In [2]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
from copy import deepcopy

import numpy as np

import torch
import torch.nn as nn
import torchvision.transforms as t

from torchsummary import summary

from src.distributions import *
from src.loggers import TensorBoardLogger, WandbLogger
from src.plotters import ImagePlotter
from src.utils import *
from src.costs import InnerGW_opt, InnerGW_const
from src.models.resnet import resnet18_d, resnet18_g
from src.train import train


# Experiment loggers. Both are created once here and handed to individual
# runs through the `logger=` argument.
tb_logger = TensorBoardLogger()

# Keyword settings forwarded verbatim to WandbLogger.
# NOTE(review): mode="offline" presumably keeps runs local — confirm.
wandb_settings = dict(
    project="optimal-transport",
    entity="_devourer_",
    mode="offline",
)
wandb_logger = WandbLogger(**wandb_settings)

In [14]:
# Preferred CUDA device ordinal for this notebook.
GPU_INDEX = 2

if torch.cuda.is_available():
    # Selecting an ordinal that does not exist on this machine raises a CUDA
    # "invalid device ordinal" error, so fall back to GPU 0 when fewer
    # devices are visible than the preferred index requires.
    n_visible_gpus = torch.cuda.device_count()
    gpu_index = GPU_INDEX if GPU_INDEX < n_visible_gpus else 0
    if gpu_index != GPU_INDEX:
        print(f"GPU {GPU_INDEX} is not available "
              f"({n_visible_gpus} visible); using GPU {gpu_index} instead")
    torch.cuda.set_device(gpu_index)
    # Index-less "cuda" refers to the current device selected just above.
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

cuda


## Define source and target distributions

In [8]:
# --- Source: Gaussian mixture -------------------------------------------
# Component locations come from `uniform_circle(n_components)` scaled by 2;
# every scale entry is 0.3.
n_components = 10
locs = 2 * uniform_circle(n_components)
scales = 0.3 * torch.ones_like(locs)
source = GaussianMixture(locs, scales, device=DEVICE)

# --- Target: MNIST --------------------------------------------------------
# Each image is padded by 2 pixels and converted to a tensor on load.
mnist_transform = t.Compose([t.Pad(2), t.ToTensor()])
features, classes = load_mnist("../data/", transform=mnist_transform)
target = TensorDatasetDistribution(features, classes, device=DEVICE)

# Number of elements in a single source / target event.
p = source.event_shape.numel()
q = target.event_shape.numel()

  0%|          | 0/60000 [00:00<?, ?it/s]

In [18]:
# NOTE(review): `n_neurons` is not referenced anywhere in this cell —
# confirm whether a later cell still needs it.
n_neurons = 128

# Critic: built from the target event shape and moved to DEVICE.
# `summary` prints a per-layer table for an input of that shape with a
# batch size of 512.
critic = resnet18_d(target.event_shape).to(DEVICE)
summary(critic, target.event_shape, batch_size=512)

# Mover: built from the target event shape and the source event size `p`;
# its summary is computed for an input of shape (p,).
mover = resnet18_g(target.event_shape, p).to(DEVICE)
summary(mover, (p,), batch_size=512)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [512, 64, 16, 16]           3,136
       BatchNorm2d-2          [512, 64, 16, 16]             128
         LeakyReLU-3          [512, 64, 16, 16]               0
         MaxPool2d-4            [512, 64, 8, 8]               0
            Conv2d-5            [512, 64, 8, 8]          36,864
       BatchNorm2d-6            [512, 64, 8, 8]             128
         LeakyReLU-7            [512, 64, 8, 8]               0
            Conv2d-8            [512, 64, 8, 8]          36,864
       BatchNorm2d-9            [512, 64, 8, 8]             128
     EncoderBlock-10            [512, 64, 8, 8]               0
           Conv2d-11            [512, 64, 8, 8]          36,864
      BatchNorm2d-12            [512, 64, 8, 8]             128
        LeakyReLU-13            [512, 64, 8, 8]               0
           Conv2d-14            [512, 6

In [21]:
def run_experiment(source, target, mover, critic, cost, n_iter, *,
                   logger=None, **kwargs):
    """Run `train` with the logger started before and finished after.

    Args:
        source, target, mover, critic, cost: forwarded positionally to
            `train` in this order.
        n_iter: forwarded to `train` as the `n_iter` keyword.
        logger: optional object exposing `start()` and `finish()`. It is
            also forwarded to `train` as the `logger` keyword. `None`
            disables the start/finish calls.
        **kwargs: any further keyword arguments, forwarded to `train`.

    A `KeyboardInterrupt` raised during training is caught so that the run
    can be stopped by hand; any other exception propagates. In every case
    `logger.finish()` is called once training has ended.
    """
    # Compare against None explicitly: a logger object that happens to be
    # falsy (e.g. one defining __len__) must still be started and finished.
    has_logger = logger is not None

    if has_logger:
        logger.start()
    try:
        train(source, target, mover, critic, cost,
              n_iter=n_iter,
              logger=logger,
              **kwargs)
    except KeyboardInterrupt:
        # Manual stop is a supported way to end a run early; say so instead
        # of swallowing the interrupt without a trace.
        print("Training interrupted by user; stopping early.")
    finally:
        if has_logger:
            logger.finish()

In [28]:
# Optimizer hyper-parameters used for both the cost and the main training
# loop. Each consumer receives its own copy of the dict.
optimizer_settings = dict(lr=2e-5, weight_decay=1e-10)

# Train on copies of the models so the `mover` / `critic` objects defined
# earlier are not the ones handed to `train`.
run_experiment(
    source,
    target,
    *copy_models(mover, critic),
    n_iter=4000,
    n_samples=512,
    cost=InnerGW_opt(
        p, q,
        optimizer_params=dict(optimizer_settings),
        n_iter=2,
        device=DEVICE,
    ),
    # Alternative cost kept for reference (needs `V`, which is not defined
    # in the cells shown here):
    # cost=InnerGW_const(V.to(DEVICE)),
    plotter=ImagePlotter(plot_interval=100, n_images=10),
    logger=wandb_logger,
    n_iter_mover=2,
    optimizer_params=dict(optimizer_settings),
)

Output()

  0%|          | 0/4000 [00:00<?, ?it/s]




VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

UnboundLocalError: local variable 'interrupted' referenced before assignment