In [1]:
import json
import sys
sys.path.append('../')

import numpy as np
from KAIST_frame_parser.src import dataio, etri
from KAIST_frame_parser.src import targetid
import torch
from torch import nn
from torch.optim import Adam
import glob
import os
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from keras.preprocessing.sequence import pad_sequences
from pytorch_pretrained_bert import BertTokenizer, BertConfig, BertModel
from tqdm import tqdm, trange
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
n_gpu = torch.cuda.device_count()

from KAIST_frame_parser.koreanframenet.src import conll2textae
from KAIST_frame_parser.koreanframenet import koreanframenet

from KAIST_frame_parser.src.fn_modeling import BertForJointFrameParsing

from sklearn.metrics import accuracy_score
from seqeval.metrics import f1_score

from pprint import pprint

from datetime import datetime
start_time = datetime.now()

Using TensorFlow backend.


Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.
### Korean FrameNet ###
	# contact: hahmyg@kaist, hahmyg@gmail.com #



In [2]:
MAX_LEN = 256
batch_size = 6

In [3]:
# language = 'en'
# version = 1.7
# version = 1.5

language = 'ko'
version = 1.1
if language == 'en':
    framenet = 'fn'+str(version)
    fn_dir = '/disk_4/resource/fn'+str(version)
    trn_d, dev_d, tst_d = dataio.load_fn_data(fn_dir)
elif language == 'ko':
    framenet = 'kfn'+str(version)
    kfn = koreanframenet.interface(version=version)
    trn_d, dev_d, tst_d = kfn.load_data()
    
try:
    dir_path = os.path.dirname(os.path.abspath( __file__ ))
except:
    dir_path = '.'
    
# save your model to
model_dir = '/disk_4/resource/models/'

if language == 'en':
    print('### loading English FrameNet', str(version), 'data...')
    print('\t# of instances in training data:',len(trn_d))
    print('\t# of instances in dev data:',len(dev_d))
    print('\t# of instances in test data:',len(tst_d))


### loading Korean FrameNet 1.1 data...
	# of instances in training data: 17838
	# of instances in dev data: 2548
	# of instances in test data: 5097


In [4]:
print(trn_d[0])

[['태풍', 'Hugo가', '남긴', '피해들과', '회사', '내', '몇몇', '주요', '부서들의', '저조한', '실적들을', '반영하여,', 'Aetna', 'Life', 'and', 'Casualty', 'Co.의', '3분기', '순이익이', '182.6', '백만', '달러', '또는', '주당', '1.63', '달러로', '22', '%', '하락하였다.'], ['_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '이익.n', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_'], ['_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', 'Earnings_and_losses', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_'], ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Earner', 'I-Earner', 'I-Earner', 'I-Earner', 'I-Earner', 'B-Time', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']]


In [4]:
def data2tgt_data(input_data):
    result = []
    for item in input_data:
        ori_tokens, ori_lus, ori_frames, ori_args = item[0],item[1],item[2],item[3]
        for idx in range(len(ori_lus)):
            lu = ori_lus[idx]
            if lu != '_':
                if idx == 0:
                    begin = idx
                elif ori_lus[idx-1] == '_':
                    begin = idx
                end = idx
        tokens, lus, frames, args = [],[],[],[]
        for idx in range(len(ori_lus)):
            token = ori_tokens[idx]
            lu = ori_lus[idx]
            frame = ori_frames[idx]
            arg = ori_args[idx]
            if idx == begin:
                tokens.append('<tgt>')
                lus.append('_')
                frames.append('_')
                args.append('O')
                
            tokens.append(token)
            lus.append(lu)
            frames.append(frame)
            args.append(arg)
            
            if idx == end:
                tokens.append('</tgt>')
                lus.append('_')
                frames.append('_')
                args.append('O')
        sent = []
        sent.append(tokens)
        sent.append(lus)
        sent.append(frames)
        sent.append(args)
        result.append(sent)
    return result 
    
trn = data2tgt_data(trn_d)
dev = data2tgt_data(dev_d)
tst = data2tgt_data(tst_d)

print('### an example of training data')
print('[')
for i in trn[0]:
    print('\t',i)
print(']')

### an example of training data
[
	 ['Paula_Zahn', ':', 'Questions', 'about', 'the', 'facts', 'or', 'what', 'were', 'presented', 'as', 'facts', 'that', 'led', 'the', 'United', 'States', 'into', 'the', 'war', 'in', 'Iraq', '<tgt>', 'spilled', '</tgt>', 'into', 'open', 'warfare', 'today', 'on', 'the', 'Senate', 'floor', '.']
	 ['_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', 'spill.v', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_']
	 ['_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_', 'Fluidic_motion', '_', '_', '_', '_', '_', '_', '_', '_', '_', '_']
	 ['O', 'O', 'B-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'I-Fluid', 'O', 'O', 'O', 'B-Goal', 'I-Goal', 'I-Goal', 'B-Time', 'B-Place', 'I-Place', 

In [11]:
class for_BERT():
    
    def __init__(self, mode='training', language='ko', version=1.0):
        version = str(version)
        self.mode = mode
        if language == 'en':
            data_path = dir_path+'/koreanframenet/resource/info/fn'+version+'_'
        else:
            data_path = dir_path+'/koreanframenet/resource/info/kfn'+version+'_'
        with open(data_path+'lu2idx.json','r') as f:
            self.lu2idx = json.load(f)
        if version == '1.5':
            fname = dir_path+'/koreanframenet/resource/info/fn1.5_frame2idx.json'
        else:
            fname = dir_path+'/koreanframenet/resource/info/fn1.7_frame2idx.json'
        with open(fname,'r') as f:
            #self.sense2idx = json.load(f)
            self.frame2idx = json.load(f)
        with open(data_path+'lufrmap.json','r') as f:
            #self.lusensemap = json.load(f)
            self.lufrmap = json.load(f)
        with open(dir_path+'/koreanframenet/resource/info/fn1.7_fe2idx.json','r') as f:
            self.arg2idx = json.load(f)
        with open(dir_path+'/koreanframenet/resource/info/fn1.7_frargmap.json','r') as f:
            self.frargmap = json.load(f)
        with open(dir_path+'/koreanframenet/resource/info/fn1.7_bio_fe2idx.json','r') as f:
            self.bio_arg2idx = json.load(f)
        with open(dir_path+'/koreanframenet/resource/info/fn1.7_bio_frargmap.json','r') as f:
            self.bio_frargmap = json.load(f)

        self.idx2frame = dict(zip(self.frame2idx.values(),self.frame2idx.keys()))
        self.idx2lu = dict(zip(self.lu2idx.values(),self.lu2idx.keys()))
        self.idx2arg = dict(zip(self.arg2idx.values(),self.arg2idx.keys()))
        self.idx2bio_arg = dict(zip(self.bio_arg2idx.values(),self.bio_arg2idx.keys()))

        # load pretrained BERT tokenizer
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased', do_lower_case=False)
        
        # load BERT tokenizer with untokenizing frames
        never_split_tuple = ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")
        added_never_split = []
        added_never_split.append('<tgt>')
        added_never_split.append('</tgt>')
        added_never_split_tuple = tuple(added_never_split)
        never_split_tuple += added_never_split_tuple
        vocab_file_path = dir_path+'/data/bert-multilingual-cased-dict-add-frames'
        self.tokenizer_with_frame = BertTokenizer(vocab_file_path, do_lower_case=False, max_len=256, never_split=never_split_tuple)

    def idx2tag(self, predictions, model='frameid'):
        if model == 'frameid':
            pred_tags = [self.idx2frame[p_i] for p in predictions for p_i in p]
        elif model == 'argclassification':
            pred_tags = [self.idx2arg[p_i] for p in predictions for p_i in p]
        elif model == 'argid':
            pred_tags = [self.idx2bio_arg[p_i] for p in predictions for p_i in p]
        return pred_tags
    
    def get_masks(self, datas, model='frameid'):
        if model == 'frameid':
            mapdata = self.lufrmap
            num_label = len(self.frame2idx)
        elif model == 'argclassification':
            mapdata = self.frargmap
            num_label = len(self.arg2idx)
        elif model == 'argid':
            mapdata = self.bio_frargmap
            num_label = len(self.bio_arg2idx)
        masks = []
        for idx in datas:
            mask = torch.zeros(num_label)
            try:
                candis = mapdata[str(int(idx[0]))]
            except KeyboardInterrupt:
                raise
            except:
                candis = mapdata[int(idx[0])]
            for candi_idx in candis:
                mask[candi_idx] = 1
            masks.append(mask)
        masks = torch.stack(masks)
        return masks    
    
    # bert tokenizer and assign to the first token
    def bert_tokenizer(self, text):
        orig_tokens = text.split(' ')
        bert_tokens = []
        orig_to_tok_map = []
        bert_tokens.append("[CLS]")
        for orig_token in orig_tokens:
            orig_to_tok_map.append(len(bert_tokens))
            bert_tokens.extend(self.tokenizer_with_frame.tokenize(orig_token))
        bert_tokens.append("[SEP]")

        return orig_tokens, bert_tokens, orig_to_tok_map
    
    def convert_to_bert_input_JointFrameParsing(self, input_data):
        tokenized_texts, lus, frames, args = [],[],[],[]
        orig_tok_to_maps = []
        for i in range(len(input_data)):    
            data = input_data[i]
            text = ' '.join(data[0])
            orig_tokens, bert_tokens, orig_to_tok_map = self.bert_tokenizer(text)
            orig_tok_to_maps.append(orig_to_tok_map)
            tokenized_texts.append(bert_tokens)

            ori_lus = data[1]    
            lu_sequence = []
            for i in range(len(bert_tokens)):
                if i in orig_to_tok_map:
                    idx = orig_to_tok_map.index(i)
                    l = ori_lus[idx]
                    lu_sequence.append(l)
                else:
                    lu_sequence.append('_')
            lus.append(lu_sequence)        

            if self.mode == 'training':
                ori_frames, ori_args = data[2], data[3]
                frame_sequence, arg_sequence = [],[]
                for i in range(len(bert_tokens)):
                    if i in orig_to_tok_map:
                        idx = orig_to_tok_map.index(i)
                        fr = ori_frames[idx]
                        frame_sequence.append(fr)
                        ar = ori_args[idx]
                        arg_sequence.append(ar)
                    else:
                        frame_sequence.append('_')
                        arg_sequence.append('X')
                frames.append(frame_sequence)
                args.append(arg_sequence)

        input_ids = pad_sequences([self.tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],
                              maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
        orig_tok_to_maps = pad_sequences(orig_tok_to_maps, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post", value=-1)
        
        if self.mode =='training':
            arg_ids = pad_sequences([[self.bio_arg2idx.get(ar) for ar in arg] for arg in args],
                                    maxlen=MAX_LEN, value=self.bio_arg2idx["X"], padding="post",
                                    dtype="long", truncating="post")

        lu_seq, frame_seq = [],[]
        for sent_idx in range(len(lus)):
            lu_items = lus[sent_idx]
            lu = []
            for idx in range(len(lu_items)):
                if lu_items[idx] != '_':
                    if len(lu) == 0:
                        lu.append(self.lu2idx[lu_items[idx]])
            lu_seq.append(lu)
            
            if self.mode == 'training':
                frame_items, arg_items = frames[sent_idx], args[sent_idx]
                frame= []
                for idx in range(len(frame_items)):
                    if frame_items[idx] != '_':
                        if len(frame) == 0:
                            frame.append(self.frame2idx[frame_items[idx]])
                frame_seq.append(frame)

        attention_masks = [[float(i>0) for i in ii] for ii in input_ids]    
        data_inputs = torch.tensor(input_ids)
        data_orig_tok_to_maps = torch.tensor(orig_tok_to_maps)
        data_lus = torch.tensor(lu_seq)
        data_masks = torch.tensor(attention_masks)
        
        if self.mode == 'training':
            data_frames = torch.tensor(frame_seq)
            data_args = torch.tensor(arg_ids)
            bert_inputs = TensorDataset(data_inputs, data_orig_tok_to_maps, data_lus, data_frames, data_args, data_masks)
        else:
            bert_inputs = TensorDataset(data_inputs, data_orig_tok_to_maps, data_lus, data_masks)
        return bert_inputs

In [12]:
bert_io = for_BERT(mode='training', language=language, version=version)

In [13]:
trn_data = bert_io.convert_to_bert_input_JointFrameParsing(trn)

In [7]:
def train():
    model_path = model_dir+framenet+'/joint/'
    print('your model would be saved at', model_path)
    
    model = BertForJointFrameParsing.from_pretrained("bert-base-multilingual-cased", 
                                                     num_frames = len(bert_io.frame2idx), 
                                                     num_args = len(bert_io.bio_arg2idx),
                                                     lufrmap=bert_io.lufrmap, frargmap = bert_io.bio_frargmap)
    model.to(device);
    
    trn_data = bert_io.convert_to_bert_input_JointFrameParsing(trn)
    sampler = RandomSampler(trn)
    trn_dataloader = DataLoader(trn_data, sampler=sampler, batch_size=batch_size)
    
    # load optimizer
    FULL_FINETUNING = True
    if FULL_FINETUNING:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.0}
        ]
    else:
        param_optimizer = list(model.classifier.named_parameters()) 
        optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]
    optimizer = Adam(optimizer_grouped_parameters, lr=3e-5)
    
    
    # train 
    epochs = 50
    max_grad_norm = 1.0
    num_of_epoch = 0
    accuracy_result = []
    for _ in trange(epochs, desc="Epoch"):
        # TRAIN loop
        model.train()
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(trn_dataloader):
            # add batch to gpu
            batch = tuple(t.to(device) for t in batch)
            b_input_ids, b_input_orig_tok_to_maps, b_input_lus, b_input_frames, b_input_args, b_input_masks = batch            
            # forward pass
            loss = model(b_input_ids, token_type_ids=None, lus=b_input_lus, frames=b_input_frames, args=b_input_args,
                         attention_mask=b_input_masks)
            # backward pass
            loss.backward()
            # track train loss
            tr_loss += loss.item()
            nb_tr_examples += b_input_ids.size(0)
            nb_tr_steps += 1
            # gradient clipping
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
            # update parameters
            optimizer.step()
            model.zero_grad()
#             break
#         break

        # print train loss per epoch
        print("Train loss: {}".format(tr_loss/nb_tr_steps))
        model_saved_path = model_path+'epoch-'+str(num_of_epoch)+'-joint.pt'        
        torch.save(model, model_saved_path)
        num_of_epoch += 1
    print('...training is done')

In [8]:
train()

In [9]:
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=2).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

## TODO (0514) 
def logit2label(logit, mask):
    masking = np.multiply(logit, mask)
    masking[masking==0] = np.NINF
    sm = nn.Softmax()
    pred_logits = sm(masking).view(1,-1)
    score, label = pred_logits.max(1)
    score = float(score)
    return label, score
# def logit2label(logit, mask):
#     pred_idxs, pred_logits = [],[]
#     for idx in range(len(mask)):
#         if mask[idx] > 0:
#             pred_idxs.append(idx)
#             pred_logits.append(logit[idx].item())
#     pred_idxs = torch.tensor(pred_idxs)
#     pred_logits = torch.tensor(pred_logits)
#     sm = nn.Softmax()
#     pred_logits = sm(pred_logits).view(1,-1)
#     score, indice = pred_logits.max(1)
#     label = pred_idxs[indice]
#     score = float(score)
#     return label, score
    

def test():
    model_dir
    model_path = model_dir+framenet+'/joint/'
    result_path = model_path+'result/'
    models = glob.glob(model_path+'*.pt')
    results = []
    for m in models:
        print('model:', m)
        model = torch.load(m)
        model.eval()

        tst_data = bert_io.convert_to_bert_input_JointFrameParsing(tst)
        sampler = RandomSampler(tst)
        tst_dataloader = DataLoader(tst_data, sampler=sampler, batch_size=batch_size)

        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        
        pred_frames, true_frames, pred_args, true_args = [],[],[],[]
        for batch in tst_dataloader:
            batch = tuple(t.to(device) for t in batch)
            b_input_ids, b_orig_tok_to_maps, b_lus, b_frames, b_args, b_masks = batch

            with torch.no_grad():
                tmp_eval_loss = model(b_input_ids, token_type_ids=None, 
                                     lus=b_lus, attention_mask=b_masks)
                frame_logits, arg_logits = model(b_input_ids, token_type_ids=None, 
                                lus=b_lus, attention_mask=b_masks)
            frame_logits = frame_logits.detach().cpu().numpy()
            arg_logits = arg_logits.detach().cpu().numpy()
            
            gold_frame_ids = b_frames.to('cpu').numpy()
            gold_arg_ids = b_args.to('cpu').numpy()
            input_ids = b_input_ids.to('cpu').numpy()
            lufr_masks = dataio.get_masks(b_lus, bert_io.lufrmap, num_label=len(bert_io.frame2idx)).to(device)
            
            for b_idx in range(len(frame_logits)):
                input_id = input_ids[b_idx]
                frame_logit = frame_logits[b_idx]
                arg_logit = arg_logits[b_idx]
                lufr_mask = lufr_masks[b_idx]
                orig_tok_to_map = b_orig_tok_to_maps[b_idx]
                
                pred_frame, frame_score = logit2label(frame_logit, lufr_mask)
                frarg_mask = dataio.get_masks([pred_frame], bert_io.bio_frargmap, num_label=len(bert_io.bio_arg2idx)).to(device)[0]

                pred_arg_bert = []
                for logit in arg_logit:
                    label, score = logit2label(logit, frarg_mask)
                    pred_arg_bert.append(int(label))
                 
                #infer
                pred_arg,true_arg = [],[]
                for idx in orig_tok_to_map:
                    if idx != -1:
                        tok_id = int(input_id[idx])
                        if tok_id == 1:
                            pass
                        elif tok_id == 2:
                            pass
                        else:
                            pred_arg.append(pred_arg_bert[idx])
                            true_arg.append(gold_arg_ids[b_idx][idx])
                
                pred_frames.append([int(pred_frame)])
                pred_args.append(pred_arg)
                true_args.append(true_arg)
            true_frames.append(gold_frame_ids)
            
#             break
#         break

        pred_frame_tags = [bert_io.idx2frame[p_i] for p in pred_frames for p_i in p]
        valid_frame_tags = [bert_io.idx2frame[l_ii] for l in true_frames for l_i in l for l_ii in l_i]
        
        pred_arg_tags = [[bert_io.idx2bio_arg[p_i] for p_i in p] for p in pred_args]
        valid_arg_tags = [[bert_io.idx2bio_arg[v_i] for v_i in v] for v in true_args]

        acc = accuracy_score(pred_frame_tags, valid_frame_tags)
        f1 = f1_score(pred_arg_tags, valid_arg_tags)
        print("FrameId Accuracy: {}".format(accuracy_score(pred_frame_tags, valid_frame_tags)))
        print("ArgId F1: {}".format(f1_score(pred_arg_tags, valid_arg_tags)))
        
        result = m+'\tframeid:'+str(acc)+'\targid:'+str(f1)+'\n'
        results.append(result)
        
        epoch = m.split('-')[1]
        fname = model_path+str(epoch)+'-result.txt'
        with open(fname, 'w') as f:
            line = result
            f.write(line)
            line = 'gold'+'\t'+'pred'+'\n'
            f.write(line)
            for r in range(len(pred_frame_tags)):
                line = valid_frame_tags[r] + '\t' + pred_frame_tags[r]+'\n'
                f.write(line)
                line = str(valid_arg_tags[r]) + '\t' + str(pred_arg_tags[r])+'\n'
                f.write(line)
    fname = model_path+'result.txt'
    with open(fname, 'w') as f:
        for r in results:
            f.write(r)

    print('result is written to', fname)


In [10]:
test()

model: /disk_4/resource/models/fn1.7/joint/epoch-0-joint.pt


  app.launch_new_instance()


FrameId Accuracy: 1.0
ArgId F1: 0.4
model: /disk_4/resource/models/fn1.7/joint/epoch-1-joint.pt
FrameId Accuracy: 0.8333333333333334
ArgId F1: 0.09999999999999999
model: /disk_4/resource/models/fn1.7/joint/epoch-2-joint.pt
FrameId Accuracy: 1.0
ArgId F1: 0.27272727272727276
result is written to /disk_4/resource/models/fn1.7/joint/result.txt


In [None]:
end_time = datetime.now()
ptime = end_time - start_time
print('...all process is done during', str(ptime))