<div
  style="
    background-color: #f0f0f0;
    color:rgb(56, 56, 56);
    padding: 8px;
    display: flex;
    align-items: center;
    gap: 100px;
  "
>
  <img src="./images/brand.svg" style="max-height: 80px;">
  <strong>
    AI Saga: Deep Learning &amp; Generative AI<br>
    Lab 2. GAN Explorer: CNNs as Discriminators and Bias Detection<br>
  </strong>
  <em>
    Student Name: [Fill in your name]<br>
    Date: [Fill in the submission date]<br>
  </em>
</div>

## Background

This lab explores how CNNs function as "critics" in generative AI systems. You'll work with a Generative Adversarial Network (GAN), modify its CNN-based discriminator, and analyze the model for biases. This hands-on experience demonstrates the dual nature of CNNs: not just as classifiers, but as quality evaluators guiding generation.

---

## Description

This lab is structured around two parts.

### Part 1: Understanding the GAN Discriminator

You will:
- Load the Fashion-MNIST dataset and set up data pipelines
- Review a provided generator architecture
- **Implement** a CNN discriminator from scratch
- **Complete** the adversarial training loop
- **Visualize** training dynamics (loss curves, generated samples, discriminator confidence)
- **Experiment** by modifying the discriminator architecture and observing the impact

**Key Deliverables:**
- Working discriminator implementation
- Complete GAN training loop
- Visualizations of training progress
- Before/after comparison of architecture modification
- Answers to 2 reflection questions

### Part 2: Bias Detection and Critical Analysis

You will:
- Analyze the Fashion-MNIST class distribution
- Train a simple classifier to label generated images
- **Identify** which clothing categories are over/under-represented in generated samples
- **Document** 2-3 specific bias examples with evidence (e.g., class imbalance, quality variance, mode collapse)
- **Propose and test** one mitigation strategy (e.g., weighted sampling, conditional generation, architecture changes)
- **Reflect** on ethical implications of deploying biased generative models

**Key Deliverables:**
- Class distribution analysis (real vs. generated)
- 2-3 documented bias examples with:
  - Clear description
  - Quantitative evidence
  - Potential source analysis
  - Impact assessment
- One tested mitigation strategy with results
- Ethical reflection (3-4 sentences)

---

## Instructions

1. **Work through sections sequentially** - Part 1 must be completed before Part 2
2. **Complete all code cells marked with `# TODO`** - These are required for full credit
3. **Answer all reflection questions** - Use the designated markdown cells
4. **Document your findings** - Use clear visualizations and written analysis

In [None]:
# ⚠️ IMPORTANT NOTICE FOR STUDENTS ⚠️
#
# Please make sure to check the official instructions for this assignment in Canvas LMS
# as they may have been updated or changed. The instructions above are provided for
# reference only and may not reflect the most current requirements.
#
# Always refer to Canvas LMS for:
# - Latest assignment requirements
# - Due dates
# - Grading criteria
# - Any special instructions
#
# When in doubt, ask your instructor for clarification.

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from collections import defaultdict
from tqdm import tqdm

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Part 1: Understanding the GAN Discriminator

### 1.1 Load Fashion-MNIST Dataset

Fashion-MNIST contains 70,000 grayscale images (28x28) across 10 clothing categories, split into 60,000 training images and 10,000 test images.
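
The data cell below asks for pixel values in the [-1, 1] range. As a minimal sketch of the arithmetic, assuming `transforms.ToTensor()` has already scaled pixels to [0, 1] and that the normalization uses a mean and standard deviation of 0.5 (illustrative values, not prescribed by this lab), the mapping and its inverse could look like this. The helper names `normalize` and `denormalize` are hypothetical.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Map a pixel in [0, 1] to [-1, 1] via (pixel - mean) / std."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Invert normalize(): map a value in [-1, 1] back to [0, 1]."""
    return value * std + mean


# Assumed mean/std of 0.5; other values would give a different range.
for pixel in (0.0, 0.25, 0.5, 1.0):
    scaled = normalize(pixel)
    print(f"{pixel:.2f} -> {scaled:+.2f} -> {denormalize(scaled):.2f}")
```

With these assumed values, `denormalize` reduces to `(value + 1) / 2`, which is the expression used to display generated images in section 1.7.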

In [None]:
# TODO: Define data transformations
# Hint: Use transforms.ToTensor() and normalize to the [-1, 1] range. Why do we need to do this?
transform = transforms.Compose(
    [
        # Your code here
    ]
)

# TODO: Load Fashion-MNIST dataset
train_dataset = datasets.FashionMNIST(
    # Your code here
)

# TODO: Create DataLoader with batch_size=128
train_loader = DataLoader(
    # Your code here
)


# TODO: Fashion-MNIST class names
class_names = [
    # Your code here
]

print(f"Training samples: {len(train_dataset)}")
print(f"Number of batches: {len(train_loader)}")

### 1.2 Define the Generator Network

The generator takes random noise (latent vector) and transforms it into a 28x28 image.

**This is provided for you - review the architecture carefully!**
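
To see how the layer settings produce a 28x28 image, the spatial size can be traced by hand. The sketch below applies the standard transposed-convolution size formula (assuming dilation 1 and no output padding) to the kernel, stride, and padding values used in the `Generator` class below. `conv_transpose_out` is a hypothetical helper written only for this illustration.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output size of a transposed convolution (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) of the three ConvTranspose2d layers, in order.
generator_layers = [(7, 1, 0), (4, 2, 1), (4, 2, 1)]

size = 1  # the latent vector is reshaped to a 1x1 spatial map
sizes = [size]
for kernel, stride, padding in generator_layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(" -> ".join(f"{s}x{s}" for s in sizes))
```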

In [None]:
class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim

        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 256, 7, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.model(z)


# Test generator
generator = Generator(latent_dim=100).to(device)
test_noise = torch.randn(4, 100).to(device)
test_output = generator(test_noise)
print(f"Generator output shape: {test_output.shape}")

### 1.3 Define the CNN Discriminator Network

Complete the discriminator implementation.

The discriminator is a CNN that learns to distinguish between real and generated images.

#### Key Requirements:
- Input: 28x28 grayscale image (1 channel)
- Use Conv2d layers with stride=2 for downsampling
- Use LeakyReLU activation (negative_slope=0.2)
- Output: Single probability value (Sigmoid activation)
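
The downsampling requirement can be checked on paper before writing any layers. The sketch below traces the spatial size through a stack of convolutions using the standard Conv2d size formula (assuming dilation 1). The kernel sizes and paddings are assumptions chosen to reproduce the size progression listed in the TODO comments of the cell below; the lab does not prescribe them, and `conv_out` is a hypothetical helper.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a standard convolution (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


# Assumed (kernel, stride, padding) per layer; other choices can also work.
assumed_layers = [
    (4, 2, 1),  # layer 1
    (4, 2, 1),  # layer 2
    (4, 2, 1),  # layer 3
    (3, 1, 0),  # final layer: collapses the remaining map to 1x1
]

size = 28
sizes = [size]
for kernel, stride, padding in assumed_layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(" -> ".join(f"{s}x{s}" for s in sizes))
```

Note the integer division: a 7x7 map with these settings rounds down to 3x3 rather than 3.5x3.5.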

In [None]:
class Discriminator(nn.Module):
    def __init__(self, num_layers=3, base_channels=64):
        super(Discriminator, self).__init__()

        # TODO: Build the discriminator architecture
        # Layer 1: 1 -> base_channels (28x28 -> 14x14)
        # Layer 2: base_channels -> base_channels*2 (14x14 -> 7x7)
        # Layer 3: base_channels*2 -> base_channels*4 (7x7 -> 3x3)
        # Final: base_channels*4 -> 1 (3x3 -> 1x1)

        # Hint: Use nn.Sequential to build the model
        self.model = nn.Sequential(
            # Your code here
        )

    def forward(self, img):
        # TODO: Pass image through model and return single probability
        # Hint: Use .view() or .squeeze() to get correct output shape
        pass


# Test discriminator
discriminator = Discriminator(num_layers=3, base_channels=64).to(device)
test_input = torch.randn(4, 1, 28, 28).to(device)
test_output = discriminator(test_input)
print(f"Discriminator output shape: {test_output.shape}")  # Should be [4]
print(f"Sample outputs: {test_output}")

#### REFLECTION QUESTION 1:

Compare this discriminator to a typical CNN classifier (e.g., for MNIST digit recognition):
- What's different about the final layer?
- Why do we use Sigmoid activation at the end instead of Softmax?
- What does the output represent?

**Your Answer:**

_(Write your response here - 3-5 sentences)_

### 1.4 Training Setup

In [None]:
# Hyperparameters
latent_dim = 100
lr = 0.0002
beta1 = 0.5
num_epochs = 10

# TODO: Initialize models
generator = # Your code here
discriminator = # Your code here

# TODO: Define loss function (Binary Cross Entropy)
criterion = # Your code here

# TODO: Define optimizers (Adam with lr=0.0002, betas=(0.5, 0.999))
optimizer_G = # Your code here
optimizer_D = # Your code here

print("Training setup complete!")

### 1.5 GAN Training Loop

Complete the adversarial training loop.

**Adversarial Training Process:**
1. **Train Discriminator**: Learn to distinguish real from fake
2. **Train Generator**: Learn to fool the discriminator
3. Repeat for `num_epochs` epochs, aiming for an equilibrium in which neither network dominates
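
The two training steps above evaluate the same loss function against opposite targets. As a rough illustration in plain Python (not the PyTorch code this lab asks for), the sketch below computes binary cross-entropy on two made-up discriminator scores. Summing the real and fake terms for the discriminator is an assumption; averaging them is also common.

```python
import math


def bce(prob, target):
    """Binary cross-entropy for a single prediction strictly inside (0, 1)."""
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))


# Hypothetical discriminator outputs, chosen only for illustration.
d_on_real = 0.9  # D(x): score for a real image
d_on_fake = 0.2  # D(G(z)): score for a generated image

# Discriminator step: real images are labelled 1, generated images 0.
d_loss = bce(d_on_real, 1.0) + bce(d_on_fake, 0.0)

# Generator step: the fake score is compared against the "real" label,
# so this loss shrinks as D(G(z)) approaches 1.
g_loss = bce(d_on_fake, 1.0)

print(f"d_loss={d_loss:.4f}, g_loss={g_loss:.4f}")
```

In this toy case the discriminator is doing well (small `d_loss`) precisely because the generator is doing badly (large `g_loss`).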

In [None]:
# Storage for tracking
history = {
    'd_loss': [],
    'g_loss': [],
    'd_real_acc': [],
    'd_fake_acc': [],
    'generated_samples': []
}

# Fixed noise for visualization
fixed_noise = torch.randn(64, latent_dim).to(device)

for epoch in range(num_epochs):
    epoch_d_loss = 0
    epoch_g_loss = 0
    epoch_d_real_acc = 0
    epoch_d_fake_acc = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
    
    for i, (real_imgs, _) in enumerate(pbar):
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        
        # TODO: Create labels (real=1, fake=0)
        real_labels = # Your code here
        fake_labels = # Your code here
        
        # ---------------------
        #  Train Discriminator
        # ---------------------
        # TODO: Zero gradients
        
        # TODO: Get discriminator output on real images
        real_output = # Your code here
        
        # TODO: Calculate loss on real images
        d_loss_real = # Your code here
        
        # TODO: Generate fake images
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = # Your code here (use generator)
        
        # TODO: Get discriminator output on fake images
        fake_output = # Your code here (use .detach() on fake_imgs!)
        
        # TODO: Calculate loss on fake images
        d_loss_fake = # Your code here
        
        # TODO: Total discriminator loss
        d_loss = # Your code here
        
        # TODO: Backpropagate and update discriminator
        
        # -----------------
        #  Train Generator
        # -----------------
        # TODO: Zero gradients
        
        # TODO: Generate new fake images
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = # Your code here
        
        # TODO: Get discriminator output
        fake_output = # Your code here
        
        # TODO: Generator loss (wants discriminator to think fakes are real!)
        g_loss = # Your code here (hint: use real_labels)
        
        # TODO: Backpropagate and update generator
        
        # Track metrics
        epoch_d_loss += d_loss.item()
        epoch_g_loss += g_loss.item()
        epoch_d_real_acc += (real_output > 0.5).float().mean().item()
        epoch_d_fake_acc += (fake_output < 0.5).float().mean().item()
    
    # Store epoch metrics
    num_batches = len(train_loader)
    history['d_loss'].append(epoch_d_loss / num_batches)
    history['g_loss'].append(epoch_g_loss / num_batches)
    history['d_real_acc'].append(epoch_d_real_acc / num_batches)
    history['d_fake_acc'].append(epoch_d_fake_acc / num_batches)
    
    # Generate samples for visualization
    with torch.no_grad():
        samples = generator(fixed_noise).cpu()
        history['generated_samples'].append(samples)
    
    print(f"Epoch {epoch+1}: D_loss={history['d_loss'][-1]:.4f}, G_loss={history['g_loss'][-1]:.4f}")

print("Training complete!")

### 1.6 Visualize Training Progress

In [None]:
# TODO: Create two subplots showing:
# 1. Discriminator and Generator loss over epochs
# 2. Discriminator accuracy on real vs. fake images

fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Your visualization code here

plt.tight_layout()
plt.show()

#### REFLECTION QUESTION 2:

Looking at the loss curves:
- What would it mean if the discriminator loss goes to zero?
- What would it mean if the generator loss goes to zero?
- Why is it important that both losses stay relatively balanced?

**Your Answer:**

_(Write your response here - 5-7 sentences)_

### 1.7 Visualize Generated Samples at Different Training Stages

In [None]:
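def show_generated_images(samples, epoch, n_images=16):
    """Display up to n_images generated images in a square grid."""
    grid = int(np.sqrt(n_images))  # e.g. n_images=16 gives a 4x4 grid
    fig, axes = plt.subplots(grid, grid, figsize=(8, 8), squeeze=False)
    fig.suptitle(f"Generated Samples - Epoch {epoch}", fontsize=14)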

    for idx, ax in enumerate(axes.flat):
        img = samples[idx].squeeze().numpy()
        img = (img + 1) / 2  # Denormalize
        ax.imshow(img, cmap="gray")
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# TODO: Show samples at beginning, middle, and end of training
checkpoints = [0, num_epochs // 2, num_epochs - 1]
for checkpoint in checkpoints:
    show_generated_images(history["generated_samples"][checkpoint], checkpoint + 1)

### 1.8 Discriminator Confidence Analysis

In [None]:
# TODO: Evaluate discriminator on 1000 real and 1000 fake samples
# Create a histogram showing the distribution of discriminator scores

discriminator.eval()
real_scores = []
fake_scores = []

with torch.no_grad():
    # Your code here
    pass

# TODO: Plot histogram
plt.figure(figsize=(10, 5))
# Your plotting code here
plt.show()

print(f"Average discriminator score for REAL images: {np.mean(real_scores):.3f}")
print(f"Average discriminator score for FAKE images: {np.mean(fake_scores):.3f}")

### 1.9 EXPERIMENT: Modify the Discriminator Architecture

Modify the discriminator and observe the impact.

Try ONE of these modifications:
- Change `num_layers` (2 or 4 instead of 3)
- Change `base_channels` (32 or 128 instead of 64)
- Add dropout layers

Train for 5 epochs and compare results.
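
When changing `base_channels`, it can help to know roughly how much the discriminator's capacity changes. The sketch below counts convolution weights and biases for an assumed layout: three 4x4 convolutions followed by a 3x3 output convolution, all with bias and without normalization layers. Your own architecture may differ, so treat the numbers as an order-of-magnitude guide. `conv_params` and `discriminator_params` are hypothetical helpers.

```python
def conv_params(in_ch, out_ch, kernel, bias=True):
    """Number of learnable parameters in one 2D convolution layer."""
    return out_ch * in_ch * kernel * kernel + (out_ch if bias else 0)


def discriminator_params(base_channels):
    """Total conv parameters for the assumed three-layer-plus-head layout."""
    layers = [
        (1, base_channels, 4),
        (base_channels, base_channels * 2, 4),
        (base_channels * 2, base_channels * 4, 4),
        (base_channels * 4, 1, 3),  # output head
    ]
    return sum(conv_params(i, o, k) for i, o, k in layers)


for base in (32, 64, 128):
    print(f"base_channels={base:>3}: {discriminator_params(base):,} parameters")
```

Under these assumptions, doubling `base_channels` roughly quadruples the parameter count, because most weights sit in layers whose input and output widths both scale with it.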

In [None]:
# TODO: Create modified discriminator and retrain
print("EXPERIMENT: Modified Discriminator")
print("=" * 50)

# Your experimental code here

#### EXPERIMENT Findings:

**Modification Tested:** _(describe what you changed)_

**Observed Impact:**

_(Write 5-7 sentences describing what happened when you modified the discriminator. Compare generation quality, training stability, and loss curves to the original model.)_

---

## Part 2: Bias Detection and Critical Analysis

### 2.1 Dataset Analysis: Class Distribution

In [None]:
# TODO: Count samples per class in Fashion-MNIST
class_counts = defaultdict(int)

# Your code here

# TODO: Create a bar chart showing class distribution
plt.figure(figsize=(12, 5))
# Your plotting code here
plt.show()

### 2.2 Generation Quality Analysis by Class

We'll train a simple classifier to label generated images, then analyze which classes are generated most and least often.
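
The comparison itself is simple bookkeeping once labels are available. As a minimal sketch, assuming the real labels and the classifier's predictions have been collected into plain Python lists, the per-class percentages and their difference could be computed as follows. The toy label lists and the helper `class_percentages` are hypothetical and do not represent Fashion-MNIST results.

```python
from collections import Counter


def class_percentages(labels, num_classes):
    """Return the percentage of labels that fall in each class index."""
    counts = Counter(labels)
    total = len(labels)
    return [100.0 * counts.get(c, 0) / total for c in range(num_classes)]


# Hypothetical labels for a 3-class toy problem.
real_class_labels = [0, 1, 2] * 4                     # perfectly balanced
generated_class_labels = [0] * 6 + [1] * 4 + [2] * 2  # skewed toward class 0

real_pct = class_percentages(real_class_labels, 3)
gen_pct = class_percentages(generated_class_labels, 3)

for c, (r, g) in enumerate(zip(real_pct, gen_pct)):
    print(f"class {c}: real={r:5.1f}%  generated={g:5.1f}%  diff={g - r:+5.1f}")
```

Keep in mind that the generated percentages inherit the classifier's own errors, so a small difference may reflect misclassification rather than generator bias.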

In [None]:
# Simple classifier (provided)
class SimpleClassifier(nn.Module):
    def __init__(self):
        super(SimpleClassifier, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(), nn.Linear(64 * 7 * 7, 128), nn.ReLU(), nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)


# TODO: Train the classifier for 3 epochs
print("Training classifier...")
classifier = SimpleClassifier().to(device)

# Your training code here

print("Classifier training complete!")

In [None]:
# TODO: Generate 5000 samples and classify them
classifier.eval()
generator.eval()

generated_class_counts = defaultdict(int)

# Your code here

# TODO: Create side-by-side bar charts comparing real vs. generated distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Your plotting code here
plt.show()

# Print comparison table
print("\nGeneration Distribution Analysis:")
print(f"{'Class':<15} {'Real %':>10} {'Generated %':>15} {'Difference':>15}")
print("-" * 60)
# Your code here

### 2.3 Identify Failure Modes and Bias Examples

Find and visualize 2-3 specific examples of bias or quality issues.
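
One way to organize this search is to group low-scoring samples by their predicted class. The sketch below shows only the grouping logic on made-up records; in the lab, the scores would come from the discriminator and the class names from the classifier. The record layout and the `THRESHOLD` name are assumptions for illustration.

```python
from collections import defaultdict

THRESHOLD = 0.3  # score below which a sample counts as low quality

# Hypothetical (sample_id, predicted_class, discriminator_score) records.
records = [
    (0, "Sandal", 0.12),
    (1, "Shirt", 0.55),
    (2, "Sandal", 0.28),
    (3, "Bag", 0.81),
    (4, "Shirt", 0.05),
]

failures = defaultdict(list)
for sample_id, predicted_class, score in records:
    if score < THRESHOLD:
        failures[predicted_class].append((sample_id, score))

# List classes with the most low-score samples first.
for name, items in sorted(failures.items(), key=lambda kv: -len(kv[1])):
    print(f"{name}: {len(items)} low-score samples {items}")
```

Raw counts are easier to interpret when divided by how many samples of each class were generated, since a frequently generated class will also collect more failures.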

In [None]:
# TODO: Generate samples, score them with the discriminator,
# and collect low-quality examples (discriminator score < 0.3)

failure_examples = defaultdict(list)

# Your code here

# TODO: Visualize failure examples for classes with poor generation

### 2.4 Bias Documentation

Document 2-3 specific bias examples with evidence.
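
A single summary number can support the quantitative evidence. One option, offered here as a suggestion rather than a requirement of this lab, is the normalized Shannon entropy of the generated class counts: values near 1 indicate an even spread across classes, while values near 0 suggest that a few classes dominate, as in mode collapse. The count vectors below are invented, and `normalized_entropy` is a hypothetical helper.

```python
import math


def normalized_entropy(counts):
    """Shannon entropy of a count vector, scaled to [0, 1] by log(num_classes).

    Assumes at least two classes and a non-zero total.
    """
    total = sum(counts)
    probs = [c / total for c in counts if c > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    return entropy / math.log(len(counts))


balanced = [500] * 10           # hypothetical: every class equally frequent
collapsed = [4910] + [10] * 9   # hypothetical: one class dominates

print(f"balanced:  {normalized_entropy(balanced):.3f}")
print(f"collapsed: {normalized_entropy(collapsed):.3f}")
```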

#### BIAS EXAMPLE 1:

**Description:**

_(Describe the bias you observed - e.g., class imbalance, quality variance, etc.)_

**Evidence:**

_(Provide specific numbers, percentages, or observations from your analysis)_

**Potential Source:**

_(What might be causing this bias? Consider dataset, architecture, or training dynamics)_

**Impact:**

_(If deployed in a real application, how would this bias affect users or outcomes?)_

#### BIAS EXAMPLE 2:

**Description:**

_(Your analysis here)_

**Evidence:**

_(Your data here)_

**Potential Source:**

_(Your explanation here)_

**Impact:**

_(Your assessment here)_

#### BIAS EXAMPLE 3 (Optional):

**Description:**

**Evidence:**

**Potential Source:**

**Impact:**

### 2.5 Proposed Mitigation Strategy

Choose ONE mitigation approach and test it.
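
If you pick weighted sampling, the core idea is to draw rare classes more often. The sketch below illustrates inverse-frequency weighting with the standard library only, on an invented two-class label list. In PyTorch, per-sample weights of this kind can be passed to `torch.utils.data.WeightedRandomSampler`; wiring that into your `DataLoader` is left to you. `inverse_frequency_weights` is a hypothetical helper.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """Give each sample a weight of 1 / (number of samples sharing its label)."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# Hypothetical imbalanced labels: class 0 is four times as common as class 1.
labels = [0] * 80 + [1] * 20
weights = inverse_frequency_weights(labels)

rng = random.Random(0)  # fixed seed so the sketch is repeatable
drawn = rng.choices(labels, weights=weights, k=10_000)
print(Counter(drawn))
```

Because every class ends up with the same total weight, the draws come out roughly balanced even though the underlying list is not. The trade-off is that samples from rare classes are repeated more often.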

#### MITIGATION STRATEGY:

**Approach:**

_(Describe what mitigation strategy you chose - e.g., weighted sampling, conditional GAN, architecture changes, etc.)_

**Implementation:**

_(Explain how you implemented this mitigation)_

**Expected Impact:**

_(What improvements do you expect this to bring?)_

In [None]:
# TODO: Implement your chosen mitigation strategy
print("MITIGATION TEST")
print("=" * 60)

# Your code here

#### MITIGATION RESULTS:

**Effectiveness:**

_(Did your mitigation work? What improved? What didn't?)_

**Trade-offs:**

_(What are the pros and cons of this approach? Any new problems introduced?)_

### 2.6 Ethical Reflection

#### ETHICAL REFLECTION (3-4 sentences):

**What are the ethical implications of deploying a generative model with the biases you identified?**

_(Consider: Who might be harmed? How could biases be amplified? What responsibilities do developers have?)_

**Your reflection:**



