In [1]:
import torch
import torch.nn as nn
import torch.nn.init as init
from torch import optim
import numpy as np

import random
import os

# Global RNG seed; train_and_test() re-seeds `random`, numpy and torch with it.
seed = 10

In [2]:
TEST_FILE = "data/test.txt"
TRAIN_FILE = "data/train.txt"
F_VOCAB_FILE = "data/vocab.f.txt"
Q_VOCAB_FILE = "data/vocab.q.txt"

# Every artefact of this experiment lives under one run directory.
OUT_DIR = "out/Attention/WithGradientClippingAndLearningRateDecay"
CHECKPOINT_FILE = OUT_DIR + "/checkpoints/model_seq2seq_attention.pt"
LOSS_DIR = OUT_DIR + "/losses"
METRICS_DIR = OUT_DIR + "/metrics"
ATTENTION_ALIGNMENTS_DIR = OUT_DIR + "/attention_alignments"

# exist_ok=True makes this idempotent (safe to re-run the cell) and avoids the
# check-then-create race of testing os.path.exists first.
os.makedirs(os.path.dirname(CHECKPOINT_FILE), exist_ok=True)
os.makedirs(LOSS_DIR, exist_ok=True)
os.makedirs(METRICS_DIR, exist_ok=True)
os.makedirs(ATTENTION_ALIGNMENTS_DIR, exist_ok=True)

In [3]:
class Options:
    """Hyper-parameters shared by the encoder, decoder, attention and training loop.

    Every value is a class attribute; instances carry no state of their own.
    """

    # --- architecture ---
    rnn_size = 50                   # embedding size and LSTM hidden size
    init_weight = 0.08              # parameters are re-initialised uniformly in [-init_weight, init_weight]
    dropout = 0                     # dropout on embeddings and on the attention output (0 disables it)
    dropoutrec = 0                  # dropout on the LSTM candidate cell values (0 disables it)

    # --- optimisation ---
    learning_rate = 0.01
    decay_rate = 0.985              # passed to RMSprop as its smoothing constant `alpha`
    grad_clip = 5                   # element-wise gradient clip value; -1 disables clipping
    learning_rate_decay = 0.985     # LR multiplier applied per epoch; values >= 1 disable decay
    learning_rate_decay_after = 5   # decay is applied at the end of every 0-based epoch index >= this

    # --- logging (counted in training examples) ---
    plot_every = 10
    print_every = 200

In [4]:
class Language:
    """Token-to-index vocabularies for questions and logical forms.

    Attributes:
        w2i:  question token -> index
        lf2i: logical-form token -> index
        i2lf: index -> logical-form token (inverse of lf2i)
    All vocabularies reserve 0 for "<s>", 1 for "</s>" and 2 for "UNK".
    """

    def __init__(self, question_vocab, form_vocab):
        self.w2i = Language.__read_vocab(question_vocab)
        self.lf2i = Language.__read_vocab(form_vocab)
        self.i2lf = {index: token for token, index in self.lf2i.items()}

    @staticmethod
    def __read_vocab(filename):
        """Read a vocabulary file and return a token -> index dict.

        The first whitespace-separated field of each line is the token; any further
        fields are ignored. Blank lines are skipped. Tokens already present (including
        the three reserved ones) keep their first index.
        """
        t2i = {"<s>": 0, "</s>": 1, "UNK": 2}
        with open(filename, encoding="utf-8") as target:
            for line in target:
                fields = line.split()
                if not fields:
                    # Blank or whitespace-only line: indexing fields[0] would raise IndexError.
                    continue
                # len(t2i) is evaluated before insertion, so new tokens get the next free index.
                t2i.setdefault(fields[0], len(t2i))
        return t2i

In [5]:
class Entry:
    """One (question, logical form) example together with its index tensors.

    Attributes:
        sentence, form: the token lists as given.
        sentence_tensor, form_tensor: LongTensors of shape (1, len + 2) holding
            <s>, the token indices (unknown tokens -> "UNK"), and </s>.
        prediction: predicted logical-form tokens, filled in by the evaluation
            loops; None until then.
        predicted_form: kept for backward compatibility; not assigned elsewhere
            in this notebook.
    """

    def __init__(self, sentence, form, language):
        self.sentence = sentence
        self.form = form
        self.sentence_tensor = Entry.__create_index_tensor(sentence, language.w2i)
        self.form_tensor = Entry.__create_index_tensor(form, language.lf2i)
        self.predicted_form = None
        # The evaluation code writes `entry.prediction`; initialise it so the
        # attribute exists (as None) before any prediction has been made.
        self.prediction = None

    @staticmethod
    def __create_index_tensor(sequence, dictionary):
        """Return a (1, len(sequence) + 2) LongTensor: <s>, token indices, </s>."""
        indices = [dictionary["<s>"]]
        indices.extend(
            dictionary[token] if token in dictionary else dictionary["UNK"]
            for token in sequence
        )
        indices.append(dictionary["</s>"])
        return torch.tensor([indices], dtype=torch.long)

In [6]:
class Data:
    """Train/test examples plus the shared Language vocabularies.

    Data files are expected to hold one example per line: the question tokens,
    a single tab, then the logical-form tokens (both whitespace-separated).
    """

    def __init__(self, train_file, test_file, question_vocab, form_vocab, shuffle_train=False):
        self.language = Language(question_vocab, form_vocab)
        self.train_entries = self.__read_data(train_file, self.language)
        self.test_entries = self.__read_data(test_file, self.language)
        if shuffle_train:
            # Uses the global `random` state, so seeding `random` first makes it reproducible.
            random.shuffle(self.train_entries)

    @staticmethod
    def __read_data(file, language):
        """Parse `file` into a list of Entry objects, skipping blank lines.

        Raises:
            ValueError: if a non-blank line does not have exactly one tab separator.
        """
        entries = []
        with open(file, encoding="utf-8") as target:
            for line_no, line in enumerate(target, 1):
                line = line.strip()
                if not line:
                    continue
                fields = line.split("\t")
                if len(fields) != 2:
                    raise ValueError(
                        "{}:{}: expected 'sentence<TAB>logical form', got {} tab-separated field(s)".format(
                            file, line_no, len(fields)))
                sentence, lf = fields
                entries.append(Entry(sentence.split(), lf.split(), language))
        return entries

In [7]:
class LSTM(nn.Module):
    """A single LSTM cell step built from two linear layers.

    Gate pre-activations are `i2h(x) + h2h(prev_h)`, split into four equal chunks
    in the order (input, forget, candidate, output). When `opt.dropoutrec > 0`,
    dropout is applied to the tanh candidate values before they enter the cell.
    """

    def __init__(self, opt):
        super(LSTM, self).__init__()
        self.opt = opt
        width = opt.rnn_size
        self.i2h = nn.Linear(width, 4 * width)
        self.h2h = nn.Linear(width, 4 * width)
        if opt.dropoutrec > 0:
            self.dropout = nn.Dropout(opt.dropoutrec)

    def forward(self, x, prev_c, prev_h):
        """Advance one step; returns (new_cell, new_hidden), each (batch, rnn_size)."""
        pre_activations = self.i2h(x) + self.h2h(prev_h)
        in_pre, forget_pre, cand_pre, out_pre = pre_activations.chunk(4, 1)

        candidate = torch.tanh(cand_pre)
        if self.opt.dropoutrec > 0:
            candidate = self.dropout(candidate)

        next_c = (torch.sigmoid(forget_pre) * prev_c) + (torch.sigmoid(in_pre) * candidate)
        next_h = torch.sigmoid(out_pre) * torch.tanh(next_c)
        return next_c, next_h


In [8]:
class RNN(nn.Module):
    """Token embedding followed by one LSTM step; used as both encoder and decoder.

    Each forward() call consumes one token index per batch row and returns the
    next (cell, hidden) pair.
    """

    def __init__(self, opt, input_size):
        super(RNN, self).__init__()
        self.opt = opt
        self.hidden_size = opt.rnn_size
        # Construction order (embedding, then lstm) determines both the RNG draws
        # made by the layers' default initialisers and the state_dict layout.
        self.embedding = nn.Embedding(input_size, self.hidden_size)
        self.lstm = LSTM(self.opt)
        if opt.dropout > 0:
            self.dropout = nn.Dropout(opt.dropout)
        self.__initParameters()

    def __initParameters(self):
        """Overwrite every trainable parameter with U(-init_weight, init_weight)."""
        bound = self.opt.init_weight
        trainable = (p for p in self.parameters() if p.requires_grad)
        for param in trainable:
            init.uniform_(param, -bound, bound)

    def forward(self, input_src, prev_c, prev_h):
        """Embed `input_src` and run one LSTM step from (prev_c, prev_h).

        In this notebook `input_src` is one column of the index tensor, i.e.
        shape (1,), so the embedding is (1, rnn_size).
        """
        embedded = self.embedding(input_src)
        if self.opt.dropout > 0:
            embedded = self.dropout(embedded)
        return self.lstm(embedded, prev_c, prev_h)

In [9]:
class Attention(nn.Module):
    """Dot-product attention over encoder states plus the output projection.

    Scores are dot products between each encoder state and the decoder state.
    Their softmax weights the encoder states into a context vector. Then
    tanh(linear_att([context; decoder_state])) is projected to log-probabilities
    over the output vocabulary.
    """

    def __init__(self, opt, output_size):
        super(Attention, self).__init__()
        self.opt = opt
        self.hidden_size = opt.rnn_size
        width = self.hidden_size

        self.linear_att = nn.Linear(2 * width, width)
        self.linear_out = nn.Linear(width, output_size)
        if opt.dropout > 0:
            self.dropout = nn.Dropout(opt.dropout)

        self.softmax = nn.Softmax(dim=1)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.__initParameters()

    def __initParameters(self):
        """Overwrite every trainable parameter with U(-init_weight, init_weight)."""
        bound = self.opt.init_weight
        trainable = (p for p in self.parameters() if p.requires_grad)
        for param in trainable:
            init.uniform_(param, -bound, bound)

    def forward(self, enc_s_top, dec_s_top):
        """Attend over `enc_s_top` with query `dec_s_top`.

        Args:
            enc_s_top: encoder states, (batch, src_len, rnn_size); batch is 1 here.
            dec_s_top: current decoder hidden state, (batch, rnn_size).

        Returns:
            (log_probs, weights): log_probs is (batch, output_size); weights is
            (batch, src_len, 1) and sums to 1 over src_len.
        """
        scores = torch.bmm(enc_s_top, dec_s_top.unsqueeze(2)).squeeze(2)
        weights = self.softmax(scores).unsqueeze(2)
        context = torch.bmm(enc_s_top.transpose(1, 2), weights).squeeze(2)

        combined = torch.tanh(self.linear_att(torch.cat((context, dec_s_top), 1)))
        if self.opt.dropout > 0:
            combined = self.dropout(combined)
        log_probs = self.logsoftmax(self.linear_out(combined))
        return log_probs, weights

In [24]:
class Model:
    """Bundles encoder, decoder and attention modules with their optimisers and loss.

    Call create() (fresh parameters) or load() (from a checkpoint) before use.
    """

    def __init__(self):
        self.opt = None
        self.encoder = None
        self.decoder = None
        self.attention = None
        self.optimizers = {}
        self.criterion = None

    def create(self, opt, data):
        """Build fresh modules sized by `data`'s vocabularies, one RMSprop per module, and the loss."""
        self.opt = opt
        # Construction order fixes the RNG draws used for initialisation; keep it.
        self.encoder = RNN(self.opt, len(data.language.w2i))
        self.decoder = RNN(self.opt, len(data.language.lf2i))
        self.attention = Attention(self.opt, len(data.language.lf2i))
        self.optimizers["encoder_optimizer"] = optim.RMSprop(self.encoder.parameters(), lr=self.opt.learning_rate, alpha=self.opt.decay_rate)
        self.optimizers["decoder_optimizer"] = optim.RMSprop(self.decoder.parameters(), lr=self.opt.learning_rate, alpha=self.opt.decay_rate)
        self.optimizers["attention_optimizer"] = optim.RMSprop(self.attention.parameters(), lr=self.opt.learning_rate, alpha=self.opt.decay_rate)
        # Index 0 is "<s>" in the Language vocabularies; targets equal to it are ignored.
        self.criterion = nn.NLLLoss(ignore_index=0)

    def train(self):
        """Put all three modules in training mode."""
        self.encoder.train()
        self.decoder.train()
        self.attention.train()

    def eval(self):
        """Put all three modules in evaluation mode."""
        self.encoder.eval()
        self.decoder.eval()
        self.attention.eval()

    def step(self):
        """Apply one update with every optimiser."""
        for optimizer in self.optimizers.values():
            optimizer.step()

    def zero_grad(self):
        """Clear the gradients tracked by every optimiser."""
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()

    def rate_decay(self):
        """Multiply every optimiser's learning rate by `opt.learning_rate_decay`."""
        # Must go through self.optimizers: there is no module-level `optimizers`.
        for optimizer in self.optimizers.values():
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * self.opt.learning_rate_decay

    def grad_clip(self):
        """Clamp every gradient element to [-opt.grad_clip, opt.grad_clip]."""
        for module in (self.encoder, self.decoder, self.attention):
            torch.nn.utils.clip_grad_value_(module.parameters(), self.opt.grad_clip)

    def save(self, filename):
        """Write opt, all module state_dicts and all optimiser state_dicts to `filename`."""
        checkpoint = {}
        checkpoint["opt"] = self.opt
        checkpoint["encoder"] = self.encoder.state_dict()
        checkpoint["decoder"] = self.decoder.state_dict()
        checkpoint["attention"] = self.attention.state_dict()
        for name, optimizer in self.optimizers.items():
            checkpoint[name] = optimizer.state_dict()
        torch.save(checkpoint, filename)

    def load(self, filename, data):
        """Rebuild the model from a checkpoint written by save(), using `data`'s vocabularies.

        The checkpoint is a full pickle (it contains the Options object), so only
        load files you trust.
        """
        try:
            # Newer torch defaults to weights_only=True, which rejects the pickled Options.
            checkpoint = torch.load(filename, weights_only=False)
        except TypeError:
            # Older torch (< 1.13) has no `weights_only` argument.
            checkpoint = torch.load(filename)

        opt = checkpoint["opt"]
        self.create(opt, data)

        self.encoder.load_state_dict(checkpoint["encoder"])
        self.decoder.load_state_dict(checkpoint["decoder"])
        self.attention.load_state_dict(checkpoint["attention"])
        for name, optimizer in self.optimizers.items():
            optimizer.load_state_dict(checkpoint[name])

In [11]:
def train(model, sentence, form):
    """Run one teacher-forced optimisation step on a single (sentence, form) pair.

    Args:
        model: a Model with encoder, decoder, attention, criterion and optimisers set up.
        sentence: source index tensor, assumed (1, src_len) as built by Entry.
        form: target index tensor, assumed (1, tgt_len) and starting with <s>.

    Returns:
        The sum of model.criterion over every target position, as a 0-dim tensor
        still attached to the autograd graph.
    """
    model.zero_grad()

    width = model.opt.rnn_size
    src_len = sentence.size(1)
    encoder_states = torch.zeros((1, src_len, width), dtype=torch.float, requires_grad=False)
    c = torch.zeros((1, width), dtype=torch.float, requires_grad=True)
    h = torch.zeros((1, width), dtype=torch.float, requires_grad=True)

    # Encode, remembering every hidden state for the attention layer.
    for pos in range(src_len):
        c, h = model.encoder(sentence[:, pos], c, h)
        encoder_states[:, pos, :] = h

    # Decode with teacher forcing: feed gold token t, score gold token t + 1.
    loss = 0
    for t in range(form.size(1) - 1):
        c, h = model.decoder(form[:, t], c, h)
        log_probs, _ = model.attention(encoder_states, h)
        loss += model.criterion(log_probs, form[:, t + 1])

    loss.backward()
    if model.opt.grad_clip != -1:
        model.grad_clip()
    model.step()
    return loss

In [12]:
def predict(model, sentence, lf2i, max_length=100):
    """Greedily decode a logical form for one sentence.

    Runs entirely under torch.no_grad(), so it is safe to call with or without an
    outer no_grad context.

    Args:
        model: a Model with encoder, decoder, attention and opt.rnn_size.
        sentence: source index tensor, assumed (1, src_len) as built by Entry.
        lf2i: logical-form vocabulary; must contain "<s>" and "</s>".
        max_length: maximum number of decoder steps (default 100).

    Returns:
        (predicted_form, attentions):
        - predicted_form: list of predicted token ids, excluding </s>.
        - attentions: tensor (steps, src_len), one row of attention weights per
          decoder step. steps == len(predicted_form) + 1; the extra row is the step
          that emitted </s>, or, when max_length is reached, the final step, whose
          token is not appended.

    Raises:
        ValueError: if max_length < 1.
    """
    if max_length < 1:
        raise ValueError("max_length must be at least 1, got {}".format(max_length))

    with torch.no_grad():
        width = model.opt.rnn_size
        src_len = sentence.size(1)
        c = torch.zeros((1, width), dtype=torch.float)
        h = torch.zeros((1, width), dtype=torch.float)
        # Plain (non-grad) buffer: writing into it in place is always allowed.
        outputs = torch.zeros((1, src_len, width), dtype=torch.float)

        for i in range(src_len):
            c, h = model.encoder(sentence[:, i], c, h)
            outputs[:, i, :] = h

        end_id = lf2i["</s>"]
        prev = torch.tensor([lf2i['<s>']], dtype=torch.long)
        predicted_form = []
        decoder_attentions = torch.zeros(max_length, src_len)
        counter = 0
        while True:
            c, h = model.decoder(prev, c, h)
            pred, attention_weights = model.attention(outputs, h)
            decoder_attentions[counter] = attention_weights.view(-1)
            form_id = pred.argmax().item()
            prev = torch.tensor([form_id], dtype=torch.long)
            counter += 1
            if form_id == end_id or counter >= max_length:
                break
            predicted_form.append(form_id)
    return predicted_form, decoder_attentions[:counter]

In [13]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker


def showPlot(points, fig_name, extra_info):
    """Save a line plot of `points` to "<fig_name>.png", titled `extra_info`.

    Exactly one figure is created and closed per call, so repeated calls (once per
    epoch) do not accumulate open figures.
    """
    fig, ax = plt.subplots()
    # Major y ticks at regular 0.2 intervals.
    ax.yaxis.set_major_locator(ticker.MultipleLocator(base=0.2))
    ax.set_title(extra_info)
    ax.plot(points)
    fig.savefig("{}.png".format(fig_name))
    plt.close(fig)

def showAttention(fig_name, input_sentence, output_words, attentions):
    """Save a heat map of decoder attention weights to "<fig_name>.png".

    Args:
        fig_name: output path without the ".png" extension.
        input_sentence: source tokens without <s>/</s>; columns are labelled
            <s>, the tokens, </s>.
        output_words: predicted tokens without </s>; rows are labelled the
            tokens, then </s>.
        attentions: 2-D array, one row per decoder step and one column per
            source position.
    """
    x_labels = ['<s>'] + list(input_sentence) + ['</s>']
    y_labels = list(output_words) + ['</s>']

    fig, ax = plt.subplots()
    cax = ax.matshow(attentions, cmap='bone')
    fig.colorbar(cax)

    # Pin one tick per column/row and label it directly. This avoids relying on a
    # leading '' placeholder to skip the locator's off-screen first tick.
    ax.set_xticks(range(len(x_labels)))
    ax.set_xticklabels(x_labels, rotation=90)
    ax.set_yticks(range(len(y_labels)))
    ax.set_yticklabels(y_labels)

    fig.savefig("{}.png".format(fig_name))
    plt.close(fig)

def evaluateAndShowAttention(fig_name, input_sentence, input_tensor, preds, attentions, i2lf):
    """Map predicted ids to logical-form tokens and save their attention heat map.

    `input_tensor` is accepted for call-site compatibility but is not used.
    `attentions` must support `.numpy()` (e.g. the tensor returned by predict).
    """
    labels = list(map(i2lf.__getitem__, preds))
    showAttention(fig_name, input_sentence, labels, attentions.numpy())

In [21]:
def _test_set_accuracy(model, data):
    """Decode every test entry, store `entry.prediction`, and return exact-match accuracy in percent."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for entry in data.test_entries:
            prediction, _ = predict(model, entry.sentence_tensor, data.language.lf2i)
            entry.prediction = [data.language.i2lf[p] for p in prediction]
            if entry.prediction == entry.form:
                correct += 1
    if not data.test_entries:
        return 0.0
    return 100 * (correct / len(data.test_entries))


def train_and_test(epoch_num):
    """Train a freshly initialised model for `epoch_num` epochs, evaluating after each.

    Seeds `random`, numpy and torch with the module-level `seed`, then for every epoch:
      * trains on each training entry, printing the running loss every
        `opt.print_every` examples and recording it every `opt.plot_every`;
      * applies learning-rate decay when `opt.learning_rate_decay < 1` and the
        0-based epoch index is >= `opt.learning_rate_decay_after`;
      * measures exact-match test accuracy and saves a checkpoint to
        CHECKPOINT_FILE whenever it beats the best so far;
      * writes the epoch's loss curve to LOSS_DIR/epoch.<0-based epoch>.png.
    Finally writes METRICS_DIR/accuracies.png and LOSS_DIR/all_losses.png.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    data = Data(TRAIN_FILE, TEST_FILE, Q_VOCAB_FILE, F_VOCAB_FILE, shuffle_train=True)
    model = Model()
    model.create(Options(), data)
    opt = model.opt
    n_train = len(data.train_entries)

    losses = []
    # Start below any attainable accuracy so the first epoch always writes a
    # checkpoint, even if no test entry is ever decoded correctly.
    max_acc = -1.0
    maxAccEpochId = 0
    accuracies = []
    chunkend_losses = []
    for epoch in range(epoch_num):
        print("---Epoch {}---\n".format(epoch+1))

        print("Training...")
        model.train()
        plot_data = []
        # Position in `losses` of this epoch's first example.
        epoch_start = epoch * n_train
        for index, entry in enumerate(data.train_entries):
            loss = train(model, entry.sentence_tensor, entry.form_tensor)
            if index != 0:
                # `losses` does not yet include this example, so each slice covers
                # exactly the previous plot_every / print_every examples.
                if index % opt.plot_every == 0:
                    plot_data.append(np.mean(losses[epoch_start + index - opt.plot_every:]))
                    chunkend_losses.append(np.mean(plot_data))
                if index % opt.print_every == 0:
                    print("Index {} Loss {}".format(index,
                                                    np.mean(losses[epoch_start + index - opt.print_every:])))
            losses.append(loss.item())

        if opt.learning_rate_decay < 1 and epoch >= opt.learning_rate_decay_after:
            model.rate_decay()

        print("Predicting..")
        accuracy = _test_set_accuracy(model, data)
        accuracies.append(accuracy)
        if accuracy > max_acc:
            max_acc = accuracy
            maxAccEpochId = epoch
            model.save(CHECKPOINT_FILE)

        print("Accuracy: {} Max Accuracy {}".format(accuracy, max_acc))
        file_name = "{}/epoch.{}".format(LOSS_DIR, epoch)
        extra = "Mean Loss {0:.2f}".format(np.mean(plot_data))
        showPlot(plot_data, file_name, extra)

    if not accuracies:
        # epoch_num <= 0: nothing was trained, so there is nothing to summarise.
        return

    file_name = "{}/{}".format(METRICS_DIR, "accuracies")
    # Report the best epoch 1-based, matching the "---Epoch N---" progress lines.
    extra = "Maximum Accuracy {0:.2f} at epoch {1}".format(np.max(accuracies), maxAccEpochId + 1)
    showPlot(accuracies, file_name, extra)
    file_name = "{}/{}".format(LOSS_DIR, "all_losses")
    extra = "Mean Loss {0:.2f}".format(np.mean(chunkend_losses))
    showPlot(chunkend_losses, file_name, extra)


In [15]:
def show_predictions(model, data):
    """Decode every test entry, save its attention heat map, and print exact-match accuracy.

    Side effects: sets `entry.prediction` on each test entry and writes
    ATTENTION_ALIGNMENTS_DIR/entry.<k>.png for k = 1..len(test_entries).
    """
    model.eval()
    lf2i = data.language.lf2i
    i2lf = data.language.i2lf
    hits = 0.0
    with torch.no_grad():
        for number, entry in enumerate(data.test_entries, start=1):
            ids, weights = predict(model, entry.sentence_tensor, lf2i)
            entry.prediction = [i2lf[t] for t in ids]
            figure_path = "{}/entry.{}".format(ATTENTION_ALIGNMENTS_DIR, number)
            showAttention(figure_path, entry.sentence, entry.prediction, weights.numpy())
            if entry.prediction == entry.form:
                hits += 1

    accuracy = 100*(hits/len(data.test_entries))
    print("Accuracy : {}".format(accuracy))

In [22]:
# Train for 100 epochs from the fixed seed; the best-accuracy model is checkpointed to CHECKPOINT_FILE.
train_and_test(100)

---Epoch 1---

Training...
Index 200 Loss 24.758871879577637
Index 400 Loss 14.563745485544205
Predicting..
Accuracy: 8.214285714285714 Max Accuracy 8.214285714285714
---Epoch 2---

Training...
Index 200 Loss 10.910113530158997
Index 400 Loss 8.931237440109253
Predicting..
Accuracy: 27.142857142857142 Max Accuracy 27.142857142857142


<matplotlib.figure.Figure at 0x1094c1630>

<matplotlib.figure.Figure at 0x11565a8d0>

<matplotlib.figure.Figure at 0x115a62828>

<matplotlib.figure.Figure at 0x11a15f0b8>

In [25]:
# Rebuild vocabularies and entries (unshuffled) and restore the best checkpoint.
data = Data(TRAIN_FILE, TEST_FILE, Q_VOCAB_FILE, F_VOCAB_FILE, shuffle_train=False)
model = Model()
model.load(CHECKPOINT_FILE, data)

In [26]:
# Re-evaluate the restored model on the test set and save one attention plot per entry.
show_predictions(model=model, data=data)

Accuracy : 27.142857142857142
