In [1]:
import itertools
import tensorflow as tf

def ot_sinkhorn(C, a, b, eps, itr):
    K = tf.exp(-C / eps)
    u = tf.ones_like(a)
    for _ in range(itr):
        v = b / tf.linalg.matvec(tf.transpose(K),u)
        u = a / tf.linalg.matvec(K,v)
    P = tf.reshape(u, (-1,1))*(K * tf.reshape(v, (1,-1)))
    return u, v, P

def ot_sort(x, eps=0.1, itr=10):
    l = x.shape[0]
    y = (
        tf.math.reduce_min(x)
        + (tf.range(l, dtype=tf.float32) * tf.math.reduce_max(x) / l)
    )
    # stop_gradient(y)
    C = (
        tf.repeat(y[None, :], l, axis=0)
        - tf.transpose(tf.repeat(x[None, :], l, axis=0))
    ) ** 2
    
    a = tf.ones_like(x) / l
    b = tf.ones_like(y) / l
    
    _, _, P = ot_sinkhorn(C, a, b, eps, itr)
    b_hat = tf.cumsum(b, axis=0)
    r = l**2 * tf.linalg.matvec(P, b_hat)
    s = l * tf.linalg.matvec(tf.transpose(P), x)
    return r, s    

In [2]:
x = tf.constant([6.5, 0.4, 1.5, 3.8,])
print("x =", x.numpy())

for eps, itr in itertools.product([1.0, 0.5, 0.1, 0.02, 0.01], [10, 3, 1]):
    r, s = ot_sort(x, eps, itr)
    print(
        "(eps, iter) =", (format(eps, '.2f'), itr),
        ", rank(x) =", r.numpy(),
        ", sort(x) =", s.numpy(),
    )

x = [6.5 0.4 1.5 3.8]
(eps, iter) = ('1.00', 10) , rank(x) = [3.9968197 1.1420645 1.8819754 3.0264602] , sort(x) = [0.5554942 1.3809325 3.6119862 6.6515865]
(eps, iter) = ('1.00', 3) , rank(x) = [3.9989045 1.1250924 1.852144  3.0912457] , sort(x) = [0.5925583 1.3958677 3.2627218 6.9488516]
(eps, iter) = ('1.00', 1) , rank(x) = [3.9995468 1.0959185 1.8018402 3.209512 ] , sort(x) = [0.67236686 1.3496025  2.7493653  7.428665  ]
(eps, iter) = ('0.50', 10) , rank(x) = [3.9999995 1.0188777 1.9609721 3.0395496] , sort(x) = [0.45122853 1.4556235  3.6357753  6.6573734 ]
(eps, iter) = ('0.50', 3) , rank(x) = [3.9999998 1.0119616 1.9392414 3.0985885] , sort(x) = [0.48653013 1.4220234  3.408082   6.883364  ]
(eps, iter) = ('0.50', 1) , rank(x) = [3.9999995 1.009408  1.923769  3.167692 ] , sort(x) = [0.51072514 1.3990191  3.1431413  7.1471143 ]
(eps, iter) = ('0.10', 10) , rank(x) = [3.9999998 1.        1.9999944 3.0011442] , sort(x) = [0.40000835 1.4999917  3.7956526  6.504347  ]
(eps, iter) = ('0

In [3]:
# batch, desc order対応版

def _t(x):
    return tf.transpose(x, perm=[0, 2, 1])

def _e(x):
    return tf.expand_dims(x,axis=1)

def ot_sinkhorn_batch(C, a, b, eps, itr):
    K = tf.exp(-C / eps)
    u = tf.ones_like(a)
    for _ in range(itr):
        v = b / tf.linalg.matvec(_t(K),u)
        u = a / tf.linalg.matvec(K,v)
    P = _e(u) * (K * _e(v))
    return u, v, P

def ot_sort_batch(x, eps=0.1, itr=10, desc_order=False):
    l = x.shape[1]
    m = x.shape[0]
    i = tf.range(l, dtype=tf.float32)
    
    y = (
        tf.math.reduce_min(x, axis=1)[:, None]
        + tf.repeat(i[None, :], m, axis=0)
        * tf.math.reduce_max(x, axis=1)[:, None] / l
    )
    if desc_order:
        y = tf.reverse(y, [-1])
    # stop_gradient(y)
    C = (
        tf.repeat(_e(y), l, axis=1)
        - _t(tf.repeat(_e(x), l, axis=1))
    ) ** 2

    a = tf.ones_like(x) / l
    b = tf.ones_like(x) / l
    _, _, P = ot_sinkhorn_batch(C, a, b, eps, itr)
    b_hat = tf.cumsum(b, axis=1)
    r = l**2 * tf.linalg.matvec(P, b_hat)
    s = l * tf.linalg.matvec(_t(P), x)
    return r, s

In [4]:
x = tf.constant([[6.5, 0.4, 1.5, 3.8,], [9,1,5,3,]])
r,s = ot_sort_batch(x)
print("x =", x)
print("rank(x) =", r)
print("sort(x) =", s)

x = tf.Tensor(
[[6.5 0.4 1.5 3.8]
 [9.  1.  5.  3. ]], shape=(2, 4), dtype=float32)
rank(x) = tf.Tensor(
[[3.9084735 1.011591  2.0002186 3.0358374]
 [4.        1.        3.        2.       ]], shape=(2, 4), dtype=float32)
sort(x) = tf.Tensor(
[[0.40464482 1.5001597  3.8395972  6.3556166 ]
 [1.         3.         5.         9.        ]], shape=(2, 4), dtype=float32)


In [5]:
r,s = ot_sort_batch(x, desc_order=True)
print("x =", x)
print("rank(x) =", r)
print("sort(x) =", s)

x = tf.Tensor(
[[6.5 0.4 1.5 3.8]
 [9.  1.  5.  3. ]], shape=(2, 4), dtype=float32)
rank(x) = tf.Tensor(
[[0.99999994 3.9537764  3.0000052  2.022238  ]
 [1.         4.         2.         3.        ]], shape=(2, 4), dtype=float32)
sort(x) = tf.Tensor(
[[6.504449  3.8400276 1.4999917 0.3953859]
 [9.        5.        3.        1.       ]], shape=(2, 4), dtype=float32)
