In [None]:
import json
import random

import torch
import torch.nn as nn
from torch.optim import AdamW
import numpy as np
from torch.utils.data import Dataset
import torch.nn.functional as F
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score
from torch.optim import Adam

from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm_notebook as tqdm

from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForTokenClassification

from crf_layer import CRFLayer
from multiLabelTokenClassfication import MultiLabelTokenClassification
from stackedModel import MultiLabelStackedClassification

from utils import extract_result_multilabel

import warnings
warnings.filterwarnings(action='ignore',category=UserWarning,module='torch')

In [None]:
def load_dict(dict_path):
    """Load a tab-separated dictionary file into a ``{label: id}`` mapping.

    Each non-empty line must contain exactly two tab-separated fields: an
    integer id followed by the label string.

    Args:
        dict_path: Path of the UTF-8 encoded dictionary file.

    Returns:
        dict: Maps the second field (label) to ``int`` of the first field (id).

    Raises:
        ValueError: If a non-empty line does not have exactly two
            tab-separated fields or the id is not an integer.
    """
    vocab = {}
    # Context manager so the handle is closed deterministically (the bare
    # open() in a for-loop left that to the garbage collector).
    with open(dict_path, 'r', encoding='utf-8') as dict_file:
        for line in dict_file:
            line = line.strip('\n')
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # failing on the tuple unpacking below.
                continue
            value, key = line.split('\t')
            vocab[key] = int(value)
    return vocab

In [None]:
# Fixed padded sequence length; BaiduEventDataset pads/truncates every example to it.
max_seq_len = 512
# NOTE(review): enum_role is not referenced anywhere in the visible notebook — confirm it is still needed.
enum_role = "环节"

# label -> id mapping read from the trigger tag dictionary, plus its inverse (id -> label).
label_vocab = load_dict(dict_path='./dictionary/trigger_tag.dict')
id2label = dict((label_id, label) for label, label_id in label_vocab.items())

In [None]:
def data_process(dataset):
    """Convert raw event records into per-character tokens and BIO trigger labels.

    For every record that has a non-empty ``event_list`` the lower-cased text
    is split into single characters (space/newline/tab become "，", a few
    zero-width / private-use characters become "[UNK]"). Each position is
    labelled "O", or with a list of "B-<event_type>" / "I-<event_type>" tags
    when one or more triggers cover it.

    Args:
        dataset: Iterable of dicts with a "text" string and an "event_list" of
            dicts carrying "event_type", "trigger" and "trigger_start_index".

    Returns:
        list[dict]: One ``{"tokens": [...], "labels": [...]}`` entry per record
        that has at least one event; records without events are skipped.
    """
    whitespace_chars = (" ", "\n", "\t")
    unknown_chars = ('\u200b', '\ufeff', '\ue601', '\u3000')

    def normalize_char(char):
        """Map a single character to the token used for it."""
        if char in whitespace_chars:
            return "，"
        if char in unknown_chars:
            return '[UNK]'
        return char

    def label_data(data, start, l, _type):
        """Tag ``l`` positions from ``start`` with B-/I- tags of ``_type``."""
        # Clamp to the text length so an annotation running past the end of
        # the text labels what exists instead of raising IndexError.
        stop = min(start + l, len(data))
        for i in range(start, stop):
            prefix = "B-" if i == start else "I-"
            if isinstance(data[i], str):
                # First trigger on this position: replace the "O" string with
                # a list so overlapping triggers can accumulate several tags.
                data[i] = []
            tag = "{}{}".format(prefix, _type)
            if tag not in data[i]:
                data[i].append(tag)
        return data

    output = []
    for d_json in dataset:
        events = d_json.get("event_list") or []
        if not events:
            continue
        # The record "id" is intentionally not read: it was never used, and
        # requiring it made id-less records fail with KeyError.
        text_a = [normalize_char(t) for t in d_json["text"].lower()]
        labels = ["O"] * len(text_a)
        for event in events:
            start = event["trigger_start_index"]
            if start >= 0:
                labels = label_data(labels, start, len(event["trigger"]), event["event_type"])
        output.append({
            "tokens": text_a, "labels": labels
        })
    return output

In [None]:
def set_seed(seed = 42):
    """Seed the random number generators used in this notebook.

    Applies the same seed, in order, to Python's ``random`` module, NumPy's
    global generator and PyTorch (via ``torch.manual_seed``).

    Args:
        seed (int, optional): Seed value handed to each generator. Defaults to 42.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

In [None]:
# Select the compute device: the preferred GPU when it exists, otherwise the
# first GPU, otherwise the CPU. n_gpu is always defined (0 on CPU-only hosts).
preferred_cuda_index = 2
n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
if n_gpu > preferred_cuda_index:
    device = torch.device(f'cuda:{preferred_cuda_index}')
elif n_gpu > 0:
    # Fewer GPUs than the preferred index: 'cuda:2' would be an invalid
    # device ordinal, so fall back to the first visible GPU.
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('Using device:', device)
print()

# Additional info when using CUDA — reported for the selected device
# (previously the name/memory of GPU 0 was printed while running on cuda:2).
if device.type == 'cuda':
    print(torch.cuda.get_device_name(device))

    print('Memory Usage:')
    print('Allocated:', torch.cuda.memory_allocated(device)/1024**3, 'GB')
    print('Cached:   ', torch.cuda.memory_reserved(device)/1024**3, 'GB')

    print('CUDA Device Count:', n_gpu)

set_seed(seed=42)

In [None]:
# Maps each base-model alias to its tokenizer (despite the name, the values are
# tokenizers, not models). The aliases double as checkpoint directory names
# when the fine-tuned models are loaded later in the notebook.
model_dict = {
    alias: AutoTokenizer.from_pretrained(checkpoint)
    for alias, checkpoint in (
        ('ernie-base', "nghuyong/ernie-1.0"),
        ('roberta-chinese-base', "hfl/chinese-roberta-wwm-ext"),
        ('roberta-chinese-large', "hfl/chinese-roberta-wwm-ext-large"),
    )
}

In [None]:
class BaiduEventDataset(Dataset):
    """Trigger-labelling dataset encoded once per tokenizer in ``model_dict``.

    Construction reads a JSON file, runs it through ``data_process`` and, for
    every resulting example, encodes the tokens with each tokenizer in the
    module-level ``model_dict``, padding to the module-level ``max_seq_len``.

    Each item is a dict with:
        input_ids / attention_masks / token_type_ids: tensors with one row of
            length ``max_seq_len`` per tokenizer.
        seq_lens: number of tokens in the example (before truncation).
        text: the example's characters as a string (see ``_tokens_to_text``).
        encoded_label: float tensor of ``max_seq_len`` multi-hot rows of width
            ``label_num`` (only when the example has labels).
    """

    def __init__(self, dataset_path, label_dict_path):
        """Load, preprocess and tokenize the whole dataset eagerly.

        Args:
            dataset_path: Path of the UTF-8 JSON file holding the raw records.
            label_dict_path: Path of the tag dictionary passed to ``load_dict``.
        """
        self.label_vocab = load_dict(label_dict_path)
        self.label_num = max(self.label_vocab.values()) + 1
        self.examples = []
        # Only the read needs the file; close it before the long tokenization loop.
        with open(dataset_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        preprocess_dataset = data_process(dataset)
        for d_json in tqdm(preprocess_dataset, total=len(preprocess_dataset)):
            tokens = d_json['tokens']
            b_input_ids, b_attention_masks, b_token_type_ids = [], [], []
            for tokenizer in model_dict.values():
                # Use the id properties rather than tokenizer.vocab[...]: on
                # fast tokenizers `.vocab` rebuilds the full vocabulary dict on
                # every access, which was paid twice per tokenizer per example.
                pad_id = tokenizer.pad_token_id
                sep_id = tokenizer.sep_token_id
                input_ids = tokenizer(tokens, is_split_into_words=True, add_special_tokens=True, max_length=max_seq_len, truncation=True)['input_ids']
                b_input_ids.append(input_ids + [pad_id] * (max_seq_len - len(input_ids)))
                b_attention_masks.append(self._get_attention_mask(input_ids, max_seq_len))
                b_token_type_ids.append(self._get_token_type_id(input_ids, max_seq_len, sep_token=sep_id))
            example = {
                "input_ids": b_input_ids,
                "attention_masks": b_attention_masks,
                "token_type_ids": b_token_type_ids,
                # NOTE(review): consumers slice label rows with [1: seq_len - 1];
                # with seq_lens == len(tokens) (special tokens not counted) that
                # slice stops two tokens early — confirm whether this should be
                # the encoded length instead.
                "seq_lens": len(tokens),
                # train()/evaluate() in this notebook read batch["text"]; without
                # this key the first batch fails with KeyError.
                "text": self._tokens_to_text(tokens),
            }
            if 'labels' in d_json:
                labels = d_json['labels'][:(max_seq_len - 2)]
                # Surround with "O" for the two special tokens, then zero-pad
                # up to max_seq_len rows.
                encoded_label = self.to_one_hot_vector(["O"] + labels + ["O"], max_seq_len - 2 - len(labels))
                example.update({"encoded_label": encoded_label})
            self.examples.append(example)

    @staticmethod
    def _tokens_to_text(tokens):
        """Join tokens into a string with exactly one character per token.

        A plain ``str`` is used because the default DataLoader collation keeps
        strings as a per-sample list, whereas variable-length token lists
        cannot be collated. Multi-character tokens ("[UNK]") are collapsed to
        U+FFFD so that ``text[i]`` stays aligned with token ``i``.
        NOTE(review): assumes ``extract_result_multilabel`` indexes ``text``
        by token position — confirm against utils.
        """
        return "".join(tok if len(tok) == 1 else "\ufffd" for tok in tokens)

    def to_one_hot_vector(self, labels, zero_padding_len = 0):
        """Encode per-token labels as multi-hot rows.

        Args:
            labels: Sequence whose items are a label string or a list of label
                strings; labels missing from ``label_vocab`` fall back to index 0.
            zero_padding_len: Number of all-zero rows appended at the end.

        Returns:
            numpy.ndarray: ``len(labels) + zero_padding_len`` rows of width
            ``label_num``.
        """
        one_hot_vectors = []
        for label in labels:
            one_hot_vector = np.zeros(self.label_num)
            if isinstance(label, str):
                one_hot_vector[self.label_vocab.get(label, 0)] = 1
            elif isinstance(label, list):
                for l in label:
                    one_hot_vector[self.label_vocab.get(l, 0)] = 1
            one_hot_vectors.append(one_hot_vector)
        one_hot_vectors.extend(np.zeros(self.label_num) for _ in range(zero_padding_len))
        return np.array(one_hot_vectors)

    def _get_attention_mask(self, input_ids, max_seq_len):
        """Return 1 for every real token and 0 for padding, ``max_seq_len`` long.

        Raises:
            IndexError: If ``input_ids`` is longer than ``max_seq_len``.
        """
        if len(input_ids) > max_seq_len:
            raise IndexError("Token length more than max seq length!")
        return [1] * len(input_ids) + [0] * (max_seq_len - len(input_ids))

    def _get_token_type_id(self, input_ids, max_seq_len, sep_token):
        """Return segment ids: 0 up to and including the first separator, 1 after.

        Padding positions are filled with 0.

        Raises:
            IndexError: If ``input_ids`` is longer than ``max_seq_len``.
        """
        if len(input_ids) > max_seq_len:
            raise IndexError("Token length more than max seq length!")
        segments = []
        current_segment_id = 0
        for input_id in input_ids:
            segments.append(current_segment_id)
            if input_id == sep_token:
                # The separator itself keeps the current segment; what follows is segment 1.
                current_segment_id = 1
        return segments + [0] * (max_seq_len - len(input_ids))

    def __len__(self):
        """Number of encoded examples."""
        return len(self.examples)

    def __getitem__(self, item_idx):
        """Return example ``item_idx`` with its encodings converted to tensors."""
        stored = self.examples[item_idx]
        example = {
            "input_ids": torch.tensor(stored["input_ids"]).long(),
            "attention_masks": torch.tensor(stored["attention_masks"]),
            "token_type_ids": torch.tensor(stored["token_type_ids"]),
            "seq_lens": stored["seq_lens"],
            "text": stored["text"],
        }
        if "encoded_label" in stored:
            example.update({"encoded_label": torch.tensor(stored["encoded_label"], dtype=torch.float)})
        return example

In [None]:
# Build the train and dev splits; both share the same trigger tag dictionary.
train_dataset = BaiduEventDataset(
    dataset_path='./resources/duee_fin_train_preprocess.json',
    label_dict_path='./dictionary/trigger_tag.dict',
)
dev_dataset = BaiduEventDataset(
    dataset_path='./resources/duee_fin_dev_preprocess.json',
    label_dict_path='./dictionary/trigger_tag.dict',
)

In [None]:
# Sanity check: display the first encoded dev example (one row per tokenizer in model_dict).
dev_dataset[0]

In [None]:
# Load one fine-tuned base model per alias in model_dict and stack them.
# NOTE(security): torch.load unpickles complete model objects, which can execute
# arbitrary code — only load checkpoints that come from a trusted source.
# map_location=device restores the tensors directly on the selected device; without
# it they are restored onto the device they were saved from, which fails when that
# GPU index does not exist on this machine.
models = [
    torch.load(f'./models/DuEE_fin/{model_name}/trigger-multilabel.bin', map_location=device).to(device)
    for model_name in model_dict.keys()
]
stacked_model = MultiLabelStackedClassification(models=models, num_labels=len(label_vocab))

In [None]:
@torch.no_grad()
def evaluate(model, eval_dataloader):
    """Run the model over ``eval_dataloader`` and return batch-averaged metrics.

    Predictions are the sigmoid of the logits thresholded at 0.5. For every
    sample, predicted and gold label rows are decoded with
    ``extract_result_multilabel`` into (type, text, start) triples, and
    precision / recall / F1 are computed on those sets. Accuracy is the
    element-wise agreement of the flattened multi-hot matrices.

    Args:
        model: Callable returning ``(loss, logits)`` for the keyword arguments
            ``input_ids``, ``attention_mask``, ``token_type_ids``, ``labels``.
        eval_dataloader: Iterable of batches providing the keys "input_ids",
            "attention_masks", "token_type_ids", "encoded_label", "text" and
            "seq_lens".

    Returns:
        tuple: ``(loss, accuracy, precision, recall, f1)``, each averaged over
        the number of batches. All zeros when the loader yields no batch.
    """

    def decode_events(text, label_rows, seq_len):
        """Decode one sample's multi-hot rows into a set of (type, text, start)."""
        tags = [
            [id2label[label_id] for label_id in np.argwhere(row).flatten()]
            for row in label_rows[1: seq_len - 1]
        ]
        return {
            (event["type"], ''.join(event["text"]), event["start"])
            for event in extract_result_multilabel(text, tags)
        }

    # Remember the caller's mode so evaluation does not silently switch an
    # inference-mode model into training mode.
    was_training = model.training
    model.eval()
    step = 0
    eval_acc = eval_f1 = eval_precision = eval_recall = eval_loss = 0.0
    try:
        for batch in eval_dataloader:
            loss, logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_masks'].to(device),
                token_type_ids=batch['token_type_ids'].to(device),
                labels=batch['encoded_label'].to(device)
            )

            eval_loss += loss.item()
            pred_Y = (torch.sigmoid(logits) > 0.5).cpu().numpy()
            true_Y = batch['encoded_label'].cpu().numpy()
            n_samples = true_Y.shape[0]
            batch_precision, batch_recall, batch_f1 = 0.0, 0.0, 0.0
            for text, t_ids, p_ids, seq_len in zip(batch["text"], true_Y, pred_Y, batch['seq_lens']):
                seq_len = int(seq_len)
                pred_event_types = decode_events(text, p_ids, seq_len)
                true_event_types = decode_events(text, t_ids, seq_len)
                count_correct = len(pred_event_types & true_event_types)
                p = count_correct / max(1, len(pred_event_types))  # precision
                r = count_correct / max(1, len(true_event_types))  # recall
                batch_precision += p
                batch_recall += r
                batch_f1 += 2 * r * p / max(1e-9, r + p)  # f1 score
            eval_acc += accuracy_score(pred_Y.flatten(), true_Y.flatten())
            eval_precision += batch_precision / n_samples
            eval_recall += batch_recall / n_samples
            eval_f1 += batch_f1 / n_samples
            step += 1
    finally:
        model.train(was_training)
    # Guard against an empty loader (previously a ZeroDivisionError).
    n_batches = max(1, step)
    return eval_loss/n_batches, eval_acc/n_batches, eval_precision/n_batches, eval_recall/n_batches, eval_f1/n_batches

In [None]:
### train model

def train(model, ds_train, ds_dev = None, n_epochs = 100, learning_rate = 1e-2, weight_decay = 0.01, batch_size = 1, eval_per_epoch = 1):
    """Train ``model`` with AdamW, reporting running metrics in the progress bar.

    Args:
        model: Module returning ``(loss, logits)`` for the keyword arguments
            ``input_ids``, ``attention_mask``, ``token_type_ids``, ``labels``.
        ds_train: Training dataset; batches must provide the keys "input_ids",
            "attention_masks", "token_type_ids", "encoded_label", "text" and
            "seq_lens".
        ds_dev: Optional evaluation dataset. When ``None`` no evaluation runs.
        n_epochs: Number of passes over ``ds_train``.
        learning_rate: AdamW learning rate.
        weight_decay: Weight decay for all parameters except those whose name
            contains "bias" or "norm" (case-insensitive), which get 0.
        batch_size: Batch size for both the training and evaluation loaders.
        eval_per_epoch: Evaluate at the end of every epoch whose zero-based
            index is a multiple of this value.

    Returns:
        None. The model is updated in place.
    """

    def decode_events(text, label_rows, seq_len):
        """Decode one sample's multi-hot rows into a set of (type, text, start)."""
        tags = [
            [id2label[label_id] for label_id in np.argwhere(row).flatten()]
            for row in label_rows[1: seq_len - 1]
        ]
        return {
            (event["type"], ''.join(event["text"]), event["start"])
            for event in extract_result_multilabel(text, tags)
        }

    model = model.to(device)

    train_dataloader = DataLoader(ds_train, sampler=RandomSampler(ds_train), batch_size=batch_size)
    # Only build the evaluation loader when a dev set was actually supplied.
    eval_dataloader = None
    if ds_dev is not None:
        eval_dataloader = DataLoader(ds_dev, sampler=SequentialSampler(ds_dev), batch_size=batch_size)

    # Exclude bias and normalisation parameters from weight decay through separate
    # parameter groups. The previous "apply_decay_param_fun" group entry is a
    # Paddle option that torch's AdamW ignores, so decay was applied to everything.
    # Names are lower-cased so "LayerNorm" style names match "norm".
    decay_params, no_decay_params = [], []
    for param_name, param in model.named_parameters():
        lowered_name = param_name.lower()
        if any(marker in lowered_name for marker in ("bias", "norm")):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    optimizer_grouped_parameters = [
        group
        for group in (
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        )
        if group["params"]
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

    f1 = acc = precision = recall = tr_loss = 0.0
    global_step = 0
    steps_per_epoch = len(train_dataloader)
    model.train()
    model.zero_grad()
    postfix = {}
    for epoch in range(0, n_epochs):
        train_iterator = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{n_epochs}")
        for batch in train_iterator:
            loss, logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_masks'].to(device),
                token_type_ids=batch['token_type_ids'].to(device),
                labels=batch['encoded_label'].to(device)
            )

            loss.backward()
            optimizer.step()
            model.zero_grad()

            tr_loss += loss.item()
            pred_Y = (torch.sigmoid(logits).detach() > 0.5).cpu().numpy()
            true_Y = batch['encoded_label'].cpu().numpy()
            # Actual number of samples in this batch (the last batch may be
            # smaller); kept separate from the batch_size argument.
            n_samples = true_Y.shape[0]
            batch_precision, batch_recall, batch_f1 = 0.0, 0.0, 0.0
            for text, t_ids, p_ids, seq_len in zip(batch["text"], true_Y, pred_Y, batch['seq_lens']):
                seq_len = int(seq_len)
                pred_event_types = decode_events(text, p_ids, seq_len)
                true_event_types = decode_events(text, t_ids, seq_len)
                count_correct = len(pred_event_types & true_event_types)
                p = count_correct / max(1, len(pred_event_types))  # precision
                r = count_correct / max(1, len(true_event_types))  # recall
                batch_precision += p
                batch_recall += r
                batch_f1 += 2 * r * p / max(1e-9, r + p)  # f1 score
            acc += accuracy_score(pred_Y.flatten(), true_Y.flatten())
            precision += batch_precision / n_samples
            recall += batch_recall / n_samples
            f1 += batch_f1 / n_samples

            postfix.update({"Avg loss": f"{tr_loss / (global_step + 1):.5f}", "Avg acc score": f"{acc / (global_step + 1):.5f}", "Avg precision score": f"{precision / (global_step + 1):.5f}", "Avg recall score": f"{recall / (global_step + 1):.5f}", "Avg f1 score": f"{f1 / (global_step + 1):.5f}"})
            # The step counter hits a multiple of the loader length exactly once
            # per epoch, on its last batch.
            is_epoch_end = (global_step + 1) % steps_per_epoch == 0
            if is_epoch_end and (epoch % eval_per_epoch) == 0 and eval_dataloader is not None:
                # Eval metrics are only reported when they were computed; before,
                # ds_dev=None reached the postfix update with undefined names.
                eval_loss, eval_acc, eval_precision, eval_recall, eval_f1 = evaluate(model, eval_dataloader)
                postfix.update({"Avg eval loss": f"{eval_loss:.5f}", "Avg eval acc": f"{eval_acc:.5f}", "Avg eval precision": f"{eval_precision:.5f}", "Avg eval recall": f"{eval_recall:.5f}", "Avg eval f1": f"{eval_f1:.5f}"})
            train_iterator.set_postfix(postfix)
            global_step += 1

In [None]:
# Train the stacked model for 30 epochs with batches of 256, evaluating on the dev set;
# learning_rate and weight_decay are left at train()'s defaults (1e-2 and 0.01).
train(stacked_model, train_dataset, ds_dev=dev_dataset, n_epochs=30, batch_size=256)

In [None]:
# Persist only the stacked model's state_dict (not the pickled module object).
# NOTE(review): the target directory is not created here — confirm './models/DuEE_fin/stacked/' exists.
torch.save(stacked_model.state_dict(), './models/DuEE_fin/stacked/stacked_trigger-multilabel.dict')

In [None]:
# Ask PyTorch to clear its CUDA memory cache now that training and saving are done.
torch.cuda.empty_cache()