In [17]:
from cosmikyu import visualization as covis
from cosmikyu import gan, config
import numpy as np
import os
import torchvision.transforms as transforms
from torchvision import datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import torch
from tensorboardX import SummaryWriter

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
# --- Paths: MNIST lives in its own sub-folder of the project data directory ---
data_dir = config.default_data_dir
mnist_dir = os.path.join(data_dir, "mnist")

# --- Data / model geometry ---
# presumably (channels, height, width) of one MNIST image -- TODO confirm
shape = (1, 28, 28)
latent_dim = 100
batch_size = 64

# --- Run-time switches ---
# NOTE(review): `cuda`, `sample_interval` and `save_interval` are not read by
# any cell visible in this notebook -- confirm whether they are still needed.
cuda = False
sample_interval = 1000
save_interval = 50000

# NOTE(review): `writer` is not referenced by any visible cell either.
writer = SummaryWriter()

In [19]:
# Build the MNIST training data loader.
# Make sure both target folders exist before the dataset is downloaded.
for folder in (data_dir, mnist_dir):
    os.makedirs(folder, exist_ok=True)

# Convert each image to a tensor, then normalise with mean 0.5 and std 0.5.
mnist_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)
mnist_train = datasets.MNIST(
    mnist_dir, train=True, download=True, transform=mnist_transform
)
dataloader = torch.utils.data.DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True
)

# Visdom plotting is disabled; assign covis.VisdomPlotter() here to enable it.
VP = None

In [20]:
# Derive the device settings from the host instead of hard-coding
# `cuda=True, ngpu=4`, which only works on a machine with CUDA and >= 4 GPUs.
# The GPU count stays capped at the original value of 4, so behaviour is
# unchanged on such a machine.
MAX_GPUS = 4
if torch.cuda.is_available():
    device_kwargs = dict(cuda=True, ngpu=min(MAX_GPUS, torch.cuda.device_count()))
else:
    # No CUDA device: run on CPU and leave `ngpu` at the library default.
    # NOTE(review): assumes gan.WGAN accepts cuda=False without ngpu -- confirm.
    device_kwargs = dict(cuda=False)

WGAN = gan.WGAN("mnistv2", shape, latent_dim, **device_kwargs)

In [21]:
# Start (or continue) training on the MNIST loader.
# With verbose=True the recorded run printed the D/G losses every few batches;
# it also reported "loading saved states" before the first batch.
train_options = dict(visdom_plotter=VP, verbose=True)
WGAN.train(dataloader, **train_options)

loading saved states
[Epoch 0/200] [Batch 0/938] [D loss: 0.022548] [G loss: 0.009579]
saving states
[Epoch 0/200] [Batch 5/938] [D loss: 0.529250] [G loss: 0.016102]
[Epoch 0/200] [Batch 10/938] [D loss: 1.272625] [G loss: 0.040125]
[Epoch 0/200] [Batch 15/938] [D loss: 2.039369] [G loss: 0.079226]
[Epoch 0/200] [Batch 20/938] [D loss: 2.881820] [G loss: 0.137924]
[Epoch 0/200] [Batch 25/938] [D loss: 3.649480] [G loss: 0.216242]
[Epoch 0/200] [Batch 30/938] [D loss: 4.377010] [G loss: 0.312017]
[Epoch 0/200] [Batch 35/938] [D loss: 5.139267] [G loss: 0.441116]
[Epoch 0/200] [Batch 40/938] [D loss: 5.788734] [G loss: 0.577204]
[Epoch 0/200] [Batch 45/938] [D loss: 6.637061] [G loss: 0.714371]
[Epoch 0/200] [Batch 50/938] [D loss: 7.104387] [G loss: 0.910814]
[Epoch 0/200] [Batch 55/938] [D loss: 7.791006] [G loss: 1.134170]
[Epoch 0/200] [Batch 60/938] [D loss: 8.318607] [G loss: 1.303359]
[Epoch 0/200] [Batch 65/938] [D loss: 8.741735] [G loss: 1.618415]
[Epoch 0/200] [Batch 70/938] 

[Epoch 0/200] [Batch 600/938] [D loss: 2.005661] [G loss: 16.794849]
[Epoch 0/200] [Batch 605/938] [D loss: 1.424290] [G loss: 16.813429]
[Epoch 0/200] [Batch 610/938] [D loss: 1.711176] [G loss: 16.794567]
[Epoch 0/200] [Batch 615/938] [D loss: 1.892174] [G loss: 16.482193]
[Epoch 0/200] [Batch 620/938] [D loss: 1.616594] [G loss: 16.689457]
[Epoch 0/200] [Batch 625/938] [D loss: 1.696751] [G loss: 16.614529]
[Epoch 0/200] [Batch 630/938] [D loss: 1.707897] [G loss: 16.563805]
[Epoch 0/200] [Batch 635/938] [D loss: 1.563280] [G loss: 16.447937]
[Epoch 0/200] [Batch 640/938] [D loss: 1.550278] [G loss: 16.469917]
[Epoch 0/200] [Batch 645/938] [D loss: 1.439404] [G loss: 16.583298]
[Epoch 0/200] [Batch 650/938] [D loss: 1.559813] [G loss: 16.467180]
[Epoch 0/200] [Batch 655/938] [D loss: 1.863800] [G loss: 16.289307]
[Epoch 0/200] [Batch 660/938] [D loss: 1.277176] [G loss: 16.379204]
[Epoch 0/200] [Batch 665/938] [D loss: 1.632572] [G loss: 16.097008]
[Epoch 0/200] [Batch 670/938] [D l

[Epoch 1/200] [Batch 265/938] [D loss: 0.325077] [G loss: 13.498137]
[Epoch 1/200] [Batch 270/938] [D loss: 0.168543] [G loss: 13.522017]
[Epoch 1/200] [Batch 275/938] [D loss: 0.377774] [G loss: 13.482162]
[Epoch 1/200] [Batch 280/938] [D loss: 0.283166] [G loss: 13.452322]
[Epoch 1/200] [Batch 285/938] [D loss: 0.646639] [G loss: 13.294157]
[Epoch 1/200] [Batch 290/938] [D loss: 0.439073] [G loss: 13.395706]
[Epoch 1/200] [Batch 295/938] [D loss: 0.310261] [G loss: 13.367146]
[Epoch 1/200] [Batch 300/938] [D loss: 0.474146] [G loss: 13.289360]
[Epoch 1/200] [Batch 305/938] [D loss: 0.291006] [G loss: 13.304443]
[Epoch 1/200] [Batch 310/938] [D loss: 0.367512] [G loss: 13.274711]
[Epoch 1/200] [Batch 315/938] [D loss: 0.218348] [G loss: 13.291154]
[Epoch 1/200] [Batch 320/938] [D loss: 0.368296] [G loss: 13.245365]
[Epoch 1/200] [Batch 325/938] [D loss: 0.305426] [G loss: 13.219190]
[Epoch 1/200] [Batch 330/938] [D loss: 0.420334] [G loss: 13.151430]
[Epoch 1/200] [Batch 335/938] [D l

[Epoch 1/200] [Batch 865/938] [D loss: 0.230275] [G loss: 12.540371]
[Epoch 1/200] [Batch 870/938] [D loss: 0.236594] [G loss: 12.528528]
[Epoch 1/200] [Batch 875/938] [D loss: 0.186902] [G loss: 12.472450]
[Epoch 1/200] [Batch 880/938] [D loss: 0.250137] [G loss: 12.547741]
[Epoch 1/200] [Batch 885/938] [D loss: 0.200026] [G loss: 12.562500]
[Epoch 1/200] [Batch 890/938] [D loss: 0.159545] [G loss: 12.559885]
[Epoch 1/200] [Batch 895/938] [D loss: 0.231613] [G loss: 12.627895]
[Epoch 1/200] [Batch 900/938] [D loss: 0.146267] [G loss: 12.601877]
[Epoch 1/200] [Batch 905/938] [D loss: 0.126176] [G loss: 12.658211]
[Epoch 1/200] [Batch 910/938] [D loss: 0.145391] [G loss: 12.638365]
[Epoch 1/200] [Batch 915/938] [D loss: 0.146534] [G loss: 12.671419]
[Epoch 1/200] [Batch 920/938] [D loss: 0.213681] [G loss: 12.618069]
[Epoch 1/200] [Batch 925/938] [D loss: 0.164080] [G loss: 12.741311]
[Epoch 1/200] [Batch 930/938] [D loss: 0.135478] [G loss: 12.718567]
[Epoch 1/200] [Batch 935/938] [D l

KeyboardInterrupt: 