In [16]:
import numpy as np
import math

In [10]:
def listmle(x):
    '''Compute the ListMLE loss for scores given in ground-truth order.

    ListMLE is the negative log-likelihood of the ground-truth permutation
    under the Plackett-Luce model:

        loss = sum_i [ log(sum_{j >= i} exp(x_j)) - x_i ]

    args:
        x: a 1-D list/array of scores for feature vectors,
        ordered by their ground truth rank (best first)
    returns:
        listMLE loss (non-negative; 0.0 for an empty list)
    '''
    scores = np.asarray(x, dtype=float)
    if scores.size == 0:
        return 0.0
    # log(sum_{j >= i} exp(x_j)) for every i, computed as a reverse running
    # log-sum-exp. Unlike np.exp followed by np.cumsum, this does not overflow
    # to inf for large scores or underflow to log(0) for very negative ones.
    tail_logsumexp = np.logaddexp.accumulate(scores[::-1])[::-1]
    return np.sum(tail_logsumexp - scores)

def normalized_listmle(x):
    '''Compute a normalized ListMLE loss for scores in ground-truth order.

    The scores are min-max scaled to [0, 1] and passed through a logistic
    sigmoid, which makes the result invariant to shifting or positively
    rescaling the input. The ListMLE loss of the squashed scores is then
    divided by log(n!), the ListMLE loss of n equal scores, so a list of
    identical scores gives 1.0 (up to rounding).

    args:
        x: a non-empty 1-D list/array of scores for feature vectors,
        ordered by their ground truth rank (best first)
    returns:
        normalized listMLE (0.0 for a single-element list, whose ranking
        is trivially correct)
    raises:
        ValueError: if x is empty
    '''
    scores = np.asarray(x, dtype=float)
    n = scores.size
    if n == 0:
        raise ValueError('normalized_listmle requires at least one score')
    if n == 1:
        # log(1!) == 0, so the ratio below would be 0/0 (nan).
        return 0.0

    # Min-max scale to [0, 1]; an all-equal list stays all zeros.
    scores = scores - scores.min()
    spread = scores.max()
    if spread != 0:
        scores = scores / spread
    # Element-wise logistic sigmoid (same formula as sigmoid(), vectorized);
    # inputs are in [0, 1], so np.exp cannot overflow here.
    squashed = 1.0 / (1.0 + np.exp(-scores))

    # ListMLE of the squashed scores via a reverse running log-sum-exp:
    # sum_i [ log(sum_{j >= i} exp(s_j)) - s_i ].
    tail_logsumexp = np.logaddexp.accumulate(squashed[::-1])[::-1]
    loss = np.sum(tail_logsumexp - squashed)
    # lgamma(n + 1) == log(n!) without building the huge integer n!;
    # np.log(math.factorial(n)) fails once n! exceeds 64 bits (n >= 21).
    return loss / math.lgamma(n + 1)

def sigmoid(x):
    '''Numerically stable logistic sigmoid of a scalar: 1 / (1 + exp(-x)).

    math.exp(-x) raises OverflowError for x below about -709, so for
    negative x the algebraically equivalent exp(x) / (1 + exp(x)) is used
    instead; there exp(x) can only underflow harmlessly to 0.0.
    '''
    if x >= 0:
        return 1 / (1 + math.exp(-x))
    z = math.exp(x)
    return z / (1 + z)

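## Sensitivity to Score Range vs. Sensitivity to Correct Ranking
ListMLE heavily rewards a large score gap between item 1 and item N: the same correct ranking costs far less with scores `[50, 40, 30, 20, 10]` than with `[5, 4, 3, 2, 1]`. Normalized ListMLE is insensitive to this, because it min-max scales the scores before computing the loss.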

In [25]:
for scores in ([50, 40, 30, 20, 10], [5, 4, 3, 2, 1]):
    print(listmle(scores))
    print(normalized_listmle(scores))

0.0001362008198633191
0.9404697626258447
1.6129717464613917
0.9404697626258447


In [14]:
scores = [5, 6, 3, 2, 1]
print(listmle(scores))
print(normalized_listmle(scores))

2.1595703608692767
0.9419838634525912


In [15]:
scores = [5, 7, 3, 2, 1]
print(listmle(scores))
print(normalized_listmle(scores))

2.8989168332898747
0.9470181651420085


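## Sensitivity to range of scores
ListMLE is sensitive to the spread of the scores, though not to a constant shift: `[50, 49, 48, 47, 46]` gets the same loss as `[5, 4, 3, 2, 1]`. Normalized ListMLE, by contrast, returns the same value for every correctly ranked list of evenly spaced scores of a given length.

**Note:** the normalized ListMLE outputs in this and later sections were recorded with an earlier version of the functions. The same inputs printed 0.9404697626258447 in the first section but 0.7658302279538801 here. Re-run the notebook before comparing numbers across sections.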

In [1494]:
scores = [5, 4, 3, 2, 1]
print(listmle(scores))
print(normalized_listmle(scores))

1.6129717464613917
0.7658302279538801


In [1495]:
scores = [50, 40, 30, 20, 10]
print(listmle(scores))
print(normalized_listmle(scores))

0.00018160178023052254
0.7658302279538801


In [1496]:
scores = [50, 49, 48, 47, 46]
print(listmle(scores))
print(normalized_listmle(scores))

1.6129717464613904
0.7658302279538801


In [1497]:
scores = [60, 49, 48, 47, 46]
print(listmle(scores))
print(normalized_listmle(scores))

1.1610832879586894
0.8232429606901938


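## "Top Heavy" vs "Bottom Heavy"
Compare two ranking errors. In the first, the item that belongs last receives the highest score. In the second, the item that belongs first receives the lowest score.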

In [1475]:
scores = [5, 4, 3, 2, 6]
print(listmle(scores))
print(normalized_listmle(scores))

10.721130680216644
1.1465386738111343


In [1476]:
scores = [1, 5, 4, 3, 2]
print(listmle(scores))
print(normalized_listmle(scores))

5.612971746461392
0.9747078728099092


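## Sensitivity to length of list
Both ListMLE and normalized ListMLE depend on the length of the list: correctly ranked, evenly spaced scores of different lengths give different losses. For n equal scores, ListMLE gives log(n!) while normalized ListMLE gives 1.0.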

In [1480]:
scores = [3, 2, 1]
print(listmle(scores))
print(normalized_listmle(scores))

0.7208676519626027
0.6442531347799542


In [1481]:
scores = [5, 4, 3, 2, 1]
print(listmle(scores))
print(normalized_listmle(scores))

1.6129717464613917
0.7658302279538801


In [1482]:
scores = list(range(20, 6, -1))  # 20, 19, ..., 7
print(listmle(scores))
print(normalized_listmle(scores))

5.737123652372659
0.8707988057569969


In [1483]:
scores = [1] * 3
print(listmle(scores))
print(normalized_listmle(scores))

1.7917594692280547
1.0


In [1484]:
scores = [1] * 4
print(listmle(scores))
print(normalized_listmle(scores))

3.1780538303479453
1.0


In [1489]:
scores = np.linspace(1, 5, 4)[::-1]  # evenly spaced, descending
print(listmle(scores))
print(normalized_listmle(scores))


0.8225933226855302
0.7227566990100301


In [1490]:
scores = np.linspace(1, 5, 5)[::-1]  # evenly spaced, descending
print(listmle(scores))
print(normalized_listmle(scores))

1.6129717464613917
0.7658302279538801


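## Sanity Check: both losses decrease as the scores move closer to the correct ranking
**Note:** the normalized outputs below (about 0.03) look stale. The current `normalized_listmle` printed about 0.94 for similar five-item lists earlier in the notebook, so re-run these cells before relying on them.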

In [1447]:
scores = [4, 5, 3, 2, 1]
print(listmle(scores))
print(normalized_listmle(scores))

2.3579645005040084
0.031343008294906556


In [1448]:
scores = [4.5, 5, 3, 2, 1]
print(listmle(scores))
print(normalized_listmle(scores))

1.999359629365888
0.03056502616613068


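## Sensitivity to ties
**Note:** every output in this section is negative, which the current definitions cannot produce. Each ListMLE term $\log\sum_{j \ge i} e^{x_j} - x_i$ is non-negative. For example, `listmle([1, 1, 1, 1])` printed 3.178… in the list-length section but −3.178… here. These outputs come from an earlier version of the functions, so re-run before interpreting them.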

In [1206]:
scores = [0.] * 4
print(listmle(scores))
print(normalized_listmle(scores))

-3.1780538303479453
-0.7945134575869863


In [1207]:

scores = [1.] * 4
print(listmle(scores))
print(normalized_listmle(scores))

-3.1780538303479453
-0.7945134575869863


In [1208]:

scores = [2.] * 4
print(listmle(scores))
print(normalized_listmle(scores))

-3.1780538303479458
-0.7945134575869863


In [1218]:

scores = [2., 2., 1., 1.]
print(listmle(scores))
print(normalized_listmle(scores))

-4.561550852696365
-1.140387713174091


In [1219]:
scores = [4., 3., 2., 1.]
print(listmle(scores))
print(normalized_listmle(scores))

-7.161057350523799
-1.0742399239246159


In [1220]:
scores = [3., 3., 1., 1.]
print(listmle(scores))
print(normalized_listmle(scores))

-6.271846047842377
-1.140387713174091


In [1211]:
scores = [2., 1.]
print(listmle(scores))
print(normalized_listmle(scores))

-1.3132616875182226
-0.6566308437591114
