In [1]:
import torch
import torch.nn as nn
import torch.nn.init as init
from torch import optim
import numpy as np

import random
import os

seed = 10

random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

In [2]:
TEST_FILE = "data/test.txt"
TRAIN_FILE = "data/train.txt"
WHOLE_FILE = "data/whole.txt"
F_VOCAB_FILE = "data/vocab.f.txt"
Q_VOCAB_FILE = "data/vocab.q.txt"

In [3]:
class Options:
    def __init__(self):
        self.rnn_size = 50
        self.init_weight = 0.08
        self.decay_rate = 0.985
        self.learning_rate = 0.01
        self.plot_every = 10
        self.print_every = 50
        self.grad_clip = 5
        self.dropout = 0
        self.dropoutrec = 0
        self.learning_rate_decay =  0.985
        self.learning_rate_decay_after = 5
        
        
opt = Options()

In [4]:
class LSTM(nn.Module):
    def __init__(self, opt):
        super(LSTM, self).__init__()
        self.opt = opt
        self.i2h = nn.Linear(opt.rnn_size, 4 * opt.rnn_size)
        self.h2h = nn.Linear(opt.rnn_size, 4 * opt.rnn_size)
        if opt.dropoutrec > 0:
            self.dropout = nn.Dropout(opt.dropoutrec)
            
    def forward(self, x, prev_c, prev_h):
        gates = self.i2h(x) + self.h2h(prev_h)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)
        if self.opt.dropoutrec > 0:
            cellgate = self.dropout(cellgate)
        cy = (forgetgate * prev_c) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)  # n_b x hidden_dim
        return cy, hy


In [5]:
class RNN(nn.Module):
    def __init__(self, opt, input_size):
        super(RNN, self).__init__()
        self.opt = opt
        self.hidden_size = opt.rnn_size
        self.embedding = nn.Embedding(input_size, self.hidden_size)
        self.lstm = LSTM(self.opt)
        if opt.dropout > 0:
            self.dropout = nn.Dropout(opt.dropout)
        self.__initParameters()

    def __initParameters(self):
        for name, param in self.named_parameters():
            if param.requires_grad:
                init.uniform_(param, -opt.init_weight, opt.init_weight)
                
    def forward(self, input_src, prev_c, prev_h):
        src_emb = self.embedding(input_src) # batch_size x src_length x emb_size
        if self.opt.dropout > 0:
            src_emb = self.dropout(src_emb)
        prev_cy, prev_hy = self.lstm(src_emb, prev_c, prev_h)
        return prev_cy, prev_hy

In [6]:
class Attention(nn.Module):
    def __init__(self, opt, output_size):
        super(Attention, self).__init__()
        self.opt = opt
        self.hidden_size = opt.rnn_size

        self.linear_att = nn.Linear(2*self.hidden_size, self.hidden_size)
        self.linear_out = nn.Linear(self.hidden_size, output_size)
        if opt.dropout > 0:
            self.dropout = nn.Dropout(opt.dropout)

        self.softmax = nn.Softmax(dim=1)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.__initParameters()
    
    def __initParameters(self):
        for name, param in self.named_parameters():
            if param.requires_grad:
                init.uniform_(param, -opt.init_weight, opt.init_weight)

    def forward(self, enc_s_top, dec_s_top):
        dot = torch.bmm(enc_s_top, dec_s_top.unsqueeze(2))
        attention = self.softmax(dot.squeeze(2)).unsqueeze(2)
        enc_attention = torch.bmm(enc_s_top.permute(0,2,1), attention)
        hid = torch.tanh(self.linear_att(torch.cat((enc_attention.squeeze(2),dec_s_top), 1)))
        h2y_in = hid
        if self.opt.dropout > 0:
            h2y_in = self.dropout(h2y_in)
        h2y = self.linear_out(h2y_in)
        pred = self.logsoftmax(h2y)
        return pred

In [7]:
def read_data(fh):
    for line in fh:
        sentence, lf = line.strip().split("\t")
        sentence = sentence.split()
        lf = lf.split()
        yield sentence, lf

def read_vocab(filename):
    t2i = {"<s>": 0, "</s>":1, "UNK": 2}
    with open(filename) as target:
        for line in target:
            token = line.strip().split()[0]
            if token not in t2i:
                t2i[token] = len(t2i)
    return t2i

def is_equal(gold, predictions):
    total_correct = 0.0
    if len(gold) == len(predictions):
        equal = True
        for g, p in zip(gold, predictions):
            if g != p:
                equal = False
        return equal
    return False

In [8]:
def preprare_data(file_name):
    shuffledData = None
    with open(TRAIN_FILE, 'r') as train:
        shuffledData = list(read_data(train))
        random.shuffle(shuffledData)
    sentence_index_tensors = []
    form_index_tensors = []
    for sentence in shuffledData:
        text_tensor = torch.zeros((1, len(sentence[0]) + 2), dtype=torch.long)
        text_tensor[0][0] = w2i["<s>"]
        for idx, word in enumerate(sentence[0]):
            word_index = w2i[word] if word in w2i else w2i["UNK"]
            text_tensor[0][idx+1] = word_index
        text_tensor[0][-1] = w2i["</s>"]
        sentence_index_tensors.append(text_tensor)
        form_tensor = torch.zeros((1, len(sentence[1]) + 2), dtype=torch.long)
        form_tensor[0][0] = lf2i["<s>"]
        for idx, form in enumerate(sentence[1]):
            form_index = lf2i[form] if form in lf2i else lf2i["UNK"]
            form_tensor[0][idx+1] = form_index
        form_tensor[0][-1] = lf2i["</s>"]
        form_index_tensors.append(form_tensor)
    return shuffledData, sentence_index_tensors, form_index_tensors

In [9]:
def train(opt, criterion, encoder_optimizer, decoder_optimizer, attention_optimizer, encoder, decoder, attention, s1, f1):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    attention_optimizer.zero_grad()

    outputs = torch.zeros((1, s1.size(1), opt.rnn_size), dtype=torch.float, requires_grad=False)
    c = torch.zeros((1, opt.rnn_size), dtype=torch.float, requires_grad=True)
    h = torch.zeros((1, opt.rnn_size), dtype=torch.float, requires_grad=True)
    for i in range(s1.size(1)):
        c, h = encoder(s1[:, i], c, h)
        outputs[:,i, :] = h
        
    loss = 0
    for i in range(f1.size(1)-1):
        c, h = decoder(f1[:, i], c, h)
        pred = attention(outputs, h)
        loss += criterion(pred, f1[:, i+1])
        
    loss.backward()
    if opt.grad_clip != -1:
        torch.nn.utils.clip_grad_value_(encoder.parameters(),opt.grad_clip)
        torch.nn.utils.clip_grad_value_(decoder.parameters(),opt.grad_clip)
        torch.nn.utils.clip_grad_value_(attention.parameters(),opt.grad_clip)

    encoder_optimizer.step()
    decoder_optimizer.step()
    attention_optimizer.step()
    return loss

In [10]:
def predict(opt, s1, lf2i, encoder, decoder, attention):
    c = torch.zeros((1, opt.rnn_size), dtype=torch.float, requires_grad=True)
    h = torch.zeros((1, opt.rnn_size), dtype=torch.float, requires_grad=True)
    outputs = torch.zeros((1, s1.size(1), opt.rnn_size), dtype=torch.float, requires_grad=True)

    for i in range(s1.size(1)):
        c, h = encoder(s1[:, i], c, h)
        outputs[:,i, :] = h
            
    prev = torch.tensor([lf2i['<s>']], dtype=torch.long)
    predicted_form = []
    counter = 0
    while True:
        counter += 1
        c, h = decoder(prev, c, h)
        pred = attention(outputs, h)
        form_id = pred.argmax().item()
        prev = torch.tensor([form_id], dtype=torch.long)
        if form_id == lf2i["</s>"] or counter >= 100:
            break
        predicted_form.append(form_id)
    return predicted_form

In [11]:
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker


def showPlot(points, fig_name, extra_info):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.title(extra_info) 
    plt.plot(points)
    plt.savefig("{}.png".format(fig_name))
    plt.close('all')

In [12]:
w2i = read_vocab(Q_VOCAB_FILE)
lf2i = read_vocab(F_VOCAB_FILE)
i2lf = {lf2i[i] : i for i in lf2i}

In [13]:
def train_and_test(epoch_num, directory):
    train_data, sentence_index_tensors_train, form_index_tensors_train = preprare_data(TRAIN_FILE)
    test_data, sentence_index_tensors_test, form_index_tensors_test = preprare_data(TEST_FILE)
    
    encoder = RNN(opt, len(w2i))
    decoder = RNN(opt, len(lf2i))
    attention = Attention(opt, len(lf2i))
    optim_state = {"learningRate" : opt.learning_rate, "alpha" :  opt.decay_rate}
    encoder_optimizer = optim.RMSprop(encoder.parameters(),  lr=optim_state["learningRate"], alpha=optim_state["alpha"])
    decoder_optimizer = optim.RMSprop(decoder.parameters(),  lr=optim_state["learningRate"], alpha=optim_state["alpha"])
    attention_optimizer = optim.RMSprop(attention.parameters(),  lr=optim_state["learningRate"], alpha=optim_state["alpha"])

    criterion = nn.NLLLoss(ignore_index=0)

    losses = []
    max_acc = 0
    maxAccEpochId = 0
    accuracies = []
    chunkend_losses = []
    for epoch in range(epoch_num):
        print("---Epoch {}---\n".format(epoch+1))
        print("Training...")
        encoder.train()
        decoder.train()
        attention.train()
        plot_data = []
        for index, (sentence, form) in enumerate(zip(sentence_index_tensors_train, form_index_tensors_train)):
            loss = train(opt, criterion, encoder_optimizer, decoder_optimizer, attention_optimizer, encoder, decoder, attention, sentence, form)
            if index != 0:
                if index % opt.plot_every == 0:     
                    plot_data.append(np.mean(losses[epoch*len(train_data)+index-opt.plot_every:]))
                    chunkend_losses.append(np.mean(plot_data))
                if index % opt.print_every == 0:
                    print("Index {} Loss {}".format(index, np.mean(losses[epoch*len(train_data)+index-opt.print_every:])))
            losses.append(loss.item())

        if opt.learning_rate_decay < 1:
            if epoch >= opt.learning_rate_decay_after:
                decay_factor = opt.learning_rate_decay
                optim_state["learningRate"] = optim_state["learningRate"] * decay_factor #decay it
                for param_group in encoder_optimizer.param_groups:
                    param_group['lr'] = optim_state["learningRate"]
                for param_group in decoder_optimizer.param_groups:
                    param_group['lr'] = optim_state["learningRate"]
                for param_group in attention_optimizer.param_groups:
                    param_group['lr'] = optim_state["learningRate"]

        print("Predicting..")
        encoder.eval()
        decoder.eval()
        attention.eval()
        correct = 0.0
        with torch.no_grad():
            for index, (sentence, form) in enumerate(zip(sentence_index_tensors_test, form_index_tensors_test)):
                prediction = predict(opt, sentence, lf2i, encoder, decoder, attention)
                prediction = [i2lf[p] for p in prediction]
                #print(test_data[index][1])
                #print(prediction)
                if len(test_data[index][1]) == len(prediction):
                    same = True
                    for g, p in zip(test_data[index][1], prediction):
                        if g != p:
                            same = False
                    if same:
                        correct += 1

        accuracy = 100*(correct/len(test_data))
        accuracies.append(accuracy)
        if accuracy > max_acc:
            max_acc = accuracy
            maxAccEpochId = epoch

        print("Accuracy: {} Max Accuracy {}".format(accuracy, max_acc))

        if not os.path.exists(directory):
            os.makedirs(directory)

        file_name = "{}/epoch.{}".format(directory, epoch)
        extra = "Mean Loss {0:.2f}".format(np.mean(plot_data))
        showPlot(plot_data, file_name, extra)

    file_name = "{}/{}".format(directory, "accuracies")
    extra = "Maximum Accuracy {0:.2f} at epoch {1}".format(np.max(accuracies), maxAccEpochId)
    showPlot(accuracies, file_name, extra)
    file_name = "{}/{}".format(directory, "all_losses") 
    extra = "Mean Loss {0:.2f}".format(np.mean(chunkend_losses))
    showPlot(chunkend_losses, file_name, extra)


In [14]:
#train_and_test(100, "out/Attention/BaseExperiment.required_grad_false")

In [15]:
train_and_test(100, "out/Attention/WithGradientClippingAndLearningRateDecay")

---Epoch 1---

Training...
Index 50 Loss 41.97954932212829
Index 100 Loss 20.482985944747924
Index 150 Loss 19.88434991836548
Index 200 Loss 16.688602333068847
Index 250 Loss 11.353065056800842
Index 300 Loss 14.88828067779541
Index 350 Loss 14.386179599761963
Index 400 Loss 17.627456607818605
Index 450 Loss 13.326656484603882
Index 500 Loss 14.242035932540894
Index 550 Loss 9.719332494735717
Predicting..
Accuracy: 7.166666666666667 Max Accuracy 7.166666666666667
---Epoch 2---

Training...
Index 50 Loss 13.416100749969482
Index 100 Loss 9.839723510742187
Index 150 Loss 10.965814399719239
Index 200 Loss 9.418815460205078
Index 250 Loss 6.330221309661865
Index 300 Loss 9.345489149093629
Index 350 Loss 8.623284606933593
Index 400 Loss 11.425954694747924
Index 450 Loss 8.220655422210694
Index 500 Loss 9.958322334289551
Index 550 Loss 6.825116271972656
Predicting..
Accuracy: 31.5 Max Accuracy 31.5
---Epoch 3---

Training...
Index 50 Loss 10.03083477973938
Index 100 Loss 6.705715613365173
In

Predicting..
Accuracy: 78.83333333333333 Max Accuracy 83.83333333333334
---Epoch 19---

Training...
Index 50 Loss 2.1990102195739745
Index 100 Loss 1.4025133323669434
Index 150 Loss 1.3063645267486572
Index 200 Loss 0.6221094608306885
Index 250 Loss 0.6473284530639648
Index 300 Loss 1.5741555404663086
Index 350 Loss 1.071119375228882
Index 400 Loss 1.456754903793335
Index 450 Loss 1.0423159313201904
Index 500 Loss 1.6220253658294679
Index 550 Loss 0.6622550868988037
Predicting..
Accuracy: 79.16666666666666 Max Accuracy 83.83333333333334
---Epoch 20---

Training...
Index 50 Loss 1.572121057510376
Index 100 Loss 1.3988638496398926
Index 150 Loss 1.1362424182891846
Index 200 Loss 0.9564263534545898
Index 250 Loss 0.28587471961975097
Index 300 Loss 0.94455979347229
Index 350 Loss 0.8909911632537841
Index 400 Loss 1.2073979187011719
Index 450 Loss 1.0104517650604248
Index 500 Loss 1.4400315380096436
Index 550 Loss 0.9718067169189453
Predicting..
Accuracy: 80.5 Max Accuracy 83.83333333333334

Index 300 Loss 0.35222079515457155
Index 350 Loss 0.0894132137298584
Index 400 Loss 0.5227732276916504
Index 450 Loss 0.4157846403121948
Index 500 Loss 0.6021114826202393
Index 550 Loss 0.219708890914917
Predicting..
Accuracy: 94.16666666666667 Max Accuracy 94.16666666666667
---Epoch 37---

Training...
Index 50 Loss 0.6854583263397217
Index 100 Loss 0.46393558502197263
Index 150 Loss 0.6090233802795411
Index 200 Loss 0.5767447853088379
Index 250 Loss 0.10915672302246093
Index 300 Loss 0.21147939682006836
Index 350 Loss 0.1313382625579834
Index 400 Loss 0.5821598243713378
Index 450 Loss 0.33011043548583985
Index 500 Loss 0.7743584537506103
Index 550 Loss 0.3640983867645264
Predicting..
Accuracy: 92.83333333333333 Max Accuracy 94.16666666666667
---Epoch 38---

Training...
Index 50 Loss 1.3162111568450927
Index 100 Loss 0.7639776992797852
Index 150 Loss 0.6088557243347168
Index 200 Loss 0.3641281032562256
Index 250 Loss 0.017347488403320312
Index 300 Loss 0.30410518646240237
Index 350 Los

Index 550 Loss 0.6380233287811279
Predicting..
Accuracy: 97.5 Max Accuracy 97.5
---Epoch 54---

Training...
Index 50 Loss 0.24075490951538087
Index 100 Loss 0.10400729179382324
Index 150 Loss 0.291354808807373
Index 200 Loss 0.2694266891479492
Index 250 Loss 0.06679290771484375
Index 300 Loss 0.08801857948303222
Index 350 Loss 0.1389940357208252
Index 400 Loss 0.09743263244628907
Index 450 Loss 0.07249117136001587
Index 500 Loss 0.1531997060775757
Index 550 Loss 0.11463861703872681
Predicting..
Accuracy: 97.83333333333334 Max Accuracy 97.83333333333334
---Epoch 55---

Training...
Index 50 Loss 0.13407788276672364
Index 100 Loss 0.17035475730895996
Index 150 Loss 0.6904112625122071
Index 200 Loss 0.056035614013671874
Index 250 Loss 0.019073867797851564
Index 300 Loss 0.15106025218963623
Index 350 Loss 0.038874607086181644
Index 400 Loss 0.08158164501190185
Index 450 Loss 0.16020237922668457
Index 500 Loss 0.2753929805755615
Index 550 Loss 0.07304740905761718
Predicting..
Accuracy: 95.0 

Index 100 Loss 0.014113454818725587
Index 150 Loss 0.10284672737121582
Index 200 Loss 0.06603322982788086
Index 250 Loss 0.005804805755615234
Index 300 Loss 0.0071156883239746095
Index 350 Loss 0.014759678840637207
Index 400 Loss 0.006152191162109375
Index 450 Loss 0.038515617847442625
Index 500 Loss 0.2814496421813965
Index 550 Loss 0.043446893692016604
Predicting..
Accuracy: 99.0 Max Accuracy 99.16666666666667
---Epoch 72---

Training...
Index 50 Loss 0.22654109001159667
Index 100 Loss 0.018632688522338868
Index 150 Loss 0.08878922462463379
Index 200 Loss 0.014737939834594727
Index 250 Loss 0.006326980590820312
Index 300 Loss 0.009355564117431641
Index 350 Loss 0.002128705978393555
Index 400 Loss 0.010801048278808593
Index 450 Loss 0.07330493927001953
Index 500 Loss 0.11248868942260742
Index 550 Loss 0.15038013458251953
Predicting..
Accuracy: 97.83333333333334 Max Accuracy 99.16666666666667
---Epoch 73---

Training...
Index 50 Loss 0.2273752498626709
Index 100 Loss 0.0565216636657714

Index 50 Loss 0.03648658752441406
Index 100 Loss 0.02005523681640625
Index 150 Loss 0.004096012115478515
Index 200 Loss 0.0002939414978027344
Index 250 Loss 0.00011381149291992188
Index 300 Loss 0.0003295421600341797
Index 350 Loss 0.00023969650268554688
Index 400 Loss 0.00035937309265136717
Index 450 Loss 0.013740291595458984
Index 500 Loss 0.1297918128967285
Index 550 Loss 0.0005522537231445313
Predicting..
Accuracy: 99.66666666666667 Max Accuracy 99.83333333333333
---Epoch 89---

Training...
Index 50 Loss 0.0009208393096923829
Index 100 Loss 0.00028606414794921873
Index 150 Loss 0.14298168182373047
Index 200 Loss 0.0003307723999023437
Index 250 Loss 6.959915161132813e-05
Index 300 Loss 0.0001576995849609375
Index 350 Loss 0.00010250091552734375
Index 400 Loss 0.0001356983184814453
Index 450 Loss 0.012747316360473633
Index 500 Loss 0.13955438613891602
Index 550 Loss 9.207725524902343e-05
Predicting..
Accuracy: 99.83333333333333 Max Accuracy 99.83333333333333
---Epoch 90---

Training.