In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import ot
import time
import scipy.stats as st

from solvers import solvers_L2_UOT as sl2
from solvers import solver_kl_UOT as skl
from solvers import solver_semirelax_L2_UOT as ssrl2

# Use font type 42 for both the PDF and the PostScript backends so saved
# figures embed the fonts the same way in either format.
plt.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

### Data generation
Both the source and the target samples are 10-dimensional vectors drawn from Gaussian distributions ($n = m = 500$), and the cost matrix $C$ is the pairwise squared Euclidean distance (the default metric of `ot.dist`), normalized by its maximum entry.

In [2]:
# --- Experiment configuration ------------------------------------------------
n = 500                        # number of source points (the target uses the same n)
dim = 10                       # dimension of every sample
nb_run = 5                     # number of independent repetitions

# Source samples: standard-normal draws multiplied by `sigma`, shifted by `m`.
sigma = 2 * np.eye(dim)
m = np.arange(dim)
# Target samples: standard-normal draws multiplied by `sigma2`, shifted by `m2`.
sigma2 = np.eye(dim)
m2 = np.arange(5, dim + 5)

# Per-run containers; entry i of each list belongs to run i.
a_list, b_list, C_list = [], [], []   # source weights, target weights, cost matrices
coef_list = []                        # tolerance coefficient used with the Celer solver
balanced_distance = []                # transport cost of the balanced OT solution

for i in range(nb_run):
    # Draw the two point clouds, each from its own seed so runs are repeatable.
    np.random.seed(i)
    xs = np.random.randn(n, dim).dot(sigma) + m
    np.random.seed(i + 100)
    xt = np.random.randn(n, dim).dot(sigma2) + m2

    # Draw the marginal weights and rescale each of them to sum to one.
    # NOTE(review): `a` reuses seed i (same as xs) and `b` reuses seed i+100
    # (same as xt), so weights and samples come from the same random stream --
    # confirm this coupling is intended.
    np.random.seed(i)
    a = np.random.normal(100, 10, n)
    a /= a.sum()
    a_list.append(a)
    np.random.seed(i + 100)
    b = np.random.normal(100, 10, n)
    b /= b.sum()
    b_list.append(b)

    # Pairwise cost between source and target points, scaled so its max is 1.
    C = ot.dist(xs, xt)
    C = C / C.max()
    C_list.append(C)

    # Solution of balanced OT and the transport cost it achieves.
    G0 = ot.emd(a, b, C)
    balanced_distance.append((G0 * C).sum())

    # Tolerance coefficient for the Celer solver: squared norm of the stacked
    # marginals divided by their total length.
    y = np.concatenate((a, b))
    coef = np.linalg.norm(y) ** 2 / y.size
    coef_list.append(coef)

### L2-penalized UOT with different solvers and different $\lambda$

In [3]:
%%capture
# Benchmark of five solvers for L2-penalized UOT (Lasso/celer, Lasso/CD, BFGS,
# regularization path, multiplicative update). For every run and every
# regularization value in `reg_list`, each solver is timed and its transport
# plan and transport cost sum(G * C) are stored.
#
# Durations are measured with time.perf_counter(): unlike time.time(), it is
# monotonic and uses the highest-resolution clock available, so a system clock
# adjustment during a solve cannot corrupt a timing.
maxiter = 200000
tol = 1e-7
tol_dual = 1e-4
reg_list = np.geomspace(10, 20000, num=15)

# Results over all runs: one inner list per run, one entry per value of reg.
ot_cost_lasso_all = []
ot_cost_lasso_cd_all = []
ot_cost_bfgs_all = []
ot_cost_regpath_all = []
ot_cost_mu_all = []

timing_lasso_all = []
timing_lasso_cd_all = []
timing_bfgs_all = []
timing_regpath_all = []
timing_mu_all = []

sol_lasso_all = []
sol_lasso_cd_all = []
sol_bfgs_all = []
sol_regpath_all = []
sol_mu_all = []

for i in range(nb_run):
    # Results of the current run only.
    ot_cost_lasso = []
    ot_cost_lasso_cd = []
    ot_cost_bfgs = []
    ot_cost_regpath = []
    ot_cost_mu = []

    timing_lasso = []
    timing_lasso_cd = []
    timing_bfgs = []
    timing_regpath = []
    timing_mu = []

    sol_lasso = []
    sol_lasso_cd = []
    sol_bfgs = []
    sol_regpath = []
    sol_mu = []

    # Problem instance generated for run i.
    a = a_list[i]
    b = b_list[i]
    C = C_list[i]
    for reg in reg_list:
        # Lasso (celer): the tolerance is scaled by the per-run coefficient.
        start = time.perf_counter()
        Gl = sl2.ot_ul2_solve_lasso_celer(C, a, b, reg, maxiter, tol_dual*coef_list[i])
        timing_lasso.append(time.perf_counter()-start)
        sol_lasso.append(Gl)
        ot_cost_lasso.append(np.sum(Gl*C))

        # Lasso (coordinate descent)
        start = time.perf_counter()
        Gl2 = sl2.ot_ul2_solve_lasso_cd(C, a, b, reg, maxiter, tol_dual)
        timing_lasso_cd.append(time.perf_counter()-start)
        sol_lasso_cd.append(Gl2)
        ot_cost_lasso_cd.append(np.sum(Gl2*C))

        # BFGS tol = 1e-12
        start = time.perf_counter()
        Gss = sl2.ot_ul2_solve_BFGS(C, a, b, reg, maxiter, 1e-12)
        timing_bfgs.append(time.perf_counter()-start)
        sol_bfgs.append(Gss)
        ot_cost_bfgs.append(np.sum(Gss*C))

        # Regularization path evaluated at this value of reg
        start = time.perf_counter()
        Grp, _, _, _ = sl2.ot_ul2_reg_path(a, b, C, reg, savePi=False)
        timing_regpath.append(time.perf_counter()-start)
        sol_regpath.append(Grp)
        ot_cost_regpath.append(np.sum(Grp*C))

        # Multiplicative update
        start = time.perf_counter()
        Gmu = sl2.ot_ul2_solve_mu(C, a, b, reg, maxiter, tol)
        timing_mu.append(time.perf_counter()-start)
        sol_mu.append(Gmu)
        ot_cost_mu.append(np.sum(Gmu*C))

    ot_cost_lasso_all.append(ot_cost_lasso)
    ot_cost_lasso_cd_all.append(ot_cost_lasso_cd)
    ot_cost_bfgs_all.append(ot_cost_bfgs)
    ot_cost_regpath_all.append(ot_cost_regpath)
    ot_cost_mu_all.append(ot_cost_mu)

    timing_lasso_all.append(timing_lasso)
    timing_lasso_cd_all.append(timing_lasso_cd)
    timing_bfgs_all.append(timing_bfgs)
    timing_regpath_all.append(timing_regpath)
    timing_mu_all.append(timing_mu)

    sol_lasso_all.append(sol_lasso)
    sol_lasso_cd_all.append(sol_lasso_cd)
    sol_bfgs_all.append(sol_bfgs)
    sol_regpath_all.append(sol_regpath)
    sol_mu_all.append(sol_mu)


### KL-penalized UOT with different solvers and different $\lambda$

In [4]:
%%capture
# Benchmark of two solvers for KL-penalized UOT (BFGS and multiplicative
# update). For every run and every regularization value in `reg_list`, each
# solver is timed and its transport plan and transport cost sum(G * C) are
# stored.
#
# Durations are measured with time.perf_counter(): unlike time.time(), it is
# monotonic and uses the highest-resolution clock available.

# Results over all runs: one inner list per run, one entry per value of reg.
ot_cost_bfgs_kl_all = []
ot_cost_mu_kl_all = []
timing_bfgs_kl_all = []
timing_mu_kl_all = []
sol_bfgs_kl_all = []
sol_mu_kl_all = []

maxiter = 200000
tol = 1e-7

# NOTE: this overwrites the `reg_list` used for the L2-penalized benchmark.
reg_list = np.geomspace(0.01, 150, num=15)
for i in range(nb_run):
    # Results of the current run only.
    ot_cost_bfgs_kl = []
    ot_cost_mu_kl = []
    timing_bfgs_kl = []
    timing_mu_kl = []
    sol_bfgs_kl = []
    sol_mu_kl = []

    # Problem instance generated for run i.
    a = a_list[i]
    b = b_list[i]
    C = C_list[i]
    for reg in reg_list:
        # BFGS
        start = time.perf_counter()
        G1 = skl.ot_ukl_solve_BFGS(C, a, b, reg, maxiter, tol)
        timing_bfgs_kl.append(time.perf_counter()-start)
        sol_bfgs_kl.append(G1)
        ot_cost_bfgs_kl.append(np.sum(G1*C))

        # Multiplicative update
        start = time.perf_counter()
        Gmu = skl.ot_ukl_solve_mu(C, a, b, reg, maxiter, tol)
        timing_mu_kl.append(time.perf_counter()-start)
        sol_mu_kl.append(Gmu)
        ot_cost_mu_kl.append(np.sum(Gmu*C))

    ot_cost_bfgs_kl_all.append(ot_cost_bfgs_kl)
    ot_cost_mu_kl_all.append(ot_cost_mu_kl)
    timing_bfgs_kl_all.append(timing_bfgs_kl)
    timing_mu_kl_all.append(timing_mu_kl)
    sol_bfgs_kl_all.append(sol_bfgs_kl)
    sol_mu_kl_all.append(sol_mu_kl)
    

### Regularization path with different sample sizes

In [None]:
%%capture
# Scaling study of the regularization path: for each run and each sample size
# in `n_list`, generate a fresh problem, compute the regularization path and
# record the number of iterations it reports and the elapsed time.
#
# Durations are measured with time.perf_counter(): unlike time.time(), it is
# monotonic and uses the highest-resolution clock available.
niter_list_all = []      # one inner list per run, one entry per sample size
timings_list_all = []    # same layout, elapsed seconds

# generate 10-dimensional source and target samples
# NOTE: this cell reassigns dim, sigma, m, sigma2, m2 and nb_run, and leaves
# n, a, b and C set to the values of the last generated problem.
dim = 10
sigma = np.eye(dim)*2
m = np.arange(0,dim)
sigma2 = np.eye(dim)
m2 = np.arange(5,dim+5)
n_list = np.linspace(100, 1000, 10, endpoint=True, dtype=int)
nb_run = 5
# compute the regularization path for different sizes of source and target samples
for i in range(nb_run):
    niter_list = []
    timings_list = []
    for n in n_list:
        # Seeds depend on both the sample size and the run index.
        np.random.seed(n+i)
        xs = np.random.randn(n, dim).dot(sigma) + m
        np.random.seed(2*n+i)
        xt = np.random.randn(n, dim).dot(sigma2) + m2
        # Marginal weights, rescaled to sum to one.
        np.random.seed(n+i)
        a = np.random.normal(100,10,n)
        a /= np.sum(a)
        np.random.seed(2*n+i)
        b = np.random.normal(100,10,n)
        b /= np.sum(b)
        # Pairwise cost, scaled so that its maximum is 1.
        C = ot.dist(xs, xt)
        C = C/C.max()

        # Only the solver call is timed; data generation is excluded.
        start = time.perf_counter()
        _, _, _, niter = sl2.ot_ul2_reg_path(a, b, C, savePi=False)
        t = time.perf_counter()-start
        niter_list.append(niter)
        timings_list.append(t)

    niter_list_all.append(niter_list)
    timings_list_all.append(timings_list)
    

### Computation of confidence interval

In [None]:
def _mean_confidence_interval(samples, confidence=0.95):
    """Student-t confidence interval of the column-wise mean of `samples`.

    Parameters
    ----------
    samples : 2-D array
        One row per run, one column per experimental setting.
    confidence : float
        Confidence level of the interval.

    Returns
    -------
    (low, high) : pair of arrays with one entry per column of `samples`.
    """
    # Degrees of freedom come from the data itself rather than from the global
    # `nb_run`, which another cell reassigns; this stays correct even if the
    # benchmarks were run with different numbers of repetitions.
    dof = samples.shape[0] - 1
    return st.t.interval(confidence, dof, loc=np.mean(samples, 0), scale=st.sem(samples))


# Timings of the L2-penalized UOT solvers (rows: runs, columns: values of reg).
data_bfgs = np.array(timing_bfgs_all)
data_mu = np.array(timing_mu_all)
data_regpath = np.array(timing_regpath_all)
data_celer = np.array(timing_lasso_all)
data_cd = np.array(timing_lasso_cd_all)

# calculate confidence interval of l2 penalized UOT
low_CI_bound_bfgs, high_CI_bound_bfgs = _mean_confidence_interval(data_bfgs)
low_CI_bound_mu, high_CI_bound_mu = _mean_confidence_interval(data_mu)
low_CI_bound_regpath, high_CI_bound_regpath = _mean_confidence_interval(data_regpath)
low_CI_bound_celer, high_CI_bound_celer = _mean_confidence_interval(data_celer)
low_CI_bound_cd, high_CI_bound_cd = _mean_confidence_interval(data_cd)

# Timings of the KL-penalized UOT solvers.
data_bfgs_kl = np.array(timing_bfgs_kl_all)
data_mu_kl = np.array(timing_mu_kl_all)
# calculate confidence interval of kl penalized UOT
low_CI_bound_bfgs_kl, high_CI_bound_bfgs_kl = _mean_confidence_interval(data_bfgs_kl)
low_CI_bound_mu_kl, high_CI_bound_mu_kl = _mean_confidence_interval(data_mu_kl)

# Timings of the regularization path (rows: runs, columns: sample sizes).
data_time_regpath = np.array(timings_list_all)
# calculate confidence interval of timings of regularization path
low_CI_bound, high_CI_bound = _mean_confidence_interval(data_time_regpath)


### Figure 3

In [None]:
# Three-panel timing figure: regularization-path time vs. sample size (left),
# L2-penalized UOT solver times vs. lambda (center), KL-penalized UOT solver
# times vs. lambda (right). Shaded bands are the confidence intervals computed
# from the repeated runs.

# Apply the text/font settings BEFORE any figure or axes exists: matplotlib
# picks up several of these rcParams when artists are created, so updating
# them after the axes are built gives a different rendering on the first run
# of this cell than on a re-run.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "sans-serif",
    "font.sans-serif": ["Helvetica"],
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    "font.size": 16})

fig = plt.figure(figsize=(16,4), constrained_layout=True)
gs = GridSpec(1, 3, figure=fig, width_ratios=[1,1,1])
ax1 = fig.add_subplot(gs[0])
ax2 = fig.add_subplot(gs[1])
ax3 = fig.add_subplot(gs[2])

# Left panel
ax1.loglog(n_list, np.mean(data_time_regpath, 0), linestyle = '-', marker = '^', markersize=4)
ax1.fill_between(n_list, low_CI_bound, high_CI_bound, alpha=0.25)

ax1.set_xlabel("number of points (n=m)", fontsize=14)
ax1.set_ylabel("Time(sec)", fontsize=14)
ax1.grid(which='both')
ax1.set_title("Regularization path", fontsize=18)


# center panel
# `reg_list` was overwritten by the KL benchmark, so rebuild the grid that was
# used for the L2-penalized benchmark.
reg_list = np.geomspace(10, 20000, num=15)
ax2.loglog(reg_list, np.mean(data_bfgs, 0), label = 'BFGS', linestyle = '-.', marker = '^', markersize=4, c='green')
ax2.loglog(reg_list, np.mean(data_mu, 0), label = 'Multiplicative update', linestyle = '--', marker = 'v', markersize=4, c='orange')
ax2.loglog(reg_list, np.mean(data_celer, 0), label = 'Lasso (celer)', linestyle = ':', marker = 'd', markersize=4, c='brown')
ax2.loglog(reg_list, np.mean(data_cd, 0), label = 'Lasso (CD)', linestyle = '--', marker = 'D', markersize=4, c='c')
ax2.loglog(reg_list, np.mean(data_regpath, 0), label = 'Regularization path', linestyle = '-', marker = 's', markersize=4, c='m')

# plot confidence interval
ax2.fill_between(reg_list, low_CI_bound_bfgs, high_CI_bound_bfgs, alpha=0.25, color = 'green')
ax2.fill_between(reg_list, low_CI_bound_mu, high_CI_bound_mu, alpha=0.25, color = 'orange')
ax2.fill_between(reg_list, low_CI_bound_celer, high_CI_bound_celer, alpha=0.25, color = 'brown')
ax2.fill_between(reg_list, low_CI_bound_cd, high_CI_bound_cd, alpha=0.25, color = 'c')
ax2.fill_between(reg_list, low_CI_bound_regpath, high_CI_bound_regpath, alpha=0.25, color = 'm')

# Raw strings keep the LaTeX backslashes literal without relying on Python's
# handling of unrecognized escape sequences (which emits a warning).
ax2.set_xlabel(r"$\lambda$", fontsize=14)
ax2.set_ylabel("Time (sec)", fontsize=14)
ax2.set_title(r'$\ell_2$-penalized UOT', fontsize=18)
ax2.legend(fontsize=12, loc = 'lower right')
ax2.grid()

# Right panel
# Grid of regularization values used for the KL-penalized benchmark.
reg_list = np.geomspace(0.01, 150, num=15)
ax3.loglog(reg_list, np.mean(data_bfgs_kl, 0), label = 'BFGS', linestyle = '--', marker = '^', markersize=4, c='green')
ax3.loglog(reg_list, np.mean(data_mu_kl, 0), label = 'Multiplicative update', linestyle = '-.', marker = 'v', markersize=4, c='orange')

# plot confidence interval
ax3.fill_between(reg_list, low_CI_bound_bfgs_kl, high_CI_bound_bfgs_kl, alpha=0.25, color = 'green')
ax3.fill_between(reg_list, low_CI_bound_mu_kl, high_CI_bound_mu_kl, alpha=0.25, color = 'orange')

ax3.set_xlabel(r"$\lambda$",fontsize=14)
ax3.set_ylabel("Time (sec)", fontsize=14)
ax3.set_title('KL-penalized UOT', fontsize=18)
ax3.legend(fontsize=12, loc = 'lower right')
ax3.grid()

# plt.savefig('simu.pdf', bbox_inches='tight', pad_inches=0)
plt.savefig('simu.jpg', bbox_inches='tight', pad_inches=0)
plt.show()


### Empirical complexity of the regularization path

In [None]:
# Estimate the empirical complexity exponent as the slope of a degree-1 fit of
# log(mean time) against log(n): once over every sample size, once skipping
# the three smallest sizes.
log_n = np.log(n_list)
mean_time = np.mean(data_time_regpath, 0)

z1 = np.polyfit(log_n, np.log(mean_time), 1)
z2 = np.polyfit(log_n[3:], np.log(mean_time[3:]), 1)

# np.polyfit returns the highest-degree coefficient first, so index 0 is the slope.
print('empirical complexity with all points:', z1[0])
print('empirical complexity with last 7 points:', z2[0])
