In [None]:
# torch
import torch
import torch.distributions as TD
from torch import nn
import torch.nn.functional as F
# from torch.autograd import grad
from functorch import grad, vmap
# import geotorch

# base
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple
from cycler import cycler

from tqdm import tqdm
import logging
import numpy as np
import os
import sys
import itertools
import functools
import operator as ops

from sklearn.decomposition import PCA

import warnings

warnings.filterwarnings("ignore")
sys.path.append("..")

In [None]:
# dataset utils
sys.path.append("../..")
from src.utils import Distrib2Sampler
from src.utils import Config

 

# langevin sampling
from src.eot import sample_langevin_batch
from src.eot_utils import computePotGrad, evaluating
# from src.plotters import plot_training_phase
from src.tools import *
from src import distributions
from src import benchmarks

# training utils
from src.dgm_utils.statsmanager import StatsManager, StatsManagerDrawScheduler

# typing
from typing import Callable, Tuple, Union
import ot

# models
from src.models2D import FullyConnectedMLP
# Device on which the distributions, networks and samples of this notebook are placed.
DEVICE = 'cuda:0'
# DEVICE = "cpu"

In [None]:
# Central experiment configuration; every setting is attached to one Config instance.
CONFIG = Config()

CONFIG.CLIP_GRADS_NORM = False  # not read anywhere in the visible cells
CONFIG.HREG = 0.01  # divides the cost term of the Langevin drift (see cond_score)
CONFIG.USE_L2 = False  # True -> l2_grad_y, False -> cost_grad_y (transformed cost)
CONFIG.DISCRETE_OT_FOR_GT = CONFIG.USE_L2  # plot_distributions: discrete-OT barycenter vs. Zgtdistrib samples

# Langevin sampler settings; the first four are forwarded to sample_langevin_batch,
# the two coefficients scale the score and cost parts inside cond_score.
CONFIG.LANGEVIN_THRESH = None
CONFIG.LANGEVIN_SAMPLING_NOISE = 0.03
CONFIG.ENERGY_SAMPLING_ITERATIONS = 300
CONFIG.LANGEVIN_DECAY = 1.0
CONFIG.LANGEVIN_SCORE_COEFFICIENT = 1.0
CONFIG.LANGEVIN_COST_COEFFICIENT = 1.0

# learning parameters
CONFIG.MAX_STEPS = 200  # number of optimisation steps in the training loop
CONFIG.BATCH_SIZE = 256  # samples drawn per measure and per step
CONFIG.BASIC_NOISE_VAR = 1.0  # multiplies the scale of the initial-noise Normal

CONFIG.DIM = 2  # first constructor argument of each FullyConnectedMLP
CONFIG.NUM = 3  # number of potential networks / input measures
# 1 / ones evaluates to a vector of ones, i.e. equal but unnormalised weights.
# NOTE(review): if weights summing to one were intended this would be
# np.ones(CONFIG.NUM) / CONFIG.NUM -- confirm.
CONFIG.ALPHAS = 1 / np.ones(CONFIG.NUM)
# CONFIG.ALPHAS = np.array([0.3, 0.7])

CONFIG.OUTPUT_SEED = 0xAB0BA  # passed to seed() before the models are created

# One weight per measure is required.
assert CONFIG.NUM == len(CONFIG.ALPHAS)

In [None]:
class TransformedL2Generic:
    """Euclidean geometry measured after mapping points through ``h``.

    Subclasses supply ``h`` (X -> Z) and ``h_inv`` (Z -> X). Distances are
    ordinary L2 distances between ``h``-images, and barycenters are weighted
    means taken in Z-space and mapped back with ``h_inv``.
    """

    def h(self, X):
        """Map points from X-space to Z-space. Must be provided by a subclass."""
        raise NotImplementedError()

    def h_inv(self, Z):
        """Map points from Z-space back to X-space. Must be provided by a subclass."""
        raise NotImplementedError()

    def __init__(self):
        pass

    def dist_squared(self, X, Y):
        """Squared transformed distance; shorthand for ``dist(X, Y, squared=True)``."""
        return self.dist(X, Y, squared=True)

    def dist(self, X, Y, squared=False):
        """Pairwise (row-by-row) L2 distance between ``h(X)`` and ``h(Y)``.

        The squared differences are summed over the last dimension. With
        ``squared=True`` the square root is skipped.
        """
        z_X = self.h(X)
        z_Y = self.h(Y)
        dist_squared = torch.sum((z_X - z_Y).pow(2), dim=-1)
        if squared:
            return dist_squared
        return torch.sqrt(dist_squared)

    def cdist(self, X, Y, squared=False):
        '''
        All-pairs distances between the h-images of two batched point sets.

        X: (bs, n, D)
        Y: (bs, m, D)
        Returns the result of torch.cdist on the transformed sets, squared
        element-wise when ``squared=True``.
        '''
        # h is applied to a flat list of points; reshape (rather than view) so a
        # non-contiguous output of h does not raise.
        z_X = self.h(X.flatten(start_dim=0, end_dim=1)).reshape(X.shape)
        z_Y = self.h(Y.flatten(start_dim=0, end_dim=1)).reshape(Y.shape)
        dists = torch.cdist(z_X, z_Y)
        if squared:
            return dists.pow(2)
        return dists

    def bary(self, Xs, alps):
        """Weighted barycenter of corresponding points under the transformed cost.

        Args:
            Xs: list of tensors, one per measure, combined element-wise.
            alps: one non-negative weight per entry of ``Xs``; normalised to
                sum to one internally. The argument itself is left untouched.

        Returns:
            ``h_inv(sum_i w_i * h(Xs[i]))`` with ``w = alps / sum(alps)``.
        """
        assert isinstance(Xs, list)
        assert len(Xs) == len(alps)
        # Out-of-place division: an in-place `/=` would overwrite the caller's
        # array (np.asarray does not copy ndarrays) and fails for integer weights.
        weights = np.asarray(alps)
        weights = weights / np.sum(weights)
        baryZ = 0.
        for X, weight in zip(Xs, weights):
            baryZ += self.h(X) * weight
        return self.h_inv(baryZ)

class TransformedL2TwoMaps(TransformedL2Generic):
    """Transformed-L2 geometry whose two maps are supplied as plain callables."""

    def __init__(self, h, h_inv):
        """Store the forward and backward maps.

        Args:
            h: callable taking X-space points to Z-space.
            h_inv: callable taking Z-space points back to X-space.
        """
        super().__init__()
        # Stored as instance attributes, so lookups of self.h / self.h_inv
        # resolve to these callables.
        self.h_inv = h_inv
        self.h = h

In [None]:
# Rotation angle gained per unit of distance from the origin.
H_SLOPE = 0.4


def norm2theta(norms):
    """Turn point norms into rotation angles: ``theta = H_SLOPE * norm``."""
    angles = norms * H_SLOPE
    return angles

def rotate_batch(Rs, Xs):
    """Apply one 2x2 matrix per point, i.e. ``Rs[b] @ Xs[b]`` for every b.

    Args:
        Rs: tensor of shape (B, 2, 2).
        Xs: tensor of shape (B, 2); a single (2,) vector is treated as B = 1.

    Returns:
        Tensor of shape (B, 2).
    """
    if Xs.dim() == 1:
        Xs = Xs.unsqueeze(0)
    assert Xs.dim() == 2
    assert Xs.shape[1] == 2
    assert Xs.shape[0] == Rs.shape[0]
    assert Rs.dim() == 3
    assert Rs.shape[1] == Rs.shape[2] == 2
    # Row-vector form: x^T R^T equals (R x)^T, computed as a batched matmul.
    row_vectors = Xs.unsqueeze(1)
    rotated = row_vectors @ Rs.transpose(1, 2)
    return rotated.squeeze(1)

def cossin2R(cos, sin):
    """Assemble a batch of 2-D rotation matrices ``[[cos, -sin], [sin, cos]]``.

    Args:
        cos, sin: 1-D tensors of identical shape (B,).

    Returns:
        Tensor of shape (B, 2, 2).
    """
    assert cos.shape == sin.shape
    assert cos.dim() == 1
    # Each row lists the four matrix entries in row-major order.
    entries = torch.stack((cos, -sin, sin, cos), dim=-1)
    return entries.reshape(-1, 2, 2)
    
def lin_space_rotator(Xs, pos=True):
    """Rotate every 2-D point about the origin by an angle set by its norm.

    The angle is ``norm2theta(||x||)``; with ``pos`` falsy the angle is negated.

    Args:
        Xs: tensor of shape (B, 2), or a single (2,) point (treated as B = 1).
        pos: rotate by the positive angle when truthy, the negative one otherwise.

    Returns:
        Tensor of shape (B, 2) with the rotated points.
    """
    if Xs.dim() == 1:
        Xs = Xs.unsqueeze(0)
    assert Xs.dim() == 2
    assert Xs.shape[1] == 2
    radii = torch.norm(Xs, dim=-1)
    angles = norm2theta(radii)
    angles = angles if pos else -angles
    rotations = cossin2R(torch.cos(angles), torch.sin(angles))
    return rotate_batch(rotations, Xs)

def h(Xs):
    """Forward map X -> Z: rotate each point by the positive norm-dependent angle."""
    return lin_space_rotator(Xs, pos=True)

def h_inv(Zs):
    """Backward map Z -> X: rotate each point by the negative norm-dependent angle.

    A rotation leaves the norm unchanged, so the angle computed here matches
    the one used by ``h`` and the two maps undo each other.
    """
    return lin_space_rotator(Zs, pos=False)

In [None]:
# Cost geometry used throughout the notebook: L2 distance between h-images.
tf = TransformedL2TwoMaps(h, h_inv)

In [None]:
# Require and select CUDA only when a GPU device is requested, so that
# switching DEVICE to "cpu" does not trip the availability check.
if DEVICE != "cpu":
    assert torch.cuda.is_available(), f"DEVICE={DEVICE!r} requested but CUDA is not available"
    torch.cuda.set_device(DEVICE)

def seed(seed):
    """Seed NumPy and PyTorch (CPU and every CUDA device) and request
    deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to all generators.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # `torch.cuda.deterministic` is not a PyTorch setting; assigning it only
    # creates an unused attribute. The cuDNN flags live under torch.backends.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
# Fix all random generators before the networks are created and training starts.
seed(CONFIG.OUTPUT_SEED)

## Initializing distributions

In [None]:
# Gaussians with unit scale, specified in Z-space; samples are mapped to
# X-space with tf.h_inv wherever they are drawn.
Z1distrib = TD.Normal(
    loc=torch.tensor([0., 4], device=DEVICE),
    scale=torch.tensor([1., 1.], device=DEVICE),
)

Z2distrib = TD.Normal(
    loc=torch.tensor([3.46, -2.], device=DEVICE),
    scale=torch.tensor([1., 1.], device=DEVICE),
)

Z3distrib = TD.Normal(
    loc=torch.tensor([-3.46, -2], device=DEVICE),
    scale=torch.tensor([1., 1.], device=DEVICE),
)

# Reference distribution plotted as the target when discrete OT is not used.
Zgtdistrib = TD.Normal(
    loc=torch.tensor([0.0, 0.0], device=DEVICE),
    scale=torch.tensor([1., 1.], device=DEVICE),
)

# Nets

In [None]:
# One network per input measure.
# NOTE(review): arguments are presumably (input dim, hidden widths, output dim) -- confirm.
nets = [
    FullyConnectedMLP(CONFIG.DIM, [32, 32], 1).to(DEVICE)
    for _ in range(CONFIG.NUM)
]
# parameters() yields one-shot generators; building the optimizer drains them.
param_gens = [net.parameters() for net in nets]
# A single Adam instance updates the parameters of all networks jointly.
opt = torch.optim.Adam(itertools.chain.from_iterable(param_gens), lr=1e-2)

def make_f_pot(idx):
    """Build the idx-th potential from the shared list of networks.

    The returned function evaluates
        f_idx(x) = net_idx(x) - sum_{j != idx} alpha_j * net_j(x) / ((NUM - 1) * alpha_idx),
    reading ``nets`` and ``CONFIG`` at call time. With this parametrisation the
    alpha-weighted sum of all NUM potentials cancels to zero for every x.
    """
    def f_pot(x):
        total = 0.0
        for j, (net_j, alpha_j) in enumerate(zip(nets, CONFIG.ALPHAS)):
            if j != idx:
                # Every other network enters with a negative, rescaled weight.
                total -= alpha_j * net_j(x) / (CONFIG.NUM - 1) / CONFIG.ALPHAS[idx]
                continue
            total += net_j(x)
        return total
    return f_pot

# One potential per measure, in the same order as `nets`.
f_pots = list(map(make_f_pot, range(CONFIG.NUM)))

In [None]:
def cost_grad_y(y: torch.Tensor, x: torch.Tensor):
    """Per-sample gradient with respect to y of ``0.5 * tf.dist_squared(x, y)``.

    The scalar cost of a single (y, x) pair is differentiated in its first
    argument with functorch ``grad`` and mapped over the leading dimension of
    both inputs with ``vmap``.
    """
    def half_squared_cost(y_single, x_single):
        # squeeze() reduces the result to the scalar that grad requires.
        return 0.5 * tf.dist_squared(x_single, y_single).squeeze()

    per_sample_grad = vmap(grad(half_squared_cost))
    return per_sample_grad(y, x)

def l2_grad_y(y, x):
    r'''
    Gradient in y of the quadratic cost c(x, y) = 0.5 * ||x - y||^2,
    i.e. \nabla_y c(x, y) = y - x.
    '''
    return y - x

# Cost gradient used by the Langevin score: closed-form L2 or autodiff of the transformed cost.
grad_fn = l2_grad_y if CONFIG.USE_L2 else cost_grad_y

def cond_score(
        f: Callable[[torch.Tensor], torch.Tensor],
        cost_grad_y_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        y: torch.Tensor,
        x: torch.Tensor,
        config: Config,
        ret_stats=False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    """Drift used by the Langevin sampler for y conditioned on x.

    Computes ``score_part - cost_part`` where
      * score_part = computePotGrad(y, f(y)) * LANGEVIN_SCORE_COEFFICIENT
        (presumably the gradient of the potential in y -- computePotGrad is
        defined outside this notebook),
      * cost_part = cost_grad_y_fn(y, x) * LANGEVIN_COST_COEFFICIENT
        * LANGEVIN_SAMPLING_NOISE ** 2 / HREG.

    Side effect: ``y`` is switched to ``requires_grad=True`` in place.

    Returns:
        The drift alone, or ``(drift, cost_part, score_part)`` when
        ``ret_stats`` is truthy.
    """
    # Gradients are needed here even when the caller runs under no_grad.
    with torch.enable_grad():
        y.requires_grad_(True)
        potential = f(y)
        potential_grad = computePotGrad(y, potential)
        assert potential_grad.shape == y.shape
    noise_sq_over_hreg = config.LANGEVIN_SAMPLING_NOISE ** 2 / config.HREG
    cost_weight = config.LANGEVIN_COST_COEFFICIENT * noise_sq_over_hreg
    cost_part = cost_grad_y_fn(y, x) * cost_weight
    score_part = potential_grad * config.LANGEVIN_SCORE_COEFFICIENT
    drift = score_part - cost_part
    if ret_stats:
        return drift, cost_part, score_part
    return drift

def sample_langevin_mu_f(
        f: Callable[[torch.Tensor], torch.Tensor],
        x: torch.Tensor,
        y_init: torch.Tensor,
        config: Config
    ) -> torch.Tensor:
    """Run the Langevin sampler from ``y_init`` using the drift of ``cond_score``.

    Args:
        f: potential evaluated on y.
        x: conditioning points passed to the cost gradient.
        y_init: starting points of the chains.
        config: supplies the sampler's step count, decay, threshold and noise.

    Returns:
        The first element returned by ``sample_langevin_batch``; the four
        statistics returned alongside it are discarded.
    """
    def conditional_score(y, ret_stats=False):
        return cond_score(f, grad_fn, y, x, config, ret_stats=ret_stats)

    def identity(sample):
        # No projection is applied to the iterates.
        return sample

    sampled_y, _r_t, _cost_r_t, _score_r_t, _noise_norm = sample_langevin_batch(
        conditional_score,
        y_init,
        n_steps=config.ENERGY_SAMPLING_ITERATIONS,
        decay=config.LANGEVIN_DECAY,
        thresh=config.LANGEVIN_THRESH,
        noise=config.LANGEVIN_SAMPLING_NOISE,
        data_projector=identity,
        compute_stats=True,
    )
    return sampled_y

In [None]:
# Sampler for the starting points of the Langevin chains: a zero-mean Normal in DIM dimensions.
# NOTE(review): BASIC_NOISE_VAR is passed as the Normal's scale (standard deviation),
# not its variance; the two coincide only for the current value 1.0 -- confirm intent.
init_noise_sampler = Distrib2Sampler(TD.Normal(
    torch.zeros(CONFIG.DIM).to(DEVICE), 
    torch.ones(CONFIG.DIM).to(DEVICE) * CONFIG.BASIC_NOISE_VAR))

In [None]:
# Z-space distributions of the input measures, in the same order as the potentials.
samplers = [Z1distrib, Z2distrib, Z3distrib]

In [None]:
def alpha_color(color_rgb, alpha=0.5):
    """Blend an RGB colour toward white.

    ``alpha=1`` returns the colour unchanged, ``alpha=0`` returns white;
    channel values are expected on a 0..1 scale.
    """
    rgb = np.asanyarray(color_rgb)
    distance_to_white = 1. - rgb
    return 1. - distance_to_white * alpha

In [None]:
def plot_bary_i(f_pot, sampler, ax, i, n_samples=512, n_maps=0, n_arrows_per_map=1):
    """Scatter samples of the i-th measure next to their Langevin-sampled images.

    Args:
        f_pot: potential driving the conditional Langevin sampler.
        sampler: Z-space distribution of the i-th measure; its samples are
            mapped to X-space with ``tf.h_inv``.
        ax: matplotlib axes to draw on.
        i: zero-based index of the measure (selects the colour and the labels).
        n_samples: number of points in the background clouds.
        n_maps: number of extra source points highlighted with line segments.
        n_arrows_per_map: number of independent target samples per highlighted
            source point.

    With ``n_maps * n_arrows_per_map == 0`` only the two clouds are drawn and
    the legend lists just those.
    """
    n_arrows = n_maps * n_arrows_per_map
    X = sampler.sample((n_samples,)).to(DEVICE)
    if n_maps > 0:
        Xm = sampler.sample((n_maps,)).to(DEVICE)
        # Repeat every highlighted source so it receives several targets.
        Xm = torch.tile(Xm, (n_arrows_per_map, 1))
        # torch.cat is available on every torch version, unlike the newer
        # torch.concatenate alias.
        X = torch.cat((X, Xm), dim=0)

    X = tf.h_inv(X)
    Y_init = init_noise_sampler.sample(n_samples + n_arrows).to(DEVICE)
    Y = sample_langevin_mu_f(f_pot, X, Y_init, CONFIG).to(DEVICE)
    X_np = X.detach().cpu().numpy()
    Y_np = Y.detach().cpu().numpy()

    cols = mpl.colormaps["Dark2"].colors
    col_bary = mpl.colormaps["tab10"].colors[CONFIG.NUM]
    x_cloud = ax.scatter(
        X_np[:n_samples, 0], X_np[:n_samples, 1],
        edgecolors=alpha_color((0, 0, 0)), color=alpha_color(cols[i]), zorder=0, linewidth=.5,
    )
    y_cloud = ax.scatter(
        Y_np[:n_samples, 0], Y_np[:n_samples, 1],
        edgecolors=(0, 0, 0), color=col_bary, zorder=0, linewidth=.5,
    )

    # Legend entries start with the clouds; the highlighted markers are added
    # only if they are drawn, so the legend never references undefined artists.
    y_handles = (y_cloud,)
    x_handles = (x_cloud,)

    if n_arrows > 0:
        x_marked = ax.scatter(
            X_np[-n_arrows:, 0], X_np[-n_arrows:, 1],
            linewidth=.5, edgecolors='black', color=cols[i], zorder=2,
        )
        y_marked = ax.scatter(
            Y_np[-n_arrows:, 0], Y_np[-n_arrows:, 1],
            linewidth=.5, edgecolors='black', color=cols[CONFIG.NUM + 2], zorder=2,
        )
        # Head sizes of zero turn the quiver arrows into plain segments.
        ax.quiver(
            X_np[-n_arrows:, 0], X_np[-n_arrows:, 1],
            Y_np[-n_arrows:, 0] - X_np[-n_arrows:, 0], Y_np[-n_arrows:, 1] - X_np[-n_arrows:, 1],
            angles='xy', scale_units='xy', scale=0.95, width=.005, zorder=1, headwidth=0.0, headlength=0.0,
        )
        y_handles = (y_cloud, y_marked)
        x_handles = (x_marked, x_cloud)

    ax.legend(
        [
            y_handles,
            x_handles,
        ], [
            f"$y \\sim \\pi_{i + 1}^{{f_{{\\theta^*,{i + 1}}}}}(\\cdot \\mid x_{i + 1})$",
            f"$x_{i + 1} \\sim \\mathbb{{P}}_{i + 1}$",
        ],
        handler_map={tuple: HandlerTuple(ndivide=None)},
        loc="upper left",
        prop={"size": 13.5},
    )

In [None]:
def plot_distributions(ax, samplers, n_samples=512):
    """Scatter samples of every input measure together with a reference target.

    Args:
        ax: matplotlib axes to draw on.
        samplers: Z-space distributions; their samples are mapped to X-space
            with ``tf.h_inv`` before plotting.
        n_samples: number of points drawn per cloud (default 512).

    The reference cloud is a free-support discrete OT barycenter of the drawn
    samples (uniform point weights) when ``CONFIG.DISCRETE_OT_FOR_GT`` is set,
    and ``tf.h_inv`` applied to ``Zgtdistrib`` samples otherwise.
    """
    cols = mpl.colormaps["Dark2"].colors
    Xs = []
    for i, distr in enumerate(samplers):
        X = tf.h_inv(distr.sample((n_samples,))).detach().cpu().numpy()
        Xs.append(X)
        ax.scatter(
            X[:, 0], X[:, 1],
            label=f"$x_{{{i + 1}}} \\sim \\mathbb{{P}}_{{{i + 1}}}$",
            edgecolors=alpha_color((0, 0, 0)), color=alpha_color(cols[i]), linewidth=.5,
        )

    if CONFIG.DISCRETE_OT_FOR_GT:
        mw = [ot.unif(x.shape[0]) for x in Xs]
        # The sampler takes an integer count at its other call sites in this
        # notebook; pass the count the same way here instead of a tuple.
        Y_init = init_noise_sampler.sample(Xs[0].shape[0]).detach().cpu().numpy()
        Xgt = ot.lp.free_support_barycenter(Xs, mw, Y_init)
    else:
        Xgt = tf.h_inv(Zgtdistrib.sample((n_samples,))).detach().cpu().numpy()
    ax.scatter(
        Xgt[:, 0], Xgt[:, 1],
        label=r"$y \sim \mathbb{Q}_*$",
        edgecolors='black', color=cols[CONFIG.NUM + 1], linewidth=.5,
    )
    ax.legend(ncol=2, loc="upper center", prop={"size": 13.5})

In [None]:
def plot_bary(potential_fns, samplers, arrows=True):
    """Draw the summary figure: one overview panel plus one panel per measure.

    Args:
        potential_fns: one potential per measure.
        samplers: Z-space distributions, ordered like ``potential_fns``.
        arrows: when truthy, highlight a few source points and connect them to
            their sampled targets in the per-measure panels.

    Returns:
        The matplotlib figure.
    """
    N_SAMPLES, N_MAPS, N_ARROWS_PER_MAP = 512, 5, 3
    n_maps = 0 if not arrows else N_MAPS

    fig, axs = plt.subplots(
        ncols=CONFIG.NUM + 1,
        figsize=(15, 3.75),
        sharex=True,
        sharey=True,
        dpi=200,
    )
    overview_ax, panel_axes = axs[0], axs[1:]

    plot_distributions(overview_ax, samplers)
    # The axes are shared, so these limits apply to every panel.
    overview_ax.set_xlim(-7, 7)
    overview_ax.set_ylim(-7, 7)

    panels = zip(potential_fns, samplers, panel_axes)
    for i, (f_pot, sampler, ax) in enumerate(panels):
        plot_bary_i(f_pot, sampler, ax, i, N_SAMPLES, n_maps, N_ARROWS_PER_MAP)

    fig.tight_layout()
    return fig

In [None]:
# Loss tracker / plot scheduler.
# NOTE(review): the meaning of the positional arguments (1, 1, (5, 4)) is defined by
# StatsManagerDrawScheduler, which is outside this notebook -- confirm.
SMDS = StatsManagerDrawScheduler(StatsManager('loss'), 1, 1, (5, 4), epoch_freq=10)
# Not read anywhere in the visible cells.
last_plot_it = -1
last_score_it = -1

for it in tqdm(range(CONFIG.MAX_STEPS)):
    # Fresh batch from every measure (mapped to X-space) and fresh chain starting points.
    Xs = [tf.h_inv(s.sample((CONFIG.BATCH_SIZE,))).to(DEVICE) for s in samplers]
    Ys_init = [init_noise_sampler.sample(CONFIG.BATCH_SIZE).to(DEVICE) for _ in range(CONFIG.NUM)]

    # Sampling phase: networks in eval mode, no graph is kept for the chains
    # (cond_score re-enables grad internally where it needs the potential's gradient).
    for net in nets: net.eval()
    with torch.no_grad():
        Ys = [sample_langevin_mu_f(f, X.to(DEVICE), Y_init, CONFIG) for f, X, Y_init in zip(f_pots, Xs, Ys_init)]

    # Update phase: the loss is the alpha-weighted sum of the mean potential values at the samples.
    for net in nets: net.train()
    loss = sum(alpha * f(Y).mean() for alpha, f, Y in zip(CONFIG.ALPHAS, f_pots, Ys))
    opt.zero_grad()
    loss.backward()
    opt.step()
    SMDS.SM.upd('loss', loss.item())
    SMDS.epoch()

In [None]:
# Re-seed so the samples shown in the figure are reproducible independently of training.
seed(4)
plot_bary(f_pots, samplers)
plt.show()