# Training a GAN on the MNIST Dataset

In [1]:
# Import necessary libraries
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [2]:
# Experiment settings
seed = 42 # Random seed, applied to PyTorch's global RNG below
batch_size = 32 # Batch size (kept equal to the value the data-loading cell assigns)
learning_rate = 0.0002 # Learning rate
z_dim = 100 # Dimension of the noise vector
max_epochs = 200 # Number of epochs to train the model
mnist_data_root = './datasets' # Root directory for the MNIST dataset
saved_models_dir = './saved_models' # Model weight save directory
experiment_name = 'gan_mnist_pytorch' # Name of the experiment
output_dir = './data/' + experiment_name + '_output' # Output directory for generated images
tensorboard_log_dir = './runs/' + experiment_name # TensorBoard log directory

# Seed once, before any weights are initialised or noise is sampled
torch.manual_seed(seed)

In [3]:
# Pick the compute device: use the active accelerator when there is one, otherwise the CPU.
# `torch.accelerator` is looked up defensively so that PyTorch builds without that
# namespace fall back to the explicit CUDA / MPS checks instead of raising AttributeError.
_accelerator = getattr(torch, "accelerator", None)
_mps_backend = getattr(torch.backends, "mps", None)
if _accelerator is not None and _accelerator.is_available():
    device = _accelerator.current_accelerator().type
elif _accelerator is None and torch.cuda.is_available():
    device = "cuda"
elif _accelerator is None and _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using mps device


In [4]:
# MNIST Dataset
# Preprocessing: force a single channel, convert to a tensor, then apply
# (x - 0.5) / 0.5 so pixel values are centred the same way as the generator's tanh output.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to 1 channel
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Adjust for 1 channel
])

# Dataset for MNIST training images.
# download=True fetches the files only when they are missing under `mnist_data_root`,
# so the notebook also runs on a machine that has never downloaded MNIST.
train_dataset = datasets.MNIST(root=mnist_data_root, train=True, transform=transform, download=True)

# DataLoader for MNIST training images
# NOTE: this reassigns the global `batch_size` from the settings cell; the value set here
# is the one later cells read. Keep the two in sync when changing either.
batch_size = 32
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)


In [5]:
# Generator model
class Generator(nn.Module):
    """Fully connected generator: noise vector in, flattened image out.

    Three hidden layers of width 256, 512 and 1024 with LeakyReLU(0.2),
    followed by a tanh output layer, so outputs lie in [-1, 1].

    Args:
        g_input_dim: Size of the input noise vector.
        g_output_dim: Size of the generated (flattened) output.
    """

    def __init__(self, g_input_dim, g_output_dim):
        super().__init__()
        # Each hidden layer doubles the width of the previous one.
        first, second, third = 256, 512, 1024
        self.fc1 = nn.Linear(g_input_dim, first)
        self.fc2 = nn.Linear(first, second)
        self.fc3 = nn.Linear(second, third)
        self.fc4 = nn.Linear(third, g_output_dim)

    def forward(self, x):
        """Map a batch of noise vectors to a batch of generated outputs."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), negative_slope=0.2)
        return torch.tanh(self.fc4(hidden))

In [6]:
# Discriminator model
class Discriminator(nn.Module):
    """Fully connected discriminator: flattened image in, probability out.

    Three hidden layers of width 1024, 512 and 256 with LeakyReLU(0.2) and
    dropout (p=0.3), followed by a single sigmoid output unit.

    Args:
        d_input_dim: Size of the flattened input.
    """

    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        """Return one value in (0, 1) per input row."""
        # F.dropout defaults to training=True; tie it to the module's mode so that
        # dropout is switched off after .eval() and predictions become deterministic.
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [7]:
# build network
# Flattened image size = height * width of one training image.
# `.data` is the current attribute name; `.train_data` is a deprecated alias for it.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

# Initialize the generator and discriminator
generator_model = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
discriminator_model = Discriminator(mnist_dim).to(device)



In [None]:
# Print both architectures, generator first, one after the other
print(generator_model, discriminator_model, sep="\n")
sys.stdout.flush()

Generator(
  (fc1): Linear(in_features=100, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=784, bias=True)
)
Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)


In [9]:
# Binary cross-entropy between the discriminator's output and the 0/1 targets
criterion = nn.BCELoss()

# One Adam optimizer per network, both using the configured learning rate
generator_optimizer = optim.Adam(
    params=generator_model.parameters(),
    lr=learning_rate,
)
discriminator_optimizer = optim.Adam(
    params=discriminator_model.parameters(),
    lr=learning_rate,
)

In [10]:
# Function to train the discriminator
def D_train(x):
    """Run one discriminator update and return its loss.

    The discriminator is scored on the real batch (target 1) and on an equally
    sized batch of generated images (target 0); the two BCE losses are summed.

    Args:
        x: Batch of real images; it is flattened to ``(-1, mnist_dim)``.

    Returns:
        float: The summed discriminator loss for this step.
    """
    # Reset gradients of model
    discriminator_model.zero_grad()

    # Take the real images. Targets and noise are sized from the batch that actually
    # arrived, not from the global batch_size, so a smaller final batch still works.
    x_real = x.view(-1, mnist_dim).to(device)
    n_samples = x_real.size(0)
    y_real = torch.ones(n_samples, 1, device=device)

    # Let the discriminator predict the real images
    D_real_loss = criterion(discriminator_model(x_real), y_real)

    # Let the discriminator predict the fake images.
    # The generator is not updated in this step, so build the fakes without tracking
    # gradients; this avoids backpropagating the discriminator loss into the generator.
    with torch.no_grad():
        x_fake = generator_model(torch.randn(n_samples, z_dim, device=device))
    y_fake = torch.zeros(n_samples, 1, device=device)
    D_fake_loss = criterion(discriminator_model(x_fake), y_fake)

    # Calculate the total discriminator loss, backpropagate, and update the parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    discriminator_optimizer.step()

    # Return the discriminator loss
    return D_loss.item()

In [None]:
# Function to train the generator
def G_train(x):
    """Run one generator update and return its loss.

    A batch of ``batch_size`` noise vectors is turned into images, scored by the
    discriminator, and compared against targets of 1, so the loss falls as the
    discriminator rates the generated images closer to 1.

    Args:
        x: Not used by this function; accepted so it is called like ``D_train``.

    Returns:
        float: The generator loss for this step.
    """
    # Start from clean generator gradients
    generator_model.zero_grad()

    # Sample the noise, and build targets of 1 for the generated images
    noise = torch.randn(batch_size, z_dim).to(device)
    targets = torch.ones(batch_size, 1).to(device)

    # Generate images and let the discriminator score them
    verdict = discriminator_model(generator_model(noise))

    # Backpropagate the loss and update the generator's parameters
    loss = criterion(verdict, targets)
    loss.backward()
    generator_optimizer.step()

    return loss.item()

In [None]:
# Training loop
# Create the output directories up front; with a missing directory the first
# checkpoint or image save of epoch 1 would fail.
os.makedirs(saved_models_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)

writer = SummaryWriter(tensorboard_log_dir)
try:
    for epoch in range(1, max_epochs+1):
        # Lists to store losses
        D_losses, G_losses = [], []

        # Iterate through the training data
        for batch_idx, (x, _) in enumerate(train_loader):
            # Train the discriminator and store the loss
            D_losses.append(D_train(x))
            # Train the generator and store the loss
            G_losses.append(G_train(x))

        # Average the per-batch losses once and reuse the result for printing and logging
        mean_d_loss = torch.mean(torch.FloatTensor(D_losses)).item()
        mean_g_loss = torch.mean(torch.FloatTensor(G_losses)).item()

        # Print the average losses for the epoch
        print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (epoch, max_epochs, mean_d_loss, mean_g_loss))
        sys.stdout.flush()

        # Log the losses to TensorBoard
        writer.add_scalars('Discriminator vs Generator Loss', {'Discriminator': mean_d_loss, 'Generator': mean_g_loss}, epoch)
        writer.flush()

        # Save the generator and discriminator models for every epoch
        torch.save(generator_model.state_dict(), saved_models_dir + '/' + experiment_name + '_G_epoch_' + str(epoch) + '.pth')
        torch.save(discriminator_model.state_dict(), saved_models_dir + '/' + experiment_name + '_D_epoch_' + str(epoch) + '.pth')

        # Save image for every epoch
        with torch.no_grad():
            test_z = torch.randn(batch_size, z_dim).to(device)
            generated = generator_model(test_z)
            # The generator ends in tanh, so pixels lie in [-1, 1]. Map them to [0, 1]
            # before saving; otherwise every negative value is clipped to black.
            images = (generated.view(generated.size(0), 1, 28, 28) + 1) / 2
            save_image(images, output_dir + '/' + experiment_name + '_output_' + str(epoch) + '.png')
finally:
    # Close the TensorBoard writer even if training is interrupted
    writer.close()


[1/200]: loss_d: 0.847, loss_g: 2.852
[2/200]: loss_d: 0.538, loss_g: 2.849
[3/200]: loss_d: 0.648, loss_g: 2.356
[4/200]: loss_d: 0.767, loss_g: 1.967
[5/200]: loss_d: 0.854, loss_g: 1.749
[6/200]: loss_d: 0.922, loss_g: 1.573
[7/200]: loss_d: 0.946, loss_g: 1.511
[8/200]: loss_d: 1.045, loss_g: 1.295
[9/200]: loss_d: 1.048, loss_g: 1.311
[10/200]: loss_d: 1.077, loss_g: 1.235
[11/200]: loss_d: 1.099, loss_g: 1.204
[12/200]: loss_d: 1.091, loss_g: 1.225
[13/200]: loss_d: 1.110, loss_g: 1.186
[14/200]: loss_d: 1.144, loss_g: 1.114
[15/200]: loss_d: 1.170, loss_g: 1.065
[16/200]: loss_d: 1.187, loss_g: 1.034
[17/200]: loss_d: 1.190, loss_g: 1.032
