### raw data
* word embedding: glove
* doc text: ./data/IMDB.txt

### dataset
1. IMDB
2. CNNNews
3. [PubMed](https://github.com/LIAAD/KeywordExtractor-Datasets/blob/master/datasets/PubMed.zip)

### preprocess
1. filter out words that are too frequent or too rare
2. stemming
3. document vector aggregation

### model
1. TopK
2. Sklearn
3. Our model

### evaluation
1. F1
2. NDCG

In [None]:
import os
from collections import defaultdict
import math
import numpy as np 
import random
import re
import torch
import torch.nn as nn
from itertools import cycle
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
from tqdm.auto import tqdm

# Used to compute the NDCG metric
from sklearn.metrics import ndcg_score

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import nltk
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords
from nltk.collocations import BigramAssocMeasures, BigramCollocationFinder
nltk.download('stopwords')

import matplotlib.pyplot as plt 
import matplotlib
matplotlib.use('Agg')

seed = 33
import pandas as pd

## Preprocess config

In [None]:
# Experiment configuration shared by the cells below.
config = {
    # Corpus to load: "IMDB", "CNN" or "PubMed".
    "dataset": "CNN",
    # Upper bound on the number of documents turned into document vectors.
    "n_document": 100,
    # When true, every word embedding is rescaled to unit norm after loading.
    "normalize_word_embedding": False,
    # Words seen fewer times than this in the corpus are dropped.
    "min_word_freq_threshold": 20,
    # Number of lowest-IDF (most common) words removed from the vocabulary.
    "topk_word_freq_threshold": 100,
    # Word weighting scheme: 'mean', 'IDF', 'uniform', 'gaussian', 'exponential', 'pmi'.
    "document_vector_agg_weight": 'IDF',
    # True -> weighted mean of word embeddings, False -> weighted sum.
    "document_vector_weight_normalize": True,
    # Optional per-document cap on words kept by TF-IDF; None disables it.
    "select_topk_TFIDF": None,
    # Path to the pretrained embedding text file.
    "embedding_file": "../data/glove.6B.100d.txt",
    # Cut-offs used for the precision/recall/F1/NDCG metrics.
    "topk": [10, 30, 50],
}


In [None]:
def in_notebook():
    """Return True when the code runs inside an IPython kernel, else False.

    False is returned when IPython is not installed, when no IPython shell
    is active (``get_ipython()`` returns None), or when the active shell's
    config has no 'IPKernelApp' entry.
    """
    try:
        from IPython import get_ipython
    except ImportError:
        return False

    shell = get_ipython()
    # Outside IPython there is no shell object, so there is no config to inspect.
    if shell is None:
        return False
    return 'IPKernelApp' in shell.config

In [None]:
def load_word2emb(embedding_file):
    """Read a whitespace-separated embedding text file into a dict.

    Each non-empty line is expected to hold a word followed by its vector
    components.

    Args:
        embedding_file: path to the embedding text file.

    Returns:
        dict mapping each word to a 1-D numpy array of floats.
    """
    word2embedding = {}

    # Explicit encoding so the result does not depend on the platform default.
    with open(embedding_file, "r", encoding="utf-8") as f:
        for line in tqdm(f):
            fields = line.strip().split()
            # Skip blank lines instead of failing on fields[0].
            if not fields:
                continue
            word2embedding[fields[0]] = np.array([float(value) for value in fields[1:]])

    print("Number of words:%d" % len(word2embedding))

    return word2embedding

# Load the pretrained embedding table from the path given in the config.
word2embedding = load_word2emb(config["embedding_file"])

In [None]:
def normalize_wordemb(word2embedding):
    """Rescale every embedding to unit L2 norm, updating the dict in place.

    All-zero vectors are left unchanged (they cannot be normalized) instead
    of being turned into NaN.

    Args:
        word2embedding: mapping from word to an embedding vector; all vectors
            must have the same length.

    Returns:
        The same dict object, whose values are now float numpy arrays.
    """
    if not word2embedding:
        return word2embedding

    words = list(word2embedding)
    # dtype=float avoids integer truncation when the inputs are integer arrays.
    matrix = np.array([word2embedding[word] for word in words], dtype=float)

    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # dividing a zero vector by 1 keeps it at zero
    matrix = matrix / norms

    for word, emb in zip(words, matrix):
        word2embedding[word] = emb
    return word2embedding

# Unit-normalize the embeddings in place only when the config asks for it.
if config["normalize_word_embedding"]:
    normalize_wordemb(word2embedding)

In [None]:
class Vocabulary:
    """Vocabulary over a text corpus plus weighted document vectors.

    The class reads raw documents, keeps stemmed words that have a pretrained
    embedding, filters rare and very common words, and aggregates word
    embeddings into one vector per document.
    """

    def __init__(self, word2embedding, config):
        """Store the embedding table and config and create empty state.

        Args:
            word2embedding: mapping from word to embedding vector; must
                contain the key 'the', which is used to read the dimension.
            config: dict of preprocessing options read by the methods below.
        """
        # Index 0 is reserved for <UNK>; real words are added by build_vocabulary().
        self.itos = {0: "<UNK>"}
        self.stoi = {"<UNK>": 0}

        self.word2embedding = word2embedding
        self.config = config

        self.word_freq_in_corpus = defaultdict(int)
        self.IDF = {}
        self.ps = PorterStemmer()
        self.stop_words = set(stopwords.words('english'))

        self.word_dim = len(word2embedding['the'])

    def __len__(self):
        """Return the vocabulary size, including the <UNK> entry."""
        return len(self.itos)

    def tokenizer_eng(self, text):
        """Strip non-alphanumeric characters, drop stop words, stem the rest."""
        text = re.sub(r'[^A-Za-z0-9 ]+', '', text)
        tokens = text.strip().split()

        return [self.ps.stem(w) for w in tokens if w.lower() not in self.stop_words]

    def read_raw(self):
        """Load the corpus selected by config["dataset"], one document per line.

        Returns:
            The list of raw document strings (also stored in self.raw_documents).

        Raises:
            ValueError: if config["dataset"] is not a known dataset name.
        """
        data_files = {
            'IMDB': '../data/IMDB.txt',
            'CNN': '../data/CNN.txt',
            'PubMed': '../data/PubMed.txt',
        }
        dataset = self.config["dataset"]
        if dataset not in data_files:
            raise ValueError(
                "Unknown dataset {!r}; expected one of {}".format(dataset, sorted(data_files))
            )
        data_file_path = data_files[dataset]

        # raw documents
        self.raw_documents = []
        with open(data_file_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Loading documents"):
                self.raw_documents.append(line.strip("\n"))

        return self.raw_documents

    def build_vocabulary(self):
        """Count words, compute IDF, filter the vocabulary and index it.

        Words without an embedding are ignored. Words whose corpus frequency
        is below config["min_word_freq_threshold"] are removed, then the
        config["topk_word_freq_threshold"] words with the lowest IDF are
        removed. The remaining words receive indices starting at 1.
        """
        sentence_list = self.raw_documents

        self.doc_freq = defaultdict(int)  # number of documents each word appears in
        self.document_num = len(sentence_list)
        self.word_vectors = [[0] * self.word_dim]  # row 0: unknown word embedding

        for sentence in tqdm(sentence_list, desc="Preprocessing documents"):
            # distinct words of this document, for the document frequency
            document_words = set()

            for word in self.tokenizer_eng(sentence):
                # skip words without an embedding
                if word not in self.word2embedding:
                    continue

                self.word_freq_in_corpus[word] += 1
                document_words.add(word)

            for word in document_words:
                self.doc_freq[word] += 1

        # calculate IDF
        print('doc num', self.document_num)
        for word, freq in self.doc_freq.items():
            self.IDF[word] = math.log(self.document_num / (freq + 1))

        # delete infrequent words
        min_freq = self.config["min_word_freq_threshold"]
        delete_words = [word for word, v in self.word_freq_in_corpus.items() if v < min_freq]
        for word in delete_words:
            del self.IDF[word]
            del self.word_freq_in_corpus[word]

        # delete too frequent words (lowest IDF first)
        print('eliminate freq words')
        idf_ranked = sorted(self.IDF.items(), key=lambda x: x[1])

        # Slicing keeps this safe when fewer words remain than the threshold.
        for word, _ in idf_ranked[:self.config["topk_word_freq_threshold"]]:
            # Print the word that is actually being removed.
            print(word)
            del self.IDF[word]
            del self.word_freq_in_corpus[word]

        # construct word_vectors
        for idx, word in enumerate(self.word_freq_in_corpus, start=1):
            self.word_vectors.append(self.word2embedding[word])
            self.stoi[word] = idx
            self.itos[idx] = word

    def init_word_weight(self, sentence_list, agg):
        """Set self.word_weight according to the aggregation scheme `agg`.

        Args:
            sentence_list: raw documents; only read by the 'pmi' scheme.
            agg: one of 'mean', 'IDF', 'uniform', 'gaussian', 'exponential', 'pmi'.

        Raises:
            ValueError: if `agg` is not one of the supported schemes.
        """
        if agg == 'mean':
            self.word_weight = {word: 1 for word in self.IDF.keys()}
        elif agg == 'IDF':
            # Shares the IDF dict itself rather than copying it.
            self.word_weight = self.IDF
        elif agg == 'uniform':
            self.word_weight = {word: np.random.uniform(low=0.0, high=1.0) for word in self.IDF.keys()}
        elif agg == 'gaussian':
            mu, sigma = 10, 1  # mean and standard deviation
            self.word_weight = {word: np.random.normal(mu, sigma) for word in self.IDF.keys()}
        elif agg == 'exponential':
            self.word_weight = {word: np.random.exponential(scale=1.0) for word in self.IDF.keys()}
        elif agg == 'pmi':
            bigram_measures = BigramAssocMeasures()
            self.word_weight = defaultdict(int)
            corpus = []

            # NOTE(review): the corpus uses raw whitespace tokens while the
            # weights are later looked up by stemmed tokens - confirm intended.
            for text in tqdm(sentence_list):
                corpus.extend(text.split())

            finder = BigramCollocationFinder.from_words(corpus)
            for pair, score in finder.score_ngrams(bigram_measures.pmi):
                self.word_weight[pair[0]] += score
                self.word_weight[pair[1]] += score
        else:
            raise ValueError("Unknown document_vector_agg_weight: {!r}".format(agg))

    def calculate_document_vector(self):
        """Aggregate word embeddings into one vector per document.

        Only the first config["n_document"] raw documents are processed.
        Documents without any in-vocabulary word are printed and skipped.

        Returns:
            document_vectors: per kept document, the weighted sum of its word
                embeddings (divided by the total weight when
                config["document_vector_weight_normalize"] is true).
            document_answers_idx: per kept document, the vocabulary indices
                of its words (repeats included).
            document_answers_wsum: per kept document, the total word weight
                (1 when normalization is enabled).
        """
        document_vectors = []
        document_answers = []
        document_answers_wsum = []

        sentence_list = self.raw_documents
        agg = self.config["document_vector_agg_weight"]
        n_document = self.config["n_document"]
        select_topk_TFIDF = self.config["select_topk_TFIDF"]
        normalize = self.config["document_vector_weight_normalize"]

        self.init_word_weight(sentence_list, agg)
        for sentence in tqdm(sentence_list[:min(n_document, len(sentence_list))], desc="calculate document vectors"):
            # keep only in-vocabulary words
            select_words = [word for word in self.tokenizer_eng(sentence) if word in self.stoi]

            # optionally keep only the words with the top-k TF-IDF in this document
            if select_topk_TFIDF is not None:
                doc_TFIDF = defaultdict(float)
                for word in select_words:
                    doc_TFIDF[word] += self.IDF[word]

                ranked = sorted(doc_TFIDF.items(), key=lambda x: x[1], reverse=True)
                select_topk_words = {word for word, _ in ranked[:select_topk_TFIDF]}
                select_words = [word for word in select_words if word in select_topk_words]

            if len(select_words) == 0:
                print('error', sentence)
                continue

            # aggregate to doc vectors
            document_vector = np.zeros(len(self.word_vectors[0]))
            total_weight = 0
            for word in select_words:
                weight = self.word_weight[word]
                document_vector += np.array(self.word2embedding[word]) * weight
                total_weight += weight

            if normalize:
                document_vector /= total_weight
                total_weight = 1

            document_vectors.append(document_vector)
            document_answers.append(select_words)
            document_answers_wsum.append(total_weight)

        # map the kept words of every document to vocabulary indices
        document_answers_idx = [
            [self.stoi[token] for token in ans if token in self.stoi]
            for ans in document_answers
        ]

        self.document_vectors = document_vectors
        self.document_answers_idx = document_answers_idx
        self.document_answers_wsum = document_answers_wsum

        return document_vectors, document_answers_idx, document_answers_wsum

    def numericalize(self, text):
        """Tokenize `text` and map each token to its index (<UNK> when unseen)."""
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]

    def check_docemb(self):
        """Sanity check: rebuild the first document vector from its word indices.

        Raises:
            AssertionError: if the rebuilt vector differs from the stored one
                beyond floating-point tolerance.
        """
        word_vectors = np.array(self.word_vectors)
        pred = np.zeros(word_vectors.shape[1])
        cnt = 0

        for word_idx in self.document_answers_idx[0]:
            weight = self.word_weight[self.itos[word_idx]]
            pred += word_vectors[word_idx] * weight
            cnt += weight

        if self.config["document_vector_weight_normalize"]:
            pred /= cnt
        # Element-wise tolerant comparison: exact equality of float sums is
        # fragile and can hide differences that cancel out.
        assert np.allclose(self.document_vectors[0], pred), \
            "document vector does not match its weighted word embeddings"

In [None]:
def build_vocab(config, word2embedding):
    """Create a Vocabulary, load the corpus and compute document vectors.

    Args:
        config: preprocessing configuration dict passed to Vocabulary.
        word2embedding: mapping from word to embedding vector.

    Returns:
        The populated Vocabulary, after its document-embedding sanity check.
    """
    # build vocabulary
    vocab = Vocabulary(word2embedding, config)
    vocab.read_raw()
    vocab.build_vocabulary()
    # get doc emb
    vocab.calculate_document_vector()
    vocab.check_docemb()

    return vocab

# Build the vocabulary and document vectors used by every later cell.
vocab = build_vocab(config, word2embedding)

In [None]:
# Summary of the corpus and vocabulary that were just built.
print("Finish building dataset!")
print(f"Number of documents:{len(vocab.raw_documents)}")
print(f"Number of words:{len(vocab)}")

# Number of in-vocabulary tokens kept for each processed document.
l = list(map(len, vocab.document_answers_idx))
print("Average length of document:", np.mean(l))

In [None]:
# Embedding matrix; row 0 is the all-zero <UNK> vector.
word_vectors = np.array(vocab.word_vectors)
print("word_vectors:", word_vectors.shape)

# One aggregated embedding per kept document.
document_vectors = np.array(vocab.document_vectors)
print("document_vectors", document_vectors.shape)

# Total word weight of each document, as a column vector.
document_answers_wsum = np.array(vocab.document_answers_wsum).reshape(-1, 1)
print("document_answers_wsum", document_answers_wsum.shape)

# create weight_ans
document_answers_idx = vocab.document_answers_idx

# random shuffle: one seeded permutation is applied to all three
# per-document structures so they stay aligned
shuffle_idx = list(range(len(document_vectors)))
random.Random(seed).shuffle(shuffle_idx)

document_vectors = document_vectors[shuffle_idx]
document_answers_wsum = document_answers_wsum[shuffle_idx]
document_answers_idx = [document_answers_idx[idx] for idx in shuffle_idx]

In [None]:
# onehot_ans: word frequency matrix (documents x vocabulary)
# weight_ans: per-document word weight matrix (accumulated vocab.word_weight)

onehot_ans = np.zeros((len(document_answers_idx), word_vectors.shape[0]))
weight_ans = np.zeros((len(document_answers_idx), word_vectors.shape[0]))
print(weight_ans.shape)

for i in tqdm(range(len(document_answers_idx))):
    for word_idx in document_answers_idx[i]:
        # repeated words accumulate both their weight and their count
        weight_ans[i, word_idx] += vocab.word_weight[vocab.itos[word_idx]]
        onehot_ans[i, word_idx] += 1
        
    # match the normalization applied to the document vectors
    if config["document_vector_weight_normalize"]:
        weight_ans[i] /= np.sum(weight_ans[i])

In [None]:
# check: weight_ans @ word_vectors must reproduce every document vector.
# The absolute value matters: without it, negative deviations always pass.
assert np.sum(np.abs(document_vectors - np.dot(weight_ans, word_vectors)) > 1e-10) == 0

## Results

In [None]:
# Result rows (one pandas Series per model) are collected here.
final_results = []

# Column order of the final table: model name, then each metric at every
# configured cut-off, then NDCG over the full ranking.
select_columns = ['model']
for column_template in ('percision@{}', 'recall@{}', 'F1@{}', 'ndcg@{}'):
    for topk in config["topk"]:
        select_columns.append(column_template.format(topk))
select_columns.append('ndcg@all')
select_columns

## setting training size

In [None]:
# A ratio of 1 uses every document for the per-document fitting below.
train_size_ratio = 1
train_size = int(len(document_answers_idx) * train_size_ratio)
train_size

## Top K freq word

In [None]:
# Metrics of the top-k frequent word baseline, keyed by metric name.
topk_results = {}

In [None]:
# Ground-truth word indices of the documents the baseline is scored on.
test_ans = document_answers_idx[:train_size]

In [None]:
# (word, corpus frequency) pairs, most frequent first.
word_freq = sorted(vocab.word_freq_in_corpus.items(), key=lambda item: item[1], reverse=True)
word_freq[:10]

In [None]:
def topk_word_evaluation(k=50):
    """Score the baseline that predicts the k most frequent corpus words.

    Precision and recall are computed per document in `test_ans` against its
    set of distinct words, then averaged; F1 is computed from those means.

    Args:
        k: number of most frequent words used as the prediction.

    Returns:
        The F1 score built from the mean precision and mean recall.
    """
    # A set gives O(1) membership tests inside the loop.
    topk_words = {word for (word, freq) in word_freq[:k]}

    # Named so that the module-level `re` import is not shadowed.
    precisions, recalls = [], []
    for ans in tqdm(test_ans):
        ans_words = {vocab.itos[a] for a in ans}
        n_hit = len(ans_words & topk_words)

        precisions.append(n_hit / k)
        # An empty answer has no recall to measure; avoid dividing by zero.
        recalls.append(n_hit / len(ans_words) if ans_words else 0.0)

    mean_precision = np.mean(precisions)
    mean_recall = np.mean(recalls)
    total = mean_precision + mean_recall
    f1 = 2 * mean_precision * mean_recall / total if total != 0 else 0
    print('top {} word'.format(k))
    print('percision', mean_precision)
    print('recall', mean_recall)
    print('F1', f1)
    return f1


# Evaluate the frequency baseline at every configured cut-off.
for topk in config['topk']:
    topk_results["F1@{}".format(topk)] = topk_word_evaluation(k=topk)


In [None]:
def topk_word_evaluation_NDCG(k=50):
    """Return the mean NDCG@k of the corpus-frequency ranking over `test_ans`.

    Every document is scored with the same ranking: words ordered by corpus
    frequency. The relevance of a word in a document is its accumulated IDF.

    Args:
        k: cut-off passed to `ndcg_score`.

    Returns:
        The mean NDCG over all documents in `test_ans`.
    """
    vocab_size = len(vocab.word_vectors)
    ranked_indices = [vocab.stoi[word] for word, _ in word_freq if word in vocab.stoi]

    # More frequent words receive larger scores.
    scores = np.zeros(vocab_size)
    for rank, idx in enumerate(ranked_indices):
        scores[idx] = vocab_size - rank
    predicted = scores.reshape(1, -1)

    NDCGs = []
    for ans in tqdm(test_ans):
        relevance = np.zeros(vocab_size)
        for word_idx in ans:
            # index 0 is <UNK> and carries no IDF
            if word_idx != 0:
                relevance[word_idx] += vocab.IDF[vocab.itos[word_idx]]

        NDCGs.append(ndcg_score(relevance.reshape(1, -1), predicted, k=k))

    mean_ndcg = np.mean(NDCGs)
    print('top {} NDCG:{}'.format(k, mean_ndcg))

    return mean_ndcg


# for topk in config['topk']:
#     topk_results["ndcg@{}".format(topk)] = topk_word_evaluation_NDCG(k=topk)
    
# topk_results["ndcg@all"] = topk_word_evaluation_NDCG(k=None)


In [None]:
# Label the baseline row and add it to the final results table.
topk_results["model"] = "topk"
final_results.append(pd.Series(topk_results))

## Sklearn

In [None]:
from sklearn.linear_model import LinearRegression, Lasso

In [None]:
# Shapes of the regression inputs and of the target weight matrix.
print(document_vectors.shape)
print(weight_ans.shape)
print(word_vectors.shape)

In [None]:
def evaluate_sklearn(pred, ans):
    """Compute precision/recall/F1 and NDCG for one document.

    Args:
        pred: 1-D array of predicted word scores (e.g. regression coefficients).
        ans: 1-D array of true word weights; entries > 0 mark relevant words.

    Returns:
        dict with 'percision@k', 'recall@k', 'F1@k' and 'ndcg@k' for every k
        in config["topk"], plus 'ndcg@all'.
    """
    relevant_idx = np.flatnonzero(ans > 0)
    ranking = np.argsort(pred)  # ascending, so the best words are at the end

    results = {}
    for topk in config["topk"]:
        hit_count = len(np.intersect1d(ranking[-topk:], relevant_idx))
        percision = hit_count / topk
        recall = hit_count / len(relevant_idx)
        if (percision + recall) > 0:
            f1 = 2 * percision * recall / (percision + recall)
        else:
            f1 = 0

        results['percision@{}'.format(topk)] = percision
        results['recall@{}'.format(topk)] = recall
        results['F1@{}'.format(topk)] = f1

    # ndcg_score expects 2-D (n_samples, n_labels) inputs.
    relevance_row = ans.reshape(1, -1)
    score_row = pred.reshape(1, -1)
    for topk in config["topk"]:
        results['ndcg@{}'.format(topk)] = ndcg_score(relevance_row, score_row, k=topk)

    results['ndcg@all'] = ndcg_score(relevance_row, score_row, k=None)

    return results

### linear regression

In [None]:
# Fit one unregularized linear regression per document: the document
# embedding is regressed on the word embeddings, and the coefficients are
# scored as predicted word weights.
results = []

for doc_id, doc_emb in enumerate(tqdm(document_vectors[:train_size])):
    x = word_vectors.T  # (embedding dim, vocabulary): one feature per word
    y = doc_emb
    
    ans = weight_ans[doc_id]
    model = LinearRegression(fit_intercept=False).fit(x, y)
    r2 = model.score(x, y)  # computed but not used below

    res = evaluate_sklearn(model.coef_, ans)
    results.append(res)

In [None]:
# Average the per-document metrics and record them under the model name.
results = pd.DataFrame(results).mean()
results['model'] = 'sk-linear-regression'
final_results.append(results)
results

### lasso

In [None]:
# Same per-document regression as above, but with an L1 penalty and
# coefficients constrained to be non-negative.
results = []
sk_lasso_epoch = 10000  # passed to Lasso as max_iter

for doc_id, doc_emb in enumerate(tqdm(document_vectors[:train_size])):
    x = word_vectors.T  # (embedding dim, vocabulary): one feature per word
    y = doc_emb
    
    ans = weight_ans[doc_id]
    model = Lasso(positive=True, fit_intercept=False, alpha=0.0001, max_iter=sk_lasso_epoch, tol=0).fit(x, y)
    r2 = model.score(x, y)  # computed but not used below

    res = evaluate_sklearn(model.coef_, ans)
    results.append(res)

In [None]:
# Average the per-document metrics and record them under the model name.
results = pd.DataFrame(results).mean()
results['model'] = 'sk-lasso'
final_results.append(results)
results

## Our Model

In [None]:
# Use the first GPU when one is available, otherwise the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
class Custom_Lasso_Dataset(Dataset):
    """Dataset of document vectors paired with their total word weight.

    Each item also carries the document's position, which the model uses as
    the row index of its per-document weight table.
    """

    def __init__(self,
                 doc_vectors,
                 doc_w_sum,
                 weight_ans
                 ):
        """Convert the inputs to float tensors and keep `weight_ans` as given.

        Args:
            doc_vectors: per-document embedding vectors.
            doc_w_sum: per-document total word weight; same length as `doc_vectors`.
            weight_ans: ground-truth weights, stored untouched for evaluation.
        """
        self.doc_vectors = torch.FloatTensor(doc_vectors)
        self.doc_w_sum = torch.FloatTensor(doc_w_sum)
        self.weight_ans = weight_ans
        assert len(doc_vectors) == len(doc_w_sum)

    def __getitem__(self, idx):
        """Return (document vector, total weight, document index)."""
        vector = self.doc_vectors[idx]
        weight_sum = self.doc_w_sum[idx]
        return vector, weight_sum, idx

    def __len__(self):
        """Return the number of documents."""
        return self.doc_vectors.shape[0]


In [None]:
class LR(nn.Module):
    """Trainable table of per-document word weights.

    `emb` holds one row of `num_words` weights per document, initialized to
    zero. The forward pass multiplies the selected rows by the word
    embedding matrix to reconstruct document embeddings.
    """
    def __init__(self, num_doc, num_words):
        super().__init__()
        initial_weights = torch.zeros(num_doc, num_words).to(device)
        self.emb = nn.Embedding.from_pretrained(initial_weights, freeze=False)

    def forward(self, doc_ids, word_vectors):
        """Return the selected documents' weight rows times `word_vectors`."""
        doc_word_weights = self.emb(doc_ids)
        return torch.matmul(doc_word_weights, word_vectors)

In [None]:
def evaluate_Custom_Lasso(model, train_loader):
    """Score the learned per-document word weights against the ground truth.

    Args:
        model: module whose `emb.weight` rows are per-document word scores.
        train_loader: DataLoader whose dataset exposes `weight_ans`, a 2-D
            array of true word weights aligned with the rows of `emb.weight`.

    Returns:
        dict with mean 'F1@k', 'percision@k', 'recall@k' and 'ndcg@k' for
        every k in config["topk"], plus 'ndcg@all'. The model is left in
        eval mode.
    """
    results = {}
    model.eval()

    # Copy the weights to host memory without moving the module itself, so
    # the parameters (and any optimizer state tied to them) stay in place.
    scores = model.emb.weight.detach().cpu().numpy().copy()
    true_relevance = train_loader.dataset.weight_ans

    # F1
    F1s = []
    precisions = []
    recalls = []
    for i in range(true_relevance.shape[0]):
        one_hot_ans = np.flatnonzero(true_relevance[i] > 0)
        ranking = np.argsort(scores[i])  # ascending: best words are last

        F1_ = []
        percision_ = []
        recall_ = []
        for topk in config["topk"]:
            hit = np.intersect1d(ranking[-topk:], one_hot_ans)
            percision = len(hit) / topk
            # A document without positive targets has recall 0, not a crash.
            recall = len(hit) / len(one_hot_ans) if len(one_hot_ans) > 0 else 0

            F1 = 2 * percision * recall / (percision + recall) if (percision + recall) > 0 else 0
            F1_.append(F1)
            percision_.append(percision)
            recall_.append(recall)

        F1s.append(F1_)
        precisions.append(percision_)
        recalls.append(recall_)

    # average over documents, keeping one value per cut-off
    F1s = np.mean(F1s, axis=0)
    precisions = np.mean(precisions, axis=0)
    recalls = np.mean(recalls, axis=0)

    for i, topk in enumerate(config["topk"]):
        results['F1@{}'.format(topk)] = F1s[i]
        results['percision@{}'.format(topk)] = precisions[i]
        results['recall@{}'.format(topk)] = recalls[i]

    # NDCG
    for topk in config["topk"]:
        results['ndcg@{}'.format(topk)] = ndcg_score(true_relevance, scores, k=topk)
    results['ndcg@all'] = ndcg_score(true_relevance, scores, k=None)

    return results

In [None]:
batch_size = 100
print('document num', train_size)

# Dataset items are (document vector, total weight, document index); the
# ground-truth weights are kept on the dataset for evaluation.
train_dataset = Custom_Lasso_Dataset(document_vectors[:train_size], document_answers_wsum[:train_size], weight_ans[:train_size])
train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

## start training

In [None]:
# setting
lr = 0.1
momentum = 0.999
weight_decay = 0
nesterov = False # True

n_epoch = 50000

w_sum_reg = 1e-2  # coefficient of the weight-sum loss term
w_sum_reg_mul = 1  # multiplier applied to the target weight sum
w_clip_value = 0  # lower bound the weights are clamped to after every step

L1 = 1e-5  # coefficient of the L1 penalty

verbose = True
valid_epoch = 100  # evaluate every this many epochs

model = LR(num_doc=train_size, num_words=word_vectors.shape[0]).to(device)
model.train()

word_vectors_tensor = torch.FloatTensor(word_vectors).to(device)
    
opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
criterion = nn.MSELoss(reduction='mean')

results = []
step = 0
for epoch in tqdm(range(n_epoch)):    
    loss_mse_his = []
    loss_w_reg_his = []
    
    model.train()

    for data in train_loader:
        doc_embs, doc_w_sum, doc_ids = data
        
        doc_embs = doc_embs.to(device)
        doc_w_sum = doc_w_sum.to(device)
        doc_ids = doc_ids.to(device)
        
        # target for the sum of each document's predicted weights
        w_reg = doc_w_sum * w_sum_reg_mul
        # w_reg = (torch.ones(doc_embs.size(0), 1) * w_sum_reg_mul).to(device)
        
        # MSE loss between reconstructed and true document embeddings
        pred_doc_embs = model(doc_ids, word_vectors_tensor)     
        loss_mse = criterion(pred_doc_embs, doc_embs)

        # MSE between the predicted weight sum and its target
        pred_w_sum = torch.sum(model.emb(doc_ids), axis=1).view(-1, 1)
        loss_w_reg = criterion(pred_w_sum, w_reg)
        
        # L1 penalty on the weights of the documents in this batch
        loss_l1 = torch.sum(torch.abs(model.emb(doc_ids)))
        loss = loss_mse + loss_w_reg * w_sum_reg + loss_l1 * L1
        
        # Model backwarding
        model.zero_grad()
        loss.backward()
        opt.step()

        loss_mse_his.append(loss_mse.item())
        loss_w_reg_his.append(loss_w_reg.item())

        # project the weights back to values >= w_clip_value
        for p in model.parameters():
            p.data.clamp_(w_clip_value, float('inf'))

        
    # periodic evaluation on the same documents that are being fitted
    if epoch % valid_epoch == 0:
        res = {}
        res['epoch'] = epoch
        res['loss_mse'] = np.mean(loss_mse_his)
        res['loss_w_reg'] = np.mean(loss_w_reg_his)
        
        res_ndcg = evaluate_Custom_Lasso(model, train_loader)
        res.update(res_ndcg)
        results.append(res)
        
        if verbose:
            print()
            for k, v in res.items():
                print(k, v)

In [None]:
# One row per evaluated epoch, indexed by epoch number.
pd.set_option('display.max_rows', 500)
results_df = pd.DataFrame(results).set_index('epoch')
results_df

In [None]:
# Record the metrics of the last evaluated epoch as this model's result row.
results_df['model'] = 'our-lasso'
final_results.append(results_df[select_columns].iloc[-1])

## Quality Check

In [None]:
# select doc_id and k
doc_id = 40
topk = 30

model

In [None]:
import colored
from colored import stylize

word_list = vocab.itos

# Top-k words by true weight and by learned weight for the chosen document.
gt = [word_list[word_idx] for word_idx in np.argsort(weight_ans[doc_id])[::-1][:topk]]
# NOTE: model.emb.cpu() moves the embedding module to the CPU in place.
pred = [word_list[word_idx] for word_idx in np.argsort(model.emb.cpu().weight.data[doc_id].numpy())[::-1][:topk]]

# Words present in both lists are printed on a yellow background.
print('ground truth')
for word in gt:
    if word in pred:
        print(stylize(word, colored.bg("yellow")), end=' ')
    else:
        print(word, end=' ')

print()
print('\nprediction')
for word in pred:
    if word in gt:
        print(stylize(word, colored.bg("yellow")), end=' ')
    else:
        print(word, end=' ')


In [None]:
# raw document
# Yellow: stem is in both gt and pred; light gray: stem is in gt only.
print()
ps = PorterStemmer()

for word in vocab.raw_documents[doc_id].split():
    word_stem = ps.stem(word).lower()

    if word_stem not in gt:
        print(word, end=' ')
    elif word_stem in pred:
        print(stylize(word, colored.bg("yellow")), end=' ')
    else:
        print(stylize(word, colored.bg("light_gray")), end=' ')
# print(dataset.documents[doc_id])

In [None]:
# NDCG of the learned weights for the single document selected above.
results = {}

# detach().cpu() makes the conversion work wherever the model currently
# lives; converting a CUDA tensor straight to numpy raises an error.
scores = model.emb.weight.detach().cpu().numpy()[doc_id].reshape(1, -1)
true_relevance = train_loader.dataset.weight_ans[doc_id].reshape(1, -1)

results['ndcg@50'] = (ndcg_score(true_relevance, scores, k=50))
results['ndcg@100'] = (ndcg_score(true_relevance, scores, k=100))
results['ndcg@200'] = (ndcg_score(true_relevance, scores, k=200))
results['ndcg@all'] = (ndcg_score(true_relevance, scores, k=None))

print('This document ndcg:')
print('ground truth length:', np.sum(weight_ans[doc_id] > 0))
print('NDCG top50', results['ndcg@50'])
print('NDCG top100', results['ndcg@100'])
print('NDCG top200', results['ndcg@200'])
print('NDCG ALL', results['ndcg@all'])


## Final results

In [None]:
# Remember whether figures can be shown inline.
is_notebook = in_notebook()

In [None]:
# One row per model, in the order the results were appended.
final_results_df = pd.DataFrame(final_results).reset_index(drop=True)

# The output directory name encodes the main preprocessing settings.
experiment_dir = './records/dataset-{}-n_document-{}-wdist-{}-filtertopk-{}'.format(
                                        config['dataset'],
                                        config['n_document'],
                                        config["document_vector_agg_weight"],
                                        config["topk_word_freq_threshold"])

print('Saving to directory', experiment_dir)
os.makedirs(experiment_dir, exist_ok=True)

In [None]:
# Save the metrics table and the config that produced it side by side.
final_results_df.to_csv(os.path.join(experiment_dir, 'result.csv'), index=False)

import json
with open(os.path.join(experiment_dir, 'config.json'), 'w') as json_file:
    json.dump(config, json_file)

In [None]:
# One bar chart per metric, comparing the models; each chart is saved as a PNG.
for feat in final_results_df.set_index('model').columns:
    plt.bar(final_results_df['model'],
            final_results_df[feat],
            width=0.5,
            bottom=None,
            align='center',
            color=['lightsteelblue',
                   'cornflowerblue',
                   'royalblue',
                   'navy'])
    plt.title(feat)
    plt.savefig(os.path.join(experiment_dir, '{}.png'.format(feat)))
    # Show before clearing; clearing first would display an empty figure.
    if is_notebook:
        plt.show()
    plt.clf()
In [None]:
# Print the table for script runs; the bare expression displays it in a notebook.
print(final_results_df)
final_results_df

## MLP Decoder

In [None]:
# Use the first GPU when one is available, otherwise the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
class MLPDecoderDataset(Dataset):
    """Dataset of document vectors with word-weight targets and ranked labels.

    Besides the raw weights, each item carries the word indices sorted by
    descending weight; positions from `topk` onward are filled with -1.
    """

    def __init__(self,
                 doc_vectors,
                 weight_ans,
                 topk=50):
        """Convert the inputs to tensors and precompute the ranked labels.

        Args:
            doc_vectors: per-document embedding vectors.
            weight_ans: per-document word weights; same length as `doc_vectors`.
            topk: number of leading ranked indices kept per document.
        """
        self.doc_vectors = torch.FloatTensor(doc_vectors)
        self.weight_ans = torch.FloatTensor(weight_ans)

        ranked_word_idx = torch.argsort(self.weight_ans, dim=1, descending=True)
        ranked_word_idx[:, topk:] = -1  # everything past the top-k is padding
        self.weight_ans_s = ranked_word_idx

        assert len(doc_vectors) == len(weight_ans)

    def __getitem__(self, idx):
        """Return (document vector, word weights, ranked word indices)."""
        vector = self.doc_vectors[idx]
        weights = self.weight_ans[idx]
        ranked = self.weight_ans_s[idx]
        return vector, weights, ranked

    def __len__(self):
        """Return the number of documents."""
        return self.doc_vectors.shape[0]

In [None]:
batch_size = 100

# First 90% of the (already shuffled) documents train, the rest validate.
train_size_ratio = 0.9
train_size = int(len(document_answers_idx) * train_size_ratio)
train_size

print('train size', train_size)
print('valid size', len(document_vectors) - train_size)

# Each dataset keeps the 50 highest-weight word indices per document as labels.
train_dataset = MLPDecoderDataset(document_vectors[:train_size], weight_ans[:train_size], topk=50)
train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

valid_dataset = MLPDecoderDataset(document_vectors[train_size:], weight_ans[train_size:], topk=50)
valid_loader  = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

In [None]:
class MLPDecoder(nn.Module):
    """One-hidden-layer MLP mapping a document embedding to per-word scores.

    Args:
        doc_emb_dim: size of the input document embedding.
        num_words: number of output scores (one per vocabulary entry).
        h_dim: width of the hidden layer.
    """
    def __init__(self, doc_emb_dim, num_words, h_dim=300):
        super().__init__()

        self.fc1 = nn.Linear(doc_emb_dim, h_dim)
#         self.fc2 = nn.Linear(h_dim, h_dim)
#         self.fc3 = nn.Linear(h_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, num_words)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return unnormalized word scores for the batch of embeddings `x`."""
        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        x = torch.tanh(self.fc1(x))
#         x = F.tanh(self.fc2(x))
#         x = F.relu(self.fc3(x))
        x = self.dropout(x)
        x = self.fc4(x)

        return x

In [None]:
from sklearn.metrics import f1_score
import timeit
from torchmetrics import F1
from torchmetrics.functional import retrieval_normalized_dcg

def evaluate_MLPDecoder(model, data_loader):
    """Compute F1@k and NDCG for the decoder over every batch of `data_loader`.

    Args:
        model: decoder mapping document embeddings to word scores.
        data_loader: loader yielding (doc_embs, target, _) batches, where
            `target` holds the true word weights.

    Returns:
        dict with 'F1@k' and 'ndcg@k' for every k in config["topk"], plus
        'ndcg@all'. The model is left in eval mode. Timings of the three
        stages are printed.
    """
    results = {}
    model.eval()

    pred_all = []
    target_all = []

    # predict all data
    start = timeit.default_timer()
    # Inference only: skip autograd bookkeeping so no graph is kept alive.
    with torch.no_grad():
        for doc_embs, target, _ in data_loader:
            doc_embs = doc_embs.to(device)
            target = target.to(device)

            pred_all.append(model(doc_embs))
            target_all.append(target)

    pred_all = torch.cat(pred_all, dim=0)
    target_all = torch.cat(target_all, dim=0)
    stop1 = timeit.default_timer()
    print('Time1: ', stop1 - start)

    # F1: any word with positive weight counts as a positive label
    pred_cpu = pred_all.cpu()
    binary_target = torch.sign(target_all).int().cpu()
    for topk in config["topk"]:
        f1 = F1(top_k=topk)
        # Named f1_value so the imported sklearn f1_score is not shadowed.
        f1_value = f1(pred_cpu, binary_target)
        results['F1@{}'.format(topk)] = f1_value.item()
    stop2 = timeit.default_timer()
    print('Time2: ', stop2 - stop1)

    # NDCG, computed per document and averaged
    n_docs = pred_all.shape[0]
    for topk in config["topk"]:
        ndcg_scores = [
            retrieval_normalized_dcg(pred_all[i], target_all[i], k=topk).item()
            for i in range(n_docs)
        ]
        results['ndcg@{}'.format(topk)] = np.mean(ndcg_scores)

    ndcg_scores = [
        retrieval_normalized_dcg(pred_all[i], target_all[i]).item()
        for i in range(n_docs)
    ]
    results['ndcg@all'] = np.mean(ndcg_scores)
    stop3 = timeit.default_timer()
    print('Time3: ', stop3 - stop2)

    return results

In [None]:
# Defaults of the listNet loss: epsilon added before the log, and the label
# value that marks a padded item.
DEFAULT_EPS = 1e-10
PADDED_Y_VALUE = -1

In [None]:
# setting
lr = 0.05
momentum = 0.
weight_decay = 0
nesterov = False # True

n_epoch = 5000
verbose = True
valid_epoch = 50  # run the full metric evaluation every this many epochs

h_dim = 3000

model = MLPDecoder(doc_emb_dim=document_vectors.shape[1], num_words=len(word_vectors), h_dim=h_dim).to(device)
model.train()

word_vectors_tensor = torch.FloatTensor(word_vectors).to(device)

opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
# opt = torch.optim.Adam(model.parameters(), lr=lr)
# criterion = nn.MSELoss(reduction='mean')
criterion = nn.MultiLabelMarginLoss(reduction='mean')
# criterion = nn.MultiLabelSoftMarginLoss(reduction='mean')

results = []
step = 0
clip_value = 1  # gradients are clipped element-wise to [-clip_value, clip_value]

for epoch in tqdm(range(n_epoch)):    
    train_loss_his = []
    valid_loss_his = []
    
    model.train()

    for data in train_loader:
        doc_embs, target, target_rank = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_rank = target_rank.to(device)
        
        # multi-label margin loss against the ranked top-k label indices
        pred = model(doc_embs)     
#         loss = criterion(pred, target)
        loss = criterion(pred, target_rank)
#         loss = criterion(pred, torch.sign(target))
        
        # Model backwarding
        model.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
        
        opt.step()

        train_loss_his.append(loss.item())
        
    # validation loss with dropout disabled
    model.eval()
    for data in valid_loader:
        doc_embs, target, target_rank = data
        
        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_rank = target_rank.to(device)
        
        # same criterion as in training
        pred = model(doc_embs)     
        
#         loss = criterion(pred, target)
        loss = criterion(pred, target_rank)
#         loss = criterion(pred, torch.sign(target))
        
        valid_loss_his.append(loss.item())
    
    print("Epoch", epoch, np.mean(train_loss_his), np.mean(valid_loss_his))
    
    if epoch % valid_epoch == 0:
        res = {}
        res['epoch'] = epoch
        
        train_res_ndcg = evaluate_MLPDecoder(model, train_loader)
        valid_res_ndcg = evaluate_MLPDecoder(model, valid_loader)
        
        # only the validation metrics are stored in the results history
        res.update(valid_res_ndcg)
        results.append(res)
        
        if verbose:
            print()
            print('train', train_res_ndcg)
            print('valid', valid_res_ndcg)
            


In [None]:
def listNet(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
    """
    ListNet listwise loss ("Learning to Rank: From Pairwise Approach to Listwise Approach").

    Both score lists are turned into distributions with a softmax along dim 1 and
    the cross-entropy between them is averaged over the batch.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: added to the predicted probabilities before the logarithm, for numerical stability
    :param padded_value_indicator: value in y_true that marks a padded item, e.g. -1
    :return: loss value, a torch.Tensor
    """
    padded = y_true == padded_value_indicator

    # masked_fill returns new tensors, so the caller's inputs are never modified.
    # Padded slots get -inf so that they receive zero probability in the softmax.
    pred_scores = y_pred.masked_fill(padded, float('-inf'))
    true_scores = y_true.masked_fill(padded, float('-inf'))

    true_dist = F.softmax(true_scores, dim=1)
    pred_log_dist = torch.log(F.softmax(pred_scores, dim=1) + eps)

    per_slate_cross_entropy = -torch.sum(true_dist * pred_log_dist, dim=1)
    return torch.mean(per_slate_cross_entropy)

In [None]:
# ---- settings ----
# Optimizer (SGD) hyper-parameters
lr = 0.05
momentum = 0.
weight_decay = 0
nesterov = False # True

# Training schedule
n_epoch = 5000
verbose = True
valid_epoch = 50   # run the full evaluation every `valid_epoch` epochs

# Forwarded to MLPDecoder as h_dim
h_dim = 3000

model = MLPDecoder(doc_emb_dim=document_vectors.shape[1], num_words=len(word_vectors), h_dim=h_dim).to(device)
model.train()

word_vectors_tensor = torch.FloatTensor(word_vectors).to(device)

opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay, nesterov=nesterov)
# opt = torch.optim.Adam(model.parameters(), lr=lr)
# `criterion` is only referenced by the commented-out alternatives below;
# this run optimises listNet(pred, target) instead.
# criterion = nn.MSELoss(reduction='mean')
criterion = nn.MultiLabelMarginLoss(reduction='mean')
# criterion = nn.MultiLabelSoftMarginLoss(reduction='mean')

results = []      # one dict per evaluation round: 'epoch' plus the validation metrics
step = 0          # not updated anywhere in this cell
clip_value = 1    # only used by the (disabled) gradient clipping below

for epoch in tqdm(range(n_epoch)):
    train_loss_his = []
    valid_loss_his = []

    # ---- training pass ----
    model.train()

    for data in train_loader:
        doc_embs, target, target_rank = data

        doc_embs = doc_embs.to(device)
        target = target.to(device)
        target_rank = target_rank.to(device)

        pred = model(doc_embs)
        # Alternatives tried with `criterion`:
        #   loss = criterion(pred, target)
        #   loss = criterion(pred, target_rank)
        #   loss = criterion(pred, torch.sign(target))
        loss = listNet(pred, target)

        # Backward pass (gradient clipping is disabled for this run).
        model.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)

        opt.step()

        train_loss_his.append(loss.item())

    # ---- validation pass ----
    model.eval()
    # No parameter update happens here, so skip autograd bookkeeping entirely.
    with torch.no_grad():
        for data in valid_loader:
            doc_embs, target, target_rank = data

            doc_embs = doc_embs.to(device)
            target = target.to(device)
            target_rank = target_rank.to(device)

            pred = model(doc_embs)
            loss = listNet(pred, target)

            valid_loss_his.append(loss.item())

    print("Epoch", epoch, np.mean(train_loss_his), np.mean(valid_loss_his))

    # Every `valid_epoch` epochs, run the full evaluation on both splits and
    # record the validation metrics.
    if epoch % valid_epoch == 0:
        res = {}
        res['epoch'] = epoch

        train_res_ndcg = evaluate_MLPDecoder(model, train_loader)
        valid_res_ndcg = evaluate_MLPDecoder(model, valid_loader)

        res.update(valid_res_ndcg)
        results.append(res)

        if verbose:
            print()
            print('train', train_res_ndcg)
            print('valid', valid_res_ndcg)
            


In [None]:
def deterministic_neural_sort(s, tau, mask):
    """
    Deterministic neural sort.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).

    Row j of each returned matrix is a softmax distribution over the items that
    should occupy rank j (rank 0 = largest score); as tau -> 0 the rows approach
    one-hot vectors.
    :param s: values to sort, shape [batch_size, slate_length, 1] (the trailing singleton
        axis is required by the s - s^T broadcast below)
    :param tau: temperature for the final softmax function
    :param mask: boolean mask indicating padded elements, shape [batch_size, slate_length]
    :return: approximate permutation matrices of shape [batch_size, slate_length, slate_length]
    """
    # Allocate every helper tensor next to the scores instead of on a globally
    # chosen device, so the function works for CPU and GPU inputs alike.
    dev = s.device

    n = s.size()[1]
    one = torch.ones((n, 1), dtype=torch.float32, device=dev)

    # Pairwise absolute score differences; any pair involving a padded item contributes 0.
    s = s.masked_fill(mask[:, :, None], -1e8)
    A_s = torch.abs(s - s.permute(0, 2, 1))
    A_s = A_s.masked_fill(mask[:, :, None] | mask[:, None, :], 0.0)

    # B[b, i, :] = sum_k |s_i - s_k|, replicated across all columns.
    B = torch.matmul(A_s, torch.matmul(one, torch.transpose(one, 0, 1)))

    # Per-slate rank coefficients (n_valid + 1 - 2 * rank) for the valid positions,
    # zero-padded up to the slate length. The padded counts are read once as
    # plain ints rather than iterating over 0-dim tensors.
    padded_counts = mask.squeeze(-1).sum(dim=1).tolist()
    rows = []
    for n_padded in padded_counts:
        n_valid = n - n_padded
        coeffs = (n_valid + 1 - 2 * (torch.arange(n_valid, device=dev) + 1)).type(torch.float32)
        rows.append(torch.cat((coeffs, torch.zeros(n - len(coeffs), device=dev))))
    scaling = torch.stack(rows).type(torch.float32).to(dev)  # type: ignore

    s = s.masked_fill(mask[:, :, None], 0.0)
    C = torch.matmul(s, scaling.unsqueeze(-2))

    P_max = (C - B).permute(0, 2, 1)
    # Entries pairing a padded and a valid position can never be selected (-inf);
    # padded-with-padded entries get a finite value so their softmax rows stay defined.
    P_max = P_max.masked_fill(mask[:, :, None] | mask[:, None, :], -np.inf)
    P_max = P_max.masked_fill(mask[:, :, None] & mask[:, None, :], 1.0)
    sm = torch.nn.Softmax(-1)
    P_hat = sm(P_max / tau)
    return P_hat

def sinkhorn_scaling(mat, mask=None, tol=1e-6, max_iter=50):
    """
    Sinkhorn scaling: alternately normalise along dim 1 and dim 2 until all sums are
    within `tol` of 1 or `max_iter` rounds have been performed.
    :param mat: a tensor of square matrices of shape N x M x M, where N is batch size
    :param mask: a tensor of masks of shape N x M
    :param tol: Sinkhorn scaling tolerance
    :param max_iter: maximum number of iterations of the Sinkhorn scaling
    :return: a tensor of (approximately) doubly stochastic matrices
    """
    masked = mask is not None
    if masked:
        any_padded = mask[:, None, :] | mask[:, :, None]
        both_padded = mask[:, None, :] & mask[:, :, None]
        # Zero every entry touching a padded index, then put 1.0 where both
        # indices are padded so those rows/columns do not sum to zero.
        mat = mat.masked_fill(any_padded, 0.0).masked_fill(both_padded, 1.0)

    for _ in range(max_iter):
        dim1_sums = mat.sum(dim=1, keepdim=True).clamp(min=DEFAULT_EPS)
        mat = mat / dim1_sums
        dim2_sums = mat.sum(dim=2, keepdim=True).clamp(min=DEFAULT_EPS)
        mat = mat / dim2_sums

        dim2_converged = torch.max(torch.abs(mat.sum(dim=2) - 1.)) < tol
        if dim2_converged and torch.max(torch.abs(mat.sum(dim=1) - 1.)) < tol:
            break

    if masked:
        # Clear everything that involves a padded index in the result.
        mat = mat.masked_fill(any_padded, 0.0)

    return mat

def stochastic_neural_sort(s, n_samples, tau, mask, beta=1.0, log_scores=True, eps=1e-10):
    """
    Stochastic neural sort. Please note that memory complexity grows by factor n_samples.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param s: values to sort, shape [batch_size, slate_length, 1] (callers in this file pass
        y_pred.unsqueeze(-1))
    :param n_samples: number of samples (approximations) for each permutation matrix
    :param tau: temperature for the final softmax function
    :param mask: mask indicating padded elements, shape [batch_size, slate_length]
    :param beta: scale parameter for the Gumbel distribution
    :param log_scores: whether to apply the logarithm function to scores prior to Gumbel perturbation
    :param eps: epsilon for the logarithm function
    :return: approximate permutation matrices of shape [n_samples, batch_size, slate_length, slate_length]
    """
    # Sample the noise on the same device as the scores.
    dev = s.device

    batch_size = s.size()[0]
    n = s.size()[1]

    # Shift the scores so they are non-negative before the optional logarithm.
    s_positive = s + torch.abs(s.min())
    samples = beta * sample_gumbel([n_samples, batch_size, n, 1], device=dev)
    if log_scores:
        s_positive = torch.log(s_positive + eps)

    # [n_samples, batch_size, n, 1] -> [n_samples * batch_size, n, 1]; the flat index
    # is sample * batch_size + batch (sample-major).
    s_perturb = (s_positive + samples).view(n_samples * batch_size, n, 1)

    # The mask must follow the same sample-major order, i.e. the whole batch of
    # masks tiled n_samples times. repeat_interleave would instead yield
    # [m0, m0, ..., m1, m1, ...] and pair slates with the wrong masks whenever
    # n_samples > 1 and the masks differ between slates.
    mask_repeated = mask.repeat(n_samples, 1)

    P_hat = deterministic_neural_sort(s_perturb, tau, mask_repeated)
    P_hat = P_hat.view(n_samples, batch_size, n, n)
    return P_hat

def dcg(y_pred, y_true, ats=None, gain_function=lambda x: torch.pow(2, x) - 1, padding_indicator=PADDED_Y_VALUE):
    """
    Discounted Cumulative Gain at k.
    Evaluated at every rank listed in ats, or at the full slate length when ats is None.
    Requested ranks larger than the slate length are clipped to the slate length.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param ats: optional list of ranks for DCG evaluation, if None, maximum rank is used
    :param gain_function: callable, gain function for the ground truth labels, e.g. torch.pow(2, x) - 1
    :param padding_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: DCG values for each slate and evaluation position, shape [batch_size, len(ats)]
    """
    # Work on copies so the helper below cannot alter the caller's tensors.
    y_true = y_true.clone()
    y_pred = y_pred.clone()

    slate_length = y_true.shape[1]
    cutoffs = [slate_length] if ats is None else ats
    cutoffs = [min(at, slate_length) for at in cutoffs]

    labels_in_pred_order = __apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator)

    # Position discount 1 / log2(rank + 1) for ranks 1..slate_length.
    positions = torch.arange(labels_in_pred_order.shape[1], dtype=torch.float)
    discounts = (torch.tensor(1) / torch.log2(positions + 2.0)).to(device=labels_in_pred_order.device)

    weighted_gains = (gain_function(labels_in_pred_order) * discounts)[:, :np.max(cutoffs)]
    running_dcg = torch.cumsum(weighted_gains, dim=1)

    # A cut-off at rank k reads the cumulative sum at zero-based column k - 1.
    cutoff_columns = torch.tensor(cutoffs, dtype=torch.long) - torch.tensor(1)
    return running_dcg[:, cutoff_columns]

def get_torch_device():
    """
    Getter for the pyTorch device used by the ranking-loss helpers.

    The device is pinned to the CPU. The automatic CUDA selection that used to
    follow was placed after an unconditional ``return "cpu"`` and therefore never
    ran; it is preserved below as a comment so it can be re-enabled on purpose.
    :return: the string "cpu"
    """
    # To re-enable automatic selection, replace the return below with:
    #   return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    return "cpu"


def neuralNDCG(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1., powered_relevancies=True, k=None,
               stochastic=False, n_samples=32, beta=0.1, log_scores=True):
    """
    NeuralNDCG loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
    :return: loss value (negated mean NDCG), a torch.Tensor
    """
    # Keep helper tensors on the same device as the predictions.
    dev = y_pred.device

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)
    # Choose the deterministic/stochastic variant; either way P_hat ends up with
    # shape [n_variants, batch_size, slate_length, slate_length].
    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    n_variants = P_hat.shape[0]

    # Perform sinkhorn scaling to obtain doubly stochastic permutation matrices.
    # Flattening the two leading axes is variant-major (index = variant * batch + slate),
    # so the masks have to be tiled with repeat(); repeat_interleave would order them
    # slate-major and mismatch masks and matrices whenever n_variants > 1.
    P_hat = sinkhorn_scaling(P_hat.view(n_variants * P_hat.shape[1], P_hat.shape[2], P_hat.shape[3]),
                             mask.repeat(n_variants, 1), tol=1e-6, max_iter=50)
    P_hat = P_hat.view(int(P_hat.shape[0] / y_pred.shape[0]), y_pred.shape[0], P_hat.shape[1], P_hat.shape[2])

    # Mask P_hat and apply to true labels, ie approximately sort them
    P_hat = P_hat.masked_fill(mask[None, :, :, None] | mask[None, :, None, :], 0.)
    y_true_masked = y_true.masked_fill(mask, 0.).unsqueeze(-1).unsqueeze(0)
    if powered_relevancies:
        y_true_masked = torch.pow(2., y_true_masked) - 1.

    ground_truth = torch.matmul(P_hat, y_true_masked).squeeze(-1)
    discounts = (torch.tensor(1.) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
    discounted_gains = ground_truth * discounts

    # Ideal DCG: the true labels ranked by themselves, with the matching gain function.
    if powered_relevancies:
        idcg = dcg(y_true, y_true, ats=[k]).permute(1, 0)
    else:
        idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).permute(1, 0)

    discounted_gains = discounted_gains[:, :, :k]
    ndcg = discounted_gains.sum(dim=-1) / (idcg + DEFAULT_EPS)
    # Slates whose ideal DCG is zero carry no ranking signal and are excluded from the mean.
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask.repeat(ndcg.shape[0], 1), 0.)

    # NOTE(review): as written this condition is always true (a count is never negative),
    # so it does not enforce the message; left unchanged to avoid introducing new failures.
    assert (ndcg < 0.).sum() >= 0, "every ndcg should be non-negative"
    if idcg_mask.all():
        return torch.tensor(0.)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 cause we want to maximize NDCG


def neuralNDCG_transposed(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1.,
                          powered_relevancies=True, k=None, stochastic=False, n_samples=32, beta=0.1, log_scores=True,
                          max_iter=50, tol=1e-6):
    """
    NeuralNDCG Transposed loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
    :param max_iter: maximum iteration count for Sinkhorn scaling
    :param tol: tolerance for Sinkhorn scaling
    :return: loss value (negated mean NDCG), a torch.Tensor
    """
    # Keep helper tensors on the same device as the predictions.
    dev = y_pred.device

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)

    # P_hat has shape [n_variants, batch_size, slate_length, slate_length] in both branches.
    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    n_variants = P_hat.shape[0]

    # Perform sinkhorn scaling to obtain doubly stochastic permutation matrices.
    # Flattening the two leading axes is variant-major (index = variant * batch + slate),
    # so the masks have to be tiled with repeat(); repeat_interleave would order them
    # slate-major and mismatch masks and matrices whenever n_variants > 1.
    P_hat_masked = sinkhorn_scaling(P_hat.view(n_variants * y_pred.shape[0], y_pred.shape[1], y_pred.shape[1]),
                                    mask.repeat(n_variants, 1), tol=tol, max_iter=max_iter)
    P_hat_masked = P_hat_masked.view(n_variants, y_pred.shape[0], y_pred.shape[1], y_pred.shape[1])
    discounts = (torch.tensor(1) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)

    # This takes care of the @k metric truncation - if something is @>k, it is useless and gets 0.0 discount
    discounts[k:] = 0.
    discounts = discounts[None, None, :, None]

    # Here the discounts become expected discounts
    discounts = torch.matmul(P_hat_masked.permute(0, 1, 3, 2), discounts).squeeze(-1)

    # The ideal DCG must use the same gain function as the numerator; with linear
    # gains the default 2^x - 1 gain of dcg() would normalise by the wrong quantity.
    if powered_relevancies:
        gains = torch.pow(2., y_true) - 1
        idcg = dcg(y_true, y_true, ats=[k]).squeeze()
    else:
        gains = y_true
        idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).squeeze()
    discounted_gains = gains.unsqueeze(0) * discounts

    ndcg = discounted_gains.sum(dim=2) / (idcg + DEFAULT_EPS)
    # Slates whose ideal DCG is zero carry no ranking signal and are excluded from the mean.
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask, 0.)

    # NOTE(review): as written this condition is always true (a count is never negative),
    # so it does not enforce the message; left unchanged to avoid introducing new failures.
    assert (ndcg < 0.).sum() >= 0, "every ndcg should be non-negative"
    if idcg_mask.all():
        return torch.tensor(0.)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 cause we want to maximize NDCG

In [None]:
# def lambdaLoss(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE, weighing_scheme=None, k=None, sigma=1., mu=10.,
#                reduction="sum", reduction_log="binary"):
#     """
#     LambdaLoss framework for LTR losses implementations, introduced in "The LambdaLoss Framework for Ranking Metric Optimization".
#     Contains implementations of different weighing schemes corresponding to e.g. LambdaRank or RankNet.
#     :param y_pred: predictions from the model, shape [batch_size, slate_length]
#     :param y_true: ground truth labels, shape [batch_size, slate_length]
#     :param eps: epsilon value, used for numerical stability
#     :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
#     :param weighing_scheme: a string corresponding to a name of one of the weighing schemes
#     :param k: rank at which the loss is truncated
#     :param sigma: score difference weight used in the sigmoid function
#     :param mu: optional weight used in NDCGLoss2++ weighing scheme
#     :param reduction: losses reduction method, could be either a sum or a mean
#     :param reduction_log: logarithm variant used prior to masking and loss reduction, either binary or natural
#     :return: loss value, a torch.Tensor
#     """
#     device = y_pred.device
#     y_pred = y_pred.clone()
#     y_true = y_true.clone()

#     padded_mask = y_true == padded_value_indicator
#     y_pred[padded_mask] = float("-inf")
#     y_true[padded_mask] = float("-inf")

#     # Here we sort the true and predicted relevancy scores.
#     y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
#     y_true_sorted, _ = y_true.sort(descending=True, dim=-1)

#     # After sorting, we can mask out the pairs of indices (i, j) containing index of a padded element.
#     true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
#     true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
#     print(true_sorted_by_preds.shape)
#     print(true_diffs.shape)
#     padded_pairs_mask = torch.isfinite(true_diffs)

#     if weighing_scheme != "ndcgLoss1_scheme":
#         padded_pairs_mask = padded_pairs_mask & (true_diffs > 0)

#     ndcg_at_k_mask = torch.zeros((y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device)
#     ndcg_at_k_mask[:k, :k] = 1

#     # Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs)
#     true_sorted_by_preds.clamp_(min=0.)
#     y_true_sorted.clamp_(min=0.)

#     # Here we find the gains, discounts and ideal DCGs per slate.
#     pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
#     D = torch.log2(1. + pos_idxs.float())[None, :]
#     maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], dim=-1).clamp(min=eps)
#     G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]

#     # Here we apply appropriate weighing scheme - ndcgLoss1, ndcgLoss2, ndcgLoss2++ or no weights (=1.0)
#     if weighing_scheme is None:
#         weights = 1.
#     else:
#         weights = globals()[weighing_scheme](G, D, mu, true_sorted_by_preds)  # type: ignore

#     # We are clamping the array entries to maintain correct backprop (log(0) and division by 0)
#     scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]).clamp(min=-1e8, max=1e8)
#     scores_diffs.masked_fill(torch.isnan(scores_diffs), 0.)
#     weighted_probas = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) ** weights).clamp(min=eps)
#     if reduction_log == "natural":
#         losses = torch.log(weighted_probas)
#     elif reduction_log == "binary":
#         losses = torch.log2(weighted_probas)
#     else:
#         raise ValueError("Reduction logarithm base can be either natural or binary")

#     if reduction == "sum":
#         loss = -torch.sum(losses[padded_pairs_mask & ndcg_at_k_mask])
#     elif reduction == "mean":
#         loss = -torch.mean(losses[padded_pairs_mask & ndcg_at_k_mask])
#     else:
#         raise ValueError("Reduction method can be either sum or mean")

#     return loss


# def ndcgLoss1_scheme(G, D, *args):
#     return (G / D)[:, :, None]


# def ndcgLoss2_scheme(G, D, *args):
#     pos_idxs = torch.arange(1, G.shape[1] + 1, device=G.device)
#     delta_idxs = torch.abs(pos_idxs[:, None] - pos_idxs[None, :])
#     deltas = torch.abs(torch.pow(torch.abs(D[0, delta_idxs - 1]), -1.) - torch.pow(torch.abs(D[0, delta_idxs]), -1.))
#     deltas.diagonal().zero_()

#     return deltas[None, :, :] * torch.abs(G[:, :, None] - G[:, None, :])


# def lambdaRank_scheme(G, D, *args):
#     return torch.abs(torch.pow(D[:, :, None], -1.) - torch.pow(D[:, None, :], -1.)) * torch.abs(G[:, :, None] - G[:, None, :])


# def ndcgLoss2PP_scheme(G, D, *args):
#     return args[0] * ndcgLoss2_scheme(G, D) + lambdaRank_scheme(G, D)


# def rankNet_scheme(G, D, *args):
#     return 1.


# def rankNetWeightedByGTDiff_scheme(G, D, *args):
#     return torch.abs(args[1][:, :, None] - args[1][:, None, :])


# def rankNetWeightedByGTDiffPowed_scheme(G, D, *args):
#     return torch.abs(torch.pow(args[1][:, :, None], 2) - torch.pow(args[1][:, None, :], 2))

In [None]:
# from torchmetrics.functional import retrieval_normalized_dcg
# preds = torch.tensor([[.1, .2, .3, 4, 70], [.1, .2, .3, 4, 2]])
# target = torch.tensor([[10, 0, 8, 1, 5], [10, 0, 0, 1, 5]])
# # preds = torch.tensor([[.1, .2, .3, 4, 70]])
# # target = torch.tensor([[10, 0, 0, 1, 6]])

# print(retrieval_normalized_dcg(preds[0], target[0], k=3))
# print(ndcg_score(target[0].reshape(1,-1), preds[0].reshape(1,-1), k=3))
# print(retrieval_normalized_dcg(preds[1], target[1], k=3))
# print(ndcg_score(target[1].reshape(1,-1), preds[1].reshape(1,-1), k=3))
# print(retrieval_normalized_dcg(preds, target, k=3))
# print(ndcg_score(target, preds, k=3))

# ndcg = RetrievalNormalizedDCG(k=3)
# ndcg(preds, target, indexes=torch.vstack([torch.arange(preds.shape[0]) for _ in range(preds.shape[1])]).T)