In [1]:
import json
import os
import random

import torch
import torch.nn as nn
from torch.optim import AdamW
import numpy as np
from torch.utils.data import Dataset
import torch.nn.functional as F
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score
from torch.optim import Adam

from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm_notebook as tqdm

from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForSequenceClassification

from crf_layer import CRFLayer
from stackedModel import MultiLabelStackedClassification

import warnings
warnings.filterwarnings(action='ignore',category=UserWarning,module='torch')

In [2]:
def load_dict(dict_path):
    """Load a tab-separated label dictionary.

    Each non-empty line of ``dict_path`` has the form ``<index><TAB><label>``; the
    result maps every label to its integer index.

    Args:
        dict_path: path to a UTF-8 text file with one ``index<TAB>label`` pair per line.

    Returns:
        dict[str, int]: label -> index.

    Raises:
        ValueError: if a non-empty line does not split into exactly two tab-separated
            fields, or its index field is not an integer.
    """
    vocab = {}
    # `with` guarantees the file handle is closed once the dictionary is read.
    with open(dict_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip('\n')
            if not line:
                # Tolerate blank lines (e.g. an extra newline at end of file).
                continue
            value, key = line.split('\t')
            vocab[key] = int(value)
    return vocab

In [3]:
enum_role = "环节"
max_seq_len = 512

label_vocab = load_dict(dict_path='./dictionary/enum_tag.dict')

In [4]:
def enum_data_process(dataset, event_type="公司上市", role=None):
    """Extract enum-role labels for one event type from DuEE-fin style records.

    For each record, collects the distinct ``argument`` strings of every argument whose
    ``role`` equals ``role`` inside events whose ``event_type`` equals ``event_type``.
    Records that yield no such label are dropped.

    Args:
        dataset: iterable of dicts with ``"id"``, ``"text"`` and an optional
            ``"event_list"`` of ``{"event_type", "arguments": [{"role", "argument"}, ...]}``.
        event_type: event type to read labels from (default: "公司上市", company listing).
        role: argument role holding the enum value; ``None`` uses the module-level
            ``enum_role``.

    Returns:
        list of ``{"id", "text", "labels"}`` dicts. ``text`` is lower-cased with tabs
        replaced by spaces; ``labels`` keeps first-seen order without duplicates.
    """
    if role is None:
        role = enum_role
    output = []
    for d_json in dataset:
        text = d_json["text"].lower().replace("\t", " ")
        labels = []
        for event in d_json.get("event_list", []):
            if event["event_type"] != event_type:
                continue
            # Events without an "arguments" key simply contribute no labels.
            for argument in event.get("arguments", []):
                if argument["role"] == role and argument["argument"] not in labels:
                    labels.append(argument["argument"])
        if labels:
            output.append({"id": d_json['id'], "text": text, "labels": labels})
    return output

In [5]:
def set_seed(seed = 42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with one value.

    ``torch.manual_seed`` seeds the CPU generator and every CUDA device, and is safe
    to call when CUDA is unavailable.

    Args:
        seed (int, optional): value used for all three generators. Defaults to 42.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

In [6]:
# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    n_gpu = torch.cuda.device_count()
    print(torch.cuda.get_device_name(0))
    
    print('Memory Usage:')
    print('Allocated:', torch.cuda.memory_allocated(0)/1024**3, 'GB')
    print('Cached:   ', torch.cuda.memory_reserved(0)/1024**3, 'GB')
    
    print('CUDA Device Count:', n_gpu)
    
set_seed(42)  # seed random / NumPy / torch before any data or model is built

Using device: cuda

Tesla V100-PCIE-32GB
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB
CUDA Device Count: 3


In [7]:
# Short model name -> tokenizer. The key order matters: the dataset encodes each text
# with these tokenizers in this order, and the base models are loaded from
# ./models/DuEE_fin/<key>/ in the same order.
_hub_ids = {
    'ernie-base': "nghuyong/ernie-1.0",
    'roberta-chinese-base': "hfl/chinese-roberta-wwm-ext",
    'roberta-chinese-large': "hfl/chinese-roberta-wwm-ext-large",
}
model_dict = {name: AutoTokenizer.from_pretrained(hub_id) for name, hub_id in _hub_ids.items()}

In [8]:
class BaiduEnermDataset(Dataset):
    """Multi-label dataset for the DuEE-fin enum role, encoded once per tokenizer.

    Every text is tokenized by each tokenizer in ``model_dict`` (in key order) so the
    stacked ensemble can give each base model its own encoding of the same example.
    Encodings are truncated and right-padded to ``max_seq_len``.

    Args:
        dataset_path: JSON file holding a list of records; they are filtered and
            labelled by ``enum_data_process``.
        label_dict_path: tab-separated label dictionary read by ``load_dict``.
    """

    def __init__(self, dataset_path, label_dict_path):
        self.label_vocab = load_dict(label_dict_path)
        # Label indices are used directly as positions in the multi-hot vector.
        self.label_num = max(self.label_vocab.values()) + 1
        self.examples = []
        with open(dataset_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        preprocess_dataset = enum_data_process(dataset)
        # Resolve the padding / separator ids once per tokenizer instead of once per
        # example (on fast tokenizers `tokenizer.vocab` builds a fresh dict per access).
        tokenizers = [
            (tokenizer, tokenizer.pad_token_id, tokenizer.sep_token_id)
            for tokenizer in model_dict.values()
        ]
        for d_json in tqdm(preprocess_dataset, total=len(preprocess_dataset)):
            example = self._encode(d_json['text'], tokenizers)
            if 'labels' in d_json:
                example["encoded_label"] = self.to_one_hot_vector(d_json['labels'])
            self.examples.append(example)

    def _encode(self, text, tokenizers):
        """Tokenize `text` with each (tokenizer, pad_id, sep_id) and pad to `max_seq_len`.

        Returns a dict of three parallel lists, one entry per tokenizer:
        ``input_ids``, ``attention_masks`` and ``token_type_ids``.
        """
        b_input_ids, b_attention_masks, b_token_type_ids = [], [], []
        for tokenizer, pad_id, sep_id in tokenizers:
            input_ids = tokenizer(
                text, is_split_into_words=False, add_special_tokens=True,
                max_length=max_seq_len, truncation=True,
            )['input_ids']
            b_input_ids.append(input_ids + [pad_id] * (max_seq_len - len(input_ids)))
            b_attention_masks.append(self._get_attention_mask(input_ids, max_seq_len))
            b_token_type_ids.append(self._get_token_type_id(input_ids, max_seq_len, sep_token=sep_id))
        return {
            "input_ids": b_input_ids,
            "attention_masks": b_attention_masks,
            "token_type_ids": b_token_type_ids,
        }

    def to_one_hot_vector(self, labels):
        """Convert a list of label strings to a multi-hot float vector of length `label_num`."""
        one_hot_vector = np.zeros(self.label_num)
        for label in labels:
            one_hot_vector[self.label_vocab[label]] = 1
        return one_hot_vector

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item_idx):
        """Return one example as tensors (ids as int64, the label vector as float32)."""
        stored = self.examples[item_idx]
        example = {
            "input_ids": torch.tensor(stored["input_ids"]).long(),
            "attention_masks": torch.tensor(stored["attention_masks"]),
            "token_type_ids": torch.tensor(stored["token_type_ids"]),
        }
        if "encoded_label" in stored:
            example["encoded_label"] = torch.tensor(stored["encoded_label"], dtype=torch.float)
        return example

    def _get_attention_mask(self, input_ids, max_seq_len):
        """Mask for padding: 1 for each real token, 0 for each padding slot."""
        if len(input_ids) > max_seq_len:
            raise IndexError("Token length more than max seq length!")
        return [1] * len(input_ids) + [0] * (max_seq_len - len(input_ids))

    def _get_token_type_id(self, input_ids, max_seq_len, sep_token):
        """Segments: 0 up to and including the first `sep_token`, 1 afterwards; padding is 0."""
        if len(input_ids) > max_seq_len:
            raise IndexError("Token length more than max seq length!")
        segments = []
        current_segment_id = 0
        for input_id in input_ids:
            segments.append(current_segment_id)
            if input_id == sep_token:
                current_segment_id = 1
        return segments + [0] * (max_seq_len - len(input_ids))

In [9]:
# Both splits are labelled with the same enum tag dictionary.
enum_tag_dict_path = './dictionary/enum_tag.dict'
train_dataset = BaiduEnermDataset('./resources/duee_fin_train_preprocess.json', enum_tag_dict_path)
dev_dataset = BaiduEnermDataset('./resources/duee_fin_dev_preprocess.json', enum_tag_dict_path)

HBox(children=(IntProgress(value=0, max=426), HTML(value='')))




HBox(children=(IntProgress(value=0, max=66), HTML(value='')))




In [10]:
# Peek at one encoded example: one row per tokenizer, plus the multi-hot label vector.
first_example = train_dataset[0]
first_example

{'input_ids': tensor([[   1,  252,  560,  ...,    0,    0,    0],
         [ 101, 1266,  776,  ...,    0,    0,    0],
         [ 101, 1266,  776,  ...,    0,    0,    0]]),
 'attention_masks': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'encoded_label': tensor([0., 0., 1., 0.])}

In [11]:
# Load the fine-tuned base models in the same key order as the tokenizers in model_dict.
# NOTE(review): torch.load unpickles whole model objects; only load trusted checkpoints.
# map_location lets checkpoints saved on GPU be loaded on CPU-only hosts as well.
models = [
    torch.load(f'./models/DuEE_fin/{model_name}/enum.bin', map_location=device).to(device)
    for model_name in model_dict.keys()
]
# Use the same width as the dataset's multi-hot labels (max label index + 1).
stacked_model = MultiLabelStackedClassification(models=models, num_labels=max(label_vocab.values()) + 1)

In [12]:
@torch.no_grad()
def evaluate(model, eval_dataloader):
    """Compute mean per-batch loss and multi-label metrics of `model` on `eval_dataloader`.

    Predictions are ``sigmoid(logits) > 0.5``. A batch's predictions and targets are
    flattened into one binary vector, so accuracy/precision/recall/F1 are computed over
    individual (example, label) cells and macro-averaged over the classes 0 and 1.
    Each metric is averaged over batches without weighting by batch size.

    Args:
        model: module called with the keyword inputs below and returning
            ``(loss, logits)``; under ``DataParallel`` ``loss`` has one entry per replica.
        eval_dataloader: iterable of dicts with ``input_ids``, ``attention_masks``,
            ``token_type_ids`` and ``encoded_label`` tensors.

    Returns:
        tuple ``(loss, accuracy, precision, recall, f1)`` of batch means.

    Raises:
        ValueError: if ``eval_dataloader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    step = 0
    eval_loss = eval_acc = eval_precision = eval_recall = eval_f1 = 0.0
    try:
        for batch in eval_dataloader:
            labels = batch['encoded_label']
            loss, logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_masks'].to(device),
                token_type_ids=batch['token_type_ids'].to(device),
                labels=labels.to(device),
            )
            # DataParallel returns one loss per replica; a scalar loss is left unchanged.
            if loss.dim() > 0:
                loss = loss.mean()

            eval_loss += loss.item()
            pred_Y = (torch.sigmoid(logits) > 0.5).cpu().numpy().flatten()
            true_Y = labels.cpu().numpy().flatten()
            # sklearn metrics take (y_true, y_pred); swapping them exchanges precision and recall.
            eval_acc += accuracy_score(true_Y, pred_Y)
            eval_precision += precision_score(true_Y, pred_Y, average="macro", zero_division=1)
            eval_recall += recall_score(true_Y, pred_Y, average="macro", zero_division=1)
            eval_f1 += f1_score(true_Y, pred_Y, average="macro", zero_division=1)
            step += 1
    finally:
        # Restore the caller's mode rather than always switching to train mode.
        model.train(was_training)
    if step == 0:
        raise ValueError("eval_dataloader yielded no batches")
    return eval_loss/step, eval_acc/step, eval_precision/step, eval_recall/step, eval_f1/step

In [13]:
### train model

def train(model, ds_train, ds_dev = None, n_epochs = 100, learning_rate = 1e-2, weight_decay = 0.01, batch_size = 1, eval_per_epoch = 1):
    """Fit `model` on `ds_train` with AdamW, optionally evaluating on `ds_dev`.

    Args:
        model: module called with ``input_ids``, ``attention_mask``, ``token_type_ids``
            and ``labels`` keywords and returning ``(loss, logits)``.
        ds_train: dataset yielding dicts with ``input_ids``, ``attention_masks``,
            ``token_type_ids`` and ``encoded_label`` tensors.
        ds_dev: optional dataset with the same layout; evaluated after the last batch of
            every ``eval_per_epoch``-th epoch (1-based epochs 1, 1 + eval_per_epoch, ...).
        n_epochs: number of passes over ``ds_train``.
        learning_rate: AdamW learning rate applied to all parameters.
        weight_decay: AdamW weight decay applied to all parameters.
        batch_size: batch size for both the train and dev loaders.
        eval_per_epoch: evaluate every this many epochs.

    Progress is shown on the tqdm bar. The "Avg" training figures are running means over
    every step since training began, not per-epoch values. Parameters are updated in
    place; nothing is returned.
    """
    model = model.to(device)

    train_dataloader = DataLoader(ds_train, sampler=RandomSampler(ds_train), batch_size=batch_size)
    eval_dataloader = None
    if ds_dev is not None:
        eval_dataloader = DataLoader(ds_dev, sampler=SequentialSampler(ds_dev), batch_size=batch_size)

    # Query the GPU count directly rather than relying on the notebook-level `n_gpu`.
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model, device_ids=[0, 1])

    # NOTE(review): model.parameters() covers every registered submodule; if the stacked
    # model registers its base transformers, they are trained too at this learning rate.
    optimizer = AdamW([{"params": model.parameters(), "lr": learning_rate, "weight_decay": weight_decay}])

    tr_loss = acc = precision = recall = f1 = 0.0
    global_step = 0
    postfix = {}
    model.train()
    model.zero_grad()
    last_step = len(train_dataloader) - 1
    for epoch in range(n_epochs):
        train_iterator = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{n_epochs}")
        for step, batch in enumerate(train_iterator):
            labels = batch['encoded_label']
            loss, logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_masks'].to(device),
                token_type_ids=batch['token_type_ids'].to(device),
                labels=labels.to(device),
            )
            # DataParallel returns one loss per replica; reduce to a scalar before backward.
            if loss.dim() > 0:
                loss = loss.mean()

            loss.backward()
            optimizer.step()
            model.zero_grad()

            tr_loss += loss.item()
            true_Y = labels.cpu().numpy().flatten()
            pred_Y = (torch.sigmoid(logits.detach()) > 0.5).cpu().numpy().flatten()
            # sklearn metrics take (y_true, y_pred); swapping them exchanges precision and recall.
            acc += accuracy_score(true_Y, pred_Y)
            precision += precision_score(true_Y, pred_Y, average="macro", zero_division=1)
            recall += recall_score(true_Y, pred_Y, average="macro", zero_division=1)
            f1 += f1_score(true_Y, pred_Y, average='macro', zero_division=1)

            global_step += 1
            postfix.update({
                "Avg loss": f"{tr_loss / global_step:.2f}",
                "Avg acc score": f"{acc / global_step:.2f}",
                "Avg precision score": f"{precision / global_step:.2f}",
                "Avg recall score": f"{recall / global_step:.2f}",
                "Avg f1 score": f"{f1 / global_step:.2f}",
            })
            # Evaluate only when a dev set exists, after the epoch's last batch.
            if eval_dataloader is not None and step == last_step and epoch % eval_per_epoch == 0:
                eval_loss, eval_acc, eval_precision, eval_recall, eval_f1 = evaluate(model, eval_dataloader)
                postfix.update({
                    "Avg eval loss": f"{eval_loss:.2f}",
                    "Avg eval acc": f"{eval_acc:.2f}",
                    "Avg eval precision": f"{eval_precision:.2f}",
                    "Avg eval recall": f"{eval_recall:.2f}",
                    "Avg eval f1": f"{eval_f1:.2f}",
                })
            train_iterator.set_postfix(postfix)

In [14]:
# The 426 training examples fit in one 512-example batch, so each epoch is a single step.
train(stacked_model, train_dataset, ds_dev=dev_dataset, n_epochs=100, batch_size=512)

HBox(children=(IntProgress(value=0, description='Epoch 1/100', max=1, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Epoch 2/100', max=1, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Epoch 3/100', max=1, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Epoch 4/100', max=1, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Epoch 5/100', max=1, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Epoch 6/100', max=1, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Epoch 7/100', max=1, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Epoch 8/100', max=1, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Epoch 9/100', max=1, style=ProgressStyle(description_width='i…




HBox(children=(IntProgress(value=0, description='Epoch 10/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 11/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 12/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 13/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 14/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 15/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 16/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 17/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 18/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 19/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 20/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 21/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 22/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 23/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 24/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 25/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 26/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 27/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 28/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 29/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 30/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 31/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 32/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 33/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 34/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 35/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 36/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 37/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 38/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 39/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 40/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 41/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 42/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 43/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 44/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 45/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 46/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 47/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 48/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 49/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 50/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 51/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 52/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 53/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 54/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 55/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 56/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 57/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 58/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 59/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 60/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 61/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 62/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 63/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 64/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 65/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 66/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 67/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 68/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 69/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 70/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 71/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 72/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 73/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 74/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 75/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 76/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 77/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 78/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 79/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 80/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 81/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 82/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 83/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 84/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 85/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 86/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 87/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 88/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 89/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 90/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 91/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 92/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 93/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 94/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 95/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 96/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 97/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 98/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 99/100', max=1, style=ProgressStyle(description_width='…




HBox(children=(IntProgress(value=0, description='Epoch 100/100', max=1, style=ProgressStyle(description_width=…




In [15]:
# Save only the stacked model's weights; create the output directory on a fresh checkout.
STACKED_MODEL_DIR = './models/DuEE_fin/stacked'
os.makedirs(STACKED_MODEL_DIR, exist_ok=True)
torch.save(stacked_model.state_dict(), os.path.join(STACKED_MODEL_DIR, 'stacked_enum.dict'))

In [16]:
# Hand cached-but-unused GPU memory back to the driver (a no-op without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()