In [49]:
import time
import math
import torch

from torch import nn
import torch.optim as O
import torch.nn.functional as F
from torchtext import data, vocab, datasets

In [107]:
class Parameters():
    def __init__(self):
        # gpu
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # word vectors
        self.embed_size = 50
        self.word_vectors = True
        self.glove_path = '/home/ndg/users/jkurre/mnli/utils/embeddings/glove.6B.50d.txt'
        # model configs
        self.hidden_size = 1024
        self.batch_size = 32
        self.input_size = 76790
        self.output_size = 4
        self.n_layers = 2
        self.n_cells = 4
        self.dropout = 0.5
        # training
        self.epochs = 5
        self.learning_rate = 0.001

params = Parameters()

In [108]:
inputs = data.Field(
    lower=True,
    tokenize='spacy'
)

answers = data.Field(
    sequential=False
)

train, dev, test = datasets.MultiNLI.splits(
    text_field=inputs,
    label_field=answers
    )

inputs.build_vocab(train, dev, test)

if params.word_vectors:
    inputs.vocab.load_vectors(vocab.Vectors(params.glove_path, cache="."))

answers.build_vocab(train)

In [109]:
params.n_embed = len(inputs.vocab)
params.d_out = len(answers.vocab)

print(f"Unique tokens in inputs vocabulary: {params.n_embed}")
print(f"Unique tokens in answers vocabulary: {params.d_out}")

Unique tokens in inputs vocabulary: 76790
Unique tokens in answers vocabulary: 4


In [110]:
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train, dev, test), batch_size=params.batch_size, device=params.device)

In [118]:
class MultiNLIModel(nn.Module):
    def __init__(self, input_size, output_size, embed_size, 
                 hidden_size, dropout, n_layers, n_cells):
        
        super(MultiNLIModel, self).__init__()
        
        self.hidden_size = hidden_size
        self.n_cells = n_cells
        
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.embed = nn.Embedding(input_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size,
                            num_layers=n_layers, dropout=dropout, 
                            bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, output_size, bias=False)
        self.attention = Attention(hidden_size)
    
    def encode(self, pair_embed, batch_size):
        state_shape = self.n_cells, batch_size, self.hidden_size
        h0 = c0 = pair_embed.new_zeros(state_shape)
        outputs, (ht, ct) = self.lstm(pair_embed, (h0, c0))
        return ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
        
    def forward(self, pair):
        # get batch size
        batch_size = pair.batch_size
        
        # seq_length, batch_size, embed_size
        prem_embed = self.embed(pair.premise)
        hypo_embed = self.embed(pair.hypothesis)
        
        # fix word embeddings
        prem_embed.detach()
        hypo_embed.detach()
        
        # seq_length * 2, batch_size, embed_size
        pair_embed = torch.cat((prem_embed, hypo_embed),0)
        pair_embed = self.encode(pair_embed, batch_size)

        # seq_length * 2, batch_size, output_size
        pair_embed = self.relu(self.fc(pair_embed))
        
        return pair_embed

In [119]:
model = MultiNLIModel(params.input_size, params.output_size, params.embed_size,
                      params.hidden_size, params.dropout, params.n_layers, params.n_cells).to(params.device)

In [121]:
criterion = nn.CrossEntropyLoss()
opt = O.Adam(model.parameters(), lr=params.learning_rate)

for epoch in range(params.epochs):
    train_iterator.init_epoch()
    n_correct, n_total = 0, 0
    for batch_idx, batch in enumerate(train_iterator):
        
        # switch model to training mode, clear gradient accumulators
        model.train();
        opt.zero_grad()

        iterations += 1

        # forward pass
        answer = model(batch)
        break
    break

torch.Size([32, 2048])


In [124]:
answer

tensor([[5.0365e-03, 1.4996e-03, 0.0000e+00, 5.3676e-03],
        [1.5427e-02, 0.0000e+00, 1.9569e-02, 4.4181e-03],
        [1.2586e-02, 0.0000e+00, 0.0000e+00, 4.6843e-03],
        [7.9193e-03, 0.0000e+00, 0.0000e+00, 1.3234e-02],
        [7.1614e-03, 1.2290e-04, 0.0000e+00, 3.2336e-03],
        [1.5495e-02, 0.0000e+00, 0.0000e+00, 5.6884e-03],
        [7.0269e-03, 1.3025e-03, 1.4302e-02, 1.3160e-03],
        [2.5676e-02, 0.0000e+00, 0.0000e+00, 1.3555e-02],
        [6.0713e-03, 0.0000e+00, 0.0000e+00, 1.0785e-02],
        [1.4159e-02, 7.0453e-04, 3.8457e-03, 1.7475e-02],
        [1.4288e-02, 0.0000e+00, 0.0000e+00, 4.1077e-03],
        [1.1782e-02, 0.0000e+00, 1.5258e-03, 4.7654e-03],
        [1.4330e-02, 0.0000e+00, 6.0556e-03, 1.2336e-02],
        [2.0785e-02, 0.0000e+00, 0.0000e+00, 1.4587e-02],
        [1.2617e-02, 0.0000e+00, 0.0000e+00, 1.6060e-02],
        [2.1105e-02, 6.3266e-03, 0.0000e+00, 0.0000e+00],
        [1.9460e-02, 8.2628e-03, 2.9484e-03, 1.8655e-03],
        [1.531