# GAN for Minist

In [2]:
# import libraries
import os
import imageio
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image

In [3]:
# Ensure the output directories exist
output_dir = 'output/minist_GAN'
checkpoint_dir = 'checkpoint/minist_GAN'
os.makedirs(output_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

In [4]:
# Parameters
epochs = 300
batch_size = 64
learning_rate = 0.0002
b1 = 0.5
b2 = 0.999
latent_dim = 100
img_size = 28
channels = 1 # gray

img_shape = (channels, img_size, img_size) 

In [6]:
# Check CUDA's presence
cuda_is_present = True if torch.cuda.is_available() else False
print(cuda_is_present)

False


In [7]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        def layer_block(input_size, output_size, normalize=True):
            layers = [nn.Linear(input_size, output_size)]
            if normalize:
                layers.append(nn.BatchNorm1d(output_size, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *layer_block(latent_dim, 128, normalize=False),
            *layer_block(128, 256),
            *layer_block(256, 512),
            *layer_block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

In [8]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        verdict = self.model(img_flat)
        return verdict

In [9]:
# Initialize models and loss function
generator = Generator()
discriminator = Discriminator()
adversarial_loss = torch.nn.BCELoss()

if cuda_is_present:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

In [10]:
# Load MNIST dataset
os.makedirs('datasets/mnist', exist_ok=True)
data_loader = DataLoader(
    datasets.MNIST('datasets/mnist', train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])),
    batch_size=batch_size, shuffle=True)

In [12]:
Tensor = torch.cuda.FloatTensor if cuda_is_present else torch.FloatTensor

optimizer_generator = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(b1, b2))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(b1, b2))

losses = []
images_for_gif = []
for epoch in range(1, epochs + 1):
    for i, (images, _) in enumerate(data_loader):
        real_images = Variable(images.type(Tensor))
        real_output = Variable(Tensor(images.size(0), 1).fill_(1.0), requires_grad=False)
        fake_output = Variable(Tensor(images.size(0), 1).fill_(0.0), requires_grad=False)

        # Training Generator
        optimizer_generator.zero_grad()
        z = Variable(Tensor(np.random.normal(0, 1, (images.shape[0], latent_dim))))
        generated_images = generator(z)
        generator_loss = adversarial_loss(discriminator(generated_images), real_output)
        generator_loss.backward()
        optimizer_generator.step()

        # Training Discriminator
        optimizer_discriminator.zero_grad()
        discriminator_loss_real = adversarial_loss(discriminator(real_images), real_output)
        discriminator_loss_fake = adversarial_loss(discriminator(generated_images.detach()), fake_output)
        discriminator_loss = (discriminator_loss_real + discriminator_loss_fake) / 2
        discriminator_loss.backward()
        optimizer_discriminator.step()
        
    print(f"[Epoch {epoch:=4d}/{epochs}] [Batch {i:=4d}/{len(data_loader)}] ---> "
        f"[D Loss: {discriminator_loss.item():.6f}] [G Loss: {generator_loss.item():.6f}]")

    losses.append((generator_loss.item(), discriminator_loss.item()))
    if epoch % 10 == 0:
        image_filename = f'{output_dir}/images/epoch_{epoch}.png'
        os.makedirs(f'{output_dir}/images', exist_ok=True)
        save_image(generated_images.data[:25], image_filename, nrow=5, normalize=True)
        images_for_gif.append(imageio.imread(image_filename))
        # Save model checkpoints
        checkpoint_path = os.path.join(checkpoint_dir, f'generator_epoch_{epoch}.pth')
        torch.save(generator.state_dict(), checkpoint_path)
        checkpoint_path = os.path.join(checkpoint_dir, f'discriminator_epoch_{epoch}.pth')
        torch.save(discriminator.state_dict(), checkpoint_path)

[Epoch    1/300] [Batch    0/938] ---> [D Loss: 0.493361] [G Loss: 2.044260]
[Epoch    1/300] [Batch    1/938] ---> [D Loss: 0.943768] [G Loss: 0.189669]
[Epoch    1/300] [Batch    2/938] ---> [D Loss: 0.410894] [G Loss: 0.778165]
[Epoch    1/300] [Batch    3/938] ---> [D Loss: 0.471762] [G Loss: 1.880143]
[Epoch    1/300] [Batch    4/938] ---> [D Loss: 0.407336] [G Loss: 1.502279]
[Epoch    1/300] [Batch    5/938] ---> [D Loss: 0.456592] [G Loss: 0.780123]
[Epoch    1/300] [Batch    6/938] ---> [D Loss: 0.415473] [G Loss: 0.846547]
[Epoch    1/300] [Batch    7/938] ---> [D Loss: 0.395354] [G Loss: 1.188578]
[Epoch    1/300] [Batch    8/938] ---> [D Loss: 0.452313] [G Loss: 1.138787]
[Epoch    1/300] [Batch    9/938] ---> [D Loss: 0.530754] [G Loss: 0.798984]
[Epoch    1/300] [Batch   10/938] ---> [D Loss: 0.526021] [G Loss: 0.870470]
[Epoch    1/300] [Batch   11/938] ---> [D Loss: 0.512091] [G Loss: 0.984320]
[Epoch    1/300] [Batch   12/938] ---> [D Loss: 0.535481] [G Loss: 0.789073]

In [None]:
# Visualizing the losses at every epoch
losses = np.array(losses)
plt.plot(losses.T[0], label='Generator')
plt.plot(losses.T[1], label='Discriminator')
plt.title("Training Losses")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.savefig(f'{output_dir}/loss_plot.png')

In [None]:
# Creating a gif of generated images at every epoch
imageio.mimwrite(f'{output_dir}/generated_images.gif', images_for_gif, fps=len(images_for_gif)/5)

# Load the final trained models for generating MNIST images
generator.load_state_dict(torch.load(os.path.join(checkpoint_dir, f'generator_epoch_{epochs}.pth')))
generator.eval()

# Generate some MNIST images
num_images = 10
z = Variable(Tensor(np.random.normal(0, 1, (num_images, latent_dim))))
generated_images = generator(z)

# Save generated images
generated_image_filename = f'{output_dir}/final_generated_images.png'
save_image(generated_images.data, generated_image_filename, nrow=num_images, normalize=True)
print(f'Generated images saved to {generated_image_filename}')