In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import importlib
import argparse
import time

In [2]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.autograd as autograd
from torch.autograd import Variable
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.optim as optim
from torchvision import transforms
import torchvision.utils as vutils
from torchnet.meter import AverageValueMeter

# Set parameters

In [3]:
# All tunable settings in one place. `parser` is a plain dict (not an
# argparse.ArgumentParser); wrapping it in a Namespace gives the rest of the
# notebook attribute access such as args.lr or args.z_dim.
parser = {
    'dataset': 'cifar10',      # only CIFAR-10 is loaded by the data cell
    'dataroot': './data',      # where the dataset is downloaded / read from
    'workers': 2,              # DataLoader worker processes
    'batch_size': 64,
    'image_size': 64,          # generator/discriminator are built for 64x64 images
    'image_channels': 3,
    'z_dim': 100,              # length of the latent noise vector
    'n_G': 64,                 # base feature-map width of the generator
    'n_D': 64,                 # base feature-map width of the discriminator
    'epochs': 25,
    'lr': 1e-3,                # NOTE: the DCGAN paper uses 2e-4
    'beta1': 0.5,              # Adam beta1 (DCGAN setting)
    'netG': '',                # optional path of a generator state_dict to load
    'netD': '',                # optional path of a discriminator state_dict to load
    'outf': './output',        # directory for sample images
    'ngpu': 0,
    'manualSeed': 7,
    'no_cuda': True,           # not read by the cells shown here; training runs on CPU
}
args = argparse.Namespace(**parser)

# Apply the seed so weight init, noise vectors and DataLoader shuffling are
# reproducible (previously manualSeed was declared but never used).
np.random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)

In [4]:
print(args)

Namespace(batch_size=64, beta1=0.5, dataroot='./data', dataset='cifar10', epochs=25, image_channels=3, image_size=64, lr=0.001, manualSeed=7, n_D=64, n_G=64, netD='', netG='', ngpu=0, no_cuda=True, outf='./output', workers=2, z_dim=100)


# Load data

In [5]:
# CIFAR-10 (downloaded to args.dataroot if missing). Each image is resized so its
# shorter side is args.image_size, then center-cropped to a square. The crop is a
# no-op for CIFAR's square images but guarantees the 64x64 input the networks
# expect. Images are converted to tensors in [0, 1] and normalized to [-1, 1],
# which matches the generator's Tanh output range.
# `transforms.Scale` was renamed `Resize` (and `Scale` was later removed), so
# prefer `Resize` and fall back to `Scale` on old torchvision installs.
_Resize = transforms.Resize if hasattr(transforms, 'Resize') else transforms.Scale
dataset = dset.CIFAR10(root=args.dataroot, download=True, transform=transforms.Compose([
    _Resize(args.image_size),
    transforms.CenterCrop(args.image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))

assert len(dataset) > 0, 'dataset is empty'

Files already downloaded and verified


In [6]:
# Shuffled mini-batches of (image, label) pairs; the training loop ignores the labels.
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=int(args.workers),
    shuffle=True,
    batch_size=args.batch_size,
)

# Define the models: generator and discriminator

In [7]:
# Custom weight initialization, applied to netG and netD via `net.apply(weights_init)`.
def weights_init(m):
    """Initialize one layer in place, following the DCGAN initialization.

    `Module.apply` calls this on every submodule. Convolution and
    transposed-convolution weights are drawn from N(0, 0.02). BatchNorm scale
    (weight) is drawn from N(1, 0.02) and its shift (bias) is set to 0. All
    other layers are left untouched.

    Args:
        m: an ``nn.Module`` (one layer of a model).
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:  # matches Conv2d and ConvTranspose2d
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        # A BatchNorm built with affine=False has weight/bias set to None; skip them.
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)

```class torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True)```

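For reference, here is an excerpt of `nn.Module.__init__` from an older PyTorch release. It creates the registries that hold a module's parameters, buffers, hooks and submodules. This is why each model class below calls `super(...).__init__()` before assigning any layers.

```python
def __init__(self):
    self._backend = thnn_backend
    self._parameters = OrderedDict()
    self._buffers = OrderedDict()
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self._modules = OrderedDict()
    self.training = True
```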

```class torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True)```

In [8]:
#Define class Discriminator:
class _netG(nn.Module):
    def __init__(self, ngpu):
        super(_netG, self).__init__()  #no need to list __init__ of nn.Module
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(args.z_dim, args.n_G * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(args.n_G * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(args.n_G * 8, args.n_G * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.n_G * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(args.n_G * 4, args.n_G * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.n_G * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(args.n_G * 2, args.n_G, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.n_G),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(args.n_G, args.image_channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64 --> image size = 64
        )

    #def forward(self, input):
    #    gpu_ids = None
    #    if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
    #       gpu_ids = range(self.ngpu)
    #    return nn.parallel.data_parallel(self.main, input, gpu_ids)
    
    def forward(self, input):
        return self.main(input)

In [9]:
netG = _netG(args.ngpu)
netG.apply(weights_init)
if args.netG != '':
    # Resume from a saved state_dict. map_location keeps every tensor on the CPU,
    # so checkpoints saved from a GPU run also load in this CPU-only notebook.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    netG.load_state_dict(torch.load(args.netG, map_location=lambda storage, loc: storage))
print(netG)

_netG (
  (main): Sequential (
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU (inplace)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU (inplace)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (8): ReLU (inplace)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (11): ReLU (inplace)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh ()
  )
)


```class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)```

In [10]:
class _netD(nn.Module):
    def __init__(self, ngpu):
        super(_netD, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(args.image_channels, args.n_D, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(args.n_D, args.n_D * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.n_D * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(args.n_D * 2, args.n_D * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.n_D * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(args.n_D * 4, args.n_D * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(args.n_D * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(args.n_D * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    #def forward(self, input):
    #    gpu_ids = None
    #    if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
    #        gpu_ids = range(self.ngpu)
    #    output = nn.parallel.data_parallel(self.main, input, gpu_ids)
    #   return output.view(-1, 1)
    
    def forward(self, input):
        output = self.main(input)
        return output.view(-1, 1)

In [11]:
netD = _netD(args.ngpu)
netD.apply(weights_init)
if args.netD != '':
    # Resume from a saved state_dict. map_location keeps every tensor on the CPU,
    # so checkpoints saved from a GPU run also load in this CPU-only notebook.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    netD.load_state_dict(torch.load(args.netD, map_location=lambda storage, loc: storage))
print(netD)

_netD (
  (main): Sequential (
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU (0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (4): LeakyReLU (0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (7): LeakyReLU (0.2, inplace)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (10): LeakyReLU (0.2, inplace)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid ()
  )
)


# Train

In [12]:
# Binary cross-entropy between the discriminator's sigmoid outputs and 0/1
# targets (1 = real image, 0 = generated), with no per-element weighting.
criterion = nn.BCELoss(weight=None)

```class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)```

In [13]:
# One Adam optimizer per network, each over that network's own parameters,
# with beta1 taken from the config (DCGAN uses 0.5) and beta2 = 0.999.
optimizerD, optimizerG = (
    optim.Adam(net.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    for net in (netD, netG)
)

**NOTE**: If PyTorch is upgraded, the order of the training steps changes:

buffer --> backward --> update _**CHANGED TO**_ backward --> buffer --> update

In [14]:
def _scalar(value):
    """Return a Python float from a number, a one-element tensor or a 0-dim tensor.

    PyTorch < 0.4 returns Python numbers from ``tensor.mean()`` and one-element
    tensors from ``loss.data``. Newer releases return 0-dim tensors, for which
    ``[0]`` indexing is deprecated or an error. Flattening first makes ``[0]``
    valid either way.
    """
    if hasattr(value, 'view'):
        value = value.contiguous().view(-1)[0]
    return float(value)


def train(args, data_loader, netG, netD, G_optimizer, D_optimizer, criterion, epoch):
    """Run one epoch of DCGAN training and print the epoch's averaged statistics.

    Args:
        args: config namespace; ``args.z_dim`` sets the latent noise size.
        data_loader: iterable of ``(images, labels)`` batches; labels are ignored.
        netG, netD: generator and discriminator modules (updated in place).
        G_optimizer, D_optimizer: optimizers over netG's / netD's parameters.
        criterion: loss on discriminator probabilities (BCELoss in this notebook).
        epoch: epoch number, used in the log line and the sample filenames.

    Side effects: steps both optimizers once per batch, calls ``plot`` every 100
    batches (writes image files), and prints one summary line.
    """
    D_losses = AverageValueMeter()
    G_losses = AverageValueMeter()
    D_reals = AverageValueMeter()
    D_fakes = AverageValueMeter()
    G_reals = AverageValueMeter()

    start = time.time()

    # Iterate the loader that was passed in (previously the global `dataloader`).
    for i, (real, _) in enumerate(data_loader):
        batch_size = real.size(0)  # the last batch can be smaller than args.batch_size

        # Targets are shaped (N, 1) to match netD's `output.view(-1, 1)`. Newer
        # PyTorch's BCELoss rejects an (N,) target for an (N, 1) input.
        real_label = Variable(torch.ones(batch_size, 1))
        fake_label = Variable(torch.zeros(batch_size, 1))
        z = Variable(torch.randn(batch_size, args.z_dim, 1, 1))

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ############################
        # train with real
        real = Variable(real)
        real_output = netD(real)
        D_real_loss = criterion(real_output, real_label)
        D_real = _scalar(real_output.data.mean())

        # train with fake
        fake = netG(z)
        # detach(): the D loss must not backpropagate into netG. The attached
        # `fake` is reused for the G step below, so G(z) is computed only once.
        fake_output = netD(fake.detach())
        D_fake_loss = criterion(fake_output, fake_label)
        D_fake = _scalar(fake_output.data.mean())

        D_loss = D_real_loss + D_fake_loss
        netD.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        #     (implemented as BCE of D's output on fakes against "real" targets)
        ############################
        output = netD(fake)
        G_loss = criterion(output, real_label)
        G_real = _scalar(output.data.mean())

        netG.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # Each value is a per-batch mean times the batch size, added with count
        # batch_size. This presumes AverageValueMeter.add(value, n) treats `value`
        # as a sum over n samples, giving per-sample averages; confirm against the
        # installed torchnet version.
        D_losses.add(_scalar(D_loss.data) * batch_size, batch_size)
        G_losses.add(_scalar(G_loss.data) * batch_size, batch_size)
        D_reals.add(D_real * batch_size, batch_size)
        D_fakes.add(D_fake * batch_size, batch_size)
        G_reals.add(G_real * batch_size, batch_size)

        if i % 100 == 0:
            plot(real, epoch)

    print("=> EPOCH {} | Time: {}s | D_loss: {:.4f} | G_loss: {:.4f}"
          " | D_real: {:.4f} | D_fake: {:.4f} | G_real: {:.4f}"
          .format(epoch, round(time.time()-start), D_losses.value()[0],
                  G_losses.value()[0], D_reals.value()[0],
                  D_fakes.value()[0], G_reals.value()[0]))

In [None]:
def plot(X, epoch, noise=None):
    """Save a grid of real images and a grid of generator samples.

    Args:
        X: batch of real images as a Variable; ``X.data`` is written out.
        epoch: epoch number, zero-padded to 3 digits in the fake-sample filename.
        noise: optional latent batch of shape (N, args.z_dim, 1, 1), as the
            generator expects (a Variable on old PyTorch). Pass the same tensor on
            every call to make samples comparable across epochs. When None, fresh
            N(0, 1) noise with N = X.size(0) is drawn, as before.

    Writes ``<args.outf>/real_samples.png``, overwritten on each call, and
    ``<args.outf>/fake_samples_epoch_<epoch>.png``. Uses the global ``netG``.
    """
    # save_image does not create missing directories.
    if not os.path.isdir(args.outf):
        os.makedirs(args.outf)
    if noise is None:
        z = Variable(torch.randn(X.size(0), args.z_dim, 1, 1))
    else:
        z = noise
    vutils.save_image(X.data, '%s/real_samples.png' % args.outf)
    fake = netG(z)
    vutils.save_image(fake.data,
                      '%s/fake_samples_epoch_%03d.png' % (args.outf, epoch))

In [None]:
# Train for args.epochs epochs, numbered from 1. After each epoch both networks'
# state_dicts are saved, so trained weights can be reloaded through the 'netG' /
# 'netD' config entries. Previously nothing but sample images was persisted.
if not os.path.isdir(args.outf):
    os.makedirs(args.outf)
for epoch in range(1, args.epochs + 1):
    train(args, dataloader, netG, netD, optimizerG, optimizerD, criterion, epoch)
    torch.save(netG.state_dict(), '%s/netG_epoch_%03d.pth' % (args.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%03d.pth' % (args.outf, epoch))