In [1]:
# install our package
! pip install .

Processing /Users/flora/21Spring/663-Final-Project
Building wheels for collected packages: Sinkhorn-663
  Building wheel for Sinkhorn-663 (setup.py) ... done
  Created wheel for Sinkhorn-663: filename=Sinkhorn_663-0.1-cp38-cp38-macosx_10_13_x86_64.whl size=75374 sha256=f24eba0358fa1783f11780fded77681894c4bf5a3feaf3fd77cd2e7b4f5bd8b0
  Stored in directory: /Users/flora/Library/Caches/pip/wheels/02/4e/3e/69d82bfc22d8b08c274774ab6a96edd186d47484910d84d7ef
Successfully built Sinkhorn-663
Installing collected packages: Sinkhorn-663
  Attempting uninstall: Sinkhorn-663
    Found existing installation: Sinkhorn-663 0.1
    Uninstalling Sinkhorn-663-0.1:
      Successfully uninstalled Sinkhorn-663-0.1
Successfully installed Sinkhorn-663-0.1


In [1]:
from Sinkhorn_663 import Sinkhorn, log_domain_sinkhorn_2, Sinkhorn_numba, Sinkhorn_numba_parallel
from Sinkhorn_663 import sample_to_prob_vec, sample_to_prob_vec_nD
from Skh_cpp import Sinkhorn_cpp
import numpy as np

## 2.1 Basic implementation

First, we implement the algorithm described above in plain Python as a baseline, using the *numpy* package for the matrix operations. The function returns the Sinkhorn distance and the iteration count. As described above, the matrix $P^{\lambda}$ can be computed by Sinkhorn's fixed-point iteration $(u, v) \leftarrow (r./Kv, c./K'u)$. Furthermore, the two updates can be merged into a single update $u\leftarrow 1./(\tilde K(c./K'u))$, where $\tilde K = \text{diag}(1./r)K$. We use $\max|u_{new} - u| \leq tol$ as the stopping criterion.
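The equivalence of the two-step and the merged update can be checked numerically. The snippet below is only a small sketch on made-up data (the sizes, the random inputs, and the value `lam = 5.0` are illustrative assumptions, not the settings used elsewhere in this notebook):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5

# made-up probability vectors and a made-up cost matrix (illustration only)
r = rng.random(n)
r /= r.sum()
c = rng.random(n)
c /= c.sum()
x, y = rng.random(n), rng.random(n)
M = np.abs(x[:, None] - y[None, :])

lam = 5.0
K = np.exp(-lam * M)
K_tilde = np.diag(1 / r) @ K

u = np.ones(n) / n

# two-step update: v <- c ./ K'u, then u <- r ./ Kv
v = c / (K.T @ u)
u_two_step = r / (K @ v)

# merged update: u <- 1 ./ (K_tilde (c ./ K'u))
u_merged = 1 / (K_tilde @ (c / (K.T @ u)))

print(np.allclose(u_two_step, u_merged))
```

Both expressions agree because $\tilde K w = (Kw)./r$ for any vector $w$, so taking the elementwise reciprocal gives $r./(Kw)$.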

In [3]:
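def Sinkhorn_plain(r, C, M, lamda, tol = 1e-6, maxiter = 10000):
    # drop empty bins of r and the matching rows of M
    M = M[r > 0]
    r = r[r > 0]
    # elementwise kernel K = exp(-lamda * M)
    K = np.exp(-lamda * M)
    # number of target histograms (columns of C)
    N = np.shape(C)[1]
    u = np.ones((len(r), N)) / len(r)
    K_tilde = np.diag(1/r) @ K
    # placeholders, not used below
    d_prev = np.repeat(2., N)
    d = np.ones(N) + 0.5
    for i in range(maxiter):
        # merged fixed-point update of u
        u_new = 1/(K_tilde @ (C / (K.T @ u)))
        # stop once the largest change in u is within tol
        if np.max(np.abs(u_new - u)) <= tol:
            break
        u = u_new
    v = C/(K.T @ u)
    # Sinkhorn distance for each column of C
    d = np.sum(u * ((K * M) @ v), axis = 0)
    return d[0], i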

### 2.1.1 Test and Check results

To verify the correctness of our implementation, we generate simulated data and test the algorithm on it. The optimized versions are later tested on the same data to check their results and compare performance.

As a start, we consider the case where the empirical measures $r$ and $c$ come from the same distribution. We draw two groups of samples from $\text{Beta}(1, 2)$, use the function *sample_to_prob_vec* to convert the samples into the probability vectors and cost matrix used as inputs, and compute the distance with our function. Here we choose sample size $N = 3000$, $\text{maxiter} = 10000$, $\text{tol} = 10^{-6}$, and $\lambda = 20$. The result is close to $0$, as expected. It is slightly larger than $0$ because of the entropy regularization.

In [3]:
# create simulation data
N = 3000
np.random.seed(1)
u1 = np.random.beta(a = 1, b = 2, size = N)
v1 = np.random.beta(a = 1, b = 2, size = N)
M1, r1, c1 = sample_to_prob_vec(u1, v1)
c1 = c1.reshape(-1, 1)
# set parameters
maxiter = 10000
tol = 1e-6
lamda = 20

In [5]:
Sinkhorn_plain(r1, c1, M1, lamda, tol, maxiter)

(0.047851415045381665, 36)

Next, we test on simulated data from two distributions with a known OT distance. We use $\text{Uniform}(0, 1)$ and $\text{Uniform}(10, 11)$, whose OT distance is $10$. The output of *Sinkhorn_plain* is close to $10$ and slightly larger than $10$, as expected.

In [4]:
np.random.seed(1)
u2 = np.random.uniform(0, 1, size = N)
v2 = np.random.uniform(10, 11, size = N)
M2, r2, c2 = sample_to_prob_vec(u2, v2, 1)
c2 = c2.reshape(-1, 1)

In [7]:
Sinkhorn_plain(r2, c2, M2, lamda, tol, maxiter)

(10.061646663537982, 230)
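
As an independent cross-check of the reference value, the unregularized distance between the two samples can be computed directly. This is a sketch under the assumption that the ground cost is the absolute difference $|x - y|$; in that case, for two equal-size one-dimensional samples with uniform weights, the optimal plan matches sorted values. It does not use *sample_to_prob_vec*, so any binning done by that function is not reflected here.

```python
import numpy as np

np.random.seed(1)
N = 3000
u2 = np.random.uniform(0, 1, size=N)
v2 = np.random.uniform(10, 11, size=N)

# With equal sample sizes and cost |x - y| (an assumption), the optimal
# transport plan pairs the i-th smallest of u2 with the i-th smallest of v2.
w1_empirical = np.mean(np.abs(np.sort(u2) - np.sort(v2)))
print(w1_empirical)
```

The printed value should be close to $10$, which gives a baseline for judging how much the entropy regularization adds.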

### 2.1.2 Profiling

Before turning to optimization, we first profile the plain implementation with the *profile* module.

In [8]:
import profile

In [9]:
profile.run("Sinkhorn_plain(r1, c1, M1, lamda, tol, maxiter)")

         346 function calls in 2.250 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 :0(abs)
        2    0.000    0.000    0.000    0.000 :0(array)
        2    0.000    0.000    0.000    0.000 :0(empty)
        1    0.000    0.000    2.250    2.250 :0(exec)
        2    0.000    0.000    0.000    0.000 :0(getattr)
       43    0.001    0.000    0.018    0.000 :0(implement_array_function)
        1    0.000    0.000    0.000    0.000 :0(isinstance)
       38    0.000    0.000    0.000    0.000 :0(items)
        3    0.000    0.000    0.000    0.000 :0(len)
       38    0.002    0.000    0.002    0.000 :0(reduce)
        1    0.000    0.000    0.000    0.000 :0(repeat)
        1    0.000    0.000    0.000    0.000 :0(setprofile)
        1    0.013    0.013    0.013    0.013 :0(zeros)
       37    0.001    0.000    0.006    0.000 <__array_function__ internals>:2(amax)
        2  
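
A long report like the one above is easier to read when it is sorted and truncated. The sketch below uses the standard-library modules `cProfile` and `pstats` on a hypothetical stand-in function (`toy_workload` is made up for illustration, it is not part of our package); the same pattern could be applied to *Sinkhorn_plain*.

```python
import cProfile
import io
import pstats

import numpy as np


def toy_workload(n=200, reps=20):
    # hypothetical stand-in for the function being profiled
    a = np.ones((n, n))
    for _ in range(reps):
        a = a @ a / n
    return a


profiler = cProfile.Profile()
profiler.enable()
toy_workload()
profiler.disable()

# sort by cumulative time and keep only the top 5 entries
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer)
stats.sort_stats("cumulative").print_stats(5)
report = buffer.getvalue()
print(report)
```

`cProfile` has lower overhead than the pure-Python `profile` module, so its timings are usually closer to the unprofiled run.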

## 2.2 Optimization

(Describe numba, cpp, log_domain versions)

### Time

In [10]:
%timeit Sinkhorn_plain(r2, c2, M2, lamda, tol, maxiter)
%timeit Sinkhorn(r2, c2, M2, lamda, tol, maxiter)
%timeit Sinkhorn_cpp(r2, c2, M2, lamda, tol, maxiter)

3.07 s ± 436 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.79 s ± 45.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
604 ms ± 61.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Overflow

Another issue is that the algorithm needs $K = e^{-\lambda M}$. When $\lambda$ becomes large, this computation leaves the floating-point range: the entries of $K$ can underflow to zero, the subsequent divisions break down, and the distance evaluates to nan. To mitigate this problem, we implement the algorithm in the log domain and compile it with numba.
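
To illustrate the idea, the sketch below shows the kernel vanishing in floating point and a log-domain iteration that avoids forming $K$ explicitly. It is a minimal illustration on a made-up $2 \times 2$ problem, not the implementation in our package: the names `logsumexp` and `log_sinkhorn_sketch`, the fixed iteration count, and the toy inputs are all assumptions, and the inputs `r` and `c` are assumed strictly positive.

```python
import numpy as np


def logsumexp(a, axis):
    # stable log(sum(exp(a))) along `axis`: subtract the maximum first
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def log_sinkhorn_sketch(r, c, M, lam, n_iter=100):
    # hypothetical helper: iterate on log u and log v instead of u and v
    log_K = -lam * M
    log_u = np.zeros_like(r)
    log_v = np.zeros_like(c)
    for _ in range(n_iter):
        # log u <- log r - log(K v), log v <- log c - log(K' u)
        log_u = np.log(r) - logsumexp(log_K + log_v[None, :], axis=1)
        log_v = np.log(c) - logsumexp(log_K + log_u[:, None], axis=0)
    # transport plan P = diag(u) K diag(v), assembled in the log domain
    P = np.exp(log_u[:, None] + log_K + log_v[None, :])
    return np.sum(P * M)


# made-up toy problem (illustration only)
r = np.array([0.5, 0.5])
c = np.array([0.5, 0.5])
M = np.array([[10.0, 11.0], [11.0, 10.0]])
lam = 100.0

K = np.exp(-lam * M)  # every entry underflows to 0.0
print(K)

d = log_sinkhorn_sketch(r, c, M, lam)
print(d)
```

On this toy problem the plan concentrates on the diagonal, so the distance is close to $10$ even though every entry of $K$ is zero in double precision.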


In [10]:
lamdas = list(np.arange(0, 120, 20))
for lam in lamdas:
    print("lambda = ", lam, Sinkhorn(r2, c2, M2, lam, tol, maxiter))

lambda =  0 (10.06164666353842, 0)
lambda =  20 (10.061646663537982, 230)
lambda =  40 (10.061646663537982, 287)
lambda =  60 (10.060680452304231, 2839)
lambda =  80 (nan, 9999)
lambda =  100 (nan, 9999)


In [12]:
for lam in lamdas:
    if lam != 0:
        print("lambda = ", lam, Sinkhorn(r2, c2, M2, lam, tol, maxiter, log_domain = True))

lambda =  20 [10.06164666]
lambda =  40 [10.06164666]
lambda =  60 [10.06164666]
lambda =  80 [nan]
lambda =  100 [nan]
