In [1]:
import sys

# Make the project root (one directory above this notebook) importable so the
# `dataset`, `metric`, `model` and `utils` packages resolve. Guarded so that
# re-running this cell does not keep appending duplicate entries.
if '../' not in sys.path:
    sys.path.append('../')

In [2]:
import copy
import json
import os
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast, AdamW

from dataset.dataset import collate_fn, DuEEEventDataset
from metric.metric import ChunkEvaluator
from model.model import DuEEEvent_model
from utils.finetuning_argparse import get_argparse
from utils.utils import init_logger, seed_everything, logger, ProgressBar


def evaluate(args, eval_iter, model, metric):
    """Decode ``eval_iter`` with the model's CRF and score it with ``metric``.

    Args:
        args: run configuration; only ``args.device`` is read.
        eval_iter: iterable of dict batches whose values are tensors, with keys
            ``all_input_ids``, ``all_attention_mask``, ``all_token_type_ids``,
            ``all_labels`` and ``all_seq_lens``.
        model: model exposing a ``crf`` attribute (``decode`` plus a
            log-likelihood call); called without ``labels`` it must return the
            emission scores.
        metric: evaluator providing ``reset`` / ``compute`` / ``update`` /
            ``accumulate``.

    Returns:
        ``(precision, recall, f1_score, avg_loss)``; ``avg_loss`` is the mean
        per-batch CRF negative log-likelihood (0.0 when ``eval_iter`` is empty).
    """
    metric.reset()
    total_loss = 0.0
    n_batches = 0

    model.eval()
    with torch.no_grad():
        for batch in eval_iter:
            batch = {key: value.to(args.device) for key, value in batch.items()}
            logits = model(
                input_ids=batch['all_input_ids'],
                attention_mask=batch['all_attention_mask'],
                token_type_ids=batch['all_token_type_ids'],
            )
            # Same mask dtype/semantics as the training objective, so Viterbi
            # decoding never walks through padding positions.
            crf_mask = batch['all_attention_mask'].to(torch.uint8)

            # The CRF call returns a log-likelihood; negate it to get the loss.
            loss = -model.crf(emissions=logits, tags=batch['all_labels'], mask=crf_mask)
            total_loss += loss.item()
            n_batches += 1

            # decode() yields one variable-length tag path per sequence;
            # right-pad with 0 so the paths stack into a (batch, seq_len) tensor.
            # NOTE(review): assumes all_seq_lens never exceeds the mask length,
            # so the metric never scores the padded zeros -- TODO confirm.
            paths = model.crf.decode(logits, mask=crf_mask)
            preds = torch.zeros(logits.shape[:2], dtype=torch.int)
            for row, path in enumerate(paths):
                preds[row, :len(path)] = torch.tensor(path, dtype=torch.int)

            n_infer, n_label, n_correct = metric.compute(batch["all_seq_lens"], preds, batch['all_labels'])
            metric.update(n_infer, n_label, n_correct)

    precision, recall, f1_score = metric.accumulate()
    avg_loss = total_loss / n_batches if n_batches else 0.0
    return precision, recall, f1_score, avg_loss



def main():
    """Command-line entry point: build data and model, train with early stopping.

    The best ``state_dict`` by dev F1 is written to
    ``<output_dir>/<dataset>/<event_type>/best_model.pkl``. Training stops once
    ``args.early_stop`` consecutive epochs pass without an F1 improvement.
    """
    args = get_argparse().parse_args()
    print(json.dumps(vars(args), sort_keys=True, indent=4, separators=(', ', ': '), ensure_ascii=False))
    # Ensure the log directory exists in case init_logger does not create it.
    os.makedirs("./log", exist_ok=True)
    init_logger(log_file="./log/{}.log".format(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())))
    seed_everything(args.seed)

    args.output_model_path = os.path.join(args.output_dir, args.dataset, args.event_type, "best_model.pkl")
    # Create the checkpoint directory (no-op if it already exists).
    os.makedirs(os.path.dirname(args.output_model_path), exist_ok=True)

    # Device
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenizer
    tokenizer = BertTokenizerFast.from_pretrained(args.model_name_or_path)

    # Datasets & dataloaders
    args.train_data = "./data/{}/{}/train.tsv".format(args.dataset, args.event_type)
    args.dev_data = "./data/{}/{}/dev.tsv".format(args.dataset, args.event_type)
    args.tag_path = "./conf/{}/{}_tag.dict".format(args.dataset, args.event_type)
    train_dataset = DuEEEventDataset(args,
                                     args.train_data,
                                     args.tag_path,
                                     tokenizer)
    eval_dataset = DuEEEventDataset(args,
                                    args.dev_data,
                                    args.tag_path,
                                    tokenizer)
    logger.info("The nums of the train_dataset features is {}".format(len(train_dataset)))
    logger.info("The nums of the eval_dataset features is {}".format(len(eval_dataset)))
    # Worker count is configurable (e.g. 0 where subprocess loading is
    # unavailable); the historical default of 20 is kept.
    num_workers = getattr(args, "num_workers", 20)
    train_iter = DataLoader(train_dataset,
                            shuffle=True,
                            batch_size=args.per_gpu_train_batch_size,
                            collate_fn=collate_fn,
                            num_workers=num_workers)
    eval_iter = DataLoader(eval_dataset,
                           shuffle=False,
                           batch_size=args.per_gpu_eval_batch_size,
                           collate_fn=collate_fn,
                           num_workers=num_workers)

    # Label vocabulary for the evaluator
    args.id2label = train_dataset.label_vocab
    args.num_classes = len(args.id2label)
    metric = ChunkEvaluator(label_list=args.id2label.keys(), suffix=False)

    # Model
    # NOTE(review): evaluate() decodes through `model.crf`; confirm that
    # DuEEEvent_model exposes a CRF head, otherwise evaluation will fail.
    model = DuEEEvent_model(args.model_name_or_path, num_classes=args.num_classes)
    model.to(args.device)

    best_f1 = 0
    early_stop = 0
    for epoch in range(int(args.num_train_epochs)):
        model.train()
        train(args, train_iter, model)
        eval_p, eval_r, eval_f1, eval_loss = evaluate(args, eval_iter, model, metric)
        logger.info(
            "The F1-score is {}".format(eval_f1)
        )
        if eval_f1 > best_f1:
            early_stop = 0
            best_f1 = eval_f1
            logger.info("the best eval f1 is {:.4f}, saving model !!".format(best_f1))
            # Serialize the weights directly: deep-copying the model first only
            # duplicated it in (GPU) memory before saving identical tensors.
            model_to_save = model.module if hasattr(model, "module") else model
            torch.save(model_to_save.state_dict(), args.output_model_path)
        else:
            early_stop += 1
            if early_stop == args.early_stop:
                logger.info("Early stop in {} epoch!".format(epoch))
                break

In [3]:
def _build_optimizer(args, model):
    """Build AdamW: BERT params at ``learning_rate``, all other params at ``linear_learning_rate``."""
    no_decay = ("bias", "LayerNorm.weight")
    bert_param_ids = {id(p) for p in model.bert.parameters()}
    bert_named = list(model.bert.named_parameters())
    # Everything outside the encoder (classifier, CRF transitions, ...) is a
    # "head" parameter. Listing only `model.classifier` here previously left
    # CRF transition parameters out of the optimizer, so they never trained.
    head_named = [(n, p) for n, p in model.named_parameters() if id(p) not in bert_param_ids]

    def _split(named, lr):
        # Bias / LayerNorm weights are conventionally excluded from weight decay.
        decay = [p for n, p in named if not any(nd in n for nd in no_decay)]
        no_dec = [p for n, p in named if any(nd in n for nd in no_decay)]
        return [
            {'params': decay, 'weight_decay': args.weight_decay, 'lr': lr},
            {'params': no_dec, 'weight_decay': 0.0, 'lr': lr},
        ]

    groups = _split(bert_named, args.learning_rate) + _split(head_named, args.linear_learning_rate)
    groups = [g for g in groups if g['params']]
    return AdamW(groups, lr=args.learning_rate, eps=args.adam_epsilon)


def train(args, train_iter, model, optimizer=None):
    """Run one training epoch over ``train_iter``.

    Args:
        args: run configuration; reads ``device``, ``learning_rate``,
            ``linear_learning_rate``, ``weight_decay`` and ``adam_epsilon``.
        train_iter: sized iterable (e.g. DataLoader) of dict batches whose
            values are tensors, with keys ``all_input_ids``,
            ``all_attention_mask``, ``all_token_type_ids`` and ``all_labels``.
        model: model with a ``bert`` sub-module; called with ``labels`` it must
            return the scalar training loss.
        optimizer: optimizer returned by a previous call. Passing it back keeps
            AdamW's moment estimates across epochs; when omitted a fresh
            optimizer is built (the original per-call behaviour).

    Returns:
        The optimizer that was used, so callers can reuse it next epoch.
    """
    logger.info("***** Running train *****")
    if optimizer is None:
        optimizer = _build_optimizer(args, model)

    model.train()
    running_loss = 0.0
    pbar = ProgressBar(n_total=len(train_iter), desc='Training')
    for step, batch in enumerate(train_iter):
        batch = {key: value.to(args.device) for key, value in batch.items()}
        # With labels supplied, the model's forward returns the loss itself.
        loss = model(
            input_ids=batch['all_input_ids'],
            attention_mask=batch['all_attention_mask'],
            token_type_ids=batch['all_token_type_ids'],
            labels=batch['all_labels']
        )
        loss.backward()
        optimizer.step()
        model.zero_grad()

        running_loss += loss.item()
        pbar(step, {'batch_loss': running_loss / (step + 1)})
    return optimizer

In [4]:
class CFG:
    """Notebook stand-in for the command-line arguments parsed in ``main``.

    Defaults reproduce the original run. Any attribute may be overridden by
    keyword, e.g. ``CFG(max_len=128, per_gpu_train_batch_size=8)``; unknown
    keywords raise ``TypeError`` so a typo cannot silently fall back to a
    default.
    """

    def __init__(self, **overrides):
        self.dataset = 'DuEE1.0'
        self.event_type = 'role'
        self.max_len = 200
        self.per_gpu_train_batch_size = 16
        self.per_gpu_eval_batch_size = 32
        # Local pretrained checkpoint (a lighter alternative previously used:
        # 'F:/prev_trained_model/rbt3').
        self.model_name_or_path = 'F:/prev_trained_model/chinese_wwm_pytorch'
        self.linear_learning_rate = 1e-4  # lr for parameters outside the BERT encoder
        self.early_stop = 5               # epochs without dev-F1 gain before stopping
        self.seed = 1
        self.output_dir = '../output'
        self.num_train_epochs = 50
        self.weight_decay = 0.01
        self.learning_rate = 1e-5         # lr for BERT encoder parameters
        self.adam_epsilon = 1e-8

        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError("CFG got unexpected option(s): {}".format(", ".join(unknown)))
        for name, value in overrides.items():
            setattr(self, name, value)

In [5]:
args = CFG()

print(json.dumps(vars(args), sort_keys=True, indent=4, separators=(', ', ': '), ensure_ascii=False))
log_file = ".././log/{}.log".format(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
# Ensure the log directory exists in case init_logger does not create it.
os.makedirs(os.path.dirname(log_file), exist_ok=True)
init_logger(log_file=log_file)
seed_everything(args.seed)

args.output_model_path = os.path.join(args.output_dir, args.dataset, args.event_type, "best_model.pkl")
# Create the checkpoint directory (no-op if it already exists).
os.makedirs(os.path.dirname(args.output_model_path), exist_ok=True)

# Device
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer
tokenizer = BertTokenizerFast.from_pretrained(args.model_name_or_path)

# Datasets & dataloaders (paths are relative to this notebook's directory)
args.train_data = "../data/{}/{}/train.tsv".format(args.dataset, args.event_type)
args.dev_data = "../data/{}/{}/dev.tsv".format(args.dataset, args.event_type)
args.tag_path = "../conf/{}/{}_tag.dict".format(args.dataset, args.event_type)
train_dataset = DuEEEventDataset(args,
                                 args.train_data,
                                 args.tag_path,
                                 tokenizer)
eval_dataset = DuEEEventDataset(args,
                                args.dev_data,
                                args.tag_path,
                                tokenizer)
logger.info("The nums of the train_dataset features is {}".format(len(train_dataset)))
logger.info("The nums of the eval_dataset features is {}".format(len(eval_dataset)))
# num_workers=0 loads batches in the notebook process itself (main() uses 20).
train_iter = DataLoader(train_dataset,
                        shuffle=True,
                        batch_size=args.per_gpu_train_batch_size,
                        collate_fn=collate_fn,
                        num_workers=0)
eval_iter = DataLoader(eval_dataset,
                       shuffle=False,
                       batch_size=args.per_gpu_eval_batch_size,
                       collate_fn=collate_fn,
                       num_workers=0)

# Label vocabulary for the evaluator
args.id2label = train_dataset.label_vocab
args.num_classes = len(args.id2label)
metric = ChunkEvaluator(label_list=args.id2label.keys(), suffix=False)



{
    "adam_epsilon": 1e-08, 
    "dataset": "DuEE1.0", 
    "early_stop": 5, 
    "event_type": "role", 
    "learning_rate": 1e-05, 
    "linear_learning_rate": 0.0001, 
    "max_len": 200, 
    "model_name_or_path": "F:/prev_trained_model/chinese_wwm_pytorch", 
    "num_train_epochs": 50, 
    "output_dir": "../output", 
    "per_gpu_eval_batch_size": 32, 
    "per_gpu_train_batch_size": 16, 
    "seed": 1, 
    "weight_decay": 0.01
}


tokenizing...: 100%|███████████████████████████████████████████████████████████| 13915/13915 [00:09<00:00, 1504.69it/s]
tokenizing...: 100%|█████████████████████████████████████████████████████████████| 1790/1790 [00:01<00:00, 1444.51it/s]
05/17/2021 14:42:41 - INFO - root -   The nums of the train_dataset features is 13915
05/17/2021 14:42:41 - INFO - root -   The nums of the eval_dataset features is 1790


In [6]:
from torch import nn
from transformers import BertModel
from torchcrf import CRF


class DuEEEvent_crf_model(nn.Module):
    """BERT encoder + linear emission layer + CRF for sequence labelling.

    ``forward`` returns the CRF negative log-likelihood when ``labels`` is
    given, using torchcrf's default reduction (summed over the batch).
    Otherwise it returns the per-token emission scores produced by the linear
    layer, one score per tag.
    """

    def __init__(self, pretrained_model_path, num_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_path)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
        # batch_first=True: emissions / tags / mask are laid out batch-major.
        self.crf = CRF(num_tags=num_classes, batch_first=True)

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                attention_mask=None,
                labels=None):
        output = self.bert(input_ids,
                           token_type_ids=token_type_ids,
                           attention_mask=attention_mask)
        sequence_output = output[0]
        logits = self.classifier(sequence_output)

        if labels is None:
            return logits
        # torchcrf takes a byte mask; with no attention mask every position counts
        # (previously this path crashed on `None.to(...)`).
        mask = attention_mask.to(torch.uint8) if attention_mask is not None else None
        log_likelihood = self.crf(emissions=logits, tags=labels, mask=mask)
        return -1 * log_likelihood

In [7]:
# Build the BERT+CRF tagger and place it on the configured device.
# nn.Module.to() returns the module itself, so the call is chained and the
# assignment keeps the cell from echoing the model repr.
model = DuEEEvent_crf_model(args.model_name_or_path, num_classes=args.num_classes).to(args.device)

In [8]:
# Train for up to `num_train_epochs`, checkpointing whenever dev F1 improves
# and stopping after `args.early_stop` consecutive epochs without improvement.
best_f1 = 0
early_stop = 0
for epoch in range(int(args.num_train_epochs)):
    model.train()
    train(args, train_iter, model)
    eval_p, eval_r, eval_f1, eval_loss = evaluate(args, eval_iter, model, metric)
    logger.info("The F1-score is {}".format(eval_f1))

    if not eval_f1 > best_f1:
        # No improvement this epoch: count it and maybe stop.
        early_stop += 1
        if early_stop == args.early_stop:
            logger.info("Early stop in {} epoch!".format(epoch))
            break
        continue

    # New best: reset the patience counter and checkpoint the weights.
    early_stop = 0
    best_f1 = eval_f1
    logger.info("the best eval f1 is {:.4f}, saving model !!".format(best_f1))
    best_model = copy.deepcopy(model.module if hasattr(model, "module") else model)
    torch.save(best_model.state_dict(), args.output_model_path)

05/17/2021 14:42:43 - INFO - root -   ***** Running train *****


********************************************************************************
torch.Size([16, 94, 243]) torch.Size([16, 94])
[Training] 1/870 [..............................] - ETA: 9:37  batch_loss: 3570.5664 

RuntimeError: CUDA out of memory. Tried to allocate 26.00 MiB (GPU 0; 8.00 GiB total capacity; 3.01 GiB already allocated; 3.09 MiB free; 3.17 GiB reserved in total by PyTorch)