In [1]:
from torchtext import data, datasets
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import re
import random

In [66]:
inputs = datasets.snli.ParsedTextField(lower=True)
answers = data.Field(sequential=False)

train, dev, test = datasets.SNLI.splits(inputs, answers)

inputs.build_vocab(train, dev, test)
answers.build_vocab(train)

train_iter, dev_iter, test_iter = data.BucketIterator.splits(
            (train, dev, test), batch_size=32, device=-1)

In [316]:
# A Multi-Layer Perceptron (MLP)
class DecomposableAttention(nn.Module): # inheriting from nn.Module!
    
    def __init__(self, input_size, embedding_dim, hidden_dim, num_labels):
        super(DecomposableAttention, self).__init__()
        
        # Define the parameters that you will need.  
        # You need an embedding matrix, parameters for affine mappings and ReLus
        # Pay attention to dimensions!
        self.embedding_dim = embedding_dim 
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(p=0.2)
        self.embed = nn.Embedding(input_size, embedding_dim, padding_idx=0)
        self.linear_1 = nn.Linear(embedding_dim, hidden_dim)
        self.linear_f = nn.Linear(hidden_dim, hidden_dim)
        self.linear_g = nn.Linear(2 * hidden_dim, hidden_dim)
        self.linear_final = nn.Linear(hidden_dim, num_labels)
        
    def forward(self, prem, hypo):
        # Prepare  
        prem_emb = self.embed(prem)
        hypo_emb = self.embed(hypo)
        prem_emb = self.linear_1(prem_emb)
        hypo_emb = self.linear_1(hypo_emb)
        
        # Attend
        #print('Premise size', prem_emb.size())
        #print('Hypo size', hypo_emb.size())
        prem_ff = self.dropout(prem_emb)
        prem_ff = F.relu(self.linear_f(prem_ff))
        prem_ff = self.dropout(prem_ff)
        prem_ff = F.relu(self.linear_f(prem_ff))
        
        hypo_ff = self.dropout(hypo_emb)
        hypo_ff = F.relu(self.linear_f(hypo_ff))
        hypo_ff = self.dropout(hypo_ff)
        hypo_ff = F.relu(self.linear_f(hypo_ff))

        e_ij = torch.bmm(prem_ff, torch.transpose(hypo_ff, 1, 2))
        #print('eij', e_ij.size())
        beta_ij = F.softmax(e_ij)
        #print('beta_ij', beta_ij.size())
        beta_i = torch.bmm(beta_ij, hypo_emb)

        e_ji = torch.transpose(e_ij, 1, 2)
        alpha_ji = F.softmax(e_ji)
        alpha_j = torch.bmm(alpha_ji, prem_emb)
        
        aligned_1 = torch.cat((prem_emb, beta_i), 2)       
        aligned_2 = torch.cat((hypo_emb, alpha_j), 2)
        #print(aligned_2.size())
        
        # Compare
        aligned_1 = self.dropout(aligned_1)
        aligned_1 = F.relu(self.linear_g(aligned_1))
        aligned_1 = self.dropout(aligned_1)
        aligned_1 = F.relu(self.linear_f(aligned_1))
        
        aligned_2 = self.dropout(aligned_2)
        #print(aligned_2.size())
        aligned_2 = F.relu(self.linear_g(aligned_2))
        aligned_2 = self.dropout(aligned_2)
        aligned_2 = F.relu(self.linear_f(aligned_2))
        #print(aligned_2.size())
        
        # Aggregate
        v_1 = torch.sum(aligned_1, 1)
        v_2 = torch.sum(aligned_2, 1)
        #print(v_1.size())
        v_concat = torch.cat((v_1, v_2), 1)
        
        v_concat = self.dropout(v_concat)
        v_concat = F.relu(self.linear_g(v_concat))
        v_concat = self.dropout(v_concat)
        v_concat = F.relu(self.linear_f(v_concat))
        #print(v_concat.size())
        
        out = F.log_softmax(self.linear_final(v_concat))
        
        return out

In [317]:
def training_loop(model, loss, optimizer, train_iter, dev_iter):
    step = 0
    for i in range(num_train_steps):
        model.train()
        for batch in train_iter:
            premise = batch.premise.transpose(0,1)
            hypothesis = batch.hypothesis.transpose(0,1)
            labels = batch.label-1
            model.zero_grad()
            output = model(premise, hypothesis)
            lossy = loss(output, labels)
            #print(lossy)
            lossy.backward()
            optimizer.step()

            if step % 10 == 0:
                print( "Step %i; Loss %f; Dev acc %f" 
                %(step, lossy.data[0], evaluate(model, dev_iter)))

            step += 1

In [318]:
def evaluate(model, data_iter):
    model.eval()
    correct = 0
    total = 0
    for batch in data_iter:
        premise = batch.premise.transpose(0,1)
        hypothesis = batch.hypothesis.transpose(0,1)
        labels = (batch.label-1).data
        output = model(premise, hypothesis)
        _, predicted = torch.max(output.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
    model.train()
    return correct / float(total)

In [319]:
vocab_size = len(inputs.vocab)
input_size = vocab_size
num_labels = 3
hidden_dim = 50
embedding_dim = 300
batch_size = 32
learning_rate = 0.004
num_train_steps = 1000

In [320]:
model = DecomposableAttention(input_size, embedding_dim, hidden_dim, num_labels)
    
# Loss and Optimizer
loss = nn.CrossEntropyLoss()  
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
training_loop(model, loss, optimizer, train_iter, dev_iter)

  """


Step 0; Loss 1.256658; Dev acc 0.336212
Step 10; Loss 1.111215; Dev acc 0.337736
Step 20; Loss 1.103053; Dev acc 0.333164
Step 30; Loss 1.096885; Dev acc 0.337939
Step 40; Loss 1.100307; Dev acc 0.352571
Step 50; Loss 1.118882; Dev acc 0.390165
Step 60; Loss 1.121069; Dev acc 0.404389
Step 70; Loss 1.052480; Dev acc 0.425422
Step 80; Loss 1.058879; Dev acc 0.437309
Step 90; Loss 1.093908; Dev acc 0.420037
Step 100; Loss 0.968255; Dev acc 0.416988
Step 110; Loss 0.916794; Dev acc 0.425726
Step 120; Loss 1.050667; Dev acc 0.414245
Step 130; Loss 1.171308; Dev acc 0.438021
Step 140; Loss 1.102476; Dev acc 0.440154
Step 150; Loss 1.041205; Dev acc 0.379191
Step 160; Loss 1.112971; Dev acc 0.365170
Step 170; Loss 1.091776; Dev acc 0.431721
Step 180; Loss 1.102773; Dev acc 0.447978
Step 190; Loss 1.056771; Dev acc 0.452550
Step 200; Loss 1.068482; Dev acc 0.439342
Step 210; Loss 1.140976; Dev acc 0.420341
Step 220; Loss 1.021827; Dev acc 0.442085
Step 230; Loss 1.037763; Dev acc 0.433550
Ste

KeyboardInterrupt: 