## Load

In [1]:
import os
from collections import defaultdict
import math
import numpy as np 
import random
import re
import torch
import torch.nn as nn
from itertools import cycle
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
from tqdm.auto import tqdm

# Used to get the data
from sklearn.metrics import ndcg_score

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import nltk
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords
from nltk.collocations import BigramAssocMeasures, BigramCollocationFinder
nltk.download('stopwords')

import matplotlib.pyplot as plt 
import matplotlib
matplotlib.use('Agg')

seed = 33
import pandas as pd

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/chrisliu/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


## Preprocess config

In [2]:
config = {}

config["dataset"] = "CNN" # "IMDB" "CNN", "PubMed"
config["n_document"] = 10000000
config["normalize_word_embedding"] = True
config["min_word_freq_threshold"] = 20
config["topk_word_freq_threshold"] = 100
config["document_vector_agg_weight"] = 'IDF' # ['mean', 'IDF', 'uniform', 'gaussian', 'exponential', 'pmi']
config["document_vector_weight_normalize"] = False # weighted sum or mean, True for sum, False for mean 
config["select_topk_TFIDF"] = None # ignore
config["embedding_file"] = "../data/glove.6B.100d.txt"
config["topk"] = [10, 30, 50]


In [3]:
def in_notebook():
    try:
        from IPython import get_ipython
        if 'IPKernelApp' not in get_ipython().config:  # pragma: no cover
            return False
    except ImportError:
        return False
    return True

In [4]:
def load_word2emb(embedding_file):
    word2embedding = dict()
    word_dim = int(re.findall(r".(\d+)d", embedding_file)[0])

    with open(embedding_file, "r") as f:
        for line in tqdm(f):
            line = line.strip().split()
            word = line[0]
            embedding = list(map(float, line[1:]))
            word2embedding[word] = np.array(embedding)

    print("Number of words:%d" % len(word2embedding))

    return word2embedding

word2embedding = load_word2emb(config["embedding_file"])

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Number of words:400000


In [5]:
def normalize_wordemb(word2embedding):
    # Every word emb should have norm 1
    
    word_emb = []
    word_list = []
    for word, emb in word2embedding.items():
        word_list.append(word)
        word_emb.append(emb)

    word_emb = np.array(word_emb)

    for i in range(len(word_emb)):
        norm = np.linalg.norm(word_emb[i])
        word_emb[i] = word_emb[i] / norm

    for word, emb in tqdm(zip(word_list, word_emb)):
        word2embedding[word] = emb
    return word2embedding

if config["normalize_word_embedding"]:
    normalize_wordemb(word2embedding)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [6]:
class Vocabulary:
    def __init__(self, word2embedding, config):
        # The low frequency words will be assigned as <UNK> token
        self.itos = {0: "<UNK>"}
        self.stoi = {"<UNK>": 0}
        
        self.word2embedding = word2embedding
        self.config = config

        self.word_freq_in_corpus = defaultdict(int)
        self.IDF = {}
        self.ps = PorterStemmer()
        self.stop_words = set(stopwords.words('english'))
        
        self.word_dim = len(word2embedding['the'])
    def __len__(self):
        return len(self.itos)

    def tokenizer_eng(self, text):
        text = re.sub(r'[^A-Za-z0-9 ]+', '', text)
        text = text.strip().split()
        
        return [self.ps.stem(w) for w in text if w.lower() not in self.stop_words]
    
    def read_raw(self):        
        if self.config["dataset"] == 'IMDB':
            data_file_path = '../data/IMDB.txt'
        elif self.config["dataset"] == 'CNN':
            data_file_path = '../data/CNN.txt'
        elif self.config["dataset"] == 'PubMed':
            data_file_path = '../data/PubMed.txt'
        
        # raw documents
        self.raw_documents = []
        with open(data_file_path,'r',encoding='utf-8') as f:
            for line in tqdm(f, desc="Loading documents"):
                self.raw_documents.append(line.strip("\n"))
                
        return self.raw_documents
    
    def build_vocabulary(self):
        sentence_list = self.raw_documents
        
        self.doc_freq = defaultdict(int) # # of document a word appear
        self.document_num = len(sentence_list)
        self.word_vectors = [[0]*self.word_dim] # unknown word emb
        
        for sentence in tqdm(sentence_list, desc="Preprocessing documents"):
            # for doc_freq
            document_words = set()
            
            for word in self.tokenizer_eng(sentence):
                # pass unknown word
                if word not in self.word2embedding:
                    continue
                    
                # calculate word freq
                self.word_freq_in_corpus[word] += 1
                document_words.add(word)
                
            for word in document_words:
                self.doc_freq[word] += 1
        
        # calculate IDF
        print('doc num', self.document_num)
        for word, freq in self.doc_freq.items():
            self.IDF[word] = math.log(self.document_num / (freq+1))
        
        # delete less freq words:
        delete_words = []
        for word, v in self.word_freq_in_corpus.items():
            if v < self.config["min_word_freq_threshold"]:
                delete_words.append(word)     
        for word in delete_words:
            del self.IDF[word]    
            del self.word_freq_in_corpus[word]    
        
        # delete too freq words
        print('eliminate freq words')
        IDF = [(word, freq) for word, freq in self.IDF.items()]
        IDF.sort(key=lambda x: x[1])

        for i in range(self.config["topk_word_freq_threshold"]):
            print(word)
            word = IDF[i][0]
            del self.IDF[word]
            del self.word_freq_in_corpus[word]
        
        # construct word_vectors
        idx = 1
        for word in self.word_freq_in_corpus:
            self.word_vectors.append(self.word2embedding[word])
            self.stoi[word] = idx
            self.itos[idx] = word
            idx += 1
            
    def init_word_weight(self,sentence_list, agg):
        if agg == 'mean':
            self.word_weight = {word: 1 for word in self.IDF.keys()}
        elif agg == 'IDF':
            self.word_weight = self.IDF
        elif agg == 'uniform':
            self.word_weight = {word: np.random.uniform(low=0.0, high=1.0) for word in self.IDF.keys()}
        elif agg == 'gaussian':
            mu, sigma = 10, 1 # mean and standard deviation
            self.word_weight = {word: np.random.normal(mu, sigma) for word in self.IDF.keys()}
        elif agg == 'exponential':
            self.word_weight = {word: np.random.exponential(scale=1.0) for word in self.IDF.keys()}
        elif agg == 'pmi':
            trigram_measures = BigramAssocMeasures()
            self.word_weight = defaultdict(int)
            corpus = []

            for text in tqdm(sentence_list):
                corpus.extend(text.split())

            finder = BigramCollocationFinder.from_words(corpus)
            for pmi_score in finder.score_ngrams(trigram_measures.pmi):
                pair, score = pmi_score
                self.word_weight[pair[0]] += score
                self.word_weight[pair[1]] += score
                
    def calculate_document_vector(self):
        # Return
        # document_vectors: weighted sum of word emb
        # document_answers_idx: doc to word index list
        # document_answers_wsum: word weight summation, e.g. total TFIDF score of a doc
        
        document_vectors = [] 
        document_answers = []
        document_answers_wsum = []
        
        sentence_list = self.raw_documents
        agg = self.config["document_vector_agg_weight"]
        n_document = self.config["n_document"]
        select_topk_TFIDF = self.config["select_topk_TFIDF"]
        
        self.init_word_weight(sentence_list, agg)
        for sentence in tqdm(sentence_list[:min(n_document, len(sentence_list))], desc="calculate document vectors"):
            document_vector = np.zeros(len(self.word_vectors[0]))
            select_words = []
            for word in self.tokenizer_eng(sentence):
                # pass unknown word
                if word not in self.stoi:
                    continue
                else:
                    select_words.append(word)

            # select topk TDIDF
            if select_topk_TFIDF is not None:
                doc_TFIDF = defaultdict(float)
                for word in select_words:    
                    doc_TFIDF[word] += self.IDF[word]

                doc_TFIDF_l = [(word, TFIDF) for word, TFIDF in doc_TFIDF.items()]
                doc_TFIDF_l.sort(key=lambda x:x[1], reverse=True)
                
                select_topk_words = set(list(map(lambda x:x[0], doc_TFIDF_l[:select_topk_TFIDF])))
                select_words = [word for word in select_words if word in select_topk_words]
            else:
                pass
            
            total_weight = 0
            # aggregate to doc vectors
            for word in select_words:
                document_vector += np.array(self.word2embedding[word]) * self.word_weight[word]
                total_weight += self.word_weight[word]
                
            if len(select_words) == 0:
                print('error', sentence)
                continue
            else:
                if self.config["document_vector_weight_normalize"]:
                    document_vector /= total_weight
            
            document_vectors.append(document_vector)
            document_answers.append(select_words)
            document_answers_wsum.append(total_weight)
        
        # get answers
        document_answers_idx = []    
        for ans in document_answers:
            ans_idx = []
            for token in ans:
                if token in self.stoi:
                    ans_idx.append(self.stoi[token])                    
            document_answers_idx.append(ans_idx)
        
        self.document_vectors = document_vectors
        self.document_answers_idx = document_answers_idx
        self.document_answers_wsum = document_answers_wsum
        
        return document_vectors, document_answers_idx, document_answers_wsum
        
    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)

        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]
    
    def check_docemb(self):
        word_vectors = np.array(self.word_vectors)
        pred = np.zeros(word_vectors.shape[1])
        cnt = 0

        for word_idx in self.document_answers_idx[0]:
            pred += word_vectors[word_idx] * self.word_weight[self.itos[word_idx]]
            cnt += self.word_weight[self.itos[word_idx]]
        
        if self.config["document_vector_weight_normalize"]:
            pred /= cnt
        assert np.sum(self.document_vectors[0]) - np.sum(pred) == 0

In [7]:
def build_vocab(config, word2embedding):
    # build vocabulary
    vocab = Vocabulary(word2embedding, config)
    vocab.read_raw()
    vocab.build_vocabulary()
    vocab_size = len(vocab)
    # get doc emb
    vocab.calculate_document_vector()
    vocab.check_docemb()
    
    return vocab

vocab = build_vocab(config, word2embedding)

HBox(children=(IntProgress(value=1, bar_style='info', description='Loading documents', max=1, style=ProgressSt…




HBox(children=(IntProgress(value=0, description='Preprocessing documents', max=19026, style=ProgressStyle(desc…


doc num 19026
eliminate freq words
paroxysm
subject
line
organ
write
univers
one
would
use
like
get
know
dont
think
time
make
also
say
go
im
could
want
new
work
good
well
way
need
look
even
anyon
thing
see
tri
thank
much
year
world
system
right
problem
may
take
mani
two
first
seem
question
pleas
1
state
us
come
2
post
help
call
usa
point
sinc
find
read
still
back
mean
ive
give
email
sure
differ
might
run
cant
reason
last
day
interest
case
let
person
said
never
start
doesnt
tell
better
ask
got
without
follow
part
lot
3
number
put
fact
gener
inform
actual
that


HBox(children=(IntProgress(value=0, description='calculate document vectors', max=19026, style=ProgressStyle(d…

error |> 
error 
error 
error Mikael Fredriksson
error 
error -------------------------------------------------
error email: mikael_fredriksson@macexchange.se
error 
error FIDO 2:203/211
error  
error  
error   
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error    
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error  
error    
error  
error  
error  
error  
error  
error  
error  
error  



In [8]:
print("Finish building dataset!")
print(f"Number of documents:{len(vocab.raw_documents)}")
print(f"Number of words:{len(vocab)}")

l = list(map(len, vocab.document_answers_idx))
print("Average length of document:", np.mean(l))

Finish building dataset!
Number of documents:19026
Number of words:7602
Average length of document: 84.86410175216382


In [9]:
word_vectors = np.array(vocab.word_vectors)
print(word_vectors.shape)

document_vectors = np.array(vocab.document_vectors)
print(document_vectors.shape)

document_answers_wsum = np.array(vocab.document_answers_wsum).reshape(-1, 1)
print(document_answers_wsum.shape)

# create weight_ans
document_answers_idx = vocab.document_answers_idx

# random shuffle
shuffle_idx = list(range(len(document_vectors)))
random.Random(seed).shuffle(shuffle_idx)

document_vectors = document_vectors[shuffle_idx]
document_answers_wsum = document_answers_wsum[shuffle_idx]
document_answers_idx = [document_answers_idx[idx] for idx in shuffle_idx]

(7602, 100)
(18948, 100)
(18948, 1)


In [10]:
# onthot_ans: word freq matrix
# weight_ans: TFIDF matrix

onehot_ans = np.zeros((len(document_answers_idx), word_vectors.shape[0]))
weight_ans = np.zeros((len(document_answers_idx), word_vectors.shape[0]))
print(weight_ans.shape)

for i in tqdm(range(len(document_answers_idx))):
    for word_idx in document_answers_idx[i]:
        weight_ans[i, word_idx] += vocab.word_weight[vocab.itos[word_idx]]
        onehot_ans[i, word_idx] += 1

(18948, 7602)


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))




## Results

In [11]:
final_results = []
select_columns = ['model']
for topk in config["topk"]:
    select_columns.append('F1@{}'.format(topk))
for topk in config["topk"]:
    select_columns.append('ndcg@{}'.format(topk))
select_columns.append('ndcg@all')
select_columns

['model',
 'F1@10',
 'F1@30',
 'F1@50',
 'ndcg@10',
 'ndcg@30',
 'ndcg@50',
 'ndcg@all']

## setting training size

In [12]:
train_size_ratio = 1
train_size = int(len(document_answers_idx) * train_size_ratio)
train_size

18948

## Top K freq word

In [13]:
topk_results = {}

In [14]:
test_ans = document_answers_idx[:train_size]

In [15]:
word_freq = [(word, freq) for word, freq in vocab.word_freq_in_corpus.items()]
word_freq.sort(key=lambda x:x[1], reverse=True)
word_freq[:10]

[('x', 6539),
 ('god', 5208),
 ('file', 4918),
 ('0', 4520),
 ('window', 4444),
 ('program', 4201),
 ('drive', 3633),
 ('4', 3528),
 ('game', 3474),
 ('govern', 3268)]

In [16]:
def topk_word_evaluation(k=50):
    topk_word = [word for (word, freq) in word_freq[:k]]

    pr, re = [], []
    for ans in tqdm(test_ans):
        ans = set(ans)
        ans = [vocab.itos[a] for a in ans]

        hit = []
        for word in ans:
            if word in topk_word:
                hit.append(word)

        precision = len(hit) / k
        recall = len(hit) / len(ans)
        pr.append(precision)
        re.append(recall)

    pr = np.mean(pr)
    re = np.mean(re)
    f1 = 2 * pr * re / (pr + re) if (pr + re) != 0 else 0
    print('top {} word'.format(k))
    print('percision', np.mean(pr))
    print('recall', np.mean(re))
    print('F1', f1)
    return f1


for topk in config['topk']:
    topk_results["F1@{}".format(topk)] = topk_word_evaluation(k=topk)


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top 10 word
percision 0.06950073886426009
recall 0.016353539716943234
F1 0.026477028568787433


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top 30 word
percision 0.07639680529167546
recall 0.04850870113830837
F1 0.059339414277828975


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top 50 word
percision 0.07682921680388433
recall 0.07978013464427035
F1 0.07827687433156728


In [17]:
def topk_word_evaluation_NDCG(k=50):
    freq_word =[word for (word, freq) in word_freq]
    freq_word_idx = [vocab.stoi[word] for word in freq_word if word in vocab.stoi]
    
    scores = np.zeros(len(vocab.word_vectors))
    for rank, idx in enumerate(freq_word_idx):
        scores[idx] = len(vocab.word_vectors) - rank
    
    NDCGs = []
    
    for ans in tqdm(test_ans):
        weight_ans = np.zeros(len(vocab.word_vectors))
        
        for word_idx in ans:
            if word_idx == 0:
                continue
            word = vocab.itos[word_idx]
            weight_ans[word_idx] += vocab.IDF[word]

        NDCG_score = ndcg_score(weight_ans.reshape(1,-1), scores.reshape(1,-1), k=k)
        NDCGs.append(NDCG_score)

    print('top {} NDCG:{}'.format(k, np.mean(NDCGs)))
    
    return np.mean(NDCGs)


for topk in config['topk']:
    topk_results["ndcg@{}".format(topk)] = topk_word_evaluation_NDCG(k=topk)
    
topk_results["ndcg@all"] = topk_word_evaluation_NDCG(k=None)


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top 10 NDCG:0.028211750096691315


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top 30 NDCG:0.03761860810723896


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top 50 NDCG:0.04587430885683149


HBox(children=(IntProgress(value=0, max=18948), HTML(value='')))


top None NDCG:0.31437606429732357


In [18]:
topk_results["model"] = "topk"
final_results.append(pd.Series(topk_results))

## MLP Decoder

In [19]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [20]:
class MLPDecoderDataset(Dataset):
    def __init__(self, 
                 doc_vectors,
                 weight_ans,
                 topk=50):
        self.doc_vectors = torch.FloatTensor(doc_vectors)
        self.weight_ans = torch.FloatTensor(weight_ans)
        self.weight_ans_s = torch.argsort(self.weight_ans, dim=1, descending=True)
        self.weight_ans_s[:, topk:] = -1
        
        assert len(doc_vectors) == len(weight_ans)
        
    def __getitem__(self, idx):
        return self.doc_vectors[idx], self.weight_ans[idx], self.weight_ans_s[idx]

    def __len__(self):
        return len(self.doc_vectors)

In [21]:
batch_size = 100

train_size_ratio = 0.9
train_size = int(len(document_answers_idx) * train_size_ratio)
train_size

print('train size', train_size)
print('valid size', len(document_vectors) - train_size)

train_dataset = MLPDecoderDataset(document_vectors[:train_size], weight_ans[:train_size], topk=50)
train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

valid_dataset = MLPDecoderDataset(document_vectors[train_size:], weight_ans[train_size:], topk=50)
valid_loader  = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

train size 17053
valid size 1895


In [22]:
class MLPDecoder(nn.Module):
    def __init__(self, doc_emb_dim, num_words, h_dim=300):
        super().__init__()
        
        self.fc1 = nn.Linear(doc_emb_dim, h_dim) 
#         self.fc2 = nn.Linear(h_dim, h_dim)
#         self.fc3 = nn.Linear(h_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, num_words)
        self.dropout = nn.Dropout(p=0.2)
        
    def forward(self, x):
        x = F.tanh(self.fc1(x))
#         x = F.tanh(self.fc2(x))
#         x = F.relu(self.fc3(x))
        x = self.dropout(x)
        x = self.fc4(x)
        
        return x

In [23]:
def _dcg(target):
    """Computes Discounted Cumulative Gain for input tensor."""
    denom = torch.log2(torch.arange(target.shape[-1], device=target.device) + 2.0)
    return (target / denom).sum(dim=-1)

def retrieval_normalized_dcg_all(preds, target, k=None):
    """Computes `Normalized Discounted Cumulative Gain`_ (for information retrieval).
    ``preds`` and ``target`` should be of the same shape and live on the same device.
    ``target`` must be either `bool` or `integers` and ``preds`` must be `float`,
    otherwise an error is raised.
    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document relevance.
        k: consider only the top k elements (default: None, which considers them all)
    Return:
        a single-value tensor with the nDCG of the predictions ``preds`` w.r.t. the labels ``target``.
    Raises:
        ValueError:
            If ``k`` parameter is not `None` or an integer larger than 0
    Example:
        >>> from torchmetrics.functional import retrieval_normalized_dcg
        >>> preds = torch.tensor([.1, .2, .3, 4, 70])
        >>> target = torch.tensor([10, 0, 0, 1, 5])
        >>> retrieval_normalized_dcg(preds, target)
        tensor(0.6957)
    """
    k = [preds.shape[-1]] if k is None else k + [preds.shape[-1]]
    
    assert preds.shape == target.shape and max(k) <= preds.shape[-1]
    
    if not isinstance(k, list):
        raise ValueError("`k` has to be a list of positive integer or None")

    sorted_target = target.gather(1, torch.argsort(preds, dim=-1, descending=True))
    ideal_target = torch.sort(target, descending=True)[0]
    
    ndcg_scores = {}
    for topk in k:
        sorted_target_k = sorted_target[:,:topk]
        ideal_target_k = ideal_target[:,:topk]
        
        ideal_dcg_k = _dcg(ideal_target_k)
        target_dcg_k = _dcg(sorted_target_k)
        
        # filter undefined scores
        target_dcg_k /= ideal_dcg_k
        
        if topk == preds.shape[-1]:
            topk = 'all'
        ndcg_scores[topk] = target_dcg_k.mean().item()
        
    return ndcg_scores

def retrieval_precision_all(preds, target, k = [10]):
        
    assert preds.shape == target.shape and max(k) <= preds.shape[-1]
    
    if not isinstance(k, list):
        raise ValueError("`k` has to be a list of positive integer")
        
    precision_scores = {}
    target_onehot = target > 0
    
    for topk in k:
        relevant = target_onehot.gather(1, preds.topk(topk, dim=-1)[1])
        relevant = relevant.sum(axis=1).float()
        relevant /= topk    
        precision_scores[topk] = relevant.mean().item()
    
    return precision_scores

In [24]:
import timeit
from torchmetrics.functional import retrieval_normalized_dcg, retrieval_precision
from sklearn.metrics import ndcg_score

def evaluate_MLPDecoder(model, data_loader):
    results = {}
    model.eval()
    
    pred_all = []
    target_all = []
    
    # predict all data
    start = timeit.default_timer()
    for data in data_loader:
        doc_embs, target, _ = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
                
        pred = model(doc_embs)
        pred_all.append(pred)
        target_all.append(target)
        
    pred_all = torch.cat(pred_all, dim=0)
    target_all = torch.cat(target_all, dim=0)
    stop1 = timeit.default_timer()
    print('Time1: ', stop1 - start)  
    
    # Precision
    precision_scores = retrieval_precision_all(pred_all, target_all, k=config["topk"])
    for k, v in precision_scores.items():
        results['precision@{}'.format(k)] = v
    
    stop2 = timeit.default_timer()
    print('Time2: ', stop2 - stop1)  
    
    # NDCG
    ndcg_scores = retrieval_normalized_dcg_all(pred_all, target_all, k=config["topk"])
    for k, v in ndcg_scores.items():
        results['ndcg@{}'.format(k)] = v
        
    stop3 = timeit.default_timer()
    print('Time3: ', stop3 - stop2)      
    
    return results

In [25]:
# setting
lr = 0.05
momentum = 0.
weight_decay = 0
nesterov = False # True

n_epoch = 5000
verbose = True
valid_epoch = 50

h_dim = 3000

model = MLPDecoder(doc_emb_dim=document_vectors.shape[1], num_words=len(word_vectors), h_dim=h_dim).to(device)
model.train()

word_vectors_tensor = torch.FloatTensor(word_vectors).to(device)

opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
# opt = torch.optim.Adam(model.parameters(), lr=lr)
# criterion = nn.MSELoss(reduction='mean')
criterion = nn.MultiLabelMarginLoss(reduction='mean')
# criterion = nn.MultiLabelSoftMarginLoss(reduction='mean')

results = []
step = 0
clip_value = 1

for epoch in tqdm(range(n_epoch)):    
    train_loss_his = []
    valid_loss_his = []
    
    model.train()

    for data in train_loader:
        doc_embs, target, target_rank = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_rank = target_rank.to(device)
        
        # MSE loss
        pred = model(doc_embs)     
#         loss = criterion(pred, target)
        loss = criterion(pred, target_rank)
#         loss = criterion(pred, torch.sign(target))
        
        # Model backwarding
        model.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
        
        opt.step()

        train_loss_his.append(loss.item())
        
    model.eval()
    for data in valid_loader:
        doc_embs, target, target_rank = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_rank = target_rank.to(device)
        
        # MSE loss
        pred = model(doc_embs)     
        
#         loss = criterion(pred, target)
        loss = criterion(pred, target_rank)
#         loss = criterion(pred, torch.sign(target))
        
        valid_loss_his.append(loss.item())
    
    print("Epoch", epoch, np.mean(train_loss_his), np.mean(valid_loss_his))
    
    if epoch % valid_epoch == 0:
        res = {}
        res['epoch'] = epoch
        
        train_res_ndcg = evaluate_MLPDecoder(model, train_loader)
        valid_res_ndcg = evaluate_MLPDecoder(model, valid_loader)
        
        res.update(valid_res_ndcg)
        results.append(res)
        
        if verbose:
            print()
            print('train', train_res_ndcg)
            print('valid', valid_res_ndcg)
            


HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))



Epoch 0 13.793183956927026 9.811852103785464
Time1:  0.5027011595666409
Time2:  0.05020587891340256
Time3:  0.16641565039753914
Time1:  0.0545343067497015
Time2:  0.00821242667734623
Time3:  0.018884481862187386

train {'precision@10': 0.12872809171676636, 'precision@30': 0.1125139370560646, 'precision@50': 0.11361167579889297, 'ndcg@10': 0.0913453921675682, 'ndcg@30': 0.10261578112840652, 'ndcg@50': 0.12063910067081451, 'ndcg@all': 0.3932613730430603}
valid {'precision@10': 0.12844327092170715, 'precision@30': 0.11134565621614456, 'precision@50': 0.1125699058175087, 'ndcg@10': 0.08903079479932785, 'ndcg@30': 0.09877375513315201, 'ndcg@50': 0.11589748412370682, 'ndcg@all': 0.38783103227615356}
Epoch 1 7.630794572551348 8.119166223626388
Epoch 2 5.741322891056886 7.372438832333214
Epoch 3 4.70718807226036 7.003115578701622
Epoch 4 4.040690806874058 6.7176112626728255
Epoch 5 3.56652854060569 6.533565847497237
Epoch 6 3.208136084484078 6.433688464917634
Epoch 7 2.9337398266931722 6.33289

Epoch 136 0.4292029172007801 5.65837456050672
Epoch 137 0.4282070572613276 5.65800696925113
Epoch 138 0.42720581059567414 5.6559469574376156
Epoch 139 0.42479992121980903 5.6543067630968595
Epoch 140 0.4231208271799032 5.653824831310072
Epoch 141 0.4224900303528323 5.66227388381958
Epoch 142 0.41994188874088534 5.658804316269724
Epoch 143 0.4174054937753064 5.663816100672672
Epoch 144 0.41673597432019416 5.66397732182553
Epoch 145 0.41485506958431667 5.662386392292223
Epoch 146 0.41262536456710414 5.656373827080977
Epoch 147 0.4106323590055544 5.669774457028038
Epoch 148 0.40890299454767104 5.660621542679636
Epoch 149 0.40782701742579364 5.664152998673289
Epoch 150 0.4057483878749156 5.659847083844636
Time1:  0.48781407438218594
Time2:  0.05162840895354748
Time3:  0.16418270207941532
Time1:  0.05432446487247944
Time2:  0.008219212293624878
Time3:  0.018908094614744186

train {'precision@10': 0.549035370349884, 'precision@30': 0.5579506754875183, 'precision@50': 0.5615457892417908, 'ndc

Epoch 268 0.2918160395664081 5.740734777952495
Epoch 269 0.29151924625474807 5.739767099681654
Epoch 270 0.29027844956743787 5.739340430811832
Epoch 271 0.2897665636581287 5.7408826476649235
Epoch 272 0.28847124632339033 5.739184931704872
Epoch 273 0.2889929445166337 5.741115319101434
Epoch 274 0.28884462219232704 5.743186298169587
Epoch 275 0.2881839871406555 5.742859714909604
Epoch 276 0.2868552479827613 5.739415871469598
Epoch 277 0.2861120071676042 5.740544243862755
Epoch 278 0.2865508019227033 5.742963088186164
Epoch 279 0.28664050551882964 5.741399413660953
Epoch 280 0.2853217271336338 5.7383442176015755
Epoch 281 0.2838943107434881 5.750457964445415
Epoch 282 0.2830490241622367 5.748042307401958
Epoch 283 0.2840631930103079 5.747386681406121
Epoch 284 0.28254524770884487 5.7449340318378646
Epoch 285 0.2829003910920773 5.747419633363423
Epoch 286 0.2825185234783686 5.75143605784366
Epoch 287 0.2806964277001152 5.744705576645701
Epoch 288 0.28042446190153647 5.749707046308015
Epoc

Epoch 401 0.23550660132664686 5.820138881081029
Epoch 402 0.23481796464027718 5.819893887168483
Epoch 403 0.2346050953655912 5.8203007296512
Epoch 404 0.23460334757266685 5.812806606292725
Epoch 405 0.23422171044767948 5.820287729564466
Epoch 406 0.23385358353455862 5.823027962132504
Epoch 407 0.2333930906845115 5.825497250807913
Epoch 408 0.23288103240972374 5.825371265411377
Epoch 409 0.2341129078328261 5.8256534275255705
Epoch 410 0.23251494242433915 5.825591012051231
Epoch 411 0.23350451912796288 5.832686850899144
Epoch 412 0.23222543281769892 5.827295629601729
Epoch 413 0.231398525729514 5.82989660062288
Epoch 414 0.23154000606801775 5.826279138263903
Epoch 415 0.23120858559482976 5.832578809637773
Epoch 416 0.23145011529239298 5.8320183753967285
Epoch 417 0.23001703794239559 5.827104869641755
Epoch 418 0.2305875200974314 5.830052175019917
Epoch 419 0.23068695563321923 5.829786928076493
Epoch 420 0.2300387060607386 5.828658329813104
Epoch 421 0.22938153862256055 5.828743231923957


Epoch 546 0.2011494898830938 5.909132982555189
Epoch 547 0.2010974900590049 5.903503342678673
Epoch 548 0.20051061505811257 5.905333368401778
Epoch 549 0.20029538111728534 5.913570178182502
Epoch 550 0.20009820877808576 5.913835425125925
Time1:  0.4871331490576267
Time2:  0.05110298655927181
Time3:  0.1644585896283388
Time1:  0.054303379729390144
Time2:  0.00816640816628933
Time3:  0.01890675164759159

train {'precision@10': 0.6337829232215881, 'precision@30': 0.6621885299682617, 'precision@50': 0.6629484295845032, 'ndcg@10': 0.5297787189483643, 'ndcg@30': 0.6213610768318176, 'ndcg@50': 0.7080297470092773, 'ndcg@all': 0.7459125518798828}
valid {'precision@10': 0.4357256293296814, 'precision@30': 0.32636764645576477, 'precision@50': 0.2756516933441162, 'ndcg@10': 0.403056263923645, 'ndcg@30': 0.39115017652511597, 'ndcg@50': 0.40438133478164673, 'ndcg@all': 0.5967994928359985}
Epoch 551 0.19962088092725877 5.901751493152819
Epoch 552 0.19979362590619695 5.906456395199425
Epoch 553 0.1993

Epoch 677 0.18059594502225954 5.97379920357152
Epoch 678 0.18052472296165445 5.966741110149183
Epoch 679 0.18044614722157082 5.968415410895097
Epoch 680 0.1798765083328325 5.970431227433054
Epoch 681 0.17953013972929346 5.96610719279239
Epoch 682 0.17930160549997587 5.97498715551276
Epoch 683 0.18021289655688214 5.970082709663792
Epoch 684 0.17967763987549565 5.97006873080605
Epoch 685 0.18015936095463603 5.9750689707304305
Epoch 686 0.1788641583849812 5.980073878639622
Epoch 687 0.1787409807680643 5.977621078491211
Epoch 688 0.17865454235620665 5.973803369622481
Epoch 689 0.17897522702203159 5.973914397390265
Epoch 690 0.177978864451598 5.97963285446167
Epoch 691 0.17954197316838985 5.97973098252949
Epoch 692 0.17822053235525276 5.982519551327354
Epoch 693 0.17834272832549802 5.973575165397243
Epoch 694 0.17901882407260916 5.978704979545192
Epoch 695 0.17792882548089614 5.976785559403269
Epoch 696 0.17790056965504472 5.984351308722245
Epoch 697 0.17830472287030247 5.977495268771523
Ep

Epoch 808 0.1652661850403624 6.042579650878906
Epoch 809 0.16535809139410654 6.0284546550951505
Epoch 810 0.1651827851582689 6.03145328320955
Epoch 811 0.16528210558040798 6.030243045405338
Epoch 812 0.16493761870596144 6.03867214604428
Epoch 813 0.16537347697375113 6.032085167734246
Epoch 814 0.1653787313323272 6.033125249963057
Epoch 815 0.16487865507254126 6.0359549271432975
Epoch 816 0.1649005331142604 6.040178725593968
Epoch 817 0.16432141961410032 6.033740219316985
Epoch 818 0.16445980826665085 6.033116039476897
Epoch 819 0.16445465316200814 6.0385592611212475
Epoch 820 0.16403356081212472 6.0327999215377
Epoch 821 0.1637124114217814 6.038586315355803
Epoch 822 0.16372724008141903 6.040528949938323
Epoch 823 0.16425691065732498 6.04152963035985
Epoch 824 0.16356331406281008 6.04493502566689
Epoch 825 0.16446924654015324 6.0443651550694515
Epoch 826 0.1630555357326541 6.044811298972682
Epoch 827 0.16391851269362265 6.04040893755461
Epoch 828 0.16369886164776762 6.060615087810316
E

Epoch 951 0.15278294063799563 6.104398426256682
Epoch 952 0.1523261931207445 6.098376625462582
Epoch 953 0.1527002981008842 6.0976865166111995
Epoch 954 0.15311455020779058 6.087757110595703
Epoch 955 0.15239752349797744 6.097585301650198
Epoch 956 0.15260893847161566 6.097335991106536
Epoch 957 0.1521275902700703 6.091224193572998
Epoch 958 0.15266523662714931 6.095336236451802
Epoch 959 0.15264774214105997 6.0948815847698015
Epoch 960 0.15295581624173282 6.0903244269521615
Epoch 961 0.1528252280420727 6.1158725086011385
Epoch 962 0.15224078602609578 6.099996918126156
Epoch 963 0.15195065518917397 6.100025578549034
Epoch 964 0.15143061485904002 6.101496721568861
Epoch 965 0.15210567597757307 6.1069999494050675
Epoch 966 0.1515413454750128 6.101678195752595
Epoch 967 0.15144425273290155 6.1024876142802995
Epoch 968 0.1513831293896625 6.10139472861039
Epoch 969 0.1515801840009745 6.097827961570339
Epoch 970 0.15125836109557347 6.1057104562458235
Epoch 971 0.151187673298239 6.09949656536

Epoch 1094 0.14324159397367844 6.155201786442807
Epoch 1095 0.1433908949382821 6.164302047930266
Epoch 1096 0.14312918013647982 6.154518152538099
Epoch 1097 0.14321727793641956 6.150161165940134
Epoch 1098 0.14263663558583511 6.154587544892964
Epoch 1099 0.1426046696322703 6.152897634004292
Epoch 1100 0.14311536763146607 6.149674189718146
Time1:  0.48792089708149433
Time2:  0.05090300925076008
Time3:  0.16409999504685402
Time1:  0.05438375286757946
Time2:  0.008204026147723198
Time3:  0.018886035308241844

train {'precision@10': 0.617169976234436, 'precision@30': 0.6679333448410034, 'precision@50': 0.6810320615768433, 'ndcg@10': 0.5149565935134888, 'ndcg@30': 0.6214480996131897, 'ndcg@50': 0.7148897647857666, 'ndcg@all': 0.7416992783546448}
valid {'precision@10': 0.4184168875217438, 'precision@30': 0.3200879395008087, 'precision@50': 0.27382582426071167, 'ndcg@10': 0.38814017176628113, 'ndcg@30': 0.38147884607315063, 'ndcg@50': 0.39694786071777344, 'ndcg@all': 0.5894569754600525}
Epoch

Epoch 1222 0.13582512379041192 6.199155656914962
Epoch 1223 0.13516819359440552 6.205816871241519
Epoch 1224 0.13555247746190133 6.208217169109144
Epoch 1225 0.13575042008656507 6.21362859324405
Epoch 1226 0.1358044117427709 6.208360295546682
Epoch 1227 0.1354921579970951 6.202748072774787
Epoch 1228 0.13532188835374095 6.201494969819722
Epoch 1229 0.13567149246993818 6.200528646770277
Epoch 1230 0.13503261773209824 6.207676059321353
Epoch 1231 0.13485549638668695 6.2032656167682845
Epoch 1232 0.13487081844032855 6.208439525805022
Epoch 1233 0.13540128723048328 6.2076804261458545
Epoch 1234 0.1358253321888154 6.211378724951493
Epoch 1235 0.13528634490151153 6.2011593015570385
Epoch 1236 0.13481312876904916 6.2060371449119165
Epoch 1237 0.13499177857274897 6.212559649818822
Epoch 1238 0.1351552487988221 6.2023399503607495
Epoch 1239 0.13450359635882908 6.208979732111881
Epoch 1240 0.1346157905128267 6.209242394095973
Epoch 1241 0.13459848987254483 6.207298554872212
Epoch 1242 0.13504822

Epoch 1351 0.13003203695454793 6.2456204263787525
Epoch 1352 0.12958424111381608 6.2547023421839665
Epoch 1353 0.12906627574859306 6.257042558569657
Epoch 1354 0.12901664464271556 6.255743252603631
Epoch 1355 0.12936417205117598 6.251935557315224
Epoch 1356 0.12996215133639108 6.2564944217079566
Epoch 1357 0.1292763904869905 6.252857057671798
Epoch 1358 0.12838760706764912 6.255588983234606
Epoch 1359 0.12957214877793663 6.252112438804225
Epoch 1360 0.12899645714209093 6.252381073801141
Epoch 1361 0.12894736539724974 6.251436584874203
Epoch 1362 0.12921952996511907 6.261046585283782
Epoch 1363 0.12820428213355137 6.253853346172132
Epoch 1364 0.12873141297645735 6.252940102627403
Epoch 1365 0.129333845661049 6.254155761317203
Epoch 1366 0.1289472540306766 6.255042653334768
Epoch 1367 0.12915345067866366 6.261382755480315
Epoch 1368 0.12976264844686664 6.290702970404374
Epoch 1369 0.12916708624328088 6.256226941158897
Epoch 1370 0.12978161623080572 6.252899847532573
Epoch 1371 0.12892570

Epoch 1493 0.12396980717516783 6.299332844583612
Epoch 1494 0.12355025949185355 6.308816985080116
Epoch 1495 0.12341050633735824 6.296660573858964
Epoch 1496 0.1231989205667847 6.305860996246338
Epoch 1497 0.12358520624407551 6.307034693266216
Epoch 1498 0.12354286899517851 6.300889291261372
Epoch 1499 0.12398566148783031 6.308823183963173
Epoch 1500 0.12341251516202736 6.30682006635164
Time1:  0.48806686513125896
Time2:  0.05068837106227875
Time3:  0.16421656496822834
Time1:  0.05450800806283951
Time2:  0.008211132138967514
Time3:  0.018962277099490166

train {'precision@10': 0.6478977203369141, 'precision@30': 0.69218909740448, 'precision@50': 0.6942262053489685, 'ndcg@10': 0.5367221236228943, 'ndcg@30': 0.6439714431762695, 'ndcg@50': 0.7333750128746033, 'ndcg@all': 0.7531360983848572}
valid {'precision@10': 0.42986807227134705, 'precision@30': 0.3272647559642792, 'precision@50': 0.2773192822933197, 'ndcg@10': 0.39738035202026367, 'ndcg@30': 0.38996976613998413, 'ndcg@50': 0.40425384

Epoch 1621 0.11902409495666014 6.348368117683812
Epoch 1622 0.11948833279093804 6.35045239799901
Epoch 1623 0.11919334543901577 6.345412128850033
Epoch 1624 0.11940842605473702 6.344491632361161
Epoch 1625 0.11883658314483207 6.349827540548224
Epoch 1626 0.11951891729357647 6.346530914306641
Epoch 1627 0.11839026256262908 6.339624279423764
Epoch 1628 0.11960166400810432 6.34820042158428
Epoch 1629 0.11930335487364328 6.347242807087145
Epoch 1630 0.11922640248871687 6.3524700214988306
Epoch 1631 0.11915918931975002 6.353889766492341
Epoch 1632 0.11927374783489439 6.349397558914988
Epoch 1633 0.11822700644271415 6.359099839863024
Epoch 1634 0.11838640337973311 6.360484725550601
Epoch 1635 0.11872171392740562 6.355201670998021
Epoch 1636 0.1188894140354374 6.351198346991288
Epoch 1637 0.11883284347622018 6.350670663934005
Epoch 1638 0.11789902931416942 6.350819813577752
Epoch 1639 0.11820550201929103 6.346531064886796
Epoch 1640 0.11875767326145842 6.3505323058680485
Epoch 1641 0.11837321

Epoch 1751 0.11508784552066646 6.389854983279579
Epoch 1752 0.11546003866439675 6.390398000416003
Epoch 1753 0.11461744775549013 6.387274164902537
Epoch 1754 0.11547226376003689 6.390908090691817
Epoch 1755 0.11510113321723994 6.388809028424714
Epoch 1756 0.11506164749289116 6.387948412644236
Epoch 1757 0.11481057957076189 6.388876287560714
Epoch 1758 0.115208349991263 6.389404698422081
Epoch 1759 0.11439811643104107 6.392311497738487
Epoch 1760 0.11497514233080267 6.3870137365240796
Epoch 1761 0.11493113323261864 6.391644553134316
Epoch 1762 0.11447551140659734 6.390206813812256
Epoch 1763 0.1151251832121297 6.39363936374062
Epoch 1764 0.11441858099740848 6.391613307752107
Epoch 1765 0.11474305962087118 6.386793462853682
Epoch 1766 0.1143752964355095 6.393971493369655
Epoch 1767 0.11493094823165247 6.392254076505962
Epoch 1768 0.11490220511168764 6.3922161303068465
Epoch 1769 0.11452185715499677 6.3931428005820825
Epoch 1770 0.11475044109842233 6.395209262245579
Epoch 1771 0.114931212

Epoch 1892 0.11153411107105121 6.427857148019891
Epoch 1893 0.11156244715403395 6.439224694904528
Epoch 1894 0.11104254337430697 6.4291074150486995
Epoch 1895 0.11152146695650111 6.429888072766755
Epoch 1896 0.11112736387733828 6.428094964278372
Epoch 1897 0.11171371881899081 6.436765570389597
Epoch 1898 0.11143266610060519 6.432217748541581
Epoch 1899 0.11085939934553458 6.434158827129163
Epoch 1900 0.11125452903627653 6.432562326130114
Time1:  0.48789888247847557
Time2:  0.050532422959804535
Time3:  0.16389675997197628
Time1:  0.05458090826869011
Time2:  0.00821535475552082
Time3:  0.018925847485661507

train {'precision@10': 0.6226177215576172, 'precision@30': 0.6800640821456909, 'precision@50': 0.6932633519172668, 'ndcg@10': 0.5167406797409058, 'ndcg@30': 0.6309431791305542, 'ndcg@50': 0.7249358296394348, 'ndcg@all': 0.7444005608558655}
valid {'precision@10': 0.41408970952033997, 'precision@30': 0.3192436695098877, 'precision@50': 0.27285486459732056, 'ndcg@10': 0.3833693265914917,

Epoch 2020 0.10872115504149107 6.47960100675884
Epoch 2021 0.10800563544034958 6.4711572496514576
Epoch 2022 0.10813809321289174 6.474529216164036
Epoch 2023 0.10806005683384444 6.478775626734683
Epoch 2024 0.10804364446833817 6.479060424001593
Epoch 2025 0.10761960791914087 6.476909010033858
Epoch 2026 0.1082728641510707 6.473943559746993
Epoch 2027 0.1084984237251923 6.474972373560855
Epoch 2028 0.10817600468620223 6.47109664113898
Epoch 2029 0.10811825636883228 6.471841385490016
Epoch 2030 0.10733515939168763 6.469636917114258
Epoch 2031 0.10803347272656814 6.476727084109657
Epoch 2032 0.10777952151688916 6.472081962384675
Epoch 2033 0.10798479281148018 6.476990022157368
Epoch 2034 0.10740975395106432 6.4788634902552555
Epoch 2035 0.10765659090196877 6.484890235097785
Epoch 2036 0.10794977717406569 6.474394271248265
Epoch 2037 0.10784450129807344 6.477341325659501
Epoch 2038 0.10768277259075154 6.4767046978599145
Epoch 2039 0.10744971433403896 6.475915130816008
Epoch 2040 0.10768049

Epoch 2151 0.10501774836179109 6.512912122826827
Epoch 2152 0.10533012188317482 6.508501303823371
Epoch 2153 0.1057466203619165 6.507728526466771
Epoch 2154 0.10475125102794658 6.516601989143773
Epoch 2155 0.10539067706518006 6.509318502325761
Epoch 2156 0.10570492887357522 6.509448553386488
Epoch 2157 0.10525595161475633 6.517148871170847
Epoch 2158 0.1052617021994284 6.510915480161968
Epoch 2159 0.10558015038395485 6.509504017076995
Epoch 2160 0.10539303924779446 6.5148904449061344
Epoch 2161 0.10440539847514783 6.513654633572227
Epoch 2162 0.10509634832715431 6.513159199764854
Epoch 2163 0.1050263299056661 6.522997630269904
Epoch 2164 0.10498623005305117 6.51343930395026
Epoch 2165 0.10475128810656698 6.512347146084434
Epoch 2166 0.10488699119516283 6.508367764322381
Epoch 2167 0.10495465967738837 6.516734374196906
Epoch 2168 0.10482913817752872 6.517443506341231
Epoch 2169 0.10467544546601368 6.509822770168907
Epoch 2170 0.10485728786528459 6.518064498901367
Epoch 2171 0.1051861072

Epoch 2292 0.10225447424148258 6.5491393239874585
Epoch 2293 0.10294382640144281 6.5514812218515495
Epoch 2294 0.10267514513249983 6.546850631111546
Epoch 2295 0.10226684041887696 6.548980336440237
Epoch 2296 0.10301932413675631 6.545513228366249
Epoch 2297 0.10227363151416444 6.550157572093763
Epoch 2298 0.1016057665236512 6.5471119378742415
Epoch 2299 0.10218699664226052 6.548255092219303
Epoch 2300 0.10241636764584926 6.546947454151354
Time1:  0.4882304146885872
Time2:  0.05045063979923725
Time3:  0.16395986825227737
Time1:  0.054344331845641136
Time2:  0.00812956877052784
Time3:  0.018926048651337624

train {'precision@10': 0.6344044804573059, 'precision@30': 0.6894583702087402, 'precision@50': 0.6989526748657227, 'ndcg@10': 0.5238460302352905, 'ndcg@30': 0.6386987566947937, 'ndcg@50': 0.7314682602882385, 'ndcg@all': 0.7480750679969788}
valid {'precision@10': 0.41883906722068787, 'precision@30': 0.32209324836730957, 'precision@50': 0.2744590938091278, 'ndcg@10': 0.38707035779953003

Epoch 2420 0.09965777070375911 6.583598362772088
Epoch 2421 0.09978620063143166 6.585091339914422
Epoch 2422 0.09951265182411462 6.601437844728169
Epoch 2423 0.09988889438018464 6.587978789680882
Epoch 2424 0.09966542136076598 6.588800104040849
Epoch 2425 0.09986182180238747 6.585944928620991
Epoch 2426 0.10017163341331203 6.583787667123895
Epoch 2427 0.09928699129680443 6.5883253750048185
Epoch 2428 0.09943556763798173 6.589062013124165
Epoch 2429 0.10009834688832188 6.594801024386757
Epoch 2430 0.09945310839609793 6.585454288281892
Epoch 2431 0.09967910001675288 6.588284166235673
Epoch 2432 0.09958096262481478 6.587717859368575
Epoch 2433 0.0993409915799983 6.587257711510909
Epoch 2434 0.09993942909770542 6.59045211892379
Epoch 2435 0.0996229310481869 6.583856331674676
Epoch 2436 0.09946925646206092 6.590099184136641
Epoch 2437 0.09954470282758189 6.589400416926334
Epoch 2438 0.09973259413974327 6.589399237381785
Epoch 2439 0.09942417247602117 6.592803879788048
Epoch 2440 0.099743524

Epoch 2551 0.09764389584810414 6.623587106403551
Epoch 2552 0.09734531008360679 6.629624291470177
Epoch 2553 0.09767779302701615 6.618170261383057
Epoch 2554 0.09734793704504158 6.622844118820994
Epoch 2555 0.0980332081603725 6.621036153090627
Epoch 2556 0.0974504186047448 6.620049526816921
Epoch 2557 0.09689086039512478 6.619750248758416
Epoch 2558 0.09772472013855538 6.619012230320981
Epoch 2559 0.09752323652742899 6.624424256776509
Epoch 2560 0.09744817367074085 6.623631577742727
Epoch 2561 0.09757858362288503 6.6276468728718
Epoch 2562 0.097624183790377 6.630044811650326
Epoch 2563 0.09742165191306008 6.628704522785387
Epoch 2564 0.0973533435087455 6.6214032424123666
Epoch 2565 0.09742873841733263 6.622744083404541
Epoch 2566 0.09698694757027933 6.630314099161248
Epoch 2567 0.09732552160296523 6.622450853648939
Epoch 2568 0.09683621195498963 6.642610876183761
Epoch 2569 0.09750575912103318 6.6257571170204566
Epoch 2570 0.09725755792960786 6.626663458974738
Epoch 2571 0.097180525026

Epoch 2693 0.0949769389559651 6.6592996245936344
Epoch 2694 0.09557288976614936 6.663779936338726
Epoch 2695 0.09559735254934656 6.658239715977719
Epoch 2696 0.09520526517901504 6.657645727458753
Epoch 2697 0.09442350682285097 6.6679408675745915
Epoch 2698 0.09527621779874054 6.658059722498844
Epoch 2699 0.09531153525002518 6.659031190370259
Epoch 2700 0.09500599626386375 6.659852153376529
Time1:  0.48800705187022686
Time2:  0.050308508798480034
Time3:  0.16407573968172073
Time1:  0.05439259670674801
Time2:  0.00811847485601902
Time3:  0.018742568790912628

train {'precision@10': 0.6483492255210876, 'precision@30': 0.6988428235054016, 'precision@50': 0.7037647366523743, 'ndcg@10': 0.533232569694519, 'ndcg@30': 0.64768385887146, 'ndcg@50': 0.7388082146644592, 'ndcg@all': 0.7528384327888489}
valid {'precision@10': 0.42348283529281616, 'precision@30': 0.32452067732810974, 'precision@50': 0.27586281299591064, 'ndcg@10': 0.39023515582084656, 'ndcg@30': 0.38533148169517517, 'ndcg@50': 0.4002

Epoch 2821 0.09360033718117497 6.690753434833727
Epoch 2822 0.09357914036651802 6.687015985187731
Epoch 2823 0.09359575511767851 6.689244220131322
Epoch 2824 0.09331140566987602 6.692051962802284
Epoch 2825 0.09311907716661866 6.688623001700954
Epoch 2826 0.09277582225220943 6.687707850807591
Epoch 2827 0.09354842685119451 6.687484063600239
Epoch 2828 0.09309630470666272 6.690291053370426
Epoch 2829 0.09310007666112387 6.688182655133699
Epoch 2830 0.0933548591551725 6.692067196494655
Epoch 2831 0.09274948571334805 6.689076022097939
Epoch 2832 0.09321822548470302 6.689942234440854
Epoch 2833 0.09338867002063328 6.685310915896767
Epoch 2834 0.09376933495377937 6.68724270870811
Epoch 2835 0.0934571361071185 6.6879878295095345
Epoch 2836 0.09361565927838722 6.692419077220716
Epoch 2837 0.09327445719500034 6.6890319021124585
Epoch 2838 0.09321193845822798 6.684301326149388
Epoch 2839 0.0929391077823109 6.688423081448204
Epoch 2840 0.09336184710264206 6.69474122398778
Epoch 2841 0.0927633298

Epoch 2951 0.09216449483793382 6.717920453924882
Epoch 2952 0.09162807190104534 6.719230526372006
Epoch 2953 0.09141023744616592 6.724516793301231
Epoch 2954 0.09173147049215105 6.722435298718904
Epoch 2955 0.09112702218586938 6.722242731797068
Epoch 2956 0.09140453269781425 6.726499181044729
Epoch 2957 0.09145224395028331 6.723903405038934
Epoch 2958 0.09132481453537244 6.727524958158794
Epoch 2959 0.09185921949775595 6.721400060151753
Epoch 2960 0.09145273198509773 6.731787003968892
Epoch 2961 0.09201879799365997 6.724028687728079
Epoch 2962 0.09138128671206926 6.725037775541606
Epoch 2963 0.09117027528976139 6.72184796082346
Epoch 2964 0.09174917206952446 6.725612314123857
Epoch 2965 0.09139918675373869 6.723749461926912
Epoch 2966 0.09165192312664455 6.723313908827932
Epoch 2967 0.09177711710595247 6.72242367895026
Epoch 2968 0.0914426606736685 6.727186228099622
Epoch 2969 0.09175926359773379 6.722978366048713
Epoch 2970 0.0914252216530125 6.726050326698704
Epoch 2971 0.09136060616

Epoch 3092 0.09013957190409042 6.752833291103966
Epoch 3093 0.08993043698239744 6.754698301616468
Epoch 3094 0.08977453284270583 6.754577134784899
Epoch 3095 0.08966787974213997 6.755701742674175
Epoch 3096 0.08983551933054339 6.75731867238095
Epoch 3097 0.09029294851056316 6.752124936957109
Epoch 3098 0.09016426569885677 6.755489048204924
Epoch 3099 0.09003304541982406 6.749021530151367
Epoch 3100 0.0896259172443758 6.7577461443449325
Time1:  0.48790457285940647
Time2:  0.050270143896341324
Time3:  0.16383101604878902
Time1:  0.05459956265985966
Time2:  0.008119093254208565
Time3:  0.01880658231675625

train {'precision@10': 0.629842221736908, 'precision@30': 0.6903496980667114, 'precision@50': 0.7025379538536072, 'ndcg@10': 0.5196765661239624, 'ndcg@30': 0.6392518877983093, 'ndcg@50': 0.7330239415168762, 'ndcg@all': 0.7469692230224609}
valid {'precision@10': 0.40891823172569275, 'precision@30': 0.3179771304130554, 'precision@50': 0.272664874792099, 'ndcg@10': 0.38071033358573914, 'nd

Epoch 3220 0.08819873256292957 6.786601342652974
Epoch 3221 0.08848760651740414 6.7828114158228825
Epoch 3222 0.0880533472487801 6.78559486489547
Epoch 3223 0.08821427382049504 6.78239295357152
Epoch 3224 0.08853906881042391 6.783917477256374
Epoch 3225 0.08835018792173319 6.793031391344573
Epoch 3226 0.08767956878706726 6.786232697336297
Epoch 3227 0.08857946646840949 6.784926188619513
Epoch 3228 0.08788016630195038 6.7887239707143685
Epoch 3229 0.08824891257181502 6.786152413016872
Epoch 3230 0.0880074745468926 6.783759393190083
Epoch 3231 0.08812260270467279 6.785074133622019
Epoch 3232 0.08869567243327871 6.793946291271009
Epoch 3233 0.08836430045422057 6.791840779153924
Epoch 3234 0.08841833236970399 6.784364875994231
Epoch 3235 0.08800920765650899 6.786920422001889
Epoch 3236 0.08801431849337461 6.791594580600136
Epoch 3237 0.08775157324577633 6.79121218229595
Epoch 3238 0.08843261765980581 6.787802043714021
Epoch 3239 0.08815871976446688 6.792820629320647
Epoch 3240 0.0887141986

Epoch 3351 0.08708041322510145 6.81220165051912
Epoch 3352 0.08694471985275982 6.815454407742149
Epoch 3353 0.0867005648494464 6.812865056489644
Epoch 3354 0.08707682679445423 6.81275601136057
Epoch 3355 0.08678510885315331 6.815217921608372
Epoch 3356 0.08652366340508935 6.815952426508853
Epoch 3357 0.0869195492644059 6.8131529406497355
Epoch 3358 0.08681543262904151 6.813116048511706
Epoch 3359 0.08702643860501853 6.814633620412726
Epoch 3360 0.08643068704340193 6.819416573173122
Epoch 3361 0.08680912556006895 6.822060233668277
Epoch 3362 0.0868631246232847 6.82215512426276
Epoch 3363 0.08672585730489932 6.825013737929495
Epoch 3364 0.08686845326981349 6.812470360806114
Epoch 3365 0.08672786103180277 6.816369659022281
Epoch 3366 0.08657262801078328 6.816603786066959
Epoch 3367 0.08699511662561293 6.817348931965075
Epoch 3368 0.08653421072583449 6.821032448818809
Epoch 3369 0.08669334122834847 6.814783146506862
Epoch 3370 0.08634496741650398 6.818107303820159
Epoch 3371 0.086415296365

Epoch 3492 0.08556800109078312 6.842887752934506
Epoch 3493 0.0851040982712082 6.848600061316239
Epoch 3494 0.08491895644114031 6.846276157780697
Epoch 3495 0.08560481310239312 6.844491632361161
Epoch 3496 0.08562280093891579 6.848322441703395
Epoch 3497 0.08443178707047512 6.841982841491699
Epoch 3498 0.08517523057628096 6.847090068616365
Epoch 3499 0.0853563936220275 6.846368212448923
Epoch 3500 0.0853277848739373 6.846039671646921
Time1:  0.48812262900173664
Time2:  0.050142738968133926
Time3:  0.16389093548059464
Time1:  0.054349806159734726
Time2:  0.00810421071946621
Time3:  0.01887904480099678

train {'precision@10': 0.6517679691314697, 'precision@30': 0.7028988599777222, 'precision@50': 0.7077922224998474, 'ndcg@10': 0.5353508591651917, 'ndcg@30': 0.6520482301712036, 'ndcg@50': 0.7428076863288879, 'ndcg@all': 0.7543975710868835}
valid {'precision@10': 0.4198944568634033, 'precision@30': 0.3236059844493866, 'precision@50': 0.2746807336807251, 'ndcg@10': 0.3883833587169647, 'ndcg

Epoch 3620 0.08408646978307188 6.874548083857486
Epoch 3621 0.08480047129573878 6.870610839442203
Epoch 3622 0.08440252240986852 6.869061796288741
Epoch 3623 0.08383715640731722 6.882628917694092
Epoch 3624 0.08367898302120075 6.871010027433696
Epoch 3625 0.08430641449508611 6.873443678805702
Epoch 3626 0.08360858331298271 6.871944552973697
Epoch 3627 0.08391924429009533 6.8781937548988745
Epoch 3628 0.08395718260292422 6.874500600915206
Epoch 3629 0.0835881860196939 6.872214869449013
Epoch 3630 0.08406038704322792 6.876219849837454
Epoch 3631 0.08443176559014627 6.874734025252493
Epoch 3632 0.08388503819529773 6.884509488155968
Epoch 3633 0.08387219430933222 6.8745193230478385
Epoch 3634 0.08404153790216 6.877157361883866
Epoch 3635 0.08361501760824382 6.879459381103516
Epoch 3636 0.084123579295058 6.884507430227179
Epoch 3637 0.08393015690714295 6.8780395859166195
Epoch 3638 0.08379730965658935 6.880899755578292
Epoch 3639 0.08387497355017745 6.8719000816345215
Epoch 3640 0.083955571

Epoch 3751 0.08270448903765595 6.90050714894345
Epoch 3752 0.08275305501550262 6.898682418622468
Epoch 3753 0.0823953506803652 6.903947002009342
Epoch 3754 0.08278298190636942 6.900190353393555
Epoch 3755 0.08287839378006974 6.902396628731175
Epoch 3756 0.08255856492888858 6.899849841469212
Epoch 3757 0.08259248114817323 6.9056800541124845
Epoch 3758 0.08258497349002905 6.896706079181872
Epoch 3759 0.08229857744180669 6.901628117812307
Epoch 3760 0.08265574646797794 6.9107923005756575
Epoch 3761 0.08229119844778239 6.898821629975972
Epoch 3762 0.08304799705395224 6.90788459777832
Epoch 3763 0.08249156974386751 6.907181539033589
Epoch 3764 0.08276501402520296 6.902996364392732
Epoch 3765 0.08290847813525395 6.906750779402883
Epoch 3766 0.08255578586232593 6.910129496925755
Epoch 3767 0.08291342052799916 6.908054778450413
Epoch 3768 0.08269821787089632 6.905374225817229
Epoch 3769 0.08267158239382749 6.9001263066342
Epoch 3770 0.08278504484578182 6.90454794231214
Epoch 3771 0.08303770148

Epoch 3892 0.0818961154647738 6.927345602135909
Epoch 3893 0.08144739871485192 6.936786300257633
Epoch 3894 0.081631743202084 6.928255457627146
Epoch 3895 0.08146745951203575 6.929188703235827
Epoch 3896 0.08157749164697023 6.931223994807193
Epoch 3897 0.0814545306197384 6.934426458258378
Epoch 3898 0.08126913660276704 6.93889025637978
Epoch 3899 0.0811982799256057 6.931828448646947
Epoch 3900 0.08101611763064624 6.933616713473671
Time1:  0.4875898212194443
Time2:  0.05005566589534283
Time3:  0.16378521732985973
Time1:  0.0544129665941
Time2:  0.008078902959823608
Time3:  0.018848920240998268

train {'precision@10': 0.6370022892951965, 'precision@30': 0.6970171332359314, 'precision@50': 0.7072609066963196, 'ndcg@10': 0.5232338309288025, 'ndcg@30': 0.6444600224494934, 'ndcg@50': 0.737514853477478, 'ndcg@all': 0.749002993106842}
valid {'precision@10': 0.410817950963974, 'precision@30': 0.31934916973114014, 'precision@50': 0.27278098464012146, 'ndcg@10': 0.3802909851074219, 'ndcg@30': 0.3

Epoch 4021 0.08061837648961977 6.956771674909089
Epoch 4022 0.0805058620042271 6.954074909812526
Epoch 4023 0.08043934897198314 6.958366193269429
Epoch 4024 0.07996619096276356 6.9614389570135815
Epoch 4025 0.0804489996523885 6.958659046574643
Epoch 4026 0.08037251473693123 6.963343143463135
Epoch 4027 0.08062632037707937 6.958450994993511
Epoch 4028 0.08015203484666278 6.957799459758558
Epoch 4029 0.08043883919541599 6.959434358697188
Epoch 4030 0.0807022023619267 6.9603267970838045
Epoch 4031 0.080370966628281 6.959244627701609
Epoch 4032 0.08099792052430717 6.962058318288703
Epoch 4033 0.08081790913789592 6.958752632141113
Epoch 4034 0.08007319149566673 6.958911117754485
Epoch 4035 0.08019698040875775 6.958211070612857
Epoch 4036 0.0802141407079864 6.963475327742727
Epoch 4037 0.08078062756542574 6.9656635083650285
Epoch 4038 0.08041360643174914 6.957892242230867
Epoch 4039 0.08039418276813295 6.961696574562474
Epoch 4040 0.08006257228335442 6.9580361968592594
Epoch 4041 0.080185429

Epoch 4151 0.07956293804777993 6.981447596299021
Epoch 4152 0.0795951553604059 6.984243970168264
Epoch 4153 0.07907199371627897 6.981300077940288
Epoch 4154 0.07923560567766602 6.983444891477886
Epoch 4155 0.07977421599173407 6.986751581493177
Epoch 4156 0.07898403333815915 6.980314882178056
Epoch 4157 0.07927252604947453 6.99155835101479
Epoch 4158 0.07913697441244683 6.9949744877062345
Epoch 4159 0.0791380476881886 6.98638885899594
Epoch 4160 0.07936054821077146 6.984373745165374
Epoch 4161 0.07928465868820224 6.98613957354897
Epoch 4162 0.0789729162004956 6.987059944554379
Epoch 4163 0.07930749001210197 7.001666922318308
Epoch 4164 0.0792721445013208 6.98660441448814
Epoch 4165 0.07943415659212927 6.989505541952033
Epoch 4166 0.07910110335252439 6.988113127256694
Epoch 4167 0.07925379210919664 6.986603209846898
Epoch 4168 0.0793703431971589 6.988864848488255
Epoch 4169 0.07943022547409548 6.986473911686947
Epoch 4170 0.07931033256109694 6.989578171780235
Epoch 4171 0.079405852838566

Epoch 4292 0.07825596552146108 7.009679719021446
Epoch 4293 0.07786440701164 7.019050949498227
Epoch 4294 0.07873949839880592 7.015943928768761
Epoch 4295 0.07841734440006011 7.012344837188721
Epoch 4296 0.07831513389335042 7.0159707822297745
Epoch 4297 0.07799086645681258 7.017945515482049
Epoch 4298 0.07822060319241027 7.008882798646626
Epoch 4299 0.07799929219205477 7.020005728069105
Epoch 4300 0.07843410180151811 7.02079607311048
Time1:  0.4876679703593254
Time2:  0.04990185424685478
Time3:  0.1638012994080782
Time1:  0.05435021221637726
Time2:  0.008067192509770393
Time3:  0.018849261105060577

train {'precision@10': 0.6696827411651611, 'precision@30': 0.7162708044052124, 'precision@50': 0.7126511335372925, 'ndcg@10': 0.547155499458313, 'ndcg@30': 0.6635112762451172, 'ndcg@50': 0.751183807849884, 'ndcg@all': 0.7602506279945374}
valid {'precision@10': 0.4296042323112488, 'precision@30': 0.3278804123401642, 'precision@50': 0.27700263261795044, 'ndcg@10': 0.39472970366477966, 'ndcg@3

Epoch 4420 0.07741389427965845 7.037549018859863
Epoch 4421 0.0774901384237217 7.0338822164033585
Epoch 4422 0.07713585843642552 7.031951979586952
Epoch 4423 0.07724913527742464 7.041886254360802
Epoch 4424 0.07683403441431927 7.0361193606728
Epoch 4425 0.07754857276092496 7.033593629535876
Epoch 4426 0.07757229157532865 7.036704866509688
Epoch 4427 0.0774683110110941 7.036513780292712
Epoch 4428 0.07718088883056975 7.041012186753123
Epoch 4429 0.07785510703137047 7.038277801714446
Epoch 4430 0.07746072039443846 7.043885732951917
Epoch 4431 0.07712128745359287 7.048231375844855
Epoch 4432 0.07702300980774282 7.041406004052413
Epoch 4433 0.07752352347325163 7.040489096390574
Epoch 4434 0.0770272777268761 7.04215649554604
Epoch 4435 0.07751737426073231 7.0435934568706315
Epoch 4436 0.07699117552467256 7.041855485815751
Epoch 4437 0.07779535298284732 7.042866832331607
Epoch 4438 0.07754639431572798 7.043884804374294
Epoch 4439 0.07737019209311022 7.041373177578575
Epoch 4440 0.07727654766

Epoch 4551 0.07666044712763781 7.06122360731426
Epoch 4552 0.07610992590586345 7.069509832482589
Epoch 4553 0.0767347830516553 7.063116148898476
Epoch 4554 0.07644956959792745 7.065639420559532
Epoch 4555 0.07695401024225859 7.065969893806859
Epoch 4556 0.07632868933050256 7.066195989909925
Epoch 4557 0.07618135788984466 7.0661822871158
Epoch 4558 0.07646072830198801 7.066916139502275
Epoch 4559 0.07633161279018859 7.064556272406327
Epoch 4560 0.07679096077792129 7.070150877300062
Epoch 4561 0.07689602002065782 7.069032167133532
Epoch 4562 0.07610240304156353 7.0676799824363306
Epoch 4563 0.07632450310633196 7.066681736393979
Epoch 4564 0.07643127789971424 7.068909544693796
Epoch 4565 0.07670640710153077 7.065959704549689
Epoch 4566 0.07655732843436693 7.066118265453138
Epoch 4567 0.07589931903701079 7.070708525808234
Epoch 4568 0.07661366941984633 7.067647256349263
Epoch 4569 0.07630262332178696 7.0649818119249845
Epoch 4570 0.07664957951906828 7.072997971584923
Epoch 4571 0.076329694

Epoch 4693 0.07549281466251229 7.08869941611039
Epoch 4694 0.07584924782403031 7.092229391399183
Epoch 4695 0.07536059652852733 7.100721610219855
Epoch 4696 0.07551679254798165 7.090002586967067
Epoch 4697 0.07546061586741118 7.089119610033538
Epoch 4698 0.07566381446276492 7.092567895588122
Epoch 4699 0.07601361665112233 7.090487204099956
Epoch 4700 0.0752817799299084 7.088813580964741
Time1:  0.4886180441826582
Time2:  0.04987652786076069
Time3:  0.1637501996010542
Time1:  0.05440474860370159
Time2:  0.008071228861808777
Time3:  0.018833201378583908

train {'precision@10': 0.6484020352363586, 'precision@30': 0.7053499817848206, 'precision@50': 0.7107007503509521, 'ndcg@10': 0.5307245850563049, 'ndcg@30': 0.6524545550346375, 'ndcg@50': 0.7431520819664001, 'ndcg@all': 0.7528306841850281}
valid {'precision@10': 0.4177308976650238, 'precision@30': 0.3217414617538452, 'precision@50': 0.27331921458244324, 'ndcg@10': 0.3849017322063446, 'ndcg@30': 0.3813273310661316, 'ndcg@50': 0.3960969150

Epoch 4821 0.07471378744519942 7.109301968624718
Epoch 4822 0.07450176588101694 7.105955224288137
Epoch 4823 0.07489262978758729 7.113927715703061
Epoch 4824 0.07533428865915154 7.110732530292712
Epoch 4825 0.07450341624997513 7.113054426092851
Epoch 4826 0.07505470004520919 7.107476058759187
Epoch 4827 0.07467740708798692 7.120363084893477
Epoch 4828 0.07512135503062031 7.114572299154181
Epoch 4829 0.07452628208182709 7.117935306147525
Epoch 4830 0.07452633802653753 7.111491705241956
Epoch 4831 0.07465025669301463 7.111699631339626
Epoch 4832 0.0750370801192278 7.109407851570531
Epoch 4833 0.07489248522018131 7.110749319980019
Epoch 4834 0.07459740859200382 7.117331454628392
Epoch 4835 0.07449727867081848 7.11566360373246
Epoch 4836 0.07445408369016926 7.11579694245991
Epoch 4837 0.07453008327219221 7.113644499527781
Epoch 4838 0.0746192558292757 7.116317096509431
Epoch 4839 0.0746051164572699 7.112031409614964
Epoch 4840 0.07488164793678194 7.119180051903975
Epoch 4841 0.074169893405

Epoch 4951 0.07397593048058058 7.137382457130833
Epoch 4952 0.07337395999340983 7.136935560326827
Epoch 4953 0.07397205770364282 7.139230050538716
Epoch 4954 0.07424376595612855 7.135972750814338
Epoch 4955 0.07418042133774674 7.139362259915001
Epoch 4956 0.0742234420358089 7.131974345759342
Epoch 4957 0.07431088592748196 7.13766958839015
Epoch 4958 0.07391684759430021 7.135665517104299
Epoch 4959 0.07409835968449799 7.139606902473851
Epoch 4960 0.07338012839879907 7.148129111842105
Epoch 4961 0.07464625346556045 7.140285140589664
Epoch 4962 0.07385740555517854 7.139320197858308
Epoch 4963 0.07387916602150739 7.1337527224892066
Epoch 4964 0.07391641916413057 7.137392596194618
Epoch 4965 0.07407160552098738 7.139000340511925
Epoch 4966 0.0736810290064031 7.136038479052092
Epoch 4967 0.0740282325542461 7.141219866903205
Epoch 4968 0.07381757701698102 7.1382036209106445
Epoch 4969 0.07420801841898968 7.1379829456931665
Epoch 4970 0.07367140115701665 7.140657399830065
Epoch 4971 0.07393752

In [26]:
def listNet(y_pred, y_true, eps=1e-10, padded_value_indicator=-1):
    """
    ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach".
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: loss value, a torch.Tensor
    """
    y_pred = y_pred.clone()
    y_true = y_true.clone()

    mask = y_true == padded_value_indicator
    y_pred[mask] = float('-inf')
    y_true[mask] = float('-inf')

    preds_smax = F.softmax(y_pred, dim=1)
    true_smax = F.softmax(y_true, dim=1)

    preds_smax = preds_smax + eps
    preds_log = torch.log(preds_smax)

    return torch.mean(-torch.sum(true_smax * preds_log, dim=1))

In [27]:
# setting
lr = 0.1
momentum = 0.
weight_decay = 0
nesterov = False # True

n_epoch = 5000
verbose = True
valid_epoch = 50

h_dim = 3000

model = MLPDecoder(doc_emb_dim=document_vectors.shape[1], num_words=len(word_vectors), h_dim=h_dim).to(device)
model.train()

word_vectors_tensor = torch.FloatTensor(word_vectors).to(device)

opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
# opt = torch.optim.Adam(model.parameters(), lr=lr)
# criterion = nn.MSELoss(reduction='mean')
# criterion = nn.MultiLabelMarginLoss(reduction='mean')
# criterion = nn.MultiLabelSoftMarginLoss(reduction='mean')

results = []
step = 0
clip_value = 1

for epoch in tqdm(range(n_epoch)):    
    train_loss_his = []
    valid_loss_his = []
    
    model.train()

    for data in train_loader:
        doc_embs, target, target_rank = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_rank = target_rank.to(device)
        
        pred = model(doc_embs)     
        loss = listNet(pred, target)
    
        # Model backwarding
        model.zero_grad()
        loss.backward()
        
        opt.step()

        train_loss_his.append(loss.item())
        
    model.eval()
    for data in valid_loader:
        doc_embs, target, target_rank = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_rank = target_rank.to(device)
        
        # MSE loss
        pred = model(doc_embs)     
        loss = listNet(pred, target)
    
        valid_loss_his.append(loss.item())
    
    print("Epoch", epoch, np.mean(train_loss_his), np.mean(valid_loss_his))
    
    if epoch % valid_epoch == 0:
        res = {}
        res['epoch'] = epoch
        
        train_res_ndcg = evaluate_MLPDecoder(model, train_loader)
        valid_res_ndcg = evaluate_MLPDecoder(model, valid_loader)
        
        res.update(valid_res_ndcg)
        results.append(res)
        
        if verbose:
            print()
            print('train', train_res_ndcg)
            print('valid', valid_res_ndcg)


HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

Epoch 0 7.558826981929311 6.642893239071495
Time1:  0.48111855424940586
Time2:  0.050766244530677795
Time3:  0.1644516959786415
Time1:  0.05355365201830864
Time2:  0.008210383355617523
Time3:  0.018949346616864204

train {'precision@10': 0.19491584599018097, 'precision@30': 0.11187867075204849, 'precision@50': 0.08433237671852112, 'ndcg@10': 0.24771833419799805, 'ndcg@30': 0.21983011066913605, 'ndcg@50': 0.21720483899116516, 'ndcg@all': 0.4307364821434021}
valid {'precision@10': 0.18659630417823792, 'precision@30': 0.10979772359132767, 'precision@50': 0.08303956687450409, 'ndcg@10': 0.21357251703739166, 'ndcg@30': 0.1910678893327713, 'ndcg@50': 0.18884526193141937, 'ndcg@all': 0.40859586000442505}
Epoch 1 5.386427056719685 5.65228974191766
Epoch 2 4.122493099748042 5.1396158870897795
Epoch 3 3.3047335440652414 4.75669695201673
Epoch 4 2.7447768749549373 4.531344840401097
Epoch 5 2.3409588950419287 4.370237086948595
Epoch 6 2.0675212883809855 4.255293871227064
Epoch 7 1.8832298001350716

Epoch 135 1.2597397468243425 3.406078526848241
Epoch 136 1.2599005632930331 3.4035764367956864
Epoch 137 1.2579292058944702 3.406371493088572
Epoch 138 1.258521216654638 3.3975021337207996
Epoch 139 1.2588074437358923 3.4017832028238395
Epoch 140 1.2572579042256227 3.404090680574116
Epoch 141 1.2588042048682944 3.3992137407001697
Epoch 142 1.2595475047652485 3.402209294469733
Epoch 143 1.257382897954238 3.3968052738591243
Epoch 144 1.2562592008657623 3.388517166438856
Epoch 145 1.2570915612560964 3.4015126102849056
Epoch 146 1.2568077214977198 3.39712875767758
Epoch 147 1.2569824631451165 3.390445960195441
Epoch 148 1.256877797388891 3.398538087543688
Epoch 149 1.2571047331854615 3.3956476010774312
Epoch 150 1.2566403635064063 3.3917504988218607
Time1:  0.47929034382104874
Time2:  0.05126351863145828
Time3:  0.1642293818295002
Time1:  0.05348225310444832
Time2:  0.008181614801287651
Time3:  0.018965011462569237

train {'precision@10': 0.5086261034011841, 'precision@30': 0.2592564523220

Epoch 268 1.2403727449171724 3.3510616327586926
Epoch 269 1.2396841028280425 3.3480429900319955
Epoch 270 1.2399257832800437 3.3476637037176835
Epoch 271 1.2394402009701868 3.3426146256296256
Epoch 272 1.2402539016210545 3.348825504905299
Epoch 273 1.2388543796818159 3.353108569195396
Epoch 274 1.2400201036219012 3.3402598908073022
Epoch 275 1.2410284183178728 3.346604447615774
Epoch 276 1.2411729707355388 3.350219312467073
Epoch 277 1.2371617918126068 3.3472148619200053
Epoch 278 1.2391816178957622 3.3592327519467005
Epoch 279 1.2391265205472533 3.34495600901152
Epoch 280 1.2390639088307207 3.3446347964437386
Epoch 281 1.2383231388198004 3.343806643235056
Epoch 282 1.2395718156942848 3.3396058459030953
Epoch 283 1.2391244985206782 3.343995257427818
Epoch 284 1.238725067230693 3.342506270659597
Epoch 285 1.2382654662717854 3.348313281410619
Epoch 286 1.2413311304404722 3.3434331668050667
Epoch 287 1.2381334757944296 3.3505653080187345
Epoch 288 1.23868955924497 3.347330883929604
Epoch 

Epoch 401 1.2319392032093472 3.340606451034546
Epoch 402 1.2317541792378788 3.333127762141981
Epoch 403 1.230491468432354 3.338856421018902
Epoch 404 1.2304626197842827 3.3317849008660567
Epoch 405 1.232363504275941 3.3356052448875024
Epoch 406 1.2315258958883453 3.335697186620612
Epoch 407 1.2306940830938997 3.339101276899639
Epoch 408 1.2328254270274737 3.329742619865819
Epoch 409 1.232120604891526 3.3404737146277177
Epoch 410 1.231837414858634 3.3304218994943717
Epoch 411 1.2308993987869798 3.3323654375578227
Epoch 412 1.2301074364031965 3.334660028156481
Epoch 413 1.2300295157042163 3.3296350428932593
Epoch 414 1.2319474680381908 3.334347737462897
Epoch 415 1.2293350661707203 3.3356498668068335
Epoch 416 1.231697606761553 3.327612048701236
Epoch 417 1.2311688675517924 3.329871993315847
Epoch 418 1.2305808520456503 3.3337048856835616
Epoch 419 1.2297281120952808 3.3285773804313257
Epoch 420 1.230550779933818 3.3309754321449683
Epoch 421 1.2288366580567165 3.3297368476265357
Epoch 42

Epoch 547 1.2256234831977308 3.3380564513959383
Epoch 548 1.227158390639121 3.3394918065322075
Epoch 549 1.2255926320427342 3.3335102231878984
Epoch 550 1.2264449690517627 3.3332309095483077
Time1:  0.4793894086033106
Time2:  0.05086854100227356
Time3:  0.16421302035450935
Time1:  0.05353509820997715
Time2:  0.008177397772669792
Time3:  0.01894674450159073

train {'precision@10': 0.5663167834281921, 'precision@30': 0.2923806309700012, 'precision@50': 0.20947866141796112, 'ndcg@10': 0.7500427961349487, 'ndcg@30': 0.6242126822471619, 'ndcg@50': 0.6043657064437866, 'ndcg@all': 0.7580590844154358}
valid {'precision@10': 0.467387855052948, 'precision@30': 0.2571679949760437, 'precision@50': 0.18930870294570923, 'ndcg@10': 0.5848190188407898, 'ndcg@30': 0.5023055672645569, 'ndcg@50': 0.49015316367149353, 'ndcg@all': 0.6654290556907654}
Epoch 551 1.2242247525014376 3.338511667753521
Epoch 552 1.224712240068536 3.3363934441616663
Epoch 553 1.2255599129269694 3.3311534555334794
Epoch 554 1.2261

Epoch 680 1.2207015052176358 3.340752124786377
Epoch 681 1.219726638835773 3.34030903013129
Epoch 682 1.219618746760296 3.34334565463819
Epoch 683 1.2193106252547594 3.3462039922413074
Epoch 684 1.221551333254541 3.338478627957796
Epoch 685 1.2211290242379171 3.343057770478098
Epoch 686 1.2213567648714745 3.3350486755371094
Epoch 687 1.2205200749531127 3.348749035283139
Epoch 688 1.2204923159197758 3.337512204521581
Epoch 689 1.2207919763542756 3.3415659477836206
Epoch 690 1.2214152077485245 3.347462666662116
Epoch 691 1.2217217105173925 3.3418641968777307
Epoch 692 1.2218149973635088 3.3420703913036145
Epoch 693 1.2196414582213464 3.347943845548128
Epoch 694 1.221260744925828 3.3436700168408846
Epoch 695 1.2208167471383746 3.340518926319323
Epoch 696 1.2205689106768334 3.342003947810123
Epoch 697 1.2208242078273617 3.344422465876529
Epoch 698 1.219531480331867 3.344138132898431
Epoch 699 1.2199749838539033 3.3391949126594946
Epoch 700 1.2203128358774018 3.3408012641103646
Time1:  0.47

Epoch 813 1.218286817533928 3.351712051190828
Epoch 814 1.2171335025140417 3.3522389060572575
Epoch 815 1.2183801458593 3.349456322820563
Epoch 816 1.2169683227065013 3.3528333839617277
Epoch 817 1.2175859767093993 3.353894672895733
Epoch 818 1.2176650549933228 3.352802414643137
Epoch 819 1.2189851159240768 3.3530053464989913
Epoch 820 1.2176883489067791 3.3447943110215035
Epoch 821 1.218661512199201 3.359586088280929
Epoch 822 1.2164964829271996 3.3568719186280904
Epoch 823 1.2160908332345082 3.345466325157567
Epoch 824 1.2175654322780363 3.353211202119526
Epoch 825 1.2171973042320787 3.3554758398156417
Epoch 826 1.2168255476226584 3.3473233925668815
Epoch 827 1.2164568437470331 3.353098254454763
Epoch 828 1.2186609384609244 3.3573696989762154
Epoch 829 1.2158850741665266 3.3556306236668636
Epoch 830 1.2167660092052661 3.3498891780250952
Epoch 831 1.2169781177364596 3.3495521921860543
Epoch 832 1.217406938299101 3.3568158526169625
Epoch 833 1.2173213693830702 3.3463985418018543
Epoch 

Epoch 951 1.214418261720423 3.372355686990838
Epoch 952 1.2150431969012434 3.3595125549717952
Epoch 953 1.2145565413592154 3.3604939234884164
Epoch 954 1.2166004885009856 3.3693934490806177
Epoch 955 1.213540826624597 3.3684566773866353
Epoch 956 1.2140282368102269 3.366467337859304
Epoch 957 1.2149322726573164 3.3701779967860173
Epoch 958 1.2152279828026977 3.3714835392801383
Epoch 959 1.2136038381453844 3.3660215578581156
Epoch 960 1.2151407749332181 3.366242898137946
Epoch 961 1.2141861971358807 3.366050544537996
Epoch 962 1.2127235615462588 3.3708241613287675
Epoch 963 1.2129095705629092 3.3645049521797583
Epoch 964 1.2161064695196542 3.364478726136057
Epoch 965 1.2135351627890827 3.363938833537855
Epoch 966 1.214097921959838 3.370548223194323
Epoch 967 1.214037847449208 3.3630147858669885
Epoch 968 1.216049343521832 3.3719073094819723
Epoch 969 1.2153030079707765 3.3654888303656327
Epoch 970 1.2141502496094732 3.36854130343387
Epoch 971 1.214844122267606 3.3681636233078804
Epoch 9

Epoch 1095 1.211554091227682 3.3798355930729915
Epoch 1096 1.212605850041261 3.386644764950401
Epoch 1097 1.2113511715716088 3.379405975341797
Epoch 1098 1.213872222175375 3.379949030123259
Epoch 1099 1.2130097541195608 3.385654462011237
Epoch 1100 1.2122933902935675 3.381299709018908
Time1:  0.4792938884347677
Time2:  0.05055186338722706
Time3:  0.16432500444352627
Time1:  0.057885363698005676
Time2:  0.008059408515691757
Time3:  0.019036423414945602

train {'precision@10': 0.5920600891113281, 'precision@30': 0.31079769134521484, 'precision@50': 0.22236555814743042, 'ndcg@10': 0.7668549418449402, 'ndcg@30': 0.6420891284942627, 'ndcg@50': 0.6219351887702942, 'ndcg@all': 0.7705984115600586}
valid {'precision@10': 0.4821636378765106, 'precision@30': 0.2662796974182129, 'precision@50': 0.19602110981941223, 'ndcg@10': 0.5905939936637878, 'ndcg@30': 0.5093238949775696, 'ndcg@50': 0.4973691403865814, 'ndcg@all': 0.6710137128829956}
Epoch 1101 1.2126375213701126 3.3904004975369104
Epoch 1102 

Epoch 1224 1.2096858819325764 3.394897046842073
Epoch 1225 1.2110123658737941 3.3984400975076774
Epoch 1226 1.2123095992712947 3.3972891631879305
Epoch 1227 1.21021031322535 3.3966882605301705
Epoch 1228 1.2094168537541439 3.3933592846519067
Epoch 1229 1.210005912167287 3.398190586190475
Epoch 1230 1.210721799156122 3.3949857511018453
Epoch 1231 1.208321800357417 3.391814482839484
Epoch 1232 1.2102354374545359 3.395700580195377
Epoch 1233 1.2101755288609288 3.400462313702232
Epoch 1234 1.2095225178010283 3.391105237760042
Epoch 1235 1.2090328283477247 3.3951198176333777
Epoch 1236 1.209720126020978 3.3938869928058826
Epoch 1237 1.209962264836183 3.403132350821244
Epoch 1238 1.2106139196289911 3.398274120531584
Epoch 1239 1.209111946368078 3.3976425120705054
Epoch 1240 1.2109813672757288 3.398077023656745
Epoch 1241 1.2099351140490748 3.4000008733649003
Epoch 1242 1.2094267575364364 3.3986426529131437
Epoch 1243 1.209442683130677 3.397110888832494
Epoch 1244 1.210582064954858 3.39637032

Epoch 1354 1.2100491300660965 3.4032394886016846
Epoch 1355 1.2075292283331442 3.414962040750604
Epoch 1356 1.208238568919444 3.4150462401540658
Epoch 1357 1.2098409836752373 3.4123940216867545
Epoch 1358 1.2091635514421073 3.4130529604460063
Epoch 1359 1.2077913803663867 3.415101176814029
Epoch 1360 1.208432524176369 3.4120544257916903
Epoch 1361 1.207256205249251 3.4108935657300448
Epoch 1362 1.209432322030876 3.4095913359993384
Epoch 1363 1.2088486025905052 3.4105458008615592
Epoch 1364 1.2076799196806567 3.4112378421582674
Epoch 1365 1.2095932859426353 3.410399637724224
Epoch 1366 1.2075765303700987 3.411956473400718
Epoch 1367 1.2072893907452187 3.4101760387420654
Epoch 1368 1.2082275302089445 3.4143740252444617
Epoch 1369 1.2088687151496174 3.415614492014835
Epoch 1370 1.208438689945734 3.4066926303662752
Epoch 1371 1.20792287553263 3.4117425240968404
Epoch 1372 1.209290567545863 3.410003850334569
Epoch 1373 1.2086736308900934 3.4172191996323433
Epoch 1374 1.2081225420299329 3.42

Epoch 1497 1.2055162703781797 3.4251376453198885
Epoch 1498 1.2060677437057272 3.429544737464503
Epoch 1499 1.2069985249586273 3.4337553852482845
Epoch 1500 1.2067718962479752 3.428237224880018
Time1:  0.4792105834931135
Time2:  0.05033884383738041
Time3:  0.16427990794181824
Time1:  0.05358118750154972
Time2:  0.008145757019519806
Time3:  0.018839459866285324

train {'precision@10': 0.6037471890449524, 'precision@30': 0.3191579580307007, 'precision@50': 0.22855448722839355, 'ndcg@10': 0.7738591432571411, 'ndcg@30': 0.6496145129203796, 'ndcg@50': 0.6296790242195129, 'ndcg@all': 0.775989830493927}
valid {'precision@10': 0.48691293597221375, 'precision@30': 0.269850492477417, 'precision@50': 0.1991451233625412, 'ndcg@10': 0.5911972522735596, 'ndcg@30': 0.5104780793190002, 'ndcg@50': 0.4987637400627136, 'ndcg@all': 0.6719908714294434}
Epoch 1501 1.2057444592665512 3.4313853790885522
Epoch 1502 1.2065936145726701 3.43232892688952
Epoch 1503 1.2086348620771665 3.4275822764948796
Epoch 1504 

Epoch 1627 1.2057498247302763 3.4413114346955953
Epoch 1628 1.2053959519542448 3.450330709156237
Epoch 1629 1.2037252379439727 3.4488906860351562
Epoch 1630 1.204418351078591 3.4389016377298454
Epoch 1631 1.2061238187795493 3.45144910561411
Epoch 1632 1.2055522663551463 3.4460593901182475
Epoch 1633 1.2051014199591519 3.448011134800158
Epoch 1634 1.2051950135426215 3.4526280101976896
Epoch 1635 1.2063747857049194 3.446348980853432
Epoch 1636 1.205016159174735 3.447292214945743
Epoch 1637 1.2049059989856699 3.4542048102930973
Epoch 1638 1.205443819712477 3.4490345026317395
Epoch 1639 1.2033851516177083 3.448683362258108
Epoch 1640 1.2041591433753744 3.450129170166819
Epoch 1641 1.2046767936812506 3.446876965071026
Epoch 1642 1.2051493180425543 3.449544116070396
Epoch 1643 1.2048196335982162 3.448498512569227
Epoch 1644 1.2052023494452762 3.4410678712945235
Epoch 1645 1.2044963216223912 3.4405195713043213
Epoch 1646 1.2042187190892404 3.4500959672425924
Epoch 1647 1.205858374199672 3.445

Epoch 1757 1.2036582994879337 3.466364622116089
Epoch 1758 1.2036690705003794 3.4630305641575863
Epoch 1759 1.2037152839682952 3.456662341168052
Epoch 1760 1.203321569844296 3.4638533717707585
Epoch 1761 1.2039809624354045 3.4644809271159924
Epoch 1762 1.2027266029028865 3.4603490954951237
Epoch 1763 1.2027112165389702 3.4682789727261194
Epoch 1764 1.2024456705266273 3.462083841625013
Epoch 1765 1.2036987421108267 3.4732921374471566
Epoch 1766 1.2036745736473484 3.4688071828139457
Epoch 1767 1.2026751017012791 3.4663060715324
Epoch 1768 1.2033417444480092 3.460426041954442
Epoch 1769 1.2054254541620177 3.456512877815648
Epoch 1770 1.2033488987482082 3.4549830963737085
Epoch 1771 1.2030062455880015 3.4581782190423263
Epoch 1772 1.2041899600224188 3.4624351200304533
Epoch 1773 1.2037053031530993 3.4616852559541402
Epoch 1774 1.2023692883943256 3.46679959799114
Epoch 1775 1.202604399787055 3.4667076939030697
Epoch 1776 1.2036698216583297 3.461879040065565
Epoch 1777 1.205222997749061 3.46

Time1:  0.4787521604448557
Time2:  0.050198474898934364
Time3:  0.16426915675401688
Time1:  0.053563227877020836
Time2:  0.008146403357386589
Time3:  0.01891898736357689

train {'precision@10': 0.6122089624404907, 'precision@30': 0.32640397548675537, 'precision@50': 0.23390722274780273, 'ndcg@10': 0.7790075540542603, 'ndcg@30': 0.6558478474617004, 'ndcg@50': 0.6359550356864929, 'ndcg@all': 0.7804374098777771}
valid {'precision@10': 0.4905013144016266, 'precision@30': 0.2721196115016937, 'precision@50': 0.20101317763328552, 'ndcg@10': 0.5905339121818542, 'ndcg@30': 0.5113978385925293, 'ndcg@50': 0.49992573261260986, 'ndcg@all': 0.6731889843940735}
Epoch 1901 1.2021480045820538 3.4765089687548185
Epoch 1902 1.2023591200510662 3.4778839914422286
Epoch 1903 1.2012168953293247 3.478755047446803
Epoch 1904 1.2020652154732865 3.474665641784668
Epoch 1905 1.2028267634542364 3.4760097453468726
Epoch 1906 1.2026312835732398 3.4806040713661597
Epoch 1907 1.2037848945249592 3.482568389491031
Epoch

Epoch 2031 1.2014841283971107 3.4933570560656095
Epoch 2032 1.2008953160709805 3.498928007326628
Epoch 2033 1.2010618392487018 3.490784858402453
Epoch 2034 1.2011323606061657 3.4940308520668433
Epoch 2035 1.2024997134654842 3.4950621630016125
Epoch 2036 1.2019634253797475 3.4940928534457556
Epoch 2037 1.2000476460010685 3.490415460185001
Epoch 2038 1.2008431700238011 3.497340829748856
Epoch 2039 1.2009691592545537 3.4956611081173548
Epoch 2040 1.2017864257271527 3.4938400168167916
Epoch 2041 1.2009683226981358 3.5035538924367806
Epoch 2042 1.2012542820813363 3.4939331004494116
Epoch 2043 1.2004676311336764 3.500512386623182
Epoch 2044 1.201006957662036 3.497074139745612
Epoch 2045 1.2026541856994406 3.4969249524568258
Epoch 2046 1.2017223754821464 3.498152820687545
Epoch 2047 1.2021588791183562 3.496214590574566
Epoch 2048 1.2006222818330017 3.4910309942145097
Epoch 2049 1.201157485881047 3.4947624081059505
Epoch 2050 1.2021263050057038 3.4936293802763285
Time1:  0.478642912581563
Time

Epoch 2161 1.2000579530732673 3.510768438640394
Epoch 2162 1.1996978981453075 3.511096377121775
Epoch 2163 1.201222622952266 3.515484408328408
Epoch 2164 1.2013668846665768 3.509620327698557
Epoch 2165 1.199029656878689 3.5183974943662943
Epoch 2166 1.1993401343362373 3.510169694298192
Epoch 2167 1.2004169201293187 3.512089440697118
Epoch 2168 1.1998232669300504 3.513648321754054
Epoch 2169 1.200021719723417 3.509569858249865
Epoch 2170 1.1998926937928673 3.5073737345243754
Epoch 2171 1.1986614016064427 3.5135608221355237
Epoch 2172 1.1993501531450372 3.5104757735603735
Epoch 2173 1.2010140425977651 3.5187893541235673
Epoch 2174 1.2012750251948485 3.5108533282029
Epoch 2175 1.1999697521416068 3.508210357866789
Epoch 2176 1.1997241489371362 3.5131320702402213
Epoch 2177 1.199064208053009 3.510583626596551
Epoch 2178 1.1982089040572184 3.511428531847502
Epoch 2179 1.2010206127027323 3.5117837378853247
Epoch 2180 1.199990983943493 3.5134748659635844
Epoch 2181 1.2004666969790097 3.5099166

Epoch 2301 1.1995609158660934 3.5277850502415706
Epoch 2302 1.1994850220736006 3.5265850142428747
Epoch 2303 1.2006831301583185 3.5344987668489156
Epoch 2304 1.1980055061697263 3.5343215967479504
Epoch 2305 1.2003490140563564 3.5314599965748035
Epoch 2306 1.1993379380270752 3.526826180909809
Epoch 2307 1.2004676684301498 3.5325654682360197
Epoch 2308 1.1984572696406939 3.5301532619877865
Epoch 2309 1.1996631552601418 3.529097030037328
Epoch 2310 1.198632085881038 3.5236100146644995
Epoch 2311 1.199792694627193 3.5255059944955924
Epoch 2312 1.2000308552680656 3.5324539510827315
Epoch 2313 1.198457646439647 3.526915010653044
Epoch 2314 1.1981341633183218 3.536867304852134
Epoch 2315 1.1981325871066044 3.5329588588915373
Epoch 2316 1.1997225640112894 3.5257581535138582
Epoch 2317 1.2004496392450834 3.5325908786372135
Epoch 2318 1.198709894342032 3.534240245819092
Epoch 2319 1.1992658622432173 3.534687092429713
Epoch 2320 1.1985388471369158 3.52929415200886
Epoch 2321 1.1992456689912674 3.

Epoch 2444 1.197333408029456 3.5492015637849508
Epoch 2445 1.1987848236546879 3.543488251535516
Epoch 2446 1.1964919321369707 3.5456865335765637
Epoch 2447 1.1989728766575194 3.546447415100901
Epoch 2448 1.1987694902726782 3.5469293343393424
Epoch 2449 1.1989618473582797 3.554651473697863
Epoch 2450 1.1987905275751973 3.5465728860152397
Time1:  0.4791992921382189
Time2:  0.05001219920814037
Time3:  0.16422765515744686
Time1:  0.05335942842066288
Time2:  0.008114904165267944
Time3:  0.01898769661784172

train {'precision@10': 0.619498074054718, 'precision@30': 0.33387088775634766, 'precision@50': 0.23966455459594727, 'ndcg@10': 0.7830520272254944, 'ndcg@30': 0.6616485118865967, 'ndcg@50': 0.6421034932136536, 'ndcg@all': 0.7847689390182495}
valid {'precision@10': 0.49034300446510315, 'precision@30': 0.2749868333339691, 'precision@50': 0.2043377161026001, 'ndcg@10': 0.5885800719261169, 'ndcg@30': 0.5112393498420715, 'ndcg@50': 0.500589907169342, 'ndcg@all': 0.6733241081237793}
Epoch 2451 

Epoch 2574 1.1973131466330142 3.560375452041626
Epoch 2575 1.199193611131077 3.560296886845639
Epoch 2576 1.1982559303791203 3.555749955930208
Epoch 2577 1.1974811822350262 3.5591916159579626
Epoch 2578 1.1974780381771557 3.5594933158472966
Epoch 2579 1.1976591840124966 3.558589822367618
Epoch 2580 1.1982459418954905 3.562194510510093
Epoch 2581 1.1990107231669955 3.560219250227276
Epoch 2582 1.1975678756223087 3.5664812389173006
Epoch 2583 1.1968972045078612 3.5614986921611584
Epoch 2584 1.1978637583074514 3.5651742659117045
Epoch 2585 1.1981131859690126 3.564826124592831
Epoch 2586 1.1959639592477453 3.5652949810028076
Epoch 2587 1.1973649911713182 3.5598481328863847
Epoch 2588 1.1974034037506371 3.5644551076387105
Epoch 2589 1.1962471935484145 3.5619163638667057
Epoch 2590 1.1972545472502012 3.562049489272268
Epoch 2591 1.1980887710699561 3.5659550491132235
Epoch 2592 1.1969915123013726 3.5651158784565173
Epoch 2593 1.1985524268875345 3.5610429864180717
Epoch 2594 1.1979777690959952

Epoch 2703 1.1961715963848851 3.5803655825163188
Epoch 2704 1.1963647642330817 3.5737156240563643
Epoch 2705 1.1974344727588675 3.577881750307585
Epoch 2706 1.1968464467957702 3.5726383108841744
Epoch 2707 1.197266223486404 3.572600477620175
Epoch 2708 1.1966020457925852 3.5775443378247713
Epoch 2709 1.199269898105086 3.576879538987812
Epoch 2710 1.1963575736821046 3.5754912276017037
Epoch 2711 1.1975872258693852 3.5712815083955465
Epoch 2712 1.1974745675137168 3.578944055657638
Epoch 2713 1.19817930581974 3.575814686323467
Epoch 2714 1.195850079868272 3.576230789485731
Epoch 2715 1.19476648189171 3.5806714610049597
Epoch 2716 1.1986791714590195 3.5724850704795434
Epoch 2717 1.1973171582696034 3.5786533983130204
Epoch 2718 1.1973596438329819 3.578512819189774
Epoch 2719 1.197594085632012 3.5704162999203333
Epoch 2720 1.1964327633729455 3.5764311363822534
Epoch 2721 1.196359490093432 3.578477508143375
Epoch 2722 1.197463521831914 3.5779654226805033
Epoch 2723 1.1962486854073597 3.572828

Epoch 2846 1.1968056803558305 3.5849607116297673
Epoch 2847 1.1974542569695859 3.5866140942824516
Epoch 2848 1.1964901316235637 3.5884205291145728
Epoch 2849 1.1965217799471135 3.586797739330091
Epoch 2850 1.1956450834608914 3.589496637645521
Time1:  0.47954620234668255
Time2:  0.049966951832175255
Time3:  0.1643209494650364
Time1:  0.05349960923194885
Time2:  0.008093316107988358
Time3:  0.018878979608416557

train {'precision@10': 0.6235911250114441, 'precision@30': 0.33739128708839417, 'precision@50': 0.24232567846775055, 'ndcg@10': 0.7856317758560181, 'ndcg@30': 0.6649056673049927, 'ndcg@50': 0.6454285979270935, 'ndcg@all': 0.7869139909744263}
valid {'precision@10': 0.49039578437805176, 'precision@30': 0.27447670698165894, 'precision@50': 0.20414775609970093, 'ndcg@10': 0.5869216322898865, 'ndcg@30': 0.5102509260177612, 'ndcg@50': 0.499580979347229, 'ndcg@all': 0.6726781725883484}
Epoch 2851 1.195864905390823 3.5894656808752763
Epoch 2852 1.1972994386104114 3.5856180567490425
Epoch

Epoch 2976 1.1952200552873444 3.607753377211721
Epoch 2977 1.1951073882175467 3.603384595168264
Epoch 2978 1.1953078407293174 3.5988423447859916
Epoch 2979 1.1952413090488367 3.5963495907030607
Epoch 2980 1.196654498577118 3.599669205515008
Epoch 2981 1.1959519313092817 3.6002452498988102
Epoch 2982 1.1953133804756297 3.60679292678833
Epoch 2983 1.1961203151279025 3.6052537089899968
Epoch 2984 1.1952401713321084 3.6039272484026457
Epoch 2985 1.1971294255981668 3.6057890590868498
Epoch 2986 1.194202691142322 3.6021421081141423
Epoch 2987 1.1951120746763129 3.6036954929954126
Epoch 2988 1.1944000410754778 3.599889918377525
Epoch 2989 1.1950344241153428 3.5983962761728385
Epoch 2990 1.1968730569582933 3.604512252305683
Epoch 2991 1.196158776506346 3.6031641960144043
Epoch 2992 1.1951560134079025 3.6022889363138297
Epoch 2993 1.195469732173005 3.600785644430863
Epoch 2994 1.1953929544192308 3.6079015355361137
Epoch 2995 1.1947191814232987 3.5982863652078727
Epoch 2996 1.197334433159633 3.6

Epoch 3106 1.1958342647691915 3.6180613542857922
Epoch 3107 1.1952500228296246 3.611261957570126
Epoch 3108 1.1959918054223757 3.6162353189367997
Epoch 3109 1.19656542204974 3.617560599979601
Epoch 3110 1.1941126058673301 3.6089964163930794
Epoch 3111 1.1942543851004706 3.617464592582301
Epoch 3112 1.1947515094489383 3.6091266556790003
Epoch 3113 1.1940776095752828 3.613578884225143
Epoch 3114 1.1941392700574551 3.611830849396555
Epoch 3115 1.1940127646016796 3.6181067416542456
Epoch 3116 1.1949682444856877 3.611082064478021
Epoch 3117 1.1944105935375593 3.6160808487942346
Epoch 3118 1.1961452521775897 3.6175179356022884
Epoch 3119 1.1949840006772539 3.6118758854113127
Epoch 3120 1.1939755409781696 3.613909131602237
Epoch 3121 1.194513706087369 3.615985079815513
Epoch 3122 1.1948497344178763 3.620204160087987
Epoch 3123 1.1942755998226635 3.6199823178743062
Epoch 3124 1.1946333869856003 3.622109337856895
Epoch 3125 1.1950457908256709 3.616036716260408
Epoch 3126 1.1949688799200002 3.61

Epoch 3249 1.1933665523055004 3.6317365420492074
Epoch 3250 1.1941574836334987 3.632058469872726
Time1:  0.4794121701270342
Time2:  0.04982390254735947
Time3:  0.16437037102878094
Time1:  0.053525203838944435
Time2:  0.008100949227809906
Time3:  0.018966970965266228

train {'precision@10': 0.6269981861114502, 'precision@30': 0.3410876393318176, 'precision@50': 0.24527646601200104, 'ndcg@10': 0.7875974178314209, 'ndcg@30': 0.6678304076194763, 'ndcg@50': 0.6485500335693359, 'ndcg@all': 0.7890027761459351}
valid {'precision@10': 0.4910290241241455, 'precision@30': 0.2760421931743622, 'precision@50': 0.20534037053585052, 'ndcg@10': 0.5837547779083252, 'ndcg@30': 0.5082827210426331, 'ndcg@50': 0.49815797805786133, 'ndcg@all': 0.6712161898612976}
Epoch 3251 1.1937515613628409 3.630015385778327
Epoch 3252 1.194058899293866 3.6296314440275492
Epoch 3253 1.1940918913361622 3.629655725077579
Epoch 3254 1.1950421225257783 3.6274963554583097
Epoch 3255 1.1942230044749744 3.63165471428319
Epoch 325

Epoch 3379 1.1929242969953526 3.639067662389655
Epoch 3380 1.193811329484683 3.639524271613673
Epoch 3381 1.193397365466893 3.645044264040495
Epoch 3382 1.192427598942093 3.640509981858103
Epoch 3383 1.194427111692596 3.6413646120774117
Epoch 3384 1.1948213873550906 3.6378940030148157
Epoch 3385 1.192665884020733 3.6426073751951518
Epoch 3386 1.1936067987603751 3.640351596631502
Epoch 3387 1.1930978699037207 3.6367239952087402
Epoch 3388 1.1939291964497483 3.639421877108122
Epoch 3389 1.1931278851994298 3.6385661300859953
Epoch 3390 1.193293311093983 3.646208311382093
Epoch 3391 1.194236922333812 3.6438870931926526
Epoch 3392 1.1949130139852826 3.640828847885132
Epoch 3393 1.1941757829565751 3.642691875758924
Epoch 3394 1.1940612559430084 3.634943962097168
Epoch 3395 1.1945954212668346 3.645476981213218
Epoch 3396 1.194507825095751 3.6354189044550846
Epoch 3397 1.193867203784965 3.637949353770206
Epoch 3398 1.1940976196562338 3.6430099763368307
Epoch 3399 1.1943005455864801 3.639586674

Epoch 3509 1.192667915807133 3.6463129771383187
Epoch 3510 1.1949698938960918 3.6514046066685726
Epoch 3511 1.1938791358680056 3.6525714271946956
Epoch 3512 1.1938566739796197 3.6512676038240133
Epoch 3513 1.1928714827487343 3.650448886971725
Epoch 3514 1.1938720509322762 3.6458543852755896
Epoch 3515 1.1929865925632723 3.6514792693288705
Epoch 3516 1.193034647501003 3.6508861717424894
Epoch 3517 1.1942891046317696 3.6454146786739954
Epoch 3518 1.193178521959405 3.649066423114977
Epoch 3519 1.1927185644183242 3.6470228119900354
Epoch 3520 1.193150004099684 3.6533688495033667
Epoch 3521 1.1939519211562752 3.6537704844223824
Epoch 3522 1.192318211870584 3.6525193766543738
Epoch 3523 1.1949219996469063 3.650127511275442
Epoch 3524 1.1948283074194925 3.6556090053759123
Epoch 3525 1.192841561217057 3.656518032676295
Epoch 3526 1.1939287391322397 3.657063885738975
Epoch 3527 1.1945635495130082 3.6500192190471448
Epoch 3528 1.1926455748708624 3.6576797711221793
Epoch 3529 1.193793008899131 3.

Time3:  0.1643634159117937
Time1:  0.05363424867391586
Time2:  0.008083462715148926
Time3:  0.018910104408860207

train {'precision@10': 0.6296839118003845, 'precision@30': 0.3442581295967102, 'precision@50': 0.24757987260818481, 'ndcg@10': 0.789080798625946, 'ndcg@30': 0.6701392531394958, 'ndcg@50': 0.6510009765625, 'ndcg@all': 0.7906332015991211}
valid {'precision@10': 0.48833775520324707, 'precision@30': 0.27678099274635315, 'precision@50': 0.20563587546348572, 'ndcg@10': 0.5814154148101807, 'ndcg@30': 0.5071789622306824, 'ndcg@50': 0.4969978630542755, 'ndcg@all': 0.6704376339912415}
Epoch 3651 1.1923572515186511 3.662954330444336
Epoch 3652 1.192594684006875 3.6688742888601205
Epoch 3653 1.193082839076282 3.667316486960963
Epoch 3654 1.1932195871196993 3.6626805757221423
Epoch 3655 1.1929067345390543 3.6691340145311857
Epoch 3656 1.193149669128552 3.659453605350695
Epoch 3657 1.1930117662887127 3.66701424749274
Epoch 3658 1.1923197936593442 3.6624848591653922
Epoch 3659 1.192896730

Epoch 3782 1.1926586920754951 3.675624370574951
Epoch 3783 1.1915153181343747 3.6780947635048316
Epoch 3784 1.191576588223552 3.679460839221352
Epoch 3785 1.1930269222510488 3.6819898203799597
Epoch 3786 1.192685009443272 3.6784628315975794
Epoch 3787 1.192177056569105 3.6867520934657048
Epoch 3788 1.1922040051884122 3.6767037165792367
Epoch 3789 1.1913365251139592 3.680720191252859
Epoch 3790 1.1927555855254681 3.678916491960224
Epoch 3791 1.1933452274367127 3.6792400385204114
Epoch 3792 1.1926015632891516 3.6772612270555998
Epoch 3793 1.1938759611363996 3.683207348773354
Epoch 3794 1.1936338672861022 3.6801062257666337
Epoch 3795 1.1928157851709957 3.682639711781552
Epoch 3796 1.1921505060112267 3.6837403272327625
Epoch 3797 1.1926204478531552 3.6791226487410698
Epoch 3798 1.193779528838152 3.680831143730565
Epoch 3799 1.1929521954547593 3.6784214095065466
Epoch 3800 1.1922543359081648 3.6758317445453845
Time1:  0.47955929674208164
Time2:  0.049739863723516464
Time3:  0.1642654035240

Epoch 3912 1.192625225635997 3.6846469828957007
Epoch 3913 1.1926448083760446 3.68156950097335
Epoch 3914 1.191492373831788 3.687700346896523
Epoch 3915 1.1914099768588418 3.6864226115377328
Epoch 3916 1.1916998645715546 3.6855945085224353
Epoch 3917 1.191954577526851 3.6851682035546554
Epoch 3918 1.193025167225397 3.689168189701281
Epoch 3919 1.1918830233707762 3.6879144718772485
Epoch 3920 1.1924705930620607 3.6882551218333997
Epoch 3921 1.190376743238572 3.6862873780099967
Epoch 3922 1.1931496882996364 3.6895087769157007
Epoch 3923 1.1930017262174373 3.6886807240937887
Epoch 3924 1.1919596094137046 3.6833688836348686
Epoch 3925 1.1912752463106524 3.688190121399729
Epoch 3926 1.1927650285046003 3.691995143890381
Epoch 3927 1.1922272164919223 3.691216970744886
Epoch 3928 1.191815760051995 3.6913309850190816
Epoch 3929 1.1917420427004497 3.6854057688462105
Epoch 3930 1.1908962555796083 3.687919955504568
Epoch 3931 1.1939883486569276 3.696187006799798
Epoch 3932 1.1921397907692088 3.686

Epoch 4051 1.1918307239549202 3.697535891281931
Epoch 4052 1.1922951643927056 3.692358506353278
Epoch 4053 1.191576717889797 3.6974845810940393
Epoch 4054 1.191473843061436 3.697029000834415
Epoch 4055 1.1915490958425734 3.6956940952100252
Epoch 4056 1.191616234026457 3.6948712123067757
Epoch 4057 1.1909439054846067 3.7042399707593416
Epoch 4058 1.1921514872221919 3.695414718828703
Epoch 4059 1.190964865405657 3.701478029552259
Epoch 4060 1.191676868332757 3.697339835919832
Epoch 4061 1.1917549674971062 3.69633140062031
Epoch 4062 1.1911470792446917 3.6931953806626168
Epoch 4063 1.190962715107098 3.6979136216013053
Epoch 4064 1.1915043961932088 3.6939223189102974
Epoch 4065 1.1908300261051334 3.690528944918984
Epoch 4066 1.1906331300038344 3.6958864237132825
Epoch 4067 1.191953870636678 3.6927596016934046
Epoch 4068 1.192579820839285 3.6993020082774914
Epoch 4069 1.1905107639337842 3.6960222344649467
Epoch 4070 1.1932919147418954 3.7013486561022306
Epoch 4071 1.1915936599000854 3.69389

Epoch 4194 1.1916663685040167 3.7071788561971566
Epoch 4195 1.1916528033931353 3.705935264888563
Epoch 4196 1.1908640335177818 3.7061458888806795
Epoch 4197 1.1901508123553983 3.710914260462711
Epoch 4198 1.1915482599832858 3.7097510664086593
Epoch 4199 1.191786771977854 3.712395479804591
Epoch 4200 1.1914350840083339 3.71233774486341
Time1:  0.47860129550099373
Time2:  0.04967091977596283
Time3:  0.16438000090420246
Time1:  0.053690847009420395
Time2:  0.008043956011533737
Time3:  0.018899399787187576

train {'precision@10': 0.6327332258224487, 'precision@30': 0.3471432626247406, 'precision@50': 0.24959243834018707, 'ndcg@10': 0.7912788391113281, 'ndcg@30': 0.6728547811508179, 'ndcg@50': 0.6536798477172852, 'ndcg@all': 0.7923296093940735}
valid {'precision@10': 0.48860159516334534, 'precision@30': 0.2749340236186981, 'precision@50': 0.20430606603622437, 'ndcg@10': 0.5802209973335266, 'ndcg@30': 0.504845142364502, 'ndcg@50': 0.49497970938682556, 'ndcg@all': 0.6690194010734558}
Epoch 42

Epoch 4324 1.1902811924616497 3.717664969594855
Epoch 4325 1.189992007804893 3.7245326167658757
Epoch 4326 1.192484137956162 3.718633162347894
Epoch 4327 1.1916824924318414 3.7131298843183016
Epoch 4328 1.1921954639473853 3.7124556114799097
Epoch 4329 1.1918721997249893 3.7164385444239567
Epoch 4330 1.1919129968386644 3.7136423713282536
Epoch 4331 1.1913063902603953 3.718430468910619
Epoch 4332 1.1921423921111034 3.7170152538701107
Epoch 4333 1.1928337784538492 3.7141535533101937
Epoch 4334 1.190712336908307 3.715080687874242
Epoch 4335 1.1903164616802282 3.7170440146797583
Epoch 4336 1.1918157269383034 3.7129639324389005
Epoch 4337 1.1911606419156169 3.721477596383346
Epoch 4338 1.1904674242114464 3.71507423802426
Epoch 4339 1.190175391777217 3.718521682839645
Epoch 4340 1.1914261586484853 3.714700749045924
Epoch 4341 1.1924030599538347 3.7183845670599687
Epoch 4342 1.1888802818387572 3.716432860023097
Epoch 4343 1.1901181131078487 3.718793605503283
Epoch 4344 1.191279693305144 3.7137

Epoch 4454 1.1897095072339152 3.7263473209581877
Epoch 4455 1.1907009660151966 3.7259935328834937
Epoch 4456 1.1920762811487877 3.7280005279340243
Epoch 4457 1.1896402089916476 3.723363060700266
Epoch 4458 1.19157490256237 3.72905579366182
Epoch 4459 1.191259024784579 3.7277876954329643
Epoch 4460 1.1903925586862174 3.724735699201885
Epoch 4461 1.1896651917033725 3.731282999641017
Epoch 4462 1.1899849929307635 3.7213162999404106
Epoch 4463 1.1897990804666665 3.7340313384407446
Epoch 4464 1.1903305893753007 3.7259052050741097
Epoch 4465 1.1901094749657035 3.7244349529868677
Epoch 4466 1.1903130614269546 3.7312541258962533
Epoch 4467 1.1901471144274662 3.7271989521227384
Epoch 4468 1.1909509761291637 3.726009833185296
Epoch 4469 1.1910028970032407 3.7241711741999577
Epoch 4470 1.1913873372022172 3.7236947511371814
Epoch 4471 1.1903377231101544 3.724169505269904
Epoch 4472 1.19044104758759 3.727147491354691
Epoch 4473 1.1913169479509542 3.7220175140782406
Epoch 4474 1.191837723840747 3.72

Epoch 4597 1.1913482208698116 3.730882067429392
Epoch 4598 1.1897837599815682 3.727907268624557
Epoch 4599 1.1908460203667133 3.7391460820248255
Epoch 4600 1.189570987433718 3.7268323270898116
Time1:  0.48012007586658
Time2:  0.04951302148401737
Time3:  0.16425329819321632
Time1:  0.05348131246864796
Time2:  0.008081275969743729
Time3:  0.019031476229429245

train {'precision@10': 0.6353134512901306, 'precision@30': 0.3506421148777008, 'precision@50': 0.25204360485076904, 'ndcg@10': 0.7927646636962891, 'ndcg@30': 0.6751971244812012, 'ndcg@50': 0.6560013890266418, 'ndcg@all': 0.7938013076782227}
valid {'precision@10': 0.4852770268917084, 'precision@30': 0.2764819860458374, 'precision@50': 0.20608969032764435, 'ndcg@10': 0.5765466690063477, 'ndcg@30': 0.5036401152610779, 'ndcg@50': 0.49409908056259155, 'ndcg@all': 0.6680598258972168}
Epoch 4601 1.1896837827754996 3.732344075253135
Epoch 4602 1.1893633942157902 3.729081028386166
Epoch 4603 1.1901702636863754 3.7388489497335335
Epoch 4604 

Epoch 4727 1.1892588867081537 3.7468808575680383
Epoch 4728 1.1904896062019972 3.744225288692274
Epoch 4729 1.189841447866451 3.7370326268045524
Epoch 4730 1.1903995739089117 3.7434855134863603
Epoch 4731 1.1915097888450177 3.739301505841707
Epoch 4732 1.1893500866945723 3.7451886126869605
Epoch 4733 1.1910221691717182 3.7417986894908704
Epoch 4734 1.1905440953042772 3.7455725293410453
Epoch 4735 1.190063734849294 3.741906580172087
Epoch 4736 1.189807280810953 3.7416234267385384
Epoch 4737 1.1899431926465174 3.737865836996781
Epoch 4738 1.189521410660437 3.7467691270928634
Epoch 4739 1.1896859027488886 3.743713742808292
Epoch 4740 1.1915455823056182 3.7458613044337223
Epoch 4741 1.1917886078706261 3.743445772873728
Epoch 4742 1.1898383777043973 3.744497675644724
Epoch 4743 1.189900403134307 3.7456717114699516
Epoch 4744 1.1903628473393402 3.735151454022056
Epoch 4745 1.1900857819451227 3.7497260319559196
Epoch 4746 1.192620342935038 3.7460539717423287
Epoch 4747 1.1895781755447388 3.74

Epoch 4857 1.1910690675702011 3.744597309514096
Epoch 4858 1.1888161984800596 3.7496253942188464
Epoch 4859 1.1893282797601488 3.7490815237948767
Epoch 4860 1.1891282204298945 3.744137977298937
Epoch 4861 1.19089509834323 3.7531441387377287
Epoch 4862 1.1901542220199317 3.7535357224313834
Epoch 4863 1.1908886307164241 3.7562148069080554
Epoch 4864 1.1901121132555064 3.7559164950722144
Epoch 4865 1.188481344465624 3.7525775683553597
Epoch 4866 1.1907926627069887 3.7535735305986906
Epoch 4867 1.1898452217815911 3.7529687254052413
Epoch 4868 1.1888747734633105 3.7483897083684017
Epoch 4869 1.189007830550099 3.7516904002741764
Epoch 4870 1.1901632997027614 3.7468401632810893
Epoch 4871 1.191444562889679 3.745100824456466
Epoch 4872 1.188799940703208 3.75021263172752
Epoch 4873 1.189607061837849 3.74874798875106
Epoch 4874 1.1897988511107818 3.7530726759057296
Epoch 4875 1.1896258017473054 3.7551655518381217
Epoch 4876 1.190208353494343 3.751370404895983
Epoch 4877 1.1892231591263709 3.7529

In [28]:
def deterministic_neural_sort(s, tau, mask):
    """
    Deterministic neural sort.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param s: values to sort, shape [batch_size, slate_length]
    :param tau: temperature for the final softmax function
    :param mask: mask indicating padded elements
    :return: approximate permutation matrices of shape [batch_size, slate_length, slate_length]
    """
    dev = get_torch_device()

    n = s.size()[1]
    one = torch.ones((n, 1), dtype=torch.float32, device=dev)
    s = s.masked_fill(mask[:, :, None], -1e8)
    A_s = torch.abs(s - s.permute(0, 2, 1))
    A_s = A_s.masked_fill(mask[:, :, None] | mask[:, None, :], 0.0)

    B = torch.matmul(A_s, torch.matmul(one, torch.transpose(one, 0, 1)))

    temp = [n - m + 1 - 2 * (torch.arange(n - m, device=dev) + 1) for m in mask.squeeze(-1).sum(dim=1)]
    temp = [t.type(torch.float32) for t in temp]
    temp = [torch.cat((t, torch.zeros(n - len(t), device=dev))) for t in temp]
    scaling = torch.stack(temp).type(torch.float32).to(dev)  # type: ignore

    s = s.masked_fill(mask[:, :, None], 0.0)
    C = torch.matmul(s, scaling.unsqueeze(-2))

    P_max = (C - B).permute(0, 2, 1)
    P_max = P_max.masked_fill(mask[:, :, None] | mask[:, None, :], -np.inf)
    P_max = P_max.masked_fill(mask[:, :, None] & mask[:, None, :], 1.0)
    sm = torch.nn.Softmax(-1)
    P_hat = sm(P_max / tau)
    return P_hat

def sinkhorn_scaling(mat, mask=None, tol=1e-6, max_iter=50):
    """
    Sinkhorn scaling procedure.
    :param mat: a tensor of square matrices of shape N x M x M, where N is batch size
    :param mask: a tensor of masks of shape N x M
    :param tol: Sinkhorn scaling tolerance
    :param max_iter: maximum number of iterations of the Sinkhorn scaling
    :return: a tensor of (approximately) doubly stochastic matrices
    """
    if mask is not None:
        mat = mat.masked_fill(mask[:, None, :] | mask[:, :, None], 0.0)
        mat = mat.masked_fill(mask[:, None, :] & mask[:, :, None], 1.0)

    for _ in range(max_iter):
        mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=DEFAULT_EPS)
        mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=DEFAULT_EPS)

        if torch.max(torch.abs(mat.sum(dim=2) - 1.)) < tol and torch.max(torch.abs(mat.sum(dim=1) - 1.)) < tol:
            break

    if mask is not None:
        mat = mat.masked_fill(mask[:, None, :] | mask[:, :, None], 0.0)

    return mat

def stochastic_neural_sort(s, n_samples, tau, mask, beta=1.0, log_scores=True, eps=1e-10):
    """
    Stochastic neural sort. Please note that memory complexity grows by factor n_samples.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param s: values to sort, shape [batch_size, slate_length]
    :param n_samples: number of samples (approximations) for each permutation matrix
    :param tau: temperature for the final softmax function
    :param mask: mask indicating padded elements
    :param beta: scale parameter for the Gumbel distribution
    :param log_scores: whether to apply the logarithm function to scores prior to Gumbel perturbation
    :param eps: epsilon for the logarithm function
    :return: approximate permutation matrices of shape [n_samples, batch_size, slate_length, slate_length]
    """
    dev = get_torch_device()

    batch_size = s.size()[0]
    n = s.size()[1]
    s_positive = s + torch.abs(s.min())
    samples = beta * sample_gumbel([n_samples, batch_size, n, 1], device=dev)
    if log_scores:
        s_positive = torch.log(s_positive + eps)

    s_perturb = (s_positive + samples).view(n_samples * batch_size, n, 1)
    mask_repeated = mask.repeat_interleave(n_samples, dim=0)

    P_hat = deterministic_neural_sort(s_perturb, tau, mask_repeated)
    P_hat = P_hat.view(n_samples, batch_size, n, n)
    return P_hat

def dcg(y_pred, y_true, ats=None, gain_function=lambda x: torch.pow(2, x) - 1, padding_indicator=PADDED_Y_VALUE):
    """
    Discounted Cumulative Gain at k.
    Compute DCG at ranks given by ats or at the maximum rank if ats is None.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param ats: optional list of ranks for DCG evaluation, if None, maximum rank is used
    :param gain_function: callable, gain function for the ground truth labels, e.g. torch.pow(2, x) - 1
    :param padding_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: DCG values for each slate and evaluation position, shape [batch_size, len(ats)]
    """
    y_true = y_true.clone()
    y_pred = y_pred.clone()

    actual_length = y_true.shape[1]

    if ats is None:
        ats = [actual_length]
    ats = [min(at, actual_length) for at in ats]

    true_sorted_by_preds = __apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator)

    discounts = (torch.tensor(1) / torch.log2(torch.arange(true_sorted_by_preds.shape[1], dtype=torch.float) + 2.0)).to(
        device=true_sorted_by_preds.device)

    gains = gain_function(true_sorted_by_preds)

    discounted_gains = (gains * discounts)[:, :np.max(ats)]

    cum_dcg = torch.cumsum(discounted_gains, dim=1)

    ats_tensor = torch.tensor(ats, dtype=torch.long) - torch.tensor(1)

    dcg = cum_dcg[:, ats_tensor]

    return dcg

def get_torch_device():
    """
    Getter for an available pyTorch device.
    :return: CUDA-capable GPU if available, CPU otherwise
    """
    return "cpu"
    return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def neuralNDCG(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1., powered_relevancies=True, k=None,
               stochastic=False, n_samples=32, beta=0.1, log_scores=True):
    """
    NeuralNDCG loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
    :return: loss value, a torch.Tensor
    """
    dev = get_torch_device()

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)
    # Choose the deterministic/stochastic variant
    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    # Perform sinkhorn scaling to obtain doubly stochastic permutation matrices
    P_hat = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * P_hat.shape[1], P_hat.shape[2], P_hat.shape[3]),
                             mask.repeat_interleave(P_hat.shape[0], dim=0), tol=1e-6, max_iter=50)
    P_hat = P_hat.view(int(P_hat.shape[0] / y_pred.shape[0]), y_pred.shape[0], P_hat.shape[1], P_hat.shape[2])

    # Mask P_hat and apply to true labels, ie approximately sort them
    P_hat = P_hat.masked_fill(mask[None, :, :, None] | mask[None, :, None, :], 0.)
    y_true_masked = y_true.masked_fill(mask, 0.).unsqueeze(-1).unsqueeze(0)
    if powered_relevancies:
        y_true_masked = torch.pow(2., y_true_masked) - 1.

    ground_truth = torch.matmul(P_hat, y_true_masked).squeeze(-1)
    discounts = (torch.tensor(1.) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
    discounted_gains = ground_truth * discounts

    if powered_relevancies:
        idcg = dcg(y_true, y_true, ats=[k]).permute(1, 0)
    else:
        idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).permute(1, 0)

    discounted_gains = discounted_gains[:, :, :k]
    ndcg = discounted_gains.sum(dim=-1) / (idcg + DEFAULT_EPS)
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask.repeat(ndcg.shape[0], 1), 0.)

    assert (ndcg < 0.).sum() >= 0, "every ndcg should be non-negative"
    if idcg_mask.all():
        return torch.tensor(0.)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 cause we want to maximize NDCG


def neuralNDCG_transposed(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1.,
                          powered_relevancies=True, k=None, stochastic=False, n_samples=32, beta=0.1, log_scores=True,
                          max_iter=50, tol=1e-6):
    """
    NeuralNDCG Transposed loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
    :param max_iter: maximum iteration count for Sinkhorn scaling
    :param tol: tolerance for Sinkhorn scaling
    :return: loss value, a torch.Tensor
    """
    dev = get_torch_device()

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)

    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    # Perform sinkhorn scaling to obtain doubly stochastic permutation matrices
    P_hat_masked = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * y_pred.shape[0], y_pred.shape[1], y_pred.shape[1]),
                                    mask.repeat_interleave(P_hat.shape[0], dim=0), tol=tol, max_iter=max_iter)
    P_hat_masked = P_hat_masked.view(P_hat.shape[0], y_pred.shape[0], y_pred.shape[1], y_pred.shape[1])
    discounts = (torch.tensor(1) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)

    # This takes care of the @k metric truncation - if something is @>k, it is useless and gets 0.0 discount
    discounts[k:] = 0.
    discounts = discounts[None, None, :, None]

    # Here the discounts become expected discounts
    discounts = torch.matmul(P_hat_masked.permute(0, 1, 3, 2), discounts).squeeze(-1)
    if powered_relevancies:
        gains = torch.pow(2., y_true) - 1
        discounted_gains = gains.unsqueeze(0) * discounts
        idcg = dcg(y_true, y_true, ats=[k]).squeeze()
    else:
        gains = y_true
        discounted_gains = gains.unsqueeze(0) * discounts
        idcg = dcg(y_true, y_true, ats=[k]).squeeze()

    ndcg = discounted_gains.sum(dim=2) / (idcg + DEFAULT_EPS)
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask, 0.)

    assert (ndcg < 0.).sum() >= 0, "every ndcg should be non-negative"
    if idcg_mask.all():
        return torch.tensor(0.)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 cause we want to maximize NDCG

NameError: name 'PADDED_Y_VALUE' is not defined

In [None]:
def lambdaLoss(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE, weighing_scheme=None, k=None, sigma=1., mu=10.,
               reduction="sum", reduction_log="binary"):
    """
    LambdaLoss framework for LTR losses implementations, introduced in "The LambdaLoss Framework for Ranking Metric Optimization".
    Contains implementations of different weighing schemes corresponding to e.g. LambdaRank or RankNet.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param weighing_scheme: a string corresponding to a name of one of the weighing schemes
    :param k: rank at which the loss is truncated
    :param sigma: score difference weight used in the sigmoid function
    :param mu: optional weight used in NDCGLoss2++ weighing scheme
    :param reduction: losses reduction method, could be either a sum or a mean
    :param reduction_log: logarithm variant used prior to masking and loss reduction, either binary or natural
    :return: loss value, a torch.Tensor
    """
    device = y_pred.device
    y_pred = y_pred.clone()
    y_true = y_true.clone()

    padded_mask = y_true == padded_value_indicator
    y_pred[padded_mask] = float("-inf")
    y_true[padded_mask] = float("-inf")

    # Here we sort the true and predicted relevancy scores.
    y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
    y_true_sorted, _ = y_true.sort(descending=True, dim=-1)

    # After sorting, we can mask out the pairs of indices (i, j) containing index of a padded element.
    true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
    true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
    print(true_sorted_by_preds.shape)
    print(true_diffs.shape)
    padded_pairs_mask = torch.isfinite(true_diffs)

    if weighing_scheme != "ndcgLoss1_scheme":
        padded_pairs_mask = padded_pairs_mask & (true_diffs > 0)

    ndcg_at_k_mask = torch.zeros((y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device)
    ndcg_at_k_mask[:k, :k] = 1

    # Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs)
    true_sorted_by_preds.clamp_(min=0.)
    y_true_sorted.clamp_(min=0.)

    # Here we find the gains, discounts and ideal DCGs per slate.
    pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
    D = torch.log2(1. + pos_idxs.float())[None, :]
    maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1).clamp(min=eps)
    G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]

    # Here we apply appropriate weighing scheme - ndcgLoss1, ndcgLoss2, ndcgLoss2++ or no weights (=1.0)
    if weighing_scheme is None:
        weights = 1.
    else:
        weights = globals()[weighing_scheme](G, D, mu, true_sorted_by_preds)  # type: ignore

    # We are clamping the array entries to maintain correct backprop (log(0) and division by 0)
    scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp(min=-1e8, max=1e8)
    scores_diffs.masked_fill(torch.isnan(scores_diffs), 0.)
    weighted_probas = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp(min=eps)
    if reduction_log == "natural":
        losses = torch.log(weighted_probas)
    elif reduction_log == "binary":
        losses = torch.log2(weighted_probas)
    else:
        raise ValueError("Reduction logarithm base can be either natural or binary")

    if reduction == "sum":
        loss = -torch.sum(losses[padded_pairs_mask & ndcg_at_k_mask])
    elif reduction == "mean":
        loss = -torch.mean(losses[padded_pairs_mask & ndcg_at_k_mask])
    else:
        raise ValueError("Reduction method can be either sum or mean")

    return loss


def ndcgLoss1_scheme(G, D, *args):
    return (G / D)[:, :, None]


def ndcgLoss2_scheme(G, D, *args):
    pos_idxs = torch.arange(1, G.shape[1] + 1, device=G.device)
    delta_idxs = torch.abs(pos_idxs[:, None] - pos_idxs[None, :])
    deltas = torch.abs(torch.pow(torch.abs(D[0, delta_idxs - 1]), -1.) - torch.pow(torch.abs(D[0, delta_idxs]), -1.))
    deltas.diagonal().zero_()

    return deltas[None, :, :] * torch.abs(G[:, :, None] - G[:, None, :])


def lambdaRank_scheme(G, D, *args):
    return torch.abs(torch.pow(D[:, :, None], -1.) - torch.pow(D[:, None, :], -1.)) * torch.abs(G[:, :, None] - G[:, None, :])


def ndcgLoss2PP_scheme(G, D, *args):
    return args[0] * ndcgLoss2_scheme(G, D) + lambdaRank_scheme(G, D)


def rankNet_scheme(G, D, *args):
    return 1.


def rankNetWeightedByGTDiff_scheme(G, D, *args):
    return torch.abs(args[1][:, :, None] - args[1][:, None, :])


def rankNetWeightedByGTDiffPowed_scheme(G, D, *args):
    return torch.abs(torch.pow(args[1][:, :, None], 2) - torch.pow(args[1][:, None, :], 2))