## Load

In [1]:
import os
from collections import defaultdict
import math
import numpy as np 
import random
import re
import torch
import torch.nn as nn
from itertools import cycle
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
from tqdm.auto import tqdm

# Used to get the data
from sklearn.metrics import ndcg_score

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import nltk
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords
from nltk.collocations import BigramAssocMeasures, BigramCollocationFinder
nltk.download('stopwords')

import matplotlib.pyplot as plt 
import matplotlib
matplotlib.use('Agg')

seed = 33
import pandas as pd

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/chrisliu/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


## Preprocess config

In [2]:
config = {}

config["dataset"] = "CNN" # "IMDB" "CNN", "PubMed"
config["n_document"] = 10000000
config["normalize_word_embedding"] = True
config["min_word_freq_threshold"] = 20
config["topk_word_freq_threshold"] = 100
config["document_vector_agg_weight"] = 'IDF' # ['mean', 'IDF', 'uniform', 'gaussian', 'exponential', 'pmi']
config["document_vector_weight_normalize"] = False # weighted sum or mean, True for sum, False for mean 
config["select_topk_TFIDF"] = None # ignore
config["embedding_file"] = "../data/glove.6B.100d.txt"
config["topk"] = [10, 30, 50]


In [3]:
def in_notebook():
    try:
        from IPython import get_ipython
        if 'IPKernelApp' not in get_ipython().config:  # pragma: no cover
            return False
    except ImportError:
        return False
    return True

In [4]:
def load_word2emb(embedding_file):
    word2embedding = dict()
    word_dim = int(re.findall(r".(\d+)d", embedding_file)[0])

    with open(embedding_file, "r") as f:
        for line in tqdm(f):
            line = line.strip().split()
            word = line[0]
            embedding = list(map(float, line[1:]))
            word2embedding[word] = np.array(embedding)

    print("Number of words:%d" % len(word2embedding))

    return word2embedding

word2embedding = load_word2emb(config["embedding_file"])

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Number of words:400000


In [5]:
def normalize_wordemb(word2embedding):
    # Every word emb should have norm 1
    
    word_emb = []
    word_list = []
    for word, emb in word2embedding.items():
        word_list.append(word)
        word_emb.append(emb)

    word_emb = np.array(word_emb)

    for i in range(len(word_emb)):
        norm = np.linalg.norm(word_emb[i])
        word_emb[i] = word_emb[i] / norm

    for word, emb in tqdm(zip(word_list, word_emb)):
        word2embedding[word] = emb
    return word2embedding

if config["normalize_word_embedding"]:
    normalize_wordemb(word2embedding)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [6]:
class Vocabulary:
    def __init__(self, word2embedding, config):
        # The low frequency words will be assigned as <UNK> token
        self.itos = {0: "<UNK>"}
        self.stoi = {"<UNK>": 0}
        
        self.word2embedding = word2embedding
        self.config = config

        self.word_freq_in_corpus = defaultdict(int)
        self.IDF = {}
        self.ps = PorterStemmer()
        self.stop_words = set(stopwords.words('english'))
        
        self.word_dim = len(word2embedding['the'])
    def __len__(self):
        return len(self.itos)

    def tokenizer_eng(self, text):
        text = re.sub(r'[^A-Za-z0-9 ]+', '', text)
        text = text.strip().split()
        
        return [self.ps.stem(w) for w in text if w.lower() not in self.stop_words]
    
    def read_raw(self):        
        if self.config["dataset"] == 'IMDB':
            data_file_path = '../data/IMDB.txt'
        elif self.config["dataset"] == 'CNN':
            data_file_path = '../data/CNN.txt'
        elif self.config["dataset"] == 'PubMed':
            data_file_path = '../data/PubMed.txt'
        
        # raw documents
        self.raw_documents = []
        with open(data_file_path,'r',encoding='utf-8') as f:
            for line in tqdm(f, desc="Loading documents"):
                self.raw_documents.append(line.strip("\n"))
                
        return self.raw_documents
    
    def build_vocabulary(self):
        sentence_list = self.raw_documents
        
        self.doc_freq = defaultdict(int) # # of document a word appear
        self.document_num = len(sentence_list)
        self.word_vectors = [[0]*self.word_dim] # unknown word emb
        
        for sentence in tqdm(sentence_list, desc="Preprocessing documents"):
            # for doc_freq
            document_words = set()
            
            for word in self.tokenizer_eng(sentence):
                # pass unknown word
                if word not in self.word2embedding:
                    continue
                    
                # calculate word freq
                self.word_freq_in_corpus[word] += 1
                document_words.add(word)
                
            for word in document_words:
                self.doc_freq[word] += 1
        
        # calculate IDF
        print('doc num', self.document_num)
        for word, freq in self.doc_freq.items():
            self.IDF[word] = math.log(self.document_num / (freq+1))
        
        # delete less freq words:
        delete_words = []
        for word, v in self.word_freq_in_corpus.items():
            if v < self.config["min_word_freq_threshold"]:
                delete_words.append(word)     
        for word in delete_words:
            del self.IDF[word]    
            del self.word_freq_in_corpus[word]    
        
        # delete too freq words
        print('eliminate freq words')
        IDF = [(word, freq) for word, freq in self.IDF.items()]
        IDF.sort(key=lambda x: x[1])

        for i in range(self.config["topk_word_freq_threshold"]):
            print(word)
            word = IDF[i][0]
            del self.IDF[word]
            del self.word_freq_in_corpus[word]
        
        # construct word_vectors
        idx = 1
        for word in self.word_freq_in_corpus:
            self.word_vectors.append(self.word2embedding[word])
            self.stoi[word] = idx
            self.itos[idx] = word
            idx += 1
            
    def init_word_weight(self,sentence_list, agg):
        if agg == 'mean':
            self.word_weight = {word: 1 for word in self.IDF.keys()}
        elif agg == 'IDF':
            self.word_weight = self.IDF
        elif agg == 'uniform':
            self.word_weight = {word: np.random.uniform(low=0.0, high=1.0) for word in self.IDF.keys()}
        elif agg == 'gaussian':
            mu, sigma = 10, 1 # mean and standard deviation
            self.word_weight = {word: np.random.normal(mu, sigma) for word in self.IDF.keys()}
        elif agg == 'exponential':
            self.word_weight = {word: np.random.exponential(scale=1.0) for word in self.IDF.keys()}
        elif agg == 'pmi':
            trigram_measures = BigramAssocMeasures()
            self.word_weight = defaultdict(int)
            corpus = []

            for text in tqdm(sentence_list):
                corpus.extend(text.split())

            finder = BigramCollocationFinder.from_words(corpus)
            for pmi_score in finder.score_ngrams(trigram_measures.pmi):
                pair, score = pmi_score
                self.word_weight[pair[0]] += score
                self.word_weight[pair[1]] += score
                
    def calculate_document_vector(self):
        # Return
        # document_vectors: weighted sum of word emb
        # document_answers_idx: doc to word index list
        # document_answers_wsum: word weight summation, e.g. total TFIDF score of a doc
        
        document_vectors = [] 
        document_answers = []
        document_answers_wsum = []
        
        sentence_list = self.raw_documents
        agg = self.config["document_vector_agg_weight"]
        n_document = self.config["n_document"]
        select_topk_TFIDF = self.config["select_topk_TFIDF"]
        
        self.init_word_weight(sentence_list, agg)
        for sentence in tqdm(sentence_list[:min(n_document, len(sentence_list))], desc="calculate document vectors"):
            document_vector = np.zeros(len(self.word_vectors[0]))
            select_words = []
            for word in self.tokenizer_eng(sentence):
                # pass unknown word
                if word not in self.stoi:
                    continue
                else:
                    select_words.append(word)

            # select topk TDIDF
            if select_topk_TFIDF is not None:
                doc_TFIDF = defaultdict(float)
                for word in select_words:    
                    doc_TFIDF[word] += self.IDF[word]

                doc_TFIDF_l = [(word, TFIDF) for word, TFIDF in doc_TFIDF.items()]
                doc_TFIDF_l.sort(key=lambda x:x[1], reverse=True)
                
                select_topk_words = set(list(map(lambda x:x[0], doc_TFIDF_l[:select_topk_TFIDF])))
                select_words = [word for word in select_words if word in select_topk_words]
            else:
                pass
            
            total_weight = 0
            # aggregate to doc vectors
            for word in select_words:
                document_vector += np.array(self.word2embedding[word]) * self.word_weight[word]
                total_weight += self.word_weight[word]
                
            if len(select_words) == 0:
                print('error', sentence)
                continue
            else:
                if self.config["document_vector_weight_normalize"]:
                    document_vector /= total_weight
            
            document_vectors.append(document_vector)
            document_answers.append(select_words)
            document_answers_wsum.append(total_weight)
        
        # get answers
        document_answers_idx = []    
        for ans in document_answers:
            ans_idx = []
            for token in ans:
                if token in self.stoi:
                    ans_idx.append(self.stoi[token])                    
            document_answers_idx.append(ans_idx)
        
        self.document_vectors = document_vectors
        self.document_answers_idx = document_answers_idx
        self.document_answers_wsum = document_answers_wsum
        
        return document_vectors, document_answers_idx, document_answers_wsum
        
    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)

        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]
    
    def check_docemb(self):
        word_vectors = np.array(self.word_vectors)
        pred = np.zeros(word_vectors.shape[1])
        cnt = 0

        for word_idx in self.document_answers_idx[0]:
            pred += word_vectors[word_idx] * self.word_weight[self.itos[word_idx]]
            cnt += self.word_weight[self.itos[word_idx]]
        
        if self.config["document_vector_weight_normalize"]:
            pred /= cnt
        assert np.sum(self.document_vectors[0]) - np.sum(pred) == 0

In [7]:
def build_vocab(config, word2embedding):
    # build vocabulary
    vocab = Vocabulary(word2embedding, config)
    vocab.read_raw()
    vocab.build_vocabulary()
    vocab_size = len(vocab)
    # get doc emb
    vocab.calculate_document_vector()
    vocab.check_docemb()
    
    return vocab

vocab = build_vocab(config, word2embedding)

HBox(children=(IntProgress(value=1, bar_style='info', description='Loading documents', max=1, style=ProgressSt…




HBox(children=(IntProgress(value=0, description='Preprocessing documents', max=19026, style=ProgressStyle(desc…


doc num 19026
eliminate freq words
paroxysm
subject
line
organ
write
univers
one
would
use
like
get
know
dont
think
time
make
also
say
go
im
could
want
new
work
good
well
way
need
look
even
anyon
thing
see
tri
thank
much
year
world
system
right
problem
may
take
mani
two
first
seem
question
pleas
1
state
us
come
2
post
help
call
usa
point
sinc
find
read
still
back
mean
ive
give
email
sure
differ
might
run
cant
reason
last
day
interest
case
let
person
said
never
start
doesnt
tell
better
ask
got
without
follow
part
lot
3
number
put
fact
gener
inform
actual
that


HBox(children=(IntProgress(value=0, description='calculate document vectors', max=19026, style=ProgressStyle(d…

error |> 
error 
error 
error Mikael Fredriksson
error 
error -------------------------------------------------
error email: mikael_fredriksson@macexchange.se
error 
error FIDO 2:203/211
error  
error  
error   
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error    
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error    
error  
error  
error  
error  
error  
error  
error  
error  



In [8]:
print("Finish building dataset!")
print(f"Number of documents:{len(vocab.raw_documents)}")
print(f"Number of words:{len(vocab)}")

l = list(map(len, vocab.document_answers_idx))
print("Average length of document:", np.mean(l))

Finish building dataset!
Number of documents:19026
Number of words:7602
Average length of document: 84.86410175216382


In [9]:
word_vectors = np.array(vocab.word_vectors)
print(word_vectors.shape)

document_vectors = np.array(vocab.document_vectors)
print(document_vectors.shape)

document_answers_wsum = np.array(vocab.document_answers_wsum).reshape(-1, 1)
print(document_answers_wsum.shape)

# create weight_ans
document_answers_idx = vocab.document_answers_idx

# random shuffle
shuffle_idx = list(range(len(document_vectors)))
random.Random(seed).shuffle(shuffle_idx)

document_vectors = document_vectors[shuffle_idx]
document_answers_wsum = document_answers_wsum[shuffle_idx]
document_answers_idx = [document_answers_idx[idx] for idx in shuffle_idx]

(7602, 100)
(18948, 100)
(18948, 1)


In [10]:
# onthot_ans: word freq matrix
# weight_ans: TFIDF matrix

onehot_ans = np.zeros((len(document_answers_idx), word_vectors.shape[0]))
weight_ans = np.zeros((len(document_answers_idx), word_vectors.shape[0]))
print(weight_ans.shape)

for i in tqdm(range(len(document_answers_idx))):
    for word_idx in document_answers_idx[i]:
        weight_ans[i, word_idx] += vocab.word_weight[vocab.itos[word_idx]]
        onehot_ans[i, word_idx] += 1

(18948, 7602)


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))




## Results

In [11]:
final_results = []
select_columns = ['model']
for topk in config["topk"]:
    select_columns.append('F1@{}'.format(topk))
for topk in config["topk"]:
    select_columns.append('ndcg@{}'.format(topk))
select_columns.append('ndcg@all')
select_columns

['model',
 'F1@10',
 'F1@30',
 'F1@50',
 'ndcg@10',
 'ndcg@30',
 'ndcg@50',
 'ndcg@all']

## setting training size

In [12]:
train_size_ratio = 1
train_size = int(len(document_answers_idx) * train_size_ratio)
train_size

18948

## Top K freq word

In [13]:
topk_results = {}

In [14]:
test_ans = document_answers_idx[:train_size]

In [15]:
word_freq = [(word, freq) for word, freq in vocab.word_freq_in_corpus.items()]
word_freq.sort(key=lambda x:x[1], reverse=True)
word_freq[:10]

[('x', 6539),
 ('god', 5208),
 ('file', 4918),
 ('0', 4520),
 ('window', 4444),
 ('program', 4201),
 ('drive', 3633),
 ('4', 3528),
 ('game', 3474),
 ('govern', 3268)]

In [16]:
def topk_word_evaluation(k=50):
    topk_word = [word for (word, freq) in word_freq[:k]]

    pr, re = [], []
    for ans in tqdm(test_ans):
        ans = set(ans)
        ans = [vocab.itos[a] for a in ans]

        hit = []
        for word in ans:
            if word in topk_word:
                hit.append(word)

        precision = len(hit) / k
        recall = len(hit) / len(ans)
        pr.append(precision)
        re.append(recall)

    pr = np.mean(pr)
    re = np.mean(re)
    f1 = 2 * pr * re / (pr + re) if (pr + re) != 0 else 0
    print('top {} word'.format(k))
    print('percision', np.mean(pr))
    print('recall', np.mean(re))
    print('F1', f1)
    return f1


for topk in config['topk']:
    topk_results["F1@{}".format(topk)] = topk_word_evaluation(k=topk)


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top 10 word
percision 0.06950073886426009
recall 0.016353539716943234
F1 0.026477028568787433


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top 30 word
percision 0.07639680529167546
recall 0.04850870113830837
F1 0.059339414277828975


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top 50 word
percision 0.07682921680388433
recall 0.07978013464427035
F1 0.07827687433156728


In [17]:
def topk_word_evaluation_NDCG(k=50):
    freq_word =[word for (word, freq) in word_freq]
    freq_word_idx = [vocab.stoi[word] for word in freq_word if word in vocab.stoi]
    
    scores = np.zeros(len(vocab.word_vectors))
    for rank, idx in enumerate(freq_word_idx):
        scores[idx] = len(vocab.word_vectors) - rank
    
    NDCGs = []
    
    for ans in tqdm(test_ans):
        weight_ans = np.zeros(len(vocab.word_vectors))
        
        for word_idx in ans:
            if word_idx == 0:
                continue
            word = vocab.itos[word_idx]
            weight_ans[word_idx] += vocab.IDF[word]

        NDCG_score = ndcg_score(weight_ans.reshape(1,-1), scores.reshape(1,-1), k=k)
        NDCGs.append(NDCG_score)

    print('top {} NDCG:{}'.format(k, np.mean(NDCGs)))
    
    return np.mean(NDCGs)


for topk in config['topk']:
    topk_results["ndcg@{}".format(topk)] = topk_word_evaluation_NDCG(k=topk)
    
topk_results["ndcg@all"] = topk_word_evaluation_NDCG(k=None)


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top 10 NDCG:0.028211750096691315


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top 30 NDCG:0.03761860810723896


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top 50 NDCG:0.04587430885683149


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top None NDCG:0.31437606429732357


In [18]:
topk_results["model"] = "topk"
final_results.append(pd.Series(topk_results))

## MLP Decoder

In [19]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [20]:
class MLPDecoderDataset(Dataset):
    def __init__(self, 
                 doc_vectors,
                 weight_ans,
                 topk=50):
        self.doc_vectors = torch.FloatTensor(doc_vectors)
        self.weight_ans = torch.FloatTensor(weight_ans)
        self.weight_ans_s = torch.argsort(self.weight_ans, dim=1, descending=True)
        self.weight_ans_s[:, topk:] = -1
        
        assert len(doc_vectors) == len(weight_ans)
        
    def __getitem__(self, idx):
        return self.doc_vectors[idx], self.weight_ans[idx], self.weight_ans_s[idx]

    def __len__(self):
        return len(self.doc_vectors)

In [21]:
class MLPDecoderDataset2(Dataset):
    def __init__(self, 
                 doc_vectors,
                 weight_ans,
                 topk=50):
        self.doc_vectors = torch.FloatTensor(doc_vectors)
        self.weight_ans = torch.FloatTensor(weight_ans)
        self.weight_ans_s = torch.argsort(self.weight_ans, dim=1, descending=True)
        
        self.weight_ans_s_pos = self.weight_ans_s[:, :topk]
        self.weight_ans_s_neg = self.weight_ans_s[:, topk:]
        
        assert len(doc_vectors) == len(weight_ans)
        
    def __getitem__(self, idx):
        return self.doc_vectors[idx], self.weight_ans[idx], self.weight_ans_s_pos[idx], self.weight_ans_s_neg[idx]

    def __len__(self):
        return len(self.doc_vectors)

In [22]:
batch_size = 100

train_size_ratio = 0.9
train_size = int(len(document_answers_idx) * train_size_ratio)
train_size

print('train size', train_size)
print('valid size', len(document_vectors) - train_size)

train_dataset = MLPDecoderDataset(document_vectors[:train_size], weight_ans[:train_size], topk=50)
train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

valid_dataset = MLPDecoderDataset(document_vectors[train_size:], weight_ans[train_size:], topk=50)
valid_loader  = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

train_dataset2 = MLPDecoderDataset2(document_vectors[:train_size], weight_ans[:train_size], topk=50)
train_loader2  = torch.utils.data.DataLoader(train_dataset2, batch_size=batch_size, shuffle=True, pin_memory=True)

valid_dataset2 = MLPDecoderDataset2(document_vectors[train_size:], weight_ans[train_size:], topk=50)
valid_loader2  = torch.utils.data.DataLoader(valid_dataset2, batch_size=batch_size, shuffle=True, pin_memory=True)

train size 17053
valid size 1895


In [23]:
class MLPDecoder(nn.Module):
    def __init__(self, doc_emb_dim, num_words, h_dim=300):
        super().__init__()
        
        self.fc1 = nn.Linear(doc_emb_dim, h_dim) 
#         self.fc2 = nn.Linear(h_dim, h_dim)
#         self.fc3 = nn.Linear(h_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, num_words)
        self.dropout = nn.Dropout(p=0.2)
        
    def forward(self, x):
        x = F.tanh(self.fc1(x))
#         x = F.tanh(self.fc2(x))
#         x = F.relu(self.fc3(x))
        x = self.dropout(x)
        x = self.fc4(x)
        
        return x

In [24]:
def _dcg(target):
    """Computes Discounted Cumulative Gain for input tensor."""
    denom = torch.log2(torch.arange(target.shape[-1], device=target.device) + 2.0)
    return (target / denom).sum(dim=-1)

def retrieval_normalized_dcg_all(preds, target, k=None):
    """Computes `Normalized Discounted Cumulative Gain`_ (for information retrieval).
    ``preds`` and ``target`` should be of the same shape and live on the same device.
    ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
    otherwise an error is raised.
    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document relevance.
        k: consider only the top k elements (default: None, which considers them all)
    Return:
        a single-value tensor with the nDCG of the predictions ``preds`` w.r.t. the labels ``target``.
    Raises:
        ValueError:
            If ``k`` parameter is not `None` or an integer larger than 0
    Example:
        >>> from torchmetrics.functional import retrieval_normalized_dcg
        >>> preds = torch.tensor([.1, .2, .3, 4, 70])
        >>> target = torch.tensor([10, 0, 0, 1, 5])
        >>> retrieval_normalized_dcg(preds, target)
        tensor(0.6957)
    """
    k = [preds.shape[-1]] if k is None else k + [preds.shape[-1]]
    
    assert preds.shape == target.shape and max(k) <= preds.shape[-1]
    
    if not isinstance(k, list):
        raise ValueError("`k` has to be a list of positive integer or None")

    sorted_target = target.gather(1, torch.argsort(preds, dim=-1, descending=True))
    ideal_target = torch.sort(target, descending=True)[0]
    
    ndcg_scores = {}
    for topk in k:
        sorted_target_k = sorted_target[:,:topk]
        ideal_target_k = ideal_target[:,:topk]
        
        ideal_dcg_k = _dcg(ideal_target_k)
        target_dcg_k = _dcg(sorted_target_k)
        
        # filter undefined scores
        target_dcg_k /= ideal_dcg_k
        
        if topk == preds.shape[-1]:
            topk = 'all'
        ndcg_scores[topk] = target_dcg_k.mean().item()
        
    return ndcg_scores

def retrieval_precision_all(preds, target, k = [10]):
        
    assert preds.shape == target.shape and max(k) <= preds.shape[-1]
    
    if not isinstance(k, list):
        raise ValueError("`k` has to be a list of positive integer")
        
    precision_scores = {}
    target_onehot = target > 0
    
    for topk in k:
        relevant = target_onehot.gather(1, preds.topk(topk, dim=-1)[1])
        relevant = relevant.sum(axis=1).float()
        relevant /= topk    
        precision_scores[topk] = relevant.mean().item()
    
    return precision_scores

In [25]:
import timeit
from torchmetrics.functional import retrieval_normalized_dcg, retrieval_precision
from sklearn.metrics import ndcg_score

def evaluate_MLPDecoder(model, data_loader):
    results = {}
    model.eval()
    
    pred_all = []
    target_all = []
    
    # predict all data
    start = timeit.default_timer()
    for data in data_loader:
        doc_embs, target, _ = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
                
        pred = model(doc_embs)
        pred_all.append(pred)
        target_all.append(target)
        
    pred_all = torch.cat(pred_all, dim=0)
    target_all = torch.cat(target_all, dim=0)
    stop1 = timeit.default_timer()
    print('Time1: ', stop1 - start)  
    
    # Precision
    precision_scores = retrieval_precision_all(pred_all, target_all, k=config["topk"])
    for k, v in precision_scores.items():
        results['precision@{}'.format(k)] = v
    
    stop2 = timeit.default_timer()
    print('Time2: ', stop2 - stop1)  
    
    # NDCG
    ndcg_scores = retrieval_normalized_dcg_all(pred_all, target_all, k=config["topk"])
    for k, v in ndcg_scores.items():
        results['ndcg@{}'.format(k)] = v
        
    stop3 = timeit.default_timer()
    print('Time3: ', stop3 - stop2)      
    
    return results

In [26]:
# setting
lr = 0.05
momentum = 0.
weight_decay = 0
nesterov = False # True

n_epoch = 1000
verbose = True
valid_epoch = 50

h_dim = 3000

model = MLPDecoder(doc_emb_dim=document_vectors.shape[1], num_words=len(word_vectors), h_dim=h_dim).to(device)
model.train()

word_vectors_tensor = torch.FloatTensor(word_vectors).to(device)

opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
# opt = torch.optim.Adam(model.parameters(), lr=lr)
# criterion = nn.MSELoss(reduction='mean')
criterion = nn.MultiLabelMarginLoss(reduction='mean')
# criterion = nn.MultiLabelSoftMarginLoss(reduction='mean')

results = []
step = 0
clip_value = 1

for epoch in tqdm(range(n_epoch)):    
    train_loss_his = []
    valid_loss_his = []
    
    model.train()

    for data in train_loader:
        doc_embs, target, target_rank = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_rank = target_rank.to(device)
        
        # MSE loss
        pred = model(doc_embs)     
#         loss = criterion(pred, target)
        loss = criterion(pred, target_rank)
#         loss = criterion(pred, torch.sign(target))
        
        # Model backwarding
        model.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
        
        opt.step()

        train_loss_his.append(loss.item())
        
    model.eval()
    for data in valid_loader:
        doc_embs, target, target_rank = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_rank = target_rank.to(device)
        
        # MSE loss
        pred = model(doc_embs)     
        
#         loss = criterion(pred, target)
        loss = criterion(pred, target_rank)
#         loss = criterion(pred, torch.sign(target))
        
        valid_loss_his.append(loss.item())
    
    print("Epoch", epoch, np.mean(train_loss_his), np.mean(valid_loss_his))
    
    if epoch % valid_epoch == 0:
        res = {}
        res['epoch'] = epoch
        
        train_res_ndcg = evaluate_MLPDecoder(model, train_loader)
        valid_res_ndcg = evaluate_MLPDecoder(model, valid_loader)
        
        res.update(valid_res_ndcg)
        results.append(res)
        
        if verbose:
            print()
            print('train', train_res_ndcg)
            print('valid', valid_res_ndcg)
            


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))



Epoch 0 13.774092027318408 9.819838373284592
Time1:  0.5011053886264563
Time2:  0.050783200189471245
Time3:  0.164084542542696
Time1:  0.054288867861032486
Time2:  0.00835254043340683
Time3:  0.018570462241768837

train {'precision@10': 0.13616959750652313, 'precision@30': 0.11514885723590851, 'precision@50': 0.11314137279987335, 'ndcg@10': 0.09797479957342148, 'ndcg@30': 0.10815028101205826, 'ndcg@50': 0.12431544065475464, 'ndcg@all': 0.39644137024879456}
valid {'precision@10': 0.13999998569488525, 'precision@30': 0.11502198874950409, 'precision@50': 0.11233773082494736, 'ndcg@10': 0.09675917029380798, 'ndcg@30': 0.10591454803943634, 'ndcg@50': 0.12054401636123657, 'ndcg@all': 0.3911643922328949}
Epoch 1 7.626935518275925 8.113733843753213
Epoch 2 5.74174726218508 7.4057998406259635
Epoch 3 4.703449546245107 7.012203216552734
Epoch 4 4.043521631530851 6.727579141917982
Epoch 5 3.56131168555098 6.581755638122559
Epoch 6 3.207375446955363 6.433835581729286
Epoch 7 2.9368464361157334 6.3

Epoch 136 0.4296006637009961 5.686600760409706
Epoch 137 0.42914326445401063 5.6860671043396
Epoch 138 0.42668687373574016 5.6868034162019425
Epoch 139 0.423295512708307 5.699390888214111
Epoch 140 0.42316667779147277 5.686873285393966
Epoch 141 0.42209145897313166 5.69462776184082
Epoch 142 0.41998631354661015 5.694002427552876
Epoch 143 0.41675617297490436 5.691985782824065
Epoch 144 0.41477007370943214 5.7012862406278915
Epoch 145 0.41468492376874067 5.687095140155993
Epoch 146 0.413425761530971 5.6937637078134635
Epoch 147 0.4106232780113555 5.694085497605173
Epoch 148 0.4089116299012948 5.69727877566689
Epoch 149 0.4074685177956408 5.6989678834614
Epoch 150 0.40622720348904706 5.694697455355995
Time1:  0.4880787320435047
Time2:  0.05154542252421379
Time3:  0.1613706648349762
Time1:  0.05425946041941643
Time2:  0.008229320868849754
Time3:  0.01865067146718502

train {'precision@10': 0.5076349973678589, 'precision@30': 0.5306046009063721, 'precision@50': 0.5456787347793579, 'ndcg@10

Epoch 268 0.29240303569369847 5.76671048214561
Epoch 269 0.2918248390942289 5.773429644735236
Epoch 270 0.291341586593996 5.771934459083958
Epoch 271 0.29034136802132365 5.776569015101383
Epoch 272 0.28946574468013136 5.773978961141486
Epoch 273 0.2892913014916649 5.779578610470421
Epoch 274 0.2883787752940641 5.777883454373009
Epoch 275 0.28801096862519693 5.779555672093442
Epoch 276 0.28739846571844224 5.7749840083875155
Epoch 277 0.28614871508893913 5.779549372823615
Epoch 278 0.2865511853443949 5.772086369363885
Epoch 279 0.2853540267512115 5.785025069588109
Epoch 280 0.2851044415381917 5.768930159117046
Epoch 281 0.28391007216353165 5.780322928177683
Epoch 282 0.2829780716296525 5.780143461729351
Epoch 283 0.2843003862085398 5.787902355194092
Epoch 284 0.28286239269532654 5.779654000934801
Epoch 285 0.2827959224494577 5.784282910196405
Epoch 286 0.28213270409413943 5.78583920629401
Epoch 287 0.2814030338797653 5.774811594109786
Epoch 288 0.28093179833819293 5.781280542674818
Epoch

Epoch 401 0.23650873657207042 5.8612195065147
Epoch 402 0.23558108107736933 5.854523708945827
Epoch 403 0.23551087090146472 5.85194206237793
Epoch 404 0.23515070913827907 5.859023872174714
Epoch 405 0.235338092808835 5.851057855706466
Epoch 406 0.23472198870098382 5.854438430384586
Epoch 407 0.23451573257906394 5.852364715776946
Epoch 408 0.23436497587558122 5.857109747434917
Epoch 409 0.233199316071488 5.859779056749846
Epoch 410 0.233017276079334 5.861140828383596
Epoch 411 0.23344259635049697 5.858320813429983
Epoch 412 0.2329014647773832 5.863837543286775
Epoch 413 0.2319366047431154 5.858856100785105
Epoch 414 0.2324632542175159 5.864731989408794
Epoch 415 0.23271469458153374 5.857732898310611
Epoch 416 0.23207369865032665 5.865589317522551
Epoch 417 0.23126493324661812 5.863570564671567
Epoch 418 0.23058289533470108 5.862864870774119
Epoch 419 0.23095301787058511 5.860668358049895
Epoch 420 0.23096478830652628 5.867321516338148
Epoch 421 0.22940205124735136 5.8689264498258895
Epo

Epoch 546 0.20158138679482088 5.9298320820457056
Epoch 547 0.20222887963230846 5.938636152367843
Epoch 548 0.20237419183491268 5.935240845931204
Epoch 549 0.20129749327026614 5.9341633947272046
Epoch 550 0.20149074930545183 5.933874406312642
Time1:  0.4887952357530594
Time2:  0.05116501823067665
Time3:  0.16155647858977318
Time1:  0.054428666830062866
Time2:  0.008151691406965256
Time3:  0.0185895636677742

train {'precision@10': 0.6139798760414124, 'precision@30': 0.6506538987159729, 'precision@50': 0.658177375793457, 'ndcg@10': 0.5152397155761719, 'ndcg@30': 0.6097865700721741, 'ndcg@50': 0.6986270546913147, 'ndcg@all': 0.7390797734260559}
valid {'precision@10': 0.42358842492103577, 'precision@30': 0.31927886605262756, 'precision@50': 0.2725910246372223, 'ndcg@10': 0.39375123381614685, 'ndcg@30': 0.384351521730423, 'ndcg@50': 0.3979935348033905, 'ndcg@all': 0.5920253396034241}
Epoch 551 0.20141379918619903 5.936331347415321
Epoch 552 0.20109015790342588 5.938878912674753
Epoch 553 0.

Epoch 677 0.1816972687579038 6.000831026779978
Epoch 678 0.1809112675357283 5.994735441709819
Epoch 679 0.18152080206146018 5.999486496574001
Epoch 680 0.1820363834587454 6.004712104797363
Epoch 681 0.1813156994288428 6.0050359525178605
Epoch 682 0.1804262924264049 6.00637129733437
Epoch 683 0.18113773817207382 6.006562483938117
Epoch 684 0.18103955006390288 6.003423364538896
Epoch 685 0.18085941898892496 6.004913104207892
Epoch 686 0.17965436583025413 6.0120485205399365
Epoch 687 0.18038689210052378 6.006954946016011
Epoch 688 0.17964857807982038 6.008853184549432
Epoch 689 0.1794189551071814 6.008137979005513
Epoch 690 0.1798882405137458 6.0135932972556665
Epoch 691 0.17975520913363896 6.014241544823897
Epoch 692 0.1801463587765108 6.007287351708663
Epoch 693 0.1791451359875718 6.010923385620117
Epoch 694 0.1794222675220311 6.011786109522769
Epoch 695 0.18041109498481303 6.008814962286698
Epoch 696 0.17908960641824712 6.008260626541941
Epoch 697 0.17930596378463054 6.014237177999396


Epoch 809 0.16613457600275675 6.067732685490658
Epoch 810 0.1660820555965803 6.068340251320286
Epoch 811 0.1664863320296271 6.062950284857499
Epoch 812 0.16616390959212654 6.065597207922685
Epoch 813 0.16615875178610373 6.06797813114367
Epoch 814 0.16603390166634008 6.07265241522538
Epoch 815 0.1667611138862476 6.074714083420603
Epoch 816 0.1651394968492943 6.070406236146626
Epoch 817 0.16582262794873867 6.072983992727179
Epoch 818 0.16593037746105974 6.074741212945235
Epoch 819 0.16580406389041255 6.0745191072162825
Epoch 820 0.1656268187607938 6.071238768728156
Epoch 821 0.16536873735879598 6.075608680122777
Epoch 822 0.16485071077681424 6.075381630345395
Epoch 823 0.16534774685115144 6.076525060754073
Epoch 824 0.16478516684289565 6.071406339344225
Epoch 825 0.16555589123776085 6.0754860325863485
Epoch 826 0.1647962096490358 6.065119216316624
Epoch 827 0.1651611557306602 6.079014025236431
Epoch 828 0.1647767320013883 6.079128917894866
Epoch 829 0.1651252807058089 6.0782338945489185


Epoch 951 0.15377008470526912 6.133888646175987
Epoch 952 0.15435433117618338 6.130281749524568
Epoch 953 0.15435087218967794 6.128678597901997
Epoch 954 0.15416659092345433 6.126897385245876
Epoch 955 0.1536862363069378 6.134445918233771
Epoch 956 0.1538882070814657 6.134461628763299
Epoch 957 0.15378774388840324 6.13007957056949
Epoch 958 0.153656993257372 6.132847861239784
Epoch 959 0.15336524202809695 6.138913531052439
Epoch 960 0.15428219317344197 6.133810796235737
Epoch 961 0.15429008102904984 6.131294150101511
Epoch 962 0.15358493429178383 6.133237261521189
Epoch 963 0.15381190041352433 6.133098200747841
Epoch 964 0.1537640755288085 6.134746350740132
Epoch 965 0.15370922374446488 6.133723936582866
Epoch 966 0.15372614745508162 6.138847300880833
Epoch 967 0.1532796348570383 6.136527161849172
Epoch 968 0.15297188561910774 6.137047115125154
Epoch 969 0.15314763732123793 6.133457459901509
Epoch 970 0.15403434730064103 6.1372356163827995
Epoch 971 0.15181501662870597 6.13675975799560

### train {'precision@10': 0.6233331561088562, 'precision@30': 0.6673136949539185, 'precision@50': 0.6772298216819763, 'ndcg@10': 0.5198198556900024, 'ndcg@30': 0.6220351457595825, 'ndcg@50': 0.7136395573616028, 'ndcg@all': 0.7433396577835083}
### valid {'precision@10': 0.42121371626853943, 'precision@30': 0.319947212934494, 'precision@50': 0.27297094464302063, 'ndcg@10': 0.39169588685035706, 'ndcg@30': 0.38360509276390076, 'ndcg@50': 0.397529661655426, 'ndcg@all': 0.5908011198043823}

In [29]:
def listNet(y_pred, y_true, eps=1e-10, padded_value_indicator=-1):
    """
    ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach".
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: loss value, a torch.Tensor
    """
    y_pred = y_pred.clone()
    y_true = y_true.clone()

    mask = y_true == padded_value_indicator
    y_pred[mask] = float('-inf')
    y_true[mask] = float('-inf')

    preds_smax = F.softmax(y_pred, dim=1)
    true_smax = F.softmax(y_true, dim=1)

    preds_smax = preds_smax + eps
    preds_log = torch.log(preds_smax)

    return torch.mean(-torch.sum(true_smax * preds_log, dim=1))

In [30]:
# setting
lr = 0.1
momentum = 0.
weight_decay = 0
nesterov = False # True

n_epoch = 1000
verbose = True
valid_epoch = 50

h_dim = 3000

model = MLPDecoder(doc_emb_dim=document_vectors.shape[1], num_words=len(word_vectors), h_dim=h_dim).to(device)
model.train()

word_vectors_tensor = torch.FloatTensor(word_vectors).to(device)

opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
# opt = torch.optim.Adam(model.parameters(), lr=lr)
# criterion = nn.MSELoss(reduction='mean')
# criterion = nn.MultiLabelMarginLoss(reduction='mean')
# criterion = nn.MultiLabelSoftMarginLoss(reduction='mean')

results = []
step = 0
clip_value = 1

for epoch in tqdm(range(n_epoch)):    
    train_loss_his = []
    valid_loss_his = []
    
    model.train()

    for data in train_loader:
        doc_embs, target, target_rank = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_rank = target_rank.to(device)
        
        pred = model(doc_embs)     
        loss = listNet(pred, target)
    
        # Model backwarding
        model.zero_grad()
        loss.backward()
        
        opt.step()

        train_loss_his.append(loss.item())
        
    model.eval()
    for data in valid_loader:
        doc_embs, target, target_rank = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_rank = target_rank.to(device)
        
        # MSE loss
        pred = model(doc_embs)     
        loss = listNet(pred, target)
    
        valid_loss_his.append(loss.item())
    
    print("Epoch", epoch, np.mean(train_loss_his), np.mean(valid_loss_his))
    
    if epoch % valid_epoch == 0:
        res = {}
        res['epoch'] = epoch
        
        train_res_ndcg = evaluate_MLPDecoder(model, train_loader)
        valid_res_ndcg = evaluate_MLPDecoder(model, valid_loader)
        
        res.update(valid_res_ndcg)
        results.append(res)
        
        if verbose:
            print()
            print('train', train_res_ndcg)
            print('valid', valid_res_ndcg)


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

Epoch 0 7.513300865017182 6.595380983854595
Time1:  0.48470241390168667
Time2:  0.049540938809514046
Time3:  0.1513405479490757
Time1:  0.05470498837530613
Time2:  0.008028106763958931
Time3:  0.017527904361486435

train {'precision@10': 0.20546531677246094, 'precision@30': 0.11203112453222275, 'precision@50': 0.08418343216180801, 'ndcg@10': 0.2586550712585449, 'ndcg@30': 0.22550149261951447, 'ndcg@50': 0.22232019901275635, 'ndcg@all': 0.436257928609848}
valid {'precision@10': 0.19841688871383667, 'precision@30': 0.1105189174413681, 'precision@50': 0.08373614400625229, 'ndcg@10': 0.22550415992736816, 'ndcg@30': 0.19801877439022064, 'ndcg@50': 0.19583581387996674, 'ndcg@all': 0.41462066769599915}
Epoch 1 5.358762289348402 5.63200549075478
Epoch 2 4.1026947107928535 5.08010987231606
Epoch 3 3.277609840471145 4.737294824499833
Epoch 4 2.723377107876783 4.504819230029457
Epoch 5 2.33578712410397 4.368196023137946
Epoch 6 2.054239586779946 4.242001257444683
Epoch 7 1.87393291582141 4.144072

Epoch 135 1.2575212577630204 3.39735494161907
Epoch 136 1.2587859874580338 3.394058679279528
Epoch 137 1.2588528802520351 3.390517749284443
Epoch 138 1.2577525628240485 3.386471434643394
Epoch 139 1.257191134126563 3.3902717640525415
Epoch 140 1.2598868803671228 3.392557482970388
Epoch 141 1.2574038411441602 3.3858758273877596
Epoch 142 1.2567851951247768 3.387982443759316
Epoch 143 1.2556423123119866 3.387971928245143
Epoch 144 1.2555978876805445 3.3840127995139673
Epoch 145 1.2560276720258925 3.3841958045959473
Epoch 146 1.2556004360405326 3.381834180731522
Epoch 147 1.2567581659869145 3.3834549753289473
Epoch 148 1.256479385303475 3.3832312508633264
Epoch 149 1.256467498533907 3.385368322071276
Epoch 150 1.254506065831547 3.3790754017076994
Time1:  0.4843525681644678
Time2:  0.050489213317632675
Time3:  0.1516621895134449
Time1:  0.05436438135802746
Time2:  0.008164893835783005
Time3:  0.017581962049007416

train {'precision@10': 0.5095291137695312, 'precision@30': 0.259674727916717

Epoch 267 1.2398215828583254 3.340048488817717
Epoch 268 1.239216064849095 3.336559295654297
Epoch 269 1.2403309892492684 3.3403472147489848
Epoch 270 1.2397420873418885 3.339729848660921
Epoch 271 1.238817498697872 3.3366229283182243
Epoch 272 1.2401445023497644 3.3305798078838147
Epoch 273 1.2404422512528492 3.337523999967073
Epoch 274 1.2394634575871697 3.3446221602590462
Epoch 275 1.2390470034197758 3.3426906435113204
Epoch 276 1.238864891710337 3.3461053496912907
Epoch 277 1.2399338863746465 3.333189939197741
Epoch 278 1.2396065725220575 3.3405116482784876
Epoch 279 1.2384741853552255 3.3425873706215308
Epoch 280 1.2401665128462496 3.3421010343652022
Epoch 281 1.239274613689958 3.335429304524472
Epoch 282 1.2377914908336618 3.3359299207988538
Epoch 283 1.2402692858935798 3.329199740761205
Epoch 284 1.239063795198474 3.334661496312995
Epoch 285 1.2417156351240057 3.3340547461258736
Epoch 286 1.2385170776941623 3.341941896237825
Epoch 287 1.240433817369896 3.3330900543614437
Epoch 2

Epoch 401 1.2309200547591985 3.3270285882447896
Epoch 402 1.2316580815621985 3.327397007691233
Epoch 403 1.230797084799984 3.3296976842378316
Epoch 404 1.233496264058944 3.3252651691436768
Epoch 405 1.231392666958926 3.3232551499416956
Epoch 406 1.2317682671965213 3.323264435717934
Epoch 407 1.230333629755946 3.325551760824103
Epoch 408 1.2308810377678676 3.3296040233812834
Epoch 409 1.2307747912685774 3.322053206594367
Epoch 410 1.2318375821699177 3.3204206290997957
Epoch 411 1.2313317966740034 3.32558358343024
Epoch 412 1.2331651693199113 3.325477248743961
Epoch 413 1.2313869421245063 3.3170148573423686
Epoch 414 1.231625259271142 3.330118078934519
Epoch 415 1.231686674711997 3.321652788864939
Epoch 416 1.2313124237701907 3.325735192549856
Epoch 417 1.2298018559377792 3.3304693071465743
Epoch 418 1.2309424988707605 3.325980449977674
Epoch 419 1.2306291516761334 3.328049107601768
Epoch 420 1.2311610729373685 3.3276096394187524
Epoch 421 1.2298539740997447 3.31656981769361
Epoch 422 1.

Epoch 548 1.226323379759203 3.3246301851774516
Epoch 549 1.2266624179499888 3.3389627933502197
Epoch 550 1.2262386561137193 3.324533224105835
Time1:  0.48502988182008266
Time2:  0.05011959932744503
Time3:  0.15183335542678833
Time1:  0.054034704342484474
Time2:  0.008094891905784607
Time3:  0.017591675743460655

train {'precision@10': 0.5650560259819031, 'precision@30': 0.2926054298877716, 'precision@50': 0.2092464715242386, 'ndcg@10': 0.7497748136520386, 'ndcg@30': 0.624396800994873, 'ndcg@50': 0.6043345332145691, 'ndcg@all': 0.7579907178878784}
valid {'precision@10': 0.4669129550457001, 'precision@30': 0.2575901746749878, 'precision@50': 0.18890763819217682, 'ndcg@10': 0.5836119055747986, 'ndcg@30': 0.5022545456886292, 'ndcg@50': 0.48964783549308777, 'ndcg@all': 0.6650370359420776}
Epoch 551 1.2247742968693114 3.328445735730623
Epoch 552 1.2260914427495142 3.3324062824249268
Epoch 553 1.2245571205490513 3.320934069784064
Epoch 554 1.225280921710165 3.3311219466359994
Epoch 555 1.2265

Epoch 680 1.2224910071021633 3.329479556334646
Epoch 681 1.2203986292694047 3.3327376466048393
Epoch 682 1.2203714882421215 3.340544349268863
Epoch 683 1.223043962180266 3.3359009968607047
Epoch 684 1.221206231423986 3.332494760814466
Epoch 685 1.2211447887950473 3.3244613848234477
Epoch 686 1.2197212792976557 3.3350739604548405
Epoch 687 1.220973082801752 3.3295965947602926
Epoch 688 1.2208931965437548 3.331462157400031
Epoch 689 1.2214211490419176 3.3332244471499792
Epoch 690 1.220191514631461 3.3273226838362846
Epoch 691 1.2200945695938423 3.3341881601434005
Epoch 692 1.2203943596945868 3.336782969926533
Epoch 693 1.2210693045666343 3.341597004940635
Epoch 694 1.2219985859435902 3.3382940166874935
Epoch 695 1.2207756770981684 3.339051886608726
Epoch 696 1.2197506727530942 3.336682081222534
Epoch 697 1.2194838004502637 3.3295265498914217
Epoch 698 1.2194753202081423 3.337496017154894
Epoch 699 1.220673490337461 3.3323603554775842
Epoch 700 1.2206607060125696 3.3232415224376477
Time1:

Epoch 812 1.2201024687778184 3.3422907026190507
Epoch 813 1.2185490835479826 3.3447787259754382
Epoch 814 1.2185511449624222 3.34516672084206
Epoch 815 1.2172074049536945 3.3427210857993677
Epoch 816 1.220147162501575 3.3423282347227397
Epoch 817 1.21673537973772 3.3406547998127185
Epoch 818 1.2163659125043635 3.341764286944741
Epoch 819 1.2183237975103813 3.3400083717546964
Epoch 820 1.2175627044069837 3.347992269616378
Epoch 821 1.2167053010031494 3.345720805619892
Epoch 822 1.2177938676019857 3.347747940766184
Epoch 823 1.217124862977636 3.343925852524607
Epoch 824 1.218441060760565 3.34426454493874
Epoch 825 1.2179170205579166 3.3486279688383402
Epoch 826 1.2168966977916964 3.3376299707513106
Epoch 827 1.2160156536520572 3.3502594169817472
Epoch 828 1.2174491934608995 3.3431459351589807
Epoch 829 1.2167638874890512 3.3455091526633813
Epoch 830 1.2164120740360684 3.3476367749665914
Epoch 831 1.2178276738925287 3.3412371434663473
Epoch 832 1.2174692471124973 3.350340667523836
Epoch 8

Epoch 951 1.2145649253973487 3.3610455487903796
Epoch 952 1.2152688133786296 3.3641931132266394
Epoch 953 1.2146847798810367 3.362121443999441
Epoch 954 1.2140061294126232 3.3580689430236816
Epoch 955 1.214541891164947 3.362520933151245
Epoch 956 1.216187037919697 3.3647820573104057
Epoch 957 1.2150691001735934 3.369672097657856
Epoch 958 1.215353424437562 3.3552990963584497
Epoch 959 1.2128165765115393 3.359632680290624
Epoch 960 1.2145664301532053 3.3548501918190405
Epoch 961 1.2138162767677976 3.3554646717874625
Epoch 962 1.214707664927544 3.3602983198667826
Epoch 963 1.2155962387720745 3.3571763038635254
Epoch 964 1.2144718131823846 3.3570073152843274
Epoch 965 1.215750839975145 3.3623277262637488
Epoch 966 1.2160770858240406 3.3554169880716422
Epoch 967 1.2163181444357711 3.35803232694927
Epoch 968 1.216225813006797 3.363519543095639
Epoch 969 1.2144569029584962 3.365216430864836
Epoch 970 1.2140905658404033 3.3676430802596244
Epoch 971 1.2147586181846977 3.3624141592728463
Epoch 

### train {'precision@10': 0.5873629450798035, 'precision@30': 0.3073164224624634, 'precision@50': 0.2198205590248108, 'ndcg@10': 0.7637870907783508, 'ndcg@30': 0.638739287853241, 'ndcg@50': 0.6186253428459167, 'ndcg@all': 0.7680783867835999}
### valid {'precision@10': 0.4790501594543457, 'precision@30': 0.2653825879096985, 'precision@50': 0.19505013525485992, 'ndcg@10': 0.5892416834831238, 'ndcg@30': 0.508570671081543, 'ndcg@50': 0.4960639774799347, 'ndcg@all': 0.6699864268302917}

In [31]:
# def MultiLabelMarginLoss_pos(y_pred, y_pos_id, y_neg_id):
#     """
#     ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach".
#     :param y_pred: predictions from the model, shape [batch_size, slate_length]
#     :param y_true: ground truth labels, shape [batch_size, slate_length]
#     :param eps: epsilon value, used for numerical stability
#     :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
#     :return: loss value, a torch.Tensor
#     """
#     loss = 0
    
#     for i in (range(y_pred.shape[0])):
#         pred, pos_id, neg_id = y_pred[i], y_pos_id[i], y_neg_id[i]
        
# #         idx = (gt == -1).nonzero(as_tuple=True)[0][0]     
# #         pos_id = gt[:idx]
# #         neg_id = set([j for j in range(len(gt))]) - set(pos_id.tolist())
# #         pos_id = torch.tensor(list(pos_id))
# #         neg_id = torch.tensor(list(neg_id))

#         m = pred[pos_id].view(-1, 1) - pred[neg_id].view(1, -1)
#         m = 1 - m
#         l = torch.max(m, torch.zeros(m.shape).to(device))
#         l = torch.sum(l)
        
#         mp = pred[pos_id].view(-1, 1) - pred[pos_id].view(1, -1)
#         mp = 1 - mp - torch.eye(len(pos_id)).to(device)
#         lp = torch.max(mp, torch.zeros(mp.shape).to(device))
#         lp = torch.sum(lp)
#         loss += (l+lp)/y_pred.shape[1]
    
#     loss /= y_pred.shape[0]
    
#     return loss

# x = torch.FloatTensor([[-0.5, 0.2, 0.4, 0.8, 0.9]]).to(device)
# x.requires_grad=True
# y_pos = torch.LongTensor([[3, 1]]).to(device)
# y_neg = torch.LongTensor([[0, 2, 4]]).to(device)

# MultiLabelMarginLoss_pos(x, y_pos, y_neg)

In [32]:
# def MultiLabelMarginLoss_pos(y_pred, y_pos_id, y_neg_id):
#     """
#     ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach".
#     :param y_pred: predictions from the model, shape [batch_size, slate_length]
#     :param y_true: ground truth labels, shape [batch_size, slate_length]
#     :param eps: epsilon value, used for numerical stability
#     :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
#     :return: loss value, a torch.Tensor
#     """
#     loss = 0
    
#     pn_zero = torch.zeros((y_pos_id.shape[1], y_neg_id.shape[1])).to(device)
#     pp_zero = torch.zeros((y_pos_id.shape[1], y_pos_id.shape[1])).to(device)
#     eye = torch.eye(y_pos_id.shape[1]).to(device)
    
#     for i in (range(y_pred.shape[0])):
#         pred, pos_id, neg_id = y_pred[i], y_pos_id[i], y_neg_id[i]
        
# #         idx = (gt == -1).nonzero(as_tuple=True)[0][0]     
# #         pos_id = gt[:idx]
# #         neg_id = set([j for j in range(len(gt))]) - set(pos_id.tolist())
# #         pos_id = torch.tensor(list(pos_id))
# #         neg_id = torch.tensor(list(neg_id))

#         m = pred[pos_id].view(-1, 1) - pred[neg_id].view(1, -1)
#         m = 1 - m
#         l = torch.max(m, pn_zero)
#         l = torch.sum(l)
        
#         mp = pred[pos_id].view(-1, 1) - pred[pos_id].view(1, -1)
#         mp = 1 - mp - eye
#         lp = torch.max(mp, pp_zero)
#         lp = torch.sum(lp)
#         loss += (l+lp)/y_pred.shape[1]
    
#     loss /= y_pred.shape[0]
    
#     return loss

# x = torch.FloatTensor([[-0.5, 0.2, 0.4, 0.8, 0.9]]).to(device)
# x.requires_grad=True
# y_pos = torch.LongTensor([[3, 1]]).to(device)
# y_neg = torch.LongTensor([[0, 2, 4]]).to(device)

# MultiLabelMarginLoss_pos(x, y_pos, y_neg)

In [38]:
def MultiLabelMarginLossPos(y_pred, y_pos_id, y_neg_id, alpha=1):
    """
    MultiLabelMarginLoss add positive pairs
    y_pos_id -> index of positive target, the same as MultiLabelMarginLoss before -1
    y_neg_id -> index of negative target
    alpha -> magnitude of positive pairs compared to negative pairs
    """

    y_pos = y_pred.gather(1, y_pos_id)
    y_neg = y_pred.gather(1, y_neg_id)
    
    m = y_pos.view(y_pos.shape[0], y_pos.shape[1], 1) - y_neg.view(y_neg.shape[0], 1, y_neg.shape[1])
    m = 1 - m
    l = torch.max(m, torch.zeros(m.shape).to(device))
    l = torch.sum(l)
    
    mp = y_pos.view(y_pos.shape[0], y_pos.shape[1], 1) - y_pos.view(y_pos.shape[0], 1, y_pos.shape[1])

    mp = 1 - mp - torch.eye(mp.shape[-1]).to(device)
    lp = torch.max(mp, torch.zeros(mp.shape).to(device))
    lp = torch.sum(lp)
    
    loss = (l + alpha * lp) / y_pred.shape[1] / y_pred.shape[0]
    
    return loss

x = torch.FloatTensor([[-0.5, 0.2, 0.4, 0.8, 0.9], [-0.5, 0.2, 0.4, 0.8, 0.9]]).to(device)
x.requires_grad=True
y_pos = torch.LongTensor([[3, 1], [3, 1]]).to(device)
y_neg = torch.LongTensor([[0, 2, 4], [0, 2, 4]]).to(device)

MultiLabelMarginLoss_pos(x, y_pos, y_neg)

tensor(1.3800, device='cuda:0', grad_fn=<DivBackward0>)

In [35]:
# setting
lr = 0.1
momentum = 0.
weight_decay = 0
nesterov = False # True

n_epoch = 1000
verbose = True
valid_epoch = 50

h_dim = 3000

model = MLPDecoder(doc_emb_dim=document_vectors.shape[1], num_words=len(word_vectors), h_dim=h_dim).to(device)
model.train()

word_vectors_tensor = torch.FloatTensor(word_vectors).to(device)

opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
# opt = torch.optim.Adam(model.parameters(), lr=lr)
# criterion = nn.MSELoss(reduction='mean')
# criterion = nn.MultiLabelMarginLoss(reduction='mean')
# criterion = nn.MultiLabelSoftMarginLoss(reduction='mean')

results = []
step = 0
clip_value = 1

for epoch in tqdm(range(n_epoch)):    
    train_loss_his = []
    valid_loss_his = []
    
    model.train()

    for data in train_loader2:
        doc_embs, target, target_pos_rank, target_neg_rank = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_pos_rank = target_pos_rank.to(device)
        target_neg_rank = target_neg_rank.to(device)
        
        pred = model(doc_embs)
        loss = MultiLabelMarginLossPos(pred, target_pos_rank, target_neg_rank)
    
        # Model backwarding
        model.zero_grad()
        loss.backward()
        
        opt.step()

        train_loss_his.append(loss.item())
        
    model.eval()
    for data in valid_loader2:
        doc_embs, target, target_pos_rank, target_neg_rank = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_pos_rank = target_pos_rank.to(device)
        target_neg_rank = target_neg_rank.to(device)
        
        # MSE loss
        pred = model(doc_embs)     
        loss = MultiLabelMarginLossPos(pred, target_pos_rank, target_neg_rank)
    
        valid_loss_his.append(loss.item())
    
    print("Epoch", epoch, np.mean(train_loss_his), np.mean(valid_loss_his))
    
    if epoch % valid_epoch == 0:
        res = {}
        res['epoch'] = epoch
        
        train_res_ndcg = evaluate_MLPDecoder(model, train_loader)
        valid_res_ndcg = evaluate_MLPDecoder(model, valid_loader)
        
        res.update(valid_res_ndcg)
        results.append(res)
        
        if verbose:
            print()
            print('train', train_res_ndcg)
            print('valid', valid_res_ndcg)


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

Epoch 0 14.563985752083404 10.376786181801243
Time1:  0.48530817590653896
Time2:  0.05050884373486042
Time3:  0.15140090323984623
Time1:  0.053982071578502655
Time2:  0.008150896057486534
Time3:  0.017509274184703827

train {'precision@10': 0.19220079481601715, 'precision@30': 0.14422878623008728, 'precision@50': 0.127905935049057, 'ndcg@10': 0.13978461921215057, 'ndcg@30': 0.14641310274600983, 'ndcg@50': 0.16017544269561768, 'ndcg@all': 0.4227375090122223}
valid {'precision@10': 0.1878627985715866, 'precision@30': 0.14059807360172272, 'precision@50': 0.12414774298667908, 'ndcg@10': 0.1334739476442337, 'ndcg@30': 0.13958598673343658, 'ndcg@50': 0.15216395258903503, 'ndcg@all': 0.4133604168891907}
Epoch 1 7.810711007369192 8.699836505086799
Epoch 2 5.927192278075636 7.943102284481651
Epoch 3 4.8852589590507645 7.908320502230995
Epoch 4 4.240541916841652 7.141800428691663
Epoch 5 3.731580762138144 7.152816948137786
Epoch 6 3.400951780073824 6.930462134511847
Epoch 7 3.0964213867633665 7.

Epoch 136 0.6698262088480051 6.203723531020315
Epoch 137 0.6667650558098018 6.193453613080476
Epoch 138 0.6658122905513697 6.207250419415925
Epoch 139 0.6657338449132373 6.204200945402446
Epoch 140 0.6644852335690058 6.197366187446995
Epoch 141 0.6632090256228085 6.219700236069529
Epoch 142 0.6622286362257618 6.209369057103207
Epoch 143 0.6597332086479455 6.209658848611932
Epoch 144 0.6586667362012362 6.2036137329904655
Epoch 145 0.6576474455364963 6.2206880669844775
Epoch 146 0.6576867507912262 6.220217027162251
Epoch 147 0.6575241548973217 6.2210039841501334
Epoch 148 0.6548841532907987 6.207953578547428
Epoch 149 0.6527762353768822 6.209484602275648
Epoch 150 0.6531577131204438 6.216138438174599
Time1:  0.48524813912808895
Time2:  0.05209263786673546
Time3:  0.15165714174509048
Time1:  0.05401102639734745
Time2:  0.008294003084301949
Time3:  0.017528505995869637

train {'precision@10': 0.6717938184738159, 'precision@30': 0.6415860652923584, 'precision@50': 0.620295524597168, 'ndcg@1

Epoch 269 0.5748184703246891 6.3223183531510205
Epoch 270 0.5732010913174055 6.346611851140072
Epoch 271 0.5727275069694073 6.329144678617778
Epoch 272 0.5720968389371682 6.343059012764378
Epoch 273 0.5716253048495242 6.332570326955695
Epoch 274 0.571628867534169 6.33752253181056
Epoch 275 0.5711224622893751 6.338383900491815
Epoch 276 0.5716086615595901 6.329749433617843
Epoch 277 0.5713306787418343 6.336130368082147
Epoch 278 0.5701666599825809 6.355755078165155
Epoch 279 0.5698894422653823 6.331891059875488
Epoch 280 0.5691483449517635 6.3355174566570085
Epoch 281 0.5697231902713664 6.3384644357781665
Epoch 282 0.5690673722161187 6.3344847277591105
Epoch 283 0.5689052716333266 6.336974420045552
Epoch 284 0.5679056919806185 6.3430041262978
Epoch 285 0.5676511435480843 6.340524798945377
Epoch 286 0.5673437097616363 6.3394979677702255
Epoch 287 0.5668299804654038 6.347302361538536
Epoch 288 0.5661685337797243 6.344789354424727
Epoch 289 0.5661221317380493 6.354895516445763
Epoch 290 0.

Epoch 402 0.5360968207058153 6.426361611014919
Epoch 403 0.535529551450272 6.437639487417121
Epoch 404 0.5345350737460175 6.426363493266859
Epoch 405 0.5345262191448993 6.432425348382247
Epoch 406 0.5340991347853901 6.437035535511217
Epoch 407 0.5341279684451589 6.440703818672581
Epoch 408 0.534466811090882 6.442030304356625
Epoch 409 0.533837738441445 6.434668013924046
Epoch 410 0.5331015579881724 6.435083439475612
Epoch 411 0.533100886651647 6.443674890618575
Epoch 412 0.5326475288435729 6.441915938728734
Epoch 413 0.5326759578191747 6.436111726258931
Epoch 414 0.5331718416241874 6.43978718707436
Epoch 415 0.5326414948318436 6.452534023084138
Epoch 416 0.5312365746637534 6.4446800382513745
Epoch 417 0.5318717287297834 6.439502716064453
Epoch 418 0.5306436206862243 6.431755517658434
Epoch 419 0.5328308894620304 6.430754260012978
Epoch 420 0.5313421609806038 6.450653879266036
Epoch 421 0.5307786726115042 6.436202149642141
Epoch 422 0.5308955886907745 6.440694030962493
Epoch 423 0.53032

Epoch 549 0.5111620441863411 6.513813043895521
Epoch 550 0.511207163682458 6.5191529675533895
Time1:  0.4846885930746794
Time2:  0.051844263449311256
Time3:  0.15175271220505238
Time1:  0.054115207865834236
Time2:  0.008264893665909767
Time3:  0.01750105433166027

train {'precision@10': 0.7130651473999023, 'precision@30': 0.7024140954017639, 'precision@50': 0.6832709908485413, 'ndcg@10': 0.5418848991394043, 'ndcg@30': 0.6450462937355042, 'ndcg@50': 0.7281717658042908, 'ndcg@all': 0.7548179030418396}
valid {'precision@10': 0.433878630399704, 'precision@30': 0.30916449427604675, 'precision@50': 0.2619524896144867, 'ndcg@10': 0.39896899461746216, 'ndcg@30': 0.3806408941745758, 'ndcg@50': 0.39512142539024353, 'ndcg@all': 0.595841646194458}
Epoch 551 0.5112997659465723 6.516752795169228
Epoch 552 0.5109029860175841 6.515046069496556
Epoch 553 0.5106567167050657 6.510393845407586
Epoch 554 0.5108750713498968 6.517717235966733
Epoch 555 0.5098942197554293 6.521425523256001
Epoch 556 0.5102649

Epoch 682 0.4968184088057245 6.5664161380968595
Epoch 683 0.4966203321141806 6.559604669872083
Epoch 684 0.4957735998588696 6.5670317348681
Epoch 685 0.4969809689716986 6.585301349037572
Epoch 686 0.49651797647364654 6.573781565616005
Epoch 687 0.49605552756298354 6.57199407878675
Epoch 688 0.4968748143193317 6.5703005288776595
Epoch 689 0.4964337578991003 6.588527353186357
Epoch 690 0.49631593485324704 6.574982517643979
Epoch 691 0.49542177699462714 6.588081384959974
Epoch 692 0.49568830996926067 6.5684763255872225
Epoch 693 0.4952526265069058 6.564931969893606
Epoch 694 0.4955050770999395 6.572181751853542
Epoch 695 0.4953367879167635 6.584192953611675
Epoch 696 0.4953707862667173 6.575858015763132
Epoch 697 0.49526453837316636 6.584512058057283
Epoch 698 0.49582721615395353 6.573223114013672
Epoch 699 0.49518057006841515 6.570089440596731
Epoch 700 0.49488093473060785 6.575203895568848
Time1:  0.48433514684438705
Time2:  0.051492564380168915
Time3:  0.15164403058588505
Time1:  0.054

Epoch 814 0.48587085733636776 6.62394410685489
Epoch 815 0.4861821987127003 6.613210050683272
Epoch 816 0.48552731794920584 6.628408607683684
Epoch 817 0.48597763714037445 6.626312757793226
Epoch 818 0.48645889201359443 6.6299744154277604
Epoch 819 0.4859402482969719 6.645490219718532
Epoch 820 0.4850939224686539 6.623635869277151
Epoch 821 0.4856348408941637 6.61234019931994
Epoch 822 0.48620637327606914 6.623444030159398
Epoch 823 0.4859933096762986 6.622115110096178
Epoch 824 0.4855875357201225 6.612985284704911
Epoch 825 0.48432909466369806 6.611190218674509
Epoch 826 0.48585493226497495 6.62050209547344
Epoch 827 0.4853754216118863 6.620643465142501
Epoch 828 0.4853215841521994 6.6399109238072445
Epoch 829 0.4854461742423431 6.610394076297157
Epoch 830 0.4852507992794639 6.62677714699193
Epoch 831 0.4849030670018224 6.624011140120657
Epoch 832 0.485117492968576 6.620534771367123
Epoch 833 0.4845675637847499 6.62547977347123
Epoch 834 0.484184787113067 6.623802486218904
Epoch 835 0

Epoch 951 0.4768064133605065 6.657757031290155
Epoch 952 0.47730435526858994 6.653310901240299
Epoch 953 0.4782184268298902 6.65786095669395
Epoch 954 0.47731283917064554 6.6505044886940405
Epoch 955 0.4776344809964386 6.652169277793483
Epoch 956 0.47756186580797383 6.659525394439697
Epoch 957 0.47697924870496605 6.654048919677734
Epoch 958 0.47809444812306184 6.654305734132466
Epoch 959 0.47751660165730975 6.658536735333894
Epoch 960 0.4778016666222734 6.669397128255744
Epoch 961 0.47748891564837675 6.675879127100894
Epoch 962 0.4769464391365386 6.661111455214651
Epoch 963 0.47677228412432976 6.655846269507157
Epoch 964 0.4770177658538372 6.660222179011295
Epoch 965 0.47776188052188584 6.659768731970536
Epoch 966 0.4765037615396823 6.661595018286454
Epoch 967 0.47692857360282137 6.661310221019544
Epoch 968 0.47755426621576497 6.655500186117072
Epoch 969 0.4764488918042322 6.662264648236726
Epoch 970 0.4767209005634687 6.653168126156456
Epoch 971 0.4763489755970693 6.653682231903076
Ep

### train {'precision@10': 0.754506528377533, 'precision@30': 0.7356379628181458, 'precision@50': 0.7033014893531799, 'ndcg@10': 0.5550932884216309, 'ndcg@30': 0.6669020056724548, 'ndcg@50': 0.7454498410224915, 'ndcg@all': 0.7618288993835449}
### valid {'precision@10': 0.4426385462284088, 'precision@30': 0.3160949945449829, 'precision@50': 0.26678627729415894, 'ndcg@10': 0.40172940492630005, 'ndcg@30': 0.3851601779460907, 'ndcg@50': 0.39924025535583496, 'ndcg@all': 0.5973230600357056}