In [1]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as Data
from tqdm import tqdm

import numpy as np
import json

USE_CUDA = True

In [2]:
def readfile(data):
    with open(data, "r", encoding="utf-8") as f:
        content = f.read().splitlines()
        
    return content

def schema_load(schema_root):
    raw_dict = "".join(readfile(schema_root))
    dict2json = "".join(raw_dict.split()[2:])

    json_acceptable_string = dict2json.replace("'", "\"")
    schema = json.loads(json_acceptable_string)
    
    return schema

def define_entity(schema):
    tag_type = list(schema['tagging'])
    
    entity_tag = []
    for k in list(schema['entity'].keys()):
        entity_tag.append(schema['entity'][k]['tag'])
        
    TAG = []
    for t in tag_type:
        for e in entity_tag:
            if t!='O':
                TAG.append(t+'-'+e)  
                
    TAG = [UNKOWN_TAG, PAD_TAG] + TAG + ['O']   

    return TAG

def tag2ix(TAG):
    tag_to_ix={t:i for i,t in enumerate(TAG)}
    return tag_to_ix

def define_relation(schema):
    relation_type = list(schema['relation'])
    
    relation_tag = []
    for k in list(schema['relation'].keys()):
        relation_tag.append(schema['relation'][k]['tag'])
    
    relation_tag = [REL_PAD] + [REL_NONE] + relation_tag
        
    return relation_tag

# ==================================================

def get_word_and_label(_content, start_w, end_w):
    word_list = []
    ent_list = []
    rel_list = []
    
    for word_set in _content[start_w:end_w]:
        word_set = word_set.split()
        if len(word_set)==1:
            word_list.append(' ')
            ent_list.append('O')
            rel_list.append(REL_NONE)
        
        else:
            word_list.append(word_set[0])
            ent_list.append(word_set[1])

            try:
                testerror = word_set[2]
            except IndexError:
                rel_list.append(REL_NONE)
            else:
                rel_list.append(word_set[2:])
    
    return word_list, ent_list, rel_list

def split_to_list(content):
    init = 0
    word_list = []
    ent_list = []
    rel_list = []

    for now_token, c in enumerate(content):
        if c=='':
            words, ents, rels = get_word_and_label(content, init, now_token)
            init = now_token+1
            word_list.append(words)
            ent_list.append(ents)
            rel_list.append(rels)
            
    return word_list, ent_list, rel_list

# ==================================================

def word2index(word_list):
    word_to_ix = {"<UNKNOWN>":0, "<PAD>":1}
    for sentence in word_list:
        for word in sentence:
            if word not in word_to_ix:
                word_to_ix[word] = len(word_to_ix)
                
    return word_to_ix

def dict_inverse(tag_to_ix):
    ix_to_tag = {v: k for k, v in tag_to_ix.items()}
    return ix_to_tag

def index2tag(indexs, ix_to):
    to_tags = [ix_to[i] for i in indexs.cpu().numpy()]
    return to_tags

# ==================================================

def find_max_len(word_list):
    max_len = 0
    for i in range(len(word_list)):
        if max_len<len(word_list[i]):
            max_len=len(word_list[i])
            
    return max_len

# ====== filter the length of sentence more than MAX_LEN =======

def filter_len(word_list):
    reserved_index = []
    for i in range(len(word_list)):
        if len(word_list[i])<MAX_LEN:
            reserved_index.append(i)
            
    return reserved_index


def filter_sentence(reserved_index, word_list, ent_list, rel_list):
    filter_word = list(word_list[i] for i in reserved_index)
    filter_ent = list(ent_list[i] for i in reserved_index)
    filter_rel = list(rel_list[i] for i in reserved_index)
    return filter_word, filter_ent, filter_rel

# ==================================================

def pad_seq(seq, isrel):
    if isrel:
        seq += [REL_NONE for i in range(MAX_LEN-len(seq))]
    else:
        seq += [PAD_TAG for i in range(MAX_LEN-len(seq))]
    return seq

def pad_all(filter_word, filter_ent, filter_rel):
    input_padded = [pad_seq(s, False) for s in filter_word]
    ent_padded = [pad_seq(s, False) for s in filter_ent]
    rel_padded = [pad_seq(s, True) for s in filter_rel]
    
    return input_padded, ent_padded, rel_padded

# ==================================================

def prepare_sequence(seq, to_ix):
    idxs = []
    for w in seq:
        if w not in to_ix:
            idxs.append(to_ix[UNKOWN_TAG])
        else:
            idxs.append(to_ix[w])
    
#     idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

def prepare_all(seqs, to_ix):
    seq_list = []
    for i in range(len(seqs)):
        seq_list.append(prepare_sequence(seqs[i], to_ix))
        
    seq_list = torch.stack(seq_list)
        
    return seq_list



def prepare_rel(rel_padded, to_ix):
    
    rel_ptr = torch.zeros(len(rel_padded), MAX_LEN, MAX_LEN, dtype=torch.long) 
    
    # 對當前的token，去比較之前所有出現過的entity，是否有關係，建成矩陣
    # [B*ML*ML]，第二維ML是當前token，第三維ML是根據當前token對之前出現過的entity紀錄關係，以index紀錄
    for i, rel_seq in enumerate(rel_padded):
        rel_dict = {}
        for j, token_seq in enumerate(rel_seq):
            rel_ptr[i][j][:j+1] = 1
            if token_seq != REL_NONE:
                for k, rel in enumerate(token_seq):

                    # if 是第一次出現，紀錄後面數字(標第幾對)和關係位置(A OR B)
                    # 假如下次出現又是同個關係位置(A)，依然紀錄
                    # 直到下次出現關係位置B，依照之前紀錄的A位置的字，然後在第三維去標關係

                    rel_token = rel.split('-')
                    if rel_token[1] not in rel_dict:
                        rel_dict[rel_token[1]] = {'rel':rel_token[0], 'loc':rel_token[2], 'idx':[j]}

                    elif rel_token[1] in rel_dict and rel_dict[rel_token[1]]['loc']==rel_token[2]:
                        rel_dict[rel_token[1]]['idx'].append(j)

                    else:
                        record_loc = rel_dict[rel_token[1]]['idx']
                        for idxx in record_loc:
                            rel_ptr[i][j][idxx] = to_ix[rel_token[0]]
                            
    return rel_ptr
                


# ==================================================

def dataload(input_var, ent_var, rel_var):
    torch_dataset = Data.TensorDataset(input_var, ent_var, rel_var)

    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=True,               
        num_workers=2,       
        drop_last=True
    )
    
    return loader

# ==================================================
def softmax_entity(entity):
    entity = entity.view(BATCH_SIZE,ent_size).argmax(1)
    return entity

In [3]:
class Attn(nn.Module):
    def __init__(self, attn_input, attn_output, rel_size):
        super(Attn, self).__init__()
        
        self.attn_input = attn_input
        self.attn_output = attn_output
        self.rel_size = rel_size
        
        self.w1 = nn.Linear(self.attn_input, self.attn_output)
        self.w2 = nn.Linear(self.attn_input, self.attn_output)
        self.tanh = nn.Tanh()
        self.v = nn.Linear(self.attn_output, self.rel_size, bias=False)
        self.softmax = nn.LogSoftmax(dim=2)
        
        
    def forward(self, encoder_outputs):
        
        decoder = encoder_outputs[:,-1,:].unsqueeze(1)                       #B*1*(ts+LE) [128,1,8]
        encoder_score = self.w1(encoder_outputs)                             #B*now len*ATTN_OUT
        decoder_score = self.w2(decoder)                                     #B*1*ATTN_OUT
        energy = self.tanh(encoder_score+decoder_score)                      #B*now len*ATTN_OUT            
        
        energy = self.v(energy)                                              #B*now len*rel_size
        
        
        # 針對每個entity做softmax，去顯示他們的關係權重
        # 主要都會是rel_none
        # 對第二維(rel)做softmax
        p = self.softmax(energy)                                         #B*now len*rel_size
        
        return p
    

In [4]:
class Entity_Typing(nn.Module):
    def __init__(self, vocab_size, ent_tag_to_ix, embedding_dim, hidden_dim1, hidden_dim2, \
                 label_embed_dim, rel_tag_to_ix):
        
        super(Entity_Typing, self).__init__()
        self.embedding_dim = embedding_dim                   #E
        self.hidden_dim1 = hidden_dim1                       #h1
        self.hidden_dim2 = hidden_dim2                       #h2
        self.label_embed_dim = label_embed_dim               #LE
        self.vocab_size = vocab_size                         #vs
        self.ent_to_ix = ent_tag_to_ix
        self.ent_size = len(ent_tag_to_ix)                   #es
        self.rel_to_ix = rel_tag_to_ix
        self.rel_size = len(rel_tag_to_ix)                   #rs           
        
        self.dropout = nn.Dropout(p=0.3)
        self.bn = nn.BatchNorm1d(DENSE_OUT, momentum=0.5, affine=False)
        
        
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        
#         self.bilstm = nn.LSTM(embedding_dim, hidden_dim1 // 2,
#                             num_layers=2, bidirectional=True, batch_first=True, dropout=0.2)        
        self.bilstm = nn.GRU(embedding_dim, hidden_dim1 // 2,
                            num_layers=2, bidirectional=True, batch_first=True, dropout=0.2)
        
        
        self.dense = nn.Linear(hidden_dim1, DENSE_OUT)
        self.top_hidden = nn.LSTMCell(DENSE_OUT+label_embed_dim, hidden_dim2)          
        

        # Maps the output of the LSTM into tag space.
        self.hidden2tag = nn.Linear(hidden_dim2, self.ent_size)
        self.softmax = nn.LogSoftmax(dim=1)
        self.label_embed = nn.Linear(self.ent_size, self.label_embed_dim)
        
        self.attn = Attn(ATTN_IN, ATTN_OUT, self.rel_size)
        
        
    def init_hidden1(self):       
        hidden = torch.randn(2*2, BATCH_SIZE, self.hidden_dim1 // 2)    #4*B*(h1/2)
#         hidden = Variable(hidden.data, requires_grad=True)

        return (hidden.cuda(), hidden.cuda())if USE_CUDA else (hidden,hidden)
    
    def init_hidden2(self):       
        hidden = torch.randn(BATCH_SIZE, self.hidden_dim2)              #B*h2
#         hidden = Variable(hidden.data, requires_grad=True)

        return (hidden.cuda(), hidden.cuda())if USE_CUDA else (hidden,hidden)
    
    def init_label_embed(self):
        hidden = torch.zeros(BATCH_SIZE, self.label_embed_dim)          #B*LE
        return hidden.cuda()if USE_CUDA else hidden
    
    def create_entity(self):
        output_tensor = torch.zeros(BATCH_SIZE, MAX_LEN, self.ent_size)  #B*ML*es
        return output_tensor.cuda()if USE_CUDA else output_tensor
    
    def create_rel_matrix(self):
        rel_tensor = torch.zeros(BATCH_SIZE, MAX_LEN, MAX_LEN, self.rel_size)  #B*ML*ML*rs
        return rel_tensor.cuda()if USE_CUDA else rel_tensor
    
    
        
    def forward(self, sentence):
#         self.hidden1 = self.init_hidden1()                      #4*B*(h1/2)
        entity_tensor = self.create_entity()                    #B*ML*es
        rel_tensor = self.create_rel_matrix()                   #B*ML*ML*rs
        
        embeds = self.word_embeds(sentence)                     #B*ML*E,[128, 100, 20]
        
#         bilstm_out, self.hidden1 = self.bilstm(embeds, self.hidden1)
        bilstm_out, hidden1 = self.bilstm(embeds)
        # bilstm_out -> B*ML*h1,[128, 100, 10]
        # self.hidden1 -> ( 4*B*(h1/2), 4*B*(h1/2) )
        
        # bn
        bilstm_out = self.bn(bilstm_out)
        dense_out = self.dense(bilstm_out)                      #B*ML*DENSE_OUT,[128, 100, 100]
        
        
        encoder_sequence_l = [] 

        for length in range(MAX_LEN):
            now_token = dense_out[:,length,:]
            now_token = torch.squeeze(now_token, 1)
            if length==0:
                
#                 fake_hidden=(100)
#                 noise_x = random(100)
                self.hidden2 = self.init_hidden2()
                self.zero_label_embed = self.init_label_embed()
                combine_x = torch.cat((now_token, self.zero_label_embed),1)  #B*(DENSE_OUT+LE),[128, 103]
                
            else:
#                 fake_hidden=h
                self.hidden2 = (h_next, c_next)
                combine_x = torch.cat((now_token, label),1)

            h_next, c_next = self.top_hidden(combine_x, self.hidden2)    #B*h2,[128, 8]           
            to_tags = self.hidden2tag(h_next)                            #B*es,[128, 5]            
            ent_output = self.softmax(to_tags)                               #B*es,[128, 5]             
            label = self.label_embed(ent_output)                             #B*LE,[128, 3]
            
            s_ent_output = softmax_entity(ent_output)
            
            
            # Assignments to Variables are in-place operations.
            # Use that variable in lots of other contexts 
            # and some of the functions require it to not change. 
            to_tags_clone = to_tags.clone()
            label_clone = label.clone()
            
            
#             for i, tag in enumerate(s_ent_output):
#                 if tag==ent_tag_to_ix['O']:
#                     to_tags_clone[i] = torch.FloatTensor([-999999 * self.ent_size])
#                     label_clone[i] = torch.FloatTensor([-999999 * self.ent_size])
                    
            # relation layer
            encoder_sequence_l.append(torch.cat((to_tags,label),1))          
            encoder_sequence = torch.stack(encoder_sequence_l).t()     #B*len*(es+LE), [128,1,8]          

            # Calculate attention weights 
            attn_weights = self.attn(encoder_sequence)

        
            entity_tensor[:,length,:] = ent_output
            
            # rel_tensor[:,length, 頭~當前 ,:]
            rel_tensor[:,length,:length+1,:] = attn_weights

        
        
        '''NLLLoss input: Input: (N,C) where C = number of classes'''
        return entity_tensor.view(BATCH_SIZE*MAX_LEN, self.ent_size), \
               rel_tensor.view(BATCH_SIZE*MAX_LEN*MAX_LEN, self.rel_size)

In [5]:
root = '/notebooks/sinica/dataset/'
train_data = root+'facial.train'
dev_data = root+'facial.dev'
test_data = root+'facial.test'

relation_data_old = root+'facial_r.old.train'
# relation_data = root+'facial_r.train'
relation_data = root+'facial_r2.train'
schema_root = root+'schema.txt'
dev_data = root+'facial_r2.dev'


UNKOWN_TAG = "<UNKNOWN>"
PAD_TAG = "<PAD>"
REL_NONE = 'Rel-None'
REL_PAD = 'Rel-Pad'
rule = ('FUNC', 'ApplyTo', 'STAT')

schema = schema_load(schema_root)
ENT_TAG = define_entity(schema)
REL_TAG = define_relation(schema)
ent_tag_to_ix = tag2ix(ENT_TAG)
'''{'<PAD>': 1,
 '<UNKNOWN>': 0,
 'B-FUNC': 2,
 'B-STAT': 3,
 'I-FUNC': 4,
 'I-STAT': 5,
 'O': 6}'''
rel_tag_to_ix = tag2ix(REL_TAG)
'''{'ApplyTo': 2, 'Rel-None': 1, 'Rel-Pad': 0}'''

# ========hyper-parameter-set==========

ent_size = len(ent_tag_to_ix)
rel_size = len(rel_tag_to_ix)
MAX_LEN = 100
BATCH_SIZE = 18

EMBEDDING_DIM = 20
HIDDEN_DIM1 = 10
HIDDEN_DIM2 = 8
LABEL_EMBED_DIM = ent_size
DENSE_OUT = 100

ATTN_IN = ent_size+LABEL_EMBED_DIM
ATTN_OUT = 6

In [6]:
def preprocess(data):
    content = readfile(data)
    word_list, ent_list, rel_list = split_to_list(content)
    word_to_ix = word2index(word_list)
    reserved_index = filter_len(word_list)
    filter_word, filter_ent, filter_rel = filter_sentence(reserved_index, word_list, ent_list, rel_list)
    input_padded, ent_padded, rel_padded = pad_all(filter_word, filter_ent, filter_rel)
    #================================================
    input_var = prepare_all(input_padded, word_to_ix)
    ent_var = prepare_all(ent_padded, ent_tag_to_ix)
    rel_var = prepare_rel(rel_padded, rel_tag_to_ix)
    #================================================
    vocab_size = len(word_to_ix)
    
    return input_var, ent_var, rel_var, vocab_size, word_to_ix

def dev_preprocess(dev_data):
    dev_content = readfile(dev_data)
    word_list, ent_list, rel_list = split_to_list(dev_content)
    reserved_index = filter_len(word_list)
    filter_word, filter_ent, filter_rel = filter_sentence(reserved_index, word_list, ent_list, rel_list)
    input_padded, ent_padded, rel_padded = pad_all(filter_word, filter_ent, filter_rel)
    #================================================
    input_var = prepare_all(input_padded, word_to_ix)
    ent_var = prepare_all(ent_padded, ent_tag_to_ix)
    rel_var = prepare_rel(rel_padded, rel_tag_to_ix)
    
    return input_var, ent_var, rel_var

In [7]:
ix_to_ent_tag = dict_inverse(ent_tag_to_ix)
ix_to_rel_tag = dict_inverse(rel_tag_to_ix)
#===============================================
input_var, ent_var, rel_var, vocab_size, word_to_ix = preprocess(relation_data)
loader = dataload(input_var, ent_var, rel_var)

input_dev, ent_dev, rel_dev= dev_preprocess(dev_data)
dev_loader = dataload(input_dev, ent_dev, rel_dev)

In [8]:
model = Entity_Typing(vocab_size, ent_tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM1, HIDDEN_DIM2, \
              LABEL_EMBED_DIM, rel_tag_to_ix).cuda()
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
criterion_tag = nn.NLLLoss()
# criterion_rel = nn.CrossEntropyLoss()
criterion_rel = nn.NLLLoss()

In [10]:
import time
import math

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

In [11]:
len(rel_var)

306

In [12]:
n_iters = 50
print_every = 12
all_losses = []
total_loss = 0 # Reset every plot_every iters
loss = 0
ent_loss = 0

start = time.time()

for epoch in tqdm(range(n_iters)):  
    for step, (batch_x, batch_ent, batch_rel) in enumerate(loader):
        model.train()
        optimizer.zero_grad()
        ent_output, rel_output = model(batch_x.cuda() if USE_CUDA else batch_x)
        
        batch_ent = batch_ent.view(BATCH_SIZE*MAX_LEN)
        batch_rel = batch_rel.view(BATCH_SIZE*MAX_LEN*MAX_LEN)
        
        loss_ent = criterion_tag(ent_output, batch_ent.cuda() if USE_CUDA else batch_ent)
        loss_rel = criterion_rel(rel_output, batch_rel.cuda() if USE_CUDA else batch_rel)
        loss = loss_ent+loss_rel
        
        loss.backward()
#         loss.backward(retain_graph=True)
        optimizer.step()
        
        if step % print_every == 1:
            all_losses.append(loss.cpu())
        #    print('%.4f| epoch: %d| step: %d| %s' % (loss, epoch, step, timeSince(start)))
        
    for step, (batch_x, batch_ent, batch_rel) in enumerate(dev_loader):
        model.eval()
        ent_output, rel_output = model(batch_x.cuda() if USE_CUDA else batch_x)
        val_loss_ent = criterion_tag(ent_output.cpu(), batch_ent.view(BATCH_SIZE*MAX_LEN)) 
        val_loss_rel = criterion_rel(rel_output.cpu(), batch_rel.view(BATCH_SIZE*MAX_LEN*MAX_LEN))
    
    
    print("epoch: %d | ent loss %.4f | rel loss %.4f | total loss %.4f" \
          % (epoch, loss_ent, loss_rel, loss))
    print("      %s  | val ent loss %.4f | val rel loss %.4f"
          % (" "*len(str(epoch)), val_loss_ent, val_loss_rel))


  2%|▏         | 1/50 [00:06<05:36,  6.87s/it]

epoch: 0 | ent loss 0.5988 | rel loss 0.0403 | total loss 0.6391
         | val ent loss 0.5435 | val rel loss 0.0364


  4%|▍         | 2/50 [00:14<05:39,  7.07s/it]

epoch: 1 | ent loss 0.3573 | rel loss 0.0153 | total loss 0.3725
         | val ent loss 0.3037 | val rel loss 0.0135


  6%|▌         | 3/50 [00:21<05:30,  7.04s/it]

epoch: 2 | ent loss 0.2297 | rel loss 0.0087 | total loss 0.2383
         | val ent loss 0.2190 | val rel loss 0.0083


  8%|▊         | 4/50 [00:28<05:28,  7.13s/it]

epoch: 3 | ent loss 0.1758 | rel loss 0.0078 | total loss 0.1836
         | val ent loss 0.1714 | val rel loss 0.0073


 10%|█         | 5/50 [00:35<05:22,  7.16s/it]

epoch: 4 | ent loss 0.1737 | rel loss 0.0067 | total loss 0.1804
         | val ent loss 0.1655 | val rel loss 0.0059


 12%|█▏        | 6/50 [00:43<05:17,  7.21s/it]

epoch: 5 | ent loss 0.1485 | rel loss 0.0062 | total loss 0.1547
         | val ent loss 0.1507 | val rel loss 0.0055


 14%|█▍        | 7/50 [00:50<05:12,  7.27s/it]

epoch: 6 | ent loss 0.1332 | rel loss 0.0058 | total loss 0.1390
         | val ent loss 0.1266 | val rel loss 0.0048


 16%|█▌        | 8/50 [00:58<05:06,  7.29s/it]

epoch: 7 | ent loss 0.1374 | rel loss 0.0064 | total loss 0.1438
         | val ent loss 0.1153 | val rel loss 0.0037


 18%|█▊        | 9/50 [01:05<04:59,  7.31s/it]

epoch: 8 | ent loss 0.1046 | rel loss 0.0055 | total loss 0.1101
         | val ent loss 0.1175 | val rel loss 0.0037


 20%|██        | 10/50 [01:13<04:52,  7.32s/it]

epoch: 9 | ent loss 0.0991 | rel loss 0.0047 | total loss 0.1038
         | val ent loss 0.1187 | val rel loss 0.0031


 22%|██▏       | 11/50 [01:20<04:46,  7.34s/it]

epoch: 10 | ent loss 0.1056 | rel loss 0.0058 | total loss 0.1114
          | val ent loss 0.1110 | val rel loss 0.0036


 24%|██▍       | 12/50 [01:28<04:39,  7.35s/it]

epoch: 11 | ent loss 0.0683 | rel loss 0.0039 | total loss 0.0722
          | val ent loss 0.0834 | val rel loss 0.0033


 26%|██▌       | 13/50 [01:35<04:32,  7.38s/it]

epoch: 12 | ent loss 0.0578 | rel loss 0.0037 | total loss 0.0614
          | val ent loss 0.0781 | val rel loss 0.0030


 28%|██▊       | 14/50 [01:42<04:24,  7.35s/it]

epoch: 13 | ent loss 0.0680 | rel loss 0.0037 | total loss 0.0717
          | val ent loss 0.0663 | val rel loss 0.0030


 30%|███       | 15/50 [01:50<04:17,  7.35s/it]

epoch: 14 | ent loss 0.0568 | rel loss 0.0038 | total loss 0.0605
          | val ent loss 0.0716 | val rel loss 0.0027


 32%|███▏      | 16/50 [01:57<04:10,  7.36s/it]

epoch: 15 | ent loss 0.0360 | rel loss 0.0031 | total loss 0.0392
          | val ent loss 0.0535 | val rel loss 0.0026


 34%|███▍      | 17/50 [02:05<04:03,  7.38s/it]

epoch: 16 | ent loss 0.0324 | rel loss 0.0034 | total loss 0.0358
          | val ent loss 0.0605 | val rel loss 0.0026


 36%|███▌      | 18/50 [02:12<03:56,  7.39s/it]

epoch: 17 | ent loss 0.0422 | rel loss 0.0032 | total loss 0.0454
          | val ent loss 0.0673 | val rel loss 0.0021


 38%|███▊      | 19/50 [02:20<03:49,  7.39s/it]

epoch: 18 | ent loss 0.0367 | rel loss 0.0032 | total loss 0.0399
          | val ent loss 0.0643 | val rel loss 0.0020


 40%|████      | 20/50 [02:27<03:41,  7.40s/it]

epoch: 19 | ent loss 0.0210 | rel loss 0.0028 | total loss 0.0238
          | val ent loss 0.0510 | val rel loss 0.0025


 42%|████▏     | 21/50 [02:35<03:34,  7.40s/it]

epoch: 20 | ent loss 0.0317 | rel loss 0.0027 | total loss 0.0343
          | val ent loss 0.0505 | val rel loss 0.0018


 44%|████▍     | 22/50 [02:42<03:26,  7.39s/it]

epoch: 21 | ent loss 0.0392 | rel loss 0.0030 | total loss 0.0422
          | val ent loss 0.0572 | val rel loss 0.0020


 46%|████▌     | 23/50 [02:50<03:19,  7.39s/it]

epoch: 22 | ent loss 0.0229 | rel loss 0.0019 | total loss 0.0248
          | val ent loss 0.0574 | val rel loss 0.0017


 48%|████▊     | 24/50 [02:57<03:12,  7.40s/it]

epoch: 23 | ent loss 0.0245 | rel loss 0.0024 | total loss 0.0269
          | val ent loss 0.0523 | val rel loss 0.0020


 50%|█████     | 25/50 [03:05<03:05,  7.41s/it]

epoch: 24 | ent loss 0.0230 | rel loss 0.0021 | total loss 0.0250
          | val ent loss 0.0384 | val rel loss 0.0017


 52%|█████▏    | 26/50 [03:12<02:57,  7.41s/it]

epoch: 25 | ent loss 0.0296 | rel loss 0.0024 | total loss 0.0319
          | val ent loss 0.0478 | val rel loss 0.0013


 54%|█████▍    | 27/50 [03:20<02:50,  7.41s/it]

epoch: 26 | ent loss 0.0270 | rel loss 0.0023 | total loss 0.0294
          | val ent loss 0.0731 | val rel loss 0.0020


 56%|█████▌    | 28/50 [03:27<02:42,  7.41s/it]

epoch: 27 | ent loss 0.0159 | rel loss 0.0025 | total loss 0.0184
          | val ent loss 0.0565 | val rel loss 0.0014


 58%|█████▊    | 29/50 [03:34<02:35,  7.41s/it]

epoch: 28 | ent loss 0.0193 | rel loss 0.0022 | total loss 0.0215
          | val ent loss 0.0547 | val rel loss 0.0018


 60%|██████    | 30/50 [03:42<02:28,  7.42s/it]

epoch: 29 | ent loss 0.0151 | rel loss 0.0020 | total loss 0.0170
          | val ent loss 0.0634 | val rel loss 0.0015


 62%|██████▏   | 31/50 [03:49<02:20,  7.40s/it]

epoch: 30 | ent loss 0.0289 | rel loss 0.0020 | total loss 0.0309
          | val ent loss 0.0640 | val rel loss 0.0016


 64%|██████▍   | 32/50 [03:56<02:13,  7.40s/it]

epoch: 31 | ent loss 0.0121 | rel loss 0.0016 | total loss 0.0137
          | val ent loss 0.0478 | val rel loss 0.0017


 66%|██████▌   | 33/50 [04:04<02:05,  7.41s/it]

epoch: 32 | ent loss 0.0094 | rel loss 0.0020 | total loss 0.0114
          | val ent loss 0.0349 | val rel loss 0.0016


 68%|██████▊   | 34/50 [04:12<01:58,  7.41s/it]

epoch: 33 | ent loss 0.0145 | rel loss 0.0018 | total loss 0.0163
          | val ent loss 0.0493 | val rel loss 0.0015


 70%|███████   | 35/50 [04:19<01:51,  7.41s/it]

epoch: 34 | ent loss 0.0067 | rel loss 0.0014 | total loss 0.0081
          | val ent loss 0.0362 | val rel loss 0.0014


 72%|███████▏  | 36/50 [04:27<01:43,  7.42s/it]

epoch: 35 | ent loss 0.0058 | rel loss 0.0016 | total loss 0.0074
          | val ent loss 0.0449 | val rel loss 0.0016


 74%|███████▍  | 37/50 [04:34<01:36,  7.41s/it]

epoch: 36 | ent loss 0.0126 | rel loss 0.0015 | total loss 0.0141
          | val ent loss 0.0463 | val rel loss 0.0012


 76%|███████▌  | 38/50 [04:39<01:28,  7.34s/it]

epoch: 37 | ent loss 0.0117 | rel loss 0.0017 | total loss 0.0133
          | val ent loss 0.0644 | val rel loss 0.0014


 78%|███████▊  | 39/50 [04:46<01:20,  7.34s/it]

epoch: 38 | ent loss 0.0093 | rel loss 0.0017 | total loss 0.0110
          | val ent loss 0.0639 | val rel loss 0.0016


 80%|████████  | 40/50 [04:53<01:13,  7.34s/it]

epoch: 39 | ent loss 0.0116 | rel loss 0.0016 | total loss 0.0132
          | val ent loss 0.0454 | val rel loss 0.0011


 82%|████████▏ | 41/50 [05:01<01:06,  7.35s/it]

epoch: 40 | ent loss 0.0089 | rel loss 0.0016 | total loss 0.0105
          | val ent loss 0.0699 | val rel loss 0.0015


 84%|████████▍ | 42/50 [05:08<00:58,  7.35s/it]

epoch: 41 | ent loss 0.0050 | rel loss 0.0016 | total loss 0.0065
          | val ent loss 0.0584 | val rel loss 0.0013


 86%|████████▌ | 43/50 [05:16<00:51,  7.36s/it]

epoch: 42 | ent loss 0.0074 | rel loss 0.0011 | total loss 0.0086
          | val ent loss 0.0560 | val rel loss 0.0012


 88%|████████▊ | 44/50 [05:24<00:44,  7.37s/it]

epoch: 43 | ent loss 0.0095 | rel loss 0.0017 | total loss 0.0112
          | val ent loss 0.0537 | val rel loss 0.0014


 90%|█████████ | 45/50 [05:31<00:36,  7.37s/it]

epoch: 44 | ent loss 0.0048 | rel loss 0.0016 | total loss 0.0064
          | val ent loss 0.0410 | val rel loss 0.0012


 92%|█████████▏| 46/50 [05:39<00:29,  7.37s/it]

epoch: 45 | ent loss 0.0057 | rel loss 0.0013 | total loss 0.0070
          | val ent loss 0.0517 | val rel loss 0.0012


 94%|█████████▍| 47/50 [05:46<00:22,  7.38s/it]

epoch: 46 | ent loss 0.0102 | rel loss 0.0013 | total loss 0.0114
          | val ent loss 0.0563 | val rel loss 0.0013


 96%|█████████▌| 48/50 [05:54<00:14,  7.39s/it]

epoch: 47 | ent loss 0.0049 | rel loss 0.0011 | total loss 0.0060
          | val ent loss 0.0551 | val rel loss 0.0012


 98%|█████████▊| 49/50 [06:02<00:07,  7.39s/it]

epoch: 48 | ent loss 0.0032 | rel loss 0.0011 | total loss 0.0043
          | val ent loss 0.0586 | val rel loss 0.0011


100%|██████████| 50/50 [06:09<00:00,  7.40s/it]

epoch: 49 | ent loss 0.0060 | rel loss 0.0013 | total loss 0.0073
          | val ent loss 0.0668 | val rel loss 0.0012





In [20]:
import random
def random_choose(input_var):
    r_choose = []
    for i in range(BATCH_SIZE):
        r_choose.append(random.randint(0,len(input_var)))
    return r_choose
        
def ent_argmax(output):
    output = output.view(BATCH_SIZE,MAX_LEN,ent_size).argmax(2)
    return output

def rel_argmax(output):
    output = output.view(BATCH_SIZE,MAX_LEN,MAX_LEN,rel_size).argmax(3)
    return output

In [15]:
# Check predictions after training
with torch.no_grad():
    r_choose = random_choose(input_var)
    model.eval()
    ent_output, rel_output = model(input_var[r_choose].cuda() if USE_CUDA else input_var)
    
    ent_loss = criterion_tag(ent_output.cpu(), ent_var[r_choose].view(BATCH_SIZE*MAX_LEN))
    ent_output = ent_argmax(ent_output)
    
    rel_loss = criterion_rel(rel_output.cpu(), rel_var[r_choose].view(BATCH_SIZE*MAX_LEN*MAX_LEN))
    
    
#     print('predict :', ent_output[0])
#     print('true :', ent_var[r_choose[0]])
    print()
    print('predict :', index2tag(ent_output[0], ix_to_ent_tag))
    print('true :', index2tag(ent_var[r_choose[0]], ix_to_ent_tag))
    print()
    print('===================================================')
    print()
    print()
    print('predict :', index2tag(ent_output[1], ix_to_ent_tag))
    print('true :', index2tag(ent_var[r_choose[1]], ix_to_ent_tag))
    
    print()
    print("Entity loss : %.4f" % ent_loss)
    print("Relation loss : %.4f" % rel_loss)


predict : ['B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-STAT', 'I-STAT', 'I-STAT', 'I-STAT', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
true : ['B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-STAT', 

In [34]:
with torch.no_grad():
    r_choose = random_choose(input_dev)
    model.eval()
    ent_output, rel_output = model(input_dev[r_choose].cuda() if USE_CUDA else input_dev)
    
    ent_loss = criterion_tag(ent_output.cpu(), ent_dev[r_choose].view(BATCH_SIZE*MAX_LEN))
    ent_output = ent_argmax(ent_output)
    
    rel_loss = criterion_rel(rel_output.cpu(), rel_var[r_choose].view(BATCH_SIZE*MAX_LEN*MAX_LEN))
    
    print(r_choose[0])
    print()
    print('predict :', index2tag(ent_output[0], ix_to_ent_tag))
    print()
    print('true :', index2tag(ent_dev[r_choose[0]], ix_to_ent_tag))
    print()

    print("Entity loss : %.4f" % ent_loss)
    print("Relation loss : %.4f" % rel_loss)
    

12

predict : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'B-STAT', 'I-STAT', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']

true : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'B-STAT', 'I-STAT', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O

In [66]:
with torch.no_grad():
    for step, (batch_x, batch_ent, batch_rel) in enumerate(dev_loader):
        model.eval()
        ent_output, rel_output = model(batch_x.cuda() if USE_CUDA else batch_x)
        
        ent_loss = criterion_tag(ent_output.cpu(), batch_ent.view(BATCH_SIZE*MAX_LEN))
        ent_output = ent_argmax(ent_output)
        
        rel_loss = criterion_rel(rel_output.cpu(), rel_var[r_choose].view(BATCH_SIZE*MAX_LEN*MAX_LEN))
        rel_output = rel_argmax(rel_output)
    
        print()
        print('predict :', index2tag(ent_output[0], ix_to_ent_tag))
        print('true :', index2tag(batch_ent[0], ix_to_ent_tag))
        print()
        
        print("Entity loss : %.4f" % ent_loss)
        print("Relation loss : %.4f" % rel_loss)
        


predict : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-STAT', 'I-STAT', 'I-STAT', 'I-STAT', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
true : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-STAT', 'I-STAT', 'I-STAT', 'I-STAT', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'

In [144]:
ent_output[3]9~14

tensor([ 6,  6,  3,  5,  6,  6,  6,  6,  6,  2,  4,  4,  4,  4,
         4,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1], device='cuda:0')

In [231]:
rel_output[3][14]

tensor([ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0], device='cuda:0')

In [222]:
rule = ('FUNC', 'ApplyTo', 'STAT')
def decode_output(ent_output, rel_output):
    r_list = []
    r_dict = {}
    pred_ent = index2tag(ent_output, ix_to_ent_tag)
    
    e_loc = 0
    for loc, e in enumerate(pred_ent):
        if e[0]=='B':
            e_loc = loc
            r_dict[loc] = {
                '_2ndtag':e[2:],
            }
            r_list.append([])
            
        
        elif e[0]=='I':
            r_dict[e_loc]['end'] = loc
            r_list.append([])
            
        else:
            r_list.append("")
    
    
    IsB = False
    IsNext = False
    num_reocrd = -1
    now_loc = 0
    end_loc = 0
    tag = ""
    preAorB = ""
    nowAorB = ""
    pre_complete_rel = ""
    now_complete_rel = ""
    
    for now in range(len(rel_output)):
        for loc, rel in enumerate(rel_output[now][:now+1]):
            rel = rel.cpu().numpy()

            if rel!=rel_tag_to_ix[REL_NONE] and IsB==False and IsNext==False:

                IsB = True
                IsNext = True
                tag = ix_to_rel_tag[int(rel)]
                num_reocrd+=1
                
                now_loc = loc
                end_loc = r_dict[now_loc]['end']
                
                second_tag = r_dict[now_loc]['_2ndtag']
                preAorB = check_loc(second_tag)
                nowAorB = 'B' if preAorB=='A' else 'A'
                
                pre_complete_rel = tag+"-"+str(num_reocrd)+"-"+preAorB
                now_complete_rel = tag+"-"+str(num_reocrd)+"-"+nowAorB
                
                for token in range(now_loc, end_loc+1):
                    r_list[token].append(pre_complete_rel)
                
                r_list[now].append(now_complete_rel)

                
            elif rel!=rel_tag_to_ix[REL_NONE] and IsB:
                if loc<=end_loc:
                    pass
                else:
                    IsB = False
            
            elif rel!=rel_tag_to_ix[REL_NONE] and IsNext:
                r_list[now] = r_list[now-1]
                
            else:
                IsB = False

                
                
    return r_list
                

                
def check_loc(second_tag):
    if second_tag in rule[0]:
        return 'A'
    elif second_tag in rule[2]:
        return 'B'

In [223]:
decode_output(ent_output[3], rel_output[3])

['', '', [], [], '', '', '', '', '', [], [], [], [], [], [], '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '']


['',
 '',
 ['ApplyTo-0-A'],
 ['ApplyTo-0-A'],
 '',
 '',
 '',
 '',
 '',
 ['ApplyTo-0-B'],
 ['ApplyTo-0-B'],
 ['ApplyTo-0-B'],
 ['ApplyTo-0-B'],
 ['ApplyTo-0-B'],
 [],
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '',
 '']

In [90]:
ent_output[0]

tensor([ 6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  2,  4,  4,  4,
         6,  6,  6,  6,  6,  6,  6,  3,  5,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1], device='cuda:0')