# Strong barycenter estimation on toy examples

In [None]:
import torch
from torch import nn
import torch.distributions as TD

import numpy as np
import random
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple
%matplotlib inline

from typing import Dict, Any, Literal, List, Tuple, Union, Optional
from tqdm import tqdm
import itertools
from copy import deepcopy

from IPython.display import clear_output

import sys
sys.path.append("../..")
from src.utils import Config, make_f_pot, freeze, unfreeze
from src.models import linear_model

In [None]:
def seed_everything(
    seed: int,
    *,
    avoid_benchmark_noise: bool = False,
    only_deterministic_algorithms: bool = False
):
    """Seed the Python, NumPy and PyTorch RNGs and set the determinism switches.

    Args:
        seed: value handed to every random number generator.
        avoid_benchmark_noise: when True cuDNN benchmark mode is switched off,
            otherwise it is switched on.
        only_deterministic_algorithms: forwarded to
            ``torch.use_deterministic_algorithms`` (always with ``warn_only=True``).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    torch.backends.cudnn.benchmark = False if avoid_benchmark_noise else True
    torch.use_deterministic_algorithms(
        only_deterministic_algorithms,
        warn_only=True,
    )

## Data

In [None]:
def sample_gauss(mu, cov, n):
    """Draw ``n`` samples from the multivariate normal N(mu, cov).

    mu - torch.Size([2])
    cov - torch.Size([2,2])
    n - int (amount of samples)

    Returns a tensor whose leading dimension is ``n``.
    """
    gaussian = TD.MultivariateNormal(loc=mu, covariance_matrix=cov)
    return gaussian.sample((n,))
        

In [None]:
def plot_initial_data(mus, covs, n):
    """Scatter-plot ``n`` samples of each Gaussian on the current pyplot axes.

    NOTE: a later cell of this notebook defines ``plot_initial_data(n)`` with a
    different signature; once that cell has run it replaces this function.

    mus - list of torch.Size([2])
    covs - list of torch.Size([2,2])
    n - int (amount of samples per distribution)
    """
    for idx, (mu, cov) in enumerate(zip(mus, covs)):
        d = sample_gauss(mu, cov, n)
        plt.scatter(d[:, 0], d[:, 1], edgecolor='black', label=f'distribution {idx+1}')

    # ``plt.grid()`` without arguments *toggles* the grid, so calling it once per
    # distribution left the grid switched off for an even number of
    # distributions. Enable it explicitly, once.
    plt.grid(True)
    plt.legend()

## Twister experiment

In [None]:
# Global experiment settings, stored as attributes of the project's Config object.
CONFIG = Config()

# The notebook requires CUDA: the assert aborts this cell when no GPU is available.
CONFIG.GPU_DEVICE = 0
assert torch.cuda.is_available()
CONFIG.DEVICE = f'cuda:{CONFIG.GPU_DEVICE}'

CONFIG.K = 3  # amount of distributions
CONFIG.LAMBDAS = [0.3333,0.3333,0.3333]  # barycenter weights, one per distribution
CONFIG.DIM = 2  # dimensionality of the data
CONFIG.INPUT_DIM = CONFIG.DIM
CONFIG.HIDDEN_DIMS = [128,128]  # hidden layer widths shared by maps and potentials
CONFIG.OUTPUT_DIM_POT = 1  # potentials are scalar-valued
CONFIG.OUTPUT_DIM_MAP = CONFIG.DIM  # maps output points of the data dimensionality
CONFIG.LR = 1e-3  # used both as the Adam lr and the OneCycleLR max lr (see get_opt_sched)
CONFIG.NUM_SAMPLES = 10_000  # NOTE(review): not referenced in the visible part of the notebook
CONFIG.NUM_EPOCHS = 1200  # number of potential updates performed by train()
CONFIG.BATCH_SIZE= 1024  # samples drawn per distribution for every training batch
CONFIG.INNER_ITERATIONS = 3  # map updates per potential update in train()

# Gaussian prior passed to the regularised maps and drawn by plot_gaussian_pdf.
CONFIG.PRIOR_MEAN = torch.tensor([5., 5.], device=CONFIG.DEVICE)
CONFIG.PRIOR_COV = 2 * torch.eye(2, device=CONFIG.DEVICE)
# NOTE(review): the three settings below are not referenced in the visible part
# of the notebook; the training cells pass the regularisation strength directly.
CONFIG.CONDITIONAL_COV = .1 * torch.eye(2, device=CONFIG.DEVICE)
CONFIG.KL_REG_STRENGTH = 0.0
CONFIG.ED_SAMPLE_REG_STRENGTH = 0.01

In [None]:
class MLP(nn.Module):
    def __init__(self, *hidden_dims: int):
        """Fully-connected backbone: linear layers separated by ReLU.

        ``hidden_dims`` lists every width of the network, beginning with the
        input dimension and ending with the output dimension, so the network
        holds ``len(hidden_dims) - 1`` linear layers. ReLU is placed between
        consecutive linear layers; nothing follows the last one.
        """
        assert len(hidden_dims) >= 2
        super().__init__()

        widths = list(hidden_dims)
        stack = []
        for position, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
            if position > 0:
                stack.append(nn.ReLU(inplace=True))
            stack.append(nn.Linear(n_in, n_out))
        self._layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the stack of layers."""
        return self._layers(x)

In [None]:
class OTMap(nn.Module):
    # Common interface of the transport maps defined in this notebook; subclasses
    # override ``forward``.
    def __init__(
        self,
        inp_dim: Optional[int] = None,
        hidden_dims: Optional[List[int]] = None,
        out_dim: Optional[int] = None,
        *args, **kwargs,
    ):
        """Initialize OT map class.

        The base class only initialises ``nn.Module``; the arguments are not
        stored here.
        
        Args:
            inp_dim: a dimensionality of the source space.
            out_dim: a dimensionality of the target space.
            hidden_dims: hidden dimensions.
        """
        super().__init__()
        
    def forward(
        self, 
        x: torch.FloatTensor,
        reg: bool = False,
    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]:
        """Compute OT Map.
        
        If the map is weak, return one sample per input item.
        
        Args:
            x: tensor of shape (bs, inp_dim)
            reg: whether to return the regularization term
        
        Returns:
            tensor of shape (bs, out_dim) [and regularization term]

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        
        raise NotImplementedError

In [None]:
class DeterministicMap(OTMap):
    """Deterministic map: an MLP applied directly to the input."""

    def __init__(self, inp_dim: int, hidden_dims: List[int], out_dim: int):
        super().__init__()
        self._bb = MLP(inp_dim, *hidden_dims, out_dim)

    def forward(self, x, reg: bool = False):
        """Map ``x``; with ``reg=True`` also return a zero scalar regulariser."""
        mapped = self._bb(x)
        if not reg:
            return mapped
        # A deterministic map has nothing to regularise: report a 0-d zero tensor.
        zero_reg = torch.tensor(0.0, device=x.device)
        return mapped, zero_reg

In [None]:
class GaussianMap(OTMap):
    """Stochastic map that samples from an input-conditional Gaussian.

    For every input item the backbone predicts a mean vector together with the
    entries of a lower-triangular factor ``L``; the conditional covariance is
    ``L @ L^T``. The regularisation term is the KL divergence between that
    conditional Gaussian and a fixed Gaussian prior.
    """

    def __init__(
        self,
        inp_dim: int,
        hidden_dims: List[int],
        out_dim: int,
        prior_mean: torch.FloatTensor,
        prior_cov: torch.FloatTensor,
    ):
        """
        Args:
            inp_dim: a dimensionality of the source space.
            hidden_dims: hidden dimensions of the backbone.
            out_dim: a dimensionality of the target space.
            prior_mean: mean of the Gaussian prior.
            prior_cov: covariance matrix of the Gaussian prior.
        """
        super().__init__()
        self._out_dim = out_dim
        out_dim_combined = (
            out_dim                         # mean
            + out_dim * (out_dim + 1) // 2  # covariance matrix
        )
        self._bb = MLP(inp_dim, *hidden_dims, out_dim_combined)
        # NOTE: the prior tensors are plain attributes, not registered buffers,
        # so ``module.to(device)`` does not move them; they have to be created
        # on the device the inputs will live on.
        self._m_pr = prior_mean
        self._cov_pr = prior_cov

    def forward(self, x, reg: bool = False):
        """Sample one output per input item.

        Args:
            x: tensor of shape (bs, inp_dim)
            reg: whether to return the KL regularisation term as well

        Returns:
            tensor of shape (bs, out_dim), and additionally a tensor of shape
            (bs,) holding the per-item KL term when ``reg`` is True.
        """
        bs = x.shape[0]
        d = self._out_dim

        mean_cov = self._bb(x)
        mean = mean_cov[:, :d]

        # Allocate with the backbone's dtype/device so that the index
        # assignment below also works for non-float32 models.
        cov_l = mean_cov.new_zeros((bs, d, d))
        tril_idx = torch.tril_indices(d, d, device=mean_cov.device)
        cov_l[:, tril_idx[0], tril_idx[1]] = mean_cov[:, d:]

        noise = torch.randn(bs, d, device=mean_cov.device, dtype=mean_cov.dtype)
        out = mean + torch.einsum("...ij,...j->...i", cov_l, noise)
        if not reg:
            # The covariance and the KL term are only needed for the regulariser.
            return out

        cov = cov_l @ cov_l.mT
        kl = self.kl_reg(mean, cov, self._m_pr, self._cov_pr)
        return out, kl

    # tested with torch.distributions.kl_divergence
    @staticmethod
    def kl_reg(
        m_post: torch.FloatTensor,
        cov_post: torch.FloatTensor,
        m_pr: torch.FloatTensor,
        cov_pr: torch.FloatTensor,
    ):
        """KL( N(m_post, cov_post) || N(m_pr, cov_pr) ) for a batch of posteriors.

        Args:
            m_post: posterior means, shape (bs, k)
            cov_post: posterior covariance matrices, shape (bs, k, k)
            m_pr: prior mean, shape (k,)
            cov_pr: prior covariance matrix, shape (k, k)

        Returns:
            tensor of shape (bs,)

        NOTE: ``cov_post`` enters through its log-determinant, so a singular
        posterior covariance yields an infinite value.
        """

        _, pr_log_det = torch.linalg.slogdet(cov_pr)
        _, post_log_det = torch.linalg.slogdet(cov_post)
        mean_diff = m_pr - m_post  # (bs, k)
        # Solves X @ cov_pr = mean_diff, i.e. X = mean_diff @ cov_pr^{-1}.
        mT_sigma_pr_inv = torch.linalg.solve(cov_pr, mean_diff, left=False)
        assert mT_sigma_pr_inv.shape == mean_diff.shape  # (bs, k)

        return 0.5 * (
            torch.einsum("...ii", torch.linalg.solve(cov_pr, cov_post))  # tr(cov_pr^{-1} cov_post)
            + (mT_sigma_pr_inv * mean_diff).sum(-1)                      # Mahalanobis term
            - m_pr.shape[-1]
            + pr_log_det - post_log_det
        )

In [None]:
class NoiseInputMap(OTMap):
    """Stochastic map: an MLP applied to the input concatenated with Gaussian noise.

    The regularisation term is a sample estimate of an energy-distance-like
    quantity between the map output and a prior distribution.
    """

    def __init__(
        self,
        inp_dim: int,
        hidden_dims: List[int],
        out_dim: int,
        prior: torch.distributions.Distribution,
        noise_dim: Optional[int] = None,
    ):
        """
        Args:
            inp_dim: a dimensionality of the source space.
            hidden_dims: hidden dimensions of the backbone.
            out_dim: a dimensionality of the target space.
            prior: torch distribution of item shape (out_dim,) used by the
                regulariser.
            noise_dim: size of the noise vector appended to every input item;
                ``inp_dim`` is used when it is None (or 0, because of ``or``).
        """
        super().__init__()
        self._noise_dim = noise_dim or inp_dim
        self._prior = prior
        self._bb = MLP(inp_dim + self._noise_dim, *hidden_dims, out_dim)

    def forward(self, x, reg: bool = False):
        """Sample one output per input item.

        Args:
            x: tensor of shape (bs, inp_dim)
            reg: whether to return the regularisation term as well

        Returns:
            tensor of shape (bs, out_dim), and additionally a tensor of shape
            (bs,) with the per-item regulariser when ``reg`` is True.
        """
        bs = x.shape[0]

        noise = torch.randn(bs, self._noise_dim, device=x.device, dtype=x.dtype)
        out = self._bb(torch.cat((x, noise), dim=-1))
        if not reg:
            # Skip the regulariser entirely: it costs two prior draws per item
            # and would otherwise advance the RNG for a value nobody reads.
            return out
        return out, self.energy_dist_reg_sample(out)

    def energy_dist_reg_sample(
        self,
        sample: torch.FloatTensor,
    ):
        """Compute energy distance (only sample-dependent terms) using sample estimate.

        With two independent prior draws ``p1`` and ``p2`` per item the result
        is ``2 * ||sample - p1|| - ||p1 - p2||``; the sample-vs-sample term of
        the full energy distance is not included.

        Args:
            sample: has shape (bs, d)

        Returns:
            tensor of shape (bs,)
        """
        pr_sample_1, pr_sample_2 = self._prior.sample((2, *sample.shape[:-1]))
        # Norms are taken over the feature (last) axis, matching the
        # ``shape[:-1]`` batch shape used for the prior draws above.
        l12 = (sample - pr_sample_1).norm(dim=-1)
        l11 = (pr_sample_1 - pr_sample_2).norm(dim=-1)
        return 2 * l12 - l11

In [None]:
# Rotation angle (radians) gained per unit of distance from the origin.
H_SLOPE = 0.4


def norm2theta(norms):
    """Convert point norms to rotation angles: ``theta = H_SLOPE * norm``."""
    thetas = norms * H_SLOPE
    return thetas

def rotate_batch(Rs, Xs):
    """Apply the i-th 2x2 matrix of ``Rs`` to the i-th point of ``Xs``.

    Rs - tensor of shape (n, 2, 2)
    Xs - tensor of shape (n, 2); a single point of shape (2,) is treated as n = 1

    Returns a tensor of shape (n, 2) whose i-th row is ``Rs[i] @ Xs[i]``.
    """
    if Xs.dim() == 1:
        Xs = Xs.unsqueeze(0)
    assert Xs.dim() == 2
    assert Xs.size(1) == 2
    assert Xs.size(0) == Rs.size(0)
    assert Rs.dim() == 3
    assert Rs.size(1) == Rs.size(2) == 2
    # Row-vector form: x^T R^T == (R x)^T, evaluated as a batched matmul.
    row_vectors = Xs[:, None, :]
    rotated = row_vectors @ Rs.mT
    return rotated[:, 0, :]

def cossin2R(cos, sin):
    """Build 2x2 rotation matrices ``[[cos, -sin], [sin, cos]]`` from 1-D tensors.

    cos, sin - tensors of shape (n,)

    Returns a tensor of shape (n, 2, 2).
    """
    assert cos.shape == sin.shape
    assert cos.dim() == 1
    # One row per matrix, entries in row-major order.
    flat_entries = torch.stack((cos, -sin, sin, cos), dim=1)
    return flat_entries.view(-1, 2, 2)
    
def lin_space_rotator(Xs, pos=True):
    """Rotate every 2-D point about the origin by an angle proportional to its norm.

    Xs - tensor of shape (n, 2); a single point of shape (2,) is treated as n = 1
    pos - rotate by ``+norm2theta(norm)`` when True, by the negated angle otherwise

    Returns a tensor of shape (n, 2).
    """
    if Xs.dim() == 1:
        Xs = Xs.unsqueeze(0)
    assert Xs.dim() == 2
    assert Xs.size(1) == 2

    angles = norm2theta(Xs.norm(dim=-1))
    if not pos:
        angles = -angles
    rotations = cossin2R(angles.cos(), angles.sin())
    return rotate_batch(rotations, Xs)

def h(Xs):
    """Twist: rotate each point by the positive, norm-proportional angle."""
    twisted = lin_space_rotator(Xs, pos=True)
    return twisted

def h_inv(Zs):
    """Untwist: rotate each point by the negated, norm-proportional angle.

    A rotation keeps the norm of a point, so this undoes ``h``.
    """
    untwisted = lin_space_rotator(Zs, pos=False)
    return untwisted

In [None]:
# Distributions of the twister experiment, defined before the ``h_inv`` twist is
# applied. Every one is a 2-D normal with independent unit-scale coordinates.
# The three input means have norm ~4 and sum to (0, 0).
Z1distrib = TD.Normal(
    loc=torch.tensor([0., 4], device=CONFIG.DEVICE),
    scale=torch.ones(2, device=CONFIG.DEVICE),
)

Z2distrib = TD.Normal(
    loc=torch.tensor([3.46, -2.], device=CONFIG.DEVICE),
    scale=torch.ones(2, device=CONFIG.DEVICE),
)

Z3distrib = TD.Normal(
    loc=torch.tensor([-3.46, -2], device=CONFIG.DEVICE),
    scale=torch.ones(2, device=CONFIG.DEVICE),
)

# Centred distribution; the final figure plots its ``h_inv`` image as Q*.
Zgtdistrib = TD.Normal(
    loc=torch.tensor([0.0, 0.0], device=CONFIG.DEVICE),
    scale=torch.ones(2, device=CONFIG.DEVICE),
)

# Samplers of the K input distributions, indexed by k.
twister_data = [Z1distrib, Z2distrib, Z3distrib]

In [None]:
def plot_initial_data(n):
    """Scatter-plot ``n`` samples of every twister input distribution.

    Samples are drawn from ``twister_data[k]`` for ``k < CONFIG.K`` and passed
    through ``h_inv`` before plotting on the current pyplot axes.

    n - int (amount of samples per distribution)
    """
    for k in range(CONFIG.K):
        d = h_inv(twister_data[k].sample([n])).cpu()
        plt.scatter(d[:, 0], d[:, 1], edgecolor='black', label=f'distribution {k+1}')

    plt.axis("equal")
    # ``plt.grid()`` without arguments *toggles* the grid, so calling it once per
    # distribution left the grid switched off for an even CONFIG.K. Enable it
    # explicitly, once.
    plt.grid(True)
    plt.legend()

In [None]:
# Preview the input distributions with 2000 samples each.
plot_initial_data(2_000)

In [None]:
# Reset every RNG seed to 0 and switch cuDNN benchmark mode off.
seed_everything(0, avoid_benchmark_noise=True)

In [None]:
class Pots(nn.Module):
    """A set of potential networks combined so that their weighted sum vanishes.

    ``pots[k]`` returns a callable evaluating the k-th potential. It is built
    from all networks in such a way that ``sum_k lambda_k * pots[k](x)`` is zero
    (up to rounding) for every ``x``.
    """

    # TODO: optimize when 2 potentials
    def __init__(self, bary_weights, *dims):
        """
        Args:
            bary_weights: barycenter weights, one per potential (at least two).
            *dims: layer widths forwarded to every ``MLP``.
        """
        assert len(bary_weights) > 1
        super().__init__()
        self._lambdas = bary_weights
        n_pots = len(bary_weights)
        self._nets = nn.ModuleList([MLP(*dims) for _ in range(n_pots)])

    def __getitem__(self, idx):
        """Return the ``idx``-th potential as a function of ``x``."""
        assert 0 <= idx < len(self._lambdas)

        def f_pot(x):
            total = 0.0
            for position, (net, weight) in enumerate(zip(self._nets, self._lambdas)):
                value = net(x)
                if position != idx:
                    # Every other network is subtracted, rescaled by its own
                    # weight relative to the weight of the requested potential.
                    value = -(weight * value / (len(self._lambdas) - 1) / self._lambdas[idx])
                total = total + value
            return total

        return f_pot

In [None]:
def get_opt_sched(model, total_steps):
    """Create an Adam optimizer and a one-cycle LR schedule for ``model``.

    ``CONFIG.LR`` is used both as the optimizer learning rate and as the peak
    learning rate of the schedule.

    Args:
        model: module whose parameters are optimised.
        total_steps: number of scheduler steps in the cycle.

    Returns:
        tuple ``(optimizer, scheduler)``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG.LR)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=CONFIG.LR,
        total_steps=total_steps,
    )
    return optimizer, scheduler

In [None]:
def _plot_training_progress(maps, losses):
    """Show fresh input samples, their images under ``maps`` and the loss curve."""
    samples = [
        h_inv(twister_data[k].sample([1_000])).to(CONFIG.DEVICE)
        for k in range(CONFIG.K)
    ]

    clear_output(wait=True)
    fig, (ax, ax_l) = plt.subplots(1, 2, figsize=(12.8, 4.8))
    for k in range(CONFIG.K):
        d = maps[k](samples[k]).detach().cpu()
        ax.scatter(samples[k][:, 0].cpu(), samples[k][:, 1].cpu(), edgecolor='black', label=f'data {k+1}')
        ax.scatter(d[:, 0], d[:, 1], edgecolor='black', label=f'barycenter {k+1}')

    # ``ax.grid()`` without arguments *toggles* the grid; calling it once per
    # distribution switched it off again for an even CONFIG.K.
    ax.grid(True)
    ax.legend()
    ax.set_xlim(-8, 8)
    ax.set_ylim(-8, 8)

    ax_l.plot(losses)
    plt.show()


def train(
    maps: OTMap, maps_opt, maps_sched, 
    pots: Pots, pots_opt, pots_sched,
    reg_coeff: float = 0.0,
):
    """Alternate between map updates and potential updates.

    Every epoch runs ``CONFIG.INNER_ITERATIONS`` optimisation steps of the maps
    against frozen potentials, followed by a single step of the potentials
    against frozen maps. The potential step reuses the batch of the last inner
    iteration. Every 10th epoch the current maps and the loss history of the
    potential steps are plotted.

    NOTE(review): ``strong_cost`` is neither defined nor imported in the visible
    part of this notebook; it has to be available before this function is called.

    Args:
        maps: indexable collection with one map per distribution; ``maps[k]``
            is called as ``maps[k](x, reg=True)`` and ``maps[k](x)``.
        maps_opt: optimizer of the maps.
        maps_sched: LR scheduler of the maps, stepped once per inner iteration.
        pots: potentials; ``pots[k]`` is the k-th potential function.
        pots_opt: optimizer of the potentials.
        pots_sched: LR scheduler of the potentials, stepped once per epoch.
        reg_coeff: weight of the regularisation term returned by the maps.

    Raises:
        ValueError: if ``CONFIG.INNER_ITERATIONS`` is smaller than 1.
    """
    if CONFIG.INNER_ITERATIONS < 1:
        # The potential step needs the batch sampled by the inner loop.
        raise ValueError("CONFIG.INNER_ITERATIONS must be at least 1")

    losses = []
    for epoch in tqdm(range(CONFIG.NUM_EPOCHS)):

        freeze(pots)
        unfreeze(maps)

        # inner loop: update the maps
        for it in range(CONFIG.INNER_ITERATIONS):

            # data sampling
            data = [
                h_inv(twister_data[k].sample([CONFIG.BATCH_SIZE])).to(CONFIG.DEVICE)
                for k in range(CONFIG.K)
            ]

            maps_opt.zero_grad()
            loss = 0
            for k in range(CONFIG.K):
                mapped_x_k, reg = maps[k](data[k], reg=True)  # [B, N]
                # the transport cost is evaluated after applying ``h`` to both points
                cost = strong_cost(h(data[k]), h(mapped_x_k))  # [B, 1]
                cost -= pots[k](mapped_x_k)  # [B, 1]
                cost += reg_coeff * torch.unsqueeze(reg, -1)
                cost = cost.mean(dim=0)
                loss += CONFIG.LAMBDAS[k] * cost

            loss.backward()
            maps_opt.step()
            maps_sched.step()

        # unfreezing potentials
        # freezing maps
        unfreeze(pots)
        freeze(maps)

        # outer optimization: one step of the potentials. The minimised loss is
        # the weighted mean potential value at the mapped points, i.e. the
        # negated potential part of the objective used for the maps.
        pots_opt.zero_grad()
        loss = 0
        for k in range(CONFIG.K):
            mapped_x_k = maps[k](data[k])  # [B, N]
            loss += CONFIG.LAMBDAS[k] * pots[k](mapped_x_k).mean(dim=0)

        losses.append(loss.item())
        loss.backward()
        pots_opt.step()
        pots_sched.step()

        # plotting part
        if epoch % 10 == 0:
            _plot_training_progress(maps, losses)

In [None]:
# Unregularised run: deterministic maps, default reg_coeff (0.0) in train().
# The seed is reset first, so the order in which the networks are created
# below determines their initial weights.
seed_everything(0, avoid_benchmark_noise=True)

maps_ur = nn.ModuleList([
    DeterministicMap(CONFIG.INPUT_DIM, CONFIG.HIDDEN_DIMS, CONFIG.OUTPUT_DIM_MAP)
    for _ in range(CONFIG.K)
]).to(CONFIG.DEVICE)
# the maps take INNER_ITERATIONS optimizer steps per epoch
maps_opt, maps_sched = get_opt_sched(maps_ur, CONFIG.NUM_EPOCHS * CONFIG.INNER_ITERATIONS)

pots_ur = Pots(
    CONFIG.LAMBDAS,
    CONFIG.INPUT_DIM,
    *CONFIG.HIDDEN_DIMS,
    CONFIG.OUTPUT_DIM_POT
).to(CONFIG.DEVICE)
# the potentials take one optimizer step per epoch
pots_opt, pots_sched = get_opt_sched(pots_ur, CONFIG.NUM_EPOCHS)

train(maps_ur, maps_opt, maps_sched, pots_ur, pots_opt, pots_sched)

In [None]:
# KL-regularised run: Gaussian maps with the prior N(PRIOR_MEAN, PRIOR_COV) and
# regularisation weight 1.0. The seed is reset first, so the order in which the
# networks are created below determines their initial weights.
seed_everything(0, avoid_benchmark_noise=True)

maps_kl_1 = nn.ModuleList([
    GaussianMap(
        CONFIG.INPUT_DIM,
        CONFIG.HIDDEN_DIMS,
        CONFIG.OUTPUT_DIM_MAP,
        CONFIG.PRIOR_MEAN,
        CONFIG.PRIOR_COV,
    )
    for _ in range(CONFIG.K)
]).to(CONFIG.DEVICE)
# the maps take INNER_ITERATIONS optimizer steps per epoch
maps_opt, maps_sched = get_opt_sched(maps_kl_1, CONFIG.NUM_EPOCHS * CONFIG.INNER_ITERATIONS)

pots_kl_1 = Pots(
    CONFIG.LAMBDAS,
    CONFIG.INPUT_DIM,
    *CONFIG.HIDDEN_DIMS,
    CONFIG.OUTPUT_DIM_POT
).to(CONFIG.DEVICE)
# the potentials take one optimizer step per epoch
pots_opt, pots_sched = get_opt_sched(pots_kl_1, CONFIG.NUM_EPOCHS)

# the last argument is reg_coeff, the weight of the KL term
train(maps_kl_1, maps_opt, maps_sched, pots_kl_1, pots_opt, pots_sched, 1.0)

In [None]:
# Energy-distance-regularised run with regularisation weight 1.0.
#
# NOTE(review): this cell used to build ``GaussianMap`` exactly like the KL cell
# and, with the same seed, simply repeated that experiment, while
# ``NoiseInputMap`` (the map with the energy-distance regulariser) was never
# used. It now builds ``NoiseInputMap`` with the same Gaussian prior; confirm
# that this is the intended setup.
seed_everything(0, avoid_benchmark_noise=True)

ed_prior = TD.MultivariateNormal(CONFIG.PRIOR_MEAN, covariance_matrix=CONFIG.PRIOR_COV)

maps_ed_1 = nn.ModuleList([
    NoiseInputMap(
        CONFIG.INPUT_DIM,
        CONFIG.HIDDEN_DIMS,
        CONFIG.OUTPUT_DIM_MAP,
        ed_prior,
    )
    for _ in range(CONFIG.K)
]).to(CONFIG.DEVICE)
# the maps take INNER_ITERATIONS optimizer steps per epoch
maps_opt, maps_sched = get_opt_sched(maps_ed_1, CONFIG.NUM_EPOCHS * CONFIG.INNER_ITERATIONS)

pots_ed_1 = Pots(
    CONFIG.LAMBDAS,
    CONFIG.INPUT_DIM,
    *CONFIG.HIDDEN_DIMS,
    CONFIG.OUTPUT_DIM_POT
).to(CONFIG.DEVICE)
# the potentials take one optimizer step per epoch
pots_opt, pots_sched = get_opt_sched(pots_ed_1, CONFIG.NUM_EPOCHS)

# the last argument is reg_coeff, the weight of the energy-distance term
train(maps_ed_1, maps_opt, maps_sched, pots_ed_1, pots_opt, pots_sched, 1.0)

In [None]:
def plot_gaussian_pdf(mean, covar, n_points, ax):
    """Draw contour lines of a 2-D Gaussian over the current limits of ``ax``.

    The contoured quantity is ``sqrt(-log_prob)`` evaluated on an
    ``n_points x n_points`` grid spanning ``ax.get_xlim()`` and ``ax.get_ylim()``.

    NOTE: the square root is NaN wherever the log-density is positive
    (density above 1).

    Args:
        mean: mean vector of the Gaussian, shape (2,)
        covar: covariance matrix of the Gaussian, shape (2, 2)
        n_points: number of grid points along each axis
        ax: matplotlib axes to draw on
    """
    x = torch.linspace(*ax.get_xlim(), n_points)
    y = torch.linspace(*ax.get_ylim(), n_points)
    # "ij" is what the implicit default produced; passing it explicitly avoids
    # the deprecation warning of ``torch.meshgrid``.
    X, Y = torch.meshgrid(x, y, indexing="ij")
    distr = TD.MultivariateNormal(mean, covariance_matrix=covar)
    Z = torch.sqrt(-distr.log_prob(torch.stack((X, Y), dim=-1)))

    ax.contour(X, Y, Z, levels=10, alpha=0.3, linewidths=1)

In [None]:
def plot_bary_i(map_nets, samplers, ax, i, n_samples=512, n_maps=0, n_arrows_per_map=1, seed=0):
    """Scatter samples of the ``i``-th input distribution and their mapped images.

    ``n_samples`` points are drawn from ``samplers[i]`` (passed through
    ``h_inv``) and pushed through ``map_nets[i]``. When ``n_maps > 0``,
    ``n_maps`` additional points are drawn, each repeated ``n_arrows_per_map``
    times, and every copy is connected to its image by a headless arrow; for a
    stochastic map the copies of one point may land on different images.
    No legend is added here.

    Args:
        map_nets: indexable collection of maps.
        samplers: indexable collection of distributions providing ``sample``.
        ax: matplotlib axes to draw on.
        i: index of the distribution / map to show.
        n_samples: number of background points.
        n_maps: number of highlighted source points.
        n_arrows_per_map: copies (arrows) per highlighted source point.
        seed: seed passed to ``seed_everything`` before sampling.
    """
    seed_everything(seed, avoid_benchmark_noise=True)

    arrow_count = n_maps * n_arrows_per_map
    sources = h_inv(samplers[i].sample((n_samples,))).to(CONFIG.DEVICE)
    if n_maps > 0:
        anchors = h_inv(samplers[i].sample((n_maps,))).to(CONFIG.DEVICE)
        anchors = torch.tile(anchors, (n_arrows_per_map, 1))
        sources = torch.cat((sources, anchors), dim=0)

    images = map_nets[i](sources)
    src = sources.detach().cpu().numpy()
    dst = images.detach().cpu().numpy()

    def blend_with_white(color_rgb, alpha=0.5):
        # Mix the colour with white: alpha = 1 keeps it, alpha = 0 gives white.
        rgb = np.asanyarray(color_rgb)
        return 1. - (1. - rgb) * alpha

    palette = mpl.colormaps["tab10"].colors
    bary_color = blend_with_white(palette[CONFIG.K])

    # background: faded source points and their images
    ax.scatter(
        src[:n_samples, 0],
        src[:n_samples, 1],
        edgecolors=blend_with_white((0, 0, 0)),
        color=blend_with_white(palette[i]),
        zorder=0,
        linewidth=.5,
    )
    ax.scatter(
        dst[:n_samples, 0],
        dst[:n_samples, 1],
        edgecolors=(0, 0, 0),
        color=bary_color,
        zorder=0,
        linewidth=.5,
    )

    if arrow_count <= 0:
        return

    # foreground: highlighted source points, their images and connecting lines
    hi_src = src[-arrow_count:]
    hi_dst = dst[-arrow_count:]
    ax.scatter(
        hi_src[:, 0],
        hi_src[:, 1],
        linewidth=.5,
        edgecolors='black',
        color=palette[i],
        zorder=2,
    )
    ax.scatter(
        hi_dst[:, 0],
        hi_dst[:, 1],
        linewidth=.5,
        edgecolors='black',
        color=palette[CONFIG.K + 2],
        zorder=2,
    )
    ax.quiver(
        hi_src[:, 0],
        hi_src[:, 1],
        hi_dst[:, 0] - hi_src[:, 0],
        hi_dst[:, 1] - hi_src[:, 1],
        angles='xy',
        scale_units='xy',
        scale=0.95,
        width=.005,
        zorder=1,
        headwidth=0.0,
        headlength=0.0,
    )

In [None]:
# Final comparison figure: the inputs with Q*, then the maps learned by the
# three runs for the distribution with index BARY_IDX.
BARY_IDX = 0
N_SAMPLES = 128
N_MAPS = 10
N_ARROWS_PER_MAP = 3

fig, (ax_gt, ax_ur, ax_kl_1, ax_ed_1) = plt.subplots(
    ncols=4,
    figsize=(13, 3.2), 
    dpi=300,
    sharey=True,
    sharex=True,
)

for i, s in enumerate(twister_data):
    ax_gt.scatter(
        *h_inv(s.sample((512,))).cpu().T,
        linewidth=.5, 
        edgecolors='black',
        # "\\mathbb" instead of "\mathbb": same text, but no invalid escape
        # sequence in a non-raw string literal.
        label=f"$x_{i + 1} \\sim \\mathbb{{P}}_{i + 1}$",
    )
ax_gt.scatter(
    *h_inv(Zgtdistrib.sample((512,))).cpu().T,
    linewidth=.5, 
    edgecolors='black',
    label="$y \\sim \\mathbb{Q}^*$",
)
# the axes are shared, so these limits apply to all four panels
ax_gt.set_xlim((-8, 8))
ax_gt.set_ylim((-8, 8))

# each panel also shows the contours of the prior N(PRIOR_MEAN, PRIOR_COV)
plot_bary_i(maps_ur, twister_data, ax_ur, BARY_IDX, N_SAMPLES, N_MAPS, N_ARROWS_PER_MAP)
plot_gaussian_pdf(CONFIG.PRIOR_MEAN.cpu(), CONFIG.PRIOR_COV.cpu(), 50, ax_ur)

plot_bary_i(maps_kl_1, twister_data, ax_kl_1, BARY_IDX, N_SAMPLES, N_MAPS, N_ARROWS_PER_MAP)
plot_gaussian_pdf(CONFIG.PRIOR_MEAN.cpu(), CONFIG.PRIOR_COV.cpu(), 50, ax_kl_1)

plot_bary_i(maps_ed_1, twister_data, ax_ed_1, BARY_IDX, N_SAMPLES, N_MAPS, N_ARROWS_PER_MAP)
plot_gaussian_pdf(CONFIG.PRIOR_MEAN.cpu(), CONFIG.PRIOR_COV.cpu(), 50, ax_ed_1)

fig.legend(loc="upper center")

plt.tight_layout(pad=1)