In [1]:
from parameters_lstm import *

from preprocess_datapoints_lstm import *
from preprocess_text_to_tensors_lstm import *
from meter import *

import torch
from torch.autograd import Variable
import time

# Base file name for the LSTM checkpoints written by the training loop.
saved_model_name = 'lstm_dir_trans'

# --- Corpora, vocabulary and pretrained embeddings ---
processed_corpus = process_whole_corpuses()
word_to_id_vocab = processed_corpus['word_to_id']
# NOTE(review): glove_path is not defined above this point in the notebook;
# presumably it comes from the star import of parameters_lstm -- confirm.
word2vec = load_glove_embeddings(glove_path, word_to_id_vocab)

# Title / body lookups, grouped per domain (presumably keyed by question id).
ubuntu_id_to_data_title, ubuntu_id_to_data_body = (
    processed_corpus['ubuntu_id_to_data_title'],
    processed_corpus['ubuntu_id_to_data_body'],
)
android_id_to_data_title, android_id_to_data_body = (
    processed_corpus['android_id_to_data_title'],
    processed_corpus['android_id_to_data_body'],
)


# --- Data sets ---
# Ubuntu is the training domain; the Android dev / test splits are only evaluated on.
training_data_ubuntu = ubuntu_id_to_similar_different()
dev_data_android = android_id_to_similar_different(dev=True)
test_data_android = android_id_to_similar_different(dev=False)

# Question ids of each split, in the iteration order of the corresponding mapping.
training_question_ids_ubuntu = list(training_data_ubuntu.keys())
dev_question_ids_android = list(dev_data_android.keys())
test_question_ids_android = list(test_data_android.keys())
# Note: Remember to edit batch_size accordingly if testing on smaller size data sets

In [2]:
def eval_model(lstm, ids, data, word2vec, id2Data_title, id2Data_body, word_to_id_vocab, truncation_val_title, truncation_val_body):
    """Evaluate `lstm` on a set of questions and return the AUC meter's value.

    The (candidate, main question, label) triples produced by
    `organize_test_ids(ids, data)` are processed in roughly 50 chunks. For each
    chunk the candidates and the main questions are encoded with
    `construct_qs_matrix_testing`, their cosine similarity is computed, and the
    similarities are added to the meter together with the chunk's labels.

    Relies on the module-level `auc_scorer` (reset at the start of every call)
    and on the module-level initial LSTM states `h0` / `c0`.

    Args:
        lstm: encoder to evaluate; switched to eval mode by this call.
        ids: question ids handed to `organize_test_ids`.
        data: evaluation data handed to `organize_test_ids`.
        word2vec, id2Data_title, id2Data_body, word_to_id_vocab,
        truncation_val_title, truncation_val_body: forwarded unchanged to
            `construct_qs_matrix_testing`.

    Returns:
        The result of `auc_scorer.value()` after all chunks have been added.
    """
    lstm.eval()
    auc_scorer.reset()

    candidate_ids, q_main_ids, labels = organize_test_ids(ids, data)
    num_q_main = len(q_main_ids)
    # Aim for ~50 chunks, but never let the step drop to 0: for 25 or fewer pairs
    # round() yields 0 and range() would raise "arg 3 must not be zero".
    len_pieces = max(1, round(num_q_main/50))
    print(num_q_main)

    for i in range(0, num_q_main, len_pieces):
        # Progress indicator: start offset of the chunk being scored.
        print(i, end = ' ')
        q_main_id_num_repl_tuple = q_main_ids[i:i+len_pieces]
        candidates = candidate_ids[i:i+len_pieces]
        current_labels = torch.from_numpy(np.array(labels[i:i+len_pieces])).long()

        candidates_qs_matrix = construct_qs_matrix_testing(candidates, lstm, h0, c0, word2vec, id2Data_title, id2Data_body,
        word_to_id_vocab, truncation_val_title, truncation_val_body, main=False)
        main_qs_matrix = construct_qs_matrix_testing(q_main_id_num_repl_tuple, lstm, h0, c0, word2vec, id2Data_title, id2Data_body,
        word_to_id_vocab, truncation_val_title, truncation_val_body, main=True)

        # .data hands the meter the raw similarity values rather than the Variable.
        similarity_matrix_this_batch = torch.nn.functional.cosine_similarity(candidates_qs_matrix, main_qs_matrix, eps=1e-08).data
        auc_scorer.add(similarity_matrix_this_batch, current_labels)

    auc_score = auc_scorer.value()

    return auc_score

In [3]:
# ============================ Params Dashboard ============================

# --- Procedural parameters ---
batch_size = 40                  # training questions per optimisation step
num_differing_questions = 20     # forwarded to the training-data helpers
num_epochs = 2


# --- Model specs LSTM ---
dropout = 0.2
margin = 0.15                    # margin of the multi-margin loss
lr_lstm = 10**-3

input_size = 300
hidden_size = 240
num_layers = 1
bias = True
batch_first = True
bidirectional = True
# A bidirectional LSTM keeps one state per direction and layer.
_num_directions = 2 if bidirectional else 1
first_dim = num_layers * _num_directions


# --- Model specs NN ---
# NOTE(review): lr_nn is negative and is not used in the visible cells of this
# notebook -- confirm whether the sign is intentional before reusing it.
lr_nn = -10**-4
lamb = 10**-3

# The NN input width is the hidden size times the number of directions.
input_size_nn = hidden_size * _num_directions
first_hidden_size_nn = 300
second_hidden_size_nn = 150


# --- Data processing specs ---
truncation_val_title = 15
truncation_val_body = 85
padding_idx = 0

glove_path = '../glove.840B.300d.txt'
android_corpus_path = '../android_dataset/corpus.tsv'
ubuntu_corpus_path = '../ubuntu_dataset/text_tokenized.txt'

In [6]:
# --- Encoder (LSTM) ---
# Keyword arguments make the mapping of the dashboard values onto the LSTM explicit.
lstm = torch.nn.LSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    bias=bias,
    batch_first=batch_first,
    dropout=dropout,
    bidirectional=bidirectional
)
loss_function_lstm = torch.nn.MultiMarginLoss(margin=margin)
optimizer_lstm = torch.optim.Adam(lstm.parameters(), lr=lr_lstm, weight_decay = 0.00001)

# Zero initial hidden / cell states with a batch dimension of 1; they are handed
# to the construct_qs_matrix_* helpers and never receive gradients.
_initial_state_shape = (first_dim, 1, hidden_size)
h0 = Variable(torch.zeros(*_initial_state_shape), requires_grad=False)
c0 = Variable(torch.zeros(*_initial_state_shape), requires_grad=False)


# --- Procedural parameters ---
# round() (not ceil / floor): a trailing partial batch is only trained on when
# the rounding goes up.
num_batches = round(len(training_question_ids_ubuntu) / batch_size)
auc_scorer = AUCMeter()


def train_lstm_question_similarity(lstm, batch_ids, batch_data, word2vec, id2Data_title, id2Data_body, word_to_id_vocab, truncation_val_title, truncation_val_body):
    """Compute the multi-margin similarity loss of `lstm` on one training batch.

    The batch is expanded by `organize_ids_training` into candidate sequences,
    both candidates and main questions are encoded with
    `construct_qs_matrix_training`, and the cosine similarities (taken along
    dim 2) are scored by the module-level `loss_function_lstm` against a target
    of all zeros, i.e. class index 0 is the expected winner of every row
    (presumably the similar question -- confirm against organize_ids_training).

    Relies on the module-level `h0`, `c0`, `num_differing_questions` and
    `loss_function_lstm`. Does not call backward() or step the optimizer.

    Args:
        lstm: encoder being trained; switched to train mode by this call.
        batch_ids: ids of the main questions in this batch.
        batch_data: training data handed to `organize_ids_training`.
        word2vec, id2Data_title, id2Data_body, word_to_id_vocab,
        truncation_val_title, truncation_val_body: forwarded unchanged to
            `construct_qs_matrix_training`.

    Returns:
        The loss for the batch, still attached to the graph so the caller can
        backpropagate through it.
    """
    lstm.train()
    sequence_ids, dict_sequence_lengths = organize_ids_training(batch_ids, batch_data, num_differing_questions)

    candidates_qs_tuples_matrix = construct_qs_matrix_training(sequence_ids, lstm, h0, c0, word2vec, id2Data_title, id2Data_body,
        dict_sequence_lengths, num_differing_questions, word_to_id_vocab, truncation_val_title, truncation_val_body, candidates=True)
    main_qs_tuples_matrix = construct_qs_matrix_training(batch_ids, lstm, h0, c0, word2vec, id2Data_title, id2Data_body,
        dict_sequence_lengths, num_differing_questions, word_to_id_vocab, truncation_val_title, truncation_val_body, candidates=False)

    similarity_matrix = torch.nn.functional.cosine_similarity(candidates_qs_tuples_matrix, main_qs_tuples_matrix, dim=2, eps=1e-6)
    # One target per group of (1 + num_differing_questions) candidate sequences.
    target = Variable(torch.LongTensor([0] * int(len(sequence_ids) / (1 + num_differing_questions))))
    loss_batch = loss_function_lstm(similarity_matrix, target)

    # float() accepts both a 1-element tensor and a 0-dim tensor, whereas
    # indexing with [0] raises IndexError on a 0-dim loss.
    print("lstm multi-margin loss on batch:", float(loss_batch.data))
    return loss_batch


# --- Begin training ---
for epoch in range(num_epochs):

    # One pass over the whole training data set
    for batch in range(1, num_batches + 1):
        # Batches 93 and 301 are skipped on purpose; the notebook does not record why.
        if batch in (93, 301):
            continue
        start = time.time()
        optimizer_lstm.zero_grad()
        print("Working on batch #: ", batch)

        # Train on ubuntu similar question retrieval
        batch_begin = batch_size * (batch - 1)
        ids_this_batch_for_lstm = training_question_ids_ubuntu[batch_begin:batch_begin + batch_size]
        loss_batch_similarity = train_lstm_question_similarity(
            lstm, ids_this_batch_for_lstm, training_data_ubuntu, word2vec,
            ubuntu_id_to_data_title, ubuntu_id_to_data_body,
            word_to_id_vocab, truncation_val_title, truncation_val_body)

        # The similarity loss is the only term of the objective in this notebook.
        overall_loss = loss_batch_similarity
        overall_loss.backward()
        optimizer_lstm.step()

        print("Time on batch:", time.time() - start)

        is_hundredth_batch = batch % 100 == 0

        if batch == 1 or is_hundredth_batch:
            # Checkpoint both the model and the optimizer, tagged with epoch and batch
            checkpoint_suffix = '_e' + str(epoch) + '_b' + str(batch) + '.pth'
            torch.save(lstm, '../Pickle/' + saved_model_name + checkpoint_suffix)
            torch.save(optimizer_lstm, '../Pickle/' + 'optim_lstm_direct_transfer' + checkpoint_suffix)

        if is_hundredth_batch:
            # Evaluate on the Android dev set for AUC score
            dev_AUC_score = eval_model(
                lstm, dev_question_ids_android, dev_data_android, word2vec,
                android_id_to_data_title, android_id_to_data_body,
                word_to_id_vocab, truncation_val_title, truncation_val_body)

            print("Dev AUC score:", dev_AUC_score)

Working on batch #:  1
lstm multi-margin loss on batch: 0.10016369074583054
Time on batch: 38.725939989089966
Working on batch #:  2
lstm multi-margin loss on batch: 0.07415125519037247
Time on batch: 52.00062870979309
Working on batch #:  3
lstm multi-margin loss on batch: 0.049514882266521454
Time on batch: 42.72429919242859
Working on batch #:  4
lstm multi-margin loss on batch: 0.03558170795440674
Time on batch: 33.90351915359497
Working on batch #:  5
lstm multi-margin loss on batch: 0.027691852301359177
Time on batch: 164.3411569595337
Working on batch #:  6
lstm multi-margin loss on batch: 0.016386093571782112
Time on batch: 42.048569440841675
Working on batch #:  7
lstm multi-margin loss on batch: 0.013205495662987232
Time on batch: 68.38670110702515
Working on batch #:  8
lstm multi-margin loss on batch: 0.012787408195436
Time on batch: 35.018958568573
Working on batch #:  9
lstm multi-margin loss on batch: 0.018541747704148293
Time on batch: 28.809173107147217
Working on batc

lstm multi-margin loss on batch: 0.01374753937125206
Time on batch: 32.91535711288452
Working on batch #:  76
lstm multi-margin loss on batch: 0.008533590473234653
Time on batch: 27.039607524871826
Working on batch #:  77
lstm multi-margin loss on batch: 0.008358956314623356
Time on batch: 24.920485734939575
Working on batch #:  78
lstm multi-margin loss on batch: 0.005405467934906483
Time on batch: 25.544949769973755
Working on batch #:  79
lstm multi-margin loss on batch: 0.009258832782506943
Time on batch: 38.40528082847595
Working on batch #:  80
lstm multi-margin loss on batch: 0.011072542518377304
Time on batch: 31.269765615463257
Working on batch #:  81
lstm multi-margin loss on batch: 0.006201568059623241
Time on batch: 37.517468214035034
Working on batch #:  82
lstm multi-margin loss on batch: 0.007232066243886948
Time on batch: 27.48449444770813
Working on batch #:  83
lstm multi-margin loss on batch: 0.008635452948510647
Time on batch: 37.25682091712952
Working on batch #:  

Time on batch: 61.181384801864624
Working on batch #:  147
lstm multi-margin loss on batch: 0.005511469207704067
Time on batch: 30.263870239257812
Working on batch #:  148
lstm multi-margin loss on batch: 0.005371310748159885
Time on batch: 31.637228965759277
Working on batch #:  149
lstm multi-margin loss on batch: 0.011021154932677746
Time on batch: 28.040550470352173
Working on batch #:  150
lstm multi-margin loss on batch: 0.006077388767153025
Time on batch: 26.61675524711609
Working on batch #:  151
lstm multi-margin loss on batch: 0.007018269971013069
Time on batch: 41.56206679344177
Working on batch #:  152
lstm multi-margin loss on batch: 0.014008154161274433
Time on batch: 91.24815225601196
Working on batch #:  153
lstm multi-margin loss on batch: 0.013529484160244465
Time on batch: 37.70327925682068
Working on batch #:  154
lstm multi-margin loss on batch: 0.010403303429484367
Time on batch: 50.49574542045593
Working on batch #:  155
lstm multi-margin loss on batch: 0.0087623

Time on batch: 35.289726972579956
Working on batch #:  217
lstm multi-margin loss on batch: 0.010398208163678646
Time on batch: 26.398828983306885
Working on batch #:  218
lstm multi-margin loss on batch: 0.003970896825194359
Time on batch: 29.9364914894104
Working on batch #:  219
lstm multi-margin loss on batch: 0.007909082807600498
Time on batch: 31.946791172027588
Working on batch #:  220
lstm multi-margin loss on batch: 0.01221136562526226
Time on batch: 197.52108597755432
Working on batch #:  221
lstm multi-margin loss on batch: 0.010511831380426884
Time on batch: 56.519999980926514
Working on batch #:  222
lstm multi-margin loss on batch: 0.004416931886225939
Time on batch: 32.648961544036865
Working on batch #:  223
lstm multi-margin loss on batch: 0.002608521841466427
Time on batch: 27.58937406539917
Working on batch #:  224
lstm multi-margin loss on batch: 0.0053591299802064896
Time on batch: 31.74249291419983
Working on batch #:  225
lstm multi-margin loss on batch: 0.020004

RuntimeError: $ Torch: not enough memory: you tried to reallocate 0GB. Buy new RAM! at D:\Projects\pytorch\torch\lib\TH\THGeneral.c:298

In [28]:
# --- Encoder (LSTM) ---
# The in-memory `lstm` is evaluated below. The disabled line shows how a saved
# checkpoint could be loaded (note it binds `lstm_x`, not `lstm`).
#lstm_x = torch.load('../Pickle/lstm_dir_trans_title_body_p1_epoch1_batch200.pt')

# --- Procedural parameters ---
# Fresh meter for this evaluation run; eval_model reads it as a global.
auc_scorer = AUCMeter()

# Evaluate on the Android test set for AUC score
test_AUC_score = eval_model(
    lstm, test_question_ids_android, test_data_android, word2vec,
    android_id_to_data_title, android_id_to_data_body,
    word_to_id_vocab, truncation_val_title, truncation_val_body)

print("Test AUC score:", test_AUC_score)

119685
0 4787 9574 14361 19148 23935 28722 33509 38296 43083 47870 52657 57444 62231 67018 71805 76592 81379 86166 90953 95740 100527 105314 110101 114888 119675 Test AUC score: 0.554701771596
