In [1]:
# Required modules
from datetime import datetime
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import norm
from PIL import Image, ImageDraw
from tqdm import trange, tqdm
import ot

In [3]:
# Check for CUDA availability and remember which device the solvers should use
device = torch.device("cpu")  # updated by check_for_CUDA()
def check_for_CUDA():
    """Print CUDA diagnostics and set the module-level ``device``.

    Returns
    -------
    torch.device
        ``torch.device("cuda")`` when CUDA is available, otherwise ``torch.device("cpu")``.
        The same value is stored in the global ``device``; the ``global`` statement is what
        makes the assignment visible outside this function (without it only a local is bound).
    """
    global device
    cuda_available = torch.cuda.is_available()
    print(f"Is CUDA supported by this system? {cuda_available}")
    if not cuda_available:
        device = torch.device("cpu")
        return device
    print(f"CUDA version: {torch.version.cuda}")
    cuda_id = torch.cuda.current_device()  # ID of the current CUDA device
    print(f"ID of current CUDA device: {cuda_id}")
    print(f"Name of current CUDA device: {torch.cuda.get_device_name(cuda_id)}")
    device = torch.device("cuda")
    return device

In [5]:
# Print CUDA diagnostics: availability, plus version and device ID/name when CUDA is present
check_for_CUDA()

Is CUDA supported by this system? False


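We set the basic parameters for quadratically regularized optimal transport: the dimension $n$ of the probability vectors, the regularization parameter $\varepsilon$, and an iteration cap.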

In [48]:
# Parameters of the regularized transport problem
n, epsilon, max_iter = 10_000, 1, 100_000
# n:        dimension -- the probability vectors live in R^n
# epsilon:  regularization parameter
# max_iter: iteration cap (NOTE(review): the solver calls below rely on their own num_iter default)

In [10]:
def save_result(coupling_matrices, filenames):
    """Save each tensor to ``<filename>.pt``; tensors and names are paired by position.

    Raises
    ------
    ValueError
        If the number of tensors and file names differ (``zip`` would otherwise
        silently drop the unmatched items).
    """
    matrices = list(coupling_matrices)
    names = list(filenames)
    if len(matrices) != len(names):
        raise ValueError(f"got {len(matrices)} tensors but {len(names)} file names")
    for matrix, filename in zip(matrices, names):
        torch.save(matrix, filename + ".pt")

def load_result(filenames):
    """Load the tensors written by ``save_result`` and return them as a list, in order.

    ``torch.load`` unpickles the file, so only load files from trusted sources.
    """
    return [torch.load(filename + ".pt") for filename in filenames]

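We discretize $[-5, 5]$ with $n$ equally spaced points and define five probability vectors $\mu_1, \dots, \mu_5 \in \mathbb{R}^n$ by normalizing Gaussian densities (for $\mu_2$, a two-component Gaussian mixture) to unit mass. The experiments below use $a = \mu_1$ and $b = \mu_2$ as marginals.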

In [17]:
# Equally spaced grid on [-5, 5]; a NumPy array, because torch.from_numpy below requires one
x = np.linspace(-5, 5, n)

# Densities of the five test measures evaluated on the grid
mu_1 = norm.pdf(x, loc=0, scale=0.4)
mu_2 = 0.5*norm.pdf(x, loc=-2, scale=0.2) + 0.5*norm.pdf(x, loc=1, scale=0.5)  # two-component mixture
mu_3 = norm.pdf(x, loc=2, scale=0.6)
mu_4 = norm.pdf(x, loc=3, scale=0.8)
mu_5 = norm.pdf(x, loc=2.75, scale=0.7)

# Normalize each density to unit mass (a probability vector) and convert to float64 torch tensors
mu_1, mu_2, mu_3, mu_4, mu_5 = (
    torch.from_numpy(mu / mu.sum()) for mu in (mu_1, mu_2, mu_3, mu_4, mu_5)
)

def plot_measures(x, mu_1, mu_2, mu_3, mu_4, mu_5):
    """Draw the five measures as curves over ``x`` on one 8x6 figure with a legend."""
    styled_curves = [
        (mu_1, r'$\mu_1$', 'blue'),
        (mu_2, r'$\mu_2$', 'green'),
        (mu_3, r'$\mu_3$', 'red'),
        (mu_4, r'$\mu_4$', 'orange'),
        (mu_5, r'$\mu_5$', 'purple'),
    ]
    plt.figure(figsize=(8, 6))
    for curve, name, colour in styled_curves:
        plt.plot(x, curve, label=name, color=colour)
    plt.legend()
    plt.show()

In [None]:
save_result([mu_1, mu_2, mu_3, mu_4, mu_5], [f"mu_{k}" for k in range(1, 6)])  # writes mu_1.pt ... mu_5.pt

We construct the cost matrices $C \in \mathbb{R}_{+}^{n\times n}$ (the strong Coulomb cost additionally takes the value $+\infty$ on its diagonal).

In [85]:
# Normalized squared-Euclidean cost between the grid indices 0, ..., n-1
def compute_euclidean_cost(n):
    """Return the n x n squared-Euclidean cost on the index grid, scaled to a maximum of 1.

    ``C[i, j] = (i - j)**2 / (n - 1)**2`` in float64. This matches ``ot.dist`` with its
    default ``metric='sqeuclidean'``, i.e. the cost is the *squared* distance.

    For ``n <= 1`` every entry is 0, so the matrix is returned unscaled instead of
    dividing 0 by 0.
    """
    print("Compute Euclidean Cost...")
    grid = torch.arange(n, dtype=torch.float64)  # points 0, 1, ..., n-1
    C = (grid[:, None] - grid[None, :]) ** 2     # exact for integer grid points in float64
    if C.numel() == 0:
        return C
    peak = C.max()
    return C / peak if peak > 0 else C

# Weak Coulomb cost: 1/|i - j| off the diagonal and the large finite value n on it, all divided by n
def compute_weak_coulomb_cost(n, N=2, batch_size=5):
    """Return the weak Coulomb cost matrix for ``n`` grid indices.

    ``C[i, j] = 1 / (n * |i - j|)`` for ``i != j`` and ``C[i, i] = 1``, in torch's
    default dtype.

    Parameters
    ----------
    n : int
        Number of grid points.
    N : int
        Number of marginals; only the two-marginal case ``N=2`` is supported.
    batch_size : int
        Number of rows filled per vectorized step (bounds the size of temporaries);
        the result does not depend on it.

    Raises
    ------
    ValueError
        If ``N != 2``.
    """
    print("Computing Weak Coulomb Cost...")
    if N != 2:
        raise ValueError(f"only N=2 is supported, got N={N}")
    rows_per_step = max(1, int(batch_size))
    grid = torch.arange(n, dtype=torch.float64)
    C = torch.empty((n, n))
    for start in range(0, n, rows_per_step):
        stop = min(start + rows_per_step, n)
        # 1/|i - j| for this slab of rows, computed in float64 and cast to C's dtype on assignment.
        # The diagonal entries become 1/0 = inf here and are replaced just below.
        C[start:stop] = 1.0 / (grid[start:stop, None] - grid[None, :]).abs()
    C.fill_diagonal_(float(n))
    return C / n

# Helper: every N-element index combination of range(n), built in batches of leading indices
def generate_combinations_batched(n, N, batch_size=10000):
    """Return all strictly increasing N-tuples of indices from ``0, ..., n-1``.

    The rows are in lexicographic order, i.e. the same set and order as
    ``torch.combinations(torch.arange(n), r=N)``, with dtype ``torch.long`` and shape
    ``(C(n, N), N)``. Tuples are generated per smallest index, so combinations that span
    different batches are included. ``batch_size`` leading indices are grouped per chunk;
    the result does not depend on it.

    Raises
    ------
    ValueError
        If ``N < 1``.
    """
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N}")
    if N == 1:
        return torch.arange(n, dtype=torch.long).unsqueeze(1)
    step = max(1, int(batch_size))
    chunks = []
    for start in range(0, n, step):
        stop = min(start + step, n)
        for first in range(start, stop):
            # All tuples whose smallest index is `first`: pair it with every
            # (N-1)-combination of the strictly larger indices.
            tails = torch.combinations(torch.arange(first + 1, n, dtype=torch.long), r=N - 1)
            heads = torch.full((tails.shape[0], 1), first, dtype=torch.long)
            chunks.append(torch.cat((heads, tails), dim=1))
    if not chunks:
        return torch.empty((0, N), dtype=torch.long)
    return torch.cat(chunks, dim=0)

# Strong Coulomb cost: 1/|i - j| off the diagonal and positive infinity on the diagonal entries
def compute_strong_coulomb_cost(n, N=2, batch_size=5):
    """Return the strong Coulomb cost matrix for ``n`` grid indices.

    ``C[i, j] = 1 / |i - j|`` for ``i != j`` and ``C[i, i] = +inf``, in torch's default dtype.

    Parameters
    ----------
    n : int
        Number of grid points.
    N : int
        Number of marginals; only the two-marginal case ``N=2`` is supported.
    batch_size : int
        Number of rows filled per vectorized step (bounds the size of temporaries);
        the result does not depend on it.

    Raises
    ------
    ValueError
        If ``N != 2``.
    """
    print("Computing Strong Coulomb Cost...")
    if N != 2:
        raise ValueError(f"only N=2 is supported, got N={N}")
    rows_per_step = max(1, int(batch_size))
    grid = torch.arange(n, dtype=torch.float64)
    C = torch.empty((n, n))
    for start in range(0, n, rows_per_step):
        stop = min(start + rows_per_step, n)
        # 1/|i - j| for this slab of rows, computed in float64 and cast to C's dtype on assignment
        C[start:stop] = 1.0 / (grid[start:stop, None] - grid[None, :]).abs()
    C.fill_diagonal_(float("inf"))  # explicit, although 1/0 already produced inf there
    return C

In [87]:
def plot_cost_matrix(matrix):
    """Show a 2-D matrix as a heatmap with 1-based row/column coordinates and a colorbar.

    Row 1 is drawn at the top and labelled 1 on the y-axis; columns run left to right.
    """
    n_rows, n_cols = np.shape(matrix)[:2]
    # extent = (left, right, bottom, top): x spans the columns and y the rows; bottom > top
    # keeps the y labels aligned with imshow's default origin='upper'.
    plt.imshow(matrix, interpolation='nearest', cmap=plt.cm.inferno,
               extent=(0.5, n_cols + 0.5, n_rows + 0.5, 0.5))
    plt.colorbar()
    plt.show()

In [89]:
# Build the three cost matrices on the n-point grid
C_euc = compute_euclidean_cost(n=n)
C_scou = compute_strong_coulomb_cost(n=n, batch_size=1)
C_wcou = compute_weak_coulomb_cost(n=n, batch_size=1)

Compute Euclidean Cost...
Computing Strong Coulomb Cost...
Computing Weak Coulomb Cost...


Let us plot the three cost matrices (Euclidean, weak Coulomb and strong Coulomb) side by side.

In [98]:
def plot_coupling_matrices(matrices, titles, save=False):
    """Plot matrices side by side as heatmaps that share one horizontal colorbar.

    Parameters
    ----------
    matrices : sequence of 2-D arrays or tensors
        Matrices to draw, one panel each.
    titles : sequence of str
        One title per matrix; also joined with ``_`` to form the file name when saving.
    save : bool
        If True, write the figure to ``matrices_plot_<titles>.png`` before showing it.

    Raises
    ------
    ValueError
        If ``matrices`` is empty or ``titles`` has a different length.
    """
    count = len(matrices)
    if count == 0:
        raise ValueError("plot_coupling_matrices needs at least one matrix")
    if len(titles) != count:
        raise ValueError(f"expected {count} titles, got {len(titles)}")

    # Layout: one row for up to three matrices, 2x2 for four, rows of three beyond that.
    if count <= 3:
        nrows, ncols = 1, count
    elif count == 4:
        nrows, ncols = 2, 2
    else:
        ncols = 3
        nrows = -(-count // ncols)  # ceiling division
    figsize = None if count == 1 else (14, 14)
    # squeeze=False always returns a 2-D array of Axes, so the single-panel case flattens too.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    for ax, matrix, title in zip(axes, matrices, titles):
        n_rows, n_cols = np.shape(matrix)[:2]
        # extent = (left, right, bottom, top): x spans the columns and y the rows; bottom > top
        # puts row 1 at the top, consistent with imshow's default origin='upper'.
        im = ax.imshow(matrix, interpolation='nearest', cmap=plt.cm.inferno,
                       extent=(0.5, n_cols + 0.5, n_rows + 0.5, 0.5))
        ax.set_title(title)
    for ax in axes[count:]:
        ax.set_visible(False)  # unused panels of the grid

    fig.colorbar(im, ax=axes[:count], shrink=0.5, location="bottom")

    # Save before plt.show(): after the figure is shown and closed, saving can yield an empty image.
    if save:
        filename = f"matrices_plot_{'_'.join(titles)}.png"
        fig.savefig(filename)
        print("Plot saved as: " + filename)
    plt.show()

In [None]:
plot_coupling_matrices(
    matrices=[C_euc, C_wcou, C_scou],
    titles=["Euclidean", "Weak Coulomb", "Strong Coulomb"],
)

print(C_wcou)

In [None]:
def transfer_to_GPU(args, target=None):
    """Move the tensors in ``args`` to a device and return them as a new list.

    ``Tensor.to`` is not in-place (it returns a new tensor), so callers must use the
    returned list. Entries that are not tensors (e.g. Python numbers such as ``epsilon``
    or sizes) are passed through unchanged instead of raising ``AttributeError``.

    Parameters
    ----------
    args : iterable
        Objects to transfer.
    target : torch.device or str, optional
        Destination. Defaults to the module-level ``device`` when it holds a device;
        otherwise CUDA if available, else the CPU.

    Returns
    -------
    list
        The transferred objects, in input order.
    """
    print("Transferring data to CUDA GPU...")
    if target is None:
        if isinstance(device, torch.device) or (isinstance(device, str) and device):
            target = device
        else:
            target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return [var.to(target) if torch.is_tensor(var) else var for var in tqdm(args)]

In [13]:
def sinkhorn_quadratic_cyclic_projection(C: torch.Tensor, a: torch.Tensor, b: torch.Tensor, epsilon: float, num_iter: int = 50000,
                                         convergence_error: float = 1e-8, log=False) -> torch.Tensor:
    """Quadratically regularized OT between ``a`` and ``b`` by cyclic projection.

    The dual potentials ``f`` (length n) and ``g`` (length m) are updated alternately until
    the L1 change of both falls below ``convergence_error`` or ``num_iter`` iterations ran.

    Parameters
    ----------
    C : torch.Tensor
        Cost matrix of shape (len(a), len(b)); ``+inf`` entries are allowed.
    a, b : torch.Tensor
        1-D marginals.
    epsilon : float
        Regularization parameter.
    num_iter : int
        Maximum number of iterations.
    convergence_error : float
        Stopping tolerance on the L1 change of ``f`` and ``g``.
    log : bool
        If True, print the per-iteration changes.

    Returns
    -------
    torch.Tensor
        The plan ``(f[:, None] + g[None, :] - C).clamp(min=0) / epsilon`` on the CPU.
    """
    print("Cyclic Projection")

    # Use CUDA if possible. Tensor.to returns a new tensor, so the names must be rebound.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    C, a, b = C.to(dev), a.to(dev), b.to(dev)

    n = a.size()[0]
    m = b.size()[0]
    f = torch.zeros_like(a)
    g = torch.zeros_like(b)

    for it in trange(num_iter):
        f_prev = f
        g_prev = g
        # The scheme uses rho = (C - f_prev ⊕ g_prev)_+ in the terms (rho + g - C) and
        # (rho + f_new - C). Since max(x, 0) + y = max(x + y, y), these equal
        # max(-f_prev, g_prev - C) and max(f_new - f_prev - g_prev, f_new - C), which stay
        # finite where C = +inf instead of evaluating inf - inf = nan.
        f = (epsilon * a - torch.maximum(-f_prev[:, None], g_prev[None, :] - C).sum(1)) / m
        g = (epsilon * b - torch.maximum((f - f_prev)[:, None] - g_prev[None, :], f[:, None] - C).sum(0)) / n
        f_diff = (f_prev - f).abs().sum()
        g_diff = (g_prev - g).abs().sum()
        if log:
            print(f"Iteration {it}")
            print(f"f_diff {f_diff}")
            print(f"g_diff {g_diff}")
        if f_diff < convergence_error and g_diff < convergence_error:
            break

    cyclic_projection = ((f[:, None] + g[None, :] - C).clamp(min=0) / epsilon).cpu()  # Retrieve result to CPU

    return cyclic_projection

In [14]:
def sinkhorn_quadratic_gradient_descent(C: torch.Tensor, a: torch.Tensor, b: torch.Tensor, epsilon: float, num_iter: int = 50000,
                                        convergence_error: float = 1e-8, log=False) -> torch.Tensor:
    """Quadratically regularized OT between ``a`` and ``b`` by gradient steps on the dual.

    Each iteration forms ``P = (f[:, None] + g[None, :] - C).clamp(min=0) / epsilon`` and moves
    ``f`` and ``g`` against the marginal residuals ``P.sum(1) - a`` and ``P.sum(0) - b`` with
    step ``epsilon / (n + m)``.

    Parameters
    ----------
    C : torch.Tensor
        Cost matrix of shape (len(a), len(b)).
    a, b : torch.Tensor
        1-D marginals.
    epsilon : float
        Regularization parameter.
    num_iter : int
        Maximum number of iterations.
    convergence_error : float
        Stopping tolerance on the L1 change of ``f`` and ``g``.
    log : bool
        If True, print the per-iteration changes.

    Returns
    -------
    torch.Tensor
        The plan built from the final potentials, on the CPU.
    """
    print("Gradient Descent")

    # Use CUDA if possible. Tensor.to returns a new tensor, so the names must be rebound.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    C, a, b = C.to(dev), a.to(dev), b.to(dev)

    n = a.size()[0]
    m = b.size()[0]
    f = torch.zeros_like(a)
    g = torch.zeros_like(b)
    step = 1.0 / (m + n)

    for it in trange(num_iter):
        f_prev = f.clone()
        g_prev = g.clone()

        P = (f[:, None] + g[None, :] - C).clamp(min=0) / epsilon

        f -= step * epsilon * (P.sum(1) - a)
        g -= step * epsilon * (P.sum(0) - b)

        f_diff = (f_prev - f).abs().sum()
        g_diff = (g_prev - g).abs().sum()

        if log:
            print(f"Iteration {it}")
            print(f"f_diff {f_diff}")
            print(f"g_diff {g_diff}")

        if f_diff < convergence_error and g_diff < convergence_error:
            break

    gradient_descent = ((f[:, None] + g[None, :] - C).clamp(min=0) / epsilon).cpu()  # Retrieve result to CPU

    return gradient_descent

In [15]:
def sinkhorn_quadratic_fixed_point_iteration(C: torch.Tensor, a: torch.Tensor, b: torch.Tensor, epsilon: float, num_iter: int = 50000,
                                             convergence_error: float = 1e-8, log=False) -> torch.Tensor:
    """Quadratically regularized OT between ``a`` and ``b`` by a fixed-point iteration on the dual.

    From the current plan ``P = (f[:, None] + g[None, :] - C).clamp(min=0) / epsilon`` the scaled
    residuals ``v = -epsilon * (P.sum(1) - a)`` and ``u = -epsilon * (P.sum(0) - b)`` are formed;
    each potential is shifted by its residual minus a share of the residual's total, divided by the
    other marginal's length.

    Parameters
    ----------
    C : torch.Tensor
        Cost matrix of shape (len(a), len(b)).
    a, b : torch.Tensor
        1-D marginals.
    epsilon : float
        Regularization parameter.
    num_iter : int
        Maximum number of iterations.
    convergence_error : float
        Stopping tolerance on the L1 change of ``f`` and ``g``.
    log : bool
        If True, print the per-iteration changes.

    Returns
    -------
    torch.Tensor
        The plan built from the final potentials, on the CPU.
    """
    print("Fixed Point Iteration")

    # Use CUDA if possible. Tensor.to returns a new tensor, so the names must be rebound.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    C, a, b = C.to(dev), a.to(dev), b.to(dev)

    n = a.size()[0]
    m = b.size()[0]
    f = torch.zeros_like(a)
    g = torch.zeros_like(b)

    for it in trange(num_iter):
        f_prev = f.clone()
        g_prev = g.clone()

        P = (f[:, None] + g[None, :] - C).clamp(min=0) / epsilon
        v = - epsilon * (P.sum(1) - a)
        f += (v - v.sum() / (2 * n)) / m
        u = - epsilon * (P.sum(0) - b)
        g += (u - u.sum() / (2 * m)) / n

        f_diff = (f_prev - f).abs().sum()
        g_diff = (g_prev - g).abs().sum()

        if log:
            print(f"Iteration {it}")
            print(f"f_diff {f_diff}")
            print(f"g_diff {g_diff}")

        if f_diff < convergence_error and g_diff < convergence_error:
            break

    fixed_point_iteration = ((f[:, None] + g[None, :] - C).clamp(min=0) / epsilon).cpu()

    return fixed_point_iteration

In [16]:
def sinkhorn_quadratic_nesterov_gradient_descent(C: torch.Tensor, a: torch.Tensor, b: torch.Tensor, epsilon: float, num_iter: int = 50000,
                                                 convergence_error: float = 1e-8, log=False) -> torch.Tensor:
    """Quadratically regularized OT between ``a`` and ``b`` by Nesterov-accelerated dual gradient steps.

    Each iteration extrapolates the potentials with momentum ``k / (k + 3)`` (``k`` being the
    iteration counter), evaluates ``P = (f_p[:, None] + g_p[None, :] - C).clamp(min=0) / epsilon``
    at the extrapolated point, and takes a gradient step of size ``epsilon / (n + m)`` from there.

    Parameters
    ----------
    C : torch.Tensor
        Cost matrix of shape (len(a), len(b)).
    a, b : torch.Tensor
        1-D marginals.
    epsilon : float
        Regularization parameter.
    num_iter : int
        Maximum number of iterations.
    convergence_error : float
        Stopping tolerance on the L1 change of ``f`` and ``g`` between consecutive iterates.
    log : bool
        If True, print the per-iteration changes.

    Returns
    -------
    torch.Tensor
        The plan built from the final potentials, on the CPU.
    """
    print("Nesterov Gradient Descent")

    # Use CUDA if possible. Tensor.to returns a new tensor, so the names must be rebound.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    C, a, b = C.to(dev), a.to(dev), b.to(dev)

    n = a.size()[0]
    m = b.size()[0]
    f = torch.zeros_like(a)
    g = torch.zeros_like(b)
    step = 1.0 / (m + n)

    f_previous = f
    g_previous = g

    for it in trange(num_iter):
        # Nesterov momentum depends on the iteration counter, not on the problem size
        momentum = it / (it + 3)
        f_p = f + momentum * (f - f_previous)
        g_p = g + momentum * (g - g_previous)

        P = (f_p[:, None] + g_p[None, :] - C).clamp(min=0) / epsilon

        f_new = f_p - step * epsilon * (P.sum(1) - a)
        g_new = g_p - step * epsilon * (P.sum(0) - b)

        # Change between consecutive iterates
        f_diff = (f - f_new).abs().sum()
        g_diff = (g - g_new).abs().sum()

        f_previous = f
        g_previous = g

        f = f_new
        g = g_new

        if log:
            print(f"Iteration {it}")
            print(f"f_diff {f_diff}")
            print(f"g_diff {g_diff}")

        if f_diff < convergence_error and g_diff < convergence_error:
            break

    nesterov_gradient_descent = ((f[:, None] + g[None, :] - C).clamp(min=0) / epsilon).cpu()

    return nesterov_gradient_descent

In [17]:
def sinkhorn_quadratic_nesterov_gradient_descent3(cost: torch.Tensor, marg1: torch.Tensor, marg2: torch.Tensor, marg3: torch.Tensor,
                                                  epsilon: float, num_iter: int = 50000, convergence_error: float = 1e-8, log=False) -> torch.Tensor:
    """Three-marginal quadratically regularized OT by Nesterov-accelerated dual gradient steps.

    The potentials ``p1``, ``p2``, ``p3`` are extrapolated with momentum ``k / (k + 3)``
    (``k`` being the iteration counter). The plan
    ``P = (p1[:, None, None] + p2[None, :, None] + p3[None, None, :] - cost).clamp(min=0) / epsilon``
    is evaluated at the extrapolated point, and a gradient step is taken against the
    marginal residuals.

    Parameters
    ----------
    cost : torch.Tensor
        Cost tensor of shape (len(marg1), len(marg2), len(marg3)).
    marg1, marg2, marg3 : torch.Tensor
        1-D marginals.
    epsilon : float
        Regularization parameter.
    num_iter : int
        Maximum number of iterations.
    convergence_error : float
        Stopping tolerance on the L1 change of each potential between consecutive iterates.
    log : bool
        If True, print the per-iteration changes.

    Returns
    -------
    torch.Tensor
        The 3-D plan built from the final potentials, on the CPU.
    """
    print("Nesterov Gradient Descent 3 marginals")

    # Use CUDA if possible. Tensor.to returns a new tensor, so the names must be rebound.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cost, marg1, marg2, marg3 = (t.to(dev) for t in (cost, marg1, marg2, marg3))

    n1 = marg1.size()[0]
    n2 = marg2.size()[0]
    n3 = marg3.size()[0]
    p1 = torch.zeros_like(marg1)
    p2 = torch.zeros_like(marg2)
    p3 = torch.zeros_like(marg3)
    # Step 1/L with L = ||A||^2 / epsilon for A(p1, p2, p3) = p1 ⊕ p2 ⊕ p3, where
    # ||A||^2 = n2*n3 + n1*n3 + n1*n2 (for two marginals the same bound gives n + m).
    step = 1.0 / (n2 * n3 + n1 * n3 + n1 * n2)
    p1_prev, p2_prev, p3_prev = p1, p2, p3

    def plan(q1, q2, q3):
        # (q1 ⊕ q2 ⊕ q3 - cost)_+ / epsilon with q1 along axis 0, q2 along axis 1, q3 along axis 2
        return (q1[:, None, None] + q2[None, :, None] + q3[None, None, :] - cost).clamp(min=0) / epsilon

    for it in trange(num_iter):
        momentum = it / (it + 3)
        p1_p = p1 + momentum * (p1 - p1_prev)
        p2_p = p2 + momentum * (p2 - p2_prev)
        p3_p = p3 + momentum * (p3 - p3_prev)

        # The gradient is evaluated at the extrapolated point
        P = plan(p1_p, p2_p, p3_p)

        p1_new = p1_p - step * epsilon * (P.sum((1, 2)) - marg1)
        p2_new = p2_p - step * epsilon * (P.sum((0, 2)) - marg2)
        p3_new = p3_p - step * epsilon * (P.sum((0, 1)) - marg3)

        # Change between consecutive iterates
        p1_diff = (p1 - p1_new).abs().sum()
        p2_diff = (p2 - p2_new).abs().sum()
        p3_diff = (p3 - p3_new).abs().sum()

        p1_prev, p2_prev, p3_prev = p1, p2, p3
        p1, p2, p3 = p1_new, p2_new, p3_new

        if log:
            print(f"Iteration {it}")
            print(f"p1_diff {p1_diff}")
            print(f"p2_diff {p2_diff}")
            print(f"p3_diff {p3_diff}")

        if p1_diff < convergence_error and p2_diff < convergence_error \
                and p3_diff < convergence_error:
            break

    nesterov_gradient_descent3 = plan(p1, p2, p3).cpu()

    return nesterov_gradient_descent3

In [73]:
# Euclidean cost, marginals mu_1 -> mu_2; solver defaults: max iteration = 50000, convergence error 1e-8
P_cyclic_euc = sinkhorn_quadratic_cyclic_projection(C=C_euc, a=mu_1, b=mu_2, epsilon=epsilon)
P_grad_euc = sinkhorn_quadratic_gradient_descent(C=C_euc, a=mu_1, b=mu_2, epsilon=epsilon)
P_fpi_euc = sinkhorn_quadratic_fixed_point_iteration(C=C_euc, a=mu_1, b=mu_2, epsilon=epsilon)
P_nesterov_euc = sinkhorn_quadratic_nesterov_gradient_descent(C=C_euc, a=mu_1, b=mu_2, epsilon=epsilon)

Cyclic Projection


 25%|████████████████████▎                                                             | 12360/50000 [53:22<2:42:33,  3.86it/s]


KeyboardInterrupt: 

In [None]:
# Check the marginal constraints of each plan P: (total mass, ||P 1 - mu_1||_1, ||P^T 1 - mu_2||_1).
# The total mass alone (which should be 1) does not show that the row and column sums match mu_1 and mu_2.
[(P.sum().item(), (P.sum(1) - mu_1).abs().sum().item(), (P.sum(0) - mu_2).abs().sum().item())
 for P in [P_cyclic_euc, P_grad_euc, P_fpi_euc, P_nesterov_euc]]

In [None]:
plot_coupling_matrices(
    matrices=[P_cyclic_euc, P_grad_euc, P_fpi_euc, P_nesterov_euc],
    titles=["Cyclic Projection", "Gradient Descent", "Fixed Point Iteration", "Nesterov Gradient Descent"],
)

Let us test these algorithms on the weak Coulomb cost matrix.

In [None]:
# Weak Coulomb cost, marginals mu_1 -> mu_2; solver defaults: max iteration = 50000, convergence error 1e-8
P_cyclic_wcou = sinkhorn_quadratic_cyclic_projection(C=C_wcou, a=mu_1, b=mu_2, epsilon=epsilon)
P_grad_wcou = sinkhorn_quadratic_gradient_descent(C=C_wcou, a=mu_1, b=mu_2, epsilon=epsilon)
P_fpi_wcou = sinkhorn_quadratic_fixed_point_iteration(C=C_wcou, a=mu_1, b=mu_2, epsilon=epsilon)
P_nesterov_wcou = sinkhorn_quadratic_nesterov_gradient_descent(C=C_wcou, a=mu_1, b=mu_2, epsilon=epsilon)

In [None]:
# Check the marginal constraints of each plan P: (total mass, ||P 1 - mu_1||_1, ||P^T 1 - mu_2||_1).
# The total mass alone (which should be 1) does not show that the row and column sums match mu_1 and mu_2.
[(P.sum().item(), (P.sum(1) - mu_1).abs().sum().item(), (P.sum(0) - mu_2).abs().sum().item())
 for P in [P_cyclic_wcou, P_grad_wcou, P_fpi_wcou, P_nesterov_wcou]]

In [None]:
plot_coupling_matrices(
    matrices=[P_cyclic_wcou, P_grad_wcou, P_fpi_wcou, P_nesterov_wcou],
    titles=["Cyclic Projection", "Gradient Descent", "Fixed Point Iteration", "Nesterov Gradient Descent"],
)

Let us test these algorithms on the strong Coulomb cost matrix.

In [None]:
# Strong Coulomb cost, marginals mu_1 -> mu_2; solver defaults: max iteration = 50000, convergence error 1e-8
P_cyclic_scou = sinkhorn_quadratic_cyclic_projection(C=C_scou, a=mu_1, b=mu_2, epsilon=epsilon)
P_grad_scou = sinkhorn_quadratic_gradient_descent(C=C_scou, a=mu_1, b=mu_2, epsilon=epsilon)
P_fpi_scou = sinkhorn_quadratic_fixed_point_iteration(C=C_scou, a=mu_1, b=mu_2, epsilon=epsilon)
P_nesterov_scou = sinkhorn_quadratic_nesterov_gradient_descent(C=C_scou, a=mu_1, b=mu_2, epsilon=epsilon)

In [None]:
# Check the marginal constraints of each plan P: (total mass, ||P 1 - mu_1||_1, ||P^T 1 - mu_2||_1).
# The total mass alone (which should be 1) does not show that the row and column sums match mu_1 and mu_2.
[(P.sum().item(), (P.sum(1) - mu_1).abs().sum().item(), (P.sum(0) - mu_2).abs().sum().item())
 for P in [P_cyclic_scou, P_grad_scou, P_fpi_scou, P_nesterov_scou]]

In [None]:
plot_coupling_matrices(
    matrices=[P_cyclic_scou, P_grad_scou, P_fpi_scou, P_nesterov_scou],
    titles=["Cyclic Projection", "Gradient Descent", "Fixed Point Iteration", "Nesterov Gradient Descent"],
)