In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from collections import Counter
import pandas as pd
import pickle as pkl
import random
import pdb
from tqdm import tqdm_notebook


In [70]:
random.seed(186)

BATCH_SIZE = 10
MAX_SENTENCE_LENGTH = 1000
EMBED_SIZE = 300
# save index 0 for unk and 1 for pad
PAD_IDX = 0
UNK_IDX = 1
PAD = '<pad>'
UNK = '<unk>'

label_to_id = {'contradiction' : 0, 'entailment' : 1, 'neutral' : 2}

def load_fasttext():
    fasttext_home = './'
    words_to_load = 50000

    loaded_embeddings = np.zeros((words_to_load + 2, EMBED_SIZE)) #+2 to account for pad and unk tokens
    words = {}
    idx2words = {}
    with open(fasttext_home + 'wiki-news-300d-1M.vec') as f:
        loaded_embeddings[PAD_IDX, :] = np.zeros((1, EMBED_SIZE))
        loaded_embeddings[UNK_IDX, :] = np.zeros((1, EMBED_SIZE))
        words[PAD] = PAD_IDX
        words[UNK] = UNK_IDX
        for i, line in enumerate(f):
            if i >= words_to_load: 
                break
            s = line.split()
            idx = i + 2 #+2 to account for PAD and UNK tokens
            loaded_embeddings[idx, :] = np.asarray(s[1:])
            words[s[0]] = idx 
            idx2words[idx] = s[0]
    return loaded_embeddings, words, idx2words

def load_snli_data():
    snli_train = pd.read_csv('./snli_train.tsv', names=['sentence1', 'sentence2', 'label'], skiprows=1, sep='\t|\n', engine='python')
    snli_val = pd.read_csv('./snli_val.tsv', names=['sentence1', 'sentence2', 'label'], skiprows=1, sep='\t|\n', engine='python')
    return snli_train[['sentence1', 'sentence2']], [label_to_id[x] for x in snli_train['label']], snli_val[['sentence1', 'sentence2']], [label_to_id[x] for x in snli_val['label']]


In [10]:
loaded_embeddings, words, idx2words = load_fasttext()

In [11]:
snli_train_sentences, snli_train_labels, snli_val_sentences, snli_val_labels = load_snli_data()
def map_sentence_to_idxs(sentence):
    return [words[tkn] if tkn in words else UNK_IDX for tkn in sentence]
snli_train_sentences_idxs = snli_train_sentences.applymap(lambda x: map_sentence_to_idxs(x))
snli_val_sentences_idxs = snli_val_sentences.applymap(lambda x: map_sentence_to_idxs(x))

In [38]:
class SNLIDataset(Dataset):    
    def __init__(self, data_list, target_list):
        """
        @param data_list: two sentence lists 
        @param target_list: list of targets 

        """
        self.x1 = data_list['sentence1']
        self.x2 = data_list['sentence2']
        self.y = target_list
        assert (len(self.x1) == len(self.x2) == len(self.y))

    def __len__(self):
        return len(self.y)
        
    def __getitem__(self, key):        
        sent1_idx = self.x1[key][:MAX_SENTENCE_LENGTH]
        sent2_idx = self.x2[key][:MAX_SENTENCE_LENGTH]
        label = self.y[key]
        return [sent1_idx, sent2_idx, len(sent1_idx), len(sent2_idx), label]

def SNLI_collate_function(batch):
    """
    Customized function for DataLoader that dynamically pads the batch so that all 
    data have the same length
    """
    sentence1_list = []
    sentence2_list = []
    label_list = []
    length_list_1 = []
    length_list_2 = []
    # padding
    for datum in batch:
        length_list_1.append(datum[2])
        length_list_2.append(datum[3])
        label_list.append(datum[4])
        padded_vec_1 = np.pad(np.array(datum[0]), 
                                pad_width=((0,MAX_SENTENCE_LENGTH-datum[2])), 
                                mode="constant", constant_values=0)
        sentence1_list.append(padded_vec_1)
        padded_vec_2 = np.pad(np.array(datum[1]), 
                                pad_width=((0,MAX_SENTENCE_LENGTH-datum[3])), 
                                mode="constant", constant_values=0)
        sentence2_list.append(padded_vec_2)
    return [torch.from_numpy(np.array(sentence1_list)), torch.from_numpy(np.array(sentence2_list)), 
#             torch.cuda.LongTensor(length_list_1), torch.cuda.LongTensor(length_list_2), 
            torch.cuda.LongTensor(label_list)]
#             torch.LongTensor(label_list)]

In [39]:
snli_train_dataset = SNLIDataset(snli_train_sentences_idxs, snli_train_labels)
snli_train_loader = torch.utils.data.DataLoader(dataset=snli_train_dataset, 
                                           batch_size=BATCH_SIZE,
                                           collate_fn=SNLI_collate_function,
                                           shuffle=True)
snli_val_dataset = SNLIDataset(snli_val_sentences_idxs, snli_val_labels)
snli_val_loader = torch.utils.data.DataLoader(dataset=snli_val_dataset, 
                                           batch_size=BATCH_SIZE,
                                           collate_fn=SNLI_collate_function,
                                           shuffle=True)


In [136]:
class BidirectionalGRU(nn.Module):
    def __init__(self, hidden_size, num_layers):
        # BidirectionalGRU Accepts the following hyperparams:
        # hidden_size: Hidden Size of layer in the GRU
        # num_layers: number of layers in the GRU
        super(BidirectionalGRU, self).__init__()

        self.num_layers, self.hidden_size = num_layers, hidden_size
        self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(loaded_embeddings), freeze=False)
        self.gru = nn.GRU(EMBED_SIZE, hidden_size, num_layers, batch_first=True, bidirectional=True)

    def init_hidden(self, batch_size):
        hidden = torch.randn(self.num_layers*2, batch_size, self.hidden_size)
        return hidden

    def forward(self, x):
        batch_size, seq_len = x.size()
        self.hidden = self.init_hidden(batch_size)

        embed = self.embedding(x)
        #update embedding if token is unk
        m = (x == 1).type(torch.cuda.FloatTensor)
        m = m.unsqueeze(2).repeat(1, 1, EMBED_SIZE)
        embed = m * embed + (1-m) * embed.clone().detach()
        gru_out, hidden = self.gru(embed.cuda(), self.hidden.cuda())
        return gru_out, hidden[0, :, :] + hidden[1, :, :]


In [133]:
def test_gru_model(loader, gru_model):
    correct = 0
    total = 0
    gru_model.eval()
    for sent1_batch, sent2_batch, labels_batch in loader:
        _, hidden_1 = gru_model(sent1_batch.cuda())
        _, hidden_2 = gru_model(sent2_batch.cuda())
        encoded_output = torch.cat([hidden_1, hidden_2], dim=1).cuda()
        outputs = []
        for output in encoded_output:
            outputs.append(fully_connected(output.cuda()))
        outputs = torch.stack(outputs).cuda()
        predicted = F.softmax(outputs)
        predicted = outputs.max(1, keepdim=True)[1]
        total += labels_batch.size(0)
        correct += predicted.eq(labels_batch.view_as(predicted)).sum().item()
    return (100 * correct / total)


del gru_model
torch.cuda.empty_cache()
HIDDEN_SIZE = 200
gru_model = BidirectionalGRU(hidden_size=HIDDEN_SIZE, num_layers=1)
gru_model = gru_model.cuda()

num_epochs = 10 # number epoch to train
criterion = torch.nn.CrossEntropyLoss()
learning_rate = 3e-4
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gru_model.parameters()), lr=learning_rate)
fully_connected = nn.Sequential(nn.Linear(HIDDEN_SIZE*2, HIDDEN_SIZE), nn.ReLU(inplace=True), nn.Linear(HIDDEN_SIZE, 3)).cuda()


for epoch in range(1):
    for i, (sent1, sent2, labels) in enumerate(snli_train_loader):
        gru_model.train()
        optimizer.zero_grad()
        # Forward pass
        _, hidden_1 = gru_model(sent1.cuda())
        _, hidden_2 = gru_model(sent2.cuda())
        encoded_output = torch.cat([hidden_1, hidden_2], dim=1).cuda()
        outputs = []
        for output in encoded_output:
            outputs.append(fully_connected(output.cuda()))
        outputs = torch.stack(outputs).cuda()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if (i > 0 and i % 100 == 0):
            val_acc = test_gru_model(snli_val_loader, gru_model)
            print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(
                       epoch+1, num_epochs, i+1, len(snli_train_loader), val_acc))

tensor([ 0.1073,  0.0089,  0.0006,  0.0055, -0.0646, -0.0600,  0.0450, -0.0133,
        -0.0357,  0.0430, -0.0356, -0.0032,  0.0073, -0.0001,  0.0258, -0.0166,
         0.0075,  0.0686,  0.0392,  0.0753,  0.0115, -0.0087,  0.0421,  0.0265,
        -0.0601,  0.2420,  0.0199, -0.0739, -0.0031, -0.0263, -0.0062,  0.0168,
        -0.0357, -0.0249,  0.0190, -0.0184, -0.0537,  0.1420,  0.0600,  0.0226,
        -0.0038, -0.0675, -0.0036, -0.0080,  0.0570,  0.0208,  0.0223, -0.0256,
        -0.0153,  0.0022, -0.0482,  0.0131, -0.6016, -0.0088,  0.0106,  0.0229,
         0.0336,  0.0071,  0.0887,  0.0237, -0.0290, -0.0405, -0.0125,  0.0147,
         0.0475,  0.0647,  0.0474,  0.0199,  0.0408,  0.0322,  0.0036,  0.0350,
        -0.0723, -0.0305,  0.0184, -0.0026,  0.0240, -0.0160, -0.0308,  0.0434,
         0.0147, -0.0457, -0.0267, -0.1703, -0.0099,  0.0417,  0.0235, -0.0260,
        -0.1519, -0.0116, -0.0306, -0.0413,  0.0330,  0.0723,  0.0365, -0.0001,
         0.0042,  0.0346,  0.0277, -0.03

KeyboardInterrupt: 

In [68]:
print(o)

Epoch: [1/10], Step: [101/1563], Validation Acc: 35.6
Epoch: [1/10], Step: [201/1563], Validation Acc: 35.8
Epoch: [1/10], Step: [301/1563], Validation Acc: 37.3
Epoch: [1/10], Step: [401/1563], Validation Acc: 39.2
Epoch: [1/10], Step: [501/1563], Validation Acc: 37.6
Epoch: [1/10], Step: [601/1563], Validation Acc: 41.3
Epoch: [1/10], Step: [701/1563], Validation Acc: 41.2
Epoch: [1/10], Step: [801/1563], Validation Acc: 41.8
Epoch: [1/10], Step: [901/1563], Validation Acc: 42.3
Epoch: [1/10], Step: [1001/1563], Validation Acc: 44.6
Epoch: [1/10], Step: [1101/1563], Validation Acc: 43.6
Epoch: [1/10], Step: [1201/1563], Validation Acc: 45.5
Epoch: [1/10], Step: [1301/1563], Validation Acc: 45.6
Epoch: [1/10], Step: [1401/1563], Validation Acc: 43.5
Epoch: [1/10], Step: [1501/1563], Validation Acc: 44.7
Epoch: [2/10], Step: [101/1563], Validation Acc: 48.0
Epoch: [2/10], Step: [201/1563], Validation Acc: 46.7
Epoch: [2/10], Step: [301/1563], Validation Acc: 45.8
Epoch: [2/10], Step: [