In [2]:
import pandas as pd
import numpy as np
import json
import time
import math
import re
import os
import torch
from transformers import  AdamW
from transformers import BertTokenizer, BertConfig, BertModel, BertPreTrainedModel, BertForTokenClassification
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import get_linear_schedule_with_warmup
import ast

from ckiptagger import WS, POS, NER

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
train_df = pd.read_csv('EuSun_train.csv')
val_df = pd.read_csv('EuSun_valid.csv')
print(len(train_df),len(val_df))
train_df.head(2)

3893 977


Unnamed: 0.1,Unnamed: 0,news_ID,hyperlink,content,name,full_content
0,0,1,https://news.cnyes.com/news/id/4352432,0理財基金量化交易追求絕對報酬有效對抗牛熊市鉅亨網記者鄭心芸2019/07/05 22:35...,[],近年來投資市場波動越來越明顯，追求低波動、絕對報酬的量化交易備受注目。專家表示，採用量化交易...
1,1,2,https://udn.com/news/story/120775/4112519,10月13日晚間發生Uber Eats黃姓外送人員職災死亡案件 ### 省略內文 ### 北...,[],\r\r\n\r\r\n\r\r\n10月13日晚間發生Uber Eats黃姓外送人員職災死...


In [4]:
train_df['full_content'][3715]

'\r\r\n\r\r\n檢調偵辦國安局人員走私菸品案，陸續約談多名國安局人員，本週預計將會展開大規模動作，將遭民眾舉發的華航前資深副總經理羅雅美、前空品處副總邱彰信以及稱高層不知情的董事長兼總經理謝世謙列為貪汙罪被告，並盡速約談到案說明，揪出私菸案幕後「藏鏡人」。\xa0私菸案爆發後，檢調已陸續約談吳宗憲、張恒嘉等多名國安局人員，據中國時報報導，由於有民眾告發，羅雅美、邱彰信2人被列為貪汙罪他字案被告，最快本週就會約談到案，揪出私菸案幕後「藏鏡人」。\xa0另外，日前召開記者會代表華航說羅、邱2人對私菸案不知情的謝世謙，同樣因遭民眾告發涉案，也遭北檢簽分為他字案被告，預計將會是下波約談對象之一。最HOT話題在這！想跟上時事，快點我加入TVBS新聞LINE好友！\r\r\n'

In [5]:
train_df['name'][3715]

"['羅雅美', '邱彰信', '謝世謙', '吳宗憲', '張恒嘉']"

In [6]:
#ckip
# from ckiptagger import data_utils
# data_utils.download_data_gdown("./")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
ws = WS("./data", disable_cuda=False)
pos = POS("./data", disable_cuda=False)
ner = NER("./data", disable_cuda=False)
# 18 entity types: https://github.com/ckiplab/ckiptagger/wiki/Entity-Types 
# 0 for others, 1 for mlp names
entity_dict = {'GPE':2,'PERSON':3,'DATE':4,'ORG':5,'CARDINAL':6,
'NORP':7,'LOC':8,'TIME':9,'FAC':10,'MONEY':11,'ORDINAL':12,'EVENT':13,
'WORK_OF_ART':14,'QUANTITY':15,'PERCENT':16,'LANGUAGE':17,'PRODUCT':18,'LAW':19}

In [7]:
# text = train_df['full_content'][3715]
# ws_results = ws([text])
# pos_results = pos(ws_results)
# ner_results = ner(ws_results, pos_results)
# for n in ner_results[0]:
#     print(n)

In [8]:
class NewsDataset(Dataset):
    def __init__(self, mode, tokenizer, data):
        assert mode in ["train", "test"]  
        self.mode = mode
        self.data = data
        self.len = len(self.data)
        self.tokenizer = tokenizer  

    def __len__(self):
        return self.len
    
    def clean_text(self,text):
        text = text.replace('\r','')
        text = text.replace('\n','')
        return text
    
    def __getitem__(self, idx):
        content = self.data['full_content'][idx]
        text = self.clean_text(content)
        inputs = self.tokenizer.encode_plus(text=text, max_length=512, return_tensors='pt', 
                                            pad_to_max_length = True, 
                                            return_token_type_ids = True,
                                            return_attention_mask=True)
        
        input_ids = inputs['input_ids'].squeeze(0)
        segments_tensor = inputs['token_type_ids'].squeeze(0)
        masks_tensor = inputs['attention_mask'].squeeze(0)
        
        if self.mode == 'train':
            name_true = self.data['name'][idx] #list of 
            name_true = ast.literal_eval(name_true)
            labels = [0]*512
            
            #ckip
            ws_results = ws([text])
            pos_results = pos(ws_results)
            ner_results = ner(ws_results, pos_results)
            for n in ner_results[0]:
                label_start = int(n[0])
                label_end = int(n[1])
                if label_end <= 510:
                    if n[2] == 'PERSON' and n[3] in name_true:
                        label_num = 1
                    else:
                        label_num = entity_dict[n[2]]
                    labels[label_start+1:label_end+1] = [label_num]* len(n[3])
            return input_ids, torch.tensor(labels)
            
#             if not name_true:
#                 return input_ids, torch.tensor(labels)
#             else:
#                 #print(idx,name_true)
#                 for name in name_true:
#                     index = [i.start() for i in re.finditer(name, text)]
#                     for i in range(510-3):
#                         if i in index:
#                             labels[i+1:i+len(name)+1] = [1]*len(name)
#                             i = i+len(name)
#                 return input_ids, torch.tensor(labels)
        else:
            newsid = self.data['news_ID'][idx]
            return newsid, input_ids


In [9]:
class BertForTokenClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = 20 ###
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss()

            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                #print(input_ids.shape,logits.shape,sequence_output.shape)
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)


In [10]:
def asMinutes(s): #s = time.time()-start_time
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def trainIter(PRETRAINED_MODEL_NAME, trainloader, epochs, LR):
    config = BertConfig.from_pretrained(PRETRAINED_MODEL_NAME)
    model = BertForTokenClassification.from_pretrained(PRETRAINED_MODEL_NAME)
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.train().to(device)
    min_dev_loss, best_epoch = 1000000, 0
    start_time = time.time()
    optimizer = AdamW(model.parameters(),lr = LR)
    
    for epoch in range(epochs):
        total_steps = len(trainloader) * epochs
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps = total_steps)
        
        steps, train_loss = 0, 0
        for data in trainloader:
            input_ids, labels = [t.to(device) for t in data] 
            
            outputs = model(input_ids, labels=labels)
            loss = outputs[0]
            loss.backward()
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            
            train_loss += loss.item()
            steps += 1
            print(f'Epoch : {epoch+1}/{epochs}, setps:{steps}, time: {asMinutes(time.time()-start_time)}, Training Loss : {train_loss/steps}',  end = '\r')
        print('\n===========================================================')
        
        torch.save(model.state_dict(), f'bert_{epoch}') 

In [11]:
BS = 6
PRETRAINED_MODEL_NAME = "bert-base-chinese"
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME, do_lower_case=True)
train = NewsDataset('train',tokenizer, train_df)
trainloader = torch.utils.data.DataLoader(train, batch_size=BS,drop_last = True,shuffle=True)

In [10]:
# for data in trainloader:
#     input_ids, labels = [t for t in data] 
#     #print(input_ids)
#     for l in labels:
#         print(l)
#     #print(torch.nonzero(labels[0]))
#     break

In [None]:
trainIter(PRETRAINED_MODEL_NAME, trainloader,epochs=10, LR=5e-5)

Epoch : 1/10 ,setps:257,time: 70m 23s, Training Loss : 0.8921828425348037

In [120]:
def evaluation(tokenizer, test_loader,model):
    prediction = pd.DataFrame(columns=['news_ID','name_pred'])
    with torch.no_grad():
        ids, pred = [], []
        for data in test_loader: 
            newsid, input_ids = [t.to(device) for t in data] 
            outputs = model(input_ids)
            ans_pred = torch.sigmoid(outputs[0])
            _, ans_pred = ans_pred.max(-1)
            #print(ans_pred)
            ans_pred = list(ans_pred[0])
            ans_index = [i for i,x in enumerate(ans_pred) if x == 1]
            #print(ans_index)
            text = [tokenizer.convert_ids_to_tokens(i) for i in input_ids]
            text = list(text[0])
            ans_tokens = [text[x] for i,x in enumerate(ans_index)]
            #print(ans_tokens)
            names, answer = [], []
            if ans_tokens != []:
                ws_results = ws([text])
                pos_results = pos(ws_results)
                ner_results = ner(ws_results, pos_results)
                for n in ner_results[0]:
                    if n[2] == 'PERSON':
                        names.append(n[3])
            #print(newsid,names)
            for t in ans_tokens:
                for n in names:
                    if t in n and n not in answer and len(n) > 2:
                        if len(n) < 4:
                            answer.append(n)
            ids.append(newsid)
            pred.append(answer)
            print(f'news id : {newsid.item()}, names : {answer}', end = '\r')
        prediction['news_ID'] = ids
        prediction['name_pred'] = pred
    return prediction
        
def get_score(name_preds, true_preds):
    total_score=0
    no_names = 0
    for i in range(len(name_preds)):
        name_pred = name_preds[i]
        true_pred = true_preds[i]
        if name_pred == [] and true_pred == []:
            total_score += 1
            no_names += 1
        elif name_pred != [] and true_pred != []:
            #print(name_pred,true_pred, set(name_pred) & set(true_pred))
            intersection = list(set(name_pred) & set(true_pred))
            recall = len(intersection)/len(true_pred)
            precision = len(intersection)/len(name_pred)
            if recall != 0 and precision != 0:
                f1 = 2/(1/recall+1/precision) 
            else:
                f1 = 0
            total_score += f1
    return no_names, total_score

In [116]:
config = BertConfig.from_pretrained(PRETRAINED_MODEL_NAME)
model = BertForTokenClassification(config)       
model = model.cuda()
checkpoint = torch.load('bert_3')
model.load_state_dict(checkpoint)
model = model.eval()
BS=1
val = NewsDataset('test',tokenizer, val_df)
valloader = torch.utils.data.DataLoader(val, batch_size=BS,shuffle=False)
prediction = evaluation(tokenizer, valloader,model)

news id : 4953, names : ['許玉秀', '王隆昌'] '廖麗櫻'] '吳清吉', '蔡清華', '連定安', '鍾葦怡']

In [121]:
true_pred = [ast.literal_eval(i) for i in val_df['name']]
name_pred = prediction['name_pred'].to_list()
no_names, total_score = get_score(name_pred,true_pred)
print('num of empty names:', no_names)
print('total score:', total_score)

num of empty names: 907
total score: 947.8582251082247


(33, 35, 'DATE', '本週')
(63, 66, 'PERSON', '羅雅美')
(73, 76, 'PERSON', '邱彰信')
(236, 237, 'PERSON', '邱')
(68, 71, 'ORG', '空品處')
(173, 176, 'PERSON', '羅雅美')
(10, 13, 'ORG', '國安局')
(158, 162, 'WORK_OF_ART', '中國時報')
(196, 198, 'DATE', '本週')
(222, 224, 'DATE', '日前')
(234, 235, 'PERSON', '羅')
(237, 238, 'CARDINAL', '2')
(177, 180, 'PERSON', '邱彰信')
(231, 233, 'ORG', '華航')
(92, 95, 'PERSON', '謝世謙')
(151, 154, 'ORG', '國安局')
(309, 313, 'ORG', 'TVBS')
(145, 148, 'PERSON', '張恒嘉')
(247, 250, 'PERSON', '謝世謙')
(54, 56, 'ORG', '華航')
(141, 144, 'PERSON', '吳宗憲')
(27, 30, 'ORG', '國安局')
(180, 181, 'CARDINAL', '2')
