In [None]:
import torch
import torch.nn as nn
import numpy as np
import scipy
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt
import geotorch
import geomloss
import ot
from matplotlib import collections  as mc
from matplotlib.legend_handler import HandlerTuple
import argparse
from PIL import Image
from sklearn.decomposition import PCA
import os
import wandb
from IPython.display import clear_output

from scipy.stats import ortho_group
import itertools
from copy import deepcopy

from tqdm import tqdm
import sys
sys.path.append("../..")

from typing import Callable, Tuple, Union


import src.distributions as distributions

from src.utils import freeze, unfreeze, fig2img  
from src.models import linear_model
from src.cost import strong_cost
from src import bar_benchmark

## Config

In [None]:
DIM = 2  # or 4, 8, 16, 64, 128

INPUT_DIM = DIM
# NOTE(review): HIDDEN_DIMS, NUM_SAMPLES and DEVICE_IDS are not referenced by the
# cells below (the networks use width max(100, 2 * DIM)); kept for compatibility.
HIDDEN_DIMS = [128, 128]
OUTPUT_DIM_POT = 1
OUTPUT_DIM_MAP = INPUT_DIM
LR = 5e-4

NUM_SAMPLES = 10_000
NUM_EPOCHS = 10000
PLOT_FREQ = 200
SCORE_FREQ = 1
BATCH_SIZE = 1000
INNER_ITERATIONS = 10  # map updates per potential update

# Number of marginals P_1..P_K and their barycenter weights lambda_1..lambda_K.
K = 3
LAMBDAS = np.array([0.25, 0.25, 0.5])

assert K == len(LAMBDAS)
# The weights must form a probability vector; strictly positive because the
# potentials divide by lambda_k.
assert np.all(LAMBDAS > 0) and np.isclose(LAMBDAS.sum(), 1.0), "LAMBDAS must be positive and sum to 1"
assert torch.cuda.is_available()
DEVICE = 'cuda'
DEVICE_IDS = list(range(torch.cuda.device_count()))

In [None]:
# Benchmark specification: an EigWarp benchmark over Gaussian samplers with
# K marginals weighted by LAMBDAS.
CASE = dict(
    type='EigWarp',
    sampler='Gaussians',
    params=dict(num=K, alphas=LAMBDAS, min_eig=.5, max_eig=2),
)

In [None]:
SEED = 0xB00BA
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA RNGs

EXP_NAME = f'EOTbary_WINnets_{DIM}_{K}_{INNER_ITERATIONS}'
# Checkpoints for this run go into a directory named after the experiment.
OUTPUT_PATH = f'../checkpoints/{EXP_NAME}/'
# exist_ok avoids the check-then-create race and makes the cell re-runnable.
os.makedirs(OUTPUT_PATH, exist_ok=True)

In [None]:
# Hyperparameters recorded with the wandb run so runs can be compared and reproduced.
config = dict(
    CASE=CASE['sampler'],
    BATCH_SIZE=BATCH_SIZE,
    DIM=DIM,
    K=K,
    LAMBDAS=LAMBDAS.tolist(),
    LR=LR,
    NUM_EPOCHS=NUM_EPOCHS,
    INNER_ITERATIONS=INNER_ITERATIONS,
    SEED=SEED,
)


In [None]:
# Start the wandb run that the training loop logs to.
wandb.init(project='egbarycenters', name=EXP_NAME, config=config)

## Initializing distributions

In [None]:
# Build the benchmark from a base sampler and CASE['params'].
# Unsupported cases fail loudly instead of leaving `sampler`/`benchmark`
# undefined (or silently reusing values from a previous run of this cell).
if CASE['type'] != 'EigWarp':
    raise NotImplementedError(f"Unsupported benchmark type: {CASE['type']!r}")

if CASE['sampler'] == 'Gaussians':
    sampler = distributions.StandardNormalSampler(dim=DIM)
else:
    raise NotImplementedError(f"Unsupported base sampler: {CASE['sampler']!r}")

benchmark = bar_benchmark.EigenWarpBenchmark(sampler, **CASE['params'])

## PCA

In [None]:
pca = PCA(n_components=2)


class Identity:
    """Placeholder projection; a ``transform`` attribute is attached to the instance."""
    pass


# Fit the 2-D projection on barycenter samples, preferring the exact sampler over
# its Gaussian counterpart; with neither available, use an identity projection.
if benchmark.bar_sampler is None and benchmark.gauss_bar_sampler is None:
    pca = Identity()
    pca.transform = lambda x: x
else:
    pca.fit(
        (benchmark.bar_sampler if benchmark.bar_sampler is not None
         else benchmark.gauss_bar_sampler)
        .sample(100000).cpu().detach().numpy()
    )

# 2-D data needs no projection: override whatever was fitted above.
if DIM == 2:
    pca = Identity()
    pca.transform = lambda x: x

# NOTE(review): `pca` is not referenced by the plotting or training cells in this notebook.

## Models for the experiments

In [None]:
def make_f_pot(idx, nets, config):
    """Build the idx-th barycenter potential f_idx from the raw networks ``nets``.

    f_idx(x) = g_idx(x) - sum_{i != idx} lambda_i * g_i(x) / ((K - 1) * lambda_idx),
    with K = len(nets). Summing lambda_k * f_k over k cancels every g term, so the
    potentials satisfy sum_k lambda_k * f_k = 0 whatever the networks are.

    Args:
        idx: index of the potential to build.
        nets: sequence of K callables g_1..g_K (the raw potential networks).
        config: barycenter weights lambda_1..lambda_K (the notebook passes LAMBDAS).
            Must have the same length as ``nets``; ``config[idx]`` must be non-zero.

    Returns:
        A callable ``f_pot(x)`` evaluating f_idx on ``x``.

    Raises:
        ValueError: if the number of weights does not match the number of networks.
    """
    lambdas = config
    num = len(nets)
    if len(lambdas) != num:
        raise ValueError(f"expected {num} weights, got {len(lambdas)}")

    def f_pot(x):
        res = 0.0
        for i, (net, lmbd) in enumerate(zip(nets, lambdas)):
            if i == idx:
                res += net(x)
            else:
                res -= lmbd * net(x) / (num - 1) / lambdas[idx]
        return res

    return f_pot

In [None]:
def _make_mlp(out_dim):
    """Return a ReLU MLP DIM -> out_dim with three hidden layers of width max(100, 2 * DIM)."""
    width = max(100, 2 * DIM)
    layers = [nn.Linear(DIM, width), nn.ReLU(True)]
    for _ in range(2):
        layers += [nn.Linear(width, width), nn.ReLU(True)]
    layers.append(nn.Linear(width, out_dim))
    return nn.Sequential(*layers)


# Template networks: D for the potentials, T for the transport maps.
D = _make_mlp(OUTPUT_DIM_POT).cuda()
T = _make_mlp(OUTPUT_DIM_MAP).cuda()

# One raw potential network g_k per marginal; all start from D's weights.
g = [deepcopy(D).to(DEVICE) for _ in range(K)]

param_nets = [net.parameters() for net in g]
g_opt = torch.optim.Adam(itertools.chain(*param_nets), LR)

# Potentials f_k assembled from the raw networks by make_f_pot.
f_pots = [make_f_pot(i, g, LAMBDAS) for i in range(K)]

# One transport map T_k per marginal; all start from T's weights.
maps = [deepcopy(T).to(DEVICE) for _ in range(K)]

In [None]:
# A single Adam optimizer over the parameters of all K transport maps (works for any K).
maps_opt = torch.optim.Adam(
    itertools.chain.from_iterable(mp.parameters() for mp in maps),
    LR,
)

# Each scheduler drives its own optimizer (these were previously swapped).
# MultiStepLR halves the learning rate once its step() has been called 5000 times.
g_scheduler = torch.optim.lr_scheduler.MultiStepLR(g_opt, milestones=[5000], gamma=0.5)
maps_scheduler = torch.optim.lr_scheduler.MultiStepLR(maps_opt, milestones=[5000], gamma=0.5)

## Plotters

In [None]:
def plot_distributions(ax1, ax2):
    """Scatter the input marginals on ``ax1`` and the ground-truth barycenter on ``ax2``.

    Draws 512 points from each ``benchmark.samplers[i]`` (faded colours) and 512 points
    from ``benchmark.gauss_bar_sampler``; only the first two coordinates are plotted.
    """
    palette = plt.get_cmap("Dark2").colors
    faded_edge = alpha_color((0, 0, 0))

    for idx, marginal in enumerate(benchmark.samplers):
        pts = marginal.sample(512).detach().cpu().numpy()
        ax1.scatter(
            pts[:, 0], pts[:, 1],
            label=f"$x_{{{idx + 1}}} \\sim \\mathbb{{P}}_{{{idx + 1}}}$",
            edgecolors=faded_edge,
            color=alpha_color(palette[idx]),
            linewidth=.5,
        )

    bary_pts = benchmark.gauss_bar_sampler.sample(512).detach().cpu().numpy()
    ax2.scatter(
        bary_pts[:, 0], bary_pts[:, 1],
        label=r"$x \sim \mathbb{Q}_*$",
        edgecolors='black',
        color=palette[K + 1],
        linewidth=.5,
    )

    for ax in (ax1, ax2):
        ax.legend(ncol=2, loc="upper left", prop={"size": 12})

def alpha_color(color_rgb, alpha=0.5):
    """Blend an RGB colour towards white, computing ``1 - alpha * (1 - rgb)``.

    ``alpha=1`` leaves the colour unchanged and ``alpha=0`` gives white.
    Returns a NumPy array with the same shape as ``color_rgb``.
    """
    rgb = np.asanyarray(color_rgb)
    return 1. - alpha * (1. - rgb)

def plot_bary_i(map_, sampler, ax, i, n_samples=512, n_maps=0, n_arrows_per_map=1):
    """Plot samples of marginal i and their images under ``map_`` on ``ax``.

    Args:
        map_: learned map for marginal i, applied to a batch of points on DEVICE.
        sampler: sampler of the i-th marginal (must provide ``sample(n)``).
        ax: matplotlib axes to draw on.
        i: marginal index; selects the colour and the legend subscripts.
        n_samples: number of background points drawn from ``sampler``.
        n_maps: number of extra highlighted points whose displacement x -> map_(x)
            is drawn as a segment; 0 disables the highlights and segments.
        n_arrows_per_map: number of times each highlighted point is repeated.
    """
    n_arrows = n_maps * n_arrows_per_map
    X = sampler.sample(n_samples,).to(DEVICE)
    if n_maps > 0:
        Xm = sampler.sample(n_maps,).to(DEVICE)
        Xm = torch.tile(Xm, (n_arrows_per_map, 1))
        X = torch.concatenate((X, Xm), dim=0)

    # Plotting only: no autograd graph is needed for the pushed-forward points.
    with torch.no_grad():
        Y = map_(X).to(DEVICE)
    X_np = X.detach().cpu().numpy()
    Y_np = Y.detach().cpu().numpy()

    cols = plt.get_cmap("Dark2").colors
    col_bary = plt.get_cmap("tab10").colors[K]
    p4 = ax.scatter(
        X_np[:n_samples, 0], X_np[:n_samples, 1],
        edgecolors=alpha_color((0, 0, 0)), color=alpha_color(cols[i]), zorder=0, linewidth=.5,
    )
    p1 = ax.scatter(
        Y_np[:n_samples, 0], Y_np[:n_samples, 1],
        edgecolors=(0, 0, 0), color=col_bary, zorder=0, linewidth=.5,
    )

    if n_arrows > 0:
        # Guarded: with n_arrows == 0, X_np[-0:] would select *all* points and
        # overdraw the whole scatter in full colour.
        p3 = ax.scatter(
            X_np[-n_arrows:, 0], X_np[-n_arrows:, 1],
            linewidth=.5, edgecolors='black', color=cols[i], zorder=2,
        )
        p2 = ax.scatter(
            Y_np[-n_arrows:, 0], Y_np[-n_arrows:, 1],
            linewidth=.5, edgecolors='black', color=cols[K + 2], zorder=2,
        )
        ax.quiver(
            X_np[-n_arrows:, 0], X_np[-n_arrows:, 1],
            Y_np[-n_arrows:, 0] - X_np[-n_arrows:, 0], Y_np[-n_arrows:, 1] - X_np[-n_arrows:, 1],
            angles='xy', scale_units='xy', scale=0.95, width=.005, zorder=1, headwidth=0.0, headlength=0.0,
        )
        handles = [(p1, p2), (p3, p4)]
    else:
        handles = [p1, p4]

    ax.legend(
        handles,
        [
            f"$x_{i + 1} \\sim \\pi^*_{i + 1}(\\cdot \\mid x_{i + 1})$",
            f"$x_{i + 1} \\sim \\mathbb{{P}}_{i + 1}$",
        ],
        handler_map={tuple: HandlerTuple(ndivide=None)},
        loc="upper left",
        prop={"size": 12},
    )

def plot_bary(maps, benchmark, arrows=True):
    """Draw the diagnostic figure and return it.

    Panels (K + 2 in a row, shared axes fixed to [-7, 7]^2): input marginals,
    ground-truth barycenter, then one panel per map showing its pushed samples.
    When ``arrows`` is true, a few points per map also get displacement segments.
    """
    n_samples, n_maps_default, n_arrows_per_map = 512, 5, 3
    n_maps = n_maps_default if arrows else 0

    fig, axs = plt.subplots(
        ncols=K + 2,
        figsize=(18.75, 3.75),
        sharex=True, sharey=True,
        dpi=200,
    )

    plot_distributions(axs[0], axs[1])
    for ax in axs[:2]:
        ax.set_xlim(-7, 7)
        ax.set_ylim(-7, 7)

    panels = zip(maps, benchmark.samplers, axs[2:])
    for idx, (map_, marginal, ax) in enumerate(panels):
        plot_bary_i(map_, marginal, ax, idx, n_samples, n_maps, n_arrows_per_map)

    fig.tight_layout()
    return fig

In [None]:
# Sanity-check plot of the untrained maps. plt.show() renders the figure once;
# leaving the Figure as the cell's last expression would display it a second time.
fig = plot_bary(maps, benchmark)
plt.show()

## Metrics

In [None]:
def score_forward_maps(benchmark, maps, lambdas, score_size=1024):
    """Weighted L2-UVP (in percent) of the learned maps against the Gaussian barycenter maps.

    For marginal k, with ``score_size`` samples x from ``benchmark.samplers[k]``:
        UVP_k = 100 * mean ||maps[k](x) - benchmark.gauss_bar_maps[k](x)||^2
                / benchmark.gauss_bar_sampler.var
    and the result is sum_k lambdas[k] * UVP_k. Lower is better.
    """
    assert (benchmark.gauss_bar_maps is not None) and (benchmark.gauss_bar_sampler is not None)

    def _uvp(k):
        pts = benchmark.samplers[k].sample(score_size)
        with torch.no_grad():
            sq_err = ((maps[k](pts) - benchmark.gauss_bar_maps[k](pts)) ** 2).sum(dim=1).mean()
            return 100 * (sq_err / benchmark.gauss_bar_sampler.var).item()

    per_marginal = [_uvp(k) for k in range(benchmark.num)]
    return sum(lam * uvp for lam, uvp in zip(lambdas, per_marginal))

In [None]:
# Baseline weighted L2-UVP of the untrained maps (lower is better).
score_forward_maps(benchmark, maps, lambdas=LAMBDAS, score_size=1024)

## Training block

In [None]:
from PIL import Image

In [None]:
last_plot_epoch = -1
last_score_epoch = -1
# +inf guarantees that the first evaluated score is checkpointed.
best_L2_UVP = float('inf')


def _weighted_objective(batch):
    """Return sum_k LAMBDAS[k] * mean[strong_cost(x, T_k(x)) - f_k(T_k(x))].

    ``batch[k]`` is a batch drawn from ``benchmark.samplers[k]``. The maps minimize
    this objective; the potentials maximize it.
    """
    total = 0
    for k in range(K):
        mapped_x_k = maps[k](batch[k])
        # Out-of-place subtraction: avoids modifying strong_cost's output in place,
        # which autograd may need for the backward pass.
        cost = strong_cost(batch[k], mapped_x_k) - f_pots[k](mapped_x_k)
        total += LAMBDAS[k] * cost.mean(dim=0)
    return total


for epoch in range(NUM_EPOCHS):
    # Inner problem: update the maps with the potentials held fixed.
    for idx in range(K):
        freeze(g[idx])
        unfreeze(maps[idx])

    for _ in range(INNER_ITERATIONS):
        data = [s.sample(BATCH_SIZE).to(DEVICE) for s in benchmark.samplers]
        maps_opt.zero_grad()
        loss = _weighted_objective(data)
        loss.backward()
        maps_opt.step()
    # Step once per epoch (not per inner iteration) so both schedulers'
    # milestones are measured in epochs.
    maps_scheduler.step()

    # Outer problem: update the potentials with the maps held fixed.
    for idx in range(K):
        unfreeze(g[idx])
        freeze(maps[idx])

    g_opt.zero_grad()
    # Gradient ascent on the objective, evaluated on the last inner batch.
    loss = -_weighted_objective(data)
    wandb.log({'Loss': loss.item()}, step=epoch)
    loss.backward()
    g_opt.step()
    g_scheduler.step()

    if epoch - last_plot_epoch >= PLOT_FREQ:
        last_plot_epoch = epoch

        fig = plot_bary(maps, benchmark, arrows=False)  # plot_bary already applies tight_layout()
        wandb.log({'Pca': [wandb.Image(fig2img(fig))]}, step=epoch)
        # Replace the previous snapshot instead of stacking one figure per PLOT_FREQ epochs.
        clear_output(wait=True)
        plt.show()
        plt.close(fig)

    if epoch - last_score_epoch >= SCORE_FREQ:
        last_score_epoch = epoch

        if benchmark.gauss_bar_sampler is not None:
            L2_UVP = score_forward_maps(benchmark, maps, LAMBDAS, score_size=1024)
            wandb.log({'L2_UVP': L2_UVP}, step=epoch)

            if L2_UVP < best_L2_UVP:
                best_L2_UVP = L2_UVP
                # The maps are already frozen for the potential update above.
                for k in range(benchmark.num):
                    torch.save(maps[k].state_dict(), OUTPUT_PATH + 'maps{}_best.pt'.format(k))
                np.savez(OUTPUT_PATH + 'metrics.npz', L2_UVP=best_L2_UVP)