In [22]:
from cosmikyu import visualization as covis
from cosmikyu import gan, config
import numpy as np
import os
import torchvision.transforms as transforms
from torchvision import datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import torch
import mlflow
import torchsummary

# Render matplotlib figures inline in the cell output.
%matplotlib inline
# Reload edited modules (e.g. cosmikyu) before each cell runs, so library
# changes are picked up without restarting the kernel. When the extension is
# already loaded, %load_ext only prints a notice (as seen in this cell's output).
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
# --- Paths: MNIST is downloaded/cached under the project's default data dir ---
data_dir = config.default_data_dir
mnist_dir = os.path.join(data_dir, 'mnist')

# --- Hardware ---
cuda = True  # forwarded to the cosmikyu GAN wrappers; presumably False means CPU — confirm in cosmikyu.gan

# --- Geometry ---
shape = (1, 32, 32)  # (channels, height, width) of the images fed to the discriminator
latent_dim = 100     # length of the generator's input noise vector

# --- Training schedule ---
# Intervals are handed to `.train(...)`; presumably counted in batches — confirm in cosmikyu.gan.
sample_interval = 1000
save_interval = 10000  # was 50000 here, but every training call actually used 10000
batch_size = 64
nepochs = 2

# --- Reproducibility: GAN training and DataLoader shuffling are stochastic ---
seed = 0
torch.manual_seed(seed)  # seeds CPU and all CUDA devices
np.random.seed(seed)

In [24]:
# Configure the MNIST training data loader.
# Images are resized to the configured (H, W), converted to tensors in [0, 1],
# then normalised to [-1, 1] to match the generator's Tanh output range.
os.makedirs(mnist_dir, exist_ok=True)  # also creates data_dir if it is missing
mnist_train = datasets.MNIST(
    mnist_dir,
    train=True,
    download=True,
    transform=transforms.Compose([
        # Resize to the exact (H, W) taken from `shape`. Passing an int would only
        # fix the shorter edge and silently mis-size non-square targets.
        transforms.Resize(shape[-2:]),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]),
)
dataloader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
)


In [25]:
# Build a DCGAN and print layer-by-layer summaries of both networks.
# Alternative architecture kwargs that were left commented out on this call:
#   nconv_layer_gen=2, nconv_layer_disc=2, nconv_fcgen=32, nconv_fcdis=32
DCGAN = gan.DCGAN("mnist_dcgan_v2", shape, latent_dim, cuda=cuda, ngpu=1)
# torchsummary defaults to device="cuda"; follow the notebook's `cuda` flag instead.
summary_device = "cuda" if cuda else "cpu"
# The generator consumes a latent vector; the discriminator consumes an image of `shape`.
torchsummary.summary(DCGAN.generator, (latent_dim,), device=summary_device)
torchsummary.summary(DCGAN.discriminator, shape, device=summary_device)

hello
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 8192]         827,392
       BatchNorm2d-2            [-1, 128, 8, 8]             256
          Upsample-3          [-1, 128, 16, 16]               0
            Conv2d-4          [-1, 128, 16, 16]         147,584
       BatchNorm2d-5          [-1, 128, 16, 16]             256
         LeakyReLU-6          [-1, 128, 16, 16]               0
          Upsample-7          [-1, 128, 32, 32]               0
            Conv2d-8           [-1, 64, 32, 32]          73,792
       BatchNorm2d-9           [-1, 64, 32, 32]             128
        LeakyReLU-10           [-1, 64, 32, 32]               0
           Conv2d-11            [-1, 1, 32, 32]             577
             Tanh-12            [-1, 1, 32, 32]               0
Total params: 1,049,985
Trainable params: 1,049,985
Non-trainable params: 0
---------------------

In [5]:
# Train the DCGAN baseline and log the run to MLflow.
# ncritics=1: presumably one discriminator update per generator update — confirm in cosmikyu.gan.
DCGAN = gan.DCGAN("mnist_dcgan", shape, latent_dim, cuda=cuda, ngpu=4)
mlflow.set_experiment(DCGAN.identifier)
with mlflow.start_run(experiment_id=DCGAN.experiment.experiment_id) as mlflow_run:
    torch.cuda.empty_cache()  # release cached GPU memory left over from earlier cells
    DCGAN.train(
        dataloader,
        nepochs=nepochs,
        ncritics=1,
        # Use the config-cell schedule rather than re-hard-coding it here.
        sample_interval=sample_interval,
        save_interval=save_interval,
        load_states=True,  # tries to resume; training proceeds from scratch if loading fails
        save_states=True,
        verbose=True,
        mlflow_run=mlflow_run,
        lr=2e-04,
        betas=(0.5, 0.999),  # presumably Adam betas — confirm optimiser in cosmikyu.gan
    )

loading saved states
failed to load saved states
[Epoch 0/2] [Batch 0/938] [D loss: 0.693299] [G loss: 0.707946]
saving states
[Epoch 0/2] [Batch 1/938] [D loss: 0.693183] [G loss: 0.707466]
[Epoch 0/2] [Batch 2/938] [D loss: 0.693235] [G loss: 0.706792]
[Epoch 0/2] [Batch 3/938] [D loss: 0.693113] [G loss: 0.706293]
[Epoch 0/2] [Batch 4/938] [D loss: 0.693210] [G loss: 0.705453]
[Epoch 0/2] [Batch 5/938] [D loss: 0.693136] [G loss: 0.704937]
[Epoch 0/2] [Batch 6/938] [D loss: 0.693059] [G loss: 0.704243]
[Epoch 0/2] [Batch 7/938] [D loss: 0.692941] [G loss: 0.703679]
[Epoch 0/2] [Batch 8/938] [D loss: 0.693007] [G loss: 0.703259]
[Epoch 0/2] [Batch 9/938] [D loss: 0.692832] [G loss: 0.702785]
[Epoch 0/2] [Batch 10/938] [D loss: 0.692876] [G loss: 0.702268]
[Epoch 0/2] [Batch 11/938] [D loss: 0.692638] [G loss: 0.701726]
[Epoch 0/2] [Batch 12/938] [D loss: 0.692531] [G loss: 0.701056]
[Epoch 0/2] [Batch 13/938] [D loss: 0.692413] [G loss: 0.700538]
[Epoch 0/2] [Batch 14/938] [D loss: 0

[Epoch 0/2] [Batch 126/938] [D loss: 0.689781] [G loss: 0.683614]
[Epoch 0/2] [Batch 127/938] [D loss: 0.689230] [G loss: 0.678326]
[Epoch 0/2] [Batch 128/938] [D loss: 0.691927] [G loss: 0.678007]
[Epoch 0/2] [Batch 129/938] [D loss: 0.694402] [G loss: 0.671082]
[Epoch 0/2] [Batch 130/938] [D loss: 0.697763] [G loss: 0.670242]
[Epoch 0/2] [Batch 131/938] [D loss: 0.696146] [G loss: 0.672262]
[Epoch 0/2] [Batch 132/938] [D loss: 0.699234] [G loss: 0.681113]
[Epoch 0/2] [Batch 133/938] [D loss: 0.695328] [G loss: 0.682067]
[Epoch 0/2] [Batch 134/938] [D loss: 0.697226] [G loss: 0.688764]
[Epoch 0/2] [Batch 135/938] [D loss: 0.696835] [G loss: 0.690955]
[Epoch 0/2] [Batch 136/938] [D loss: 0.696762] [G loss: 0.700489]
[Epoch 0/2] [Batch 137/938] [D loss: 0.693243] [G loss: 0.700975]
[Epoch 0/2] [Batch 138/938] [D loss: 0.693985] [G loss: 0.702345]
[Epoch 0/2] [Batch 139/938] [D loss: 0.692527] [G loss: 0.708584]
[Epoch 0/2] [Batch 140/938] [D loss: 0.693681] [G loss: 0.707838]
[Epoch 0/2

[Epoch 0/2] [Batch 251/938] [D loss: 0.692446] [G loss: 0.687075]
[Epoch 0/2] [Batch 252/938] [D loss: 0.693590] [G loss: 0.687934]
[Epoch 0/2] [Batch 253/938] [D loss: 0.694786] [G loss: 0.689274]
[Epoch 0/2] [Batch 254/938] [D loss: 0.692332] [G loss: 0.692683]
[Epoch 0/2] [Batch 255/938] [D loss: 0.694227] [G loss: 0.690323]
[Epoch 0/2] [Batch 256/938] [D loss: 0.694957] [G loss: 0.693381]
[Epoch 0/2] [Batch 257/938] [D loss: 0.693283] [G loss: 0.691676]
[Epoch 0/2] [Batch 258/938] [D loss: 0.693404] [G loss: 0.693742]
[Epoch 0/2] [Batch 259/938] [D loss: 0.692563] [G loss: 0.695447]
[Epoch 0/2] [Batch 260/938] [D loss: 0.693734] [G loss: 0.696627]
[Epoch 0/2] [Batch 261/938] [D loss: 0.693077] [G loss: 0.700139]
[Epoch 0/2] [Batch 262/938] [D loss: 0.693623] [G loss: 0.698536]
[Epoch 0/2] [Batch 263/938] [D loss: 0.693573] [G loss: 0.700273]
[Epoch 0/2] [Batch 264/938] [D loss: 0.692131] [G loss: 0.696553]
[Epoch 0/2] [Batch 265/938] [D loss: 0.691593] [G loss: 0.697125]
[Epoch 0/2

[Epoch 0/2] [Batch 376/938] [D loss: 0.690653] [G loss: 0.695544]
[Epoch 0/2] [Batch 377/938] [D loss: 0.688184] [G loss: 0.691220]
[Epoch 0/2] [Batch 378/938] [D loss: 0.689566] [G loss: 0.690136]
[Epoch 0/2] [Batch 379/938] [D loss: 0.691901] [G loss: 0.691527]
[Epoch 0/2] [Batch 380/938] [D loss: 0.690875] [G loss: 0.686844]
[Epoch 0/2] [Batch 381/938] [D loss: 0.689702] [G loss: 0.687048]
[Epoch 0/2] [Batch 382/938] [D loss: 0.693104] [G loss: 0.682397]
[Epoch 0/2] [Batch 383/938] [D loss: 0.698146] [G loss: 0.680188]
[Epoch 0/2] [Batch 384/938] [D loss: 0.693049] [G loss: 0.683059]
[Epoch 0/2] [Batch 385/938] [D loss: 0.694093] [G loss: 0.685088]
[Epoch 0/2] [Batch 386/938] [D loss: 0.692533] [G loss: 0.685312]
[Epoch 0/2] [Batch 387/938] [D loss: 0.693547] [G loss: 0.684974]
[Epoch 0/2] [Batch 388/938] [D loss: 0.697138] [G loss: 0.684343]
[Epoch 0/2] [Batch 389/938] [D loss: 0.697032] [G loss: 0.686678]
[Epoch 0/2] [Batch 390/938] [D loss: 0.694481] [G loss: 0.693864]
[Epoch 0/2

[Epoch 0/2] [Batch 502/938] [D loss: 0.698007] [G loss: 0.685271]
[Epoch 0/2] [Batch 503/938] [D loss: 0.699083] [G loss: 0.689567]
[Epoch 0/2] [Batch 504/938] [D loss: 0.693978] [G loss: 0.694699]
[Epoch 0/2] [Batch 505/938] [D loss: 0.694890] [G loss: 0.698440]
[Epoch 0/2] [Batch 506/938] [D loss: 0.694817] [G loss: 0.699158]
[Epoch 0/2] [Batch 507/938] [D loss: 0.692992] [G loss: 0.695589]
[Epoch 0/2] [Batch 508/938] [D loss: 0.694380] [G loss: 0.701679]
[Epoch 0/2] [Batch 509/938] [D loss: 0.692162] [G loss: 0.702514]
[Epoch 0/2] [Batch 510/938] [D loss: 0.692036] [G loss: 0.698655]
[Epoch 0/2] [Batch 511/938] [D loss: 0.695146] [G loss: 0.699887]
[Epoch 0/2] [Batch 512/938] [D loss: 0.693674] [G loss: 0.699516]
[Epoch 0/2] [Batch 513/938] [D loss: 0.693215] [G loss: 0.702630]
[Epoch 0/2] [Batch 514/938] [D loss: 0.695057] [G loss: 0.698485]
[Epoch 0/2] [Batch 515/938] [D loss: 0.693624] [G loss: 0.703174]
[Epoch 0/2] [Batch 516/938] [D loss: 0.692204] [G loss: 0.698709]
[Epoch 0/2

[Epoch 0/2] [Batch 628/938] [D loss: 0.693233] [G loss: 0.719420]
[Epoch 0/2] [Batch 629/938] [D loss: 0.690741] [G loss: 0.712023]
[Epoch 0/2] [Batch 630/938] [D loss: 0.694547] [G loss: 0.711143]
[Epoch 0/2] [Batch 631/938] [D loss: 0.693631] [G loss: 0.710328]
[Epoch 0/2] [Batch 632/938] [D loss: 0.691746] [G loss: 0.708108]
[Epoch 0/2] [Batch 633/938] [D loss: 0.692590] [G loss: 0.700595]
[Epoch 0/2] [Batch 634/938] [D loss: 0.693246] [G loss: 0.701110]
[Epoch 0/2] [Batch 635/938] [D loss: 0.689983] [G loss: 0.695859]
[Epoch 0/2] [Batch 636/938] [D loss: 0.691275] [G loss: 0.695521]
[Epoch 0/2] [Batch 637/938] [D loss: 0.689748] [G loss: 0.692162]
[Epoch 0/2] [Batch 638/938] [D loss: 0.692174] [G loss: 0.691856]
[Epoch 0/2] [Batch 639/938] [D loss: 0.686653] [G loss: 0.685993]
[Epoch 0/2] [Batch 640/938] [D loss: 0.689393] [G loss: 0.679632]
[Epoch 0/2] [Batch 641/938] [D loss: 0.691758] [G loss: 0.675328]
[Epoch 0/2] [Batch 642/938] [D loss: 0.694442] [G loss: 0.674480]
[Epoch 0/2

[Epoch 0/2] [Batch 754/938] [D loss: 0.692811] [G loss: 0.705539]
[Epoch 0/2] [Batch 755/938] [D loss: 0.691009] [G loss: 0.707563]
[Epoch 0/2] [Batch 756/938] [D loss: 0.687585] [G loss: 0.711134]
[Epoch 0/2] [Batch 757/938] [D loss: 0.691977] [G loss: 0.711064]
[Epoch 0/2] [Batch 758/938] [D loss: 0.691649] [G loss: 0.712881]
[Epoch 0/2] [Batch 759/938] [D loss: 0.691981] [G loss: 0.707366]
[Epoch 0/2] [Batch 760/938] [D loss: 0.695382] [G loss: 0.711961]
[Epoch 0/2] [Batch 761/938] [D loss: 0.690951] [G loss: 0.707172]
[Epoch 0/2] [Batch 762/938] [D loss: 0.688673] [G loss: 0.706280]
[Epoch 0/2] [Batch 763/938] [D loss: 0.694103] [G loss: 0.703980]
[Epoch 0/2] [Batch 764/938] [D loss: 0.686460] [G loss: 0.694805]
[Epoch 0/2] [Batch 765/938] [D loss: 0.684785] [G loss: 0.696292]
[Epoch 0/2] [Batch 766/938] [D loss: 0.692623] [G loss: 0.688782]
[Epoch 0/2] [Batch 767/938] [D loss: 0.684621] [G loss: 0.684708]
[Epoch 0/2] [Batch 768/938] [D loss: 0.695506] [G loss: 0.680030]
[Epoch 0/2

[Epoch 0/2] [Batch 880/938] [D loss: 0.689834] [G loss: 0.690530]
[Epoch 0/2] [Batch 881/938] [D loss: 0.691459] [G loss: 0.688550]
[Epoch 0/2] [Batch 882/938] [D loss: 0.684491] [G loss: 0.705864]
[Epoch 0/2] [Batch 883/938] [D loss: 0.689260] [G loss: 0.682401]
[Epoch 0/2] [Batch 884/938] [D loss: 0.695075] [G loss: 0.691271]
[Epoch 0/2] [Batch 885/938] [D loss: 0.697188] [G loss: 0.700683]
[Epoch 0/2] [Batch 886/938] [D loss: 0.693294] [G loss: 0.722101]
[Epoch 0/2] [Batch 887/938] [D loss: 0.688094] [G loss: 0.720080]
[Epoch 0/2] [Batch 888/938] [D loss: 0.693996] [G loss: 0.725306]
[Epoch 0/2] [Batch 889/938] [D loss: 0.696645] [G loss: 0.714883]
[Epoch 0/2] [Batch 890/938] [D loss: 0.687518] [G loss: 0.714030]
[Epoch 0/2] [Batch 891/938] [D loss: 0.695568] [G loss: 0.711142]
[Epoch 0/2] [Batch 892/938] [D loss: 0.685470] [G loss: 0.706918]
[Epoch 0/2] [Batch 893/938] [D loss: 0.687463] [G loss: 0.707679]
[Epoch 0/2] [Batch 894/938] [D loss: 0.686889] [G loss: 0.704934]
[Epoch 0/2

[Epoch 1/2] [Batch 69/938] [D loss: 0.688631] [G loss: 0.766298]
[Epoch 1/2] [Batch 70/938] [D loss: 0.685882] [G loss: 0.760917]
[Epoch 1/2] [Batch 71/938] [D loss: 0.683797] [G loss: 0.733277]
[Epoch 1/2] [Batch 72/938] [D loss: 0.683167] [G loss: 0.690044]
[Epoch 1/2] [Batch 73/938] [D loss: 0.684157] [G loss: 0.699670]
[Epoch 1/2] [Batch 74/938] [D loss: 0.689336] [G loss: 0.692005]
[Epoch 1/2] [Batch 75/938] [D loss: 0.679542] [G loss: 0.670355]
[Epoch 1/2] [Batch 76/938] [D loss: 0.679186] [G loss: 0.670285]
[Epoch 1/2] [Batch 77/938] [D loss: 0.684321] [G loss: 0.674666]
[Epoch 1/2] [Batch 78/938] [D loss: 0.680101] [G loss: 0.700547]
[Epoch 1/2] [Batch 79/938] [D loss: 0.674767] [G loss: 0.682305]
[Epoch 1/2] [Batch 80/938] [D loss: 0.691028] [G loss: 0.709214]
[Epoch 1/2] [Batch 81/938] [D loss: 0.680607] [G loss: 0.716516]
[Epoch 1/2] [Batch 82/938] [D loss: 0.669906] [G loss: 0.702670]
[Epoch 1/2] [Batch 83/938] [D loss: 0.680099] [G loss: 0.702405]
[Epoch 1/2] [Batch 84/938

[Epoch 1/2] [Batch 195/938] [D loss: 0.678310] [G loss: 0.762215]
[Epoch 1/2] [Batch 196/938] [D loss: 0.693340] [G loss: 0.779644]
[Epoch 1/2] [Batch 197/938] [D loss: 0.677229] [G loss: 0.734683]
[Epoch 1/2] [Batch 198/938] [D loss: 0.694666] [G loss: 0.699364]
[Epoch 1/2] [Batch 199/938] [D loss: 0.682164] [G loss: 0.690937]
[Epoch 1/2] [Batch 200/938] [D loss: 0.688665] [G loss: 0.679728]
[Epoch 1/2] [Batch 201/938] [D loss: 0.687494] [G loss: 0.670419]
[Epoch 1/2] [Batch 202/938] [D loss: 0.680021] [G loss: 0.703126]
[Epoch 1/2] [Batch 203/938] [D loss: 0.699775] [G loss: 0.715038]
[Epoch 1/2] [Batch 204/938] [D loss: 0.690823] [G loss: 0.753320]
[Epoch 1/2] [Batch 205/938] [D loss: 0.683634] [G loss: 0.782744]
[Epoch 1/2] [Batch 206/938] [D loss: 0.671962] [G loss: 0.753371]
[Epoch 1/2] [Batch 207/938] [D loss: 0.692705] [G loss: 0.743770]
[Epoch 1/2] [Batch 208/938] [D loss: 0.690801] [G loss: 0.710403]
[Epoch 1/2] [Batch 209/938] [D loss: 0.677230] [G loss: 0.676007]
[Epoch 1/2

[Epoch 1/2] [Batch 320/938] [D loss: 0.662433] [G loss: 0.694481]
[Epoch 1/2] [Batch 321/938] [D loss: 0.695829] [G loss: 0.732513]
[Epoch 1/2] [Batch 322/938] [D loss: 0.675635] [G loss: 0.804897]
[Epoch 1/2] [Batch 323/938] [D loss: 0.697875] [G loss: 0.815303]
[Epoch 1/2] [Batch 324/938] [D loss: 0.686176] [G loss: 0.788557]
[Epoch 1/2] [Batch 325/938] [D loss: 0.699160] [G loss: 0.766095]
[Epoch 1/2] [Batch 326/938] [D loss: 0.686403] [G loss: 0.741139]
[Epoch 1/2] [Batch 327/938] [D loss: 0.667078] [G loss: 0.735924]
[Epoch 1/2] [Batch 328/938] [D loss: 0.672711] [G loss: 0.714969]
[Epoch 1/2] [Batch 329/938] [D loss: 0.673281] [G loss: 0.694528]
[Epoch 1/2] [Batch 330/938] [D loss: 0.691254] [G loss: 0.733865]
[Epoch 1/2] [Batch 331/938] [D loss: 0.681039] [G loss: 0.774907]
[Epoch 1/2] [Batch 332/938] [D loss: 0.657610] [G loss: 0.779995]
[Epoch 1/2] [Batch 333/938] [D loss: 0.693910] [G loss: 0.743898]
[Epoch 1/2] [Batch 334/938] [D loss: 0.669584] [G loss: 0.687661]
[Epoch 1/2

[Epoch 1/2] [Batch 446/938] [D loss: 0.685490] [G loss: 0.696137]
[Epoch 1/2] [Batch 447/938] [D loss: 0.685740] [G loss: 0.738987]
[Epoch 1/2] [Batch 448/938] [D loss: 0.664044] [G loss: 0.729250]
[Epoch 1/2] [Batch 449/938] [D loss: 0.650421] [G loss: 0.688653]
[Epoch 1/2] [Batch 450/938] [D loss: 0.686913] [G loss: 0.705622]
[Epoch 1/2] [Batch 451/938] [D loss: 0.670815] [G loss: 0.722786]
[Epoch 1/2] [Batch 452/938] [D loss: 0.662627] [G loss: 0.687852]
[Epoch 1/2] [Batch 453/938] [D loss: 0.664294] [G loss: 0.735281]
[Epoch 1/2] [Batch 454/938] [D loss: 0.652617] [G loss: 0.700377]
[Epoch 1/2] [Batch 455/938] [D loss: 0.693633] [G loss: 0.699205]
[Epoch 1/2] [Batch 456/938] [D loss: 0.701081] [G loss: 0.775598]
[Epoch 1/2] [Batch 457/938] [D loss: 0.674072] [G loss: 0.732875]
[Epoch 1/2] [Batch 458/938] [D loss: 0.686486] [G loss: 0.744827]
[Epoch 1/2] [Batch 459/938] [D loss: 0.694326] [G loss: 0.753452]
[Epoch 1/2] [Batch 460/938] [D loss: 0.684064] [G loss: 0.715555]
[Epoch 1/2

[Epoch 1/2] [Batch 571/938] [D loss: 0.683706] [G loss: 0.734731]
[Epoch 1/2] [Batch 572/938] [D loss: 0.678017] [G loss: 0.719574]
[Epoch 1/2] [Batch 573/938] [D loss: 0.650988] [G loss: 0.752552]
[Epoch 1/2] [Batch 574/938] [D loss: 0.656187] [G loss: 0.693042]
[Epoch 1/2] [Batch 575/938] [D loss: 0.685629] [G loss: 0.717495]
[Epoch 1/2] [Batch 576/938] [D loss: 0.666466] [G loss: 0.844780]
[Epoch 1/2] [Batch 577/938] [D loss: 0.696071] [G loss: 0.889040]
[Epoch 1/2] [Batch 578/938] [D loss: 0.645061] [G loss: 0.747153]
[Epoch 1/2] [Batch 579/938] [D loss: 0.632939] [G loss: 0.685845]
[Epoch 1/2] [Batch 580/938] [D loss: 0.643815] [G loss: 0.724227]
[Epoch 1/2] [Batch 581/938] [D loss: 0.709464] [G loss: 0.763284]
[Epoch 1/2] [Batch 582/938] [D loss: 0.669768] [G loss: 0.814489]
[Epoch 1/2] [Batch 583/938] [D loss: 0.656184] [G loss: 0.794233]
[Epoch 1/2] [Batch 584/938] [D loss: 0.664672] [G loss: 0.742741]
[Epoch 1/2] [Batch 585/938] [D loss: 0.654225] [G loss: 0.702714]
[Epoch 1/2

[Epoch 1/2] [Batch 697/938] [D loss: 0.680287] [G loss: 0.749564]
[Epoch 1/2] [Batch 698/938] [D loss: 0.655953] [G loss: 0.740725]
[Epoch 1/2] [Batch 699/938] [D loss: 0.691797] [G loss: 0.823180]
[Epoch 1/2] [Batch 700/938] [D loss: 0.632593] [G loss: 0.821316]
[Epoch 1/2] [Batch 701/938] [D loss: 0.637007] [G loss: 0.807748]
[Epoch 1/2] [Batch 702/938] [D loss: 0.676359] [G loss: 0.852268]
[Epoch 1/2] [Batch 703/938] [D loss: 0.689167] [G loss: 0.778329]
[Epoch 1/2] [Batch 704/938] [D loss: 0.660801] [G loss: 0.667401]
[Epoch 1/2] [Batch 705/938] [D loss: 0.675981] [G loss: 0.725937]
[Epoch 1/2] [Batch 706/938] [D loss: 0.703388] [G loss: 0.765276]
[Epoch 1/2] [Batch 707/938] [D loss: 0.677912] [G loss: 0.830258]
[Epoch 1/2] [Batch 708/938] [D loss: 0.640332] [G loss: 0.791510]
[Epoch 1/2] [Batch 709/938] [D loss: 0.634379] [G loss: 0.679893]
[Epoch 1/2] [Batch 710/938] [D loss: 0.684543] [G loss: 0.703891]
[Epoch 1/2] [Batch 711/938] [D loss: 0.656535] [G loss: 0.734000]
[Epoch 1/2

[Epoch 1/2] [Batch 822/938] [D loss: 0.673556] [G loss: 0.799171]
[Epoch 1/2] [Batch 823/938] [D loss: 0.650048] [G loss: 0.817710]
[Epoch 1/2] [Batch 824/938] [D loss: 0.686721] [G loss: 0.848450]
[Epoch 1/2] [Batch 825/938] [D loss: 0.713508] [G loss: 0.704641]
[Epoch 1/2] [Batch 826/938] [D loss: 0.667968] [G loss: 0.696329]
[Epoch 1/2] [Batch 827/938] [D loss: 0.661082] [G loss: 0.727381]
[Epoch 1/2] [Batch 828/938] [D loss: 0.655563] [G loss: 0.779548]
[Epoch 1/2] [Batch 829/938] [D loss: 0.655811] [G loss: 0.808161]
[Epoch 1/2] [Batch 830/938] [D loss: 0.652994] [G loss: 0.762159]
[Epoch 1/2] [Batch 831/938] [D loss: 0.671419] [G loss: 0.805354]
[Epoch 1/2] [Batch 832/938] [D loss: 0.677128] [G loss: 0.777444]
[Epoch 1/2] [Batch 833/938] [D loss: 0.661790] [G loss: 0.728080]
[Epoch 1/2] [Batch 834/938] [D loss: 0.675518] [G loss: 0.714644]
[Epoch 1/2] [Batch 835/938] [D loss: 0.664765] [G loss: 0.801351]
[Epoch 1/2] [Batch 836/938] [D loss: 0.673862] [G loss: 0.754109]
[Epoch 1/2

In [6]:
# Train a WGAN with gradient penalty and log the run to MLflow.
# ncritics=5: presumably five critic updates per generator update — confirm in cosmikyu.gan.
WGAN_GP = gan.WGAN_GP("mnist_wgan_gp", shape, latent_dim, cuda=cuda, ngpu=4)
mlflow.set_experiment(WGAN_GP.identifier)
with mlflow.start_run(experiment_id=WGAN_GP.experiment.experiment_id) as mlflow_run:
    torch.cuda.empty_cache()  # release cached GPU memory left over from earlier cells
    WGAN_GP.train(
        dataloader,
        nepochs=nepochs,
        ncritics=5,
        # Use the config-cell schedule rather than re-hard-coding it here.
        sample_interval=sample_interval,
        save_interval=save_interval,
        load_states=True,  # tries to resume; training proceeds from scratch if loading fails
        save_states=True,
        verbose=True,
        mlflow_run=mlflow_run,
        lr=2e-04,
        betas=(0.5, 0.999),  # presumably Adam betas — confirm optimiser in cosmikyu.gan
        lambda_gp=10,  # presumably the gradient-penalty weight — confirm in cosmikyu.gan
    )

loading saved states
failed to load saved states
[Epoch 0/2] [Batch 0/938] [D loss: 8.338694] [G loss: -0.034286]
saving states
[Epoch 0/2] [Batch 5/938] [D loss: 4.228947] [G loss: -0.042754]
[Epoch 0/2] [Batch 10/938] [D loss: -5.872320] [G loss: -0.080434]
[Epoch 0/2] [Batch 15/938] [D loss: -20.855030] [G loss: -0.205500]
[Epoch 0/2] [Batch 20/938] [D loss: -33.967487] [G loss: -0.467806]
[Epoch 0/2] [Batch 25/938] [D loss: -40.171349] [G loss: -0.804005]
[Epoch 0/2] [Batch 30/938] [D loss: -40.587364] [G loss: -1.150115]
[Epoch 0/2] [Batch 35/938] [D loss: -40.527969] [G loss: -1.461487]
[Epoch 0/2] [Batch 40/938] [D loss: -40.401794] [G loss: -1.855785]
[Epoch 0/2] [Batch 45/938] [D loss: -40.378300] [G loss: -2.148134]
[Epoch 0/2] [Batch 50/938] [D loss: -40.284805] [G loss: -2.520683]
[Epoch 0/2] [Batch 55/938] [D loss: -40.028637] [G loss: -3.047812]
[Epoch 0/2] [Batch 60/938] [D loss: -38.714119] [G loss: -3.503300]
[Epoch 0/2] [Batch 65/938] [D loss: -37.910328] [G loss: -4.

[Epoch 0/2] [Batch 595/938] [D loss: -2.071856] [G loss: -8.411453]
[Epoch 0/2] [Batch 600/938] [D loss: -2.789547] [G loss: -8.496975]
[Epoch 0/2] [Batch 605/938] [D loss: -2.851047] [G loss: -7.051051]
[Epoch 0/2] [Batch 610/938] [D loss: -3.361418] [G loss: -5.348888]
[Epoch 0/2] [Batch 615/938] [D loss: -3.393899] [G loss: -4.635007]
[Epoch 0/2] [Batch 620/938] [D loss: -4.211488] [G loss: -3.774219]
[Epoch 0/2] [Batch 625/938] [D loss: -3.995700] [G loss: -3.650753]
[Epoch 0/2] [Batch 630/938] [D loss: -4.427218] [G loss: -3.384097]
[Epoch 0/2] [Batch 635/938] [D loss: -4.509789] [G loss: -3.663091]
[Epoch 0/2] [Batch 640/938] [D loss: -4.757484] [G loss: -4.708852]
[Epoch 0/2] [Batch 645/938] [D loss: -4.678022] [G loss: -3.275721]
[Epoch 0/2] [Batch 650/938] [D loss: -5.438917] [G loss: -1.587640]
[Epoch 0/2] [Batch 655/938] [D loss: -5.786448] [G loss: -0.101482]
[Epoch 0/2] [Batch 660/938] [D loss: -5.072834] [G loss: -1.108147]
[Epoch 0/2] [Batch 665/938] [D loss: -6.825358] 

[Epoch 1/2] [Batch 260/938] [D loss: -7.769081] [G loss: 2.431551]
[Epoch 1/2] [Batch 265/938] [D loss: -7.737829] [G loss: 1.941904]
[Epoch 1/2] [Batch 270/938] [D loss: -8.286416] [G loss: 0.872480]
[Epoch 1/2] [Batch 275/938] [D loss: -8.277871] [G loss: 2.604989]
[Epoch 1/2] [Batch 280/938] [D loss: -7.735622] [G loss: 1.397768]
[Epoch 1/2] [Batch 285/938] [D loss: -7.826842] [G loss: 2.137001]
[Epoch 1/2] [Batch 290/938] [D loss: -8.062252] [G loss: 2.631724]
[Epoch 1/2] [Batch 295/938] [D loss: -8.121384] [G loss: 3.448383]
[Epoch 1/2] [Batch 300/938] [D loss: -7.908410] [G loss: 2.361703]
[Epoch 1/2] [Batch 305/938] [D loss: -7.580414] [G loss: 2.177536]
[Epoch 1/2] [Batch 310/938] [D loss: -7.574720] [G loss: 2.219851]
[Epoch 1/2] [Batch 315/938] [D loss: -8.639702] [G loss: 2.847414]
[Epoch 1/2] [Batch 320/938] [D loss: -7.442295] [G loss: 1.425089]
[Epoch 1/2] [Batch 325/938] [D loss: -7.513917] [G loss: 2.230259]
[Epoch 1/2] [Batch 330/938] [D loss: -7.097714] [G loss: 3.058

[Epoch 1/2] [Batch 875/938] [D loss: -7.223490] [G loss: 1.878277]
[Epoch 1/2] [Batch 880/938] [D loss: -6.724697] [G loss: 0.206166]
[Epoch 1/2] [Batch 885/938] [D loss: -7.005715] [G loss: -0.040606]
[Epoch 1/2] [Batch 890/938] [D loss: -6.943243] [G loss: 0.243583]
[Epoch 1/2] [Batch 895/938] [D loss: -6.715596] [G loss: 0.986187]
[Epoch 1/2] [Batch 900/938] [D loss: -6.676334] [G loss: 0.705724]
[Epoch 1/2] [Batch 905/938] [D loss: -6.932789] [G loss: -0.405977]
[Epoch 1/2] [Batch 910/938] [D loss: -7.614096] [G loss: 0.143837]
[Epoch 1/2] [Batch 915/938] [D loss: -7.184932] [G loss: -0.453341]
[Epoch 1/2] [Batch 920/938] [D loss: -6.908195] [G loss: -0.900278]
[Epoch 1/2] [Batch 925/938] [D loss: -7.018968] [G loss: -0.819533]
[Epoch 1/2] [Batch 930/938] [D loss: -6.879057] [G loss: -1.081206]
[Epoch 1/2] [Batch 935/938] [D loss: -6.603237] [G loss: -0.498918]
saving states


In [7]:

# Train a weight-clipped WGAN and log the run to MLflow.
# ncritics=5: presumably five critic updates per generator update — confirm in cosmikyu.gan.
WGAN = gan.WGAN("mnist_wgan", shape, latent_dim, cuda=cuda, ngpu=4)
mlflow.set_experiment(WGAN.identifier)
with mlflow.start_run(experiment_id=WGAN.experiment.experiment_id) as mlflow_run:
    torch.cuda.empty_cache()  # release cached GPU memory left over from earlier cells
    WGAN.train(
        dataloader,
        nepochs=nepochs,
        ncritics=5,
        # Use the config-cell schedule rather than re-hard-coding it here.
        sample_interval=sample_interval,
        save_interval=save_interval,
        load_states=True,  # tries to resume; training proceeds from scratch if loading fails
        save_states=True,
        verbose=True,
        mlflow_run=mlflow_run,
        lr=2e-04,
        # (sic) keyword spelling must match gan.WGAN.train; presumably the
        # weight-clipping threshold — confirm in cosmikyu.gan.
        clip_tresh=0.01,
    )

loading saved states
failed to load saved states
[Epoch 0/2] [Batch 0/938] [D loss: -0.170010] [G loss: 0.005773]
saving states
[Epoch 0/2] [Batch 5/938] [D loss: -3.180050] [G loss: -0.102186]
[Epoch 0/2] [Batch 10/938] [D loss: -10.901387] [G loss: -1.057055]
[Epoch 0/2] [Batch 15/938] [D loss: -16.889118] [G loss: -3.784881]
[Epoch 0/2] [Batch 20/938] [D loss: -19.562603] [G loss: -8.415730]
[Epoch 0/2] [Batch 25/938] [D loss: -16.856379] [G loss: -14.702381]
[Epoch 0/2] [Batch 30/938] [D loss: -11.755901] [G loss: -20.045073]
[Epoch 0/2] [Batch 35/938] [D loss: -7.847595] [G loss: -23.969528]
[Epoch 0/2] [Batch 40/938] [D loss: -5.508705] [G loss: -25.085793]
[Epoch 0/2] [Batch 45/938] [D loss: -3.750599] [G loss: -25.782438]
[Epoch 0/2] [Batch 50/938] [D loss: -4.067463] [G loss: -24.500439]
[Epoch 0/2] [Batch 55/938] [D loss: -3.251064] [G loss: -24.668972]
[Epoch 0/2] [Batch 60/938] [D loss: -2.647486] [G loss: -24.228962]
[Epoch 0/2] [Batch 65/938] [D loss: -2.913563] [G loss: 

[Epoch 0/2] [Batch 600/938] [D loss: -0.139653] [G loss: -14.920118]
[Epoch 0/2] [Batch 605/938] [D loss: -0.030760] [G loss: -14.545588]
[Epoch 0/2] [Batch 610/938] [D loss: -0.117523] [G loss: -13.779972]
[Epoch 0/2] [Batch 615/938] [D loss: -0.063792] [G loss: -13.245800]
[Epoch 0/2] [Batch 620/938] [D loss: 0.025464] [G loss: -12.332292]
[Epoch 0/2] [Batch 625/938] [D loss: -0.025538] [G loss: -12.055208]
[Epoch 0/2] [Batch 630/938] [D loss: 0.031797] [G loss: -11.779657]
[Epoch 0/2] [Batch 635/938] [D loss: -0.077046] [G loss: -11.655844]
[Epoch 0/2] [Batch 640/938] [D loss: -0.194336] [G loss: -11.942909]
[Epoch 0/2] [Batch 645/938] [D loss: -0.145226] [G loss: -11.549927]
[Epoch 0/2] [Batch 650/938] [D loss: -0.126217] [G loss: -10.825991]
[Epoch 0/2] [Batch 655/938] [D loss: -0.135314] [G loss: -11.134477]
[Epoch 0/2] [Batch 660/938] [D loss: -0.056293] [G loss: -10.426508]
[Epoch 0/2] [Batch 665/938] [D loss: -0.144836] [G loss: -10.337225]
[Epoch 0/2] [Batch 670/938] [D loss:

[Epoch 1/2] [Batch 270/938] [D loss: -1.015103] [G loss: -1.878932]
[Epoch 1/2] [Batch 275/938] [D loss: -0.989455] [G loss: -2.327407]
[Epoch 1/2] [Batch 280/938] [D loss: -0.725729] [G loss: -2.478575]
[Epoch 1/2] [Batch 285/938] [D loss: -0.822465] [G loss: -2.667237]
[Epoch 1/2] [Batch 290/938] [D loss: -0.891923] [G loss: -2.736108]
[Epoch 1/2] [Batch 295/938] [D loss: -0.874644] [G loss: -2.233170]
[Epoch 1/2] [Batch 300/938] [D loss: -0.815446] [G loss: -2.528566]
[Epoch 1/2] [Batch 305/938] [D loss: -0.995366] [G loss: -2.252422]
[Epoch 1/2] [Batch 310/938] [D loss: -0.692020] [G loss: -2.220856]
[Epoch 1/2] [Batch 315/938] [D loss: -0.609085] [G loss: -2.614325]
[Epoch 1/2] [Batch 320/938] [D loss: -0.677980] [G loss: -2.696496]
[Epoch 1/2] [Batch 325/938] [D loss: -0.632648] [G loss: -2.793819]
[Epoch 1/2] [Batch 330/938] [D loss: -0.376219] [G loss: -3.824297]
[Epoch 1/2] [Batch 335/938] [D loss: -0.501630] [G loss: -3.943409]
[Epoch 1/2] [Batch 340/938] [D loss: -0.469038] 

[Epoch 1/2] [Batch 875/938] [D loss: -1.336512] [G loss: -0.826768]
[Epoch 1/2] [Batch 880/938] [D loss: -0.987317] [G loss: -1.030190]
[Epoch 1/2] [Batch 885/938] [D loss: -1.024814] [G loss: -1.044658]
[Epoch 1/2] [Batch 890/938] [D loss: -1.062631] [G loss: -1.164799]
[Epoch 1/2] [Batch 895/938] [D loss: -1.221211] [G loss: -1.045762]
[Epoch 1/2] [Batch 900/938] [D loss: -1.184878] [G loss: -1.081911]
[Epoch 1/2] [Batch 905/938] [D loss: -0.895218] [G loss: -1.340942]
[Epoch 1/2] [Batch 910/938] [D loss: -0.725032] [G loss: -1.759452]
[Epoch 1/2] [Batch 915/938] [D loss: -0.885188] [G loss: -1.368738]
[Epoch 1/2] [Batch 920/938] [D loss: -0.768236] [G loss: -1.297593]
[Epoch 1/2] [Batch 925/938] [D loss: -1.152323] [G loss: -1.094676]
[Epoch 1/2] [Batch 930/938] [D loss: -1.039139] [G loss: -1.573450]
[Epoch 1/2] [Batch 935/938] [D loss: -0.769973] [G loss: -1.741257]
saving states
