# Wasserstein GAN

Este modelo implementa la *Generative Adversarial Network* de Arjovsky et al. (2017), una alternativa a la DC-GAN tradicional con la función de loss "Wasserstein-1".

In [1]:
import argparse
import os
import numpy as np
import math
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
IMG_SIZE = 256
img_shape = (3, IMG_SIZE, IMG_SIZE)
BATCH_SIZE = 64
data_dir = "../data/PlantVillage/"

os.makedirs("images", exist_ok=True)

cuda = True if torch.cuda.is_available() else False

In [8]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(100, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity

In [9]:
ROOT_I = 'Images'
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomSizedCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]),
    'val': transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]),
}

image_datasets = {x: datasets.ImageFolder(data_dir+x,data_transforms[x]) for x in ['train']}

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],\
                                            batch_size=BATCH_SIZE,\
                                            shuffle=shuf,\
                                            num_workers=4)\
              for x,shuf in [('train', True)]}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train']}
class_names = [c for c in image_datasets['train'].classes if c != ROOT_I]

In [10]:
len(dataloaders['train'])*64

1984

In [11]:
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()

# Optimizers
# L_R = 0.00005
L_R = 0.001
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=L_R)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=L_R)

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

## La gente del paper de tomatos entrena la WGAN con 200.000 épocas ;-;

In [12]:
batches_done = 0
dataloader = dataloaders['train']
EPOCHS = 200
for epoch in range(EPOCHS):

    for i, (imgs, _) in enumerate(dataloader):

        # Configure input
        real_imgs = imgs.type(Tensor)

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Sample noise as generator input
        z = Tensor(np.random.normal(0, 1, (imgs.shape[0], 100)))

        # Generate a batch of images
        fake_imgs = generator(z).detach()
        # Adversarial loss
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

        loss_D.backward()
        optimizer_D.step()

        # Clip weights of discriminator
        for p in discriminator.parameters():
            p.data.clamp_(-0.01, 0.01)

        # Train the generator every n_critic iterations
        if i % 20 == 0:

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Generate a batch of images
            gen_imgs = generator(z)
            # Adversarial loss
            loss_G = -torch.mean(discriminator(gen_imgs))

            loss_G.backward()
            optimizer_G.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch+1, EPOCHS, batches_done % len(dataloader), len(dataloader), loss_D.item(), loss_G.item())
            )

        if batches_done % 400 == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
        batches_done += 1

[Epoch 1/200] [Batch 0/31] [D loss: 0.042334] [G loss: -1.306823]
[Epoch 1/200] [Batch 20/31] [D loss: -5488.943848] [G loss: 4020.186768]
[Epoch 2/200] [Batch 0/31] [D loss: -1843.124268] [G loss: 15.426870]
[Epoch 2/200] [Batch 20/31] [D loss: -1207.727905] [G loss: -397.375793]
[Epoch 3/200] [Batch 0/31] [D loss: -243.301697] [G loss: -377.413818]
[Epoch 3/200] [Batch 20/31] [D loss: -261.730469] [G loss: -37.577255]
[Epoch 4/200] [Batch 0/31] [D loss: -301.850800] [G loss: -19.913982]
[Epoch 4/200] [Batch 20/31] [D loss: -189.921082] [G loss: -20.499317]
[Epoch 5/200] [Batch 0/31] [D loss: -287.409821] [G loss: -5.769900]
[Epoch 5/200] [Batch 20/31] [D loss: -408.996582] [G loss: -27.887245]
[Epoch 6/200] [Batch 0/31] [D loss: -291.598328] [G loss: -32.632210]
[Epoch 6/200] [Batch 20/31] [D loss: -935.410522] [G loss: 545.153809]
[Epoch 7/200] [Batch 0/31] [D loss: -485.065491] [G loss: -5.979647]
[Epoch 7/200] [Batch 20/31] [D loss: -473.851654] [G loss: -106.596191]
[Epoch 8/200]

[Epoch 58/200] [Batch 20/31] [D loss: -229.516129] [G loss: -149.181442]
[Epoch 59/200] [Batch 0/31] [D loss: -186.472595] [G loss: 209.947876]
[Epoch 59/200] [Batch 20/31] [D loss: -298.932617] [G loss: -42.057163]
[Epoch 60/200] [Batch 0/31] [D loss: -211.534332] [G loss: -86.537605]
[Epoch 60/200] [Batch 20/31] [D loss: -307.360016] [G loss: 221.319046]
[Epoch 61/200] [Batch 0/31] [D loss: -259.135773] [G loss: -85.678680]
[Epoch 61/200] [Batch 20/31] [D loss: -293.285187] [G loss: -32.541977]
[Epoch 62/200] [Batch 0/31] [D loss: -261.860352] [G loss: 272.726288]
[Epoch 62/200] [Batch 20/31] [D loss: -247.019653] [G loss: -205.000153]
[Epoch 63/200] [Batch 0/31] [D loss: -300.527405] [G loss: 262.610474]
[Epoch 63/200] [Batch 20/31] [D loss: -279.607422] [G loss: -74.857162]
[Epoch 64/200] [Batch 0/31] [D loss: -200.567123] [G loss: -1.149960]
[Epoch 64/200] [Batch 20/31] [D loss: -306.023499] [G loss: 100.231300]
[Epoch 65/200] [Batch 0/31] [D loss: -264.123932] [G loss: -81.190308

[Epoch 116/200] [Batch 0/31] [D loss: -259.978455] [G loss: -99.804092]
[Epoch 116/200] [Batch 20/31] [D loss: -234.933212] [G loss: 173.553589]
[Epoch 117/200] [Batch 0/31] [D loss: -194.081726] [G loss: 105.665070]
[Epoch 117/200] [Batch 20/31] [D loss: -238.716629] [G loss: -42.241142]
[Epoch 118/200] [Batch 0/31] [D loss: -174.560089] [G loss: -95.945045]
[Epoch 118/200] [Batch 20/31] [D loss: -239.171417] [G loss: 278.692566]
[Epoch 119/200] [Batch 0/31] [D loss: -196.025116] [G loss: -186.148834]
[Epoch 119/200] [Batch 20/31] [D loss: -224.450714] [G loss: 244.388123]
[Epoch 120/200] [Batch 0/31] [D loss: -197.063202] [G loss: -193.758148]
[Epoch 120/200] [Batch 20/31] [D loss: -262.878418] [G loss: 141.733170]
[Epoch 121/200] [Batch 0/31] [D loss: -197.891113] [G loss: -44.165512]
[Epoch 121/200] [Batch 20/31] [D loss: -244.635422] [G loss: 157.953384]
[Epoch 122/200] [Batch 0/31] [D loss: -204.541534] [G loss: -73.650826]
[Epoch 122/200] [Batch 20/31] [D loss: -249.296661] [G l

[Epoch 172/200] [Batch 20/31] [D loss: -178.072357] [G loss: -54.578377]
[Epoch 173/200] [Batch 0/31] [D loss: -200.020493] [G loss: 196.822052]
[Epoch 173/200] [Batch 20/31] [D loss: -187.568268] [G loss: -50.883896]
[Epoch 174/200] [Batch 0/31] [D loss: -170.963120] [G loss: 69.176964]
[Epoch 174/200] [Batch 20/31] [D loss: -187.657089] [G loss: -104.249130]
[Epoch 175/200] [Batch 0/31] [D loss: -155.216766] [G loss: 79.616165]
[Epoch 175/200] [Batch 20/31] [D loss: -220.508362] [G loss: 221.050385]
[Epoch 176/200] [Batch 0/31] [D loss: -187.882797] [G loss: -137.942902]
[Epoch 176/200] [Batch 20/31] [D loss: -170.729904] [G loss: 142.608841]
[Epoch 177/200] [Batch 0/31] [D loss: -185.430176] [G loss: 106.680374]
[Epoch 177/200] [Batch 20/31] [D loss: -256.692169] [G loss: -238.214584]
[Epoch 178/200] [Batch 0/31] [D loss: -188.168076] [G loss: 295.269989]
[Epoch 178/200] [Batch 20/31] [D loss: -205.611267] [G loss: -271.163635]
[Epoch 179/200] [Batch 0/31] [D loss: -199.145264] [G l

In [None]:
z = Tensor(np.random.normal(0, 1, (imgs.shape[0], 100)))
gen_imgs = generator(z)
gen_imgs = gen_imgs.cpu()

import matplotlib.pyplot as plt
num_to_gen = gen_imgs.shape[0]
_, axes = plt.subplots(figsize=(16, 4), ncols=num_to_gen)
for ii in range(num_to_gen):
    ax = axes[ii]
    img = gen_imgs[ii]
    npimg = img.detach().numpy()
    ax.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')