# Wasserstein GAN

Este modelo implementa la *Generative Adversarial Network* de Arjovsky et al. (2017), una alternativa a la DC-GAN tradicional con la función de loss "Wasserstein-1".

In [1]:
import argparse
import os
import numpy as np
import math
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [4]:
IMG_SIZE = 256
img_shape = (3, IMG_SIZE, IMG_SIZE)
BATCH_SIZE = 64
data_dir = "../data/PlantVillage/"

os.makedirs("images", exist_ok=True)

cuda = True if torch.cuda.is_available() else False
cuda

True

In [5]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(100, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity

In [10]:
ROOT_I = 'Images'
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomSizedCrop(IMG_SIZE),
        transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]),
    'val': transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]),
}

image_datasets = {x: datasets.ImageFolder(data_dir+x,data_transforms[x]) for x in ['train']}

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],\
                                            batch_size=BATCH_SIZE,\
                                            shuffle=shuf,\
                                            num_workers=4)\
              for x,shuf in [('train', True)]}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train']}
class_names = [c for c in image_datasets['train'].classes if c != ROOT_I]

In [11]:
len(dataloaders['train'])*64

1984

In [12]:
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()

# Optimizers
# L_R = 0.00005
L_R = 0.001
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=L_R)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=L_R)

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

## La gente del paper de tomatos entrena la WGAN con 200.000 épocas ;-;

In [13]:
batches_done = 0
dataloader = dataloaders['train']
EPOCHS = 2000
for epoch in range(EPOCHS):

    for i, (imgs, _) in enumerate(dataloader):

        # Configure input
        real_imgs = imgs.type(Tensor)

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Sample noise as generator input
        z = Tensor(np.random.normal(0, 1, (imgs.shape[0], 100)))

        # Generate a batch of images
        fake_imgs = generator(z).detach()
        # Adversarial loss
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

        loss_D.backward()
        optimizer_D.step()

        # Clip weights of discriminator
        for p in discriminator.parameters():
            p.data.clamp_(-0.01, 0.01)

        # Train the generator every n_critic iterations
        if i % 50 == 0:

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Generate a batch of images
            gen_imgs = generator(z)
            # Adversarial loss
            loss_G = -torch.mean(discriminator(gen_imgs))

            loss_G.backward()
            optimizer_G.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch+1, EPOCHS, batches_done % len(dataloader), len(dataloader), loss_D.item(), loss_G.item())
            )

        if batches_done % 400 == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
        batches_done += 1

[Epoch 1/2000] [Batch 0/31] [D loss: 0.074933] [G loss: 0.248629]
[Epoch 2/2000] [Batch 0/31] [D loss: -2320.885742] [G loss: -1339.189209]
[Epoch 3/2000] [Batch 0/31] [D loss: -1727.542969] [G loss: 3699.496094]
[Epoch 4/2000] [Batch 0/31] [D loss: -1607.864990] [G loss: 4095.459961]
[Epoch 5/2000] [Batch 0/31] [D loss: -4637.826172] [G loss: 3771.036865]
[Epoch 6/2000] [Batch 0/31] [D loss: -2131.261230] [G loss: 1330.270508]
[Epoch 7/2000] [Batch 0/31] [D loss: -304.896545] [G loss: -637.565247]
[Epoch 8/2000] [Batch 0/31] [D loss: 27.346497] [G loss: -118.724716]
[Epoch 9/2000] [Batch 0/31] [D loss: 18.168625] [G loss: -67.767029]
[Epoch 10/2000] [Batch 0/31] [D loss: 3.133989] [G loss: -14.307465]
[Epoch 11/2000] [Batch 0/31] [D loss: -9.370640] [G loss: 32.715054]
[Epoch 12/2000] [Batch 0/31] [D loss: -18.962067] [G loss: 74.523026]
[Epoch 13/2000] [Batch 0/31] [D loss: -881.536377] [G loss: 4483.184570]
[Epoch 14/2000] [Batch 0/31] [D loss: -1642.019775] [G loss: 1229.350830]
[E

[Epoch 116/2000] [Batch 0/31] [D loss: -23.532959] [G loss: 693.121765]
[Epoch 117/2000] [Batch 0/31] [D loss: 14.784485] [G loss: 556.431091]
[Epoch 118/2000] [Batch 0/31] [D loss: -3.058552] [G loss: 52.394882]
[Epoch 119/2000] [Batch 0/31] [D loss: -7.302887] [G loss: 151.477112]
[Epoch 120/2000] [Batch 0/31] [D loss: 0.366376] [G loss: 8.651319]
[Epoch 121/2000] [Batch 0/31] [D loss: 6.780273] [G loss: -355.916016]
[Epoch 122/2000] [Batch 0/31] [D loss: -19.246918] [G loss: -416.783875]
[Epoch 123/2000] [Batch 0/31] [D loss: -8.465126] [G loss: 105.313187]
[Epoch 124/2000] [Batch 0/31] [D loss: -0.641434] [G loss: 119.825287]
[Epoch 125/2000] [Batch 0/31] [D loss: 6.258240] [G loss: 314.755798]
[Epoch 126/2000] [Batch 0/31] [D loss: 1.212494] [G loss: -220.129318]
[Epoch 127/2000] [Batch 0/31] [D loss: -2.488209] [G loss: 66.516685]
[Epoch 128/2000] [Batch 0/31] [D loss: -3.999115] [G loss: -172.273895]
[Epoch 129/2000] [Batch 0/31] [D loss: 0.076965] [G loss: 15.824241]
[Epoch 130

[Epoch 231/2000] [Batch 0/31] [D loss: -12.766197] [G loss: 75.780968]
[Epoch 232/2000] [Batch 0/31] [D loss: -165.648926] [G loss: 1238.026123]
[Epoch 233/2000] [Batch 0/31] [D loss: -197.257690] [G loss: 1435.858765]
[Epoch 234/2000] [Batch 0/31] [D loss: -112.060730] [G loss: 354.862793]
[Epoch 235/2000] [Batch 0/31] [D loss: -15.443779] [G loss: 113.583862]
[Epoch 236/2000] [Batch 0/31] [D loss: -14.300812] [G loss: 92.098663]
[Epoch 237/2000] [Batch 0/31] [D loss: -3.259132] [G loss: 109.751526]
[Epoch 238/2000] [Batch 0/31] [D loss: -8.336876] [G loss: 121.541687]
[Epoch 239/2000] [Batch 0/31] [D loss: 101.752319] [G loss: -699.691895]
[Epoch 240/2000] [Batch 0/31] [D loss: -67.537109] [G loss: 721.591187]
[Epoch 241/2000] [Batch 0/31] [D loss: -14.532837] [G loss: 503.736633]
[Epoch 242/2000] [Batch 0/31] [D loss: -66.751221] [G loss: -1159.906738]
[Epoch 243/2000] [Batch 0/31] [D loss: -6.295387] [G loss: 107.132553]
[Epoch 244/2000] [Batch 0/31] [D loss: -36.171692] [G loss: 5

[Epoch 345/2000] [Batch 0/31] [D loss: -7.126163] [G loss: 138.016022]
[Epoch 346/2000] [Batch 0/31] [D loss: -852.292969] [G loss: 10763.914062]
[Epoch 347/2000] [Batch 0/31] [D loss: -2.242373] [G loss: -10.082554]
[Epoch 348/2000] [Batch 0/31] [D loss: -2236.942383] [G loss: -7608.968750]
[Epoch 349/2000] [Batch 0/31] [D loss: -258.094727] [G loss: -8493.357422]
[Epoch 350/2000] [Batch 0/31] [D loss: -1.981479] [G loss: -4.810239]
[Epoch 351/2000] [Batch 0/31] [D loss: -2.479862] [G loss: -0.476772]
[Epoch 352/2000] [Batch 0/31] [D loss: -2.263179] [G loss: -10.701637]
[Epoch 353/2000] [Batch 0/31] [D loss: 7.043892] [G loss: -39.521111]
[Epoch 354/2000] [Batch 0/31] [D loss: -5.197200] [G loss: -19.923820]
[Epoch 355/2000] [Batch 0/31] [D loss: -1.703617] [G loss: 1.527264]
[Epoch 356/2000] [Batch 0/31] [D loss: -0.823080] [G loss: 1.806269]
[Epoch 357/2000] [Batch 0/31] [D loss: -2.103752] [G loss: -1.241973]
[Epoch 358/2000] [Batch 0/31] [D loss: -0.324440] [G loss: -3.824974]
[E

[Epoch 459/2000] [Batch 0/31] [D loss: -2.882935] [G loss: -79.653030]
[Epoch 460/2000] [Batch 0/31] [D loss: -3.349792] [G loss: -119.212181]
[Epoch 461/2000] [Batch 0/31] [D loss: -12.008606] [G loss: -188.893555]
[Epoch 462/2000] [Batch 0/31] [D loss: -15.261292] [G loss: -159.090164]
[Epoch 463/2000] [Batch 0/31] [D loss: -6.184998] [G loss: -148.947418]
[Epoch 464/2000] [Batch 0/31] [D loss: -7.408432] [G loss: -46.312611]
[Epoch 465/2000] [Batch 0/31] [D loss: -6.769787] [G loss: -49.952919]
[Epoch 466/2000] [Batch 0/31] [D loss: -15.105228] [G loss: -28.978580]
[Epoch 467/2000] [Batch 0/31] [D loss: 44.644470] [G loss: 372.359009]
[Epoch 468/2000] [Batch 0/31] [D loss: -330.824707] [G loss: -2413.261719]
[Epoch 469/2000] [Batch 0/31] [D loss: -904.212402] [G loss: 4455.382812]
[Epoch 470/2000] [Batch 0/31] [D loss: 58.970459] [G loss: 442.453613]
[Epoch 471/2000] [Batch 0/31] [D loss: -21.060089] [G loss: -187.124969]
[Epoch 472/2000] [Batch 0/31] [D loss: -22.804474] [G loss: -

[Epoch 575/2000] [Batch 0/31] [D loss: -7.264511] [G loss: 308.932495]
[Epoch 576/2000] [Batch 0/31] [D loss: -3.080170] [G loss: 325.857361]
[Epoch 577/2000] [Batch 0/31] [D loss: 0.027663] [G loss: 0.037240]
[Epoch 578/2000] [Batch 0/31] [D loss: 5.040062] [G loss: -61.811512]
[Epoch 579/2000] [Batch 0/31] [D loss: -158.096191] [G loss: -3411.720703]
[Epoch 580/2000] [Batch 0/31] [D loss: -207.908447] [G loss: 4112.685547]
[Epoch 581/2000] [Batch 0/31] [D loss: 6.148956] [G loss: 51.113625]
[Epoch 582/2000] [Batch 0/31] [D loss: -0.759621] [G loss: -10.533595]
[Epoch 583/2000] [Batch 0/31] [D loss: 16.744507] [G loss: -993.636475]
[Epoch 584/2000] [Batch 0/31] [D loss: 10.987885] [G loss: -254.868073]
[Epoch 585/2000] [Batch 0/31] [D loss: -27.936157] [G loss: 494.574005]
[Epoch 586/2000] [Batch 0/31] [D loss: -8.957947] [G loss: 476.714478]
[Epoch 587/2000] [Batch 0/31] [D loss: 0.084037] [G loss: -2.444636]
[Epoch 588/2000] [Batch 0/31] [D loss: 0.294414] [G loss: -6.484469]
[Epoch

[Epoch 691/2000] [Batch 0/31] [D loss: 3.117607] [G loss: 47.645767]
[Epoch 692/2000] [Batch 0/31] [D loss: -4.976588] [G loss: -26.866631]
[Epoch 693/2000] [Batch 0/31] [D loss: -112.854126] [G loss: -1310.148926]
[Epoch 694/2000] [Batch 0/31] [D loss: -14.287079] [G loss: 362.310425]
[Epoch 695/2000] [Batch 0/31] [D loss: -96.918335] [G loss: 790.160828]
[Epoch 696/2000] [Batch 0/31] [D loss: -36.164062] [G loss: 576.948853]
[Epoch 697/2000] [Batch 0/31] [D loss: -60.498474] [G loss: 563.440308]
[Epoch 698/2000] [Batch 0/31] [D loss: -0.443199] [G loss: 60.146523]
[Epoch 699/2000] [Batch 0/31] [D loss: 6.756958] [G loss: -402.769806]
[Epoch 700/2000] [Batch 0/31] [D loss: -0.623646] [G loss: -46.491005]
[Epoch 701/2000] [Batch 0/31] [D loss: -0.916034] [G loss: 12.648894]
[Epoch 702/2000] [Batch 0/31] [D loss: 1.175880] [G loss: -60.190048]
[Epoch 703/2000] [Batch 0/31] [D loss: -0.067201] [G loss: 2.672742]
[Epoch 704/2000] [Batch 0/31] [D loss: -0.269843] [G loss: 3.423733]
[Epoch 

[Epoch 805/2000] [Batch 0/31] [D loss: -0.000285] [G loss: 0.015188]
[Epoch 806/2000] [Batch 0/31] [D loss: -0.007583] [G loss: -0.104253]
[Epoch 807/2000] [Batch 0/31] [D loss: -0.017162] [G loss: -0.184792]
[Epoch 808/2000] [Batch 0/31] [D loss: -0.075432] [G loss: -2.716555]
[Epoch 809/2000] [Batch 0/31] [D loss: -0.204218] [G loss: -2.929053]
[Epoch 810/2000] [Batch 0/31] [D loss: -0.034637] [G loss: -0.660320]
[Epoch 811/2000] [Batch 0/31] [D loss: 0.008704] [G loss: -0.679437]
[Epoch 812/2000] [Batch 0/31] [D loss: 0.045034] [G loss: -1.202370]
[Epoch 813/2000] [Batch 0/31] [D loss: -0.084117] [G loss: -4.155879]
[Epoch 814/2000] [Batch 0/31] [D loss: -1.268906] [G loss: 56.833378]
[Epoch 815/2000] [Batch 0/31] [D loss: -1.789085] [G loss: 68.278831]
[Epoch 816/2000] [Batch 0/31] [D loss: 0.408385] [G loss: 1.762614]
[Epoch 817/2000] [Batch 0/31] [D loss: -0.118906] [G loss: 5.221177]
[Epoch 818/2000] [Batch 0/31] [D loss: -1.632812] [G loss: 171.467728]
[Epoch 819/2000] [Batch 0

[Epoch 923/2000] [Batch 0/31] [D loss: -0.085577] [G loss: 4.001918]
[Epoch 924/2000] [Batch 0/31] [D loss: -0.457336] [G loss: 37.652519]
[Epoch 925/2000] [Batch 0/31] [D loss: -2.401413] [G loss: -56.672234]
[Epoch 926/2000] [Batch 0/31] [D loss: -2.998978] [G loss: 141.173340]
[Epoch 927/2000] [Batch 0/31] [D loss: 1.598724] [G loss: -34.523357]
[Epoch 928/2000] [Batch 0/31] [D loss: -0.018836] [G loss: 0.814230]
[Epoch 929/2000] [Batch 0/31] [D loss: -0.042048] [G loss: 2.543615]
[Epoch 930/2000] [Batch 0/31] [D loss: 0.039669] [G loss: -4.413682]
[Epoch 931/2000] [Batch 0/31] [D loss: 0.017738] [G loss: -0.191917]
[Epoch 932/2000] [Batch 0/31] [D loss: -1.182106] [G loss: 73.867645]
[Epoch 933/2000] [Batch 0/31] [D loss: -0.048084] [G loss: 5.982499]
[Epoch 934/2000] [Batch 0/31] [D loss: -0.075743] [G loss: 1.423195]
[Epoch 935/2000] [Batch 0/31] [D loss: -0.009172] [G loss: -3.422318]
[Epoch 936/2000] [Batch 0/31] [D loss: -0.018989] [G loss: 2.043915]
[Epoch 937/2000] [Batch 0/

[Epoch 1039/2000] [Batch 0/31] [D loss: -436.166382] [G loss: 1119.843384]
[Epoch 1040/2000] [Batch 0/31] [D loss: -287.804810] [G loss: 997.334412]
[Epoch 1041/2000] [Batch 0/31] [D loss: -216.444366] [G loss: 674.210815]
[Epoch 1042/2000] [Batch 0/31] [D loss: -125.037231] [G loss: 452.127289]
[Epoch 1043/2000] [Batch 0/31] [D loss: -79.408783] [G loss: 260.248138]
[Epoch 1044/2000] [Batch 0/31] [D loss: -22.503063] [G loss: 51.354912]
[Epoch 1045/2000] [Batch 0/31] [D loss: -16.863487] [G loss: 79.231316]
[Epoch 1046/2000] [Batch 0/31] [D loss: -11.279232] [G loss: 64.958282]
[Epoch 1047/2000] [Batch 0/31] [D loss: -15.385082] [G loss: 41.623207]
[Epoch 1048/2000] [Batch 0/31] [D loss: -11.509064] [G loss: 59.009209]
[Epoch 1049/2000] [Batch 0/31] [D loss: -12.551220] [G loss: 50.494087]
[Epoch 1050/2000] [Batch 0/31] [D loss: -13.750015] [G loss: 46.537529]
[Epoch 1051/2000] [Batch 0/31] [D loss: -16.311436] [G loss: 62.471092]
[Epoch 1052/2000] [Batch 0/31] [D loss: 234.526611] [G

[Epoch 1153/2000] [Batch 0/31] [D loss: 0.000381] [G loss: -0.061929]
[Epoch 1154/2000] [Batch 0/31] [D loss: 0.002897] [G loss: -0.543552]
[Epoch 1155/2000] [Batch 0/31] [D loss: -0.030220] [G loss: -1.116771]
[Epoch 1156/2000] [Batch 0/31] [D loss: -6.250305] [G loss: -208.613312]
[Epoch 1157/2000] [Batch 0/31] [D loss: 0.562826] [G loss: -20.587332]
[Epoch 1158/2000] [Batch 0/31] [D loss: -0.283336] [G loss: -5.057531]
[Epoch 1159/2000] [Batch 0/31] [D loss: -2.146038] [G loss: 99.944885]
[Epoch 1160/2000] [Batch 0/31] [D loss: -0.000298] [G loss: 0.012318]
[Epoch 1161/2000] [Batch 0/31] [D loss: 0.035346] [G loss: -1.536203]
[Epoch 1162/2000] [Batch 0/31] [D loss: -0.000943] [G loss: -0.051049]
[Epoch 1163/2000] [Batch 0/31] [D loss: -0.220012] [G loss: -3.104123]
[Epoch 1164/2000] [Batch 0/31] [D loss: 0.032462] [G loss: -0.589282]
[Epoch 1165/2000] [Batch 0/31] [D loss: 0.115507] [G loss: -0.698901]
[Epoch 1166/2000] [Batch 0/31] [D loss: -0.012832] [G loss: 0.953814]
[Epoch 1167

[Epoch 1267/2000] [Batch 0/31] [D loss: 0.093138] [G loss: -1.325635]
[Epoch 1268/2000] [Batch 0/31] [D loss: 5.535789] [G loss: -45.446671]
[Epoch 1269/2000] [Batch 0/31] [D loss: -0.064071] [G loss: -8.576156]
[Epoch 1270/2000] [Batch 0/31] [D loss: 2.735367] [G loss: 57.727570]
[Epoch 1271/2000] [Batch 0/31] [D loss: -0.175055] [G loss: -3.038288]
[Epoch 1272/2000] [Batch 0/31] [D loss: 0.951152] [G loss: -6.532000]
[Epoch 1273/2000] [Batch 0/31] [D loss: -0.073410] [G loss: -3.428518]
[Epoch 1274/2000] [Batch 0/31] [D loss: -0.050035] [G loss: -1.768487]
[Epoch 1275/2000] [Batch 0/31] [D loss: -0.461592] [G loss: -10.570176]
[Epoch 1276/2000] [Batch 0/31] [D loss: -0.595882] [G loss: 50.789124]
[Epoch 1277/2000] [Batch 0/31] [D loss: 0.204601] [G loss: -12.889385]
[Epoch 1278/2000] [Batch 0/31] [D loss: 0.014861] [G loss: -2.873060]
[Epoch 1279/2000] [Batch 0/31] [D loss: -3.236557] [G loss: -123.810036]
[Epoch 1280/2000] [Batch 0/31] [D loss: 0.024063] [G loss: -1.447053]
[Epoch 1

[Epoch 1382/2000] [Batch 0/31] [D loss: -72.930267] [G loss: -133.446808]
[Epoch 1383/2000] [Batch 0/31] [D loss: -46.650146] [G loss: -472.645386]
[Epoch 1384/2000] [Batch 0/31] [D loss: -17.185333] [G loss: -462.396332]
[Epoch 1385/2000] [Batch 0/31] [D loss: 15.272217] [G loss: -461.155212]
[Epoch 1386/2000] [Batch 0/31] [D loss: -5.169933] [G loss: -168.048157]
[Epoch 1387/2000] [Batch 0/31] [D loss: -31.985779] [G loss: 598.294434]
[Epoch 1388/2000] [Batch 0/31] [D loss: -17.744873] [G loss: 435.497925]
[Epoch 1389/2000] [Batch 0/31] [D loss: -2.364969] [G loss: -16.387228]
[Epoch 1390/2000] [Batch 0/31] [D loss: -9.924236] [G loss: -64.851364]
[Epoch 1391/2000] [Batch 0/31] [D loss: -10.542729] [G loss: 819.924438]
[Epoch 1392/2000] [Batch 0/31] [D loss: -4.416382] [G loss: -150.511765]
[Epoch 1393/2000] [Batch 0/31] [D loss: -75.284180] [G loss: -585.335571]
[Epoch 1394/2000] [Batch 0/31] [D loss: -34.867279] [G loss: -547.795654]
[Epoch 1395/2000] [Batch 0/31] [D loss: 3.453316

[Epoch 1496/2000] [Batch 0/31] [D loss: -94.284515] [G loss: -537.617432]
[Epoch 1497/2000] [Batch 0/31] [D loss: -53.812500] [G loss: -587.239197]
[Epoch 1498/2000] [Batch 0/31] [D loss: -3.684814] [G loss: -570.932373]
[Epoch 1499/2000] [Batch 0/31] [D loss: -3.397064] [G loss: -120.993164]
[Epoch 1500/2000] [Batch 0/31] [D loss: -1.975121] [G loss: -125.332489]
[Epoch 1501/2000] [Batch 0/31] [D loss: -0.190615] [G loss: -25.309582]
[Epoch 1502/2000] [Batch 0/31] [D loss: 1.671134] [G loss: -8.271063]
[Epoch 1503/2000] [Batch 0/31] [D loss: 0.201097] [G loss: -5.743889]
[Epoch 1504/2000] [Batch 0/31] [D loss: -0.039125] [G loss: -4.205455]
[Epoch 1505/2000] [Batch 0/31] [D loss: -22.240326] [G loss: 733.838806]
[Epoch 1506/2000] [Batch 0/31] [D loss: 0.671646] [G loss: -5.237297]
[Epoch 1507/2000] [Batch 0/31] [D loss: -0.222904] [G loss: -48.076473]
[Epoch 1508/2000] [Batch 0/31] [D loss: -0.147278] [G loss: -73.101944]
[Epoch 1509/2000] [Batch 0/31] [D loss: 9.909210] [G loss: -85.

[Epoch 1609/2000] [Batch 0/31] [D loss: -11.883087] [G loss: 200.504944]
[Epoch 1610/2000] [Batch 0/31] [D loss: 0.643250] [G loss: 239.437378]
[Epoch 1611/2000] [Batch 0/31] [D loss: -0.703453] [G loss: 93.591530]
[Epoch 1612/2000] [Batch 0/31] [D loss: -4.078503] [G loss: 41.985580]
[Epoch 1613/2000] [Batch 0/31] [D loss: -3.721714] [G loss: 59.473091]
[Epoch 1614/2000] [Batch 0/31] [D loss: -10.854523] [G loss: 86.761963]
[Epoch 1615/2000] [Batch 0/31] [D loss: -6.569431] [G loss: 63.497330]
[Epoch 1616/2000] [Batch 0/31] [D loss: -9.377571] [G loss: 90.665970]
[Epoch 1617/2000] [Batch 0/31] [D loss: -8.048553] [G loss: 74.758881]
[Epoch 1618/2000] [Batch 0/31] [D loss: -2.924011] [G loss: 46.103073]
[Epoch 1619/2000] [Batch 0/31] [D loss: -4.652115] [G loss: 37.387657]
[Epoch 1620/2000] [Batch 0/31] [D loss: -1.440835] [G loss: 11.018593]
[Epoch 1621/2000] [Batch 0/31] [D loss: 0.100541] [G loss: 6.219747]
[Epoch 1622/2000] [Batch 0/31] [D loss: 6.343010] [G loss: 1.016436]
[Epoch 

[Epoch 1724/2000] [Batch 0/31] [D loss: -0.004741] [G loss: 0.006159]
[Epoch 1725/2000] [Batch 0/31] [D loss: -10.334480] [G loss: -227.624252]
[Epoch 1726/2000] [Batch 0/31] [D loss: 3.454262] [G loss: -31.724628]
[Epoch 1727/2000] [Batch 0/31] [D loss: 1.803055] [G loss: -87.105667]
[Epoch 1728/2000] [Batch 0/31] [D loss: -2.858192] [G loss: -74.865814]
[Epoch 1729/2000] [Batch 0/31] [D loss: -20.374084] [G loss: -1125.006592]
[Epoch 1730/2000] [Batch 0/31] [D loss: 0.089375] [G loss: 8.265436]
[Epoch 1731/2000] [Batch 0/31] [D loss: -40.283417] [G loss: 540.392273]
[Epoch 1732/2000] [Batch 0/31] [D loss: -31.451935] [G loss: 468.822876]
[Epoch 1733/2000] [Batch 0/31] [D loss: -10.869385] [G loss: 350.423706]
[Epoch 1734/2000] [Batch 0/31] [D loss: -0.022720] [G loss: -22.969458]
[Epoch 1735/2000] [Batch 0/31] [D loss: -7.396545] [G loss: -201.211426]
[Epoch 1736/2000] [Batch 0/31] [D loss: -0.167974] [G loss: -18.508656]
[Epoch 1737/2000] [Batch 0/31] [D loss: -0.215464] [G loss: 5.

[Epoch 1838/2000] [Batch 0/31] [D loss: -5.496971] [G loss: 43.016655]
[Epoch 1839/2000] [Batch 0/31] [D loss: -3.908195] [G loss: 70.221291]
[Epoch 1840/2000] [Batch 0/31] [D loss: 11.241501] [G loss: 67.099716]
[Epoch 1841/2000] [Batch 0/31] [D loss: 267.117188] [G loss: -260.015381]
[Epoch 1842/2000] [Batch 0/31] [D loss: -13.802834] [G loss: -174.604858]
[Epoch 1843/2000] [Batch 0/31] [D loss: -0.858687] [G loss: 7.688243]
[Epoch 1844/2000] [Batch 0/31] [D loss: -100.061523] [G loss: -153.141571]
[Epoch 1845/2000] [Batch 0/31] [D loss: 0.556412] [G loss: -43.063896]
[Epoch 1846/2000] [Batch 0/31] [D loss: 0.268477] [G loss: -8.069374]
[Epoch 1847/2000] [Batch 0/31] [D loss: -2.344921] [G loss: 40.700493]
[Epoch 1848/2000] [Batch 0/31] [D loss: -101.414062] [G loss: 1955.087158]
[Epoch 1849/2000] [Batch 0/31] [D loss: -7.761475] [G loss: -178.715500]
[Epoch 1850/2000] [Batch 0/31] [D loss: -208.765625] [G loss: -1652.045898]
[Epoch 1851/2000] [Batch 0/31] [D loss: -0.813545] [G loss

[Epoch 1951/2000] [Batch 0/31] [D loss: -23.945435] [G loss: -276.156982]
[Epoch 1952/2000] [Batch 0/31] [D loss: 2.191010] [G loss: -210.165161]
[Epoch 1953/2000] [Batch 0/31] [D loss: -328.142273] [G loss: 4820.408203]
[Epoch 1954/2000] [Batch 0/31] [D loss: -0.955678] [G loss: -14.343742]
[Epoch 1955/2000] [Batch 0/31] [D loss: -4.209810] [G loss: -29.296080]
[Epoch 1956/2000] [Batch 0/31] [D loss: -15.322495] [G loss: -148.081329]
[Epoch 1957/2000] [Batch 0/31] [D loss: -36.495422] [G loss: -324.895142]
[Epoch 1958/2000] [Batch 0/31] [D loss: -42.420349] [G loss: -332.005676]
[Epoch 1959/2000] [Batch 0/31] [D loss: -21.237823] [G loss: -308.188110]
[Epoch 1960/2000] [Batch 0/31] [D loss: -33.986298] [G loss: -250.744843]
[Epoch 1961/2000] [Batch 0/31] [D loss: -18.661865] [G loss: -212.837280]
[Epoch 1962/2000] [Batch 0/31] [D loss: -13.591522] [G loss: -198.068771]
[Epoch 1963/2000] [Batch 0/31] [D loss: -21.094925] [G loss: -161.878540]
[Epoch 1964/2000] [Batch 0/31] [D loss: -18

In [None]:
z = Tensor(np.random.normal(0, 1, (imgs.shape[0], 100)))
gen_imgs = generator(z)
gen_imgs = gen_imgs.cpu()

import matplotlib.pyplot as plt
num_to_gen = gen_imgs.shape[0]
_, axes = plt.subplots(figsize=(16, 4), ncols=num_to_gen)
for ii in range(num_to_gen):
    ax = axes[ii]
    img = gen_imgs[ii]
    npimg = img.detach().numpy()
    ax.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')