In [32]:
import json
from src import prototypical_batch_sampler
from collections import Counter, OrderedDict
from kaiser.src import dataio
from kaiser.src import utils
import torch

import sys
import random

sys.path.insert(0,'../')
sys.path.insert(0,'../../')

from kaiser.src.modeling import BertForJointShallowSemanticParsing, FrameBERT
from pprint import pprint

In [2]:
# 실행시간 측정 함수
import time

_start_time = time.time()

def tic():
    global _start_time 
    _start_time = time.time()

def tac():
    t_sec = round(time.time() - _start_time)
    (t_min, t_sec) = divmod(t_sec,60)
    (t_hour,t_min) = divmod(t_min,60)
    
    result = '{}hour:{}min:{}sec'.format(t_hour,t_min,t_sec)
    return result

In [3]:
bert_io = utils.for_BERT(mode='train', language='multi')

used dictionary:
	 /disk/kaiser/kaiser/src/../koreanframenet/resource/info/mul_lu2idx.json
	 /disk/kaiser/kaiser/src/../koreanframenet/resource/info/mul_lufrmap.json
	 /disk/kaiser/kaiser/src/../koreanframenet/resource/info/mul_bio_frargmap.json


In [4]:
tic()
trn, dev, tst = dataio.load_data(srl='framenet', language='en', exem=False)
trn = random.sample(trn, k=500)
tst = random.sample(tst, k=100)
trn_data = bert_io.convert_to_bert_input_JointShallowSemanticParsing(trn)
tst_data = bert_io.convert_to_bert_input_JointShallowSemanticParsing(tst)
tac()

# of instances in trn: 19391
# of instances in dev: 2272
# of instances in tst: 6714
data example: [['Greece', 'wildfires', 'force', 'thousands', 'to', '<tgt>', 'evacuate', '</tgt>'], ['_', '_', '_', '_', '_', '_', 'evacuate.v', '_'], ['_', '_', '_', '_', '_', '_', 'Escaping', '_'], ['O', 'O', 'O', 'B-Escapee', 'O', 'X', 'O', 'X']]


'0hour:0min:3sec'

In [5]:
with open('./koreanframenet/resource/info/fn1.7_frame2idx.json', 'r') as f:
    frame2idx = json.load(f)
with open('./koreanframenet/resource/info/fn1.7_frame_definitions.json', 'r') as f:
    frame2definition = json.load(f)

def_data, def_y = bert_io.convert_to_bert_input_label_definition(frame2definition, frame2idx)

In [6]:
def get_y(data):
    y = []
    for instance in data:
        frame = False
        for i in instance[2]:
            if i != '_':
                frame = i
                break
        frameidx = frame2idx[frame]
        y.append(frameidx)
    return tuple(y)

with open('./koreanframenet/resource/info/fn1.7_frame2idx.json', 'r') as f:
    frame2idx = json.load(f)


all_y = dict(Counter(get_y(trn)))
target_frames = []
for i in all_y:
    count = all_y[i]
    if count >= 2:
        target_frames.append(i)

In [65]:
d = prototypical_batch_sampler.PrototypicalBatchSampler(classes_per_it_tr=4, classes_per_it_val=2,
                                                        num_support_tr=2, num_support_val=2, 
                                                        target_frames=target_frames, 
                                                        def_data=def_data, def_y=def_y)

used dictionary:
	 /disk/kaiser/kaiser/src/../koreanframenet/resource/info/mul_lu2idx.json
	 /disk/kaiser/kaiser/src/../koreanframenet/resource/info/mul_lufrmap.json
	 /disk/kaiser/kaiser/src/../koreanframenet/resource/info/mul_bio_frargmap.json


In [58]:
trn_y = d.get_y(trn)
# tst_y = d.get_y(tst)

In [66]:
trn_batch = d.gen_batch(trn_data, trn_y, mode='train')
# tst_batch = d.gen_batch(tst, tst_y, mode='test')

In [30]:
frameBERT_dir = '/disk/data/models/frameBERT/frameBERT_en'

frameBERT = FrameBERT.from_pretrained(frameBERT_dir,
                                      num_senses = len(bert_io.sense2idx), 
                                      num_args = len(bert_io.bio_arg2idx),
                                      lufrmap=bert_io.lufrmap, 
                                      frargmap = bert_io.bio_frargmap)

frameBERT.eval()

FrameBERT(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
 

In [94]:
for episode in trn_batch:
    support_embs = []
    query_embs = []
    
    query_frame = []
    support_frame = []
    
    for class_indice in episode:
        support_examples, query_examples = class_indice

        query_inputs, _, query_token_type_ids, query_masks = query_examples[0][0]
        query_inputs = query_inputs.view(1,len(query_inputs))
        query_token_type_ids = query_token_type_ids.view(1,len(query_token_type_ids))
        query_masks = query_masks.view(1,len(query_masks))
        
        query_frame.append(query_examples[0][1])
    
        with torch.no_grad():
            _, query_emb = frameBERT(query_inputs, 
                                  token_type_ids=query_token_type_ids, 
                                  attention_mask=query_masks)
            query_emb = query_emb.view(-1)
        query_embs.append(query_emb)
        
        support_inputs, support_token_type_ids, support_masks = [],[],[]
        for i in range(len(support_examples)):
            support_input, _, _, _, _, support_token_type_id, support_mask = support_examples[i][0]
            support_inputs.append(support_input)
            support_token_type_ids.append(support_token_type_id)
            support_masks.append(support_mask)
            
        support_inputs = torch.stack(support_inputs)
        support_token_type_ids = torch.stack(support_token_type_ids)
        support_masks = torch.stack(support_masks)
        
        with torch.no_grad():
            _, support_emb = frameBERT(support_inputs, 
                                  token_type_ids=support_token_type_ids, 
                                  attention_mask=support_masks)
        support_embs.append(support_emb)
        
    support_embs = torch.stack(support_embs)
    support_embs = support_embs.view(-1, 768)
    query_embs = torch.stack(query_embs)
    
    print(query_frame)
            
#         s, q = [],[]
#         for i in range(len(support_examples)):
#             support_example = support_examples[i]
            
#             print('1')
#             print(len(support_example[0]))
#             print(support_example[0][0])
#             print(support_example[1])
#             print('end')
# #             break
        
#         print('support_examples')
#         print(len(support_examples))
#         print(support_examples)
        
#         print('query_examples')
#         print(len(query_examples))
#         print(query_examples)
    break

[662, 625, 1164, 787]


In [90]:
z = torch.rand(4, 2,768)
print(z)
print(z.size())

k = z.view(-1, 768)
print(k)
print(k.size())

tensor([[[0.6711, 0.6467, 0.0012,  ..., 0.2920, 0.0047, 0.0473],
         [0.9582, 0.0875, 0.3748,  ..., 0.6220, 0.4297, 0.8575]],

        [[0.3772, 0.6406, 0.7527,  ..., 0.7505, 0.7192, 0.1826],
         [0.6418, 0.9444, 0.9254,  ..., 0.9130, 0.4182, 0.3453]],

        [[0.7614, 0.5117, 0.4431,  ..., 0.1247, 0.3713, 0.5361],
         [0.3848, 0.5658, 0.7259,  ..., 0.6360, 0.4423, 0.4318]],

        [[0.2004, 0.9904, 0.9005,  ..., 0.3045, 0.3373, 0.1694],
         [0.5485, 0.8492, 0.4913,  ..., 0.9918, 0.3595, 0.0851]]])
torch.Size([4, 2, 768])
tensor([[0.6711, 0.6467, 0.0012,  ..., 0.2920, 0.0047, 0.0473],
        [0.9582, 0.0875, 0.3748,  ..., 0.6220, 0.4297, 0.8575],
        [0.3772, 0.6406, 0.7527,  ..., 0.7505, 0.7192, 0.1826],
        ...,
        [0.3848, 0.5658, 0.7259,  ..., 0.6360, 0.4423, 0.4318],
        [0.2004, 0.9904, 0.9005,  ..., 0.3045, 0.3373, 0.1694],
        [0.5485, 0.8492, 0.4913,  ..., 0.9918, 0.3595, 0.0851]])
torch.Size([8, 768])


In [12]:
# z = dict(Counter(d.trn_y))
# print(len(z))

# r = []
# for i in z:
#     count = z[i]
#     if count >=5:
#         r.append(i)
# with open('./data/target_frames.json','w') as f:
#     json.dump(r, f, ensure_ascii=False, indent=4)

In [13]:
a = (0,1,2)

for i in range(len(a)):
    print(i, a[i])

0 0
1 1
2 2


In [14]:
de = d.def_data[0]
print(de)

(tensor([  101, 10313, 36065, 24516, 17155,   169, 47684, 46767, 65314, 13135,
        10271, 10192, 20165, 12381, 10455, 12608, 10345, 10108, 10105, 16626,
        21849, 10146, 10464,   112,   187, 18381,   119,   112, 80489, 32296,
        10485, 13000, 10111, 54941, 10336, 10135,   169, 10680, 15790, 60067,
        10165, 19369,   119,   112,   112, 11982, 10921, 13221, 10261, 12153,
        10105, 18444, 10106, 10105, 23602, 78362,   112,   112, 15595, 13904,
        10115, 10426, 10108,   169, 18048, 10124, 14289, 10114, 10347,   169,
        32342, 22564, 10106, 11299, 49441, 10107,   119,   112, 11723, 10301,
        10379, 39063, 28088, 14336, 72762, 11031, 36839,   131,   112, 11149,
        12153, 10485, 12898, 30360, 17155,   119,   112,   102,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,

In [15]:
s = d.trn[2]
print(s)

AttributeError: 'PrototypicalBatchSampler' object has no attribute 'trn'

In [None]:
for i in s[1]:
    if i != '_':
        frame = i
        break
print(frame)

In [None]:
for f in d.frame2idx:
    idx = d.frame2idx[f]
    defi = d.frameidx2definition(idx)
    print(defi)
    break

In [None]:
for i in range(100):
    print(i)