In [1]:
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import numpy as np
from distributions import mallows_kendall as mk
from distributions import permutil as pu
from distributions import sampling
import scipy as sp
import pandas as pd

In [None]:
# Relative permutation between two rankings of 4 items. Because sigma1 is the
# identity, sigma equals np.argsort(sigma2), so its inverse (printed below) is
# sigma2 itself.
sigma1 = np.array([0, 1, 2, 3])
sigma2 = np.array([1, 2, 0, 3])
sigma = sigma1.take(np.argsort(sigma2))
print(sigma.argsort())
# NOTE(review): presumably the unweighted Kendall distance of sigma from the identity — confirm in mallows_kendall.
mk.distance(sigma)

In [None]:
from typing import Optional, Union

def position_weighted_distance(sigma:np.ndarray, delta:Optional[Union[str, np.ndarray]]='dcg', k:Optional[int]=3, verbose:bool=True):
    """Position-weighted inversion distance of permutation ``sigma`` from the identity.

    Each inverted pair ``i < j`` with ``sigma[i] > sigma[j]`` contributes
    ``p_bar[i] * p_bar[j]``. Here ``p`` is the running sum of the per-position
    weights ``delta``, with ``p[0] = 1``. ``p_bar[i]`` is the mean weight over
    the positions between ``i`` and ``sigma[i]``, i.e.
    ``(p[i] - p[sigma[i]]) / (i - sigma[i])``; it is 1 for fixed points.
    With all-ones weights every ``p_bar`` is 1, so the result is the plain
    inversion count.
    NOTE(review): this appears to follow the position-weighted Kendall tau of
    Kumar & Vassilvitskii (2010) — confirm against the intended reference.

    Args:
        sigma: permutation of ``0..n-1`` (any array-like).
        delta: per-position weights. The options are:
            - ``None`` returns ``mk.distance(sigma)`` unchanged.
            - ``'dcg'`` uses ``delta[0] = 1`` and
              ``delta[i] = 1/log2(i+1) - 1/log2(i+2)`` for ``i >= 1``.
            - ``'topk'`` uses 1 for the first ``k`` positions and 0 after.
            - An array-like with at least ``n`` entries is used directly.
              ``delta[0]`` is ignored because only differences of ``p`` matter.
        k: cutoff for ``delta='topk'``; ignored otherwise.
        verbose: if True, print the weight vector that is used.

    Returns:
        The weighted inversion sum; 0.0 when ``sigma`` is the identity.

    Raises:
        ValueError: if ``sigma`` is not a permutation of ``0..n-1``, ``delta``
            is an unknown string or a too-short/non-1-D array, or
            ``delta='topk'`` is given with ``k=None``.
    """
    n = len(sigma)

    if delta is None:
        # Unweighted case: delegate to the library distance (the original
        # computed this but dropped the result).
        return mk.distance(sigma)

    sigma = np.asarray(sigma)
    # p[sigma] below indexes by sigma's values, so they must be exactly 0..n-1
    # (negative entries would otherwise wrap around silently).
    if not np.array_equal(np.sort(sigma), np.arange(n)):
        raise ValueError('sigma must be a permutation of 0..n-1')
    # Lossless after the check above; also makes an empty input index-safe.
    sigma = sigma.astype(np.intp)

    if isinstance(delta, str):
        _delta = np.ones(n)
        if delta == 'dcg':
            i = np.arange(1, n)
            _delta[1:] = 1/np.log2(i + 1) - 1/np.log2(i + 2)
        elif delta == 'topk':
            if k is None:
                # _delta[None:] would zero every weight and always yield 0.
                raise ValueError("delta='topk' requires an integer k")
            _delta[k:] = 0
        else:
            raise ValueError(f"unknown delta {delta!r}; expected 'dcg', 'topk', None or an array")
    else:
        # Accept any array-like; np.array copies, so the caller's array is never aliased.
        _delta = np.array(delta, dtype=float)
        if _delta.ndim != 1 or len(_delta) < n:
            raise ValueError(f'delta must be a 1-D array with at least {n} entries')
    if verbose:
        print('Delta', _delta)

    # p[0] = 1 and p[i] = p[i-1] + delta[i]; only differences of p are used.
    p = np.ones(n)
    p[1:] += np.cumsum(_delta[1:n])

    # Average weight between each position and where sigma sends it; fixed
    # points (zero displacement) get weight 1 instead of 0/0.
    offset = np.arange(n) - sigma
    moved = offset != 0
    p_bar = np.ones(n)
    p_bar[moved] = (p[moved] - p[sigma[moved]]) / offset[moved]

    # Sum p_bar[i] * p_bar[j] over inverted pairs i < j (sigma[i] > sigma[j]).
    V = 0.0
    for j in range(1, n):
        earlier_larger = sigma[:j] > sigma[j]
        V += p_bar[j] * p_bar[:j][earlier_larger].sum()

    return V

In [None]:
# DCG-weighted distance for the sigma built above (prints the weights first); evaluates to 0.125.
position_weighted_distance(sigma=sigma, delta='dcg')

In [None]:
# Second example: sigma2 rotates the last item to the front. As before, sigma1
# is the identity, so sigma equals np.argsort(sigma2) and the printed inverse
# is sigma2 itself.
sigma1 = np.array([0, 1, 2, 3])
sigma2 = np.array([1, 2, 3, 0])
sigma = sigma1.take(np.argsort(sigma2))
print(sigma.argsort())
# NOTE(review): presumably the unweighted Kendall distance of sigma from the identity — confirm in mallows_kendall.
mk.distance(sigma)

In [None]:
# DCG-weighted distance for the rotated sigma (prints the weights first); evaluates to (1 - 1/log2(5))**2 / 3 ≈ 0.108.
position_weighted_distance(sigma=sigma, delta='dcg')