In [1]:
from cosmikyu import visualization as covis
from cosmikyu import gan, config
import numpy as np
import os
import torchvision.transforms as transforms
from torchvision import datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import torch
import mlflow
import torchsummary

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
# Paths: MNIST is downloaded/cached under the project's default data directory.
data_dir = config.default_data_dir
mnist_dir = os.path.join(data_dir, 'mnist')

# Use the GPU only when one is actually present, so Restart & Run All does not
# fail outright on a CPU-only machine (previously hardcoded to True).
cuda = torch.cuda.is_available()

# Image shape as (channels, height, width); the dataloader resizes digits to shape[-1].
shape = (1, 32, 32)
# Size of the generator's latent noise vector.
latent_dim = 64

# Cadences passed to every `train(...)` call below.
# NOTE(review): units (batches vs. generator steps) are defined by cosmikyu.gan — confirm.
sample_interval = 1000
save_interval = 50000

batch_size = 64
nepochs = 2

# Seed the RNGs this notebook relies on (DataLoader shuffling uses torch's RNG;
# presumably weight init / latent sampling inside cosmikyu.gan do too) so that
# a fresh-kernel run is reproducible.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [3]:
# Configure the MNIST training data loader.
# Each image is resized to the configured spatial size, converted to a tensor,
# and shifted/scaled by Normalize(mean=0.5, std=0.5).
for directory in (data_dir, mnist_dir):
    os.makedirs(directory, exist_ok=True)

mnist_transform = transforms.Compose(
    [
        transforms.Resize(shape[-1]),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)
mnist_train = datasets.MNIST(
    mnist_dir,
    train=True,
    download=True,
    transform=mnist_transform,
)
dataloader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)


In [5]:
# Train the COSMOGAN model on MNIST and record the run in MLflow.
# The model's `identifier` is used as the MLflow experiment name.
COSMOGAN = gan.COSMOGAN("mnist_cosmogan", shape, latent_dim, cuda=cuda, ngpu=4)
mlflow.set_experiment(COSMOGAN.identifier)
with mlflow.start_run(experiment_id=COSMOGAN.experiment.experiment_id) as mlflow_run:
    # Release cached GPU memory held over from previously executed cells.
    torch.cuda.empty_cache()
    COSMOGAN.train(
        dataloader,
        nepochs=nepochs,
        ncritics=1,
        # Take cadences from the config cell; the previous literal 10000 silently
        # overrode the configured save_interval.
        sample_interval=sample_interval,
        save_interval=save_interval,
        # Try to resume from saved states (output shows a fallback when none exist).
        load_states=True,
        save_states=True,
        verbose=True,
        mlflow_run=mlflow_run,
        lr=2e-04,
        betas=(0.5, 0.999),
    )

loading saved states
failed to load saved states


KeyboardInterrupt: 

In [5]:
# Train a small (2-conv-layer) DCGAN on MNIST and record the run in MLflow.
# Bug fix: this model was previously created with identifier "mnist_cosmogan"
# (copy-paste from the COSMOGAN cell), so its saved states and MLflow experiment
# collided with the COSMOGAN run. It now uses its own identifier.
DCGAN = gan.DCGAN(
    "mnist_dcgan",
    shape,
    latent_dim,
    cuda=cuda,
    ngpu=4,
    nconv_layer_gen=2,
    nconv_layer_disc=2,
    nconv_fcgen=32,
    nconv_fcdis=32,
)
mlflow.set_experiment(DCGAN.identifier)
with mlflow.start_run(experiment_id=DCGAN.experiment.experiment_id) as mlflow_run:
    # Release cached GPU memory held over from previously executed cells.
    torch.cuda.empty_cache()
    DCGAN.train(
        dataloader,
        nepochs=nepochs,
        ncritics=1,
        # Cadences come from the config cell instead of duplicated literals.
        sample_interval=sample_interval,
        save_interval=save_interval,
        load_states=True,
        save_states=True,
        verbose=False,
        mlflow_run=mlflow_run,
        lr=2e-04,
        betas=(0.5, 0.999),
    )

loading saved states
failed to load saved states
saving states
saving states


In [6]:
# Train a WGAN-GP model on MNIST and record the run in MLflow.
WGAN_GP = gan.WGAN_GP("mnist_wgan_gp", shape, latent_dim, cuda=cuda, ngpu=4)
mlflow.set_experiment(WGAN_GP.identifier)
with mlflow.start_run(experiment_id=WGAN_GP.experiment.experiment_id) as mlflow_run:
    # Release cached GPU memory held over from previously executed cells.
    torch.cuda.empty_cache()
    WGAN_GP.train(
        dataloader,
        nepochs=nepochs,
        # Presumably critic updates per generator update — confirm in cosmikyu.gan.
        ncritics=5,
        # Cadences come from the config cell instead of duplicated literals.
        sample_interval=sample_interval,
        save_interval=save_interval,
        load_states=True,
        save_states=True,
        verbose=True,
        mlflow_run=mlflow_run,
        lr=2e-04,
        betas=(0.5, 0.999),
        # Presumably the gradient-penalty weight — confirm in cosmikyu.gan.
        lambda_gp=10,
    )

loading saved states
failed to load saved states
[Epoch 0/2] [Batch 0/938] [D loss: 8.338694] [G loss: -0.034286]
saving states
[Epoch 0/2] [Batch 5/938] [D loss: 4.228947] [G loss: -0.042754]
[Epoch 0/2] [Batch 10/938] [D loss: -5.872320] [G loss: -0.080434]
[Epoch 0/2] [Batch 15/938] [D loss: -20.855030] [G loss: -0.205500]
[Epoch 0/2] [Batch 20/938] [D loss: -33.967487] [G loss: -0.467806]
[Epoch 0/2] [Batch 25/938] [D loss: -40.171349] [G loss: -0.804005]
[Epoch 0/2] [Batch 30/938] [D loss: -40.587364] [G loss: -1.150115]
[Epoch 0/2] [Batch 35/938] [D loss: -40.527969] [G loss: -1.461487]
[Epoch 0/2] [Batch 40/938] [D loss: -40.401794] [G loss: -1.855785]
[Epoch 0/2] [Batch 45/938] [D loss: -40.378300] [G loss: -2.148134]
[Epoch 0/2] [Batch 50/938] [D loss: -40.284805] [G loss: -2.520683]
[Epoch 0/2] [Batch 55/938] [D loss: -40.028637] [G loss: -3.047812]
[Epoch 0/2] [Batch 60/938] [D loss: -38.714119] [G loss: -3.503300]
[Epoch 0/2] [Batch 65/938] [D loss: -37.910328] [G loss: -4.

[Epoch 0/2] [Batch 595/938] [D loss: -2.071856] [G loss: -8.411453]
[Epoch 0/2] [Batch 600/938] [D loss: -2.789547] [G loss: -8.496975]
[Epoch 0/2] [Batch 605/938] [D loss: -2.851047] [G loss: -7.051051]
[Epoch 0/2] [Batch 610/938] [D loss: -3.361418] [G loss: -5.348888]
[Epoch 0/2] [Batch 615/938] [D loss: -3.393899] [G loss: -4.635007]
[Epoch 0/2] [Batch 620/938] [D loss: -4.211488] [G loss: -3.774219]
[Epoch 0/2] [Batch 625/938] [D loss: -3.995700] [G loss: -3.650753]
[Epoch 0/2] [Batch 630/938] [D loss: -4.427218] [G loss: -3.384097]
[Epoch 0/2] [Batch 635/938] [D loss: -4.509789] [G loss: -3.663091]
[Epoch 0/2] [Batch 640/938] [D loss: -4.757484] [G loss: -4.708852]
[Epoch 0/2] [Batch 645/938] [D loss: -4.678022] [G loss: -3.275721]
[Epoch 0/2] [Batch 650/938] [D loss: -5.438917] [G loss: -1.587640]
[Epoch 0/2] [Batch 655/938] [D loss: -5.786448] [G loss: -0.101482]
[Epoch 0/2] [Batch 660/938] [D loss: -5.072834] [G loss: -1.108147]
[Epoch 0/2] [Batch 665/938] [D loss: -6.825358] 

[Epoch 1/2] [Batch 260/938] [D loss: -7.769081] [G loss: 2.431551]
[Epoch 1/2] [Batch 265/938] [D loss: -7.737829] [G loss: 1.941904]
[Epoch 1/2] [Batch 270/938] [D loss: -8.286416] [G loss: 0.872480]
[Epoch 1/2] [Batch 275/938] [D loss: -8.277871] [G loss: 2.604989]
[Epoch 1/2] [Batch 280/938] [D loss: -7.735622] [G loss: 1.397768]
[Epoch 1/2] [Batch 285/938] [D loss: -7.826842] [G loss: 2.137001]
[Epoch 1/2] [Batch 290/938] [D loss: -8.062252] [G loss: 2.631724]
[Epoch 1/2] [Batch 295/938] [D loss: -8.121384] [G loss: 3.448383]
[Epoch 1/2] [Batch 300/938] [D loss: -7.908410] [G loss: 2.361703]
[Epoch 1/2] [Batch 305/938] [D loss: -7.580414] [G loss: 2.177536]
[Epoch 1/2] [Batch 310/938] [D loss: -7.574720] [G loss: 2.219851]
[Epoch 1/2] [Batch 315/938] [D loss: -8.639702] [G loss: 2.847414]
[Epoch 1/2] [Batch 320/938] [D loss: -7.442295] [G loss: 1.425089]
[Epoch 1/2] [Batch 325/938] [D loss: -7.513917] [G loss: 2.230259]
[Epoch 1/2] [Batch 330/938] [D loss: -7.097714] [G loss: 3.058

[Epoch 1/2] [Batch 875/938] [D loss: -7.223490] [G loss: 1.878277]
[Epoch 1/2] [Batch 880/938] [D loss: -6.724697] [G loss: 0.206166]
[Epoch 1/2] [Batch 885/938] [D loss: -7.005715] [G loss: -0.040606]
[Epoch 1/2] [Batch 890/938] [D loss: -6.943243] [G loss: 0.243583]
[Epoch 1/2] [Batch 895/938] [D loss: -6.715596] [G loss: 0.986187]
[Epoch 1/2] [Batch 900/938] [D loss: -6.676334] [G loss: 0.705724]
[Epoch 1/2] [Batch 905/938] [D loss: -6.932789] [G loss: -0.405977]
[Epoch 1/2] [Batch 910/938] [D loss: -7.614096] [G loss: 0.143837]
[Epoch 1/2] [Batch 915/938] [D loss: -7.184932] [G loss: -0.453341]
[Epoch 1/2] [Batch 920/938] [D loss: -6.908195] [G loss: -0.900278]
[Epoch 1/2] [Batch 925/938] [D loss: -7.018968] [G loss: -0.819533]
[Epoch 1/2] [Batch 930/938] [D loss: -6.879057] [G loss: -1.081206]
[Epoch 1/2] [Batch 935/938] [D loss: -6.603237] [G loss: -0.498918]
saving states


In [26]:

# Train a (weight-clipping) WGAN on MNIST and record the run in MLflow.
# NOTE(review): the many `torch.Size([64, 64])` lines in this cell's output are not
# printed by this cell, so they presumably come from a debug print inside cosmikyu.gan.
WGAN = gan.WGAN("mnist_wgan", shape, latent_dim, cuda=cuda, ngpu=4)
mlflow.set_experiment(WGAN.identifier)
with mlflow.start_run(experiment_id=WGAN.experiment.experiment_id) as mlflow_run:
    # Release cached GPU memory held over from previously executed cells.
    torch.cuda.empty_cache()
    WGAN.train(
        dataloader,
        nepochs=nepochs,
        ncritics=5,
        # Cadences come from the config cell instead of duplicated literals.
        sample_interval=sample_interval,
        save_interval=save_interval,
        load_states=True,
        save_states=True,
        verbose=True,
        mlflow_run=mlflow_run,
        lr=2e-04,
        # (sic) keyword spelling presumably matches the cosmikyu API — do not "fix" here.
        clip_tresh=0.01,
    )

loading saved states
failed to load saved states
torch.Size([64, 64])
torch.Size([64, 64])
[Epoch 0/2] [Batch 0/938] [D loss: -0.153472] [G loss: 0.008971]
saving states
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
[Epoch 0/2] [Batch 5/938] [D loss: -3.337174] [G loss: -0.109996]
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
[Epoch 0/2] [Batch 10/938] [D loss: -11.381778] [G loss: -1.197862]
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
[Epoch 0/2] [Batch 15/938] [D loss: -17.959133] [G loss: -3.988242]
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64, 64])
[Epoch 0/2] [Batch 20/938] [D loss: -20.941551] [G loss: -9.015030]
torch.Size([64, 64])
torch.Size([64, 64])
torch.Size([64

KeyboardInterrupt: 

In [61]:
# Build a deeper 4-layer DCGAN variant and print layer-by-layer summaries.
# `shape` comes from the config cell. It is deliberately not reassigned here, so this
# cell cannot silently reset a configured value (hidden-state hazard).
DCGAN = gan.DCGAN(
    "mnist_dcgan",
    shape,
    latent_dim,
    cuda=cuda,
    ngpu=4,
    nconv_layer_gen=4,
    nconv_layer_disc=4,
    nconv_fcgen=64,
    nconv_fcdis=64,
    kernal_size=5,  # (sic) keyword spelling presumably matches the cosmikyu API
    stride=2,
    padding=2,
    output_padding=1,
)
# torchsummary creates its dummy inputs on `device` (default "cuda"). Match it to the
# `cuda` flag so the cell also works when no GPU is used (presumably the model is on CPU then).
summary_device = "cuda" if cuda else "cpu"
# NOTE(review): the printed summary lists every layer twice (Linear-1/Linear-2, ...),
# likely because forward hooks fire once per multi-GPU replica (ngpu=4); the
# param totals may then be double-counted — confirm.
torchsummary.summary(DCGAN.generator, (1, 1, latent_dim), device=summary_device)
torchsummary.summary(DCGAN.discriminator, shape, device=summary_device)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1           [-1, 1, 1, 2048]         133,120
            Linear-2           [-1, 1, 1, 2048]         133,120
           Reshape-3            [-1, 512, 2, 2]               0
           Reshape-4            [-1, 512, 2, 2]               0
       BatchNorm2d-5            [-1, 512, 2, 2]           1,024
       BatchNorm2d-6            [-1, 512, 2, 2]           1,024
         LeakyReLU-7            [-1, 512, 2, 2]               0
         LeakyReLU-8            [-1, 512, 2, 2]               0
   ConvTranspose2d-9            [-1, 256, 4, 4]       3,277,056
  ConvTranspose2d-10            [-1, 256, 4, 4]       3,277,056
      BatchNorm2d-11            [-1, 256, 4, 4]             512
      BatchNorm2d-12            [-1, 256, 4, 4]             512
        LeakyReLU-13            [-1, 256, 4, 4]               0
        LeakyReLU-14            [-1, 25