In [None]:
import numpy as np
import pylab as pl
import scipy as sp
import torch

In [None]:
from scipy.spatial.distance import cdist

# ---- Problem size -----------------------------------------------------------
n = 500      # number of points in each cloud
dim = 10     # ambient dimension
nb_run = 5   # number of independent problem instances

# ---- Gaussian parameters ----------------------------------------------------
# Source cloud: standard normal draws right-multiplied by 2*I, shifted by 0..dim-1.
m = np.arange(0, dim)
sigma = np.eye(dim) * 2
# Target cloud: standard normal draws right-multiplied by I, shifted by 5..dim+4.
m2 = np.arange(5, dim + 5)
sigma2 = np.eye(dim)

# ---- Per-run containers (float32 tensors, one entry per run) ----------------
a_cpu_list, b_cpu_list, C_cpu_list = [], [], []
a_gpu_list, b_gpu_list, C_gpu_list = [], [], []


def _cpu_and_gpu(array):
    """Return float32 tensor copies of `array`: one on the CPU, one on 'cuda'."""
    on_cpu = torch.tensor(array, dtype=torch.float32)
    on_gpu = torch.tensor(array, dtype=torch.float32, device='cuda')
    return on_cpu, on_gpu


for i in range(nb_run):
    # Point clouds: seed i for the source, seed i + 100 for the target.
    np.random.seed(i)
    xs = np.dot(np.random.randn(n, dim), sigma) + m
    np.random.seed(i + 100)
    xt = np.dot(np.random.randn(n, dim), sigma2) + m2

    # Source weights: N(100, 10) draws normalised to sum to one.
    # NOTE(review): the generator is re-seeded with the same seed as for `xs`,
    # so weights and points start from the same random stream - confirm this
    # is intended.
    np.random.seed(i)
    a = np.random.normal(100, 10, n)
    a = a / a.sum()
    a_cpu, a_gpu = _cpu_and_gpu(a)
    a_cpu_list.append(a_cpu)
    a_gpu_list.append(a_gpu)

    # Target weights: same construction, re-seeded with the target seed.
    np.random.seed(i + 100)
    b = np.random.normal(100, 10, n)
    b = b / b.sum()
    b_cpu, b_gpu = _cpu_and_gpu(b)
    b_cpu_list.append(b_cpu)
    b_gpu_list.append(b_gpu)

    # Squared-Euclidean cost matrix, rescaled so that its largest entry is 1.
    C = cdist(xs, xt, 'sqeuclidean')
    C = C / C.max()
    C_cpu, C_gpu = _cpu_and_gpu(C)
    C_cpu_list.append(C_cpu)
    C_gpu_list.append(C_gpu)

In [None]:
import time

def ot_ukl_torch(C,a,b,reg,nitermax=10000,P0=None, tol=1e-15,verbose=False):
    """Multiplicative-update solver for KL-penalized unbalanced OT.

    Targets  min_pi <pi, C> + reg * KL(pi 1 | a) + reg * KL(pi^T 1 | b).

    Parameters
    ----------
    C : torch.Tensor
        Cost matrix; indexed as (len(a), len(b)) by the broadcasting below.
    a, b : torch.Tensor
        1-D source / target weight vectors.
    reg : float
        Marginal penalty weight; enters through the kernel exp(-C / (2 * reg)).
    nitermax : int
        Maximum number of update steps.
    P0 : torch.Tensor, optional
        Initial plan. Defaults to the outer product of `a` and `b`.
    tol : float
        Stop once the Frobenius norm of the change between two consecutive
        plans drops below this value.
    verbose : bool
        When true, print that norm at every step.

    Returns
    -------
    torch.Tensor
        The plan after the last executed update (the initial plan when
        `nitermax` is 0).
    """
    kernel = torch.exp(-C / reg / 2)
    P = a[:, None] * b[None, :] if P0 is None else P0

    for _ in range(nitermax):
        previous = P
        # Square roots of target marginal / current marginal; the 1e-16 keeps
        # the division finite when a row or column of the plan sums to zero.
        row_scale = torch.sqrt(a / (previous.sum(1) + 1e-16))
        col_scale = torch.sqrt(b / (previous.sum(0) + 1e-16))
        P = previous * kernel * row_scale[:, None] * col_scale[None, :]

        step = torch.sqrt(torch.sum((P - previous) ** 2))
        if verbose:
            print(step)
        if step < tol:
            break

    return P

def ot_ul2_torch(C,a,b,reg,nitermax=10000,P0=None, tol=1e-15,verbose=False):
    """Multiplicative-update solver for L2-penalized unbalanced OT.

    Targets (up to constant factors)
        min_pi <pi, C> + reg * L2(pi 1 - a) + reg * L2(pi^T 1 - b).
    The update treats both marginals identically, so the penalty is L2 on
    each of them; an earlier header comment labelled the first term "KL".

    Parameters
    ----------
    C : torch.Tensor
        Cost matrix; indexed as (len(a), len(b)) by the broadcasting below.
    a, b : torch.Tensor
        1-D source / target weight vectors.
    reg : float
        Marginal penalty weight.
    nitermax : int
        Maximum number of update steps.
    P0 : torch.Tensor, optional
        Initial plan. Defaults to the outer product of `a` and `b`.
    tol : float
        Stop once the Frobenius norm of the change between two consecutive
        plans drops below this value.
    verbose : bool
        When true, print that norm at every step.

    Returns
    -------
    torch.Tensor
        The plan after the last executed update (the initial plan when
        `nitermax` is 0). Its dtype and device follow the inputs.
    """
    if P0 is None:
        P = a[:, None] * b[None, :]
    else:
        P = P0

    # Loop-invariant numerator max(0, a_i + b_j - C_ij / (2 * reg)).
    # clamp() keeps the dtype/device of the inputs and avoids materialising
    # a separate float32 zero matrix.
    abt = torch.clamp(a[:, None] + b[None, :] - C / reg / 2, min=0)

    for _ in range(nitermax):
        Pold = P
        # Column sums + row sums, broadcast to the shape of P; the 1e-16
        # keeps the division finite when both sums vanish.
        Pd = P.sum(0, keepdim=True) + P.sum(1, keepdim=True) + 1e-16
        P = P * abt / Pd

        # Evaluate the step size once and reuse it for logging and stopping.
        err = torch.sqrt(torch.sum((P - Pold) ** 2))
        if verbose:
            print(err)
        if err < tol:
            break

    return P

In [None]:
# Solver settings shared by every timed call.
maxiter = 200000
tol = 1e-7

# One repetition per pre-generated problem instance.
nb_run = len(C_cpu_list)


def _timed_call(solver, C, a, b, reg, nitermax, tol):
    """Return the wall-clock seconds taken by one call to `solver`.

    Uses the monotonic `time.perf_counter` clock. For CUDA inputs the device
    is synchronised before starting and before stopping the clock, so the
    measurement covers exactly the work queued by this call and does not
    rely on the solver happening to synchronise internally.
    """
    on_gpu = C.is_cuda
    if on_gpu:
        torch.cuda.synchronize()
    start = time.perf_counter()
    solver(C, a, b, reg, nitermax=nitermax, tol=tol)
    if on_gpu:
        torch.cuda.synchronize()
    return time.perf_counter() - start


def _benchmark(solver, regs, nitermax, tol):
    """Time `solver` on CPU and GPU for every run and every value in `regs`.

    Returns `(cpu_timings, gpu_timings)`; each is a list holding, per run,
    the list of timings in seconds ordered like `regs`.
    """
    cpu_all, gpu_all = [], []
    for i in range(nb_run):
        cpu_run, gpu_run = [], []
        for reg in regs:
            # CPU first, then GPU, for each regularisation value.
            cpu_run.append(_timed_call(
                solver, C_cpu_list[i], a_cpu_list[i], b_cpu_list[i], reg, nitermax, tol))
            gpu_run.append(_timed_call(
                solver, C_gpu_list[i], a_gpu_list[i], b_gpu_list[i], reg, nitermax, tol))
        cpu_all.append(cpu_run)
        gpu_all.append(gpu_run)
    return cpu_all, gpu_all


# KL-penalized solver.
reg_list = np.geomspace(0.01, 150, num=15)
timing_mu_kl_cpu_all, timing_mu_kl_gpu_all = _benchmark(ot_ukl_torch, reg_list, maxiter, tol)

# L2-penalized solver (uses a different regularisation grid).
reg_list = np.geomspace(10, 20000, num=15)
timing_mu_l2_cpu_all, timing_mu_l2_gpu_all = _benchmark(ot_ul2_torch, reg_list, maxiter, tol)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# Set the rc parameters *before* any figure or axes exists: several of them
# are read when the artists are created, so updating them afterwards makes
# the first execution of this cell render differently from a re-run.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "sans-serif",
    "font.sans-serif": ["Helvetica"],
    "pdf.fonttype": 42,   # TrueType fonts in the PDF output
    "ps.fonttype": 42,
    "font.size": 16})

fig = plt.figure(figsize=(11,4), constrained_layout=True)
gs = GridSpec(1, 2, figure=fig, width_ratios=[1,1])
ax2 = fig.add_subplot(gs[0])
ax3 = fig.add_subplot(gs[1])


def _draw_timing_panel(ax, regs, gpu_runs, cpu_runs, title, gpu_linestyle, cpu_linestyle):
    """Plot the run-averaged GPU and CPU timings against `regs` on log-log axes."""
    ax.loglog(regs, np.mean(gpu_runs, 0), label='GPU', linestyle=gpu_linestyle,
              marker='^', markersize=4, c='green')
    ax.loglog(regs, np.mean(cpu_runs, 0), label='CPU', linestyle=cpu_linestyle,
              marker='v', markersize=4, c='orange')
    # Raw string: "\l" is not a valid Python escape sequence.
    ax.set_xlabel(r"$\lambda$", fontsize=14)
    ax.set_ylabel("Time (sec)", fontsize=14)
    ax.set_title(title, fontsize=18)
    ax.legend(fontsize=12, loc='lower right')
    ax.grid()


# left panel -- this grid must match the one used when timing the L2 solver
reg_list = np.geomspace(10, 20000, num=15)
_draw_timing_panel(ax2, reg_list, timing_mu_l2_gpu_all, timing_mu_l2_cpu_all,
                   'L2-penalized UOT', gpu_linestyle='-.', cpu_linestyle='--')

# right panel -- this grid must match the one used when timing the KL solver
reg_list = np.geomspace(0.01, 150, num=15)
_draw_timing_panel(ax3, reg_list, timing_mu_kl_gpu_all, timing_mu_kl_cpu_all,
                   'KL-penalized UOT', gpu_linestyle='--', cpu_linestyle='-.')

plt.savefig('simu_cpu_gpu.pdf', bbox_inches='tight', pad_inches=0)
plt.savefig('simu_cpu_gpu.jpg', bbox_inches='tight', pad_inches=0)
plt.show()