Run this notebook with default settings to train the InferSent NLI model using the original trained InferSent encoder. The encoder and NLI model will be saved as "../savedir/model.pickle.encoder.pkl" and "../savedir/model.pickle", respectively.

In [None]:
# import stuff
%load_ext autoreload
%autoreload 2
%matplotlib inline

import inspect
import os
from random import randint
import time

import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
import re
from nltk.tokenize import word_tokenize

## Load model

In [2]:
# Load model
from models import InferSent
model_version = 1
MODEL_PATH = "../encoder/infersent%s.pkl" % model_version
params_model = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048,
                'pool_type': 'max', 'dpout_model': 0.0, 'version': model_version}
model = InferSent(params_model)
model.load_state_dict(torch.load(MODEL_PATH))

<All keys matched successfully>

In [3]:
# Keep it on CPU or put it on GPU
use_cuda = torch.cuda.is_available()
#or force not to use cuda
#use_cuda = False
model = model.cuda() if use_cuda else model

In [4]:
# If infersent1 -> use GloVe embeddings. If infersent2 -> use InferSent embeddings.
W2V_PATH = '../GloVe/glove.840B.300d.txt' if model_version == 1 else '../fastText/crawl-300d-2M.vec'
model.set_w2v_path(W2V_PATH)

In [17]:
tmp1 = pd.read_csv('../dataset/esnli_train_1.csv', usecols=['gold_label', 'Sentence1', 'Sentence2'])
tmp2 = pd.read_csv('../dataset/esnli_train_2.csv', usecols=['gold_label', 'Sentence1', 'Sentence2'])
train = pd.concat([tmp1, tmp2], ignore_index=True)
train.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 549367 entries, 0 to 549366
Data columns (total 3 columns):
 #   Column      Non-Null Count   Dtype 
---  ------      --------------   ----- 
 0   gold_label  549367 non-null  object
 1   Sentence1   549367 non-null  object
 2   Sentence2   549361 non-null  object
dtypes: object(3)
memory usage: 12.6+ MB


In [18]:
valid = pd.read_csv('../dataset/esnli_dev.csv', usecols=['gold_label', 'Sentence1', 'Sentence2'])
valid.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9842 entries, 0 to 9841
Data columns (total 3 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   gold_label  9842 non-null   object
 1   Sentence1   9842 non-null   object
 2   Sentence2   9842 non-null   object
dtypes: object(3)
memory usage: 230.8+ KB


In [19]:
test = pd.read_csv('../dataset/esnli_test.csv', usecols=['gold_label', 'Sentence1', 'Sentence2'])
test.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9824 entries, 0 to 9823
Data columns (total 3 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   gold_label  9824 non-null   object
 1   Sentence1   9824 non-null   object
 2   Sentence2   9824 non-null   object
dtypes: object(3)
memory usage: 230.4+ KB


In [20]:
#map label to int
label_to_int = {'entailment': 0, 'neutral': 1, 'contradiction': 2}

In [21]:
#add label int
train['label'] = train['gold_label'].apply(lambda x: label_to_int[x])
valid['label'] = valid['gold_label'].apply(lambda x: label_to_int[x])
test['label'] = test['gold_label'].apply(lambda x: label_to_int[x])

In [22]:
def build_vocab(sentences, glove_path):
    word_dict = get_word_dict(sentences)
    word_vec = get_glove(word_dict, glove_path)
    print('Vocab size : {0}'.format(len(word_vec)))
    return word_vec

def get_glove(word_dict, glove_path):
    # create word_vec with glove vectors
    word_vec = {}
    with open(glove_path, encoding='utf8') as f:
        for line in f:
            #print(line)
            #break
            word, vec = line.split(' ', 1)
            if word in word_dict:
                word_vec[word] = np.array(list(map(float, vec.split())))
    print('Found {0}(/{1}) words with glove vectors'.format(
                len(word_vec), len(word_dict)))
    return word_vec

def get_word_dict(sentences):
    # create vocab of words
    word_dict = {}
    for sent in sentences:
        for word in word_tokenize(str(sent)):
            if word not in word_dict:
                word_dict[word] = ''
            #print(word)
    word_dict['<s>'] = ''
    word_dict['</s>'] = ''
    word_dict['<p>'] = ''
    return word_dict

In [23]:
glove_path = '../GloVe/glove.840B.300d.txt'

In [24]:
#converts DataFrames to dict
train = train.to_dict(orient='list')
valid = valid.to_dict(orient='list')
test = test.to_dict(orient='list')

In [26]:
word_vec = build_vocab(train['Sentence1'] + train['Sentence2'] +
                       valid['Sentence1'] + valid['Sentence2'] +
                       test['Sentence1'] + test['Sentence2'], glove_path)

Found 38909(/43393) words with glove vectors
Vocab size : 38909


In [27]:
class NLINet(nn.Module):
    def __init__(self, config):
        super(NLINet, self).__init__()

        # classifier
        self.nonlinear_fc = config['nonlinear_fc']
        self.fc_dim = config['fc_dim']
        self.n_classes = config['n_classes']
        self.enc_lstm_dim = config['enc_lstm_dim']
        self.encoder_type = config['encoder_type']
        self.dpout_fc = config['dpout_fc']

        self.encoder = model  #eval(self.encoder_type)(config)
        self.inputdim = 4*2*self.enc_lstm_dim
        self.inputdim = 4*self.inputdim if self.encoder_type in \
                        ["ConvNetEncoder", "InnerAttentionMILAEncoder"] else self.inputdim
        self.inputdim = self.inputdim/2 if self.encoder_type == "LSTMEncoder" \
                                        else self.inputdim
        if self.nonlinear_fc:
            self.classifier = nn.Sequential(
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.inputdim, self.fc_dim),
                nn.Tanh(),
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.fc_dim, self.fc_dim),
                nn.Tanh(),
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.fc_dim, self.n_classes),
                )
        else:
            self.classifier = nn.Sequential(
                nn.Linear(self.inputdim, self.fc_dim),
                nn.Linear(self.fc_dim, self.fc_dim),
                nn.Linear(self.fc_dim, self.n_classes)
                )

    def forward(self, s1, s2):
        # s1 : (s1, s1_len)
        u = self.encoder(s1)
        v = self.encoder(s2)

        features = torch.cat((u, v, torch.abs(u-v), u*v), 1)
        output = self.classifier(features)
        return output

    def encode(self, s1):
        emb = self.encoder(s1)
        return emb


In [28]:
for split in ['Sentence1', 'Sentence2']:
    for data_type in ['train', 'valid', 'test']:
        eval(data_type)[split] = np.array([['<s>'] + \
            [word for word in word_tokenize(str(sent)) if word in word_vec] + \
            ['</s>'] for sent in eval(data_type)[split]])

  eval(data_type)[split] = np.array([['<s>'] + \


In [29]:
train['label'] = np.array(train['label'])

In [30]:
parser = argparse.ArgumentParser(description='NLI training')
# paths
parser.add_argument("--nlipath", type=str, default='dataset/SNLI/', help="NLI data path (SNLI or MultiNLI)")
parser.add_argument("--outputdir", type=str, default='../savedir/', help="Output directory")
parser.add_argument("--outputmodelname", type=str, default='model.pickle')
parser.add_argument("--word_emb_path", type=str, default="../dataset/GloVe/glove.840B.300d.txt", help="word embedding file path")

# training
parser.add_argument("--n_epochs", type=int, default=50)
parser.add_argument("--batch_size", type=int, default=64)  #64)
parser.add_argument("--dpout_model", type=float, default=0., help="encoder dropout")
parser.add_argument("--dpout_fc", type=float, default=0., help="classifier dropout")
parser.add_argument("--nonlinear_fc", type=float, default=1, help="use nonlinearity in fc")
parser.add_argument("--optimizer", type=str, default="sgd,lr=0.1", help="adam or sgd,lr=0.1")
parser.add_argument("--lrshrink", type=float, default=5, help="shrink factor for sgd")
parser.add_argument("--decay", type=float, default=0.99, help="lr decay")
parser.add_argument("--minlr", type=float, default=1e-5, help="minimum lr")
parser.add_argument("--max_norm", type=float, default=5., help="max norm (grad clipping)")

# model
parser.add_argument("--encoder_type", type=str, default='InferSentV1', help="see list of encoders")
parser.add_argument("--enc_lstm_dim", type=int, default=2048, help="encoder nhid dimension")
parser.add_argument("--n_enc_layers", type=int, default=1, help="encoder num layers")
parser.add_argument("--fc_dim", type=int, default=512, help="nhid of fc layers")
parser.add_argument("--n_classes", type=int, default=3, help="entailment/neutral/contradiction")
parser.add_argument("--pool_type", type=str, default='max', help="max or mean")

# gpu
parser.add_argument("--gpu_id", type=int, default=3, help="GPU ID")
parser.add_argument("--seed", type=int, default=1234, help="seed")

# data
parser.add_argument("--word_emb_dim", type=int, default=300, help="word embedding dimension")

params, _ = parser.parse_known_args()
config_nli_model = {
    'n_words'        :  len(word_vec)          ,
    'word_emb_dim'   :  params.word_emb_dim   ,
    'enc_lstm_dim'   :  params.enc_lstm_dim   ,
    'n_enc_layers'   :  params.n_enc_layers   ,
    'dpout_model'    :  params.dpout_model    ,
    'dpout_fc'       :  params.dpout_fc       ,
    'fc_dim'         :  params.fc_dim         ,
    'bsize'          :  params.batch_size     ,
    'n_classes'      :  params.n_classes      ,
    'pool_type'      :  params.pool_type      ,
    'nonlinear_fc'   :  params.nonlinear_fc   ,
    'encoder_type'   :  params.encoder_type   ,
    'use_cuda'       :  True                  ,

}
nli_net = NLINet(config_nli_model)
print(nli_net)

NLINet(
  (encoder): InferSent(
    (enc_lstm): LSTM(300, 2048, bidirectional=True)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.0, inplace=False)
    (1): Linear(in_features=16384, out_features=512, bias=True)
    (2): Tanh()
    (3): Dropout(p=0.0, inplace=False)
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): Tanh()
    (6): Dropout(p=0.0, inplace=False)
    (7): Linear(in_features=512, out_features=3, bias=True)
  )
)


In [31]:
"""
SEED
"""
np.random.seed(params.seed)
torch.manual_seed(params.seed)
torch.cuda.manual_seed(params.seed)

In [32]:
def get_optimizer(s):
    """
    Parse optimizer parameters.
    Input should be of the form:
        - "sgd,lr=0.01"
        - "adagrad,lr=0.1,lr_decay=0.05"
    """
    if "," in s:
        method = s[:s.find(',')]
        optim_params = {}
        for x in s[s.find(',') + 1:].split(','):
            split = x.split('=')
            assert len(split) == 2
            assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None
            optim_params[split[0]] = float(split[1])
    else:
        method = s
        optim_params = {}

    if method == 'adadelta':
        optim_fn = optim.Adadelta
    elif method == 'adagrad':
        optim_fn = optim.Adagrad
    elif method == 'adam':
        optim_fn = optim.Adam
    elif method == 'adamax':
        optim_fn = optim.Adamax
    elif method == 'asgd':
        optim_fn = optim.ASGD
    elif method == 'rmsprop':
        optim_fn = optim.RMSprop
    elif method == 'rprop':
        optim_fn = optim.Rprop
    elif method == 'sgd':
        optim_fn = optim.SGD
        assert 'lr' in optim_params
    else:
        raise Exception('Unknown optimization method: "%s"' % method)

    # check that we give good parameters to the optimizer
    #expected_args = inspect.getargspec(optim_fn.__init__)[0]
    #assert expected_args[:2] == ['self', 'params']
    #if not all(k in expected_args[2:] for k in optim_params.keys()):
    #    raise Exception('Unexpected parameters: expected "%s", got "%s"' % (
    #        str(expected_args[2:]), str(optim_params.keys())))

    return optim_fn, optim_params

In [33]:
# loss
weight = torch.FloatTensor(params.n_classes).fill_(1)
loss_fn = nn.CrossEntropyLoss(weight=weight)
loss_fn.size_average = False

# optimizer
optim_fn, optim_params = get_optimizer(params.optimizer)
optimizer = optim_fn(nli_net.parameters(), **optim_params)

# cuda by default
nli_net.cuda()
loss_fn.cuda()

CrossEntropyLoss()

In [34]:
"""
TRAIN
"""
val_acc_best = -1e10
adam_stop = False
stop_training = False
lr = optim_params['lr'] if 'sgd' in params.optimizer else None

In [35]:
def trainepoch(epoch):
    print('\nTRAINING : Epoch ' + str(epoch))
    nli_net.train()
    all_costs = []
    logs = []
    words_count = 0

    last_time = time.time()
    correct = 0.
    # shuffle the data
    permutation = np.random.permutation(len(train['Sentence1']))

    s1 = train['Sentence1'][permutation]
    s2 = train['Sentence2'][permutation]
    target = train['label'][permutation]


    optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * params.decay if epoch>1\
        and 'sgd' in params.optimizer else optimizer.param_groups[0]['lr']
    print('Learning rate : {0}'.format(optimizer.param_groups[0]['lr']))

    for stidx in range(0, len(s1), params.batch_size):
        # prepare batch
        s1_batch, s1_len = get_batch(s1[stidx:stidx + params.batch_size],
                                     word_vec, params.word_emb_dim)
        s2_batch, s2_len = get_batch(s2[stidx:stidx + params.batch_size],
                                     word_vec, params.word_emb_dim)
        s1_batch, s2_batch = Variable(s1_batch.cuda()), Variable(s2_batch.cuda())
        tgt_batch = Variable(torch.LongTensor(target[stidx:stidx + params.batch_size])).cuda()
        k = s1_batch.size(1)  # actual batch size
        
        # model forward
        output = nli_net((s1_batch, s1_len), (s2_batch, s2_len))

        pred = output.data.max(1)[1]
        correct += pred.long().eq(tgt_batch.data.long()).cpu().sum()
        assert len(pred) == len(s1[stidx:stidx + params.batch_size])

        # loss
        loss = loss_fn(output, tgt_batch)
        #print(type(loss))
        all_costs.append(loss.item())  #.data[0])
        words_count += (s1_batch.nelement() + s2_batch.nelement()) / params.word_emb_dim

        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient clipping (off by default)
        shrink_factor = 1
        total_norm = 0

        for p in nli_net.parameters():
            if p.requires_grad:
                p.grad.data.div_(k)  # divide by the actual batch size
                total_norm += p.grad.data.norm() ** 2
        total_norm = np.sqrt(total_norm.cpu())

        if total_norm > params.max_norm:
            shrink_factor = params.max_norm / total_norm
        current_lr = optimizer.param_groups[0]['lr'] # current lr (no external "lr", for adam)
        optimizer.param_groups[0]['lr'] = current_lr * shrink_factor # just for update

        # optimizer step
        optimizer.step()
        optimizer.param_groups[0]['lr'] = current_lr

        if len(all_costs) == 100:
            logs.append('{0} ; loss {1} ; sentence/s {2} ; words/s {3} ; accuracy train : {4}'.format(
                            stidx, round(np.mean(all_costs), 2),
                            int(len(all_costs) * params.batch_size / (time.time() - last_time)),
                            int(words_count * 1.0 / (time.time() - last_time)),
                            100.*correct/(stidx+k)))
            print(logs[-1])
            last_time = time.time()
            words_count = 0
            all_costs = []
    train_acc = 100 * correct/len(s1)  #round(100 * correct/len(s1), 2)
    print('results : epoch {0} ; mean accuracy train : {1}'
          .format(epoch, train_acc))
    return train_acc

In [36]:
def get_batch(batch, word_vec, emb_dim=300):
    # sent in batch in decreasing order of lengths (bsize, max_len, word_dim)
    lengths = np.array([len(x) for x in batch])
    max_len = np.max(lengths)
    embed = np.zeros((max_len, len(batch), emb_dim))

    for i in range(len(batch)):
        for j in range(len(batch[i])):
            embed[j, i, :] = word_vec[batch[i][j]]

    return torch.from_numpy(embed).float(), lengths

In [37]:
def evaluate(epoch, eval_type='valid', final_eval=False):
    nli_net.eval()
    correct = 0.
    global val_acc_best, lr, stop_training, adam_stop

    if eval_type == 'valid':
        print('\nVALIDATION : Epoch {0}'.format(epoch))

    s1 = valid['Sentence1'] if eval_type == 'valid' else test['Sentence1']
    s2 = valid['Sentence2'] if eval_type == 'valid' else test['Sentence2']
    target = valid['label'] if eval_type == 'valid' else test['label']

    for i in range(0, len(s1), params.batch_size):
        # prepare batch
        s1_batch, s1_len = get_batch(s1[i:i + params.batch_size], word_vec, params.word_emb_dim)
        s2_batch, s2_len = get_batch(s2[i:i + params.batch_size], word_vec, params.word_emb_dim)
        s1_batch, s2_batch = Variable(s1_batch.cuda()), Variable(s2_batch.cuda())
        tgt_batch = Variable(torch.LongTensor(target[i:i + params.batch_size])).cuda()

        # model forward
        output = nli_net((s1_batch, s1_len), (s2_batch, s2_len))

        pred = output.data.max(1)[1]
        correct += pred.long().eq(tgt_batch.data.long()).cpu().sum()

        
    # save model
    eval_acc = 100 * correct/len(s1)  #round(100 * correct / len(s1), 2)
    if final_eval:
        print('finalgrep : accuracy {0} : {1}'.format(eval_type, eval_acc))
    else:
        print('togrep : results : epoch {0} ; mean accuracy {1} :\
              {2}'.format(epoch, eval_type, eval_acc))

    if eval_type == 'valid' and epoch <= params.n_epochs:
        if eval_acc > val_acc_best:
            print('saving model at epoch {0}'.format(epoch))
            if not os.path.exists(params.outputdir):
                os.makedirs(params.outputdir)
            torch.save(nli_net.state_dict(), os.path.join(params.outputdir,
                       params.outputmodelname))
            val_acc_best = eval_acc
        else:
            if 'sgd' in params.optimizer:
                optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] / params.lrshrink
                print('Shrinking lr by : {0}. New lr = {1}'
                      .format(params.lrshrink,
                              optimizer.param_groups[0]['lr']))
                if optimizer.param_groups[0]['lr'] < params.minlr:
                    stop_training = True
            if 'adam' in params.optimizer:
                # early stopping (at 2nd decrease in accuracy)
                stop_training = adam_stop
                adam_stop = True
    return eval_acc


In [38]:
"""
Train model on Natural Language Inference task
"""
epoch = 1

while not stop_training and epoch <= params.n_epochs:
    train_acc = trainepoch(epoch)
    eval_acc = evaluate(epoch, 'valid')
    epoch += 1


TRAINING : Epoch 1
Learning rate : 0.1
6336 ; loss 1.1 ; sentence/s 264 ; words/s 15283 ; accuracy train : 33.078125
12736 ; loss 1.1 ; sentence/s 282 ; words/s 16184 ; accuracy train : 33.875
19136 ; loss 1.1 ; sentence/s 277 ; words/s 16420 ; accuracy train : 34.30729293823242
25536 ; loss 1.1 ; sentence/s 282 ; words/s 16326 ; accuracy train : 35.2265625
31936 ; loss 1.1 ; sentence/s 279 ; words/s 16201 ; accuracy train : 35.22187423706055
38336 ; loss 1.1 ; sentence/s 280 ; words/s 16479 ; accuracy train : 34.953125
44736 ; loss 1.09 ; sentence/s 281 ; words/s 16371 ; accuracy train : 34.96651840209961
51136 ; loss 1.09 ; sentence/s 281 ; words/s 16535 ; accuracy train : 34.916015625
57536 ; loss 1.09 ; sentence/s 281 ; words/s 16055 ; accuracy train : 34.82465362548828
63936 ; loss 1.09 ; sentence/s 276 ; words/s 16046 ; accuracy train : 34.80937576293945
70336 ; loss 1.09 ; sentence/s 281 ; words/s 16079 ; accuracy train : 34.8636360168457
76736 ; loss 1.09 ; sentence/s 284 ; wo

44736 ; loss 0.85 ; sentence/s 255 ; words/s 14700 ; accuracy train : 68.72544860839844
51136 ; loss 0.85 ; sentence/s 249 ; words/s 14755 ; accuracy train : 68.669921875
57536 ; loss 0.84 ; sentence/s 252 ; words/s 14916 ; accuracy train : 68.70138549804688
63936 ; loss 0.83 ; sentence/s 254 ; words/s 14633 ; accuracy train : 68.84375
70336 ; loss 0.83 ; sentence/s 239 ; words/s 13997 ; accuracy train : 68.92613983154297
76736 ; loss 0.82 ; sentence/s 228 ; words/s 13310 ; accuracy train : 68.91666412353516
83136 ; loss 0.81 ; sentence/s 249 ; words/s 14241 ; accuracy train : 68.96394348144531
89536 ; loss 0.81 ; sentence/s 252 ; words/s 14693 ; accuracy train : 68.99665069580078
95936 ; loss 0.8 ; sentence/s 254 ; words/s 14944 ; accuracy train : 69.08125305175781
102336 ; loss 0.8 ; sentence/s 253 ; words/s 14756 ; accuracy train : 69.0380859375
108736 ; loss 0.8 ; sentence/s 251 ; words/s 14772 ; accuracy train : 69.03584289550781
115136 ; loss 0.79 ; sentence/s 244 ; words/s 14148

83136 ; loss 0.56 ; sentence/s 288 ; words/s 16281 ; accuracy train : 77.13461303710938
89536 ; loss 0.56 ; sentence/s 282 ; words/s 16560 ; accuracy train : 77.21540069580078
95936 ; loss 0.56 ; sentence/s 287 ; words/s 16475 ; accuracy train : 77.27916717529297
102336 ; loss 0.57 ; sentence/s 287 ; words/s 16801 ; accuracy train : 77.279296875
108736 ; loss 0.57 ; sentence/s 288 ; words/s 16908 ; accuracy train : 77.30422973632812
115136 ; loss 0.56 ; sentence/s 286 ; words/s 16617 ; accuracy train : 77.33246612548828
121536 ; loss 0.55 ; sentence/s 291 ; words/s 17219 ; accuracy train : 77.36759948730469
127936 ; loss 0.55 ; sentence/s 294 ; words/s 16938 ; accuracy train : 77.3984375
134336 ; loss 0.56 ; sentence/s 282 ; words/s 16246 ; accuracy train : 77.42112731933594
140736 ; loss 0.56 ; sentence/s 284 ; words/s 16496 ; accuracy train : 77.46235656738281
147136 ; loss 0.56 ; sentence/s 245 ; words/s 14092 ; accuracy train : 77.48301696777344
153536 ; loss 0.56 ; sentence/s 225 

127936 ; loss 0.45 ; sentence/s 283 ; words/s 16357 ; accuracy train : 82.1343765258789
134336 ; loss 0.46 ; sentence/s 284 ; words/s 16864 ; accuracy train : 82.1547622680664
140736 ; loss 0.46 ; sentence/s 292 ; words/s 16871 ; accuracy train : 82.15979766845703
147136 ; loss 0.45 ; sentence/s 295 ; words/s 16872 ; accuracy train : 82.1793441772461
153536 ; loss 0.44 ; sentence/s 296 ; words/s 16923 ; accuracy train : 82.22721099853516
159936 ; loss 0.45 ; sentence/s 282 ; words/s 16445 ; accuracy train : 82.26249694824219
166336 ; loss 0.46 ; sentence/s 294 ; words/s 16811 ; accuracy train : 82.26382446289062
172736 ; loss 0.44 ; sentence/s 299 ; words/s 16711 ; accuracy train : 82.30497741699219
179136 ; loss 0.43 ; sentence/s 288 ; words/s 17014 ; accuracy train : 82.35491180419922
185536 ; loss 0.45 ; sentence/s 286 ; words/s 16425 ; accuracy train : 82.36476135253906
191936 ; loss 0.46 ; sentence/s 290 ; words/s 16826 ; accuracy train : 82.37239837646484
198336 ; loss 0.43 ; sen

166336 ; loss 0.4 ; sentence/s 285 ; words/s 16489 ; accuracy train : 84.94711303710938
172736 ; loss 0.4 ; sentence/s 284 ; words/s 16584 ; accuracy train : 84.94213104248047
179136 ; loss 0.4 ; sentence/s 288 ; words/s 17081 ; accuracy train : 84.95368194580078
185536 ; loss 0.39 ; sentence/s 290 ; words/s 16983 ; accuracy train : 84.95743560791016
191936 ; loss 0.39 ; sentence/s 284 ; words/s 16634 ; accuracy train : 84.9671859741211
198336 ; loss 0.41 ; sentence/s 283 ; words/s 16933 ; accuracy train : 84.97681427001953
204736 ; loss 0.39 ; sentence/s 291 ; words/s 17101 ; accuracy train : 84.9853515625
211136 ; loss 0.4 ; sentence/s 294 ; words/s 17136 ; accuracy train : 84.97774505615234
217536 ; loss 0.39 ; sentence/s 291 ; words/s 17135 ; accuracy train : 84.97840118408203
223936 ; loss 0.39 ; sentence/s 292 ; words/s 16888 ; accuracy train : 85.01026916503906
230336 ; loss 0.39 ; sentence/s 291 ; words/s 16920 ; accuracy train : 85.01519012451172
236736 ; loss 0.37 ; sentence/

204736 ; loss 0.35 ; sentence/s 285 ; words/s 16476 ; accuracy train : 86.33056640625
211136 ; loss 0.36 ; sentence/s 282 ; words/s 16557 ; accuracy train : 86.3508529663086
217536 ; loss 0.36 ; sentence/s 295 ; words/s 16874 ; accuracy train : 86.34099578857422
223936 ; loss 0.37 ; sentence/s 279 ; words/s 15908 ; accuracy train : 86.328125
230336 ; loss 0.36 ; sentence/s 296 ; words/s 16773 ; accuracy train : 86.33506774902344
236736 ; loss 0.35 ; sentence/s 277 ; words/s 16670 ; accuracy train : 86.3492431640625
243136 ; loss 0.37 ; sentence/s 289 ; words/s 16760 ; accuracy train : 86.33676147460938
249536 ; loss 0.35 ; sentence/s 278 ; words/s 15949 ; accuracy train : 86.35416412353516
255936 ; loss 0.36 ; sentence/s 285 ; words/s 16339 ; accuracy train : 86.3628921508789
262336 ; loss 0.35 ; sentence/s 287 ; words/s 16165 ; accuracy train : 86.38300323486328
268736 ; loss 0.36 ; sentence/s 282 ; words/s 16466 ; accuracy train : 86.38876342773438
275136 ; loss 0.35 ; sentence/s 288

249536 ; loss 0.35 ; sentence/s 278 ; words/s 16098 ; accuracy train : 87.390625
255936 ; loss 0.34 ; sentence/s 298 ; words/s 17355 ; accuracy train : 87.38008117675781
262336 ; loss 0.35 ; sentence/s 279 ; words/s 16146 ; accuracy train : 87.37461853027344
268736 ; loss 0.35 ; sentence/s 282 ; words/s 16562 ; accuracy train : 87.36421203613281
275136 ; loss 0.35 ; sentence/s 287 ; words/s 16700 ; accuracy train : 87.35865020751953
281536 ; loss 0.34 ; sentence/s 290 ; words/s 16629 ; accuracy train : 87.35440063476562
287936 ; loss 0.35 ; sentence/s 288 ; words/s 16661 ; accuracy train : 87.33715057373047
294336 ; loss 0.34 ; sentence/s 296 ; words/s 17248 ; accuracy train : 87.33899688720703
300736 ; loss 0.34 ; sentence/s 301 ; words/s 17664 ; accuracy train : 87.34906768798828
307136 ; loss 0.34 ; sentence/s 289 ; words/s 16516 ; accuracy train : 87.34896087646484
313536 ; loss 0.35 ; sentence/s 287 ; words/s 16867 ; accuracy train : 87.34629821777344
319936 ; loss 0.35 ; sentence

287936 ; loss 0.34 ; sentence/s 283 ; words/s 16310 ; accuracy train : 87.6295166015625
294336 ; loss 0.34 ; sentence/s 259 ; words/s 15541 ; accuracy train : 87.63451385498047
300736 ; loss 0.35 ; sentence/s 269 ; words/s 15479 ; accuracy train : 87.61835479736328
307136 ; loss 0.34 ; sentence/s 267 ; words/s 15587 ; accuracy train : 87.62174224853516
313536 ; loss 0.34 ; sentence/s 275 ; words/s 15800 ; accuracy train : 87.60969543457031
319936 ; loss 0.34 ; sentence/s 266 ; words/s 15862 ; accuracy train : 87.60437774658203
326336 ; loss 0.34 ; sentence/s 266 ; words/s 15780 ; accuracy train : 87.59803771972656
332736 ; loss 0.34 ; sentence/s 269 ; words/s 16442 ; accuracy train : 87.59134674072266
339136 ; loss 0.35 ; sentence/s 276 ; words/s 16095 ; accuracy train : 87.58785247802734
345536 ; loss 0.35 ; sentence/s 277 ; words/s 16015 ; accuracy train : 87.57002258300781
351936 ; loss 0.35 ; sentence/s 299 ; words/s 17284 ; accuracy train : 87.56647491455078
358336 ; loss 0.33 ; s

326336 ; loss 0.33 ; sentence/s 274 ; words/s 15755 ; accuracy train : 87.74264526367188
332736 ; loss 0.34 ; sentence/s 301 ; words/s 17819 ; accuracy train : 87.74188995361328
339136 ; loss 0.35 ; sentence/s 293 ; words/s 17047 ; accuracy train : 87.72994995117188
345536 ; loss 0.33 ; sentence/s 284 ; words/s 16849 ; accuracy train : 87.73321533203125
351936 ; loss 0.33 ; sentence/s 288 ; words/s 16680 ; accuracy train : 87.73892211914062
358336 ; loss 0.35 ; sentence/s 296 ; words/s 17462 ; accuracy train : 87.71791076660156
364736 ; loss 0.34 ; sentence/s 304 ; words/s 17359 ; accuracy train : 87.71189880371094
371136 ; loss 0.32 ; sentence/s 276 ; words/s 15956 ; accuracy train : 87.72144317626953
377536 ; loss 0.33 ; sentence/s 266 ; words/s 15780 ; accuracy train : 87.7190170288086
383936 ; loss 0.32 ; sentence/s 285 ; words/s 17047 ; accuracy train : 87.73124694824219
390336 ; loss 0.34 ; sentence/s 283 ; words/s 16283 ; accuracy train : 87.72003173828125
396736 ; loss 0.34 ; s

364736 ; loss 0.33 ; sentence/s 302 ; words/s 17507 ; accuracy train : 87.73272705078125
371136 ; loss 0.35 ; sentence/s 294 ; words/s 17339 ; accuracy train : 87.7184829711914
377536 ; loss 0.33 ; sentence/s 300 ; words/s 17600 ; accuracy train : 87.72854614257812
383936 ; loss 0.33 ; sentence/s 280 ; words/s 16072 ; accuracy train : 87.73645782470703
390336 ; loss 0.33 ; sentence/s 295 ; words/s 17315 ; accuracy train : 87.74411010742188
396736 ; loss 0.32 ; sentence/s 297 ; words/s 16890 ; accuracy train : 87.7429428100586
403136 ; loss 0.33 ; sentence/s 284 ; words/s 16412 ; accuracy train : 87.73983001708984
409536 ; loss 0.34 ; sentence/s 295 ; words/s 16489 ; accuracy train : 87.739990234375
415936 ; loss 0.34 ; sentence/s 243 ; words/s 14344 ; accuracy train : 87.73966217041016
422336 ; loss 0.32 ; sentence/s 280 ; words/s 16240 ; accuracy train : 87.74384307861328
428736 ; loss 0.33 ; sentence/s 292 ; words/s 17164 ; accuracy train : 87.75233459472656
435136 ; loss 0.33 ; sent

403136 ; loss 0.34 ; sentence/s 298 ; words/s 17293 ; accuracy train : 87.76984405517578
409536 ; loss 0.34 ; sentence/s 302 ; words/s 17191 ; accuracy train : 87.76220703125
415936 ; loss 0.33 ; sentence/s 299 ; words/s 17490 ; accuracy train : 87.75985717773438
422336 ; loss 0.34 ; sentence/s 292 ; words/s 16785 ; accuracy train : 87.75521087646484
428736 ; loss 0.33 ; sentence/s 300 ; words/s 17161 ; accuracy train : 87.75326538085938
435136 ; loss 0.33 ; sentence/s 299 ; words/s 17130 ; accuracy train : 87.75252532958984
441536 ; loss 0.34 ; sentence/s 298 ; words/s 16788 ; accuracy train : 87.74932098388672
447936 ; loss 0.33 ; sentence/s 297 ; words/s 17488 ; accuracy train : 87.74508666992188
454336 ; loss 0.31 ; sentence/s 279 ; words/s 16366 ; accuracy train : 87.75814056396484
460736 ; loss 0.33 ; sentence/s 290 ; words/s 16824 ; accuracy train : 87.75824737548828
467136 ; loss 0.34 ; sentence/s 261 ; words/s 15168 ; accuracy train : 87.75149536132812
473536 ; loss 0.33 ; sen

441536 ; loss 0.33 ; sentence/s 274 ; words/s 16288 ; accuracy train : 87.74864196777344
447936 ; loss 0.34 ; sentence/s 275 ; words/s 16706 ; accuracy train : 87.75245666503906
454336 ; loss 0.34 ; sentence/s 296 ; words/s 17106 ; accuracy train : 87.7486801147461
460736 ; loss 0.35 ; sentence/s 294 ; words/s 17198 ; accuracy train : 87.74066925048828
467136 ; loss 0.32 ; sentence/s 292 ; words/s 16991 ; accuracy train : 87.75106811523438
473536 ; loss 0.34 ; sentence/s 302 ; words/s 17611 ; accuracy train : 87.74176788330078
479936 ; loss 0.32 ; sentence/s 293 ; words/s 17079 ; accuracy train : 87.74166870117188
486336 ; loss 0.32 ; sentence/s 286 ; words/s 16649 ; accuracy train : 87.74712371826172
492736 ; loss 0.35 ; sentence/s 289 ; words/s 16217 ; accuracy train : 87.73741912841797
499136 ; loss 0.33 ; sentence/s 306 ; words/s 17359 ; accuracy train : 87.73878479003906
505536 ; loss 0.34 ; sentence/s 307 ; words/s 17557 ; accuracy train : 87.73120880126953
511936 ; loss 0.32 ; s

473536 ; loss 0.35 ; sentence/s 307 ; words/s 17461 ; accuracy train : 87.73817443847656
479936 ; loss 0.32 ; sentence/s 300 ; words/s 17612 ; accuracy train : 87.74500274658203
486336 ; loss 0.32 ; sentence/s 306 ; words/s 17595 ; accuracy train : 87.75123596191406
492736 ; loss 0.34 ; sentence/s 308 ; words/s 17600 ; accuracy train : 87.75
499136 ; loss 0.34 ; sentence/s 300 ; words/s 17862 ; accuracy train : 87.74038696289062
505536 ; loss 0.32 ; sentence/s 304 ; words/s 17694 ; accuracy train : 87.75019836425781
511936 ; loss 0.34 ; sentence/s 308 ; words/s 17552 ; accuracy train : 87.74043273925781
518336 ; loss 0.33 ; sentence/s 301 ; words/s 17814 ; accuracy train : 87.7440185546875
524736 ; loss 0.33 ; sentence/s 306 ; words/s 17719 ; accuracy train : 87.74161529541016
531136 ; loss 0.32 ; sentence/s 296 ; words/s 17368 ; accuracy train : 87.7454833984375
537536 ; loss 0.33 ; sentence/s 308 ; words/s 17838 ; accuracy train : 87.73939514160156
543936 ; loss 0.32 ; sentence/s 307

511936 ; loss 0.34 ; sentence/s 292 ; words/s 17188 ; accuracy train : 87.74589538574219
518336 ; loss 0.33 ; sentence/s 299 ; words/s 17178 ; accuracy train : 87.74942016601562
524736 ; loss 0.34 ; sentence/s 292 ; words/s 17117 ; accuracy train : 87.74923706054688
531136 ; loss 0.32 ; sentence/s 300 ; words/s 17538 ; accuracy train : 87.7513198852539
537536 ; loss 0.33 ; sentence/s 292 ; words/s 17169 ; accuracy train : 87.74906921386719
543936 ; loss 0.34 ; sentence/s 292 ; words/s 17684 ; accuracy train : 87.74393463134766
results : epoch 14 ; mean accuracy train : 87.74552917480469

VALIDATION : Epoch 14
togrep : results : epoch 14 ; mean accuracy valid :              84.12924194335938
Shrinking lr by : 5. New lr = 5.616134547193394e-06


In [39]:
# Run best model on test set.
nli_net.load_state_dict(torch.load(os.path.join(params.outputdir, params.outputmodelname)))

print('\nTEST : Epoch {0}'.format(epoch))
evaluate(1e6, 'valid', True)
evaluate(0, 'test', True)

# Save encoder instead of full model
torch.save(nli_net.encoder.state_dict(), os.path.join(params.outputdir, params.outputmodelname + '.encoder.pkl'))


TEST : Epoch 15

VALIDATION : Epoch 1000000.0
finalgrep : accuracy valid : 84.24100494384766
finalgrep : accuracy test : 84.41571807861328
