In [1]:
# Reload edited project modules (the `src.*` code imported below) automatically
# before each cell runs, so library edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

import sys

# Make the project root importable so the `from src... import ...` lines resolve.
# Guarded so that re-running this cell does not keep appending duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [2]:
from pathlib import Path

import numpy as np
import torch
import torch.optim as o
import torch.optim.lr_scheduler as lr
from torchvision import transforms as t, datasets as d
from sklearn.model_selection import train_test_split

from src.costs import *
from src.distributions import *
from src.loggers import WandbLogger
from src.models.unet import unet_h
from src.plotters import ImagePlotter
from src.train import run_experiment
from src.utils import *
from src.models.resnet2 import ResNet_D, weights_init_D

In [3]:
# One seed for every RNG this notebook relies on (NumPy global state, torch).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator in the notebook

In [7]:
# Weights & Biases experiment tracker for the style-transfer runs (offline mode).
LOGGER = WandbLogger(project="optimal-transport", entity="_devourer_",
                     group="style-transfer", mode="offline")

# Image plotter handed to every run_experiment call below.
PLOTTER = ImagePlotter(n_images=10, n_samples=10, plot_source=True)

# Training-loop settings splatted (**CONFIG) into every run_experiment call below.
# Disabled alternatives, currently NOT passed to run_experiment:
#   optimizer_params=dict(lr=2e-5)
#   scheduler_params=dict(type=lr.CyclicLR,
#                         params=dict(base_lr=1e-4, max_lr=1e-2, mode="triangular"))
CONFIG = {
    "num_epochs": 200,
    "num_samples": 64,
    "num_steps_train": 250,
    "num_steps_eval": 250,
}

# Transform applied by load_dataset / load_h5py to every image.
# torchvision's ToTensor scales uint8 / PIL images to float tensors in [0, 1].
# NOTE(review): both movers below end in nn.Tanh (range [-1, 1]); confirm the
# loaders rescale the data, otherwise mover outputs and target samples live on
# different value ranges.
# Disabled alternative that also resizes to 32x32:
# TRANSFORM = t.Compose([t.ToTensor(), t.Resize(32)])
TRANSFORM = t.ToTensor()

# Where run_experiment writes checkpoints when given checkpoint_dir=CHECKPOINT_DIR.
CHECKPOINT_DIR = Path("../checkpoints/")
# exist_ok replaces the racy exists()-then-mkdir() check; parents also creates
# any missing ancestor directory instead of raising FileNotFoundError.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# GPU to train on when several are visible. It used to be hard-coded in
# set_device(2), which raises "invalid device ordinal" on hosts with fewer GPUs.
GPU_INDEX = 2

DEVICE = torch.device("cpu")
if torch.cuda.is_available():
    n_gpus = torch.cuda.device_count()
    if GPU_INDEX < n_gpus:
        # Make GPU_INDEX the current device so tensors sent to plain "cuda" land on it.
        torch.cuda.set_device(GPU_INDEX)
    else:
        print(f"GPU {GPU_INDEX} not available ({n_gpus} visible); "
              f"using the default CUDA device")
    DEVICE = torch.device("cuda")
print(DEVICE)

cuda


In [6]:
# FashionMNIST (source) -> MNIST (target).
# Each (features, classes) pair is unpacked straight into its distribution, so no
# stale `features` / `classes` globals linger in the kernel for a later cell to
# pick up by mistake.
source = TensorDatasetDistribution(
    *load_dataset(d.FashionMNIST, root="../data/", transform=TRANSFORM),
    device=DEVICE,
)
target = TensorDatasetDistribution(
    *load_dataset(d.MNIST, root="../data/", transform=TRANSFORM),
    device=DEVICE,
)

# Event shapes of the two spaces; consumed by the InnerGW cost.
p, q = source.event_shape, target.event_shape

  0%|          | 0/60000 [00:00<?, ?it/s]

  0%|          | 0/60000 [00:00<?, ?it/s]

In [9]:
# Critic and mover templates for the single-channel FashionMNIST -> MNIST task.
# torch.nn is spelled out: bare `nn` is never imported explicitly in this
# notebook and only resolved if one of the `from src.* import *` lines leaked it.
critic = ResNet_D(size=32, nc=1).to(DEVICE).apply(weights_init_D)
mover = torch.nn.Sequential(
    unet_h(source.event_shape, base_channels=48),
    torch.nn.Tanh(),  # squashes mover outputs into [-1, 1]
).to(DEVICE)

In [18]:
# Smoke-test run of the InnerGW cost on FashionMNIST -> MNIST.
# Checkpoints are written under CHECKPOINT_DIR with the name "test_run";
# W&B logging is off for this run (pass logger=LOGGER to enable it).
run_experiment(
    source,
    target,
    *copy_models(mover, critic),  # presumably fresh copies of the template models
    cost=InnerGW(p, q, device=DEVICE),
    use_fid=False,
    name="test_run",
    plotter=PLOTTER,
    checkpoint_dir=CHECKPOINT_DIR,
    **CONFIG
)

Epoch:   0%|          | 0/6 [00:00<?, ?it/s]

Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/25 [00:00<?, ?it/s]

self.name='test_run'
logger.name='test_run'
self.name='test_run'
logger.name='test_run'


Output()

Saving checkpoints to: ../checkpoints/test_run_epoch=0


In [None]:
# InnerGW_conv cost on FashionMNIST -> MNIST, logged to W&B, FID disabled.
# Named like its bags->shoes sibling so the logged run can be told apart
# from the other runs (run_experiment appears to pass `name` to the logger).
run_experiment(
    source, target, *copy_models(mover, critic),
    cost=InnerGW_conv(device=DEVICE),
    name="innerGW_conv/fashion->mnist",
    num_steps_cost=5,
    plotter=PLOTTER,
    logger=LOGGER,
    use_fid=False,
    **CONFIG
)

In [7]:
# Handbags (source) -> shoes (target) from the *_64.hdf5 image sets, each split
# 90/10 into train / eval.
handbags = load_h5py("../data/handbag_64.hdf5", transform=TRANSFORM)
shoes = load_h5py("../data/shoes_64.hdf5", transform=TRANSFORM)

# Fixed random_state: with the default (None) sklearn draws from the global NumPy
# RNG, so the split depended on how many draws preceded this cell, i.e. on cell
# execution order.
SPLIT_SEED = 0
handbags_train, handbags_eval = train_test_split(handbags, test_size=.1,
                                                 random_state=SPLIT_SEED)
shoes_train, shoes_eval = train_test_split(shoes, test_size=.1,
                                           random_state=SPLIT_SEED)
# sklearn selects the splits with an index array, which copies the tensors;
# drop this cell's references to the full sets so they are not held twice.
del handbags, shoes


def _as_distribution(images):
    """Wrap an image tensor as a TensorDatasetDistribution with all-zero labels.

    Every split uses store_on_device=False (presumably keeping the images off
    DEVICE and moving batches on demand - confirm in src.distributions).
    """
    return TensorDatasetDistribution(images,
                                     torch.zeros(images.size(0)),
                                     device=DEVICE,
                                     store_on_device=False)


source = _as_distribution(handbags_train)
target = _as_distribution(shoes_train)
source_eval = _as_distribution(handbags_eval)
target_eval = _as_distribution(shoes_eval)

# Reference Inception statistics of the held-out shoes, passed to
# run_experiment as fid_mu / fid_sigma.
fid_mu, fid_sigma = get_inception_statistics(shoes_eval,
                                             128, verbose=True)

# Event shapes of the two spaces; consumed by the InnerGW cost.
p, q = source.event_shape, target.event_shape

  0%|          | 0/138767 [00:00<?, ?it/s]

  0%|          | 0/50025 [00:00<?, ?it/s]



  0%|          | 0/40 [00:00<?, ?it/s]

In [8]:
# Critic and mover templates for the 3-channel handbags -> shoes task.
# torch.nn is spelled out: bare `nn` is never imported explicitly in this
# notebook and only resolved if one of the `from src.* import *` lines leaked it.
critic = ResNet_D(size=64, nc=3).to(DEVICE).apply(weights_init_D)
mover = torch.nn.Sequential(
    unet_h(source.event_shape, base_channels=48),
    torch.nn.Tanh(),  # squashes mover outputs into [-1, 1]
).to(DEVICE)

In [None]:
# InnerGW cost on handbags -> shoes, logged to W&B, with FID against the
# held-out shoes. Named with the same scheme as the InnerGW_conv run on this
# data so the two logged runs can be told apart.
run_experiment(
    source, target, *copy_models(mover, critic),
    cost=InnerGW(p, q, device=DEVICE),
    name="innerGW/bags->shoes/64",
    plotter=PLOTTER,
    logger=LOGGER,
    fid_mu=fid_mu,
    fid_sigma=fid_sigma,
    source_eval=source_eval,
    target_eval=target_eval,
    **CONFIG
)

In [None]:
# InnerGW_conv cost on handbags -> shoes, logged to W&B under an explicit name,
# evaluated on the held-out splits with FID against the shoe statistics.
run_experiment(
    source,
    target,
    *copy_models(mover, critic),  # presumably fresh copies of the template models
    cost=InnerGW_conv(device=DEVICE),
    name="innerGW_conv/bags->shoes/64",
    num_steps_cost=5,
    plotter=PLOTTER,
    logger=LOGGER,
    source_eval=source_eval,
    target_eval=target_eval,
    fid_mu=fid_mu,
    fid_sigma=fid_sigma,
    **CONFIG
)

Epoch:   0%|          | 0/600 [00:00<?, ?it/s]

Training:   0%|          | 0/250 [00:00<?, ?it/s]

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

[34m[1mwandb[0m: Currently logged in as: [33m_devourer_[0m. Use [1m`wandb login --relogin`[0m to force relogin


Output()

