In [None]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt 
from torch.optim import Adam
from torch.func import vmap, jacrev
from torch.nn.utils import clip_grad_norm_
from tqdm import trange
from scipy.stats import wishart

from utils.models import mGradNet_C, mGradNet_M, WSoftmax, WTanh
from utils.pdfs import MultivariateNormal

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Final test MSE of every trained map, keyed by the tuple (otmap_name, dim).
results = dict()

In [None]:
# Repeat the whole experiment once per problem dimension.
for dim in [2, 4, 8, 16, 32, 64, 128]:
    print(f"Running for dimension: {dim}")

    # Re-seed NumPy and PyTorch (CPU and, when present, CUDA) at the start of
    # every dimension so each run begins from the same RNG state.
    np.random.seed(1234)
    torch.manual_seed(1234)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1234)

    # Source distribution: random mean and a covariance drawn from a Wishart
    # distribution with dim+1 degrees of freedom and identity scale matrix.
    # NOTE(review): wishart.rvs is called without random_state, so it presumably
    # relies on the np.random.seed call above for reproducibility — confirm.
    source_mu = torch.randn(dim,) 
    source_cov = torch.from_numpy(wishart.rvs(dim+1, np.identity(dim), 1)).float()
    source = MultivariateNormal(source_mu, source_cov)
    # Target distribution: zero mean vector and identity covariance.
    target = MultivariateNormal(torch.zeros(dim), torch.eye(dim))

    # Fixed set of source samples drawn once per dimension; reused below to
    # score the learned maps against the reference map.
    source_samples = source.sample(1000)
    # target_samples is not referenced again in the visible code. The draw
    # presumably still advances the RNG, so removing it could change every
    # later random number in this iteration — TODO confirm before deleting.
    target_samples = target.sample(1000)

    def inv_sqrt(A):
        """Return the inverse matrix square root A^{-1/2}.

        Uses the eigendecomposition A = V diag(w) V^T and returns
        V diag(w^{-1/2}) V^T.

        Args:
            A: real tensor of shape (..., n, n). Each trailing n-by-n matrix is
                assumed symmetric with strictly positive eigenvalues; a zero or
                negative eigenvalue yields inf/nan entries.

        Returns:
            Tensor with the same shape as A.
        """
        eigvals, eigvecs = torch.linalg.eigh(A)
        # Scaling the columns of V by w^{-1/2} equals V @ diag(w^{-1/2}) without
        # materializing the dense diagonal matrix, and it broadcasts over any
        # leading batch dimensions (torch.diag / .T only handle one 2-D matrix).
        scaled_vecs = eigvecs * eigvals.pow(-0.5).unsqueeze(-2)
        return scaled_vecs @ eigvecs.mT

    # A is the inverse square root of the source covariance; the reference map
    # below uses it to whiten centered samples.
    A = inv_sqrt(source_cov)
    def optimal_ot_map(x):
        """Subtract source_mu from x and right-multiply the result by A."""
        return (x - source_mu) @ A

    # Reference outputs for the fixed source samples, kept on the CPU without
    # an autograd graph.
    optimal_target = optimal_ot_map(source_samples).detach().cpu()

    # Build and train each OT-map architecture in turn
    for otmap_name in ['mGradNet_C', 'mGradNet_M']:

        # Constructors are wrapped in lambdas so that only the architecture
        # selected for this pass is actually instantiated.
        otmap_builders = {
            'mGradNet_C': lambda: mGradNet_C(in_dim=dim, embed_dim=32, num_layers=4, activation=lambda: nn.Tanh()),
            'mGradNet_M': lambda: mGradNet_M(num_modules=4, in_dim=dim, embed_dim=32, activation=lambda: WSoftmax(32)),
        }
        if otmap_name not in otmap_builders:
            raise ValueError("Unknown OT map type specified.")
        otmap = otmap_builders[otmap_name]()


        # Train the OT map by matching log|det J_T(x)| to log p(x) - log q(T(x)),
        # where T is otmap, p the source density and q the target density.
        otmap = otmap.to(device)
        opt = Adam(otmap.parameters(), lr=1e-2)
        jacobian_fn = jacrev(otmap)
        batched_jacobian_fn = vmap(jacobian_fn)  # per-sample Jacobians over the batch

        # The model lives on `device`, so the held-out inputs and their reference
        # outputs must be there too; feeding the CPU copies to a CUDA model
        # raises a device-mismatch error. Move them once, outside the loop.
        eval_inputs = source_samples.to(device)
        eval_reference = optimal_target.to(device)

        pbar = trange(1000, dynamic_ncols=True) # train for 1000 iterations
        for i in pbar:
            opt.zero_grad()
            samples = source.sample(1000).to(device)  # sample points from the source distribution
            out = otmap(samples)

            J = batched_jacobian_fn(samples)
            J = J.squeeze(1)  # drops axis 1 only when it has size 1

            _, logabsdet = torch.linalg.slogdet(J)

            # NOTE(review): samples/out are on `device` while the distributions
            # were built from CPU tensors — presumably MultivariateNormal handles
            # that; confirm against utils.pdfs when running on CUDA.
            log_p = source.log_pdf(samples)       # source log-density
            log_q = target.log_pdf(out)           # target log-density

            loss = F.l1_loss(logabsdet, log_p - log_q)
            loss.backward()
            clip_grad_norm_(otmap.parameters(), max_norm=2.0) # clip gradient norms for stability
            opt.step()

            # Monitoring only: no_grad avoids building an autograd graph for a
            # value that is never back-propagated.
            with torch.no_grad():
                test_mse = F.mse_loss(otmap(eval_inputs), eval_reference).item()
            pbar.set_description(f"Loss: {loss.item():.4f}, Test loss: {test_mse:.4f}")

        # Record the final MSE between the learned map and the reference map.
        with torch.no_grad():
            results[(otmap_name, dim)] = F.mse_loss(otmap(eval_inputs), eval_reference).item()

In [None]:
# Grouped bar chart of the final MSE per dimension, one bar series per model.
fig, ax = plt.subplots(dpi=300)

dims = [2, 4, 8, 16, 32, 64, 128]
dim_strs = [str(dim) for dim in dims]
positions = np.arange(len(dims))

# (results key, legend label, horizontal offset of the bar within its group)
bar_series = [
    ('mGradNet_C', 'mGradNet-C', -0.2),
    ('mGradNet_M', 'mGradNet-M', +0.2),
]
for otmap_name, label, offset in bar_series:
    mses = [results[(otmap_name, dim)] for dim in dims]
    ax.bar(positions + offset, mses, width=0.4, label=label)

ax.set_yscale('log')
ax.set_xlabel('Dimension')
ax.set_xticks(positions)
ax.set_xticklabels(dim_strs)
ax.set_ylabel('MSE')
ax.legend()
ax.set_title('MSE between Learned and Optimal OT Maps')

# Save before showing: once plt.show() has rendered the figure, a subsequent
# plt.savefig() targets a fresh, empty figure and writes a blank file.
fig.savefig('high_dim_gaussians_results.pdf', bbox_inches='tight')
plt.show()