In [1]:
import os
import sys
import math
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from tqdm import tqdm_notebook as tqdm, tnrange
from transformers import BertModel, BertTokenizer
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

# Resolve the notebook's current working directory with the standard library
# instead of the IPython-only `%pwd` magic, so this cell is also valid plain
# Python (e.g. when the notebook is exported to a script). `%pwd` returns the
# same value as os.getcwd().
w_dir = os.getcwd()
# Parent directory of the working directory.
work_dir = os.path.dirname(w_dir)
work_dir

I1213 03:58:40.047722 140007454709568 file_utils.py:39] PyTorch version 1.1.0 available.


'/work'

In [2]:
from fgc_support_retri import config
from fgc_support_retri.sup_model import BertContextSupModel_V3
from fgc_support_retri.fgc_preprocess import BertV3Idx, SerContextDataset, bert_collate_v3
from fgc_support_retri.utils import read_fgc
from fgc_support_retri.eval import evalaluate_f1

I1213 03:58:40.492198 140007454709568 corenlp.py:42] Using an existing server http://140.109.19.191:9000
I1213 03:58:41.495132 140007454709568 corenlp.py:118] The server is available.


In [3]:
# Name of the pretrained BERT checkpoint, taken from the project config.
bert_model_name = config.BERT_EMBEDDING


def _sents_per_item(item):
    """Sort key: the length of the item's 'SENTS' list."""
    return len(item['SENTS'])


# Load both splits and order each one by sentence count, largest first.
train_items = read_fgc(config.FGC_TRAIN, eval=True)
train_items.sort(key=_sents_per_item, reverse=True)
dev_items = read_fgc(config.FGC_DEV, eval=True)
dev_items.sort(key=_sents_per_item, reverse=True)

tokenizer = BertTokenizer.from_pretrained(bert_model_name)

# Each split gets its own transform pipeline instance.
train_transform = torchvision.transforms.Compose([BertV3Idx(tokenizer, 50)])
train_set = SerContextDataset(train_items, transform=train_transform)
dev_transform = torchvision.transforms.Compose([BertV3Idx(tokenizer, 50)])
dev_set = SerContextDataset(dev_items, transform=dev_transform)

# No shuffling on either loader, so the sorted order above is preserved.
dataloader_train = DataLoader(
    train_set,
    batch_size=2,
    shuffle=False,
    collate_fn=bert_collate_v3,
)
dataloader_dev = DataLoader(
    dev_set,
    batch_size=64,
    shuffle=False,
    collate_fn=bert_collate_v3,
)

no gold supporting evidence
{'QID': 'D001Q11', 'QTYPE': '申论', 'QTEXT': '蘇東坡為何被後人認為是文學藝術史上的通才?', 'SENTS': [{'text': '苏东坡为何被后人认为是文学艺术史上的通才?', 'start': 0, 'end': 21}], 'ANSWER': [{'ATEXT': '', 'ATOKEN': [{'text': '', 'start': 0}], 'ATEXT_CN': ''}], 'ATYPE': 'Event', 'AMODE': 'Single-Span-Extraction', 'ASPAN': [], 'SHINT': [], 'QTEXT_CN': '苏东坡为何被后人认为是文学艺术史上的通才?'}
no gold supporting evidence
{'QID': 'D006Q02', 'QTYPE': '申论', 'QTEXT': '「阿拉伯之春」運動中，走上街頭的民眾的訴求為何?', 'SENTS': [{'text': '「阿拉伯之春」运动中，', 'start': 0, 'end': 11}, {'text': '走上街头的民众的诉求为何?', 'start': 11, 'end': 24}], 'ANSWER': [{'ATEXT': '', 'ATOKEN': [{'text': '', 'start': 0}], 'ATEXT_CN': ''}], 'ATYPE': 'Object', 'AMODE': 'Single-Span-Extraction', 'ASPAN': [], 'SHINT': [], 'QTEXT_CN': '「阿拉伯之春」运动中，走上街头的民众的诉求为何?'}
no gold supporting evidence
{'QID': 'D048Q09', 'QTYPE': '申论', 'QTEXT': '聊天機器人仰賴哪些方法讓回答愈來愈準確?', 'SENTS': [{'text': '聊天机器人仰赖哪些方法让回答愈来愈准确?', 'start': 0, 'end': 20}], 'ANSWER': [{'ATEXT': '', 'ATOKEN': [{'text': '', 'start': 0}], 'A

I1213 03:58:42.465002 140007454709568 tokenization_utils.py:375] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt from cache at /root/.cache/torch/transformers/8a0c070123c1f794c42a29c6904beb7c1b8715741e235bee04aca2c7636fc83f.9b42061518a39ca00b8b52059fd2bede8daa613f8a8671500e518a8c29de8c00


In [4]:
# for batch in dataloader_train:
#     print(batch['sentences']['input_ids'].shape)
#     import pdb; pdb.set_trace()

In [5]:
def train_BertContextSupModel_V3(num_epochs, batch_size, model_file_name):
    """Fine-tune ``BertContextSupModel_V3`` on the FGC training split.

    The function builds the model on top of the BERT encoder named by
    ``config.BERT_EMBEDDING``, trains it with AdamW and a linear
    warmup/decay learning-rate schedule, and periodically evaluates on the
    FGC dev split. After every evaluation the model's ``state_dict`` is
    written under ``config.TRAINED_MODELS / model_file_name`` with the epoch
    index, recall and F1 embedded in the file name.

    Args:
        num_epochs: Number of passes over the training set. The learning-rate
            schedule is sized for exactly this many epochs.
        batch_size: Training batch size (the dev loader always uses 64).
        model_file_name: Name of the sub-directory of
            ``config.TRAINED_MODELS`` that receives the checkpoints; it is
            created if it does not exist.

    Returns:
        None. Progress is printed and checkpoints are saved to disk.
    """
    torch.manual_seed(12)
    bert_model_name = config.BERT_EMBEDDING
    warmup_proportion = 0.1
    learning_rate = 5e-5
    eval_frequency = 5
    # A sentence index is predicted as supporting evidence when its score
    # reaches this threshold.
    score_threshold = 0.2

    trained_model_path = config.TRAINED_MODELS / model_file_name
    # makedirs(exist_ok=True) also creates missing parents and avoids the
    # check-then-create race of exists() + mkdir().
    os.makedirs(trained_model_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    bert_encoder = BertModel.from_pretrained(bert_model_name)
    model = BertContextSupModel_V3(bert_encoder, device)

    model.to(device)
    if n_gpu > 1:
        model = nn.DataParallel(model)

    param_optimizer = list(model.named_parameters())

    # Parameters whose name contains one of these substrings are excluded
    # from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    # Read both splits and order each by sentence count, largest first.
    train_items = read_fgc(config.FGC_TRAIN, eval=True)
    train_items.sort(key=lambda item: len(item['SENTS']), reverse=True)
    dev_items = read_fgc(config.FGC_DEV, eval=True)
    dev_items.sort(key=lambda item: len(item['SENTS']), reverse=True)

    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    train_set = SerContextDataset(train_items, transform=torchvision.transforms.Compose([BertV3Idx(tokenizer, 50)]))
    dev_set = SerContextDataset(dev_items, transform=torchvision.transforms.Compose([BertV3Idx(tokenizer, 50)]))

    # shuffle=False keeps the sorted order; the dev order must stay aligned
    # with dev_items because predictions are matched to it by position.
    dataloader_train = DataLoader(train_set, batch_size=batch_size, shuffle=False, collate_fn=bert_collate_v3)
    dataloader_dev = DataLoader(dev_set, batch_size=64, shuffle=False, collate_fn=bert_collate_v3)

    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
    num_train_optimization_steps = int(math.ceil(len(train_set) / batch_size)) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=int(num_train_optimization_steps * warmup_proportion),
                                                num_training_steps=num_train_optimization_steps)

    print('start training ... ')
    # Train for exactly num_epochs epochs so the loop matches the scheduler's
    # step budget; an extra epoch would run after the linear schedule has
    # already decayed the learning rate to zero.
    for epoch_i in range(num_epochs):
        model.train()
        running_loss = 0.0
        for batch_i, batch in enumerate(tqdm(dataloader_train)):
            optimizer.zero_grad()
            question = {key: tensor.to(dtype=torch.int64, device=device) for key, tensor in batch['question'].items()}
            sentences = {key: tensor.to(dtype=torch.int64, device=device) for key, tensor in batch['sentences'].items()}
            labels = batch['label'].to(dtype=torch.float, device=device)

            loss, _ = model(question, sentences, batch['batch_config'], labels=labels)

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.

            loss.backward()
            optimizer.step()
            scheduler.step()
            running_loss += loss.item()

        print('epoch %d train_loss: %.3f' % (epoch_i, running_loss / len(dataloader_train)))

        # Evaluate every eval_frequency epochs, and always after the final
        # epoch so the fully trained model is scored and saved.
        is_last_epoch = epoch_i == num_epochs - 1
        if epoch_i % eval_frequency == 0 or is_last_epoch:
            model.eval()
            score_list = []
            with torch.no_grad():
                for batch in dataloader_dev:
                    question = {key: tensor.to(dtype=torch.int64, device=device) for key, tensor in batch['question'].items()}
                    sentences = {key: tensor.to(dtype=torch.int64, device=device) for key, tensor in batch['sentences'].items()}

                    score = model(question, sentences, batch['batch_config'], mode=BertContextSupModel_V3.ForwardMode.EVAL)
                    # Presumably one row of per-sentence scores per dev item
                    # -- TODO confirm against the model's EVAL output.
                    score_list += score.cpu().numpy().tolist()

            # Exactly one prediction list per scored row. The previous code
            # appended inside the per-sentence loop, producing one
            # (aliased) entry per sentence and misaligning predictions
            # with dev_items.
            predictions = [
                [s_i for s_i, s in enumerate(scores) if s >= score_threshold]
                for scores in score_list
            ]

            precision, recall, f1 = evalaluate_f1(dev_items, predictions)
            print('epoch %d eval_recall: %.3f eval_f1: %.3f' % (epoch_i, recall, f1))

            # Unwrap DataParallel so the checkpoint keys carry no 'module.' prefix.
            model_to_save = model.module if hasattr(model, 'module') else model
            torch.save(model_to_save.state_dict(),
                       str(trained_model_path / "model_epoch{0}_eval_recall_{1:.3f}_f1_{2:.3f}.m".format(epoch_i, recall, f1)))

In [None]:
# Launch training: 100 epochs, training batch size 6, checkpoints stored
# under the given run directory name.
train_BertContextSupModel_V3(
    num_epochs=100,
    batch_size=6,
    model_file_name='02191212_BertContextSupModel_V3_mul_test2',
)

I1213 03:58:43.427174 140007454709568 configuration_utils.py:152] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json from cache at /root/.cache/torch/transformers/8a3b1cfe5da58286e12a0f5d7d182b8d6eca88c08e26c332ee3817548cf7e60a.0c16faba8be66db3f02805c912e4cf94d3c9cffc1f12fa1a39906f9270f76d33
I1213 03:58:43.430310 140007454709568 configuration_utils.py:169] Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "po

no gold supporting evidence
{'QID': 'D001Q11', 'QTYPE': '申论', 'QTEXT': '蘇東坡為何被後人認為是文學藝術史上的通才?', 'SENTS': [{'text': '苏东坡为何被后人认为是文学艺术史上的通才?', 'start': 0, 'end': 21}], 'ANSWER': [{'ATEXT': '', 'ATOKEN': [{'text': '', 'start': 0}], 'ATEXT_CN': ''}], 'ATYPE': 'Event', 'AMODE': 'Single-Span-Extraction', 'ASPAN': [], 'SHINT': [], 'QTEXT_CN': '苏东坡为何被后人认为是文学艺术史上的通才?'}
no gold supporting evidence
{'QID': 'D006Q02', 'QTYPE': '申论', 'QTEXT': '「阿拉伯之春」運動中，走上街頭的民眾的訴求為何?', 'SENTS': [{'text': '「阿拉伯之春」运动中，', 'start': 0, 'end': 11}, {'text': '走上街头的民众的诉求为何?', 'start': 11, 'end': 24}], 'ANSWER': [{'ATEXT': '', 'ATOKEN': [{'text': '', 'start': 0}], 'ATEXT_CN': ''}], 'ATYPE': 'Object', 'AMODE': 'Single-Span-Extraction', 'ASPAN': [], 'SHINT': [], 'QTEXT_CN': '「阿拉伯之春」运动中，走上街头的民众的诉求为何?'}
no gold supporting evidence
{'QID': 'D048Q09', 'QTYPE': '申论', 'QTEXT': '聊天機器人仰賴哪些方法讓回答愈來愈準確?', 'SENTS': [{'text': '聊天机器人仰赖哪些方法让回答愈来愈准确?', 'start': 0, 'end': 20}], 'ANSWER': [{'ATEXT': '', 'ATOKEN': [{'text': '', 'start': 0}], 'A

I1213 03:58:51.591800 140007454709568 tokenization_utils.py:375] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt from cache at /root/.cache/torch/transformers/8a0c070123c1f794c42a29c6904beb7c1b8715741e235bee04aca2c7636fc83f.9b42061518a39ca00b8b52059fd2bede8daa613f8a8671500e518a8c29de8c00


start training ... 


HBox(children=(IntProgress(value=0, max=123), HTML(value='')))