In [1]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as Data
from tqdm import tqdm

import numpy as np

In [2]:
USE_CUDA = True

In [3]:
root = '/notebooks/sinica/dataset/'
train_data = root+'facial.train'
dev_data = root+'facial.dev'
test_data = root+'facial.test'

START_TAG = "<START>"
STOP_TAG = "<STOP>"
PAD_TAG = "<PAD>"
tag_to_ix = {START_TAG: 0, STOP_TAG: 1, PAD_TAG:2, "B-Func": 3, "I-Func": 4, "O": 5}

tagset_size = len(tag_to_ix)
MAX_LEN = 100
BATCH_SIZE = 128

EMBEDDING_DIM = 20
HIDDEN_DIM1 = 10
HIDDEN_DIM2 = 8
LABEL_EMBED_DIM = 3
DENSE_OUT = 100

In [4]:
def readfile(data):
    with open(data, "r", encoding="utf-8") as f:
        content = f.read().splitlines()
        
    return content

def get_word_and_label(_content, start_w, end_w):
    word_list = []
    tag_list = []
    for word_set in _content[start_w:end_w]:
        word_list.append(word_set[0])
        tag_list.append(word_set[2:])
    
    return word_list, tag_list

def split_to_list(content):
    init = 0
    word_list = []
    tag_list = []

    for i, c in enumerate(content):
        if c=='':
            words, tags = get_word_and_label(content, init, i)
            init = i+1
            word_list.append(words)
            tag_list.append(tags)
            
    return word_list, tag_list
    
def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

def prepare_all(seqs, to_ix):
    seq_list = []
    for i in range(len(seqs)):
        seq_list.append(prepare_sequence(seqs[i], to_ix))
        
    seq_list = torch.stack(seq_list)
        
    return seq_list

def word2index(word_list):
    word_to_ix = {"<START>":0, "<STOP>":1, "<PAD>":2}
    for sentence in word_list:
        for word in sentence:
            if word not in word_to_ix:
                word_to_ix[word] = len(word_to_ix)
                
    return word_to_ix

def find_max_len(word_list):
    max_len = 0
    for i in range(len(word_list)):
        if max_len<len(word_list[i]):
            max_len=len(word_list[i])
            
    return max_len

def filter_len(word_list):
    reserved_index = []
    for i in range(len(word_list)):
        if len(word_list[i])<MAX_LEN:
            reserved_index.append(i)
            
    return reserved_index

def filter_sentence(reserved_index, word_list, tag_list):
    filter_word = list(word_list[i] for i in reserved_index)
    filter_tag = list(tag_list[i] for i in reserved_index)
    return filter_word, filter_tag

def pad_seq(seq):
    seq += [PAD_TAG for i in range(MAX_LEN-len(seq))]
    return seq

def pad_all(filter_word, filter_tag):
    input_padded = [pad_seq(s) for s in filter_word]
    target_padded = [pad_seq(s) for s in filter_tag]
    
    return input_padded, target_padded

#======================================
def dataload(input_var, target_var):
    torch_dataset = Data.TensorDataset(input_var, target_var)

    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=True,               
        num_workers=2,       
        drop_last=True
    )
    
    return loader

In [5]:
class Entity_Typing(nn.Module):
    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim1, hidden_dim2, \
                 label_embed_dim):
        super(Entity_Typing, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim1 = hidden_dim1
        self.hidden_dim2 = hidden_dim2
        self.label_embed_dim = label_embed_dim
        self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.bilstm = nn.LSTM(embedding_dim, hidden_dim1 // 2,
                            num_layers=2, bidirectional=True, batch_first=True)
        
        self.dense = nn.Linear(hidden_dim1, DENSE_OUT)
        
        self.lstm = nn.LSTM(DENSE_OUT+label_embed_dim, hidden_dim2, batch_first=True)

        # Maps the output of the LSTM into tag space.
        self.hidden2tag = nn.Linear(hidden_dim2, self.tagset_size)
        self.softmax = nn.LogSoftmax(dim=2)
        self.label_embed = nn.Linear(self.tagset_size, self.label_embed_dim)
        
#         self.hidden1 = self.init_hidden1()
#         self.hidden2 = self.init_hidden2()
#         self.to_label_embed = self.init_label_embed()
        
        
    def init_hidden1(self):       
        hidden = torch.randn(2*2, BATCH_SIZE, self.hidden_dim1 // 2)   
#         hidden = Variable(hidden.data, requires_grad=True)

        return (hidden.cuda(), hidden.cuda())if USE_CUDA else (hidden,hidden)
    
    def init_hidden2(self):       
        hidden = torch.randn(1, BATCH_SIZE, self.hidden_dim2)        
#         hidden = Variable(hidden.data, requires_grad=True)

        return (hidden.cuda(), hidden.cuda())if USE_CUDA else (hidden,hidden)
    
    def init_label_embed(self):
        hidden = torch.randn(BATCH_SIZE, MAX_LEN, self.label_embed_dim)
        return hidden.cuda()if USE_CUDA else hidden
        
    def forward(self, sentence):
        self.hidden1 = self.init_hidden1()
        self.hidden2 = self.init_hidden2()
        self.to_label_embed = self.init_label_embed()
        
        embeds = self.word_embeds(sentence)
        bilstm_out, self.hidden1 = self.bilstm(embeds, self.hidden1)
        dense_out = self.dense(bilstm_out)
#         combine_lstm = torch.cat((dense_out, to_label_embed), 2)
        combine_lstm = torch.cat((dense_out, self.to_label_embed), 2)  
        lstm_out, self.hidden2 = self.lstm(combine_lstm, self.hidden2)  
        to_tags = self.hidden2tag(lstm_out)
        output = self.softmax(to_tags)
        self.to_label_embed = self.label_embed(output)
        
        '''NLLLoss input: Input: (N,C) where C = number of classes'''
        return output.view(BATCH_SIZE*MAX_LEN, self.tagset_size)

In [6]:
content = readfile(train_data)
word_list, tag_list = split_to_list(content)
word_to_ix = word2index(word_list)
reserved_index = filter_len(word_list)
filter_word, filter_tag = filter_sentence(reserved_index, word_list, tag_list)
input_padded, target_padded = pad_all(filter_word, filter_tag)
#================================================
input_var = prepare_all(input_padded, word_to_ix)
target_var = prepare_all(target_padded, tag_to_ix)
#================================================
vocab_size = len(word_to_ix)

In [7]:
loader = dataload(input_var, target_var)
model = Entity_Typing(vocab_size, tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM1, HIDDEN_DIM2, \
              LABEL_EMBED_DIM).cuda()
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
criterion = nn.NLLLoss()

In [8]:
import time
import math

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

In [9]:
n_iters = 10
print_every = 10
all_losses = []
total_loss = 0 # Reset every plot_every iters

start = time.time()

for epoch in tqdm(range(100)):  
    for step, (batch_x, batch_y) in enumerate(loader):
        optimizer.zero_grad()
        output = model(batch_x.cuda() if USE_CUDA else batch_x)
        batch_y = batch_y.view(BATCH_SIZE*MAX_LEN)
        loss = criterion(output, batch_y.cuda() if USE_CUDA else batch_y)
        loss.backward()
#         loss.backward(retain_graph=True)
        optimizer.step()
        
        #if step % print_every == 0:
        #    print('%.4f| epoch: %d| step: %d| %s' % (loss, epoch, step, timeSince(start)))
    print("epoch: %d | loss %.4f" % (epoch,loss))

  1%|          | 1/100 [00:03<05:04,  3.08s/it]

epoch: 0 | loss 0.0998


  2%|▏         | 2/100 [00:06<05:05,  3.12s/it]

epoch: 1 | loss 0.0719


  3%|▎         | 3/100 [00:09<05:01,  3.11s/it]

epoch: 2 | loss 0.0420


  4%|▍         | 4/100 [00:12<04:54,  3.07s/it]

epoch: 3 | loss 0.0294


  5%|▌         | 5/100 [00:15<04:49,  3.05s/it]

epoch: 4 | loss 0.0220


  6%|▌         | 6/100 [00:18<04:45,  3.04s/it]

epoch: 5 | loss 0.0210


  7%|▋         | 7/100 [00:21<04:41,  3.02s/it]

epoch: 6 | loss 0.0172


  8%|▊         | 8/100 [00:24<04:36,  3.01s/it]

epoch: 7 | loss 0.0175


  9%|▉         | 9/100 [00:26<04:31,  2.98s/it]

epoch: 8 | loss 0.0195


 10%|█         | 10/100 [00:29<04:26,  2.96s/it]

epoch: 9 | loss 0.0175


 11%|█         | 11/100 [00:32<04:20,  2.93s/it]

epoch: 10 | loss 0.0155


 12%|█▏        | 12/100 [00:34<04:15,  2.91s/it]

epoch: 11 | loss 0.0159


 13%|█▎        | 13/100 [00:37<04:10,  2.88s/it]

epoch: 12 | loss 0.0154


 14%|█▍        | 14/100 [00:40<04:09,  2.90s/it]

epoch: 13 | loss 0.0152


 15%|█▌        | 15/100 [00:43<04:05,  2.89s/it]

epoch: 14 | loss 0.0171


 16%|█▌        | 16/100 [00:46<04:03,  2.90s/it]

epoch: 15 | loss 0.0151


 17%|█▋        | 17/100 [00:49<04:00,  2.90s/it]

epoch: 16 | loss 0.0147


 18%|█▊        | 18/100 [00:52<03:58,  2.91s/it]

epoch: 17 | loss 0.0138


 19%|█▉        | 19/100 [00:55<03:56,  2.92s/it]

epoch: 18 | loss 0.0149


 20%|██        | 20/100 [00:58<03:53,  2.92s/it]

epoch: 19 | loss 0.0159


 21%|██        | 21/100 [01:01<03:50,  2.92s/it]

epoch: 20 | loss 0.0157


 22%|██▏       | 22/100 [01:04<03:48,  2.92s/it]

epoch: 21 | loss 0.0167


 23%|██▎       | 23/100 [01:07<03:45,  2.93s/it]

epoch: 22 | loss 0.0149


 24%|██▍       | 24/100 [01:10<03:42,  2.93s/it]

epoch: 23 | loss 0.0147


 25%|██▌       | 25/100 [01:13<03:40,  2.94s/it]

epoch: 24 | loss 0.0160


 26%|██▌       | 26/100 [01:16<03:37,  2.94s/it]

epoch: 25 | loss 0.0148


 27%|██▋       | 27/100 [01:19<03:34,  2.94s/it]

epoch: 26 | loss 0.0146


 28%|██▊       | 28/100 [01:21<03:30,  2.92s/it]

epoch: 27 | loss 0.0170


 29%|██▉       | 29/100 [01:24<03:27,  2.92s/it]

epoch: 28 | loss 0.0139


 30%|███       | 30/100 [01:27<03:24,  2.92s/it]

epoch: 29 | loss 0.0166


 31%|███       | 31/100 [01:30<03:21,  2.93s/it]

epoch: 30 | loss 0.0157


 32%|███▏      | 32/100 [01:33<03:18,  2.93s/it]

epoch: 31 | loss 0.0127


 33%|███▎      | 33/100 [01:36<03:16,  2.93s/it]

epoch: 32 | loss 0.0155


 34%|███▍      | 34/100 [01:39<03:13,  2.93s/it]

epoch: 33 | loss 0.0138


 35%|███▌      | 35/100 [01:42<03:10,  2.93s/it]

epoch: 34 | loss 0.0148


 36%|███▌      | 36/100 [01:45<03:07,  2.93s/it]

epoch: 35 | loss 0.0148


 37%|███▋      | 37/100 [01:48<03:05,  2.94s/it]

epoch: 36 | loss 0.0132


 38%|███▊      | 38/100 [01:51<03:01,  2.93s/it]

epoch: 37 | loss 0.0152


 39%|███▉      | 39/100 [01:54<02:58,  2.93s/it]

epoch: 38 | loss 0.0144


 40%|████      | 40/100 [01:57<02:55,  2.93s/it]

epoch: 39 | loss 0.0148


 41%|████      | 41/100 [01:59<02:52,  2.92s/it]

epoch: 40 | loss 0.0149


 42%|████▏     | 42/100 [02:01<02:48,  2.90s/it]

epoch: 41 | loss 0.0144


 43%|████▎     | 43/100 [02:04<02:45,  2.90s/it]

epoch: 42 | loss 0.0160


 44%|████▍     | 44/100 [02:07<02:42,  2.90s/it]

epoch: 43 | loss 0.0147


 45%|████▌     | 45/100 [02:10<02:39,  2.90s/it]

epoch: 44 | loss 0.0134


 46%|████▌     | 46/100 [02:13<02:36,  2.90s/it]

epoch: 45 | loss 0.0150


 47%|████▋     | 47/100 [02:16<02:33,  2.90s/it]

epoch: 46 | loss 0.0141


 48%|████▊     | 48/100 [02:19<02:30,  2.90s/it]

epoch: 47 | loss 0.0151


 49%|████▉     | 49/100 [02:22<02:27,  2.90s/it]

epoch: 48 | loss 0.0143


 50%|█████     | 50/100 [02:24<02:24,  2.90s/it]

epoch: 49 | loss 0.0134


 51%|█████     | 51/100 [02:27<02:22,  2.90s/it]

epoch: 50 | loss 0.0143


 52%|█████▏    | 52/100 [02:30<02:19,  2.90s/it]

epoch: 51 | loss 0.0126


 53%|█████▎    | 53/100 [02:33<02:16,  2.90s/it]

epoch: 52 | loss 0.0132


 54%|█████▍    | 54/100 [02:36<02:13,  2.90s/it]

epoch: 53 | loss 0.0154


 55%|█████▌    | 55/100 [02:38<02:09,  2.89s/it]

epoch: 54 | loss 0.0160


 56%|█████▌    | 56/100 [02:42<02:07,  2.89s/it]

epoch: 55 | loss 0.0152


 57%|█████▋    | 57/100 [02:45<02:04,  2.90s/it]

epoch: 56 | loss 0.0143


 58%|█████▊    | 58/100 [02:47<02:01,  2.89s/it]

epoch: 57 | loss 0.0145


 59%|█████▉    | 59/100 [02:50<01:58,  2.89s/it]

epoch: 58 | loss 0.0140


 60%|██████    | 60/100 [02:53<01:55,  2.89s/it]

epoch: 59 | loss 0.0151


 61%|██████    | 61/100 [02:55<01:52,  2.88s/it]

epoch: 60 | loss 0.0131


 62%|██████▏   | 62/100 [02:57<01:48,  2.86s/it]

epoch: 61 | loss 0.0162


 63%|██████▎   | 63/100 [02:59<01:45,  2.85s/it]

epoch: 62 | loss 0.0129


 64%|██████▍   | 64/100 [03:02<01:42,  2.85s/it]

epoch: 63 | loss 0.0130


 65%|██████▌   | 65/100 [03:05<01:39,  2.85s/it]

epoch: 64 | loss 0.0152


 66%|██████▌   | 66/100 [03:07<01:36,  2.85s/it]

epoch: 65 | loss 0.0141


 67%|██████▋   | 67/100 [03:10<01:33,  2.85s/it]

epoch: 66 | loss 0.0149


 68%|██████▊   | 68/100 [03:13<01:31,  2.85s/it]

epoch: 67 | loss 0.0139


 69%|██████▉   | 69/100 [03:16<01:28,  2.85s/it]

epoch: 68 | loss 0.0142


 70%|███████   | 70/100 [03:19<01:25,  2.85s/it]

epoch: 69 | loss 0.0166


 71%|███████   | 71/100 [03:21<01:22,  2.84s/it]

epoch: 70 | loss 0.0154


 72%|███████▏  | 72/100 [03:23<01:19,  2.82s/it]

epoch: 71 | loss 0.0147


 73%|███████▎  | 73/100 [03:25<01:15,  2.81s/it]

epoch: 72 | loss 0.0131


 74%|███████▍  | 74/100 [03:27<01:12,  2.80s/it]

epoch: 73 | loss 0.0151


 75%|███████▌  | 75/100 [03:28<01:09,  2.79s/it]

epoch: 74 | loss 0.0134


 76%|███████▌  | 76/100 [03:31<01:06,  2.78s/it]

epoch: 75 | loss 0.0154


 77%|███████▋  | 77/100 [03:33<01:03,  2.77s/it]

epoch: 76 | loss 0.0156


 78%|███████▊  | 78/100 [03:36<01:00,  2.77s/it]

epoch: 77 | loss 0.0154


 79%|███████▉  | 79/100 [03:39<00:58,  2.77s/it]

epoch: 78 | loss 0.0143


 80%|████████  | 80/100 [03:40<00:55,  2.76s/it]

epoch: 79 | loss 0.0134


 81%|████████  | 81/100 [03:42<00:52,  2.75s/it]

epoch: 80 | loss 0.0154


 82%|████████▏ | 82/100 [03:44<00:49,  2.74s/it]

epoch: 81 | loss 0.0133


 83%|████████▎ | 83/100 [03:47<00:46,  2.74s/it]

epoch: 82 | loss 0.0151


 84%|████████▍ | 84/100 [03:50<00:43,  2.74s/it]

epoch: 83 | loss 0.0134


 85%|████████▌ | 85/100 [03:53<00:41,  2.75s/it]

epoch: 84 | loss 0.0135


 86%|████████▌ | 86/100 [03:56<00:38,  2.75s/it]

epoch: 85 | loss 0.0141


 87%|████████▋ | 87/100 [03:59<00:35,  2.75s/it]

epoch: 86 | loss 0.0143


 88%|████████▊ | 88/100 [04:02<00:33,  2.76s/it]

epoch: 87 | loss 0.0141


 89%|████████▉ | 89/100 [04:05<00:30,  2.76s/it]

epoch: 88 | loss 0.0135


 90%|█████████ | 90/100 [04:08<00:27,  2.76s/it]

epoch: 89 | loss 0.0140


 91%|█████████ | 91/100 [04:11<00:24,  2.76s/it]

epoch: 90 | loss 0.0126


 92%|█████████▏| 92/100 [04:14<00:22,  2.76s/it]

epoch: 91 | loss 0.0141


 93%|█████████▎| 93/100 [04:17<00:19,  2.76s/it]

epoch: 92 | loss 0.0132


 94%|█████████▍| 94/100 [04:20<00:16,  2.77s/it]

epoch: 93 | loss 0.0159


 95%|█████████▌| 95/100 [04:23<00:13,  2.77s/it]

epoch: 94 | loss 0.0134


 96%|█████████▌| 96/100 [04:25<00:11,  2.77s/it]

epoch: 95 | loss 0.0132


 97%|█████████▋| 97/100 [04:28<00:08,  2.77s/it]

epoch: 96 | loss 0.0148


 98%|█████████▊| 98/100 [04:31<00:05,  2.77s/it]

epoch: 97 | loss 0.0149


 99%|█████████▉| 99/100 [04:33<00:02,  2.76s/it]

epoch: 98 | loss 0.0158


100%|██████████| 100/100 [04:35<00:00,  2.75s/it]

epoch: 99 | loss 0.0138





In [15]:
test_content = readfile(test_data)
word_list_test, tag_list_test = split_to_list(test_content)

In [51]:
# for one input
def easy_pad(easy_sent, easy_tar):
    easy_sent += [PAD_TAG for i in range(MAX_LEN-len(easy_sent))]
    easy_tar += [PAD_TAG for i in range(MAX_LEN-len(easy_tar))]
    
    return easy_sent, easy_tar

def easy_test(_input):
    _input = torch.unsqueeze(_input, 0).expand(128,100)
    return _input

def easy_output(output):
    output = output.view(128,100,6)[0].argmax(1)
    return output

In [52]:
easy_sent, easy_tar = easy_pad(word_list_test[3],tag_list_test[3])

In [53]:
input_test = prepare_sequence(easy_sent, word_to_ix)
target_test = prepare_sequence(easy_tar, tag_to_ix)

_input = easy_test(input_test)

In [55]:
# Check predictions after training
with torch.no_grad():
    output = model(_input.cuda() if USE_CUDA else _input)
    output = easy_output(output)
    
    print('predict :', output)
    print('true :', target_test)
    

predict : tensor([ 5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  3,  4,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2], device='cuda:0')
true : tensor([ 5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  3,  4,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2])


In [None]:
# for batch input