In [1]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

import numpy as np

In [2]:
USE_CUDA = True

In [3]:
root = '/notebooks/sinica/dataset/'
train_data = root+'facial.train'
dev_data = root+'facial.dev'
test_data = root+'facial.test'

START_TAG = "<START>"
STOP_TAG = "<STOP>"
PAD_TAG = "<PAD>"
tag_to_ix = {START_TAG: 0, STOP_TAG: 1, PAD_TAG:2, "B-Func": 3, "I-Func": 4, "O": 5}

tagset_size = len(tag_to_ix)
MAX_LEN = 100
BATCH_SIZE = 128

EMBEDDING_DIM = 20
HIDDEN_DIM = 10

In [4]:
def readfile(data):
    with open(data, "r", encoding="utf-8") as f:
        content = f.read().splitlines()
        
    return content

def get_word_and_label(_content, start_w, end_w):
    word_list = []
    tag_list = []
    for word_set in _content[start_w:end_w]:
        word_list.append(word_set[0])
        tag_list.append(word_set[2:])
    
    return word_list, tag_list

def split_to_list(content):
    init = 0
    word_list = []
    tag_list = []

    for i, c in enumerate(content):
        if c=='':
            words, tags = get_word_and_label(content, init, i)
            init = i+1
            word_list.append(words)
            tag_list.append(tags)
            
    return word_list, tag_list
    
def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

def prepare_all(seqs, to_ix):
    seq_list = []
    for i in range(len(seqs)):
        seq_list.append(prepare_sequence(seqs[i], to_ix))
        
    seq_list = torch.stack(seq_list)
    
#     if USE_CUDA:
#         seq_list = seq_list.cuda()
        
    return seq_list

def word2index(word_list):
    word_to_ix = {"<START>":0, "<STOP>":1, "<PAD>":2}
    for sentence in word_list:
        for word in sentence:
            if word not in word_to_ix:
                word_to_ix[word] = len(word_to_ix)
                
    return word_to_ix

def find_max_len(word_list):
    max_len = 0
    for i in range(len(word_list)):
        if max_len<len(word_list[i]):
            max_len=len(word_list[i])
            
    return max_len

def filter_len(word_list):
    reserved_index = []
    for i in range(len(word_list)):
        if len(word_list[i])<MAX_LEN:
            reserved_index.append(i)
            
    return reserved_index

def filter_sentence(reserved_index, word_list, tag_list):
    filter_word = list(word_list[i] for i in reserved_index)
    filter_tag = list(tag_list[i] for i in reserved_index)
    return filter_word, filter_tag

def pad_seq(seq):
    seq += [PAD_TAG for i in range(MAX_LEN-len(seq))]
    return seq

def pad_all(filter_word, filter_tag):
    input_padded = [pad_seq(s) for s in filter_word]
    target_padded = [pad_seq(s) for s in filter_tag]
    
    return input_padded, target_padded

#======================================
def dataload(input_var, target_var):
    torch_dataset = Data.TensorDataset(input_var, target_var)

    loader = Data.DataLoader(
        dataset=torch_dataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=True,               
        num_workers=2,              
    )
    
    return loader

In [5]:
content = readfile(train_data)

In [6]:
word_list, tag_list = split_to_list(content)

In [7]:
word_to_ix = word2index(word_list)

In [8]:
prepare_sequence(word_list[0], word_to_ix)

tensor([  3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,  14,
         15,  16,  17,  18,  19,   3,   4,   5,   6,  20,   3,  21,
         22,  23,  24,  25,  26,  27,  28,   7,  29])

In [9]:
max_len = find_max_len(word_list)

In [10]:
reserved_index = filter_len(word_list)

In [11]:
filter_word, filter_tag = filter_sentence(reserved_index, word_list, tag_list)

In [12]:
input_padded, target_padded = pad_all(filter_word, filter_tag)

In [13]:
input_var = prepare_all(input_padded, word_to_ix)
target_var = prepare_all(target_padded, tag_to_ix)

In [14]:
loader = dataload(input_var, target_var)

In [15]:
for epoch in range(10):  
    for step, (batch_x, batch_y) in enumerate(loader):
        print(batch_x)
        print(batch_x.size())
        break

tensor([[  161,   162,   165,  ...,     2,     2,     2],
        [   15,   114,    21,  ...,     2,     2,     2],
        [  203,   165,   559,  ...,     2,     2,     2],
        ...,
        [  301,     9,    22,  ...,     2,     2,     2],
        [  301,     9,   633,  ...,     2,     2,     2],
        [    5,    35,   203,  ...,     2,     2,     2]])
torch.Size([128, 100])
tensor([[  116,   116,    15,  ...,     2,     2,     2],
        [    5,    35,   203,  ...,     2,     2,     2],
        [  727,    62,    49,  ...,     2,     2,     2],
        ...,
        [   13,    14,    45,  ...,     2,     2,     2],
        [  161,    56,    23,  ...,     2,     2,     2],
        [   12,   259,     4,  ...,     2,     2,     2]])
torch.Size([128, 100])
tensor([[  406,   232,   373,  ...,     2,     2,     2],
        [  281,   161,   162,  ...,     2,     2,     2],
        [   43,   135,   282,  ...,     2,     2,     2],
        ...,
        [  244,   259,   165,  ...,     2, 