In [1]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import time
import os
import math
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
#plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

In [2]:
# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def __save_model(model_name, model, root):
    """Write ``model.state_dict()`` to ``<root>/<model_name>-params.pkl``.

    Args:
        model_name: label used to build the file name.
        model: a module exposing ``state_dict()``.
        root: target directory; created (including any missing parents)
            if it does not exist yet.

    Returns:
        The path of the written file.
    """
    # makedirs(exist_ok=True) rather than isdir()+mkdir(): it also works for
    # nested paths such as 'runs/exp1' and cannot race with another process
    # creating the same directory between the check and the create.
    os.makedirs(root, exist_ok=True)
    p = os.path.join(root, '{}-params.pkl'.format(model_name))
    torch.save(model.state_dict(), p)
    return p

def save_model(models, root='./model'):
    """Save every model of a ``{name: model}`` mapping under ``root``.

    Returns:
        dict mapping each name to the file path its parameters were saved to.
    """
    return {name: __save_model(name, model, root)
            for name, model in models.items()}

def __load_model(model_name, model, root):
    """Load ``<root>/<model_name>-params.pkl`` into ``model`` in place.

    Args:
        model_name: label used to build the file name.
        model: a module exposing ``load_state_dict()``.
        root: directory holding the parameter files.

    Raises:
        AttributeError: if the parameter file does not exist (exception type
            kept as-is for existing callers).
    """
    p = os.path.join(root, '{}-params.pkl'.format(model_name))
    if not os.path.isfile(p):
        raise AttributeError(
            "No model parameters file for {}!".format(model_name)
        )
    # Deserialize onto the CPU: a checkpoint written from a CUDA model would
    # otherwise fail to load on a machine without CUDA. load_state_dict then
    # copies the tensors onto whatever device the model's parameters are on.
    paras = torch.load(p, map_location='cpu')
    model.load_state_dict(paras)

def load_model(models, root='./model'):
    """Restore every model of a ``{name: model}`` mapping from files under ``root``."""
    for k, m in models.items():
        __load_model(k, m, root)
        
def save_model_by_score(models, loss, root):
    """Checkpoint ``models`` under ``root`` unless a strictly better loss is on record.

    The best loss seen so far is kept in ``<root>/state.pkl``. If that file
    exists and stores a loss lower than ``loss``, nothing is written.
    Otherwise all models are saved through ``save_model`` and ``state.pkl``
    is rewritten with ``loss``. An equal loss therefore overwrites the
    previous checkpoint.
    """
    state_path = os.path.join(root, 'state.pkl')
    best = torch.load(state_path) if os.path.isfile(state_path) else None

    if best is None or not best['loss'] < loss:
        save_model(models, root)
        torch.save({'loss': loss}, state_path)

In [4]:
class CharDict:
    """Character vocabulary: 'a'..'z' get indices 0-25, then SOS (26) and EOS (27)."""

    def __init__(self):
        self.word2index = {}
        self.index2word = {}
        self.n_words = 0

        letters = [chr(code) for code in range(ord('a'), ord('a') + 26)]
        for token in letters + ["SOS", "EOS"]:
            self.addWord(token)

    def addWord(self, word):
        """Give ``word`` the next free index; does nothing if it is already known."""
        if word in self.word2index:
            return
        idx = self.n_words
        self.word2index[word] = idx
        self.index2word[idx] = word
        self.n_words = idx + 1

    def longtensorFromString(self, s):
        """Encode ``s`` character by character, wrapped in SOS ... EOS, as a 1-D LongTensor."""
        tokens = ["SOS"]
        tokens.extend(s)
        tokens.append("EOS")
        return torch.LongTensor([self.word2index[tok] for tok in tokens])

    def stringFromLongtensor(self, l):
        """Decode index tensor ``l`` to text, skipping multi-character tokens (SOS/EOS)."""
        symbols = (self.index2word[t.item()] for t in l)
        return "".join(sym for sym in symbols if len(sym) <= 1)

class wordsDataset(Dataset):
    """Tense-conversion word dataset read from ``train.txt`` / ``test.txt``.

    Training mode: the file is read as whitespace-separated words and
    flattened; item ``i`` gets tense label ``i % 4``. Presumably each row
    lists one verb in the four tenses of ``self.tenses``, in that order
    (TODO confirm against the data file). Items are
    ``(encoded_word, tense_index)``.

    Test mode: each row holds an ``input target`` word pair. The tense pair
    of row ``i`` comes from the fixed ``self.targets`` table. Items are
    ``(encoded_input, input_tense, encoded_target, target_tense)``.
    """
    def __init__(self, train=True, root='.'):
        """
        Args:
            train: read ``train.txt`` when True, ``test.txt`` otherwise.
            root: directory holding the data files. The default ``'.'``
                matches the previously hard-coded ``./train.txt`` paths.
        """
        f = os.path.join(root, 'train.txt' if train else 'test.txt')
        # `np.str` was only an alias of the builtin `str` and was removed in
        # NumPy 1.24. ndmin=2 keeps a single-row file 2-D, so the
        # `[index, 0]` / `[index, 1]` indexing in test mode still works.
        self.datas = np.loadtxt(f, dtype=str, ndmin=2)

        if train:
            self.datas = self.datas.reshape(-1)
        else:
            # Tense pair (input -> target) of each test row, using the index
            # order of self.tenses (sp=0, tp=1, pg=2, p=3):
            #   sp -> p,  sp -> pg, sp -> tp, sp -> tp, p  -> tp,
            #   sp -> pg, p  -> sp, pg -> sp, pg -> p,  pg -> tp
            self.targets = np.array([
                [0, 3],
                [0, 2],
                [0, 1],
                [0, 1],
                [3, 1],
                [0, 2],
                [3, 0],
                [2, 0],
                [2, 3],
                [2, 1],
            ])

        self.tenses = [
            'simple-present',
            'third-person',
            'present-progressive',
            'simple-past'
        ]
        self.chardict = CharDict()

        self.train = train

    def __len__(self):
        """Number of words (train) or word pairs (test)."""
        return len(self.datas)

    def __getitem__(self, index):
        """Return ``(word, tense)`` in train mode, ``(in, in_tense, out, out_tense)`` in test mode."""
        if self.train:
            c = index % len(self.tenses)
            return self.chardict.longtensorFromString(self.datas[index]), c
        else:
            i = self.chardict.longtensorFromString(self.datas[index, 0])
            ci = self.targets[index, 0]
            o = self.chardict.longtensorFromString(self.datas[index, 1])
            co = self.targets[index, 1]

            return i, ci, o, co

In [5]:
#Encoder
class EncoderRNN(nn.Module):
    """GRU encoder of the conditional VAE.

    The tense condition is embedded and concatenated onto the initial hidden
    state. The final GRU hidden state is projected to a mean and a
    log-variance, and ``z`` is drawn with the reparameterisation trick
    ``z = eps * exp(logvar / 2) + mean`` where ``eps`` comes from ``sample_z``
    (a standard normal).
    """
    def __init__(
        self, word_size, hidden_size, latent_size,
        num_condition, condition_size
    ):
        super(EncoderRNN, self).__init__()
        self.word_size = word_size
        self.hidden_size = hidden_size
        self.condition_size = condition_size
        self.latent_size = latent_size

        self.condition_embedding = nn.Embedding(num_condition, condition_size)
        self.word_embedding = nn.Embedding(word_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.mean = nn.Linear(hidden_size, latent_size)
        self.logvar = nn.Linear(hidden_size, latent_size)

    def forward(self, inputs, init_hidden, input_condition, return_stats=False):
        """Encode a 1-D index sequence into a latent sample.

        Args:
            inputs: LongTensor of character indices.
            init_hidden: tensor of shape (1, 1, hidden_size - condition_size),
                e.g. from ``initHidden()``.
            input_condition: tense index.
            return_stats: when True, also return the Gaussian parameters so
                the caller can add the KL term. For a standard-normal prior
                that term is ``-0.5 * sum(1 + logvar - mean**2 - exp(logvar))``.

        Returns:
            ``z`` of shape (1, 1, latent_size), or ``(z, mean, logvar)`` when
            ``return_stats`` is True. Default behaviour is unchanged.
        """
        c = self.condition(input_condition)

        # get (1, 1, hidden_size): zero state + condition embedding
        hidden = torch.cat((init_hidden, c), dim=2)

        # get (seq, 1, hidden_size)
        x = self.word_embedding(inputs).view(-1, 1, self.hidden_size)

        # get (seq, 1, hidden_size), (1, 1, hidden_size)
        outputs, hidden = self.gru(x, hidden)

        # get (1, 1, latent_size)
        m = self.mean(hidden)
        logvar = self.logvar(hidden)

        # reparameterisation trick: keeps the sample differentiable w.r.t. m, logvar
        z = self.sample_z() * torch.exp(logvar/2) + m

        if return_stats:
            return z, m, logvar
        return z

    def initHidden(self):
        """Zero hidden state, leaving room for the condition embedding."""
        return torch.zeros(
            1, 1, self.hidden_size - self.condition_size,
            device=device
        )

    def condition(self, c):
        """Embed tense index ``c`` as a (1, 1, condition_size) tensor."""
        c = torch.LongTensor([c]).to(device)
        return self.condition_embedding(c).view(1,1,-1)

    def sample_z(self):
        """Draw ``latent_size`` standard-normal values (on CPU, then moved to ``device``)."""
        return torch.normal(
            torch.FloatTensor([0]*self.latent_size),
            torch.FloatTensor([1]*self.latent_size)
        ).to(device)

# TODO Teacher forcing

In [6]:
#Decoder
class DecoderRNN(nn.Module):
    """GRU decoder mapping a latent code plus condition back to per-step character logits."""

    def __init__(self, word_size, hidden_size, latent_size, condition_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.word_size = word_size

        # Submodule names are the state_dict keys used by saved checkpoints,
        # and their creation order fixes the seeded weight initialisation.
        self.latent_to_hidden = nn.Linear(latent_size + condition_size, hidden_size)
        self.word_embedding = nn.Embedding(word_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, word_size)

    def forward(self, inputs, z, c):
        """Run the GRU over ``inputs`` starting from a state built from ``z`` and ``c``.

        ``z`` and ``c`` are concatenated on their last dim (expected shape
        (1, 1, ·)) and projected to the initial hidden state. Returns logits
        of shape (len(inputs), word_size).
        """
        h0 = self.latent_to_hidden(torch.cat((z, c), dim=2))
        embedded = self.word_embedding(inputs).view(-1, 1, self.hidden_size)
        states, _ = self.gru(embedded, h0)
        return self.out(states).view(-1, self.word_size)

In [7]:
# config

# Datasets (read ./train.txt and ./test.txt).
train_dataset = wordsDataset()
test_dataset = wordsDataset(False)

# Sizes derived from the vocabulary and the tense list.
word_size = train_dataset.chardict.n_words   # 26 letters + SOS + EOS
num_condition = len(train_dataset.tenses)    # one condition embedding per tense

# Model dimensions.
hidden_size, latent_size, condition_size = 256, 32, 8

# Training hyper-parameters.
# NOTE(review): in the cells shown here none of these four is read:
# trainEpochs is called without learning_rate (so it uses its 1e-2 default),
# and teacher forcing / empty input / KL weighting are not implemented yet.
teacher_forcing_ratio = 0.5
empty_input_ratio = 0.1
KLD_weight = 0.0
LR = 0.05

In [8]:
# Build both halves of the CVAE on the selected device. The encoder is
# created first, so the seeded weight-initialisation order is unchanged.
encoder = EncoderRNN(
    word_size, hidden_size, latent_size,
    num_condition, condition_size,
).to(device)
decoder = DecoderRNN(
    word_size, hidden_size, latent_size,
    condition_size,
).to(device)
# Last expression: display both module summaries.
encoder, decoder

(EncoderRNN(
   (condition_embedding): Embedding(4, 8)
   (word_embedding): Embedding(28, 256)
   (gru): GRU(256, 256)
   (mean): Linear(in_features=256, out_features=32, bias=True)
   (logvar): Linear(in_features=256, out_features=32, bias=True)
 ), DecoderRNN(
   (latent_to_hidden): Linear(in_features=40, out_features=256, bias=True)
   (word_embedding): Embedding(28, 256)
   (gru): GRU(256, 256)
   (out): Linear(in_features=256, out_features=28, bias=True)
 ))

# TODO KL weights

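Training objective: $\mathcal{L} = \mathrm{CE} + w_{KL} \cdot D_{KL}\big(q(Z \mid X, c; \theta') \,\|\, p(Z \mid c)\big)$, where $\mathrm{CE}$ is the reconstruction cross-entropy and $w_{KL}$ is the KL weight.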

In [9]:
def asMinutes(s):
    """Format ``s`` seconds as right-aligned ``'MMMMm SSs'`` (seconds truncated)."""
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60
    return '{:4d}m {:2d}s'.format(int(minutes), int(seconds))


def timeSince(since, percent):
    """Return ``'elapsed (- remaining)'`` for a job started at ``since``.

    ``percent`` is the fraction of work already done (trainEpochs passes
    ``(epoch+1) / epoch_size``). The remaining time is extrapolated linearly
    from the elapsed time, so ``percent`` must be non-zero.
    """
    elapsed = time.time() - since
    remaining = elapsed / percent - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))

def trainEpochs(
    name, encoder, decoder, epoch_size, learning_rate=1e-2,
    show_size=1000,
):
    """Train the encoder/decoder pair on ``train_dataset`` for ``epoch_size`` epochs.

    For every word the encoder sees the bare characters (SOS/EOS stripped)
    under its tense condition. The decoder gets the SOS-prefixed word as
    input and is scored with summed cross-entropy against the word shifted by
    one step (ending in EOS). After each epoch the epoch's total loss is
    recorded and passed to ``save_model_by_score`` with directory
    ``./<name>``, which only overwrites the checkpoint when the loss is not
    worse than the best recorded one.

    Args:
        name: checkpoint sub-directory, relative to the current directory.
        encoder, decoder: models to optimise (updated in place).
        epoch_size: number of passes over ``train_dataset``.
        learning_rate: SGD step size for both optimisers.
        show_size: print progress every ``show_size`` epochs.

    Returns:
        list with the summed training loss of each epoch.
    """
    started_at = time.time()
    history = []
    window_loss = 0
    epoch_loss = 0

    enc_opt = optim.SGD(encoder.parameters(), lr=learning_rate)
    dec_opt = optim.SGD(decoder.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss(reduction='sum')

    for epoch in range(epoch_size):
        for idx in range(len(train_dataset)):
            word, tense = train_dataset[idx]

            enc_opt.zero_grad()
            dec_opt.zero_grad()

            # Encoder input: the word without SOS and EOS.
            z = encoder(word[1:-1].to(device), encoder.initHidden(), tense)
            # Decoder input keeps SOS and drops EOS; the target drops SOS.
            logits = decoder(word[:-1].to(device), z, encoder.condition(tense))
            step_loss = criterion(logits, word[1:].to(device))

            step_loss.backward()
            enc_opt.step()
            dec_opt.step()

            value = step_loss.item()
            window_loss += value
            epoch_loss += value

        done = epoch + 1
        if done % show_size == 0:
            # average summed loss per epoch over the reporting window
            window_loss /= show_size
            print("{} ({} {}%) {:.4f}".format(
                timeSince(started_at, done / epoch_size),
                done, done * 100 / epoch_size, window_loss
            ))
            window_loss = 0

        history.append(epoch_loss)
        save_model_by_score(
            {'encoder': encoder, 'decoder': decoder},
            epoch_loss,
            os.path.join('.', name)
        )
        epoch_loss = 0

    return history

In [10]:
%%time
# Train for 10 epochs, reporting every epoch; checkpoints go to ./result.
# NOTE(review): learning_rate is not passed, so trainEpochs uses its default
# of 1e-2 rather than the LR = 0.05 set in the config cell. Confirm which
# value is intended before re-running.
trainEpochs('result', encoder, decoder, 10, show_size=1)

   0m  9s (-    1m 27s) (1 10.0%) 47180.8536
   0m 19s (-    1m 16s) (2 20.0%) 27094.0264
   0m 28s (-    1m  6s) (3 30.0%) 21986.9741
   0m 38s (-    0m 57s) (4 40.0%) 19374.4651
   0m 47s (-    0m 47s) (5 50.0%) 18304.5404
   0m 56s (-    0m 37s) (6 60.0%) 18668.5839
   1m  6s (-    0m 28s) (7 70.0%) 17745.1296
   1m 15s (-    0m 18s) (8 80.0%) 18474.8561
   1m 25s (-    0m  9s) (9 90.0%) 18910.2260
   1m 34s (-    0m  0s) (10 100.0%) 19746.3051
CPU times: user 1min 32s, sys: 1.77 s, total: 1min 34s
Wall time: 1min 34s


[47180.85358595848,
 27094.026449918747,
 21986.97412443161,
 19374.4650554657,
 18304.54042696953,
 18668.583859443665,
 17745.12957930565,
 18474.856108903885,
 18910.226037979126,
 19746.3050968647]

In [11]:
#compute BLEU-4 score
def compute_bleu(output, reference):
    """Smoothed BLEU-4 of ``output`` against the single ``reference``.

    Uses equal 1- to 4-gram weights and NLTK's ``method1`` smoothing, so
    short outputs with no higher-order n-gram matches do not score exactly 0.
    Presumably called with plain strings, in which case NLTK treats each
    character as a token.
    """
    smoother = SmoothingFunction().method1
    return sentence_bleu(
        [reference],
        output,
        weights=(0.25,) * 4,
        smoothing_function=smoother,
    )