In [1]:
import sys
sys.path.append("../src/")
from loss_function import loss_function
from loss_function import cs
from torch.autograd import Variable
from evaluate import Evaluation
from lstm import LSTM
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sklearn
import random
import time

In [2]:
import data_reader as dr

In [3]:
# Load the tokenized question corpus through the project's data_reader module.
corpus_path = "../data/text_tokenized.txt.gz"
corpus = dr.read_corpus(corpus_path)

In [4]:
# Load the pruned word vectors (presumably 200-d, per the file name). Judging by the
# unpacked names this yields the embedding tensor and a word -> row-index mapping;
# TODO confirm against data_reader.getEmbeddingTensor.
embedding_path = "../data/vectors_pruned.200.txt.gz"
embedding_tensor, word_to_indx = dr.getEmbeddingTensor(embedding_path)

In [5]:
# Presumably maps each corpus entry's tokens to indices via word_to_indx.
# NOTE(review): kernel_width=1 presumably sizes the padding for a width-1 encoder
# (the LSTM built below); confirm in data_reader.map_corpus.
ids_corpus = dr.map_corpus(corpus, word_to_indx, kernel_width = 1)

In [6]:
# Training annotations (presumably query ids with their positive/negative
# candidate ids; see data_reader.read_annotations).
train_path = "../data/train_random.txt"
train = dr.read_annotations(train_path)

In [7]:
# Training dataset; it is passed to train_model and iterated by run_epoch's DataLoader.
train_ex = dr.create_train_set(ids_corpus, train)

In [37]:
import torch.utils.data

def train_model(train_data, dev_data, model, num_epochs=30, lr=0.001,
                weight_decay=1e-5, checkpoint_prefix="model", checkpoint_offset=30):
    """Train `model` with Adam, saving a checkpoint and evaluating after every epoch.

    Args:
        train_data: dataset passed to run_epoch in training mode.
        dev_data: dataset passed to run_epoch in evaluation mode after each epoch.
        model: the encoder; moved to the GPU with .cuda() before training.
        num_epochs: number of epochs; epochs are numbered 1..num_epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        checkpoint_prefix: file-name prefix for the per-epoch torch.save checkpoints.
        checkpoint_offset: added to the epoch number to build the checkpoint name.
            The default 30 reproduces the existing names model31, model32, ...

    Side effects:
        Prints the train loss, validation metrics and wall time per epoch, and
        writes one whole-model checkpoint per epoch.

    Returns:
        None.
    """
    model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    last_time = time.time()
    for epoch in range(1, num_epochs + 1):
        print("-------------\nEpoch {}:\n".format(epoch))

        # run_epoch switches the model into train / eval mode itself.
        loss = run_epoch(train_data, True, model, optimizer)
        print('Train loss: {:.6f}'.format(loss))
        torch.save(model, "{}{}".format(checkpoint_prefix, checkpoint_offset + epoch))

        (MAP, MRR, P1, P5) = run_epoch(dev_data, False, model, optimizer)
        print('Val MAP: {:.6f}, MRR: {:.6f}, P1: {:.6f}, P5: {:.6f}'.format(MAP, MRR, P1, P5))

        print('This epoch took: {:.6f}'.format(time.time() - last_time))
        last_time = time.time()

        
def run_epoch(data, is_training, model, optimizer):
    '''
    Run one pass over `data`, either training or evaluating `model`.

    Args:
        data: dataset whose items are dicts with the keys 'pid_title',
            'pid_body', 'rest_title', 'rest_body', the matching '*_pad' masks
            and, for evaluation, 'labels'.
        is_training: if True, take one optimizer step per batch; otherwise
            only score the candidates.
        model: encoder applied separately to each title/body tensor; inputs
            are moved to the GPU with .cuda().
        optimizer: used only when is_training is True (None is fine for eval).

    Returns:
        Training: the mean of the per-batch losses.
        Evaluation: (MAP, MRR, P@1, P@5) computed by Evaluation, each x100.

    NOTE: calls the module-level `normalize` helper, which is defined in a
    later cell of this notebook; that cell must have run before this is called.
    '''
    data_loader = torch.utils.data.DataLoader(
        data,
        batch_size=40,
        shuffle=True,
        num_workers=4,
        drop_last=False)

    # Training: per-batch losses. Evaluation: the per-query records returned by
    # dr.convert, which are handed to Evaluation at the end.
    losses = []

    if is_training:
        model.train()
    else:
        model.eval()

    for batch in data_loader:
        # The query gets a singleton "samples" axis (dim 1) so it lines up with
        # the candidate tensors.
        pid_title = torch.unsqueeze(Variable(batch['pid_title']), 1).cuda()
        pid_body = torch.unsqueeze(Variable(batch['pid_body']), 1).cuda()
        rest_title = Variable(batch['rest_title']).cuda()
        rest_body = Variable(batch['rest_body']).cuda()

        pid_title_pad = torch.unsqueeze(Variable(batch['pid_title_pad']), 1).cuda()
        pid_body_pad = torch.unsqueeze(Variable(batch['pid_body_pad']), 1).cuda()
        rest_title_pad = Variable(batch['rest_title_pad']).cuda()
        rest_body_pad = Variable(batch['rest_body_pad']).cuda()

        if is_training:
            optimizer.zero_grad()

        # Encode, L2-normalize every time step along the feature axis (dim 2),
        # then mean-pool over the non-padded time steps:
        #   batch_size x samples x output_size x (len - kernel + 1)
        #     -> batch_size x samples x output_size
        pt = _masked_mean(normalize(model(pid_title), 2), pid_title_pad)
        pb = _masked_mean(normalize(model(pid_body), 2), pid_body_pad)
        rt = _masked_mean(normalize(model(rest_title), 2), rest_title_pad)
        rb = _masked_mean(normalize(model(rest_body), 2), rest_body_pad)

        # Average the title and body representations.
        pid_tensor = (pt + pb)/2
        rest_tensor = (rt + rb)/2

        if is_training:
            # Inverted dropout: kept activations are scaled by 1/(1-p) while
            # training, so nothing needs rescaling at evaluation time.
            pid_tensor = F.dropout(pid_tensor, p = 0.2, training = True)
            rest_tensor = F.dropout(rest_tensor, p = 0.2, training = True)

            loss = loss_function(pid_tensor, rest_tensor, margin = 1.0)
            loss.backward()
            # view(-1)[0] reads the scalar from both 1-element (old PyTorch) and
            # 0-dim (newer PyTorch) loss tensors.
            losses.append(float(loss.data.cpu().view(-1)[0]))
            optimizer.step()
        else:
            # Cosine similarity of the query with each of its candidates.
            expanded = pid_tensor.expand_as(rest_tensor)
            similarity = cs(expanded, rest_tensor, dim=2).squeeze(2)
            similarity = similarity.data.cpu().numpy()
            labels = batch['labels'].numpy()
            losses.extend(dr.convert(similarity, labels))

    # Calculate epoch level scores
    if is_training:
        avg_loss = np.mean(losses)
        return avg_loss
    else:
        e = Evaluation(losses)
        MAP = e.MAP()*100
        MRR = e.MRR()*100
        P1 = e.Precision(1)*100
        P5 = e.Precision(5)*100
        return (MAP, MRR, P1, P5)


def _masked_mean(encoded, mask):
    '''
    Mean-pool `encoded` over its last (time) axis, weighting each step by `mask`.

    encoded is batch_size x samples x output_size x len and mask is
    batch_size x samples x len (presumably 1 for real tokens, 0 for padding).
    Returns batch_size x samples x output_size. The divisor is clamped to at
    least 1, so all-padding rows pool to zeros instead of NaN.

    Works whether or not torch reductions keep the reduced dimension
    (PyTorch < 0.2 kept it by default; later releases drop it).
    '''
    mask_ex = torch.unsqueeze(mask, 2).expand_as(encoded)
    pooled = torch.sum(encoded * mask_ex, dim = 3)
    if len(pooled.size()) == len(encoded.size()):
        pooled = torch.squeeze(pooled, dim = 3)
    counts = torch.sum(mask, dim = 2)
    if len(counts.size()) < len(pooled.size()):
        counts = torch.unsqueeze(counts, 2)
    return pooled / counts.clamp(min = 1).expand_as(pooled)

In [38]:
# Build the encoder on top of the loaded embeddings.
# NOTE(review): the argument meanings are not visible here. 310 is presumably the
# total hidden size (the parameter printout further down shows two LSTM weight sets
# with hidden size 155 each) and 0.2 presumably a dropout rate; confirm in lstm.LSTM.
model = LSTM(310, embedding_tensor, 0.2)

In [39]:
# Evaluation set used for the per-epoch metrics printed during training.
# NOTE(review): despite the `val` name this is loaded from test.txt, while dev.txt is
# loaded later as `real_val`; confirm this split assignment is intended.
# K_neg=-1 / prune_pos_cnt=-1 presumably disable negative sampling and positive
# pruning for evaluation; verify against data_reader.read_annotations.
val_path = "../data/test.txt"
val = dr.read_annotations(val_path, K_neg = -1, prune_pos_cnt = -1)
val_ex = dr.create_dev_set(ids_corpus, val)

In [None]:
# Train the model. train_model saves one checkpoint per epoch ("model" + str(30 + epoch))
# and prints metrics on val_ex. It has no return value, so z is None afterwards.
z = train_model(train_ex, val_ex, model)

In [27]:
# Score one saved checkpoint on the dev split (dev.txt).
real_val_path = "../data/dev.txt"
real_val = dr.read_annotations(real_val_path, K_neg=-1, prune_pos_cnt=-1)
real_val_ex = dr.create_dev_set(ids_corpus, real_val)

# "model40" is the checkpoint train_model writes after epoch 10 (name = 30 + epoch).
# Module.cuda() returns the module itself, so the load and the GPU move can be chained.
model = torch.load("model40").cuda()
z = run_epoch(real_val_ex, False, model, None)

In [28]:
# run_epoch's eval-mode result: (MAP, MRR, P@1, P@5), each multiplied by 100.
print(z)

(51.1562037379168, 62.27780000889243, 44.97354497354497, 40.95238095238096)


In [14]:
# Evaluate the current `model` on val_ex (the split loaded from test.txt).
# NOTE(review): this cell ran as In [14], earlier than the cells above it, so `model`
# here was not necessarily the model40 checkpoint; the result depends on run order.
model = model.cuda()
z = run_epoch(val_ex, False, model, None)

In [15]:
# (MAP, MRR, P@1, P@5) on val_ex, each multiplied by 100.
print(z)

(52.91403092863726, 67.84170771599614, 53.76344086021505, 41.075268817204325)


In [33]:
# Count model parameters, skipping the embedding matrix, which is recognized by
# its first dimension.
# NOTE(review): 100408 is the row count of the first parameter in the printout below,
# presumably the vocabulary size; update it if the vocabulary changes.
EMBEDDING_ROWS = 100408
t = 0
for i in model.parameters():
    print(i.size())
    shape = list(i.size())
    if len(shape) >= 2 and shape[0] == EMBEDDING_ROWS:
        continue
    # np.prod gives the full element count for parameters of any rank
    # (1-D biases, 2-D weights, and 3-D+ kernels alike).
    t += int(np.prod(shape))
print(t)

torch.Size([100408, 201])
torch.Size([620, 201])
torch.Size([620, 155])
torch.Size([620])
torch.Size([620])
torch.Size([620, 201])
torch.Size([620, 155])
torch.Size([620])
torch.Size([620])
443920


In [9]:
def normalize(x, dim):
    """Scale `x` so that every slice along `dim` has unit L2 norm.

    Norms are clamped to at least 1e-8, so all-zero slices stay zero instead of
    producing NaN.

    Works on old PyTorch (< 0.2, where reductions kept the reduced dimension by
    default) and on newer releases (where it is dropped). Without restoring the
    reduced axis, expand_as would broadcast the norms along the wrong axis.
    """
    l2 = torch.norm(x, 2, dim)
    if len(l2.size()) < len(x.size()):
        # Newer PyTorch dropped the reduced axis: put it back at `dim`.
        l2 = torch.unsqueeze(l2, dim)
    return x / l2.clamp(min = 1e-8).expand_as(x)

def mean_pool_pad(x, mask):
    """Mean-pool `x` over its last (time) axis, weighting each step by `mask`.

    Expected layout (the same masked mean pooling that run_epoch performs):
        x:    batch_size x samples x output_size x len
        mask: batch_size x samples x len  (presumably 1 = real token, 0 = padding)
    Returns a batch_size x samples x output_size tensor holding
    sum_t(x[..., t] * mask[..., t]) / sum_t(mask[..., t]). The divisor is clamped to
    at least 1, so rows that are entirely padding pool to zeros instead of NaN.

    Works whether or not torch reductions keep the reduced dimension (PyTorch < 0.2
    kept it by default; later releases drop it).
    """
    mask_ex = torch.unsqueeze(mask, 2).expand_as(x)
    pooled = torch.sum(x * mask_ex, dim=3)
    if len(pooled.size()) == len(x.size()):
        pooled = torch.squeeze(pooled, dim=3)
    counts = torch.sum(mask, dim=2)
    if len(counts.size()) < len(pooled.size()):
        counts = torch.unsqueeze(counts, 2)
    return pooled / counts.clamp(min=1).expand_as(pooled)

In [30]:
# Score checkpoints model31..model50 on the dev split (dev.txt).
# The dev set does not depend on the checkpoint, so it is loaded once, outside the
# loop, instead of being re-read from disk on every iteration (assumes
# read_annotations / create_dev_set are deterministic; TODO confirm).
real_val_path = "../data/dev.txt"
real_val = dr.read_annotations(real_val_path, K_neg = -1, prune_pos_cnt = -1)
real_val_ex = dr.create_dev_set(ids_corpus, real_val)

for i in range(20):
    # NOTE(review): i == 8 (checkpoint model39) is skipped; the reason is not
    # recorded here, presumably that checkpoint is missing or unreadable.
    if i == 8:
        continue
    model = torch.load("model" + str(i + 31))
    model = model.cuda()
    z = run_epoch(real_val_ex, False, model, None)
    print((i, z))

(0, (42.37619878987305, 53.283766443504035, 35.978835978835974, 34.70899470899471))
(1, (45.22990586203595, 55.329089061889235, 38.095238095238095, 35.238095238095255))
(2, (48.85533094194322, 59.45116287903402, 42.32804232804233, 38.20105820105824))
(3, (49.706145216776115, 60.57115963339855, 43.91534391534391, 38.73015873015875))
(4, (48.99226053091319, 57.970810054143385, 38.095238095238095, 38.41269841269844))
(5, (50.263996342773574, 58.759211834608664, 40.74074074074074, 39.78835978835982))
(6, (50.99273366028096, 60.007872731102374, 42.857142857142854, 40.31746031746036))
(7, (49.590223722830004, 58.149638156948114, 38.095238095238095, 39.894179894179906))
(9, (51.15620373791676, 62.277800008892456, 44.97354497354497, 40.952380952381006))
(10, (51.40831691010621, 63.14382456772746, 46.03174603174603, 40.84656084656088))
(11, (50.81305455611835, 60.95742587806079, 42.857142857142854, 40.95238095238097))
(12, (49.615610380874244, 58.85543368417221, 40.21164021164021, 40.9523809523