# Librerías

In [1]:
# !pip install tensorboardX

In [2]:
import time
import argparse
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid, save_image
import torchvision.datasets as dtst

import time
from tensorboardX import SummaryWriter
from tqdm import tqdm
from copy import deepcopy
import random
import os
import sys
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Retiramos los mensajes de alerta
import warnings
warnings.filterwarnings('ignore')

In [3]:
import matplotlib.pyplot as plt

# Configuraciones para la tarjeta de video

In [4]:
# El sistema selecciona el hardware disponible
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

cuda:0


In [5]:
# Verificamos qué Hardware estamos usando
if device == "cuda:0":
    print(torch.cuda.get_device_properties(device))

_CudaDeviceProperties(name='NVIDIA GeForce RTX 3060 Ti', major=8, minor=6, total_memory=8191MB, multi_processor_count=38)


# Definimos el DataLoader

In [6]:
data_folder_complete = './Images/original/'
data_folder_pixeled = './Images/pixeled/'
# Tamaño que deseamos que tengan las imágenes
image_size = 32
# Tamaño del lote de imágenes
batch_size = 16


dsimgs_1 = dtst.ImageFolder(
    root=data_folder_complete,
    transform=transforms.Compose([
        # Se usa el resize en caso no todas las imágenes de 
        # entrada tengan el tamaño de 32px
        transforms.Resize(image_size),
        # CenterCrop busca recortar la imagen en caso sea muy grande al tamaño dado
        transforms.CenterCrop(image_size),
        # ToTensor convierte finalmente la imagen a tensor
        transforms.ToTensor(),
        # Normalize permite la normalización de la información
        # El problema encontrado es que necesitamos hallar la desviación estandar
        # media de toda la información para realizar una correcta normalización
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]))

dsimgs_2 = dtst.ImageFolder(
    root=data_folder_pixeled,
    transform=transforms.Compose([
        # Se usa el resize en caso no todas las imágenes de 
        # entrada tengan el tamaño de 32px
        transforms.Resize(image_size),
        # CenterCrop busca recortar la imagen en caso sea muy grande al tamaño dado
        transforms.CenterCrop(image_size),
        # ToTensor convierte finalmente la imagen a tensor
        transforms.ToTensor(),
        # Normalize permite la normalización de la información
        # Esta desviación estandar y media es la hallada para los datos de
        # entrenamiento
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]))


In [7]:
dt_loader_real = DataLoader(dsimgs_1,
                       batch_size=batch_size,
                       shuffle=False,
                       num_workers=2,
                       drop_last=True)

dt_loader_gen = DataLoader(dsimgs_2,
                       batch_size=batch_size,
                       shuffle=False,
                       num_workers=2,
                       drop_last=True)

# Visualizador de grupo de imágenes

In [8]:
def show_tensor_images(image_tensor, num_images=16, size=(3, 128, 128)):
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=4)
    # image_grid.permute.squeeze se encargará de convertir el tensor de 3 canales(RGB)
    # en una sola imagen de 1 canal
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

# Puntos de guardado

In [9]:
def save_checkpoint(states, is_best, output_dir,
                    filename='checkpoint.pth'):
    torch.save(states, os.path.join(output_dir, filename))
    if is_best:
        torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))

# Ratio de aprendizaje variable

In [10]:
class LinearLrDecay(object):
    def __init__(self, optimizer, start_lr, end_lr, decay_start_step, decay_end_step):

        assert start_lr > end_lr
        self.optimizer = optimizer
        self.delta = (start_lr - end_lr) / (decay_end_step - decay_start_step)
        self.decay_start_step = decay_start_step
        self.decay_end_step = decay_end_step
        self.start_lr = start_lr
        self.end_lr = end_lr

    def step(self, current_step):
        if current_step <= self.decay_start_step:
            lr = self.start_lr
        elif current_step >= self.decay_end_step:
            lr = self.end_lr
        else:
            lr = self.start_lr - self.delta * (current_step - self.decay_start_step)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
        return lr

# Inicialización de pesos, ruido y generación de ruido

In [11]:
def inits_weight(m):
        if type(m) == nn.Linear:
                nn.init.xavier_uniform_(m.weight.data, 1.)


def noise(imgs, latent_dim):
        return torch.FloatTensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))

def gener_noise(gener_batch_size, latent_dim):
        return torch.FloatTensor(np.random.normal(0, 1, (gener_batch_size, latent_dim)))

# Capa MLP

In [12]:
class MLP(nn.Module):
    def __init__(self, in_feat, hid_feat=None, out_feat=None,
                 dropout=0.):
        super().__init__()
        if not hid_feat:
            hid_feat = in_feat
        if not out_feat:
            out_feat = in_feat
        self.fc1 = nn.Linear(in_feat, hid_feat)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hid_feat, out_feat)
        self.droprateout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return self.droprateout(x)

# Bloque de atención

In [13]:
class Attention(nn.Module):
    def __init__(self, dim, heads=4, attention_dropout=0., proj_dropout=0.):
        super().__init__()
        self.heads = heads
        self.scale = 1./dim**0.5

        self.qkv = nn.Linear(dim, dim*3, bias=False)
        self.attention_dropout = nn.Dropout(attention_dropout)
        self.out = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Dropout(proj_dropout)
        )

    def forward(self, x):
        b, n, c = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.heads, c//self.heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)

        dot = (q @ k.transpose(-2, -1)) * self.scale
        attn = dot.softmax(dim=-1)
        attn = self.attention_dropout(attn)

        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
        x = self.out(x)
        return x

# Tratamiento de imágenes

In [14]:
class ImgPatches(nn.Module):
    def __init__(self, input_channel=3, dim=768, patch_size=4):
        super().__init__()
        self.patch_embed = nn.Conv2d(input_channel, dim,
                                     kernel_size=patch_size, stride=patch_size)

    def forward(self, img):
        patches = self.patch_embed(img).flatten(2).transpose(1, 2)
        return patches


# Proceso de UpSampling

In [15]:
def UpSampling(x, H, W):
        B, N, C = x.size()
        assert N == H*W
        x = x.permute(0, 2, 1)
        x = x.view(-1, C, H, W)
        x = nn.PixelShuffle(2)(x)
        B, C, H, W = x.size()
        x = x.view(-1, C, H*W)
        x = x.permute(0,2,1)
        return x, H, W

# Bloque de Codificación

In [16]:
class Encoder_Block(nn.Module):
    def __init__(self, dim, heads, mlp_ratio=4, drop_rate=0.):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads, drop_rate, drop_rate)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, dim*mlp_ratio, dropout=drop_rate)

    def forward(self, x):
        x1 = self.ln1(x)
        x = x + self.attn(x1)
        x2 = self.ln2(x)
        x = x + self.mlp(x2)
        return x

# Sección Codificadora del Transformer

In [17]:
class TransformerEncoder(nn.Module):
    def __init__(self, depth, dim, heads, mlp_ratio=4, drop_rate=0.):
        super().__init__()
        self.Encoder_Blocks = nn.ModuleList([
            Encoder_Block(dim, heads, mlp_ratio, drop_rate)
            for i in range(depth)])

    def forward(self, x):
        for Encoder_Block in self.Encoder_Blocks:
            x = Encoder_Block(x)
        return x

# Differentiable Augmentation for Data-Efficient GAN Training

 https://arxiv.org/pdf/2006.10738

In [18]:
def DiffAugment(x, policy='', channels_first=True):
    if policy:
        if not channels_first:
            x = x.permute(0, 3, 1, 2)
        for p in policy.split(','):
            for f in AUGMENT_FNS[p]:
                x = f(x)
        if not channels_first:
            x = x.permute(0, 2, 3, 1)
        x = x.contiguous()
    return x


def rand_brightness(x):
    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
    return x


def rand_saturation(x):
    x_mean = x.mean(dim=1, keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
    return x


def rand_contrast(x):
    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
    return x


def rand_translation(x, ratio=0.2):
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
    )
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x


def rand_cutout(x, ratio=0.5):
    if random.random() < 0.3:
        cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
        offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
        offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
        grid_batch, grid_x, grid_y = torch.meshgrid(
            torch.arange(x.size(0), dtype=torch.long, device=x.device),
            torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
            torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
        )
        grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
        grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
        mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
        mask[grid_batch, grid_x, grid_y] = 0
        x = x * mask.unsqueeze(1)
    return x

def rand_rotate(x, ratio=0.5):
    k = random.randint(1,3)
    if random.random() < ratio:
        x = torch.rot90(x, k, [2,3])
    return x

AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
    'rotate': [rand_rotate],
}

# Modelo del generador

In [19]:
# class Generator(nn.Module):
    
#     def __init__(self, depth1=5, depth2=4, depth3=2, initial_size=8, dim=384, heads=4, mlp_ratio=4, drop_rate=0.):
# #     def __init__(self, depth1=4, depth2=3, depth3=2, initial_size=8, dim=384, heads=4, mlp_ratio=4, drop_rate=0.):
#         super(Generator, self).__init__()

#         self.initial_size = initial_size
#         self.dim = dim
#         self.depth1 = depth1
#         self.depth2 = depth2
#         self.depth3 = depth3
#         self.heads = heads
#         self.mlp_ratio = mlp_ratio
#         self.droprate_rate = drop_rate

# #         1024 == 32 x 32 (tamaño deseado de las imágenes)
# #         self.mlp = nn.Linear(1024, (self.initial_size ** 2) * self.dim)
#         self.mlp = nn.Linear(3072, (self.initial_size ** 2) * self.dim)

#         self.positional_embedding_1 = nn.Parameter(torch.zeros(1, (8**2), dim))
#         self.positional_embedding_2 = nn.Parameter(
#             torch.zeros(1, (8*2)**2, dim//4))
#         self.positional_embedding_3 = nn.Parameter(
#             torch.zeros(1, (8*4)**2, dim//16))

#         self.TransformerEncoder_encoder1 = TransformerEncoder(depth=self.depth1,
#                                                               dim=self.dim,
#                                                               heads=self.heads,
#                                                               mlp_ratio=self.mlp_ratio,
#                                                               drop_rate=self.droprate_rate)

#         self.TransformerEncoder_encoder2 = TransformerEncoder(depth=self.depth2,
#                                                               dim=self.dim//4,
#                                                               heads=self.heads,
#                                                               mlp_ratio=self.mlp_ratio,
#                                                               drop_rate=self.droprate_rate)

#         self.TransformerEncoder_encoder3 = TransformerEncoder(depth=self.depth3,
#                                                               dim=self.dim//16,
#                                                               heads=self.heads,
#                                                               mlp_ratio=self.mlp_ratio,
#                                                               drop_rate=self.droprate_rate)

#         self.linear = nn.Sequential(nn.Conv2d(self.dim//16, 3, 1, 1, 0))

#     def forward(self, noise):
# #         print('##### GENERADOR #####\n')
# #         print('>> noise.size():', noise.size())

#         x = self.mlp(noise)
# #         print('>> x after MLP size:', x.size())
#         x = x.view(-1, self.initial_size ** 2, self.dim)

# #         print('>> x after view size :', x.size())

#         x = x + self.positional_embedding_1
# #         print('>> x+embedding:', x.size())
#         H, W = self.initial_size, self.initial_size
# #         print('H, W:', H, W)
#         x = self.TransformerEncoder_encoder1(x)
# #         print('>> X after first encoder:', x.size())
#         x, H, W = UpSampling(x, H, W)

# #         print('>> X after UPSAMPLING encoder:', x.size())
#         x = x + self.positional_embedding_2

#         x = self.TransformerEncoder_encoder2(x)

# #         print('>> X after second encoder:', x.size())

#         x, H, W = UpSampling(x, H, W)
# #         print('>> X after 2ND UPSAMPLING encoder:', x.size())
#         x = x + self.positional_embedding_3

#         x = self.TransformerEncoder_encoder3(x)
# #         print('>> X after third encoder:', x.size())
#         x = self.linear(x.permute(0, 2, 1).view(-1, self.dim//16, H, W))

# #         print('>> X after linear layer:', x.size())
# #         print('\n#################\n')
#         return x

In [20]:
class Generator_2(nn.Module):
    
    def __init__(self, depth1=5, depth2=4, depth3=2, initial_size=8, dim=384, heads=4, mlp_ratio=4, drop_rate=0., patch_size = 4, input_channel=3, diff_aug  = "translation,cutout,color"):
#     def __init__(self, depth1=4, depth2=3, depth3=2, initial_size=8, dim=384, heads=4, mlp_ratio=4, drop_rate=0.):
        super(Generator_2, self).__init__()

        self.initial_size = initial_size
        self.dim = dim
        self.depth1 = depth1
        self.depth2 = depth2
        self.depth3 = depth3
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        ##
        self.diff_aug = diff_aug
        self.patch_size = patch_size
        ##
        self.droprate_rate = drop_rate

#         1024 == 32 x 32 (tamaño deseado de las imágenes)
#         self.mlp = nn.Linear(1024, (self.initial_size ** 2) * self.dim)
#         self.mlp = nn.Linear(3072, (self.initial_size ** 2) * self.dim)
        # PLUS╦
        self.class_embedding = nn.Parameter(torch.zeros(1, 1, dim))
        # Image patches and embedding layer
        self.patches = ImgPatches(input_channel, dim, self.patch_size)
#         self.positional_embedding = nn.Parameter(torch.zeros(1, num_patches+1, dim))
        # PLUS╩
        self.positional_embedding_1 = nn.Parameter(torch.zeros(1, (8**2), dim))
        self.positional_embedding_2 = nn.Parameter(
            torch.zeros(1, (8*2)**2, dim//4))
        self.positional_embedding_3 = nn.Parameter(
            torch.zeros(1, (8*4)**2, dim//16))

        self.TransformerEncoder_encoder1 = TransformerEncoder(depth=self.depth1,
                                                              dim=self.dim,
                                                              heads=self.heads,
                                                              mlp_ratio=self.mlp_ratio,
                                                              drop_rate=self.droprate_rate)

        self.TransformerEncoder_encoder2 = TransformerEncoder(depth=self.depth2,
                                                              dim=self.dim//4,
                                                              heads=self.heads,
                                                              mlp_ratio=self.mlp_ratio,
                                                              drop_rate=self.droprate_rate)

        self.TransformerEncoder_encoder3 = TransformerEncoder(depth=self.depth3,
                                                              dim=self.dim//16,
                                                              heads=self.heads,
                                                              mlp_ratio=self.mlp_ratio,
                                                              drop_rate=self.droprate_rate)

        self.linear = nn.Sequential(nn.Conv2d(self.dim//16, 3, 1, 1, 0))

    def forward(self, noise):
#         x = self.mlp(noise)
#         x = x.view(-1, self.initial_size ** 2, self.dim)
#         x = x + self.positional_embedding_1

    # PLUS╦
        x = DiffAugment(noise, self.diff_aug)
        x = self.patches(x)
        x += self.positional_embedding_1
    # PLUS╩    
#         x = x + self.positional_embedding_1
        H, W = self.initial_size, self.initial_size
        x = self.TransformerEncoder_encoder1(x)
        x, H, W = UpSampling(x, H, W)
        x = x + self.positional_embedding_2
        x = self.TransformerEncoder_encoder2(x)
        x, H, W = UpSampling(x, H, W)
        x = x + self.positional_embedding_3
        x = self.TransformerEncoder_encoder3(x)
        x = self.linear(x.permute(0, 2, 1).view(-1, self.dim//16, H, W))
        return x

# Modelo Discriminador (basado en Transformer)

In [21]:
class Discriminator(nn.Module):
    def __init__(self, diff_aug, image_size=32, patch_size=4,
                 input_channel=3, num_classes=1,
                 dim=384, depth=7, heads=4, mlp_ratio=4,
                 drop_rate=0.):
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError('Error en la dimensión de la imagen.')
        num_patches = (image_size//patch_size) ** 2
        self.diff_aug = diff_aug
        self.patch_size = patch_size
        self.depth = depth
        # Image patches and embedding layer
        self.patches = ImgPatches(input_channel, dim, self.patch_size)

        # Embedding for patch position and class
        self.positional_embedding = nn.Parameter(torch.zeros(1,
                                                             num_patches+1,
                                                             dim))
        self.class_embedding = nn.Parameter(torch.zeros(1, 1, dim))
        nn.init.trunc_normal_(self.positional_embedding, std=0.2)
        nn.init.trunc_normal_(self.class_embedding, std=0.2)

        self.droprate = nn.Dropout(p=drop_rate)
        self.TransfomerEncoder = TransformerEncoder(depth, dim, heads,
                                                    mlp_ratio, drop_rate)
        self.norm = nn.LayerNorm(dim)
        self.out = nn.Linear(dim, num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = DiffAugment(x, self.diff_aug)
        b = x.shape[0]
        cls_token = self.class_embedding.expand(b, -1, -1)
        x = self.patches(x)
        x = torch.cat((cls_token, x), dim=1)
        x += self.positional_embedding
        x = self.droprate(x)
        x = self.TransfomerEncoder(x)
        x = self.norm(x)
        x = self.out(x[:, 0])
        return x

# Modelo Discriminador (basado en Convolucion)

# Seteo de variables

In [22]:
image_size = 32 # Size of image for discriminator input
initial_size = 8 #Initial size for generator.
patch_size = 4  # Patch size for generated image.
num_classes = 1  # Number of classes for discriminator.
lr_gen = 0.0001  # Learning rate for generator.
lr_dis = 0.0001  # Learning rate for discriminator.
weight_decay = 1e-3  # Weight decay.
latent_dim = 1024  # Latent dimension.
n_critic = 5  # n_critic.
max_iter = 400000  # max_iter.
gener_batch_size = 32  # Batch size for generator.
dis_batch_size = 32  # Batch size for discriminator.
epoch = 50  # Number of epoch.
output_dir = 'checkpoint'  # Checkpoint.
dim = 384  # Embedding dimension.
#dim = 256
img_name = "img_name"  # Name of pictures file.
# loss = "available"  # Loss function
loss = None
phi = 1  # phi
beta1 = 0 #beta1
beta2 = 0.99 # beta2
lr_decay = True # lr_decay
diff_aug = "translation,cutout,color" # help='Data Augmentation
best = 1e4  # Best lr for Adam

# Visualizador de datos por paso

In [23]:
writer=SummaryWriter()
writer_dict = {'writer':writer}
writer_dict["train_global_steps"]=0
writer_dict["valid_global_steps"]=0

# Cálculo de penalidades

In [24]:
def compute_gradient_penalty(D, real_samples, fake_samples, phi):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = torch.Tensor(np.random.random((real_samples.size(0), 1, 1, 1))).to(real_samples.get_device())
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = torch.ones([real_samples.shape[0], 1], requires_grad=False).to(real_samples.get_device())
    # Get gradient w.r.t. interpolates
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.contiguous().view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - phi) ** 2).mean()
    return gradient_penalty

# Función de entrenamiento

In [25]:
# def train(noise, generator, discriminator, optim_gen, optim_dis,
#           epoch, writer, schedulers, img_size=32, latent_dim=latent_dim,
#           n_critic=n_critic,
#           gener_batch_size=gener_batch_size, device="cuda:0"):

#     writer = writer_dict['writer']
#     gen_step = 0

#     generator = generator.train()
#     discriminator = discriminator.train()

#     ######
# #     input("Retransformando las imágenes - Press enter to continue")
#     #####
#     transform = transforms.Compose([transforms.Resize(size=(img_size, img_size)), transforms.RandomHorizontalFlip(
#     ), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#     ######
# #     input("Inicializas una nueva variablea para dt_loader_real - Press enter to continue")
#     #####
#     train_loader = dt_loader_real

#     for index, (img, _) in enumerate(train_loader):

#         global_steps = writer_dict['train_global_steps']

#         ######
# #         input("convertimos imágenes a tensores float - Press enter to continue")
#         #####
#         real_imgs = img.type(torch.cuda.FloatTensor)

#         ######
# #         input("Ruido 1 - Press enter to continue")
#         #####
#         noise = torch.cuda.FloatTensor(
#             np.random.normal(0, 1, (img.shape[0], latent_dim)))
#         ######
# #         input("optim_dis.zero_grad() - Press enter to continue")
#         #####
#         optim_dis.zero_grad()
#         ######
# #         input("discriminator(real_imgs) - Press enter to continue + 1.0GB")
#         #####
#         real_valid = discriminator(real_imgs)
#         ######
# #         input("generator(noise).detach() - Press enter to continue + 1.9GB")
#         #####
#         fake_imgs = generator(noise).detach()
#         ######
# #         input("discriminator(fake_imgs) - Press enter to continue + 0.1GB")
#         #####
#         fake_valid = discriminator(fake_imgs)

#         if loss == 'hinge':
#             loss_dis = torch.mean(nn.ReLU(inplace=True)(1.0 - real_valid)).to(
#                 device) + torch.mean(nn.ReLU(inplace=True)(1 + fake_valid)).to(device)
#         elif loss == 'wgangp_eps':
#             gradient_penalty = compute_gradient_penalty(
#                 discriminator, real_imgs, fake_imgs.detach(), phi)
#             loss_dis = -torch.mean(real_valid) + torch.mean(fake_valid) + \
#                 gradient_penalty * 10 / (phi ** 2)

#         loss_dis.backward()
#         optim_dis.step()

#         writer.add_scalar("loss_dis", loss_dis.item(), global_steps)

#         if global_steps % n_critic == 0:

#             optim_gen.zero_grad()
#             if schedulers:
#                 gen_scheduler, dis_scheduler = schedulers
#                 g_lr = gen_scheduler.step(global_steps)
#                 d_lr = dis_scheduler.step(global_steps)
#                 writer.add_scalar('LR/g_lr', g_lr, global_steps)
#                 writer.add_scalar('LR/d_lr', d_lr, global_steps)

#             ######
# #             input("Generating noise tensors - Press enter to continue + 0.1GB")
#             #####
#             gener_noise = torch.cuda.FloatTensor(
#                 np.random.normal(0, 1, (gener_batch_size, latent_dim)))

#             ######
# #             input("generator(gener_noise) - Press enter to continue + 0.0GB")
#             #####
#             generated_imgs = generator(gener_noise)
#             ######
# #             input("Press enter to continue")
#             #####
#             fake_valid = discriminator(generated_imgs)

#             gener_loss = -torch.mean(fake_valid).to(device)
#             gener_loss.backward()
#             optim_gen.step()
#             writer.add_scalar("gener_loss", gener_loss.item(), global_steps)

#             gen_step += 1

#         if gen_step and index % 100 == 0:
#             sample_imgs = generated_imgs[:8]
#             img_grid = make_grid(sample_imgs, nrow=4,
#                                  normalize=True, scale_each=True)
#             save_image(
#                 sample_imgs, f'./generated_imgs/generated_img_{epoch}_{index % len(train_loader)}.jpg', nrow=4, normalize=True, scale_each=True)
#             tqdm.write("[Epoch %d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
#                        (epoch+1, index % len(train_loader), len(train_loader), loss_dis.item(), gener_loss.item()))

In [26]:
def train_v2(gen_dataloader, disc_dataloader,
             generator, discriminator, optim_gen, optim_dis,
             epoch, writer, schedulers, img_size=32, latent_dim=latent_dim,
             n_critic=n_critic, gener_batch_size=gener_batch_size,
             device="cuda:0"):

    writer = writer_dict['writer']
    gen_step = 0

    generator = generator.train()
    discriminator = discriminator.train()

    gen_iterator = iter(gen_dataloader)

    for index, (img, _) in enumerate(disc_dataloader):

        try:
            (data2, _) = next(gen_iterator)
        except StopIteration:
            
            gen_iterator = iter(gen_dataloader)
            (data2, _) = next(gen_iterator)

        global_steps = writer_dict['train_global_steps']

        real_imgs = img.type(torch.cuda.FloatTensor)
        
#         fake_imgs = data2.view(16,3072).type(torch.cuda.FloatTensor)        
        fake_imgs = data2.type(torch.cuda.FloatTensor)
        
        ######
#         input("optim_dis.zero_grad() - Press enter to continue")
        #####
        optim_dis.zero_grad()

        real_valid = discriminator(real_imgs)

        # obtenemos un batch de imágenes fake generadas por el
        # modelo generador sin entrenamiento
        fake_imgs = generator(fake_imgs).detach()

#         print(fake_imgs.detach().shape)
        # Hallamos la predicción del discriminador
        fake_valid = discriminator(fake_imgs)

        if loss is not None:
            loss_dis = torch.mean(nn.ReLU(inplace=True)(1.0 - real_valid)).to(
                device) + torch.mean(nn.ReLU(inplace=True)(1 + fake_valid)).to(device)
        else:
            gradient_penalty = compute_gradient_penalty(
                discriminator, real_imgs, fake_imgs.detach(), phi)
            loss_dis = -torch.mean(real_valid) + torch.mean(fake_valid) + \
                gradient_penalty * 10 / (phi ** 2)

        loss_dis.backward()
        optim_dis.step()

        writer.add_scalar("loss_dis", loss_dis.item(), global_steps)

        if global_steps % n_critic == 0:

            optim_gen.zero_grad()
            if schedulers:
                gen_scheduler, dis_scheduler = schedulers
                g_lr = gen_scheduler.step(global_steps)
                d_lr = dis_scheduler.step(global_steps)
                writer.add_scalar('LR/g_lr', g_lr, global_steps)
                writer.add_scalar('LR/d_lr', d_lr, global_steps)

            generated_imgs = generator(data2.type(torch.cuda.FloatTensor))
            
            # Obtenemos la puntuación del discriminador para la retroalimentación
            fake_valid = discriminator(generated_imgs)

            gener_loss = -torch.mean(fake_valid).to(device)
            gener_loss.backward()
            optim_gen.step()
            writer.add_scalar("gener_loss", gener_loss.item(), global_steps)

            gen_step += 1

        if gen_step and index % 250 == 0:
            sample_imgs = generated_imgs[:16]
            img_grid = make_grid(sample_imgs, nrow=4,
                                 normalize=True, scale_each=True)
            save_image(
                sample_imgs, f'./generated_imgs_2/generated_img_{epoch}_{index % len(disc_dataloader)}.png', nrow=4, normalize=True, scale_each=True)
            tqdm.write("[Epoch %d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
                       (epoch+1, index % len(disc_dataloader), len(disc_dataloader), loss_dis.item(), gener_loss.item()))

# Obtención de puntuación FID

# Función de validación

In [27]:
def validate(generator, writer_dict, fid_stat):

    writer = writer_dict['writer']
    global_steps = writer_dict['valid_global_steps']

    generator = generator.eval()
    fid_score = get_fid(fid_stat, epoch, generator, num_img=5000,
                        val_batch_size=60*2, latent_dim=1024,
                        writer_dict=None, cls_idx=None)

    print(f"FID score: {fid_score}")

    writer.add_scalar('FID_score', fid_score, global_steps)

    writer_dict['valid_global_steps'] = global_steps + 1
    return fid_score

# Pasos previos al entrenamiento

In [28]:
# generator= Generator(depth1=5, depth2=4, depth3=2, initial_size=8, dim=dim, heads=4, mlp_ratio=4, drop_rate=0.5)#,device = device)
# # generator= Generator(depth1=4, depth2=3, depth3=2, initial_size=8, dim=dim, heads=4, mlp_ratio=4, drop_rate=0.5)#,device = device)
# generator.to(device)

In [29]:
generator= Generator_2(depth1=5, depth2=4, depth3=2, initial_size=8, dim=dim, heads=4, mlp_ratio=4, drop_rate=0.5)#,device = device)
# generator= Generator(depth1=4, depth2=3, depth3=2, initial_size=8, dim=dim, heads=4, mlp_ratio=4, drop_rate=0.5)#,device = device)
generator.to(device)

Generator_2(
  (patches): ImgPatches(
    (patch_embed): Conv2d(3, 384, kernel_size=(4, 4), stride=(4, 4))
  )
  (TransformerEncoder_encoder1): TransformerEncoder(
    (Encoder_Blocks): ModuleList(
      (0): Encoder_Block(
        (ln1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=False)
          (attention_dropout): Dropout(p=0.5, inplace=False)
          (out): Sequential(
            (0): Linear(in_features=384, out_features=384, bias=True)
            (1): Dropout(p=0.5, inplace=False)
          )
        )
        (ln2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate=none)
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (droprateout): Dropout(p=0.5, inplace=False)
        )
      )
      (1): Encoder_Block(
        (

In [30]:
discriminator = Discriminator(diff_aug = "translation,cutout,color", image_size=32, patch_size=4, input_channel=3, num_classes=1,
                 dim=dim, depth=7, heads=4, mlp_ratio=4, drop_rate=0.)
discriminator.to(device)

Discriminator(
  (patches): ImgPatches(
    (patch_embed): Conv2d(3, 384, kernel_size=(4, 4), stride=(4, 4))
  )
  (droprate): Dropout(p=0.0, inplace=False)
  (TransfomerEncoder): TransformerEncoder(
    (Encoder_Blocks): ModuleList(
      (0): Encoder_Block(
        (ln1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=False)
          (attention_dropout): Dropout(p=0.0, inplace=False)
          (out): Sequential(
            (0): Linear(in_features=384, out_features=384, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (ln2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate=none)
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (droprateout): Dropout(p=0.0, inplace=False)
        )
      )

In [31]:
generator.apply(inits_weight)
discriminator.apply(inits_weight)

Discriminator(
  (patches): ImgPatches(
    (patch_embed): Conv2d(3, 384, kernel_size=(4, 4), stride=(4, 4))
  )
  (droprate): Dropout(p=0.0, inplace=False)
  (TransfomerEncoder): TransformerEncoder(
    (Encoder_Blocks): ModuleList(
      (0): Encoder_Block(
        (ln1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=False)
          (attention_dropout): Dropout(p=0.0, inplace=False)
          (out): Sequential(
            (0): Linear(in_features=384, out_features=384, bias=True)
            (1): Dropout(p=0.0, inplace=False)
          )
        )
        (ln2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate=none)
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (droprateout): Dropout(p=0.0, inplace=False)
        )
      )

In [32]:
optim_gen = optim.Adam(filter(lambda p: p.requires_grad, 
                              generator.parameters()), 
                       lr=lr_gen, 
                       betas=(beta1, beta2))

optim_dis = optim.Adam(filter(lambda p: p.requires_grad, 
                                  discriminator.parameters()),
                           lr=lr_dis, 
                           betas=(beta1, beta2))

In [33]:
gen_scheduler = LinearLrDecay(optim_gen, 
                              lr_gen, 
                              0.0, 0, 
                              max_iter * n_critic)

dis_scheduler = LinearLrDecay(optim_dis, 
                              lr_dis, 
                              0.0, 0, 
                              max_iter * n_critic)

# Entrenamiento

In [34]:
for epoch in range(epoch):

    lr_schedulers = (gen_scheduler, dis_scheduler) if lr_decay else None

#     train(noise, generator, discriminator, optim_gen, optim_dis,
#           epoch, writer, lr_schedulers, img_size=32, latent_dim=latent_dim,
#           n_critic=n_critic, gener_batch_size=gener_batch_size)
    
    train_v2(dt_loader_gen, dt_loader_real, generator, discriminator, optim_gen, optim_dis,
          epoch, writer, lr_schedulers, img_size=32, latent_dim=latent_dim,
          n_critic=n_critic, gener_batch_size=gener_batch_size)

    checkpoint = {'epoch': epoch, 'best_fid': best}
    checkpoint['generator_state_dict'] = generator.state_dict()
    checkpoint['discriminator_state_dict'] = discriminator.state_dict()
    save_checkpoint(checkpoint, is_best=True, output_dir=output_dir)
#     score = validate(generator, writer_dict, fid_stat)

#     print(f'FID score: {score} - best ID score: {best} || @ epoch {epoch+1}.')
#     if epoch == 0 or epoch > 30:
#         if score < best:
#             
#             best = score

[Epoch 1] [Batch 0/3737] [D loss: 4.928785] [G loss: 0.423911]
[Epoch 1] [Batch 250/3737] [D loss: -5.603291] [G loss: 0.846895]
[Epoch 1] [Batch 500/3737] [D loss: -3.148864] [G loss: -1.865978]
[Epoch 1] [Batch 750/3737] [D loss: -0.248189] [G loss: 2.153396]
[Epoch 1] [Batch 1000/3737] [D loss: -0.616114] [G loss: 0.722374]
[Epoch 1] [Batch 1250/3737] [D loss: -0.083865] [G loss: 0.016641]
[Epoch 1] [Batch 1500/3737] [D loss: -0.743816] [G loss: 0.572345]
[Epoch 1] [Batch 1750/3737] [D loss: -0.555105] [G loss: 1.173501]
[Epoch 1] [Batch 2000/3737] [D loss: -0.411012] [G loss: 0.874375]
[Epoch 1] [Batch 2250/3737] [D loss: -0.270281] [G loss: 1.187196]
[Epoch 1] [Batch 2500/3737] [D loss: -0.493823] [G loss: 1.818932]
[Epoch 1] [Batch 2750/3737] [D loss: -0.219764] [G loss: -0.033861]
[Epoch 1] [Batch 3000/3737] [D loss: -0.407539] [G loss: 1.039855]
[Epoch 1] [Batch 3250/3737] [D loss: -0.039022] [G loss: -0.026963]
[Epoch 1] [Batch 3500/3737] [D loss: -0.413641] [G loss: 0.596636]

[Epoch 9] [Batch 750/3737] [D loss: -0.237303] [G loss: -0.069140]
[Epoch 9] [Batch 1000/3737] [D loss: 0.182671] [G loss: 0.062264]
[Epoch 9] [Batch 1250/3737] [D loss: -0.038857] [G loss: 0.437486]
[Epoch 9] [Batch 1500/3737] [D loss: 0.275319] [G loss: -0.636727]
[Epoch 9] [Batch 1750/3737] [D loss: -0.415891] [G loss: -0.305466]
[Epoch 9] [Batch 2000/3737] [D loss: 0.040276] [G loss: -0.177399]
[Epoch 9] [Batch 2250/3737] [D loss: -0.120701] [G loss: -0.103230]
[Epoch 9] [Batch 2500/3737] [D loss: -0.173978] [G loss: -0.694968]
[Epoch 9] [Batch 2750/3737] [D loss: -0.448101] [G loss: -0.758566]
[Epoch 9] [Batch 3000/3737] [D loss: 0.094326] [G loss: -0.505736]
[Epoch 9] [Batch 3250/3737] [D loss: -0.462984] [G loss: -0.780729]
[Epoch 9] [Batch 3500/3737] [D loss: -0.004895] [G loss: -0.516706]
[Epoch 10] [Batch 0/3737] [D loss: -0.639158] [G loss: -0.427350]
[Epoch 10] [Batch 250/3737] [D loss: -0.320375] [G loss: -0.715077]
[Epoch 10] [Batch 500/3737] [D loss: -0.297208] [G loss: 

[Epoch 17] [Batch 750/3737] [D loss: -0.328045] [G loss: -1.581429]
[Epoch 17] [Batch 1000/3737] [D loss: -0.474953] [G loss: -1.531839]
[Epoch 17] [Batch 1250/3737] [D loss: -0.156620] [G loss: -1.944632]
[Epoch 17] [Batch 1500/3737] [D loss: -0.133209] [G loss: -1.206184]
[Epoch 17] [Batch 1750/3737] [D loss: -0.209734] [G loss: -0.783514]
[Epoch 17] [Batch 2000/3737] [D loss: -0.230988] [G loss: -1.435113]
[Epoch 17] [Batch 2250/3737] [D loss: -0.303032] [G loss: -1.303110]
[Epoch 17] [Batch 2500/3737] [D loss: -0.109417] [G loss: -1.085255]
[Epoch 17] [Batch 2750/3737] [D loss: 0.200756] [G loss: -1.100820]
[Epoch 17] [Batch 3000/3737] [D loss: -0.258099] [G loss: -1.577160]
[Epoch 17] [Batch 3250/3737] [D loss: -0.181556] [G loss: -1.338347]
[Epoch 17] [Batch 3500/3737] [D loss: -0.148168] [G loss: -1.479494]
[Epoch 18] [Batch 0/3737] [D loss: -0.089086] [G loss: -1.458939]
[Epoch 18] [Batch 250/3737] [D loss: -0.498616] [G loss: -1.263058]
[Epoch 18] [Batch 500/3737] [D loss: -0.

KeyboardInterrupt: 

# TEST 

In [None]:
data_folder_pixeled = './Test_images/pixeled/'

# Tamaño que deseamos que tengan las imágenes
image_size = 32
# Tamaño del lote de imágenes
batch_size = 16



pixeled_imgs_test = dtst.ImageFolder(
    root=data_folder_pixeled,
    transform=transforms.Compose([
        # Se usa el resize en caso no todas las imágenes de 
        # entrada tengan el tamaño de 32px
        transforms.Resize(image_size),
        # CenterCrop busca recortar la imagen en caso sea muy grande al tamaño dado
        transforms.CenterCrop(image_size),
        # ToTensor convierte finalmente la imagen a tensor
        transforms.ToTensor(),
        # Normalize permite la normalización de la información
        # Esta desviación estandar y media es la hallada para los datos de
        # entrenamiento
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]))



dt_loader_gen_test = DataLoader(pixeled_imgs_test,
                       batch_size=batch_size,
                       shuffle=False,
                       num_workers=1,
                       drop_last=True)

In [None]:
for index, (img, _) in enumerate(dt_loader_gen_test):   
       
    fake_imgs = img.view(16,3072).type(torch.cuda.FloatTensor)
    
    generated_imgs = generator(fake_imgs)
    
    show_tensor_images(generated_imgs,16,size=(3,32,32))
    sample_imgs = generated_imgs[:16]
#     img_grid = make_grid(sample_imgs, nrow=4,
#     normalize=True, scale_each=True)
    save_image(sample_imgs, f'./generated_imgs_test/test_img_{index % len(dt_loader_gen_test)}.jpg', nrow=4, normalize=True, scale_each=True)

In [None]:
data_folder_complete = './Test_images/original/'

complete_imgs_test = dtst.ImageFolder(
    root=data_folder_complete,
    transform=transforms.Compose([
        # Se usa el resize en caso no todas las imágenes de 
        # entrada tengan el tamaño de 32px
        transforms.Resize(image_size),
        # CenterCrop busca recortar la imagen en caso sea muy grande al tamaño dado
        transforms.CenterCrop(image_size),
        # ToTensor convierte finalmente la imagen a tensor
        transforms.ToTensor(),
        # Normalize permite la normalización de la información
        # El problema encontrado es que necesitamos hallar la desviación estandar
        # media de toda la información para realizar una correcta normalización
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]))

dt_loader_real_test = DataLoader(complete_imgs_test,
                       batch_size=batch_size,
                       shuffle=False,
                       num_workers=1,
                       drop_last=True)

for index, (img, _) in enumerate(dt_loader_real_test):   
       
    real_imgs = img.type(torch.cuda.FloatTensor)
    
    show_tensor_images(real_imgs,16,size=(3,32,32))
    sample_imgs = real_imgs[:16]

    save_image(sample_imgs, f'./generated_imgs_test/real2/test_img_{index % len(dt_loader_real_test)}.png', nrow=4, normalize=True, scale_each=True)