<a href="https://colab.research.google.com/github/crodier1/data_science/blob/main/PyTorch_GAN_letters.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
import os
from datetime import datetime

In [None]:
# Preprocessing applied to every image: convert to a tensor, then apply
# (x - 0.5) / 0.5 on the single channel. With ToTensor's [0, 1] output (per the
# torchvision docs) this puts pixels in [-1, 1], the range of G's final Tanh.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

In [None]:
# MNIST training split, stored under the current directory (downloaded on first
# use) and preprocessed with `transform`.
train_dataset = torchvision.datasets.MNIST(
    root='.', train=True, download=True, transform=transform
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 107713003.18it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 28125306.20it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 82515736.18it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5493232.06it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



In [None]:
# Sanity check: number of images in the training split.
len(train_dataset)

60000

In [None]:
# Mini-batch size used for both networks; also sizes the label tensors built
# before the training loop.
batch_size = 128

# Reshuffle the training set at the start of every epoch.
dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Discriminator: MLP over flattened 784-pixel images that emits a single raw
# logit per image. There is deliberately no final Sigmoid, because the loss used
# for training (nn.BCEWithLogitsLoss) applies it internally.
D = nn.Sequential(
    # Hidden stack: Linear -> LeakyReLU(0.2) for each (fan_in, fan_out) pair,
    # built in order so layer indices match a hand-written Sequential.
    *(
        layer
        for fan_in, fan_out in ((784, 512), (512, 256))
        for layer in (nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2))
    ),
    nn.Linear(256, 1),
)

In [None]:
# Size of the noise vector fed to the generator.
latent_dim = 100

# Generator: MLP mapping a latent vector to a flattened 784-pixel image.
# Tanh bounds the output to [-1, 1].
G = nn.Sequential(
    # Three hidden stages of Linear -> LeakyReLU(0.2) -> BatchNorm1d, built in
    # order so layer indices match a hand-written Sequential.
    *(
        layer
        for fan_in, fan_out in ((latent_dim, 256), (256, 512), (512, 1024))
        for layer in (
            nn.Linear(fan_in, fan_out),
            nn.LeakyReLU(0.2),
            nn.BatchNorm1d(fan_out, momentum=0.7),
        )
    ),
    nn.Linear(1024, 784),
    nn.Tanh(),
)

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Move both networks onto the selected device.
D, G = D.to(device), G.to(device)


In [None]:
# Binary cross-entropy on raw logits (D has no final Sigmoid).
criterion = nn.BCEWithLogitsLoss()

# One Adam optimizer per network, with identical hyperparameters
# (lr=0.0002, betas=(0.5, 0.999)); D's optimizer is created first.
d_optimizer, g_optimizer = (
    torch.optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (D, G)
)

In [None]:
def scale_image(img):
  """Affinely rescale `img` so that -1 maps to 0 and +1 maps to 1.

  Computes (img + 1) / 2 element-wise; used to bring generator output
  back into a displayable range before saving.
  """
  shifted = img + 1
  return shifted / 2

In [None]:
# Output directory for the per-epoch sample grids. exist_ok=True makes the cell
# safe to re-run and avoids the check-then-create race of a separate
# os.path.exists() test.
os.makedirs('gan_images', exist_ok=True)

In [None]:
# for inputs, labels in dataloader:
#   print(inputs.shape)
#   break

torch.Size([128, 1, 28, 28])


In [None]:
# Training schedule.
num_epochs = 200
g_steps_per_batch = 2  # generator updates per discriminator update

# Pre-allocated BCE targets. They are sliced to the actual batch size inside
# the loop because the final batch of an epoch can be smaller than batch_size.
ones_ = torch.ones(batch_size, 1).to(device)
zeros_ = torch.zeros(batch_size, 1).to(device)

# Per-batch loss history (one entry per discriminator update).
d_losses = []
g_losses = []

for epoch in range(num_epochs):

  for inputs, _ in dataloader:
    n = inputs.size(0)

    # Flatten each image into a 784-vector for the MLP discriminator.
    inputs = inputs.reshape(n, 784).to(device)

    # set ones & zeros to correct size
    ones = ones_[:n]
    zeros = zeros_[:n]

    # Train Discriminator: real images are labelled 1, generated images 0.

    real_outputs = D(inputs)
    d_loss_real = criterion(real_outputs, ones)

    noise = torch.randn(n, latent_dim).to(device)

    # detach() cuts the graph at the generator output: this step only updates
    # D, so backpropagating through G would compute gradients that are
    # immediately thrown away.
    fake_images = G(noise).detach()
    fake_outputs = D(fake_images)

    d_loss_fake = criterion(fake_outputs, zeros)

    d_loss = 0.5 * (d_loss_real + d_loss_fake)

    # Also clears any gradients that the previous generator updates left on
    # D's parameters.
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # Train Generator: push D to label freshly generated images as real (1).

    for _ in range(g_steps_per_batch):
      noise = torch.randn(n, latent_dim).to(device)
      fake_images = G(noise)
      fake_outputs = D(fake_images)

      g_loss = criterion(fake_outputs, ones)

      # Gradients also flow into D here, but only G is stepped; D's are
      # cleared at the start of its next update.
      g_optimizer.zero_grad()
      g_loss.backward()
      g_optimizer.step()

    # Record the discriminator loss and the loss of the last generator update.
    d_losses.append(d_loss.item())
    g_losses.append(g_loss.item())

  # Report the losses of the final batch of the epoch.
  print(f'Epoch {epoch+1} D Loss: {d_loss.item()} G Loss: {g_loss.item()}')

  # Save the most recent generated batch as an image grid, rescaled from the
  # Tanh range to [0, 1].
  fake_images = fake_images.reshape(-1, 1, 28, 28)
  save_image(scale_image(fake_images), f'gan_images/{epoch+1}.png')

128
torch.Size([128, 1, 28, 28])
torch.Size([128, 784])
128
noise.shape torch.Size([128, 100])
128
torch.Size([128, 1, 28, 28])
torch.Size([128, 784])
128
noise.shape torch.Size([128, 100])
128
torch.Size([128, 1, 28, 28])
torch.Size([128, 784])
128
noise.shape torch.Size([128, 100])
128
torch.Size([128, 1, 28, 28])
torch.Size([128, 784])
128
noise.shape torch.Size([128, 100])
128
torch.Size([128, 1, 28, 28])
torch.Size([128, 784])
128
noise.shape torch.Size([128, 100])
128
torch.Size([128, 1, 28, 28])
torch.Size([128, 784])
128
noise.shape torch.Size([128, 100])
128
torch.Size([128, 1, 28, 28])
torch.Size([128, 784])
128
noise.shape torch.Size([128, 100])
128
torch.Size([128, 1, 28, 28])
torch.Size([128, 784])
128
noise.shape torch.Size([128, 100])
128
torch.Size([128, 1, 28, 28])
torch.Size([128, 784])
128
noise.shape torch.Size([128, 100])
128
torch.Size([128, 1, 28, 28])
torch.Size([128, 784])
128
noise.shape torch.Size([128, 100])
128
torch.Size([128, 1, 28, 28])
torch.Size([128, 

KeyboardInterrupt: 

In [None]:
# Plot the per-batch loss history of both networks. Each curve gets a label so
# that legend() has entries to show (without labels the legend is empty and
# matplotlib emits a warning).
fig, ax = plt.subplots()
ax.plot(d_losses, label='d_losses')
ax.plot(g_losses, label='g_losses')
ax.set_xlabel('training batch')
ax.set_ylabel('loss')
ax.set_title('GAN training losses')
ax.legend()
plt.show()