In [1]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable

Set up the CIFAR-10 dataset and its data loader.

In [2]:
# Minibatch size used by the DataLoader and the preallocated buffers, and the
# side length (in pixels) that every image is resized to before training.
batch_size, image_size = 16, 64

In [3]:
# torchvision renamed transforms.Scale to transforms.Resize (same behaviour) and later
# removed Scale entirely; prefer Resize and fall back to Scale on old releases.
resize_transform = transforms.Resize if hasattr(transforms, "Resize") else transforms.Scale

# CIFAR-10 images, downloaded into ./data on first use. Each image is resized so its
# shorter side is image_size, converted to a tensor, then normalized per channel with
# mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5, which maps ToTensor's [0, 1] range to
# [-1, 1] -- the same range as the generator's final tanh.
cifar_dataset = dset.CIFAR10(root="./data", download=True,
                             transform=transforms.Compose([
                                 resize_transform(image_size),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                             ]))

Files already downloaded and verified


In [4]:
# Shuffled minibatches of (image, class-label) pairs, loaded by 2 worker processes.
# The training loop only uses the images.
cifar_dataloader = torch.utils.data.DataLoader(
    cifar_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [5]:
ngpu = 1                # GPUs to use; > 1 enables data-parallel forward passes in the nets
nz = 100                # dimensionality of the noise vector z (generator input)
ngf = 64                # base number of generator feature maps
ndf = 64                # base number of discriminator feature maps
nc = 3                  # image channels (3 = RGB)
netG_path = ''          # optional generator checkpoint to resume from ('' = start fresh)
netD_path = ''          # optional discriminator checkpoint to resume from ('' = start fresh)
cuda = False            # move models and buffers to the GPU when True
lr = 0.0002             # Adam learning rate for both networks
beta1 = 0.5             # Adam beta1 for both networks
niter = 5               # number of training epochs
outf = "./saved_stuff"  # directory for sample images and per-epoch checkpoints
manual_seed = 999       # RNG seed for repeatable runs

# Seed every RNG in play so a rerun repeats weight init, data shuffling and noise
# draws (GPU kernels may still add some nondeterminism).
random.seed(manual_seed)
torch.manual_seed(manual_seed)
if cuda:
    torch.cuda.manual_seed_all(manual_seed)

In [6]:
# custom weights initialization called on netG and netD (via net.apply)
def weights_init(m):
    """Initialize one module in place following the DCGAN recipe.

    Meant to be passed to ``net.apply(weights_init)``, which calls it on every
    submodule.

    * Modules whose class name contains ``Conv`` (Conv2d, ConvTranspose2d, ...):
      weights ~ N(0, 0.02). Biases are left untouched; the convs in this notebook
      are created with ``bias=False``.
    * Modules whose class name contains ``BatchNorm``: scale (weight) ~ N(1, 0.02),
      shift (bias) = 0.
    * Every other module is left unchanged.

    Modules that have no learnable weight/bias (for example
    ``BatchNorm2d(affine=False)``, where both are ``None``) are skipped instead of
    raising ``AttributeError``.

    Args:
        m: an ``nn.Module``; modified in place. Returns None.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            # trailing underscore = in-place: overwrite the weights with N(0, 0.02) samples
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        # BatchNorm learns a per-channel scale (weight) and shift (bias)
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)

In [7]:
class _netG(nn.Module): #generator class
    """DCGAN generator: transposed-conv network mapping noise to nc x 64 x 64 images.

    Layer widths come from the module-level settings ``nz`` (noise length),
    ``ngf`` (base generator feature maps) and ``nc`` (output channels), which are
    read when an instance is constructed.

    Args:
        ngpu: number of GPUs. When > 1 and the input lives on the GPU, the forward
            pass is split across devices ``0 .. ngpu-1`` with data parallelism.
    """
    def __init__(self, ngpu):
        super(_netG, self).__init__()
        self.ngpu = ngpu
        # nn.Sequential lets forward() apply the whole stack with a single call.
        self.main = nn.Sequential(
            # Input: noise shaped (batch, nz, 1, 1). A 4x4 transposed conv with
            # stride 1 and no padding expands each 1x1 input into a 4x4 map.
            # bias=False because the following BatchNorm supplies its own shift.
            nn.ConvTranspose2d(nz, ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),
            # BatchNorm learns one scale and one shift per channel.
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),  # in-place ReLU
            # state size. (ngf*8) x 4 x 4

            # Each stride-2 block doubles the spatial size:
            # out = stride*(in-1) + kernel - 2*padding = 2*(in-1) + 4 - 2*1 = 2*in
            nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8

            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32

            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()  # outputs in [-1, 1], the same range as the normalized real images
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        """Run the generator on a noise batch (expected shape (batch, nz, 1, 1))."""
        # Use data parallelism only with several GPUs and a GPU input. Checking
        # .is_cuda covers every CUDA dtype, not just float32 (torch.cuda.FloatTensor).
        if input.data.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, list(range(self.ngpu)))
        else:
            output = self.main(input)
        return output
        
        

In [8]:
netG = _netG(ngpu)  # make the generator instance

# Module.apply(fn) calls fn recursively on every submodule (each layer) and on netG
# itself, so weights_init sees every Conv/BatchNorm layer.
netG.apply(weights_init)  # initialize weights
if netG_path != '':
    # Resume from a checkpoint. map_location keeps every tensor on the CPU while
    # loading, so a checkpoint saved from a GPU run also loads on a machine without
    # a GPU; netG.cuda() in the device cell moves the net when cuda is True.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    netG.load_state_dict(torch.load(netG_path, map_location=lambda storage, loc: storage))
print(netG)

_netG (
  (main): Sequential (
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU (inplace)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU (inplace)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (8): ReLU (inplace)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (11): ReLU (inplace)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh ()
  )
)


In [9]:
# make discriminator
class _netD(nn.Module):
    """DCGAN discriminator: conv network mapping nc x 64 x 64 images to P(real).

    Layer widths come from the module-level settings ``nc`` (input channels) and
    ``ndf`` (base discriminator feature maps), read at construction time.

    Args:
        ngpu: number of GPUs. When > 1 and the input lives on the GPU, the forward
            pass is split across devices ``0 .. ngpu-1`` with data parallelism.
    """
    def __init__(self, ngpu):
        super(_netD, self).__init__()
        self.ngpu = ngpu
        # Each stride-2 conv halves the spatial size: (in + 2*1 - 4) / 2 + 1 = in / 2.
        # All convs are bias-free; where a BatchNorm follows, its shift replaces the bias.
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            # LeakyReLU with negative slope 0.2; no BatchNorm on this first layer
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
            # state size. 1 x 1 x 1: one score per image, squashed to a probability
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return D's probability that each image in the batch is real, shape (batch,)."""
        # Use data parallelism only with several GPUs and a GPU input. Checking
        # .is_cuda covers every CUDA dtype, not just float32 (torch.cuda.FloatTensor).
        if input.data.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, list(range(self.ngpu)))
        else:
            output = self.main(input)

        # (batch, 1, 1, 1) -> (batch, 1) -> (batch,)
        return output.view(-1, 1).squeeze(1)
        

In [10]:
# Discriminator instance; its weights are initialized in the following cell.
netD = _netD(ngpu=ngpu)

In [11]:
netD.apply(weights_init)  # DCGAN init on every Conv/BatchNorm layer
if netD_path != '':
    # Resume from a checkpoint. map_location keeps every tensor on the CPU while
    # loading, so a checkpoint saved from a GPU run also loads on a machine without
    # a GPU; netD.cuda() in the device cell moves the net when cuda is True.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    netD.load_state_dict(torch.load(netD_path, map_location=lambda storage, loc: storage))
print(netD)

_netD (
  (main): Sequential (
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU (0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (4): LeakyReLU (0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (7): LeakyReLU (0.2, inplace)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (10): LeakyReLU (0.2, inplace)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid ()
  )
)


In [12]:
# Define the loss: binary cross-entropy (BCE), because the discriminator
# solves a binary real-vs-fake classification problem.

In [13]:
# Mean binary cross-entropy between D's sigmoid output and the 0/1 (fake/real) targets.
criterion = nn.BCELoss()  # default reduction: average over the batch

In [14]:
# Preallocated buffer for a batch of real images: batch x 3 (RGB) x H x W.
# The training loop resizes it to each batch (resize_as_) before copying data in.
# NOTE(review): this name shadows the builtin input().
input = torch.FloatTensor(batch_size, 3,
                          image_size, image_size)

In [15]:
# Reusable noise buffer shaped batch x nz x 1 x 1. Its contents are uninitialized
# here; the training loop resizes it and refills it with N(0, 1) samples each step.
noise = torch.FloatTensor(batch_size,
                          nz,
                          1,
                          1)

In [16]:
# A batch of Gaussian N(0, 1) noise sampled once, so the fake_samples images saved
# during training always show the generator's output for the same z vectors.
fixed_noise = torch.FloatTensor(batch_size, nz, 1, 1)
fixed_noise.normal_(0, 1)

In [17]:
# Target buffer for the BCE loss: one scalar per example, refilled in place by
# the training loop with either the real or the fake label value.
label = torch.FloatTensor(batch_size)

In [18]:
# Target values for the BCE loss: real images -> 1, generated images -> 0.
real_label, fake_label = 1, 0

In [19]:
if cuda:
    # Modules move in place: both networks and the loss go to the GPU.
    netG.cuda()
    netD.cuda()
    criterion.cuda()
    # Tensor.cuda() returns a GPU copy, so every buffer name is rebound.
    input = input.cuda()
    label = label.cuda()
    noise = noise.cuda()
    fixed_noise = fixed_noise.cuda()

In [20]:
# Wrap the fixed noise for autograd. The guard makes the cell safe to re-run:
# pre-0.4 PyTorch refuses to wrap a Variable in another Variable. On PyTorch >= 0.4
# tensors and Variables are merged, so the wrap is effectively a no-op there.
if not isinstance(fixed_noise, Variable):
    fixed_noise = Variable(fixed_noise)

In [21]:
# Set up the optimizers.
# Each optimizer takes the parameters it should update plus its hyperparameters.
# The generator and the discriminator get separate optimizers, so each step
# updates only one of the two networks.

In [22]:
# Adam over the discriminator's parameters only; lr and beta1 come from the config cell.
optimizerD = optim.Adam(
    netD.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)

In [23]:
# Adam over the generator's parameters only; same hyperparameters as optimizerD.
optimizerG = optim.Adam(
    netG.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
)

In [24]:
# With the criterion and optimizers set up, we are ready for the
# training loop.

In [None]:
# Create the output directory up front: vutils.save_image and torch.save below
# write into outf and fail if the directory does not exist.
if not os.path.isdir(outf):
    os.makedirs(outf)


def _as_float(value):
    """Return the Python number held by a one-element loss tensor/Variable.

    PyTorch >= 0.4 provides ``.item()``; older releases read a one-element
    Variable with ``.data[0]``, which raises on the 0-dim loss tensors that newer
    releases return.
    """
    if hasattr(value, 'item'):
        return value.item()
    return value.data[0]


# for each epoch
for epoch in range(niter):
    # for each minibatch
    for i, data in enumerate(cifar_dataloader):
        ############################
        # (1) Update D: maximize log(D(x)) + log(1 - D(G(z)))
        # D outputs p(real), so log(D(x)) is the log-likelihood of calling a real
        # image real, and log(1 - D(G(z))) that of calling a generated image fake.
        # The gradients of the real and fake terms are accumulated by two separate
        # backward() calls before a single optimizer step, because
        # d(mean(L_real) + mean(L_fake))/dw = d(mean(L_real))/dw + d(mean(L_fake))/dw.
        ############################

        # (1a) train with real images
        netD.zero_grad()  # clear D's gradients, including those left by the previous G step
        real_cpu, _ = data  # the CIFAR class labels are not needed
        batch_size = real_cpu.size(0)  # the last minibatch may be smaller

        if cuda:
            real_cpu = real_cpu.cuda()  # move the batch to the GPU

        # copy the real batch into the preallocated input buffer
        input.resize_as_(real_cpu).copy_(real_cpu)

        # every target is real_label (1)
        label.resize_(batch_size).fill_(real_label)

        # Variables are the autograd entry points on old PyTorch (no-op wrappers on >= 0.4)
        inputv = Variable(input)
        labelv = Variable(label)

        # discriminator output for the real batch
        output = netD(inputv)

        # mean BCE with all-ones targets = -mean(log D(x)): the first term of D's objective
        errD_real = criterion(output, labelv)

        # accumulate gradients in D's parameters; no weight update yet
        errD_real.backward()

        # average D prediction on real images
        D_x = output.data.mean()

        # (1b) train with fake images: draw fresh noise and generate a batch
        noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
        noisev = Variable(noise)
        fake = netG(noisev)

        # every target is now fake_label (0); label is refilled in place
        labelv = Variable(label.fill_(fake_label))

        # KEY POINT: detach() cuts the graph at G's output, so this backward pass
        # computes gradients for D only and never reaches G's parameters.
        output = netD(fake.detach())

        # mean BCE with all-zeros targets = -mean(log(1 - D(G(z)))): the second term
        errD_fake = criterion(output, labelv)

        # accumulate the fake-batch gradients on top of the real-batch ones
        errD_fake.backward()

        # average D prediction on fake images, before D's update
        D_G_z1 = output.data.mean()

        # total D loss for this minibatch (for logging; gradients are already accumulated)
        errD = errD_real + errD_fake

        # one optimizer step on D's parameters using the accumulated gradients
        optimizerD.step()

        ############################
        # (2) Update G: maximize log(D(G(z))), i.e. make D classify the fakes as real
        ############################

        # clear G's gradient buffers
        netG.zero_grad()

        # the fake images get real_label targets in the generator's loss
        labelv = Variable(label.fill_(real_label))

        # Reuse the same fake batch (G has not been updated yet). D's output differs
        # from (1b) because optimizerD.step() just changed netD's weights.
        # fake is NOT detached here, so gradients flow back through D into G.
        output = netD(fake)

        # mean BCE with all-ones targets = -mean(log D(G(z))); minimizing it
        # maximizes the log-likelihood that D labels the fakes as real
        errG = criterion(output, labelv)

        # average D prediction on fake images, after D's update
        D_G_z2 = output.data.mean()

        # This also accumulates gradients in D's parameters (the graph runs
        # noise -> G -> D -> loss); netD.zero_grad() clears them next iteration.
        errG.backward()

        # optimizerG only updates G's parameters
        optimizerG.step()

        # progress log
        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, niter, i, len(cifar_dataloader),
                 _as_float(errD), _as_float(errG), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                    '%s/real_samples.png' % outf,
                    normalize=True)
            # samples for the same fixed z vectors, so progress is comparable over time
            fake = netG(fixed_noise)
            vutils.save_image(fake.data,
                    '%s/fake_samples_epoch_%03d.png' % (outf, epoch),
                    normalize=True)
    # checkpoint every epoch
    # state_dict() holds the parameters and persistent buffers (e.g. BatchNorm
    # running statistics), not gradients
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (outf, epoch))
        
        
        

[0/5][0/3125] Loss_D: 1.9055 Loss_G: 3.5048 D(x): 0.3053 D(G(z)): 0.2774 / 0.0357
[0/5][1/3125] Loss_D: 1.9156 Loss_G: 6.4784 D(x): 0.7515 D(G(z)): 0.7580 / 0.0023
[0/5][2/3125] Loss_D: 0.9732 Loss_G: 6.9638 D(x): 0.7230 D(G(z)): 0.4143 / 0.0011
[0/5][3/3125] Loss_D: 0.9062 Loss_G: 6.0972 D(x): 0.6594 D(G(z)): 0.2803 / 0.0028
[0/5][4/3125] Loss_D: 0.6654 Loss_G: 5.0112 D(x): 0.7024 D(G(z)): 0.1618 / 0.0078
[0/5][5/3125] Loss_D: 1.2282 Loss_G: 6.9759 D(x): 0.8486 D(G(z)): 0.4697 / 0.0012
[0/5][6/3125] Loss_D: 0.6821 Loss_G: 7.3622 D(x): 0.7481 D(G(z)): 0.2871 / 0.0008
[0/5][7/3125] Loss_D: 0.6123 Loss_G: 6.8587 D(x): 0.7689 D(G(z)): 0.1970 / 0.0011
[0/5][8/3125] Loss_D: 0.9019 Loss_G: 7.7996 D(x): 0.7356 D(G(z)): 0.3559 / 0.0005
[0/5][9/3125] Loss_D: 1.0170 Loss_G: 10.0554 D(x): 0.8377 D(G(z)): 0.4637 / 0.0001
[0/5][10/3125] Loss_D: 1.0367 Loss_G: 7.4316 D(x): 0.5978 D(G(z)): 0.1544 / 0.0007
[0/5][11/3125] Loss_D: 0.6593 Loss_G: 9.6785 D(x): 0.9008 D(G(z)): 0.3906 / 0.0001
[0/5][12/3125

[0/5][99/3125] Loss_D: 0.4525 Loss_G: 18.6216 D(x): 0.7803 D(G(z)): 0.0000 / 0.0000
[0/5][100/3125] Loss_D: 0.0740 Loss_G: 15.3031 D(x): 0.9361 D(G(z)): 0.0000 / 0.0000
[0/5][101/3125] Loss_D: 0.0487 Loss_G: 9.5065 D(x): 0.9555 D(G(z)): 0.0001 / 0.0001
[0/5][102/3125] Loss_D: 0.0366 Loss_G: 4.5901 D(x): 0.9825 D(G(z)): 0.0177 / 0.0125
[0/5][103/3125] Loss_D: 1.8399 Loss_G: 25.5151 D(x): 0.9058 D(G(z)): 0.7836 / 0.0000
[0/5][104/3125] Loss_D: 0.0703 Loss_G: 27.3194 D(x): 0.9336 D(G(z)): 0.0000 / 0.0000
[0/5][105/3125] Loss_D: 1.5343 Loss_G: 27.4917 D(x): 0.5434 D(G(z)): 0.0000 / 0.0000
[0/5][106/3125] Loss_D: 0.1239 Loss_G: 26.7023 D(x): 0.9057 D(G(z)): 0.0000 / 0.0000
[0/5][107/3125] Loss_D: 0.3370 Loss_G: 26.7923 D(x): 0.7975 D(G(z)): 0.0000 / 0.0000
[0/5][108/3125] Loss_D: 0.1020 Loss_G: 23.6624 D(x): 0.9394 D(G(z)): 0.0000 / 0.0000
[0/5][109/3125] Loss_D: 0.0397 Loss_G: 20.0154 D(x): 0.9637 D(G(z)): 0.0000 / 0.0000
[0/5][110/3125] Loss_D: 0.0041 Loss_G: 13.3327 D(x): 0.9959 D(G(z)):

[0/5][197/3125] Loss_D: 1.2777 Loss_G: 7.3821 D(x): 0.8850 D(G(z)): 0.5818 / 0.0009
[0/5][198/3125] Loss_D: 0.5109 Loss_G: 7.0712 D(x): 0.6784 D(G(z)): 0.0041 / 0.0019
[0/5][199/3125] Loss_D: 0.1965 Loss_G: 3.9048 D(x): 0.8624 D(G(z)): 0.0301 / 0.0558
[0/5][200/3125] Loss_D: 0.4576 Loss_G: 3.9760 D(x): 0.8737 D(G(z)): 0.2172 / 0.0276
[0/5][201/3125] Loss_D: 0.3214 Loss_G: 4.1288 D(x): 0.8948 D(G(z)): 0.0647 / 0.0258
[0/5][202/3125] Loss_D: 0.2052 Loss_G: 4.5507 D(x): 0.9654 D(G(z)): 0.1473 / 0.0158
[0/5][203/3125] Loss_D: 0.4320 Loss_G: 4.0180 D(x): 0.8315 D(G(z)): 0.1614 / 0.0248
[0/5][204/3125] Loss_D: 0.4461 Loss_G: 5.6096 D(x): 0.8820 D(G(z)): 0.2473 / 0.0050
[0/5][205/3125] Loss_D: 0.9519 Loss_G: 1.7856 D(x): 0.6041 D(G(z)): 0.0471 / 0.2189
[0/5][206/3125] Loss_D: 0.8103 Loss_G: 5.4643 D(x): 0.9366 D(G(z)): 0.4540 / 0.0073
[0/5][207/3125] Loss_D: 0.9904 Loss_G: 3.3254 D(x): 0.5651 D(G(z)): 0.0683 / 0.0738
[0/5][208/3125] Loss_D: 0.3038 Loss_G: 5.1105 D(x): 0.9847 D(G(z)): 0.2272 /

[0/5][295/3125] Loss_D: 1.1523 Loss_G: 10.6312 D(x): 0.9912 D(G(z)): 0.4620 / 0.0001
[0/5][296/3125] Loss_D: 0.2619 Loss_G: 10.8034 D(x): 0.8403 D(G(z)): 0.0004 / 0.0001
[0/5][297/3125] Loss_D: 0.4332 Loss_G: 7.0878 D(x): 0.7789 D(G(z)): 0.0045 / 0.0037
[0/5][298/3125] Loss_D: 0.2026 Loss_G: 4.8087 D(x): 0.9475 D(G(z)): 0.1160 / 0.0137
[0/5][299/3125] Loss_D: 0.4772 Loss_G: 8.0874 D(x): 0.9889 D(G(z)): 0.3177 / 0.0004
[0/5][300/3125] Loss_D: 0.4312 Loss_G: 6.0076 D(x): 0.7570 D(G(z)): 0.0166 / 0.0036
[0/5][301/3125] Loss_D: 0.4708 Loss_G: 2.9767 D(x): 0.7155 D(G(z)): 0.0565 / 0.0579
[0/5][302/3125] Loss_D: 0.9399 Loss_G: 9.2141 D(x): 0.9606 D(G(z)): 0.5188 / 0.0004
[0/5][303/3125] Loss_D: 0.1193 Loss_G: 9.1784 D(x): 0.9012 D(G(z)): 0.0013 / 0.0003
[0/5][304/3125] Loss_D: 0.0517 Loss_G: 6.7395 D(x): 0.9604 D(G(z)): 0.0100 / 0.0046
[0/5][305/3125] Loss_D: 0.5165 Loss_G: 4.6793 D(x): 0.8148 D(G(z)): 0.1458 / 0.0163
[0/5][306/3125] Loss_D: 0.4493 Loss_G: 7.1408 D(x): 0.8728 D(G(z)): 0.2432