<a href="https://colab.research.google.com/github/nesmachnow/Curso-GANs/blob/main/W_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Ejemplo de Wasserstein GANs

In [1]:
# !pip install torch torchvision

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

In [3]:
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

In [4]:
import numpy as np
import datetime
import os, sys

In [5]:
from matplotlib.pyplot import imshow, imsave
%matplotlib inline

In [6]:
!nvidia-smi

Wed Oct 27 22:58:01 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P8    28W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [7]:
MODEL_NAME = 'W-GAN'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
def to_onehot(x, num_classes=10):
    assert isinstance(x, int) or isinstance(x, (torch.LongTensor, torch.cuda.LongTensor))
    if isinstance(x, int):
        c = torch.zeros(1, num_classes).long()
        c[0][x] = 1
    else:
        x = x.cpu()
        c = torch.LongTensor(x.size(0), num_classes)
        c.zero_()
        c.scatter_(1, x, 1) # dim, index, src value
    return c

In [9]:
def get_sample_image(G, n_noise=100):
    """
        100 imágenes de ejemplo
    """
    img = np.zeros([280, 280])
    for j in range(10):
        c = torch.zeros([10, 10]).to(DEVICE)
        c[:, j] = 1
        z = torch.randn(10, n_noise).to(DEVICE)
        y_hat = G(z,c).view(10, 28, 28)
        result = y_hat.cpu().data.numpy()
        img[j*28:(j+1)*28] = np.concatenate([x for x in result], axis=-1)
    return img

In [10]:
class Critic(nn.Module):
    """
        Crítico para MNIST con una ANN convolucional
    """
    def __init__(self, in_channel=1, input_size=784, condition_size=10, num_classes=1):
        super(Critic, self).__init__()
        self.transform = nn.Sequential(
            nn.Linear(input_size+condition_size, 784),
            nn.LeakyReLU(0.2),
        )
        self.conv = nn.Sequential(
            # 28 -> 14
            nn.Conv2d(in_channel, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            # 14 -> 7
            nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            # 7 -> 4
            nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(4),
        )
        self.fc = nn.Sequential(
            # reshape input, 128 -> 1
            nn.Linear(128, 1),
        )
    
    def forward(self, x, c=None):
        # x: (N, 1, 28, 28), c: (N, 10)
        x, c = x.view(x.size(0), -1), c.float() # may not need
        v = torch.cat((x, c), 1) # v: (N, 794)
        y_ = self.transform(v) # (N, 784)
        y_ = y_.view(y_.shape[0], 1, 28, 28) # (N, 1, 28, 28)
        y_ = self.conv(y_)
        y_ = y_.view(y_.size(0), -1)
        y_ = self.fc(y_)
        return y_

In [11]:
class Generator(nn.Module):
    """
        Generador para MNIST con una ANN convolucional
    """
    def __init__(self, input_size=100, condition_size=10):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size+condition_size, 4*4*512),
            nn.ReLU(),
        )
        self.conv = nn.Sequential(
            # input: 4 by 4, output: 7 by 7
            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # input: 7 by 7, output: 14 by 14
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # input: 14 by 14, output: 28 by 28
            nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )
        
    def forward(self, x, c):
        # x: (N, 100), c: (N, 10)
        x, c = x.view(x.size(0), -1), c.float() # may not need
        v = torch.cat((x, c), 1) # v: (N, 110)
        y_ = self.fc(v)
        y_ = y_.view(y_.size(0), 512, 4, 4)
        y_ = self.conv(y_) # (N, 28, 28)
        return y_

In [12]:
C = Critic().to(DEVICE)
G = Generator().to(DEVICE)
# C.load_state_dict('C_dc.pkl')
# G.load_state_dict('G_dc.pkl')

In [13]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5],
                                std=[0.5])]
)

In [14]:
mnist = datasets.MNIST(root='../data/', train=True, transform=transform, download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [15]:
batch_size = 64

In [16]:
data_loader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)

In [18]:
C_opt = torch.optim.RMSprop(C.parameters(), lr=0.0005)
G_opt = torch.optim.RMSprop(G.parameters(), lr=0.0005)

In [26]:
max_epoch = 30 # 
step = 0
g_step = 0
n_noise = 100

In [20]:
def n_critic(step, nc=2):
    if step < 25 or step % 500 == 0:
        return 100
    return nc

In [21]:
C_labels = torch.ones([batch_size, 1]).to(DEVICE) # Etiqueta del crítico 'real'
C_fakes = torch.zeros([batch_size, 1]).to(DEVICE) # Etiqueta del crítico 'fake'

In [22]:
if not os.path.exists('samples'):
    os.makedirs('samples')

In [27]:
for epoch in range(max_epoch):
    for idx, (images, labels) in enumerate(data_loader):
                   
        # Entrenamiento del crítico
        x = images.to(DEVICE)
        y = labels.view(batch_size, 1)
        y = to_onehot(y).to(DEVICE)
        x_outputs = C(x, y)

        z = torch.randn(batch_size, n_noise).to(DEVICE)
        z_outputs = C(G(z, y), y)
        C_x_loss = torch.mean(x_outputs)
        C_z_loss = torch.mean(z_outputs)
        C_loss = C_z_loss - C_x_loss
        
        C.zero_grad()
        C_loss.backward()
        C_opt.step()
        # Poda de pesos para la restricción K-Lipshitziana
        for p in C.parameters():
            p.data.clamp_(-0.01, 0.01)
                    
        # if step % n_critic(step) == 0:
        if step % 3 == 0:
            g_step += 1
            # Entrenamiento del generador
            z = torch.randn(batch_size, n_noise).to(DEVICE)
            z_outputs = C(G(z, y), y)
            G_loss = -torch.mean(z_outputs)

            C.zero_grad()
            G.zero_grad()
            G_loss.backward()
            G_opt.step()
            
        if step % 500 == 0:
            print('Epoch: {}/{}, Step: {}, C Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, C_loss.item(), G_loss.item()))
        
        if step % 1000 == 0:
            G.eval()
            img = get_sample_image(G, n_noise)
            imsave('samples/{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')
            G.train()
        step += 1

Epoch: 0/30, Step: 0, C Loss: 0.0, G Loss: -0.010217981413006783
Epoch: 0/30, Step: 500, C Loss: 0.0, G Loss: -0.01016167551279068
Epoch: 1/30, Step: 1000, C Loss: 9.313225746154785e-10, G Loss: -0.010206181555986404
Epoch: 1/30, Step: 1500, C Loss: 0.0, G Loss: -0.010232598520815372
Epoch: 2/30, Step: 2000, C Loss: 0.0, G Loss: -0.010239941067993641
Epoch: 2/30, Step: 2500, C Loss: 0.0, G Loss: -0.010267484933137894
Epoch: 3/30, Step: 3000, C Loss: -9.313225746154785e-10, G Loss: -0.010258047841489315
Epoch: 3/30, Step: 3500, C Loss: -9.313225746154785e-10, G Loss: -0.010262617841362953
Epoch: 4/30, Step: 4000, C Loss: 0.0, G Loss: -0.010267555713653564
Epoch: 4/30, Step: 4500, C Loss: 0.0, G Loss: -0.010271361097693443
Epoch: 5/30, Step: 5000, C Loss: 0.0, G Loss: -0.010281631723046303
Epoch: 5/30, Step: 5500, C Loss: 0.0, G Loss: -0.010264331474900246
Epoch: 6/30, Step: 6000, C Loss: 9.313225746154785e-10, G Loss: -0.010209852829575539
Epoch: 6/30, Step: 6500, C Loss: 0.0, G Loss: -

KeyboardInterrupt: ignored

## Sample

In [None]:
# generation to image
G.eval()
imshow(get_sample_image(G, n_noise), cmap='gray')

In [None]:
def save_checkpoint(state, file_name='checkpoint.pth.tar'):
    torch.save(state, file_name)

In [None]:
# Salvar parámetros
# torch.save(D.state_dict(), 'C_c.pkl')
# torch.save(G.state_dict(), 'G_c.pkl')
save_checkpoint({'epoch': epoch + 1, 'state_dict':C.state_dict(), 'optimizer' : C_opt.state_dict()}, 'D_w.pth.tar')
save_checkpoint({'epoch': epoch + 1, 'state_dict':G.state_dict(), 'optimizer' : G_opt.state_dict()}, 'G_w.pth.tar')