In [4]:
from cosmikyu import visualization as covis
from cosmikyu import gan, config
import numpy as np
import os
import torchvision.transforms as transforms
from torchvision import datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import torch
import mlflow

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# Paths: MNIST is downloaded into <default_data_dir>/mnist.
data_dir = config.default_data_dir
mnist_dir = os.path.join(data_dir, 'mnist')

# Use the GPU whenever one is present instead of hard-coding False, which
# contradicted the training cells (they pass cuda=True, ngpu=4).
cuda = torch.cuda.is_available()

# MNIST images as (channels, height, width) — the shape ToTensor() yields.
shape = (1, 28, 28)
# Presumably the length of the generator's input noise vector — confirm in gan.WGAN_GP.
latent_dim = 100

# Sampling / checkpoint cadence (presumably counted in batches — TODO confirm).
# NOTE(review): the training cells pass literal sample_interval=1000 and
# save_interval=10000 instead of these names, so save_interval=50000 is
# currently unused and disagrees with what was actually run.
sample_interval = 1000
save_interval = 50000
batch_size = 64

In [6]:
# Configure data loader
# makedirs(..., exist_ok=True) is a no-op for directories that already exist.
for _target_dir in (data_dir, mnist_dir):
    os.makedirs(_target_dir, exist_ok=True)

# ToTensor() scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)
mnist_train = datasets.MNIST(
    mnist_dir,
    train=True,
    download=True,
    transform=mnist_transform,
)
# NOTE(review): shuffle=True without a seeded generator, so batch order differs between runs.
dataloader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)


In [19]:
# Train the WGAN_GP model (gradient-penalty variant, per the lambda_gp argument) on
# MNIST and log the run to the MLflow experiment named after the model identifier.
WGAN_GP = gan.WGAN_GP("mnist_wgan_gp", shape, latent_dim, cuda=True, ngpu=4)
mlflow.set_experiment(WGAN_GP.identifier)

# Keyword arguments forwarded to WGAN_GP.train.
# NOTE(review): sample_interval/save_interval are literals here rather than the
# config-cell constants (whose save_interval is 50000) — confirm which is intended.
# NOTE(review): the saved output of this cell reports "[Epoch 0/1]", so it comes
# from an earlier run with nepochs=1, not from this nepochs=200 configuration.
wgan_gp_train_kwargs = dict(
    nepochs=200,
    ncritics=5,
    lr=2e-04,
    betas=(0.5, 0.999),
    lambda_gp=10,
    sample_interval=1000,
    save_interval=10000,
    load_states=True,
    save_states=True,
    verbose=True,
)

with mlflow.start_run(experiment_id=WGAN_GP.experiment.experiment_id) as mlflow_run:
    WGAN_GP.train(dataloader, mlflow_run=mlflow_run, **wgan_gp_train_kwargs)

loading saved states
failed to load saved states
[Epoch 0/1] [Batch 0/938] [D loss: -8.017841] [G loss: 0.038266]
saving states
[Epoch 0/1] [Batch 5/938] [D loss: -3.938023] [G loss: 0.052839]
[Epoch 0/1] [Batch 10/938] [D loss: 5.341599] [G loss: 0.099901]
[Epoch 0/1] [Batch 15/938] [D loss: 18.523251] [G loss: 0.235510]
[Epoch 0/1] [Batch 20/938] [D loss: 29.054474] [G loss: 0.487936]
[Epoch 0/1] [Batch 25/938] [D loss: 32.816338] [G loss: 0.767252]
[Epoch 0/1] [Batch 30/938] [D loss: 34.258266] [G loss: 1.010601]
[Epoch 0/1] [Batch 35/938] [D loss: 33.607262] [G loss: 1.282418]
[Epoch 0/1] [Batch 40/938] [D loss: 33.999588] [G loss: 1.532189]
[Epoch 0/1] [Batch 45/938] [D loss: 34.089283] [G loss: 1.851242]
[Epoch 0/1] [Batch 50/938] [D loss: 32.675175] [G loss: 2.212576]
[Epoch 0/1] [Batch 55/938] [D loss: 33.637203] [G loss: 2.539676]
[Epoch 0/1] [Batch 60/938] [D loss: 32.648293] [G loss: 2.939202]
[Epoch 0/1] [Batch 65/938] [D loss: 31.925140] [G loss: 3.363510]
[Epoch 0/1] [Bat

[Epoch 0/1] [Batch 615/938] [D loss: 4.429830] [G loss: 5.395204]
[Epoch 0/1] [Batch 620/938] [D loss: 4.344399] [G loss: 4.624345]
[Epoch 0/1] [Batch 625/938] [D loss: 4.474028] [G loss: 3.741048]
[Epoch 0/1] [Batch 630/938] [D loss: 4.463370] [G loss: 4.557129]
[Epoch 0/1] [Batch 635/938] [D loss: 4.331319] [G loss: 4.280912]
[Epoch 0/1] [Batch 640/938] [D loss: 4.225828] [G loss: 4.916999]
[Epoch 0/1] [Batch 645/938] [D loss: 4.767808] [G loss: 4.052751]
[Epoch 0/1] [Batch 650/938] [D loss: 4.462874] [G loss: 5.046508]
[Epoch 0/1] [Batch 655/938] [D loss: 4.748121] [G loss: 5.697442]
[Epoch 0/1] [Batch 660/938] [D loss: 5.078349] [G loss: 3.043800]
[Epoch 0/1] [Batch 665/938] [D loss: 4.550309] [G loss: 4.878983]
[Epoch 0/1] [Batch 670/938] [D loss: 4.955491] [G loss: 3.595835]
[Epoch 0/1] [Batch 675/938] [D loss: 4.905604] [G loss: 3.438066]
[Epoch 0/1] [Batch 680/938] [D loss: 4.262588] [G loss: 4.721101]
[Epoch 0/1] [Batch 685/938] [D loss: 5.101101] [G loss: 3.003175]
[Epoch 0/1

In [46]:

# Train the plain WGAN model on MNIST for one epoch and log the run to the
# MLflow experiment named after the model identifier.
WGAN = gan.WGAN("mnist_wgan", shape, latent_dim, cuda=True, ngpu=4)
mlflow.set_experiment(WGAN.identifier)

# Keyword arguments forwarded to WGAN.train.
# NOTE(review): `clip_tresh` (sic) is the keyword this call has always used and the
# run below trained without error — check gan.WGAN.train's signature before
# "fixing" the spelling. Presumably the critic weight-clipping threshold.
wgan_train_kwargs = dict(
    nepochs=1,
    ncritics=5,
    lr=5e-05,
    clip_tresh=0.01,
    sample_interval=1000,
    save_interval=10000,
    load_states=True,
    save_states=True,
    verbose=True,
)

with mlflow.start_run(experiment_id=WGAN.experiment.experiment_id) as mlflow_run:
    WGAN.train(dataloader, mlflow_run=mlflow_run, **wgan_train_kwargs)

loading saved states
failed to load saved states
[Epoch 0/1] [Batch 0/938] [D loss: 0.149059] [G loss: -0.009395]
saving states
[Epoch 0/1] [Batch 5/938] [D loss: 0.115641] [G loss: -0.008499]
[Epoch 0/1] [Batch 10/938] [D loss: 0.367596] [G loss: -0.005231]
[Epoch 0/1] [Batch 15/938] [D loss: 0.805645] [G loss: 0.006488]
[Epoch 0/1] [Batch 20/938] [D loss: 1.372553] [G loss: 0.032683]
[Epoch 0/1] [Batch 25/938] [D loss: 1.984034] [G loss: 0.072270]
[Epoch 0/1] [Batch 30/938] [D loss: 2.680871] [G loss: 0.136101]
[Epoch 0/1] [Batch 35/938] [D loss: 3.365093] [G loss: 0.216362]
[Epoch 0/1] [Batch 40/938] [D loss: 4.089509] [G loss: 0.321988]
[Epoch 0/1] [Batch 45/938] [D loss: 4.789481] [G loss: 0.429416]
[Epoch 0/1] [Batch 50/938] [D loss: 5.334355] [G loss: 0.576705]
[Epoch 0/1] [Batch 55/938] [D loss: 6.014461] [G loss: 0.745393]
[Epoch 0/1] [Batch 60/938] [D loss: 6.722319] [G loss: 0.946569]
[Epoch 0/1] [Batch 65/938] [D loss: 7.354253] [G loss: 1.161500]
[Epoch 0/1] [Batch 70/938]

[Epoch 0/1] [Batch 610/938] [D loss: 1.620930] [G loss: 19.343304]
[Epoch 0/1] [Batch 615/938] [D loss: 2.029818] [G loss: 19.050636]
[Epoch 0/1] [Batch 620/938] [D loss: 1.368628] [G loss: 19.084351]
[Epoch 0/1] [Batch 625/938] [D loss: 1.486616] [G loss: 19.091068]
[Epoch 0/1] [Batch 630/938] [D loss: 1.521614] [G loss: 19.150620]
[Epoch 0/1] [Batch 635/938] [D loss: 1.952475] [G loss: 18.691364]
[Epoch 0/1] [Batch 640/938] [D loss: 1.375393] [G loss: 19.063080]
[Epoch 0/1] [Batch 645/938] [D loss: 1.558311] [G loss: 18.623796]
[Epoch 0/1] [Batch 650/938] [D loss: 1.367371] [G loss: 18.871212]
[Epoch 0/1] [Batch 655/938] [D loss: 1.621298] [G loss: 18.662251]
[Epoch 0/1] [Batch 660/938] [D loss: 1.390577] [G loss: 18.606510]
[Epoch 0/1] [Batch 665/938] [D loss: 1.539234] [G loss: 18.722187]
[Epoch 0/1] [Batch 670/938] [D loss: 1.320089] [G loss: 18.728756]
[Epoch 0/1] [Batch 675/938] [D loss: 1.086983] [G loss: 18.650928]
[Epoch 0/1] [Batch 680/938] [D loss: 1.483360] [G loss: 18.551