In [1]:
from torchtext import data, datasets
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import re
import random
import torch.optim as optim

In [9]:
inputs = datasets.snli.ParsedTextField(lower=True)
answers = data.Field(sequential=False)

train, dev, test = datasets.SNLI.splits(inputs,answers)

inputs.build_vocab(train, dev, test,vectors='glove.6B.300d')

answers.build_vocab(train)

train_iter, dev_iter, test_iter = data.BucketIterator.splits(
            (train, dev, test), batch_size=64, device=-1)

.vector_cache/glove.6B.zip: 862MB [04:16, 3.36MB/s]                               
100%|██████████| 400000/400000 [01:05<00:00, 6144.97it/s]


In [10]:
len(inputs.vocab)

57324

In [11]:
batch = next(iter(train_iter))

In [12]:
example = train[0]

In [13]:
train[3].hypothesis

['they', 'are', 'smiling', 'at', 'their', 'parents']

In [14]:
example.hypothesis

['a', 'person', 'is', 'training', 'his', 'horse', 'for', 'a', 'competition.']

In [15]:
batch = next(iter(train_iter))
print(batch.premise)
print(batch.hypothesis)
print(batch.label)

Variable containing:
     1      1      1  ...       1      1      1
   442   1530   1480  ...     978   5153    419
     2     34     31  ...       3      2     19
        ...            ⋱           ...         
    11    111      4  ...      93   2217      5
     6     14     12  ...     694    342      6
     2     13      2  ...      13      2      2
[torch.LongTensor of size 17x64]

Variable containing:

Columns 0 to 10 
     1      1      1      1      1      1      1      1      1      1      1
     1      1      1      1      1      1      1      1      1      1      1
   327    658    388   1801   4202   2417  13900     62    230   2357    690
     2     91      3      2     41      2   7238      3    367      2      2
    22   1824     16     22      9    129     22   2441    819      7     16
     5      9    455     14   1715      5      5      5      4      5      5
    29     42     12      9     13    536     12     12     76     39     46
     2     13      3     45    

In [9]:
# A Multi-Layer Perceptron (MLP)
class MLPClassifier(nn.Module): # inheriting from nn.Module!
    
    def __init__(self, input_size, embedding_dim, hidden_dim, num_labels):
        super(MLPClassifier, self).__init__()
        
        # Define the parameters that you will need.  
        # You need an embedding matrix, parameters for affine mappings and ReLus
        # Pay attention to dimensions!
        
        self.embed = nn.Embedding(input_size, embedding_dim, padding_idx=0)
        self.dropout = nn.Dropout(p=0.5)
            
        self.linear_1 = nn.Linear(2*embedding_dim, hidden_dim) 
        self.linear_2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear_3 = nn.Linear(hidden_dim, num_labels)
        self.init_weights()
        
    def forward(self, prem, hypo):
        # Pass the input through your layers in order
        emb_prem = self.embed(prem).mean(1)
        emb_hypo = self.embed(hypo).mean(1)
        emb_concat = torch.cat([emb_prem, emb_hypo],1)
        out = self.dropout(emb_concat)
        out = F.relu(self.linear_1(out))
        out = F.relu(self.linear_2(out))
        out = self.dropout(self.linear_3(out))
        return F.log_softmax(out)

    def init_weights(self):
        initrange = 0.1
        lin_layers = [self.linear_1, self.linear_2]
        em_layer = [self.embed]
     
        for layer in lin_layers+em_layer:
            layer.weight.data.uniform_(-initrange, initrange)
            if layer in lin_layers:
                layer.bias.data.fill_(0)

In [25]:
'''
baseline model for Stanford natural language inference
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class encoder(nn.Module):

    def __init__(self, num_embeddings, embedding_size, hidden_size, para_init):
        super(encoder, self).__init__()

        self.num_embeddings = num_embeddings
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.para_init = para_init

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_size)
        self.input_linear = nn.Linear(
            self.embedding_size, self.hidden_size, bias=False)  # linear transformation
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, self.para_init)
                # m.bias.data.uniform_(-0.01, 0.01)

    def forward(self, sent1, sent2):
        '''
               sent: batch_size x length (Long tensor)
        '''
        batch_size = sent1.size(0)
        sent1 = self.embedding(sent1)
        sent2 = self.embedding(sent2)

        sent1 = sent1.view(-1, self.embedding_size)
        sent2 = sent2.view(-1, self.embedding_size)

        sent1_linear = self.input_linear(sent1).view(
            batch_size, -1, self.hidden_size)
        sent2_linear = self.input_linear(sent2).view(
            batch_size, -1, self.hidden_size)

        return sent1_linear, sent2_linear

class atten(nn.Module):
    '''
        intra sentence attention
    '''

    def __init__(self, hidden_size, label_size, para_init):
        super(atten, self).__init__()

        self.hidden_size = hidden_size
        self.label_size = label_size
        self.para_init = para_init

        self.mlp_f = self._mlp_layers(self.hidden_size, self.hidden_size)
        self.mlp_g = self._mlp_layers(2 * self.hidden_size, self.hidden_size)
        self.mlp_h = self._mlp_layers(2 * self.hidden_size, self.hidden_size)

        self.final_linear = nn.Linear(
            self.hidden_size, self.label_size, bias=True)

        self.log_prob = nn.LogSoftmax()

        '''initialize parameters'''
        for m in self.modules():
            # print m
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0, self.para_init)
                m.bias.data.normal_(0, self.para_init)

    def _mlp_layers(self, input_dim, output_dim):
        mlp_layers = []
        mlp_layers.append(nn.Dropout(p=0.2))
        mlp_layers.append(nn.Linear(
            input_dim, output_dim, bias=True))
        mlp_layers.append(nn.ReLU())
        mlp_layers.append(nn.Dropout(p=0.2))
        mlp_layers.append(nn.Linear(
            output_dim, output_dim, bias=True))
        mlp_layers.append(nn.ReLU())        
        return nn.Sequential(*mlp_layers)   # * used to unpack list

    def forward(self, sent1_linear, sent2_linear):
        '''
            sent_linear: batch_size x length x hidden_size
        '''
        len1 = sent1_linear.size(1)
        len2 = sent2_linear.size(1)

        '''attend'''

        f1 = self.mlp_f(sent1_linear.view(-1, self.hidden_size))
        f2 = self.mlp_f(sent2_linear.view(-1, self.hidden_size))

        f1 = f1.view(-1, len1, self.hidden_size)
        # batch_size x len1 x hidden_size
        f2 = f2.view(-1, len2, self.hidden_size)
        # batch_size x len2 x hidden_size

        score1 = torch.bmm(f1, torch.transpose(f2, 1, 2))
        # e_{ij} batch_size x len1 x len2
        prob1 = F.softmax(score1.view(-1, len2)).view(-1, len1, len2)
        # batch_size x len1 x len2

        score2 = torch.transpose(score1.contiguous(), 1, 2)
        score2 = score2.contiguous()
        # e_{ji} batch_size x len2 x len1
        prob2 = F.softmax(score2.view(-1, len1)).view(-1, len2, len1)
        # batch_size x len2 x len1

        sent1_combine = torch.cat(
            (sent1_linear, torch.bmm(prob1, sent2_linear)), 2)
        # batch_size x len1 x (hidden_size x 2)
        sent2_combine = torch.cat(
            (sent2_linear, torch.bmm(prob2, sent1_linear)), 2)
        # batch_size x len2 x (hidden_size x 2)

        '''sum'''
        g1 = self.mlp_g(sent1_combine.view(-1, 2 * self.hidden_size))
        g2 = self.mlp_g(sent2_combine.view(-1, 2 * self.hidden_size))
        g1 = g1.view(-1, len1, self.hidden_size)
        # batch_size x len1 x hidden_size
        g2 = g2.view(-1, len2, self.hidden_size)
        # batch_size x len2 x hidden_size

        sent1_output = torch.sum(g1, 1)  # batch_size x 1 x hidden_size
        sent1_output = torch.squeeze(sent1_output, 1)
        sent2_output = torch.sum(g2, 1)  # batch_size x 1 x hidden_size
        sent2_output = torch.squeeze(sent2_output, 1)

        input_combine = torch.cat((sent1_output, sent2_output), 1)
        # batch_size x (2 * hidden_size)
        h = self.mlp_h(input_combine)
        # batch_size * hidden_size

        # if sample_id == 15:
        #     print '-2 layer'
        #     print h.data[:, 100:150]

        h = self.final_linear(h)

        # print 'final layer'
        # print h.data

        log_prob = self.log_prob(h)

        return log_prob

In [26]:
def training_loop(model, loss, optimizer, train_iter, dev_iter):
    step = 0
    encoder_model = model[0]
    atten_model = model[1]
    input_optimizer = optimizer[0]
    inter_atten_optimizer = optimizer[1]
    for i in range(num_train_steps):
        encoder_model.train()
        atten_model.train()
        for batch in train_iter:
            premise = batch.premise.transpose(0,1)
            hypothesis = batch.hypothesis.transpose(0,1)
            labels = batch.label-1
            input_optimizer.zero_grad()
            inter_atten_optimizer.zero_grad()
            premise_lin, hypo_lin = encoder_model(premise, hypothesis)
            output = atten_model(premise_lin, hypo_lin)
            lossy = loss(output, labels)
            #print(lossy)
            lossy.backward()
            input_optimizer.step()
            inter_atten_optimizer.step()

            if step % 10 == 0:
                print( "Step %i; Loss %f; Dev acc %f" 
                %(step, lossy.data[0], evaluate(model, dev_iter)))

            step += 1

In [27]:
def evaluate(model, data_iter):
    model[0].eval()
    model[1].eval()
    correct = 0
    total = 0
    for batch in data_iter:
        premise = batch.premise.transpose(0,1)
        hypothesis = batch.hypothesis.transpose(0,1)
        labels = (batch.label-1).data
        premise_lin, hypo_lin = model[0](premise, hypothesis)
        output = model[1](premise_lin, hypo_lin)
        _, predicted = torch.max(output.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
    model[0].train()
    model[1].train()
    return correct / float(total)

In [49]:
vocab_size = len(inputs.vocab)
input_size = vocab_size
num_labels = 3
hidden_dim = 200
embedding_dim = 300
batch_size = 64
learning_rate = 0.004
num_train_steps = 50000000
para_init =0.01
weight_decay = 5e-5
optimizer_init = 0.1

In [45]:
#glove_home = './glove_6B/'
words_to_load = vocab_size

import numpy as np

# pre-trained word vectors
with open('glove.6B.300d.txt') as f:
    word_vecs = np.zeros((words_to_load, embedding_dim)) #dim: (50000, 50)
    words = {}
    idx2words = {}
    ordered_words = []
    for i, line in enumerate(f):
        if i >= words_to_load: 
            break
        s = line.split()
        
        word_vecs[i, :] = np.asarray(s[1:])
        words[s[0]] = i
        idx2words[i] = s[0]
        ordered_words.append(s[0])
word_vecs = torch.from_numpy(word_vecs)

In [51]:
input_encoder = encoder(word_vecs.size(0), embedding_dim, hidden_dim,para_init)
input_encoder.embedding.weight.data.copy_(word_vecs)
input_encoder.embedding.weight.requires_grad = False
inter_atten = atten(hidden_dim, num_labels, para_init)

In [52]:
para1 = filter(lambda p: p.requires_grad, input_encoder.parameters())
para2 = inter_atten.parameters()

In [None]:
loss = nn.CrossEntropyLoss()  
input_optimizer = optim.Adagrad(para1, lr=learning_rate)
inter_atten_optimizer = optim.Adagrad(para2, lr=learning_rate)
training_loop([input_encoder,inter_atten], loss, [input_optimizer,inter_atten_optimizer], train_iter, dev_iter)



Step 0; Loss 1.097516; Dev acc 0.338244
Step 10; Loss 1.107534; Dev acc 0.338244
Step 20; Loss 1.105402; Dev acc 0.338244
Step 30; Loss 1.094640; Dev acc 0.338244
Step 40; Loss 1.102262; Dev acc 0.338244
Step 50; Loss 1.095276; Dev acc 0.338244
Step 60; Loss 1.101535; Dev acc 0.338244
Step 70; Loss 1.097106; Dev acc 0.349827
Step 80; Loss 1.097952; Dev acc 0.338854
Step 90; Loss 1.093613; Dev acc 0.347287
Step 100; Loss 1.102106; Dev acc 0.342105
Step 110; Loss 1.103661; Dev acc 0.402459
Step 120; Loss 1.083205; Dev acc 0.352367
Step 130; Loss 1.075743; Dev acc 0.420748
Step 140; Loss 1.073250; Dev acc 0.391079
Step 150; Loss 1.031930; Dev acc 0.350437
Step 160; Loss 1.078828; Dev acc 0.422374
Step 170; Loss 1.016546; Dev acc 0.415261
Step 180; Loss 1.093939; Dev acc 0.426844
Step 190; Loss 1.032353; Dev acc 0.428368
Step 200; Loss 1.087542; Dev acc 0.401748
Step 210; Loss 1.051501; Dev acc 0.423593
Step 220; Loss 1.111732; Dev acc 0.441170
Step 230; Loss 1.089261; Dev acc 0.452347
Ste

Step 1940; Loss 1.020494; Dev acc 0.475513
Step 1950; Loss 1.000687; Dev acc 0.477139
Step 1960; Loss 1.057447; Dev acc 0.477240
Step 1970; Loss 0.983266; Dev acc 0.470738
Step 1980; Loss 1.033582; Dev acc 0.466877
Step 1990; Loss 1.028756; Dev acc 0.477139
Step 2000; Loss 1.006563; Dev acc 0.475818
Step 2010; Loss 1.078742; Dev acc 0.480390
Step 2020; Loss 1.054736; Dev acc 0.479984
Step 2030; Loss 0.943384; Dev acc 0.481305
Step 2040; Loss 1.039329; Dev acc 0.480797
Step 2050; Loss 1.103858; Dev acc 0.480085
Step 2060; Loss 1.029112; Dev acc 0.481101
Step 2070; Loss 1.003457; Dev acc 0.478968
Step 2080; Loss 1.038387; Dev acc 0.479882
Step 2090; Loss 1.047266; Dev acc 0.482219
Step 2100; Loss 1.038920; Dev acc 0.482117
Step 2110; Loss 1.011539; Dev acc 0.483032
Step 2120; Loss 1.030318; Dev acc 0.480898
Step 2130; Loss 1.048042; Dev acc 0.482422
Step 2140; Loss 1.082131; Dev acc 0.475208
Step 2150; Loss 1.001114; Dev acc 0.478561
Step 2160; Loss 1.001402; Dev acc 0.477444
Step 2170; 

Step 3850; Loss 1.050406; Dev acc 0.492786
Step 3860; Loss 0.931196; Dev acc 0.492684
Step 3870; Loss 1.070361; Dev acc 0.496342
Step 3880; Loss 0.984535; Dev acc 0.494818
Step 3890; Loss 0.985229; Dev acc 0.493904
Step 3900; Loss 1.072940; Dev acc 0.496850
Step 3910; Loss 1.055887; Dev acc 0.495936
Step 3920; Loss 1.017643; Dev acc 0.492888
Step 3930; Loss 1.037676; Dev acc 0.487096
Step 3940; Loss 1.048766; Dev acc 0.496139
Step 3950; Loss 0.873986; Dev acc 0.490856
Step 3960; Loss 1.008761; Dev acc 0.490551
Step 3970; Loss 0.865696; Dev acc 0.491059
Step 3980; Loss 1.085487; Dev acc 0.494412
Step 3990; Loss 1.124855; Dev acc 0.491872
Step 4000; Loss 1.017875; Dev acc 0.491567
Step 4010; Loss 1.061775; Dev acc 0.493192
Step 4020; Loss 1.043727; Dev acc 0.492176
Step 4030; Loss 0.975189; Dev acc 0.494717
Step 4040; Loss 1.006584; Dev acc 0.494310
Step 4050; Loss 1.037454; Dev acc 0.495529
Step 4060; Loss 0.945582; Dev acc 0.493802
Step 4070; Loss 1.033705; Dev acc 0.495733
Step 4080; 

Step 5760; Loss 1.074105; Dev acc 0.498273
Step 5770; Loss 1.160983; Dev acc 0.501016
Step 5780; Loss 1.003788; Dev acc 0.499086
Step 5790; Loss 0.903169; Dev acc 0.500305
Step 5800; Loss 0.962369; Dev acc 0.499898
Step 5810; Loss 0.962080; Dev acc 0.499390
Step 5820; Loss 0.987918; Dev acc 0.497561
Step 5830; Loss 0.928381; Dev acc 0.500610
Step 5840; Loss 0.939286; Dev acc 0.501524
Step 5850; Loss 0.915123; Dev acc 0.499390
Step 5860; Loss 1.124410; Dev acc 0.502134
Step 5870; Loss 1.062809; Dev acc 0.500914
Step 5880; Loss 0.935979; Dev acc 0.498679
Step 5890; Loss 1.058529; Dev acc 0.500203
Step 5900; Loss 0.965678; Dev acc 0.500000
Step 5910; Loss 0.947947; Dev acc 0.495529
Step 5920; Loss 1.011692; Dev acc 0.498578
Step 5930; Loss 0.980662; Dev acc 0.495834
Step 5940; Loss 1.064087; Dev acc 0.497968
Step 5950; Loss 0.990043; Dev acc 0.499492
Step 5960; Loss 1.069820; Dev acc 0.499898
Step 5970; Loss 0.986999; Dev acc 0.499797
Step 5980; Loss 1.053754; Dev acc 0.498374
Step 5990; 

Step 7670; Loss 0.963273; Dev acc 0.506096
Step 7680; Loss 0.998765; Dev acc 0.503658
Step 7690; Loss 1.072097; Dev acc 0.507925
Step 7700; Loss 0.992789; Dev acc 0.504775
Step 7710; Loss 1.013128; Dev acc 0.504979
Step 7720; Loss 1.108023; Dev acc 0.506503
Step 7730; Loss 1.021104; Dev acc 0.501422
Step 7740; Loss 1.006647; Dev acc 0.501829
Step 7750; Loss 1.039706; Dev acc 0.504471
Step 7760; Loss 1.001991; Dev acc 0.509144
Step 7770; Loss 1.034575; Dev acc 0.505588
Step 7780; Loss 0.933293; Dev acc 0.505588
Step 7790; Loss 1.002883; Dev acc 0.508332
Step 7800; Loss 0.948351; Dev acc 0.504267
Step 7810; Loss 1.090950; Dev acc 0.505487
Step 7820; Loss 1.025625; Dev acc 0.505283
Step 7830; Loss 1.091346; Dev acc 0.506096
Step 7840; Loss 1.007271; Dev acc 0.504674
Step 7850; Loss 0.953918; Dev acc 0.501931
Step 7860; Loss 1.091390; Dev acc 0.505385
Step 7870; Loss 1.052401; Dev acc 0.506198
Step 7880; Loss 1.010531; Dev acc 0.505385
Step 7890; Loss 0.975974; Dev acc 0.507519
Step 7900; 

Step 9580; Loss 1.055220; Dev acc 0.508535
Step 9590; Loss 1.087037; Dev acc 0.502743
Step 9600; Loss 0.912459; Dev acc 0.508840
Step 9610; Loss 1.068263; Dev acc 0.506096
Step 9620; Loss 1.023793; Dev acc 0.511583
Step 9630; Loss 0.914247; Dev acc 0.507620
Step 9640; Loss 0.787379; Dev acc 0.506503
Step 9650; Loss 0.894310; Dev acc 0.502845
Step 9660; Loss 0.933603; Dev acc 0.506604
Step 9670; Loss 0.986758; Dev acc 0.507011
Step 9680; Loss 0.995458; Dev acc 0.509246
Step 9690; Loss 0.977483; Dev acc 0.507722
Step 9700; Loss 1.020639; Dev acc 0.504471
Step 9710; Loss 0.879995; Dev acc 0.507112
Step 9720; Loss 1.042867; Dev acc 0.505792
Step 9730; Loss 1.001780; Dev acc 0.506706
Step 9740; Loss 0.752722; Dev acc 0.505080
Step 9750; Loss 1.070056; Dev acc 0.503150
Step 9760; Loss 0.981172; Dev acc 0.504979
Step 9770; Loss 1.009093; Dev acc 0.509754
Step 9780; Loss 1.006658; Dev acc 0.504775
Step 9790; Loss 1.097699; Dev acc 0.509246
Step 9800; Loss 0.954635; Dev acc 0.508941
Step 9810; 

In [21]:
model = MLPClassifier(input_size, embedding_dim, hidden_dim, num_labels)
    
# Loss and Optimizer
loss = nn.CrossEntropyLoss()  
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
training_loop(model, loss, optimizer, train_iter, dev_iter)

torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 15])
torch.Size([64, 11])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Siz



torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 15])
torch.Size([64, 11])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Siz

torch.Size([64, 12])
torch.Size([64, 12])
torch.Size([64, 12])
torch.Size([64, 14])
torch.Size([64, 11])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 16])
torch.Size([64, 13])
torch.Size([64, 16])
torch.Size([64, 13])
torch.Size([64, 10])
torch.Size([64, 14])
torch.Size([64, 10])
torch.Size([64, 14])
torch.Size([64, 10])
torch.Size([64, 14])
torch.Size([64, 11])
torch.Size([64, 13])
torch.Size([64, 12])
torch.Size([64, 14])
torch.Size([64, 12])
torch.Size([64, 15])
torch.Size([64, 12])
torch.Size([64, 16])
torch.Size([64, 10])
torch.Size([64, 16])
torch.Size([64, 10])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 16])
torch.Size([64, 14])
torch.Size([64, 14])
torch.Size([64, 16])
torch.Size([64, 16])
torch.Size([64, 16])
torch.Size([64, 16])
torch.Size([64, 20])
torch.Size([64, 19])
torch.Size([64, 25])
torch.Size([64, 20])
torch.Size([64, 6])
torch.Size([64, 18])
torch.Size([64, 6])
torch.Size([64, 18])
torch.Size([64, 7])
torch.Size([64, 17])
torch.Size([64, 

torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 15])
torch.Size([64, 11])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Siz

torch.Size([64, 14])
torch.Size([64, 8])
torch.Size([64, 15])
torch.Size([64, 8])
torch.Size([64, 15])
torch.Size([64, 6])
torch.Size([64, 16])
torch.Size([64, 6])
torch.Size([64, 16])
torch.Size([64, 6])
torch.Size([64, 16])
torch.Size([64, 7])
torch.Size([64, 15])
torch.Size([64, 8])
torch.Size([64, 16])
torch.Size([64, 8])
torch.Size([64, 16])
torch.Size([64, 7])
torch.Size([64, 16])
torch.Size([64, 9])
torch.Size([64, 9])
torch.Size([64, 9])
torch.Size([64, 9])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 10])
torch.Size([64, 12])
torch.Size([64, 10])
torch.Size([64, 12])
torch.Size([64, 11])
torch.Size([64, 12])
torch.Size([64, 11])
torch.Size([64, 10])
torch.Size([64, 11])
torch.Size([64, 10])
torch.Size([64, 12])
torch.Size([64, 10])
torch.Size([64, 12])
torch.Size([64, 10])
torch.Size([64, 12])
torch.Size([64, 11])
torch.Size([64, 11])
torch.Size([64, 12])
torch.Size([64, 12])
torch.Size([64, 12])
torch

torch.Size([64, 18])
torch.Size([64, 8])
torch.Size([64, 18])
torch.Size([64, 7])
torch.Size([64, 18])
torch.Size([64, 8])
torch.Size([64, 19])
torch.Size([64, 8])
torch.Size([64, 20])
torch.Size([64, 6])
torch.Size([64, 20])
torch.Size([64, 6])
torch.Size([64, 20])
torch.Size([64, 7])
torch.Size([64, 19])
torch.Size([64, 8])
torch.Size([64, 20])
torch.Size([64, 8])
torch.Size([64, 21])
torch.Size([64, 8])
torch.Size([64, 24])
torch.Size([64, 5])
torch.Size([64, 22])
torch.Size([64, 6])
torch.Size([64, 22])
torch.Size([64, 7])
torch.Size([64, 21])
torch.Size([64, 8])
torch.Size([64, 22])
torch.Size([64, 8])
torch.Size([64, 23])
torch.Size([64, 8])
torch.Size([64, 24])
torch.Size([64, 7])
torch.Size([64, 24])
torch.Size([64, 8])
torch.Size([64, 24])
torch.Size([64, 9])
torch.Size([64, 17])
torch.Size([64, 10])
torch.Size([64, 18])
torch.Size([64, 10])
torch.Size([64, 18])
torch.Size([64, 12])
torch.Size([64, 18])
torch.Size([64, 12])
torch.Size([64, 19])
torch.Size([64, 12])
torch.Size(

torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 15])
torch.Size([64, 11])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Siz

torch.Size([64, 16])
torch.Size([64, 10])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 16])
torch.Size([64, 14])
torch.Size([64, 14])
torch.Size([64, 16])
torch.Size([64, 16])
torch.Size([64, 16])
torch.Size([64, 16])
torch.Size([64, 20])
torch.Size([64, 19])
torch.Size([64, 25])
torch.Size([64, 20])
torch.Size([64, 6])
torch.Size([64, 18])
torch.Size([64, 6])
torch.Size([64, 18])
torch.Size([64, 7])
torch.Size([64, 17])
torch.Size([64, 8])
torch.Size([64, 18])
torch.Size([64, 8])
torch.Size([64, 18])
torch.Size([64, 7])
torch.Size([64, 18])
torch.Size([64, 8])
torch.Size([64, 19])
torch.Size([64, 8])
torch.Size([64, 20])
torch.Size([64, 6])
torch.Size([64, 20])
torch.Size([64, 6])
torch.Size([64, 20])
torch.Size([64, 7])
torch.Size([64, 19])
torch.Size([64, 8])
torch.Size([64, 20])
torch.Size([64, 8])
torch.Size([64, 21])
torch.Size([64, 8])
torch.Size([64, 24])
torch.Size([64, 5])
torch.Size([64, 22])
torch.Size([64, 6])
torch.Size([64, 22])
torch.Size([64, 7])
torch.Siz

torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 15])
torch.Size([64, 11])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Siz

torch.Size([64, 11])
torch.Size([64, 12])
torch.Size([64, 11])
torch.Size([64, 10])
torch.Size([64, 11])
torch.Size([64, 10])
torch.Size([64, 12])
torch.Size([64, 10])
torch.Size([64, 12])
torch.Size([64, 10])
torch.Size([64, 12])
torch.Size([64, 11])
torch.Size([64, 11])
torch.Size([64, 12])
torch.Size([64, 12])
torch.Size([64, 12])
torch.Size([64, 12])
torch.Size([64, 14])
torch.Size([64, 11])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 16])
torch.Size([64, 13])
torch.Size([64, 16])
torch.Size([64, 13])
torch.Size([64, 10])
torch.Size([64, 14])
torch.Size([64, 10])
torch.Size([64, 14])
torch.Size([64, 10])
torch.Size([64, 14])
torch.Size([64, 11])
torch.Size([64, 13])
torch.Size([64, 12])
torch.Size([64, 14])
torch.Size([64, 12])
torch.Size([64, 15])
torch.Size([64, 12])
torch.Size([64, 16])
torch.Size([64, 10])
torch.Size([64, 16])
torch.Size([64, 10])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 16])
torch.Size([64, 14])
torch.Size([64, 14])
torch.Size([6

torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 15])
torch.Size([64, 11])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Siz

torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 16])
torch.Size([64, 14])
torch.Size([64, 14])
torch.Size([64, 16])
torch.Size([64, 16])
torch.Size([64, 16])
torch.Size([64, 16])
torch.Size([64, 20])
torch.Size([64, 19])
torch.Size([64, 25])
torch.Size([64, 20])
torch.Size([64, 6])
torch.Size([64, 18])
torch.Size([64, 6])
torch.Size([64, 18])
torch.Size([64, 7])
torch.Size([64, 17])
torch.Size([64, 8])
torch.Size([64, 18])
torch.Size([64, 8])
torch.Size([64, 18])
torch.Size([64, 7])
torch.Size([64, 18])
torch.Size([64, 8])
torch.Size([64, 19])
torch.Size([64, 8])
torch.Size([64, 20])
torch.Size([64, 6])
torch.Size([64, 20])
torch.Size([64, 6])
torch.Size([64, 20])
torch.Size([64, 7])
torch.Size([64, 19])
torch.Size([64, 8])
torch.Size([64, 20])
torch.Size([64, 8])
torch.Size([64, 21])
torch.Size([64, 8])
torch.Size([64, 24])
torch.Size([64, 5])
torch.Size([64, 22])
torch.Size([64, 6])
torch.Size([64, 22])
torch.Size([64, 7])
torch.Size([64, 21])
torch.Size([64, 8])
torch.Size

torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 15])
torch.Size([64, 11])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Siz

torch.Size([64, 13])
torch.Size([64, 16])
torch.Size([64, 13])
torch.Size([64, 10])
torch.Size([64, 14])
torch.Size([64, 10])
torch.Size([64, 14])
torch.Size([64, 10])
torch.Size([64, 14])
torch.Size([64, 11])
torch.Size([64, 13])
torch.Size([64, 12])
torch.Size([64, 14])
torch.Size([64, 12])
torch.Size([64, 15])
torch.Size([64, 12])
torch.Size([64, 16])
torch.Size([64, 10])
torch.Size([64, 16])
torch.Size([64, 10])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 16])
torch.Size([64, 14])
torch.Size([64, 14])
torch.Size([64, 16])
torch.Size([64, 16])
torch.Size([64, 16])
torch.Size([64, 16])
torch.Size([64, 20])
torch.Size([64, 19])
torch.Size([64, 25])
torch.Size([64, 20])
torch.Size([64, 6])
torch.Size([64, 18])
torch.Size([64, 6])
torch.Size([64, 18])
torch.Size([64, 7])
torch.Size([64, 17])
torch.Size([64, 8])
torch.Size([64, 18])
torch.Size([64, 8])
torch.Size([64, 18])
torch.Size([64, 7])
torch.Size([64, 18])
torch.Size([64, 8])
torch.Size([64, 19])
torch.Size([64, 8])


torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 6])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 7])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 10])
torch.Size([64, 8])
torch.Size([64, 12])
torch.Size([64, 8])
torch.Size([64, 15])
torch.Size([64, 11])
torch.Size([64, 16])
torch.Size([64, 12])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 5])
torch.Size([64, 9])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Size([64, 6])
torch.Size([64, 10])
torch.Siz

KeyboardInterrupt: 