In [1]:
# Enable IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before executing each cell.
%autoreload 2

import sys
# Make the parent directory importable so the `src.*` imports below resolve.
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
import numpy as np
import torch
# Explicit import: `nn` is used when building the models and should not
# depend on whatever names the wildcard imports below happen to export.
from torch import nn
from torchvision import transforms as t, datasets as d

from src.costs import *
from src.distributions import *
from src.loggers import WandbLogger
from src.models.unet import unet_h
from src.plotters import ImagePlotter
from src.train import run_experiment
from src.utils import *
from src.models.resnet2 import ResNet_D, weights_init_D
import torch.optim as o
import torch.optim.lr_scheduler as lr

In [3]:
# Fix the global NumPy and PyTorch RNG seeds so runs are repeatable.
np.random.seed(0)
# The trailing semicolon keeps the notebook from echoing the returned Generator.
torch.manual_seed(0);

In [4]:
# Experiment logger shared by every run_experiment call in this notebook.
# All runs are reported under the same project / entity / group.
# To experiment without uploading, try adding mode="offline" to the call
# (kept here as a disabled option; presumably forwarded to wandb — confirm
# in src.loggers).
LOGGER = WandbLogger(
    project="optimal-transport", entity="_devourer_", group="style-transfer"
)

# Image plotter shared by every run; the exact meaning of n_images / n_samples
# is defined in src.plotters (not visible here).
PLOTTER = ImagePlotter(
    n_images=5,
    n_samples=5,
    plot_source=True,
)

# Keyword arguments unpacked into every run_experiment(...) call below.
CONFIG = {
    "num_epochs": 1000,
    "num_samples": 64,
    "optimizer_params": {"lr": 2e-5},
    # Disabled learning-rate schedule, kept for reference:
    # "scheduler_params": {
    #     "type": lr.CyclicLR,
    #     "params": {
    #         "base_lr": 1e-4,
    #         "max_lr": 1e-2,
    #         "mode": "triangular",
    #     },
    # },
}

# Preprocessing applied to every dataset loaded in this notebook:
# convert to a tensor first, then resize to 32 (matches size=32 of the critic).
TRANSFORM = t.Compose(
    [
        t.ToTensor(),
        t.Resize(32),
    ]
)

# Preferred CUDA device ordinal for this notebook.
GPU_INDEX = 2

if torch.cuda.is_available():
    # Clamp the ordinal to the GPUs actually visible, so hosts with fewer
    # than GPU_INDEX + 1 devices fall back to their last GPU instead of
    # raising an "invalid device ordinal" error.
    torch.cuda.set_device(min(GPU_INDEX, torch.cuda.device_count() - 1))
    # "cuda" without an index refers to the current device selected above.
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

cuda


In [6]:
# Source domain: FashionMNIST, preprocessed with TRANSFORM and wrapped as a
# distribution living on DEVICE.
features, classes = load_dataset(
    d.FashionMNIST,
    root="../data/",
    transform=TRANSFORM,
)
source = TensorDatasetDistribution(features, classes, device=DEVICE)

# Target domain: MNIST, loaded the same way (the temporaries are reused).
features, classes = load_dataset(
    d.MNIST,
    root="../data/",
    transform=TRANSFORM,
)
target = TensorDatasetDistribution(features, classes, device=DEVICE)

# Event shapes of both domains; passed to the InnerGW cost below.
p = source.event_shape
q = target.event_shape

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

In [7]:
# Critic for the FashionMNIST -> MNIST experiments (size=32, nc=1), moved to
# DEVICE and initialised with weights_init_D.
critic = (
    ResNet_D(size=32, nc=1)
    .to(DEVICE)
    .apply(weights_init_D)
)

# Mover: a U-Net built from the source event shape, followed by Tanh so its
# outputs are bounded to [-1, 1].
# NOTE(review): TRANSFORM shows no explicit normalisation — confirm the value
# range of the loaded data matches the Tanh output range.
mover = nn.Sequential(
    unet_h(source.event_shape, base_channels=48),
    nn.Tanh(),
).to(DEVICE)

In [15]:
# Experiment: FashionMNIST -> MNIST with the InnerGW cost.
# The models are passed through copy_models(...) rather than directly
# (presumably so each run starts from fresh copies — confirm in src.utils).
# FID is switched off: no inception statistics are computed for this pair.
run_experiment(
    source,
    target,
    *copy_models(mover, critic),
    cost=InnerGW(p, q, device=DEVICE),
    logger=LOGGER,
    plotter=PLOTTER,
    use_fid=False,
    **CONFIG,
)

In [None]:
# Experiment: FashionMNIST -> MNIST with the InnerGW_conv cost.
# Differs from the InnerGW run only in the cost object and num_steps_cost=5
# (its meaning is defined by run_experiment in src.train).
# FID is switched off: no inception statistics are computed for this pair.
run_experiment(
    source,
    target,
    *copy_models(mover, critic),
    cost=InnerGW_conv(device=DEVICE),
    num_steps_cost=5,
    logger=LOGGER,
    plotter=PLOTTER,
    use_fid=False,
    **CONFIG,
)

In [5]:
# Second dataset pair: handbags (source) -> shoes (target), read from HDF5
# files and preprocessed with TRANSFORM.
# This rebinds `source`, `target`, `p` and `q`, so the model-building cell
# must be re-run after this one.
handbags = load_h5py("../data/handbag_64.hdf5", transform=TRANSFORM)
shoes = load_h5py("../data/shoes_64.hdf5", transform=TRANSFORM)

# These datasets carry no class labels; an all-zeros vector with one entry
# per sample is passed in the label slot.
source = TensorDatasetDistribution(
    handbags,
    torch.zeros(handbags.size(0)),
    device=DEVICE,
)
target = TensorDatasetDistribution(
    shoes,
    torch.zeros(shoes.size(0)),
    device=DEVICE,
)

# Reference statistics of the target domain, later handed to run_experiment
# as fid_mu / fid_sigma. The positional 128 is presumably the batch size
# (the recorded progress bar shows 391 steps for 50025 samples) — confirm.
fid_mu, fid_sigma = get_inception_statistics(
    target.features,
    128,
    verbose=True,
)

# Event shapes of both domains; passed to the InnerGW cost below.
p = source.event_shape
q = target.event_shape

  0%|          | 0/138767 [00:00<?, ?it/s]

  0%|          | 0/50025 [00:00<?, ?it/s]



  0%|          | 0/391 [00:00<?, ?it/s]

In [6]:
# Critic for the handbags -> shoes experiments (size=32, nc=3), moved to
# DEVICE and initialised with weights_init_D.
critic = (
    ResNet_D(size=32, nc=3)
    .to(DEVICE)
    .apply(weights_init_D)
)

# Mover: a U-Net built from the current source event shape, followed by Tanh
# so its outputs are bounded to [-1, 1].
# NOTE(review): TRANSFORM shows no explicit normalisation — confirm the value
# range of the loaded data matches the Tanh output range.
mover = nn.Sequential(
    unet_h(source.event_shape, base_channels=48),
    nn.Tanh(),
).to(DEVICE)

In [12]:
# Experiment: handbags -> shoes with the InnerGW cost.
# The models are passed through copy_models(...) rather than directly
# (presumably so each run starts from fresh copies — confirm in src.utils).
# The precomputed target statistics are supplied for FID evaluation.
run_experiment(
    source,
    target,
    *copy_models(mover, critic),
    cost=InnerGW(p, q, device=DEVICE),
    logger=LOGGER,
    plotter=PLOTTER,
    fid_mu=fid_mu,
    fid_sigma=fid_sigma,
    **CONFIG,
)

Epoch:   0%|          | 0/1000 [00:00<?, ?it/s]

Training:   0%|          | 0/50 [00:00<?, ?it/s]

Validating:   0%|          | 0/50 [00:00<?, ?it/s]

[34m[1mwandb[0m: Currently logged in as: [33m_devourer_[0m. Use [1m`wandb login --relogin`[0m to force relogin


Output()



VBox(children=(Label(value='0.481 MB of 0.481 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
FID,▁▅▄█▇▆▅▅
GW,█▂▁▂▃▃▂▃
epoch,▁▁▂▂▃▃▄▄▅▅▅▅▆▆▇▇█
train/cost,█▅▅▄▄▄▄▄▄▃▃▂▂▂▂▂▃▂▂▂▂▂▁▂▁▁▂▁▁▁▁▁▁▁▁▂▂▁▁▁
train/critic(h_x),▁▁▁▁▁▁▁▁▂▂▃▃▄▅███▆▅▆▄▂▅▂▅▂▄▂▂▂▁▂▂▁▁▁▁▁▁▂
train/critic(y),▁▁▁▁▁▁▁▁▂▂▃▃▄▆███▆▅▆▄▂▅▁▅▁▄▂▂▂▁▂▂▁▁▁▁▁▁▂
train/loss,▄▄▄▄▄▄▄▅▅▆▆▆██▆▆▃▅▄▄▄▄▂▂▂▂▁▃▃▃▄▄▃▃▄▄▄▄▄▄
train/step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
FID,36.80211
GW,310276.75188
epoch,8.0
train/cost,0.00624
train/critic(h_x),12.36443
train/critic(y),12.56699
train/loss,0.20879
train/step,446.0


In [11]:
# Experiment: handbags -> shoes with the InnerGW_conv cost.
# Differs from the InnerGW run only in the cost object and num_steps_cost=5
# (its meaning is defined by run_experiment in src.train).
# The precomputed target statistics are supplied for FID evaluation.
run_experiment(
    source,
    target,
    *copy_models(mover, critic),
    cost=InnerGW_conv(device=DEVICE),
    num_steps_cost=5,
    logger=LOGGER,
    plotter=PLOTTER,
    fid_mu=fid_mu,
    fid_sigma=fid_sigma,
    **CONFIG,
)

Epoch:   0%|          | 0/1000 [00:00<?, ?it/s]

Training:   0%|          | 0/50 [00:00<?, ?it/s]

Validating:   0%|          | 0/50 [00:00<?, ?it/s]

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Output()



VBox(children=(Label(value='0.300 MB of 0.300 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
FID,▁▄▅█▇
GW,█▂▁▂▂
train/cost,▇▅█▆▆▄▃▂▂▃▃▄▅▄▄▃▃▄▃▃▃▂▂▂▃▂▁▃▁▂▂▁▂▂▃▃▂▂▃▁
train/critic(h_x),▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▃▃▄▆▆████▇▆▆▅▅▄▄▄▃▃▃▂▃
train/critic(y),▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▃▄▄▆▆████▇▆▆▅▅▄▄▄▃▃▃▂▃
train/loss,▄▄▄▄▄▄▄▄▄▄▄▄▅▅▅▅▆▆███▇▆▄▁▃▅▅▄▄▄▃▃▄▄▄▄▄▄▃
train/step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███

0,1
FID,42.17827
GW,192904.62063
train/cost,0.00578
train/critic(h_x),45.73952
train/critic(y),45.42585
train/loss,-0.3079
train/step,263.0
