<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [1]:
# Imports
import gzip
import os
import pandas as pd
import numpy as np
import torch
import basic_model
import random
from collections import Counter
from tqdm import tqdm
import ujson as json
from torch.autograd import Variable
import util
import args
import pprint

# Build the experiment configuration object.
# NOTE: this rebinds the name `args` from the imported module to the object returned by
# get_setup_args(). Re-running this cell still works only because the `import args`
# statement earlier in the same cell binds the module again first.
args = args.get_setup_args()
# Show every configuration field as a dict (rendered as this cell's output).
vars(args)

{'vocab_path': '/home/mouadh/Desktop/insuranceQA/V2/vocabulary',
 'label_path': '/home/mouadh/Desktop/insuranceQA/V2/InsuranceQA.label2answer.token.encoded',
 'train_path': '/home/mouadh/Desktop/insuranceQA/V2/InsuranceQA.question.anslabel.token.500.pool.solr.train.encoded',
 'test_path': '/home/mouadh/Desktop/insuranceQA/V2/InsuranceQA.question.anslabel.token.500.pool.solr.test.encoded',
 'glove_path': '/home/mouadh/Desktop/insuranceQA/glove.840B.300d/glove.840B.300d.txt',
 'glove_dim': 300,
 'glove_num_vecs': 2196017,
 'hidden_size': 100,
 'word_emb_file': '/home/mouadh/Desktop/insuranceQA/glove.840B.300d/word_embedding',
 'word2idx_file': '/home/mouadh/Desktop/insuranceQA/glove.840B.300d/word2idx',
 'seed': 0,
 'create_matrix_embedding': True,
 'lr': 0.5,
 'num_epochs': 30,
 'drop_prob': 0.2,
 'margin': 0.2,
 'batch_size': 128,
 'use_glove': False,
 'embd_size': 200}

In [2]:
# Fix the random seeds to get reproducible results.
random.seed(args.seed)
np.random.seed(args.seed)
# PyTorch keeps its own random streams (separate from `random` and numpy), so it has to be
# seeded too; otherwise anything drawn from torch's generators differs between runs.
torch.manual_seed(args.seed)
# Documented as safe to call when CUDA is unavailable (the call is silently ignored).
torch.cuda.manual_seed_all(args.seed)

In [3]:
# Load the id -> word table together with the two answer-label tables
# (per the original note: answer labels, and answer labels with their text).
id2w, l2a, l2at = util.load_vocabulary(args.vocab_path, args.label_path)

# Padding token; it is assigned index 0 below.
PAD = '<PAD>'

# Word -> index vocabulary: words are numbered from 1 in the iteration order of id2w.values().
w2i = {}
for position, word in enumerate(id2w.values(), start=1):
    w2i[word] = position
# Register the padding token last so it always ends up mapped to 0.
w2i[PAD] = 0

# Publish the vocabulary size on the configuration object as well.
vocab_size = len(w2i)
args.vocab_size = vocab_size
print('vocab_size:', vocab_size)

vocab_size: 68581


In [4]:
# Load the training examples from args.train_path.
# train_model later reads each item as d[0] = question tokens, d[1] = positive answer,
# d[2] = collection of negative answers to sample from.
train = util.load_data_train(args.train_path, id2w, l2at)

In [25]:
# Load the test examples from args.test_path.
# test_model later reads each item as d[0] = question tokens, d[1] = correct answer labels,
# d[2] = candidate answer labels.
# NOTE(review): this cell carries a later execution count than the cells around it, yet the
# next cell needs `test`; it must be run in this position on a fresh kernel.
test = util.load_data_test(args.test_path, id2w, l2at)

In [5]:
# Report how many examples each split contains.
print(f'n_train: {len(train)}')
print(f'n_test: {len(test)}')

n_train: 21325
n_test: 1988


In [6]:
# Choose the word -> index mapping (and, for GloVe, the pretrained vectors) used for training.
if args.use_glove:
    # NOTE(review): the original condition was `if not args.create_matrix_embedding`, which
    # built the matrix only when the flag was False. It is un-negated here so the flag does
    # what its name says; confirm against args.get_setup_args().
    if args.create_matrix_embedding:
        # Create word embedding and word2idx for glove, then cache both on disk.
        print("Creating word embedding")
        word_emb_mat, word2idx_dict = util.get_embedding(
            'word',
            emb_file=args.glove_path,
            vec_size=args.glove_dim,
            num_vectors=args.glove_num_vecs,
        )
        util.save(args.word_emb_file, word_emb_mat, message="word embedding")
        util.save(args.word2idx_file, word2idx_dict, message="word dictionary")
        # `word_vectors` is required by the model-construction cell but was never defined on
        # this path. Read it back from the file just written, exactly as the load branch does
        # (assumes util.torch_from_json can read what util.save wrote; the load branch
        # already relies on that for the same path).
        word_vectors = util.torch_from_json(args.word_emb_file)
    else:
        # Get embeddings that were cached by a previous run.
        print('Loading word vectors embeddings...')
        word_vectors = util.torch_from_json(args.word_emb_file)
        print("Loading word to index dictionary")
        # Context manager so the file handle is closed deterministically.
        with open(args.word2idx_file) as word2idx_fh:
            word2idx_dict = json.load(word2idx_fh)
else:
    # Without GloVe, use the vocabulary built from the dataset itself.
    word2idx_dict = w2i

In [7]:
# Peek at the question tokens of the first training example.
train[0][0]

['Is', 'Disability', 'Insurance', 'Required', 'By', 'Law', '?']

In [8]:
# Look up the index assigned to the '?' token in the active word -> index mapping.
word2idx_dict['?']

50370

In [9]:
# Encode the first question with the dataset vocabulary and a length of 20; the output
# below shows the word ids followed by zeros (the index given to '<PAD>').
util.make_vector([train[0][0]], w2i, 20)

tensor([[66164, 54421, 29876, 56902, 59631, 57715, 50370,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]])

In [11]:
def train_model(model, data, optimizer, word2idx_dict, args, batch_size = 128,
                num_epochs=1, max_neg_tries=100):
    """Train ``model`` on (question, positive answer, negatives) examples.

    For every example the question/positive pair is scored once. Negatives are then
    drawn at random (at most ``max_neg_tries`` draws) until one gives a non-zero value
    from ``util.loss_fn(pos_sim, neg_sim, args.margin)``; that loss is queued. When
    ``batch_size`` losses are queued, or the last example has been processed, their mean
    is back-propagated and one optimizer step is taken.

    Args:
        model: callable as ``model(question_vec, answer_vec)``; must provide ``train()``.
        data: sequence of examples read as ``d[0]`` = question tokens, ``d[1]`` =
            positive answer tokens, ``d[2]`` = sequence of negative answers.
        optimizer: optimizer over the parameters of ``model``.
        word2idx_dict: word -> index mapping handed to ``util.make_vector``.
        args: configuration object; only ``args.margin`` is read here.
        batch_size: number of queued losses that triggers an optimizer step.
        num_epochs: passes over ``data`` (default 1, the previously hard-coded value).
        max_neg_tries: maximum negatives sampled per example (default 100, the
            previously hard-coded value).

    Returns:
        The same ``model`` object, updated in place.
    """
    for epoch in range(num_epochs):
        model.train()
        # Shuffle a shallow copy: the caller's list (the notebook-level `train`) keeps its
        # order, so cells that index into it stay valid when re-run after training.
        examples = list(data)
        random.shuffle(examples)
        last_index = len(examples) - 1
        losses = []
        for i, d in enumerate(tqdm(examples)):
            q, pos, negs = d[0], d[1], d[2]

            # Each sequence is encoded at its own length, i.e. as a batch of one.
            vec_q = util.make_vector([q], word2idx_dict, len(q))
            vec_pos = util.make_vector([pos], word2idx_dict, len(pos))
            pos_sim = model(vec_q, vec_pos)

            # Keep sampling until a negative produces a non-zero loss; examples whose
            # sampled negatives all give zero loss contribute nothing.
            for _ in range(max_neg_tries):
                neg = random.choice(negs)
                vec_neg = util.make_vector([neg], word2idx_dict, len(neg))
                neg_sim = model(vec_q, vec_neg)
                loss = util.loss_fn(pos_sim, neg_sim, args.margin)
                if loss.item() != 0:
                    losses.append(loss)
                    break

            # `losses` can be empty at the final example (nothing queued since the last
            # step); torch.stack([]) would raise, so skip the update in that case.
            if losses and (len(losses) == batch_size or i == last_index):
                # Every queued loss holds a single value, so this is the batch mean.
                loss = torch.stack(losses, 0).mean()
                print("epoch: {}, iteration: {}, loss: {}".format(epoch, i, loss.item()))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses = []
    return model

In [12]:
def lowercase(tokens):
    """Return a new list holding the lower-cased form of every token, in order."""
    lowered = []
    for token in tokens:
        lowered.append(token.lower())
    return lowered

In [13]:
from nltk.tokenize import RegexpTokenizer
tokenizer_punct = RegexpTokenizer(r'\w+')
import itertools

def remove_special_chars(tokens):
    """Split every token with ``tokenizer_punct`` and return all pieces as one flat list.

    ``tokenizer_punct`` is the notebook-level ``RegexpTokenizer`` built from the
    word-character pattern defined in this cell, so a token containing no
    word characters contributes no pieces.
    """
    pieces = []
    for token in tokens:
        pieces.extend(tokenizer_punct.tokenize(token))
    return pieces


In [14]:
def remove_words_not_glove(tokens, word2idx_dict):
    """Keep only the tokens that are keys of ``word2idx_dict``.

    Args:
        tokens: iterable of tokens to filter.
        word2idx_dict: word -> index mapping; only key membership is checked, so a
            token mapped to index 0 is kept as well.

    Returns:
        A new list with the known tokens in their original order.
    """
    # An explicit membership test replaces the former lookup wrapped in a bare `except`,
    # which silently swallowed every exception instead of just the missing-key case.
    return [token for token in tokens if token in word2idx_dict]

In [15]:
# Quick check of the filter: the empty result below shows that neither token is a key of
# the active word -> index mapping.
remove_words_not_glove(['websitelink', 'hello'], word2idx_dict)

[]

In [16]:
def process_words(tokens, word2idx_dict):
    """Normalise a token list: lower-case it, split it with ``remove_special_chars``,
    then drop every piece that is not a key of ``word2idx_dict``."""
    normalized = remove_special_chars(lowercase(tokens))
    return remove_words_not_glove(normalized, word2idx_dict)

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Build the QA-LSTM; the pretrained word vectors are passed only when GloVe is enabled.
model_inputs = (args, word_vectors) if args.use_glove else (args,)
model = basic_model.QA_LSTM(*model_inputs)

# Move the model to the GPU when one is available.
if torch.cuda.is_available():
    model.cuda()

# Adam over the trainable parameters only, created with its default hyperparameters
# (args.lr is not passed to it).
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = torch.optim.Adam(trainable_params)

In [18]:
# Run training over the training set, taking an optimizer step every args.batch_size queued
# losses. The log below shows this single pass took a little over two hours.
model = train_model(model, train, optimizer, word2idx_dict, args, args.batch_size)

  1%|          | 125/21325 [00:06<19:39, 17.98it/s]

epoch: 0, iteration: 127, loss: 0.20360736548900604


  1%|          | 255/21325 [00:45<19:30, 18.00it/s]   

epoch: 0, iteration: 255, loss: 0.19979801774024963


  2%|▏         | 383/21325 [01:24<18:14, 19.14it/s]   

epoch: 0, iteration: 383, loss: 0.2003074735403061


  2%|▏         | 511/21325 [02:04<17:36, 19.70it/s]   

epoch: 0, iteration: 511, loss: 0.19824783504009247


  3%|▎         | 637/21325 [02:43<15:30, 22.24it/s]   

epoch: 0, iteration: 639, loss: 0.19930337369441986


  4%|▎         | 766/21325 [03:21<20:16, 16.90it/s]   

epoch: 0, iteration: 767, loss: 0.20024891197681427


  4%|▍         | 893/21325 [04:01<15:47, 21.56it/s]   

epoch: 0, iteration: 895, loss: 0.19931268692016602


  5%|▍         | 1022/21325 [04:40<16:02, 21.09it/s]  

epoch: 0, iteration: 1023, loss: 0.20061331987380981


  5%|▌         | 1150/21325 [05:19<22:34, 14.90it/s]   

epoch: 0, iteration: 1151, loss: 0.19722385704517365


  6%|▌         | 1278/21325 [05:57<14:17, 23.39it/s]   

epoch: 0, iteration: 1279, loss: 0.198110431432724


  7%|▋         | 1405/21325 [06:41<16:32, 20.07it/s]   

epoch: 0, iteration: 1407, loss: 0.19955451786518097


  7%|▋         | 1535/21325 [07:21<16:39, 19.81it/s]   

epoch: 0, iteration: 1535, loss: 0.1980333924293518


  8%|▊         | 1662/21325 [07:58<14:06, 23.22it/s]   

epoch: 0, iteration: 1663, loss: 0.1967649757862091


  8%|▊         | 1789/21325 [08:37<15:03, 21.62it/s]   

epoch: 0, iteration: 1791, loss: 0.19923222064971924


  9%|▉         | 1918/21325 [09:16<23:20, 13.86it/s]   

epoch: 0, iteration: 1919, loss: 0.19830664992332458


 10%|▉         | 2046/21325 [09:55<13:59, 22.97it/s]   

epoch: 0, iteration: 2047, loss: 0.1958111673593521


 10%|█         | 2174/21325 [10:32<14:29, 22.02it/s]   

epoch: 0, iteration: 2175, loss: 0.19635848701000214


 11%|█         | 2302/21325 [11:13<12:50, 24.69it/s]   

epoch: 0, iteration: 2303, loss: 0.1947476714849472


 11%|█▏        | 2431/21325 [11:50<14:50, 21.21it/s]   

epoch: 0, iteration: 2431, loss: 0.19431360065937042


 12%|█▏        | 2557/21325 [12:30<17:25, 17.95it/s]   

epoch: 0, iteration: 2559, loss: 0.19615250825881958


 13%|█▎        | 2685/21325 [13:14<14:30, 21.41it/s]   

epoch: 0, iteration: 2687, loss: 0.19718849658966064


 13%|█▎        | 2814/21325 [13:54<16:17, 18.93it/s]   

epoch: 0, iteration: 2815, loss: 0.1929452121257782


 14%|█▍        | 2943/21325 [14:37<15:02, 20.37it/s]   

epoch: 0, iteration: 2943, loss: 0.1927233189344406


 14%|█▍        | 3071/21325 [15:15<12:10, 24.99it/s]   

epoch: 0, iteration: 3071, loss: 0.1906544417142868


 15%|█▍        | 3198/21325 [15:53<15:41, 19.26it/s]   

epoch: 0, iteration: 3199, loss: 0.18828672170639038


 16%|█▌        | 3325/21325 [16:32<16:19, 18.38it/s]   

epoch: 0, iteration: 3327, loss: 0.1919081062078476


 16%|█▌        | 3454/21325 [17:13<17:48, 16.73it/s]   

epoch: 0, iteration: 3455, loss: 0.19091440737247467


 17%|█▋        | 3583/21325 [17:57<19:54, 14.86it/s]   

epoch: 0, iteration: 3583, loss: 0.1905534565448761


 17%|█▋        | 3710/21325 [18:39<15:26, 19.02it/s]   

epoch: 0, iteration: 3711, loss: 0.18700608611106873


 18%|█▊        | 3838/21325 [19:20<15:59, 18.23it/s]   

epoch: 0, iteration: 3839, loss: 0.19253912568092346


 19%|█▊        | 3966/21325 [20:01<13:26, 21.53it/s]   

epoch: 0, iteration: 3967, loss: 0.18781445920467377


 19%|█▉        | 4095/21325 [20:43<19:31, 14.71it/s]   

epoch: 0, iteration: 4095, loss: 0.18860551714897156


 20%|█▉        | 4222/21325 [21:24<14:08, 20.15it/s]   

epoch: 0, iteration: 4223, loss: 0.18997427821159363


 20%|██        | 4349/21325 [22:06<11:36, 24.38it/s]   

epoch: 0, iteration: 4351, loss: 0.19253185391426086


 21%|██        | 4477/21325 [22:46<14:52, 18.88it/s]   

epoch: 0, iteration: 4479, loss: 0.18485566973686218


 22%|██▏       | 4607/21325 [23:27<15:49, 17.61it/s]   

epoch: 0, iteration: 4607, loss: 0.18848192691802979


 22%|██▏       | 4734/21325 [24:09<16:39, 16.60it/s]   

epoch: 0, iteration: 4735, loss: 0.18578074872493744


 23%|██▎       | 4863/21325 [24:50<23:50, 11.51it/s]   

epoch: 0, iteration: 4863, loss: 0.18965578079223633


 23%|██▎       | 4990/21325 [25:32<13:11, 20.65it/s]   

epoch: 0, iteration: 4991, loss: 0.18074525892734528


 24%|██▍       | 5119/21325 [26:14<12:32, 21.54it/s]   

epoch: 0, iteration: 5119, loss: 0.1851525455713272


 25%|██▍       | 5247/21325 [27:01<14:56, 17.93it/s]   

epoch: 0, iteration: 5247, loss: 0.180517315864563


 25%|██▌       | 5374/21325 [27:48<15:39, 16.98it/s]   

epoch: 0, iteration: 5375, loss: 0.18123219907283783


 26%|██▌       | 5502/21325 [28:32<17:24, 15.15it/s]   

epoch: 0, iteration: 5503, loss: 0.1836794912815094


 26%|██▋       | 5631/21325 [29:16<18:20, 14.26it/s]   

epoch: 0, iteration: 5631, loss: 0.18394427001476288


 27%|██▋       | 5757/21325 [29:59<11:24, 22.74it/s]   

epoch: 0, iteration: 5759, loss: 0.17834383249282837


 28%|██▊       | 5887/21325 [30:41<11:32, 22.30it/s]   

epoch: 0, iteration: 5887, loss: 0.17646852135658264


 28%|██▊       | 6014/21325 [31:26<14:38, 17.43it/s]   

epoch: 0, iteration: 6015, loss: 0.17885719239711761


 29%|██▉       | 6143/21325 [32:10<13:25, 18.85it/s]   

epoch: 0, iteration: 6143, loss: 0.1798500269651413


 29%|██▉       | 6270/21325 [32:53<12:00, 20.91it/s]   

epoch: 0, iteration: 6271, loss: 0.17297278344631195


 30%|███       | 6399/21325 [33:36<17:50, 13.95it/s]   

epoch: 0, iteration: 6399, loss: 0.17367716133594513


 31%|███       | 6526/21325 [34:25<12:58, 19.02it/s]   

epoch: 0, iteration: 6527, loss: 0.17436984181404114


 31%|███       | 6653/21325 [35:03<14:25, 16.95it/s]   

epoch: 0, iteration: 6655, loss: 0.1711960881948471


 32%|███▏      | 6783/21325 [35:45<12:26, 19.48it/s]   

epoch: 0, iteration: 6783, loss: 0.1633915901184082


 32%|███▏      | 6910/21325 [36:31<10:54, 22.03it/s]   

epoch: 0, iteration: 6911, loss: 0.16570928692817688


 33%|███▎      | 7038/21325 [37:15<10:00, 23.78it/s]   

epoch: 0, iteration: 7039, loss: 0.1717926412820816


 34%|███▎      | 7165/21325 [37:55<12:32, 18.83it/s]   

epoch: 0, iteration: 7167, loss: 0.17230001091957092


 34%|███▍      | 7293/21325 [38:35<12:23, 18.86it/s]   

epoch: 0, iteration: 7295, loss: 0.17247124016284943


 35%|███▍      | 7421/21325 [39:18<10:31, 22.03it/s]   

epoch: 0, iteration: 7423, loss: 0.1629311442375183


 35%|███▌      | 7549/21325 [40:01<10:04, 22.79it/s]   

epoch: 0, iteration: 7551, loss: 0.1710156500339508


 36%|███▌      | 7678/21325 [40:40<10:34, 21.50it/s]   

epoch: 0, iteration: 7679, loss: 0.166587695479393


 37%|███▋      | 7807/21325 [41:24<16:04, 14.01it/s]   

epoch: 0, iteration: 7807, loss: 0.1667833924293518


 37%|███▋      | 7934/21325 [42:04<10:16, 21.72it/s]   

epoch: 0, iteration: 7935, loss: 0.1631874293088913


 38%|███▊      | 8062/21325 [42:50<12:13, 18.09it/s]   

epoch: 0, iteration: 8063, loss: 0.1549822986125946


 38%|███▊      | 8189/21325 [43:31<11:50, 18.50it/s]   

epoch: 0, iteration: 8191, loss: 0.15388214588165283


 39%|███▉      | 8317/21325 [44:14<11:51, 18.27it/s]   

epoch: 0, iteration: 8319, loss: 0.16623805463314056


 40%|███▉      | 8445/21325 [44:57<13:11, 16.28it/s]   

epoch: 0, iteration: 8447, loss: 0.1633126139640808


 40%|████      | 8574/21325 [45:40<17:16, 12.30it/s]   

epoch: 0, iteration: 8575, loss: 0.15570461750030518


 41%|████      | 8703/21325 [46:24<14:08, 14.87it/s]   

epoch: 0, iteration: 8703, loss: 0.1563260555267334


 41%|████▏     | 8831/21325 [47:09<12:50, 16.22it/s]   

epoch: 0, iteration: 8831, loss: 0.1592862606048584


 42%|████▏     | 8959/21325 [47:54<12:26, 16.56it/s]   

epoch: 0, iteration: 8959, loss: 0.15731391310691833


 43%|████▎     | 9087/21325 [48:36<12:15, 16.63it/s]   

epoch: 0, iteration: 9087, loss: 0.1466289907693863


 43%|████▎     | 9215/21325 [49:20<10:36, 19.02it/s]   

epoch: 0, iteration: 9215, loss: 0.16238033771514893


 44%|████▍     | 9341/21325 [50:05<09:04, 22.03it/s]   

epoch: 0, iteration: 9343, loss: 0.14819103479385376


 44%|████▍     | 9470/21325 [50:48<10:32, 18.73it/s]   

epoch: 0, iteration: 9471, loss: 0.15639635920524597


 45%|████▌     | 9598/21325 [51:29<11:31, 16.96it/s]   

epoch: 0, iteration: 9599, loss: 0.1524704396724701


 46%|████▌     | 9727/21325 [52:11<14:37, 13.22it/s]   

epoch: 0, iteration: 9727, loss: 0.16402742266654968


 46%|████▌     | 9855/21325 [52:51<13:02, 14.66it/s]   

epoch: 0, iteration: 9855, loss: 0.15128029882907867


 47%|████▋     | 9982/21325 [53:35<10:35, 17.84it/s]   

epoch: 0, iteration: 9983, loss: 0.1463891714811325


 47%|████▋     | 10110/21325 [54:15<09:47, 19.10it/s]  

epoch: 0, iteration: 10111, loss: 0.1490676999092102


 48%|████▊     | 10239/21325 [54:57<10:31, 17.55it/s]   

epoch: 0, iteration: 10239, loss: 0.14497633278369904


 49%|████▊     | 10367/21325 [55:38<12:31, 14.59it/s]   

epoch: 0, iteration: 10367, loss: 0.14588819444179535


 49%|████▉     | 10493/21325 [56:24<19:17,  9.36it/s]   

epoch: 0, iteration: 10495, loss: 0.14753104746341705


 50%|████▉     | 10622/21325 [57:13<11:36, 15.36it/s]   

epoch: 0, iteration: 10623, loss: 0.1586778610944748


 50%|█████     | 10751/21325 [57:56<19:52,  8.86it/s]   

epoch: 0, iteration: 10751, loss: 0.15759894251823425


 51%|█████     | 10878/21325 [58:40<10:56, 15.91it/s]   

epoch: 0, iteration: 10879, loss: 0.16875752806663513


 52%|█████▏    | 11007/21325 [59:24<11:26, 15.02it/s]   

epoch: 0, iteration: 11007, loss: 0.14767487347126007


 52%|█████▏    | 11135/21325 [1:00:08<30:25,  5.58it/s]  

epoch: 0, iteration: 11135, loss: 0.16744668781757355


 53%|█████▎    | 11263/21325 [1:00:51<11:28, 14.60it/s]   

epoch: 0, iteration: 11263, loss: 0.152044415473938


 53%|█████▎    | 11390/21325 [1:01:53<20:25,  8.10it/s]   

epoch: 0, iteration: 11391, loss: 0.1577659398317337


 54%|█████▍    | 11519/21325 [1:02:52<12:55, 12.64it/s]   

epoch: 0, iteration: 11519, loss: 0.152431920170784


 55%|█████▍    | 11648/21325 [1:03:45<13:32, 11.91it/s]   

epoch: 0, iteration: 11648, loss: 0.13059690594673157


 55%|█████▌    | 11776/21325 [1:04:33<24:30,  6.49it/s]   

epoch: 0, iteration: 11776, loss: 0.15821915864944458


 56%|█████▌    | 11902/21325 [1:05:15<08:49, 17.80it/s]   

epoch: 0, iteration: 11904, loss: 0.16416528820991516


 56%|█████▋    | 12032/21325 [1:05:56<12:04, 12.83it/s]  

epoch: 0, iteration: 12032, loss: 0.170810729265213


 57%|█████▋    | 12160/21325 [1:06:39<19:58,  7.65it/s]   

epoch: 0, iteration: 12160, loss: 0.15404602885246277


 58%|█████▊    | 12288/21325 [1:07:20<27:19,  5.51it/s]   

epoch: 0, iteration: 12288, loss: 0.1509464830160141


 58%|█████▊    | 12416/21325 [1:08:11<16:04,  9.24it/s]   

epoch: 0, iteration: 12416, loss: 0.13267649710178375


 59%|█████▉    | 12545/21325 [1:09:02<12:28, 11.73it/s]   

epoch: 0, iteration: 12545, loss: 0.150248721241951


 59%|█████▉    | 12672/21325 [1:09:53<17:11,  8.39it/s]   

epoch: 0, iteration: 12673, loss: 0.16147702932357788


 60%|██████    | 12800/21325 [1:10:36<10:18, 13.78it/s]   

epoch: 0, iteration: 12801, loss: 0.18090090155601501


 61%|██████    | 12929/21325 [1:11:26<16:54,  8.28it/s]   

epoch: 0, iteration: 12930, loss: 0.18212802708148956


 61%|██████    | 13059/21325 [1:12:10<08:05, 17.04it/s]   

epoch: 0, iteration: 13059, loss: 0.17926529049873352


 62%|██████▏   | 13187/21325 [1:12:52<08:31, 15.90it/s]   

epoch: 0, iteration: 13187, loss: 0.17151734232902527


 62%|██████▏   | 13315/21325 [1:13:42<28:08,  4.74it/s]   

epoch: 0, iteration: 13315, loss: 0.13338936865329742


 63%|██████▎   | 13443/21325 [1:14:30<24:53,  5.28it/s]   

epoch: 0, iteration: 13443, loss: 0.13983607292175293


 64%|██████▎   | 13571/21325 [1:15:21<08:27, 15.29it/s]   

epoch: 0, iteration: 13571, loss: 0.17743510007858276


 64%|██████▍   | 13699/21325 [1:16:04<08:49, 14.41it/s]   

epoch: 0, iteration: 13699, loss: 0.15927624702453613


 65%|██████▍   | 13828/21325 [1:16:54<42:01,  2.97it/s]   

epoch: 0, iteration: 13828, loss: 0.17584995925426483


 65%|██████▌   | 13955/21325 [1:17:42<21:15,  5.78it/s]   

epoch: 0, iteration: 13956, loss: 0.14615361392498016


 66%|██████▌   | 14084/21325 [1:18:29<10:08, 11.90it/s]   

epoch: 0, iteration: 14084, loss: 0.15419641137123108


 67%|██████▋   | 14212/21325 [1:19:15<10:31, 11.26it/s]   

epoch: 0, iteration: 14212, loss: 0.15496496856212616


 67%|██████▋   | 14339/21325 [1:20:00<12:11,  9.56it/s]   

epoch: 0, iteration: 14340, loss: 0.16007256507873535


 68%|██████▊   | 14467/21325 [1:20:51<25:23,  4.50it/s]   

epoch: 0, iteration: 14468, loss: 0.1698366403579712


 68%|██████▊   | 14595/21325 [1:21:44<16:09,  6.94it/s]   

epoch: 0, iteration: 14596, loss: 0.15651848912239075


 69%|██████▉   | 14724/21325 [1:22:34<41:05,  2.68it/s]  

epoch: 0, iteration: 14724, loss: 0.161428302526474


 70%|██████▉   | 14853/21325 [1:23:33<11:11,  9.64it/s]   

epoch: 0, iteration: 14853, loss: 0.14017006754875183


 70%|███████   | 14980/21325 [1:24:31<17:31,  6.04it/s]   

epoch: 0, iteration: 14982, loss: 0.151548370718956


 71%|███████   | 15109/21325 [1:25:31<21:27,  4.83it/s]  

epoch: 0, iteration: 15110, loss: 0.1385304182767868


 71%|███████▏  | 15237/21325 [1:26:27<33:29,  3.03it/s]  

epoch: 0, iteration: 15238, loss: 0.18277566134929657


 72%|███████▏  | 15366/21325 [1:27:16<19:41,  5.04it/s]  

epoch: 0, iteration: 15366, loss: 0.16079887747764587


 73%|███████▎  | 15494/21325 [1:28:16<15:00,  6.47it/s]   

epoch: 0, iteration: 15494, loss: 0.15411067008972168


 73%|███████▎  | 15622/21325 [1:29:10<09:45,  9.73it/s]   

epoch: 0, iteration: 15622, loss: 0.16550424695014954


 74%|███████▍  | 15750/21325 [1:30:01<09:50,  9.44it/s]   

epoch: 0, iteration: 15750, loss: 0.14227651059627533


 74%|███████▍  | 15879/21325 [1:30:54<22:31,  4.03it/s]   

epoch: 0, iteration: 15879, loss: 0.15108729898929596


 75%|███████▌  | 16008/21325 [1:31:48<09:54,  8.95it/s]   

epoch: 0, iteration: 16008, loss: 0.15206299722194672


 76%|███████▌  | 16136/21325 [1:32:42<10:34,  8.17it/s]   

epoch: 0, iteration: 16136, loss: 0.17218786478042603


 76%|███████▋  | 16262/21325 [1:33:35<08:58,  9.40it/s]   

epoch: 0, iteration: 16264, loss: 0.15451231598854065


 77%|███████▋  | 16391/21325 [1:34:27<12:53,  6.38it/s]  

epoch: 0, iteration: 16392, loss: 0.1773928552865982


 77%|███████▋  | 16519/21325 [1:35:22<16:28,  4.86it/s]  

epoch: 0, iteration: 16520, loss: 0.1835508942604065


 78%|███████▊  | 16648/21325 [1:36:14<10:16,  7.59it/s]  

epoch: 0, iteration: 16648, loss: 0.14505481719970703


 79%|███████▊  | 16776/21325 [1:37:22<10:42,  7.09it/s]   

epoch: 0, iteration: 16777, loss: 0.14750032126903534


 79%|███████▉  | 16905/21325 [1:38:16<04:52, 15.10it/s]  

epoch: 0, iteration: 16906, loss: 0.1655094474554062


 80%|███████▉  | 17034/21325 [1:39:11<11:35,  6.17it/s]  

epoch: 0, iteration: 17034, loss: 0.16292642056941986


 80%|████████  | 17162/21325 [1:40:01<10:04,  6.89it/s]   

epoch: 0, iteration: 17162, loss: 0.16776548326015472


 81%|████████  | 17289/21325 [1:40:59<09:08,  7.36it/s]   

epoch: 0, iteration: 17290, loss: 0.16091641783714294


 82%|████████▏ | 17416/21325 [1:41:47<09:31,  6.84it/s]  

epoch: 0, iteration: 17418, loss: 0.1621837168931961


 82%|████████▏ | 17547/21325 [1:42:45<09:59,  6.30it/s]  

epoch: 0, iteration: 17547, loss: 0.1573891043663025


 83%|████████▎ | 17674/21325 [1:43:32<11:59,  5.07it/s]  

epoch: 0, iteration: 17675, loss: 0.18327830731868744


 83%|████████▎ | 17803/21325 [1:44:22<05:04, 11.58it/s]  

epoch: 0, iteration: 17804, loss: 0.1728220134973526


 84%|████████▍ | 17933/21325 [1:45:15<06:08,  9.20it/s]  

epoch: 0, iteration: 17933, loss: 0.15015053749084473


 85%|████████▍ | 18060/21325 [1:46:01<06:29,  8.37it/s]  

epoch: 0, iteration: 18061, loss: 0.1495552957057953


 85%|████████▌ | 18189/21325 [1:46:55<16:07,  3.24it/s]  

epoch: 0, iteration: 18189, loss: 0.14745450019836426


 86%|████████▌ | 18316/21325 [1:47:46<10:19,  4.86it/s]  

epoch: 0, iteration: 18317, loss: 0.14980991184711456


 86%|████████▋ | 18445/21325 [1:48:35<06:25,  7.48it/s]  

epoch: 0, iteration: 18446, loss: 0.16629332304000854


 87%|████████▋ | 18574/21325 [1:49:26<14:13,  3.22it/s]  

epoch: 0, iteration: 18574, loss: 0.16860835254192352


 88%|████████▊ | 18703/21325 [1:50:25<16:39,  2.62it/s]  

epoch: 0, iteration: 18703, loss: 0.1482483595609665


 88%|████████▊ | 18830/21325 [1:51:11<04:10,  9.95it/s]  

epoch: 0, iteration: 18831, loss: 0.1714610606431961


 89%|████████▉ | 18961/21325 [1:52:04<03:08, 12.55it/s]  

epoch: 0, iteration: 18961, loss: 0.16433162987232208


 90%|████████▉ | 19089/21325 [1:52:57<08:19,  4.48it/s]  

epoch: 0, iteration: 19089, loss: 0.15079568326473236


 90%|█████████ | 19217/21325 [1:53:48<24:07,  1.46it/s]  

epoch: 0, iteration: 19217, loss: 0.14205022156238556


 91%|█████████ | 19347/21325 [1:54:48<04:01,  8.20it/s]  

epoch: 0, iteration: 19347, loss: 0.1506427526473999


 91%|█████████▏| 19476/21325 [1:55:45<03:06,  9.90it/s]  

epoch: 0, iteration: 19477, loss: 0.15843670070171356


 92%|█████████▏| 19604/21325 [1:56:31<02:33, 11.23it/s]  

epoch: 0, iteration: 19605, loss: 0.1570667326450348


 93%|█████████▎| 19735/21325 [1:57:33<13:46,  1.92it/s]  

epoch: 0, iteration: 19735, loss: 0.1679469347000122


 93%|█████████▎| 19866/21325 [1:58:33<04:00,  6.08it/s]  

epoch: 0, iteration: 19866, loss: 0.1733582615852356


 94%|█████████▍| 19995/21325 [1:59:29<03:40,  6.03it/s]  

epoch: 0, iteration: 19995, loss: 0.1467379480600357


 94%|█████████▍| 20121/21325 [2:00:17<03:26,  5.83it/s]  

epoch: 0, iteration: 20123, loss: 0.1670229136943817


 95%|█████████▍| 20252/21325 [2:01:12<05:43,  3.12it/s]  

epoch: 0, iteration: 20252, loss: 0.14732804894447327


 96%|█████████▌| 20379/21325 [2:02:01<03:16,  4.82it/s]  

epoch: 0, iteration: 20380, loss: 0.15044447779655457


 96%|█████████▌| 20507/21325 [2:02:56<01:12, 11.25it/s]  

epoch: 0, iteration: 20508, loss: 0.1405458152294159


 97%|█████████▋| 20635/21325 [2:03:48<01:39,  6.95it/s]

epoch: 0, iteration: 20636, loss: 0.17533892393112183


 97%|█████████▋| 20764/21325 [2:04:40<02:05,  4.47it/s]

epoch: 0, iteration: 20764, loss: 0.14934562146663666


 98%|█████████▊| 20893/21325 [2:05:44<02:15,  3.18it/s]  

epoch: 0, iteration: 20893, loss: 0.16148704290390015


 99%|█████████▊| 21021/21325 [2:06:40<01:02,  4.87it/s]  

epoch: 0, iteration: 21022, loss: 0.19505535066127777


 99%|█████████▉| 21151/21325 [2:07:41<00:59,  2.95it/s]

epoch: 0, iteration: 21151, loss: 0.17896968126296997


100%|█████████▉| 21280/21325 [2:08:44<00:07,  6.36it/s]

epoch: 0, iteration: 21280, loss: 0.1609634906053543


100%|█████████▉| 21323/21325 [2:09:25<00:00, 10.08it/s]

epoch: 0, iteration: 21324, loss: 0.15466032922267914


100%|██████████| 21325/21325 [2:09:35<00:00,  1.47s/it]


In [32]:
def test_model(model, data):
    """Print the top-1 answer-selection accuracy of ``model`` on ``data``.

    Each item of ``data`` is read as ``d[0]`` = question tokens, ``d[1]`` = correct
    answer labels and ``d[2]`` = candidate answer labels. Candidate labels are mapped
    to token lists through the notebook-level ``l2at`` and encoded with the
    notebook-level ``w2i`` vocabulary, so both must exist before this is called.

    A question counts as correct when the highest-scoring candidate is one of its
    labelled answers. One result line is printed per question, followed by the overall
    accuracy in percent. Returns None.
    """
    acc, total = 0, 0
    # train_model leaves the model in training mode; evaluate in eval mode so that
    # train-only behaviour (e.g. dropout, if the model applies args.drop_prob) is
    # switched off, and restore the previous mode afterwards.
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for scoring; this avoids building an autograd graph
        # for every candidate answer.
        with torch.no_grad():
            for d in data:
                q = d[0]
                print('Questions:', ' '.join(q))
                labels = d[1]
                cands = d[2]

                # prepare answer labels: positions of the correct answers in the pool
                label_indices = [cands.index(l) for l in labels if l in cands]

                # build data
                q = util.make_vector([q], w2i, len(q))
                cands = [l2at[c] for c in cands] # id to text
                # Candidate length is capped at 200 (NOTE(review): assumes
                # util.make_vector pads/truncates to the given length; confirm).
                max_cand_len = min(200, max([len(c) for c in cands]))
                cands = util.make_vector(cands, w2i, max_cand_len)

                # predict: score each candidate on its own as a batch of one and take
                # plain Python floats so argmax never has to convert tensors
                scores = [model(q, c.unsqueeze(0)).item() for c in cands]
                pred_idx = int(np.argmax(scores))
                if pred_idx in label_indices:
                    print('Correct Prediciton')
                    acc += 1
                else:
                    print('Wrong Prediction')
                total += 1
    finally:
        model.train(was_training)

    # Empty input used to end in a ZeroDivisionError.
    if total == 0:
        print('Test Acc: n/a (no test examples)')
        return
    print('Test Acc:', 100*acc/total, '%')

In [33]:
# Evaluate the trained model on the test split (prints one result per question).
# NOTE(review): the stored output below does not match the strings test_model prints and
# ends in a KeyboardInterrupt, so it appears to come from an earlier version of the function.
test_model(model, test)

q How To Get Health Insurance When Pregnant ?
wrong
q Can A Corporation Pay For Disability Insurance ?
wrong
q Will Homeowners Insurance Cover A Tree Falling On My Car ?
wrong
q Where To Start Looking For Health Insurance ?
wrong
q What Is Medigap Blue ?
wrong
q Is Workers Comp The Same As Disability Insurance ?
wrong
q Is Disability Insurance Necessary ?
wrong
q How Much Optional Life Insurance Do I Need ?
wrong
q What Is A Health Insurance Claim Form 1500 ?
wrong
q Can You Borrow Money From Renters Insurance ?
wrong
q Where To Get Home Insurance Quote ?
wrong
q Why Buy Variable Life Insurance ?
wrong
q How Long To Plan For Retirement ?
wrong
q Who Is The Least Expensive Auto Insurance ?
wrong
q What Is Centers For Medicare And Medicaid ?
wrong
q How To File A Renters Insurance Claim ?
wrong
q What Do I Need To Know About Buying Life Insurance ?
wrong
q What Is The Biggest Car Insurance Company In Us ?
wrong
q Does Home Owners Insurance Cover Pool Damage ?
wrong
q Why Do Smokers Pay M

wrong
q Can I Get Life Insurance If I Have Epilepsy ?
wrong
q What Does Health Insurance Protect You From ?
wrong
q Can Felons Buy Life Insurance ?
wrong
q How Many People Go Without Health Insurance ?
wrong
q What 's The Difference Between Whole Life Insurance And Endowment Insurance ?
wrong
q Does Renters Insurance Cover Hurricanes
wrong
q Can Health Insurance Charge More For Smokers ?
wrong
q Why Is Dental Care Not Covered By Medicare ?
wrong
q Where Can I Get Health Insurance In Georgia ?
wrong
q What Does Medicare Voucher System Mean ?
wrong
q Does Life Insurance Expire At A Certain Age ?
wrong
q Does Life Insurance Cover AD&D ?
wrong
q What Glucose Meter Is Covered By Medicare ?
wrong
q Who Is Western National Life Insurance Company ?
wrong
q How Much Is Medicare Per Year ?
wrong
q How Much Is A Term Life Insurance Policy ?
wrong
q Why Was Universal Life Insurance Created ?
wrong
q How Is Medicare Advantage Affected By Obamacare ?
wrong
q What The Best Life Insurance Policy To Ge

wrong
q Is Homeowners Insurance A Waste Of Money ?
wrong
q What Is The Best Medigap Insurance Company ?
wrong
q How To File A Claim With Home Insurance ?
wrong
q How To Write An Appeal Letter To Health Insurance ?
wrong
q Can You Negotiate Life Insurance Rates ?
wrong
q How Soon Should I Apply For Medicare ?
wrong
q Does My Renters Insurance Cover My Roommate ?
wrong
q Does Renters Insurance Cover A Broken Refrigerator ?
wrong
q Who Should Enroll In Medicare Part B ?
wrong
q How Long Can Children stay On Parents Auto Insurance ?
wrong
q Does AAA Offer Homeowners Insurance ?
wrong
q Who Has The Best Whole Life Insurance Rates ?
wrong
q What Does Dave Ramsey Say About Cancer Insurance ?
wrong
q Will Annuity Rates Improve In The Future ?
wrong
q When Do Companies Have To Offer Health Insurance ?
wrong
q Does Homeowners Insurance Cover Raccoon Damage ?
wrong
q What Does Annuity Due Mean ?
wrong
q What Does The Bible Say About Life Insurance ?
wrong
q Do You Need Home Insurance Before Closi

KeyboardInterrupt: 

In [40]:
# Save model: only the parameters (state_dict) are written, to a file named 'first_model'
# relative to the current working directory.
torch.save(model.state_dict(), 'first_model')


In [41]:
# Test load: rebuild the architecture and restore the saved parameters into it.
m = basic_model.QA_LSTM(args)
# `m` is created on the CPU, while the checkpoint may have been written from a model moved
# to the GPU; map_location='cpu' lets the file load on machines without CUDA as well.
m.load_state_dict(torch.load('first_model', map_location='cpu'))


IncompatibleKeys(missing_keys=[], unexpected_keys=[])