In [1]:
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import BertModel, BertTokenizer, BertConfig
import random
import functools
from torchcrf import CRF
from sklearn.metrics import f1_score, precision_score, recall_score
import gc
import math
import numpy as np
import os
from pyltp import SentenceSplitter
from langconv import Traditional2Simplified
import re
import collections
import transformer
import copy
import pickle
from utils import measure_event_table_filling

# Seed every RNG the notebook touches so that runs are repeatable.
_SEED = 666
random.seed(_SEED)             # Python `random` (drives random.shuffle of the training docs)
torch.manual_seed(_SEED)       # torch CPU generator
torch.cuda.manual_seed(_SEED)  # torch generator for the current CUDA device
np.random.seed(_SEED)          # numpy legacy global RNG

def strQ2B(ustring):
    """Convert full-width characters to their half-width equivalents.

    The ideographic space U+3000 becomes an ASCII space, and the full-width
    forms U+FF01..U+FF5E become U+0021..U+007E. Every other character is kept
    as-is, so the result has exactly the same length as the input.
    """
    # Code-point translation table for str.translate.
    half_width = {0x3000: 0x20}
    half_width.update({code: code - 0xFEE0 for code in range(0xFF01, 0xFF5F)})
    return ustring.translate(half_width)

def sub_list_index(list, sub_list):
    """Return every start index at which `sub_list` occurs contiguously in `list`.

    Works for any sliceable sequence (lists, tuples, strings). Overlapping
    matches are all reported, in increasing order. An empty `sub_list` yields
    no matches; the previous version raised IndexError on `sub_list[0]`.

    Note: the parameter name `list` shadows the builtin inside this function.
    It is kept so that keyword callers keep working.
    """
    width = len(sub_list)
    if width == 0:
        return []
    first = sub_list[0]
    # Only start positions where a full-length window still fits can match.
    return [
        start
        for start in range(len(list) - width + 1)
        if list[start] == first and list[start:start + width] == sub_list
    ]

In [2]:
# --- Preprocessing hyper-parameters ---------------------------------------
MAX_TOKENS_LENGTH = 500  # max characters per (merged) sentence chunk; shorter chunks are padded to this
MAX_SENT_NUM = 5         # max chunks kept per document; also the size of the sentence-position table
DEV_REAIO = 0.05         # fraction of the shuffled training docs held out as dev ("ratio"; name kept for callers)
TEXT_NORM = True         # look ids up on normalised text (full->half width, traditional->simplified, lower-case)

# --- Event schema -----------------------------------------------------------
# The nine event types, in the order used by the per-type classifiers and decode tables.
EVENT_TYPES = ['破产清算', '重大安全事故', '股东减持', '股权质押', '股东增持', '股权冻结', '高层死亡', '重大资产损失', '重大对外赔付']

# event type -> (role names, coarse entity-type hint per role); the two lists are parallel.
EVENT_FIELDS = {
    '破产清算': (['公司名称', '公司行业', '公告时间', '受理法院', '裁定时间'], ['公司', '行业', '时间', '机构', '时间']),
    '重大安全事故': (['伤亡人数', '公司名称', '公告时间', '其他影响', '损失金额'], ['数字', '公司', '时间', '文本短语', '数字']),
    '股东减持': (['减持开始日期', '减持的股东', '减持金额'], ['时间', '公司/人名', '数字和单位']),
    '股权质押': (['接收方', '质押开始日期', '质押方', '质押结束日期', '质押金额'], ['公司/人名', '时间', '公司/人名', '时间', '数字']),
    '股东增持': (['增持开始日期', '增持的股东', '增持金额'], ['时间', '公司/人名', '数字和单位']),
    '股权冻结': (['冻结开始日期', '冻结结束日期', '冻结金额', '被冻结股东'], ['时间', '时间', '数字', '公司/人名']),
    '高层死亡': (['公司名称', '死亡/失联时间', '死亡年龄', '高层人员', '高层职务'], ['公司', '时间', '数字', '人名', '职称']),
    '重大资产损失': (['公司名称', '公告时间', '其他损失', '损失金额'], ['公司', '时间', '文本短语', '数字']),
    '重大对外赔付': (['公司名称', '公告时间', '赔付对象', '赔付金额'], ['公司', '时间', '公司/人名', '数字'])
}

# BIO tag inventory shared by all event types. 'O' is id 0, followed by one
# B-/I- pair per distinct role name in first-seen order. Every B- tag therefore
# has an odd id, and its I- tag is the next (even) id.
NER_LABEL_LIST = ['O']
for _role in (role for roles, _ in EVENT_FIELDS.values() for role in roles):
    if 'B-' + _role not in NER_LABEL_LIST:
        NER_LABEL_LIST += ['B-' + _role, 'I-' + _role]
NER_LABEL2ID = {label: label_id for label_id, label in enumerate(NER_LABEL_LIST)}

# (event type, role names) pairs in EVENT_TYPES order, as handed to the evaluation helper.
EVENT_TYPE_FIELDS_PAIRS = [(event_type, EVENT_FIELDS[event_type][0]) for event_type in EVENT_TYPES]

In [3]:
# Chinese BERT checkpoint, cached locally. The encoder is built with only 4
# hidden layers (set on the config before the model is instantiated).
_BERT_NAME = 'bert-base-chinese'
_BERT_CACHE = './bert_base_chinese'

bert_config = BertConfig.from_pretrained(_BERT_NAME, cache_dir=_BERT_CACHE)
bert_config.num_hidden_layers = 4
tokenizer = BertTokenizer.from_pretrained(_BERT_NAME, cache_dir=_BERT_CACHE)
bert = BertModel.from_pretrained(_BERT_NAME, cache_dir=_BERT_CACHE, config=bert_config)

In [4]:
# Load the CCKS 4_2 data: one JSON document per line. `with` closes each file as
# soon as it has been read (the handles used to stay open for the whole session).
# `test` comes from the competition's dev file, which is only used for
# prediction output here (NOTE(review): presumably unlabelled — confirm).
with open('ccks 4_2 Data/event_element_train_data_label.txt', mode='r', encoding='utf-8') as train_file:
    train = [json.loads(line) for line in train_file]
with open('ccks 4_2 Data/event_element_dev_data.txt', mode='r', encoding='utf-8') as test_file:
    test = [json.loads(line) for line in test_file]

# Hand-patched annotations for two training documents. The indices refer to the
# file's line order, so these patches must run before `train` is shuffled.
# Document 3693's text is also rewritten to use full-width parentheses, so that
# its corrected span occurs verbatim in the content.
train[1904]['events'][0]['增持的股东'] = '微医集团（浙江）有限公司'
train[3693]['events'][-1]['被冻结股东'] = '上海九川投资（集团）有限公司'
train[3693]['content'] = train[3693]['content'].replace('上海九川投资(集团)有限公司', '上海九川投资（集团）有限公司')

In [5]:
# Preprocess the training documents IN PLACE. For each instance this:
#   * builds one BERT vocabulary id per character (no WordPiece),
#   * builds per-character BIO tag ids from the gold event roles,
#   * splits the text on '。'/'；', hard-cuts pieces longer than MAX_TOKENS_LENGTH,
#     greedily merges pieces into chunks of <= MAX_TOKENS_LENGTH characters,
#   * pads every chunk to MAX_TOKENS_LENGTH (ids 0, labels -1, mask 0),
# and stores the results on the instance dict. It also collects corpus stats,
# then shuffles `train` and splits off `dev`.
sentence_sum = 0
sentence_length_sum = 0
truncate_span = 0  # forced cuts that land between two tagged characters
span_sum = 0       # number of non-empty gold role values seen
for i, ins in enumerate(train):
    events = ins['events']
    content = (ins['content'])
    if TEXT_NORM:
        # Normalised copy used only for vocabulary lookup. Tags are still located
        # in the raw `content`, so the two strings must stay character-aligned.
        content_norm = Traditional2Simplified(strQ2B(content)).lower()
        assert len(content) == len(content_norm)
        
    
    UNK_ID = tokenizer.vocab['[UNK]']
    PAD_ID = tokenizer.vocab['[PAD]']  # NOTE(review): unused; padding below uses literal 0
    # One id per character; characters missing from the vocab map to [UNK].
    # NOTE(review): this loop re-binds `i`, shadowing the document index above.
    ids = []
    for i, char in enumerate(content):
        if TEXT_NORM:
            char = content_norm[i]
        if char in tokenizer.vocab:
            ids.append(tokenizer.vocab[char])
        else:
            ids.append(UNK_ID)
    labels = [0 for _ in ids]
    event_cls = [0 for _ in EVENT_TYPES]  # multi-hot: which event types occur in this doc
    for event in events:
        event_cls[EVENT_TYPES.index(event['event_type'])] = 1
        for event_role, span in event.items():
            if event_role == 'event_type' or event_role == 'event_id' or not span:
                continue
            span_sum += 1
            # Tag EVERY occurrence of the role value in the text. int(-0.5 + 1) == 0,
            # so the first search starts at offset 0; each later search restarts one
            # character after the previous hit. Later roles overwrite earlier tags
            # where occurrences overlap.
            find_idx = -0.5
            while find_idx != -1:
                find_idx = content.find(span, int(find_idx + 1))
                if find_idx != -1:
                    assert content[find_idx: find_idx + len(span)] == span
                    labels[find_idx] = NER_LABEL2ID['B-' + event_role]
                    for k in range(1, len(span)):
                        labels[find_idx + k] = NER_LABEL2ID['I-' + event_role]
    
    assert len(ids) == len(content) == len(labels)
    sentences_ids = []     # NOTE(review): unused
    sentences_labels = []  # NOTE(review): unused
    
    sentences = []
    # The capturing group keeps each '。'/'；'-terminated sentence in the output.
    # Unmatched text (e.g. a trailing fragment) is kept as separate pieces, and
    # empty pieces are dropped, so the pieces concatenate back to `content`.
    raw_sentences = list(filter(lambda x: bool(x), re.split('([^。；]+[。；])', content)))
    curr_pos = 0
    sentence_sum += len(raw_sentences)
    for sentence in raw_sentences:
        sentence_length_sum += len(sentence)
        # print(len(sentence))
        if len(sentence) < MAX_TOKENS_LENGTH:
            sentences.append(sentence)
            curr_pos += len(sentence)
        else:
            # Too long for one chunk: hard-cut every MAX_TOKENS_LENGTH characters.
            while len(sentence) > 0:
                sentences.append(sentence[:MAX_TOKENS_LENGTH])
                curr_pos += len(sentences[-1])
                # Approximate count of gold spans split by this cut: both characters
                # around the cut carry non-O tags.
                if curr_pos < len(labels) and labels[curr_pos] != 0 and labels[curr_pos -1] != 0:
                    truncate_span += 1
                    #print(truncate_span / span_sum)
                sentence = sentence[MAX_TOKENS_LENGTH:]
    
    # Greedily merge consecutive pieces while the merged chunk still fits.
    merge_sentences = []
    curr_sentence = ''
    for sentence in sentences:
        if len(sentence) + len(curr_sentence) <= MAX_TOKENS_LENGTH:
            curr_sentence += sentence
        else:
            merge_sentences.append(curr_sentence)
            curr_sentence = sentence
    if curr_sentence:
        merge_sentences.append(curr_sentence)
    
    curr_pos = 0
    ids_list = []
    labels_list = []
    attention_mask = []
    ids_length = []  # unpadded length of each kept chunk
    if len(merge_sentences) > 3:
        # Long document: keep only chunks that contain at least one gold tag, then
        # at most MAX_SENT_NUM of them.
        # NOTE(review): this selects model input using gold labels; `dev` (sliced
        # from `train` below) therefore only sees label-bearing chunks, unlike test.
        filted_merge_sentences = []
        for sentence in merge_sentences:
            if sum(labels[curr_pos: curr_pos + len(sentence)]) == 0:
                curr_pos += len(sentence)
            else:
                filted_merge_sentences.append(sentence)
                ids_list.append(ids[curr_pos: curr_pos + len(sentence)])
                labels_list.append(labels[curr_pos: curr_pos + len(sentence)])
                attention_mask.append([1 for _ in range(len(sentence))])
                ids_length.append(len(ids_list[-1]))
                if len(ids_list[-1]) < MAX_TOKENS_LENGTH:
                    # Pad: id 0, label -1 (ignored by the loss), mask 0.
                    pad_num = MAX_TOKENS_LENGTH - len(ids_list[-1])
                    ids_list[-1].extend([0 for _ in range(pad_num)])
                    labels_list[-1].extend([-1 for _ in range(pad_num)])
                    attention_mask[-1].extend([0 for _ in range(pad_num)])
                curr_pos += len(sentence)
                assert ids_length[-1] == sum(attention_mask[-1])
        
        assert len(filted_merge_sentences) == len(ids_list) == len(labels_list) == len(attention_mask)
        if len(filted_merge_sentences) > MAX_SENT_NUM: # truncate sentence
            filted_merge_sentences = filted_merge_sentences[:MAX_SENT_NUM]
            ids_list = ids_list[:MAX_SENT_NUM]
            labels_list = labels_list[:MAX_SENT_NUM]
            attention_mask = attention_mask[:MAX_SENT_NUM]
            ids_length = ids_length[:MAX_SENT_NUM]
        ins['merge_sentences'] = filted_merge_sentences
        assert len(filted_merge_sentences) == len(ids_list) == len(labels_list) == len(attention_mask)
    else:
        # Short document (<= 3 chunks): keep every chunk.
        for sentence in merge_sentences:
            ids_list.append(ids[curr_pos: curr_pos + len(sentence)])
            labels_list.append(labels[curr_pos: curr_pos + len(sentence)])
            attention_mask.append([1 for _ in range(len(sentence))])
            ids_length.append(len(ids_list[-1]))
            if len(ids_list[-1]) < MAX_TOKENS_LENGTH:
                pad_num = MAX_TOKENS_LENGTH - len(ids_list[-1])
                ids_list[-1].extend([0 for _ in range(pad_num)])
                labels_list[-1].extend([-1 for _ in range(pad_num)])
                attention_mask[-1].extend([0 for _ in range(pad_num)])
            curr_pos += len(sentence)
            assert ids_length[-1] == sum(attention_mask[-1])
        ins['merge_sentences'] = merge_sentences
        assert len(merge_sentences) == len(ids_list) == len(labels_list) == len(attention_mask)
    ins['ids_list'] = ids_list
    ins['labels_list'] = labels_list
    ins['attention_mask'] = attention_mask
    ins['ids'] = ids
    ins['labels'] = labels
    ins['event_cls'] = event_cls
    ins['ids_length'] = ids_length
    # Checks the unfiltered chunks, i.e. that splitting/merging lost no text.
    assert ''.join(merge_sentences) == content
#doc_id: 2047486 for test
# chr(1627)
# Shuffle (seeded above), then hold out the first DEV_REAIO fraction as dev.
random.shuffle(train)
dev_num = math.ceil(len(train) * DEV_REAIO)
dev = train[:dev_num]
part_train = train[dev_num:]

In [6]:
def collate_label(labels_list, attention_mask, ids_list):
    """Group the BIO-tagged spans of one document by their token-id sequence.

    Args:
        labels_list: per-sentence tag-id sequences (gold labels or decoded
            predictions). Ids follow NER_LABEL_LIST: 0 is 'O', odd ids are
            'B-<role>', and the following even id is the matching 'I-<role>'.
            A sequence may be shorter than its mask (e.g. CRF decode output).
            Scanning stops at the first masked position or at the end of the
            sequence, whichever comes first.
        attention_mask: per-sentence 0/1 masks; 0 marks padding.
        ids_list: per-sentence token ids aligned with `labels_list`.

    Returns:
        OrderedDict mapping a span's token-id tuple to a list of
        ``((sent_idx, char_s, char_e), entity_idx)`` occurrences. ``char_e`` is
        exclusive and ``entity_idx`` is the span's B- tag id. Keys are ordered
        by first occurrence in the document, and each list is in document order.
    """
    span_token_drange_list = []
    for sent_idx, labels in enumerate(labels_list):
        mask = attention_mask[sent_idx]
        ids = ids_list[sent_idx]
        seq_len = len(labels)
        char_s = 0
        while char_s < seq_len:
            if mask[char_s] == 0:
                break  # the rest of this sentence is padding
            entity_idx = labels[char_s]
            # Only positive odd ids are B- tags; `> 0` also rules out the -1 pad label.
            if entity_idx > 0 and entity_idx % 2 == 1:
                char_e = char_s + 1
                # Extend over the matching I- tag (B id + 1) while still unmasked.
                while char_e < seq_len and mask[char_e] == 1 and labels[char_e] == entity_idx + 1:
                    char_e += 1
                span_token_drange_list.append(
                    (tuple(ids[char_s:char_e]), (sent_idx, char_s, char_e), entity_idx)
                )
                char_s = char_e
            else:
                char_s += 1

    # Sort by drange = (sent_idx, char_s, char_e). The previous key, x[-1], is
    # entity_idx in this 3-tuple, so spans were ordered by label id instead.
    span_token_drange_list.sort(key=lambda x: x[1])

    token_tup2dranges = collections.OrderedDict()
    for token_tup, drange, entity_idx in span_token_drange_list:
        token_tup2dranges.setdefault(token_tup, []).append((drange, entity_idx))
    return token_tup2dranges

# for i in range(10):
#     ret = collate_label(train[i]['labels_list'], train[i]['attention_mask'], train[i]['ids_list'])
#     print('------------')

def measure_dee_prediction(doc_decode_res, batch_dev, EVENT_TYPES, EVENT_FIELDS, EVENT_TYPE_FIELDS_PAIRS):
    """Score predicted event tables against the gold events of `batch_dev`.

    Args:
        doc_decode_res: one entry per document. Each entry is a list indexed by
            event type holding None (type not predicted) or a list of records.
            Each record is a list of per-role dranges ``(sent_idx, char_s,
            char_e)`` or None.
            NOTE: records are converted IN PLACE to span strings; pass a deep
            copy if the dranges are still needed afterwards.
        batch_dev: preprocessed documents aligned with `doc_decode_res`. Each
            must carry 'merge_sentences' and the gold 'events'.
        EVENT_TYPES, EVENT_FIELDS, EVENT_TYPE_FIELDS_PAIRS: the event schema.

    Returns:
        The result of ``measure_event_table_filling(..., dict_return=True)``.
    """
    # Prediction side: replace each drange with the text it covers.
    for decode_res, ins in zip(doc_decode_res, batch_dev):
        merge_sentences = ins['merge_sentences']
        for records in decode_res:
            if records is None:
                continue
            for fields in records:
                for j, drange in enumerate(fields):
                    if drange is not None:
                        sent_idx, char_s, char_e = drange
                        fields[j] = merge_sentences[sent_idx][char_s:char_e]

    # Gold side: same layout (per event type: None or a list of role-value records).
    gt_decode_res = []
    for ins in batch_dev:
        decode_res = [None for _ in EVENT_TYPES]
        for event in ins['events']:
            event_type = event['event_type']
            event_idx = EVENT_TYPES.index(event_type)
            if decode_res[event_idx] is None:
                decode_res[event_idx] = []
            # Missing or empty role values are recorded as None.
            decode_res[event_idx].append(
                [event.get(field) or None for field in EVENT_FIELDS[event_type][0]]
            )
        gt_decode_res.append(decode_res)
    return measure_event_table_filling(doc_decode_res, gt_decode_res, EVENT_TYPE_FIELDS_PAIRS, dict_return=True)
#gg = measure_dee_prediction(copy.deepcopy(res), batch_dev, EVENT_TYPES, EVENT_FIELDS, EVENT_TYPE_FIELDS_PAIRS)

In [7]:
def decode_drange2text(doc_decode_res, batch_test):
    """Replace every drange in the decoded event tables with the text it spans.

    `doc_decode_res[i]` is indexed by event type and holds None or a list of
    records. Each record is a list whose entries are None or a
    ``(sent_idx, char_s, char_e)`` drange into
    ``batch_test[i]['merge_sentences']``. Records are rewritten IN PLACE:
    dranges become substrings and None stays None. Returns None.
    """
    for doc_idx, per_type_records in enumerate(doc_decode_res):
        sentences = batch_test[doc_idx]['merge_sentences']
        for records in (r for r in per_type_records if r is not None):
            for record in records:
                for slot, drange in enumerate(record):
                    if drange is None:
                        continue
                    sent_idx, start, end = drange
                    record[slot] = sentences[sent_idx][start:end]

def preprocess_test(test, tokenizer):
    """Prepare unlabelled documents for inference, mutating each instance in place.

    Mirrors the training preprocessing:
      * character-level ids with the same TEXT_NORM normalisation,
      * sentence splitting on '。'/'；',
      * hard cuts at MAX_TOKENS_LENGTH characters,
      * greedy merging up to MAX_TOKENS_LENGTH,
      * padding with id 0 / mask 0.
    Test documents have no gold labels, so this cannot keep only label-bearing
    chunks as training does. Instead it keeps the first MAX_SENT_NUM merged
    chunks; the sentence-position embedding only has MAX_SENT_NUM rows.

    Adds to each instance:
      * 'merge_sentences': the kept chunks (decoded dranges index into these),
      * 'ids_list' and 'attention_mask': padded to MAX_TOKENS_LENGTH,
      * 'ids': unpadded ids of the whole document,
      * 'ids_length': unpadded length of each kept chunk.

    Fixes compared with the previous version:
      * chunk selection read the module-level `labels`, which holds the gold
        tags of the last *training* document and is meaningless here;
      * ids were built from raw text, although training uses normalised text;
      * per-document debug prints were removed.
    """
    UNK_ID = tokenizer.vocab['[UNK]']
    for ins in test:
        content = ins['content']

        # Same normalisation as training. Fall back to the raw text if
        # normalisation would break the one-id-per-character alignment that
        # drange decoding relies on.
        lookup_text = content
        if TEXT_NORM:
            content_norm = Traditional2Simplified(strQ2B(content)).lower()
            if len(content_norm) == len(content):
                lookup_text = content_norm
        ids = [tokenizer.vocab.get(char, UNK_ID) for char in lookup_text]

        # Split after each '。'/'；' and hard-cut pieces that are too long for one chunk.
        sentences = []
        for sentence in filter(None, re.split('([^。；]+[。；])', content)):
            sentences.extend(
                sentence[beg:beg + MAX_TOKENS_LENGTH]
                for beg in range(0, len(sentence), MAX_TOKENS_LENGTH)
            )

        # Greedily merge consecutive pieces while the merged chunk still fits.
        merge_sentences = []
        curr_sentence = ''
        for sentence in sentences:
            if len(sentence) + len(curr_sentence) <= MAX_TOKENS_LENGTH:
                curr_sentence += sentence
            else:
                merge_sentences.append(curr_sentence)
                curr_sentence = sentence
        if curr_sentence:
            merge_sentences.append(curr_sentence)
        assert ''.join(merge_sentences) == content

        kept_sentences = merge_sentences[:MAX_SENT_NUM]
        ids_list, attention_mask, ids_length = [], [], []
        curr_pos = 0
        for sentence in kept_sentences:
            length = len(sentence)
            pad_num = MAX_TOKENS_LENGTH - length
            ids_list.append(ids[curr_pos:curr_pos + length] + [0] * pad_num)
            attention_mask.append([1] * length + [0] * pad_num)
            ids_length.append(length)
            curr_pos += length

        ins['merge_sentences'] = kept_sentences
        ins['ids_list'] = ids_list
        ins['attention_mask'] = attention_mask
        ins['ids'] = ids
        ins['ids_length'] = ids_length

def eval_save_test(origin_test, tokenizer, save_file_name):
    """Run the global `model` on the test documents and write its predictions.

    `origin_test` is left untouched: preprocessing and decoding run on a deep
    copy. Two files are written under ``OUTPUT_DIR/test_res``:
      * `save_file_name`: one JSON object per line, ``{"doc_id", "events"}``;
      * `save_file_name + '-with_content'`: the full original record plus the
        predicted 'events'.
    Each predicted event is ``{'event_type': ..., <role>: <span text>, ...}``;
    roles with no predicted span are omitted.

    Relies on the notebook globals `model`, `OUTPUT_DIR`, `EVENT_TYPES` and
    `EVENT_FIELDS`. Leaves `model` in eval mode.

    Returns:
        A deep copy of `origin_test` with each record's 'events' filled in.
    """
    TEST_DOC_BATCH_SIZE = 2
    TEST_BATCH_NUM = math.ceil(len(origin_test) / TEST_DOC_BATCH_SIZE)
    TEST_SAVE_DIR = '%s/test_res' % OUTPUT_DIR
    os.makedirs(TEST_SAVE_DIR, exist_ok=True)  # also creates OUTPUT_DIR if it is missing
    save_file_name = os.path.join(TEST_SAVE_DIR, save_file_name)

    test = copy.deepcopy(origin_test)
    preprocess_test(test, tokenizer)
    model.eval()
    total_decode_res = []
    # Inference only: skip building the autograd graph.
    with torch.no_grad(), tqdm(total=TEST_BATCH_NUM) as pbar:
        for batch_num in range(TEST_BATCH_NUM):
            batch_beg = batch_num * TEST_DOC_BATCH_SIZE
            batch_test = test[batch_beg: batch_beg + TEST_DOC_BATCH_SIZE]
            _, doc_decode_res = model(batch_test, train_flag=False, use_gold=False)
            total_decode_res.extend(doc_decode_res)
            pbar.update()
    decode_drange2text(total_decode_res, test)

    test_copy = copy.deepcopy(origin_test)
    for decode_res, ins_copy in zip(total_decode_res, test_copy):
        mult_ans = []
        for event_idx, mult_res in enumerate(decode_res):
            if not mult_res:  # None, or no records for this event type
                continue
            event_type = EVENT_TYPES[event_idx]
            event_fields = EVENT_FIELDS[event_type][0]
            for res in mult_res:
                ans = {'event_type': event_type}
                ans.update(
                    (field, span) for field, span in zip(event_fields, res) if span is not None
                )
                mult_ans.append(ans)
        ins_copy['events'] = mult_ans

    # Both files are closed (and flushed) even if writing fails. The
    # '-with_content' handle used to be left open.
    with open(save_file_name, mode='w', encoding='utf-8') as f, \
            open(save_file_name + '-with_content', mode='w', encoding='utf-8') as with_content:
        for ins in test_copy:
            write_obj = {'doc_id': ins['doc_id'], 'events': ins['events']}
            f.write(json.dumps(write_obj, ensure_ascii=False) + '\n')
            with_content.write(json.dumps(ins, ensure_ascii=False) + '\n')
    return test_copy

In [8]:
class EventTable(nn.Module):
    """Scoring heads for a single event type.

    Holds a binary "event type present" classifier, applied to a document vector
    obtained by attending over sentence embeddings with a learned query. It also
    holds one binary classifier per field (role), applied to span embeddings.

    Args:
        event_type: name of the event type (stored only).
        field_types: sequence whose length sets the number of per-field heads.
        hidden_size: width of the incoming embeddings.
    """

    def __init__(self, event_type, field_types, hidden_size):
        super(EventTable, self).__init__()

        self.event_type = event_type
        self.field_types = field_types
        self.num_fields = len(field_types)
        self.hidden_size = hidden_size

        # Document-level head: class 0 = event absent, class 1 = event present.
        self.event_cls = nn.Linear(hidden_size, 2)
        # One head per field: class 0 = span does not fill it, class 1 = it does.
        self.field_cls_list = nn.ModuleList([nn.Linear(hidden_size, 2) for _ in range(self.num_fields)])

        # Learned query that pools sentence embeddings into one document vector.
        self.event_query = nn.Parameter(torch.Tensor(1, hidden_size))
        # Per-field queries; they are initialised but not read by forward().
        self.field_queries = nn.ParameterList(
            [nn.Parameter(torch.Tensor(1, hidden_size)) for _ in range(self.num_fields)]
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Draw every query uniformly from [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]."""
        bound = 1. / math.sqrt(self.hidden_size)
        for query in [self.event_query, *self.field_queries]:
            query.data.uniform_(-bound, bound)

    def forward(self, sent_context_emb=None, batch_span_emb=None, field_idx=None):
        """Score either the event type or one field; pass exactly one embedding.

        With `sent_context_emb` (the document's sentence embeddings), attends over
        them with `event_query` and returns event-presence log-probabilities.
        With `batch_span_emb` ((batch, hidden_size) or (hidden_size,)) and
        `field_idx`, returns that field's log-probabilities with shape (batch, 2).
        """
        assert (sent_context_emb is None) ^ (batch_span_emb is None)

        if sent_context_emb is not None:
            doc_emb, _ = transformer.attention(self.event_query, sent_context_emb, sent_context_emb)
            return F.log_softmax(self.event_cls(doc_emb), dim=-1)

        if batch_span_emb is None:
            return None

        assert field_idx is not None
        if batch_span_emb.dim() == 1:
            batch_span_emb = batch_span_emb.unsqueeze(0)
        return F.log_softmax(self.field_cls_list[field_idx](batch_span_emb), dim=-1)
        
class SentencePosEncoder(nn.Module):
    """Add a learned sentence-position embedding, then apply LayerNorm and dropout.

    Args:
        hidden_size: embedding width; must match the incoming embeddings.
        max_sent_num: number of position embeddings; position ids must be
            smaller than this.
        dropout: dropout probability applied after the LayerNorm.
    """

    def __init__(self, hidden_size, max_sent_num=100, dropout=0.1):
        super(SentencePosEncoder, self).__init__()

        self.embedding = nn.Embedding(max_sent_num, hidden_size)
        self.layer_norm = transformer.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, batch_elem_emb, sent_pos_ids=None):
        """Return dropout(layer_norm(batch_elem_emb + embedding(sent_pos_ids))).

        When `sent_pos_ids` is None, positions 0..N-1 are used, with N taken from
        the second-to-last dimension of `batch_elem_emb`. A non-tensor sequence
        of ids is converted to a long tensor on the embeddings' device.
        """
        device = batch_elem_emb.device
        if sent_pos_ids is None:
            sent_pos_ids = torch.arange(
                batch_elem_emb.size(-2), dtype=torch.long, device=device, requires_grad=False
            )
        elif not torch.is_tensor(sent_pos_ids):
            sent_pos_ids = torch.tensor(
                sent_pos_ids, dtype=torch.long, device=device, requires_grad=False
            )

        with_pos = batch_elem_emb + self.embedding(sent_pos_ids)
        return self.dropout(self.layer_norm(with_pos))

In [9]:
class DocEE(nn.Module):
    """Document-level event extractor: character tagger plus per-type event tables.

    Pipeline (see forward):
      1. Encode every padded sentence chunk with `basic_encoder`.
      2. Tag characters with BIO role labels, optionally through a CRF.
      3. Max-pool each chunk into one vector and add a sentence-position embedding.
      4. Let each EventTable decide whether its event type occurs.
      5. At inference, fill each role of a predicted type greedily from the tagged spans.

    `config` keys read here: 'use_crf', 'POOLING', 'EVENT_FIELDS',
    'EVENT_TYPES', 'NER_LABEL_LIST', 'MAX_SENT_NUM'.
    """

    def __init__(self, config, basic_encoder):
        super().__init__()
        self.config = config
        self.init_eval_obj()

        self.basic_encoder = basic_encoder
        hidden_size = basic_encoder.config.hidden_size
        num_tags = len(config['NER_LABEL_LIST'])
        if config['use_crf']:
            self.crf = CRF(num_tags=num_tags, batch_first=True)
        self.seq_labeler = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, num_tags)
        )
        # EVENT_FIELDS[event_type] is a (role names, role type hints) pair. The
        # table needs the role names so that it builds one field head per role.
        # The whole pair used to be passed, which gave every table 2 heads.
        # Checkpoints saved before this fix therefore have fewer field heads.
        self.event_tables = nn.ModuleList([
            EventTable(event_type, config['EVENT_FIELDS'][event_type][0], hidden_size)
            for event_type in config['EVENT_TYPES']
        ])
        self.sent_pos_encoder = SentencePosEncoder(
            hidden_size, max_sent_num=config['MAX_SENT_NUM'], dropout=basic_encoder.config.hidden_dropout_prob
        )

    def init_eval_obj(self):
        """Reset the buffers that accumulate predictions and gold labels across batches."""
        self.eval_obj = {'event_type_pred': [], 'event_type_gt': [], 'ner_pred': [], 'ner_gt': []}

    def forward(self, batch_train, train_flag=True, use_gold=True, ee_method='GreedyDec'):
        """Run tagging and event decoding on a batch of preprocessed documents.

        Returns:
            ``(total_loss, decode_res)``. `total_loss` is the NER loss plus the
            event-type loss; both are 0 when `train_flag` is False. `decode_res`
            holds one decoded event table per document at inference and is
            empty during training.

        Raises:
            Exception: for an unsupported `ee_method`.
        """
        ner_loss, doc_ids_emb, doc_ner_pred, doc_sent_emb = self.do_ner(batch_train, train_flag, use_gold)
        if ee_method == 'GreedyDec':
            decode_loss, decode_res = self.greedy_dec(doc_sent_emb, doc_ner_pred, batch_train, train_flag)
            return ner_loss + decode_loss, decode_res
        raise Exception('Not support event method: ' + ee_method)

    def pooling(self, emb, emb_length, attention_mask, sent_pos_ids):
        """Pool token embeddings into one vector per sentence, then add sentence positions.

        Args:
            emb: (num_sents, seq_len, hidden_size) token embeddings.
            emb_length: unpadded lengths (unused; kept for interface compatibility).
            attention_mask: (num_sents, seq_len); 0 marks padding.
            sent_pos_ids: position of each sentence within its document.

        Raises:
            Exception: for an unsupported config['POOLING'].
        """
        pooling_type = self.config['POOLING']
        if pooling_type == 'max':
            # masked_fill is out-of-place. The previous code discarded its result,
            # so padding positions took part in the max.
            pad = (attention_mask == 0).unsqueeze(dim=-1)
            pooling_emb = emb.masked_fill(pad, float('-inf')).max(dim=1)[0]
        else:
            raise Exception('Not support pooling method: ' + pooling_type)
        return self.sent_pos_encoder(pooling_emb, sent_pos_ids)

    def greedy_dec(self, doc_sent_emb, doc_ner_pred, batch_train, train_flag=True):
        """Classify event types per document; at inference, fill roles greedily.

        Training: accumulates the event-type NLL per document and returns its
        mean over the batch, together with an empty result list.
        Inference: for each event type predicted present, emits a single record.
        Role i of that record holds the first drange tagged with that role (in
        collate_label's order), or None if no span has that role. Event types
        predicted absent get None.

        Returns:
            ``(event_cls_loss, doc_decode_res)``.
        """
        doc_decode_res = []
        event_cls_loss = 0
        for sent_emb, ner_pred, ins in zip(doc_sent_emb, doc_ner_pred, batch_train):
            # Log-probabilities (absent / present), one row per event type.
            event_type_score = torch.cat(
                [event_table(sent_context_emb=sent_emb) for event_table in self.event_tables], dim=0
            )

            if train_flag:
                event_type_label = torch.tensor(ins['event_cls'], device=event_type_score.device)
                event_cls_loss += F.nll_loss(event_type_score, event_type_label)
                continue

            span2drange = collate_label(ner_pred, ins['attention_mask'], ins['ids_list'])
            event_type_pred = torch.argmax(event_type_score, dim=-1).tolist()
            self.eval_obj['event_type_pred'].append(event_type_pred)
            if ins.get('event_cls') is not None:
                self.eval_obj['event_type_gt'].append(ins['event_cls'])

            event_types = self.config['EVENT_TYPES']
            event_fields = self.config['EVENT_FIELDS']
            ner_label_list = self.config['NER_LABEL_LIST']
            # Role name -> dranges tagged with it. Span tags are B- ids, so
            # stripping the 'B-' prefix gives the role name.
            label2drange = {}
            for dranges in span2drange.values():
                for drange, label_idx in dranges:
                    label2drange.setdefault(ner_label_list[label_idx][2:], []).append(drange)

            decode_res = []
            for event_idx, pred in enumerate(event_type_pred):
                if pred == 0:
                    decode_res.append(None)
                    continue
                fields = event_fields[event_types[event_idx]][0]
                # Greedy: one record per predicted type; each role takes its first span.
                decode_res.append([[label2drange[field][0] if field in label2drange else None for field in fields]])
            doc_decode_res.append(decode_res)
        event_cls_loss /= len(doc_sent_emb)  # mean over the documents in the batch
        return event_cls_loss, doc_decode_res

    def do_ner(self, batch_train, train_flag=True, use_gold=False):
        """Encode all sentence chunks of the batch, tag them, and split the results per document.

        Returns:
            ``(ner_loss, doc_ids_emb, doc_ner_pred, doc_sent_emb)``. The last
            three hold, per document: the token embeddings, the tag sequences
            (gold if `use_gold`, else predicted) and the pooled sentence
            embeddings. `ner_loss` is 0 unless `train_flag` is set.
        """
        device = self.basic_encoder.device
        input_ids, attention_mask, ids_length, sent_pos_ids = [], [], [], []
        ner_label = []
        # Gold tags are required for training. They are also gathered at inference
        # whenever every document carries them, so that eval_obj['ner_gt'] lines up
        # with eval_obj['ner_pred'] during evaluation as well.
        collect_gold = train_flag or all('labels_list' in ins for ins in batch_train)
        # The chunks of all documents are flattened into one encoder batch;
        # document i owns rows doc_beg_list[i]:doc_beg_list[i + 1].
        doc_beg_list = [0]
        for ins in batch_train:
            if collect_gold:
                ner_label.extend(ins['labels_list'])
            input_ids.extend(ins['ids_list'])
            attention_mask.extend(ins['attention_mask'])
            ids_length.extend(ins['ids_length'])
            sent_pos_ids.extend(range(len(ins['ids_list'])))
            doc_beg_list.append(len(input_ids))

        input_ids = torch.tensor(input_ids, device=device, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, device=device, dtype=torch.float)
        batch_emb = self.basic_encoder(input_ids, attention_mask=attention_mask)[0]
        ner_score = self.seq_labeler(batch_emb)
        pooling_emb = self.pooling(batch_emb, ids_length, attention_mask, sent_pos_ids)

        ner_loss = 0
        if collect_gold:
            self.eval_obj['ner_gt'].append(ner_label)
            ner_label = torch.tensor(ner_label, device=device, dtype=torch.long)
        if train_flag:
            if self.config['use_crf']:
                ner_loss = -self.crf(ner_score, ner_label, mask=attention_mask.to(torch.uint8), reduction='mean')
            else:
                # Padding positions carry label -1 and are ignored by the loss.
                ner_loss = F.cross_entropy(
                    ner_score.view(-1, len(self.config['NER_LABEL_LIST'])), ner_label.view(-1), ignore_index=-1
                )
        if use_gold:
            ner_pred = ner_label
        elif self.config['use_crf']:
            ner_pred = self.crf.decode(ner_score, mask=attention_mask.to(dtype=torch.uint8))
        else:
            ner_pred = torch.argmax(ner_score, dim=-1).tolist()
        self.eval_obj['ner_pred'].append(ner_pred)

        doc_ids_emb, doc_ner_pred, doc_sent_emb = [], [], []
        for doc_beg, doc_end in zip(doc_beg_list[:-1], doc_beg_list[1:]):
            doc_ids_emb.append(batch_emb[doc_beg: doc_end])
            doc_ner_pred.append(ner_pred[doc_beg: doc_end])
            doc_sent_emb.append(pooling_emb[doc_beg: doc_end])
        return ner_loss, doc_ids_emb, doc_ner_pred, doc_sent_emb
            
# Model configuration; the schema entries come from the schema cell.
config = {
    'use_crf': True,
    'SENT_BATCH_SIZE': 4,  # NOTE(review): not read by any code shown in this notebook
    'POOLING': 'max',
    'EVENT_FIELDS': EVENT_FIELDS,
    'EVENT_TYPES': EVENT_TYPES,
    'NER_LABEL_LIST': NER_LABEL_LIST,
    'NER_LABEL2ID': NER_LABEL2ID,
    'MAX_SENT_NUM': MAX_SENT_NUM
}
model = DocEE(config, bert)
# Use the GPU when present. `model.cuda()` used to fail outright on CPU-only
# machines. DocEE follows the encoder's device for its input tensors.
model.to('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
print('fin')

fin


In [10]:
EPOCH = 30
DOC_BATCH_SIZE = 2
EVAL_DOC_BATCH_SIZE = 2
# The training pass was disabled (commented out) in the recorded run. It is kept
# here as live code behind this flag; set it to True to train before each evaluation.
DO_TRAIN = False

# dev = part_train  # alternative: evaluate on the training split
BATCH_NUM = math.ceil(len(part_train) / DOC_BATCH_SIZE)
EVAL_BATCH_NUM = math.ceil(len(dev) / EVAL_DOC_BATCH_SIZE)

OUTPUT_DIR = 'output'
MODEL_SAVE_DIR = '%s/save_model' % OUTPUT_DIR  # NOTE(review): created but never written to
EVAL_SAVE_DIR = '%s/save_eval' % OUTPUT_DIR
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(EVAL_SAVE_DIR, exist_ok=True)

EVAL_JSON_FILE = os.path.join(EVAL_SAVE_DIR, 'eval-%d.json')
EVAL_OBJ_FILE = os.path.join(EVAL_SAVE_DIR, 'eval-obj-%d.pkl')
TEST_FILE = 'test-%d.txt'

for epoch in range(EPOCH):
    print('%d-----------------' % epoch)
    if DO_TRAIN:
        print_loss = 0
        random.shuffle(part_train)
        model.train()
        with tqdm(total=BATCH_NUM) as pbar:
            for batch_num in range(BATCH_NUM):
                batch_beg = batch_num * DOC_BATCH_SIZE
                batch_train = part_train[batch_beg: batch_beg + DOC_BATCH_SIZE]
                loss, _ = model(batch_train, train_flag=True, use_gold=True)
                loss.backward()
                optimizer.step()
                model.zero_grad()
                print_loss += loss.item()
                pbar.set_description('total_loss: %f' % (print_loss / (batch_num + 1)))
                pbar.update()

    # Dev evaluation (inference only, so no autograd graph is built).
    model.eval()
    model.init_eval_obj()
    total_decode_res = []
    with torch.no_grad(), tqdm(total=EVAL_BATCH_NUM) as pbar:
        for batch_num in range(EVAL_BATCH_NUM):
            batch_beg = batch_num * EVAL_DOC_BATCH_SIZE
            batch_dev = dev[batch_beg: batch_beg + EVAL_DOC_BATCH_SIZE]
            _, doc_decode_res = model(batch_dev, train_flag=False, use_gold=False)
            total_decode_res.extend(doc_decode_res)
            pbar.update()
    eval_json = measure_dee_prediction(total_decode_res, dev, EVENT_TYPES, EVENT_FIELDS, EVENT_TYPE_FIELDS_PAIRS)
    # `with` closes (and flushes) each file; the handles used to be left open.
    with open(EVAL_JSON_FILE % epoch, mode='w', encoding='utf-8') as eval_json_file:
        json.dump(eval_json, eval_json_file)
    with open(EVAL_OBJ_FILE % epoch, mode='wb') as eval_obj_file:
        pickle.dump(model.eval_obj, eval_obj_file)
    eval_save_test(test, tokenizer, TEST_FILE % epoch)
    print(eval_json[-1]['MicroF1'])

  1%|          | 1/99 [00:00<00:11,  8.70it/s]

0-----------------


100%|██████████| 99/99 [00:12<00:00,  7.86it/s]
 18%|█▊        | 65/370 [00:07<00:34,  8.81it/s]


KeyboardInterrupt: 

In [14]:
# Debug check: total of the tag ids in `labels`, which still holds the last
# document processed by the training-preprocessing loop.
sum(label_id for label_id in labels)

2892

In [None]:
%%javascript
// Calls the classic-Notebook front-end API to delete this notebook's session.
// NOTE(review): this presumably shuts the kernel down (e.g. to release GPU
// memory after a full run), so keep it as the very last cell.
Jupyter.notebook.session.delete();