In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [7]:
num_epoch = 1000
batch_size = 100
lr = 0.0001
img_size = 28*28          # flattened MNIST image (28 x 28 pixels)
num_channel = 1
dir_name = 'GAN_result'   # per-epoch generated sample grids are written here

## reproducibility: seed PyTorch's RNGs (CPU and all CUDA devices)
seed = 42
torch.manual_seed(seed)

## network sizes
noise_size = 100          # generator input (latent) dimension
hidden_size1 = 256        # shared by Generator and Discriminator
hidden_size2 = 512        # shared by Generator and Discriminator

## output directory
# exist_ok=True avoids the check-then-create race and keeps the cell idempotent.
os.makedirs(dir_name, exist_ok=True)

In [4]:
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them
# to [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=0.5, std=0.5)]
)

In [5]:
# MNIST training split, downloaded into the working directory on first use.
MNIST_dataset = datasets.MNIST(
    root='./',
    train=True,
    download=True,
    transform=transform,
)

# Mini-batches of `batch_size` images, reshuffled every epoch.
data_loader = DataLoader(MNIST_dataset, batch_size=batch_size, shuffle=True)

100%|██████████| 9.91M/9.91M [00:11<00:00, 901kB/s] 
100%|██████████| 28.9k/28.9k [00:00<00:00, 65.6kB/s]
100%|██████████| 1.65M/1.65M [00:06<00:00, 245kB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.89MB/s]


# Discriminator

In [12]:
class Discriminator(nn.Module):
  """MLP discriminator: scores a flattened image with the probability it is real.

  Layers: Linear(in_features -> hidden2) -> LeakyReLU(0.2)
          -> Linear(hidden2 -> hidden1) -> LeakyReLU(0.2)
          -> Linear(hidden1 -> 1) -> Sigmoid

  Args:
    in_features: size of the flattened input; defaults to the notebook's `img_size`.
    hidden1: width of the second hidden layer; defaults to `hidden_size1`.
    hidden2: width of the first hidden layer; defaults to `hidden_size2`.

  forward(x) takes x of shape (..., in_features) and returns (..., 1) in (0, 1).
  """
  def __init__(self, in_features=None, hidden1=None, hidden2=None):
    super().__init__()
    # Fall back to the notebook-level config so `Discriminator()` keeps working,
    # while explicit sizes let the class be used without those globals.
    in_features = img_size if in_features is None else in_features
    hidden1 = hidden_size1 if hidden1 is None else hidden1
    hidden2 = hidden_size2 if hidden2 is None else hidden2

    self.linear1 = nn.Linear(in_features, hidden2)
    self.linear2 = nn.Linear(hidden2, hidden1)
    self.linear3 = nn.Linear(hidden1, 1)
    self.leaky_relu = nn.LeakyReLU(0.2)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    x = self.leaky_relu(self.linear1(x))
    x = self.leaky_relu(self.linear2(x))
    x = self.linear3(x)
    # Sigmoid yields a probability, as required by nn.BCELoss.
    x = self.sigmoid(x)
    return x

# Generator

In [9]:
class Generator(nn.Module):
  """MLP generator: maps a noise vector to a flattened image in [-1, 1].

  Layers: Linear(noise_dim -> hidden1) -> ReLU
          -> Linear(hidden1 -> hidden2) -> ReLU
          -> Linear(hidden2 -> out_features) -> Tanh

  Args:
    noise_dim: latent input size; defaults to the notebook's `noise_size`.
    hidden1: width of the first hidden layer; defaults to `hidden_size1`.
    hidden2: width of the second hidden layer; defaults to `hidden_size2`.
    out_features: flattened output size; defaults to `img_size` (MNIST).

  forward(x) takes x of shape (..., noise_dim) and returns (..., out_features).
  """
  def __init__(self, noise_dim=None, hidden1=None, hidden2=None, out_features=None):
    super().__init__()
    # Fall back to the notebook-level config so `Generator()` keeps working,
    # while explicit sizes let the class be used without those globals.
    noise_dim = noise_size if noise_dim is None else noise_dim
    hidden1 = hidden_size1 if hidden1 is None else hidden1
    hidden2 = hidden_size2 if hidden2 is None else hidden2
    out_features = img_size if out_features is None else out_features

    self.linear1 = nn.Linear(noise_dim, hidden1)
    self.linear2 = nn.Linear(hidden1, hidden2)
    self.linear3 = nn.Linear(hidden2, out_features)
    self.relu = nn.ReLU()
    self.tanh = nn.Tanh()

  def forward(self, x):
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.linear3(x)
    # tanh keeps outputs in [-1, 1], matching the Normalize(0.5, 0.5) real data.
    x = self.tanh(x)
    return x

In [13]:
# Build both networks with their default sizes and move their parameters to
# `device`. The tuple is evaluated left to right, so the discriminator is still
# initialised before the generator (same RNG consumption order).
discriminator, generator = Discriminator().to(device), Generator().to(device)

# Vanilla GAN

In [14]:
# Binary cross-entropy on the discriminator's sigmoid output (target 1 = real, 0 = fake).
criterion = nn.BCELoss()

# One Adam optimizer per network, both using the configured learning rate.
d_optimizer, g_optimizer = (
    optim.Adam(net.parameters(), lr=lr) for net in (discriminator, generator)
)

In [15]:
# Alternating GAN training: one generator step, then one discriminator step per batch.
# At the end of each epoch, the last batch of generated images is saved to `dir_name`.
for epoch in range(num_epoch):
  for i, (images, _) in enumerate(data_loader):
    # Use the real batch size: the final batch is smaller whenever the dataset
    # size is not a multiple of `batch_size`, and a fixed view would then fail.
    cur_batch = images.size(0)
    real_label = torch.ones(cur_batch, 1, device=device)   # target 1 = real
    fake_label = torch.zeros(cur_batch, 1, device=device)  # target 0 = fake

    real_images = images.view(cur_batch, -1).to(device)

    ### generator training: push D(G(z)) toward the "real" label
    g_optimizer.zero_grad()
    z = torch.randn(cur_batch, noise_size, device=device)
    fake_images = generator(z)
    g_loss = criterion(discriminator(fake_images), real_label)
    g_loss.backward()
    g_optimizer.step()

    ### discriminator training
    # Fakes are produced without autograd tracking: D's loss must not back-propagate
    # into G (those gradients were previously computed and then thrown away).
    z = torch.randn(cur_batch, noise_size, device=device)
    with torch.no_grad():
      fake_images = generator(z)

    fake_prediction = discriminator(fake_images)
    real_prediction = discriminator(real_images)

    fake_loss = criterion(fake_prediction, fake_label)
    real_loss = criterion(real_prediction, real_label)
    d_loss = (fake_loss + real_loss) / 2

    # zero_grad here also discards the D gradients accumulated by g_loss.backward().
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    if (i+1) % 150 == 0:
      print(f'Epoch [{epoch}/{num_epoch}] Step [{i+1}/{len(data_loader)}] d_loss: {d_loss.item()} g_loss: {g_loss.item()}')

  # The generator ends in tanh, so samples lie in [-1, 1]. save_image expects [0, 1]
  # and clamps anything outside it, so rescale to avoid every negative pixel becoming black.
  samples = ((fake_images + 1) / 2).clamp(0, 1).reshape(-1, 1, 28, 28)
  save_image(samples, os.path.join(dir_name, f'Gan_fake_samples{epoch+1}.png'))

Epoch [0/1000] Step [150/600] d_loss: 0.15048757195472717 g_loss: 2.387244701385498
Epoch [0/1000] Step [300/600] d_loss: 0.04329875484108925 g_loss: 3.500458240509033
Epoch [0/1000] Step [450/600] d_loss: 0.1515117883682251 g_loss: 3.5689995288848877
Epoch [0/1000] Step [600/600] d_loss: 0.05694755166769028 g_loss: 3.7506794929504395
Epoch [1/1000] Step [150/600] d_loss: 0.08705539256334305 g_loss: 3.467283248901367
Epoch [1/1000] Step [300/600] d_loss: 0.10833865404129028 g_loss: 3.259143829345703
Epoch [1/1000] Step [450/600] d_loss: 0.09509746730327606 g_loss: 3.2585508823394775
Epoch [1/1000] Step [600/600] d_loss: 0.07430216670036316 g_loss: 3.7650530338287354
Epoch [2/1000] Step [150/600] d_loss: 0.43539804220199585 g_loss: 2.534571886062622
Epoch [2/1000] Step [300/600] d_loss: 0.08592623472213745 g_loss: 3.6821696758270264
Epoch [2/1000] Step [450/600] d_loss: 0.10013408958911896 g_loss: 3.3634300231933594
Epoch [2/1000] Step [600/600] d_loss: 0.4433450400829315 g_loss: 2.7995

KeyboardInterrupt: 