In [1]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell runs, so edits to project source are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

# Put the parent directory on the import path (presumably so the `src.*`
# imports used by this notebook resolve — confirm against the repo layout).
import sys
sys.path.append("..")

In [2]:
from copy import deepcopy

import numpy as np

import torch
import torch.nn as nn
import torchvision.transforms as t

from torchsummary import summary

from src.distributions import *
from src.loggers import TensorBoardLogger, WandbLogger
from src.plotters import ImagePlotter
from src.utils import *
from src.costs import InnerGW_conv
from src.models.resnet2 import ResNet_D, weights_init_D
from src.models.unet import unet_h
from src.train import train

In [11]:
# Preferred CUDA device ordinal. It is clamped to the devices that are actually
# present, so hosts with fewer GPUs fall back to the last available one instead
# of failing with an invalid-ordinal error.
GPU_INDEX = 3
if torch.cuda.is_available():
    torch.cuda.set_device(min(GPU_INDEX, torch.cuda.device_count() - 1))
# Plain "cuda" refers to the current device selected above.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Two logging back-ends are instantiated here: a TensorBoard logger with its
# default settings, and a Weights & Biases logger bound to a fixed
# project / group / entity.
tb_logger = TensorBoardLogger()
wandb_logger = WandbLogger(
    project="optimal-transport",
    group="style-transfer",
    entity="_devourer_",
)

In [4]:
# Source domain: shoe images, converted to tensors and resized to 32.
shoes = load_h5py(
    "../data/shoes_64.hdf5",
    transform=t.Compose([t.ToTensor(), t.Resize(32)]),
)
# All-zero tensor with one entry per image (presumably placeholder labels — confirm
# against TensorDatasetDistribution's signature).
shoe_labels = torch.zeros(shoes.size(0))
source = TensorDatasetDistribution(shoes, shoe_labels, device=DEVICE)

# Target domain: handbag images, preprocessed the same way as the shoes.
handbag = load_h5py(
    "../data/handbag_64.hdf5",
    transform=t.Compose([t.ToTensor(), t.Resize(32)]),
)
handbag_labels = torch.zeros(handbag.size(0))
target = TensorDatasetDistribution(handbag, handbag_labels, device=DEVICE)

  0%|          | 0/50025 [00:00<?, ?it/s]

  0%|          | 0/138767 [00:00<?, ?it/s]

In [5]:
# Critic network: ResNet_D moved to DEVICE, then initialised via weights_init_D.
# size=32 matches the Resize(32) applied when the images are loaded; nc=3 is
# presumably the number of image channels — confirm against ResNet_D.
critic = ResNet_D(size=32, nc=3).to(DEVICE).apply(weights_init_D)

# Pass the device explicitly (as the mover summary does) instead of relying on
# torchsummary's default of "cuda".
summary(critic, target.event_shape, batch_size=256, device=DEVICE)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [256, 64, 32, 32]           1,792
         LeakyReLU-2          [256, 64, 32, 32]               0
            Conv2d-3          [256, 64, 32, 32]          36,928
         LeakyReLU-4          [256, 64, 32, 32]               0
            Conv2d-5          [256, 64, 32, 32]          36,928
         LeakyReLU-6          [256, 64, 32, 32]               0
       ResNetBlock-7          [256, 64, 32, 32]               0
            Conv2d-8         [256, 128, 32, 32]           8,192
            Conv2d-9          [256, 64, 32, 32]          36,928
        LeakyReLU-10          [256, 64, 32, 32]               0
           Conv2d-11         [256, 128, 32, 32]          73,856
        LeakyReLU-12         [256, 128, 32, 32]               0
      ResNetBlock-13         [256, 128, 32, 32]               0
        AvgPool2d-14         [256, 128,

In [8]:
# Mover network: a U-Net (32 base channels) followed by a sigmoid, so outputs
# are squashed into (0, 1).
mover = nn.Sequential(
    unet_h(source.event_shape, base_channels=32),
    nn.Sigmoid(),
)
# Module.to() moves parameters in place and returns the same module.
mover = mover.to(DEVICE)

summary(mover, source.event_shape, batch_size=512, device=DEVICE)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [512, 32, 32, 32]             864
       BatchNorm2d-2          [512, 32, 32, 32]              64
              ReLU-3          [512, 32, 32, 32]               0
            Conv2d-4          [512, 32, 32, 32]           9,216
       BatchNorm2d-5          [512, 32, 32, 32]              64
              ReLU-6          [512, 32, 32, 32]               0
        DoubleConv-7          [512, 32, 32, 32]               0
         MaxPool2d-8          [512, 32, 16, 16]               0
            Conv2d-9          [512, 64, 16, 16]          18,432
      BatchNorm2d-10          [512, 64, 16, 16]             128
             ReLU-11          [512, 64, 16, 16]               0
           Conv2d-12          [512, 64, 16, 16]          36,864
      BatchNorm2d-13          [512, 64, 16, 16]             128
             ReLU-14          [512, 64,

In [14]:
def run_experiment(source, target, mover, critic, cost, n_iter, *,
                   logger=None, **kwargs):
    """Run a single training session, bracketed by logger start/finish calls.

    Args:
        source, target, mover, critic, cost: forwarded positionally to
            ``train`` in this order.
        n_iter: forwarded to ``train`` as the ``n_iter`` keyword.
        logger: optional logger exposing ``start()``, ``log_hparams(dict)``
            and ``finish()``. When falsy, no logging calls are made; it is
            still forwarded to ``train`` as ``logger``.
        **kwargs: extra keyword arguments. They are logged as hyperparameters
            (when a logger is given) and forwarded unchanged to ``train``.

    Returns:
        None. A ``KeyboardInterrupt`` raised while running is swallowed so a
        manual stop ends the run cleanly; any other exception propagates after
        the logger has been finished.
    """
    if logger:
        logger.start()
    try:
        # Hyperparameter logging happens inside the try block so that a failure
        # here still reaches the `finally` clause and closes the started logger.
        if logger:
            logger.log_hparams(kwargs)
        train(source, target, mover, critic, cost,
              n_iter=n_iter,
              logger=logger,
              **kwargs)
    except KeyboardInterrupt:
        # Deliberate: interrupting the cell is the normal way to stop early.
        pass
    finally:
        if logger:
            logger.finish()

In [None]:
# Launch the experiment. `copy_models` presumably returns fresh copies of the
# mover and critic so the originals built above are left untouched — confirm
# in src.utils. The cost's optimizer and the main optimizer use identical
# settings (lr=2e-4, weight_decay=1e-10).
run_experiment(
    source,
    target,
    *copy_models(mover, critic),
    n_iter=10000,
    n_samples=64,
    cost=InnerGW_conv(
        optimizer_params={"lr": 2e-4, "weight_decay": 1e-10},
        device=DEVICE,
    ),
    plotter=ImagePlotter(
        plot_interval=100,
        n_images=20,
        n_samples=2,
        plot_source=False,
    ),
    logger=wandb_logger,
    optimizer_params={"lr": 2e-4, "weight_decay": 1e-10},
)

Output()

  0%|          | 0/10000 [00:00<?, ?it/s]