In [76]:
%%writefile solution_2.py
import math
from typing import List, Optional, Union

import torch


def num_swapped_pairs(ys_true: torch.Tensor, ys_pred: torch.Tensor) -> int:
    order = ys_true.argsort(descending=True)
    pairs = torch.combinations(order, 2)
    mask_true_equal = ys_true[pairs[:, 0]] != ys_true[pairs[:, 1]]
    mask_pred = ys_pred[pairs[:, 0]] < ys_pred[pairs[:, 1]]
    return (mask_true_equal & mask_pred).sum().item()


def compute_gain(y_value: float, gain_scheme: str) -> float:
    if gain_scheme == 'exp2':
        return 2**y_value - 1.0
    return y_value


def dcg(ys_true: torch.Tensor, ys_pred: torch.Tensor, gain_scheme: str) -> float:    
    order = ys_pred.argsort(descending=True)
    index = torch.arange(len(order), dtype=torch.float64) + 1
    return (compute_gain(ys_true[order], gain_scheme) / torch.log2(index + 1)).sum().item()
    

def ndcg(ys_true: torch.Tensor, ys_pred: torch.Tensor, gain_scheme: str = 'const') -> float:
    dcg_val = dcg(ys_true, ys_pred, gain_scheme)
    dcg_best_val = dcg(ys_true, ys_true, gain_scheme)
    return dcg_val / dcg_best_val


def precission_at_k(ys_true: torch.Tensor, ys_pred: torch.Tensor, k: int) -> float:
    total_relevant = ys_true.sum().item()
    if total_relevant == 0:
        return -1
    
    order = ys_pred.argsort(descending=True)[:k]
    n_retrieved = len(order)
    n_relevant = (ys_true[order] == 1).sum().item()
    if n_retrieved > total_relevant:
        return n_relevant / total_relevant
    else:
        return n_relevant / n_retrieved


def reciprocal_rank(ys_true: torch.Tensor, ys_pred: torch.Tensor) -> float:
    order = ys_pred.argsort(descending=True)
    return 1 / (ys_true[order].argsort(descending=True)[0] + 1)


def p_found(ys_true: torch.Tensor, ys_pred: torch.Tensor, p_break: float = 0.15 ) -> float:
    order = ys_pred.argsort(descending=True)
    
    p_rels = ys_true[order]
    p_look_ = 1
    p_rel_ = p_rels[0].item()
    p_found = p_look_ * p_rel_
    for i in range(1, len(ys_true)):
        p_rel = p_rels[i].item()
        p_look = p_look_ * (1 - p_rel_) * (1 - p_break)
        
        p_found += p_look * p_rel
        
        p_rel_ = p_rel
        p_look_ = p_look
    
    return p_found
    


def average_precision(ys_true: torch.Tensor, ys_pred: torch.Tensor) -> float:
    if ys_true.sum() == 0:
        return -1
    
    order = ys_pred.argsort(descending=True)
    
    n = ys_true.sum().item()
    recall_ = 0.0
    ap = 0.0
    for k in range(1, len(ys_true) + 1):        
        n_relevant = (ys_true[order][:k] == 1).sum().item()
        
        precision = n_relevant / k
        recall = n_relevant / n
        
        ap += (recall - recall_) * precision
        recall_ = recall
    
    return ap


Overwriting solution_2.py


In [77]:
import solution_2
import torch
import numpy as np
import importlib
importlib.reload(solution_2)
y_true = torch.rand(3)
y_pred = torch.rand(3)

print(solution_2.num_swapped_pairs(y_true, y_pred))
print(solution_2.compute_gain(y_true, 'const'))
print(solution_2.dcg(y_true, y_pred, 'const'))
print(solution_2.ndcg(y_true, y_pred, 'const'))

y_bin_true = torch.Tensor(np.random.choice([0, 1], 3))
k = 100
print(solution_2.precission_at_k(y_bin_true, y_pred, k))
print(solution_2.reciprocal_rank(torch.Tensor([0, 0, 1]), y_pred))
print(solution_2.average_precision(y_bin_true, y_pred))
print(solution_2.p_found(y_bin_true, y_pred))

2
tensor([0.0258, 0.7164, 0.1161])
0.535789370221151
0.6676405111984647
1.0
tensor(0.3333)
1.0
1.0


In [78]:
y_bin_true, y_pred

(tensor([1., 1., 0.]), tensor([0.9349, 0.7841, 0.4637]))

In [62]:
order = y_pred.argsort(descending=True)
order

tensor([0, 2, 1])