In [1]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as Data
from tqdm import tqdm

import numpy as np
import json

USE_CUDA = True

In [2]:
def readfile(data):
    with open(data, "r", encoding="utf-8") as f:
        content = f.read().splitlines()
        
    return content

def schema_load(schema_root):
    raw_dict = "".join(readfile(schema_root))
    dict2json = "".join(raw_dict.split()[2:])

    json_acceptable_string = dict2json.replace("'", "\"")
    schema = json.loads(json_acceptable_string)
    
    return schema

def define_entity(schema):
    tag_type = list(schema['tagging'])
    
    entity_tag = []
    for k in list(schema['entity'].keys()):
        entity_tag.append(schema['entity'][k]['tag'])
        
    TAG = []
    for t in tag_type:
        for e in entity_tag:
            if t!='O':
                TAG.append(t+'-'+e)  
                
    TAG = [UNKOWN_TAG, PAD_TAG] + TAG + ['O']   

    return TAG

def tag2ix(TAG):
    tag_to_ix={t:i for i,t in enumerate(TAG)}
    return tag_to_ix

def define_relation(schema):
    relation_type = list(schema['relation'])
    
    relation_tag = []
    for k in list(schema['relation'].keys()):
        relation_tag.append(schema['relation'][k]['tag'])
    
    relation_tag = [REL_PAD] + [REL_NONE] + relation_tag
        
    return relation_tag

# ==================================================

def get_word_and_label(_content, start_w, end_w):
    word_list = []
    ent_list = []
    rel_list = []
    
    for word_set in _content[start_w:end_w]:
        word_set = word_set.split()
        if len(word_set)==1:
            word_list.append(' ')
            ent_list.append('O')
            rel_list.append(REL_NONE)
        
        else:
            word_list.append(word_set[0])
            ent_list.append(word_set[1])

            try:
                testerror = word_set[2]
            except IndexError:
                rel_list.append(REL_NONE)
            else:
                rel_list.append(word_set[2:])
    
    return word_list, ent_list, rel_list

def split_to_list(content):
    init = 0
    word_list = []
    ent_list = []
    rel_list = []

    for now_token, c in enumerate(content):
        if c=='':
            words, ents, rels = get_word_and_label(content, init, now_token)
            init = now_token+1
            word_list.append(words)
            ent_list.append(ents)
            rel_list.append(rels)
            
    return word_list, ent_list, rel_list

# ==================================================

def word2index(word_list):
    word_to_ix = {"<UNKNOWN>":0, "<PAD>":1}
    for sentence in word_list:
        for word in sentence:
            if word not in word_to_ix:
                word_to_ix[word] = len(word_to_ix)
                
    return word_to_ix

def dict_inverse(tag_to_ix):
    ix_to_tag = {v: k for k, v in tag_to_ix.items()}
    return ix_to_tag

def index2tag(indexs, ix_to):
    to_tags = [ix_to[i] for i in indexs.cpu().numpy()]
    return to_tags

# ==================================================

def find_max_len(word_list):
    max_len = 0
    for i in range(len(word_list)):
        if max_len<len(word_list[i]):
            max_len=len(word_list[i])
            
    return max_len

# ====== filter the length of sentence more than MAX_LEN =======

def filter_len(word_list):
    reserved_index = []
    for i in range(len(word_list)):
        if len(word_list[i])<MAX_LEN:
            reserved_index.append(i)
            
    return reserved_index


def filter_sentence(reserved_index, word_list, ent_list, rel_list):
    filter_word = list(word_list[i] for i in reserved_index)
    filter_ent = list(ent_list[i] for i in reserved_index)
    filter_rel = list(rel_list[i] for i in reserved_index)
    return filter_word, filter_ent, filter_rel

# ==================================================

def pad_seq(seq, isrel):
    if isrel:
        seq += [REL_NONE for i in range(MAX_LEN-len(seq))]
    else:
        seq += [PAD_TAG for i in range(MAX_LEN-len(seq))]
    return seq

def pad_all(filter_word, filter_ent, filter_rel):
    input_padded = [pad_seq(s, False) for s in filter_word]
    ent_padded = [pad_seq(s, False) for s in filter_ent]
    rel_padded = [pad_seq(s, True) for s in filter_rel]
    
    return input_padded, ent_padded, rel_padded

# ==================================================

def prepare_sequence(seq, to_ix):
    idxs = []
    for w in seq:
        if w not in to_ix:
            idxs.append(to_ix[UNKOWN_TAG])
        else:
            idxs.append(to_ix[w])
    
#     idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

def prepare_all(seqs, to_ix):
    seq_list = []
    for i in range(len(seqs)):
        seq_list.append(prepare_sequence(seqs[i], to_ix))
        
    seq_list = torch.stack(seq_list)
        
    return seq_list



def prepare_rel(rel_padded, to_ix):
    
    rel_ptr = torch.zeros(len(rel_padded), MAX_LEN, MAX_LEN, dtype=torch.long) 
    
    # 對當前的token，去比較之前所有出現過的entity，是否有關係，建成矩陣
    # [B*ML*ML]，第二維ML是當前token，第三維ML是根據當前token對之前出現過的entity紀錄關係，以index紀錄
    for i, rel_seq in enumerate(rel_padded):
        rel_dict = {}
        for j, token_seq in enumerate(rel_seq):
            rel_ptr[i][j][:j+1] = 1
            if token_seq != REL_NONE:
                for k, rel in enumerate(token_seq):

                    # if 是第一次出現，紀錄後面數字(標第幾對)和關係位置(A OR B)
                    # 假如下次出現又是同個關係位置(A)，依然紀錄
                    # 直到下次出現關係位置B，依照之前紀錄的A位置的字，然後在第三維去標關係

                    rel_token = rel.split('-')
                    if rel_token[1] not in rel_dict:
                        rel_dict[rel_token[1]] = {'rel':rel_token[0], 'loc':rel_token[2], 'idx':[j]}

                    elif rel_token[1] in rel_dict and rel_dict[rel_token[1]]['loc']==rel_token[2]:
                        rel_dict[rel_token[1]]['idx'].append(j)

                    else:
                        record_loc = rel_dict[rel_token[1]]['idx']
                        for idxx in record_loc:
                            rel_ptr[i][j][idxx] = to_ix[rel_token[0]]
                            
    return rel_ptr
                


# ==================================================

def dataload(input_var, ent_var, rel_var):
    torch_dataset = Data.TensorDataset(input_var, ent_var, rel_var)

    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=True,               
        num_workers=2,       
        drop_last=True
    )
    
    return loader

# ==================================================
def softmax_entity(entity):
    entity = entity.view(BATCH_SIZE,ent_size).argmax(1)
    return entity

In [3]:
class Attn(nn.Module):
    def __init__(self, attn_input, attn_output, rel_size):
        super(Attn, self).__init__()
        
        self.attn_input = attn_input
        self.attn_output = attn_output
        self.rel_size = rel_size
        
        self.w1 = nn.Linear(self.attn_input, self.attn_output)
        self.w2 = nn.Linear(self.attn_input, self.attn_output)
        self.tanh = nn.Tanh()
        self.v = nn.Linear(self.attn_output, self.rel_size, bias=False)
        
        
    def forward(self, encoder_outputs):
        
        decoder = encoder_outputs[:,-1,:].unsqueeze(1)                       #B*1*(ts+LE) [128,1,8]
        encoder_score = self.w1(encoder_outputs)                             #B*now len*ATTN_OUT
        decoder_score = self.w2(decoder)                                     #B*1*ATTN_OUT
        energy = self.tanh(encoder_score+decoder_score)                      #B*now len*ATTN_OUT            
        
        energy = self.v(energy)                                              #B*now len*rel_size
        
        
        # 針對每個entity做softmax，去顯示他們的關係權重
        # 主要都會是rel_none
        # 對第二維(rel)做softmax
#         p = F.softmax(energy, dim=2)                                         #B*now len*rel_size
        
        return energy
    

In [4]:
class Entity_Typing(nn.Module):
    def __init__(self, vocab_size, ent_tag_to_ix, embedding_dim, hidden_dim1, hidden_dim2, \
                 label_embed_dim, rel_tag_to_ix):
        
        super(Entity_Typing, self).__init__()
        self.embedding_dim = embedding_dim                   #E
        self.hidden_dim1 = hidden_dim1                       #h1
        self.hidden_dim2 = hidden_dim2                       #h2
        self.label_embed_dim = label_embed_dim               #LE
        self.vocab_size = vocab_size                         #vs
        self.ent_to_ix = ent_tag_to_ix
        self.ent_size = len(ent_tag_to_ix)                   #es
        self.rel_to_ix = rel_tag_to_ix
        self.rel_size = len(rel_tag_to_ix)                   #rs           
        
        self.dropout = nn.Dropout(p=0.3)
        self.bn = nn.BatchNorm1d(DENSE_OUT, momentum=0.5, affine=False)
        
        
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        
#         self.bilstm = nn.LSTM(embedding_dim, hidden_dim1 // 2,
#                             num_layers=2, bidirectional=True, batch_first=True, dropout=0.2)        
        self.bilstm = nn.GRU(embedding_dim, hidden_dim1 // 2,
                            num_layers=2, bidirectional=True, batch_first=True, dropout=0.2)
        
        
        self.dense = nn.Linear(hidden_dim1, DENSE_OUT)
        self.top_hidden = nn.LSTMCell(DENSE_OUT+label_embed_dim, hidden_dim2)          
        

        # Maps the output of the LSTM into tag space.
        self.hidden2tag = nn.Linear(hidden_dim2, self.ent_size)
        self.softmax = nn.LogSoftmax(dim=1)
        self.label_embed = nn.Linear(self.ent_size, self.label_embed_dim)
        
        self.attn = Attn(ATTN_IN, ATTN_OUT, self.rel_size)
        
        
    def init_hidden1(self):       
        hidden = torch.randn(2*2, BATCH_SIZE, self.hidden_dim1 // 2)    #4*B*(h1/2)
#         hidden = Variable(hidden.data, requires_grad=True)

        return (hidden.cuda(), hidden.cuda())if USE_CUDA else (hidden,hidden)
    
    def init_hidden2(self):       
        hidden = torch.randn(BATCH_SIZE, self.hidden_dim2)              #B*h2
#         hidden = Variable(hidden.data, requires_grad=True)

        return (hidden.cuda(), hidden.cuda())if USE_CUDA else (hidden,hidden)
    
    def init_label_embed(self):
        hidden = torch.zeros(BATCH_SIZE, self.label_embed_dim)          #B*LE
        return hidden.cuda()if USE_CUDA else hidden
    
    def create_entity(self):
        output_tensor = torch.zeros(BATCH_SIZE, MAX_LEN, self.ent_size)  #B*ML*es
        return output_tensor.cuda()if USE_CUDA else output_tensor
    
    def create_rel_matrix(self):
        rel_tensor = torch.zeros(BATCH_SIZE, MAX_LEN, MAX_LEN, self.rel_size)  #B*ML*ML*rs
        return rel_tensor.cuda()if USE_CUDA else rel_tensor
    
    
        
    def forward(self, sentence):
#         self.hidden1 = self.init_hidden1()                      #4*B*(h1/2)
        entity_tensor = self.create_entity()                    #B*ML*es
        rel_tensor = self.create_rel_matrix()                   #B*ML*ML*rs
        
        embeds = self.word_embeds(sentence)                     #B*ML*E,[128, 100, 20]
        
#         bilstm_out, self.hidden1 = self.bilstm(embeds, self.hidden1)
        bilstm_out, hidden1 = self.bilstm(embeds)
        # bilstm_out -> B*ML*h1,[128, 100, 10]
        # self.hidden1 -> ( 4*B*(h1/2), 4*B*(h1/2) )
        
        # bn
        bilstm_out = self.bn(bilstm_out)
        dense_out = self.dense(bilstm_out)                      #B*ML*DENSE_OUT,[128, 100, 100]
        
        
        encoder_sequence_l = [] 

        for length in range(MAX_LEN):
            now_token = dense_out[:,length,:]
            now_token = torch.squeeze(now_token, 1)
            if length==0:
                
#                 fake_hidden=(100)
#                 noise_x = random(100)
                self.hidden2 = self.init_hidden2()
                self.zero_label_embed = self.init_label_embed()
                combine_x = torch.cat((now_token, self.zero_label_embed),1)  #B*(DENSE_OUT+LE),[128, 103]
                
            else:
#                 fake_hidden=h
                self.hidden2 = (h_next, c_next)
                combine_x = torch.cat((now_token, label),1)

            h_next, c_next = self.top_hidden(combine_x, self.hidden2)    #B*h2,[128, 8]           
            to_tags = self.hidden2tag(h_next)                            #B*es,[128, 5]            
            ent_output = self.softmax(to_tags)                               #B*es,[128, 5]             
            label = self.label_embed(ent_output)                             #B*LE,[128, 3]
            
            s_ent_output = softmax_entity(ent_output)
            
            
            # Assignments to Variables are in-place operations.
            # Use that variable in lots of other contexts 
            # and some of the functions require it to not change. 
            to_tags_clone = to_tags.clone()
            label_clone = label.clone()
            
            
#             for i, tag in enumerate(s_ent_output):
#                 if tag==ent_tag_to_ix['O']:
#                     to_tags_clone[i] = torch.FloatTensor([-999999 * self.ent_size])
#                     label_clone[i] = torch.FloatTensor([-999999 * self.ent_size])
                    
            # relation layer
            encoder_sequence_l.append(torch.cat((to_tags,label),1))
            encoder_sequence = torch.stack(encoder_sequence_l).t()     #B*len*(es+LE), [128,1,8]          

            # Calculate attention weights 
            attn_weights = self.attn(encoder_sequence)

        
            entity_tensor[:,length,:] = ent_output
            
            # rel_tensor[:,length, 頭~當前 ,:]
            rel_tensor[:,length,:length+1,:] = attn_weights

        
        
        '''NLLLoss input: Input: (N,C) where C = number of classes'''
        return entity_tensor.view(BATCH_SIZE*MAX_LEN, self.ent_size), \
               rel_tensor.view(BATCH_SIZE*MAX_LEN*MAX_LEN, self.rel_size)

In [5]:
root = '/notebooks/sinica/dataset/'
train_data = root+'facial.train'
dev_data = root+'facial.dev'
test_data = root+'facial.test'

relation_data_old = root+'facial_r.old.train'
# relation_data = root+'facial_r.train'
relation_data = root+'facial_r2.train'
schema_root = root+'schema.txt'
dev_data = root+'facial_r2.dev'


UNKOWN_TAG = "<UNKNOWN>"
PAD_TAG = "<PAD>"
REL_NONE = 'Rel-None'
REL_PAD = 'Rel-Pad'

schema = schema_load(schema_root)
ENT_TAG = define_entity(schema)
REL_TAG = define_relation(schema)
ent_tag_to_ix = tag2ix(ENT_TAG)
'''{'<PAD>': 1,
 '<UNKNOWN>': 0,
 'B-FUNC': 2,
 'B-STAT': 3,
 'I-FUNC': 4,
 'I-STAT': 5,
 'O': 6}'''
rel_tag_to_ix = tag2ix(REL_TAG)
'''{'ApplyTo': 2, 'Rel-None': 1, 'Rel-Pad': 0}'''

# ========hyper-parameter-set==========

ent_size = len(ent_tag_to_ix)
rel_size = len(rel_tag_to_ix)
MAX_LEN = 100
BATCH_SIZE = 18

EMBEDDING_DIM = 20
HIDDEN_DIM1 = 10
HIDDEN_DIM2 = 8
LABEL_EMBED_DIM = ent_size
DENSE_OUT = 100

ATTN_IN = ent_size+LABEL_EMBED_DIM
ATTN_OUT = 6

In [6]:
def preprocess(data):
    content = readfile(data)
    word_list, ent_list, rel_list = split_to_list(content)
    word_to_ix = word2index(word_list)
    reserved_index = filter_len(word_list)
    filter_word, filter_ent, filter_rel = filter_sentence(reserved_index, word_list, ent_list, rel_list)
    input_padded, ent_padded, rel_padded = pad_all(filter_word, filter_ent, filter_rel)
    #================================================
    input_var = prepare_all(input_padded, word_to_ix)
    ent_var = prepare_all(ent_padded, ent_tag_to_ix)
    rel_var = prepare_rel(rel_padded, rel_tag_to_ix)
    #================================================
    vocab_size = len(word_to_ix)
    
    return input_var, ent_var, rel_var, vocab_size, word_to_ix

In [7]:
ix_to_ent_tag = dict_inverse(ent_tag_to_ix)
ix_to_rel_tag = dict_inverse(rel_tag_to_ix)
#===============================================
input_var, ent_var, rel_var, vocab_size, word_to_ix = preprocess(relation_data)

In [8]:
loader = dataload(input_var, ent_var, rel_var)
model = Entity_Typing(vocab_size, ent_tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM1, HIDDEN_DIM2, \
              LABEL_EMBED_DIM, rel_tag_to_ix).cuda()
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
criterion_tag = nn.NLLLoss()
criterion_rel = nn.CrossEntropyLoss()

In [9]:
def dev_preprocess(dev_data):
    dev_content = readfile(dev_data)
    word_list, ent_list, rel_list = split_to_list(dev_content)
    reserved_index = filter_len(word_list)
    filter_word, filter_ent, filter_rel = filter_sentence(reserved_index, word_list, ent_list, rel_list)
    input_padded, ent_padded, rel_padded = pad_all(filter_word, filter_ent, filter_rel)
    #================================================
    input_var = prepare_all(input_padded, word_to_ix)
    ent_var = prepare_all(ent_padded, ent_tag_to_ix)
    rel_var = prepare_rel(rel_padded, rel_tag_to_ix)
    
    return input_var, ent_var, rel_var


def dev_dataload(input_var, ent_var, rel_var):
    torch_dataset = Data.TensorDataset(input_var, ent_var, rel_var)

    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=True,               
        num_workers=2,       
        drop_last=True
    )
    
    return loader

In [10]:
input_dev, ent_dev, rel_dev= dev_preprocess(dev_data)
dev_loader = dev_dataload(input_dev, ent_dev, rel_dev)

In [11]:
import time
import math

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

In [12]:
len(rel_var)

306

In [13]:
n_iters = 50
print_every = 12
all_losses = []
total_loss = 0 # Reset every plot_every iters
loss = 0
ent_loss = 0

start = time.time()

for epoch in tqdm(range(n_iters)):  
    for step, (batch_x, batch_ent, batch_rel) in enumerate(loader):
        model.train()
        optimizer.zero_grad()
        ent_output, rel_output = model(batch_x.cuda() if USE_CUDA else batch_x)
        
        batch_ent = batch_ent.view(BATCH_SIZE*MAX_LEN)
        batch_rel = batch_rel.view(BATCH_SIZE*MAX_LEN*MAX_LEN)
        
        loss_ent = criterion_tag(ent_output, batch_ent.cuda() if USE_CUDA else batch_ent)
        loss_rel = criterion_rel(rel_output, batch_rel.cuda() if USE_CUDA else batch_rel)
        loss = loss_ent+loss_rel
        
        loss.backward()
#         loss.backward(retain_graph=True)
        optimizer.step()
        
        if step % print_every == 1:
            all_losses.append(loss.cpu())
        #    print('%.4f| epoch: %d| step: %d| %s' % (loss, epoch, step, timeSince(start)))
        
    for step, (batch_x, batch_ent, batch_rel) in enumerate(dev_loader):
        model.eval()
        ent_output, rel_output = model(batch_x.cuda() if USE_CUDA else batch_x)
        ent_loss = criterion_tag(ent_output.cpu(), batch_ent.view(BATCH_SIZE*MAX_LEN)) 
    
    
    print("epoch: %d | ent loss %.4f | rel loss %.4f | total loss %.4f" \
          % (epoch, loss_ent, loss_rel, loss))
    print("         | val ent loss %.4f"
          % (ent_loss))


  2%|▏         | 1/50 [00:07<05:49,  7.14s/it]

epoch: 0 | ent loss 1.1397 | rel loss 0.6035 | total loss 1.7432
         | val ent loss 1.0822


  4%|▍         | 2/50 [00:13<05:31,  6.91s/it]

epoch: 1 | ent loss 0.5929 | rel loss 0.5638 | total loss 1.1567
         | val ent loss 0.5408


  6%|▌         | 3/50 [00:20<05:25,  6.92s/it]

epoch: 2 | ent loss 0.3664 | rel loss 0.5554 | total loss 0.9218
         | val ent loss 0.3068


  8%|▊         | 4/50 [00:27<05:17,  6.89s/it]

epoch: 3 | ent loss 0.2676 | rel loss 0.5534 | total loss 0.8210
         | val ent loss 0.2249


 10%|█         | 5/50 [00:34<05:14,  6.98s/it]

epoch: 4 | ent loss 0.2048 | rel loss 0.5520 | total loss 0.7568
         | val ent loss 0.1613


 12%|█▏        | 6/50 [00:42<05:08,  7.00s/it]

epoch: 5 | ent loss 0.1582 | rel loss 0.5508 | total loss 0.7090
         | val ent loss 0.1508


 14%|█▍        | 7/50 [00:49<05:01,  7.02s/it]

epoch: 6 | ent loss 0.1188 | rel loss 0.5496 | total loss 0.6683
         | val ent loss 0.1369


 16%|█▌        | 8/50 [00:56<04:57,  7.08s/it]

epoch: 7 | ent loss 0.1170 | rel loss 0.5502 | total loss 0.6672
         | val ent loss 0.1108


 18%|█▊        | 9/50 [01:04<04:51,  7.11s/it]

epoch: 8 | ent loss 0.1015 | rel loss 0.5495 | total loss 0.6510
         | val ent loss 0.0924


 20%|██        | 10/50 [01:10<04:41,  7.03s/it]

epoch: 9 | ent loss 0.0857 | rel loss 0.5506 | total loss 0.6363
         | val ent loss 0.0811


 22%|██▏       | 11/50 [01:17<04:34,  7.04s/it]

epoch: 10 | ent loss 0.0825 | rel loss 0.5502 | total loss 0.6328
         | val ent loss 0.0743


 24%|██▍       | 12/50 [01:24<04:27,  7.03s/it]

epoch: 11 | ent loss 0.0714 | rel loss 0.5499 | total loss 0.6213
         | val ent loss 0.0738


 26%|██▌       | 13/50 [01:30<04:18,  7.00s/it]

epoch: 12 | ent loss 0.0478 | rel loss 0.5488 | total loss 0.5966
         | val ent loss 0.0594


 28%|██▊       | 14/50 [01:38<04:12,  7.01s/it]

epoch: 13 | ent loss 0.0395 | rel loss 0.5483 | total loss 0.5878
         | val ent loss 0.0849


 30%|███       | 15/50 [01:45<04:05,  7.03s/it]

epoch: 14 | ent loss 0.0357 | rel loss 0.5486 | total loss 0.5843
         | val ent loss 0.0782


 32%|███▏      | 16/50 [01:50<03:54,  6.91s/it]

epoch: 15 | ent loss 0.0405 | rel loss 0.5489 | total loss 0.5894
         | val ent loss 0.0587


 34%|███▍      | 17/50 [01:57<03:48,  6.91s/it]

epoch: 16 | ent loss 0.0287 | rel loss 0.5473 | total loss 0.5760
         | val ent loss 0.0650


 36%|███▌      | 18/50 [02:04<03:41,  6.91s/it]

epoch: 17 | ent loss 0.0257 | rel loss 0.5476 | total loss 0.5733
         | val ent loss 0.0590


 38%|███▊      | 19/50 [02:10<03:33,  6.89s/it]

epoch: 18 | ent loss 0.0236 | rel loss 0.5486 | total loss 0.5722
         | val ent loss 0.0568


 40%|████      | 20/50 [02:16<03:25,  6.84s/it]

epoch: 19 | ent loss 0.0295 | rel loss 0.5474 | total loss 0.5769
         | val ent loss 0.0639


 42%|████▏     | 21/50 [02:21<03:14,  6.72s/it]

epoch: 20 | ent loss 0.0301 | rel loss 0.5484 | total loss 0.5784
         | val ent loss 0.0447


 44%|████▍     | 22/50 [02:25<03:05,  6.61s/it]

epoch: 21 | ent loss 0.0203 | rel loss 0.5475 | total loss 0.5678
         | val ent loss 0.0650


 46%|████▌     | 23/50 [02:31<02:58,  6.61s/it]

epoch: 22 | ent loss 0.0176 | rel loss 0.5472 | total loss 0.5648
         | val ent loss 0.0609


 48%|████▊     | 24/50 [02:39<02:52,  6.63s/it]

epoch: 23 | ent loss 0.0185 | rel loss 0.5473 | total loss 0.5658
         | val ent loss 0.0393


 50%|█████     | 25/50 [02:46<02:46,  6.66s/it]

epoch: 24 | ent loss 0.0297 | rel loss 0.5471 | total loss 0.5768
         | val ent loss 0.0394


 52%|█████▏    | 26/50 [02:52<02:39,  6.63s/it]

epoch: 25 | ent loss 0.0166 | rel loss 0.5471 | total loss 0.5637
         | val ent loss 0.0702


 54%|█████▍    | 27/50 [02:59<02:32,  6.64s/it]

epoch: 26 | ent loss 0.0119 | rel loss 0.5465 | total loss 0.5584
         | val ent loss 0.0597


 56%|█████▌    | 28/50 [03:05<02:25,  6.63s/it]

epoch: 27 | ent loss 0.0159 | rel loss 0.5467 | total loss 0.5625
         | val ent loss 0.0405


 58%|█████▊    | 29/50 [03:12<02:19,  6.65s/it]

epoch: 28 | ent loss 0.0097 | rel loss 0.5472 | total loss 0.5569
         | val ent loss 0.0392


 60%|██████    | 30/50 [03:19<02:13,  6.65s/it]

epoch: 29 | ent loss 0.0117 | rel loss 0.5463 | total loss 0.5580
         | val ent loss 0.0547


 62%|██████▏   | 31/50 [03:26<02:06,  6.66s/it]

epoch: 30 | ent loss 0.0109 | rel loss 0.5466 | total loss 0.5575
         | val ent loss 0.0686


 64%|██████▍   | 32/50 [03:32<01:59,  6.65s/it]

epoch: 31 | ent loss 0.0104 | rel loss 0.5457 | total loss 0.5561
         | val ent loss 0.0605


 66%|██████▌   | 33/50 [03:39<01:53,  6.67s/it]

epoch: 32 | ent loss 0.0091 | rel loss 0.5454 | total loss 0.5545
         | val ent loss 0.0445


 68%|██████▊   | 34/50 [03:46<01:46,  6.65s/it]

epoch: 33 | ent loss 0.0157 | rel loss 0.5456 | total loss 0.5613
         | val ent loss 0.0516


 70%|███████   | 35/50 [03:53<01:39,  6.66s/it]

epoch: 34 | ent loss 0.0105 | rel loss 0.5454 | total loss 0.5559
         | val ent loss 0.0428


 72%|███████▏  | 36/50 [03:59<01:33,  6.65s/it]

epoch: 35 | ent loss 0.0079 | rel loss 0.5456 | total loss 0.5535
         | val ent loss 0.0329


 74%|███████▍  | 37/50 [04:06<01:26,  6.66s/it]

epoch: 36 | ent loss 0.0130 | rel loss 0.5455 | total loss 0.5585
         | val ent loss 0.0502


 76%|███████▌  | 38/50 [04:13<01:20,  6.68s/it]

epoch: 37 | ent loss 0.0078 | rel loss 0.5454 | total loss 0.5532
         | val ent loss 0.0571


 78%|███████▊  | 39/50 [04:20<01:13,  6.69s/it]

epoch: 38 | ent loss 0.0069 | rel loss 0.5454 | total loss 0.5523
         | val ent loss 0.0497


 80%|████████  | 40/50 [04:28<01:07,  6.71s/it]

epoch: 39 | ent loss 0.0115 | rel loss 0.5454 | total loss 0.5569
         | val ent loss 0.0458


 82%|████████▏ | 41/50 [04:35<01:00,  6.72s/it]

epoch: 40 | ent loss 0.0067 | rel loss 0.5454 | total loss 0.5521
         | val ent loss 0.0539


 84%|████████▍ | 42/50 [04:42<00:53,  6.73s/it]

epoch: 41 | ent loss 0.0086 | rel loss 0.5450 | total loss 0.5536
         | val ent loss 0.0555


 86%|████████▌ | 43/50 [04:49<00:47,  6.73s/it]

epoch: 42 | ent loss 0.0078 | rel loss 0.5453 | total loss 0.5531
         | val ent loss 0.0489


 88%|████████▊ | 44/50 [04:56<00:40,  6.74s/it]

epoch: 43 | ent loss 0.0075 | rel loss 0.5455 | total loss 0.5531
         | val ent loss 0.0457


 90%|█████████ | 45/50 [05:03<00:33,  6.74s/it]

epoch: 44 | ent loss 0.0100 | rel loss 0.5451 | total loss 0.5551
         | val ent loss 0.0476


 92%|█████████▏| 46/50 [05:10<00:26,  6.74s/it]

epoch: 45 | ent loss 0.0055 | rel loss 0.5453 | total loss 0.5508
         | val ent loss 0.0334


 94%|█████████▍| 47/50 [05:17<00:20,  6.75s/it]

epoch: 46 | ent loss 0.0082 | rel loss 0.5449 | total loss 0.5531
         | val ent loss 0.0454


 96%|█████████▌| 48/50 [05:24<00:13,  6.76s/it]

epoch: 47 | ent loss 0.0082 | rel loss 0.5454 | total loss 0.5536
         | val ent loss 0.0497


 98%|█████████▊| 49/50 [05:31<00:06,  6.76s/it]

epoch: 48 | ent loss 0.0044 | rel loss 0.5448 | total loss 0.5493
         | val ent loss 0.0572


100%|██████████| 50/50 [05:38<00:00,  6.77s/it]

epoch: 49 | ent loss 0.0079 | rel loss 0.5450 | total loss 0.5529
         | val ent loss 0.0539





In [14]:
import random
def random_choose(input_var):
    r_choose = []
    for i in range(BATCH_SIZE):
        r_choose.append(random.randint(0,len(input_var)))
    return r_choose
        
def ent_argmax(output):
    output = output.view(BATCH_SIZE,MAX_LEN,ent_size).argmax(2)
    return output

In [15]:
# Check predictions after training
with torch.no_grad():
    r_choose = random_choose(input_var)
    model.eval()
    ent_output, rel_output = model(input_var[r_choose].cuda() if USE_CUDA else input_var)
    
    ent_loss = criterion_tag(ent_output.cpu(), ent_var[r_choose].view(BATCH_SIZE*MAX_LEN))
    ent_output = ent_argmax(ent_output)
    
    rel_loss = criterion_rel(rel_output.cpu(), rel_var[r_choose].view(BATCH_SIZE*MAX_LEN*MAX_LEN))
    
    
#     print('predict :', ent_output[0])
#     print('true :', ent_var[r_choose[0]])
    print()
    print('predict :', index2tag(ent_output[0], ix_to_ent_tag))
    print('true :', index2tag(ent_var[r_choose[0]], ix_to_ent_tag))
    print()
    print('===================================================')
    print()
#     print('predict :', ent_output[1])
#     print('true :', ent_var[r_choose[1]])
    print()
    print('predict :', index2tag(ent_output[1], ix_to_ent_tag))
    print('true :', index2tag(ent_var[r_choose[1]], ix_to_ent_tag))
    
    print()
    print("Entity loss : %.4f" % ent_loss)
    print("Rel loss : %.4f" % rel_loss)


predict : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-STAT', 'I-STAT', 'I-STAT', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
true : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-STAT', 'I-STAT', 'I-STAT', 'O', 'O',

In [19]:
with torch.no_grad():
    r_choose = random_choose(input_dev)
    model.eval()
    ent_output, rel_output = model(input_dev[r_choose].cuda() if USE_CUDA else input_dev)
    
    ent_loss = criterion_tag(ent_output.cpu(), ent_dev[r_choose].view(BATCH_SIZE*MAX_LEN))
    ent_output = ent_argmax(ent_output)
    
    print(r_choose[0])
    print()
    print('predict :', index2tag(ent_output[0], ix_to_ent_tag))
    print()
    print('true :', index2tag(ent_dev[r_choose[0]], ix_to_ent_tag))
    print()

    print("Entity loss : %.4f" % ent_loss)
    

12

predict : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'B-STAT', 'I-STAT', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']

true : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'B-STAT', 'I-STAT', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O

In [18]:
with torch.no_grad():
    for step, (batch_x, batch_ent, batch_rel) in enumerate(dev_loader):
        model.eval()
        ent_output, rel_output = model(batch_x.cuda() if USE_CUDA else batch_x)
        
        ent_loss = criterion_tag(ent_output.cpu(), batch_ent.view(BATCH_SIZE*MAX_LEN))
        ent_output = ent_argmax(ent_output)
    
        print()
        print('predict :', index2tag(ent_output[0], ix_to_ent_tag))
        print('true :', index2tag(batch_ent[0], ix_to_ent_tag))
        print()
        
        print("Entity loss : %.4f" % ent_loss)
        
        


predict : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-STAT', 'I-STAT', 'I-STAT', 'I-STAT', 'I-STAT', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
true : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-STAT', 'I-STAT', 'I-STAT', 'I-STAT', 'I-STAT', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', '

In [None]:
content = readfile(relation_data)
word_list, ent_list, rel_list = split_to_list(content)
word_to_ix = word2index(word_list)
reserved_index = filter_len(word_list)
filter_word, filter_ent, filter_rel = filter_sentence(reserved_index, word_list, ent_list, rel_list)
input_padded, ent_padded, rel_padded = pad_all(filter_word, filter_ent, filter_rel)

In [None]:
rel_padded[0]