In [1]:
from torchtext import data, datasets
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn.init as init
import re
import random
import numpy as np
from recurrent_BatchNorm import recurrent_BatchNorm
from utils import *


In [2]:
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

# Seed Python, NumPy and PyTorch RNGs so parameter init, dropout and shuffling are reproducible.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if use_cuda:
    torch.cuda.manual_seed_all(SEED)

# Padding index used for masking.
# NOTE: torchtext vocabularies normally put '<unk>' at index 0 and '<pad>' at index 1,
# so prefer vocab.stoi['<pad>'] once the vocabulary has been built.
PAD_TOKEN = 0

In [3]:
class RTE(nn.Module):
    '''
    Recognising-textual-entailment classifier.

    The premise and the hypothesis are embedded and encoded by two GRUs. The hypothesis GRU
    starts from the premise GRU's final state. The hypothesis then attends over the premise
    outputs, in one of two ways:
        - once, from its last output (WBW_ATTN=False)
        - word by word (WBW_ATTN=True)
    The attended vector is combined with the last hypothesis output and projected to 3 classes.
    The model returns their log-probabilities.

    inputs:
        input_size : vocabulary size (number of embedding rows)
        EMBEDDING_DIM : word embedding size
        HIDDEN_DIM : GRU hidden size; an odd value is rounded down to the next even number
        WBW_ATTN : use word-by-word attention (adds W_t and two recurrent batch-norm layers)
        PAD_IDX : vocabulary index treated as padding when building the masks.
                  None (default) uses the module-level PAD_TOKEN.
                  NOTE: torchtext vocabularies normally put '<unk>' at 0 and '<pad>' at 1.
        MAX_LEN : timestep count passed to recurrent_BatchNorm (was hard-coded to 30).
                  Presumably the longest hypothesis it supports -- confirm against that module.
    '''

    def __init__(self, input_size, EMBEDDING_DIM, HIDDEN_DIM, WBW_ATTN, PAD_IDX=None, MAX_LEN=30):
        super(RTE, self).__init__()
        self.n_embed = EMBEDDING_DIM
        self.n_dim = HIDDEN_DIM if HIDDEN_DIM % 2 == 0 else HIDDEN_DIM - 1
        self.n_out = 3
        self.pad_idx = PAD_TOKEN if PAD_IDX is None else PAD_IDX
        self.embedding = nn.Embedding(input_size, self.n_embed).type(dtype)
        self.WBW_ATTN = WBW_ATTN

        self.p_gru = nn.GRU(self.n_embed, self.n_dim, bidirectional=False).type(dtype)
        self.h_gru = nn.GRU(self.n_embed, self.n_dim, bidirectional=False).type(dtype)
        self.out = nn.Linear(self.n_dim, self.n_out).type(dtype)

        # Attention parameters. Assigning an nn.Parameter to a module attribute registers it,
        # so no explicit register_parameter call is needed.
        # NOTE(review): torch.randn draws with unit variance; for large n_dim (300 in this notebook)
        # that likely saturates the tanh non-linearities -- a scaled init may train better.
        self.W_y = self._new_param(torch.randn(self.n_dim, self.n_dim))  # n_dim x n_dim
        self.W_h = self._new_param(torch.randn(self.n_dim, self.n_dim))  # n_dim x n_dim
        self.W_r = self._new_param(torch.randn(self.n_dim, self.n_dim))  # n_dim x n_dim
        self.W_alpha = self._new_param(torch.randn(self.n_dim, 1))  # n_dim x 1
        if WBW_ATTN:
            # The word-by-word attention layer is a simple RNN and suffers from exploding
            # gradients; an orthonormal recurrent matrix (Q of a QR decomposition) mitigates that.
            _W_t_ortho, _ = np.linalg.qr(np.random.randn(self.n_dim, self.n_dim))
            self.W_t = self._new_param(torch.Tensor(_W_t_ortho))  # n_dim x n_dim
            self.batch_norm_h_r = recurrent_BatchNorm(self.n_dim, MAX_LEN).type(dtype)
            self.batch_norm_r_r = recurrent_BatchNorm(self.n_dim, MAX_LEN).type(dtype)

        # Final combination parameters
        self.W_x = self._new_param(torch.randn(self.n_dim, self.n_dim))  # n_dim x n_dim
        self.W_p = self._new_param(torch.randn(self.n_dim, self.n_dim))  # n_dim x n_dim

    @staticmethod
    def _new_param(tensor):
        '''Wraps `tensor` as a trainable parameter converted to the notebook-wide `dtype`.'''
        return nn.Parameter(tensor.type(dtype))

    def init_hidden(self, batch_size):
        '''Zero initial GRU state : 1 x batch x n_dim.'''
        return Variable(torch.zeros(1, batch_size, self.n_dim).type(dtype))

    def attn_rnn_init_hidden(self, batch_size):
        '''Zero initial state of the word-by-word attention RNN : batch x n_dim.'''
        return Variable(torch.zeros(batch_size, self.n_dim).type(dtype))

    def mask_mult(self, o_t, o_tm1, mask_t):
        '''
        Keeps o_t where mask_t is 1 and carries o_tm1 forward where it is 0 (padding).
            o_t : batch x n
            o_tm1 : batch x n
            mask_t : batch x 1
        '''
        return (o_t * mask_t) + (o_tm1 * (1. - mask_t))

    def _gru_forward(self, gru, encoded_s, mask_s, h_0):
        '''
        Runs `gru` one step at a time so padded positions keep the previous state.
        inputs :
            gru : The GRU unit for which the forward pass is to be computed
            encoded_s : T x batch x n_embed
            mask_s : T x batch
            h_0 : 1 x batch x n_dim
        outputs :
            o_s : T x batch x n_dim (outputs)
            h_n : 1 x batch x n_dim (last hidden state; h_0 for an empty sequence)
        '''
        seq_len = encoded_s.size(0)
        batch_size = encoded_s.size(1)
        h_tm1 = h_0.squeeze(0)  # batch x n_dim
        o_tm1 = None
        outputs = []

        for ix in range(seq_len):
            mask_t = mask_s[ix].unsqueeze(1)  # batch x 1
            o_t, h_t = gru(encoded_s[ix].unsqueeze(0), h_tm1.unsqueeze(0))  # each 1 x batch x n_dim
            o_t = o_t[0]  # batch x n_dim
            h_t = self.mask_mult(h_t[0], h_tm1, mask_t)  # batch x n_dim
            # The first output is kept unmasked: there is no previous output to carry over.
            if o_tm1 is not None:
                o_t = self.mask_mult(o_t, o_tm1, mask_t)
            outputs.append(o_t)
            o_tm1, h_tm1 = o_t, h_t

        if outputs:
            o_s = torch.stack(outputs, 0)
        else:
            o_s = Variable(torch.zeros(seq_len, batch_size, self.n_dim).type(dtype))
        return o_s, h_tm1.unsqueeze(0)

    def _attention_forward(self, Y, mask_Y, h, r_tm1=None, index=None):
        '''
        Computes attention weights over Y using h, and r_tm1 when given.
        Returns an attention-weighted representation of Y together with the weights.
        inputs:
            Y : T x batch x n_dim
            mask_Y : T x batch
            h : batch x n_dim
            r_tm1 : batch x n_dim (previous word-by-word attention state, or None)
            index : int : the timestep, forwarded to the recurrent batch-norm layers
        params:
            W_y, W_h, W_r : n_dim x n_dim
            W_alpha : n_dim x 1
            W_t : n_dim x n_dim (only used when r_tm1 is given)
        outputs :
            r : batch x n_dim
            alpha : batch x T
        '''
        Y = Y.transpose(1, 0)  # batch x T x n_dim
        mask_Y = mask_Y.transpose(1, 0)  # batch x T

        Wy = torch.matmul(Y, self.W_y)  # batch x T x n_dim
        Wh = torch.mm(h, self.W_h)  # batch x n_dim
        if r_tm1 is not None:
            W_r_tm1 = torch.mm(r_tm1, self.W_r)
            if hasattr(self, 'batch_norm_r_r'):
                W_r_tm1 = self.batch_norm_r_r(W_r_tm1, index)
            if hasattr(self, 'batch_norm_h_r'):
                Wh = self.batch_norm_h_r(Wh, index)
            Wh = Wh + W_r_tm1
        M = torch.tanh(Wy + Wh.unsqueeze(1).expand_as(Wy))  # batch x T x n_dim
        alpha = torch.matmul(M, self.W_alpha).squeeze(-1)  # batch x T
        # A large negative bias keeps probability mass off padded positions.
        alpha = alpha + (-1000.0 * (1. - mask_Y))
        alpha = F.softmax(alpha, dim=1)
        r = torch.bmm(alpha.unsqueeze(1), Y).squeeze(1)  # batch x n_dim
        if r_tm1 is not None:
            r = r + torch.tanh(torch.mm(r_tm1, self.W_t))
        return r, alpha

    def _combine_last(self, r, h_t):
        '''
        inputs:
            r : batch x n_dim
            h_t : batch x n_dim (the last output of the hypothesis GRU)
        params :
            W_x : n_dim x n_dim
            W_p : n_dim x n_dim
        out :
            h_star : batch x n_dim = tanh(r W_p + h_t W_x)
        '''
        return torch.tanh(torch.mm(r, self.W_p) + torch.mm(h_t, self.W_x))

    def _attn_rnn_forward(self, o_h, mask_h, r_0, o_p, mask_p):
        '''
        Word-by-word attention: each hypothesis step attends over the premise.
        inputs:
            o_h : T_h x batch x n_dim : The hypothesis
            mask_h : T_h x batch
            r_0 : batch x n_dim
            o_p : T_p x batch x n_dim : The premise. Will attend on it at every step
            mask_p : T_p x batch : the mask for the premise
        params:
            W_t : n_dim x n_dim
        outputs:
            r : batch x n_dim : the last state of the rnn (r_0 for an empty hypothesis)
            alpha_vec : T_h x batch x T_p : the attention weights at every step
        '''
        r_tm1 = r_0
        alphas = []
        for ix in range(o_h.size(0)):
            r_t, alpha = self._attention_forward(o_p, mask_p, o_h[ix], r_tm1, ix)
            alphas.append(alpha)
            # Padded hypothesis positions keep the previous attention state.
            r_tm1 = self.mask_mult(r_t, r_tm1, mask_h[ix].unsqueeze(1))

        if alphas:
            alpha_vec = torch.stack(alphas, 0)
        else:
            alpha_vec = Variable(torch.zeros(o_h.size(0), o_h.size(1), o_p.size(0)).type(dtype))
        return r_tm1, alpha_vec

    def forward(self, premise, hypothesis, training=None):
        '''
        inputs:
            premise : batch x T_p token indices
            hypothesis : batch x T_h token indices
            training : True/False switches the module into train/eval mode for this call.
                       None (default) keeps the current mode, so model.train()/model.eval()
                       are respected instead of being overridden.
        outputs :
            pred : batch x num_classes log-probabilities
        '''
        if training is None:
            training = self.training
        else:
            self.train(training)
        batch_size = premise.size(0)

        # 1 for real tokens, 0 for padding; T x batch
        mask_p = torch.ne(premise, self.pad_idx).type(dtype).transpose(1, 0)
        mask_h = torch.ne(hypothesis, self.pad_idx).type(dtype).transpose(1, 0)

        encoded_p = F.dropout(self.embedding(premise), p=0.1, training=training)  # batch x T x n_embed
        encoded_h = F.dropout(self.embedding(hypothesis), p=0.1, training=training)
        encoded_p = encoded_p.transpose(1, 0)  # T x batch x n_embed
        encoded_h = encoded_h.transpose(1, 0)

        h_0 = self.init_hidden(batch_size)  # 1 x batch x n_dim
        o_p, h_n = self._gru_forward(self.p_gru, encoded_p, mask_p, h_0)  # T_p x batch x n_dim
        # The hypothesis encoder starts from the premise's final state.
        o_h, _ = self._gru_forward(self.h_gru, encoded_h, mask_h, h_n)  # T_h x batch x n_dim

        if self.WBW_ATTN:
            r_0 = self.attn_rnn_init_hidden(batch_size)
            r, _ = self._attn_rnn_forward(o_h, mask_h, r_0, o_p, mask_p)  # batch x n_dim
        else:
            r, _ = self._attention_forward(o_p, mask_p, o_h[-1])  # batch x n_dim

        h_star = self._combine_last(r, o_h[-1])  # batch x n_dim
        return F.log_softmax(self.out(h_star), dim=1)  # batch x num_classes

    def _get_numpy_array_from_variable(self, variable):
        '''Converts a torch autograd variable to its corresponding numpy array (on the CPU).'''
        return variable.data.cpu().numpy()

In [4]:
inputs = datasets.snli.ParsedTextField(lower=True)
answers = data.Field(sequential=False)

train, dev, test = datasets.SNLI.splits(inputs, answers)

# Build the vocabularies; the input vocab also gets 300-d GloVe vectors (downloaded/cached by torchtext).
inputs.build_vocab(train, vectors='glove.6B.300d')
answers.build_vocab(train)

# torchtext normally puts '<unk>' before '<pad>' (index 0 vs 1), so read the padding index
# from the vocabulary instead of assuming 0.
PAD_TOKEN = inputs.vocab.stoi[inputs.pad_token]

# Notebook-level parameters (a `global` statement is a no-op at module level, so none is needed).
vocab_size = len(inputs.vocab)
input_size = vocab_size
num_train_steps = 50000000

# NOTE(review): device=-1 keeps batches on the CPU even when use_cuda is True, while the model's
# parameters are then on the GPU -- confirm how this notebook was meant to run on a GPU.
train_iter, dev_iter, test_iter = data.BucketIterator.splits((train, dev, test), batch_size=32, device=-1)

downloading snli_1.0.zip
extracting


.vector_cache/glove.6B.zip: 862MB [09:35, 1.50MB/s]                               
100%|██████████| 400000/400000 [01:24<00:00, 4727.75it/s]


In [5]:
# Size the embedding layer from the pretrained vectors attached to the vocabulary and copy them in;
# otherwise the GloVe vectors loaded with build_vocab are never used by the model.
EMBEDDING_DIM = inputs.vocab.vectors.size(1)
model = RTE(input_size, EMBEDDING_DIM=EMBEDDING_DIM, HIDDEN_DIM=300, WBW_ATTN=True)
model.embedding.weight.data.copy_(inputs.vocab.vectors)

In [6]:
def training_loop(model, loss, optimizer, train_iter, dev_iter, max_steps=None, eval_every=1000):
    '''
    Trains `model` for at most `max_steps` optimizer steps and returns the best dev accuracy seen.

    inputs:
        model : module called as model(premise, hypothesis, training=True) -> batch x num_classes
        loss : criterion called as loss(output, labels)
        optimizer : torch optimizer over the model's parameters
        train_iter : iterable of batches with .premise / .hypothesis (T x batch) and .label fields
        dev_iter : iterable passed to `evaluate`
        max_steps : total number of optimizer steps; None uses the notebook-level num_train_steps
        eval_every : evaluate on dev_iter every this many steps. A full dev pass is far more
                     expensive than a training step, so evaluating every 10 steps
                     (the old behaviour) made evaluation dominate the run time.
    outputs:
        best dev accuracy (float); 0.0 if no evaluation happened
    '''
    if max_steps is None:
        max_steps = num_train_steps
    best_dev_acc = 0.0
    step = 0
    # train_iter may repeat forever (presumably torchtext's default for the training split);
    # the step budget is what ends training either way.
    while step < max_steps:
        seen_batch = False
        model.train()
        for batch in train_iter:
            if step >= max_steps:
                break
            seen_batch = True
            premise = batch.premise.transpose(0, 1)
            hypothesis = batch.hypothesis.transpose(0, 1)
            labels = batch.label - 1  # label vocab index 0 is reserved, so shift classes to 0-based
            model.zero_grad()

            # Ask for training mode explicitly so dropout / batch norm behave as in training.
            output = model(premise, hypothesis, training=True)
            lossy = loss(output, labels)
            lossy.backward()
            optimizer.step()

            if step % eval_every == 0:
                dev_acc = float(evaluate(model, dev_iter))
                best_dev_acc = max(best_dev_acc, dev_acc)
                # .item() on newer PyTorch; .data[0] on old Variable-era versions.
                loss_value = lossy.item() if hasattr(lossy, 'item') else lossy.data[0]
                print("Step %i; Loss %f; Dev acc %f" % (step, loss_value, dev_acc))

            step += 1
        if not seen_batch:
            # An empty training iterator would otherwise spin forever.
            break
    return best_dev_acc

In [7]:
def evaluate(model, data_iter):
    '''
    Computes the classification accuracy of `model` over every batch in `data_iter`.

    inputs:
        model : module called as model(premise, hypothesis, training=False) -> batch x num_classes
        data_iter : iterable of batches with .premise / .hypothesis (T x batch) and .label fields;
                    labels are shifted by one (index 0 of the label vocab is reserved)
    outputs:
        accuracy in [0, 1] as a float; 0.0 for an empty iterator
    The model's train/eval mode on entry is restored before returning.
    '''
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    for batch in data_iter:
        premise = batch.premise.transpose(0, 1)
        hypothesis = batch.hypothesis.transpose(0, 1)
        labels = (batch.label - 1).data
        output = model(premise, hypothesis, training=False)
        _, predicted = torch.max(output.data, 1)
        total += labels.size(0)
        # Accumulate a plain int (a uint8/bool sum can be a tensor or overflow on some versions).
        correct += int((predicted == labels).long().sum())
    model.train(was_training)
    return correct / float(total) if total else 0.0

In [8]:
# Loss: the model already returns log-probabilities (log_softmax), so NLLLoss is the matching
# criterion. CrossEntropyLoss would apply log_softmax a second time (numerically harmless, but redundant).
loss = nn.NLLLoss()

# Optimizer
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.001, weight_decay=5e-5)

# Train the model and show the best dev accuracy reached.
best_dev_acc = training_loop(model, loss, optimizer, train_iter, dev_iter)
best_dev_acc



Step 0; Loss 1.088470; Dev acc 0.336212


KeyboardInterrupt: 