In [1]:
from preprocess import *
from scoring_metrics import *

import numpy as np
import torch
from torch.autograd import Variable
from torch.nn.modules.distance import CosineSimilarity

import time


# Given a matrix of tuples [ [p+,p-,p--,...],[p++,p-,p--,...] ]
# Returns the MRR score for this set
def get_MRR_score(similarity_matrix):
    rows = similarity_matrix.split(1)
    reciprocal_ranks = []
    for r in rows:
        lst_scores = list(r[0].data)
        score_pos = lst_scores[0]
        lst_sorted_scores = sorted(lst_scores, reverse=True)
        rank = lst_sorted_scores.index(score_pos) + 1
        reciprocal_ranks.append(1.0 / rank)
    return sum(reciprocal_ranks)/len(reciprocal_ranks)


# Produces tensor [1 x num_words x input_size] for one particular question
def get_question_matrix(questionID, word2vec, id2Data, input_size):
    # Get the vector representation for each word in this question as list [v1,v2,v3,...]
    q_word_vecs = []
    for word in id2Data[questionID]:
        try:
            word_vec = np.array(word2vec[word]).astype(np.float32).reshape(len(word2vec[word]), -1)
            q_word_vecs.append(word_vec)
        except KeyError:
            pass

    # num_words x dim_words
    q_matrix = torch.Tensor(np.concatenate(q_word_vecs, axis=1).T)
    num_words_found = q_matrix.size()[0]
    
    if num_words_found < 100:
        padding_rows = torch.zeros(100-num_words_found, input_size)
        q_matrix = torch.cat((q_matrix, padding_rows), 0)
    
    return [q_matrix.unsqueeze(0), num_words_found]


''' Data Prep '''
training_data = training_id_to_similar_different()
trainingQuestionIds = list(training_data.keys())[:10]
word2vec = get_words_and_embeddings()
id2Data = questionID_to_questionData_truncate(100)

dev_data = devTest_id_to_similar_different(dev=True)
devQuestionIds = list(dev_data.keys())[:2]

''' Model Specs '''
input_size = len(word2vec[list(word2vec.keys())[0]])
hidden_size = 100
num_layers = 1
bias = True
batch_first = True
bidirectional = False

'''Hyperparams'''
dropout = 0.1
margin = 0.2
lr = 10**-3

lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional)
loss_function = torch.nn.MultiMarginLoss(margin=margin)
optimizer = torch.optim.Adam(lstm.parameters(), lr=lr)

h0 = Variable(torch.zeros(1, 1, hidden_size), requires_grad=True)
c0 = Variable(torch.zeros(1, 1, hidden_size), requires_grad=True)


''' Procedural parameters '''
batch_size = 5
num_differing_questions = 20

num_epochs = 10
num_batches = round(len(trainingQuestionIds)/batch_size)


# Given ids of main qs in this batch
#
# Returns:
# 1. ids in ordered list as: 
# [
# q_1+, q_1-, q_1--,..., q_1++, q_1-, q_1--,...,
# q_2+, q_2-, q_2--,..., q_2++, q_2-, q_2--,...,
# ...
# ]
# All n main questions have their pos,neg,neg,neg,... interleaved
#
# 2. A dict mapping main question id --> its interleaved sequence length
def order_ids(q_ids, dev=False):
    global training_data
    global dev_data
    global num_differing_questions
    
    if dev: data = dev_data
    else: data = training_data
        
    sequence_ids = []
    dict_sequence_lengths = {}
    
    for q_main in q_ids:
        p_pluses = data[q_main][0]
        p_minuses = list(np.random.choice(data[q_main][1], num_differing_questions, replace = False))
        sequence_length = len(p_pluses) * num_differing_questions + len(p_pluses)
        dict_sequence_lengths[q_main] = sequence_length
        for p_plus in p_pluses:
            sequence_ids += [p_plus] + p_minuses

    return sequence_ids, dict_sequence_lengths


'''Matrix constructors (use global vars, leave in order)'''
# A tuple is (q+, q-, q--, q--- ...)
# Let all main questions be set Q
# Each q in Q has a number of tuples equal to number of positives |q+, q++, ...|
# Each q in Q will have a 2D matrix of: num_tuples x num_candidates_in_tuple
# Concatenate this matrix for all q in Q and you get a matrix of: |Q| x num_tuples x num_candidates_in_tuple

# The above is for candidates
# To do cosine_similarity, need same structure with q's
# Basically each q will be a matrix of repeated q's: num_tuples x num_candidates_in_tuple, all elts are q (repeated)

# This method constructs those matrices, use candidates=True for candidates matrix

def construct_qs_matrix(q_ids_sequential, dict_sequence_lengths, candidates=False):
    global lstm, h0, c0, word2vec, id2data, input_size, num_differing_questions
    
    if not candidates:
        q_ids_complete = []
        for q in q_ids_sequential:
            q_ids_complete += [q] * dict_sequence_lengths[q]
    
    else: q_ids_complete = q_ids_sequential

    qs_matrix_list = []
    qs_seq_length = []
    
    for q in q_ids_complete:
        q_matrix_3d, q_num_words = get_question_matrix(q, word2vec, id2Data, input_size)
        qs_matrix_list.append(q_matrix_3d)
        qs_seq_length.append(q_num_words)

    qs_padded = Variable(torch.cat(qs_matrix_list, 0))
    qs_hidden = lstm(qs_padded, (h0, c0)) # [ [num_q, num_word_per_q, hidden_size] i.e. all hidden, [1, num_q, hidden_size]  i.e. final hidden]
    sum_h_qs = torch.sum(qs_hidden[0], dim=1)
    mean_pooled_h_qs = torch.div(sum_h_qs, torch.autograd.Variable(torch.FloatTensor(qs_seq_length)[:, np.newaxis]))
    qs_tuples = mean_pooled_h_qs.split(1+num_differing_questions)
    final_matrix_tuples_by_constituent_qs_by_hidden_size = torch.stack(qs_tuples, dim=0, out=None)
    return final_matrix_tuples_by_constituent_qs_by_hidden_size


'''Begin training'''

for epoch in range(num_epochs):
    
    for batch in range(1, num_batches+1):
        start = time.time()
        
        print("Working on batch #: ", batch)
        
        optimizer.zero_grad()
        questions_this_batch = trainingQuestionIds[batch_size * (batch - 1):batch_size * batch]
        sequence_ids, dict_sequence_lengths = order_ids(questions_this_batch)

        candidates_qs_tuples_matrix = construct_qs_matrix(sequence_ids, dict_sequence_lengths, candidates=True)
        main_qs_tuples_matrix = construct_qs_matrix(questions_this_batch, dict_sequence_lengths, candidates=False)
        
        similarity_matrix = torch.nn.functional.cosine_similarity(candidates_qs_tuples_matrix, main_qs_tuples_matrix, dim=2, eps=1e-08)

        target = Variable(torch.LongTensor([0] * int(len(sequence_ids)/(1+num_differing_questions))))
        loss_batch = loss_function(similarity_matrix, target)

        loss_batch.backward()

        optimizer.step()
        
        print("loss_on_batch:", loss_batch.data[0], " time_on_batch:", time.time() - start)
        

    '''Dev eval after each epoch'''
    
    optimizer.zero_grad()
    sequence_ids, dict_sequence_lengths = order_ids(devQuestionIds, dev=True)

    candidates_qs_tuples_matrix = construct_qs_matrix(sequence_ids, dict_sequence_lengths, candidates=True)
    main_qs_tuples_matrix = construct_qs_matrix(devQuestionIds, dict_sequence_lengths, candidates=False)

    similarity_matrix = torch.nn.functional.cosine_similarity(candidates_qs_tuples_matrix, main_qs_tuples_matrix, dim=2, eps=1e-08)

    MRR_score = get_MRR_score(similarity_matrix)
    
    with open('logs.txt', 'a') as log_file:
        log_file.write('epoch: ' + str(epoch) + '\n')
        log_file.write('lr: ' + str(lr) +  ' marg: ' + str(margin) + ' drop: ' + str(dropout) + '\n' )        
        log_file.write('MRR: ' +  str(MRR_score) + '\n')

    print("MRR score on evaluation set:", MRR_score)

    target = Variable(torch.LongTensor([0] * int(len(sequence_ids)/(1+num_differing_questions))))
    loss_batch = loss_function(similarity_matrix, target)

    loss_batch.backward()

    optimizer.step()
    
    
#     '''Save model for this epoch'''
    
#     torch.save(lstm, '../Pickle/LSTM_m2d2l3epoch' + str(epoch) + '.pt')


Working on batch #:  1
loss_on_batch: 0.18772007524967194  time_on_batch: 2.4718480110168457
Working on batch #:  2
loss_on_batch: 0.18679234385490417  time_on_batch: 1.761749029159546
scores [0.9962385892868042, 0.9947255849838257, 0.9980756044387817, 0.9975488185882568, 0.9964129328727722, 0.9953456521034241, 0.9955769181251526, 0.9986089468002319, 0.9974638223648071, 0.9988320469856262, 0.9977778196334839, 0.9969286918640137, 0.9951834082603455, 0.9962385892868042, 0.9968279004096985, 0.9983557462692261, 0.9953375458717346, 0.9974240660667419, 0.9924705028533936, 0.9952450394630432, 0.9970040917396545]
my score 0.9962385892868042
descending [0.9988320469856262, 0.9986089468002319, 0.9983557462692261, 0.9980756044387817, 0.9977778196334839, 0.9975488185882568, 0.9974638223648071, 0.9974240660667419, 0.9970040917396545, 0.9969286918640137, 0.9968279004096985, 0.9964129328727722, 0.9962385892868042, 0.9962385892868042, 0.9955769181251526, 0.9953456521034241, 0.9953375458717346, 0.99524

Working on batch #:  1
loss_on_batch: 0.18417543172836304  time_on_batch: 1.86604905128479
Working on batch #:  2
loss_on_batch: 0.17990969121456146  time_on_batch: 1.3840439319610596
scores [0.9922423362731934, 0.993614673614502, 0.9921610951423645, 0.9922423362731934, 0.9908297061920166, 0.9909071326255798, 0.9910522103309631, 0.9912462830543518, 0.9979712963104248, 0.9933276772499084, 0.9972124695777893, 0.9946883320808411, 0.9955887198448181, 0.9972532987594604, 0.9845411777496338, 0.9959691166877747, 0.9961565732955933, 0.9895541667938232, 0.9981297254562378, 0.9947322010993958, 0.9964109063148499]
my score 0.9922423362731934
descending [0.9981297254562378, 0.9979712963104248, 0.9972532987594604, 0.9972124695777893, 0.9964109063148499, 0.9961565732955933, 0.9959691166877747, 0.9955887198448181, 0.9947322010993958, 0.9946883320808411, 0.993614673614502, 0.9933276772499084, 0.9922423362731934, 0.9922423362731934, 0.9921610951423645, 0.9912462830543518, 0.9910522103309631, 0.99090713

Working on batch #:  1
loss_on_batch: 0.17353467643260956  time_on_batch: 1.895920991897583
Working on batch #:  2
loss_on_batch: 0.17179527878761292  time_on_batch: 1.4222626686096191
scores [0.9834895133972168, 0.9869424700737, 0.9823215007781982, 0.9863227605819702, 0.9814980030059814, 0.9783787131309509, 0.9671558737754822, 0.9960179924964905, 0.9968607425689697, 0.9908470511436462, 0.9951043128967285, 0.99701988697052, 0.9866730570793152, 0.9943199753761292, 0.9819519519805908, 0.9947921633720398, 0.990044355392456, 0.9932354688644409, 0.9908567667007446, 0.9834895133972168, 0.9807136654853821]
my score 0.9834895133972168
descending [0.99701988697052, 0.9968607425689697, 0.9960179924964905, 0.9951043128967285, 0.9947921633720398, 0.9943199753761292, 0.9932354688644409, 0.9908567667007446, 0.9908470511436462, 0.990044355392456, 0.9869424700737, 0.9866730570793152, 0.9863227605819702, 0.9834895133972168, 0.9834895133972168, 0.9823215007781982, 0.9819519519805908, 0.9814980030059814,

Working on batch #:  1
loss_on_batch: 0.15767152607440948  time_on_batch: 1.918045997619629
Working on batch #:  2
loss_on_batch: 0.1571490466594696  time_on_batch: 1.7790768146514893
scores [0.9667055010795593, 0.9913989305496216, 0.995366632938385, 0.9912290573120117, 0.9606286287307739, 0.9808458089828491, 0.9883272051811218, 0.9837499856948853, 0.9941472411155701, 0.9815576076507568, 0.9646101593971252, 0.9578217267990112, 0.9311615824699402, 0.9945917129516602, 0.9772298336029053, 0.9747296571731567, 0.9922195076942444, 0.9642335772514343, 0.9631184339523315, 0.9714096784591675, 0.9667055010795593]
my score 0.9667055010795593
descending [0.995366632938385, 0.9945917129516602, 0.9941472411155701, 0.9922195076942444, 0.9913989305496216, 0.9912290573120117, 0.9883272051811218, 0.9837499856948853, 0.9815576076507568, 0.9808458089828491, 0.9772298336029053, 0.9747296571731567, 0.9714096784591675, 0.9667055010795593, 0.9667055010795593, 0.9646101593971252, 0.9642335772514343, 0.96311843

Working on batch #:  1
loss_on_batch: 0.14384378492832184  time_on_batch: 1.9275918006896973
Working on batch #:  2
loss_on_batch: 0.11622164398431778  time_on_batch: 1.4006409645080566
scores [0.9583896994590759, 0.9825911521911621, 0.9676429033279419, 0.9548956751823425, 0.9495834708213806, 0.9567411541938782, 0.9887893795967102, 0.9721230864524841, 0.9775277972221375, 0.9770222306251526, 0.9921197891235352, 0.9872093200683594, 0.9608380198478699, 0.8873852491378784, 0.9446130990982056, 0.967731237411499, 0.9916563034057617, 0.9930668473243713, 0.9583896994590759, 0.9888454079627991, 0.9488916397094727]
my score 0.9583896994590759
descending [0.9930668473243713, 0.9921197891235352, 0.9916563034057617, 0.9888454079627991, 0.9887893795967102, 0.9872093200683594, 0.9825911521911621, 0.9775277972221375, 0.9770222306251526, 0.9721230864524841, 0.967731237411499, 0.9676429033279419, 0.9608380198478699, 0.9583896994590759, 0.9583896994590759, 0.9567411541938782, 0.9548956751823425, 0.949583

Working on batch #:  1
loss_on_batch: 0.12803234159946442  time_on_batch: 1.9010870456695557
Working on batch #:  2
loss_on_batch: 0.10104205459356308  time_on_batch: 1.3933238983154297
scores [0.970180869102478, 0.9626343250274658, 0.9696707725524902, 0.9800425171852112, 0.9918833374977112, 0.9794012308120728, 0.9754745960235596, 0.9756754636764526, 0.9862624406814575, 0.9843453168869019, 0.9924389123916626, 0.9635952711105347, 0.9675367474555969, 0.9676913619041443, 0.8871744871139526, 0.970180869102478, 0.9714919924736023, 0.9906483292579651, 0.9905983209609985, 0.9836387634277344, 0.9560202956199646]
my score 0.970180869102478
descending [0.9924389123916626, 0.9918833374977112, 0.9906483292579651, 0.9905983209609985, 0.9862624406814575, 0.9843453168869019, 0.9836387634277344, 0.9800425171852112, 0.9794012308120728, 0.9756754636764526, 0.9754745960235596, 0.9714919924736023, 0.970180869102478, 0.970180869102478, 0.9696707725524902, 0.9676913619041443, 0.9675367474555969, 0.963595271

Working on batch #:  1
loss_on_batch: 0.1085008755326271  time_on_batch: 1.9827311038970947
Working on batch #:  2
loss_on_batch: 0.06663309037685394  time_on_batch: 1.4795432090759277
scores [0.977342963218689, 0.9795952439308167, 0.9536849856376648, 0.980187714099884, 0.976687490940094, 0.9801796078681946, 0.9925503730773926, 0.9905223250389099, 0.9813316464424133, 0.972704291343689, 0.9808725118637085, 0.9716188311576843, 0.9784212708473206, 0.977342963218689, 0.9881708025932312, 0.9814333319664001, 0.9766296148300171, 0.8838210105895996, 0.9902105331420898, 0.9753222465515137, 0.9909160733222961]
my score 0.977342963218689
descending [0.9925503730773926, 0.9909160733222961, 0.9905223250389099, 0.9902105331420898, 0.9881708025932312, 0.9814333319664001, 0.9813316464424133, 0.9808725118637085, 0.980187714099884, 0.9801796078681946, 0.9795952439308167, 0.9784212708473206, 0.977342963218689, 0.977342963218689, 0.976687490940094, 0.9766296148300171, 0.9753222465515137, 0.972704291343689

Working on batch #:  1
loss_on_batch: 0.08181502670049667  time_on_batch: 2.1473090648651123
Working on batch #:  2
loss_on_batch: 0.05690991133451462  time_on_batch: 1.9216089248657227
scores [0.9746415615081787, 0.9652218222618103, 0.9907729625701904, 0.9752257466316223, 0.9875660538673401, 0.9827098250389099, 0.973492443561554, 0.9769141674041748, 0.7924537062644958, 0.9734825491905212, 0.9685850143432617, 0.9874421954154968, 0.9706308841705322, 0.9746415615081787, 0.9814789891242981, 0.9797103404998779, 0.9727612733840942, 0.9591606259346008, 0.9206885099411011, 0.9874862432479858, 0.9739153385162354]
my score 0.9746415615081787
descending [0.9907729625701904, 0.9875660538673401, 0.9874862432479858, 0.9874421954154968, 0.9827098250389099, 0.9814789891242981, 0.9797103404998779, 0.9769141674041748, 0.9752257466316223, 0.9746415615081787, 0.9746415615081787, 0.9739153385162354, 0.973492443561554, 0.9734825491905212, 0.9727612733840942, 0.9706308841705322, 0.9685850143432617, 0.965221

Working on batch #:  1
loss_on_batch: 0.042748793959617615  time_on_batch: 1.9721949100494385
Working on batch #:  2
loss_on_batch: 0.047602854669094086  time_on_batch: 1.4090549945831299
scores [0.9575063586235046, 0.9523175954818726, 0.5928608179092407, 0.8326156139373779, 0.964106559753418, 0.9738541841506958, 0.957752525806427, 0.9534027576446533, 0.9575063586235046, 0.9696760773658752, 0.9640253782272339, 0.9261384010314941, 0.9760856628417969, 0.9394931197166443, 0.916240394115448, 0.9819503426551819, 0.9827983975410461, 0.9870944619178772, 0.9567950963973999, 0.9725953340530396, 0.9614583253860474]
my score 0.9575063586235046
descending [0.9870944619178772, 0.9827983975410461, 0.9819503426551819, 0.9760856628417969, 0.9738541841506958, 0.9725953340530396, 0.9696760773658752, 0.964106559753418, 0.9640253782272339, 0.9614583253860474, 0.957752525806427, 0.9575063586235046, 0.9575063586235046, 0.9567950963973999, 0.9534027576446533, 0.9523175954818726, 0.9394931197166443, 0.9261384

Working on batch #:  1


KeyboardInterrupt: 