In [1]:
import sys
sys.path.append('../')

In [2]:
class CFG:
    """Notebook stand-in for the command-line arguments parsed in ``main()``.

    Holds the dataset choice, pretrained-model location, batch sizes,
    learning rates and early-stopping patience used by the cells below.

    Any default can be overridden by keyword, e.g. ``CFG(seed=42, max_len=256)``.
    An unknown keyword raises ``TypeError`` so a typo cannot silently leave the
    default in place.
    """

    def __init__(self, **overrides):
        self.dataset = 'DuEE1.0'
        self.event_type = 'trigger'
        self.max_len = 200
        self.per_gpu_train_batch_size = 16
        self.per_gpu_eval_batch_size = 32
        # NOTE(review): machine-specific absolute path; override via CFG(model_name_or_path=...)
        self.model_name_or_path = 'F:/prev_trained_model/rbt3'
        self.linear_learning_rate = 1e-4   # lr for the linear classification heads
        self.early_stop = 5                # epochs without improvement before stopping
        self.seed = 1
        self.output_dir = '../output'
        self.num_train_epochs = 50
        self.weight_decay = 0.01
        self.learning_rate = 1e-5          # lr for the BERT encoder
        self.adam_epsilon = 1e-8

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError("CFG got an unexpected keyword argument {!r}".format(name))
            setattr(self, name, value)

In [3]:
import copy
import json
import os
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast, AdamW

from dataset.dataset import joint_collate_fn,DuEEJointDataset
from metric.metric import ChunkEvaluator
from model.model import DuEEEvent_model
from utils.finetuning_argparse import get_argparse
from utils.utils import init_logger, seed_everything, logger, ProgressBar





def main():
    """Command-line entry point for joint trigger + role training.

    Mirrors the notebook cells below, with arguments from ``get_argparse()``.
    It reads data relative to the working directory (``./data``, ``./conf``,
    ``./log``). It trains ``DuEEEvent_joint_model`` with ``train``/``evaluate``
    and saves the state dict whenever trigger F1 + role F1 improves. It stops
    after ``args.early_stop`` consecutive epochs without improvement.

    ``DuEEEvent_joint_model``, ``train`` and ``evaluate`` are module-level names
    defined in later cells; they are resolved when ``main()`` is called.
    """
    args = get_argparse().parse_args()
    print(json.dumps(vars(args), sort_keys=True, indent=4, separators=(', ', ': '), ensure_ascii=False))
    init_logger(log_file="./log/{}.log".format(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())))
    seed_everything(args.seed)

    args.output_model_path = os.path.join(args.output_dir, args.dataset, args.event_type, "best_model.pkl")
    # Create the checkpoint directory
    os.makedirs(os.path.dirname(args.output_model_path), exist_ok=True)

    # device
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # tokenizer
    tokenizer = BertTokenizerFast.from_pretrained(args.model_name_or_path)

    # Joint dataset: trigger and role files are loaded together.
    args.trigger_train_data = "./data/{}/trigger/train.tsv".format(args.dataset)
    args.trigger_dev_data = "./data/{}/trigger/dev.tsv".format(args.dataset)
    args.trigger_tag_path = "./conf/{}/trigger_tag.dict".format(args.dataset)
    args.role_train_data = "./data/{}/role/train.tsv".format(args.dataset)
    args.role_dev_data = "./data/{}/role/dev.tsv".format(args.dataset)
    args.role_tag_path = "./conf/{}/role_tag.dict".format(args.dataset)

    train_dataset = DuEEJointDataset(args,
                                     args.trigger_train_data,
                                     args.role_train_data,
                                     args.trigger_tag_path,
                                     args.role_tag_path,
                                     tokenizer)
    eval_dataset = DuEEJointDataset(args,
                                    args.trigger_dev_data,
                                    args.role_dev_data,
                                    args.trigger_tag_path,
                                    args.role_tag_path,
                                    tokenizer)
    logger.info("The nums of the train_dataset features is {}".format(len(train_dataset)))
    logger.info("The nums of the eval_dataset features is {}".format(len(eval_dataset)))
    train_iter = DataLoader(train_dataset,
                            shuffle=True,
                            batch_size=args.per_gpu_train_batch_size,
                            collate_fn=joint_collate_fn,
                            num_workers=20)
    eval_iter = DataLoader(eval_dataset,
                           shuffle=False,
                           batch_size=args.per_gpu_eval_batch_size,
                           collate_fn=joint_collate_fn,
                           num_workers=20)

    # Label vocabularies and chunk-level metrics for evaluation
    args.tagger_id2label = train_dataset.tagger_dataset.label_vocab
    args.num_tagger_classes = len(args.tagger_id2label)
    tagger_metric = ChunkEvaluator(label_list=args.tagger_id2label.keys(), suffix=False)
    args.role_id2label = train_dataset.role_dataset.label_vocab
    args.num_role_classes = len(args.role_id2label)
    role_metric = ChunkEvaluator(label_list=args.role_id2label.keys(), suffix=False)

    # model
    model = DuEEEvent_joint_model(args.model_name_or_path,
                                  num_tagger_classes=args.num_tagger_classes,
                                  num_role_class=args.num_role_classes)
    model.to(args.device)

    best_score = 0   # best trigger F1 + role F1 so far
    early_stop = 0   # consecutive epochs without improvement
    for epoch in range(int(args.num_train_epochs)):
        model.train()
        train(args, train_iter, model)
        (trigger_p, trigger_r, trigger_f1,
         role_p, role_r, role_f1) = evaluate(args, eval_iter, model, tagger_metric, role_metric)
        logger.info(
            "The trigger F1-score is {} , the role F1-score is {}".format(trigger_f1, role_f1)
        )
        score = trigger_f1 + role_f1
        if score > best_score:
            early_stop = 0
            best_score = score
            logger.info("the best trigger eval f1 is {:.4f} , role eval f1 is {:.4f}, saving model !!".format(trigger_f1, role_f1))
            # Save the live weights directly; a deepcopy of the whole model is not needed for state_dict().
            model_to_save = model.module if hasattr(model, "module") else model
            torch.save(model_to_save.state_dict(), args.output_model_path)
        else:
            early_stop += 1
            if early_stop >= args.early_stop:
                logger.info("Early stop in {} epoch!".format(epoch))
                break

In [4]:
args = CFG()

print(json.dumps(vars(args), sort_keys=True, indent=4, separators=(', ', ': '), ensure_ascii=False))
log_file = ".././log/{}.log".format(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
os.makedirs(os.path.dirname(log_file), exist_ok=True)  # make sure the log directory exists before logging to it
init_logger(log_file=log_file)
seed_everything(args.seed)

args.output_model_path = os.path.join(args.output_dir, args.dataset, args.event_type, "best_model.pkl")
# Create the checkpoint directory
os.makedirs(os.path.dirname(args.output_model_path), exist_ok=True)

# device
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# tokenizer
tokenizer = BertTokenizerFast.from_pretrained(args.model_name_or_path)

# Trigger dataset (paths follow args.dataset)
args.trigger_train_data = "../data/{}/trigger/train.tsv".format(args.dataset)
args.trigger_dev_data = "../data/{}/trigger/dev.tsv".format(args.dataset)
args.trigger_tag_path = "../conf/{}/trigger_tag.dict".format(args.dataset)

# Role dataset
args.role_train_data = "../data/{}/role/train.tsv".format(args.dataset)
args.role_dev_data = "../data/{}/role/dev.tsv".format(args.dataset)
args.role_tag_path = "../conf/{}/role_tag.dict".format(args.dataset)

train_dataset = DuEEJointDataset(args,
                                 args.trigger_train_data,
                                 args.role_train_data,
                                 args.trigger_tag_path,
                                 args.role_tag_path,
                                 tokenizer)

eval_dataset = DuEEJointDataset(args,
                                args.trigger_dev_data,
                                args.role_dev_data,
                                args.trigger_tag_path,
                                args.role_tag_path,
                                tokenizer)

logger.info("The nums of the train_dataset features is {}".format(len(train_dataset)))
logger.info("The nums of the eval_dataset features is {}".format(len(eval_dataset)))
# NOTE(review): main() uses 20 workers; 0 here presumably avoids DataLoader
# multiprocessing problems inside Jupyter — confirm.
train_iter = DataLoader(train_dataset,
                        shuffle=True,
                        batch_size=args.per_gpu_train_batch_size,
                        collate_fn=joint_collate_fn,
                        num_workers=0)
eval_iter = DataLoader(eval_dataset,
                       shuffle=False,
                       batch_size=args.per_gpu_eval_batch_size,
                       collate_fn=joint_collate_fn,
                       num_workers=0)

# Label vocabularies and metrics for evaluation.
# NOTE(review): .keys() is passed as the label list, so label_vocab is presumably
# a label -> id dict despite the *_id2label name — confirm in DuEEJointDataset.
args.tagger_id2label = train_dataset.tagger_dataset.label_vocab
args.num_tagger_classes = len(args.tagger_id2label)
tagger_metric = ChunkEvaluator(label_list=args.tagger_id2label.keys(), suffix=False)
tagegr_metric = tagger_metric  # original misspelled name, still used by the training-loop cell

args.role_id2label = train_dataset.role_dataset.label_vocab
args.num_role_classes = len(args.role_id2label)
role_metric = ChunkEvaluator(label_list=args.role_id2label.keys(), suffix=False)

{
    "adam_epsilon": 1e-08, 
    "dataset": "DuEE1.0", 
    "early_stop": 5, 
    "event_type": "trigger", 
    "learning_rate": 1e-05, 
    "linear_learning_rate": 0.0001, 
    "max_len": 200, 
    "model_name_or_path": "F:/prev_trained_model/rbt3", 
    "num_train_epochs": 50, 
    "output_dir": "../output", 
    "per_gpu_eval_batch_size": 32, 
    "per_gpu_train_batch_size": 16, 
    "seed": 1, 
    "weight_decay": 0.01
}


tokenizing...: 100%|███████████████████████████████████████████████████████████| 11958/11958 [00:07<00:00, 1629.01it/s]
tokenizing...: 100%|███████████████████████████████████████████████████████████| 13915/13915 [00:09<00:00, 1530.08it/s]
tokenizing...: 100%|█████████████████████████████████████████████████████████████| 1498/1498 [00:00<00:00, 1615.91it/s]
tokenizing...: 100%|█████████████████████████████████████████████████████████████| 1790/1790 [00:01<00:00, 1507.97it/s]
05/14/2021 16:29:12 - INFO - root -   The nums of the train_dataset features is 13915
05/14/2021 16:29:12 - INFO - root -   The nums of the eval_dataset features is 1790


In [6]:
from torch import nn
from transformers import BertModel
from torchcrf import CRF


class DuEEEvent_joint_model(nn.Module):
    """BERT encoder with two token-level linear heads on a shared encoding.

    ``tagger_classifier`` is sized by the trigger tag vocabulary and
    ``role_classifier`` by the role tag vocabulary. No CRF layer is used:
    outputs are raw logits that callers decode with argmax and train with
    token-level cross-entropy.
    """

    def __init__(self, pretrained_model_path, num_tagger_classes, num_role_class):
        """
        Args:
            pretrained_model_path: name or directory passed to ``BertModel.from_pretrained``.
            num_tagger_classes: output size of the trigger head.
            num_role_class: output size of the role head.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_path)
        hidden_size = self.bert.config.hidden_size
        self.tagger_classifier = nn.Linear(hidden_size, num_tagger_classes)
        self.role_classifier = nn.Linear(hidden_size, num_role_class)

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                attention_mask=None,
                tagger_labels=None,
                role_labels=None):
        """Return ``(tagger_logits, role_logits)`` for every token.

        ``tagger_labels`` and ``role_labels`` are accepted so a whole batch dict
        can be passed through, but they are ignored; the loss is computed by
        the caller. Both heads are applied to the encoder's first output (the
        per-token hidden states), presumably ``(batch, seq_len, hidden)`` —
        the training code reshapes the logits with ``view(-1, num_classes)``.
        """
        sequence_output = self.bert(input_ids,
                                    token_type_ids=token_type_ids,
                                    attention_mask=attention_mask)[0]
        tagger_logits = self.tagger_classifier(sequence_output)
        role_logits = self.role_classifier(sequence_output)
        return tagger_logits, role_logits

In [7]:
# model
# Build the joint trigger/role model and move it to the configured device.
# nn.Module.to() returns the module itself, so the call is chained instead of
# assigning its result to a throwaway variable.
model = DuEEEvent_joint_model(
    args.model_name_or_path,
    num_tagger_classes=args.num_tagger_classes,
    num_role_class=args.num_role_classes,
).to(args.device)

In [10]:
def _build_optimizer(args, model):
    """Build AdamW with separate lr for the BERT encoder and the two linear heads.

    Parameters whose name contains "bias" or "LayerNorm.weight" get no weight decay.
    """
    no_decay = ["bias", "LayerNorm.weight"]

    def _param_groups(named_params, lr):
        # Split one parameter set into a weight-decayed group and a no-decay group.
        decayed = [p for n, p in named_params if not any(nd in n for nd in no_decay)]
        undecayed = [p for n, p in named_params if any(nd in n for nd in no_decay)]
        return [
            {'params': decayed, 'weight_decay': args.weight_decay, 'lr': lr},
            {'params': undecayed, 'weight_decay': 0.0, 'lr': lr},
        ]

    bert_params = list(model.bert.named_parameters())
    head_params = (list(model.tagger_classifier.named_parameters())
                   + list(model.role_classifier.named_parameters()))
    optimizer_grouped_parameters = (_param_groups(bert_params, args.learning_rate)
                                    + _param_groups(head_params, args.linear_learning_rate))
    return AdamW(optimizer_grouped_parameters,
                 lr=args.learning_rate,
                 eps=args.adam_epsilon)


def train(args, train_iter, model, optimizer=None):
    """Run one epoch of joint trigger + role training.

    The loss is the sum of two token-level cross-entropy losses (trigger head
    and role head). A running mean of it is shown in the progress bar.

    Args:
        args: config object; reads ``device``, ``num_tagger_classes``,
            ``num_role_classes``, ``learning_rate``, ``linear_learning_rate``,
            ``weight_decay`` and ``adam_epsilon``.
        train_iter: iterable of dicts of tensors with keys ``all_input_ids``,
            ``all_attention_mask``, ``all_token_type_ids``,
            ``all_tagger_labels`` and ``all_role_labels``.
        model: joint model exposing ``bert``, ``tagger_classifier`` and
            ``role_classifier`` and returning ``(tagger_logits, role_logits)``.
        optimizer: optimizer to reuse. When ``None`` (the default and the
            original behaviour), a fresh AdamW is built, which restarts its
            moment estimates on every call. Pass the returned optimizer back
            on later epochs to keep that state.

    Returns:
        The optimizer used for this epoch.
    """
    logger.info("***** Running train *****")
    if optimizer is None:
        optimizer = _build_optimizer(args, model)
    # Targets equal to -1 are excluded from the loss
    # (presumably padding / special tokens — confirm in the dataset code).
    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(args.device)

    model.train()
    model.zero_grad()  # start the epoch from clean gradients
    batch_loss = 0
    pbar = ProgressBar(n_total=len(train_iter), desc='Training')
    print("****" * 20)
    for step, batch in enumerate(train_iter):
        for key in batch.keys():
            batch[key] = batch[key].to(args.device)
        tagger_logits, role_logits = model(
            input_ids=batch['all_input_ids'],
            attention_mask=batch['all_attention_mask'],
            token_type_ids=batch['all_token_type_ids'],
            tagger_labels=batch['all_tagger_labels'],
            role_labels=batch['all_role_labels']
        )

        # Flatten to (tokens, classes) for token-level cross-entropy.
        trigger_loss = criterion(tagger_logits.view(-1, args.num_tagger_classes),
                                 batch["all_tagger_labels"].view(-1))
        role_loss = criterion(role_logits.view(-1, args.num_role_classes),
                              batch["all_role_labels"].view(-1))
        loss = trigger_loss + role_loss
        loss.backward()

        batch_loss += loss.item()
        pbar(step,
             {
                 'batch_loss': batch_loss / (step + 1),
             })
        optimizer.step()
        model.zero_grad()
    return optimizer

In [14]:
def evaluate(args, eval_iter, model, tagger_metric, role_metric):
    """Score the joint model on ``eval_iter`` with the two chunk metrics.

    Both metrics are reset first, then updated batch by batch. Each metric gets
    the argmax predictions of its head, that head's labels and the batch's
    ``all_seq_lens``.

    Args:
        args: config object; only ``args.device`` is read.
        eval_iter: iterable of dicts of tensors with keys ``all_input_ids``,
            ``all_attention_mask``, ``all_token_type_ids``,
            ``all_tagger_labels``, ``all_role_labels`` and ``all_seq_lens``.
        model: joint model returning ``(tagger_logits, role_logits)``.
        tagger_metric, role_metric: objects providing ``reset``, ``compute``,
            ``update`` and ``accumulate`` (e.g. ``ChunkEvaluator``).

    Returns:
        ``(trigger_precision, trigger_recall, trigger_f1,
        role_precision, role_recall, role_f1)`` from the metrics' ``accumulate()``.

    Note: the model is left in eval mode; call ``model.train()`` before training again.
    """
    tagger_metric.reset()
    role_metric.reset()

    model.eval()
    with torch.no_grad():
        for batch in eval_iter:
            batch = {key: value.to(args.device) for key, value in batch.items()}
            tagger_logits, role_logits = model(
                input_ids=batch['all_input_ids'],
                attention_mask=batch['all_attention_mask'],
                token_type_ids=batch['all_token_type_ids'],
                tagger_labels=batch['all_tagger_labels'],
                role_labels=batch['all_role_labels']
            )
            seq_lens = batch["all_seq_lens"]
            # Same update for both heads, trigger first then role.
            for logits, labels, metric in ((tagger_logits, batch['all_tagger_labels'], tagger_metric),
                                           (role_logits, batch['all_role_labels'], role_metric)):
                preds = torch.argmax(logits, dim=-1)
                n_infer, n_label, n_correct = metric.compute(seq_lens, preds, labels)
                metric.update(n_infer, n_label, n_correct)

    trigger_p, trigger_r, trigger_f1 = tagger_metric.accumulate()
    role_p, role_r, role_f1 = role_metric.accumulate()
    return trigger_p, trigger_r, trigger_f1, role_p, role_r, role_f1

In [16]:
# Train for up to args.num_train_epochs epochs. Checkpoint whenever
# trigger F1 + role F1 improves; stop once args.early_stop consecutive
# epochs pass without improvement.
best_f1 = 0       # best (trigger F1 + role F1) seen so far
early_stop = 0    # consecutive epochs without improvement
for epoch in range(int(args.num_train_epochs)):
    model.train()
    train(args, train_iter, model)
    (eval_p, eval_r, eval_f1,
     eval_p2, eval_r2, eval_f12) = evaluate(args, eval_iter, model, tagegr_metric, role_metric)
    logger.info(
        "The trigger F1-score is {} , the role F1-score is {}".format(eval_f1, eval_f12)
    )

    sumf1 = eval_f1 + eval_f12
    if not (sumf1 > best_f1):
        # No improvement this epoch: count it and maybe stop.
        early_stop += 1
        if early_stop == args.early_stop:
            logger.info("Early stop in {} epoch!".format(epoch))
            break
        continue

    # Improvement: reset patience and snapshot the model.
    early_stop = 0
    best_f1 = sumf1
    logger.info("the best trigger eval f1 is {:.4f} , role eval f1 is {:.4f}, saving model !!".format(eval_f1, eval_f12))
    unwrapped = model.module if hasattr(model, "module") else model
    best_model = copy.deepcopy(unwrapped)
    torch.save(best_model.state_dict(), args.output_model_path)

05/14/2021 16:38:36 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 16:40:28 - INFO - root -   The trigger F1-score is 0.7793240556660039 , the role F1-score is 0.4533152909336942
05/14/2021 16:40:28 - INFO - root -   the best trigger eval f1 is 0.7793 , role eval f1 is 0.4533, saving model !!
05/14/2021 16:40:29 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 16:42:22 - INFO - root -   The trigger F1-score is 0.7924596050269299 , the role F1-score is 0.47027285782671135
05/14/2021 16:42:22 - INFO - root -   the best trigger eval f1 is 0.7925 , role eval f1 is 0.4703, saving model !!
05/14/2021 16:42:23 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 16:44:18 - INFO - root -   The trigger F1-score is 0.782959970620639 , the role F1-score is 0.4859230394544569
05/14/2021 16:44:18 - INFO - root -   the best trigger eval f1 is 0.7830 , role eval f1 is 0.4859, saving model !!
05/14/2021 16:44:19 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 16:46:13 - INFO - root -   The trigger F1-score is 0.7898433279308482 , the role F1-score is 0.49415975885455915
05/14/2021 16:46:13 - INFO - root -   the best trigger eval f1 is 0.7898 , role eval f1 is 0.4942, saving model !!
05/14/2021 16:46:14 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 16:48:09 - INFO - root -   The trigger F1-score is 0.7957423380436777 , the role F1-score is 0.5006423559640282
05/14/2021 16:48:09 - INFO - root -   the best trigger eval f1 is 0.7957 , role eval f1 is 0.5006, saving model !!
05/14/2021 16:48:10 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 16:50:05 - INFO - root -   The trigger F1-score is 0.7921114528677402 , the role F1-score is 0.49222499520061436
05/14/2021 16:50:05 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 16:51:58 - INFO - root -   The trigger F1-score is 0.788222384784199 , the role F1-score is 0.5000495687518588
05/14/2021 16:51:58 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 16:53:53 - INFO - root -   The trigger F1-score is 0.7897417242633685 , the role F1-score is 0.5011448481831757
05/14/2021 16:53:53 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 16:55:49 - INFO - root -   The trigger F1-score is 0.7959633027522935 , the role F1-score is 0.5097256857855362
05/14/2021 16:55:49 - INFO - root -   the best trigger eval f1 is 0.7960 , role eval f1 is 0.5097, saving model !!
05/14/2021 16:55:49 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 16:57:45 - INFO - root -   The trigger F1-score is 0.7956673398200843 , the role F1-score is 0.5049975739932072
05/14/2021 16:57:45 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 16:59:40 - INFO - root -   The trigger F1-score is 0.7961271465107781 , the role F1-score is 0.5029061457837065
05/14/2021 16:59:40 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 17:01:36 - INFO - root -   The trigger F1-score is 0.8024327312937708 , the role F1-score is 0.503968253968254
05/14/2021 17:01:36 - INFO - root -   the best trigger eval f1 is 0.8024 , role eval f1 is 0.5040, saving model !!
05/14/2021 17:01:37 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 17:03:31 - INFO - root -   The trigger F1-score is 0.7930007446016382 , the role F1-score is 0.5005988023952096
05/14/2021 17:03:31 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 17:05:25 - INFO - root -   The trigger F1-score is 0.8020813975097565 , the role F1-score is 0.5084978878082327
05/14/2021 17:05:25 - INFO - root -   the best trigger eval f1 is 0.8021 , role eval f1 is 0.5085, saving model !!
05/14/2021 17:05:25 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 17:07:19 - INFO - root -   The trigger F1-score is 0.7957963399166517 , the role F1-score is 0.5096942220507379
05/14/2021 17:07:19 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 17:09:13 - INFO - root -   The trigger F1-score is 0.7837131116921927 , the role F1-score is 0.5098786480794304
05/14/2021 17:09:13 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 17:11:07 - INFO - root -   The trigger F1-score is 0.7843210802700675 , the role F1-score is 0.5086022600020361
05/14/2021 17:11:07 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 17:13:01 - INFO - root -   The trigger F1-score is 0.7943288528816057 , the role F1-score is 0.5080294919705081
05/14/2021 17:13:01 - INFO - root -   ***** Running train *****


********************************************************************************

05/14/2021 17:14:55 - INFO - root -   The trigger F1-score is 0.793943870014771 , the role F1-score is 0.5133176654915785
05/14/2021 17:14:55 - INFO - root -   Early stop in 18 epoch!
