### raw data
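* word embedding: GloVe (`../data/glove.6B.100d.txt`, set by `config["embedding_file"]`)
* doc text: one document per line in `../data/IMDB.txt`, `../data/CNN.txt` or `../data/PubMed.txt`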

### dataset
1. IMDB
2. CNNNews
3. [PubMed](https://github.com/LIAAD/KeywordExtractor-Datasets/blob/master/datasets/PubMed.zip)

### preprocess
1. filter out words that are too frequent or too rare
2. stemming
3. document vector aggregation

### model
1. TopK
2. Sklearn
3. Our model

### evaluation
1. F1
2. NDCG

In [1]:
import os
from collections import defaultdict
import math
import numpy as np 
import random
import re
import torch
import torch.nn as nn
from itertools import cycle
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
from tqdm.auto import tqdm

# Used to get the data
from sklearn.metrics import ndcg_score

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import nltk
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords
from nltk.collocations import BigramAssocMeasures, BigramCollocationFinder
nltk.download('stopwords')

import matplotlib.pyplot as plt 
import matplotlib
matplotlib.use('Agg')

seed = 33
import pandas as pd

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/chrisliu/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


## Preprocess config

In [2]:
# Preprocessing / evaluation settings for this run.
config = {}

config["dataset"] = "CNN"  # one of "IMDB", "CNN", "PubMed"
config["n_document"] = 10000  # at most this many documents are turned into vectors
config["normalize_word_embedding"] = True  # rescale every word embedding to unit L2 norm
config["min_word_freq_threshold"] = 20  # drop words seen fewer times than this in the corpus
config["topk_word_freq_threshold"] = 100  # drop this many lowest-IDF (most widespread) words
config["document_vector_agg_weight"] = 'IDF'  # ['mean', 'IDF', 'uniform', 'gaussian', 'exponential', 'pmi']
# True: divide the weighted sum by the total weight (weighted mean);
# False: keep the plain weighted sum.
config["document_vector_weight_normalize"] = False
config["select_topk_TFIDF"] = None  # None disables per-document top-TFIDF word selection
config["embedding_file"] = "../data/glove.6B.100d.txt"
config["topk"] = [10, 30, 50]  # cutoffs k for F1@k / NDCG@k


In [3]:
def in_notebook():
    """Return True when running inside a Jupyter/IPython kernel.

    Returns False when IPython is not installed, when there is no active
    IPython shell (``get_ipython()`` returns None, e.g. a plain ``python``
    process), or when the shell config has no ``IPKernelApp`` section.
    """
    try:
        from IPython import get_ipython
    except ImportError:
        return False
    shell = get_ipython()
    # Outside any IPython session get_ipython() is None; the previous version
    # dereferenced it unconditionally and raised AttributeError.
    if shell is None:  # pragma: no cover
        return False
    return 'IPKernelApp' in shell.config

In [4]:
def load_word2emb(embedding_file):
    """Load a whitespace-separated text embedding file (GloVe format).

    Each non-empty line is ``word v1 v2 ... vD``.

    Args:
        embedding_file: path to the embedding text file.

    Returns:
        dict mapping each word (first field of a line) to a 1-D float
        ``np.ndarray`` built from the remaining fields.
    """
    word2embedding = dict()

    # Read as UTF-8 explicitly (as read_raw does for the corpora) rather than
    # relying on the platform's default encoding.
    with open(embedding_file, "r", encoding="utf-8") as f:
        for line in tqdm(f):
            fields = line.split()
            if not fields:  # tolerate blank lines instead of raising IndexError
                continue
            word2embedding[fields[0]] = np.array(list(map(float, fields[1:])))

    print("Number of words:%d" % len(word2embedding))

    return word2embedding

# Load the pre-trained word embeddings named in the config.
word2embedding = load_word2emb(embedding_file=config["embedding_file"])

0it [00:00, ?it/s]

Number of words:400000


In [5]:
def normalize_wordemb(word2embedding):
    """Scale every embedding in *word2embedding* to unit L2 norm, in place.

    All-zero vectors are left as zeros; the previous version divided them by
    a zero norm and stored NaNs.

    Args:
        word2embedding: dict mapping word -> 1-D array; all vectors must have
            the same length.

    Returns:
        The same dict object, whose values are replaced by normalized rows.
    """
    if not word2embedding:
        return word2embedding

    words = list(word2embedding)
    # Stack into a (num_words, dim) float matrix and normalize all rows at once.
    matrix = np.array([word2embedding[word] for word in words], dtype=float)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # avoid 0/0 for zero vectors
    matrix /= norms

    for word, emb in zip(words, matrix):
        word2embedding[word] = emb
    return word2embedding

# Optionally rescale every word vector to unit length (the dict is updated in place).
word2embedding = normalize_wordemb(word2embedding) if config["normalize_word_embedding"] else word2embedding

0it [00:00, ?it/s]

In [6]:
class Vocabulary:
    """Stemmed vocabulary, IDF table and weighted document vectors for a corpus.

    Pipeline (as used by ``build_vocab``): ``read_raw`` -> ``build_vocabulary``
    -> ``calculate_document_vector`` -> ``check_docemb``.

    Index 0 is reserved for ``<UNK>``; its row in ``word_vectors`` is zeros.
    """

    # Corpus file for each supported ``config["dataset"]`` value.
    DATA_FILES = {
        'IMDB': '../data/IMDB.txt',
        'CNN': '../data/CNN.txt',
        'PubMed': '../data/PubMed.txt',
    }

    # Characters removed before splitting. Whitespace of any kind is kept so
    # that tab- or otherwise-separated words are not glued together.
    _NON_ALNUM = re.compile(r'[^A-Za-z0-9\s]+')

    def __init__(self, word2embedding, config):
        # The low frequency words will be assigned as <UNK> token
        self.itos = {0: "<UNK>"}
        self.stoi = {"<UNK>": 0}

        self.word2embedding = word2embedding
        self.config = config

        self.word_freq_in_corpus = defaultdict(int)  # token -> occurrences in corpus
        self.IDF = {}  # token -> log(N / (doc_freq + 1))
        self.ps = PorterStemmer()
        self.stop_words = set(stopwords.words('english'))

        # Embedding dimension, read from any entry (no need for 'the' to exist).
        self.word_dim = len(next(iter(word2embedding.values())))

    def __len__(self):
        return len(self.itos)

    def tokenizer_eng(self, text):
        """Drop punctuation, remove English stop words, Porter-stem the rest."""
        text = self._NON_ALNUM.sub('', text)
        return [self.ps.stem(w) for w in text.split() if w.lower() not in self.stop_words]

    def read_raw(self):
        """Read one document per line from the corpus file of ``config["dataset"]``.

        Returns:
            The list of raw documents (also stored as ``self.raw_documents``).

        Raises:
            ValueError: if the dataset name is not a key of ``DATA_FILES``.
        """
        dataset = self.config["dataset"]
        try:
            data_file_path = self.DATA_FILES[dataset]
        except KeyError:
            raise ValueError(
                f"Unknown dataset {dataset!r}; expected one of {sorted(self.DATA_FILES)}"
            ) from None

        with open(data_file_path, 'r', encoding='utf-8') as f:
            self.raw_documents = [line.strip("\n") for line in tqdm(f, desc="Loading documents")]

        return self.raw_documents

    def build_vocabulary(self):
        """Count tokens, compute IDF, prune rare and too-common words, index the rest.

        Requires ``read_raw`` first. Populates ``doc_freq``, ``document_num``,
        ``IDF``, ``word_freq_in_corpus``, ``stoi``/``itos`` and
        ``word_vectors`` (row i is the embedding of ``itos[i]``).
        """
        self._count_words(self.raw_documents)
        self._compute_idf()
        self._drop_rare_words()
        self._drop_most_common_words()
        self._index_words()

    def _count_words(self, sentence_list):
        """Corpus and document frequency of tokens that have an embedding."""
        self.doc_freq = defaultdict(int)  # number of documents a word appears in
        self.document_num = len(sentence_list)

        for sentence in tqdm(sentence_list, desc="Preprocessing documents"):
            document_words = set()
            for word in self.tokenizer_eng(sentence):
                if word not in self.word2embedding:  # skip words without embedding
                    continue
                self.word_freq_in_corpus[word] += 1
                document_words.add(word)
            for word in document_words:
                self.doc_freq[word] += 1

    def _compute_idf(self):
        """IDF with +1 smoothing on the document frequency."""
        print('doc num', self.document_num)
        for word, freq in self.doc_freq.items():
            self.IDF[word] = math.log(self.document_num / (freq + 1))

    def _drop_rare_words(self):
        """Remove words seen fewer than ``min_word_freq_threshold`` times."""
        min_freq = self.config["min_word_freq_threshold"]
        rare = [word for word, count in self.word_freq_in_corpus.items() if count < min_freq]
        for word in rare:
            del self.IDF[word]
            del self.word_freq_in_corpus[word]

    def _drop_most_common_words(self):
        """Remove the ``topk_word_freq_threshold`` lowest-IDF words, printing each.

        At most the whole remaining vocabulary is removed (previously a
        threshold larger than the vocabulary raised IndexError).
        """
        print('eliminate freq words')
        by_idf = sorted(self.IDF.items(), key=lambda item: item[1])
        n_drop = min(self.config["topk_word_freq_threshold"], len(by_idf))
        for word, _ in by_idf[:n_drop]:
            # Print the word actually being removed (the old code printed the
            # previous loop's word before reassigning it).
            print(word)
            del self.IDF[word]
            del self.word_freq_in_corpus[word]

    def _index_words(self):
        """Assign indices 1..V to the surviving words and collect their vectors."""
        self.word_vectors = [[0] * self.word_dim]  # row 0: <UNK>
        for idx, word in enumerate(self.word_freq_in_corpus, start=1):
            self.word_vectors.append(self.word2embedding[word])
            self.stoi[word] = idx
            self.itos[idx] = word

    def init_word_weight(self, sentence_list, agg):
        """Set ``self.word_weight`` (word -> aggregation weight) according to *agg*.

        Args:
            sentence_list: raw documents; only used by the 'pmi' scheme.
            agg: 'mean' (all 1), 'IDF' (shares ``self.IDF``), 'uniform',
                'gaussian', 'exponential' (random draws per vocabulary word)
                or 'pmi' (sum of bigram PMI scores a token takes part in).

        Raises:
            ValueError: for an unknown *agg* (previously ``word_weight`` was
                left unset and later lookups failed).
        """
        if agg == 'mean':
            self.word_weight = {word: 1 for word in self.IDF.keys()}
        elif agg == 'IDF':
            self.word_weight = self.IDF
        elif agg == 'uniform':
            self.word_weight = {word: np.random.uniform(low=0.0, high=1.0) for word in self.IDF.keys()}
        elif agg == 'gaussian':
            mu, sigma = 10, 1  # mean and standard deviation
            self.word_weight = {word: np.random.normal(mu, sigma) for word in self.IDF.keys()}
        elif agg == 'exponential':
            self.word_weight = {word: np.random.exponential(scale=1.0) for word in self.IDF.keys()}
        elif agg == 'pmi':
            bigram_measures = BigramAssocMeasures()
            self.word_weight = defaultdict(int)
            corpus = []
            for text in tqdm(sentence_list):
                # Use the same stemmed tokens as the vocabulary; raw
                # ``text.split()`` tokens never matched vocabulary keys, so
                # almost every weight was 0.
                corpus.extend(self.tokenizer_eng(text))

            finder = BigramCollocationFinder.from_words(corpus)
            for (first, second), score in finder.score_ngrams(bigram_measures.pmi):
                self.word_weight[first] += score
                self.word_weight[second] += score
        else:
            raise ValueError(f"Unknown document_vector_agg_weight {agg!r}")

    def calculate_document_vector(self):
        """Build one weighted-embedding vector per document.

        The first ``config["n_document"]`` raw documents are processed. Each
        document's in-vocabulary tokens (optionally only those among its
        ``config["select_topk_TFIDF"]`` highest summed-IDF words) are combined
        as ``sum(word_weight[w] * embedding[w])``, divided by the total weight
        when ``config["document_vector_weight_normalize"]`` is true. Documents
        without a usable token are printed and skipped.

        Returns:
            document_vectors: list of 1-D arrays.
            document_answers_idx: per document, vocabulary index of every kept
                token occurrence (duplicates preserved).
            document_answers_wsum: per document, the total word weight.
        """
        document_vectors = []
        document_answers = []
        document_answers_wsum = []

        sentence_list = self.raw_documents
        agg = self.config["document_vector_agg_weight"]
        n_document = self.config["n_document"]
        select_topk_TFIDF = self.config["select_topk_TFIDF"]

        self.init_word_weight(sentence_list, agg)
        for sentence in tqdm(sentence_list[:n_document], desc="calculate document vectors"):
            document_vector = np.zeros(len(self.word_vectors[0]))
            select_words = [word for word in self.tokenizer_eng(sentence) if word in self.stoi]

            if select_topk_TFIDF is not None:
                # Per-document TF-IDF: IDF summed over the word's occurrences.
                doc_TFIDF = defaultdict(float)
                for word in select_words:
                    doc_TFIDF[word] += self.IDF[word]
                ranked = sorted(doc_TFIDF.items(), key=lambda x: x[1], reverse=True)
                keep = {word for word, _ in ranked[:select_topk_TFIDF]}
                select_words = [word for word in select_words if word in keep]

            if len(select_words) == 0:
                print('error', sentence)
                continue

            total_weight = 0
            for word in select_words:
                weight = self.word_weight[word]
                document_vector += np.array(self.word2embedding[word]) * weight
                total_weight += weight

            if self.config["document_vector_weight_normalize"]:
                document_vector /= total_weight

            document_vectors.append(document_vector)
            document_answers.append(select_words)
            document_answers_wsum.append(total_weight)

        document_answers_idx = [
            [self.stoi[token] for token in ans if token in self.stoi]
            for ans in document_answers
        ]

        self.document_vectors = document_vectors
        self.document_answers_idx = document_answers_idx
        self.document_answers_wsum = document_answers_wsum

        return document_vectors, document_answers_idx, document_answers_wsum

    def numericalize(self, text):
        """Tokenize *text* and map each token to its index (0 for unknown)."""
        unk = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk) for token in self.tokenizer_eng(text)]

    def check_docemb(self):
        """Sanity check: rebuild document 0 from ``word_vectors`` and compare.

        Asserts the rebuilt vector matches ``document_vectors[0]``
        element-wise within float tolerance (the old check compared only the
        two vectors' sums, with exact equality). Does nothing when no
        document vector was built.
        """
        if len(self.document_vectors) == 0:
            return
        word_vectors = np.array(self.word_vectors)
        pred = np.zeros(word_vectors.shape[1])
        cnt = 0

        for word_idx in self.document_answers_idx[0]:
            weight = self.word_weight[self.itos[word_idx]]
            pred += word_vectors[word_idx] * weight
            cnt += weight

        if self.config["document_vector_weight_normalize"]:
            pred /= cnt
        assert np.allclose(self.document_vectors[0], pred), "document vector 0 does not match its words"
        assert np.sum(self.document_vectors[0]) - np.sum(pred) == 0

In [7]:
def build_vocab(config, word2embedding):
    """Read the configured corpus and build the vocabulary and document vectors.

    Args:
        config: preprocessing settings (dataset, thresholds, aggregation, ...).
        word2embedding: dict word -> embedding vector.

    Returns:
        A ``Vocabulary`` with ``document_vectors``, ``document_answers_idx``
        and ``document_answers_wsum`` populated and sanity-checked.
    """
    vocab = Vocabulary(word2embedding, config)
    vocab.read_raw()
    vocab.build_vocabulary()
    # get doc emb
    vocab.calculate_document_vector()
    vocab.check_docemb()

    return vocab

# Build the vocabulary and document vectors for the configured dataset.
vocab = build_vocab(config=config, word2embedding=word2embedding)

Loading documents: 0it [00:00, ?it/s]

Preprocessing documents:   0%|          | 0/19026 [00:00<?, ?it/s]

doc num 19026
eliminate freq words
paroxysm
subject
line
organ
write
univers
one
would
use
like
get
know
dont
think
time
make
also
say
go
im
could
want
new
work
good
well
way
need
look
even
anyon
thing
see
tri
thank
much
year
world
system
right
problem
may
take
mani
two
first
seem
question
pleas
1
state
us
come
2
post
help
call
usa
point
sinc
find
read
still
back
mean
ive
give
email
sure
differ
might
run
cant
reason
last
day
interest
case
let
person
said
never
start
doesnt
tell
better
ask
got
without
follow
part
lot
3
number
put
fact
gener
inform
actual
that


calculate document vectors:   0%|          | 0/10000 [00:00<?, ?it/s]

error |> 
error 
error 
error Mikael Fredriksson
error 
error -------------------------------------------------
error email: mikael_fredriksson@macexchange.se
error 
error FIDO 2:203/211
error  
error  


In [8]:
# Summary of the built dataset.
print("Finish building dataset!")
print("Number of documents:{}".format(len(vocab.raw_documents)))
print("Number of words:{}".format(len(vocab)))

# Number of kept token occurrences per embedded document.
l = [len(answer) for answer in vocab.document_answers_idx]
print("Average length of document:", np.mean(l))

Finish building dataset!
Number of documents:19026
Number of words:7602
Average length of document: 86.49083992391631


In [9]:
# Dense arrays: word embeddings, document vectors, and per-document total
# word weight as a column vector.
word_vectors = np.array(vocab.word_vectors)
print(word_vectors.shape)

document_vectors = np.array(vocab.document_vectors)
print(document_vectors.shape)

document_answers_wsum = np.array(vocab.document_answers_wsum).reshape(-1, 1)
print(document_answers_wsum.shape)

document_answers_idx = vocab.document_answers_idx

# Apply one seeded permutation to every per-document structure.
shuffle_idx = list(range(len(document_vectors)))
random.Random(seed).shuffle(shuffle_idx)

document_vectors, document_answers_wsum = document_vectors[shuffle_idx], document_answers_wsum[shuffle_idx]
document_answers_idx = [document_answers_idx[j] for j in shuffle_idx]

(7602, 100)
(9989, 100)
(9989, 1)


In [10]:
# onehot_ans[i, w]: number of occurrences of word w in document i (term frequency).
# weight_ans[i, w]: the same occurrences weighted by vocab.word_weight
#                   (a TF-IDF matrix when the aggregation weight is 'IDF').

onehot_ans = np.zeros((len(document_answers_idx), word_vectors.shape[0]))
weight_ans = np.zeros_like(onehot_ans)
print(weight_ans.shape)

for i in tqdm(range(len(document_answers_idx))):
    for word_idx in document_answers_idx[i]:
        onehot_ans[i, word_idx] += 1
        weight_ans[i, word_idx] += vocab.word_weight[vocab.itos[word_idx]]

(9989, 7602)


  0%|          | 0/9989 [00:00<?, ?it/s]

## Results

In [11]:
# Collected per-model result rows and the metric columns reported for each.
final_results = []
select_columns = (
    ['model']
    + ['F1@{}'.format(k) for k in config["topk"]]
    + ['ndcg@{}'.format(k) for k in config["topk"]]
    + ['ndcg@all']
)
select_columns

['model',
 'F1@10',
 'F1@30',
 'F1@50',
 'ndcg@10',
 'ndcg@30',
 'ndcg@50',
 'ndcg@all']

## Set training size

In [12]:
train_size_ratio = 1  # use every document
train_size = int(train_size_ratio * len(document_answers_idx))
train_size

9989

## Top K freq word

In [13]:
topk_results = dict()  # metrics of the top-k frequent-word baseline

In [14]:
# Ground-truth word indices of the documents used for evaluation.
test_ans = document_answers_idx[0:train_size]

In [15]:
# (word, corpus count) pairs, most frequent first.
word_freq = sorted(vocab.word_freq_in_corpus.items(), key=lambda pair: pair[1], reverse=True)
word_freq[:10]

[('x', 6539),
 ('god', 5208),
 ('file', 4918),
 ('0', 4520),
 ('window', 4444),
 ('program', 4201),
 ('drive', 3633),
 ('4', 3528),
 ('game', 3474),
 ('govern', 3268)]

In [16]:
def topk_word_evaluation(k=50):
    """F1 of the baseline that predicts the *k* most frequent corpus words.

    Uses the notebook globals ``word_freq`` ((word, count) pairs sorted by
    count, descending), ``test_ans`` (per-document word-index lists) and
    ``vocab``. Precision (hits / k) and recall (hits / distinct document
    words) are averaged over documents; F1 is computed from those averages.
    A document with no words contributes 0 precision and 0 recall instead of
    raising ZeroDivisionError.

    Returns:
        The F1 score (0 when precision and recall are both 0).
    """
    topk_words = {word for word, _ in word_freq[:k]}  # set for O(1) lookups

    precisions, recalls = [], []
    for ans in tqdm(test_ans):
        ans_words = {vocab.itos[idx] for idx in ans}
        n_hit = len(ans_words & topk_words)
        precisions.append(n_hit / k)
        recalls.append(n_hit / len(ans_words) if ans_words else 0.0)

    # Local names deliberately avoid shadowing the ``re`` module.
    pr = np.mean(precisions)
    rec = np.mean(recalls)
    f1 = 2 * pr * rec / (pr + rec) if (pr + rec) != 0 else 0
    print('top {} word'.format(k))
    print('percision', pr)
    print('recall', rec)
    print('F1', f1)
    return f1


for topk in config['topk']:
    topk_results[f"F1@{topk}"] = topk_word_evaluation(k=topk)


  0%|          | 0/9989 [00:00<?, ?it/s]

top 10 word
percision 0.07090799879867855
recall 0.016414023763480924
F1 0.02665732064265653


  0%|          | 0/9989 [00:00<?, ?it/s]

top 30 word
percision 0.07792238128608137
recall 0.04873623777882011
F1 0.05996660520519939


  0%|          | 0/9989 [00:00<?, ?it/s]

top 50 word
percision 0.07793572930223247
recall 0.07962461018581644
F1 0.07877111823192375


In [17]:
def topk_word_evaluation_NDCG(k=50):
    """Mean NDCG@k of the corpus-frequency ranking baseline.

    Every vocabulary word gets a score that decreases with its corpus
    frequency rank (from the global ``word_freq``), and that same ranking is
    used for every document in ``test_ans``. A word's relevance in a document
    is its ``vocab.word_weight`` summed over its occurrences. This is the
    same construction as the global ``weight_ans`` matrix used to evaluate
    the other models. The previous version used ``vocab.IDF``, which only
    agreed with it when ``document_vector_agg_weight == 'IDF'``.

    Args:
        k: cutoff passed to ``sklearn.metrics.ndcg_score``; ``None`` means all.

    Returns:
        NDCG averaged over documents.
    """
    n_words = len(vocab.word_vectors)
    ranked_idx = [vocab.stoi[word] for word, _ in word_freq if word in vocab.stoi]

    scores = np.zeros(n_words)
    for rank, idx in enumerate(ranked_idx):
        scores[idx] = n_words - rank
    scores = scores.reshape(1, -1)

    ndcgs = []
    for ans in tqdm(test_ans):
        relevance = np.zeros(n_words)
        for word_idx in ans:
            if word_idx == 0:  # <UNK> carries no weight
                continue
            relevance[word_idx] += vocab.word_weight[vocab.itos[word_idx]]
        ndcgs.append(ndcg_score(relevance.reshape(1, -1), scores, k=k))

    mean_ndcg = np.mean(ndcgs)
    print('top {} NDCG:{}'.format(k, mean_ndcg))

    return mean_ndcg


for topk in config['topk']:
    topk_results[f"ndcg@{topk}"] = topk_word_evaluation_NDCG(k=topk)
topk_results["ndcg@all"] = topk_word_evaluation_NDCG(k=None)


  0%|          | 0/9989 [00:00<?, ?it/s]

top 10 NDCG:0.02843140685579532


  0%|          | 0/9989 [00:00<?, ?it/s]

top 30 NDCG:0.03790739702719704


  0%|          | 0/9989 [00:00<?, ?it/s]

top 50 NDCG:0.046074423314069815


  0%|          | 0/9989 [00:00<?, ?it/s]

top None NDCG:0.31543624801354186


In [18]:
topk_results.update(model="topk")
final_results.append(pd.Series(topk_results))

## Sklearn

In [19]:
from sklearn.linear_model import LinearRegression, Lasso

In [20]:
print(*(arr.shape for arr in (document_vectors, weight_ans, word_vectors)), sep="\n")

(9989, 100)
(9989, 7602)
(7602, 100)


In [21]:
def evaluate_sklearn(pred, ans):
    """F1@k and NDCG@k of one document's predicted word scores.

    Args:
        pred: 1-D array of predicted scores, one per vocabulary word.
        ans: 1-D array of ground-truth word weights of the same length; words
            with weight > 0 form the relevant set for F1.

    Returns:
        dict with 'F1@{k}' and 'ndcg@{k}' for each k in config["topk"] and
        'ndcg@all'. F1 is 0 when there is no relevant word (previously a
        ZeroDivisionError).
    """
    results = {}

    relevant = np.flatnonzero(ans > 0)
    order = np.argsort(pred)  # ascending; the top-k are the last k entries

    for topk in config["topk"]:
        n_hit = len(np.intersect1d(order[-topk:], relevant))
        precision = n_hit / topk
        recall = n_hit / len(relevant) if len(relevant) else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        results['F1@{}'.format(topk)] = f1

    ans = ans.reshape(1, -1)
    pred = pred.reshape(1, -1)
    for topk in config["topk"]:
        results['ndcg@{}'.format(topk)] = ndcg_score(ans, pred, k=topk)

    results['ndcg@all'] = ndcg_score(ans, pred, k=None)

    return results

### linear regression

In [22]:
# results = []

# for doc_id, doc_emb in enumerate(tqdm(document_vectors[:train_size])):
#     x = word_vectors.T
#     y = doc_emb
    
#     ans = weight_ans[doc_id]
#     model = LinearRegression(fit_intercept=False).fit(x, y)
#     r2 = model.score(x, y)

#     res = evaluate_sklearn(model.coef_, ans)
#     results.append(res)

In [23]:
# results = pd.DataFrame(results).mean()
# results['model'] = 'sk-linear-regression'
# final_results.append(results)
# results

### lasso

In [24]:
# results = []
# sk_lasso_epoch = 10000

# for doc_id, doc_emb in enumerate(tqdm(document_vectors[:train_size])):
#     x = word_vectors.T
#     y = doc_emb
    
#     ans = weight_ans[doc_id]
#     model = Lasso(positive=True, fit_intercept=False, alpha=0.0001, max_iter=sk_lasso_epoch, tol=0).fit(x, y)
#     r2 = model.score(x, y)

#     res = evaluate_sklearn(model.coef_, ans)
#     results.append(res)

In [25]:
# results = pd.DataFrame(results).mean()
# results['model'] = 'sk-lasso'
# final_results.append(results)
# results

## Our Model

In [26]:
# Use the first GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [27]:
class Custom_Lasso_Dataset(Dataset):
    """Dataset of ``(doc_vector, doc_weight_sum, index)`` triples.

    ``weight_ans`` is kept as given (not converted to a tensor); it is read
    directly as ``dataset.weight_ans`` by ``evaluate_NDCG``.
    """

    def __init__(self, doc_vectors, doc_w_sum, weight_ans):
        self.doc_vectors, self.doc_w_sum = torch.FloatTensor(doc_vectors), torch.FloatTensor(doc_w_sum)
        self.weight_ans = weight_ans
        assert len(doc_vectors) == len(doc_w_sum)

    def __getitem__(self, idx):
        return (self.doc_vectors[idx], self.doc_w_sum[idx], idx)

    def __len__(self):
        return self.doc_vectors.size(0)


In [28]:
class LR(nn.Module):
    """Per-document linear model over word embeddings.

    Each of the ``num_doc`` documents owns a trainable row of ``num_words``
    word weights (``self.emb.weight``, initialised to zeros on the global
    ``device``). ``forward(doc_ids, word_vectors)`` returns
    ``emb(doc_ids) @ word_vectors``. For ``doc_ids`` of shape (N,) and
    ``word_vectors`` of shape (num_words, D) (assumed: the training cell
    passes the vocabulary embedding matrix), the output is (N, D), one
    weighted sum of word vectors per document.
    """

    def __init__(self, num_doc, num_words):
        super().__init__()
        weight = torch.zeros(num_doc, num_words).to(device)
        self.emb = torch.nn.Embedding.from_pretrained(weight, freeze=False)

    def forward(self, doc_ids, word_vectors):
        return self.emb(doc_ids) @ word_vectors

In [29]:
def evaluate_NDCG(model, train_loader):
    """F1@k and NDCG@k of an ``LR`` model's learned per-document word weights.

    Row i of ``model.emb.weight`` is the predicted word-score vector of
    document i and is compared with row i of
    ``train_loader.dataset.weight_ans``; words with weight > 0 form the
    relevant set for F1. Puts the model in eval mode.

    Returns:
        dict with 'F1@{k}' and 'ndcg@{k}' for each k in config["topk"] and
        'ndcg@all'. A document without relevant words contributes F1 = 0
        (previously a ZeroDivisionError).
    """
    results = {}
    model.eval()

    # Copy the weights to host memory without moving the module; the old
    # ``model.emb.cpu()`` + ``.to(device)`` round trip copied the whole
    # parameter twice per evaluation.
    scores = model.emb.weight.detach().cpu().numpy()
    true_relevance = train_loader.dataset.weight_ans

    # F1 per document, averaged over documents.
    F1s = []
    for i in range(true_relevance.shape[0]):
        relevant = np.flatnonzero(true_relevance[i] > 0)
        order = np.argsort(scores[i])  # ascending; top-k are the last k

        F1 = []
        for topk in config["topk"]:
            n_hit = len(np.intersect1d(order[-topk:], relevant))
            precision = n_hit / topk
            recall = n_hit / len(relevant) if len(relevant) else 0.0
            F1.append(2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0)
        F1s.append(F1)

    F1s = np.mean(F1s, axis=0)
    for topk, f1 in zip(config["topk"], F1s):
        results['F1@{}'.format(topk)] = f1

    # NDCG
    for topk in config["topk"]:
        results['ndcg@{}'.format(topk)] = ndcg_score(true_relevance, scores, k=topk)
    results['ndcg@all'] = ndcg_score(true_relevance, scores, k=None)

    return results

In [30]:
batch_size = 100
print('document num', train_size)

train_dataset = Custom_Lasso_Dataset(
    doc_vectors=document_vectors[:train_size],
    doc_w_sum=document_answers_wsum[:train_size],
    weight_ans=weight_ans[:train_size],
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True,
)

document num 9989


## start training

In [31]:
# # setting
# lr = 0.5
# momentum = 0.99
# weight_decay = 0
# nesterov = False # True

# n_epoch = 50000

# w_sum_reg = 1e-3
# w_sum_reg_mul = 0.9
# w_clip_value = 0

# L1 = 1e-6

# verbose = False
# valid_epoch = 100

# model = LR(num_doc=train_size, num_words=word_vectors.shape[0]).to(device)
# model.train()

# word_vectors_tensor = torch.FloatTensor(word_vectors).to(device)
    
# opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
# criterion = nn.MSELoss(reduction='mean')

# results = []
# step = 0
# for epoch in tqdm(range(n_epoch)):    
#     loss_mse_his = []
#     loss_w_reg_his = []
    
#     model.train()

#     for data in train_loader:
#         doc_embs, doc_w_sum, doc_ids = data
        
#         doc_embs = doc_embs.to(device)
#         doc_w_sum = doc_w_sum.to(device)
#         doc_ids = doc_ids.to(device)
        
#         w_reg = (torch.ones(doc_embs.size(0), 1) * w_sum_reg_mul).to(device)
        
#         # MSE loss
#         pred_doc_embs = model(doc_ids, word_vectors_tensor)     
#         loss_mse = criterion(pred_doc_embs, doc_embs)

#         pred_w_sum = torch.sum(model.emb(doc_ids), axis=1).view(-1, 1)
#         loss_w_reg = criterion(pred_w_sum, w_reg)
        
#         loss_l1 = torch.sum(torch.abs(model.emb(doc_ids)))
#         loss = loss_mse + loss_w_reg * w_sum_reg + loss_l1 * L1
        
#         # Model backwarding
#         model.zero_grad()
#         loss.backward()
#         opt.step()

#         loss_mse_his.append(loss_mse.item())
#         loss_w_reg_his.append(loss_w_reg.item())

#         for p in model.parameters():
#             p.data.clamp_(w_clip_value, float('inf'))

        
#     if epoch % valid_epoch == 0:
#         res = {}
#         res['epoch'] = epoch
#         res['loss_mse'] = np.mean(loss_mse_his)
#         res['loss_w_reg'] = np.mean(loss_w_reg_his)
        
#         res_ndcg = evaluate_NDCG(model, train_loader)
#         res.update(res_ndcg)
#         results.append(res)
        
#         if verbose:
#             print()
#             for k, v in res.items():
#                 print(k, v)

In [32]:
# pd.set_option('display.max_rows', 500)
# results_df = pd.DataFrame(results).set_index('epoch')
# results_df

NameError: name 'results' is not defined

In [33]:
# results_df['model'] = 'our-lasso'
# final_results.append(results_df[select_columns].iloc[-1])

## Quality Check

In [34]:
# # select doc_id and k
# doc_id = 40
# topk = 30

# model

In [35]:
# import colored
# from colored import stylize

# word_list = vocab.itos

# gt = [word_list[word_idx] for word_idx in np.argsort(weight_ans[doc_id])[::-1][:topk]]
# pred = [word_list[word_idx] for word_idx in np.argsort(model.emb.cpu().weight.data[doc_id].numpy())[::-1][:topk]]

# print('ground truth')
# for word in gt:
#     if word in pred:
#         print(stylize(word, colored.bg("yellow")), end=' ')
#     else:
#         print(word, end=' ')

# print()
# print('\nprediction')
# for word in pred:
#     if word in gt:
#         print(stylize(word, colored.bg("yellow")), end=' ')
#     else:
#         print(word, end=' ')


In [36]:
# # raw document
# print()
# ps = PorterStemmer()
    
# for word in vocab.raw_documents[doc_id].split():
#     word_stem = ps.stem(word).lower()

#     if word_stem in gt:
#         if word_stem in pred:
#             print(stylize(word, colored.bg("yellow")), end=' ')
#         else:
#             print(stylize(word, colored.bg("light_gray")), end=' ')
#     else:
#         print(word, end=' ')
# # print(dataset.documents[doc_id])

In [37]:
# results = {}
   
# scores = np.array(model.emb.weight.data)[doc_id].reshape(1, -1)
# true_relevance = train_loader.dataset.weight_ans[doc_id].reshape(1, -1)

# results['ndcg@50'] = (ndcg_score(true_relevance, scores, k=50))
# results['ndcg@100'] = (ndcg_score(true_relevance, scores, k=100))
# results['ndcg@200'] = (ndcg_score(true_relevance, scores, k=200))
# results['ndcg@all'] = (ndcg_score(true_relevance, scores, k=None))

# print('This document ndcg:')
# print('ground truth length:', np.sum(weight_ans[doc_id] > 0))
# print('NDCG top50', results['ndcg@50'])
# print('NDCG top100', results['ndcg@100'])
# print('NDCG top200', results['ndcg@200'])
# print('NDCG ALL', results['ndcg@all'])


## Final results

In [38]:
is_notebook: bool = in_notebook()  # controls whether plots are also shown inline

In [39]:
# final_results_df = pd.DataFrame(final_results).reset_index(drop=True)

# experiment_dir = './records/dataset-{}-n_document-{}-wdist-{}-filtertopk-{}'.format(
#                                         config['dataset'],
#                                         config['n_document'],
#                                         config["document_vector_agg_weight"],
#                                         config["topk_word_freq_threshold"])

# print('Saving to directory', experiment_dir)
# os.makedirs(experiment_dir, exist_ok=True)

In [40]:
# final_results_df.to_csv(os.path.join(experiment_dir, 'result.csv'), index=False)

# import json
# with open(os.path.join(experiment_dir, 'config.json'), 'w') as json_file:
#     json.dump(config, json_file)

In [41]:
# for feat in final_results_df.set_index('model').columns:
#     plt.bar(final_results_df['model'],
#             final_results_df[feat], 
#             width=0.5, 
#             bottom=None, 
#             align='center', 
#             color=['lightsteelblue', 
#                    'cornflowerblue', 
#                    'royalblue', 
#                    'navy'])
#     plt.title(feat)
#     plt.savefig(os.path.join(experiment_dir, '{}.png'.format(feat)))
#     plt.clf()
#     if is_notebook:
#         plt.show()

In [42]:
# print(final_results_df)

## MLP Decoder

In [43]:
# Use the first GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [44]:
class MLPDecoderDataset(Dataset):
    """Dataset of ``(doc_vector, word_weights, ranked_word_ids)`` triples.

    ``weight_ans_s[i]`` lists the word indices of document i ordered by
    decreasing weight; only the first ``topk`` positions are kept and every
    later position is set to -1.
    """

    def __init__(self, doc_vectors, weight_ans, topk=50):
        self.doc_vectors = torch.FloatTensor(doc_vectors)
        self.weight_ans = torch.FloatTensor(weight_ans)
        ranked = self.weight_ans.argsort(dim=1, descending=True)
        ranked[:, topk:] = -1
        self.weight_ans_s = ranked

        assert len(doc_vectors) == len(weight_ans)

    def __getitem__(self, idx):
        return (self.doc_vectors[idx], self.weight_ans[idx], self.weight_ans_s[idx])

    def __len__(self):
        return self.doc_vectors.size(0)

In [45]:
batch_size = 100

# 90% / 10% train / validation split of the (already shuffled) documents.
train_size_ratio = 0.9
train_size = int(len(document_answers_idx) * train_size_ratio)

print('train size', train_size)
print('valid size', len(document_vectors) - train_size)

train_dataset = MLPDecoderDataset(
    doc_vectors=document_vectors[:train_size], weight_ans=weight_ans[:train_size], topk=50,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True,
)

valid_dataset = MLPDecoderDataset(
    doc_vectors=document_vectors[train_size:], weight_ans=weight_ans[train_size:], topk=50,
)
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=batch_size, shuffle=True, pin_memory=True,
)

train size 8990
valid size 999


In [46]:
class MLPDecoder(nn.Module):
    """One-hidden-layer MLP mapping a document vector to one score per word.

    x -> Linear(doc_emb_dim, h_dim) -> tanh -> Dropout(dropout)
      -> Linear(h_dim, num_words)

    Args:
        doc_emb_dim: size of the input document vector.
        num_words: number of output scores (vocabulary size).
        h_dim: hidden layer width.
        dropout: dropout probability after the hidden layer (default 0.2,
            the previously hard-coded value).
    """

    def __init__(self, doc_emb_dim, num_words, h_dim=300, dropout=0.2):
        super().__init__()

        self.fc1 = nn.Linear(doc_emb_dim, h_dim)
        # self.fc2 = nn.Linear(h_dim, h_dim)
        # self.fc3 = nn.Linear(h_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, num_words)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        x = torch.tanh(self.fc1(x))
        # x = torch.tanh(self.fc2(x))
        # x = torch.relu(self.fc3(x))
        x = self.dropout(x)
        x = self.fc4(x)

        return x

In [47]:
from sklearn.metrics import f1_score
import timeit
from torchmetrics import F1
from torchmetrics.functional import retrieval_normalized_dcg

def evaluate_MLPDecoder(model, data_loader, topks=None):
    """Compute F1@k and NDCG@k of ``model``'s word scores over ``data_loader``.

    The model is switched to eval mode and left in eval mode on return.
    Wall-clock times are printed per stage:
    - ``Time1``: inference
    - ``Time2``: F1
    - ``Time3``: NDCG

    Args:
        model: module mapping a batch of document embeddings to per-word scores.
        data_loader: yields ``(doc_embs, target, _)`` batches. ``target`` holds
            per-word answer weights; for F1 a word is relevant when
            ``sign(weight)`` is nonzero.
        topks: iterable of cut-offs k; defaults to ``config["topk"]``.

    Returns:
        dict with ``'F1@k'`` and ``'ndcg@k'`` for every k, in that order, plus
        ``'ndcg@all'`` (NDCG over the full ranking). NDCG values are the mean of
        per-document scores.
    """
    if topks is None:
        topks = config["topk"]
    topks = list(topks)
    results = {}
    model.eval()
    # Run on whatever device the model's parameters live on.
    model_device = next(model.parameters()).device

    pred_all = []
    target_all = []

    # Predict all data. no_grad: evaluation needs no autograd graph, and keeping
    # one alive for every stored prediction wastes memory.
    start = timeit.default_timer()
    with torch.no_grad():
        for doc_embs, target, _ in data_loader:
            pred_all.append(model(doc_embs.to(model_device)))
            target_all.append(target.to(model_device))

    pred_all = torch.cat(pred_all, dim=0)
    target_all = torch.cat(target_all, dim=0)
    stop1 = timeit.default_timer()
    print('Time1: ', stop1 - start)

    # F1 on binarized relevance; the conversions are loop-invariant, so do them once.
    pred_cpu = pred_all.cpu()
    relevance_cpu = torch.sign(target_all).int().cpu()
    for topk in topks:
        f1_metric = F1(top_k=topk)
        results['F1@{}'.format(topk)] = f1_metric(pred_cpu, relevance_cpu).item()
    stop2 = timeit.default_timer()
    print('Time2: ', stop2 - stop1)

    # NDCG per document, averaged. The final pass (no cut-off) scores the full ranking.
    for topk in topks + [None]:
        ndcg_kwargs = {} if topk is None else {'k': topk}
        ndcg_scores = [
            retrieval_normalized_dcg(doc_pred, doc_target, **ndcg_kwargs).item()
            for doc_pred, doc_target in zip(pred_all, target_all)
        ]
        key = 'ndcg@all' if topk is None else 'ndcg@{}'.format(topk)
        results[key] = np.mean(ndcg_scores)
    stop3 = timeit.default_timer()
    print('Time3: ', stop3 - stop2)

    return results

In [48]:
# Tiny epsilon, presumably to guard divisions/logs against zero (unused in this chunk).
DEFAULT_EPS: float = 1.0e-10
# Padding value for target index rows; the same -1 that MLPDecoderDataset pads with.
PADDED_Y_VALUE: int = -1

In [54]:
# setting
lr = 0.05
momentum = 0.
weight_decay = 0
nesterov = False # True

n_epoch = 5000
verbose = True
valid_epoch = 50  # run full metric evaluation every `valid_epoch` epochs

h_dim = 3000

model = MLPDecoder(doc_emb_dim=document_vectors.shape[1], num_words=len(word_vectors), h_dim=h_dim).to(device)
model.train()

# Not used by this cell; kept for later cells that may reference it.
word_vectors_tensor = torch.FloatTensor(word_vectors).to(device)

opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
# Targets are the -1-padded top-k word-index rows built by MLPDecoderDataset,
# which is the target format MultiLabelMarginLoss expects.
# Alternatives tried earlier: Adam; MSELoss on the raw weights;
# MultiLabelSoftMarginLoss on sign(weights).
criterion = nn.MultiLabelMarginLoss(reduction='mean')

results = []  # one dict of validation metrics (plus 'epoch') per evaluation
step = 0
clip_value = 1

for epoch in tqdm(range(n_epoch)):
    train_loss_his = []
    valid_loss_his = []

    model.train()
    for doc_embs, _, target_rank in train_loader:
        doc_embs = doc_embs.to(device)
        target_rank = target_rank.to(device)

        # Multi-label margin (ranking) loss against the padded top-k word indices
        pred = model(doc_embs)
        loss = criterion(pred, target_rank)

        # Model backwarding, with element-wise gradient clipping
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
        opt.step()

        train_loss_his.append(loss.item())

    # Validation loss: eval mode (no dropout) and no autograd graph.
    model.eval()
    with torch.no_grad():
        for doc_embs, _, target_rank in valid_loader:
            pred = model(doc_embs.to(device))
            loss = criterion(pred, target_rank.to(device))
            valid_loss_his.append(loss.item())

    print("Epoch", epoch, np.mean(train_loss_his), np.mean(valid_loss_his))

    if epoch % valid_epoch == 0:
        # Training-set metrics are only printed, never stored, so the full
        # (slow) pass over train_loader is skipped when not verbose.
        train_res_ndcg = evaluate_MLPDecoder(model, train_loader) if verbose else None
        valid_res_ndcg = evaluate_MLPDecoder(model, valid_loader)

        res = {'epoch': epoch}
        res.update(valid_res_ndcg)
        results.append(res)

        if verbose:
            print()
            print('train', train_res_ndcg)
            print('valid', valid_res_ndcg)


  0%|          | 0/5000 [00:00<?, ?it/s]

Epoch 0 16.27092376285129 11.797836303710938
Time1:  0.31581695564091206
Time2:  5.738117668777704
Time3:  20.17899980954826
Time1:  0.029079269617795944
Time2:  0.6093807443976402
Time3:  2.2846240177750587

train {'F1@10': 0.022304276004433632, 'F1@30': 0.054059337824583054, 'F1@50': 0.08223669230937958, 'ndcg@10': 0.046150971230938385, 'ndcg@30': 0.06001905465219636, 'ndcg@50': 0.07714005401871453, 'ndcg@all': 0.36135738962583736}
valid {'F1@10': 0.022343382239341736, 'F1@30': 0.05248989164829254, 'F1@50': 0.07892013341188431, 'ndcg@10': 0.046766206919046885, 'ndcg@30': 0.05951047831497438, 'ndcg@50': 0.07486342325190397, 'ndcg@all': 0.3565012764584672}
Epoch 1 8.92164199617174 9.856325054168702
Epoch 2 6.442120212978787 9.010136842727661
Epoch 3 5.077125724156698 8.550235843658447
Epoch 4 4.203001186582777 8.289485311508178
Epoch 5 3.593711259629991 8.111674070358276
Epoch 6 3.1613865507973564 7.971252822875977
Epoch 7 2.8388562917709352 7.8873515129089355
Epoch 8 2.588604919115702

Epoch 138 0.30710871981249915 7.359880208969116
Epoch 139 0.30371472239494324 7.3572102069854735
Epoch 140 0.3022480378548304 7.3659289360046385
Epoch 141 0.30153957075542875 7.3633434772491455
Epoch 142 0.30019175046020086 7.367502212524414
Epoch 143 0.29837778045071495 7.370369482040405
Epoch 144 0.2964234656757779 7.365205764770508
Epoch 145 0.29498112201690674 7.368148612976074
Epoch 146 0.2935927285088433 7.362018489837647
Epoch 147 0.29078970717059244 7.370318078994751
Epoch 148 0.2906249708599514 7.367989540100098
Epoch 149 0.2885123252868652 7.3690392017364506
Epoch 150 0.2881535937388738 7.361169052124024
Time1:  0.2586024980992079
Time2:  4.051573064178228
Time3:  15.940886987373233
Time1:  0.02937052771449089
Time2:  0.42283318378031254
Time3:  1.754525724798441

train {'F1@10': 0.16021370887756348, 'F1@30': 0.39880651235580444, 'F1@50': 0.568867027759552, 'ndcg@10': 0.43356649215427434, 'ndcg@30': 0.5214489496970867, 'ndcg@50': 0.6168439858943199, 'ndcg@all': 0.692684572865

Epoch 272 0.18382354312472873 7.4549102783203125
Epoch 273 0.18319454424911075 7.454755544662476
Epoch 274 0.1823247199257215 7.452724933624268
Epoch 275 0.18207369877232446 7.466057682037354
Epoch 276 0.18207059287362629 7.460947799682617
Epoch 277 0.18161441485087076 7.454405498504639
Epoch 278 0.1816793700059255 7.46027102470398
Epoch 279 0.1805549446079466 7.464733695983886
Epoch 280 0.1810360550880432 7.469304418563842
Epoch 281 0.1796016866962115 7.467050790786743
Epoch 282 0.1791052887837092 7.471155977249145
Epoch 283 0.17918475137816536 7.46147608757019
Epoch 284 0.17797645578781765 7.459161710739136
Epoch 285 0.17749439941512213 7.463137483596801
Epoch 286 0.17736505998505486 7.469886207580567
Epoch 287 0.17658910519546933 7.468975162506103
Epoch 288 0.17569056500991184 7.470464134216309
Epoch 289 0.1768105637696054 7.47409200668335
Epoch 290 0.17539538542429606 7.477338218688965
Epoch 291 0.17525885668065813 7.469485187530518
Epoch 292 0.17476791391770044 7.472585535049438
E

Epoch 406 0.13802005391981867 7.550264883041382
Epoch 407 0.13789961561560632 7.55513710975647
Epoch 408 0.13768347021606234 7.556053400039673
Epoch 409 0.1376279557744662 7.54733452796936
Epoch 410 0.13683834175268808 7.558027458190918
Epoch 411 0.13618515125579303 7.563369369506836
Epoch 412 0.1364024424718486 7.551956129074097
Epoch 413 0.13581805965966648 7.558000946044922
Epoch 414 0.13537215929892327 7.55357928276062
Epoch 415 0.13526911611358325 7.553326511383057
Epoch 416 0.13528130261434448 7.551773738861084
Epoch 417 0.13555233139130804 7.550339984893799
Epoch 418 0.1354444960753123 7.55699405670166
Epoch 419 0.13541649762127134 7.557705736160278
Epoch 420 0.13464082520869042 7.554332828521728
Epoch 421 0.1349620213939084 7.559082412719727
Epoch 422 0.13376895462473234 7.560801887512207
Epoch 423 0.1332315791812208 7.558249759674072
Epoch 424 0.1335543279018667 7.553527402877807
Epoch 425 0.1344457829164134 7.556572103500367
Epoch 426 0.1333683422870106 7.557716274261475
Epoc

Time2:  0.4315410442650318
Time3:  1.7479445151984692

train {'F1@10': 0.18996410071849823, 'F1@30': 0.47269144654273987, 'F1@50': 0.66602623462677, 'ndcg@10': 0.5046630417598666, 'ndcg@30': 0.6118511625657789, 'ndcg@50': 0.7157400509613971, 'ndcg@all': 0.7357560570549647}
valid {'F1@10': 0.11421271413564682, 'F1@30': 0.19712920486927032, 'F1@50': 0.22898586094379425, 'ndcg@10': 0.35749135861309383, 'ndcg@30': 0.3472371970320137, 'ndcg@50': 0.3615148259462925, 'ndcg@all': 0.5645432679562478}
Epoch 551 0.11170983620815807 7.642478179931641
Epoch 552 0.11111850539843242 7.633271551132202
Epoch 553 0.11161943036648962 7.6328778743743895
Epoch 554 0.11076355054974556 7.642559576034546
Epoch 555 0.110913445138269 7.630881309509277
Epoch 556 0.11119532502359813 7.637653589248657
Epoch 557 0.11045482009649277 7.639739847183227
Epoch 558 0.11042225534717241 7.6385030269622805
Epoch 559 0.11066984393530421 7.6407708644866945
Epoch 560 0.11038478041688601 7.640829992294312
Epoch 561 0.1100167188

Epoch 687 0.0961205504834652 7.703366899490357
Epoch 688 0.0951374380952782 7.709466505050659
Epoch 689 0.09639393935600916 7.700532913208008
Epoch 690 0.0959256416393651 7.712397480010987
Epoch 691 0.09530509089430174 7.717400312423706
Epoch 692 0.09531877024306191 7.704378509521485
Epoch 693 0.09576510844959153 7.713072919845581
Epoch 694 0.09569648694660929 7.716042804718017
Epoch 695 0.09535803356104428 7.713563060760498
Epoch 696 0.09591751620173454 7.707888841629028
Epoch 697 0.09552560630771849 7.7042285919189455
Epoch 698 0.0949524571498235 7.708773183822632
Epoch 699 0.09494426780276828 7.715491819381714
Epoch 700 0.09524691849946976 7.715918827056885
Time1:  0.25818501971662045
Time2:  3.9582136813551188
Time3:  15.788366299122572
Time1:  0.028829175978899002
Time2:  0.4367510788142681
Time3:  1.7522211894392967

train {'F1@10': 0.18982456624507904, 'F1@30': 0.47601404786109924, 'F1@50': 0.6731919646263123, 'ndcg@10': 0.5026670344228814, 'ndcg@30': 0.6137889785824698, 'ndcg@5

Epoch 820 0.08563120464483896 7.769314575195312
Epoch 821 0.08562528904941347 7.768396043777466
Epoch 822 0.08482945304777888 7.774940538406372
Epoch 823 0.08508239818943872 7.766163778305054
Epoch 824 0.0839565407898691 7.768567132949829
Epoch 825 0.0846867400738928 7.767776679992676
Epoch 826 0.08562084486087164 7.772817850112915
Epoch 827 0.08502113090621101 7.766281509399414
Epoch 828 0.08522918423016866 7.774250793457031
Epoch 829 0.08476383859912555 7.7613283634185795
Epoch 830 0.08449768755171035 7.76769609451294
Epoch 831 0.084861661410994 7.772727537155151
Epoch 832 0.08419377389881345 7.7790850639343265
Epoch 833 0.08495370166169272 7.763153791427612
Epoch 834 0.08432583750949965 7.770180892944336
Epoch 835 0.08477092799213197 7.773558473587036
Epoch 836 0.0840770460665226 7.770386505126953
Epoch 837 0.08436948267949952 7.765375375747681
Epoch 838 0.08436714990271463 7.7691370964050295
Epoch 839 0.08389450990491443 7.7782539367675785
Epoch 840 0.08397633739643627 7.7695452690

Epoch 954 0.07753872805171542 7.829533624649048
Epoch 955 0.07749150875541899 7.830902862548828
Epoch 956 0.07697399556636811 7.821924924850464
Epoch 957 0.07735319187243779 7.828331232070923
Epoch 958 0.07742391659153833 7.839411449432373
Epoch 959 0.07695769055022134 7.829232931137085
Epoch 960 0.07672654257880317 7.828181552886963
Epoch 961 0.07716319503055678 7.822313213348389
Epoch 962 0.07696787276201777 7.825705242156983
Epoch 963 0.07729084731804 7.8267217636108395
Epoch 964 0.07757146532336871 7.8247472763061525
Epoch 965 0.07697880905535485 7.830098724365234
Epoch 966 0.07659748006198142 7.8311878681182865
Epoch 967 0.07691320818331507 7.8331183910369875
Epoch 968 0.0770681384536955 7.824763584136963
Epoch 969 0.07595796386400859 7.827968597412109
Epoch 970 0.07714802614516682 7.829751062393188
Epoch 971 0.07619845122098923 7.837547683715821
Epoch 972 0.07654456082317564 7.830195951461792
Epoch 973 0.07653608032398754 7.827185726165771
Epoch 974 0.07662506815459993 7.83344454

Epoch 1098 0.07043555921150578 7.874982309341431
Epoch 1099 0.07009857735700077 7.879877281188965
Epoch 1100 0.07033725207050641 7.883915662765503
Time1:  0.2579369805753231
Time2:  3.952727947384119
Time3:  15.815279016271234
Time1:  0.02896430343389511
Time2:  0.4299383610486984
Time3:  1.7504175584763288

train {'F1@10': 0.19044393301010132, 'F1@30': 0.48312291502952576, 'F1@50': 0.6824499368667603, 'ndcg@10': 0.5000179078482772, 'ndcg@30': 0.6189961858472898, 'ndcg@50': 0.7260840111752936, 'ndcg@all': 0.7353004243237291}
valid {'F1@10': 0.11361847072839737, 'F1@30': 0.19733542203903198, 'F1@50': 0.230290949344635, 'ndcg@10': 0.35597640099572586, 'ndcg@30': 0.34667964228470854, 'ndcg@50': 0.3625339699799085, 'ndcg@all': 0.5633423658194127}
Epoch 1101 0.07033728878531191 7.881252717971802
Epoch 1102 0.07005002680752012 7.877880668640136
Epoch 1103 0.07104230779740546 7.883528518676758
Epoch 1104 0.0704152384152015 7.881418657302857
Epoch 1105 0.07061330137981309 7.879646110534668
Epo

Epoch 1229 0.06570829111668798 7.917858695983886
Epoch 1230 0.06548140086233616 7.919892454147339
Epoch 1231 0.0658771015289757 7.926320314407349
Epoch 1232 0.06583023162351714 7.915860509872436
Epoch 1233 0.0655169507695569 7.925847959518433
Epoch 1234 0.06509752708176772 7.919323110580445
Epoch 1235 0.06534259712530507 7.931047630310059
Epoch 1236 0.06544499529732598 7.921488809585571
Epoch 1237 0.06548708346154955 7.926074552536011
Epoch 1238 0.06530274165173372 7.9256895065307615
Epoch 1239 0.06531566662920846 7.92656683921814
Epoch 1240 0.06523455513848199 7.919151782989502
Epoch 1241 0.06509030800726678 7.933512020111084
Epoch 1242 0.06481446545157167 7.926430130004883
Epoch 1243 0.06500975067416827 7.925202512741089
Epoch 1244 0.06513366753028499 7.92923641204834
Epoch 1245 0.06533324159681797 7.926547479629517
Epoch 1246 0.0653295091042916 7.933751153945923
Epoch 1247 0.06498683666189511 7.9243511199951175
Epoch 1248 0.06513426825404167 7.9368894577026365
Epoch 1249 0.064633780

Epoch 1359 0.06152565090192689 7.964690446853638
Epoch 1360 0.06139019504189491 7.977547311782837
Epoch 1361 0.06138296516405212 7.98258228302002
Epoch 1362 0.06127396271460586 7.969602108001709
Epoch 1363 0.06108328054348628 7.971360683441162
Epoch 1364 0.060950702304641405 7.979902172088623
Epoch 1365 0.06121253110468387 7.976155424118042
Epoch 1366 0.0607494062019719 7.971653699874878
Epoch 1367 0.06134419014884366 7.978368043899536
Epoch 1368 0.06113534929851691 7.976303434371948
Epoch 1369 0.061108754533860415 7.978564691543579
Epoch 1370 0.06135344583955076 7.971100521087647
Epoch 1371 0.061686068028211596 7.970548343658447
Epoch 1372 0.06077036903136306 7.973257541656494
Epoch 1373 0.06061430701778995 7.967184352874756
Epoch 1374 0.06054247890909513 7.978722763061524
Epoch 1375 0.061443879620896445 7.979366779327393
Epoch 1376 0.06103732830120458 7.9812733173370365
Epoch 1377 0.06089255230294333 7.974730968475342
Epoch 1378 0.06116827059951094 7.98211407661438
Epoch 1379 0.06059

Time1:  0.25773273780941963
Time2:  3.9559650644659996
Time3:  15.805981867015362
Time1:  0.028937162831425667
Time2:  0.4338857810944319
Time3:  1.7509114127606153

train {'F1@10': 0.18404600024223328, 'F1@30': 0.47931045293807983, 'F1@50': 0.6860391497612, 'ndcg@10': 0.48177726618433836, 'ndcg@30': 0.6095537344567702, 'ndcg@50': 0.721441958429037, 'ndcg@all': 0.7278994474580741}
valid {'F1@10': 0.11040957272052765, 'F1@30': 0.1950213462114334, 'F1@50': 0.22837059199810028, 'ndcg@10': 0.345715535492324, 'ndcg@30': 0.3406699866520705, 'ndcg@50': 0.3574648274470482, 'ndcg@all': 0.5585940971567824}
Epoch 1501 0.05807926683790154 8.0274151802063
Epoch 1502 0.05805324133899477 8.007538223266602
Epoch 1503 0.05760491270985868 8.024711275100708
Epoch 1504 0.05777031232913335 8.017338275909424
Epoch 1505 0.05749732823007637 8.022352457046509
Epoch 1506 0.05799436217380895 8.017513656616211
Epoch 1507 0.057047044734160106 8.02062931060791
Epoch 1508 0.05757047488457627 8.021351289749145
Epoch 

Epoch 1632 0.05481370753712124 8.047397708892822
Epoch 1633 0.054594740519920984 8.054189491271973
Epoch 1634 0.05485740999380748 8.0552649974823
Epoch 1635 0.05447383692695035 8.065141630172729
Epoch 1636 0.0546534033285247 8.06380581855774
Epoch 1637 0.055078980326652524 8.062771987915038
Epoch 1638 0.05453979116347101 8.053283786773681
Epoch 1639 0.05420392925540606 8.05304217338562
Epoch 1640 0.05465914756059646 8.05260148048401
Epoch 1641 0.05458404007885191 8.060881757736206
Epoch 1642 0.05481058628194862 8.060192823410034
Epoch 1643 0.054619802824325034 8.054856729507446
Epoch 1644 0.05441303741600778 8.05541534423828
Epoch 1645 0.05407781112525198 8.059922122955323
Epoch 1646 0.05420096516609192 8.06683316230774
Epoch 1647 0.05423528667953279 8.058914136886596
Epoch 1648 0.054380225431587964 8.050005292892456
Epoch 1649 0.0541789165387551 8.053609943389892
Epoch 1650 0.05424917286468877 8.07479019165039
Time1:  0.2587531115859747
Time2:  3.970171980559826
Time3:  15.82747696712

Epoch 1763 0.051908071297738285 8.096971988677979
Epoch 1764 0.051861854394276936 8.096527910232544
Epoch 1765 0.05188369630939431 8.092075300216674
Epoch 1766 0.05218850320412053 8.08618655204773
Epoch 1767 0.05173533451226023 8.093121433258057
Epoch 1768 0.052349773877196845 8.100607585906982
Epoch 1769 0.05162763227191236 8.096926832199097
Epoch 1770 0.052007083470622696 8.098446702957153
Epoch 1771 0.05154583607282903 8.100353240966797
Epoch 1772 0.052206295190585984 8.103116750717163
Epoch 1773 0.05176212746236059 8.104344177246094
Epoch 1774 0.05146613675687048 8.098731088638306
Epoch 1775 0.051727672749095496 8.0969744682312
Epoch 1776 0.05154489771359497 8.093671703338623
Epoch 1777 0.051962752847207916 8.092840099334717
Epoch 1778 0.0513975165784359 8.095728445053101
Epoch 1779 0.05142304338514805 8.106065702438354
Epoch 1780 0.0521154144157966 8.091757011413574
Epoch 1781 0.0516846127808094 8.107407760620116
Epoch 1782 0.0511818875455194 8.108851528167724
Epoch 1783 0.0513478

Epoch 1901 0.04956140919691986 8.12650089263916
Epoch 1902 0.04982141852378845 8.13837308883667
Epoch 1903 0.04958527634541194 8.135844039916993
Epoch 1904 0.049464289678467645 8.134235906600953
Epoch 1905 0.049129393614000746 8.135094356536865
Epoch 1906 0.04951608131329219 8.129247331619263
Epoch 1907 0.04967779620654053 8.13455309867859
Epoch 1908 0.049540828127000065 8.135205030441284
Epoch 1909 0.04941839679247803 8.131310272216798
Epoch 1910 0.049953172852595644 8.138406372070312
Epoch 1911 0.04908854998648167 8.128668069839478
Epoch 1912 0.049260026257899076 8.137015056610107
Epoch 1913 0.049755842776762114 8.12500615119934
Epoch 1914 0.049173663432399435 8.137998723983765
Epoch 1915 0.049327744874689314 8.127329587936401
Epoch 1916 0.04939284457100762 8.127794361114502
Epoch 1917 0.04946137223806646 8.128180170059204
Epoch 1918 0.04913751607139905 8.138126993179322
Epoch 1919 0.04899077212644948 8.141171932220459
Epoch 1920 0.0492030221141047 8.138358402252198
Epoch 1921 0.0490

Epoch 2044 0.04719983786344528 8.158564615249634
Epoch 2045 0.04757089917030599 8.180843210220337
Epoch 2046 0.04723801894320382 8.174518966674805
Epoch 2047 0.04707379076215956 8.172006845474243
Epoch 2048 0.04711380253235499 8.167097568511963
Epoch 2049 0.04697791453864839 8.157781219482422
Epoch 2050 0.04723312295973301 8.16982855796814
Time1:  0.257731419056654
Time2:  3.987428905442357
Time3:  15.859078647568822
Time1:  0.028839215636253357
Time2:  0.4363979045301676
Time3:  1.7618069592863321

train {'F1@10': 0.1917983889579773, 'F1@30': 0.49090149998664856, 'F1@50': 0.6910829544067383, 'ndcg@10': 0.49862930421842816, 'ndcg@30': 0.6260889360882151, 'ndcg@50': 0.732974677595483, 'ndcg@all': 0.7360143055225639}
valid {'F1@10': 0.11290538311004639, 'F1@30': 0.19843515753746033, 'F1@50': 0.23051467537879944, 'ndcg@10': 0.3529596682451479, 'ndcg@30': 0.34684969726632964, 'ndcg@50': 0.362909586393499, 'ndcg@all': 0.5621704472316517}
Epoch 2051 0.047328190753857295 8.172993469238282
Epo

Epoch 2175 0.04529775782591767 8.205718898773194
Epoch 2176 0.045459727322061855 8.198348569869996
Epoch 2177 0.04551427144971159 8.200740623474122
Epoch 2178 0.04524165495402283 8.210577964782715
Epoch 2179 0.04590280449224843 8.19968376159668
Epoch 2180 0.045515247931083046 8.203929376602172
Epoch 2181 0.04524252890712685 8.204383611679077
Epoch 2182 0.04567943906618489 8.201511526107788
Epoch 2183 0.045246903391347994 8.21257905960083
Epoch 2184 0.045531070232391356 8.211217641830444
Epoch 2185 0.04569129397471746 8.203909969329834
Epoch 2186 0.044868792303734355 8.205661296844482
Epoch 2187 0.044903086374203365 8.22233510017395
Epoch 2188 0.04530925059484111 8.197622966766357
Epoch 2189 0.04443679629928536 8.20797781944275
Epoch 2190 0.044967762836151656 8.213773632049561
Epoch 2191 0.04538362796107928 8.213127279281617
Epoch 2192 0.04522518436941836 8.211595916748047
Epoch 2193 0.045071810401148264 8.212260341644287
Epoch 2194 0.04489904186791844 8.216304302215576
Epoch 2195 0.045

Epoch 2305 0.04418833119173845 8.239878845214843
Epoch 2306 0.04395017189284166 8.234713077545166
Epoch 2307 0.044324978275431526 8.221443557739258
Epoch 2308 0.043731244819031824 8.240064430236817
Epoch 2309 0.043963765435748633 8.229146432876586
Epoch 2310 0.043781837531261976 8.242859601974487
Epoch 2311 0.04386407786773311 8.242995643615723
Epoch 2312 0.04373742582069503 8.242847967147828
Epoch 2313 0.04367152079939842 8.235160303115844
Epoch 2314 0.0437319861104091 8.235620021820068
Epoch 2315 0.04369467964602841 8.237807607650756
Epoch 2316 0.043814382867680655 8.235635709762573
Epoch 2317 0.04365226324233744 8.230633306503297
Epoch 2318 0.04372824012405342 8.235096549987793
Epoch 2319 0.0435998825977246 8.240533304214477
Epoch 2320 0.04308101228541798 8.238311767578125
Epoch 2321 0.04349744444092115 8.236018514633178
Epoch 2322 0.0438560142285294 8.234566974639893
Epoch 2323 0.043305272112290065 8.243887615203857
Epoch 2324 0.04333281384574043 8.249486684799194
Epoch 2325 0.0432

Epoch 2448 0.04221036384503047 8.268073749542236
Epoch 2449 0.04232171575228373 8.269210052490234
Epoch 2450 0.04199278416732947 8.27009449005127
Time1:  0.25814448669552803
Time2:  3.9510425087064505
Time3:  15.78801561333239
Time1:  0.028768587857484818
Time2:  0.43122883327305317
Time3:  1.750363763421774

train {'F1@10': 0.19515050947666168, 'F1@30': 0.4967048764228821, 'F1@50': 0.6933062076568604, 'ndcg@10': 0.5047163648101956, 'ndcg@30': 0.6328309672210237, 'ndcg@50': 0.7378149421622477, 'ndcg@all': 0.7391839355843616}
valid {'F1@10': 0.11569830030202866, 'F1@30': 0.2001764178276062, 'F1@50': 0.23120447993278503, 'ndcg@10': 0.3577080105428581, 'ndcg@30': 0.34984712970365694, 'ndcg@50': 0.365086049048437, 'ndcg@all': 0.5637060655010594}
Epoch 2451 0.04217396639287472 8.275912857055664
Epoch 2452 0.042263862821790905 8.275386857986451
Epoch 2453 0.04196529876854685 8.273436498641967
Epoch 2454 0.0420437999897533 8.278204488754273
Epoch 2455 0.041922230314877296 8.261780881881714
Ep

Epoch 2578 0.04074928640491433 8.300139093399048
Epoch 2579 0.041135667181677285 8.297080707550048
Epoch 2580 0.04044537747071849 8.30473780632019
Epoch 2581 0.04078024009035693 8.303091764450073
Epoch 2582 0.0406185091783603 8.303248882293701
Epoch 2583 0.040359294538696605 8.307743406295776
Epoch 2584 0.04064930540819963 8.307206678390504
Epoch 2585 0.04062437601387501 8.303932857513427
Epoch 2586 0.04040813023845355 8.304090309143067
Epoch 2587 0.04053335069782204 8.30728759765625
Epoch 2588 0.04041852951049805 8.308993291854858
Epoch 2589 0.04062200213472048 8.298439073562623
Epoch 2590 0.040562180222736464 8.30800395011902
Epoch 2591 0.04070603305266963 8.30535192489624
Epoch 2592 0.04076857707566685 8.31609649658203
Epoch 2593 0.04092179102202256 8.2955472946167
Epoch 2594 0.040516122885876235 8.307344388961791
Epoch 2595 0.04047212509645356 8.30036883354187
Epoch 2596 0.0403059755348497 8.306413221359254
Epoch 2597 0.04076293698615498 8.315699911117553
Epoch 2598 0.0409014327244

Epoch 2709 0.04027576405141089 8.324152565002441
Epoch 2710 0.03960717341138257 8.323559951782226
Epoch 2711 0.03959749920500649 8.319403028488159
Epoch 2712 0.039517122672663794 8.325510454177856
Epoch 2713 0.03962590971754657 8.33189115524292
Epoch 2714 0.03903299358983835 8.337518882751464
Epoch 2715 0.039659280454119046 8.328873825073241
Epoch 2716 0.039195041524039374 8.332196903228759
Epoch 2717 0.03936241397427188 8.320109319686889
Epoch 2718 0.03984762599898709 8.3233473777771
Epoch 2719 0.039375444501638414 8.315608882904053
Epoch 2720 0.03923661009305053 8.334021520614623
Epoch 2721 0.039311027899384496 8.323015451431274
Epoch 2722 0.03933563571837213 8.3281888961792
Epoch 2723 0.03948150227467219 8.332848978042602
Epoch 2724 0.03942407936685615 8.326542949676513
Epoch 2725 0.039764591058095294 8.343820476531983
Epoch 2726 0.03914614818576309 8.341254091262817
Epoch 2727 0.039729216943184535 8.348093366622924
Epoch 2728 0.03923847795360618 8.338160610198974
Epoch 2729 0.03961

Time2:  3.9495657347142696
Time3:  15.843152113258839
Time1:  0.02894400805234909
Time2:  0.42987013049423695
Time3:  1.7538739405572414

train {'F1@10': 0.19297589361667633, 'F1@30': 0.4961628317832947, 'F1@50': 0.6945581436157227, 'ndcg@10': 0.4982812706194611, 'ndcg@30': 0.631062141145363, 'ndcg@50': 0.7366128322263316, 'ndcg@all': 0.7369158308136715}
valid {'F1@10': 0.11358875781297684, 'F1@30': 0.19948908686637878, 'F1@50': 0.23150280117988586, 'ndcg@10': 0.3523354086208272, 'ndcg@30': 0.3470310129624259, 'ndcg@50': 0.36302839442938417, 'ndcg@all': 0.5614131557213532}
Epoch 2851 0.03846247738434209 8.357382440567017
Epoch 2852 0.03839828388558494 8.362364292144775
Epoch 2853 0.038507869467139244 8.362634944915772
Epoch 2854 0.03844217786358462 8.3648024559021
Epoch 2855 0.03831714532441563 8.362031745910645
Epoch 2856 0.03868921659886837 8.35427646636963
Epoch 2857 0.0382121125029193 8.36125135421753
Epoch 2858 0.03849299405184057 8.364939308166504
Epoch 2859 0.03822083891265922 8

Epoch 2982 0.0373837838984198 8.383244514465332
Epoch 2983 0.0370822061267164 8.386211585998534
Epoch 2984 0.03696215817083915 8.401716804504394
Epoch 2985 0.03746652656959163 8.397342586517334
Epoch 2986 0.03728695712569687 8.398602151870728
Epoch 2987 0.0370847100391984 8.389498567581176
Epoch 2988 0.037441081491609414 8.385215520858765
Epoch 2989 0.03690579757094383 8.399203634262085
Epoch 2990 0.03707849697934257 8.381817245483399
Epoch 2991 0.037285484663314286 8.38159646987915
Epoch 2992 0.03731967769563198 8.38449935913086
Epoch 2993 0.03727425667974684 8.390125942230224
Epoch 2994 0.037141128298309116 8.386331129074097
Epoch 2995 0.03730278785030047 8.388624572753907
Epoch 2996 0.037145440735750726 8.38301224708557
Epoch 2997 0.036698510787553255 8.381248092651367
Epoch 2998 0.03707855112022824 8.397642612457275
Epoch 2999 0.03686467972066668 8.391482973098755
Epoch 3000 0.036932339767615 8.390386867523194
Time1:  0.258599903434515
Time2:  3.96446780487895
Time3:  15.8228374589

Epoch 3113 0.036263931356370446 8.41009979248047
Epoch 3114 0.03588178459968832 8.4243634223938
Epoch 3115 0.036324537897275556 8.412762308120728
Epoch 3116 0.03609855824874507 8.410626411437988
Epoch 3117 0.03662786678307586 8.40910291671753
Epoch 3118 0.036488118146856624 8.425509452819824
Epoch 3119 0.03663442002402412 8.417601490020752
Epoch 3120 0.03614295725193289 8.409615421295166
Epoch 3121 0.036061400009526144 8.420972776412963
Epoch 3122 0.0364888582792547 8.423680782318115
Epoch 3123 0.036240094817346996 8.422263717651367
Epoch 3124 0.03605532294346227 8.411875486373901
Epoch 3125 0.036243647171391384 8.404303884506225
Epoch 3126 0.036610224097967145 8.405243158340454
Epoch 3127 0.036008361768391395 8.418468046188355
Epoch 3128 0.036015441889564195 8.420290470123291
Epoch 3129 0.03624295954489046 8.421184873580932
Epoch 3130 0.036372987346516714 8.420170211791993
Epoch 3131 0.03582191111312972 8.42524333000183
Epoch 3132 0.036192641448643474 8.420343780517578
Epoch 3133 0.03

Epoch 3251 0.03499844279140234 8.442130899429321
Epoch 3252 0.03515518053124348 8.448317766189575
Epoch 3253 0.03543362439506584 8.438825178146363
Epoch 3254 0.035231944690975875 8.443594884872436
Epoch 3255 0.0355472330417898 8.44919924736023
Epoch 3256 0.035183458175096244 8.438719177246094
Epoch 3257 0.035154405339724484 8.4376051902771
Epoch 3258 0.035126412121786015 8.443004274368286
Epoch 3259 0.035376689376102555 8.449206876754761
Epoch 3260 0.035443693337341146 8.456278562545776
Epoch 3261 0.035327299704982176 8.438163900375367
Epoch 3262 0.035286306734714244 8.446435260772706
Epoch 3263 0.034913062552611035 8.455647563934326
Epoch 3264 0.03501827472613917 8.442747974395752
Epoch 3265 0.03510732439657052 8.4476900100708
Epoch 3266 0.03521393771386809 8.444510364532471
Epoch 3267 0.0354917627448837 8.438886594772338
Epoch 3268 0.03537704975654681 8.436422395706177
Epoch 3269 0.03518711705174711 8.439616489410401
Epoch 3270 0.0347231385194593 8.4345871925354
Epoch 3271 0.03535297

Epoch 3394 0.0343845259398222 8.466054916381836
Epoch 3395 0.03442471716552973 8.464006185531616
Epoch 3396 0.03419158411108785 8.46919388771057
Epoch 3397 0.03454421659310659 8.464384937286377
Epoch 3398 0.03430006771037976 8.470265960693359
Epoch 3399 0.03468115675366587 8.470146179199219
Epoch 3400 0.03449479041414128 8.482157325744629
Time1:  0.2583376746624708
Time2:  4.100243708118796
Time3:  15.850954441353679
Time1:  0.028795059770345688
Time2:  0.43242539279162884
Time3:  1.7533062510192394

train {'F1@10': 0.20044921338558197, 'F1@30': 0.5062450170516968, 'F1@50': 0.6959705948829651, 'ndcg@10': 0.5148175468952193, 'ndcg@30': 0.6444090562548335, 'ndcg@50': 0.7449829925295044, 'ndcg@all': 0.7442559589367688}
valid {'F1@10': 0.11706504970788956, 'F1@30': 0.20292577147483826, 'F1@50': 0.2326587438583374, 'ndcg@10': 0.3606207027072991, 'ndcg@30': 0.3534839828848257, 'ndcg@50': 0.36796905820512466, 'ndcg@all': 0.5653216243237704}
Epoch 3401 0.034408250802920926 8.47313084602356
Epo

Epoch 3524 0.03382171889146169 8.489439392089844
Epoch 3525 0.03346931737744146 8.491001796722411
Epoch 3526 0.033842636437879665 8.47547812461853
Epoch 3527 0.033377928307486905 8.503305625915527
Epoch 3528 0.033425181669493514 8.493395566940308
Epoch 3529 0.03351460078524219 8.4822518825531
Epoch 3530 0.03357071460535129 8.489150524139404
Epoch 3531 0.033310733797649546 8.501464080810546
Epoch 3532 0.033541945761276615 8.490343618392945
Epoch 3533 0.03365327604115009 8.497290086746215
Epoch 3534 0.0337449110009604 8.48912010192871
Epoch 3535 0.033510204404592514 8.491716861724854
Epoch 3536 0.03353272517108255 8.496157836914062
Epoch 3537 0.03370495987021261 8.49480094909668
Epoch 3538 0.03379914671596554 8.487381172180175
Epoch 3539 0.03306470902429687 8.503437900543213
Epoch 3540 0.03393990002158615 8.494452857971192
Epoch 3541 0.033272679377761155 8.4967435836792
Epoch 3542 0.03366216808143589 8.500555944442748
Epoch 3543 0.03333664083232482 8.499409341812134
Epoch 3544 0.03348405

Epoch 3655 0.03260047650999493 8.516926574707032
Epoch 3656 0.032713050684995124 8.508572769165038
Epoch 3657 0.03260867109315263 8.5145339012146
Epoch 3658 0.03272403729044729 8.503111791610717
Epoch 3659 0.03315440643992689 8.507572746276855
Epoch 3660 0.03291873859448565 8.514738988876342
Epoch 3661 0.033081251507004104 8.51312689781189
Epoch 3662 0.03287842869758606 8.511976051330567
Epoch 3663 0.032727877609431744 8.52460594177246
Epoch 3664 0.03285763572073645 8.514109563827514
Epoch 3665 0.0328626517413391 8.517494678497314
Epoch 3666 0.032814743535386194 8.509931612014771
Epoch 3667 0.0327724470032586 8.511334419250488
Epoch 3668 0.0327420261171129 8.518805646896363
Epoch 3669 0.03269073596845071 8.511245155334473
Epoch 3670 0.03239493142399523 8.524007940292359
Epoch 3671 0.033040659750501314 8.52566409111023
Epoch 3672 0.03282905810823043 8.51654281616211
Epoch 3673 0.03281351890828874 8.510879421234131
Epoch 3674 0.032400057692494655 8.524649667739869
Epoch 3675 0.0327567916

Epoch 3798 0.03221568003710773 8.532896900177002
Epoch 3799 0.03196507710963488 8.540262699127197
Epoch 3800 0.03199270570443736 8.556234169006348
Time1:  0.2588244993239641
Time2:  3.958852520212531
Time3:  15.816060757264495
Time1:  0.028908811509609222
Time2:  0.42956038378179073
Time3:  1.7511030156165361

train {'F1@10': 0.20008507370948792, 'F1@30': 0.5073238611221313, 'F1@50': 0.6966419816017151, 'ndcg@10': 0.5116300474152019, 'ndcg@30': 0.6437551761006686, 'ndcg@50': 0.7442952172480548, 'ndcg@all': 0.7430692803790094}
valid {'F1@10': 0.11602514237165451, 'F1@30': 0.20219260454177856, 'F1@50': 0.23321805894374847, 'ndcg@10': 0.35687878253954547, 'ndcg@30': 0.3513423206566183, 'ndcg@50': 0.36684806897744043, 'ndcg@all': 0.563911886171655}
Epoch 3801 0.031786419720285466 8.538947868347169
Epoch 3802 0.031918244084550275 8.536634588241578
Epoch 3803 0.031956496751970716 8.543809747695922
Epoch 3804 0.03215208078424136 8.543517446517944
Epoch 3805 0.032029819592005675 8.549917125701

Epoch 3928 0.031206132223208744 8.569773101806641
Epoch 3929 0.03156344203485383 8.566259574890136
Epoch 3930 0.031338151233891644 8.561782169342042
Epoch 3931 0.03136343904253509 8.563489246368409
Epoch 3932 0.031664244934088655 8.555794477462769
Epoch 3933 0.03116191489001115 8.566700267791749
Epoch 3934 0.03117789886891842 8.5615629196167
Epoch 3935 0.0312536205475529 8.56364164352417
Epoch 3936 0.031715213155580894 8.560491275787353
Epoch 3937 0.03127106978661484 8.55902681350708
Epoch 3938 0.03117505752791961 8.570997667312621
Epoch 3939 0.03133691706591182 8.555479574203492
Epoch 3940 0.0312006798469358 8.572279214859009
Epoch 3941 0.031248139527936775 8.560316181182861
Epoch 3942 0.031367909991078906 8.565645503997803
Epoch 3943 0.03154155713402563 8.564314937591552
Epoch 3944 0.03162404029733605 8.573010540008545
Epoch 3945 0.03145084970941146 8.565627670288086
Epoch 3946 0.03129514668964677 8.563296604156495
Epoch 3947 0.03146428999801477 8.565271759033203
Epoch 3948 0.0316655

Epoch 4058 0.030762385266522568 8.595937538146973
Epoch 4059 0.030733595188293193 8.592056894302369
Epoch 4060 0.030910989145437877 8.586308574676513
Epoch 4061 0.030487562405566375 8.592395877838134
Epoch 4062 0.030673717748787667 8.591708421707153
Epoch 4063 0.030746397686501344 8.583958148956299
Epoch 4064 0.030487447708017295 8.579886054992675
Epoch 4065 0.03076930640058385 8.59314022064209
Epoch 4066 0.030623730127182273 8.581980323791504
Epoch 4067 0.030782048880226082 8.587394523620606
Epoch 4068 0.030680980160832406 8.594218206405639
Epoch 4069 0.030666117887530063 8.582756423950196
Epoch 4070 0.030720357514090007 8.589304685592651
Epoch 4071 0.03088891798009475 8.583889436721801
Epoch 4072 0.030918517377641466 8.597157430648803
Epoch 4073 0.030745238272680178 8.590699481964112
Epoch 4074 0.03084465784745084 8.584955787658691
Epoch 4075 0.03066112036920256 8.580272102355957
Epoch 4076 0.030726483940250345 8.58911199569702
Epoch 4077 0.030629398549596468 8.589999961853028
Epoch 

Epoch 4199 0.030201663739151424 8.602332973480225
Epoch 4200 0.030251958614422217 8.60368766784668
Time1:  0.2581804059445858
Time2:  3.951729716733098
Time3:  15.749070849269629
Time1:  0.028965070843696594
Time2:  0.43275645188987255
Time3:  1.7461068090051413

train {'F1@10': 0.19188688695430756, 'F1@30': 0.4976481795310974, 'F1@50': 0.6968087553977966, 'ndcg@10': 0.4904156559136581, 'ndcg@30': 0.6286639450122089, 'ndcg@50': 0.7349337956639166, 'ndcg@all': 0.7334633661804926}
valid {'F1@10': 0.11269740015268326, 'F1@30': 0.19880172610282898, 'F1@50': 0.23114857077598572, 'ndcg@10': 0.348894155735301, 'ndcg@30': 0.3447589591729107, 'ndcg@50': 0.36135055409377076, 'ndcg@all': 0.5592978796711913}
Epoch 4201 0.030241620499226782 8.61372241973877
Epoch 4202 0.030024382265077698 8.620417404174805
Epoch 4203 0.030255760749181113 8.609473705291748
Epoch 4204 0.030093072023656634 8.613524436950684
Epoch 4205 0.030013196356594562 8.614064121246338
Epoch 4206 0.03020143181913429 8.612858533859

Epoch 4328 0.029871476689974467 8.63325548171997
Epoch 4329 0.029457109545667968 8.615801239013672
Epoch 4330 0.029346745419833394 8.62168092727661
Epoch 4331 0.029275527141160437 8.62207908630371
Epoch 4332 0.02937869005319145 8.623494720458984
Epoch 4333 0.029574912496738964 8.636596155166625
Epoch 4334 0.02941345243404309 8.632451820373536
Epoch 4335 0.02934232217570146 8.627283477783203
Epoch 4336 0.02968860591451327 8.623214960098267
Epoch 4337 0.029617900815274982 8.628540515899658
Epoch 4338 0.029305137486921415 8.633312034606934
Epoch 4339 0.029372248198423122 8.636498355865479
Epoch 4340 0.02934088415155808 8.636508464813232
Epoch 4341 0.02949589751660824 8.6330397605896
Epoch 4342 0.02950546582125955 8.63449649810791
Epoch 4343 0.029540154006746082 8.648384380340577
Epoch 4344 0.02990237068798807 8.639788818359374
Epoch 4345 0.02937744804140594 8.635269832611083
Epoch 4346 0.029723032067219417 8.6395854473114
Epoch 4347 0.029457231611013414 8.640587520599365
Epoch 4348 0.0296

Epoch 4458 0.02901214085933235 8.650586318969726
Epoch 4459 0.028946352005004884 8.64810562133789
Epoch 4460 0.02885147819502486 8.6474365234375
Epoch 4461 0.02902298480686214 8.642538928985596
Epoch 4462 0.028887384798791674 8.656278800964355
Epoch 4463 0.029044922958645556 8.652577590942382
Epoch 4464 0.02918148955537213 8.650642681121827
Epoch 4465 0.029065911223491035 8.645921611785889
Epoch 4466 0.02900791441400846 8.657538604736327
Epoch 4467 0.028979530785646704 8.656515502929688
Epoch 4468 0.028901776795585952 8.650046014785767
Epoch 4469 0.02904153244776858 8.658582735061646
Epoch 4470 0.028727555688884524 8.656681728363036
Epoch 4471 0.028704367246892716 8.657898664474487
Epoch 4472 0.028653240348729823 8.666394329071045
Epoch 4473 0.029104994299511116 8.654007053375244
Epoch 4474 0.028739392571151257 8.655403184890748
Epoch 4475 0.02891493828760253 8.650689125061035
Epoch 4476 0.028717754843334355 8.650680875778198
Epoch 4477 0.029239445303877194 8.654525709152221
Epoch 4478

Epoch 4600 0.028524549740056197 8.664960670471192
Time1:  0.2578737922012806
Time2:  3.9535233974456787
Time3:  15.82808637432754
Time1:  0.028924720361828804
Time2:  0.42919897101819515
Time3:  1.7489221654832363

train {'F1@10': 0.19237352907657623, 'F1@30': 0.4975413382053375, 'F1@50': 0.6970452666282654, 'ndcg@10': 0.4912356136565628, 'ndcg@30': 0.6287114135181331, 'ndcg@50': 0.7353134398482267, 'ndcg@all': 0.7335930843589966}
valid {'F1@10': 0.11290538311004639, 'F1@30': 0.1986413598060608, 'F1@50': 0.23032821714878082, 'ndcg@10': 0.34847512748465315, 'ndcg@30': 0.3444881222907219, 'ndcg@50': 0.36071296673998965, 'ndcg@all': 0.5591648723091092}
Epoch 4601 0.02862893758962552 8.671133947372436
Epoch 4602 0.028553549986746578 8.675738096237183
Epoch 4603 0.0286311874166131 8.670692014694215
Epoch 4604 0.028409454226493835 8.674978399276734
Epoch 4605 0.028557978239324357 8.67033805847168
Epoch 4606 0.028608204196724628 8.664687633514404
Epoch 4607 0.028582812514570025 8.680933475494

Epoch 4730 0.028026378610067897 8.694759082794189
Epoch 4731 0.02777672770122687 8.690771484375
Epoch 4732 0.02786417535195748 8.685422325134278
Epoch 4733 0.027732363177670374 8.694517040252686
Epoch 4734 0.02757557239383459 8.696991395950317
Epoch 4735 0.028157026258607706 8.69846978187561
Epoch 4736 0.027807343213094606 8.703439474105835
Epoch 4737 0.028072274724642434 8.69614224433899
Epoch 4738 0.027929879642195173 8.700318145751954
Epoch 4739 0.027998955527113543 8.682997512817384
Epoch 4740 0.027899964981608922 8.699979972839355
Epoch 4741 0.027925808500084612 8.685987281799317
Epoch 4742 0.027911795862019063 8.690841484069825
Epoch 4743 0.02792750605278545 8.69177770614624
Epoch 4744 0.02808500298609336 8.689690685272216
Epoch 4745 0.027796277797056568 8.703771591186523
Epoch 4746 0.02799513971226083 8.699297142028808
Epoch 4747 0.027834707270893787 8.701103210449219
Epoch 4748 0.02792649625076188 8.706213331222534
Epoch 4749 0.027984517916209168 8.695131015777587
Epoch 4750 0.

Epoch 4859 0.027414239301449723 8.719456911087036
Epoch 4860 0.027584816743102338 8.715330600738525
Epoch 4861 0.027480417241652805 8.729506683349609
Epoch 4862 0.027578793300522698 8.723277568817139
Epoch 4863 0.027550171626110873 8.717992305755615
Epoch 4864 0.02758892306851016 8.70523681640625
Epoch 4865 0.027740880598624548 8.710594749450683
Epoch 4866 0.027464465859035652 8.70513849258423
Epoch 4867 0.02743535985549291 8.708373975753783
Epoch 4868 0.02754152226779196 8.714540576934814
Epoch 4869 0.027433682212399113 8.714008045196532
Epoch 4870 0.02766302964753575 8.718526840209961
Epoch 4871 0.027396050799224112 8.720158433914184
Epoch 4872 0.027252088776893085 8.711903381347657
Epoch 4873 0.02737470469954941 8.725405883789062
Epoch 4874 0.027288570441305636 8.719518804550171
Epoch 4875 0.027437323269744716 8.714760494232177
Epoch 4876 0.027644457730154195 8.712922048568725
Epoch 4877 0.027287240905894173 8.721023988723754
Epoch 4878 0.02751271403912041 8.705244255065917
Epoch 48

In [55]:
def listNet(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
    """
    ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach".

    For every slate (row) it computes the cross-entropy between softmax(y_true)
    and softmax(y_pred) over the non-padded items, then averages over slates.

    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length];
        non-floating labels are promoted to float
    :param eps: epsilon added to the predicted probabilities before the log, for numerical stability
    :param padded_value_indicator: value of y_true marking a padded item, e.g. -1;
        padded items are excluded from both softmaxes
    :return: loss value, a scalar torch.Tensor. Slates consisting only of padding
        are ignored; if every slate is padding the loss is 0 (still attached to
        the autograd graph of y_pred)
    """
    # Clone so the in-place masking below never mutates the caller's tensors.
    y_pred = y_pred.clone()
    y_true = y_true.clone()
    if not torch.is_floating_point(y_true):
        # softmax and the -inf fill below are undefined for integer tensors.
        y_true = y_true.float()

    mask = y_true == padded_value_indicator

    # A slate made only of padding would become all -inf, whose softmax is NaN;
    # that NaN would poison the batch mean and every gradient. Keep only slates
    # that contain at least one real item.
    has_items = ~mask.all(dim=1)
    if not bool(has_items.all()):
        y_pred, y_true, mask = y_pred[has_items], y_true[has_items], mask[has_items]
        if y_pred.shape[0] == 0:
            # Nothing to rank: zero loss that is still differentiable w.r.t. y_pred.
            return y_pred.sum()

    y_pred[mask] = float('-inf')
    y_true[mask] = float('-inf')

    preds_smax = F.softmax(y_pred, dim=1)
    true_smax = F.softmax(y_true, dim=1)

    # eps keeps log() finite where a predicted probability is exactly 0 (masked items).
    preds_smax = preds_smax + eps
    preds_log = torch.log(preds_smax)

    return torch.mean(-torch.sum(true_smax * preds_log, dim=1))

In [56]:
# setting
lr = 0.05
momentum = 0.
weight_decay = 0
nesterov = False # True

n_epoch = 5000
verbose = True
valid_epoch = 50  # ranking metrics are computed every `valid_epoch` epochs and after the last epoch

h_dim = 3000

model = MLPDecoder(doc_emb_dim=document_vectors.shape[1], num_words=len(word_vectors), h_dim=h_dim).to(device)
model.train()

word_vectors_tensor = torch.FloatTensor(word_vectors).to(device)

opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
# opt = torch.optim.Adam(model.parameters(), lr=lr)
# criterion = nn.MSELoss(reduction='mean')
# NOTE: `criterion` is only used by the commented-out alternatives below; the active loss is listNet.
criterion = nn.MultiLabelMarginLoss(reduction='mean')
# criterion = nn.MultiLabelSoftMarginLoss(reduction='mean')

results = []  # one dict per evaluation: {'epoch': ..., **validation metrics}
step = 0
clip_value = 1

for epoch in tqdm(range(n_epoch)):
    train_loss_his = []
    valid_loss_his = []

    # ---- training pass ----
    model.train()
    for doc_embs, target, target_rank in train_loader:
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_rank = target_rank.to(device)

        # ListNet loss between predicted word scores and target word weights
        pred = model(doc_embs)
#         loss = criterion(pred, target)
#         loss = criterion(pred, target_rank)
#         loss = criterion(pred, torch.sign(target))
        loss = listNet(pred, target)

        # Model backwarding
        opt.zero_grad()
        loss.backward()
#         torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
        opt.step()

        train_loss_his.append(loss.item())

    # ---- validation pass ----
    model.eval()
    # No parameter updates happen here, so skip building the autograd graph
    # (saves memory and time every epoch; loss values are unchanged).
    with torch.no_grad():
        for doc_embs, target, target_rank in valid_loader:
            doc_embs = doc_embs.to(device)
            target = target.to(device)
            target_rank = target_rank.to(device)

            pred = model(doc_embs)
#             loss = criterion(pred, target)
#             loss = criterion(pred, target_rank)
#             loss = criterion(pred, torch.sign(target))
            loss = listNet(pred, target)

            valid_loss_his.append(loss.item())

    print("Epoch", epoch, np.mean(train_loss_his), np.mean(valid_loss_his))

    # Also evaluate after the final epoch; otherwise (n_epoch - 1) % valid_epoch != 0
    # means the fully trained model is never scored.
    if epoch % valid_epoch == 0 or epoch == n_epoch - 1:
        res = {}
        res['epoch'] = epoch

        train_res_ndcg = evaluate_MLPDecoder(model, train_loader)
        valid_res_ndcg = evaluate_MLPDecoder(model, valid_loader)

        res.update(valid_res_ndcg)
        results.append(res)

        if verbose:
            print()
            print('train', train_res_ndcg)
            print('valid', valid_res_ndcg)


  0%|          | 0/5000 [00:00<?, ?it/s]

Epoch 0 8.282003551059299 7.740373373031616
Time1:  0.25538268126547337
Time2:  3.970807919278741
Time3:  15.84010280482471
Time1:  0.028463536873459816
Time2:  0.4305055998265743
Time3:  1.750818321481347

train {'F1@10': 0.04074257239699364, 'F1@30': 0.05158373713493347, 'F1@50': 0.05294894054532051, 'ndcg@10': 0.14633899610564713, 'ndcg@30': 0.1295833511801967, 'ndcg@50': 0.12889668041308233, 'ndcg@all': 0.35752222899500863}
valid {'F1@10': 0.038714658468961716, 'F1@30': 0.049900904297828674, 'F1@50': 0.05214732512831688, 'ndcg@10': 0.13092860741703002, 'ndcg@30': 0.11550609224724832, 'ndcg@50': 0.11484438787867969, 'ndcg@all': 0.3468204210172067}
Epoch 1 6.985731659995185 7.194931888580323
Epoch 2 6.092027176751031 6.780399131774902
Epoch 3 5.346906359990438 6.456426000595092
Epoch 4 4.7333599037594265 6.172033596038818
Epoch 5 4.207616909344991 5.953709936141967
Epoch 6 3.7671866178512574 5.7682637691497805
Epoch 7 3.3899072064293754 5.619031810760498
Epoch 8 3.0724344041612413 5.

Epoch 139 1.263231109910541 4.2111776113510135
Epoch 140 1.2632724603017171 4.2101717710495
Epoch 141 1.2623099015818702 4.2047199010849
Epoch 142 1.2624440020985073 4.201068496704101
Epoch 143 1.2615688337220086 4.201591038703919
Epoch 144 1.2610105521149106 4.203146052360535
Epoch 145 1.2602436754438613 4.199316692352295
Epoch 146 1.260306899415122 4.195049381256103
Epoch 147 1.259492959578832 4.19432532787323
Epoch 148 1.2589012689060635 4.19471378326416
Epoch 149 1.2583658854166666 4.193408131599426
Epoch 150 1.2579590035809411 4.194846963882446
Time1:  0.2554880306124687
Time2:  3.967317884787917
Time3:  15.88158598728478
Time1:  0.028435220941901207
Time2:  0.43453568033874035
Time3:  1.756269283592701

train {'F1@10': 0.13561570644378662, 'F1@30': 0.15707826614379883, 'F1@50': 0.15281274914741516, 'ndcg@10': 0.6535564544990874, 'ndcg@30': 0.5359639566180372, 'ndcg@50': 0.5172786758932906, 'ndcg@all': 0.6923878856865198}
valid {'F1@10': 0.11341048032045364, 'F1@30': 0.14147754013

Epoch 276 1.2259035262796614 4.0658392906188965
Epoch 277 1.2257415937052833 4.068649816513061
Epoch 278 1.2255081944995456 4.069665169715881
Epoch 279 1.2253522766960991 4.068708968162537
Epoch 280 1.2256091521845924 4.0657330513000485
Epoch 281 1.2250961131519742 4.064980101585388
Epoch 282 1.224692157904307 4.06555917263031
Epoch 283 1.224692157904307 4.062156391143799
Epoch 284 1.2246405495537651 4.061219263076782
Epoch 285 1.2247793310218387 4.062985777854919
Epoch 286 1.2244803845882415 4.062053465843201
Epoch 287 1.2240287376774681 4.061098313331604
Epoch 288 1.2242251879639097 4.062237739562988
Epoch 289 1.2239112436771393 4.062393832206726
Epoch 290 1.2244316624270546 4.05859341621399
Epoch 291 1.2241053091155158 4.059314131736755
Epoch 292 1.223591085937288 4.059266829490662
Epoch 293 1.2239060408539242 4.058769297599793
Epoch 294 1.2229907121923236 4.058314251899719
Epoch 295 1.2233721243010627 4.057916903495789
Epoch 296 1.222854165898429 4.058439946174621
Epoch 297 1.22348

Epoch 413 1.2119596653514437 4.007680416107178
Epoch 414 1.211221135324902 4.0049227476119995
Epoch 415 1.211825070116255 4.0024110078811646
Epoch 416 1.2121552871333228 4.003641796112061
Epoch 417 1.2105319877465566 4.000801849365234
Epoch 418 1.2110837890042199 4.00432984828949
Epoch 419 1.211373828517066 4.001468443870545
Epoch 420 1.2111894508202872 4.004180073738098
Epoch 421 1.2109325302971734 4.0048997640609745
Epoch 422 1.211112129688263 4.0002758502960205
Epoch 423 1.2112083103921678 4.001972627639771
Epoch 424 1.2104515486293368 3.9997998237609864
Epoch 425 1.2113604254192776 3.9991337776184084
Epoch 426 1.2104715764522553 4.00017204284668
Epoch 427 1.2112419989373948 3.997720813751221
Epoch 428 1.2104423112339444 4.002051377296448
Epoch 429 1.2106835775905185 3.9994035720825196
Epoch 430 1.2108005550172594 3.998451328277588
Epoch 431 1.2106362018320296 3.9969583988189696
Epoch 432 1.210180424981647 4.000105857849121
Epoch 433 1.2094403598043653 3.9973737478256224
Epoch 434 1

Epoch 551 1.2038968278302087 3.9706897020339964
Epoch 552 1.2038602570692698 3.965231442451477
Epoch 553 1.2041191187169817 3.962864303588867
Epoch 554 1.2033239920934042 3.9667211532592774
Epoch 555 1.2032449854744804 3.96760516166687
Epoch 556 1.2032928870783912 3.963817811012268
Epoch 557 1.2029109729660883 3.9645206928253174
Epoch 558 1.2036041286256578 3.965013837814331
Epoch 559 1.2033789303567675 3.9647298097610473
Epoch 560 1.2038362251387702 3.967046856880188
Epoch 561 1.2033545003996955 3.967442011833191
Epoch 562 1.2035255451997122 3.9696921825408937
Epoch 563 1.2040499091148376 3.9647473812103273
Epoch 564 1.203703741894828 3.9690941333770753
Epoch 565 1.202871748473909 3.9622127056121825
Epoch 566 1.2029552161693573 3.963101530075073
Epoch 567 1.2036874877081978 3.9652790307998655
Epoch 568 1.2026923841900297 3.960527181625366
Epoch 569 1.2025277482138739 3.9644999265670777
Epoch 570 1.2026540332370335 3.9659228324890137
Epoch 571 1.2028726140658061 3.9626482725143433
Epoc

Epoch 699 1.1978276358710396 3.9413910150527953
Epoch 700 1.198935106727812 3.9405080795288088
Time1:  0.25616661086678505
Time2:  4.056248344480991
Time3:  15.892941545695066
Time1:  0.02858608029782772
Time2:  0.4304351359605789
Time3:  1.7509670462459326

train {'F1@10': 0.16074462234973907, 'F1@30': 0.18747344613075256, 'F1@50': 0.1805022805929184, 'ndcg@10': 0.7200613737984812, 'ndcg@30': 0.5942041225656784, 'ndcg@50': 0.5738216499001352, 'ndcg@all': 0.7360389612391343}
valid {'F1@10': 0.12395822256803513, 'F1@30': 0.15669068694114685, 'F1@50': 0.1576538383960724, 'ndcg@10': 0.5156586523409243, 'ndcg@30': 0.4429191573055925, 'ndcg@50': 0.43276759325428765, 'ndcg@all': 0.6196654702241237}
Epoch 701 1.1977938771247865 3.9412967920303346
Epoch 702 1.1984469864103529 3.9405803203582765
Epoch 703 1.1990669978989494 3.941893768310547
Epoch 704 1.1982420682907104 3.9415714740753174
Epoch 705 1.1981464750236934 3.9422979593276977
Epoch 706 1.198199945025974 3.9432605504989624
Epoch 707 1.

Epoch 834 1.1941354294617972 3.9242907762527466
Epoch 835 1.194254520866606 3.928330159187317
Epoch 836 1.1939412461386787 3.9232163429260254
Epoch 837 1.1947961390018462 3.9224135637283326
Epoch 838 1.194669172498915 3.9298900842666624
Epoch 839 1.1946702533298068 3.9282042026519775
Epoch 840 1.194083360168669 3.927843379974365
Epoch 841 1.1946452935536702 3.92181510925293
Epoch 842 1.1945864127741919 3.925040102005005
Epoch 843 1.194380572769377 3.9253304481506346
Epoch 844 1.1948889129691653 3.9260945081710816
Epoch 845 1.1942339208390977 3.9266061067581175
Epoch 846 1.1943813880284628 3.924824523925781
Epoch 847 1.1944239675998687 3.9240992307662963
Epoch 848 1.1951680143674215 3.923406171798706
Epoch 849 1.1937725649939643 3.9209569692611694
Epoch 850 1.1947918706470066 3.92563853263855
Time1:  0.2550753690302372
Time2:  4.02676036208868
Time3:  15.88817179389298
Time1:  0.02836534008383751
Time2:  0.43076448142528534
Time3:  1.7548394277691841

train {'F1@10': 0.16377681493759155

Epoch 969 1.1918552974859873 3.9161741971969604
Epoch 970 1.1925813542471992 3.91773316860199
Epoch 971 1.191544767220815 3.9179988861083985
Epoch 972 1.1923274702495998 3.9187676668167115
Epoch 973 1.1921399692694346 3.9163455724716187
Epoch 974 1.1919195301002927 3.9177692890167237
Epoch 975 1.1916112853421106 3.9161560773849486
Epoch 976 1.1919051819377475 3.9177542686462403
Epoch 977 1.1919949889183044 3.915113615989685
Epoch 978 1.19121928413709 3.9169204473495483
Epoch 979 1.1919421368175083 3.9165480613708494
Epoch 980 1.1913817630873786 3.91429545879364
Epoch 981 1.191839290327496 3.9148746728897095
Epoch 982 1.1925164931350285 3.9127419710159304
Epoch 983 1.1912710090478262 3.91385395526886
Epoch 984 1.1921346174346077 3.9138432264328005
Epoch 985 1.191401368379593 3.914120602607727
Epoch 986 1.1917218691772884 3.915145492553711
Epoch 987 1.191122798787223 3.915494513511658
Epoch 988 1.1911691261662378 3.9145288467407227
Epoch 989 1.191197376118766 3.914929175376892
Epoch 990 

Epoch 1102 1.1891640371746488 3.903961682319641
Epoch 1103 1.1892294075753953 3.9083755254745483
Epoch 1104 1.189440224568049 3.9093725204467775
Epoch 1105 1.1895102898279826 3.91271767616272
Epoch 1106 1.1895240147908528 3.9082591056823732
Epoch 1107 1.189380266269048 3.9106884002685547
Epoch 1108 1.1888952652613323 3.910168242454529
Epoch 1109 1.1892339401774936 3.9046444416046144
Epoch 1110 1.1891038881407843 3.9030915975570677
Epoch 1111 1.1892699029710558 3.9052152156829836
Epoch 1112 1.1890909334023794 3.906708097457886
Epoch 1113 1.1893894288274978 3.9046161651611326
Epoch 1114 1.1887665980392033 3.9078672170639037
Epoch 1115 1.1897568225860595 3.906430411338806
Epoch 1116 1.1895737879806094 3.9075716018676756
Epoch 1117 1.1892045054170821 3.9089657068252563
Epoch 1118 1.1894526057773167 3.906771183013916
Epoch 1119 1.1886909928586749 3.9083744764328
Epoch 1120 1.1882022672229342 3.912795639038086
Epoch 1121 1.1890515625476836 3.907426285743713
Epoch 1122 1.188876051372952 3.908

Epoch 1247 1.1870057033167944 3.9036655187606812
Epoch 1248 1.1874682479434544 3.905927586555481
Epoch 1249 1.18713997470008 3.902501606941223
Epoch 1250 1.187324854400423 3.9035677909851074
Time1:  0.2557370513677597
Time2:  3.968011947348714
Time3:  15.77169855311513
Time1:  0.02842358872294426
Time2:  0.43063221871852875
Time3:  1.7515177950263023

train {'F1@10': 0.1695690006017685, 'F1@30': 0.199637770652771, 'F1@50': 0.19213123619556427, 'ndcg@10': 0.7402905429994171, 'ndcg@30': 0.6134897068855627, 'ndcg@50': 0.5928819690218757, 'ndcg@all': 0.7496349171815114}
valid {'F1@10': 0.1280287653207779, 'F1@30': 0.16228105127811432, 'F1@50': 0.16361993551254272, 'ndcg@10': 0.5254907986248011, 'ndcg@30': 0.4526841640211381, 'ndcg@50': 0.44291109909994286, 'ndcg@all': 0.6274155239323834}
Epoch 1251 1.187650348742803 3.9052348136901855
Epoch 1252 1.1874385277430217 3.9070218801498413
Epoch 1253 1.187076434161928 3.9076489210128784
Epoch 1254 1.187520173523161 3.901311993598938
Epoch 1255 1.

Epoch 1379 1.1859911415312026 3.8994593143463137
Epoch 1380 1.185479427046246 3.8983542680740357
Epoch 1381 1.1858552714188895 3.9019304275512696
Epoch 1382 1.1862084693378872 3.9041942358016968
Epoch 1383 1.185974219110277 3.900104832649231
Epoch 1384 1.185205700662401 3.9011902093887327
Epoch 1385 1.1857993874284956 3.899787735939026
Epoch 1386 1.1858872870604198 3.9026143550872803
Epoch 1387 1.1848759909470876 3.8987674951553344
Epoch 1388 1.1855257723066541 3.901110577583313
Epoch 1389 1.1853540254963768 3.900988507270813
Epoch 1390 1.185389338599311 3.901816701889038
Epoch 1391 1.1850898610221015 3.9061426639556887
Epoch 1392 1.1853585051165687 3.9011024951934816
Epoch 1393 1.1849302695857153 3.9018874883651735
Epoch 1394 1.184990413321389 3.8961390256881714
Epoch 1395 1.1855114122231802 3.9011223316192627
Epoch 1396 1.1857401940557692 3.8979740381240844
Epoch 1397 1.1850109186437394 3.9009191751480103
Epoch 1398 1.1850233197212219 3.8980042219161986
Epoch 1399 1.1847055832544962 

Epoch 1511 1.1840470280912188 3.8978271484375
Epoch 1512 1.183569037914276 3.8982808351516725
Epoch 1513 1.1838595887025198 3.898422288894653
Epoch 1514 1.1846875429153443 3.8971019983291626
Epoch 1515 1.1844122260808945 3.8918957233428957
Epoch 1516 1.1838289631737604 3.8959859132766725
Epoch 1517 1.183594693077935 3.8989710569381715
Epoch 1518 1.1839304327964784 3.8978424072265625
Epoch 1519 1.1842424432436625 3.8985063076019286
Epoch 1520 1.1837947030862173 3.8942667961120607
Epoch 1521 1.1836195601357353 3.8967997550964357
Epoch 1522 1.1840783529811436 3.8927413702011107
Epoch 1523 1.1845911191569434 3.8971287488937376
Epoch 1524 1.1837929666042328 3.8977840900421143
Epoch 1525 1.1842404537730746 3.8984957933425903
Epoch 1526 1.1841218517886267 3.897414493560791
Epoch 1527 1.184147126144833 3.899429941177368
Epoch 1528 1.1836267623636458 3.8984952211380004
Epoch 1529 1.1836387515068054 3.9006134986877443
Epoch 1530 1.1834696253140768 3.8986513137817385
Epoch 1531 1.1838232060273488

Epoch 1651 1.1824838234318626 3.8968693733215334
Epoch 1652 1.1827681726879544 3.90003981590271
Epoch 1653 1.1831182883845435 3.898288369178772
Epoch 1654 1.182696164978875 3.8996184825897218
Epoch 1655 1.182006652487649 3.8963422060012816
Epoch 1656 1.182882848713133 3.8985244750976564
Epoch 1657 1.1828803188270993 3.8986581563949585
Epoch 1658 1.1824163377285004 3.899065852165222
Epoch 1659 1.182603015502294 3.9019115447998045
Epoch 1660 1.1828278217050765 3.9004045009613035
Epoch 1661 1.182341637214025 3.9020429849624634
Epoch 1662 1.1821234934859806 3.8967950344085693
Epoch 1663 1.1827607035636902 3.8985056638717652
Epoch 1664 1.1825838069121042 3.8962980270385743
Epoch 1665 1.1827548682689666 3.8977756023406984
Epoch 1666 1.1825574179490408 3.8936012029647826
Epoch 1667 1.1824521283308664 3.89535448551178
Epoch 1668 1.1829650587505764 3.898763394355774
Epoch 1669 1.182603630092409 3.9003991365432737
Epoch 1670 1.1827739708953433 3.898916482925415
Epoch 1671 1.1826125542322794 3.89

Epoch 1797 1.181849608156416 3.8968082904815673
Epoch 1798 1.1816318485471937 3.894526982307434
Epoch 1799 1.1813243236806659 3.8971397161483763
Epoch 1800 1.1811382128132715 3.89496591091156
Time1:  0.2555961776524782
Time2:  3.9643267169594765
Time3:  15.86089632473886
Time1:  0.02856377884745598
Time2:  0.43206309527158737
Time3:  1.7508957032114267

train {'F1@10': 0.17456482350826263, 'F1@30': 0.20742680132389069, 'F1@50': 0.19951018691062927, 'ndcg@10': 0.7514575584512929, 'ndcg@30': 0.6250768726308301, 'ndcg@50': 0.6041697868383236, 'ndcg@all': 0.7576424037190247}
valid {'F1@10': 0.129752054810524, 'F1@30': 0.1656719297170639, 'F1@50': 0.16695721447467804, 'ndcg@10': 0.5293143173219176, 'ndcg@30': 0.45671176680133446, 'ndcg@50': 0.44720143087961534, 'ndcg@all': 0.630718444724818}
Epoch 1801 1.1813367625077567 3.895174241065979
Epoch 1802 1.1813427395290799 3.896507430076599
Epoch 1803 1.181484764814377 3.896136164665222
Epoch 1804 1.1817438781261445 3.8977041721343992
Epoch 1805

Epoch 1929 1.1804130176703136 3.8970319509506224
Epoch 1930 1.1802561965253617 3.896972393989563
Epoch 1931 1.1812631713019477 3.8985816478729247
Epoch 1932 1.1799589647187128 3.8985713958740233
Epoch 1933 1.1801705327298906 3.901512384414673
Epoch 1934 1.1804188046190474 3.901765513420105
Epoch 1935 1.1801072537899018 3.9015398025512695
Epoch 1936 1.1801624092790814 3.900637316703796
Epoch 1937 1.1805218855539958 3.8966240167617796
Epoch 1938 1.1799009581406912 3.901553821563721
Epoch 1939 1.1806571675671471 3.8986969232559203
Epoch 1940 1.1798994282881419 3.9016797304153443
Epoch 1941 1.18038646446334 3.8934589624404907
Epoch 1942 1.179795061217414 3.900215172767639
Epoch 1943 1.180106243160036 3.897476410865784
Epoch 1944 1.1805298573440977 3.8958166599273683
Epoch 1945 1.1798599792851343 3.8981491565704345
Epoch 1946 1.1800871200031704 3.8965900182724
Epoch 1947 1.1799330625269149 3.9010093450546264
Epoch 1948 1.1803642941845789 3.8990558385849
Epoch 1949 1.1805542124642265 3.89752

Epoch 2062 1.1794688979784647 3.898500895500183
Epoch 2063 1.1793106410238479 3.89774374961853
Epoch 2064 1.1795527570777469 3.901267433166504
Epoch 2065 1.1792002419630687 3.89868278503418
Epoch 2066 1.179381204975976 3.8981569051742553
Epoch 2067 1.1790928271081713 3.900921106338501
Epoch 2068 1.1797147002485064 3.899459433555603
Epoch 2069 1.1790245572725933 3.897066855430603
Epoch 2070 1.179219541284773 3.89889919757843
Epoch 2071 1.1788820876015558 3.900363397598267
Epoch 2072 1.1792808320787218 3.898343634605408
Epoch 2073 1.1791467686494193 3.9017058610916138
Epoch 2074 1.178932398557663 3.8993619680404663
Epoch 2075 1.1789670500490401 3.898793649673462
Epoch 2076 1.17943245238728 3.9008362770080565
Epoch 2077 1.1791162603431278 3.899111533164978
Epoch 2078 1.1794831931591034 3.900964879989624
Epoch 2079 1.1797160936726465 3.898269772529602
Epoch 2080 1.1790972179836696 3.899533414840698
Epoch 2081 1.1795684152179293 3.8999401807785032
Epoch 2082 1.1794416348139445 3.89604446887

Epoch 2201 1.1790927906831106 3.9009015798568725
Epoch 2202 1.1782842437426249 3.902545619010925
Epoch 2203 1.1784537937906054 3.8983985662460325
Epoch 2204 1.1782335539658864 3.901146388053894
Epoch 2205 1.179102372460895 3.9036539793014526
Epoch 2206 1.1785723659727307 3.9029767513275146
Epoch 2207 1.1778542796770732 3.9023578643798826
Epoch 2208 1.1781459987163543 3.906861972808838
Epoch 2209 1.1776505185498132 3.9036605834960936
Epoch 2210 1.178396643532647 3.9022266149520872
Epoch 2211 1.1779977473947736 3.9007471323013307
Epoch 2212 1.1786823358800675 3.902396869659424
Epoch 2213 1.1783398575252957 3.901318073272705
Epoch 2214 1.1782540029949613 3.9013506889343263
Epoch 2215 1.177741277217865 3.9009265184402464
Epoch 2216 1.1779999494552613 3.897782802581787
Epoch 2217 1.1783175422085657 3.9025309324264525
Epoch 2218 1.1784692380163404 3.900354290008545
Epoch 2219 1.1785565495491028 3.901191306114197
Epoch 2220 1.1782268027464549 3.901654505729675
Epoch 2221 1.1775313602553474 3.

Epoch 2346 1.1775655388832091 3.904308533668518
Epoch 2347 1.1773126357131534 3.9070533514022827
Epoch 2348 1.17745821012391 3.9060434818267824
Epoch 2349 1.1771640963024563 3.90427942276001
Epoch 2350 1.1772257930702634 3.9080851793289186
Time1:  0.25552340410649776
Time2:  3.9614133033901453
Time3:  15.770186649635434
Time1:  0.028735708445310593
Time2:  0.43720840103924274
Time3:  1.7490929514169693

train {'F1@10': 0.17808368802070618, 'F1@30': 0.21310244500637054, 'F1@50': 0.20472082495689392, 'ndcg@10': 0.758971543963149, 'ndcg@30': 0.6329892367688117, 'ndcg@50': 0.611941168239305, 'ndcg@all': 0.7630252927707485}
valid {'F1@10': 0.13088110089302063, 'F1@30': 0.1675964742898941, 'F1@50': 0.16861651837825775, 'ndcg@10': 0.5317799298090858, 'ndcg@30': 0.4593624566831776, 'ndcg@50': 0.44974292343197103, 'ndcg@all': 0.633027443790937}
Epoch 2351 1.1770981093247732 3.9060108423233033
Epoch 2352 1.177762954764896 3.9029460668563845
Epoch 2353 1.177433051665624 3.9013632774353026
Epoch 2

Epoch 2478 1.1765183177259233 3.9069068908691404
Epoch 2479 1.1766289432843526 3.9109716415405273
Epoch 2480 1.1766422702206505 3.9062843561172484
Epoch 2481 1.1763347095913357 3.90921688079834
Epoch 2482 1.1763913558589087 3.904183864593506
Epoch 2483 1.1760637515121035 3.9054454803466796
Epoch 2484 1.1768243941995833 3.909564185142517
Epoch 2485 1.1759487284554375 3.9085453033447264
Epoch 2486 1.176282861497667 3.907271933555603
Epoch 2487 1.1765658802456327 3.9068273544311523
Epoch 2488 1.1760697894626193 3.902876687049866
Epoch 2489 1.1763780176639558 3.907194638252258
Epoch 2490 1.1768368416362338 3.908922028541565
Epoch 2491 1.1764712724420758 3.9082860946655273
Epoch 2492 1.176334187719557 3.9082827091217043
Epoch 2493 1.1769742654429542 3.9047865152359007
Epoch 2494 1.1760910941494835 3.9065433979034423
Epoch 2495 1.176883072985543 3.9084484577178955
Epoch 2496 1.1768744501802657 3.908168363571167
Epoch 2497 1.1766973237196605 3.9065330028533936
Epoch 2498 1.1761245403024885 3.

Epoch 2610 1.1762417640950944 3.90970356464386
Epoch 2611 1.1762112081050873 3.905740833282471
Epoch 2612 1.1759745849503411 3.909955096244812
Epoch 2613 1.175638672378328 3.9075211763381956
Epoch 2614 1.1761260198222265 3.9107038497924806
Epoch 2615 1.1763628462950388 3.910283637046814
Epoch 2616 1.1757977883021036 3.9051475048065187
Epoch 2617 1.175580581029256 3.9095849990844727
Epoch 2618 1.1758378095097013 3.9071627378463747
Epoch 2619 1.175754475593567 3.911455583572388
Epoch 2620 1.1764816761016845 3.910498332977295
Epoch 2621 1.1756452765729692 3.9099003791809084
Epoch 2622 1.1756155987580617 3.9091978311538695
Epoch 2623 1.1759856230682797 3.9112216949462892
Epoch 2624 1.1755887687206268 3.9081504344940186
Epoch 2625 1.176070746448305 3.9068254709243773
Epoch 2626 1.1760223633713192 3.9095182180404664
Epoch 2627 1.1757308423519135 3.909983515739441
Epoch 2628 1.1755505363146463 3.9078661441802978
Epoch 2629 1.1762640900082058 3.9108503103256225
Epoch 2630 1.175439676311281 3.9

Epoch 2751 1.1749144434928893 3.9153462409973145
Epoch 2752 1.1748267339335547 3.911493754386902
Epoch 2753 1.1749603595998552 3.9150699853897093
Epoch 2754 1.175449964735243 3.915991997718811
Epoch 2755 1.1751480089293587 3.9165707349777223
Epoch 2756 1.175140814648734 3.9182755947113037
Epoch 2757 1.175294437011083 3.9125524282455446
Epoch 2758 1.1752859029504987 3.915825533866882
Epoch 2759 1.1749147163497078 3.9113507747650145
Epoch 2760 1.1751490943961673 3.913152551651001
Epoch 2761 1.174758337603675 3.9158138036727905
Epoch 2762 1.1748056206438275 3.9161578178405763
Epoch 2763 1.174533690346612 3.9144811391830445
Epoch 2764 1.1748816668987274 3.9164103984832765
Epoch 2765 1.175060170226627 3.9158015727996824
Epoch 2766 1.174611477719413 3.9144341945648193
Epoch 2767 1.1752642419603136 3.9124483108520507
Epoch 2768 1.175445826186074 3.9127735614776613
Epoch 2769 1.1748632033665976 3.914789319038391
Epoch 2770 1.1747426238324907 3.9134753465652468
Epoch 2771 1.1760353598329756 3.9

Epoch 2896 1.1741236799293093 3.9144269466400146
Epoch 2897 1.1743208666642506 3.9193315267562867
Epoch 2898 1.1737302660942077 3.918328356742859
Epoch 2899 1.1750509295198652 3.91927056312561
Epoch 2900 1.1741589883963266 3.915877866744995
Time1:  0.25508348271250725
Time2:  3.9532958082854748
Time3:  15.751880768686533
Time1:  0.02858113870024681
Time2:  0.4315254557877779
Time3:  1.74780697748065

train {'F1@10': 0.18069389462471008, 'F1@30': 0.21816831827163696, 'F1@50': 0.20931710302829742, 'ndcg@10': 0.7642042423680309, 'ndcg@30': 0.6391865503569201, 'ndcg@50': 0.6180732940241811, 'ndcg@all': 0.7673394730776648}
valid {'F1@10': 0.1321290135383606, 'F1@30': 0.17018546164035797, 'F1@50': 0.17089110612869263, 'ndcg@10': 0.5338462794980815, 'ndcg@30': 0.4617124519630655, 'ndcg@50': 0.45195307580565847, 'ndcg@all': 0.6348229805837283}
Epoch 2901 1.1743052807119159 3.9184261560440063
Epoch 2902 1.1746031873755984 3.9197145223617555
Epoch 2903 1.174193540546629 3.919498920440674
Epoch 2

Epoch 3029 1.17410230173005 3.914847183227539
Epoch 3030 1.173498241106669 3.9167039155960084
Epoch 3031 1.1733445160918765 3.9188063621520994
Epoch 3032 1.1739995380242665 3.920490336418152
Epoch 3033 1.1742723928557501 3.921822953224182
Epoch 3034 1.173516880803638 3.917348289489746
Epoch 3035 1.1740227904584672 3.9218995332717896
Epoch 3036 1.1737652129597134 3.9230612754821776
Epoch 3037 1.1733337687121497 3.9200758934020996
Epoch 3038 1.1733501977390712 3.9194408416748048
Epoch 3039 1.1737396895885468 3.920808029174805
Epoch 3040 1.1735771569940778 3.921548104286194
Epoch 3041 1.1738986406061385 3.918537712097168
Epoch 3042 1.1743369294537438 3.917910075187683
Epoch 3043 1.1738717310958438 3.9166011571884156
Epoch 3044 1.1734141204092237 3.9165947675704955
Epoch 3045 1.1736429658201006 3.916804575920105
Epoch 3046 1.1733000854651132 3.920549917221069
Epoch 3047 1.174060153298908 3.920826721191406
Epoch 3048 1.174012107981576 3.9200560808181764
Epoch 3049 1.1734658128685422 3.91823

Epoch 3161 1.1728984190358056 3.925794076919556
Epoch 3162 1.173040227095286 3.9206278800964354
Epoch 3163 1.1735449161794451 3.9218024969100953
Epoch 3164 1.173585318856769 3.9230504512786863
Epoch 3165 1.1732178025775486 3.9230223178863524
Epoch 3166 1.1733005868064033 3.9221843242645265
Epoch 3167 1.173357037703196 3.924805188179016
Epoch 3168 1.173175538248486 3.9246280908584597
Epoch 3169 1.1728141029675803 3.923851823806763
Epoch 3170 1.1735962152481079 3.9237307786941527
Epoch 3171 1.1730669756730399 3.9232542514801025
Epoch 3172 1.1731104294459025 3.925450563430786
Epoch 3173 1.1727220820056068 3.9231401443481446
Epoch 3174 1.173212210337321 3.921018934249878
Epoch 3175 1.1731285744243198 3.9230398416519163
Epoch 3176 1.1728159719043307 3.9263059377670286
Epoch 3177 1.173073285155826 3.9244366884231567
Epoch 3178 1.1740060342682732 3.925092911720276
Epoch 3179 1.1723508384492662 3.921573114395142
Epoch 3180 1.1735473189089034 3.9225061893463136
Epoch 3181 1.1727623111671872 3.9

Epoch 3301 1.1729188091225093 3.925978899002075
Epoch 3302 1.1728021681308747 3.9283469200134276
Epoch 3303 1.1724189288086362 3.926156830787659
Epoch 3304 1.1725453562206691 3.9243528842926025
Epoch 3305 1.1725960751374562 3.927945303916931
Epoch 3306 1.172533435291714 3.922365593910217
Epoch 3307 1.1726425256994035 3.927656126022339
Epoch 3308 1.172515290313297 3.9294564723968506
Epoch 3309 1.1728831542862785 3.9292616128921507
Epoch 3310 1.1723703523476918 3.9266469717025756
Epoch 3311 1.1727025051911673 3.931173872947693
Epoch 3312 1.172313665681415 3.9253559827804567
Epoch 3313 1.1730528785122765 3.926344966888428
Epoch 3314 1.1727697749932606 3.9292980194091798
Epoch 3315 1.172641834947798 3.926156187057495
Epoch 3316 1.1729825297991434 3.926886534690857
Epoch 3317 1.1720827798048654 3.929874229431152
Epoch 3318 1.173341041141086 3.925369381904602
Epoch 3319 1.1723244362407261 3.927344012260437
Epoch 3320 1.172683201233546 3.9334069967269896
Epoch 3321 1.1723666661315495 3.929079

Epoch 3446 1.1721286078294118 3.9355419397354128
Epoch 3447 1.1716734462314182 3.935372996330261
Epoch 3448 1.1716981212298074 3.9337430953979493
Epoch 3449 1.1719128217962054 3.9297673225402834
Epoch 3450 1.1720836751990849 3.9296764135360718
Time1:  0.2551628090441227
Time2:  3.9787470754235983
Time3:  15.897585939615965
Time1:  0.028688376769423485
Time2:  0.4348601680248976
Time3:  1.7490069046616554

train {'F1@10': 0.18293659389019012, 'F1@30': 0.22189736366271973, 'F1@50': 0.2132166475057602, 'ndcg@10': 0.7686722733055524, 'ndcg@30': 0.6441944258563113, 'ndcg@50': 0.6231298159604873, 'ndcg@all': 0.7707097768949321}
valid {'F1@10': 0.1319507360458374, 'F1@30': 0.17188091576099396, 'F1@50': 0.17251311242580414, 'ndcg@10': 0.5340208831980631, 'ndcg@30': 0.46365453325420825, 'ndcg@50': 0.4538904119122613, 'ndcg@all': 0.6362064261067737}
Epoch 3451 1.1720712496174706 3.930811285972595
Epoch 3452 1.1714386926756966 3.930591344833374
Epoch 3453 1.1719862699508667 3.933813691139221
Epoc

Epoch 3579 1.1713713043265872 3.9356902599334718
Epoch 3580 1.1713512433899773 3.9368056774139406
Epoch 3581 1.1715740058157178 3.933602476119995
Epoch 3582 1.1714152766598596 3.9371817111968994
Epoch 3583 1.171153614918391 3.93344509601593
Epoch 3584 1.1720756391684215 3.9371954917907717
Epoch 3585 1.1717088811927372 3.9363996505737306
Epoch 3586 1.1720433361000484 3.931976866722107
Epoch 3587 1.171216317680147 3.9375588417053224
Epoch 3588 1.1712725930743748 3.9336116790771483
Epoch 3589 1.1722652282979753 3.934677505493164
Epoch 3590 1.1716490434275733 3.9369903802871704
Epoch 3591 1.1714513699213664 3.936224675178528
Epoch 3592 1.1713979336950513 3.9344537019729615
Epoch 3593 1.1714652650886113 3.9358620405197144
Epoch 3594 1.1716912501388126 3.933594822883606
Epoch 3595 1.1714151137404971 3.9348700523376463
Epoch 3596 1.1708291437890794 3.9373247385025025
Epoch 3597 1.170792622698678 3.9380990505218505
Epoch 3598 1.1714564873112572 3.934998369216919
Epoch 3599 1.1719324794080523 3

Epoch 3711 1.170799312988917 3.9351577758789062
Epoch 3712 1.170791197485394 3.9373724460601807
Epoch 3713 1.171377052201165 3.9394708395004274
Epoch 3714 1.1707883298397064 3.937738060951233
Epoch 3715 1.1712168527974023 3.937634539604187
Epoch 3716 1.1708874358071222 3.9394202709197996
Epoch 3717 1.1707780639330545 3.9370587348937987
Epoch 3718 1.1711635291576385 3.9377561807632446
Epoch 3719 1.1712635762161678 3.9384808778762816
Epoch 3720 1.1709952460394966 3.9376105070114136
Epoch 3721 1.1714172475867801 3.936022162437439
Epoch 3722 1.1711883280012343 3.941288638114929
Epoch 3723 1.1713600330882603 3.938583493232727
Epoch 3724 1.1705285436577266 3.934883379936218
Epoch 3725 1.170735056532754 3.9385719537734984
Epoch 3726 1.1712651716338263 3.9371596336364747
Epoch 3727 1.170933304230372 3.9398845195770265
Epoch 3728 1.1711217767662472 3.9371890783309937
Epoch 3729 1.1710255927509732 3.937090849876404
Epoch 3730 1.1707106113433838 3.9416948318481446
Epoch 3731 1.1715963390138415 3.

Epoch 3851 1.1706590374310812 3.9418090105056764
Epoch 3852 1.1710540943675571 3.938898491859436
Epoch 3853 1.1704183240731558 3.9412240743637086
Epoch 3854 1.1703051997555627 3.9402428150177
Epoch 3855 1.170098629924986 3.9377074003219605
Epoch 3856 1.170573651128345 3.944311332702637
Epoch 3857 1.1703990214400821 3.946201753616333
Epoch 3858 1.1702956895033518 3.9418755769729614
Epoch 3859 1.1708469443851046 3.9440001010894776
Epoch 3860 1.1705276230971018 3.937233304977417
Epoch 3861 1.1705179082022772 3.9441550970077515
Epoch 3862 1.170354891485638 3.9400760889053346
Epoch 3863 1.1699559383922153 3.9431238412857055
Epoch 3864 1.1704807877540588 3.946762537956238
Epoch 3865 1.170867727862464 3.9445263862609865
Epoch 3866 1.1701003730297088 3.9465267658233643
Epoch 3867 1.1707608507739173 3.943842148780823
Epoch 3868 1.1697623146904839 3.941539239883423
Epoch 3869 1.1705615984068976 3.939166235923767
Epoch 3870 1.1708102742830913 3.943260598182678
Epoch 3871 1.1707281370957692 3.9443

Epoch 3996 1.1697340799702538 3.9476075649261473
Epoch 3997 1.1701389564408196 3.9533324241638184
Epoch 3998 1.1697228206528558 3.9462578773498533
Epoch 3999 1.1693850100040435 3.947212743759155
Epoch 4000 1.170466320382224 3.9479496240615846
Time1:  0.25508829578757286
Time2:  3.9594200905412436
Time3:  15.760886088013649
Time1:  0.028712304309010506
Time2:  0.43040965497493744
Time3:  1.74822904355824

train {'F1@10': 0.1843693107366562, 'F1@30': 0.22531889379024506, 'F1@50': 0.21615557372570038, 'ndcg@10': 0.771838407041101, 'ndcg@30': 0.648437612045461, 'ndcg@50': 0.6271440586272018, 'ndcg@all': 0.7735186260819302}
valid {'F1@10': 0.13307979702949524, 'F1@30': 0.17268279194831848, 'F1@50': 0.17286737263202667, 'ndcg@10': 0.5356736472557048, 'ndcg@30': 0.4647422625279582, 'ndcg@50': 0.45490932645643495, 'ndcg@all': 0.6373665056995921}
Epoch 4001 1.1696135573916966 3.9463621616363525
Epoch 4002 1.170462558666865 3.9487936019897463
Epoch 4003 1.1702236692110697 3.9464893102645875
Epoc

Epoch 4129 1.1695596953233083 3.953209948539734
Epoch 4130 1.1699311196804048 3.9505628824234007
Epoch 4131 1.1695387621720632 3.9498183965682983
Epoch 4132 1.1695887042416466 3.950254464149475
Epoch 4133 1.169493126206928 3.954115462303162
Epoch 4134 1.1696403668986426 3.9527190208435057
Epoch 4135 1.1693369408448537 3.9561453580856325
Epoch 4136 1.1692424601978726 3.952668333053589
Epoch 4137 1.1698264141877492 3.9539751052856444
Epoch 4138 1.1699803312619528 3.95240638256073
Epoch 4139 1.1696884241369037 3.956814432144165
Epoch 4140 1.1693964322408041 3.9558383226394653
Epoch 4141 1.1701561603281232 3.953968644142151
Epoch 4142 1.1695374223921033 3.952404022216797
Epoch 4143 1.1693168580532074 3.9541922569274903
Epoch 4144 1.1691719737317827 3.9566041946411135
Epoch 4145 1.1693932043181525 3.953344392776489
Epoch 4146 1.1692341069380443 3.9545814752578736
Epoch 4147 1.169851800468233 3.9514056921005247
Epoch 4148 1.1703344490793017 3.9549196004867553
Epoch 4149 1.1696550203694238 3.

Epoch 4261 1.1691081172890134 3.9548882484436034
Epoch 4262 1.169494268629286 3.9546003103256226
Epoch 4263 1.1691203090879652 3.9574207067489624
Epoch 4264 1.1691329836845399 3.9587443351745604
Epoch 4265 1.1689785950713687 3.9561693906784057
Epoch 4266 1.1692658536963993 3.953891396522522
Epoch 4267 1.1691869497299194 3.9566968202590944
Epoch 4268 1.1693344626161788 3.956840252876282
Epoch 4269 1.1692057331403096 3.9550431966781616
Epoch 4270 1.1694188051753573 3.9584787607192995
Epoch 4271 1.1690647668308682 3.95813467502594
Epoch 4272 1.1693235496679941 3.95920307636261
Epoch 4273 1.169152256515291 3.9560474157333374
Epoch 4274 1.1690120783117082 3.9573192596435547
Epoch 4275 1.169351683060328 3.9535213232040407
Epoch 4276 1.1688809136549632 3.955638790130615
Epoch 4277 1.1690478715631696 3.952527117729187
Epoch 4278 1.1692302491929796 3.9566576719284057
Epoch 4279 1.169625296857622 3.9555699110031126
Epoch 4280 1.16950467493799 3.953077960014343
Epoch 4281 1.1689817084206475 3.956

Epoch 4401 1.1691047569115958 3.96573326587677
Epoch 4402 1.16841083433893 3.961131954193115
Epoch 4403 1.1684832784864638 3.9566322803497314
Epoch 4404 1.1687785135375128 3.9646331071853638
Epoch 4405 1.1690763884120516 3.9626843214035032
Epoch 4406 1.1682941986454858 3.9619353771209718
Epoch 4407 1.1688477026091681 3.9579038619995117
Epoch 4408 1.1688388956917657 3.9613857746124266
Epoch 4409 1.1686950471666124 3.9578627824783323
Epoch 4410 1.1684579968452453 3.961476182937622
Epoch 4411 1.168502637412813 3.9594822406768797
Epoch 4412 1.1684431778060065 3.95615496635437
Epoch 4413 1.1689589947462082 3.9591180801391603
Epoch 4414 1.168913949198193 3.955283784866333
Epoch 4415 1.1681882341702778 3.9614351272583006
Epoch 4416 1.168787278731664 3.9598438262939455
Epoch 4417 1.1685082342889574 3.9612508535385134
Epoch 4418 1.1690841999318864 3.9630778789520265
Epoch 4419 1.1682618313365511 3.9595579147338866
Epoch 4420 1.16870137585534 3.9590271949768066
Epoch 4421 1.1680792027049594 3.96

Epoch 4546 1.168256500032213 3.966037940979004
Epoch 4547 1.1691234469413758 3.9696967601776123
Epoch 4548 1.168587139579985 3.9658596754074096
Epoch 4549 1.1688888768355052 3.965130829811096
Epoch 4550 1.1682601597574023 3.967977738380432
Time1:  0.25571646727621555
Time2:  3.9625871665775776
Time3:  15.863826112821698
Time1:  0.02847956493496895
Time2:  0.43031965009868145
Time3:  1.7510002367198467

train {'F1@10': 0.185836061835289, 'F1@30': 0.2282114326953888, 'F1@50': 0.21890656650066376, 'ndcg@10': 0.7749225826165302, 'ndcg@30': 0.6519925903425333, 'ndcg@50': 0.630718545514564, 'ndcg@all': 0.7760015363654783}
valid {'F1@10': 0.13370373845100403, 'F1@30': 0.17341597378253937, 'F1@50': 0.17407920956611633, 'ndcg@10': 0.5357104120226027, 'ndcg@30': 0.46488717801727214, 'ndcg@50': 0.455400214173474, 'ndcg@all': 0.6375818651866865}
Epoch 4551 1.168306589126587 3.965852379798889
Epoch 4552 1.1684383816189237 3.96849410533905
Epoch 4553 1.1684189882543352 3.963755464553833
Epoch 4554 1

Epoch 4679 1.1674600395891401 3.9719737768173218
Epoch 4680 1.168486777941386 3.9743817806243897
Epoch 4681 1.1676471418804593 3.9703279972076415
Epoch 4682 1.1671780678961012 3.972154426574707
Epoch 4683 1.1674615813626184 3.9677141904830933
Epoch 4684 1.1677422953976526 3.965385580062866
Epoch 4685 1.16770840883255 3.9724135637283324
Epoch 4686 1.1680107911427815 3.9697240591049194
Epoch 4687 1.1680881261825562 3.9695950984954833
Epoch 4688 1.1680275354120466 3.9665042400360107
Epoch 4689 1.1681327905919816 3.9671592235565187
Epoch 4690 1.1676384978824192 3.9654650688171387
Epoch 4691 1.167839056915707 3.9671314001083373
Epoch 4692 1.1682360576258766 3.9685758352279663
Epoch 4693 1.1679765204588572 3.9708054304122924
Epoch 4694 1.168201955821779 3.969405937194824
Epoch 4695 1.1684908833768632 3.9696892499923706
Epoch 4696 1.1680777960353428 3.970832920074463
Epoch 4697 1.168077954981062 3.9657038927078245
Epoch 4698 1.1679920607142977 3.9653351306915283
Epoch 4699 1.1679037272930146 

Epoch 4812 1.1672775977187686 3.9747653007507324
Epoch 4813 1.1672821323076883 3.974014711380005
Epoch 4814 1.1676717930369906 3.9737016439437864
Epoch 4815 1.167243053515752 3.9738086462020874
Epoch 4816 1.1675302777025434 3.974694561958313
Epoch 4817 1.1680070320765177 3.972719407081604
Epoch 4818 1.1677178064982097 3.9736797332763674
Epoch 4819 1.1675103578302595 3.9727112293243407
Epoch 4820 1.1678763965765635 3.973256325721741
Epoch 4821 1.167923284901513 3.97327606678009
Epoch 4822 1.1677608847618104 3.9712779045104982
Epoch 4823 1.1672236601511636 3.973597836494446
Epoch 4824 1.1675588952170477 3.9737276077270507
Epoch 4825 1.1675576355722215 3.970454382896423
Epoch 4826 1.1679449101289114 3.971514105796814
Epoch 4827 1.1671906027528975 3.9711604833602907
Epoch 4828 1.1674522373411391 3.976252579689026
Epoch 4829 1.1671211196316613 3.9772651195526123
Epoch 4830 1.1671992242336273 3.9733057022094727
Epoch 4831 1.167876101202435 3.975286865234375
Epoch 4832 1.1674926857153574 3.97

Epoch 4951 1.1668488535616133 3.9790252447128296
Epoch 4952 1.1670742372671763 3.9761707067489622
Epoch 4953 1.1677469359503851 3.9764343976974486
Epoch 4954 1.1672996587223476 3.978725552558899
Epoch 4955 1.167238908343845 3.9816753387451174
Epoch 4956 1.166672670841217 3.979105520248413
Epoch 4957 1.1674071106645796 3.9818819761276245
Epoch 4958 1.1673452231619093 3.9802659511566163
Epoch 4959 1.1673103166951073 3.9779738187789917
Epoch 4960 1.1670275469621023 3.980030560493469
Epoch 4961 1.1675269895129734 3.977017068862915
Epoch 4962 1.1673874759011797 3.977394437789917
Epoch 4963 1.1672039555178748 3.9769142866134644
Epoch 4964 1.167130868964725 3.976837229728699
Epoch 4965 1.1671762062443627 3.979563069343567
Epoch 4966 1.1672773917516073 3.977637457847595
Epoch 4967 1.1677186787128448 3.976111364364624
Epoch 4968 1.1666782206959194 3.9795403242111207
Epoch 4969 1.166677176952362 3.979121994972229
Epoch 4970 1.166628101799223 3.977053904533386
Epoch 4971 1.167515348063575 3.98111

In [None]:
def deterministic_neural_sort(s, tau, mask):
    """
    Deterministic neural sort.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param s: values to sort, shape [batch_size, slate_length, 1] (scores with a trailing singleton axis,
        callers pass ``y_pred.unsqueeze(-1)``)
    :param tau: temperature for the final softmax function
    :param mask: boolean mask of shape [batch_size, slate_length], True for padded elements
    :return: approximate permutation matrices of shape [batch_size, slate_length, slate_length]
    """
    dev = get_torch_device()

    n = s.size()[1]
    one = torch.ones((n, 1), dtype=torch.float32, device=dev)
    s = s.masked_fill(mask[:, :, None], -1e8)
    # Pairwise absolute differences |s_i - s_j|, zeroed for any pair that involves a padded item.
    A_s = torch.abs(s - s.permute(0, 2, 1))
    A_s = A_s.masked_fill(mask[:, :, None] | mask[:, None, :], 0.0)

    # B[b, i, j] = sum_k A_s[b, i, k]: row sums of A_s broadcast across columns.
    B = torch.matmul(A_s, torch.matmul(one, torch.transpose(one, 0, 1)))

    # Per-row scaling vector: (n_valid + 1 - 2 * (j + 1)) for the first n_valid positions, 0 afterwards.
    # Built vectorised instead of per-row; summing the 2-D mask directly (no squeeze) also keeps this
    # working when slate_length == 1.
    # NOTE(review): like the original, this assumes padded items sit at the tail of each slate.
    n_valid = (n - mask.sum(dim=1)).to(dev)  # [batch_size]
    positions = torch.arange(n, device=dev)  # [slate_length]
    scaling = (n_valid[:, None] + 1 - 2 * (positions[None, :] + 1)).type(torch.float32)
    scaling = scaling.masked_fill(positions[None, :] >= n_valid[:, None], 0.0)

    s = s.masked_fill(mask[:, :, None], 0.0)
    C = torch.matmul(s, scaling.unsqueeze(-2))

    P_max = (C - B).permute(0, 2, 1)
    # Pairs touching a padded item get -inf (zero probability). Entries where both the row and the
    # column are padded are set to 1.0, so padded rows still have finite entries and softmax avoids NaN.
    P_max = P_max.masked_fill(mask[:, :, None] | mask[:, None, :], -np.inf)
    P_max = P_max.masked_fill(mask[:, :, None] & mask[:, None, :], 1.0)
    sm = torch.nn.Softmax(-1)
    P_hat = sm(P_max / tau)
    return P_hat

def sinkhorn_scaling(mat, mask=None, tol=1e-6, max_iter=50):
    """
    Make a batch of square matrices (approximately) doubly stochastic via Sinkhorn scaling.

    Columns and rows are normalised alternately until every row sum and column sum is within
    ``tol`` of 1, or until ``max_iter`` sweeps have run.
    :param mat: batch of square matrices, shape N x M x M, where N is the batch size
    :param mask: optional boolean tensor of shape N x M marking padded positions
    :param tol: convergence tolerance on row and column sums
    :param max_iter: maximum number of normalisation sweeps
    :return: tensor of the same shape holding the scaled matrices
    """
    has_mask = mask is not None
    if has_mask:
        row_pad = mask[:, :, None]
        col_pad = mask[:, None, :]
        # Zero every padded row/column, then put ones where both the row and column are padded.
        mat = mat.masked_fill(col_pad | row_pad, 0.0)
        mat = mat.masked_fill(col_pad & row_pad, 1.0)

    sweep = 0
    while sweep < max_iter:
        sweep += 1
        mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=DEFAULT_EPS)
        mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=DEFAULT_EPS)

        row_error = torch.max(torch.abs(mat.sum(dim=2) - 1.))
        col_error = torch.max(torch.abs(mat.sum(dim=1) - 1.))
        if row_error < tol and col_error < tol:
            break

    if has_mask:
        # Padded rows/columns carry no probability mass in the result.
        mat = mat.masked_fill(col_pad | row_pad, 0.0)

    return mat

def stochastic_neural_sort(s, n_samples, tau, mask, beta=1.0, log_scores=True, eps=1e-10):
    """
    Stochastic neural sort. Please note that memory complexity grows by factor n_samples.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param s: values to sort, shape [batch_size, slate_length, 1] (callers pass ``y_pred.unsqueeze(-1)``)
    :param n_samples: number of samples (approximations) for each permutation matrix
    :param tau: temperature for the final softmax function
    :param mask: boolean mask of shape [batch_size, slate_length], True for padded elements
    :param beta: scale parameter for the Gumbel distribution
    :param log_scores: whether to apply the logarithm function to scores prior to Gumbel perturbation
    :param eps: epsilon for the logarithm function
    :return: approximate permutation matrices of shape [n_samples, batch_size, slate_length, slate_length]
    """
    dev = get_torch_device()

    batch_size = s.size()[0]
    n = s.size()[1]
    # Shift so every score is non-negative before the optional log.
    s_positive = s + torch.abs(s.min())
    samples = beta * sample_gumbel([n_samples, batch_size, n, 1], device=dev)
    if log_scores:
        s_positive = torch.log(s_positive + eps)

    # [batch, n, 1] + [n_samples, batch, n, 1] broadcasts to [n_samples, batch, n, 1]. Flattening the
    # first two axes is sample-major: flat row i is sample i // batch_size of slate i % batch_size.
    s_perturb = (s_positive + samples).view(n_samples * batch_size, n, 1)
    # The mask must use the same sample-major layout, i.e. tile the whole batch n_samples times.
    # repeat_interleave gave a slate-major layout and paired slates with the wrong padding masks
    # whenever batch_size > 1.
    mask_repeated = mask.repeat(n_samples, 1)

    P_hat = deterministic_neural_sort(s_perturb, tau, mask_repeated)
    P_hat = P_hat.view(n_samples, batch_size, n, n)
    return P_hat

def dcg(y_pred, y_true, ats=None, gain_function=lambda x: torch.pow(2, x) - 1, padding_indicator=PADDED_Y_VALUE):
    """
    Discounted Cumulative Gain at one or more cutoffs.

    The ground-truth labels are reordered by descending prediction. Each label is passed through
    ``gain_function`` and weighted by 1 / log2(rank + 2). The running sum is then read off at each
    requested cutoff.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param ats: optional list of cutoff ranks; defaults to the full slate length. Each cutoff is
        clipped to the slate length.
    :param gain_function: callable mapping labels to gains, e.g. torch.pow(2, x) - 1
    :param padding_indicator: the y_true value marking a padded item, e.g. -1
    :return: DCG values for each slate and cutoff, shape [batch_size, len(ats)]
    """
    labels = y_true.clone()
    scores = y_pred.clone()

    slate_len = labels.shape[1]
    cutoffs = [slate_len] if ats is None else [min(at, slate_len) for at in ats]

    ordered_labels = __apply_mask_and_get_true_sorted_by_preds(scores, labels, padding_indicator)

    ranks = torch.arange(ordered_labels.shape[1], dtype=torch.float)
    discounts = (torch.tensor(1) / torch.log2(ranks + 2.0)).to(device=ordered_labels.device)

    per_position = (gain_function(ordered_labels) * discounts)[:, :max(cutoffs)]
    running_total = torch.cumsum(per_position, dim=1)

    # Cutoff k corresponds to column k - 1 of the running total.
    cutoff_columns = torch.tensor(cutoffs, dtype=torch.long) - torch.tensor(1)
    return running_total[:, cutoff_columns]

def get_torch_device(prefer_cuda=False):
    """
    Getter for the pyTorch device used by the ranking helpers in this notebook.

    By default this always returns the string "cpu", whether or not a GPU is present. The CUDA
    lookup is only done when explicitly requested.
    :param prefer_cuda: if True, return ``torch.device("cuda:0")`` when CUDA is available and
        ``torch.device("cpu")`` otherwise
    :return: "cpu" by default; a ``torch.device`` when ``prefer_cuda`` is True
    """
    if not prefer_cuda:
        return "cpu"
    return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def neuralNDCG(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1., powered_relevancies=True, k=None,
               stochastic=False, n_samples=32, beta=0.1, log_scores=True):
    """
    NeuralNDCG loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
    :return: loss value (negative mean NDCG over slates with non-zero ideal DCG), a torch.Tensor
    """
    dev = get_torch_device()

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)
    # Choose the deterministic/stochastic variant. Both give P_hat of shape
    # [n_draws, batch, n, n], with n_draws == 1 for the deterministic variant.
    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    n_draws, batch_size = P_hat.shape[0], P_hat.shape[1]
    # Perform sinkhorn scaling to obtain doubly stochastic permutation matrices.
    # Flattening [n_draws, batch, n, n] is sample-major, so the mask is tiled batch-wise with
    # `repeat`; repeat_interleave paired slates with the wrong masks when n_draws > 1.
    P_hat = sinkhorn_scaling(P_hat.view(n_draws * batch_size, P_hat.shape[2], P_hat.shape[3]),
                             mask.repeat(n_draws, 1), tol=1e-6, max_iter=50)
    P_hat = P_hat.view(n_draws, batch_size, P_hat.shape[1], P_hat.shape[2])

    # Mask P_hat and apply to true labels, ie approximately sort them
    P_hat = P_hat.masked_fill(mask[None, :, :, None] | mask[None, :, None, :], 0.)
    y_true_masked = y_true.masked_fill(mask, 0.).unsqueeze(-1).unsqueeze(0)
    if powered_relevancies:
        y_true_masked = torch.pow(2., y_true_masked) - 1.

    ground_truth = torch.matmul(P_hat, y_true_masked).squeeze(-1)
    discounts = (torch.tensor(1.) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
    discounted_gains = ground_truth * discounts

    # Ideal DCG uses the same gain function as the approximate DCG above.
    if powered_relevancies:
        idcg = dcg(y_true, y_true, ats=[k]).permute(1, 0)
    else:
        idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).permute(1, 0)

    discounted_gains = discounted_gains[:, :, :k]
    ndcg = discounted_gains.sum(dim=-1) / (idcg + DEFAULT_EPS)
    # Slates with zero ideal DCG contribute 0 and are excluded from the mean below.
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask.repeat(ndcg.shape[0], 1), 0.)

    # NOTE(review): this condition is always true (a count is never negative), so the assert never
    # fires; `== 0` was probably intended. Left unchanged to avoid introducing a new failure mode.
    assert (ndcg < 0.).sum() >= 0, "every ndcg should be non-negative"
    if idcg_mask.all():
        return torch.tensor(0.)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 cause we want to maximize NDCG


def neuralNDCG_transposed(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1.,
                          powered_relevancies=True, k=None, stochastic=False, n_samples=32, beta=0.1, log_scores=True,
                          max_iter=50, tol=1e-6):
    """
    NeuralNDCG Transposed loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
    :param max_iter: maximum iteration count for Sinkhorn scaling
    :param tol: tolerance for Sinkhorn scaling
    :return: loss value (negative mean NDCG over slates with non-zero ideal DCG), a torch.Tensor
    """
    dev = get_torch_device()

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)

    # P_hat has shape [n_draws, batch, n, n]; n_draws == 1 for the deterministic variant.
    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    # Perform sinkhorn scaling to obtain doubly stochastic permutation matrices.
    # Flattening [n_draws, batch, n, n] is sample-major, so the mask is tiled batch-wise with
    # `repeat`; repeat_interleave paired slates with the wrong masks when n_draws > 1.
    P_hat_masked = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * y_pred.shape[0], y_pred.shape[1], y_pred.shape[1]),
                                    mask.repeat(P_hat.shape[0], 1), tol=tol, max_iter=max_iter)
    P_hat_masked = P_hat_masked.view(P_hat.shape[0], y_pred.shape[0], y_pred.shape[1], y_pred.shape[1])
    discounts = (torch.tensor(1) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)

    # This takes care of the @k metric truncation - if something is @>k, it is useless and gets 0.0 discount
    discounts[k:] = 0.
    discounts = discounts[None, None, :, None]

    # Here the discounts become expected discounts per item (padded items get 0, since their
    # columns of P_hat_masked are zeroed by sinkhorn_scaling).
    discounts = torch.matmul(P_hat_masked.permute(0, 1, 3, 2), discounts).squeeze(-1)
    if powered_relevancies:
        gains = torch.pow(2., y_true) - 1
        idcg = dcg(y_true, y_true, ats=[k]).squeeze(-1)
    else:
        gains = y_true
        # The ideal DCG must use the same linear gain as the numerator. Previously the default
        # 2^x - 1 gain was used here too, which mis-normalised the loss.
        idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).squeeze(-1)
    discounted_gains = gains.unsqueeze(0) * discounts

    ndcg = discounted_gains.sum(dim=2) / (idcg + DEFAULT_EPS)
    # Slates with zero ideal DCG contribute 0 and are excluded from the mean below.
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask, 0.)

    # NOTE(review): this condition is always true (a count is never negative), so the assert never
    # fires; `== 0` was probably intended. Left unchanged to avoid introducing a new failure mode.
    assert (ndcg < 0.).sum() >= 0, "every ndcg should be non-negative"
    if idcg_mask.all():
        return torch.tensor(0.)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 cause we want to maximize NDCG

In [None]:
def lambdaLoss(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE, weighing_scheme=None, k=None, sigma=1., mu=10.,
               reduction="sum", reduction_log="binary"):
    """
    LambdaLoss framework for LTR losses implementations, introduced in "The LambdaLoss Framework for Ranking Metric Optimization".
    Contains implementations of different weighing schemes corresponding to e.g. LambdaRank or RankNet.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]; non-floating labels are cast to float
        so that padded entries can be marked with -inf
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param weighing_scheme: name of a module-level weighing function (e.g. "ndcgLoss2_scheme"), looked up
        via ``globals()`` and called as ``fn(G, D, mu, true_sorted_by_preds)``; None means uniform weights
    :param k: rank at which the loss is truncated; None disables truncation (all positions are used)
    :param sigma: score difference weight used in the sigmoid function
    :param mu: optional weight used in NDCGLoss2++ weighing scheme
    :param reduction: losses reduction method, could be either a sum or a mean
    :param reduction_log: logarithm variant used prior to masking and loss reduction, either binary or natural
    :return: loss value, a torch.Tensor
    :raises ValueError: if ``reduction_log`` or ``reduction`` is not one of the supported values
    """
    device = y_pred.device
    y_pred = y_pred.clone()
    y_true = y_true.clone()
    # Padded entries are overwritten with -inf below; integer tensors cannot hold -inf.
    if not torch.is_floating_point(y_true):
        y_true = y_true.float()

    padded_mask = y_true == padded_value_indicator
    y_pred[padded_mask] = float("-inf")
    y_true[padded_mask] = float("-inf")

    # Here we sort the true and predicted relevancy scores.
    y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
    y_true_sorted, _ = y_true.sort(descending=True, dim=-1)

    # After sorting, we can mask out the pairs of indices (i, j) containing index of a padded element.
    # A padded element contributes -inf, so any pair touching it yields a non-finite difference.
    true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
    true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
    padded_pairs_mask = torch.isfinite(true_diffs)

    # Except for NDCG-Loss1, only pairs where item i is strictly more relevant than item j count.
    if weighing_scheme != "ndcgLoss1_scheme":
        padded_pairs_mask = padded_pairs_mask & (true_diffs > 0)

    # Restrict pairs to the top-k predicted positions (k=None keeps every position).
    ndcg_at_k_mask = torch.zeros((y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device)
    ndcg_at_k_mask[:k, :k] = 1

    # Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs)
    true_sorted_by_preds.clamp_(min=0.)
    y_true_sorted.clamp_(min=0.)

    # Here we find the gains, discounts and ideal DCGs per slate.
    pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
    D = torch.log2(1. + pos_idxs.float())[None, :]
    maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1).clamp(min=eps)
    G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]

    # Here we apply appropriate weighing scheme - ndcgLoss1, ndcgLoss2, ndcgLoss2++ or no weights (=1.0)
    if weighing_scheme is None:
        weights = 1.
    else:
        weights = globals()[weighing_scheme](G, D, mu, true_sorted_by_preds)  # type: ignore

    # We are clamping the array entries to maintain correct backprop (log(0) and division by 0)
    scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp(min=-1e8, max=1e8)
    # (-inf) - (-inf) between two padded items is NaN; masked_fill is out-of-place, so keep its result.
    scores_diffs = scores_diffs.masked_fill(torch.isnan(scores_diffs), 0.)
    weighted_probas = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp(min=eps)
    if reduction_log == "natural":
        losses = torch.log(weighted_probas)
    elif reduction_log == "binary":
        losses = torch.log2(weighted_probas)
    else:
        raise ValueError("Reduction logarithm base can be either natural or binary")

    if reduction == "sum":
        loss = -torch.sum(losses[padded_pairs_mask & ndcg_at_k_mask])
    elif reduction == "mean":
        loss = -torch.mean(losses[padded_pairs_mask & ndcg_at_k_mask])
    else:
        raise ValueError("Reduction method can be either sum or mean")

    return loss


def ndcgLoss1_scheme(G, D, *args):
    """NDCG-Loss1 weighting: each item's normalized gain divided by its position discount.

    The result gets a singleton axis inserted at dim 2 so it broadcasts against a
    pairwise (i, j) matrix. Any extra positional arguments are ignored.
    """
    gain_over_discount = G / D
    return gain_over_discount.unsqueeze(2)


def ndcgLoss2_scheme(G, D, *args):
    """NDCG-Loss2 pairwise weights.

    For every pair (i, j) the weight is |1/D[|i-j|-1] - 1/D[|i-j|]| (read from row 0 of D)
    times |G_i - G_j|. The diagonal is forced to zero; there |i-j| - 1 == -1 wraps to the
    last column of D, so that entry is meaningless anyway. Extra positional arguments
    are ignored.
    """
    n_items = G.shape[1]
    ranks = torch.arange(1, n_items + 1, device=G.device)
    rank_gap = (ranks[:, None] - ranks[None, :]).abs()

    inv_disc_lower = torch.pow(D[0, rank_gap - 1].abs(), -1.)
    inv_disc_upper = torch.pow(D[0, rank_gap].abs(), -1.)
    deltas = (inv_disc_lower - inv_disc_upper).abs()
    deltas.diagonal().zero_()

    gain_gap = (G[:, :, None] - G[:, None, :]).abs()
    return deltas.unsqueeze(0) * gain_gap


def lambdaRank_scheme(G, D, *args):
    """LambdaRank pairwise weights: |1/D_i - 1/D_j| * |G_i - G_j|.

    Extra positional arguments are ignored.
    """
    inv_disc = torch.pow(D, -1.)
    disc_gap = (inv_disc[:, :, None] - inv_disc[:, None, :]).abs()
    gain_gap = (G[:, :, None] - G[:, None, :]).abs()
    return disc_gap * gain_gap


def ndcgLoss2PP_scheme(G, D, *args):
    """NDCG-Loss2++ weights: ``mu * ndcgLoss2_scheme + lambdaRank_scheme``.

    ``args[0]`` is the mixing weight ``mu``.
    """
    mu = args[0]
    ndcg2_part = ndcgLoss2_scheme(G, D)
    return mu * ndcg2_part + lambdaRank_scheme(G, D)


def rankNet_scheme(G, D, *args):
    """RankNet weighting: every pair gets the same constant weight of 1.0; all inputs are ignored."""
    constant_weight = 1.
    return constant_weight


def rankNetWeightedByGTDiff_scheme(G, D, *args):
    """RankNet pair weights equal to the absolute difference of ground-truth labels.

    ``args[1]`` holds the labels (2-D, indexed as [batch, item]); G and D are ignored.
    """
    labels = args[1]
    label_gap = labels[:, :, None] - labels[:, None, :]
    return label_gap.abs()


def rankNetWeightedByGTDiffPowed_scheme(G, D, *args):
    """RankNet pair weights equal to |label_i^2 - label_j^2|.

    ``args[1]`` holds the labels (2-D, indexed as [batch, item]); G and D are ignored.
    """
    squared_labels = torch.pow(args[1], 2)
    return (squared_labels[:, :, None] - squared_labels[:, None, :]).abs()

In [None]:
# from torchmetrics.functional import retrieval_normalized_dcg
# preds = torch.tensor([[.1, .2, .3, 4, 70], [.1, .2, .3, 4, 2]])
# target = torch.tensor([[10, 0, 8, 1, 5], [10, 0, 0, 1, 5]])
# # preds = torch.tensor([[.1, .2, .3, 4, 70]])
# # target = torch.tensor([[10, 0, 0, 1, 6]])

# print(retrieval_normalized_dcg(preds[0], target[0], k=3))
# print(ndcg_score(target[0].reshape(1,-1), preds[0].reshape(1,-1), k=3))
# print(retrieval_normalized_dcg(preds[1], target[1], k=3))
# print(ndcg_score(target[1].reshape(1,-1), preds[1].reshape(1,-1), k=3))
# print(retrieval_normalized_dcg(preds, target, k=3))
# print(ndcg_score(target, preds, k=3))

# ndcg = RetrievalNormalizedDCG(k=3)
# ndcg(preds, target, indexes=torch.vstack([torch.arange(preds.shape[0]) for _ in range(preds.shape[1])]).T)