In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

In [None]:
import random

import numpy as np
import torch

from src.costs import *
from src.distributions import *
from src.loggers import WandbLogger
from src.models.simple import mlp
from src.plotters import Plotter
from src.train import run_experiment, Experiment
from src.utils import *


In [None]:
# Seed every RNG the experiments may draw from so "Restart & Run All" reproduces results.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing `;` hides the returned torch.Generator repr

In [None]:
# W&B logger for runs that opt in via `logger=logger`; offline mode keeps runs local.
logger = WandbLogger(
    project="optimal-transport",
    entity="_devourer_",
    group="test",
    mode="offline",
)

# Shared plotting settings for every experiment below.
plotter = Plotter(
    plot_target=True,
    plot_critic=False,
    plot_arrows=False,
    n_samples=512,
)

# Default training hyperparameters, passed to `run_experiment` as `**config`.
# To override one entry, merge it (`**{**config, "key": value}`) instead of passing the key
# a second time next to `**config`, which raises a TypeError for duplicate keywords.
config = dict(
    num_epochs=100,
    num_samples=512,
    num_steps_mover=10,
    num_steps_critic=1,
)

# Preferred CUDA device index. Falls back to device 0 when the machine has fewer GPUs,
# instead of crashing in `torch.cuda.set_device`.
GPU_ID = 2
if torch.cuda.is_available():
    gpu_id = GPU_ID if GPU_ID < torch.cuda.device_count() else 0
    torch.cuda.set_device(gpu_id)
    DEVICE = torch.device("cuda", gpu_id)
else:
    DEVICE = torch.device("cpu")
DEVICE

### Case \#1: Gaussian to Gaussian

In [None]:
# Source and target are `Normal`s that share the second parameter (1, 2) and differ only in the
# first one: (3, 0) for the source vs (-3, 0) for the target (presumably loc / scale — confirm).
source = to_composite(
    Normal(torch.tensor([3., 0.]), torch.tensor([1., 2.]), device=DEVICE)
)
target = Normal(torch.tensor([-3., 0.]), torch.tensor([1., 2.]), device=DEVICE)

# Flattened event sizes of the two distributions.
p = source.event_shape.numel()
q = target.event_shape.numel()

# Critic takes q-dimensional inputs; mover maps p -> q. The critic is built first so the
# random weight initialisation matches the original construction order.
hidden_size = 64
critic = mlp(q, hidden_size=hidden_size).to(DEVICE)
mover = mlp(p, q, hidden_size=hidden_size).to(DEVICE)

#### Fixed $P$ cost

In [None]:
# Train a fresh copy of the mover/critic pair with the fixed-P inner-GW cost (logging disabled).
gw_mover, gw_critic = copy_models(mover, critic)
run_experiment(
    source,
    target,
    gw_mover,
    gw_critic,
    cost=InnerGW(p, q, device=DEVICE),
    plotter=plotter,
    # logger=logger,
    **config,
)

In [None]:
# Same fixed-P inner-GW experiment on a fresh model copy, this time reporting to `logger`.
logged_mover, logged_critic = copy_models(mover, critic)
run_experiment(
    source,
    target,
    logged_mover,
    logged_critic,
    cost=InnerGW(p, q, device=DEVICE),
    plotter=plotter,
    logger=logger,
    **config,
)

#### Kernel cost

In [None]:
# Kernel-based inner-GW cost. `innerGW_kernel` receives the mover as an argument, so the SAME
# copied mover is passed both to the cost and to `run_experiment`. Working on copies keeps the
# shared `mover`/`critic` untrained for the experiments in the following cells.
kernel_mover, kernel_critic = copy_models(mover, critic)
run_experiment(
    source, target, kernel_mover, kernel_critic,
    cost=innerGW_kernel(kernel_1, source, kernel_mover, n_samples_mc=512),
    plotter=plotter,
    # logger=logger,
    # Override `num_steps_mover` by merging into `config`: passing it next to `**config`
    # (which already contains it) raises "got multiple values for keyword argument".
    **{**config, "num_steps_mover": 5},
)

#### Trainable $P$ cost

In [None]:
# Trainable-P inner-GW cost: the cost gets its own update steps (`num_steps_cost`)
# on top of the shared training config; models are fresh copies.
trainable_mover, trainable_critic = copy_models(mover, critic)
run_experiment(
    source,
    target,
    trainable_mover,
    trainable_critic,
    cost=InnerGW_opt(p, q, device=DEVICE),
    plotter=plotter,
    # logger=logger,
    num_steps_cost=10,
    **config,
)

### Case \#3: 3D GMM to 2D GMM with the same number of components

In [None]:
n_components = 10
component_scale = .1

# Source: 3D mixture whose component centres come from `fibonacci_sphere`, scaled by 2.
locs_3d = 2 * fibonacci_sphere(n_components)
scales_3d = component_scale * torch.ones_like(locs_3d)
source = GaussianMixture(locs_3d, scales_3d, device=DEVICE)

# Target: 2D mixture with the same number of components, centres from `uniform_circle`.
locs_2d = uniform_circle(n_components)
scales_2d = component_scale * torch.ones_like(locs_2d)
target = GaussianMixture(locs_2d, scales_2d, device=DEVICE)

# Flattened event sizes of the two mixtures.
p = source.event_shape.numel()
q = target.event_shape.numel()

# Critic before mover, preserving the order in which weights are randomly initialised.
hidden_size = 64
critic = mlp(q, hidden_size=hidden_size).to(DEVICE)
mover = mlp(p, q, hidden_size=hidden_size).to(DEVICE)

#### Fixed $P$ cost

In [None]:
# Fixed-P inner-GW cost on the 3D -> 2D mixtures, trained on fresh model copies.
gw_mover, gw_critic = copy_models(mover, critic)
run_experiment(
    source,
    target,
    gw_mover,
    gw_critic,
    cost=InnerGW(p, q, device=DEVICE),
    plotter=plotter,
    # logger=logger,
    **config,
)

#### Trainable $P$ cost

In [None]:
# Trainable-P inner-GW cost on the 3D -> 2D mixtures; the cost gets 10 update steps.
trainable_mover, trainable_critic = copy_models(mover, critic)
run_experiment(
    source,
    target,
    trainable_mover,
    trainable_critic,
    cost=InnerGW_opt(p, q, device=DEVICE),
    plotter=plotter,
    # logger=logger,
    num_steps_cost=10,
    **config,
)

#### Kernel cost

In [None]:
# Kernel-based inner-GW cost on the 3D -> 2D mixtures. `innerGW_kernel` receives the mover as
# an argument, so the SAME copied mover goes to the cost and to `run_experiment`. Copies keep
# the shared `mover`/`critic` untouched if earlier cells are re-run afterwards.
kernel_mover, kernel_critic = copy_models(mover, critic)
run_experiment(
    source, target, kernel_mover, kernel_critic,
    cost=innerGW_kernel(kernel_1, source, kernel_mover, n_samples_mc=512),
    plotter=plotter,
    # logger=logger,
    # Override `num_steps_mover` by merging into `config`: passing it next to `**config`
    # (which already contains it) raises "got multiple values for keyword argument".
    **{**config, "num_steps_mover": 5},
)