<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [1]:
# Imports
import gzip
import os
import pandas as pd
import numpy as np
import torch
import basic_model
import random
from collections import Counter
from tqdm import tqdm
import ujson as json
from torch.autograd import Variable
import util
import args
import pprint

# Parse the run configuration. NOTE: this rebinds the name `args` from the
# imported module to the namespace it returns, so the module itself is no longer
# reachable under that name once this line has run.
args = args.get_setup_args()
# Display every configured option as a dict
vars(args)

{'vocab_path': '/home/mouadh/Desktop/insuranceQA/V2/vocabulary',
 'label_path': '/home/mouadh/Desktop/insuranceQA/V2/InsuranceQA.label2answer.token.encoded',
 'train_path': '/home/mouadh/Desktop/insuranceQA/V2/InsuranceQA.question.anslabel.token.500.pool.solr.train.encoded',
 'test_path': '/home/mouadh/Desktop/insuranceQA/V2/InsuranceQA.question.anslabel.token.500.pool.solr.test.encoded',
 'glove_path': '/home/mouadh/Desktop/insuranceQA/glove.840B.300d/glove.840B.300d.txt',
 'glove_dim': 300,
 'glove_num_vecs': 2196017,
 'hidden_size': 100,
 'word_emb_file': '/home/mouadh/Desktop/insuranceQA/glove.840B.300d/word_embedding',
 'word2idx_file': '/home/mouadh/Desktop/insuranceQA/glove.840B.300d/word2idx',
 'seed': 0,
 'create_matrix_embedding': True,
 'lr': 0.5,
 'num_epochs': 30,
 'drop_prob': 0.2,
 'margin': 0.2,
 'batch_size': 128,
 'use_glove': False,
 'embd_size': 200}

In [2]:
# Fix every random seed the notebook relies on to get reproducible results
random.seed(args.seed)
np.random.seed(args.seed)
# PyTorch keeps its own generators (used for weight initialisation, dropout, ...),
# so it has to be seeded separately from `random` and NumPy.
torch.manual_seed(args.seed)
# Seeds all CUDA devices; per the PyTorch docs this call is silently ignored
# when CUDA is not available.
torch.cuda.manual_seed_all(args.seed)

In [3]:
# Read the id -> word vocabulary together with the answer-label tables
# (labels alone, and labels with their text)
id2w, l2a, l2at = util.load_vocabulary(args.vocab_path, args.label_path)

# Word -> index lookup; vocabulary words are numbered from 1 so that 0 stays free
w2i = dict(zip(id2w.values(), range(1, len(id2w) + 1)))
# Index 0 is reserved for the padding token
PAD = '<PAD>'
w2i[PAD] = 0

# Keep the vocabulary size both as a notebook variable and on the config object
args.vocab_size = vocab_size = len(w2i)
print('vocab_size:', vocab_size)

vocab_size: 68581


In [4]:
# Load the training split from args.train_path using the vocabulary and
# label-to-answer-text table loaded above.
train = util.load_data_train(args.train_path, id2w, l2at)

# Load the test split from args.test_path in the same way.
test = util.load_data_test(args.test_path, id2w, l2at)

In [5]:
# Report how many examples each split contains
print('n_train:', len(train))
print('n_test:', len(test))

n_train: 21325
n_test: 1988


In [6]:
# Choose the word -> index mapping used to vectorise text and, when GloVe is
# enabled, make the pretrained embedding matrix available as `word_vectors`.
if args.use_glove:
    # NOTE(review): the embeddings are built when `create_matrix_embedding` is
    # False and loaded from disk when it is True, which reads as inverted relative
    # to the flag name. Left as-is because existing configs may rely on it - confirm.
    if not args.create_matrix_embedding:
        # Create word embedding and word2idx for glove
        print("Creating word embedding")
        word_emb_mat, word2idx_dict = util.get_embedding(
            'word',
            emb_file=args.glove_path,
            vec_size=args.glove_dim,
            num_vectors=args.glove_num_vecs,
        )
        util.save(args.word_emb_file, word_emb_mat, message="word embedding")
        util.save(args.word2idx_file, word2idx_dict, message="word dictionary")
        # The model-construction cell needs `word_vectors`; previously this branch
        # never defined it, so building the model raised NameError. Re-read the file
        # that was just written, exactly as the loading branch below does.
        word_vectors = util.torch_from_json(args.word_emb_file)
    else:
        # Get embeddings
        print('Loading word vectors embeddings...')
        word_vectors = util.torch_from_json(args.word_emb_file)
        print("Loading word to index dictionary")
        # Use a context manager so the file handle is closed deterministically
        with open(args.word2idx_file) as word2idx_fh:
            word2idx_dict = json.load(word2idx_fh)
else:
    # Without GloVe the dataset's own vocabulary is the lookup table
    word2idx_dict = w2i

In [7]:
# Peek at the first field of the first training example (a tokenised question)
train[0][0]

['Is', 'Disability', 'Insurance', 'Required', 'By', 'Law', '?']

In [8]:
# Look up the index assigned to the '?' token
word2idx_dict['?']

50370

In [9]:
# Sanity-check util.make_vector on the first training question with a length
# argument of 20 (the displayed result is zero-padded up to that length)
util.make_vector([train[0][0]], w2i, 20)

tensor([[66164, 54421, 29876, 56902, 59631, 57715, 50370,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]],
       device='cuda:0')

In [13]:
def train_model(model, data, optimizer, word2idx_dict, args, batch_size=128,
                num_epochs=5, max_neg_samples=100):
    """Train `model` on (question, positive answer, negative answers) examples.

    For every example the question/positive similarity is computed once, then
    negatives are drawn at random until one gives a non-zero
    `util.loss_fn(pos_sim, neg_sim, args.margin)` (presumably a margin-based
    ranking loss - confirm in `util`). That loss is queued; once `batch_size`
    losses are queued, or the last example of the epoch is reached, their mean
    is back-propagated and the optimizer takes one step.

    Args:
        model: callable module taking two index tensors and returning a similarity.
        data: mutable sequence of examples where item[0] is the question tokens,
            item[1] the positive answer tokens and item[2] a non-empty sequence
            of negative answers. It is shuffled IN PLACE at the start of every epoch.
        optimizer: optimizer updating the parameters of `model`.
        word2idx_dict: token -> index mapping handed to `util.make_vector`.
        args: configuration object; only `args.margin` is read here.
        batch_size: number of queued losses that triggers an optimizer step.
        num_epochs: number of passes over `data` (default 5, the previously
            hard-coded value).
        max_neg_samples: maximum number of random negatives tried per example
            (default 100, the previously hard-coded value).

    Returns:
        The same `model` object, after training.
    """
    for epoch in range(num_epochs):
        model.train()
        random.shuffle(data)
        losses = []
        last_index = len(data) - 1
        for i, d in enumerate(tqdm(data)):
            q, pos, negs = d[0], d[1], d[2]

            # Each example is its own "batch" of one, sized to its exact length.
            vec_q = util.make_vector([q], word2idx_dict, len(q))
            vec_pos = util.make_vector([pos], word2idx_dict, len(pos))
            pos_sim = model(vec_q, vec_pos)

            # Look for a negative that yields a non-zero loss; give up after
            # `max_neg_samples` draws so the example simply contributes nothing.
            for _ in range(max_neg_samples):
                neg = random.choice(negs)
                vec_neg = util.make_vector([neg], word2idx_dict, len(neg))
                neg_sim = model(vec_q, vec_neg)
                loss = util.loss_fn(pos_sim, neg_sim, args.margin)
                # `.item()` works for any one-element tensor, whereas `loss.data[0]`
                # raises on 0-dim tensors in current PyTorch.
                if loss.item() != 0:
                    losses.append(loss)
                    break

            # Flush on a full batch or at the end of the epoch - but only when
            # something is queued: torch.stack([]) raises on an empty list.
            if losses and (len(losses) == batch_size or i == last_index):
                loss = torch.stack(losses).mean()
                print("epoch: {}, iteration: {}, loss: {}".format(epoch, i, loss.item()))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses = []
    return model

In [14]:
def lowercase(tokens):
    """Return a new list holding the lower-cased form of every token, in order."""
    lowered = []
    for token in tokens:
        lowered.append(token.lower())
    return lowered

In [15]:
from nltk.tokenize import RegexpTokenizer
# Tokenizer built on the pattern \w+ : it extracts runs of word characters,
# which drops punctuation and other special characters.
tokenizer_punct = RegexpTokenizer(r'\w+')
import itertools

def remove_special_chars(tokens):
    """Re-tokenise every token with `tokenizer_punct` and return the flattened pieces.

    A token made only of non-word characters contributes nothing; a token that
    contains several word runs contributes one entry per run.
    """
    pieces = []
    for token in tokens:
        pieces.extend(tokenizer_punct.tokenize(token))
    return pieces


In [16]:
def remove_words_not_glove(tokens, word2idx_dict):
    """Keep only the tokens that are keys of `word2idx_dict`, preserving order.

    Args:
        tokens: iterable of (hashable) tokens.
        word2idx_dict: mapping whose keys form the known vocabulary.

    Returns:
        A new list with the tokens found in `word2idx_dict`.

    A membership test is used instead of a subscript lookup guarded by a bare
    `except`, so unrelated errors are no longer swallowed and defaultdict-like
    mappings are not silently extended with every unknown token.
    """
    return [token for token in tokens if token in word2idx_dict]

In [17]:
# Quick check of the vocabulary filter on two lower-case tokens
remove_words_not_glove(['websitelink', 'hello'], word2idx_dict)

[]

In [18]:
def process_words(tokens, word2idx_dict):
    """Normalise tokens: lower-case, strip special characters, drop unknown words.

    The three steps run in that order; the result contains only tokens that are
    keys of `word2idx_dict`.
    """
    cleaned = remove_special_chars(lowercase(tokens))
    return remove_words_not_glove(cleaned, word2idx_dict)

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Build the QA-LSTM; `word_vectors` is only passed along when GloVe is enabled
model = basic_model.QA_LSTM(args, word_vectors) if args.use_glove else basic_model.QA_LSTM(args)

# Move the model to the GPU when one is present
if torch.cuda.is_available():
    model.cuda()

# Adam with its default hyper-parameters, restricted to parameters that require gradients
optimizer = torch.optim.Adam(p for p in model.parameters() if p.requires_grad)

In [None]:
# Launch training on the training split with the configured batch size;
# train_model returns the model it was given, so `model` is simply rebound.
model = train_model(model, train, optimizer, word2idx_dict, args, args.batch_size)

  0%|          | 5/21325 [00:00<09:22, 37.90it/s]

epoch: 0


  1%|          | 123/21325 [00:02<08:39, 40.82it/s]

epoch: 0, iteration: 127, loss: 0.2012072503566742


  1%|          | 248/21325 [00:10<06:43, 52.18it/s]  

epoch: 0, iteration: 255, loss: 0.2000810205936432


  2%|▏         | 380/21325 [00:21<08:07, 42.97it/s]  

epoch: 0, iteration: 383, loss: 0.19875001907348633


  2%|▏         | 509/21325 [00:29<06:53, 50.37it/s]  

epoch: 0, iteration: 511, loss: 0.20075929164886475


  3%|▎         | 638/21325 [00:38<06:48, 50.61it/s]  

epoch: 0, iteration: 639, loss: 0.2011936902999878


  4%|▎         | 765/21325 [00:46<06:16, 54.60it/s]  

epoch: 0, iteration: 767, loss: 0.19968342781066895


  4%|▍         | 892/21325 [00:55<07:17, 46.75it/s]  

epoch: 0, iteration: 895, loss: 0.19879993796348572


  5%|▍         | 1021/21325 [01:04<11:24, 29.64it/s] 

epoch: 0, iteration: 1023, loss: 0.19846376776695251


  5%|▌         | 1150/21325 [01:13<07:30, 44.80it/s]  

epoch: 0, iteration: 1151, loss: 0.19810694456100464


  6%|▌         | 1277/21325 [01:22<06:57, 48.07it/s]  

epoch: 0, iteration: 1279, loss: 0.19784626364707947


  7%|▋         | 1403/21325 [01:31<07:58, 41.66it/s]  

epoch: 0, iteration: 1407, loss: 0.19759759306907654


  7%|▋         | 1533/21325 [01:40<06:11, 53.34it/s]  

epoch: 0, iteration: 1535, loss: 0.19568777084350586


  8%|▊         | 1661/21325 [01:49<08:14, 39.79it/s]  

epoch: 0, iteration: 1663, loss: 0.19495365023612976


  8%|▊         | 1790/21325 [01:59<09:08, 35.61it/s]  

epoch: 0, iteration: 1791, loss: 0.19645363092422485


  9%|▉         | 1917/21325 [02:13<11:51, 27.27it/s]  

epoch: 0, iteration: 1919, loss: 0.19647037982940674


 10%|▉         | 2046/21325 [02:23<20:42, 15.52it/s]  

epoch: 0, iteration: 2047, loss: 0.19322988390922546


 10%|█         | 2173/21325 [02:33<06:16, 50.88it/s]  

epoch: 0, iteration: 2175, loss: 0.19280970096588135


 11%|█         | 2301/21325 [02:41<04:10, 75.92it/s]  

epoch: 0, iteration: 2303, loss: 0.19256578385829926


 11%|█▏        | 2431/21325 [02:48<04:35, 68.52it/s]  

epoch: 0, iteration: 2431, loss: 0.1928739696741104


 12%|█▏        | 2551/21325 [02:55<04:55, 63.45it/s]  

epoch: 0, iteration: 2559, loss: 0.19376611709594727


 13%|█▎        | 2683/21325 [03:03<04:27, 69.59it/s]

epoch: 0, iteration: 2687, loss: 0.19745276868343353


 13%|█▎        | 2809/21325 [03:10<05:26, 56.65it/s]  

epoch: 0, iteration: 2815, loss: 0.19025999307632446


 14%|█▍        | 2940/21325 [03:17<04:35, 66.76it/s]  

epoch: 0, iteration: 2943, loss: 0.1908089965581894


 14%|█▍        | 3070/21325 [03:24<04:08, 73.42it/s]  

epoch: 0, iteration: 3071, loss: 0.1879393756389618


 15%|█▍        | 3196/21325 [03:32<04:30, 67.09it/s]

epoch: 0, iteration: 3199, loss: 0.18615120649337769


 16%|█▌        | 3320/21325 [03:39<04:31, 66.40it/s]  

epoch: 0, iteration: 3327, loss: 0.18874706327915192


 16%|█▌        | 3451/21325 [03:48<09:53, 30.14it/s]  

epoch: 0, iteration: 3455, loss: 0.1871625781059265


 17%|█▋        | 3578/21325 [04:01<07:18, 40.45it/s]  

epoch: 0, iteration: 3583, loss: 0.18684950470924377


 17%|█▋        | 3710/21325 [04:10<07:54, 37.11it/s]  

epoch: 0, iteration: 3711, loss: 0.18071693181991577


 18%|█▊        | 3838/21325 [04:21<07:08, 40.82it/s]  

epoch: 0, iteration: 3839, loss: 0.1929149180650711


 19%|█▊        | 3964/21325 [04:33<04:32, 63.82it/s]  

epoch: 0, iteration: 3967, loss: 0.18541359901428223


 19%|█▉        | 4095/21325 [04:41<05:16, 54.37it/s]  

epoch: 0, iteration: 4095, loss: 0.18711261451244354


 20%|█▉        | 4218/21325 [04:48<04:48, 59.25it/s]  

epoch: 0, iteration: 4223, loss: 0.1845526397228241


 20%|██        | 4345/21325 [04:56<03:58, 71.13it/s]  

epoch: 0, iteration: 4351, loss: 0.1872325837612152


 21%|██        | 4474/21325 [05:03<04:20, 64.77it/s]

epoch: 0, iteration: 4479, loss: 0.18194204568862915


 22%|██▏       | 4603/21325 [05:14<07:26, 37.44it/s]  

epoch: 0, iteration: 4607, loss: 0.1838524043560028


 22%|██▏       | 4731/21325 [05:27<07:57, 34.77it/s]  

epoch: 0, iteration: 4735, loss: 0.18083280324935913


 23%|██▎       | 4862/21325 [05:38<06:12, 44.22it/s]  

epoch: 0, iteration: 4863, loss: 0.18122732639312744


 23%|██▎       | 4988/21325 [05:47<06:05, 44.64it/s]  

epoch: 0, iteration: 4991, loss: 0.17276552319526672


 24%|██▍       | 5118/21325 [05:55<05:08, 52.56it/s]  

epoch: 0, iteration: 5119, loss: 0.17853882908821106


 25%|██▍       | 5245/21325 [06:04<06:49, 39.26it/s]  

epoch: 0, iteration: 5247, loss: 0.1708948016166687


 25%|██▌       | 5373/21325 [06:13<05:53, 45.09it/s]  

epoch: 0, iteration: 5375, loss: 0.17609049379825592


 26%|██▌       | 5499/21325 [06:22<05:21, 49.25it/s]  

epoch: 0, iteration: 5503, loss: 0.1810987889766693


 26%|██▋       | 5626/21325 [06:30<05:02, 51.81it/s]  

epoch: 0, iteration: 5631, loss: 0.17956969141960144


 27%|██▋       | 5758/21325 [06:39<04:38, 55.84it/s]  

epoch: 0, iteration: 5759, loss: 0.1694107949733734


 28%|██▊       | 5884/21325 [06:52<07:37, 33.73it/s]  

epoch: 0, iteration: 5887, loss: 0.17172688245773315


 28%|██▊       | 6013/21325 [07:04<08:49, 28.89it/s]  

epoch: 0, iteration: 6015, loss: 0.16810454428195953


 29%|██▉       | 6139/21325 [07:15<09:37, 26.29it/s]  

epoch: 0, iteration: 6143, loss: 0.1720658540725708


 29%|██▉       | 6268/21325 [07:24<04:35, 54.70it/s]  

epoch: 0, iteration: 6271, loss: 0.15972217917442322


 30%|███       | 6398/21325 [07:34<04:46, 52.09it/s]  

epoch: 0, iteration: 6399, loss: 0.1691184639930725


 31%|███       | 6525/21325 [07:43<03:42, 66.49it/s]  

epoch: 0, iteration: 6527, loss: 0.1571468561887741


 31%|███       | 6653/21325 [07:58<14:23, 16.99it/s]  

epoch: 0, iteration: 6655, loss: 0.16121557354927063


 32%|███▏      | 6780/21325 [08:08<09:05, 26.67it/s]  

epoch: 0, iteration: 6783, loss: 0.16781193017959595


 32%|███▏      | 6900/21325 [08:20<06:01, 39.95it/s]  

epoch: 0, iteration: 6911, loss: 0.16341255605220795


 33%|███▎      | 7039/21325 [08:30<04:13, 56.38it/s]

epoch: 0, iteration: 7039, loss: 0.16670063138008118


 34%|███▎      | 7165/21325 [08:39<04:05, 57.65it/s]  

epoch: 0, iteration: 7167, loss: 0.15693403780460358


 34%|███▍      | 7292/21325 [08:51<04:09, 56.18it/s]  

epoch: 0, iteration: 7295, loss: 0.1582498699426651


 35%|███▍      | 7416/21325 [09:01<03:46, 61.42it/s]  

epoch: 0, iteration: 7423, loss: 0.15809786319732666


 35%|███▌      | 7541/21325 [09:09<03:35, 63.96it/s]

epoch: 0, iteration: 7551, loss: 0.16954074800014496


 36%|███▌      | 7674/21325 [09:18<05:28, 41.58it/s]

epoch: 0, iteration: 7679, loss: 0.15418483316898346


 37%|███▋      | 7806/21325 [09:30<07:18, 30.84it/s]  

epoch: 0, iteration: 7807, loss: 0.1516759842634201


 37%|███▋      | 7929/21325 [09:40<04:06, 54.39it/s]  

epoch: 0, iteration: 7935, loss: 0.15966133773326874


 38%|███▊      | 8063/21325 [09:49<06:16, 35.18it/s]  

epoch: 0, iteration: 8063, loss: 0.16352763772010803


 38%|███▊      | 8190/21325 [10:00<06:07, 35.76it/s]  

epoch: 0, iteration: 8191, loss: 0.15030619502067566


 39%|███▉      | 8319/21325 [10:11<03:34, 60.74it/s]  

epoch: 0, iteration: 8319, loss: 0.1669164001941681


 40%|███▉      | 8446/21325 [10:24<04:11, 51.22it/s]  

epoch: 0, iteration: 8447, loss: 0.15549775958061218


 40%|████      | 8571/21325 [10:34<03:38, 58.39it/s]  

epoch: 0, iteration: 8575, loss: 0.15183496475219727


 41%|████      | 8697/21325 [10:42<03:58, 52.99it/s]

epoch: 0, iteration: 8703, loss: 0.14576922357082367


 41%|████▏     | 8826/21325 [10:52<03:16, 63.47it/s]  

epoch: 0, iteration: 8831, loss: 0.16736942529678345


 42%|████▏     | 8956/21325 [10:59<04:06, 50.12it/s]

epoch: 0, iteration: 8959, loss: 0.14317923784255981


 43%|████▎     | 9080/21325 [11:08<04:50, 42.21it/s]

epoch: 0, iteration: 9087, loss: 0.1492835134267807


 43%|████▎     | 9215/21325 [11:19<04:56, 40.81it/s]  

epoch: 0, iteration: 9215, loss: 0.1722930669784546


 44%|████▍     | 9337/21325 [11:28<04:13, 47.34it/s]  

epoch: 0, iteration: 9343, loss: 0.16198782622814178


 44%|████▍     | 9471/21325 [11:40<08:26, 23.40it/s]

epoch: 0, iteration: 9471, loss: 0.14880619943141937


 45%|████▌     | 9597/21325 [11:53<03:45, 52.00it/s]  

epoch: 0, iteration: 9599, loss: 0.1450852006673813


 46%|████▌     | 9725/21325 [12:06<03:54, 49.49it/s]  

epoch: 0, iteration: 9727, loss: 0.14793545007705688


 46%|████▌     | 9851/21325 [12:17<03:42, 51.67it/s]  

epoch: 0, iteration: 9855, loss: 0.17477047443389893


 47%|████▋     | 9981/21325 [12:28<04:32, 41.61it/s]  

epoch: 0, iteration: 9983, loss: 0.15620341897010803


 47%|████▋     | 10110/21325 [12:40<05:40, 32.92it/s] 

epoch: 0, iteration: 10111, loss: 0.1527608036994934


 48%|████▊     | 10233/21325 [12:48<03:18, 55.76it/s]  

epoch: 0, iteration: 10239, loss: 0.14311274886131287


 49%|████▊     | 10367/21325 [13:00<04:40, 39.04it/s]

epoch: 0, iteration: 10367, loss: 0.15745875239372253


 49%|████▉     | 10494/21325 [13:11<31:39,  5.70it/s]  

epoch: 0, iteration: 10495, loss: 0.14267197251319885


 50%|████▉     | 10621/21325 [13:23<05:22, 33.23it/s]  

epoch: 0, iteration: 10623, loss: 0.16037073731422424


 50%|█████     | 10750/21325 [13:34<06:29, 27.17it/s]  

epoch: 0, iteration: 10751, loss: 0.1759561449289322


 51%|█████     | 10879/21325 [13:43<04:29, 38.75it/s]  

epoch: 0, iteration: 10879, loss: 0.15533453226089478


 52%|█████▏    | 11006/21325 [13:56<1:03:04,  2.73it/s]

epoch: 0, iteration: 11007, loss: 0.1415022313594818


 52%|█████▏    | 11134/21325 [14:07<16:07, 10.54it/s]  

epoch: 0, iteration: 11136, loss: 0.17490680515766144


 53%|█████▎    | 11263/21325 [14:19<04:43, 35.53it/s]  

epoch: 0, iteration: 11264, loss: 0.1533457487821579


 53%|█████▎    | 11391/21325 [14:29<08:57, 18.48it/s]  

epoch: 0, iteration: 11392, loss: 0.16045711934566498


 54%|█████▍    | 11519/21325 [14:38<07:01, 23.24it/s]  

epoch: 0, iteration: 11520, loss: 0.14855051040649414


 55%|█████▍    | 11644/21325 [14:50<06:25, 25.11it/s]  

epoch: 0, iteration: 11649, loss: 0.15698647499084473


 55%|█████▌    | 11774/21325 [14:59<08:10, 19.48it/s]

epoch: 0, iteration: 11777, loss: 0.15164732933044434


 56%|█████▌    | 11902/21325 [15:09<05:11, 30.29it/s]  

epoch: 0, iteration: 11905, loss: 0.15948575735092163


 56%|█████▋    | 12032/21325 [15:19<05:21, 28.94it/s]  

epoch: 0, iteration: 12033, loss: 0.16830116510391235


 57%|█████▋    | 12158/21325 [15:28<04:50, 31.54it/s]  

epoch: 0, iteration: 12161, loss: 0.17708277702331543


 58%|█████▊    | 12290/21325 [15:39<04:46, 31.53it/s]  

epoch: 0, iteration: 12290, loss: 0.1661912351846695


 58%|█████▊    | 12418/21325 [15:51<04:51, 30.55it/s]  

epoch: 0, iteration: 12418, loss: 0.1471431404352188


 59%|█████▉    | 12544/21325 [16:02<08:22, 17.47it/s]  

epoch: 0, iteration: 12546, loss: 0.16435948014259338


 59%|█████▉    | 12672/21325 [16:13<07:13, 19.94it/s]  

epoch: 0, iteration: 12675, loss: 0.15276385843753815


 60%|██████    | 12803/21325 [16:23<03:45, 37.86it/s]  

epoch: 0, iteration: 12803, loss: 0.17893019318580627


 61%|██████    | 12931/21325 [16:35<05:23, 25.91it/s]

epoch: 0, iteration: 12932, loss: 0.18505579233169556


 61%|██████    | 13059/21325 [16:46<05:55, 23.26it/s]

epoch: 0, iteration: 13060, loss: 0.18059584498405457


 62%|██████▏   | 13188/21325 [16:58<11:08, 12.17it/s]  

epoch: 0, iteration: 13188, loss: 0.17397820949554443


 62%|██████▏   | 13316/21325 [17:11<07:33, 17.66it/s]  

epoch: 0, iteration: 13316, loss: 0.17092549800872803


 63%|██████▎   | 13442/21325 [17:22<06:43, 19.53it/s]  

epoch: 0, iteration: 13444, loss: 0.17828486859798431


 64%|██████▎   | 13572/21325 [17:34<05:44, 22.51it/s]  

epoch: 0, iteration: 13572, loss: 0.17231665551662445


 64%|██████▍   | 13698/21325 [17:46<06:53, 18.43it/s]  

epoch: 0, iteration: 13700, loss: 0.15944911539554596


 65%|██████▍   | 13827/21325 [17:57<05:49, 21.48it/s]  

epoch: 0, iteration: 13828, loss: 0.15955571830272675


 65%|██████▌   | 13954/21325 [18:08<06:14, 19.66it/s]  

epoch: 0, iteration: 13956, loss: 0.16683098673820496


 66%|██████▌   | 14084/21325 [18:20<06:46, 17.83it/s]  

epoch: 0, iteration: 14084, loss: 0.16919487714767456


 67%|██████▋   | 14211/21325 [18:32<03:53, 30.43it/s]  

epoch: 0, iteration: 14213, loss: 0.1631467342376709


 67%|██████▋   | 14341/21325 [18:44<12:37,  9.22it/s]

epoch: 0, iteration: 14342, loss: 0.183099627494812


 68%|██████▊   | 14467/21325 [18:54<06:27, 17.68it/s]  

epoch: 0, iteration: 14470, loss: 0.17627869546413422


 68%|██████▊   | 14597/21325 [19:05<06:00, 18.65it/s]

epoch: 0, iteration: 14598, loss: 0.16173626482486725


 69%|██████▉   | 14725/21325 [19:17<03:55, 28.03it/s]  

epoch: 0, iteration: 14726, loss: 0.15956875681877136


 70%|██████▉   | 14851/21325 [19:29<04:29, 24.02it/s]

epoch: 0, iteration: 14854, loss: 0.17082250118255615


 70%|███████   | 14980/21325 [19:42<07:49, 13.53it/s]

epoch: 0, iteration: 14983, loss: 0.15628525614738464


 71%|███████   | 15112/21325 [19:56<05:08, 20.12it/s]

epoch: 0, iteration: 15112, loss: 0.1615576595067978


 71%|███████▏  | 15241/21325 [20:09<06:29, 15.63it/s]

epoch: 0, iteration: 15241, loss: 0.15705269575119019


 72%|███████▏  | 15368/21325 [20:20<04:49, 20.60it/s]

epoch: 0, iteration: 15369, loss: 0.156161367893219


 73%|███████▎  | 15497/21325 [20:35<05:32, 17.51it/s]

epoch: 0, iteration: 15499, loss: 0.1612587869167328


 73%|███████▎  | 15627/21325 [20:48<07:58, 11.91it/s]

epoch: 0, iteration: 15628, loss: 0.15297529101371765


 74%|███████▍  | 15756/21325 [21:02<05:25, 17.13it/s]  

epoch: 0, iteration: 15756, loss: 0.16282644867897034


 74%|███████▍  | 15884/21325 [21:15<05:42, 15.87it/s]

epoch: 0, iteration: 15885, loss: 0.18797796964645386


 75%|███████▌  | 16014/21325 [21:33<03:24, 25.91it/s]  

epoch: 0, iteration: 16015, loss: 0.17709052562713623


 76%|███████▌  | 16141/21325 [21:45<04:59, 17.28it/s]

epoch: 0, iteration: 16143, loss: 0.16371211409568787


 76%|███████▋  | 16267/21325 [21:56<03:09, 26.73it/s]

epoch: 0, iteration: 16271, loss: 0.16182726621627808


 77%|███████▋  | 16398/21325 [22:12<03:21, 24.49it/s]

epoch: 0, iteration: 16399, loss: 0.166375070810318


 77%|███████▋  | 16525/21325 [22:26<06:44, 11.88it/s]

epoch: 0, iteration: 16527, loss: 0.1633816510438919


 78%|███████▊  | 16656/21325 [22:44<08:40,  8.97it/s]

epoch: 0, iteration: 16657, loss: 0.15950796008110046


 79%|███████▊  | 16785/21325 [23:01<04:05, 18.46it/s]  

epoch: 0, iteration: 16785, loss: 0.1519383192062378


 79%|███████▉  | 16914/21325 [23:14<03:22, 21.77it/s]

epoch: 0, iteration: 16915, loss: 0.16352573037147522


 80%|███████▉  | 17041/21325 [23:30<04:29, 15.90it/s]

epoch: 0, iteration: 17043, loss: 0.18080267310142517


 81%|████████  | 17168/21325 [23:47<05:40, 12.21it/s]  

epoch: 0, iteration: 17171, loss: 0.15805253386497498


 81%|████████  | 17300/21325 [24:05<03:22, 19.92it/s]

epoch: 0, iteration: 17301, loss: 0.1760888397693634


 82%|████████▏ | 17428/21325 [24:23<07:27,  8.72it/s]

epoch: 0, iteration: 17429, loss: 0.15611067414283752


 82%|████████▏ | 17557/21325 [24:44<03:50, 16.35it/s]  

epoch: 0, iteration: 17557, loss: 0.14810851216316223


 83%|████████▎ | 17682/21325 [24:57<03:20, 18.21it/s]

epoch: 0, iteration: 17685, loss: 0.1764041930437088


 84%|████████▎ | 17812/21325 [25:13<03:56, 14.84it/s]

epoch: 0, iteration: 17814, loss: 0.1510748267173767


 84%|████████▍ | 17941/21325 [25:29<05:52,  9.61it/s]

epoch: 0, iteration: 17942, loss: 0.16151931881904602


 85%|████████▍ | 18067/21325 [25:45<05:10, 10.48it/s]

epoch: 0, iteration: 18070, loss: 0.16148698329925537


 85%|████████▌ | 18200/21325 [26:06<03:14, 16.10it/s]

epoch: 0, iteration: 18202, loss: 0.17120622098445892


 86%|████████▌ | 18329/21325 [26:31<05:10,  9.65it/s]

epoch: 0, iteration: 18330, loss: 0.153428316116333


 87%|████████▋ | 18458/21325 [26:57<03:20, 14.33it/s]  

epoch: 0, iteration: 18459, loss: 0.16315650939941406


 87%|████████▋ | 18585/21325 [27:10<03:04, 14.89it/s]

epoch: 0, iteration: 18588, loss: 0.1717318892478943


 88%|████████▊ | 18716/21325 [27:23<02:14, 19.43it/s]

epoch: 0, iteration: 18716, loss: 0.15580883622169495


 88%|████████▊ | 18842/21325 [27:33<02:01, 20.44it/s]

epoch: 0, iteration: 18844, loss: 0.16512353718280792


 89%|████████▉ | 18974/21325 [27:48<03:39, 10.72it/s]

epoch: 0, iteration: 18974, loss: 0.15951673686504364


 90%|████████▉ | 19102/21325 [28:02<03:50,  9.66it/s]

epoch: 0, iteration: 19102, loss: 0.15701395273208618


 90%|█████████ | 19227/21325 [28:26<01:45, 19.96it/s]  

epoch: 0, iteration: 19230, loss: 0.1663718968629837


 91%|█████████ | 19357/21325 [28:43<01:41, 19.37it/s]

epoch: 0, iteration: 19358, loss: 0.1604972779750824


 91%|█████████▏| 19487/21325 [29:01<03:00, 10.20it/s]

epoch: 0, iteration: 19488, loss: 0.15558096766471863


 92%|█████████▏| 19618/21325 [29:18<05:59,  4.74it/s]

epoch: 0, iteration: 19619, loss: 0.17657819390296936


 93%|█████████▎| 19748/21325 [29:35<02:23, 10.99it/s]

epoch: 0, iteration: 19748, loss: 0.15774089097976685


 93%|█████████▎| 19877/21325 [29:51<01:47, 13.49it/s]

epoch: 0, iteration: 19878, loss: 0.16291838884353638


 94%|█████████▍| 20005/21325 [30:06<01:21, 16.25it/s]

epoch: 0, iteration: 20006, loss: 0.14283840358257294


 94%|█████████▍| 20135/21325 [30:25<02:04,  9.52it/s]

epoch: 0, iteration: 20136, loss: 0.15081439912319183


 95%|█████████▌| 20264/21325 [30:40<00:50, 21.02it/s]

epoch: 0, iteration: 20265, loss: 0.17677050828933716


 96%|█████████▌| 20393/21325 [30:57<01:24, 11.08it/s]

epoch: 0, iteration: 20394, loss: 0.14365293085575104


 96%|█████████▌| 20520/21325 [31:11<01:17, 10.41it/s]

epoch: 0, iteration: 20522, loss: 0.15422503650188446


 97%|█████████▋| 20648/21325 [31:28<01:42,  6.58it/s]

epoch: 0, iteration: 20652, loss: 0.1633191853761673


 97%|█████████▋| 20780/21325 [31:48<01:40,  5.45it/s]

epoch: 0, iteration: 20781, loss: 0.1527581512928009


 98%|█████████▊| 20911/21325 [32:09<01:12,  5.68it/s]

epoch: 0, iteration: 20912, loss: 0.14265766739845276


 99%|█████████▊| 21037/21325 [32:24<00:17, 16.27it/s]

epoch: 0, iteration: 21040, loss: 0.15206706523895264


 99%|█████████▉| 21167/21325 [32:38<00:09, 15.86it/s]

epoch: 0, iteration: 21168, loss: 0.16853636503219604


100%|█████████▉| 21295/21325 [32:54<00:01, 18.50it/s]

epoch: 0, iteration: 21296, loss: 0.13979105651378632


100%|█████████▉| 21323/21325 [33:01<00:00, 11.50it/s]

epoch: 0, iteration: 21324, loss: 0.1524733006954193


100%|██████████| 21325/21325 [33:02<00:00, 10.76it/s]
  0%|          | 0/21325 [00:00<?, ?it/s]

epoch: 1


  1%|          | 128/21325 [00:15<2:03:04,  2.87it/s]

epoch: 1, iteration: 128, loss: 0.14892496168613434


  1%|          | 255/21325 [00:29<22:23, 15.68it/s]   

epoch: 1, iteration: 257, loss: 0.14850279688835144


  2%|▏         | 386/21325 [00:47<42:41,  8.18it/s]  

epoch: 1, iteration: 386, loss: 0.15067481994628906


  2%|▏         | 513/21325 [01:06<39:08,  8.86it/s]  

epoch: 1, iteration: 514, loss: 0.14312881231307983


  3%|▎         | 643/21325 [01:24<22:45, 15.14it/s]  

epoch: 1, iteration: 643, loss: 0.18210509419441223


  4%|▎         | 769/21325 [01:45<57:14,  5.99it/s]  

epoch: 1, iteration: 772, loss: 0.1561812162399292


  4%|▍         | 900/21325 [02:03<28:27, 11.96it/s]  

epoch: 1, iteration: 901, loss: 0.14478996396064758


  5%|▍         | 1029/21325 [02:18<19:08, 17.67it/s] 

epoch: 1, iteration: 1030, loss: 0.14097750186920166


  5%|▌         | 1158/21325 [02:31<33:39,  9.99it/s]  

epoch: 1, iteration: 1158, loss: 0.14455746114253998


  6%|▌         | 1289/21325 [02:51<43:41,  7.64it/s]  

epoch: 1, iteration: 1289, loss: 0.1603405624628067


  7%|▋         | 1418/21325 [03:09<1:27:18,  3.80it/s]

epoch: 1, iteration: 1418, loss: 0.1349734663963318


  7%|▋         | 1545/21325 [03:28<40:23,  8.16it/s]  

epoch: 1, iteration: 1547, loss: 0.16064220666885376


  8%|▊         | 1676/21325 [03:50<57:54,  5.66it/s]  

epoch: 1, iteration: 1679, loss: 0.16041840612888336


  8%|▊         | 1802/21325 [04:06<22:05, 14.73it/s]  

epoch: 1, iteration: 1809, loss: 0.150725319981575


  9%|▉         | 1938/21325 [04:28<27:49, 11.61it/s]  

epoch: 1, iteration: 1939, loss: 0.15507309138774872


 10%|▉         | 2065/21325 [04:39<20:04, 15.99it/s]  

epoch: 1, iteration: 2067, loss: 0.14425405859947205


 10%|█         | 2195/21325 [05:02<31:47, 10.03it/s]  

epoch: 1, iteration: 2198, loss: 0.15086215734481812


 11%|█         | 2326/21325 [05:21<23:53, 13.26it/s]  

epoch: 1, iteration: 2329, loss: 0.14407707750797272


 12%|█▏        | 2453/21325 [05:33<18:42, 16.81it/s]  

epoch: 1, iteration: 2457, loss: 0.17597004771232605


 12%|█▏        | 2586/21325 [05:51<27:17, 11.44it/s]  

epoch: 1, iteration: 2588, loss: 0.1391095668077469


 13%|█▎        | 2713/21325 [06:03<16:03, 19.32it/s]  

epoch: 1, iteration: 2716, loss: 0.14863985776901245


 13%|█▎        | 2843/21325 [06:16<26:17, 11.72it/s]  

epoch: 1, iteration: 2844, loss: 0.161837637424469


 14%|█▍        | 2976/21325 [06:33<31:10,  9.81it/s]  

epoch: 1, iteration: 2976, loss: 0.16759231686592102


 15%|█▍        | 3107/21325 [06:50<18:33, 16.36it/s]  

epoch: 1, iteration: 3108, loss: 0.1716737002134323


 15%|█▌        | 3234/21325 [07:02<25:57, 11.62it/s]  

epoch: 1, iteration: 3236, loss: 0.16556954383850098


 16%|█▌        | 3365/21325 [07:22<36:38,  8.17it/s]  

epoch: 1, iteration: 3368, loss: 0.15175750851631165


 16%|█▋        | 3500/21325 [07:39<26:15, 11.32it/s]  

epoch: 1, iteration: 3500, loss: 0.14766517281532288


 17%|█▋        | 3631/21325 [08:02<39:19,  7.50it/s]  

epoch: 1, iteration: 3631, loss: 0.15741513669490814


 18%|█▊        | 3760/21325 [08:22<14:30, 20.18it/s]  

epoch: 1, iteration: 3761, loss: 0.1529121696949005


 18%|█▊        | 3889/21325 [08:42<37:55,  7.66it/s]  

epoch: 1, iteration: 3889, loss: 0.14761625230312347


 19%|█▉        | 4016/21325 [09:00<18:53, 15.27it/s]  

epoch: 1, iteration: 4019, loss: 0.14757126569747925


 19%|█▉        | 4148/21325 [09:19<31:09,  9.19it/s]  

epoch: 1, iteration: 4150, loss: 0.14904049038887024


 20%|██        | 4279/21325 [09:43<22:08, 12.83it/s]  

epoch: 1, iteration: 4282, loss: 0.15396693348884583


 21%|██        | 4410/21325 [10:04<1:06:08,  4.26it/s]

epoch: 1, iteration: 4411, loss: 0.16659536957740784


 21%|██▏       | 4540/21325 [10:26<21:16, 13.15it/s]  

epoch: 1, iteration: 4540, loss: 0.16541311144828796


 22%|██▏       | 4669/21325 [10:48<1:12:49,  3.81it/s]

epoch: 1, iteration: 4670, loss: 0.15546965599060059


 23%|██▎       | 4800/21325 [11:07<15:14, 18.08it/s]  

epoch: 1, iteration: 4800, loss: 0.13951167464256287


 23%|██▎       | 4927/21325 [11:31<17:45, 15.39it/s]  

epoch: 1, iteration: 4929, loss: 0.14341624081134796


 24%|██▎       | 5056/21325 [11:47<30:22,  8.93it/s]  

epoch: 1, iteration: 5058, loss: 0.14799165725708008


 24%|██▍       | 5186/21325 [12:06<22:43, 11.84it/s]  

epoch: 1, iteration: 5188, loss: 0.13856029510498047


 25%|██▍       | 5318/21325 [12:28<19:38, 13.58it/s]  

epoch: 1, iteration: 5319, loss: 0.16410689055919647


 26%|██▌       | 5451/21325 [12:49<54:49,  4.83it/s]  

epoch: 1, iteration: 5451, loss: 0.1597670167684555


 26%|██▌       | 5579/21325 [13:05<27:45,  9.45it/s]  

epoch: 1, iteration: 5581, loss: 0.15444941818714142


 27%|██▋       | 5712/21325 [13:23<16:22, 15.88it/s]  

epoch: 1, iteration: 5713, loss: 0.1456325650215149


 27%|██▋       | 5843/21325 [13:38<29:37,  8.71it/s]  

epoch: 1, iteration: 5843, loss: 0.14980271458625793


 28%|██▊       | 5972/21325 [13:55<40:53,  6.26it/s]  

epoch: 1, iteration: 5972, loss: 0.13279923796653748


 29%|██▊       | 6100/21325 [14:09<13:47, 18.40it/s]  

epoch: 1, iteration: 6101, loss: 0.15228867530822754


 29%|██▉       | 6227/21325 [14:26<16:33, 15.20it/s]  

epoch: 1, iteration: 6230, loss: 0.1326444447040558


 30%|██▉       | 6360/21325 [14:42<30:58,  8.05it/s]  

epoch: 1, iteration: 6360, loss: 0.1584196835756302


 30%|███       | 6488/21325 [14:58<23:03, 10.72it/s]  

epoch: 1, iteration: 6489, loss: 0.13654419779777527


 31%|███       | 6620/21325 [15:16<22:15, 11.01it/s]  

epoch: 1, iteration: 6620, loss: 0.13710851967334747


 32%|███▏      | 6749/21325 [15:37<1:41:59,  2.38it/s]

epoch: 1, iteration: 6752, loss: 0.14350013434886932


 32%|███▏      | 6879/21325 [15:53<25:07,  9.58it/s]  

epoch: 1, iteration: 6881, loss: 0.13921959698200226


 33%|███▎      | 7009/21325 [16:09<11:59, 19.91it/s]  

epoch: 1, iteration: 7011, loss: 0.15456822514533997


 33%|███▎      | 7143/21325 [16:33<1:34:27,  2.50it/s]

epoch: 1, iteration: 7143, loss: 0.16610662639141083


 34%|███▍      | 7273/21325 [16:52<19:48, 11.82it/s]  

epoch: 1, iteration: 7275, loss: 0.15431801974773407


 35%|███▍      | 7404/21325 [17:08<12:35, 18.42it/s]  

epoch: 1, iteration: 7404, loss: 0.13863588869571686


 35%|███▌      | 7533/21325 [17:24<1:09:00,  3.33it/s]

epoch: 1, iteration: 7533, loss: 0.137155681848526


 36%|███▌      | 7662/21325 [17:39<10:47, 21.09it/s]  

epoch: 1, iteration: 7664, loss: 0.13590237498283386


 37%|███▋      | 7795/21325 [18:06<50:17,  4.48it/s]  

epoch: 1, iteration: 7796, loss: 0.16517531871795654


 37%|███▋      | 7922/21325 [18:23<14:12, 15.72it/s]  

epoch: 1, iteration: 7926, loss: 0.16815996170043945


 38%|███▊      | 8054/21325 [18:41<1:00:13,  3.67it/s]

epoch: 1, iteration: 8057, loss: 0.16912594437599182


 38%|███▊      | 8186/21325 [19:03<34:32,  6.34it/s]  

epoch: 1, iteration: 8188, loss: 0.15731504559516907


 39%|███▉      | 8316/21325 [19:22<31:12,  6.95it/s]  

epoch: 1, iteration: 8319, loss: 0.16139262914657593


 40%|███▉      | 8444/21325 [19:43<18:47, 11.42it/s]  

epoch: 1, iteration: 8449, loss: 0.1420072466135025


 40%|████      | 8579/21325 [19:59<10:47, 19.69it/s]  

epoch: 1, iteration: 8579, loss: 0.15700849890708923


 41%|████      | 8710/21325 [20:19<17:19, 12.13it/s]  

epoch: 1, iteration: 8711, loss: 0.17139658331871033


 41%|████▏     | 8840/21325 [20:44<14:44, 14.12it/s]  

epoch: 1, iteration: 8844, loss: 0.14337782561779022


 42%|████▏     | 8975/21325 [21:01<14:02, 14.65it/s]  

epoch: 1, iteration: 8975, loss: 0.14888879656791687


 43%|████▎     | 9102/21325 [21:14<10:27, 19.47it/s]  

epoch: 1, iteration: 9103, loss: 0.17192226648330688


 43%|████▎     | 9233/21325 [21:32<19:49, 10.16it/s]  

epoch: 1, iteration: 9234, loss: 0.15533240139484406


 44%|████▍     | 9365/21325 [21:50<37:33,  5.31it/s]  

epoch: 1, iteration: 9366, loss: 0.15709900856018066


 45%|████▍     | 9494/21325 [22:05<15:02, 13.11it/s]  

epoch: 1, iteration: 9495, loss: 0.1374700963497162


 45%|████▌     | 9622/21325 [22:17<13:31, 14.41it/s]  

epoch: 1, iteration: 9623, loss: 0.15558013319969177


 46%|████▌     | 9751/21325 [22:33<35:17,  5.46it/s]  

epoch: 1, iteration: 9752, loss: 0.13785643875598907


 46%|████▋     | 9884/21325 [22:53<16:36, 11.48it/s]  

epoch: 1, iteration: 9884, loss: 0.14448252320289612


 47%|████▋     | 10015/21325 [23:07<09:36, 19.63it/s] 

epoch: 1, iteration: 10015, loss: 0.1333197057247162


 48%|████▊     | 10143/21325 [23:23<06:34, 28.34it/s]  

epoch: 1, iteration: 10146, loss: 0.12823563814163208


 48%|████▊     | 10277/21325 [23:43<08:55, 20.62it/s]  

epoch: 1, iteration: 10278, loss: 0.12440050393342972


 49%|████▉     | 10408/21325 [23:59<36:47,  4.95it/s]  

epoch: 1, iteration: 10409, loss: 0.1187022253870964


 49%|████▉     | 10540/21325 [24:20<18:26,  9.75it/s]  

epoch: 1, iteration: 10540, loss: 0.16067053377628326


 50%|█████     | 10667/21325 [24:34<10:37, 16.72it/s]  

epoch: 1, iteration: 10668, loss: 0.14971858263015747


 51%|█████     | 10797/21325 [24:49<11:53, 14.75it/s]  

epoch: 1, iteration: 10797, loss: 0.15605637431144714


 51%|█████     | 10927/21325 [25:09<30:24,  5.70it/s]  

epoch: 1, iteration: 10927, loss: 0.13163945078849792


 52%|█████▏    | 11057/21325 [25:31<07:36, 22.51it/s]  

epoch: 1, iteration: 11059, loss: 0.1502145528793335


 52%|█████▏    | 11187/21325 [25:46<07:21, 22.97it/s]  

epoch: 1, iteration: 11189, loss: 0.14823076128959656


 53%|█████▎    | 11319/21325 [26:02<08:22, 19.92it/s]  

epoch: 1, iteration: 11319, loss: 0.14693312346935272


 54%|█████▎    | 11449/21325 [26:18<08:14, 19.96it/s]  

epoch: 1, iteration: 11449, loss: 0.14143261313438416


 54%|█████▍    | 11576/21325 [26:32<10:18, 15.76it/s]  

epoch: 1, iteration: 11579, loss: 0.13935858011245728


 55%|█████▍    | 11709/21325 [26:48<10:12, 15.71it/s]  

epoch: 1, iteration: 11709, loss: 0.15887433290481567


 56%|█████▌    | 11837/21325 [27:04<11:40, 13.54it/s]  

epoch: 1, iteration: 11839, loss: 0.13933628797531128


 56%|█████▌    | 11968/21325 [27:20<16:13,  9.61it/s]  

epoch: 1, iteration: 11969, loss: 0.14811941981315613


 57%|█████▋    | 12097/21325 [27:37<29:56,  5.14it/s]  

epoch: 1, iteration: 12099, loss: 0.13878311216831207


 57%|█████▋    | 12227/21325 [27:54<13:19, 11.38it/s]  

epoch: 1, iteration: 12228, loss: 0.13719409704208374


 58%|█████▊    | 12357/21325 [28:13<13:47, 10.83it/s]  

epoch: 1, iteration: 12359, loss: 0.14235267043113708


 59%|█████▊    | 12484/21325 [28:33<11:37, 12.68it/s]  

epoch: 1, iteration: 12487, loss: 0.14939090609550476


 59%|█████▉    | 12616/21325 [28:50<10:57, 13.24it/s]  

epoch: 1, iteration: 12617, loss: 0.13774631917476654


 60%|█████▉    | 12745/21325 [29:08<05:20, 26.79it/s]  

epoch: 1, iteration: 12747, loss: 0.15139488875865936


 60%|██████    | 12876/21325 [29:29<12:33, 11.21it/s]  

epoch: 1, iteration: 12876, loss: 0.13676396012306213


 61%|██████    | 13006/21325 [29:50<17:18,  8.01it/s]  

epoch: 1, iteration: 13008, loss: 0.14587754011154175


 62%|██████▏   | 13140/21325 [30:16<32:51,  4.15it/s]  

epoch: 1, iteration: 13140, loss: 0.15461227297782898


 62%|██████▏   | 13269/21325 [30:31<06:26, 20.87it/s]  

epoch: 1, iteration: 13269, loss: 0.15770521759986877


 63%|██████▎   | 13396/21325 [30:49<07:25, 17.80it/s]  

epoch: 1, iteration: 13398, loss: 0.13774487376213074


 63%|██████▎   | 13530/21325 [31:13<21:51,  5.95it/s]  

epoch: 1, iteration: 13530, loss: 0.14387398958206177


 64%|██████▍   | 13661/21325 [31:33<24:52,  5.14it/s]  

epoch: 1, iteration: 13661, loss: 0.14618845283985138


 65%|██████▍   | 13797/21325 [31:55<18:04,  6.94it/s]  

epoch: 1, iteration: 13797, loss: 0.1742575466632843


 65%|██████▌   | 13926/21325 [32:09<20:47,  5.93it/s]  

epoch: 1, iteration: 13926, loss: 0.15630637109279633


 66%|██████▌   | 14056/21325 [32:26<29:26,  4.11it/s]  

epoch: 1, iteration: 14056, loss: 0.1237468272447586


 67%|██████▋   | 14188/21325 [32:42<22:35,  5.26it/s]  

epoch: 1, iteration: 14188, loss: 0.15079376101493835


 67%|██████▋   | 14315/21325 [32:58<11:16, 10.36it/s]  

epoch: 1, iteration: 14316, loss: 0.1407206952571869


 68%|██████▊   | 14449/21325 [33:16<10:32, 10.87it/s]  

epoch: 1, iteration: 14449, loss: 0.16697761416435242


 68%|██████▊   | 14578/21325 [33:34<16:57,  6.63it/s]  

epoch: 1, iteration: 14578, loss: 0.15090864896774292


 69%|██████▉   | 14705/21325 [33:54<17:29,  6.31it/s]  

epoch: 1, iteration: 14706, loss: 0.13112156093120575


 70%|██████▉   | 14836/21325 [34:13<10:04, 10.73it/s]  

epoch: 1, iteration: 14837, loss: 0.15136492252349854


 70%|███████   | 14967/21325 [34:34<21:15,  4.99it/s]  

epoch: 1, iteration: 14968, loss: 0.14503245055675507


 71%|███████   | 15101/21325 [34:57<10:44,  9.65it/s]  

epoch: 1, iteration: 15101, loss: 0.13893695175647736


 71%|███████▏  | 15228/21325 [35:20<15:20,  6.62it/s]  

epoch: 1, iteration: 15230, loss: 0.14192676544189453


 72%|███████▏  | 15356/21325 [35:40<11:14,  8.85it/s]  

epoch: 1, iteration: 15360, loss: 0.13398148119449615


 73%|███████▎  | 15490/21325 [36:00<14:26,  6.74it/s]

epoch: 1, iteration: 15490, loss: 0.13897767663002014


 73%|███████▎  | 15616/21325 [36:16<12:07,  7.84it/s]  

epoch: 1, iteration: 15619, loss: 0.15106350183486938


 74%|███████▍  | 15748/21325 [36:38<08:58, 10.36it/s]

epoch: 1, iteration: 15748, loss: 0.1496482789516449


 74%|███████▍  | 15880/21325 [37:00<32:36,  2.78it/s]  

epoch: 1, iteration: 15880, loss: 0.15801408886909485


 75%|███████▌  | 16008/21325 [37:16<05:31, 16.06it/s]  

epoch: 1, iteration: 16008, loss: 0.13151875138282776


 76%|███████▌  | 16135/21325 [37:35<08:06, 10.67it/s]  

epoch: 1, iteration: 16136, loss: 0.1527833640575409


 76%|███████▋  | 16264/21325 [37:58<10:44,  7.85it/s]  

epoch: 1, iteration: 16265, loss: 0.14194414019584656


 77%|███████▋  | 16396/21325 [38:17<06:24, 12.81it/s]  

epoch: 1, iteration: 16396, loss: 0.16579967737197876


 77%|███████▋  | 16525/21325 [38:41<09:31,  8.39it/s]  

epoch: 1, iteration: 16528, loss: 0.15555596351623535


 78%|███████▊  | 16657/21325 [38:59<08:05,  9.61it/s]

epoch: 1, iteration: 16658, loss: 0.14369496703147888


 79%|███████▊  | 16788/21325 [39:18<08:48,  8.58it/s]  

epoch: 1, iteration: 16788, loss: 0.12398037314414978


 79%|███████▉  | 16915/21325 [39:39<05:07, 14.34it/s]  

epoch: 1, iteration: 16918, loss: 0.13738347589969635


 80%|███████▉  | 17048/21325 [39:59<03:22, 21.08it/s]

epoch: 1, iteration: 17048, loss: 0.1303389072418213


 81%|████████  | 17177/21325 [40:20<13:34,  5.09it/s]

epoch: 1, iteration: 17177, loss: 0.14969882369041443


 81%|████████  | 17306/21325 [40:42<04:15, 15.73it/s]  

epoch: 1, iteration: 17307, loss: 0.14084017276763916


 82%|████████▏ | 17434/21325 [41:05<05:37, 11.54it/s]  

epoch: 1, iteration: 17436, loss: 0.1404498815536499


 82%|████████▏ | 17566/21325 [41:30<04:13, 14.83it/s]

epoch: 1, iteration: 17566, loss: 0.1374669373035431


 83%|████████▎ | 17696/21325 [41:57<08:52,  6.81it/s]  

epoch: 1, iteration: 17696, loss: 0.15835195779800415


 84%|████████▎ | 17824/21325 [42:13<04:33, 12.79it/s]

epoch: 1, iteration: 17825, loss: 0.12470747530460358


 84%|████████▍ | 17952/21325 [42:34<09:03,  6.21it/s]

epoch: 1, iteration: 17954, loss: 0.13445571064949036


 85%|████████▍ | 18087/21325 [43:00<11:20,  4.76it/s]

epoch: 1, iteration: 18087, loss: 0.15335485339164734


 85%|████████▌ | 18215/21325 [43:19<09:02,  5.73it/s]

epoch: 1, iteration: 18216, loss: 0.1307578980922699


 86%|████████▌ | 18346/21325 [43:35<02:21, 21.07it/s]

epoch: 1, iteration: 18347, loss: 0.13718664646148682


 87%|████████▋ | 18478/21325 [43:54<05:00,  9.49it/s]

epoch: 1, iteration: 18478, loss: 0.14317595958709717


 87%|████████▋ | 18608/21325 [44:14<04:18, 10.50it/s]

epoch: 1, iteration: 18608, loss: 0.16610966622829437


 88%|████████▊ | 18740/21325 [44:36<08:26,  5.10it/s]

epoch: 1, iteration: 18741, loss: 0.1562584936618805


 88%|████████▊ | 18868/21325 [44:55<02:43, 15.03it/s]

epoch: 1, iteration: 18870, loss: 0.14433132112026215


 89%|████████▉ | 19002/21325 [45:16<02:43, 14.17it/s]

epoch: 1, iteration: 19003, loss: 0.13959193229675293


 90%|████████▉ | 19134/21325 [45:34<02:36, 13.99it/s]

epoch: 1, iteration: 19135, loss: 0.15177848935127258


 90%|█████████ | 19264/21325 [45:51<04:53,  7.03it/s]

epoch: 1, iteration: 19265, loss: 0.14660251140594482


 91%|█████████ | 19392/21325 [46:07<02:21, 13.67it/s]

epoch: 1, iteration: 19394, loss: 0.12834104895591736


 92%|█████████▏| 19521/21325 [46:23<02:18, 13.00it/s]

epoch: 1, iteration: 19522, loss: 0.1185358464717865


 92%|█████████▏| 19651/21325 [46:39<01:46, 15.78it/s]

epoch: 1, iteration: 19652, loss: 0.15085375308990479


 93%|█████████▎| 19782/21325 [46:58<01:31, 16.78it/s]

epoch: 1, iteration: 19783, loss: 0.14154590666294098


 93%|█████████▎| 19911/21325 [47:10<02:03, 11.46it/s]

epoch: 1, iteration: 19911, loss: 0.1225077360868454


 94%|█████████▍| 20042/21325 [47:30<01:57, 10.96it/s]

epoch: 1, iteration: 20045, loss: 0.13923843204975128


 95%|█████████▍| 20172/21325 [47:49<01:00, 19.18it/s]

epoch: 1, iteration: 20174, loss: 0.1600920557975769


 95%|█████████▌| 20304/21325 [48:16<00:43, 23.55it/s]

epoch: 1, iteration: 20305, loss: 0.1554068922996521


 96%|█████████▌| 20433/21325 [48:32<00:43, 20.40it/s]

epoch: 1, iteration: 20433, loss: 0.15284378826618195


 96%|█████████▋| 20561/21325 [48:50<01:08, 11.08it/s]

epoch: 1, iteration: 20562, loss: 0.16230779886245728


 97%|█████████▋| 20690/21325 [49:05<00:28, 22.58it/s]

epoch: 1, iteration: 20691, loss: 0.16713589429855347


 98%|█████████▊| 20819/21325 [49:22<01:02,  8.14it/s]

epoch: 1, iteration: 20821, loss: 0.1577712893486023


 98%|█████████▊| 20949/21325 [49:39<00:40,  9.28it/s]

epoch: 1, iteration: 20950, loss: 0.14374767243862152


 99%|█████████▉| 21081/21325 [49:59<00:31,  7.69it/s]

epoch: 1, iteration: 21082, loss: 0.1550312340259552


 99%|█████████▉| 21213/21325 [50:18<00:11,  9.58it/s]

epoch: 1, iteration: 21213, loss: 0.15202586352825165


100%|█████████▉| 21324/21325 [50:33<00:00, 10.76it/s]

epoch: 1, iteration: 21324, loss: 0.13401077687740326


100%|██████████| 21325/21325 [50:37<00:00,  7.02it/s]
  0%|          | 0/21325 [00:00<?, ?it/s]

epoch: 2


  0%|          | 30/21325 [00:04<52:38,  6.74it/s]  

In [None]:
def test(model, data):
    acc, total = 0, 0
    for d in data:
        q = d[0]
        print('q', ' '.join(q))
        labels = d[1]
        cands = d[2]

        # preprare answer labels
        label_indices = [cands.index(l) for l in labels if l in cands]

        # build data
        q = make_vector([q], w2i, len(q))
        cands = [label_to_ans_text[c] for c in cands] # id to text
        max_cand_len = min(args.max_sent_len, max([len(c) for c in cands]))
        cands = make_vector(cands, w2i, max_cand_len)

        # predict
        scores = [model(q, c.unsqueeze(0)).data[0] for c in cands]
        pred_idx = np.argmax(scores)
        if pred_idx in label_indices:
            print('correct')
            acc += 1
        else:
            print('wrong')
        total += 1
    print('Test Acc:', 100*acc/total, '%')

In [None]:
# Run the evaluation on the test split.
# NOTE(review): this needs a function named `test_model` in scope, and `test`
# must still refer to the data returned by util.load_data_test — confirm that
# no later definition has rebound either name.
test_model(model, test)