In [1]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported automatically before each cell runs (no kernel restart needed).
%load_ext autoreload
%autoreload 2

import sys

# Put the parent directory on the import path (presumably where the `src`
# package imported below lives; this relies on the notebook's working directory).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate ".." entry each time.
if ".." not in sys.path:
    sys.path.append("..")

In [5]:
from copy import deepcopy

import numpy as np

import torch
import torch.nn as nn
import torchvision.transforms as t

from torchsummary import summary

from src.distributions import *
from src.loggers import TensorBoardLogger, WandbLogger
from src.plotters import ImagePlotter
from src.utils import *
from src.costs import InnerGW_conv
from src.models.resnet import resnet18_d
from src.models.unet import unet_h
from src.train import train

In [3]:
# Preferred CUDA device index for this notebook.
GPU_INDEX = 2

if torch.cuda.is_available():
    # Selecting a device index that does not exist raises, so fall back to
    # device 0 on machines that expose fewer than GPU_INDEX + 1 GPUs.
    torch.cuda.set_device(GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0)
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Experiment loggers; only `wandb_logger` is passed to a run in the cells below.
tb_logger = TensorBoardLogger()
wandb_logger = WandbLogger(project="optimal-transport",
                           entity="_devourer_",
                           mode="offline")

In [4]:
def _as_distribution(images):
    """Wrap an image tensor in a TensorDatasetDistribution on DEVICE.

    The second constructor argument is a zero vector with one entry per image
    (presumably placeholder labels that are never used — confirm).
    """
    zeros_per_image = torch.zeros(images.size(0))
    return TensorDatasetDistribution(images, zeros_per_image, device=DEVICE)


# Both datasets go through the same pipeline: convert to tensor, resize to 32.
to_tensor_32 = t.Compose([t.ToTensor(), t.Resize(32)])

shoes = load_h5py("../data/shoes_64.hdf5", transform=to_tensor_32)
source = _as_distribution(shoes)

handbag = load_h5py("../data/handbag_64.hdf5", transform=to_tensor_32)
target = _as_distribution(handbag)

  0%|          | 0/50025 [00:00<?, ?it/s]

  0%|          | 0/138767 [00:00<?, ?it/s]

In [11]:
# Build the critic for inputs shaped like target samples, move it to DEVICE,
# then print a per-layer summary (shapes are reported for a batch of 512).
critic_input_shape = target.event_shape
critic = resnet18_d(critic_input_shape).to(DEVICE)
summary(critic, critic_input_shape, batch_size=512, device=DEVICE)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [512, 64, 16, 16]           9,408
       BatchNorm2d-2          [512, 64, 16, 16]             128
         LeakyReLU-3          [512, 64, 16, 16]               0
         MaxPool2d-4            [512, 64, 8, 8]               0
            Conv2d-5            [512, 64, 8, 8]          36,864
       BatchNorm2d-6            [512, 64, 8, 8]             128
         LeakyReLU-7            [512, 64, 8, 8]               0
            Conv2d-8            [512, 64, 8, 8]          36,864
       BatchNorm2d-9            [512, 64, 8, 8]             128
     EncoderBlock-10            [512, 64, 8, 8]               0
           Conv2d-11            [512, 64, 8, 8]          36,864
      BatchNorm2d-12            [512, 64, 8, 8]             128
        LeakyReLU-13            [512, 64, 8, 8]               0
           Conv2d-14            [512, 6

In [12]:
# Build the mover for inputs shaped like source samples, move it to DEVICE,
# then print a per-layer summary (shapes are reported for a batch of 512).
mover_input_shape = source.event_shape
mover = unet_h(mover_input_shape, base_channels=32).to(DEVICE)
summary(mover, mover_input_shape, batch_size=512, device=DEVICE)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [512, 32, 32, 32]             864
       BatchNorm2d-2          [512, 32, 32, 32]              64
              ReLU-3          [512, 32, 32, 32]               0
            Conv2d-4          [512, 32, 32, 32]           9,216
       BatchNorm2d-5          [512, 32, 32, 32]              64
              ReLU-6          [512, 32, 32, 32]               0
        DoubleConv-7          [512, 32, 32, 32]               0
         MaxPool2d-8          [512, 32, 16, 16]               0
            Conv2d-9          [512, 64, 16, 16]          18,432
      BatchNorm2d-10          [512, 64, 16, 16]             128
             ReLU-11          [512, 64, 16, 16]               0
           Conv2d-12          [512, 64, 16, 16]          36,864
      BatchNorm2d-13          [512, 64, 16, 16]             128
             ReLU-14          [512, 64,

In [17]:
def run_experiment(source, target, mover, critic, cost, n_iter, *,
                   logger=None, **kwargs):
    """Run one training session, bracketed by logger start/finish calls.

    Args:
        source, target, mover, critic, cost: forwarded positionally to
            ``train`` unchanged.
        n_iter: forwarded to ``train`` as the ``n_iter`` keyword.
        logger: optional object exposing ``start()``, ``log_hparams(dict)``
            and ``finish()``. A falsy value disables logging; it is still
            forwarded to ``train`` as ``logger``.
        **kwargs: extra keyword arguments forwarded to ``train``; the same
            dict is also passed to ``logger.log_hparams``.

    Returns:
        None.

    A ``KeyboardInterrupt`` is swallowed so a run can be stopped by hand
    without a traceback. Any other exception propagates, but only after
    ``logger.finish()`` has been called.
    """
    if logger:
        logger.start()
    try:
        # Hyper-parameter logging sits inside the try block so that a failure
        # here still reaches the `finally` clause and closes the started run.
        if logger:
            logger.log_hparams(kwargs)
        train(source, target, mover, critic, cost,
              n_iter=n_iter,
              logger=logger,
              **kwargs)
    except KeyboardInterrupt:
        # Deliberate: manual interruption ends the experiment quietly.
        pass
    finally:
        if logger:
            logger.finish()

In [18]:
# Optimizer settings as plain dicts, named so the call below stays readable.
cost_optimizer_params = {"lr": 1e-3, "weight_decay": 1e-10}
train_optimizer_params = {"lr": 1e-4, "weight_decay": 1e-10}

# Launch the run on the output of `copy_models` (presumably independent copies,
# so the `mover` / `critic` built above stay untouched — confirm).
run_experiment(
    source,
    target,
    *copy_models(mover, critic),
    n_iter=5000,
    n_samples=128,
    cost=InnerGW_conv(optimizer_params=cost_optimizer_params, n_iter=10, device=DEVICE),
    plotter=ImagePlotter(
        plot_interval=100,
        n_images=20,
        n_samples=2,
        plot_source=False,
    ),
    logger=wandb_logger,
    n_iter_mover=5,
    optimizer_params=train_optimizer_params,
)

Output()

  0%|          | 0/5000 [00:00<?, ?it/s]




VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
cost,▁
loss,▁

0,1
cost,550466.0
loss,550466.75
