In [1]:
import itertools
import tensorflow as tf

In [2]:
def ot_sinkhorn(C, a, b, eps, iter):
    K = tf.exp(-C / eps)
    u = tf.ones_like(a)
    for _ in range(iter):
        v = b / tf.linalg.matvec(tf.transpose(K),u)
        u = a / tf.linalg.matvec(K,v)
    return u, v, tf.reshape(u, (-1,1))*(K * tf.reshape(v, (1,-1)))

def ot_sort(x, eps=0.1, iter=10):
    l = x.shape[0]
    y = (
        tf.math.reduce_min(x)
        + (tf.range(l, dtype=tf.float32) * tf.math.reduce_max(x) / l)
    )
    C = (
        tf.repeat(y[None, :], l, axis=0)
        - tf.transpose(tf.repeat(x[None, :], l, axis=0))
    ) ** 2
    
    a = tf.ones_like(x) / l
    b = tf.ones_like(y) / l
    
    _, _, P = ot_sinkhorn(C, a, b, eps, iter)
    b_hat = tf.cumsum(b, axis=0)
    r = l**2 * tf.linalg.matvec(P, b_hat)
    s = l * tf.linalg.matvec(tf.transpose(P), x)
    return r, s    

In [3]:
x = tf.constant([6.5, 0.4, 1.5, 3.8,])

In [4]:
for eps, iter in itertools.product(tf.range(0.01, 0.5, 0.1).numpy(), [1, 3, 10]):
    r, s = ot_sort(x, eps, iter)
    print(
        "(eps, iter) =", (format(eps, '.2f'), iter),
        ", rank(x) =", r.numpy(),
        ", sort(x) =", s.numpy(),
    )

(eps, iter) = ('0.01', 1) , rank(x) = [nan nan nan nan] , sort(x) = [nan nan nan nan]
(eps, iter) = ('0.01', 3) , rank(x) = [nan nan nan nan] , sort(x) = [nan nan nan nan]
(eps, iter) = ('0.01', 10) , rank(x) = [nan nan nan nan] , sort(x) = [nan nan nan nan]
(eps, iter) = ('0.11', 1) , rank(x) = [4.        1.        1.9999833 3.0021536] , sort(x) = [0.40002507 1.499975   3.791817   6.508183  ]
(eps, iter) = ('0.11', 3) , rank(x) = [3.9999998 1.        1.9999833 3.002135 ] , sort(x) = [0.40002507 1.499975   3.7918868  6.508113  ]
(eps, iter) = ('0.11', 10) , rank(x) = [4.        1.        1.9999833 3.002073 ] , sort(x) = [0.40002504 1.499975   3.7921221  6.507878  ]
(eps, iter) = ('0.21', 1) , rank(x) = [4.        1.0000129 1.9968747 3.0371933] , sort(x) = [0.4046831 1.4953212 3.658656  6.64134  ]
(eps, iter) = ('0.21', 3) , rank(x) = [3.9999998 1.0000131 1.9969128 3.0323765] , sort(x) = [0.40462542 1.4953785  3.6769617  6.623034  ]
(eps, iter) = ('0.21', 10) , rank(x) = [3.9999998 1.00