In [11]:
""" Conditional DCGAN for MNIST images generations.
    Author: Moustafa Alzantot (malzantot@ucla.edu)
    All rights reserved.
"""
#ref: https://github.com/pytorch/examples/blob/master/dcgan/main.py

import os
import argparse
import numpy as np
import torch
from torch import nn, optim

import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader

import torchvision
import torchvision.utils as vutils
from torchvision.utils import save_image
from torchvision import datasets, transforms

LATENT_DIM = 100
GEN_FILTERS = 64
DISCRIM_FILTERS = 64
CHANNELS = 3
IMAGE_SIZE = 64

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(LATENT_DIM, GEN_FILTERS * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(GEN_FILTERS * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(GEN_FILTERS * 8, GEN_FILTERS * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(GEN_FILTERS * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(GEN_FILTERS * 4, GEN_FILTERS * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(GEN_FILTERS * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(GEN_FILTERS * 2, GEN_FILTERS, 4, 2, 1, bias=False),
            nn.BatchNorm2d(GEN_FILTERS),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(GEN_FILTERS, CHANNELS, 4, 2, 1, bias=False),
            nn.Tanh()
        )
        
    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output

            # state size. (nc) x 32 x 32
            
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 32 x 32
            nn.Conv2d(in_channels=CHANNELS, out_channels=DISCRIM_FILTERS, 
                      kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(DISCRIM_FILTERS, DISCRIM_FILTERS * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(DISCRIM_FILTERS * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(DISCRIM_FILTERS * 2, DISCRIM_FILTERS * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(DISCRIM_FILTERS * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(DISCRIM_FILTERS * 4, DISCRIM_FILTERS * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(DISCRIM_FILTERS * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(DISCRIM_FILTERS * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output.view(-1, 1).squeeze(1)

In [6]:
BATCH_SIZE = 128

# train_dataset = datasets.CIFAR10(root='../_datasets/cifar-10',
#     train=True,
#     download=True,
#     transform=transforms.ToTensor())

train_dataset = datasets.CIFAR10(root='../_datasets/cifar-10', train=True, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(IMAGE_SIZE),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

train_loader = DataLoader(train_dataset, shuffle=True,
    batch_size=BATCH_SIZE)

Files already downloaded and verified


In [15]:
# parser = argparse.ArgumentParser('Conditional DCGAN')
BATCH_SIZE = 128
LR = .01
EPOCHS = 10                   
NOISE = 100 # number of dimensions for input noise
CUDA = False
SAVE_EVERY = 5 # after how many epochs to save the model
PRINT_EVERY = 1 # After how many epochs to print loss and save output samples.
SAVE_DIR = 'models' # Path to save the trained models.
SAMPLES_DIR = 'samples' # Path to save the output samples

NGPU = 0
device = torch.device("cuda:0" if CUDA else "cpu")
NET_G = ''
NET_D = ''
BETA1 = .5

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
        
netG = Generator(NGPU).to(device)
netG.apply(weights_init)
if NET_G != '':
    netG.load_state_dict(torch.load(NET_G))
print(netG)

netD = Discriminator(NGPU).to(device)
netD.apply(weights_init)

if NET_D != '':
    netD.load_state_dict(torch.load(NET_D))
print(netD)

criterion = nn.BCELoss()

fixed_noise = torch.randn(size=(BATCH_SIZE, LATENT_DIM, 1, 1), device=device)
real_label = 1
fake_label = 0

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(BETA1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(BETA1, 0.999))

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)


In [16]:
for epoch in range(EPOCHS):
    for i, data in enumerate(train_loader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
#         print(f'type data[0]: {type(data[0])}')
#         print(f'shape data[0]: {data[0].shape}')
#         print(f'len data: {len(data)}')
#         print(f'content data[1]: {data[1]}')
        real_input = data[0].to(device)
        batch_size = real_input.size(0)
        label = torch.full((batch_size,), real_label, device=device)
        
        output = netD(real_input)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(size=(batch_size, LATENT_DIM, 1, 1), device=device)
        # generate fake input
        fake = netG(noise)
        label.fill_(fake_label)

        output = netD(fake.detach()) #detach the generated image from the graph
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item() # mean of the descriminator output
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item() # ???
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, EPOCHS, i, len(train_loader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_input,
                    f'{SAMPLES_DIR}/real_samples.png', normalize=False)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                    f'{SAMPLES_DIR}/fake_samples_epoch_%03d.png' % (epoch),
                    normalize=False)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (SAVE_DIR, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (SAVE_DIR, epoch))

[0/10][0/391] Loss_D: 1.9414 Loss_G: 62.7425 D(x): 0.6502 D(G(z)): 0.7014 / 0.0000
[0/10][1/391] Loss_D: 37.7405 Loss_G: 41.9640 D(x): 0.0000 D(G(z)): 0.0000 / 0.0000
[0/10][2/391] Loss_D: 21.4176 Loss_G: 32.7500 D(x): 0.3877 D(G(z)): 0.5682 / 0.0000
[0/10][3/391] Loss_D: 9.6515 Loss_G: 9.6678 D(x): 0.5968 D(G(z)): 0.7731 / 0.1239
[0/10][4/391] Loss_D: 10.5163 Loss_G: 1.9072 D(x): 0.5044 D(G(z)): 0.2898 / 0.3753
[0/10][5/391] Loss_D: 7.2621 Loss_G: 0.1580 D(x): 0.7685 D(G(z)): 0.4327 / 0.9234
[0/10][6/391] Loss_D: 9.7982 Loss_G: 2.1033 D(x): 0.9907 D(G(z)): 0.9814 / 0.3039
[0/10][7/391] Loss_D: 5.0940 Loss_G: 7.9459 D(x): 0.6752 D(G(z)): 0.3495 / 0.4209
[0/10][8/391] Loss_D: 6.1636 Loss_G: 4.6553 D(x): 0.5059 D(G(z)): 0.2504 / 0.1916
[0/10][9/391] Loss_D: 7.9230 Loss_G: 27.5097 D(x): 0.8725 D(G(z)): 0.8866 / 0.0135
[0/10][10/391] Loss_D: 9.0681 Loss_G: 19.2424 D(x): 0.2919 D(G(z)): 0.0918 / 0.0804
[0/10][11/391] Loss_D: 6.4379 Loss_G: 12.4586 D(x): 0.5643 D(G(z)): 0.2548 / 0.0557
[0/10

[0/10][195/391] Loss_D: 0.6686 Loss_G: 2.7718 D(x): 0.8233 D(G(z)): 0.3066 / 0.0870
[0/10][196/391] Loss_D: 0.8563 Loss_G: 2.2083 D(x): 0.7218 D(G(z)): 0.3330 / 0.1638
[0/10][197/391] Loss_D: 0.4757 Loss_G: 2.6853 D(x): 0.7462 D(G(z)): 0.1038 / 0.0948
[0/10][198/391] Loss_D: 0.2727 Loss_G: 2.6466 D(x): 0.9238 D(G(z)): 0.1579 / 0.0955
[0/10][199/391] Loss_D: 0.6150 Loss_G: 4.3788 D(x): 0.9553 D(G(z)): 0.3951 / 0.0276
[0/10][200/391] Loss_D: 0.7142 Loss_G: 2.0679 D(x): 0.6359 D(G(z)): 0.1102 / 0.1899
[0/10][201/391] Loss_D: 0.3221 Loss_G: 3.2781 D(x): 0.8979 D(G(z)): 0.1725 / 0.0502
[0/10][202/391] Loss_D: 1.2701 Loss_G: 6.1909 D(x): 0.8169 D(G(z)): 0.5947 / 0.0038
[0/10][203/391] Loss_D: 2.8594 Loss_G: 1.4244 D(x): 0.1109 D(G(z)): 0.0058 / 0.3578
[0/10][204/391] Loss_D: 1.0573 Loss_G: 0.6493 D(x): 0.5453 D(G(z)): 0.0865 / 0.5726
[0/10][205/391] Loss_D: 1.5341 Loss_G: 3.9994 D(x): 0.9920 D(G(z)): 0.6857 / 0.0435
[0/10][206/391] Loss_D: 0.3292 Loss_G: 4.1742 D(x): 0.8391 D(G(z)): 0.0992 /

[0/10][389/391] Loss_D: 1.2190 Loss_G: 1.6406 D(x): 0.6756 D(G(z)): 0.4927 / 0.2433
[0/10][390/391] Loss_D: 1.2873 Loss_G: 1.1439 D(x): 0.5186 D(G(z)): 0.3761 / 0.3628
[1/10][0/391] Loss_D: 1.1133 Loss_G: 1.6914 D(x): 0.6833 D(G(z)): 0.4632 / 0.2303
[1/10][1/391] Loss_D: 1.1344 Loss_G: 1.6966 D(x): 0.5962 D(G(z)): 0.3838 / 0.2134
[1/10][2/391] Loss_D: 1.0969 Loss_G: 1.1692 D(x): 0.5554 D(G(z)): 0.3332 / 0.3433
[1/10][3/391] Loss_D: 1.3003 Loss_G: 2.6287 D(x): 0.7028 D(G(z)): 0.5547 / 0.1346
[1/10][4/391] Loss_D: 1.7554 Loss_G: 0.7518 D(x): 0.3549 D(G(z)): 0.2497 / 0.5167
[1/10][5/391] Loss_D: 1.9027 Loss_G: 2.3631 D(x): 0.6991 D(G(z)): 0.6706 / 0.1326
[1/10][6/391] Loss_D: 1.6717 Loss_G: 0.8491 D(x): 0.3385 D(G(z)): 0.2866 / 0.4647
[1/10][7/391] Loss_D: 1.3788 Loss_G: 1.5572 D(x): 0.6719 D(G(z)): 0.5546 / 0.2384
[1/10][8/391] Loss_D: 1.1732 Loss_G: 1.2859 D(x): 0.5027 D(G(z)): 0.3342 / 0.2965
[1/10][9/391] Loss_D: 0.9980 Loss_G: 1.3827 D(x): 0.6181 D(G(z)): 0.3688 / 0.2730
[1/10][10/39

[1/10][193/391] Loss_D: 1.0502 Loss_G: 1.4160 D(x): 0.5719 D(G(z)): 0.2938 / 0.2766
[1/10][194/391] Loss_D: 1.0954 Loss_G: 2.2249 D(x): 0.7799 D(G(z)): 0.5318 / 0.1552
[1/10][195/391] Loss_D: 0.9798 Loss_G: 1.5046 D(x): 0.5940 D(G(z)): 0.2758 / 0.2550
[1/10][196/391] Loss_D: 0.7457 Loss_G: 2.3912 D(x): 0.8256 D(G(z)): 0.3888 / 0.1356
[1/10][197/391] Loss_D: 0.6587 Loss_G: 2.1733 D(x): 0.7556 D(G(z)): 0.2576 / 0.1328
[1/10][198/391] Loss_D: 0.9343 Loss_G: 2.6485 D(x): 0.7672 D(G(z)): 0.4457 / 0.0931
[1/10][199/391] Loss_D: 0.7090 Loss_G: 2.2988 D(x): 0.6605 D(G(z)): 0.1952 / 0.1247
[1/10][200/391] Loss_D: 0.5078 Loss_G: 2.4673 D(x): 0.8087 D(G(z)): 0.2178 / 0.1112
[1/10][201/391] Loss_D: 0.6167 Loss_G: 4.5864 D(x): 0.8929 D(G(z)): 0.3617 / 0.0161
[1/10][202/391] Loss_D: 0.8269 Loss_G: 1.9370 D(x): 0.5344 D(G(z)): 0.0306 / 0.2018
[1/10][203/391] Loss_D: 1.2313 Loss_G: 4.8563 D(x): 0.9343 D(G(z)): 0.6353 / 0.0112
[1/10][204/391] Loss_D: 1.8945 Loss_G: 0.6802 D(x): 0.2498 D(G(z)): 0.0501 /

[1/10][387/391] Loss_D: 1.7104 Loss_G: 2.3562 D(x): 0.8574 D(G(z)): 0.7387 / 0.1259
[1/10][388/391] Loss_D: 1.2050 Loss_G: 1.3261 D(x): 0.4397 D(G(z)): 0.2247 / 0.2898
[1/10][389/391] Loss_D: 1.0210 Loss_G: 1.2646 D(x): 0.6562 D(G(z)): 0.4110 / 0.3114
[1/10][390/391] Loss_D: 1.1534 Loss_G: 1.6747 D(x): 0.6411 D(G(z)): 0.4716 / 0.2134
[2/10][0/391] Loss_D: 1.2784 Loss_G: 1.4039 D(x): 0.5165 D(G(z)): 0.3959 / 0.2872
[2/10][1/391] Loss_D: 1.3249 Loss_G: 1.6908 D(x): 0.5857 D(G(z)): 0.4912 / 0.2134
[2/10][2/391] Loss_D: 1.1595 Loss_G: 0.9943 D(x): 0.4848 D(G(z)): 0.2870 / 0.3995
[2/10][3/391] Loss_D: 1.2350 Loss_G: 2.6287 D(x): 0.7980 D(G(z)): 0.6035 / 0.0909
[2/10][4/391] Loss_D: 1.1460 Loss_G: 1.0289 D(x): 0.4436 D(G(z)): 0.1773 / 0.3934
[2/10][5/391] Loss_D: 1.1316 Loss_G: 2.6361 D(x): 0.7604 D(G(z)): 0.5322 / 0.0906
[2/10][6/391] Loss_D: 0.9832 Loss_G: 1.4130 D(x): 0.5313 D(G(z)): 0.2342 / 0.2844
[2/10][7/391] Loss_D: 0.9347 Loss_G: 1.9597 D(x): 0.7220 D(G(z)): 0.4073 / 0.1661
[2/10][8

[2/10][191/391] Loss_D: 1.0910 Loss_G: 1.4158 D(x): 0.6087 D(G(z)): 0.4008 / 0.2658
[2/10][192/391] Loss_D: 1.2676 Loss_G: 1.1214 D(x): 0.4846 D(G(z)): 0.3767 / 0.3473
[2/10][193/391] Loss_D: 1.3218 Loss_G: 1.2491 D(x): 0.5516 D(G(z)): 0.4640 / 0.3136
[2/10][194/391] Loss_D: 1.1591 Loss_G: 1.3023 D(x): 0.5797 D(G(z)): 0.4181 / 0.2893
[2/10][195/391] Loss_D: 1.0566 Loss_G: 1.5696 D(x): 0.6214 D(G(z)): 0.4049 / 0.2277
[2/10][196/391] Loss_D: 0.9919 Loss_G: 1.5327 D(x): 0.6040 D(G(z)): 0.3488 / 0.2421
[2/10][197/391] Loss_D: 1.0406 Loss_G: 1.6369 D(x): 0.6235 D(G(z)): 0.3891 / 0.2184
[2/10][198/391] Loss_D: 1.1660 Loss_G: 2.0131 D(x): 0.6289 D(G(z)): 0.4537 / 0.1620
[2/10][199/391] Loss_D: 1.1111 Loss_G: 1.7649 D(x): 0.5349 D(G(z)): 0.3140 / 0.2065
[2/10][200/391] Loss_D: 1.1677 Loss_G: 1.7070 D(x): 0.5917 D(G(z)): 0.4043 / 0.2080
[2/10][201/391] Loss_D: 1.1076 Loss_G: 2.0160 D(x): 0.6322 D(G(z)): 0.4231 / 0.1608
[2/10][202/391] Loss_D: 0.9608 Loss_G: 1.8107 D(x): 0.6425 D(G(z)): 0.3478 /

[2/10][385/391] Loss_D: 1.1066 Loss_G: 1.5898 D(x): 0.5462 D(G(z)): 0.3458 / 0.2323
[2/10][386/391] Loss_D: 1.0942 Loss_G: 1.6525 D(x): 0.6188 D(G(z)): 0.4089 / 0.2363
[2/10][387/391] Loss_D: 1.4152 Loss_G: 2.0112 D(x): 0.6086 D(G(z)): 0.5061 / 0.2224
[2/10][388/391] Loss_D: 1.6508 Loss_G: 1.5344 D(x): 0.4163 D(G(z)): 0.3563 / 0.3125
[2/10][389/391] Loss_D: 1.3832 Loss_G: 1.1224 D(x): 0.5825 D(G(z)): 0.4555 / 0.3598
[2/10][390/391] Loss_D: 1.1840 Loss_G: 1.8182 D(x): 0.6489 D(G(z)): 0.4756 / 0.1916
[3/10][0/391] Loss_D: 1.1565 Loss_G: 1.2616 D(x): 0.5396 D(G(z)): 0.3585 / 0.3256
[3/10][1/391] Loss_D: 1.3494 Loss_G: 2.6705 D(x): 0.6448 D(G(z)): 0.5292 / 0.1098
[3/10][2/391] Loss_D: 1.2479 Loss_G: 1.2419 D(x): 0.4528 D(G(z)): 0.2463 / 0.3705
[3/10][3/391] Loss_D: 1.1322 Loss_G: 2.1204 D(x): 0.7359 D(G(z)): 0.4724 / 0.1398
[3/10][4/391] Loss_D: 1.0883 Loss_G: 0.8180 D(x): 0.5011 D(G(z)): 0.2489 / 0.4660
[3/10][5/391] Loss_D: 1.8200 Loss_G: 3.2792 D(x): 0.8092 D(G(z)): 0.7639 / 0.0563
[3/1

KeyboardInterrupt: 