In [1]:
from preprocess import *
from scoring_metrics import *
from cnn_utils import *

import torch
from torch.autograd import Variable

import time

saved_model_name = "best_cnn_title_body"

'''Hyperparams dashboard'''
margin = 0.3
lr = 10**-3
truncation_val_title = 40
truncation_val_body = 60


''' Data Prep '''
word2vec = get_words_and_embeddings()
id2Data = questionID_to_questionData_truncate(truncation_val_title, truncation_val_body)

training_data = training_id_to_similar_different()
trainingQuestionIds = list(training_data.keys())

dev_data = devTest_id_to_similar_different(dev=True)
dev_question_ids = list(dev_data.keys())

test_data = devTest_id_to_similar_different(dev=False)
test_question_ids = list(test_data.keys())

In [7]:
''' Model Specs '''
# CNN parameters
input_size = len(word2vec[list(word2vec.keys())[0]])
hidden_size = 667
kernel_size = 3
stride = 1
padding = 0
dilation = 1
groups = 1
bias = True

# CNN model
#cnn = torch.nn.Sequential()
#cnn.add_module('conv', torch.nn.Conv1d(in_channels = input_size, out_channels = hidden_size, kernel_size = kernel_size, padding = padding, dilation = dilation, groups = groups, bias = bias))
#cnn.add_module('tanh', torch.nn.Tanh())
#cnn.add_module('norm', torch.nn.BatchNorm1d(num_features = hidden_size))

# Loss function
#loss_function = torch.nn.MultiMarginLoss(margin=margin)

# Optimizer
#optimizer = torch.optim.Adam(cnn.parameters(), lr=lr)


''' Procedural parameters '''
batch_size = 20
num_differing_questions = 20
num_epochs = 10
num_batches = round(len(trainingQuestionIds)/batch_size)


def train_model(cnn, optimizer, batch_ids, batch_data, word2vec, id2Data, truncation_val_title, truncation_val_body):
    cnn.train()
    optimizer.zero_grad()

    sequence_ids, dict_sequence_lengths = organize_ids_training(batch_ids, batch_data, num_differing_questions)

    candidates_qs_tuples_matrix = construct_qs_matrix_training(sequence_ids, cnn, word2vec, id2Data, dict_sequence_lengths, input_size, num_differing_questions, truncation_val_title,\
                                                               truncation_val_body, candidates=True)
    main_qs_tuples_matrix = construct_qs_matrix_training(batch_ids, cnn, word2vec, id2Data, dict_sequence_lengths, input_size, num_differing_questions, truncation_val_title,\
                                                         truncation_val_body, candidates=False)
    similarity_matrix = torch.nn.functional.cosine_similarity(candidates_qs_tuples_matrix, main_qs_tuples_matrix, dim=2, eps=1e-08)

    target = Variable(torch.LongTensor([0] * int(len(sequence_ids)/(1+num_differing_questions))))
    loss_batch = loss_function(similarity_matrix, target)

    loss_batch.backward()
    optimizer.step()

    print("loss_on_batch:", loss_batch.data[0], " time_on_batch:", time.time() - start)
    return


def eval_model(cnn, ids, data, word2vec, id2Data, truncation_val_title, truncation_val_body):
    cnn.eval()
    sequence_ids, p_pluses_indices_dict = organize_test_ids(ids, data)

    candidates_qs_tuples_matrix = construct_qs_matrix_testing(sequence_ids, cnn, word2vec, id2Data, input_size, num_differing_questions, truncation_val_title, truncation_val_body,\
                                                              candidates=True)
    main_qs_tuples_matrix = construct_qs_matrix_testing(ids, cnn, word2vec, id2Data, input_size, num_differing_questions, truncation_val_title, truncation_val_body, candidates=False)

    similarity_matrix = torch.nn.functional.cosine_similarity(candidates_qs_tuples_matrix, main_qs_tuples_matrix, dim=2, eps=1e-08)
    MRR_score = get_MRR_score(similarity_matrix, p_pluses_indices_dict)
    MAP_score = get_MAP_score(similarity_matrix, p_pluses_indices_dict)
    avg_prec_at_1 = avg_precision_at_k(similarity_matrix, p_pluses_indices_dict, 1)
    avg_prec_at_5 = avg_precision_at_k(similarity_matrix, p_pluses_indices_dict, 5) 
    return MRR_score, MAP_score, avg_prec_at_1, avg_prec_at_5


'''Begin training'''
for epoch in range(num_epochs):

    # Train on whole training data set
    for batch in range(1, num_batches+1):
        #if batch == 121 or batch == 98: # memory error with this batch
        #    continue
        start = time.time()
        questions_this_training_batch = trainingQuestionIds[batch_size * (batch - 1):batch_size * batch]
        print("Working on batch #: ", batch)
        train_model(cnn, optimizer, questions_this_training_batch, training_data, word2vec, id2Data, truncation_val_title, truncation_val_body)
        
        if batch == 1 or batch % 50 == 0:
            
            # Evaluate on dev and test sets for MRR score

            dev_scores = eval_model(cnn, dev_question_ids, dev_data, word2vec, id2Data, truncation_val_title, truncation_val_body)
            test_scores = eval_model(cnn, test_question_ids, test_data, word2vec, id2Data, truncation_val_title, truncation_val_body)
            print("MRR score on dev set:", dev_scores[0])
            print("MRR score on test set:", test_scores[0])
            print("MAP score on dev set:", dev_scores[1])
            print("MAP score on test set:", test_scores[1])
            print("Precision at 1 score on dev set:", dev_scores[2])
            print("Precision at 1 score on test set:", test_scores[2])
            print("Precision at 5 score on dev set:", dev_scores[3])
            print("Precision at 5 score on test set:", test_scores[3])

            # Log results to local logs.txt file
            with open('logs_cnn2.txt', 'a') as log_file:
                log_file.write('epoch: ' + str(epoch) + '\n')
                log_file.write('batch: ' + str(batch) + '\n')
                log_file.write('lr: ' + str(lr) +  ' marg: ' + str(margin) + '\n' )        
                log_file.write('dev_MRR: ' +  str(dev_scores[0]) + '\n')
                log_file.write('test_MRR: ' +  str(test_scores[0]) + '\n')
                log_file.write('dev_MAP: ' +  str(dev_scores[1]) + '\n')
                log_file.write('test_MAP: ' +  str(test_scores[1]) + '\n')
                log_file.write('dev_p_at_1: ' +  str(dev_scores[2]) + '\n')
                log_file.write('test_p_at_1: ' +  str(test_scores[2]) + '\n')
                log_file.write('dev_p_at_5: ' +  str(dev_scores[3]) + '\n')
                log_file.write('test_p_at_5: ' +  str(test_scores[3]) + '\n')

            # Save model for this epoch
            torch.save(cnn, '../Pickle/' + saved_model_name + '_epoch' + str(epoch) + '_batch' + str(batch)+ '.pt')
            # Save optimizer for this epoch
            torch.save(optimizer, '../Pickle/' + 'optim_cnn2' + '_epoch' + str(epoch) + '_batch' + str(batch) +'.pth')

Working on batch #:  1
loss_on_batch: 0.012950340285897255  time_on_batch: 21.740819215774536
MRR score on dev set: 0.6679373215818497
MRR score on test set: 0.6643294471155751
MAP score on dev set: 0.536302686749
MAP score on test set: 0.543962369513
Precision at 1 score on dev set: 0.515
Precision at 1 score on test set: 0.475
Precision at 5 score on dev set: 0.3880000000000002
Precision at 5 score on test set: 0.37700000000000017
Working on batch #:  2
loss_on_batch: 0.028630465269088745  time_on_batch: 13.067741870880127
Working on batch #:  3
loss_on_batch: 0.03022531047463417  time_on_batch: 32.06428074836731
Working on batch #:  4
loss_on_batch: 0.01395374070852995  time_on_batch: 12.967473030090332
Working on batch #:  5
loss_on_batch: 0.021925244480371475  time_on_batch: 16.046658515930176
Working on batch #:  6
loss_on_batch: 0.02627919428050518  time_on_batch: 22.281232595443726
Working on batch #:  7
loss_on_batch: 0.030992334708571434  time_on_batch: 12.543344259262085
Wor

loss_on_batch: 0.010863441973924637  time_on_batch: 13.13291335105896
Working on batch #:  82
loss_on_batch: 0.029183262959122658  time_on_batch: 12.078110218048096
Working on batch #:  83
loss_on_batch: 0.0250753965228796  time_on_batch: 45.45584154129028
Working on batch #:  84
loss_on_batch: 0.02293611504137516  time_on_batch: 13.958107233047485
Working on batch #:  85
loss_on_batch: 0.02669551968574524  time_on_batch: 22.289252758026123
Working on batch #:  86
loss_on_batch: 0.021008215844631195  time_on_batch: 16.70139980316162
Working on batch #:  87
loss_on_batch: 0.058786336332559586  time_on_batch: 156.31770634651184
Working on batch #:  88
loss_on_batch: 0.014715636149048805  time_on_batch: 15.336303234100342
Working on batch #:  89
loss_on_batch: 0.009951694868505001  time_on_batch: 14.569249391555786
Working on batch #:  90
loss_on_batch: 0.02232053503394127  time_on_batch: 12.438554286956787
Working on batch #:  91
loss_on_batch: 0.014329861849546432  time_on_batch: 13.887

loss_on_batch: 0.00857936404645443  time_on_batch: 13.25423550605774
Working on batch #:  162
loss_on_batch: 0.01584712229669094  time_on_batch: 25.093711376190186
Working on batch #:  163
loss_on_batch: 0.019255228340625763  time_on_batch: 12.552371740341187
Working on batch #:  164
loss_on_batch: 0.030755311250686646  time_on_batch: 12.565404653549194
Working on batch #:  165
loss_on_batch: 0.022939734160900116  time_on_batch: 17.13655686378479
Working on batch #:  166
loss_on_batch: 0.01450288761407137  time_on_batch: 20.216744422912598
Working on batch #:  167
loss_on_batch: 0.016805989667773247  time_on_batch: 16.795657873153687
Working on batch #:  168
loss_on_batch: 0.014187564142048359  time_on_batch: 13.350488662719727
Working on batch #:  169
loss_on_batch: 0.019240472465753555  time_on_batch: 15.900268077850342
Working on batch #:  170
loss_on_batch: 0.020914098247885704  time_on_batch: 11.499572277069092
Working on batch #:  171
loss_on_batch: 0.03653458133339882  time_on_b

KeyboardInterrupt: 