In [2]:
import torch

In [38]:
from torch import Tensor

def _find_sorted_array_position(tensor: Tensor, values_tensor: Tensor) -> Tensor:
    dim0, dim1 = tensor.shape
    expanded_values_tensor = values_tensor.resize_((dim0, 1)).expand(dim0, dim1)
    position_of_value = torch.sum((tensor > expanded_values_tensor).long(), 1)
    return position_of_value.add(1)


def mrr(preds: Tensor, targs: Tensor) -> Tensor:
    """
    Mean reciprocal rank of the target entries over a batch of predictions.

    For each row of `preds`, the score stored at the index given by `targs`
    is ranked within that row (rank 1 means no strictly larger score in the
    row); the result is the mean of 1/rank over all rows.

    E.g.:
    preds = torch.tensor([[0.2, 0.55, 0.25], [0.005, 0.005, 0.99]])
    targs = torch.tensor([1, 2])
    -> 1.0
    """
    # One target index per row, shaped as a column so it can drive gather.
    target_column = targs.view(-1, 1)
    target_scores = preds.gather(1, target_column)
    ranks = _find_sorted_array_position(preds, target_scores)
    return ranks.float().reciprocal().mean()


In [39]:
import unittest


class MrrTest(unittest.TestCase):
    """Checks `mrr` against two small hand-computed batches."""

    def test_mrr_1(self):
        """Each target holds the largest score of its row, so the MRR is 1."""
        scores = torch.tensor([[0.2, 0.55, 0.25], [0.005, 0.005, 0.99]])
        labels = torch.tensor([1, 2])
        expected = 1.0

        result = mrr(scores, labels)

        self.assertAlmostEqual(expected, result.item())

    def test_mrr_simple(self):
        """Targets ranked 3rd and 2nd in their rows: mean of 1/3 and 1/2."""
        scores = torch.tensor([[0.2, 0.55, 0.25], [0.006, 0.004, 0.99]])
        labels = torch.tensor([0, 0])
        expected = 0.41666668653

        result = mrr(scores, labels)

        self.assertAlmostEqual(expected, result.item())
        
# Call each test method directly on a fresh MrrTest instance rather than going
# through a unittest runner: a failing assertion raises AssertionError in this
# cell, and a passing run produces no output.
MrrTest().test_mrr_1()
MrrTest().test_mrr_simple()

In [1]:
from fastai.metrics import accuracy