# GAN

In [1]:
import os
import time

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchnet.meter import AverageValueMeter

In [4]:
import mnist

In [5]:
intermediate_path = os.path.join("..", "intermediate", "gan")
if not os.path.exists(intermediate_path):
    os.makedirs(intermediate_path)

In [6]:
mb_size = 64
epochs = 100
Z_dim = 100
X_dim = 784
y_dim = 10
h_dim = 128
lr = 1e-3

In [7]:
def xavier_init(size):
    in_dim = size[0]
    out_dim = size[1]
    xavier_stddev = 1. / np.sqrt((in_dim+out_dim) / 2.)
    return Variable(torch.randn(*size) * xavier_stddev,
                    requires_grad=True)

In [8]:
Wzh = xavier_init(size=[Z_dim, h_dim])
Whx = xavier_init(size=[h_dim, X_dim])

In [9]:
bzh = Variable(torch.zeros(h_dim), requires_grad=True)
bhx = Variable(torch.zeros(X_dim), requires_grad=True)

In [10]:
def G(z):
    h = F.relu(z @ Wzh + bzh.repeat(z.size(0), 1))
    X = F.sigmoid(h @ Whx + bhx.repeat(h.size(0), 1))
    return X

In [11]:
Wxh = xavier_init(size=[X_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

Why = xavier_init(size=[h_dim, 1])
bhy = Variable(torch.zeros(1), requires_grad=True)


def D(X):
    h = F.relu(X @ Wxh + bxh.repeat(X.size(0), 1))
    y = F.sigmoid(h @ Why + bhy.repeat(h.size(0), 1))
    return y

In [12]:
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
params = G_params + D_params

In [13]:
def reset_grad():
    for p in params:
        p.grad.data.zero_()

In [14]:
G_solver = optim.Adam(G_params, lr=1e-3)
D_solver = optim.Adam(D_params, lr=1e-3)

In [15]:
train_loader = DataLoader(
    mnist.MNIST('../data', train=True, download=True,
                transform=transforms.ToTensor()),
                batch_size=mb_size, shuffle=True)

In [16]:
def train(epoch):
    D_losses = AverageValueMeter()
    G_losses = AverageValueMeter()
    start = time.time()
    
    for i, (X, _) in enumerate(train_loader):
        ones_label = Variable(torch.ones(X.size(0)))
        zeros_label = Variable(torch.zeros(X.size(0)))
        X = X.view(-1, 784)
        X = Variable(X)
        # Sample data
        z = Variable(torch.randn(X.size(0), Z_dim))
        
        # Dicriminator forward-loss-backward-update
        G_sample = G(z)
        D_real = D(X)
        D_fake = D(G_sample)
        
        D_loss_real = F.binary_cross_entropy(D_real, ones_label)
        D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
        
        D_loss = D_loss_real + D_loss_fake
        
        D_loss.backward()
        D_solver.step()
        reset_grad()
        
        # Generator forward-loss-backward-update
        z = Variable(torch.randn(X.size(0), Z_dim))
        G_sample = G(z)
        D_fake = D(G_sample)
        
        G_loss = F.binary_cross_entropy(D_fake, ones_label)
        
        G_loss.backward()
        G_solver.step()
        reset_grad()
        
        D_losses.add(D_loss.data.cpu()[0] * X.size(0), X.size(0))
        G_losses.add(G_loss.data.cpu()[0] * X.size(0), X.size(0))
        
    print("   * EPOCH {} | Time: {}s | D_loss: {:.4f} | G_loss: {:.4f}"
          .format(epoch, round(time.time()-start),
                  D_losses.value()[0],
                  G_losses.value()[0]))

In [None]:
def plot(samples, epoch):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis("off")
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect("equal")
        plt.imshow(sample.reshape(28, 28), cmap="Greys_r")

    out_path = os.path.join(intermediate_path, "out")
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    out_filepath = os.path.join(out_path, 
                                "{}.png".format(str(epoch).zfill(3)))
    plt.savefig(out_filepath, bbox_inches='tight')
    plt.close(fig)

In [None]:
for epoch in range(1, epochs+1):
    train(epoch)
    z = torch.randn(mb_size, Z_dim)
    samples = G(Variable(z)).data.numpy()[:16]
    plot(samples, epoch)

   * EPOCH 1 | Time: 25s | D_loss: 0.0885 | G_loss: 5.3372
   * EPOCH 2 | Time: 26s | D_loss: 0.0639 | G_loss: 5.1707
   * EPOCH 3 | Time: 25s | D_loss: 0.1065 | G_loss: 4.5663
   * EPOCH 4 | Time: 25s | D_loss: 0.2214 | G_loss: 4.6510
   * EPOCH 5 | Time: 26s | D_loss: 0.3088 | G_loss: 4.1979
   * EPOCH 6 | Time: 25s | D_loss: 0.4808 | G_loss: 3.7872
   * EPOCH 7 | Time: 25s | D_loss: 0.5365 | G_loss: 3.1798
   * EPOCH 8 | Time: 26s | D_loss: 0.5744 | G_loss: 3.1267
   * EPOCH 9 | Time: 25s | D_loss: 0.6274 | G_loss: 2.9110
   * EPOCH 10 | Time: 25s | D_loss: 0.6818 | G_loss: 2.7613
   * EPOCH 11 | Time: 25s | D_loss: 0.6946 | G_loss: 2.6511
   * EPOCH 12 | Time: 26s | D_loss: 0.7201 | G_loss: 2.7746
   * EPOCH 13 | Time: 25s | D_loss: 0.7558 | G_loss: 2.6204
   * EPOCH 14 | Time: 25s | D_loss: 0.7741 | G_loss: 2.3711
   * EPOCH 15 | Time: 25s | D_loss: 0.7586 | G_loss: 2.2646
   * EPOCH 16 | Time: 26s | D_loss: 0.7686 | G_loss: 2.1683
   * EPOCH 17 | Time: 26s | D_loss: 0.7824 | G_lo

# DCGAN

In [1]:
import argparse
import os
import time

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchnet.meter import AverageValueMeter

In [3]:
parser = {
    "data_path": os.path.join("..", "data"),
    "workers": 5,
    "batch_size": 64,
    "image_size": 64,
    "z_dim": 100,
    "G_features": 64,
    "D_features": 64,
    "image_channels": 3,
    "epochs": 25,
    "lr": 0.0002,
    "beta1": 0.5,
    "cuda": True,
    "intermediate_path": os.path.join("..", "intermediate", "dcgan"),
    "seed": 7
}
args = argparse.Namespace(**parser)

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
args.cuda = args.cuda and torch.cuda.is_available()

if not os.path.isdir(args.intermediate_path):
    os.makedirs(args.intermediate_path)

In [4]:
dataset = CIFAR10(root=args.data_path, download=True,
                  transform=transforms.Compose([
                      transforms.Scale(args.image_size),
                      transforms.ToTensor(),
                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                  ]))

Files already downloaded and verified


In [5]:
data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=args.workers)

In [6]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [7]:
class _netG(nn.Module):
    def __init__(self):
        super(_netG, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(args.z_dim, args.G_features * 8,
                               4, 1, 0, bias=False),
            nn.BatchNorm2d(args.G_features * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(args.G_features * 8, args.G_features * 4,
                               4, 2, 1, bias=False),
            nn.BatchNorm2d(args.G_features * 4),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(args.G_features * 4, args.G_features * 2,
                               4, 2, 1, bias=False),
            nn.BatchNorm2d(args.G_features * 2),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(args.G_features * 2, args.G_features,
                               4, 2, 1, bias=False),
            nn.BatchNorm2d(args.G_features),
            nn.ReLU(True),
            # state size. (ngf) x 16 x 16
            nn.ConvTranspose2d(args.G_features, args.image_channels,
                               4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 32 x 32
        )

    def forward(self, input):
        return self.main(input)

In [8]:
netG = _netG()
netG.apply(weights_init)

_netG (
  (main): Sequential (
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU (inplace)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU (inplace)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (8): ReLU (inplace)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (11): ReLU (inplace)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh ()
  )
)

In [9]:
class _netD(nn.Module):
    def __init__(self):
        super(_netD, self).__init__()
        self.main = nn.Sequential(
            # input is (nc) x 32 x 32
            nn.Conv2d(args.image_channels, args.D_features,
                      4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 16 x 16
            nn.Conv2d(args.D_features, args.D_features * 2,
                      4, 2, 1, bias=False),
            nn.BatchNorm2d(args.D_features * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 8 x 8
            nn.Conv2d(args.D_features * 2, args.D_features * 4,
                      4, 2, 1, bias=False),
            nn.BatchNorm2d(args.D_features * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(args.D_features * 4, args.D_features * 8,
                      4, 2, 1, bias=False),
            nn.BatchNorm2d(args.D_features * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 4 x 4
            nn.Conv2d(args.D_features * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1, 1)

In [10]:
netD = _netD()
netD.apply(weights_init)

_netD (
  (main): Sequential (
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU (0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (4): LeakyReLU (0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (7): LeakyReLU (0.2, inplace)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (10): LeakyReLU (0.2, inplace)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid ()
  )
)

In [11]:
criterion = nn.BCELoss()

In [12]:
if args.cuda:
    netD.cuda()
    netG.cuda()
    criterion.cuda()

In [13]:
D_optimizer = optim.Adam(netD.parameters(), lr=args.lr,
                         betas=(args.beta1, 0.999))
G_optimizer = optim.Adam(netG.parameters(), lr=args.lr,
                         betas=(args.beta1, 0.999))

In [14]:
def train(epoch):
    D_losses = AverageValueMeter()
    G_losses = AverageValueMeter()
    D_real_accuracies = AverageValueMeter()
    D_fake_accuracies = AverageValueMeter()
    G_real_accuracies = AverageValueMeter()
    
    start = time.time()
    for i, (real, _) in enumerate(data_loader):
        batch_size = real.size(0)
        real_label = Variable(torch.ones(batch_size))
        fake_label = Variable(torch.zeros(batch_size))
        real = Variable(real)
        z = Variable(torch.randn(batch_size, args.z_dim, 1, 1))
        if args.cuda:
            real_label = real_label.cuda()
            fake_label = fake_label.cuda()
            real = real.cuda()
            z = z.cuda()
        
        real_output = netD(real)
        D_real_loss = criterion(real_output, real_label)
        D_real_loss.backward()
        D_real_accuracy = real_output.data.mean()
        
        fake = netG(z)
        fake_output = netD(fake.detach())
        D_fake_loss = criterion(fake_output, fake_label)
        D_fake_loss.backward()
        D_fake_accuracy = fake_output.data.mean()
        
        D_loss = (D_real_loss + D_fake_loss)/2
        D_optimizer.step()
        netD.zero_grad()
        
        output = netD(fake)
        G_loss = criterion(output, real_label)
        G_loss.backward()
        G_real_accuracy = output.data.mean()
        G_optimizer.step()
        netG.zero_grad()
        
        D_losses.add(D_loss.data.cpu()[0] * batch_size, batch_size)
        G_losses.add(G_loss.data.cpu()[0] * batch_size, batch_size)
        D_real_accuracies.add(D_real_accuracy * batch_size, batch_size)
        D_fake_accuracies.add(D_fake_accuracy * batch_size, batch_size)
        G_real_accuracies.add(G_real_accuracy * batch_size, batch_size)
        
    print("=> EPOCH {:2d} | Time: {}s | D_loss: {:.3f} | G_loss: {:.3f} "
          "| D_real_acc: {:.3f} | D_fake_acc: {:.3f} | G_real_acc: {:.3f}"
          .format(epoch, int(time.time()-start), D_losses.value()[0],
                  G_losses.value()[0], D_real_accuracies.value()[0],
                  D_fake_accuracies.value()[0], G_real_accuracies.value()[0]))
    
    z = Variable(torch.randn(args.batch_size, args.z_dim, 1, 1))
    if args.cuda:
        z = z.cuda()
    fake = netG(z)
    vutils.save_image(fake.data.cpu(), os.path.join(
        args.intermediate_path,
        "fake_samples_epoch_{:02d}.png".format(epoch)))

In [None]:
for epoch in range(1, args.epochs+1):
    train(epoch)

=> EPOCH  1 | Time: 175s | D_loss: 0.586 | G_loss: 0.749 | D_real_acc: 0.877 | D_fake_acc: 0.618 | G_real_acc: 0.503
=> EPOCH  2 | Time: 175s | D_loss: 0.642 | G_loss: 0.577 | D_real_acc: 0.791 | D_fake_acc: 0.638 | G_real_acc: 0.569
