In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import MNIST
import numpy as np
import time, os, pathlib, random, re
from datetime import datetime
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

from models import Generator, Discriminator, initialize_weights

# Treinamento de GAN, utilizando Wasserstein Loss + Gradient Penalty

In [2]:
# TensorBoard writer shared by the rest of the notebook; events are written under this run directory.
writer = SummaryWriter(log_dir='./logs/wgan-gp-mnist')

In [3]:
# Store a short description of this run in TensorBoard.
writer.add_text(
    tag='texto_inicial',
    text_string='Este é um treinamento de Wgan utilizando o dataset MNIST e aplicando as equações de Wasserstein Loss + Gradient Penalty',
)

# Baixando o dataset

In [5]:
# Load the MNIST training split from the current directory.
# download=True only fetches the files when they are not already present under `root`,
# so the notebook also runs from a fresh checkout (download=False raised when they were missing).
dataset_mnist = MNIST(root='./', train=True, download=True)

In [7]:
# --- Run configuration ---

# Network / data geometry
CHANNEL_NOISE = 1
NOISE_DIM = 100   # size of the latent vector produced by get_noise
IMG_CHANNEL = 1   # number of image channels given to the networks
IMG_SIZE = 64     # side length the images are resized to
FEATURES = 16     # passed as `features` to Generator and Discriminator

# Optimisation
BATCH_SIZE = 128
LEARNING_RATE = 1e-4
TAXA_TREINAMENTO_DISCRIMINATOR = 5  # discriminator updates performed per generator update
LAMBDA_GP = 10                      # weight of the gradient-penalty term in the discriminator loss

# Checkpoint directory
MODELS_DIR = './models'

# Use the first CUDA device when available, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print (f'{device=}')

device=device(type='cuda', index=0)


In [8]:
# Resize to IMG_SIZE and then shift/scale by mean 0.5 and std 0.5 (maps [0, 1] onto [-1, 1]).
transformer = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.Normalize(mean=0.5, std=0.5),
])

# Take the raw image tensor and add a channel axis so the networks receive (N, 1, H, W).
dataset = dataset_mnist.data
N, H, W = dataset.size()
dataset = dataset.unsqueeze(1)

# Scale to [0, 1] before applying the transforms above.
dataset = dataset / 255.
dataset = transformer(dataset)

print (f'{dataset.min()=}, {dataset.max()=}')
print (f'Novo shape: {dataset.shape=}')

dataset.min()=tensor(-1.), dataset.max()=tensor(1.)
Novo shape: dataset.shape=torch.Size([60000, 1, 64, 64])


In [9]:
# Send one shuffled batch of real images to TensorBoard as a 16-column grid.
dataloader = DataLoader(dataset, shuffle=True, batch_size=128)
imgs_tensor = next(iter(dataloader))
grid = make_grid(
    imgs_tensor,
    nrow=16,
    padding=0,
    normalize=True,
)
writer.add_image(tag='real_imgs', img_tensor=grid)

# Funções úteis

In [11]:
def get_noise(b_size, device_ = torch.device('cpu')):
    """Sample a batch of latent vectors from a standard normal distribution.

    Args:
        b_size: number of noise vectors to draw.
        device_: device on which the tensor is created (CPU by default).

    Returns:
        A tensor of shape (b_size, NOISE_DIM, 1, 1).
    """
    noise_shape = (b_size, NOISE_DIM, 1, 1)
    return torch.randn(noise_shape, device=device_)

def list_models():
    """List the saved checkpoints, highest epoch first.

    Scans MODELS_DIR for ``*.pt`` files and reads the epoch from the first run of
    digits in each file *name*. Only the name is inspected (not the whole path), so
    digits in a directory name are never mistaken for the epoch. Files whose name
    contains no digits are skipped.

    Returns:
        A list of ``[path_as_str, epoch]`` pairs sorted by epoch in descending order,
        or ``None`` when no checkpoint is found.
    """
    found = []
    for model_path in pathlib.Path(MODELS_DIR).glob('*.pt'):
        match = re.search(r'[0-9]+', model_path.stem)
        if match is None:
            # Not a checkpoint written with an epoch number in its name; ignore it.
            continue
        found.append([str(model_path), int(match.group())])

    if not found:
        return None

    # Newest (largest epoch) first.
    found.sort(key=lambda item: item[1], reverse=True)
    return found

def carregar_treinamento():
    """Build the generator and discriminator, resuming from the newest checkpoint if any.

    When list_models() reports at least one checkpoint, the first entry (highest
    epoch) is loaded onto the CPU and both state dicts plus the stored fixed noise
    are restored. Otherwise both networks get `initialize_weights` applied, a new
    fixed noise batch of 64 vectors is sampled and the epoch is reported as 0.

    Returns:
        Tuple ``(generator, discriminator, fixed_noise, last_epoch)``.
    """
    checkpoints = list_models()

    print ('Criando os modelos...')
    generator_ = Generator(channel_noise=NOISE_DIM, channel_img=IMG_CHANNEL, features=FEATURES)
    discriminator_ = Discriminator(channels_img=IMG_CHANNEL, img_size=64, features=FEATURES)

    # Fresh start: no checkpoint available.
    if not checkpoints:
        print ('Iniciando os pesos dos modelos.')
        generator_.apply(initialize_weights)
        discriminator_.apply(initialize_weights)
        return generator_, discriminator_, get_noise(64), 0

    # Resume: the list is sorted with the highest epoch first.
    last_checkpoint_path, last_epoch_ = checkpoints[0]
    cpu = torch.device('cpu')
    last_checkpoint = torch.load(last_checkpoint_path, map_location=cpu)

    print (f'Carregando o último treinamento. {last_epoch_=}')
    generator_status = generator_.load_state_dict(last_checkpoint['generator_state_dict'])
    print ('generator: ', generator_status)
    discriminator_status = discriminator_.load_state_dict(last_checkpoint['discriminator_state_dict'])
    print ('discriminator: ', discriminator_status)

    return generator_, discriminator_, last_checkpoint['fixed_noise'], last_epoch_

def salvar_modelos(model_g, model_d, fixed_noise, epoch_):
    """Save a training checkpoint for the given epoch under MODELS_DIR.

    The checkpoint holds both state dicts and the fixed noise batch, and is written
    to ``MODELS_DIR/checkpoin_<epoch padded to 4 digits>.pt``.

    Args:
        model_g: generator whose ``state_dict()`` is stored.
        model_d: discriminator whose ``state_dict()`` is stored.
        fixed_noise: noise tensor stored alongside the weights.
        epoch_: epoch number used in the file name.
    """
    # Create the target directory on first use; otherwise the save would fail
    # only after a whole epoch of training has already been spent.
    os.makedirs(MODELS_DIR, exist_ok=True)

    checkpoint = {
        'generator_state_dict': model_g.state_dict(),
        'discriminator_state_dict': model_d.state_dict(),
        'fixed_noise': fixed_noise
    }
    torch.save(checkpoint, MODELS_DIR + f'/checkpoin_{str(epoch_).zfill(4)}.pt')

def manter_somente_n_ultimos_modelos(n_ultimos):
    """Delete every checkpoint except the `n_ultimos` with the highest epoch.

    Relies on list_models() returning the checkpoints sorted from the highest
    epoch to the lowest. Removal is best-effort: a failure to delete one file is
    reported and the remaining files are still processed.

    Args:
        n_ultimos: number of most recent checkpoints to keep.
    """
    # list_models() returns None when there is nothing saved yet.
    lista = list_models() or []
    for checkpoint_path, _ in lista[n_ultimos:]:
        try:
            os.remove(checkpoint_path)
        except FileNotFoundError:
            print (checkpoint_path + ' já não existia...')
        except OSError as err:
            # Any other OS-level failure (permissions, file in use, ...) is reported as such.
            print (f'Não foi possível remover {checkpoint_path}: {err}')

def gradient_penalty(model_d, real_imgs, fake_imgs, device_):
    """Compute the gradient penalty term of the WGAN-GP discriminator loss.

    Random per-sample interpolations between real and fake images are scored by
    `model_d`; the penalty is the mean squared deviation from 1 of the L2 norm of
    the score's gradient with respect to each interpolated image.

    Args:
        model_d: discriminator, called on the interpolated images.
        real_imgs: batch of real images; the first axis is the batch axis.
        fake_imgs: batch of generated images with the same shape as `real_imgs`.
        device_: device on which the interpolation coefficients are created.

    Returns:
        Scalar tensor with the penalty. It stays attached to the graph of
        `model_d`, so it can be back-propagated into the discriminator parameters.
    """
    b_size = real_imgs.shape[0]
    # One coefficient per sample, created directly on the target device and
    # broadcast over the remaining axes.
    alpha_shape = (b_size,) + (1,) * (real_imgs.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device_)

    # The interpolation is made an independent leaf that requires grad:
    #  - autograd.grad below works even when neither input requires grad
    #    (e.g. fake images that were detached from the generator);
    #  - the penalty never back-propagates into whatever produced the inputs.
    interpolated_imgs = real_imgs.detach() * alpha + fake_imgs.detach() * (1 - alpha)
    interpolated_imgs.requires_grad_(True)

    # Discriminator score of the interpolated images.
    mixed_score = model_d(interpolated_imgs)

    # Gradient of the scores w.r.t. the interpolated images; create_graph keeps it
    # differentiable so the penalty can be part of the loss.
    gradient = torch.autograd.grad(
        inputs=interpolated_imgs,
        outputs=mixed_score,
        grad_outputs=torch.ones_like(mixed_score),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Per-sample L2 norm over all non-batch axes.
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = torch.mean((gradient_norm - 1) ** 2)

    return penalty

def checkpoint_image(model_g, fixed_noise_, epoch_):
    """Log a grid of images generated from the fixed noise to TensorBoard.

    Args:
        model_g: generator called on `fixed_noise_`.
        fixed_noise_: noise batch given to the generator.
        epoch_: epoch number used in the image tag (``fake_img_<epoch_>``).
    """
    # No gradients are needed for logging samples.
    with torch.no_grad():
        fake = model_g(fixed_noise_)
        fake_grid = make_grid(fake, nrow=8, padding=0, normalize=True)
        # Use the epoch passed by the caller, not a global loop variable.
        writer.add_image(f'fake_img_{epoch_}', fake_grid)

In [12]:
# Smoke test: push a small noise batch through both networks and print the output shapes.
generator = Generator(channel_noise=NOISE_DIM, channel_img=IMG_CHANNEL, features=FEATURES)
noise = get_noise(4)
output_generator = generator(noise)
print (f'{output_generator.shape=}')
# Note that the generator output is a 64x64 image, which must match the discriminator's input size.

discriminator = Discriminator(channels_img=IMG_CHANNEL, img_size=64, features=FEATURES)
output_discriminator = discriminator(output_generator)
print (f'{output_discriminator.shape=}')

# Check that the weight initialisation can be applied to both networks.
generator.apply(initialize_weights)
discriminator.apply(initialize_weights)

print ('ok')

output_generator.shape=torch.Size([4, 1, 64, 64])
output_discriminator.shape=torch.Size([4, 1, 1, 1])
ok


# Preparando para treinamento

In [15]:
# Shuffled mini-batches of BATCH_SIZE images for training.
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)
print (f'{len(dataloader)=}')

# Build the networks (resuming from the newest checkpoint when there is one)
# and move everything used during training to the selected device.
generator, discriminator, fixed_noise, last_epoch = carregar_treinamento()
generator = generator.to(device)
discriminator = discriminator.to(device)
fixed_noise = fixed_noise.to(device)

# One Adam optimiser per network, with identical settings.
optim_generator, optim_discriminator = (
    torch.optim.Adam(net.parameters(), lr=LEARNING_RATE, betas=(0., 0.9))
    for net in (generator, discriminator)
)

len(dataloader)=469
Criando os modelos...
Iniciando os pesos dos modelos.


# Treinamento

In [16]:
discriminator.train()
generator.train()

# NOTE(review): when training is resumed, `last_epoch` is the epoch number of the loaded
# checkpoint, so this range runs that epoch number again and overwrites its checkpoint
# and scalars. Confirm whether resuming should start at `last_epoch + 1`.
for epoch in range(last_epoch, last_epoch + 15 + 1, 1):

    for real_imgs in tqdm(dataloader):
        real_imgs = real_imgs.to(device)
        b_size = len(real_imgs)

        # --- Discriminator: TAXA_TREINAMENTO_DISCRIMINATOR updates per batch ---
        for _ in range(TAXA_TREINAMENTO_DISCRIMINATOR):
            noise = get_noise(b_size, device)

            # The generator is not updated in this phase, so no graph through it is
            # built; this removes the need for retain_graph and for back-propagating
            # through the generator on every discriminator step.
            with torch.no_grad():
                fake_imgs = generator(noise)
            # Make the fakes a leaf that requires grad so the gradient penalty can be
            # differentiated with respect to images interpolated from them.
            fake_imgs.requires_grad_(True)

            output_real = discriminator(real_imgs)
            output_fake = discriminator(fake_imgs)

            # Wasserstein loss for the discriminator plus the weighted gradient penalty.
            gp = gradient_penalty(discriminator, real_imgs, fake_imgs, device)
            loss_discriminator = -(torch.mean(output_real.view(-1)) - torch.mean(output_fake.view(-1))) + LAMBDA_GP*gp
            discriminator.zero_grad()
            loss_discriminator.backward()
            optim_discriminator.step()

        # --- Generator: one update on a fresh batch of fakes ---
        noise = get_noise(b_size, device)
        fake_imgs = generator(noise)
        output_fake_for_generator = discriminator(fake_imgs)
        loss_generator = -torch.mean(output_fake_for_generator.view(-1))
        generator.zero_grad()
        loss_generator.backward()
        optim_generator.step()

    # Log the losses of the last batch of the epoch to TensorBoard.
    writer.add_scalar('loss_discriminator', loss_discriminator.item(), epoch)
    writer.add_scalar('loss_generator', loss_generator.item(), epoch)

    # Save a checkpoint, keep only the three most recent ones, and log sample images.
    salvar_modelos(generator, discriminator, fixed_noise, epoch)
    manter_somente_n_ultimos_modelos(3)
    checkpoint_image(generator, fixed_noise, epoch)

    # Make this epoch's events visible on disk without waiting for the periodic flush.
    writer.flush()

100%|██████████| 469/469 [02:09<00:00,  3.63it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
100%|██████████| 469/469 [02:09<00:00,  3.61it/s]
