In [1]:
import json
from src import prototypical_batch_sampler
from collections import Counter, OrderedDict
from kaiser.src import dataio
from kaiser.src import utils
import torch
from torch import nn
from torch.optim import Adam

import sys
import random

# Put the parent and grandparent directories (relative to the working directory)
# at the front of the module search path.
# NOTE(review): these inserts run after the `src` / `kaiser.src` imports above, so
# they can only affect the imports that follow — confirm the intended order.
sys.path.insert(0,'../')
sys.path.insert(0,'../../')

from kaiser.src.modeling import BertForJointShallowSemanticParsing, FrameBERT
from pprint import pprint

from kaiser.src.prototypical_loss import prototypical_loss as loss_fn

# Select the compute device: the first CUDA GPU when one is available, else the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
# Compare on `device.type`: a torch.device does not compare equal to a plain string,
# so the previous `device != "cpu"` test was true even on CPU-only hosts and
# torch.cuda.set_device was called without a GPU.
if device.type == "cuda":
    torch.cuda.set_device(device)

### Korean FrameNet ###
	# contact: hahmyg@kaist, hahmyg@gmail.com #



Using TensorFlow backend.


In [2]:
# Helper functions for measuring execution time
import time

# Reference point of the stopwatch; reset by tic() and read by tac().
# A monotonic clock is used so that system clock adjustments cannot distort
# (or make negative) the measured elapsed time.
_start_time = time.monotonic()

def tic():
    """Start (or restart) the stopwatch that tac() reads."""
    global _start_time
    _start_time = time.monotonic()

def tac():
    """Return the time elapsed since the last tic() (or since this cell ran).

    Returns:
        A string of the form '{h}hour:{m}min:{s}sec', with the elapsed time
        rounded to whole seconds before being split into hours/minutes/seconds.
    """
    elapsed = round(time.monotonic() - _start_time)
    minutes, seconds = divmod(elapsed, 60)
    hours, minutes = divmod(minutes, 60)

    return '{}hour:{}min:{}sec'.format(hours, minutes, seconds)

In [3]:
# Build the project's BERT input helper with mode='train' and language='multi';
# the later cells call its convert_to_bert_input_* methods and read its label maps.
bert_io = utils.for_BERT(mode='train', language='multi')

used dictionary:
	 /disk/kaiser/kaiser/src/../koreanframenet/resource/info/mul_lu2idx.json
	 /disk/kaiser/kaiser/src/../koreanframenet/resource/info/mul_lufrmap.json
	 /disk/kaiser/kaiser/src/../koreanframenet/resource/info/mul_bio_frargmap.json


In [4]:
tic()
# Load the train/dev/test splits (srl='framenet', language='en', exem=False).
trn, dev, tst = dataio.load_data(srl='framenet', language='en', exem=False)

# Subsample with a dedicated, seeded generator so that Restart & Run All draws the
# same instances every time, without touching the global `random` state.
# `k` is clamped because random.sample raises ValueError when k exceeds the population.
_subsample_rng = random.Random(42)
trn = _subsample_rng.sample(trn, k=min(500, len(trn)))
tst = _subsample_rng.sample(tst, k=min(100, len(tst)))

# Convert the sampled instances into the model's input format.
trn_data = bert_io.convert_to_bert_input_JointShallowSemanticParsing(trn)
tst_data = bert_io.convert_to_bert_input_JointShallowSemanticParsing(tst)
# Last expression of the cell: displays the elapsed time since tic().
tac()

# of instances in trn: 19391
# of instances in dev: 2272
# of instances in tst: 6714
data example: [['Greece', 'wildfires', 'force', 'thousands', 'to', '<tgt>', 'evacuate', '</tgt>'], ['_', '_', '_', '_', '_', '_', 'evacuate.v', '_'], ['_', '_', '_', '_', '_', '_', 'Escaping', '_'], ['O', 'O', 'O', 'B-Escapee', 'O', 'X', 'O', 'X']]


'0hour:0min:3sec'

In [5]:
# Load the frame -> index mapping and the frame definitions.
# UTF-8 is given explicitly so decoding does not depend on the platform's
# default text encoding.
with open('./koreanframenet/resource/info/fn1.7_frame2idx.json', 'r', encoding='utf-8') as f:
    frame2idx = json.load(f)
with open('./koreanframenet/resource/info/fn1.7_frame_definitions.json', 'r', encoding='utf-8') as f:
    frame2definition = json.load(f)

# Convert the frame definitions into model inputs (def_data) and their labels (def_y).
def_data, def_y = bert_io.convert_to_bert_input_label_definition(frame2definition, frame2idx)

In [6]:
def get_y(data):
    """Return the frame index of every instance in ``data`` as a tuple.

    For each instance, the frame is the first entry of ``instance[2]`` that is
    not the placeholder ``'_'``; it is looked up in the module-level
    ``frame2idx`` mapping. When every entry is ``'_'`` the lookup key falls
    back to ``False``.
    """
    return tuple(
        frame2idx[next((label for label in instance[2] if label != '_'), False)]
        for instance in data
    )

# Reload the frame -> index mapping used by get_y; UTF-8 is given explicitly so
# decoding does not depend on the platform's default text encoding.
with open('./koreanframenet/resource/info/fn1.7_frame2idx.json', 'r', encoding='utf-8') as f:
    frame2idx = json.load(f)


# Count how many sampled training instances carry each frame index, then keep
# the frame indices that occur at least twice.
all_y = dict(Counter(get_y(trn)))
target_frames = [frame_idx for frame_idx, count in all_y.items() if count >= 2]

In [7]:
# Episode sampler. Going by the keyword names: 4 classes per training episode,
# 2 per validation episode, and 2 support examples per class in both modes.
# It is restricted to `target_frames` and is also given the frame-definition
# inputs and labels (def_data, def_y).
d = prototypical_batch_sampler.PrototypicalBatchSampler(classes_per_it_tr=4, classes_per_it_val=2,
                                                        num_support_tr=2, num_support_val=2, 
                                                        target_frames=target_frames, 
                                                        def_data=def_data, def_y=def_y)

used dictionary:
	 /disk/kaiser/kaiser/src/../koreanframenet/resource/info/mul_lu2idx.json
	 /disk/kaiser/kaiser/src/../koreanframenet/resource/info/mul_lufrmap.json
	 /disk/kaiser/kaiser/src/../koreanframenet/resource/info/mul_bio_frargmap.json


In [8]:
# Labels of the sampled training instances, computed by the sampler's own get_y
# (presumably one frame index per instance — confirm against the sampler).
trn_y = d.get_y(trn)
# Test-set labels are not computed in this notebook (left disabled).
# tst_y = d.get_y(tst)

In [9]:
# Build the training episodes from the converted inputs and their labels;
# the result is iterated episode by episode in the training cell below.
trn_batch = d.gen_batch(trn_data, trn_y, mode='train')
# Test episodes are not generated in this notebook (left disabled).
# tst_batch = d.gen_batch(tst, tst_y, mode='test')

In [10]:
# Directory holding the pretrained English FrameBERT model.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
frameBERT_dir = '/disk/data/models/frameBERT/frameBERT_en'

# Load FrameBERT; the label-space sizes and the LU->frame / frame->argument maps
# are taken from bert_io.
frameBERT = FrameBERT.from_pretrained(frameBERT_dir,
                                      num_senses = len(bert_io.sense2idx), 
                                      num_args = len(bert_io.bio_arg2idx),
                                      lufrmap=bert_io.lufrmap, 
                                      frargmap = bert_io.bio_frargmap)

# Move the model to the selected device and switch it to evaluation mode;
# it is only called under torch.no_grad() in the episode loop.
frameBERT.to(device)
frameBERT.eval()
# Prints an empty line; presumably here so the cell does not echo the model repr
# returned by eval() — confirm.
print('')




In [11]:
class MLP(nn.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear) applied to flattened inputs.

    Args:
        input_dim: number of features per flattened input row (default 768).
        hidden_dim: width of the hidden layer (default 768).
        output_dim: number of features per output row (default 768).

    With the defaults the network is identical to the original fixed 768-wide one.
    """
    def __init__(self, input_dim=768, hidden_dim=768, output_dim=768):
        super(MLP, self).__init__()
        # Attribute name and layer order are unchanged, so state_dict keys
        # (layers.0.*, layers.2.*) stay the same.
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        """Flatten everything after the first dimension and apply the layers.

        Args:
            x: tensor whose first dimension indexes the rows; the remaining
               dimensions must multiply to ``input_dim``.

        Returns:
            Tensor of shape (x.size(0), output_dim).
        """
        # reshape gives the same result as view for contiguous input, but also
        # accepts non-contiguous tensors (e.g. after transpose), where view raises.
        x = x.reshape(x.size(0), -1)
        return self.layers(x)
    
# Instantiate the MLP, move it to the selected device and put it in training mode.
mlp_model = MLP()
mlp_model.to(device)
mlp_model.train()
# Prints an empty line; presumably here so the cell does not echo the module repr
# returned by train() — confirm.
print('')




In [12]:
# Run one episode: embed the support and query examples with FrameBERT (no gradients),
# pass the embeddings through the MLP and score them with the prototypical loss.
# Only the first episode is processed: the loop body ends with `break`.
# NOTE(review): there is no backward()/optimizer step in this cell, so the MLP
# weights are not updated here.
for episode in trn_batch:
    support_embs = []
    query_embs = []
    
    support_y, query_y = [],[]
    
    # Each element of the episode unpacks into (support examples, query examples).
    for class_indice in episode:
        support_examples, query_examples = class_indice

        # Only the first query example is used; its features unpack into four
        # fields, of which the second is ignored.
        query_inputs, _, query_token_type_ids, query_masks = query_examples[0][0]
        # Reshape to (1, len) and move to the device
        # (assumes these are 1-D tensors — TODO confirm).
        query_inputs = query_inputs.view(1,len(query_inputs)).to(device)
        query_token_type_ids = query_token_type_ids.view(1,len(query_token_type_ids)).to(device)
        query_masks = query_masks.view(1,len(query_masks)).to(device)
        
        # The second element of the example is recorded as the query label.
        query_frame = query_examples[0][1]
        query_y.append(query_frame)
    
        # FrameBERT runs without gradient tracking; only its second output is kept.
        with torch.no_grad():
            _, query_emb = frameBERT(query_inputs, 
                                  token_type_ids=query_token_type_ids, 
                                  attention_mask=query_masks)
            # Flatten to 1-D before the per-class query embeddings are stacked below.
            query_emb = query_emb.view(-1)
        query_embs.append(query_emb)
        
        support_inputs, support_token_type_ids, support_masks = [],[],[]
        for i in range(len(support_examples)):
            # Support features unpack into seven fields; only the first and the
            # last two are used.
            support_input, _, _, _, _, support_token_type_id, support_mask = support_examples[i][0]
            support_inputs.append(support_input)
            support_token_type_ids.append(support_token_type_id)
            support_masks.append(support_mask)
            
            # The second element of the example is recorded as the support label.
            support_frame = support_examples[i][1]
            support_y.append(support_frame)
            
            
            
        # Batch all support examples of this class and embed them in one call.
        support_inputs = torch.stack(support_inputs).to(device)
        support_token_type_ids = torch.stack(support_token_type_ids).to(device)
        support_masks = torch.stack(support_masks).to(device)
        
        with torch.no_grad():
            _, support_emb = frameBERT(support_inputs, 
                                  token_type_ids=support_token_type_ids, 
                                  attention_mask=support_masks)
        support_embs.append(support_emb)
        
    # Stack the per-class support embeddings and flatten them to rows of width 768
    # (assumes FrameBERT's second output is 768-wide per example — TODO confirm).
    support_embs = torch.stack(support_embs)
    support_embs = support_embs.view(-1, 768)
    query_embs = torch.stack(query_embs)
    
    # Gradients are tracked from here on: the MLP is applied outside torch.no_grad().
    support_embs = mlp_model(support_embs)
    query_embs = mlp_model(query_embs)
    
    support_y = tuple(support_y)
    query_y = tuple(query_y)
    
    # NOTE(review): same value as num_support_tr given to the sampler; presumably
    # the two must stay in sync — confirm.
    n_support = 2
    loss_val, acc_val = loss_fn(support_embs, query_embs, support_y, query_y, n_support)
    
    print(loss_val)
    print(acc_val)
            
    # Leftover debugging prints, kept disabled.
#         s, q = [],[]
#         for i in range(len(support_examples)):
#             support_example = support_examples[i]
            
#             print('1')
#             print(len(support_example[0]))
#             print(support_example[0][0])
#             print(support_example[1])
#             print('end')
# #             break
        
#         print('support_examples')
#         print(len(support_examples))
#         print(support_examples)
        
#         print('query_examples')
#         print(len(query_examples))
#         print(query_examples)
    # Stop after the first episode.
    break

classes
tensor([ 524,  505, 1121, 1142])
support_idxs
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7]])

prototypes
tensor([[-0.2642,  0.1758, -0.1155,  ...,  0.1528,  0.0091,  0.1088],
        [-0.3860,  0.0462, -0.1096,  ...,  0.0209,  0.3764,  0.0566],
        [-0.4127,  0.1546, -0.0004,  ..., -0.0424,  0.0542, -0.2226],
        [-0.3157,  0.1858,  0.1979,  ...,  0.1914, -0.0928, -0.0042]],
       device='cuda:0', grad_fn=<CopyBackwards>)
query_embs
tensor([[-0.6290,  0.0325,  0.0662,  ...,  0.1483,  0.0024, -0.1080],
        [-0.6205,  0.0480, -0.0288,  ...,  0.1163,  0.0429,  0.0696],
        [-0.6698,  0.0900,  0.0536,  ...,  0.1522, -0.0614, -0.0286],
        [-0.4019,  0.0774,  0.1149,  ...,  0.1429, -0.0161, -0.1366]],
       device='cuda:0', grad_fn=<AddmmBackward>)
tensor(3.0586, device='cuda:0', grad_fn=<NegBackward>)
tensor(0.2500, device='cuda:0')


In [13]:
# z = torch.rand(4, 2,768)
# print(z)
# print(z.size())

# k = z.view(-1, 768)
# print(k)
# print(k.size())

In [14]:
# z = dict(Counter(d.trn_y))
# print(len(z))

# r = []
# for i in z:
#     count = z[i]
#     if count >=5:
#         r.append(i)
# with open('./data/target_frames.json','w') as f:
#     json.dump(r, f, ensure_ascii=False, indent=4)

In [15]:
# a = (0,1,2)

# for i in range(len(a)):
#     print(i, a[i])

In [16]:
# de = d.def_data[0]
# print(de)

In [17]:
# s = d.trn[2]
# print(s)

In [18]:
# for i in s[1]:
#     if i != '_':
#         frame = i
#         break
# print(frame)

In [19]:
# for f in d.frame2idx:
#     idx = d.frame2idx[f]
#     defi = d.frameidx2definition(idx)
#     print(defi)
#     break

In [20]:
# for i in range(100):
#     print(i)