In [None]:
import os

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Hyperparameters
latent_dim = 100      # length of the noise vector fed to the generator
lr = 0.0002           # Adam learning rate, shared by generator and discriminator
batch_size = 128
image_size = 28*28    # flattened MNIST image (1 x 28 x 28 pixels)
num_epochs = 50
seed = 42             # fixed seed so weight init, noise, dropout and shuffling repeat across runs

# Seed torch's RNGs before any model weights or batches are drawn.
# NOTE: some CUDA kernels are still nondeterministic, so GPU runs may differ slightly.
torch.manual_seed(seed)

## Device Setup and MNIST Dataset

In [None]:
# Device configuration: use CUDA when PyTorch can see a GPU, otherwise the CPU.
# Models and batches are moved onto this device later in the notebook.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# MNIST training split.
# ToTensor converts each image to a float tensor in [0, 1]; Normalize with
# mean 0.5 / std 0.5 then rescales it to [-1, 1], the same range as the
# generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Reshuffle every epoch; drop_last=True discards the trailing partial batch
# so every batch holds exactly `batch_size` images.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

## Models

In [None]:
# Generator
class Generator(nn.Module):
    """MLP mapping latent noise vectors to flattened images.

    Layers: z_dim -> 256 -> 512 -> 1024 -> img_dim, with LeakyReLU(0.2)
    after each hidden Linear and a final Tanh, so every output value lies
    in [-1, 1].

    Args:
        z_dim: Length of the input noise vector. When omitted, the
            notebook-level ``latent_dim`` is read at construction time
            (the original zero-argument behaviour).
        img_dim: Length of the flattened output image. When omitted, the
            notebook-level ``image_size`` is used.
    """

    def __init__(self, z_dim=None, img_dim=None):
        super().__init__()
        # Fall back to the notebook hyperparameters so `Generator()` keeps working.
        if z_dim is None:
            z_dim = latent_dim
        if img_dim is None:
            img_dim = image_size
        self.model = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, z):
        """Map noise ``z`` of shape (N, z_dim) to flattened images (N, img_dim) in [-1, 1]."""
        return self.model(z)

# Discriminator
class Discriminator(nn.Module):
    """MLP scoring images with a probability of being real.

    Layers: img_dim -> 1024 -> 512 -> 256 -> 1, each hidden Linear followed
    by LeakyReLU(0.2) and Dropout(0.3), with a final Sigmoid so the score
    lies in [0, 1] (suitable for ``nn.BCELoss``).

    Args:
        img_dim: Number of features after flattening one image. When omitted,
            the notebook-level ``image_size`` is read at construction time
            (the original zero-argument behaviour).
    """

    def __init__(self, img_dim=None):
        super().__init__()
        # Fall back to the notebook hyperparameter so `Discriminator()` keeps working.
        if img_dim is None:
            img_dim = image_size
        self.model = nn.Sequential(
            nn.Linear(img_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor of shape (N, ...) whose trailing dims multiply to
                ``img_dim`` — e.g. (N, 1, 28, 28) MNIST batches or
                (N, 784) generator output.

        Returns:
            Tensor of shape (N, 1) with probabilities in [0, 1].
        """
        # flatten() works on non-contiguous inputs (copying only when needed),
        # whereas view() raises for them.
        img_flat = img.flatten(start_dim=1)
        return self.model(img_flat)

## Training Stage

In [None]:
# Initialize models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Optimizers: one Adam per network, both with the shared learning rate
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)

# Loss function: binary cross-entropy on the discriminator's sigmoid output
criterion = nn.BCELoss()

# Per-epoch sample images are written here. save_image writes through PIL,
# which fails if the directory does not exist, so create it up front.
sample_dir = './data/gan'
os.makedirs(sample_dir, exist_ok=True)

# Training
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):

        real_images = images.to(device)
        # Size labels and noise from the actual batch instead of the
        # `batch_size` constant, so this stays correct without drop_last.
        cur_batch = real_images.size(0)
        real_labels = torch.ones(cur_batch, 1, device=device)    # target 1 = "real"
        fake_labels = torch.zeros(cur_batch, 1, device=device)   # target 0 = "fake"

        # --------- Train the Discriminator --------- #
        d_optimizer.zero_grad()
        # How well does D recognise real images?
        outputs = discriminator(real_images)
        d_real_loss = criterion(outputs, real_labels)
        # How well does D reject generated images? detach() keeps this
        # backward pass from computing gradients for the generator.
        z = torch.randn(cur_batch, latent_dim, device=device)
        fake_images = generator(z)
        outputs = discriminator(fake_images.detach())
        d_fake_loss = criterion(outputs, fake_labels)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()       # compute discriminator gradients
        d_optimizer.step()      # update discriminator weights

        # --------- Train the Generator --------- #
        g_optimizer.zero_grad()
        # Re-score the same fakes with the just-updated discriminator; the
        # generator's loss is low when they are classified as real.
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        # This also writes gradients into the discriminator's parameters;
        # they are discarded by d_optimizer.zero_grad() next iteration.
        g_loss.backward()       # compute generator gradients
        g_optimizer.step()      # update generator weights

        if (i+1) % 400 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}')

    # Save the epoch's last generated batch. The generator's Tanh output is in
    # [-1, 1] but save_image expects [0, 1]; without rescaling, every negative
    # pixel would be clipped to black.
    samples = ((fake_images.detach() + 1) / 2).clamp(0, 1)
    save_image(samples.reshape(samples.size(0), 1, 28, 28),
               f'{sample_dir}/fake_image-{epoch + 1:03d}.png')

print("Training complete.")