<a href="https://colab.research.google.com/github/fsilvao/ia/blob/main/GAN_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm


#Configuration
epochs = 100
batch_size = 64
sample_size = 100 # Number of random values to sample per generated image
g_lr = 1e-4 # Generator learning rate
d_lr = 1e-4 # Discriminator learning rate

#DataLoader for MNIST

# ToTensor turns each image into a float tensor so it can be fed to the networks.
transform = transforms.ToTensor()
dataset = datasets.MNIST('./data', train=True, transform=transform, download=True)
# drop_last=True discards the final incomplete batch, so every batch holds
# exactly batch_size images.
# NOTE(review): shuffle is left at its default, so batches come in the same
# dataset order every epoch - confirm this is intended.
dataloader = DataLoader(dataset, batch_size, drop_last=True)

#Generator network
class Generator(nn.Sequential):
  """Fully connected generator that maps random noise to 1x28x28 images.

  The stack is Linear(sample_size, 128) -> LeakyReLU(0.01) ->
  Linear(128, 784) -> Sigmoid, so every output pixel lies in [0, 1].

  Args:
    sample_size: Length of the random noise vector drawn for each image.
  """

  def __init__(self, sample_size:int):
    super().__init__(
        nn.Linear(sample_size, 128),
        nn.LeakyReLU(0.01),
        nn.Linear(128, 784),
        nn.Sigmoid()
    )
    #Size of the random noise vector
    self.sample_size = sample_size

  def forward(self, batch_size: int):
    """Sample fresh noise and return a batch of generated images.

    Unlike a regular module, the input is not a tensor: the noise is drawn
    internally and only the number of images is passed in.

    Args:
      batch_size: Number of images to generate.

    Returns:
      Tensor of shape (batch_size, 1, 28, 28).
    """
    #Draw the noise with the same device and dtype as the weights so the
    #generator keeps working after .to(device), .double(), .half(), etc.
    reference = next(self.parameters())
    z = torch.randn(
        batch_size,
        self.sample_size,
        device=reference.device,
        dtype=reference.dtype,
    )

    #Run the noise through the sequential layers
    output = super().forward(z)

    #Reshape the 784 outputs into a single-channel 28x28 image
    generated_images = output.reshape(batch_size, 1, 28, 28)
    return generated_images

#Discriminator network
class Discriminator(nn.Sequential):
  """Fully connected discriminator that returns its own training loss.

  The stack is Linear(784, 128) -> LeakyReLU(0.01) -> Linear(128, 1); the
  single output is a raw logit (no sigmoid is applied here).
  """

  def __init__(self):
    layers = (
        nn.Linear(784, 128),
        nn.LeakyReLU(0.01),
        nn.Linear(128, 1),
    )
    super().__init__(*layers)

  def forward(self, images: torch.Tensor, targets: torch.Tensor):
    """Score the images and return the loss against the given targets.

    Args:
      images: Tensor that can be reshaped to (-1, 784).
      targets: Labels passed to the loss together with the logits.

    Returns:
      The binary cross-entropy-with-logits loss between the logits and
      targets (not the logits themselves).
    """
    #Flatten every image into a 784-long row before the linear layers
    flat_images = torch.reshape(images, (-1, 784))
    logits = super().forward(flat_images)
    return F.binary_cross_entropy_with_logits(logits, targets)

#Helper that saves a batch of images as a single grid picture:

def save_image_grid(epoch:int, images: torch.Tensor, ncol:int):
  """Tile the images into a grid and write it to 'generated_<epoch>.jpg'.

  Args:
    epoch: Epoch number, zero-padded to three digits in the file name.
    images: Batch of images handed to make_grid.
    ncol: Passed as make_grid's second positional argument (images per row).
  """
  image_grid = make_grid(images, ncol) #Images arranged in the grid
  #Detach explicitly: numpy() raises on a tensor that still requires grad,
  #and callers may pass generator output that is attached to a graph.
  image_grid = image_grid.detach()
  image_grid = image_grid.permute(1, 2, 0) #Move channel last
  image_grid = image_grid.cpu().numpy() #To numpy

  plt.imshow(image_grid)
  #Hide the axis ticks so only the picture is saved
  plt.xticks([])
  plt.yticks([])
  plt.savefig(f'generated_{epoch:03d}.jpg')
  #Close the current figure once it has been written
  plt.close()


#Real and fake/generated labels
#They are built once with batch_size rows, so every batch must have exactly
#batch_size images (the dataloader above is created with drop_last=True).
real_targets = torch.ones(batch_size, 1)
fake_targets = torch.zeros(batch_size, 1)

#Instantiate the generator and discriminator networks
generator = Generator(sample_size)
discriminator = Discriminator()

#Optimizers
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr)

#Training loop
for epoch in range(epochs):
  d_losses = []
  g_losses = []

  for images, labels in tqdm(dataloader):
    #Train the discriminator
    #Loss with the MNIST images as inputs and real_targets as labels
    discriminator.train()
    d_loss = discriminator(images, real_targets)

    #Generate images in eval mode; no_grad keeps the generator out of the
    #discriminator's backward pass
    generator.eval()
    with torch.no_grad():
      generated_images = generator(batch_size)

    #Loss with the generated images as inputs and fake_targets as labels
    d_loss += discriminator(generated_images, fake_targets)

    #Optimize the discriminator parameters
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()


    #Train the generator
    generator.train()
    generated_images = generator(batch_size)

    #Loss with generated images as input and real_targets as labels:
    #the generator is rewarded when the discriminator scores them as real
    discriminator.eval()
    g_loss = discriminator(generated_images, real_targets)

    #Optimize the generator parameters.
    #g_loss.backward() also leaves gradients on the discriminator; they are
    #cleared by d_optimizer.zero_grad() before its next backward pass.
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    #Store the losses for logging
    d_losses.append(d_loss.item())
    g_losses.append(g_loss.item())

  #Print the average losses for the epoch
  print(epoch, np.mean(d_losses), np.mean(g_losses))

  #Save a grid of samples. These images are only rendered, never
  #backpropagated, so skip building an autograd graph for them.
  with torch.no_grad():
    save_image_grid(epoch, generator(batch_size), ncol=8)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 80256877.93it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 61490199.91it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 16721086.15it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 17721422.11it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw




100%|██████████| 937/937 [00:26<00:00, 35.71it/s]


0 0.3957257770645962 2.9097461256426196


100%|██████████| 937/937 [00:17<00:00, 52.85it/s]


1 0.2990843303652685 2.293939580530596


100%|██████████| 937/937 [00:18<00:00, 51.49it/s]


2 0.48986765141166516 1.5066959983892858


100%|██████████| 937/937 [00:17<00:00, 53.56it/s]


3 0.50743497702966 1.2746323718587007


100%|██████████| 937/937 [00:17<00:00, 53.05it/s]


4 0.36052650482003185 1.8540173331345158


100%|██████████| 937/937 [00:19<00:00, 48.72it/s]


5 0.5187241329772749 1.9480203298519871


100%|██████████| 937/937 [00:18<00:00, 49.98it/s]


6 0.4991274860587166 2.1030600114209674


100%|██████████| 937/937 [00:17<00:00, 53.10it/s]


7 0.43615316120130404 2.274750127481931


100%|██████████| 937/937 [00:17<00:00, 53.47it/s]


8 0.53335220000151 2.086005357566231


100%|██████████| 937/937 [00:18<00:00, 51.20it/s]


9 0.5428089505834793 1.9331011563507698


100%|██████████| 937/937 [00:17<00:00, 53.41it/s]


10 0.5699112327846926 1.9846828687407292


100%|██████████| 937/937 [00:18<00:00, 51.13it/s]


11 0.5164475905666229 1.940176420685067


100%|██████████| 937/937 [00:17<00:00, 53.99it/s]


12 0.3635348344943312 2.252374063052134


100%|██████████| 937/937 [00:18<00:00, 50.83it/s]


13 0.36613288024001084 2.260573486762724


100%|██████████| 937/937 [00:17<00:00, 52.23it/s]


14 0.35778748095353613 2.3760571011611464


100%|██████████| 937/937 [00:17<00:00, 53.46it/s]


15 0.39131393780576 2.394910442918189


100%|██████████| 937/937 [00:18<00:00, 51.21it/s]


16 0.4492986014424928 2.385027566994267


100%|██████████| 937/937 [00:17<00:00, 53.42it/s]


17 0.5398628793760196 2.1963787405824204


100%|██████████| 937/937 [00:18<00:00, 50.86it/s]


18 0.3795446370397332 2.4830589099780216


100%|██████████| 937/937 [00:17<00:00, 53.43it/s]


19 0.46529788219368956 2.185509528141775


100%|██████████| 937/937 [00:18<00:00, 51.35it/s]


20 0.36193850103729946 2.4762862091257833


100%|██████████| 937/937 [00:17<00:00, 53.65it/s]


21 0.39148300725533336 2.5074032391148164


100%|██████████| 937/937 [00:19<00:00, 49.10it/s]


22 0.40155789928795055 2.4215353981789556


100%|██████████| 937/937 [00:17<00:00, 53.23it/s]


23 0.38587597875260493 2.61573077571303


100%|██████████| 937/937 [00:17<00:00, 52.30it/s]


24 0.45697890443125044 2.3401652976568377


100%|██████████| 937/937 [00:18<00:00, 51.90it/s]


25 0.40863803669953425 2.5756076810963


100%|██████████| 937/937 [00:17<00:00, 53.32it/s]


26 0.3124223267669484 3.0977882639193868


100%|██████████| 937/937 [00:18<00:00, 51.21it/s]


27 0.6124096014233765 2.217941043598293


100%|██████████| 937/937 [00:17<00:00, 53.10it/s]


28 0.357112608619408 2.8324131342237506


100%|██████████| 937/937 [00:18<00:00, 51.06it/s]


29 0.45817088614851076 2.5038877336294445


100%|██████████| 937/937 [00:17<00:00, 52.87it/s]


30 0.46518272589275905 2.4839633201967564


100%|██████████| 937/937 [00:19<00:00, 48.53it/s]


31 0.6318365850246321 2.149648007740969


100%|██████████| 937/937 [00:17<00:00, 53.69it/s]


32 0.4752297785168905 2.4333787721401854


100%|██████████| 937/937 [00:18<00:00, 51.25it/s]


33 0.48414532719388714 2.428315446623616


100%|██████████| 937/937 [00:17<00:00, 53.44it/s]


34 0.7138263973285193 2.13988409295535


100%|██████████| 937/937 [00:18<00:00, 51.63it/s]


35 0.4932630555336478 2.662304824037257


100%|██████████| 937/937 [00:17<00:00, 52.70it/s]


36 0.5066494669451149 2.3286921822783024


100%|██████████| 937/937 [00:17<00:00, 53.53it/s]


37 0.5178914827114744 2.2537030616463056


100%|██████████| 937/937 [00:18<00:00, 51.09it/s]


38 0.5140640905470832 2.2849473538495433


100%|██████████| 937/937 [00:17<00:00, 53.74it/s]


39 0.5257932196113827 2.416916157290928


100%|██████████| 937/937 [00:19<00:00, 49.01it/s]


40 0.5658271398908173 2.331281066576979


100%|██████████| 937/937 [00:17<00:00, 53.72it/s]


41 0.4892772166903271 2.3805150724907693


100%|██████████| 937/937 [00:18<00:00, 51.10it/s]


42 0.4828308164564497 2.3628844352262126


100%|██████████| 937/937 [00:17<00:00, 53.67it/s]


43 0.8241029608847619 1.9531752144005408


100%|██████████| 937/937 [00:18<00:00, 51.68it/s]


44 0.37863805252370863 2.9529450881188555


100%|██████████| 937/937 [00:17<00:00, 52.47it/s]


45 0.5682223247107727 2.414064347171478


100%|██████████| 937/937 [00:17<00:00, 53.53it/s]


46 0.5094783504273873 2.371898370338733


100%|██████████| 937/937 [00:18<00:00, 51.77it/s]


47 0.542555907238255 2.355025525408595


100%|██████████| 937/937 [00:18<00:00, 51.19it/s]


48 0.7914632369894355 2.0442675591278485


100%|██████████| 937/937 [00:18<00:00, 51.12it/s]


49 0.4192344630286177 2.7676669721796774


100%|██████████| 937/937 [00:17<00:00, 54.13it/s]


50 0.5774200405356978 2.3747984254881818


100%|██████████| 937/937 [00:18<00:00, 51.81it/s]


51 0.5596231837600883 2.2909377327341915


100%|██████████| 937/937 [00:17<00:00, 54.55it/s]


52 0.5726261259779223 2.228636838837647


100%|██████████| 937/937 [00:17<00:00, 53.33it/s]


53 0.5867194254726585 2.1973599494457754


100%|██████████| 937/937 [00:17<00:00, 53.04it/s]


54 0.5884527456544125 2.215234513471068


100%|██████████| 937/937 [00:17<00:00, 53.93it/s]


55 0.5902420494192565 2.254245050052696


100%|██████████| 937/937 [00:18<00:00, 51.77it/s]


56 0.5798951102740737 2.2605193579845775


100%|██████████| 937/937 [00:18<00:00, 51.42it/s]


57 0.5868273516729268 2.229788203885741


100%|██████████| 937/937 [00:18<00:00, 51.65it/s]


58 0.5765785516707849 2.2396875025750225


100%|██████████| 937/937 [00:17<00:00, 53.94it/s]


59 0.5742358992486524 2.263701405825457


100%|██████████| 937/937 [00:18<00:00, 50.02it/s]


60 0.5728564831083333 2.2668595174079007


100%|██████████| 937/937 [00:18<00:00, 51.95it/s]


61 0.5752711472605437 2.2635092615699564


100%|██████████| 937/937 [00:18<00:00, 49.36it/s]


62 0.5715880740986562 2.2761785232142935


100%|██████████| 937/937 [00:18<00:00, 51.29it/s]


63 0.5664302408695221 2.300790564990985


100%|██████████| 937/937 [00:18<00:00, 50.06it/s]


64 0.5469622536937099 2.2695419363431006


100%|██████████| 937/937 [00:17<00:00, 52.18it/s]


65 0.5457390757037074 2.327235242982048


100%|██████████| 937/937 [00:19<00:00, 47.90it/s]


66 0.5769514152372659 2.320167957527788


100%|██████████| 937/937 [00:17<00:00, 52.19it/s]


67 0.6015476043730688 2.2538606733751654


100%|██████████| 937/937 [00:18<00:00, 49.89it/s]


68 0.6065007597048993 2.222309759103374


100%|██████████| 937/937 [00:18<00:00, 51.93it/s]


69 0.6039201290274124 2.2031258890443066


100%|██████████| 937/937 [00:18<00:00, 50.52it/s]


70 0.6118841223108603 2.211110358049928


100%|██████████| 937/937 [00:18<00:00, 51.43it/s]


71 1.369609744342694 1.9734240074263312


100%|██████████| 937/937 [00:18<00:00, 51.99it/s]


72 3.105749548054047 0.7049518345736897


100%|██████████| 937/937 [00:18<00:00, 51.35it/s]


73 1.2716415573464133 1.4755264583621235


100%|██████████| 937/937 [00:18<00:00, 50.03it/s]


74 0.7973339061154373 2.0529301987131987


100%|██████████| 937/937 [00:18<00:00, 50.10it/s]


75 0.5249987158411468 2.56156407158301


100%|██████████| 937/937 [00:18<00:00, 51.37it/s]


76 0.5758713694684915 2.532390032532121


100%|██████████| 937/937 [00:18<00:00, 50.00it/s]


77 0.6923839993766939 2.1381350820575986


100%|██████████| 937/937 [00:17<00:00, 52.10it/s]


78 0.6609878696207049 2.142697972574509


100%|██████████| 937/937 [00:18<00:00, 49.92it/s]


79 0.6474760208465755 2.1569356438698866


100%|██████████| 937/937 [00:18<00:00, 52.02it/s]


80 0.6519673952043629 2.203921011825


100%|██████████| 937/937 [00:18<00:00, 49.79it/s]


81 0.6640466677150197 2.1937473359713557


100%|██████████| 937/937 [00:18<00:00, 49.95it/s]


82 0.6524546313056824 2.1354983800757785


100%|██████████| 937/937 [00:18<00:00, 50.25it/s]


83 0.6489339548788304 2.1152512944495285


 83%|████████▎ | 781/937 [00:14<00:02, 55.67it/s]