# Generative Adversarial Network

In [8]:
# Import de paquetes
%matplotlib inline
import matplotlib.pyplot as plt
%matplotlib inline
plt.ion()

import sys
import os

# Numpy
import numpy as np
from skimage import color, io

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# Torchvision
import torchvision.utils
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader

## Hiperparámetros

In [9]:
# Numero de canales de entrada (3= RGB)
nc = 3
# Tamaño del latent vector (z)
nz = 100
# Features en el generador
ngf = 64
# features en el discriminador
ndf = 64
# Learning rate
lr = 0.0002
# Cantidad de épocas de entrenamiento
num_epochs = 1
# Tamaño del batch
batch_size = 128
# Escala de la imagen
image_size = 64
# Cantidad de GPUs disponibles
ngpu = 1
# Directorio de descarga del dataset
dataroot = "data_faces/"

## CELEB a

In [12]:
# Download CELEB A http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
!mkdir data_faces && wget https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip 

--2019-10-18 16:54:05--  https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip
Resolving s3-us-west-1.amazonaws.com (s3-us-west-1.amazonaws.com)... 52.219.116.232
Connecting to s3-us-west-1.amazonaws.com (s3-us-west-1.amazonaws.com)|52.219.116.232|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1443490838 (1.3G) [application/zip]
Saving to: ‘celeba.zip.1’

celeba.zip.1          6%[>                     ]  91.46M   293KB/s   eta 69m 37s^C


In [11]:
# Unzip
import zipfile
with zipfile.ZipFile("celeba.zip","r") as zip_ref:
  zip_ref.extractall(dataroot)

BadZipFile: File is not a zip file

In [13]:
# Transformaciones para el dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

# Creamos el cargador de datos
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=3, drop_last=True)

# Decidimos que dispositivo vamos a usar
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Ploteamos algunas imagenes para verificar el dataset
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

FileNotFoundError: [Errno 2] No such file or directory: 'data_faces/'

In [None]:
# Generador
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input Z
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

In [None]:
# Discriminador
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input = (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

In [None]:
# Creo el generador 
netG = Generator(ngpu).to(device)
print(netG)

In [None]:
# Creo el discriminador
netD = Discriminator(ngpu).to(device)
print(netD)

In [None]:
# Initializo BCE Loss
criterion = nn.BCELoss()

# Creo un batch con ruido fijo para hacer un seguimiento del progreso de la optimización
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Convención de etiquetas para imágenes reales y falsas
real_label = 1
fake_label = 0

# Dos optimizadores
optimizerD = torch.optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

In [None]:
def D_train(x):
    #=======================Train the discriminator=======================#
    netD.zero_grad()

    # train discriminator on real
    x_real, y_real = x, torch.ones(batch_size, 1)
    x_real, y_real = x_real.to(device), y_real.to(device)
    
    D_output = netD(x_real).view(-1)
    
    D_real_loss = criterion(D_output, y_real)
    D_real_score = D_output

    # train discriminator on fake
    z = torch.randn(batch_size, nz, 1, 1).to(device)
    x_fake, y_fake = netG(z), torch.zeros(batch_size, 1).to(device)

    D_output = netD(x_fake)
    D_fake_loss = criterion(D_output, y_fake)
    D_fake_score = D_output

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    optimizerD.step()
        
    return  D_loss

In [None]:
def G_train():
    #=======================Train the generator=======================#
    netG.zero_grad()

    z = torch.randn(batch_size, nz, 1, 1).to(device)
    y = torch.ones(batch_size, 1).to(device)

    G_output = netG(z)
    D_output = netD(G_output)
    G_loss = criterion(D_output, y)

    # gradient backprop & optimize ONLY G's parameters
    G_loss.backward()
    optimizerG.step()
        
    return G_loss

In [None]:
# Loop de entrenamiento

# Listas para llevar constancia del progreso de la red
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# Por cada época
for epoch in range(num_epochs):
    # por cada batch..
    for i, (data,_)in enumerate(dataloader, 0):

        errD = D_train(data)
        errG = G_train()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\t'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item()))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

## Plot de loss

In [None]:
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]