# Generative Adversarial Network (GAN) in PyTorch

This notebook implements a basic GAN in PyTorch that learns to generate handwritten digits resembling those in the MNIST dataset.

## What is a GAN?
A Generative Adversarial Network consists of two neural networks:
- **Generator**: Creates fake images from random noise
- **Discriminator**: Distinguishes between real and fake images

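The two networks are trained against each other: the Generator tries to produce images that the Discriminator classifies as real, while the Discriminator tries to tell real images from generated ones. Each network improves in response to the other.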

## 1. Import Required Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Seed PyTorch's global RNG so weight initialization, noise sampling and
# data shuffling are repeatable between runs
torch.manual_seed(42)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

## 2. Define Hyperparameters

In [None]:
# ---- Hyperparameters ----
# Model sizes
latent_dim = 100        # length of the random noise vector fed to the Generator
hidden_dim = 256        # base hidden width; the networks use 1x, 2x and 4x this value
image_dim = 28 * 28     # a flattened 28 x 28 MNIST image (784 values)

# Optimization settings
batch_size = 64
num_epochs = 50
learning_rate = 2e-4    # Adam step size shared by both optimizers

## 3. Load and Prepare MNIST Dataset

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps
# them to [-1, 1], the same range as the Generator's Tanh output
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split, stored under ./data (downloaded if not already present)
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)

# Mini-batches of `batch_size` images, reshuffled every epoch
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

print(f'Number of batches: {len(train_loader)}')
print(f'Total images: {len(train_dataset)}')

## 4. Define the Generator Network

The Generator takes random noise as input and generates fake images.

In [None]:
class Generator(nn.Module):
    """Fully connected generator that maps noise vectors to flattened images.

    Hidden widths grow hidden_dim -> 2*hidden_dim -> 4*hidden_dim, each layer
    followed by LeakyReLU(0.2). A final Linear + Tanh produces image_dim values
    in [-1, 1].

    Args:
        latent_dim: Length of each input noise vector.
        hidden_dim: Width of the first hidden layer.
        image_dim: Number of output values per sample (784 for 28x28 MNIST).
    """

    def __init__(self, latent_dim, hidden_dim, image_dim):
        super().__init__()
        widths = [latent_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(in_features, out_features), nn.LeakyReLU(0.2)]
        # Output head: project to pixel space and squash into [-1, 1]
        layers += [nn.Linear(widths[-1], image_dim), nn.Tanh()]
        # The attribute name `model` and the layer order determine the
        # state_dict keys (model.0.weight, model.2.weight, ...) used by saved
        # checkpoints, so both are kept stable.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of noise vectors of shape (N, latent_dim) to (N, image_dim)."""
        return self.model(x)

## 5. Define the Discriminator Network

The Discriminator takes images as input and predicts whether they are real or fake.

In [None]:
class Discriminator(nn.Module):
    """Fully connected classifier scoring flattened images as real vs. fake.

    Widths shrink image_dim -> 4*hidden_dim -> 2*hidden_dim -> hidden_dim, each
    hidden layer followed by LeakyReLU(0.2) and Dropout(0.3). A final
    Linear + Sigmoid emits one probability per sample.

    Args:
        image_dim: Number of input values per sample (784 for 28x28 MNIST).
        hidden_dim: Width of the last hidden layer.
    """

    def __init__(self, image_dim, hidden_dim):
        super().__init__()
        widths = (image_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim)
        layers = []
        for in_features, out_features in zip(widths, widths[1:]):
            layers.extend(
                (nn.Linear(in_features, out_features), nn.LeakyReLU(0.2), nn.Dropout(0.3))
            )
        # Output head: one logit squashed to a probability in [0, 1]
        layers.extend((nn.Linear(widths[-1], 1), nn.Sigmoid()))
        # The attribute name `model` and the layer order determine the
        # state_dict keys (model.0.weight, model.3.weight, ...) used by saved
        # checkpoints, so both are kept stable.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (N, 1) tensor of probabilities that each input image is real."""
        return self.model(x)

## 6. Initialize Networks and Optimizers

In [None]:
# Instantiate both networks and move their parameters to the selected device
generator = Generator(latent_dim, hidden_dim, image_dim).to(device)
discriminator = Discriminator(image_dim, hidden_dim).to(device)

# Binary cross-entropy on the Discriminator's sigmoid output
# (training uses target 1 for real images and 0 for generated ones)
criterion = nn.BCELoss()

# Adam with beta1 = 0.5 instead of the default 0.9. This is the setting
# commonly recommended for GANs (DCGAN, Radford et al.): less momentum lets
# each optimizer react faster to the constantly shifting adversarial objective,
# which typically stabilizes training.
adam_betas = (0.5, 0.999)
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=adam_betas)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=adam_betas)

print("Generator Architecture:")
print(generator)
print("\nDiscriminator Architecture:")
print(discriminator)

## 7. Training Loop

Train the GAN by alternating between:
1. Training the Discriminator to distinguish real from fake images
2. Training the Generator to fool the Discriminator

In [None]:
# Per-epoch mean losses, consumed by the loss plot
G_losses = []
D_losses = []

# Fixed noise batch, reused after training to visualize the Generator's output
fixed_noise = torch.randn(64, latent_dim, device=device)

print("Starting Training...")
for epoch in range(num_epochs):
    # Ensure training mode (Discriminator dropout active) even if an
    # inference cell switched a network to eval mode on a previous run.
    generator.train()
    discriminator.train()

    d_loss_sum = 0.0
    g_loss_sum = 0.0
    samples_seen = 0

    for real_images, _ in train_loader:
        # Local name on purpose: assigning to `batch_size` here would overwrite
        # the global hyperparameter (the last batch of an epoch is smaller).
        n = real_images.size(0)
        real_images = real_images.view(-1, image_dim).to(device)

        # Targets for BCELoss: 1 = real, 0 = fake
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # ==================== Train Discriminator ====================
        # Real images should be scored as real
        outputs = discriminator(real_images)
        d_loss_real = criterion(outputs, real_labels)

        # Generated images should be scored as fake; detach() stops this loss
        # from back-propagating into the Generator
        noise = torch.randn(n, latent_dim, device=device)
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # ==================== Train Generator ====================
        noise = torch.randn(n, latent_dim, device=device)
        fake_images = generator(noise)
        outputs = discriminator(fake_images)

        # The Generator is rewarded when the Discriminator scores its output as real
        g_loss = criterion(outputs, real_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        # Sample-weighted accumulation so the smaller final batch counts proportionally
        d_loss_sum += d_loss.item() * n
        g_loss_sum += g_loss.item() * n
        samples_seen += n

    # Record the epoch mean rather than the loss of the last (noisy) batch
    epoch_d_loss = d_loss_sum / samples_seen
    epoch_g_loss = g_loss_sum / samples_seen
    G_losses.append(epoch_g_loss)
    D_losses.append(epoch_d_loss)

    if (epoch + 1) % 5 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {epoch_d_loss:.4f}, g_loss: {epoch_g_loss:.4f}')

print("Training completed!")

## 8. Plot Training Losses

In [None]:
# One line per network; each list holds one value per epoch
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(G_losses, label='Generator Loss')
ax.plot(D_losses, label='Discriminator Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('GAN Training Losses')
ax.legend()
ax.grid(True)
plt.show()

## 9. Generate and Visualize Images

In [None]:
# Generate images from the fixed noise batch used during training.
# eval() has no effect on this fully connected Generator (it has no dropout or
# batch norm), but it is the correct inference mode. The previous mode is
# restored afterwards so later training is not affected.
was_training = generator.training
generator.eval()
with torch.no_grad():
    fake_images = generator(fixed_noise).cpu().view(-1, 28, 28)
generator.train(was_training)

# One subplot per generated image, 8 per row (64 samples -> 8 x 8 grid)
n_cols = 8
n_rows = (len(fake_images) + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 1.5 * n_rows), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i < len(fake_images):
        # Fix the colour scale to the Tanh range [-1, 1] so brightness is
        # comparable across images instead of being stretched per image
        ax.imshow(fake_images[i], cmap='gray', vmin=-1, vmax=1)
    ax.axis('off')
plt.suptitle('Generated Images from GAN', fontsize=16)
plt.tight_layout()
plt.show()

## 10. Compare Real vs Generated Images

In [None]:
# Two rows of eight real images above two rows of eight generated images;
# the 4 x 8 grid and the title below assume n_show = 16 and n_cols = 8.
n_show = 16
n_cols = 8

# One shuffled batch of real images (normalized to [-1, 1] by the transform)
real_batch = next(iter(train_loader))
real_images_sample = real_batch[0][:n_show]

# Generate exactly as many fake images as will be shown; no gradients needed
with torch.no_grad():
    noise = torch.randn(n_show, latent_dim, device=device)
    fake_images_sample = generator(noise).cpu().view(-1, 28, 28)

fig, axes = plt.subplots(4, n_cols, figsize=(12, 6))
fig.suptitle('Top 2 rows: Real Images | Bottom 2 rows: Generated Images', fontsize=14)

for i in range(n_show):
    row, col = divmod(i, n_cols)
    # Shared [-1, 1] colour scale so real and generated digits are compared on
    # the same intensity range rather than each being auto-stretched
    real_ax = axes[row, col]
    real_ax.imshow(real_images_sample[i].squeeze(), cmap='gray', vmin=-1, vmax=1)
    real_ax.axis('off')

    fake_ax = axes[row + 2, col]
    fake_ax.imshow(fake_images_sample[i], cmap='gray', vmin=-1, vmax=1)
    fake_ax.axis('off')

plt.tight_layout()
plt.show()

## 11. Save the Trained Models (Optional)

In [None]:
# Persist only each network's learned parameters (its state_dict)
for net, checkpoint_path in ((generator, 'generator.pth'), (discriminator, 'discriminator.pth')):
    torch.save(net.state_dict(), checkpoint_path)
print("Models saved successfully!")

## Conclusion

This notebook demonstrates a basic GAN implementation in PyTorch. The Generator learns to produce increasingly realistic handwritten digits, while the Discriminator learns to distinguish real images from generated ones.

### Key Takeaways:
- GANs consist of two competing networks
- Training involves alternating updates to both networks
- The Generator improves by trying to fool the Discriminator
- The Discriminator improves by better distinguishing real from fake

### Possible Improvements:
- Replace the fully connected layers with convolutional networks (DCGAN)
- Implement techniques like batch normalization
- Try different datasets (CelebA, CIFAR-10)
- Experiment with alternative training objectives and architectures (e.g., WGAN, StyleGAN)