In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm_notebook
import torch.nn as nn
# GPU index used throughout the notebook. set_device makes it the default target of
# every bare `.cuda()` call; collate_fn and the DataParallel setup also pass it explicitly.
device = 0
torch.cuda.set_device(device)

In [None]:
#Data preprocessing
import csv
import pickle
import os
import json

# Directory holding the WikiPassageQA split .tsv files, document_passages.json and
# the .pkl caches (vocabulary, token ids) written by later cells.
data_path = 'data/WikiPassageQA'

def get_data(data_type):
    """Read one WikiPassageQA split from ``<data_path>/<data_type>.tsv``.

    Returns a dict of three parallel lists:
      'q'     - question strings (``Question`` column)
      'pid'   - document ids (``DocumentID`` column), used to look up passages
      'label' - per question, the list of relevant passage-id strings obtained by
                splitting the ``RelevantPassages`` column on commas
    """
    split_file = os.path.join(data_path, data_type + '.tsv')
    questions, doc_ids, relevant = [], [], []
    with open(split_file, 'r', encoding='utf-8') as fp:
        for row in tqdm_notebook(csv.DictReader(fp, dialect='excel-tab')):
            questions.append(row['Question'])
            doc_ids.append(row['DocumentID'])
            relevant.append(row['RelevantPassages'].split(','))
    return {'q': questions, 'pid': doc_ids, 'label': relevant}

def get_document():
    """Load ``document_passages.json`` and order each document's passages.

    The JSON is read as {document id: {passage id: passage}} where passage ids are
    int-convertible strings. The result maps each document id to a list of its
    passages sorted by the integer value of the passage id.
    """
    with open(os.path.join(data_path, 'document_passages.json'), 'r', encoding='utf-8') as fp:
        raw = json.load(fp)
    return {
        doc_id: [passages[pid] for pid in sorted(passages, key=int)]
        for doc_id, passages in raw.items()
    }
    
def doc_to_list(doc):
    """Flatten ``{doc_id: [passage, ...]}`` into a single list, in key order."""
    return [passage for key in doc.keys() for passage in doc[key]]

class WikiPassageQADataset(Dataset):
    """One split (train/dev/test) of WikiPassageQA as a map-style Dataset.

    Each item is ``(question, passages_of_its_document, relevant_passage_ids)``.
    Before :meth:`encoding` is called, questions and passages are raw strings.
    Afterwards they are whatever ``encoder.encode`` returns; for the SpacyEncoder
    used in this notebook that is presumably a 1-D token-id tensor (TODO confirm).
    """
    # The most compact version: per-split questions/labels, passages looked up by document id.
    def __init__(self, data_type, passage):
        """
        Args:
            data_type: split name; questions are read from ``<data_path>/<data_type>.tsv``.
            passage: dict mapping document id -> list of passages (as built by
                ``get_document``). It is stored by reference and never mutated.
        """
        dic = get_data(data_type)
        self.q = dic['q']
        self.p = passage
        self.pid = dic['pid']
        self.label = dic['label']
        self.len = len(self.q)
        self.data_type = data_type

    def __getitem__(self, index):
        return self.q[index], self.p[self.pid[index]], self.label[index]

    @staticmethod
    def _cache_path(name):
        # All token caches live next to the raw data files.
        return os.path.join(data_path, name)

    def encoding(self, encoder):
        """Replace raw questions/passages with ``encoder.encode`` outputs.

        Results are cached as pickles in ``data_path``:
        ``<data_type>_token.pkl`` for this split's questions and
        ``passage_token.pkl`` for the passages (shared by all splits).

        NOTE: caches are keyed only by file name, not by the encoder or vocabulary.
        Delete the .pkl files after rebuilding the vocabulary. Only load caches you
        wrote yourself, because pickle.load can execute arbitrary code.
        """
        q_cache = self._cache_path(self.data_type + '_token.pkl')
        if os.path.exists(q_cache):
            with open(q_cache, 'rb') as fp:
                self.q = pickle.load(fp)
        else:
            self.q = [encoder.encode(qi) for qi in self.q]
            with open(q_cache, 'wb') as fp:
                pickle.dump(self.q, fp, pickle.HIGHEST_PROTOCOL)

        p_cache = self._cache_path('passage_token.pkl')
        if os.path.exists(p_cache):
            with open(p_cache, 'rb') as fp:
                self.p = pickle.load(fp)
        else:
            # Build a new dict instead of encoding in place. The passage dict is
            # shared with the caller and the other splits, so mutating it would
            # silently change their contents too.
            self.p = {
                doc_id: [encoder.encode(passage) for passage in passages]
                for doc_id, passages in tqdm_notebook(self.p.items())
            }
            with open(p_cache, 'wb') as fp:
                pickle.dump(self.p, fp, pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return self.len


In [None]:
# Passages of every document, shared by all three splits.
p = get_document()
train_data, dev_data, test_data = (
    WikiPassageQADataset(data_type=split, passage=p) for split in ('train', 'dev', 'test')
)

# Vocabulary: reuse the pickled SpacyEncoder if present; otherwise fit it on every
# question of every split plus every passage, then cache it.
vocab_file = os.path.join(data_path, 'vocabulary.pkl')
if not os.path.exists(vocab_file):
    sentence_corpus = train_data.q + dev_data.q + test_data.q + doc_to_list(p)
    from torchnlp.text_encoders import SpacyEncoder
    sentence_encoder = SpacyEncoder(sentence_corpus)
    with open(vocab_file, 'wb') as fp:
        pickle.dump(sentence_encoder, fp, pickle.HIGHEST_PROTOCOL)
else:
    with open(vocab_file, 'rb') as fp:
        sentence_encoder = pickle.load(fp)

# 300-d embedding table, initialised row by row from GloVe 840B vectors.
embed = nn.Embedding(sentence_encoder.vocab_size, 300)
from torchnlp.word_to_vector import GloVe
vectors = GloVe('840B')
for i, token in tqdm_notebook(enumerate(sentence_encoder.vocab)):
    embed.weight.data[i] = vectors[token]

# Tokenise every split (results are cached on disk by WikiPassageQADataset.encoding).
for split_data in (train_data, dev_data, test_data):
    split_data.encoding(sentence_encoder)

In [None]:
import math

def idf(docFreq, docCount):
    """BM25-style inverse document frequency (always non-negative).

    ``log(1 + (N - n + 0.5) / (n + 0.5))`` for ``N = docCount`` documents, of which
    ``n = docFreq`` contain the term.
    """
    numerator = docCount - docFreq + 0.5
    denominator = docFreq + 0.5
    return math.log(1 + numerator / denominator)


def tfNorm(termFreq, avgFieldLength, fieldLength, k1=1.2, b=0.75):
    """BM25 term-frequency saturation with document-length normalisation.

    ``k1`` controls saturation and ``b`` how strongly length relative to
    ``avgFieldLength`` is penalised.
    """
    length_norm = 1 - b + b * fieldLength / avgFieldLength
    return termFreq * (k1 + 1) / (termFreq + k1 * length_norm)


def _get_bm25(docFreq, docCount, termFreq, avgFieldLength, fieldLength, k1=1.2, b=0.75):
    """BM25 contribution of one term to one document: idf times saturated tf."""
    weight = idf(docFreq, docCount)
    saturation = tfNorm(termFreq, avgFieldLength, fieldLength, k1=k1, b=b)
    return weight * saturation


def doc_freq(term, docs):
    """Number of documents in ``docs`` that contain ``term`` at least once."""
    return sum(1 for doc in docs if term_freq(term, doc))

def doc_count(docs):
    """Number of documents in the collection ``docs``."""
    total = len(docs)
    return total


def avg_field_length(docs):
    """Mean document length over ``docs``.

    Raises ZeroDivisionError for an empty collection.
    """
    count = doc_count(docs)
    total = sum(field_length(doc) for doc in docs)
    return total / count


def field_length(doc):
    """Length of one document, i.e. its number of elements (tokens)."""
    length = len(doc)
    return length


def term_freq(term, doc):
    """Number of occurrences of ``term`` in ``doc`` (any iterable of tokens)."""
    tokens = list(doc)
    return tokens.count(term)

In [None]:
# Passages are truncated to this many tokens in collate_fn;
# 99.9% of all passages have fewer than 400 words.
max_sent_length = 400
# Question tokens that collate_fn leaves with an IDF weight of 0: the space character
# and ASCII punctuation (https://www.owasp.org/index.php/Password_special_characters) ...
special_characters = list(' !"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~')
# ... plus the question words "what", "how" and "why". Broken down by first word,
# they account for 43.8%, 36.6% and 14.0% of the collection's questions respectively.
special_characters.extend(['what', 'how', 'why'])

def collate_fn(batch):
    """Collate ``WikiPassageQADataset`` items into padded GPU tensors.

    Every question is paired with every passage slot of its document. A batch of
    B items whose largest document has D passages therefore yields B*D rows
    (row ``i*D + j`` = question i vs. passage j).

    Args:
        batch: list of ``(question, passages, relevant_ids)`` items. ``question``
            is assumed to be a 1-D token-id tensor, ``passages`` a list of 1-D
            token-id tensors, and ``relevant_ids`` a list of passage-index strings
            (produced by WikiPassageQADataset after ``encoding`` — TODO confirm).

    Returns:
        dict of tensors on GPU ``device``:
          'q'     int64   (B*D, max_q)  question ids, repeated for each passage slot
          'p'     int64   (B*D, max_p)  passage ids, truncated to ``max_sent_length``
          'label' int64   (B*D, 1)      1 for relevant passages, else 0
          'q_idf' float32 (B*D, max_q)  IDF of each question token over the passages
                  of its own document; 0 for padding and ``special_characters`` tokens
    """
    q, p, _label = zip(*batch)

    batch_size = len(q)
    max_q = max((len(qi) for qi in q), default=0)
    max_d = max((len(pi) for pi in p), default=0)
    max_p = min(max((len(pij) for pi in p for pij in pi), default=0), max_sent_length)

    q_input = np.zeros((batch_size, max_d, max_q), dtype=np.int64)
    p_input = np.zeros((batch_size * max_d, max_p), dtype=np.int64)
    label = np.zeros((batch_size * max_d, 1), dtype=np.int64)
    q_idf = np.zeros((batch_size, 1, max_q), dtype=np.float32)

    # The same question ids fill every passage slot of its document (broadcast over D).
    for i, qi in enumerate(q):
        q_input[i, :, :len(qi)] = np.asarray(qi)
    q_input = q_input.reshape(-1, max_q)

    for i, pi in enumerate(p):
        for j, pij in enumerate(pi):
            tokens = pij[:max_p]
            p_input[i * max_d + j, :len(tokens)] = np.asarray(tokens)

    # IDF of each question token, computed over the passages of its own document.
    for i, (qi, pi) in enumerate(zip(q, p)):
        doc_cnt = doc_count(pi)
        for k, w in enumerate(qi):
            t = sentence_encoder.decode(w.view(1, 1)).lower()
            if t in special_characters:
                continue
            q_idf[i, 0, k] = idf(doc_freq(w, pi), doc_cnt)
    # Repeat once, after every question's IDF row is filled. Repeating inside the
    # loop would grow axis 1 to max_d**B and break the reshape for B > 1.
    q_idf = np.repeat(q_idf, max_d, axis=1).reshape(-1, max_q)

    for i, li in enumerate(_label):
        for lij in li:
            label[i * max_d + int(lij), 0] = 1

    collated = {'q': q_input, 'p': p_input, 'label': label, 'q_idf': q_idf}
    # `device` is also the default CUDA device (torch.cuda.set_device), so a single
    # transfer per tensor suffices.
    return {key: torch.from_numpy(value).cuda(device) for key, value in collated.items()}

In [None]:
# One question per batch. Training drops a trailing partial batch;
# evaluation keeps every question in order.
batch_size = 1
eval_batch_size = 1
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                          collate_fn=collate_fn, drop_last=True)
dev_loader, test_loader = (
    DataLoader(split, batch_size=eval_batch_size, shuffle=False, collate_fn=collate_fn)
    for split in (dev_data, test_data)
)

In [None]:
def criterion(output, label):
    """Margin-weighted binary cross-entropy over each question's passages.

    Args:
        output: (B, D) raw relevance scores (logits), one row per question.
        label:  (B, D) 0/1 relevance labels aligned with ``output``.

    For each row, ``distance`` is the mean score of the relevant passages minus the
    highest score among the non-relevant ones. The row's element-wise BCE is weighted
    by the hinge ``max(0, 1 - distance)``, so a row whose relevant passages already
    beat every non-relevant one by at least 1 contributes nothing.

    Returns:
        Scalar float64 tensor: sum of the weighted BCE over all elements.
    """
    output = output.double()
    label = label.double()

    # clamp guards rows without any relevant passage (0 / 0 would give NaN).
    num_of_r = label.sum(dim=1).clamp(min=1.)
    mean_relevant = (output * label).sum(dim=1) / num_of_r
    # Mask relevant positions with -inf instead of zeroing them. Zeros would make the
    # "best non-relevant" score >= 0 whenever all non-relevant scores are negative.
    max_non_relevant = output.masked_fill(label.bool(), float('-inf')).max(dim=1)[0]
    distance = mean_relevant - max_non_relevant

    # Element-wise hinge. The builtin max(0., tensor) only works when B == 1.
    hinge = torch.clamp(1. - distance, min=0.).view(-1, 1)
    # With-logits BCE is equivalent to sigmoid + BCELoss but numerically stabler.
    bce_loss = nn.functional.binary_cross_entropy_with_logits(output, label, reduction='none')
    return (hinge * bce_loss).sum()

In [None]:
import numpy as np
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import make_scorer


def dcg_score(y_true, y_score, k=5):
    """Discounted cumulative gain (DCG) at rank K.

    Parameters
    ----------
    y_true : array-like, shape = [n_samples]
        Ground-truth relevance of each item.
    y_score : array-like, shape = [n_samples]
        Predicted score of each item; items are ranked by descending score.
    k : int
        Only the top-``k`` ranked items contribute.

    Returns
    -------
    score : float
        ``sum_i (2**rel_i - 1) / log2(i + 2)`` over the top-k ranked positions i (0-based).
    """
    order = np.argsort(y_score)[::-1]
    y_true = np.take(y_true, order[:k])

    gain = 2 ** y_true - 1

    discounts = np.log2(np.arange(len(y_true)) + 2)
    return np.sum(gain / discounts)


def ndcg_score(ground_truth, predictions, k=5):
    """Normalised DCG at rank ``k``: DCG of ``predictions`` divided by the ideal DCG.

    Returns 0.0 when ``ground_truth`` has no relevant item (ideal DCG is 0)
    instead of raising ZeroDivisionError.
    """
    actual = dcg_score(ground_truth, predictions, k)
    best = dcg_score(ground_truth, ground_truth, k)
    if best == 0:
        return 0.0
    return float(actual) / float(best)

In [None]:
from sklearn.metrics import average_precision_score
import copy

def get_metric(output, label):
    """Ranking metrics for a single question.

    Args:
        output: list of predicted scores, one per passage slot.
        label:  list of 0/1 relevance labels aligned with ``output``.

    Returns:
        (metric, rank). ``metric`` maps metric name -> float. ``rank`` lists the
        1-based positions of the relevant passages in descending-score order.
        mAP, MRR, recall@k and nDCG stay 0 when no passage is relevant.
    """
    metric = {'mAP': 0., 'MRR': 0., 'precision_at_5': 0., 'precision_at_10': 0.,
              'nDCG': 0., 'recall_at_5': 0., 'recall_at_10': 0., 'recall_at_20': 0.}
    # Stable sort by descending score; ties keep their original order.
    sorted_rank = sorted(zip(output, label), key=lambda m: m[0], reverse=True)
    ranked_labels = [l for _, l in sorted_rank]

    rank = [i + 1 for i, l in enumerate(ranked_labels) if l == 1.]
    if rank:
        num_relevant = sum(label)
        metric['mAP'] = average_precision_score(np.array(label), np.array(output))
        metric['MRR'] = 1 / rank[0]
        metric['recall_at_5'] = sum(ranked_labels[:5]) / num_relevant
        metric['recall_at_10'] = sum(ranked_labels[:10]) / num_relevant
        metric['recall_at_20'] = sum(ranked_labels[:20]) / num_relevant
        # Full-list nDCG. It is skipped without relevant items because the ideal DCG would be 0.
        metric['nDCG'] = ndcg_score(np.array(label), np.array(output), len(label))
    metric['precision_at_5'] = sum(ranked_labels[:5]) / 5
    metric['precision_at_10'] = sum(ranked_labels[:10]) / 10

    return metric, rank
    
def test(model, loader):
    """Evaluate ``model`` on ``loader``, print the averaged metrics and return them.

    Scores every batch without building an autograd graph and averages
    ``get_metric`` over all questions in ``loader.dataset``. The model is put back
    into train mode afterwards, even if evaluation raises.

    Returns:
        (MRR, mAP, metric_dict)
    """
    model.eval()
    metric = {'mAP': 0., 'MRR': 0., 'precision_at_5': 0., 'precision_at_10': 0., 'nDCG': 0., 'recall_at_5': 0., 'recall_at_10': 0., 'recall_at_20': 0.}
    data_size = max(len(loader.dataset), 1)
    try:
        # Evaluation needs no gradients; skipping graph construction saves GPU memory.
        with torch.no_grad():
            for batch in tqdm_notebook(loader):
                output = model(batch['q'], batch['p'], batch['q_idf'])
                # NOTE(review): assumes every batch holds exactly eval_batch_size
                # questions (always true for eval_batch_size == 1).
                output = output.view(eval_batch_size, -1)
                label = batch['label'].view(eval_batch_size, -1)
                for o, l in zip(output, label):
                    _metric, _ = get_metric(o.tolist(), l.tolist())
                    for k in metric:
                        metric[k] += _metric[k]
                torch.cuda.empty_cache()
    finally:
        model.train()
    for k in metric:
        metric[k] /= data_size
    for k in metric:
        print(str(k)+":\t"+str(metric[k]))
    print()
    return metric['MRR'], metric['mAP'], metric

In [None]:
import time
import math

def time_since(since):
    """Format the wall-clock time elapsed since ``since`` as ``'<m>m <s>s'``."""
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    seconds = elapsed - minutes * 60
    return '%dm %ds' % (minutes, seconds)

def save_checkpoint(state, mrr):
    """Save ``state`` to ``checkpoint/<mrr>.ckpt``, creating the directory if needed."""
    os.makedirs('checkpoint', exist_ok=True)
    torch.save(state, os.path.join('checkpoint', str(mrr) + '.ckpt'))

def load_checkpoint(mrr, map_location=None):
    """Load the checkpoint written by ``save_checkpoint(state, mrr)``.

    ``map_location`` is forwarded to ``torch.load``; for example, ``'cpu'`` loads a
    GPU-saved checkpoint on a CPU-only machine. torch.load unpickles its input, so
    only load trusted files.
    """
    return torch.load(os.path.join('checkpoint', str(mrr) + '.ckpt'), map_location=map_location)
        
def train(model, epoch, batch_size, best_score, parallel, save_duration = 500):
    """Run one training epoch over the global ``train_loader``.

    Every ``save_duration`` global steps the model is evaluated on ``dev_loader``.
    If dev mAP or dev MRR is at least the best seen so far, a checkpoint named after
    the dev MRR is saved. Uses the globals ``optimizer``, ``criterion``,
    ``train_loader`` and ``dev_loader``.

    Args:
        model: the model (possibly wrapped in nn.DataParallel).
        epoch: 0-based epoch index, used for the global step count and logging.
        batch_size: questions per training batch (must match train_loader).
        best_score: dict with the best 'MRR' and 'mAP' so far.
        parallel: True if ``model`` is wrapped and its weights live in ``model.module``.
        save_duration: global-step interval between dev evaluations.

    Returns:
        (max_mrr, max_map): best dev MRR seen, and the mAP recorded alongside it.
    """
    model.train()

    start = time.time()
    total_loss = 0
    num_of_instance = 0  # question-passage rows seen (B * D per batch)
    max_mrr = best_score['MRR']
    max_map = best_score['mAP']

    for i, batch in enumerate(tqdm_notebook(train_loader)):
        output = model(batch['q'], batch['p'], batch['q_idf'])
        # One row per question: reshape with the TRAINING batch size.
        output = output.view(batch_size, -1)
        label = batch['label'].view(batch_size, -1)
        loss = criterion(output, label)
        total_loss += loss.item()
        num_of_instance += batch['q'].size(0)
        model.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 10 == 0:
            print('[{}] Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.2f}'.format(
                time_since(start), epoch,  i *
                batch_size, len(train_loader.dataset),
                100. * i * batch_size / len(train_loader.dataset),
                total_loss / num_of_instance))
        global_step = i + len(train_loader) * epoch
        if i != 0 and global_step % save_duration == 0:
            dev_mrr, dev_map, _ = test(model, dev_loader)
            if dev_map >= max_map or dev_mrr >= max_mrr:
                # Save the bare module's weights so the checkpoint loads without DataParallel.
                if parallel and hasattr(model, 'module'):
                    param = model.module.state_dict()
                else:
                    param = model.state_dict()
                save_checkpoint({
                    'best_MRR': dev_mrr,
                    'best_mAP': dev_map,
                    'state_dict': param,
                }, str(dev_mrr))
                # Only an MRR improvement moves the tracked best (mAP is recorded with it).
                if dev_mrr >= max_mrr:
                    max_mrr = dev_mrr
                    max_map = dev_map
        torch.cuda.empty_cache()
    return max_mrr, max_map

In [None]:
from torch.autograd import Variable
import copy

class RelevanceMatching(nn.Module):
    """Multi-granularity relevance matching between question and passage rows.

    Both texts are embedded and passed through four n-gram encoders (conv kernel
    widths 1, 2, 3, 5). Each of the 16 (question-encoder, passage-encoder) pairs
    gives a dot-product similarity matrix. For every question token the best
    passage match is taken, summed over the 16 pairs, weighted by the token's IDF,
    and summed into one score per row.
    """
    def __init__(self, hidden_size, dropout, embed):
        super(RelevanceMatching, self).__init__()
        # Embedding output dim must equal hidden_size (the convs use it as channel count).
        self.embeddings = embed
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.dropout0 = nn.Dropout(self.dropout)
        self.dropout1 = nn.Dropout(self.dropout)
        self.dropout2 = nn.Dropout(self.dropout)
        self.dropout3 = nn.Dropout(self.dropout)
        self.tanh = nn.Tanh()
        # A width-k conv padded by k-1 emits L+k-1 positions; get_embeddings averages
        # k shifted windows back down to length L.
        self.cnn1 = nn.Conv1d(hidden_size, hidden_size, 1)
        self.cnn2 = nn.Conv1d(hidden_size, hidden_size, 2, padding=1)
        self.cnn3 = nn.Conv1d(hidden_size, hidden_size, 3, padding=2)
        self.cnn5 = nn.Conv1d(hidden_size, hidden_size, 5, padding=4)
        # Unused in forward; kept so existing checkpoints' state_dict keys still match.
        self.linear = nn.Linear(16, 16)

    @staticmethod
    def _shift_average(x, width):
        """Average ``width`` consecutive shifted windows of ``x`` along dim 1.

        ``x`` is (N, L + width - 1, H); the result is (N, L, H).
        """
        length = x.size(1) - (width - 1)
        total = x[:, width - 1:width - 1 + length, :]
        for offset in range(width - 2, -1, -1):
            total = total + x[:, offset:offset + length, :]
        return total / width

    def get_embeddings(self, input):
        """Encode token ids (N, L) into four (N, L, hidden_size) n-gram feature maps.

        Returns the width-1/2/3/5 features, each tanh-activated and followed by its own dropout.
        """
        # Conv1d expects channels first: (N, L, H) -> (N, H, L).
        embedding = self.embeddings(input).permute(0, 2, 1)
        cnn1 = self.dropout0(self.tanh(self.cnn1(embedding).permute(0, 2, 1)))
        cnn2 = self.dropout1(self.tanh(self._shift_average(self.cnn2(embedding).permute(0, 2, 1), 2)))
        cnn3 = self.dropout2(self.tanh(self._shift_average(self.cnn3(embedding).permute(0, 2, 1), 3)))
        cnn5 = self.dropout3(self.tanh(self._shift_average(self.cnn5(embedding).permute(0, 2, 1), 5)))
        return (cnn1, cnn2, cnn3, cnn5)

    def get_matrix(self, q, p):
        """Batched dot-product similarity: (N, Lq, H) x (N, Lp, H) -> (N, 1, Lq, Lp)."""
        matrix = torch.bmm(q, p.permute(0, 2, 1))
        return matrix.unsqueeze(1)

    def forward(self, q, p, q_idf):
        """Score each (question, passage) row.

        Args:
            q: (N, Lq) question token ids.
            p: (N, Lp) passage token ids.
            q_idf: (N, Lq) IDF weight of each question token.

        Returns:
            (N, 1) relevance scores.
        """
        q_embed = self.get_embeddings(q)
        p_embed = self.get_embeddings(p)

        # 16 similarity maps, ordered (question width 1,2,3,5) x (passage width 1,2,3,5): (N, 16, Lq, Lp).
        m = torch.cat([self.get_matrix(qe, pe) for qe in q_embed for pe in p_embed], 1)
        # Best passage match per question token, summed over the 16 maps: (N, Lq).
        m = m.max(3)[0].sum(1)

        m = m * q_idf.float()

        return m.sum(1).unsqueeze(1)

In [None]:
model = RelevanceMatching(hidden_size=300, dropout=0.3, embed=embed)

# Set `load` to the MRR string a checkpoint was saved under (e.g. '0.61') to resume
# from checkpoint/<load>.ckpt; False starts from scratch.
load = False
if load:
    ckpt = load_checkpoint(load)
    model.load_state_dict(ckpt['state_dict'])
    best_MRR = ckpt['best_MRR']
    best_mAP = ckpt['best_mAP']
    print("checkpoint loaded...")
    print("best_MRR:", best_MRR)
    print("best_mAP:", best_mAP)
else:
    best_MRR = 0.
    best_mAP = 0.

parallel = False
# Wrap in DataParallel only when more than one GPU is visible. Otherwise force
# `parallel` off so train() saves model.state_dict() instead of touching model.module.
parallel = parallel and torch.cuda.device_count() > 1
if parallel:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [33, xxx] -> [11, ...], [11, ...], [11, ...] on 3 GPUs
    model = nn.DataParallel(model, output_device=device).cuda()
else:
    model.cuda()

save_duration = 500
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6, weight_decay=1e-7)
max_epochs = None  # None: keep training until the kernel is interrupted
epoch = 0
while max_epochs is None or epoch < max_epochs:
    max_mrr, max_map = train(model, epoch, batch_size, {'MRR': best_MRR, 'mAP': best_mAP}, parallel, save_duration)
    epoch += 1
    # The carried-over best only advances when both dev MRR and mAP improved.
    if max_mrr > best_MRR and max_map > best_mAP:
        best_MRR = max_mrr
        best_mAP = max_map