<a href="https://colab.research.google.com/github/haruka-inb/pytorch_practice/blob/main/generative_adversarial_network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Generative Adversarial Network

In [None]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

In [None]:
# Device configuration: run on the GPU when one is available, else on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_size = 64    # size of the noise vector fed to the generator
hidden_size = 256   # width of the hidden layers in both networks
image_size = 784    # number of pixels in a flattened 28x28 image
# num_epochs = 200
num_epochs = 20
batch_size = 100
sample_dir = 'samples'

# Create the directory that stores the generated images if it does not exist.
# makedirs(exist_ok=True) replaces the exists()-then-mkdir() pair, which can
# fail if the directory appears between the check and the creation.
os.makedirs(sample_dir, exist_ok=True)

# Image preprocessing: normalize with mean=0.5, std=0.5.
# The statistics are given as one-element tuples (one entry per channel),
# which is accepted by every torchvision release; bare floats are not.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Load the MNIST training set (downloaded on first use).
# NOTE(review): the root starts with '/', so it is an absolute path that
# resolves to '/data'; a relative '../../data' may have been intended - confirm.
mnist = torchvision.datasets.MNIST(
    root='/../../data',
    train=True,
    transform=transform,
    download=True,
)

# Data loader that yields shuffled mini-batches
data_loader = torch.utils.data.DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)


In [None]:
# Discriminator: maps a flattened image to a single value in (0, 1) through
# two LeakyReLU hidden layers and a final Sigmoid. It is moved to `device`
# as soon as it is built.
D = nn.Sequential(
    nn.Linear(in_features=image_size, out_features=hidden_size),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=hidden_size, out_features=hidden_size),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(in_features=hidden_size, out_features=1),
    nn.Sigmoid(),
).to(device)

# Generator: maps a latent vector to a flattened image through two ReLU
# hidden layers; the final Tanh keeps every output value in (-1, 1).
G = nn.Sequential(
    nn.Linear(in_features=latent_size, out_features=hidden_size),
    nn.ReLU(),
    nn.Linear(in_features=hidden_size, out_features=hidden_size),
    nn.ReLU(),
    nn.Linear(in_features=hidden_size, out_features=image_size),
    nn.Tanh(),
).to(device)

# Binary cross entropy loss, used for both the discriminator and the generator
criterion = nn.BCELoss()

# Adam optimizers, one per network.
# The learning rate is 0.0002, the value commonly used for Adam in GAN
# training. The previous value of 0.002 was ten times larger, and the run
# recorded with it shows the discriminator overpowering the generator from
# the first epoch (D(x) close to 1, fake score around 1e-38, g_loss near 99).
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

def denorm(x):
    """Rescale ``x`` with ``(x + 1) / 2`` and clamp the result to [0, 1].

    This maps values from the [-1, 1] range back to [0, 1]; anything that
    falls outside after rescaling is clipped to the nearest bound.
    """
    return ((x + 1) / 2).clamp(min=0, max=1)

def reset_grad():
    """Clear the gradients held by the discriminator and generator optimizers."""
    for optimizer in (d_optimizer, g_optimizer):
        optimizer.zero_grad()

In [None]:
# Start training: every step updates the discriminator once, then the generator once
for e in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Use the size of the current batch instead of the configured
        # batch_size, so a smaller final batch does not break the reshape
        # or mismatch the label tensors.
        n = images.size(0)
        images = images.reshape(n, -1).to(device)

        # Create the labels which are later used as input for the BCE loss:
        # real images are labelled 1 and fake images 0
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #
        # Compute the BCE loss on real images, where
        # BCE_Loss(x, y) = - y * log(D(x)) - (1 - y) * log(1 - D(x)).
        # The second term is always zero since real_labels == 1.
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Compute the BCE loss on fake images.
        # The first term is always zero since fake_labels == 0.
        # detach() stops this backward pass at the generator output: only D
        # is updated here, so gradients for G would be computed for nothing.
        z = torch.randn(n, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #
        # Compute the loss with a fresh batch of fake images
        z = torch.randn(n, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images)

        # We train G to maximize log(D(G(z))) instead of minimizing
        # log(1 - D(G(z))), hence the real labels as target
        g_loss = criterion(outputs, real_labels)

        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 200 == 0:
            # Epoch and step are printed one-based so they line up with the
            # totals; the last value is the discriminator's mean score on
            # fake images, i.e. D(G(z)).
            print(f"Epoch {e+1}/{num_epochs}, Steps {i+1}/{len(data_loader)}, d_loss: {d_loss.item()}, g_loss: {g_loss.item()}, D(x): {real_score.mean().item()}, D(G(z)): {fake_score.mean().item()}")

    # Save the real images of the last batch, once, after the first epoch
    if e == 0:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'realimages.png'))

    # Save the generated images of the last batch. This must write
    # fake_images; writing `images` would store real digits under the
    # fake_images-*.png name.
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, f'fake_images-{e+1}.png'))

# Write the learned parameters of the generator, then the discriminator, to disk
for network, checkpoint_path in ((G, 'G.ckpt'), (D, 'D.ckpt')):
    torch.save(network.state_dict(), checkpoint_path)

Epoch 0/20, Steps 199/600, d_loss: 0.016027158126235008, g_loss: 6.685244083404541, D(x): 0.9961581230163574, G(x): 0.012034215964376926
Epoch 0/20, Steps 399/600, d_loss: 1.0172320799028967e-05, g_loss: 98.96879577636719, D(x): 0.9999897480010986, G(x): 7.286766027473692e-40
Epoch 0/20, Steps 599/600, d_loss: 1.2976300240552519e-05, g_loss: 99.07817840576172, D(x): 0.999987006187439, G(x): 1.2509890253280942e-37
Epoch 1/20, Steps 199/600, d_loss: 4.2743417907331605e-06, g_loss: 97.8670425415039, D(x): 0.9999957084655762, G(x): 1.382111044972203e-38
Epoch 1/20, Steps 399/600, d_loss: 2.6667642032407457e-06, g_loss: 98.30533599853516, D(x): 0.9999973177909851, G(x): 4.130481366428474e-39
Epoch 1/20, Steps 599/600, d_loss: 1.9930444977944717e-05, g_loss: 98.89678192138672, D(x): 0.9999801516532898, G(x): 9.027356685070084e-39
Epoch 2/20, Steps 199/600, d_loss: 7.808275199749914e-07, g_loss: 98.41651916503906, D(x): 0.9999992251396179, G(x): 2.9346398816444973e-38
Epoch 2/20, Steps 399/60