In [1]:
# import stuff
%load_ext autoreload
%autoreload 2
%matplotlib inline

import inspect
import os
from random import randint
import time

import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
import re

## Load model

In [2]:
# Load model
from models import InferSent
model_version = 1
MODEL_PATH = "../encoder/infersent%s.pkl" % model_version
params_model = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048,
                'pool_type': 'max', 'dpout_model': 0.0, 'version': model_version}
model = InferSent(params_model)
model.load_state_dict(torch.load(MODEL_PATH))

<All keys matched successfully>

In [3]:
# Keep it on CPU or put it on GPU
use_cuda = torch.cuda.is_available()
#or force not to use cuda
#use_cuda = False
model = model.cuda() if use_cuda else model

In [4]:
# If infersent1 -> use GloVe embeddings. If infersent2 -> use InferSent embeddings.
W2V_PATH = '../GloVe/glove.840B.300d.txt' if model_version == 1 else '../fastText/crawl-300d-2M.vec'
model.set_w2v_path(W2V_PATH)

In [5]:
# Load embeddings of K most frequent words
#model.build_vocab_k_words(K=100000)

## Load sentences

In [6]:
# Load some sentences
#sentences = []
#with open('samples.txt') as f:
#    for line in f:
#        sentences.append(line.strip())
#print(len(sentences))

In [7]:
#sentences[:5]

## Encode sentences

In [8]:
# gpu mode : >> 1000 sentences/s
# cpu mode : ~100 sentences/s

In [9]:
#embeddings = model.encode(sentences, bsize=128, tokenize=False, verbose=True)
#print('nb sentences encoded : {0}'.format(len(embeddings)))

## Visualization

In [10]:
#np.linalg.norm(model.encode(['the cat eats.']))

In [11]:
#def cosine(u, v):
#    return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))

In [12]:
#cosine(model.encode(['the cat eats.'])[0], model.encode(['the cat drinks.'])[0])

In [13]:
#idx = randint(0, len(sentences))
#_, _ = model.visualize(sentences[idx])

In [14]:
#my_sent = 'The cat is drinking milk.'
#_, _ = model.visualize(my_sent)

In [15]:
#model.build_vocab_k_words(500000) # getting 500K words vocab
#my_sent = 'barack-obama is the former president of the United-States.'
#_, _ = model.visualize(my_sent)

**InferSent inference**

In [16]:
%ls

 Volume in drive C has no label.
 Volume Serial Number is F0F5-7230

 Directory of C:\Users\ktjam\YKT\MComp AI Classes\CS4248 Natural Language Processing\Github_project\4248-project\src

18/03/2023  11:33 pm    <DIR>          .
18/03/2023  04:59 pm    <DIR>          ..
18/03/2023  06:32 pm    <DIR>          .ipynb_checkpoints
18/03/2023  04:50 pm    <DIR>          __pycache__
18/03/2023  07:26 pm            50,356 demo_training.ipynb
18/03/2023  11:33 pm            10,486 eval_preds.py
01/03/2023  03:27 am            10,140 models.py
01/03/2023  03:27 am           590,791 samples.txt
18/03/2023  04:18 pm             4,395 test.ipynb
18/03/2023  04:19 pm           449,448 visualize.ipynb
               6 File(s)      1,115,616 bytes
               4 Dir(s)  46,681,780,224 bytes free


In [17]:
tmp1 = pd.read_csv('../dataset/esnli_train_1.csv', usecols=['gold_label', 'Sentence1', 'Sentence2'])
tmp2 = pd.read_csv('../dataset/esnli_train_2.csv', usecols=['gold_label', 'Sentence1', 'Sentence2'])
train = pd.concat([tmp1, tmp2], ignore_index=True)
train.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 549367 entries, 0 to 549366
Data columns (total 3 columns):
 #   Column      Non-Null Count   Dtype 
---  ------      --------------   ----- 
 0   gold_label  549367 non-null  object
 1   Sentence1   549367 non-null  object
 2   Sentence2   549361 non-null  object
dtypes: object(3)
memory usage: 12.6+ MB


In [18]:
valid = pd.read_csv('../dataset/esnli_dev.csv', usecols=['gold_label', 'Sentence1', 'Sentence2'])
valid.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9842 entries, 0 to 9841
Data columns (total 3 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   gold_label  9842 non-null   object
 1   Sentence1   9842 non-null   object
 2   Sentence2   9842 non-null   object
dtypes: object(3)
memory usage: 230.8+ KB


In [19]:
test = pd.read_csv('../dataset/esnli_test.csv', usecols=['gold_label', 'Sentence1', 'Sentence2'])
test.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 9824 entries, 0 to 9823
Data columns (total 3 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   gold_label  9824 non-null   object
 1   Sentence1   9824 non-null   object
 2   Sentence2   9824 non-null   object
dtypes: object(3)
memory usage: 230.4+ KB


In [20]:
#map label to int
label_to_int = {'entailment': 0, 'neutral': 1, 'contradiction': 2}

In [21]:
#add label int
train['label'] = train['gold_label'].apply(lambda x: label_to_int[x])
valid['label'] = valid['gold_label'].apply(lambda x: label_to_int[x])
test['label'] = test['gold_label'].apply(lambda x: label_to_int[x])

In [22]:
def build_vocab(sentences, glove_path):
    word_dict = get_word_dict(sentences)
    word_vec = get_glove(word_dict, glove_path)
    print('Vocab size : {0}'.format(len(word_vec)))
    return word_vec

def get_glove(word_dict, glove_path):
    # create word_vec with glove vectors
    word_vec = {}
    with open(glove_path, encoding='utf8') as f:
        for line in f:
            #print(line)
            #break
            word, vec = line.split(' ', 1)
            if word in word_dict:
                word_vec[word] = np.array(list(map(float, vec.split())))
    print('Found {0}(/{1}) words with glove vectors'.format(
                len(word_vec), len(word_dict)))
    return word_vec

def get_word_dict(sentences):
    # create vocab of words
    word_dict = {}
    for sent in sentences:
        for word in str(sent).split():
            if word not in word_dict:
                word_dict[word] = ''
    word_dict['<s>'] = ''
    word_dict['</s>'] = ''
    word_dict['<p>'] = ''
    return word_dict

In [23]:
glove_path = '../GloVe/glove.840B.300d.txt'

In [24]:
#converts DataFrames to dict
train = train.to_dict(orient='list')
valid = valid.to_dict(orient='list')
test = test.to_dict(orient='list')

In [25]:
for i in range(1):
    print(train['Sentence2'][i])

A person is training his horse for a competition.


In [26]:
word_vec = build_vocab(train['Sentence1'] + train['Sentence2'] +
                       valid['Sentence1'] + valid['Sentence2'] +
                       test['Sentence1'] + test['Sentence2'], glove_path)

Found 37925(/64300) words with glove vectors
Vocab size : 37925


In [27]:
class NLINet(nn.Module):
    def __init__(self, config):
        super(NLINet, self).__init__()

        # classifier
        self.nonlinear_fc = config['nonlinear_fc']
        self.fc_dim = config['fc_dim']
        self.n_classes = config['n_classes']
        self.enc_lstm_dim = config['enc_lstm_dim']
        self.encoder_type = config['encoder_type']
        self.dpout_fc = config['dpout_fc']

        self.encoder = model  #eval(self.encoder_type)(config)
        self.inputdim = 4*2*self.enc_lstm_dim
        self.inputdim = 4*self.inputdim if self.encoder_type in \
                        ["ConvNetEncoder", "InnerAttentionMILAEncoder"] else self.inputdim
        self.inputdim = self.inputdim/2 if self.encoder_type == "LSTMEncoder" \
                                        else self.inputdim
        if self.nonlinear_fc:
            self.classifier = nn.Sequential(
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.inputdim, self.fc_dim),
                nn.Tanh(),
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.fc_dim, self.fc_dim),
                nn.Tanh(),
                nn.Dropout(p=self.dpout_fc),
                nn.Linear(self.fc_dim, self.n_classes),
                )
        else:
            self.classifier = nn.Sequential(
                nn.Linear(self.inputdim, self.fc_dim),
                nn.Linear(self.fc_dim, self.fc_dim),
                nn.Linear(self.fc_dim, self.n_classes)
                )

    def forward(self, s1, s2):
        # s1 : (s1, s1_len)
        u = self.encoder(s1)
        v = self.encoder(s2)

        features = torch.cat((u, v, torch.abs(u-v), u*v), 1)
        output = self.classifier(features)
        return output

    def encode(self, s1):
        emb = self.encoder(s1)
        return emb


In [28]:
for split in ['Sentence1', 'Sentence2']:
    for data_type in ['train', 'valid', 'test']:
        eval(data_type)[split] = np.array([['<s>'] + \
            [word for word in str(sent).split() if word in word_vec] + \
            ['</s>'] for sent in eval(data_type)[split]])

  eval(data_type)[split] = np.array([['<s>'] + \


In [29]:
train['label'] = np.array(train['label'])

In [30]:
parser = argparse.ArgumentParser(description='NLI training')
# paths
parser.add_argument("--nlipath", type=str, default='dataset/SNLI/', help="NLI data path (SNLI or MultiNLI)")
parser.add_argument("--outputdir", type=str, default='../savedir/', help="Output directory")
parser.add_argument("--outputmodelname", type=str, default='model.pickle')
parser.add_argument("--word_emb_path", type=str, default="../dataset/GloVe/glove.840B.300d.txt", help="word embedding file path")

# training
parser.add_argument("--n_epochs", type=int, default=50)
parser.add_argument("--batch_size", type=int, default=64)  #64)
parser.add_argument("--dpout_model", type=float, default=0., help="encoder dropout")
parser.add_argument("--dpout_fc", type=float, default=0., help="classifier dropout")
parser.add_argument("--nonlinear_fc", type=float, default=1, help="use nonlinearity in fc")
parser.add_argument("--optimizer", type=str, default="sgd,lr=0.1", help="adam or sgd,lr=0.1")
parser.add_argument("--lrshrink", type=float, default=5, help="shrink factor for sgd")
parser.add_argument("--decay", type=float, default=0.99, help="lr decay")
parser.add_argument("--minlr", type=float, default=1e-5, help="minimum lr")
parser.add_argument("--max_norm", type=float, default=5., help="max norm (grad clipping)")

# model
parser.add_argument("--encoder_type", type=str, default='InferSentV1', help="see list of encoders")
parser.add_argument("--enc_lstm_dim", type=int, default=2048, help="encoder nhid dimension")
parser.add_argument("--n_enc_layers", type=int, default=1, help="encoder num layers")
parser.add_argument("--fc_dim", type=int, default=512, help="nhid of fc layers")
parser.add_argument("--n_classes", type=int, default=3, help="entailment/neutral/contradiction")
parser.add_argument("--pool_type", type=str, default='max', help="max or mean")

# gpu
parser.add_argument("--gpu_id", type=int, default=3, help="GPU ID")
parser.add_argument("--seed", type=int, default=1234, help="seed")

# data
parser.add_argument("--word_emb_dim", type=int, default=300, help="word embedding dimension")

params, _ = parser.parse_known_args()
config_nli_model = {
    'n_words'        :  len(word_vec)          ,
    'word_emb_dim'   :  params.word_emb_dim   ,
    'enc_lstm_dim'   :  params.enc_lstm_dim   ,
    'n_enc_layers'   :  params.n_enc_layers   ,
    'dpout_model'    :  params.dpout_model    ,
    'dpout_fc'       :  params.dpout_fc       ,
    'fc_dim'         :  params.fc_dim         ,
    'bsize'          :  params.batch_size     ,
    'n_classes'      :  params.n_classes      ,
    'pool_type'      :  params.pool_type      ,
    'nonlinear_fc'   :  params.nonlinear_fc   ,
    'encoder_type'   :  params.encoder_type   ,
    'use_cuda'       :  True                  ,

}
nli_net = NLINet(config_nli_model)
print(nli_net)

NLINet(
  (encoder): InferSent(
    (enc_lstm): LSTM(300, 2048, bidirectional=True)
  )
  (classifier): Sequential(
    (0): Dropout(p=0.0, inplace=False)
    (1): Linear(in_features=16384, out_features=512, bias=True)
    (2): Tanh()
    (3): Dropout(p=0.0, inplace=False)
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): Tanh()
    (6): Dropout(p=0.0, inplace=False)
    (7): Linear(in_features=512, out_features=3, bias=True)
  )
)


In [31]:
"""
SEED
"""
np.random.seed(params.seed)
torch.manual_seed(params.seed)
torch.cuda.manual_seed(params.seed)

In [32]:
def get_optimizer(s):
    """
    Parse optimizer parameters.
    Input should be of the form:
        - "sgd,lr=0.01"
        - "adagrad,lr=0.1,lr_decay=0.05"
    """
    if "," in s:
        method = s[:s.find(',')]
        optim_params = {}
        for x in s[s.find(',') + 1:].split(','):
            split = x.split('=')
            assert len(split) == 2
            assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None
            optim_params[split[0]] = float(split[1])
    else:
        method = s
        optim_params = {}

    if method == 'adadelta':
        optim_fn = optim.Adadelta
    elif method == 'adagrad':
        optim_fn = optim.Adagrad
    elif method == 'adam':
        optim_fn = optim.Adam
    elif method == 'adamax':
        optim_fn = optim.Adamax
    elif method == 'asgd':
        optim_fn = optim.ASGD
    elif method == 'rmsprop':
        optim_fn = optim.RMSprop
    elif method == 'rprop':
        optim_fn = optim.Rprop
    elif method == 'sgd':
        optim_fn = optim.SGD
        assert 'lr' in optim_params
    else:
        raise Exception('Unknown optimization method: "%s"' % method)

    # check that we give good parameters to the optimizer
    #expected_args = inspect.getargspec(optim_fn.__init__)[0]
    #assert expected_args[:2] == ['self', 'params']
    #if not all(k in expected_args[2:] for k in optim_params.keys()):
    #    raise Exception('Unexpected parameters: expected "%s", got "%s"' % (
    #        str(expected_args[2:]), str(optim_params.keys())))

    return optim_fn, optim_params

In [33]:
# loss
weight = torch.FloatTensor(params.n_classes).fill_(1)
loss_fn = nn.CrossEntropyLoss(weight=weight)
loss_fn.size_average = False

# optimizer
optim_fn, optim_params = get_optimizer(params.optimizer)
optimizer = optim_fn(nli_net.parameters(), **optim_params)

# cuda by default
nli_net.cuda()
loss_fn.cuda()

CrossEntropyLoss()

In [34]:
"""
TRAIN
"""
val_acc_best = -1e10
adam_stop = False
stop_training = False
lr = optim_params['lr'] if 'sgd' in params.optimizer else None

In [35]:
def trainepoch(epoch):
    print('\nTRAINING : Epoch ' + str(epoch))
    nli_net.train()
    all_costs = []
    logs = []
    words_count = 0

    last_time = time.time()
    correct = 0.
    # shuffle the data
    permutation = np.random.permutation(len(train['Sentence1']))

    s1 = train['Sentence1'][permutation]
    s2 = train['Sentence2'][permutation]
    target = train['label'][permutation]


    optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * params.decay if epoch>1\
        and 'sgd' in params.optimizer else optimizer.param_groups[0]['lr']
    print('Learning rate : {0}'.format(optimizer.param_groups[0]['lr']))

    for stidx in range(0, len(s1), params.batch_size):
        # prepare batch
        s1_batch, s1_len = get_batch(s1[stidx:stidx + params.batch_size],
                                     word_vec, params.word_emb_dim)
        s2_batch, s2_len = get_batch(s2[stidx:stidx + params.batch_size],
                                     word_vec, params.word_emb_dim)
        s1_batch, s2_batch = Variable(s1_batch.cuda()), Variable(s2_batch.cuda())
        tgt_batch = Variable(torch.LongTensor(target[stidx:stidx + params.batch_size])).cuda()
        k = s1_batch.size(1)  # actual batch size
        
        # model forward
        output = nli_net((s1_batch, s1_len), (s2_batch, s2_len))

        pred = output.data.max(1)[1]
        correct += pred.long().eq(tgt_batch.data.long()).cpu().sum()
        assert len(pred) == len(s1[stidx:stidx + params.batch_size])

        # loss
        loss = loss_fn(output, tgt_batch)
        #print(type(loss))
        all_costs.append(loss.item())  #.data[0])
        words_count += (s1_batch.nelement() + s2_batch.nelement()) / params.word_emb_dim

        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient clipping (off by default)
        shrink_factor = 1
        total_norm = 0

        for p in nli_net.parameters():
            if p.requires_grad:
                p.grad.data.div_(k)  # divide by the actual batch size
                total_norm += p.grad.data.norm() ** 2
        total_norm = np.sqrt(total_norm.cpu())

        if total_norm > params.max_norm:
            shrink_factor = params.max_norm / total_norm
        current_lr = optimizer.param_groups[0]['lr'] # current lr (no external "lr", for adam)
        optimizer.param_groups[0]['lr'] = current_lr * shrink_factor # just for update

        # optimizer step
        optimizer.step()
        optimizer.param_groups[0]['lr'] = current_lr

        if len(all_costs) == 100:
            logs.append('{0} ; loss {1} ; sentence/s {2} ; words/s {3} ; accuracy train : {4}'.format(
                            stidx, round(np.mean(all_costs), 2),
                            int(len(all_costs) * params.batch_size / (time.time() - last_time)),
                            int(words_count * 1.0 / (time.time() - last_time)),
                            100.*correct/(stidx+k)))
            print(logs[-1])
            last_time = time.time()
            words_count = 0
            all_costs = []
    train_acc = 100 * correct/len(s1)  #round(100 * correct/len(s1), 2)
    print('results : epoch {0} ; mean accuracy train : {1}'
          .format(epoch, train_acc))
    return train_acc

In [36]:
def get_batch(batch, word_vec, emb_dim=300):
    # sent in batch in decreasing order of lengths (bsize, max_len, word_dim)
    lengths = np.array([len(x) for x in batch])
    max_len = np.max(lengths)
    embed = np.zeros((max_len, len(batch), emb_dim))

    for i in range(len(batch)):
        for j in range(len(batch[i])):
            embed[j, i, :] = word_vec[batch[i][j]]

    return torch.from_numpy(embed).float(), lengths

In [37]:
def evaluate(epoch, eval_type='valid', final_eval=False):
    nli_net.eval()
    correct = 0.
    global val_acc_best, lr, stop_training, adam_stop

    if eval_type == 'valid':
        print('\nVALIDATION : Epoch {0}'.format(epoch))

    s1 = valid['Sentence1'] if eval_type == 'valid' else test['Sentence1']
    s2 = valid['Sentence2'] if eval_type == 'valid' else test['Sentence2']
    target = valid['label'] if eval_type == 'valid' else test['label']

    for i in range(0, len(s1), params.batch_size):
        # prepare batch
        s1_batch, s1_len = get_batch(s1[i:i + params.batch_size], word_vec, params.word_emb_dim)
        s2_batch, s2_len = get_batch(s2[i:i + params.batch_size], word_vec, params.word_emb_dim)
        s1_batch, s2_batch = Variable(s1_batch.cuda()), Variable(s2_batch.cuda())
        tgt_batch = Variable(torch.LongTensor(target[i:i + params.batch_size])).cuda()

        # model forward
        output = nli_net((s1_batch, s1_len), (s2_batch, s2_len))

        pred = output.data.max(1)[1]
        correct += pred.long().eq(tgt_batch.data.long()).cpu().sum()

        
    # save model
    eval_acc = 100 * correct/len(s1)  #round(100 * correct / len(s1), 2)
    if final_eval:
        print('finalgrep : accuracy {0} : {1}'.format(eval_type, eval_acc))
    else:
        print('togrep : results : epoch {0} ; mean accuracy {1} :\
              {2}'.format(epoch, eval_type, eval_acc))

    if eval_type == 'valid' and epoch <= params.n_epochs:
        if eval_acc > val_acc_best:
            print('saving model at epoch {0}'.format(epoch))
            if not os.path.exists(params.outputdir):
                os.makedirs(params.outputdir)
            torch.save(nli_net.state_dict(), os.path.join(params.outputdir,
                       params.outputmodelname))
            val_acc_best = eval_acc
        else:
            if 'sgd' in params.optimizer:
                optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] / params.lrshrink
                print('Shrinking lr by : {0}. New lr = {1}'
                      .format(params.lrshrink,
                              optimizer.param_groups[0]['lr']))
                if optimizer.param_groups[0]['lr'] < params.minlr:
                    stop_training = True
            if 'adam' in params.optimizer:
                # early stopping (at 2nd decrease in accuracy)
                stop_training = adam_stop
                adam_stop = True
    return eval_acc


In [38]:
"""
Train model on Natural Language Inference task
"""
epoch = 1

while not stop_training and epoch <= params.n_epochs:
    train_acc = trainepoch(epoch)
    eval_acc = evaluate(epoch, 'valid')
    epoch += 1


TRAINING : Epoch 1
Learning rate : 0.1
6336 ; loss 1.1 ; sentence/s 263 ; words/s 13801 ; accuracy train : 31.953125
12736 ; loss 1.1 ; sentence/s 272 ; words/s 14194 ; accuracy train : 32.4140625
19136 ; loss 1.1 ; sentence/s 271 ; words/s 14348 ; accuracy train : 33.04166793823242
25536 ; loss 1.1 ; sentence/s 273 ; words/s 14259 ; accuracy train : 33.890625
31936 ; loss 1.1 ; sentence/s 272 ; words/s 14260 ; accuracy train : 34.243751525878906
38336 ; loss 1.1 ; sentence/s 273 ; words/s 14325 ; accuracy train : 34.15104293823242
44736 ; loss 1.1 ; sentence/s 270 ; words/s 14316 ; accuracy train : 34.32143020629883
51136 ; loss 1.1 ; sentence/s 272 ; words/s 14416 ; accuracy train : 34.345703125
57536 ; loss 1.09 ; sentence/s 273 ; words/s 14045 ; accuracy train : 34.31770706176758
63936 ; loss 1.09 ; sentence/s 271 ; words/s 14136 ; accuracy train : 34.34062576293945
70336 ; loss 1.09 ; sentence/s 271 ; words/s 13987 ; accuracy train : 34.421875
76736 ; loss 1.09 ; sentence/s 272 ;

44736 ; loss 0.98 ; sentence/s 331 ; words/s 17330 ; accuracy train : 60.90625
51136 ; loss 0.98 ; sentence/s 329 ; words/s 17732 ; accuracy train : 60.90625
57536 ; loss 0.98 ; sentence/s 334 ; words/s 17698 ; accuracy train : 60.89756774902344
63936 ; loss 0.97 ; sentence/s 337 ; words/s 17515 ; accuracy train : 60.985939025878906
70336 ; loss 0.97 ; sentence/s 333 ; words/s 17527 ; accuracy train : 61.01420593261719
76736 ; loss 0.96 ; sentence/s 335 ; words/s 17565 ; accuracy train : 61.01823043823242
83136 ; loss 0.96 ; sentence/s 336 ; words/s 17394 ; accuracy train : 61.03725814819336
89536 ; loss 0.96 ; sentence/s 333 ; words/s 17340 ; accuracy train : 61.04352569580078
95936 ; loss 0.95 ; sentence/s 332 ; words/s 17541 ; accuracy train : 61.11145782470703
102336 ; loss 0.95 ; sentence/s 332 ; words/s 17503 ; accuracy train : 61.1435546875
108736 ; loss 0.95 ; sentence/s 331 ; words/s 17614 ; accuracy train : 61.15900802612305
115136 ; loss 0.95 ; sentence/s 333 ; words/s 17434

83136 ; loss 0.75 ; sentence/s 339 ; words/s 17557 ; accuracy train : 67.60816955566406
89536 ; loss 0.75 ; sentence/s 334 ; words/s 17670 ; accuracy train : 67.6953125
95936 ; loss 0.75 ; sentence/s 336 ; words/s 17538 ; accuracy train : 67.73229217529297
102336 ; loss 0.75 ; sentence/s 334 ; words/s 17612 ; accuracy train : 67.759765625
108736 ; loss 0.75 ; sentence/s 333 ; words/s 17679 ; accuracy train : 67.81433868408203
115136 ; loss 0.75 ; sentence/s 332 ; words/s 17553 ; accuracy train : 67.85243225097656
121536 ; loss 0.73 ; sentence/s 332 ; words/s 17621 ; accuracy train : 67.9555892944336
127936 ; loss 0.73 ; sentence/s 336 ; words/s 17322 ; accuracy train : 67.98359680175781
134336 ; loss 0.74 ; sentence/s 338 ; words/s 17532 ; accuracy train : 67.98809814453125
140736 ; loss 0.74 ; sentence/s 334 ; words/s 17457 ; accuracy train : 68.0234375
147136 ; loss 0.74 ; sentence/s 337 ; words/s 17400 ; accuracy train : 68.06521606445312
153536 ; loss 0.74 ; sentence/s 334 ; words/

121536 ; loss 0.67 ; sentence/s 332 ; words/s 17724 ; accuracy train : 72.66447448730469
127936 ; loss 0.66 ; sentence/s 336 ; words/s 17456 ; accuracy train : 72.6539077758789
134336 ; loss 0.66 ; sentence/s 334 ; words/s 17922 ; accuracy train : 72.64657592773438
140736 ; loss 0.66 ; sentence/s 335 ; words/s 17442 ; accuracy train : 72.6555404663086
147136 ; loss 0.66 ; sentence/s 340 ; words/s 17459 ; accuracy train : 72.6596450805664
153536 ; loss 0.65 ; sentence/s 336 ; words/s 17512 ; accuracy train : 72.68880462646484
159936 ; loss 0.66 ; sentence/s 335 ; words/s 17596 ; accuracy train : 72.70437622070312
166336 ; loss 0.65 ; sentence/s 336 ; words/s 17479 ; accuracy train : 72.72535705566406
172736 ; loss 0.64 ; sentence/s 341 ; words/s 17262 ; accuracy train : 72.76388549804688
179136 ; loss 0.63 ; sentence/s 333 ; words/s 17800 ; accuracy train : 72.81082916259766
185536 ; loss 0.65 ; sentence/s 336 ; words/s 17608 ; accuracy train : 72.8356704711914
191936 ; loss 0.65 ; sent

166336 ; loss 0.6 ; sentence/s 338 ; words/s 17669 ; accuracy train : 75.46935272216797
172736 ; loss 0.6 ; sentence/s 336 ; words/s 17534 ; accuracy train : 75.47164154052734
179136 ; loss 0.6 ; sentence/s 335 ; words/s 17793 ; accuracy train : 75.48213958740234
185536 ; loss 0.58 ; sentence/s 333 ; words/s 17733 ; accuracy train : 75.4989242553711
191936 ; loss 0.58 ; sentence/s 334 ; words/s 17696 ; accuracy train : 75.5218734741211
198336 ; loss 0.6 ; sentence/s 332 ; words/s 17967 ; accuracy train : 75.5241928100586
204736 ; loss 0.59 ; sentence/s 332 ; words/s 17775 ; accuracy train : 75.54248046875
211136 ; loss 0.59 ; sentence/s 334 ; words/s 17624 ; accuracy train : 75.5516128540039
217536 ; loss 0.59 ; sentence/s 334 ; words/s 17611 ; accuracy train : 75.56204223632812
223936 ; loss 0.58 ; sentence/s 336 ; words/s 17623 ; accuracy train : 75.5999984741211
230336 ; loss 0.59 ; sentence/s 335 ; words/s 17527 ; accuracy train : 75.58853912353516
236736 ; loss 0.57 ; sentence/s 3

211136 ; loss 0.55 ; sentence/s 333 ; words/s 17591 ; accuracy train : 77.05255889892578
217536 ; loss 0.56 ; sentence/s 339 ; words/s 17471 ; accuracy train : 77.05928039550781
223936 ; loss 0.56 ; sentence/s 338 ; words/s 17437 ; accuracy train : 77.06160736083984
230336 ; loss 0.56 ; sentence/s 340 ; words/s 17497 ; accuracy train : 77.06163024902344
236736 ; loss 0.55 ; sentence/s 331 ; words/s 17807 ; accuracy train : 77.07474517822266
243136 ; loss 0.56 ; sentence/s 337 ; words/s 17624 ; accuracy train : 77.08264923095703
249536 ; loss 0.55 ; sentence/s 335 ; words/s 17464 ; accuracy train : 77.09214782714844
255936 ; loss 0.56 ; sentence/s 338 ; words/s 17534 ; accuracy train : 77.08203125
262336 ; loss 0.55 ; sentence/s 341 ; words/s 17328 ; accuracy train : 77.09565734863281
268736 ; loss 0.56 ; sentence/s 334 ; words/s 17649 ; accuracy train : 77.09970092773438
275136 ; loss 0.55 ; sentence/s 334 ; words/s 17941 ; accuracy train : 77.109375
281536 ; loss 0.56 ; sentence/s 335

249536 ; loss 0.54 ; sentence/s 338 ; words/s 17616 ; accuracy train : 78.171875
255936 ; loss 0.53 ; sentence/s 338 ; words/s 17533 ; accuracy train : 78.1878890991211
262336 ; loss 0.54 ; sentence/s 337 ; words/s 17591 ; accuracy train : 78.18978881835938
268736 ; loss 0.55 ; sentence/s 335 ; words/s 17777 ; accuracy train : 78.17298889160156
275136 ; loss 0.54 ; sentence/s 336 ; words/s 17631 ; accuracy train : 78.17805480957031
281536 ; loss 0.53 ; sentence/s 336 ; words/s 17542 ; accuracy train : 78.20134735107422
287936 ; loss 0.54 ; sentence/s 334 ; words/s 17728 ; accuracy train : 78.1982650756836
294336 ; loss 0.53 ; sentence/s 336 ; words/s 17629 ; accuracy train : 78.21399688720703
300736 ; loss 0.54 ; sentence/s 334 ; words/s 17669 ; accuracy train : 78.21210479736328
307136 ; loss 0.54 ; sentence/s 337 ; words/s 17420 ; accuracy train : 78.21614837646484
313536 ; loss 0.55 ; sentence/s 333 ; words/s 17711 ; accuracy train : 78.20695495605469
319936 ; loss 0.54 ; sentence/s

287936 ; loss 0.52 ; sentence/s 338 ; words/s 17695 ; accuracy train : 79.20972442626953
294336 ; loss 0.51 ; sentence/s 333 ; words/s 17932 ; accuracy train : 79.20686340332031
300736 ; loss 0.53 ; sentence/s 340 ; words/s 17618 ; accuracy train : 79.19248962402344
307136 ; loss 0.52 ; sentence/s 336 ; words/s 17792 ; accuracy train : 79.18326568603516
313536 ; loss 0.53 ; sentence/s 339 ; words/s 17559 ; accuracy train : 79.1769790649414
319936 ; loss 0.53 ; sentence/s 332 ; words/s 17883 ; accuracy train : 79.15937805175781
326336 ; loss 0.53 ; sentence/s 332 ; words/s 17712 ; accuracy train : 79.14920043945312
332736 ; loss 0.52 ; sentence/s 331 ; words/s 18383 ; accuracy train : 79.15084075927734
339136 ; loss 0.53 ; sentence/s 335 ; words/s 17689 ; accuracy train : 79.14533996582031
345536 ; loss 0.53 ; sentence/s 335 ; words/s 17549 ; accuracy train : 79.13194274902344
351936 ; loss 0.53 ; sentence/s 337 ; words/s 17604 ; accuracy train : 79.12216186523438
358336 ; loss 0.51 ; s

326336 ; loss 0.51 ; sentence/s 337 ; words/s 17378 ; accuracy train : 79.79901885986328
332736 ; loss 0.5 ; sentence/s 335 ; words/s 17839 ; accuracy train : 79.80198669433594
339136 ; loss 0.51 ; sentence/s 336 ; words/s 17795 ; accuracy train : 79.796875
345536 ; loss 0.5 ; sentence/s 335 ; words/s 17846 ; accuracy train : 79.80034637451172
351936 ; loss 0.5 ; sentence/s 337 ; words/s 17656 ; accuracy train : 79.81108093261719
358336 ; loss 0.52 ; sentence/s 335 ; words/s 17847 ; accuracy train : 79.80078125
364736 ; loss 0.51 ; sentence/s 341 ; words/s 17568 ; accuracy train : 79.79988861083984
371136 ; loss 0.5 ; sentence/s 340 ; words/s 17557 ; accuracy train : 79.80899810791016
377536 ; loss 0.51 ; sentence/s 335 ; words/s 17846 ; accuracy train : 79.7971420288086
383936 ; loss 0.49 ; sentence/s 335 ; words/s 17936 ; accuracy train : 79.8031234741211
390336 ; loss 0.51 ; sentence/s 337 ; words/s 17379 ; accuracy train : 79.79585266113281
396736 ; loss 0.51 ; sentence/s 337 ; wor

371136 ; loss 0.51 ; sentence/s 334 ; words/s 17785 ; accuracy train : 79.94773864746094
377536 ; loss 0.5 ; sentence/s 335 ; words/s 17696 ; accuracy train : 79.94808959960938
383936 ; loss 0.49 ; sentence/s 337 ; words/s 17565 ; accuracy train : 79.96562194824219
390336 ; loss 0.49 ; sentence/s 335 ; words/s 17759 ; accuracy train : 79.97976684570312
396736 ; loss 0.49 ; sentence/s 341 ; words/s 17332 ; accuracy train : 79.98815155029297
403136 ; loss 0.5 ; sentence/s 340 ; words/s 17524 ; accuracy train : 79.97866821289062
409536 ; loss 0.51 ; sentence/s 344 ; words/s 17355 ; accuracy train : 79.977294921875
415936 ; loss 0.51 ; sentence/s 335 ; words/s 17910 ; accuracy train : 79.97595977783203
422336 ; loss 0.49 ; sentence/s 337 ; words/s 17695 ; accuracy train : 79.98413848876953
428736 ; loss 0.5 ; sentence/s 335 ; words/s 17957 ; accuracy train : 79.97457885742188
435136 ; loss 0.5 ; sentence/s 334 ; words/s 17730 ; accuracy train : 79.97702026367188
441536 ; loss 0.52 ; senten

409536 ; loss 0.5 ; sentence/s 341 ; words/s 17493 ; accuracy train : 80.1064453125
415936 ; loss 0.49 ; sentence/s 335 ; words/s 17669 ; accuracy train : 80.11129760742188
422336 ; loss 0.5 ; sentence/s 338 ; words/s 17547 ; accuracy train : 80.109375
428736 ; loss 0.5 ; sentence/s 339 ; words/s 17570 ; accuracy train : 80.11566925048828
435136 ; loss 0.5 ; sentence/s 338 ; words/s 17511 ; accuracy train : 80.11764526367188
441536 ; loss 0.5 ; sentence/s 339 ; words/s 17295 ; accuracy train : 80.11164093017578
447936 ; loss 0.5 ; sentence/s 336 ; words/s 17761 ; accuracy train : 80.11473083496094
454336 ; loss 0.49 ; sentence/s 336 ; words/s 17764 ; accuracy train : 80.12698364257812
460736 ; loss 0.49 ; sentence/s 339 ; words/s 17652 ; accuracy train : 80.12890625
467136 ; loss 0.51 ; sentence/s 336 ; words/s 17554 ; accuracy train : 80.13291931152344
473536 ; loss 0.51 ; sentence/s 341 ; words/s 17388 ; accuracy train : 80.1233139038086
479936 ; loss 0.49 ; sentence/s 332 ; words/s 

447936 ; loss 0.51 ; sentence/s 332 ; words/s 18134 ; accuracy train : 80.17388153076172
454336 ; loss 0.51 ; sentence/s 338 ; words/s 17648 ; accuracy train : 80.16527557373047
460736 ; loss 0.5 ; sentence/s 336 ; words/s 17859 ; accuracy train : 80.15950775146484
467136 ; loss 0.49 ; sentence/s 336 ; words/s 17737 ; accuracy train : 80.16288757324219
473536 ; loss 0.5 ; sentence/s 337 ; words/s 17755 ; accuracy train : 80.16216278076172
479936 ; loss 0.5 ; sentence/s 338 ; words/s 17863 ; accuracy train : 80.16666412353516
486336 ; loss 0.48 ; sentence/s 337 ; words/s 17521 ; accuracy train : 80.1772232055664
492736 ; loss 0.5 ; sentence/s 340 ; words/s 17427 ; accuracy train : 80.17005157470703
499136 ; loss 0.49 ; sentence/s 339 ; words/s 17411 ; accuracy train : 80.1728744506836
505536 ; loss 0.49 ; sentence/s 340 ; words/s 17507 ; accuracy train : 80.17701721191406
511936 ; loss 0.48 ; sentence/s 336 ; words/s 17860 ; accuracy train : 80.1861343383789
518336 ; loss 0.51 ; sentenc

486336 ; loss 0.49 ; sentence/s 337 ; words/s 17642 ; accuracy train : 80.32154846191406
492736 ; loss 0.5 ; sentence/s 341 ; words/s 17487 ; accuracy train : 80.31919860839844
499136 ; loss 0.5 ; sentence/s 336 ; words/s 17910 ; accuracy train : 80.30689239501953
505536 ; loss 0.49 ; sentence/s 334 ; words/s 17714 ; accuracy train : 80.31092071533203
511936 ; loss 0.49 ; sentence/s 340 ; words/s 17485 ; accuracy train : 80.3115234375
518336 ; loss 0.49 ; sentence/s 333 ; words/s 17740 ; accuracy train : 80.31192016601562
524736 ; loss 0.49 ; sentence/s 339 ; words/s 17663 ; accuracy train : 80.31288146972656
531136 ; loss 0.48 ; sentence/s 334 ; words/s 17716 ; accuracy train : 80.32003021240234
537536 ; loss 0.5 ; sentence/s 338 ; words/s 17493 ; accuracy train : 80.32328796386719
543936 ; loss 0.49 ; sentence/s 338 ; words/s 17701 ; accuracy train : 80.3218765258789
results : epoch 13 ; mean accuracy train : 80.3182601928711

VALIDATION : Epoch 13
togrep : results : epoch 13 ; mean 

524736 ; loss 0.49 ; sentence/s 334 ; words/s 17754 ; accuracy train : 80.4209213256836
531136 ; loss 0.48 ; sentence/s 335 ; words/s 17623 ; accuracy train : 80.41679382324219
537536 ; loss 0.48 ; sentence/s 335 ; words/s 17700 ; accuracy train : 80.42764282226562
543936 ; loss 0.5 ; sentence/s 331 ; words/s 18132 ; accuracy train : 80.42518615722656
results : epoch 14 ; mean accuracy train : 80.42947387695312

VALIDATION : Epoch 14
togrep : results : epoch 14 ; mean accuracy valid :              77.92115783691406
saving model at epoch 14

TRAINING : Epoch 15
Learning rate : 0.0034749832510759127
6336 ; loss 0.49 ; sentence/s 334 ; words/s 17568 ; accuracy train : 80.234375
12736 ; loss 0.51 ; sentence/s 334 ; words/s 17635 ; accuracy train : 79.703125
19136 ; loss 0.49 ; sentence/s 336 ; words/s 17667 ; accuracy train : 80.08333587646484
25536 ; loss 0.5 ; sentence/s 338 ; words/s 17562 ; accuracy train : 80.08203125
31936 ; loss 0.5 ; sentence/s 332 ; words/s 17815 ; accuracy train 


TRAINING : Epoch 16
Learning rate : 0.0034402334185651535
6336 ; loss 0.49 ; sentence/s 334 ; words/s 17517 ; accuracy train : 80.390625
12736 ; loss 0.49 ; sentence/s 335 ; words/s 17485 ; accuracy train : 80.609375
19136 ; loss 0.49 ; sentence/s 337 ; words/s 17784 ; accuracy train : 80.515625
25536 ; loss 0.5 ; sentence/s 338 ; words/s 17611 ; accuracy train : 80.46484375
31936 ; loss 0.48 ; sentence/s 342 ; words/s 17423 ; accuracy train : 80.56562805175781
38336 ; loss 0.49 ; sentence/s 337 ; words/s 17440 ; accuracy train : 80.57291412353516
44736 ; loss 0.49 ; sentence/s 338 ; words/s 17672 ; accuracy train : 80.46205139160156
51136 ; loss 0.49 ; sentence/s 336 ; words/s 17469 ; accuracy train : 80.470703125
57536 ; loss 0.49 ; sentence/s 339 ; words/s 17521 ; accuracy train : 80.42361450195312
63936 ; loss 0.49 ; sentence/s 338 ; words/s 17518 ; accuracy train : 80.47969055175781
70336 ; loss 0.48 ; sentence/s 335 ; words/s 17755 ; accuracy train : 80.515625
76736 ; loss 0.49 

44736 ; loss 0.47 ; sentence/s 338 ; words/s 17612 ; accuracy train : 80.31696319580078
51136 ; loss 0.49 ; sentence/s 338 ; words/s 17386 ; accuracy train : 80.333984375
57536 ; loss 0.5 ; sentence/s 332 ; words/s 17710 ; accuracy train : 80.22048950195312
63936 ; loss 0.48 ; sentence/s 329 ; words/s 17756 ; accuracy train : 80.265625
70336 ; loss 0.5 ; sentence/s 334 ; words/s 17671 ; accuracy train : 80.20454406738281
76736 ; loss 0.48 ; sentence/s 331 ; words/s 17658 ; accuracy train : 80.27083587646484
83136 ; loss 0.49 ; sentence/s 332 ; words/s 17660 ; accuracy train : 80.28966522216797
89536 ; loss 0.48 ; sentence/s 336 ; words/s 17780 ; accuracy train : 80.34598541259766
95936 ; loss 0.47 ; sentence/s 335 ; words/s 17717 ; accuracy train : 80.38541412353516
102336 ; loss 0.48 ; sentence/s 330 ; words/s 17807 ; accuracy train : 80.404296875
108736 ; loss 0.5 ; sentence/s 332 ; words/s 17389 ; accuracy train : 80.37224578857422
115136 ; loss 0.49 ; sentence/s 330 ; words/s 16894

83136 ; loss 0.49 ; sentence/s 337 ; words/s 17591 ; accuracy train : 80.34615325927734
89536 ; loss 0.5 ; sentence/s 332 ; words/s 17795 ; accuracy train : 80.33258819580078
95936 ; loss 0.49 ; sentence/s 332 ; words/s 17601 ; accuracy train : 80.35520935058594
102336 ; loss 0.5 ; sentence/s 338 ; words/s 17566 ; accuracy train : 80.33203125
108736 ; loss 0.49 ; sentence/s 336 ; words/s 17676 ; accuracy train : 80.33547973632812
115136 ; loss 0.48 ; sentence/s 333 ; words/s 17434 ; accuracy train : 80.3671875
121536 ; loss 0.49 ; sentence/s 334 ; words/s 17673 ; accuracy train : 80.37006378173828
127936 ; loss 0.5 ; sentence/s 335 ; words/s 17704 ; accuracy train : 80.35469055175781
134336 ; loss 0.48 ; sentence/s 339 ; words/s 17242 ; accuracy train : 80.41294860839844
140736 ; loss 0.51 ; sentence/s 335 ; words/s 17453 ; accuracy train : 80.3671875
147136 ; loss 0.49 ; sentence/s 339 ; words/s 17513 ; accuracy train : 80.35869598388672
153536 ; loss 0.49 ; sentence/s 333 ; words/s 1

121536 ; loss 0.48 ; sentence/s 324 ; words/s 17301 ; accuracy train : 80.76973724365234
127936 ; loss 0.49 ; sentence/s 324 ; words/s 17050 ; accuracy train : 80.75077819824219
134336 ; loss 0.49 ; sentence/s 334 ; words/s 17390 ; accuracy train : 80.73958587646484
140736 ; loss 0.49 ; sentence/s 330 ; words/s 17221 ; accuracy train : 80.72372436523438
147136 ; loss 0.5 ; sentence/s 336 ; words/s 17367 ; accuracy train : 80.69837188720703
153536 ; loss 0.49 ; sentence/s 334 ; words/s 17599 ; accuracy train : 80.69075775146484
159936 ; loss 0.49 ; sentence/s 332 ; words/s 17410 ; accuracy train : 80.69437408447266
166336 ; loss 0.49 ; sentence/s 335 ; words/s 17722 ; accuracy train : 80.70673370361328
172736 ; loss 0.49 ; sentence/s 329 ; words/s 17066 ; accuracy train : 80.69791412353516
179136 ; loss 0.48 ; sentence/s 321 ; words/s 16763 ; accuracy train : 80.72098541259766
185536 ; loss 0.49 ; sentence/s 335 ; words/s 17724 ; accuracy train : 80.70097351074219
191936 ; loss 0.47 ; s

159936 ; loss 0.49 ; sentence/s 341 ; words/s 17452 ; accuracy train : 80.53312683105469
166336 ; loss 0.49 ; sentence/s 340 ; words/s 17671 ; accuracy train : 80.52884674072266
172736 ; loss 0.48 ; sentence/s 338 ; words/s 17620 ; accuracy train : 80.54513549804688
179136 ; loss 0.49 ; sentence/s 339 ; words/s 17538 ; accuracy train : 80.55245208740234
185536 ; loss 0.48 ; sentence/s 338 ; words/s 17672 ; accuracy train : 80.5792007446289
191936 ; loss 0.49 ; sentence/s 337 ; words/s 17654 ; accuracy train : 80.55677032470703
198336 ; loss 0.48 ; sentence/s 339 ; words/s 17540 ; accuracy train : 80.5826644897461
204736 ; loss 0.48 ; sentence/s 333 ; words/s 18064 ; accuracy train : 80.5927734375
211136 ; loss 0.49 ; sentence/s 337 ; words/s 17673 ; accuracy train : 80.59091186523438
217536 ; loss 0.49 ; sentence/s 335 ; words/s 17924 ; accuracy train : 80.59788513183594
223936 ; loss 0.49 ; sentence/s 336 ; words/s 17740 ; accuracy train : 80.60044860839844
230336 ; loss 0.49 ; senten

198336 ; loss 0.48 ; sentence/s 328 ; words/s 16850 ; accuracy train : 80.52368927001953
204736 ; loss 0.49 ; sentence/s 324 ; words/s 17188 ; accuracy train : 80.53369140625
211136 ; loss 0.51 ; sentence/s 318 ; words/s 16597 ; accuracy train : 80.50757598876953
217536 ; loss 0.49 ; sentence/s 324 ; words/s 17184 ; accuracy train : 80.50827026367188
223936 ; loss 0.49 ; sentence/s 327 ; words/s 17256 ; accuracy train : 80.51473236083984
230336 ; loss 0.49 ; sentence/s 331 ; words/s 17190 ; accuracy train : 80.50347137451172
236736 ; loss 0.5 ; sentence/s 324 ; words/s 16483 ; accuracy train : 80.48690795898438
243136 ; loss 0.5 ; sentence/s 326 ; words/s 17009 ; accuracy train : 80.48026275634766
249536 ; loss 0.49 ; sentence/s 325 ; words/s 17076 ; accuracy train : 80.47996520996094
255936 ; loss 0.49 ; sentence/s 319 ; words/s 16535 ; accuracy train : 80.486328125
262336 ; loss 0.48 ; sentence/s 315 ; words/s 16348 ; accuracy train : 80.48704528808594
268736 ; loss 0.48 ; sentence/s

In [39]:
# Run best model on test set.
nli_net.load_state_dict(torch.load(os.path.join(params.outputdir, params.outputmodelname)))

print('\nTEST : Epoch {0}'.format(epoch))
evaluate(1e6, 'valid', True)
evaluate(0, 'test', True)

# Save encoder instead of full model
torch.save(nli_net.encoder.state_dict(), os.path.join(params.outputdir, params.outputmodelname + '.encoder.pkl'))


TEST : Epoch 22

VALIDATION : Epoch 1000000.0
finalgrep : accuracy valid : 78.03292083740234
finalgrep : accuracy test : 78.7866439819336
