In [1]:
import copy
import json
import random

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.optim import Adam

from collections import Counter

from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm_notebook as tqdm

from transformers import AutoTokenizer, AutoConfig, AutoModel

from crf_layer import CRFLayer
from multiLabelTokenClassfication import MultiLabelTokenClassification
from multiLabelSequenceClassfication import MultiLabelSequenceClassification

from utils import read_by_lines, extract_result_multilabel

In [2]:
device_Id = "1"
folder_name = "ernie-base"
tokenizer_model = "nghuyong/ernie-1.0"

shema_path = './dictionary/event_schema.json'
enerm_dict_path = './dictionary/enum_tag.dict'
trigger_dict_path = './dictionary/trigger_tag.dict'
role_dict_path = './dictionary/role_tag.dict'

enerm_model_path = f'./models/DuEE_fin/{folder_name}/enum.bin'
tigger_model_path = f'./models/DuEE_fin/{folder_name}/trigger-multilabel.bin'
role_model_path = f'./models/DuEE_fin/{folder_name}/role-multilabel-trick1.bin'

duee_fin_dev_path = './resources/duee_fin_dev.json'
duee_fin_dev_preprocess_path = './resources/duee_fin_dev_preprocess.json'

enum_role = "环节"
enum_event_type = "公司上市"
max_seq_len = 512

In [3]:
def load_dict(dict_path):
    """load_dict"""
    vocab = {}
    for line in open(dict_path, 'r', encoding='utf-8'):
        value, key = line.strip('\n').split('\t')
        vocab[key] = int(value)
    return vocab

In [4]:
label_enum_vocab = load_dict(dict_path=enerm_dict_path)
id2enumlabel = {val: key for key, val in label_enum_vocab.items()}
label_trigger_vocab = load_dict(dict_path=trigger_dict_path)
id2triggerlabel = {val: key for key, val in label_trigger_vocab.items()}
label_role_vocab = load_dict(dict_path=role_dict_path)
id2rolelabel = {val: key for key, val in label_role_vocab.items()}

In [5]:
def enum_data_process(dataset):
    """enum_data_process"""
    output = []
    for d_json in dataset:
        text = d_json["text"].lower().replace("\t", " ")
        labels = []
        for event in d_json.get("event_list", []):
            if event["event_type"] != "公司上市":
                continue
            for argument in event["arguments"]:
                role_type = argument["role"]
                if role_type == enum_role and argument["argument"] not in labels:
                    labels.append(argument["argument"])
        if labels:
            output.append({
                "id": d_json["id"],
                "sent_id": d_json["sent_id"],
                "text": text,
                "labels": labels
            })
    return output

In [6]:
def trigger_data_process(dataset):
    """data_process"""
    
    def label_data(data, start, l, _type):
        """label_data"""
        for i in range(start, start + l):
            suffix = "B-" if i == start else "I-"
            if isinstance(data[i], str):
                data[i] = []
            solt = "{}{}".format(suffix, _type)
            if solt not in data[i]:
                data[i].append(solt)
        return data
    
    def replace_control_chars(str):
        if str == '\u200b' or str == '\ufeff' or str == '\ue601' or str == '\u3000':
            return '[UNK]'
        else:
            return str

    output = []
    for d_json in dataset:
        _id = d_json["id"]
        text_a = [
            "，" if t == " " or t == "\n" or t == "\t" else replace_control_chars(t)
            for t in list(d_json["text"].lower())
        ]
        labels = ["O"] * len(text_a)
        for event in d_json.get("event_list", []):
            event_type = event["event_type"]
            start = event["trigger_start_index"]
            trigger = event["trigger"]
            if start >= 0:
                labels = label_data(labels, start, len(trigger), event_type)
        output.append({
            "id": d_json["id"],
            "sent_id": d_json["sent_id"],
            "text": d_json["text"],
            "tokens": text_a,
            "labels": labels
        })
    return output

In [7]:
def role_data_process(dataset):
    """data_process"""
    
    def label_data(data, start, l, _type):
        """label_data"""
        for i in range(start, start + l):
            suffix = "B-" if i == start else "I-"
            if isinstance(data[i], str):
                data[i] = []
            solt = "{}{}".format(suffix, _type)
            if solt not in data[i]:
                data[i].append(solt)
        return data
    
    def replace_control_chars(str):
        if str == '\u200b' or str == '\ufeff' or str == '\ue601' or str == '\u3000':
            return '[UNK]'
        else:
            return str

    output = []
    for d_json in dataset:
        _id = d_json["id"]
        text_a = [
            "，" if t == " " or t == "\n" or t == "\t" else replace_control_chars(t)
            for t in list(d_json["text"].lower())
        ]
        if len(d_json.get("event_list", [])) == 0:
            continue
         ### combine same event type
        event_type_mapping = {}
        for event in d_json.get("event_list", []):
            event_type = event['event_type']
            trigger = event['trigger']
            type_tuple = (event_type, trigger)
            if type_tuple not in event_type_mapping:
                event_type_mapping[type_tuple] = []
            for argument in event["arguments"]:
                if argument not in event_type_mapping[type_tuple]:
                    event_type_mapping[type_tuple].append(argument)

        for type_tuple, arguments in event_type_mapping.items():
            event_type = type_tuple[0]
            trigger = type_tuple[1]
            trigger_text = event_type+f"({trigger})："
            labels = ["O"] * len(text_a)
            for arg in arguments:
                role_type = arg["role"]
                if role_type == enum_role:
                    continue
                argument = arg["argument"]
                start = arg["argument_start_index"]
                labels = label_data(labels, start, len(argument), role_type)
            if d_json["id"] == "10be7f956da35f15fa4a9ad2a4556960":
                print(labels)
            text_a_trigger = [
                "，" if t == " " or t == "\n" or t == "\t" else t
                for t in list(trigger_text.lower())
            ]
            trigger_label = ["O"] * len(text_a_trigger)
            output.append({
                "id": d_json["id"],
                "sent_id": d_json["sent_id"],
                "text": trigger_text + d_json["text"],
                "event_type": event_type,
                "trigger": trigger,
                "tokens": text_a_trigger + text_a,
                "labels": trigger_label+labels
            })
    return output

In [8]:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)

In [9]:
PADDING = tokenizer.vocab[tokenizer.pad_token]
SEP = tokenizer.vocab[tokenizer.sep_token]

In [10]:
class BaiduEnermDataset(Dataset):
    
    def __init__(self, dataset_path, label_dict_path):
        self.label_vocab = load_dict(label_dict_path)
        self.label_num = max(self.label_vocab.values()) + 1
        self.examples = []
        with open(dataset_path, 'r', encoding='utf-8') as f:
            dataset = json.loads(f.read())
            preprocess_dataset = enum_data_process(dataset)
            for d_json in preprocess_dataset:
                text = d_json['text']
                input_ids = tokenizer(text, is_split_into_words=False, add_special_tokens=True, max_length=max_seq_len, truncation=True)['input_ids']
                tokens_input = input_ids + [PADDING] * (max_seq_len - len(input_ids))
                attention_masks = self._get_attention_mask(input_ids, max_seq_len)
                token_type_ids = self._get_token_type_id(input_ids, max_seq_len)
                example = {
                    "input_ids": tokens_input,
                    "attention_masks": attention_masks,
                    "token_type_ids": token_type_ids
                }
                example.update(d_json)
                if 'labels' in d_json:
                    labels = self.to_one_hot_vector(d_json['labels'])
                    example.update({"encoded_label": labels})
                self.examples.append(example)
                
    def to_one_hot_vector(self, labels):
        """Convert seq to one hot."""
        one_hot_vector = np.zeros(self.label_num)
        for label in labels:
            one_hot_vector[self.label_vocab[label]] = 1
        return one_hot_vector
                
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, item_idx):
        example = {
            "id": self.examples[item_idx]["id"],
            "sent_id": self.examples[item_idx]["sent_id"],
            "text": self.examples[item_idx]["text"],
            "input_ids": torch.tensor(self.examples[item_idx]["input_ids"]).long(),
            "attention_masks": torch.tensor(self.examples[item_idx]["attention_masks"]),
            "token_type_ids": torch.tensor(self.examples[item_idx]["token_type_ids"])
        }
        if "encoded_label" in self.examples[item_idx]:
            example.update({"encoded_label": torch.tensor(self.examples[item_idx]["encoded_label"], dtype=torch.float)})
        return example

    def _get_attention_mask(self, input_ids, max_seq_len):
        """Mask for padding."""
        if len(input_ids) > max_seq_len:
            raise IndexError("Token length more than max seq length!")
        return [1] * len(input_ids) + [0] * (max_seq_len - len(input_ids))

    def _get_token_type_id(self, input_ids, max_seq_len):
        """Segments: 0 for the first sequence, 1 for the second."""
        if len(input_ids) > max_seq_len:
            raise IndexError("Token length more than max seq length!")
        segments = []
        current_segment_id = 0
        for input_id in input_ids:
            segments.append(current_segment_id)
            if input_id == SEP:
                current_segment_id = 1
        return segments + [0] * (max_seq_len - len(input_ids)) 

In [11]:
class BaiduTriggerDataset(Dataset):
    
    def __init__(self, dataset_path, label_dict_path, ignore_index=-100):
        self.label_vocab = load_dict(label_dict_path)
        self.label_num = max(self.label_vocab.values()) + 1
        self.examples = []
        with open(dataset_path, 'r', encoding='utf-8') as f:
            dataset = json.loads(f.read())
            preprocess_dataset = trigger_data_process(dataset)
            for d_json in preprocess_dataset:
                tokens = d_json['tokens']
                input_ids = tokenizer(tokens, is_split_into_words=True, add_special_tokens=True, max_length=max_seq_len, truncation=True)['input_ids']
                tokens_input = input_ids + [PADDING] * (max_seq_len - len(input_ids))
                attention_masks = self._get_attention_mask(input_ids, max_seq_len)
                token_type_ids = self._get_token_type_id(input_ids, max_seq_len)
                example = {
                    "input_ids": tokens_input,
                    "attention_masks": attention_masks,
                    "token_type_ids": token_type_ids,
                    "seq_lens": len(tokens)
                }
                example.update(d_json)
                if 'labels' in d_json:
                    labels = d_json['labels']
                    labels = labels[:(max_seq_len - 2)]
                    encoded_label = ["O"] + labels + ["O"]
                    encoded_label = self.to_one_hot_vector(encoded_label, max_seq_len - 2 - len(labels))
                    example.update({"encoded_label": encoded_label})
                self.examples.append(example)
                
    def to_one_hot_vector(self, labels, zero_padding_len = 0):
        """Convert seq to one hot."""
        one_hot_vectors = []
        for label in labels:
            one_hot_vector = np.zeros(self.label_num)
            if isinstance(label, str):
                one_hot_vector[self.label_vocab.get(label, 0)] = 1
            elif isinstance(label, list):
                for l in label:
                    one_hot_vector[self.label_vocab.get(l, 0)] = 1
            one_hot_vectors.append(one_hot_vector)
        for _ in range(zero_padding_len):
            one_hot_vector = np.zeros(self.label_num)
            one_hot_vectors.append(one_hot_vector)
        return np.array(one_hot_vectors)

    def _get_attention_mask(self, input_ids, max_seq_len):
        """Mask for padding."""
        if len(input_ids) > max_seq_len:
            raise IndexError("Token length more than max seq length!")
        return [1] * len(input_ids) + [0] * (max_seq_len - len(input_ids))

    def _get_token_type_id(self, input_ids, max_seq_len):
        """Segments: 0 for the first sequence, 1 for the second."""
        if len(input_ids) > max_seq_len:
            raise IndexError("Token length more than max seq length!")
        segments = []
        current_segment_id = 0
        for input_id in input_ids:
            segments.append(current_segment_id)
            if input_id == SEP:
                current_segment_id = 1
        return segments + [0] * (max_seq_len - len(input_ids))        
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, item_idx):
        example = {
            "id": self.examples[item_idx]["id"],
            "sent_id": self.examples[item_idx]["sent_id"],
            "text": self.examples[item_idx]["text"],
            "input_ids": torch.tensor(self.examples[item_idx]["input_ids"]).long(),
            "attention_masks": torch.tensor(self.examples[item_idx]["attention_masks"]),
            "token_type_ids": torch.tensor(self.examples[item_idx]["token_type_ids"]),
            "seq_lens": self.examples[item_idx]["seq_lens"]
        }
        if "encoded_label" in self.examples[item_idx]:
            example.update({"encoded_label": torch.tensor(self.examples[item_idx]["encoded_label"], dtype=torch.float)})
        return example

In [12]:
class BaiduRoleDataset(Dataset):
    
    def __init__(self, dataset_path, label_dict_path, ignore_index=-100):
        self.label_vocab = load_dict(label_dict_path)
        self.label_num = max(self.label_vocab.values()) + 1
        self.examples = []
        with open(dataset_path, 'r', encoding='utf-8') as f:
            dataset = json.loads(f.read())
            preprocess_dataset = role_data_process(dataset)
            for d_json in preprocess_dataset:
                tokens = d_json['tokens']
                input_ids = tokenizer(tokens, is_split_into_words=True, add_special_tokens=True, max_length=max_seq_len, truncation=True)['input_ids']
                tokens_input = input_ids + [PADDING] * (max_seq_len - len(input_ids))
                attention_masks = self._get_attention_mask(input_ids, max_seq_len)
                token_type_ids = self._get_token_type_id(input_ids, max_seq_len)
                example = {
                    "input_ids": tokens_input,
                    "attention_masks": attention_masks,
                    "token_type_ids": token_type_ids,
                    "seq_lens": len(tokens)
                }
                example.update(d_json)
                if 'labels' in d_json:
                    labels = d_json['labels']
                    labels = labels[:(max_seq_len - 2)]
                    encoded_label = ["O"] + labels + ["O"]
                    encoded_label = self.to_one_hot_vector(encoded_label, max_seq_len - 2 - len(labels))
                    example.update({"encoded_label": encoded_label})
                self.examples.append(example)
                
    def to_one_hot_vector(self, labels, zero_padding_len = 0):
        """Convert seq to one hot."""
        one_hot_vectors = []
        for label in labels:
            one_hot_vector = np.zeros(self.label_num)
            if isinstance(label, str):
                one_hot_vector[self.label_vocab.get(label, 0)] = 1
            elif isinstance(label, list):
                for l in label:
                    one_hot_vector[self.label_vocab.get(l, 0)] = 1
            one_hot_vectors.append(one_hot_vector)
        for _ in range(zero_padding_len):
            one_hot_vector = np.zeros(self.label_num)
            one_hot_vectors.append(one_hot_vector)
        return np.array(one_hot_vectors)

    def _get_attention_mask(self, input_ids, max_seq_len):
        """Mask for padding."""
        if len(input_ids) > max_seq_len:
            raise IndexError("Token length more than max seq length!")
        return [1] * len(input_ids) + [0] * (max_seq_len - len(input_ids))

    def _get_token_type_id(self, input_ids, max_seq_len):
        """Segments: 0 for the first sequence, 1 for the second."""
        if len(input_ids) > max_seq_len:
            raise IndexError("Token length more than max seq length!")
        segments = []
        current_segment_id = 0
        for input_id in input_ids:
            segments.append(current_segment_id)
            if input_id == SEP:
                current_segment_id = 1
        return segments + [0] * (max_seq_len - len(input_ids))        
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, item_idx):
        example = {
            "id": self.examples[item_idx]["id"],
            "sent_id": self.examples[item_idx]["sent_id"],
            "text": self.examples[item_idx]["text"],
            "event_type": self.examples[item_idx]["event_type"],
            "trigger": self.examples[item_idx]["trigger"],
            "input_ids": torch.tensor(self.examples[item_idx]["input_ids"]).long(),
            "attention_masks": torch.tensor(self.examples[item_idx]["attention_masks"]),
            "token_type_ids": torch.tensor(self.examples[item_idx]["token_type_ids"]),
            "seq_lens": self.examples[item_idx]["seq_lens"]
        }
        if "encoded_label" in self.examples[item_idx]:
            example.update({"encoded_label": torch.tensor(self.examples[item_idx]["encoded_label"], dtype=torch.float)})
        return example

In [13]:
dev_enerm_dataset = BaiduEnermDataset(dataset_path=duee_fin_dev_preprocess_path, label_dict_path=enerm_dict_path)
dev_trigger_dataset = BaiduTriggerDataset(dataset_path=duee_fin_dev_preprocess_path, label_dict_path=trigger_dict_path)
dev_role_dataset = BaiduRoleDataset(dataset_path=duee_fin_dev_preprocess_path, label_dict_path=role_dict_path)

[['B-约谈机构'], ['I-约谈机构'], ['I-约谈机构'], ['I-约谈机构'], ['I-约谈机构'], 'O', 'O', ['B-公司名称'], ['I-公司名称'], 'O', ['B-公司名称'], ['I-公司名称'], ['I-公司名称'], 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', ['B-被约谈时间'], ['I-被约谈时间'], 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', ['B-公司名称'], ['I-公司名称'], 'O', ['B-公司名称'], ['I-公司名称'], 'O', ['B-公司名称'], ['I-公司名称'], 'O', ['B-公司名称'], ['I-公司名称'], 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O

In [14]:
len(dev_enerm_dataset)

66

In [15]:
len(dev_trigger_dataset)

1946

In [16]:
len(dev_role_dataset)

1803

In [17]:
def set_seed(seed = 42):
    """Set the seed for generating random numbers on all GPUs.

    It's safe to call this function if CUDA is not available; in that case, it is silently ignored.

    Args:
        seed (int, optional): random numbers on all GPUs. Defaults to 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

In [18]:
# setting device on GPU if available, else CPU
device = torch.device(f'cuda:{device_Id}' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    n_gpu = torch.cuda.device_count()
    print(torch.cuda.get_device_name(0))
    
    print('Memory Usage:')
    print('Allocated:', torch.cuda.memory_allocated(0)/1024**3, 'GB')
    print('Cached:   ', torch.cuda.memory_reserved(0)/1024**3, 'GB')
    
    print('CUDA Device Count:', n_gpu)
    
set_seed(seed=42)

Using device: cuda:1

Tesla V100-PCIE-32GB
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB
CUDA Device Count: 3


In [19]:
@torch.no_grad()
def test_enerm(model, test_dataloader):
    from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score
    
    model.eval()
    step = 0
    eval_acc = 0.0
    eval_f1 = 0.0
    eval_precision = 0.0
    eval_recall = 0.0
    results = []
    test_iterator = tqdm(test_dataloader)
    for batch in test_iterator:
        _, logits = model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_masks'].to(device),
            token_type_ids=batch['token_type_ids'].to(device),
            labels=batch['encoded_label'].to(device)
        )

        probs = torch.sigmoid(logits).data.cpu()
        probs_ids = (probs > 0.5).numpy()
        true_label = batch.get("encoded_label", None).cpu().numpy()
        probs = probs.numpy()
        pred_Y = probs_ids.flatten()
        true_Y = true_label.flatten()
        eval_acc += accuracy_score(pred_Y, true_Y)
        eval_precision += precision_score(pred_Y, true_Y, average="macro", zero_division=1)
        eval_recall += recall_score(pred_Y, true_Y, average="macro", zero_division=1)
        eval_f1 += f1_score(pred_Y, true_Y, average="macro")
        for id_, sent_id, text, label_probs, p_id in zip(batch['id'], batch['sent_id'], batch['text'], probs.tolist(), probs_ids.tolist()):
            true_indices = np.argwhere(p_id).flatten()
            labels = [id2enumlabel[true_index] for true_index in true_indices]
            results.append({"id": id_, "sent_id": sent_id, "text": text, "pred":{"probs": label_probs, "label": labels}})
        step += 1
    print({"Avg eval acc": f"{eval_acc/step:.2f}", "Avg eval precision": f"{eval_precision/step:.2f}", "Avg eval recall": f"{eval_recall/step:.2f}", "Avg eval f1": f"{eval_f1/step:.2f}"})
    return results

In [20]:
enum_model = torch.load(enerm_model_path).to(device)

test_enerm_sampler = SequentialSampler(dev_enerm_dataset)
test_enerm_dataloader = DataLoader(dev_enerm_dataset, sampler=test_enerm_sampler, batch_size = 512)
    
sentences_enum_data = test_enerm(enum_model, test_enerm_dataloader)

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))


{'Avg eval acc': '1.00', 'Avg eval precision': '1.00', 'Avg eval recall': '1.00', 'Avg eval f1': '1.00'}


In [31]:
@torch.no_grad()
def test_trigger(model, test_dataloader):
    from sklearn.metrics import accuracy_score
    
    model.eval()
    step = 0
    eval_acc = 0.0
    eval_f1 = 0.0
    eval_precision = 0.0
    eval_recall = 0.0
    results = []
    test_iterator = tqdm(test_dataloader)
    for batch in test_iterator:
        _, logits = model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_masks'].to(device),
            token_type_ids=batch['token_type_ids'].to(device)
        )

        probs = torch.sigmoid(logits).data.cpu()
        probs_ids = (probs > 0.5).numpy()
        probs = probs.numpy()
        true_Y = batch['encoded_label'].cpu().numpy()
        batch_size = true_Y.shape[0]
        batch_precision, batch_recall, batch_f1 = 0.0, 0.0, 0.0
        for text, t_ids, p_ids, seq_len in zip(batch["text"], true_Y, probs_ids, batch['seq_lens']):
            true_label, pred_label = [], []
            for pid in p_ids[1: seq_len - 1]:
                true_indices = np.argwhere(pid).flatten()
                pred_label.append([id2triggerlabel[true_index] for true_index in true_indices])
            p_ret = extract_result_multilabel(text, pred_label)
            pred_event_types = set([(p["type"], ''.join(p["text"]), p["start"]) for p in p_ret])
            for tid in t_ids[1: seq_len - 1]:
                true_indices = np.argwhere(tid).flatten()
                true_label.append([id2triggerlabel[true_index] for true_index in true_indices])
            t_ret = extract_result_multilabel(text, true_label)
            true_event_types = set([(t["type"], ''.join(t["text"]), t["start"]) for t in t_ret])
            count_predict = len(list(pred_event_types))
            count_true = len(list(true_event_types))
            count_correct = len(list(pred_event_types & true_event_types))
            p = count_correct / max(1, count_predict)  # precision
            r = count_correct / max(1, count_true)  # recall
            batch_precision += p
            batch_recall += r
            batch_f1 += 2 * r * p / max(1e-9, r + p) # f1 score
        eval_acc += accuracy_score(probs_ids.flatten(), true_Y.flatten())
        eval_precision += batch_precision / batch_size
        eval_recall += batch_recall / batch_size
        eval_f1 += batch_f1 / batch_size
        for id_, sent_id, text, p_list, p_ids, seq_len in zip(batch['id'], batch['sent_id'], batch['text'], probs.tolist(), probs_ids.tolist(), batch['seq_lens']):
            prob_multi, label_multi = [], []
            for index, pid in enumerate(p_ids[1: seq_len - 1]):
                true_indices = np.argwhere(pid).flatten()
                prob_multi.append(p_list[index])
                label_multi.append([id2triggerlabel[true_index] for true_index in true_indices])
            results.append({"id": id_, "sent_id":sent_id, "text": text, "pred": {"probs": prob_multi, "labels": label_multi}})
        step += 1
    print({"Avg eval acc": f"{eval_acc/step:.2f}", "Avg eval precision": f"{eval_precision/step:.2f}", "Avg eval recall": f"{eval_recall/step:.2f}", "Avg eval f1": f"{eval_f1/step:.2f}"})
    return results

In [32]:
tigger_model = torch.load(tigger_model_path).to(device)

test_trigger_sampler = SequentialSampler(dev_trigger_dataset)
test_trigger_dataloader = DataLoader(dev_trigger_dataset, sampler=test_trigger_sampler, batch_size = 512)
    
sentences_tigger_data = test_trigger(tigger_model, test_trigger_dataloader)

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))


{'Avg eval acc': '1.00', 'Avg eval precision': '0.62', 'Avg eval recall': '0.65', 'Avg eval f1': '0.63'}


In [37]:
@torch.no_grad()
def test_role(model, test_dataloader):
    from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score

    model.eval()
    step = 0
    eval_acc = 0.0
    eval_f1 = 0.0
    eval_precision = 0.0
    eval_recall = 0.0
    results = []
    test_iterator = tqdm(test_dataloader)
    for batch in test_iterator:
        loss, logits = model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_masks'].to(device),
            token_type_ids=batch['token_type_ids'].to(device)
        )

        probs = torch.sigmoid(logits).data.cpu()
        probs_ids = (probs > 0.3).numpy()
        probs = probs.numpy()
        true_Y = batch['encoded_label'].cpu().numpy()
        batch_size = true_Y.shape[0]
        batch_precision, batch_recall, batch_f1 = 0.0, 0.0, 0.0
        for text, t_ids, p_ids, seq_len in zip(batch["text"], true_Y, probs_ids, batch['seq_lens']):
            true_label, pred_label = [], []
            for pid in p_ids[1: seq_len - 1]:
                true_indices = np.argwhere(pid).flatten()
                pred_label.append([id2rolelabel[true_index] for true_index in true_indices])
            p_ret = extract_result_multilabel(text, pred_label)
            pred_event_types = set([(p["type"], ''.join(p["text"]), p["start"]) for p in p_ret])
            for tid in t_ids[1: seq_len - 1]:
                true_indices = np.argwhere(tid).flatten()
                true_label.append([id2rolelabel[true_index] for true_index in true_indices])
            t_ret = extract_result_multilabel(text, true_label)
            true_event_types = set([(t["type"], ''.join(t["text"]), t["start"]) for t in t_ret])
            count_predict = len(list(pred_event_types))
            count_true = len(list(true_event_types))
            count_correct = len(list(pred_event_types & true_event_types))
            p = count_correct / max(1, count_predict)  # precision
            r = count_correct / max(1, count_true)  # recall
            batch_precision += p
            batch_recall += r
            batch_f1 += 2 * r * p / max(1e-9, r + p) # f1 score
        eval_acc += accuracy_score(probs_ids.flatten(), true_Y.flatten())
        eval_precision += batch_precision / batch_size
        eval_recall += batch_recall / batch_size
        eval_f1 += batch_f1 / batch_size
        for id_, sent_id, text, input_ids, event_type, trigger, p_list, p_ids, seq_len in zip(batch['id'], batch['sent_id'], batch['text'], batch['input_ids'], batch['event_type'], batch['trigger'], probs.tolist(), probs_ids.tolist(), batch['seq_lens']):
            prob_multi, label_multi = [], []
            for index, pid in enumerate(p_ids[1: seq_len - 1]):
                true_indices = np.argwhere(pid).flatten()
                prob_multi.append(p_list[index])
                label_multi.append([id2rolelabel[true_index] for true_index in true_indices])
            results.append({"id": id_, "sent_id":sent_id, "event_type": event_type, "trigger": trigger, "text": text, "tokens": tokenizer.convert_ids_to_tokens(input_ids), "pred": {"probs": prob_multi, "labels": label_multi}})
        step += 1
    print({"Avg eval acc": f"{eval_acc/step:.2f}", "Avg eval precision": f"{eval_precision/step:.2f}", "Avg eval recall": f"{eval_recall/step:.2f}", "Avg eval f1": f"{eval_f1/step:.2f}"})
    return results

In [38]:
role_model = torch.load(role_model_path).to(device)

test_role_sampler = SequentialSampler(dev_role_dataset)
test_role_dataloader = DataLoader(dev_role_dataset, sampler=test_role_sampler, batch_size = 512)
    
sentences_role_data = test_role(role_model, test_role_dataloader)

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))


{'Avg eval acc': '1.00', 'Avg eval precision': '0.67', 'Avg eval recall': '0.74', 'Avg eval f1': '0.69'}


In [39]:
def event_normalization(doc):
    """event_merge"""
    for event in doc.get("event_list", []):
        argument_list = []
        argument_set = set()
        for arg in event["arguments"]:
            arg_str = "{}-{}".format(arg["role"], arg["argument"])
            if arg_str not in argument_set:
                argument_list.append(arg)
            argument_set.add(arg_str)
        event["arguments"] = argument_list

    event_list = sorted(
        doc.get("event_list", []),
        key=lambda x: len(x["arguments"]),
        reverse=True)
    new_event_list = []
    for event in event_list:
        event_type = event["event_type"]
        event_argument_set = set()
        for arg in event["arguments"]:
            event_argument_set.add("{}-{}".format(arg["role"], arg["argument"]))
        flag = True
        for new_event in new_event_list:
            if event_type != new_event["event_type"]:
                continue
            new_event_argument_set = set()
            for arg in new_event["arguments"]:
                new_event_argument_set.add("{}-{}".format(arg["role"], arg[
                    "argument"]))
            if len(event_argument_set & new_event_argument_set) == len(
                    new_event_argument_set):
                flag = False
        if flag:
            new_event_list.append(event)
    doc["event_list"] = new_event_list
    return doc

def predict_data_process(trigger_data, role_data, enum_data, schema_file):
    """predict_data_process"""
    pred_ret = []
    schema_data = read_by_lines(schema_file)
    print("trigger predict {} load.".format(len(trigger_data)))
    print("role predict {} load".format(len(role_data)))
    print("enum predict {} load".format(len(enum_data)))
    print("schema {} load from {}".format(len(schema_data), schema_file))

    schema, sent_role_mapping, sent_enum_mapping = {}, {}, {}
    for s in schema_data:
        d_json = json.loads(s)
        schema[d_json["event_type"]] = [r["role"] for r in d_json["role_list"]]

    # role depends on id and sent_id 
    for d_json in role_data:
        r_ret = extract_result_multilabel(d_json["text"], d_json["pred"]["labels"])
        role_ret = {}
        for r in r_ret:
            role_type = r["type"]
            if role_type not in role_ret:
                role_ret[role_type] = []
            role_ret[role_type].append("".join(r["text"]))
        _id = "{}\t{}\t{}\t{}".format(d_json["id"], d_json["sent_id"], d_json["event_type"], d_json["trigger"])
        if _id not in sent_role_mapping:
            sent_role_mapping[_id] = role_ret
        else:
            for role_type, vals in role_ret.items():
                if role_type in sent_role_mapping[_id]:
                    sent_role_mapping[_id][role_type].extend(vals)
                else:
                    sent_role_mapping[_id][role_type] = vals
            
    # process the enum_role data
    for d_json in enum_data:
        _id = "{}\t{}".format(d_json["id"], d_json["sent_id"])
        labels = d_json["pred"]["label"]
        sent_enum_mapping[_id] = labels

    # process trigger data
    for d_json in trigger_data:
        t_ret = extract_result_multilabel(d_json["text"], d_json["pred"]["labels"])
        pred_event_types = list(set([(t["type"], ''.join(t["text"])) for t in t_ret]))
        event_list = []
        _id = "{}\t{}".format(d_json["id"], d_json["sent_id"])
        for pred_event_type in pred_event_types:
            event_type = pred_event_type[0]
            trigger = pred_event_type[1]
            role_id = _id + "\t{}\t{}".format(event_type, trigger)
            role_list = schema[event_type]
            arguments = []
            for role_type, ags in sent_role_mapping.get(role_id, {}).items():
                if role_type not in role_list:
                    continue
                for arg in ags:
                    out = {"role": role_type, "argument": arg}
                    if out not in arguments:
                        arguments.append(out)
            # 特殊处理环节
            if arguments and event_type == enum_event_type:
                for label in sent_enum_mapping.get(_id, []):
                    arguments.append({
                        "role": enum_role,
                        "argument": label
                    })
            if arguments:
                event = {
                    "event_type": event_type,
                    "arguments": arguments,
                    "text": d_json["text"]
                }
                event_list.append(event)
        pred_ret.append({
            "id": d_json["id"],
            "sent_id": d_json["sent_id"],
            "text": d_json["text"],
            "event_list": event_list
        })
    doc_pred = {}
    for d in pred_ret:
        if d["id"] not in doc_pred:
            doc_pred[d["id"]] = {"id": d["id"], "event_list": []}
        doc_pred[d["id"]]["event_list"].extend(d["event_list"])

    # unfiy the all prediction results and save them
    doc_pred = [
        event_normalization(r)
        for r in doc_pred.values()
    ]
    print("submit data {} save".format(len(doc_pred)))
    return doc_pred

In [40]:
doc_pred = predict_data_process(sentences_tigger_data, sentences_role_data, sentences_enum_data, shema_path)

trigger predict 1946 load.
role predict 1803 load
enum predict 66 load
schema 13 load from ./dictionary/event_schema.json
submit data 1174 save


In [28]:
# true_sent_enum_data, true_sent_tigger_data, true_sent_role_data  = [], [], []
# with open(duee_fin_dev_preprocess_path, 'r', encoding='utf-8') as f:
#     dataset = json.loads(f.read())
#     preprocess_enum_dataset = enum_data_process(dataset)
#     preprocess_trigger_dataset = trigger_data_process(dataset)
#     preprocess_role_dataset = role_data_process(dataset)
#     for d_json in preprocess_enum_dataset:
#         true_sent_enum_data.append({"id": d_json['id'], "sent_id":d_json['sent_id'], "text": d_json['text'], "pred": {"label": d_json['label']}})
#     for d_json in preprocess_trigger_dataset:
#         labels = d_json['labels']
#         labels = labels[:(max_seq_len - 2)]
#         encoded_label = [[l] if isinstance(l , str) else l for l in labels]
#         true_sent_tigger_data.append({"id": d_json['id'], "sent_id":d_json['sent_id'], "text": d_json['text'], "pred": {"labels": encoded_label}})
#     for d_json in preprocess_role_dataset:
#         labels = d_json['labels']
#         labels = labels[:(max_seq_len - 2)]
#         encoded_label = [[l] if isinstance(l , str) else l for l in labels]
#         true_sent_role_data.append({"id": d_json['id'], "sent_id":d_json['sent_id'], "event_type": d_json['event_type'], "trigger": d_json['trigger'], "text": d_json['text'], "pred": {"labels": encoded_label}})

In [29]:
true_data_list = []
with open(duee_fin_dev_path, 'r', encoding='utf-8') as f:
    for line in f:
        json_data = json.loads(line)
        true_data_list.append(json_data)
pred_mapping_dict = {}
for doc in doc_pred:
    pred_mapping_dict[doc['id']] = doc

In [30]:
count = 0
for true_data in true_data_list:
    id_ = true_data['id']
    true_data_set = set([(argument['role'], argument['argument']) for true_event in true_data.get('event_list', []) for argument in true_event.get('arguments')])
    pred_data = pred_mapping_dict[id_]
    pred_data_set = set([(argument['role'], argument['argument']) for pred_event in pred_data.get('event_list', []) for argument in pred_event.get('arguments')])
    if len(list(true_data_set)) != len(list(pred_data_set)):
        print(id_)
        print(list(true_data_set))
        print(list(pred_data_set))
        count += 1

10be7f956da35f15fa4a9ad2a4556960
[('公司名称', '美团'), ('公司名称', '饿了么'), ('约谈机构', '北京市监局'), ('公司名称', '京东'), ('公司名称', '快手'), ('被约谈时间', '近日'), ('公司名称', '抖音'), ('公司名称', '微店')]
[('公司名称', '美团'), ('公司名称', '饿'), ('公司名称', '美'), ('公司名称', '饿了么'), ('约谈机构', '北京市监局'), ('公司名称', '京东'), ('约谈机构', '北京市市场监管局'), ('公司名称', '快手'), ('被约谈时间', '近日'), ('公司名称', '抖音'), ('公司名称', '微店')]
be11a1377c0d97bd5f5da146990597fa
[('事件时间', '7月22日'), ('被投资方', '天宜上佳'), ('投资方', '启赋安泰基金'), ('环节', '筹备上市'), ('事件时间', '2018年6月'), ('披露时间', '近日'), ('上市公司', '天宜上佳')]
[('事件时间', '7月22日'), ('事件时间', '近日'), ('环节', '筹备上市'), ('上市公司', '天宜上佳')]
364cf72ae0d2d5b94dcf6d815716b189
[('质押物', '股份'), ('质押股票/股份数量', '854.68万'), ('质押物所属公司', '翔鹭钨业'), ('质押方', '众达投资')]
[('质权方', '海通证券'), ('质押方', '潮州市众达投资有限公司'), ('质押物所属公司', '翔鹭钨业'), ('质押物', '股份'), ('质押股票/股份数量', '854.68万'), ('披露时间', '5月20日晚间')]
5259811fce42c3a4833e326fb8f847a9
[('中标公司', '华润'), ('披露日期', '近日'), ('中标标的', '医用耗材集中配送服务遴选供应商项目'), ('中标公司', '国药控股'), ('中标金额', '3.25亿元'), ('招标方', '河南科技大学第二附属医院'), ('中标公司', '国药器械'), 

[('环节', '正式上市'), ('破产公司', '誉衡药业'), ('事件时间', '2010年'), ('上市公司', '誉衡药业')]
[('破产公司', '誉衡药业'), ('披露时间', '最'), ('破产时间', '2000年')]
90457ba7aa92ca9fd56321dd72b93eb9
[('质权方', '海通证券的公司'), ('净亏损', '5.4亿元'), ('质押物', '有限售条件股份'), ('披露时间', '8月6日晚间'), ('事件时间', '近日'), ('财报周期', '上半年'), ('质押股票/股份数量', '1.416亿'), ('质押物占持股比', '99.81%'), ('披露时间', '7月30日'), ('质押方', '邢加兴'), ('质押物', '股份'), ('事件时间', '截至目前'), ('质押物所属公司', '拉夏贝尔'), ('披露时间', '8月6日'), ('质押物占总股比', '25.85%'), ('公司名称', '拉夏贝尔')]
[('事件时间', '2017年11月起'), ('事件时间', '从017年11月起'), ('事件时间', '近日'), ('事件时间', '8月6日'), ('质权方', '海通证券'), ('质押股票/股份数量', '1.416亿'), ('披露时间', '8'), ('质押物占持股比', '99.81%'), ('质押方', '邢加兴'), ('质押物', '股份'), ('事件时间', '8月'), ('质押物所属公司', '拉夏贝尔'), ('披露时间', '8月6日'), ('质押物占总股比', '25.85%'), ('事件时间', '截至目前')]
79ed2396795118056dbb8460050ce4c1
[('质权方', '兴业证券'), ('质押物占总股比', '14.84%'), ('质押物占总股比', '1.78%'), ('事件时间', '近日'), ('事件时间', '截至公告日'), ('质押股票/股份数量', '1.5亿'), ('质押物', '股份'), ('披露时间', '10月9日'), ('质押物占持股比', '69.30%'), ('质押股票/股份数量', '12.5194亿'), ('质押物所属公

[('财报周期', '上年同期'), ('净亏损', '992.4万元'), ('净亏损', '1,411.1万元'), ('亏损变化', '扩大'), ('财报周期', '截至今年6月底止中期')]
1d15b00cf74941f59a80dbbc59bb4479
[('交易完成时间', '8月31日'), ('交易金额', '420.4万元'), ('股票简称', '湘油泵'), ('披露时间', '9月1日'), ('每股交易价格', '37.84元'), ('交易股票/股份数量', '11.11万'), ('减持方', '袁春华'), ('减持部分占总股本比例', '0.11%')]
[('交易完成时间', '8月31日'), ('交易金额', '420.4万元'), ('股票简称', '湘油泵'), ('披露时间', '9月1日'), ('股票简称', '湖南机油泵'), ('减持方', '袁春华'), ('交易股票/股份数量', '11.11万'), ('每股交易价格', '37.84元/股'), ('减持部分占总股本比例', '0.11%')]
5d2a4299fcabb0627b8a8a1cd445f19d
[('质押物', '股'), ('质押股票/股份数量', '3000万'), ('披露时间', '6月22日'), ('质押物所属公司', '盛洋科技'), ('质押方', '叶利明'), ('质押股票/股份数量', '1500万'), ('质押方', '盛洋电器'), ('披露时间', '2020年6月22日')]
[('质押股票/股份数量', '3000万'), ('披露时间', '6月22日'), ('事件时间', '2020年6月22日'), ('质押方', '叶利明'), ('质押物所属公司', '盛洋科技'), ('质押股票/股份数量', '1500万'), ('质押物所属公司', '盛洋电器')]
cf393c2b24f7b4b7466ba80733d81ee3
[('披露日期', '8月23日'), ('中标公司', '光大国际'), ('中标标的', '甘肃凉州区静脉产业园PPP项目'), ('中标日期', '8月23日'), ('中标金额', '73934.74万元')]
[('招标方', '甘肃武威市凉州区'), ('中标标

In [31]:
count

643

In [1]:
for d_json in sentences_role_data:
    if d_json['id'] == '28488b2dd838cfdf73bba1670c5efd0c':
        r_ret = extract_result_multilabel(d_json["text"], d_json["pred"]["labels"])
        print(d_json["text"])
#         print(d_json["pred"]["labels"][d_json["text"].find("8-10美元"):d_json["text"].find("8-10美元")+1])
#         print(d_json["pred"]["probs"][d_json["text"].find("8-10美元")+1:d_json["text"].find("8-10美元")+2][0])
        print(d_json["pred"]["labels"])
        print(r_ret)

NameError: name 'sentences_role_data' is not defined

In [33]:
def evaluate_mergedata(predict_doc, true_merge_dataset_path):
    true_data_list = []
    with open(true_merge_dataset_path, 'r', encoding='utf-8') as f:
        for line in f:
            json_data = json.loads(line)
            true_data_list.append(json_data)
    predict_mapping_dict = {}
    for doc in predict_doc:
        predict_mapping_dict[doc['id']] = doc
    count_predict = 0
    count_true = 0
    count_correct = 0
    for true_data in true_data_list:
        if true_data['id'] not in predict_mapping_dict:
            if 'event_list' in true_data:
                print('error: ', true_data)
        else:
            predict_doc = predict_mapping_dict[true_data['id']]
            pred_data_set = set([(pred_event['event_type'], argument['role'], argument['argument']) for pred_event in predict_doc.get('event_list', []) for argument in pred_event.get('arguments')])
            true_data_set = set([(true_event['event_type'], argument['role'], argument['argument']) for true_event in true_data.get('event_list', []) for argument in true_event.get('arguments')])
            count_predict += len(list(pred_data_set))
            count_true += len(list(true_data_set))
            count_correct += len(list(pred_data_set & true_data_set))
    p = count_correct / max(1, count_predict)  # precision
    r = count_correct / max(1, count_true)  # recall
    f1 = 2 * r * p / max(1e-9, r + p) # f1 score
    s = count_true  # support

    print("{:>10}{:>10}{:>10}{:>10}\n".format("precision", "recall", "f1-score", "support"))
    formatter = "{:>10.3f}{:>10.3f}{:>10.3f}{:>10d}".format
    print(formatter(p, r, f1, s))
    print("")

In [34]:
evaluate_mergedata(doc_pred, duee_fin_dev_path)

 precision    recall  f1-score   support

     0.698     0.734     0.716      7061

