In [29]:
from preprocess import *

import numpy as np
import torch
from torch.autograd import Variable
from torch.nn.modules.distance import CosineSimilarity

import time

# Produces [tensor, num_words_found] for one particular question; the tensor is
# [1 x num_words x input_size], zero-padded up to truncation_val rows
def get_question_matrix(questionID, word2vec, id2Data, input_size, truncation_val):
    """Build the fixed-height word-embedding matrix for a single question.

    Each word of the question that has an embedding contributes one row; words
    without an embedding are skipped. The result is truncated or zero-padded so
    that it always has exactly ``truncation_val`` rows, which lets the matrices
    of several questions be concatenated into one batch.

    Args:
        questionID: Key into ``id2Data``.
        word2vec: Mapping from word to its embedding vector. A missing word
            must raise ``KeyError`` (as a dict does).
        id2Data: Mapping from question id to an iterable of words.
        input_size: Embedding dimensionality; used as the width of the padding
            rows, so it must equal the length of the vectors in ``word2vec``.
        truncation_val: Number of rows in the returned matrix.

    Returns:
        A two-element list ``[q_matrix, num_words_found]`` where ``q_matrix``
        is a float tensor of shape ``[1 x truncation_val x input_size]`` and
        ``num_words_found`` is the number of non-padding rows (an ``int``,
        at most ``truncation_val``).

    Raises:
        ValueError: If none of the question's words has an embedding, since no
            meaningful matrix (or mean) can be produced for such a question.
    """
    # One [1 x dim] float32 row per word that has an embedding.
    q_word_vecs = []
    for word in id2Data[questionID]:
        try:
            embedding = word2vec[word]
        except KeyError:
            # Out-of-vocabulary words are skipped on purpose.
            continue
        q_word_vecs.append(np.asarray(embedding, dtype=np.float32).reshape(1, -1))

    if not q_word_vecs:
        raise ValueError(
            "question %r has no words with a known embedding" % (questionID,)
        )

    # num_words x dim_words
    q_matrix = torch.from_numpy(np.concatenate(q_word_vecs, axis=0))

    # Enforce the upper bound as well as the lower one: an over-long question
    # would otherwise produce a matrix that cannot be batched with the others.
    if q_matrix.size(0) > truncation_val:
        q_matrix = q_matrix[:truncation_val]
    num_words_found = q_matrix.size(0)

    if num_words_found < truncation_val:
        padding_rows = torch.zeros(truncation_val - num_words_found, input_size)
        q_matrix = torch.cat((q_matrix, padding_rows), 0)

    return [q_matrix.unsqueeze(0), num_words_found]


# Takes a list of question tensors and concatenates them into one batch [num_questions x num_words x input_size]
def padded_q_matrix(q_matrix_list):
    """Join per-question tensors into a single batch.

    Args:
        q_matrix_list: List of tensors that agree on every dimension except
            the first (e.g. the ``[1 x num_words x input_size]`` tensors
            returned by ``get_question_matrix``).

    Returns:
        An autograd ``Variable`` holding the tensors concatenated along dim 0.
    """
    batch = torch.cat(q_matrix_list, 0)
    return Variable(batch)


''' Data Prep '''
# Maps each training question id to a pair indexed as [0] = similar (positive)
# question ids and [1] = different (negative) question ids (see order_ids).
training_data = training_id_to_similar_different()
# Main-question ids, in the order in which the training loop slices them into batches.
trainingQuestionIds = list(training_data.keys())
# Word -> embedding vector lookup used by get_question_matrix.
word2vec = get_words_and_embeddings()
# Question id -> iterable of words. The 50 is presumably the maximum number of
# words kept per question -- TODO confirm against preprocess.
id2Data = questionID_to_questionData_truncate(50)

In [30]:
''' Model Specs '''
# CNN parameters
# Embedding dimensionality, read off the first entry of word2vec.
input_size = len(word2vec[list(word2vec.keys())[0]])
hidden_size = 200
kernel_size = 3
stride = 1
padding = 0
dilation = 1
groups = 1
bias = True

# CNN model: one 1-D convolution over the word axis followed by tanh.
# in_channels is tied to input_size (rather than a literal 200) so the model
# follows the embeddings actually loaded, and stride is passed explicitly so
# the parameter above takes effect.
cnn = torch.nn.Sequential()
cnn.add_module('conv', torch.nn.Conv1d(in_channels = input_size, out_channels = hidden_size, kernel_size = kernel_size, stride = stride, padding = padding, dilation = dilation, groups = groups, bias = bias))
cnn.add_module('tanh', torch.nn.Tanh())

# Loss function: multi-class margin loss; the training loop uses class 0
# (the positive candidate) as the target for every tuple.
loss_function = torch.nn.MultiMarginLoss(margin=0.2)

# Optimizer
optimizer = torch.optim.Adam(cnn.parameters(), lr=10**-2, weight_decay=0.001)


''' Procedural parameters '''
# Number of main questions per batch.
batch_size = 100
# Number of negative candidates sampled for each main question.
num_differing_questions = 20

num_epochs = 10
# Ceiling division: round() silently dropped a trailing partial batch of up to
# batch_size/2 questions and produced 0 batches for very small training sets.
num_batches = -(-len(trainingQuestionIds) // batch_size)


# Given ids of main qs in this batch
#
# Returns:
# 1. ids in ordered list as: 
# [
# q_1+, q_1-, q_1--,..., q_1++, q_1-, q_1--,...,
# q_2+, q_2-, q_2--,..., q_2++, q_2-, q_2--,...,
# ...
# ]
# All n main questions have their pos,neg,neg,neg,... interleaved
#
# 2. A dict mapping main question id --> its interleaved sequence length
def order_ids(q_ids):
    """Lay out the candidate ids for a batch of main questions.

    For every main question, ``num_differing_questions`` negatives are sampled
    once and then repeated after each of its positives, giving consecutive
    tuples ``(positive, negative, negative, ...)``.

    Reads the module-level ``training_data`` (main id -> indexable pair of
    positives at [0] and negatives at [1]) and ``num_differing_questions``.

    Args:
        q_ids: Iterable of main question ids; each must be a key of
            ``training_data``.

    Returns:
        A pair ``(sequence_ids, dict_sequence_lengths)``:
        ``sequence_ids`` is the flat list of candidate ids in tuple order;
        ``dict_sequence_lengths`` maps each main question id to the number of
        entries it contributed, i.e.
        ``len(positives) * (1 + num_differing_questions)``.

    Raises:
        ValueError: Propagated from ``np.random.choice`` if a question has no
            negatives at all.
    """
    sequence_ids = []
    dict_sequence_lengths = {}

    for q_main in q_ids:
        p_pluses = training_data[q_main][0]
        negative_pool = training_data[q_main][1]

        # Prefer distinct negatives; only when the pool is too small to supply
        # num_differing_questions distinct ids fall back to sampling with
        # replacement so every tuple keeps the same width.
        need_replacement = len(negative_pool) < num_differing_questions
        p_minuses = list(np.random.choice(negative_pool, num_differing_questions, replace=need_replacement))

        dict_sequence_lengths[q_main] = len(p_pluses) * (1 + num_differing_questions)

        # The same sampled negatives are shared by all positives of q_main.
        for p_plus in p_pluses:
            sequence_ids.append(p_plus)
            sequence_ids.extend(p_minuses)

    return sequence_ids, dict_sequence_lengths


'''Matrix constructors (use global vars, leave in order)'''
# A tuple is (q+, q-, q--, q--- ...)
# Let all main questions be set Q
# Each q in Q has a number of tuples equal to number of positives |q+, q++, ...|
# Each q in Q will have a 2D matrix of: num_tuples x num_candidates_in_tuple
# Concatenate this matrix for all q in Q and you get a matrix of: |Q| x num_tuples x num_candidates_in_tuple

# The above is for candidates
# To do cosine_similarity, need same structure with q's
# Basically each q will be a matrix of repeated q's: num_tuples x num_candidates_in_tuple, all elts are q (repeated)

# This method constructs those matrices, use candidates=True for candidates matrix

def construct_qs_matrix(q_ids_sequential, dict_sequence_lengths, candidates=False, truncation_val=50):
    """Encode a batch of questions with the CNN and group them into tuples.

    Reads the module-level ``cnn``, ``word2vec``, ``id2Data``, ``input_size``
    and ``num_differing_questions``.

    Args:
        q_ids_sequential: With ``candidates=True``, the flat candidate id list
            produced by ``order_ids``. With ``candidates=False``, the main
            question ids of the batch.
        dict_sequence_lengths: Main question id -> number of candidate entries
            it owns; used only when ``candidates=False`` to repeat each main
            question once per candidate it is compared against.
        candidates: Selects which of the two layouts above is built.
        truncation_val: Number of word rows every question matrix is
            padded/limited to. The default matches the 50 passed to
            ``questionID_to_questionData_truncate`` when ``id2Data`` is built.

    Returns:
        The pooled encodings, stacked as
        ``[num_tuples x (1 + num_differing_questions) x hidden_size]``.
    """
    if candidates:
        q_ids_complete = q_ids_sequential
    else:
        # Repeat each main question so it lines up element-wise with the
        # candidate layout.
        q_ids_complete = []
        for q in q_ids_sequential:
            q_ids_complete += [q] * dict_sequence_lengths[q]

    qs_matrix_list = []
    qs_seq_length = []

    for q in q_ids_complete:
        # truncation_val was previously omitted here, which is a TypeError
        # against get_question_matrix's five-argument signature.
        q_matrix_3d, q_num_words = get_question_matrix(q, word2vec, id2Data, input_size, truncation_val)
        qs_matrix_list.append(q_matrix_3d)
        qs_seq_length.append(q_num_words)

    qs_padded = padded_q_matrix(qs_matrix_list)
    # Conv1d convolves over the last axis, so move the embedding axis to dim 1:
    # [batch x num_words x input_size] -> [batch x input_size x num_words].
    qs_hidden = cnn(torch.transpose(qs_padded, 1, 2))

    # Pool over the convolution output positions.
    # NOTE(review): the sum runs over every output position (including those
    # that only cover zero padding) but is divided by the number of real
    # words, so this is not a strict mean over real positions -- confirm this
    # is intended.
    sum_h_qs = torch.sum(qs_hidden, dim=2)
    word_counts = torch.autograd.Variable(torch.FloatTensor(qs_seq_length).unsqueeze(1))
    mean_pooled_h_qs = torch.div(sum_h_qs, word_counts)

    # Consecutive runs of (1 positive + num_differing_questions negatives)
    # form one tuple each.
    qs_tuples = mean_pooled_h_qs.split(1 + num_differing_questions)
    return torch.stack(qs_tuples, dim=0)


'''Begin training'''

for epoch in range(num_epochs):

    for batch in range(1, num_batches+1):
        start = time.time()

        print("Working on batch #: ", batch)
        optimizer.zero_grad()
        # Batches are numbered from 1, hence the (batch - 1) offset.
        questions_this_batch = trainingQuestionIds[batch_size * (batch - 1):batch_size * batch]
        sequence_ids, dict_sequence_lengths = order_ids(questions_this_batch)

        # Both matrices are [num_tuples x (1 + num_differing_questions) x hidden_size];
        # the main-question matrix repeats each main question to line up
        # element-wise with its candidates.
        candidates_qs_tuples_matrix = construct_qs_matrix(sequence_ids, dict_sequence_lengths, candidates=True)
        main_qs_tuples_matrix = construct_qs_matrix(questions_this_batch, dict_sequence_lengths, candidates=False)

        # Cosine similarity over the hidden axis -> [num_tuples x (1 + num_differing_questions)].
        similarity_matrix = torch.nn.functional.cosine_similarity(candidates_qs_tuples_matrix, main_qs_tuples_matrix, dim=2, eps=1e-08)

        # order_ids places the positive candidate first in every tuple, so the
        # correct "class" is always column 0.
        num_tuples = len(sequence_ids) // (1 + num_differing_questions)
        target = Variable(torch.LongTensor([0] * num_tuples))
        loss_batch = loss_function(similarity_matrix, target)

        loss_batch.backward()

        optimizer.step()

        # .item() extracts the Python scalar; indexing a 0-dim loss with
        # .data[0] raises on current PyTorch releases.
        print("loss_on_batch:", loss_batch.item(), " time_on_batch:", time.time() - start)


Working on batch #:  1
loss_on_batch: 0.1852787733078003  time_on_batch: 72.3687093257904
Working on batch #:  2
loss_on_batch: 0.11037001758813858  time_on_batch: 142.6866376399994
Working on batch #:  3
loss_on_batch: 0.17380912601947784  time_on_batch: 100.35934805870056
Working on batch #:  4
loss_on_batch: 0.17231303453445435  time_on_batch: 77.33512353897095
Working on batch #:  5
loss_on_batch: 0.17525404691696167  time_on_batch: 117.04152965545654
Working on batch #:  6
loss_on_batch: 0.17258015275001526  time_on_batch: 65.9485433101654
Working on batch #:  7
loss_on_batch: 0.1699548065662384  time_on_batch: 64.75343918800354
Working on batch #:  8
loss_on_batch: 0.16704756021499634  time_on_batch: 56.31921410560608
Working on batch #:  9
loss_on_batch: 0.1718629151582718  time_on_batch: 46.24176621437073
Working on batch #:  10
loss_on_batch: 0.17305846512317657  time_on_batch: 68.52842307090759
Working on batch #:  11
loss_on_batch: 0.17407797276973724  time_on_batch: 51.1799

KeyboardInterrupt: 