[View in Colaboratory](https://colab.research.google.com/github/hyerim1048/100_pytorch/blob/master/07-sentrep/text_retrieval_pytorch.ipynb)

In [0]:
# Notebook setup: install the Python packages and clone nn4nlp-code, whose
# data/parallel/ directory supplies the train/dev files read below.
!pip install torch torchtext
!git clone https://github.com/neubig/nn4nlp-code.git

fatal: destination path 'nn4nlp-code' already exists and is not an empty directory.


In [0]:
from __future__ import print_function
import time

from collections import defaultdict
import random
import math
import sys
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [0]:
# format of files: each line is "word1 word2 ..." aligned line-by-line
train_src_file = "nn4nlp-code/data/parallel/train.ja"
train_trg_file = "nn4nlp-code/data/parallel/train.en"
dev_src_file = "nn4nlp-code/data/parallel/dev.ja"
dev_trg_file = "nn4nlp-code/data/parallel/dev.en"

# Growing vocabularies: looking up an unseen word assigns it the next free id.
w2i_src = defaultdict(lambda: len(w2i_src))
w2i_trg = defaultdict(lambda: len(w2i_trg))
# Reserve id 0 for padding. Batches are right-padded with 0 and the encoder's
# embedding uses padding_idx=0; without this reservation the first word seen
# in training would share the padding id and keep a frozen all-zero embedding.
pad_src = w2i_src["<pad>"]
pad_trg = w2i_trg["<pad>"]

def read(fname_src, fname_trg, vocab_src=None, vocab_trg=None):
    """
    Read parallel files where each line lines up.

    Yields one (sent_src, sent_trg) pair of word-id lists per line pair;
    iteration stops at the end of the shorter file.

    Words are mapped through vocab_src / vocab_trg. When omitted they default
    to the module-level w2i_src / w2i_trg, resolved when iteration starts, so
    whichever vocabulary is bound at that moment (growing or frozen) is used.

    Both files are decoded as UTF-8 explicitly instead of with the
    locale-dependent default encoding, which would misread the (presumably
    Japanese) .ja side on non-UTF-8 platforms.
    """
    if vocab_src is None:
        vocab_src = w2i_src
    if vocab_trg is None:
        vocab_trg = w2i_trg
    with open(fname_src, "r", encoding="utf-8") as f_src, \
            open(fname_trg, "r", encoding="utf-8") as f_trg:
        for line_src, line_trg in zip(f_src, f_trg):
            sent_src = [vocab_src[x] for x in line_src.strip().split()]
            sent_trg = [vocab_trg[x] for x in line_trg.strip().split()]
            yield (sent_src, sent_trg)

# Read the data
# Reading train grows the vocabularies: read() resolves the module-level
# w2i_src / w2i_trg, whose factory assigns each unseen word the next id.
train = list(read(train_src_file, train_trg_file))
# Freeze the vocabularies: add "<unk>", then rebind each name to a defaultdict
# (seeded with the existing entries) whose factory returns the <unk> id, so
# words unseen in training map to <unk>.
unk_src = w2i_src["<unk>"]
w2i_src = defaultdict(lambda: unk_src, w2i_src)
unk_trg = w2i_trg["<unk>"]
w2i_trg = defaultdict(lambda: unk_trg, w2i_trg)
# Sizes are taken BEFORE reading dev: a defaultdict lookup inserts the missing
# key, so reading dev adds entries (all mapped to the <unk> id) and len() would
# overcount the number of distinct ids afterwards.
nwords_src = len(w2i_src)
nwords_trg = len(w2i_trg)
# Dev is encoded with the frozen vocabularies bound above.
dev = list(read(dev_src_file, dev_trg_file))

In [0]:
# Split the (src, trg) pairs into parallel tuples of sentences.
train_src, train_trg = zip(*train)
dev_src, dev_trg = zip(*dev)
# Sentence lengths, materialized as lists: in Python 3 map() returns a one-shot
# iterator that would be empty after its first use.
train_src_len, train_trg_len = list(map(len, train_src)), list(map(len, train_trg))
dev_src_len, dev_trg_len = list(map(len, dev_src)), list(map(len, dev_trg))

In [0]:
# Training sentence lengths as reusable lists (a bare map() is a one-shot iterator).
train_src_len, train_trg_len = [len(sent) for sent in train_src], [len(sent) for sent in train_trg]

In [0]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

In [0]:
class ParallelCorpus(Dataset):
  """Map-style dataset over sentence pairs of word ids.

  Each item of `data` is indexable as (source ids, target ids); indexing the
  dataset returns the pair as two 1-D LongTensors.
  """

  def __init__(self, data):
    self.data = data

  def __len__(self):
    return len(self.data)

  def __getitem__(self, ix):
    pair = self.data[ix]
    src_tensor = torch.LongTensor(pair[0])
    trg_tensor = torch.LongTensor(pair[1])
    return src_tensor, trg_tensor

In [0]:
def my_collate_fn(batch):
  """Collate (src, trg) 1-D tensor pairs into right-padded batch tensors.

  Returns (src, trg, src_len, trg_len): src and trg are (batch, max_len)
  tensors right-padded with 0, and the two LongTensors hold each sentence's
  unpadded length.
  """
  sources, targets = zip(*batch)
  src_lengths = torch.LongTensor([len(seq) for seq in sources])
  trg_lengths = torch.LongTensor([len(seq) for seq in targets])
  padded_src = pad_sequence(sources, batch_first=True, padding_value=0)
  padded_trg = pad_sequence(targets, batch_first=True, padding_value=0)
  return padded_src, padded_trg, src_lengths, trg_lengths

In [0]:
# Model parameters: word-embedding size, LSTM hidden size (per direction),
# and number of sentence pairs per minibatch.
EMBED_SIZE, HIDDEN_SIZE, BATCH_SIZE = 64, 128, 16

In [0]:
# Wrap the encoded pairs as datasets; only the training batches are shuffled,
# so dev batches keep file order.
train_corpus = ParallelCorpus(train)
dev_corpus = ParallelCorpus(dev)

train_loader = DataLoader(train_corpus, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=1, collate_fn=my_collate_fn)
dev_loader = DataLoader(dev_corpus, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=1, collate_fn=my_collate_fn)

In [0]:
# Sanity check: show the padded shape of the first dev target batch.
first_batch = next(iter(dev_loader), None)
if first_batch is not None:
  print(first_batch[1].shape)

torch.Size([16, 22])


In [0]:
class SentRep(nn.Module):
  """Bidirectional-LSTM sentence encoder.

  Maps a right-padded batch of word-id sequences to one vector per sentence:
  the forward direction's final state concatenated with the backward
  direction's final state, shape (batch, 2 * hid_dim).
  """
  def __init__(self, vocab_size, emb_dim, hid_dim, batch_size):
    super(SentRep, self).__init__()

    self.hid_dim = hid_dim
    # Stored for interface compatibility; forward() sizes its state from the input.
    self.batch_size = batch_size

    # Id 0 is treated as padding (its embedding stays zero and untrained).
    self.embeddings = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
    self.lstm = nn.LSTM(emb_dim, hid_dim, bidirectional=True, batch_first=True)

  def init_hidden(self, bs):
    """Return zero (h0, c0), each (2, bs, hid_dim), on this module's device.

    The leading 2 is num_layers * num_directions for this one-layer
    bidirectional LSTM. Device and dtype follow the parameters rather than
    assuming CUDA.
    """
    weight = self.embeddings.weight
    shape = (2, bs, self.hid_dim)
    return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
            torch.zeros(shape, device=weight.device, dtype=weight.dtype))

  def forward(self, x, x_len):
    """Encode a batch of sentences.

    Args:
      x: (batch, max_len) LongTensor of word ids, right-padded with 0.
      x_len: (batch,) true length of each sentence (any device).
    Returns:
      (batch, 2 * hid_dim) tensor of sentence representations.
    """
    h0 = self.init_hidden(x_len.shape[0])
    encoded = self.embeddings(x)
    # Pack so BOTH directions stop at each sentence's real end. On a padded
    # run, output[i, len-1] holds a correct forward state, but the backward
    # direction at that position has only read the padding and the last word.
    # pack_padded_sequence needs CPU int64 lengths >= 1 (an empty sentence is
    # encoded from a single pad token).
    lengths = x_len.detach().to("cpu", torch.int64).clamp(min=1)
    packed = pack_padded_sequence(encoded, lengths, batch_first=True,
                                  enforce_sorted=False)
    _, (h_n, _) = self.lstm(packed, h0)
    # h_n: (2, batch, hid_dim), restored to input order; [0] is the forward
    # final state and [1] the backward final state.
    return torch.cat([h_n[0], h_n[1]], dim=1)

In [0]:
def calc_score(x, y):
  """Margin (hinge) ranking loss between paired source/target encodings.

  Row i of x is paired with row i of y. With m = x @ y.T, each off-diagonal
  term is max(0, 1 + m[i, j] - m[j, j]), i.e. target j should score at least
  1 higher with its own source j than with source i. Diagonal terms are
  excluded, and the sum is divided by the batch size.
  """
  m = x @ y.transpose(0, 1)
  # m.diag() broadcasts along rows, so column j is compared against m[j, j].
  score = (1 + m - m.diag()).clamp(min=0)
  # Drop the i == j terms (each is exactly 1). This uses an out-of-place
  # mask on m's own device instead of a CUDA-only zero tensor and a Python
  # loop of in-place writes.
  eye = torch.eye(m.shape[0], dtype=torch.bool, device=m.device)
  score = score.masked_fill(eye, 0.)
  return score.sum() / m.shape[0]

In [0]:
def index_corpus(loader, src_model=None, trg_model=None):
  """Encode every batch of `loader` into a source and a target NumPy matrix.

  Args:
    loader: iterable of (src, trg, src_len, trg_len) batches.
    src_model, trg_model: encoders to apply; default to the module-level
      src_reps / trg_reps.
  Returns:
    (src_matrix, trg_matrix): encoder outputs concatenated along dim 0 in
    loader order and converted to NumPy arrays.
  """
  if src_model is None:
    src_model = src_reps
  if trg_model is None:
    trg_model = trg_reps
  # Move inputs to wherever each encoder's parameters live instead of assuming CUDA.
  src_device = next(src_model.parameters()).device
  trg_device = next(trg_model.parameters()).device
  souts, touts = [], []
  # Inference only: no_grad avoids keeping an autograd graph for the whole corpus.
  with torch.no_grad():
    for s, t, sl, tl in loader:
      souts.append(src_model(s.to(src_device), sl.to(src_device)))
      touts.append(trg_model(t.to(trg_device), tl.to(trg_device)))

  return torch.cat(souts, 0).cpu().numpy(), torch.cat(touts, 0).cpu().numpy()

def retrieve(src, db_mtx):
    """Rank the rows of db_mtx by dot-product similarity to the query src.

    Returns (ranks, scores): scores[j] is the dot product of db_mtx[j] with
    src, and ranks orders row indices from highest to lowest score.
    """
    similarity = np.dot(db_mtx, src)
    order = np.argsort(np.negative(similarity))
    return order, similarity

In [0]:
# Use the GPU when one is available instead of requiring CUDA unconditionally.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
src_reps = SentRep(nwords_src, EMBED_SIZE, HIDDEN_SIZE, BATCH_SIZE).to(device)
trg_reps = SentRep(nwords_trg, EMBED_SIZE, HIDDEN_SIZE, BATCH_SIZE).to(device)

# A single Adam optimizer over both encoders, which are trained jointly.
trainer = torch.optim.Adam(list(src_reps.parameters()) + list(trg_reps.parameters()), lr=1e-3)

In [0]:
# Jointly train both encoders, then report dev retrieval recall after each epoch.
# Inputs are sent to whatever device the encoders live on.
train_device = next(src_reps.parameters()).device
for epoch_i in range(100):
  src_reps.train()
  trg_reps.train()
  total_loss = 0.
  n_batches = 0
  for s, t, sl, tl in train_loader:
    sout = src_reps(s.to(train_device), sl.to(train_device))
    tout = trg_reps(t.to(train_device), tl.to(train_device))

    score = calc_score(sout, tout)
    total_loss += score.item()
    n_batches += 1

    trainer.zero_grad()
    score.backward()
    trainer.step()

  # Mean loss per batch: divide by the batch COUNT (guarded against an empty loader).
  print("epoch %r | loss: %f" % (epoch_i, total_loss / max(n_batches, 1)))

  # Perform evaluation: dev_loader does not shuffle, so row i of both matrices
  # is dev pair i. For each source i, rank all targets and count a hit when
  # target i appears within the top 1 / 5 / 10.
  src_reps.eval()
  trg_reps.eval()
  rec_at_1, rec_at_5, rec_at_10 = 0, 0, 0
  with torch.no_grad():
    src_mtx, trg_mtx = index_corpus(dev_loader)
  for i in range(trg_mtx.shape[0]):
    ranks, scores = retrieve(src_mtx[i], trg_mtx)
    if ranks[0] == i: rec_at_1 += 1
    if i in ranks[:5]: rec_at_5 += 1
    if i in ranks[:10]: rec_at_10 += 1
  n_dev = len(dev)
  print("epoch %r: dev recall@1=%.2f%% recall@5=%.2f%% recall@10=%.2f%%" % (epoch_i, rec_at_1/n_dev*100, rec_at_5/n_dev*100, rec_at_10/n_dev*100))

tensor([ 10,  17,  22,  ...,  12,  13,  11])