# VisionGen-Comparative-Study (Part 1)

_A Comparative Analysis of CGANs and Diffusion Models for Conditional Image Synthesis_

---

## Motivation

Generative models are transforming the field of computer vision, enabling AI systems to generate realistic images from simple inputs like class labels or text prompts. This project investigates:
- How do CGANs and Diffusion Models perform on classic and challenging datasets?
- What are the differences in sample quality, training stability, and controllability?
- Which model is more suitable for high-fidelity, class-conditional image synthesis?

## Approach

- **CGANs:**  
  Implemented and trained on MNIST (digits), Oxford-102 Flowers, and CUB-200 Birds datasets, using fully connected and convolutional architectures.

- **Diffusion Models:**  
  Integrated modern diffusion pipelines (e.g., DDPM, Stable Diffusion) using open-source libraries. Evaluated on the same datasets for direct comparison.

- **Comparative Analysis:**  
  Results evaluated both quantitatively (FID, Inception Score) and qualitatively (side-by-side image grids, training dynamics).

## Key Features

- End-to-end PyTorch and HuggingFace-based implementation
- Clean code, modular design, and ready-to-run Jupyter notebooks
- Direct, apples-to-apples comparison of CGANs and Diffusion Models
- Sample outputs, evaluation metrics, and visualizations included

---

## Author

**Ayushman Mishra**  
[GitHub: frMishR](https://github.com/frMishR)

---

## Research Inspiration

This project draws inspiration from foundational work in generative modeling, especially:

- **Conditional Generative Adversarial Nets (Mirza & Osindero, 2014):**  
  [arXiv:1411.1784](https://arxiv.org/abs/1411.1784)  
  The original CGAN paper provided the conceptual and technical basis for the CGAN implementations.

Recent advances in diffusion models and open-source research also guided the design of experiments and code.

---

## Project Origin

This project originated as a group submission for:

- **Course:** EEE 598: Generative AI – Theory and Practice  
- **Professor:** Dr. Lalitha Sankar (Arizona State University)  
- **Semester:** Spring 2025

### Original Contributors

- **Ayushman Mishra** (Solo Upgrade & Comparative Study, 2025)
- **Snavya Sai Munti Mudugu Badri Prasad** [GitHub: snavya0309](https://github.com/snavya0309)
- **Sushma Niresh** [GitHub: SushmaNiresh](https://github.com/SushmaNiresh)

> *Note: This repository represents a major solo upgrade and extension by Ayushman Mishra. All diffusion modeling, comparative analysis, and new documentation were developed independently by Ayushman Mishra (July–September 2025). The original CGAN implementation and dataset preparation were developed collaboratively by the above group.*

---

# VisionGen – Conditional DCGAN on Multi-Class Datasets

> **Notebook Focus**: This notebook implements and trains a **Conditional DCGAN (cDCGAN)** — a class-conditional variant of the Deep Convolutional GAN architecture — on curated image datasets such as **MNIST**, **Oxford 102 Flowers**, and **CUB-200 Birds**.

---

## Objective

This notebook explores **conditional image generation** using GANs. Unlike standard GANs that generate images without control, **Conditional GANs (cGANs)** allow us to **guide the generation process using class labels** (e.g., “rose,” “eagle,” “digit 3”).

We specifically adopt a **DCGAN-style architecture**, integrating label conditioning into both:
- The **Generator**, which takes noise + label as input and generates class-specific images
- The **Discriminator**, which receives an image together with its label and learns to classify whether the image is real or fake **for that label**

---

## Key Concepts Covered

- **Label conditioning**: Injecting class information into both G and D as one-hot label vectors (no learned embedding layer is used)
- **DCGAN Architecture**: Convolutional generator and discriminator for stable image synthesis
- **Multi-class Datasets**: Trained on MNIST (digits), Oxford Flowers (102 classes), and CUB Birds (200 fine-grained classes)
- **Loss Functions**: Binary Cross Entropy (BCE) for adversarial loss
- **Visualization**: Generate and compare outputs for specific class labels
- **Training Dynamics**: Track loss curves, sample quality, and class-conditional accuracy visually

---

## Notebook Structure

1. **Imports & Setup** – Libraries, device config, seeds
2. **Dataset Handling** – Class-wise folder structure, transformations, label encoding
3. **Conditional DCGAN Architecture**
   - `Generator`: Accepts noise + one-hot class label → image
   - `Discriminator`: Accepts image + one-hot class label → real/fake score
4. **Training Loop** – Adversarial training with label guidance
5. **Sample Generation** – Periodic image sampling for visual inspection
6. **(Optional/Planned)**: FID / Inception Score metrics, Diffusion model comparison

---

## Why this matters

This notebook demonstrates:
- A full **end-to-end GAN training pipeline** with class-conditioning
- How to **scale to large multi-class datasets**
- The **visual power of generative models** in robotics, vision, and simulation use-cases

It sets the foundation for the upcoming **Diffusion model extension**, which will be added to the same notebook for comparative study.

---

### Setting up GPU.

In [None]:
import torch
print("Device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU (no CUDA GPU detected)")

### Imports & Setup.

- **Python**: 3.11.8
- **Key libraries detected**: PIL, difflib, math, matplotlib, numpy, os, random, scipy, sklearn, tarfile, torch, torchmetrics, torchvision, tqdm, urllib

In [None]:
!pip install torch torchvision numpy matplotlib pillow tqdm scikit-learn scipy torchmetrics --quiet

In [None]:
import os, random, json, numpy as np, matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid, save_image

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

# reproducibility
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

## 1. Flower Dataset (Oxford-102)

### Oxford-102 Configurations

In [None]:
dataset_name = "Oxford102"

# Hyperparameters
batch_size  = 64            # Flower images are larger than MNIST digits, so use a smaller batch
z_dim       = 100           # Latent vector dimension
y_dim       = 102           # 102 flower categories
img_size    = 64            # Resize all to 64x64
img_channels= 3             # RGB images
img_dim     = img_size * img_size * img_channels
lr          = 0.0002
epochs      = 200           # Same as MNIST

save_dir = f"./outputs/{dataset_name}"
os.makedirs(save_dir, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {DEVICE}")

### Oxford-102 Dataset Download

In [None]:
import os
import urllib.request
import tarfile
import scipy.io

# Create target directory
oxford102_root = "./data/Oxford102"
os.makedirs(oxford102_root, exist_ok=True)

images_url = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz"
labels_url = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat"

# Target paths
images_tar_path = os.path.join(oxford102_root, "102flowers.tgz")
labels_mat_path = os.path.join(oxford102_root, "imagelabels.mat")

# Download images archive if not already downloaded
if not os.path.exists(images_tar_path):
    print("Downloading 102flowers.tgz...")
    urllib.request.urlretrieve(images_url, images_tar_path)
    print("Downloaded images archive!")

# Extract images if not already extracted
jpg_folder = os.path.join(oxford102_root, "jpg")
if not os.path.exists(jpg_folder):
    print("Extracting images...")
    with tarfile.open(images_tar_path, "r:gz") as tar:
        tar.extractall(path=oxford102_root)
    print("Extraction complete!")

# Download labels if not already downloaded
if not os.path.exists(labels_mat_path):
    print("Downloading imagelabels.mat...")
    urllib.request.urlretrieve(labels_url, labels_mat_path)
    print("Downloaded labels file!")

print("Oxford-102 dataset downloaded.")

### Oxford-102 Dataset Loader

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import scipy.io

# Define transformations
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Paths
oxford102_root = "./data/Oxford102"
image_dir = os.path.join(oxford102_root, "jpg")
labels_mat_path = os.path.join(oxford102_root, "imagelabels.mat")

# Custom Dataset Class
class Oxford102Dataset(Dataset):
    def __init__(self, image_dir, labels_mat_path, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        
        # Load labels
        labels_data = scipy.io.loadmat(labels_mat_path)
        self.labels = labels_data["labels"][0]  # (8189,)
        
        # Load image file names
        self.image_files = sorted(os.listdir(image_dir))
        
        assert len(self.labels) == len(self.image_files), "Mismatch between number of labels and images!"
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert("RGB")
        label = int(self.labels[idx]) - 1  # imagelabels.mat is 1-indexed; cast from uint8 and shift to 0..101
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Instantiate dataset and dataloader
oxford_dataset = Oxford102Dataset(
    image_dir=image_dir,
    labels_mat_path=labels_mat_path,
    transform=transform
)

oxford_loader = DataLoader(
    oxford_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    pin_memory=True,
)

print(f"Oxford-102 dataset ready: {len(oxford_dataset)} images across {y_dim} classes.")
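
The loader above pairs `labels[i]` with the i-th entry of `sorted(os.listdir(image_dir))`. That pairing is only correct when lexicographic order matches numeric order, which holds for zero-padded file names (assumed here to look like `image_00001.jpg`). A minimal sketch of that check on made-up names; `sorted_matches_numeric` is a hypothetical helper, not part of the notebook:

```python
def sorted_matches_numeric(names):
    """Return True if a plain string sort equals a sort by the embedded integer."""
    def number(name):
        # Pull the digits out of e.g. "image_00042.jpg" -> 42
        return int("".join(ch for ch in name if ch.isdigit()))
    return sorted(names) == sorted(names, key=number)

# Hypothetical file names, with and without zero padding
padded = [f"image_{i:05d}.jpg" for i in (1, 2, 10, 100)]
unpadded = [f"image_{i}.jpg" for i in (1, 2, 10, 100)]

print(sorted_matches_numeric(padded))    # True: padding keeps the order numeric
print(sorted_matches_numeric(unpadded))  # False: "image_10.jpg" sorts before "image_2.jpg"
```

If the check ever fails on a real folder, sorting with a numeric key instead of the default string sort would keep images and labels aligned.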

### The D & G Setup (Generator64 and Discriminator64) for Oxford-102
###### (The conditioning scheme follows the original cGAN of Mirza & Osindero; the layers are convolutional because the images are 64×64×3 rather than 28×28×1.)

In [None]:
import torch.nn as nn
import torch

# Utility function to expand one-hot labels spatially
def expand_y(y, h, w):
    """Expand y: (B,C) -> (B,C,h,w) by spatial tiling."""
    return y.view(y.size(0), y.size(1), 1, 1).expand(-1, -1, h, w)

# Generator for 64x64 images
class Generator64(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(z_dim + y_dim, 4 * 4 * 512),
            nn.BatchNorm1d(4 * 4 * 512),
            nn.ReLU(True)
        )
        self.conv_blocks = nn.ModuleList([
            nn.ConvTranspose2d(512 + y_dim, 256, 4, 2, 1, bias=False),  # 8x8
            nn.ConvTranspose2d(256 + y_dim, 128, 4, 2, 1, bias=False),  # 16x16
            nn.ConvTranspose2d(128 + y_dim, 64, 4, 2, 1, bias=False),   # 32x32
            nn.ConvTranspose2d(64 + y_dim, 3, 4, 2, 1, bias=False)      # 64x64
        ])
        self.bns = nn.ModuleList([
            nn.BatchNorm2d(256),
            nn.BatchNorm2d(128),
            nn.BatchNorm2d(64)
        ])

    def forward(self, z, y):  # z: (B, 100), y: (B, 102)
        x = torch.cat([z, y], dim=1)  # (B, 202)
        x = self.fc(x).view(-1, 512, 4, 4)  # (B, 512, 4, 4)

        for i, conv in enumerate(self.conv_blocks):
            h, w = x.shape[2], x.shape[3]
            y_exp = expand_y(y, h, w)
            x = torch.cat([x, y_exp], dim=1)
            x = conv(x)
            if i < 3:
                x = self.bns[i](x)
                x = nn.ReLU(True)(x)
            else:
                x = torch.tanh(x)
        return x

# Discriminator for 64x64 images
class Discriminator64(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_blocks = nn.ModuleList([
            nn.Conv2d(3 + y_dim, 64, 4, 2, 1, bias=True),   # 32x32
            nn.Conv2d(64 + y_dim, 128, 4, 2, 1, bias=True), # 16x16
            nn.Conv2d(128 + y_dim, 256, 4, 2, 1, bias=True),# 8x8
            nn.Conv2d(256 + y_dim, 512, 4, 2, 1, bias=True) # 4x4
        ])
        self.fc = nn.Linear(4 * 4 * 512, 1)  # Final linear layer for logits

    def forward(self, x, y):  # x: (B, 3, 64, 64), y: (B, 102)
        for conv in self.conv_blocks:
            h, w = x.shape[2], x.shape[3]
            y_exp = expand_y(y, h, w)
            x = torch.cat([x, y_exp], dim=1)
            x = conv(x)
            x = nn.LeakyReLU(0.2, inplace=True)(x)
        
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc(x)             # Logits (use BCEWithLogitsLoss later)
        return x
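
The size comments in the two classes above (`# 8x8`, `# 16x16`, ...) follow from the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` with kernel 4, stride 2, padding 1. The sketch below reproduces that arithmetic in plain Python; the helper names are illustrative and dilation 1 / output padding 0 are assumed:

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Output size of a strided convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    """Output size of a transposed convolution along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + kernel

# Generator64: start at 4x4 and apply four transposed convolutions.
g_sizes = [4]
for _ in range(4):
    g_sizes.append(conv_transpose_out(g_sizes[-1]))

# Discriminator64: start at 64x64 and apply four strided convolutions.
d_sizes = [64]
for _ in range(4):
    d_sizes.append(conv_out(d_sizes[-1]))

print(g_sizes)  # [4, 8, 16, 32, 64]
print(d_sizes)  # [64, 32, 16, 8, 4]
```

Each block doubles (generator) or halves (discriminator) the resolution, so feeding a 128×128 image through the same four discriminator blocks would leave an 8×8 map, and the final `nn.Linear(4 * 4 * 512, 1)` would no longer match.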

### Optimizers + Loss Setup

In [None]:
import torch.optim as optim
import torch.nn as nn

# Initialize models
G = Generator64().to(DEVICE)
D = Discriminator64().to(DEVICE)

# Optimizers
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# Loss Function
criterion = nn.BCEWithLogitsLoss()

print("Optimizers and Loss function = READY")

### One-Hot Encoding Helper

In [None]:
def make_one_hot(labels, num_classes):
    """
    Converts integer labels into one-hot encoded vectors.
    labels: (batch_size,) --> returns (batch_size, num_classes)
    """
    return torch.zeros(labels.size(0), num_classes, device=labels.device).scatter_(1, labels.unsqueeze(1), 1)
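
As a quick sanity check, the helper above should agree with PyTorch's built-in `torch.nn.functional.one_hot`, and the spatial tiling done by `expand_y` should produce one constant plane per class. The sketch below copies both helpers from this notebook and runs them on three made-up labels:

```python
import torch
import torch.nn.functional as F

def make_one_hot(labels, num_classes):
    # Copied from the notebook: (B,) integer labels -> (B, num_classes) floats
    return torch.zeros(labels.size(0), num_classes, device=labels.device).scatter_(1, labels.unsqueeze(1), 1)

def expand_y(y, h, w):
    # Copied from the notebook: (B, C) -> (B, C, h, w) by spatial tiling
    return y.view(y.size(0), y.size(1), 1, 1).expand(-1, -1, h, w)

labels = torch.tensor([0, 3, 101])          # made-up class indices
onehot = make_one_hot(labels, 102)
reference = F.one_hot(labels, num_classes=102).float()

print(torch.equal(onehot, reference))       # both encodings should match
tiled = expand_y(onehot, 4, 4)
print(tuple(tiled.shape))                   # (3, 102, 4, 4)
```

Only the plane at index `labels[i]` is filled with ones for sample `i`; every other plane is zero, which is what each convolution block receives as its extra `y_dim` input channels.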

### Save Image Grid (Oxford-102, RGB)

In [None]:
import matplotlib.pyplot as plt

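def save_image_grid(images, labels, epoch, save_dir, nrow=8):
    """
    Save a grid of generated images as <save_dir>/epoch_XXX.png.
    Images are laid out in the order given; `labels` is accepted for
    call-site compatibility but is not used for the layout.
    """
    images = images.detach().cpu()
    images = ((images + 1) / 2).clamp(0, 1)  # Denormalize from [-1,1] to [0,1]
    
    batch_size = images.size(0)
    nrows = max(1, -(-batch_size // nrow))  # ceil division, so no empty extra row
    fig, axes = plt.subplots(nrows=nrows, ncols=nrow, figsize=(nrow * 2, nrows * 2), squeeze=False)
    axes = axes.flatten()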

    for img, ax in zip(images, axes):
        img = img.permute(1, 2, 0)  # (C,H,W) --> (H,W,C)
        ax.imshow(img)
        ax.axis('off')
    
    for ax in axes[batch_size:]:
        ax.axis('off')
        
    plt.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, f"epoch_{epoch:03d}.png"))
    plt.close()

## Training Loop for Oxford-102 (Please 'DO NOT RUN' unless willing to RE-TRAIN)

In [None]:
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

# Fixed noise and labels for consistent evaluation
fixed_z = torch.randn(y_dim * 2, z_dim, device=DEVICE)
fixed_labels = torch.arange(y_dim, device=DEVICE).repeat_interleave(2)
fixed_labels_onehot = make_one_hot(fixed_labels, y_dim)

D_losses = []
G_losses = []

for epoch in range(1, epochs + 1):
    loop = tqdm(oxford_loader, leave=False, desc=f"Epoch {epoch}/{epochs}")
    for real_imgs, labels in loop:
        real_imgs, labels = real_imgs.to(DEVICE), labels.to(DEVICE)
        bsize = real_imgs.size(0)
        
        real_targets = torch.ones(bsize, 1, device=DEVICE)
        fake_targets = torch.zeros(bsize, 1, device=DEVICE)
        
        labels_onehot = make_one_hot(labels.long(), y_dim)  

        # Train Discriminator
        D_optimizer.zero_grad()
        
        D_real = D(real_imgs, labels_onehot)
        D_loss_real = criterion(D_real, real_targets)
        
        z = torch.randn(bsize, z_dim, device=DEVICE)
        fake_imgs = G(z, labels_onehot)
        D_fake = D(fake_imgs.detach(), labels_onehot)
        D_loss_fake = criterion(D_fake, fake_targets)
        
        D_loss = D_loss_real + D_loss_fake
        D_loss.backward()
        D_optimizer.step()
        
        # Train Generator
        G_optimizer.zero_grad()
        
        z = torch.randn(bsize, z_dim, device=DEVICE)
        fake_imgs = G(z, labels_onehot)
        D_fake = D(fake_imgs, labels_onehot)
        G_loss = criterion(D_fake, real_targets)
        
        G_loss.backward()
        G_optimizer.step()
        
        D_losses.append(D_loss.item())
        G_losses.append(G_loss.item())
        
    
    print(f"Epoch [{epoch}/{epochs}]  D_loss: {D_loss.item():.4f}  G_loss: {G_loss.item():.4f}")
    
    # Save and show fake samples at epoch 1 and then every 10 epochs
    if epoch % 10 == 0 or epoch == 1:
        with torch.no_grad():
            fake_imgs_fixed = G(fixed_z, fixed_labels_onehot)
        
        # Save to file
        save_image_grid(fake_imgs_fixed, fixed_labels, epoch, save_dir)
        
        # Show the first 32 fixed samples in the notebook (the saved grid holds all of them)
        fig, axes = plt.subplots(4, 8, figsize=(16, 8))
        axes = axes.flatten()
        fake_imgs_fixed = fake_imgs_fixed.detach().cpu()
        fake_imgs_fixed = (fake_imgs_fixed + 1) / 2  # denormalize

        for img, ax in zip(fake_imgs_fixed, axes):
            img = img.permute(1, 2, 0)
            ax.imshow(img)
            ax.axis('off')

        for ax in axes[len(fake_imgs_fixed):]:
            ax.axis('off')

        plt.suptitle(f"Generated Samples at Epoch {epoch}", fontsize=16)
        plt.tight_layout()
        plt.show()


np.save(os.path.join(save_dir, "D_losses.npy"), np.array(D_losses))
np.save(os.path.join(save_dir, "G_losses.npy"), np.array(G_losses))
torch.save(G.state_dict(), os.path.join(save_dir, "generator.pth"))
torch.save(D.state_dict(), os.path.join(save_dir, "discriminator.pth"))

print("Training done and all models and losses saved.")
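
In the training loop above, `D_loss` is the sum of two `BCEWithLogitsLoss` terms (real batch scored against ones, fake batch scored against zeros), and `G_loss` scores the fake batch against ones, which is the non-saturating generator loss. A useful reference point when reading the printed values is the loss of a discriminator that outputs logit 0 everywhere, meaning it cannot tell real from fake. The sketch below computes it with a hypothetical pure-Python helper that applies the same per-element formula:

```python
import math

def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a raw logit."""
    # max(x, 0) - x * t + log(1 + exp(-|x|))
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

# Logit 0 corresponds to a predicted probability of 0.5.
d_loss_at_chance = bce_with_logits(0.0, 1.0) + bce_with_logits(0.0, 0.0)
g_loss_at_chance = bce_with_logits(0.0, 1.0)

print(round(d_loss_at_chance, 4))  # 1.3863, i.e. 2 * ln 2
print(round(g_loss_at_chance, 4))  # 0.6931, i.e. ln 2
```

A `D_loss` far below 1.386 paired with a steadily growing `G_loss` suggests the discriminator is winning; values hovering near these reference points suggest the two networks are roughly balanced.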

### Load Trained Generator + Generate Flowers + Map Label to Flower Name

In [None]:
import torch
import matplotlib.pyplot as plt
import os

# Load trained Generator
G = Generator64().to(DEVICE)
G.load_state_dict(torch.load(os.path.join(save_dir, "generator.pth"), map_location=DEVICE))
G.eval()  # Set to evaluation mode

# Prepare Correct Class Names (Oxford-102 flowers)
class_names = [
    'pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold',
    'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle', 'snapdragon',
    "colt's foot", 'king protea', 'spear thistle', 'yellow iris', 'globe-flower', 'purple coneflower',
    'peruvian lily', 'balloon flower', 'giant white arum lily', 'fire lily', 'pincushion flower', 'fritillary',
    'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke',
    'sweet william', 'carnation', 'garden phlox', 'love in the mist', 'mexican aster', 'alpine sea holly',
    'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose', 'barbeton daisy',
    'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue', 'wallflower', 'marigold', 'buttercup',
    'oxeye daisy', 'common dandelion', 'petunia', 'wild pansy', 'primula', 'sunflower', 'pelargonium',
    'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia', 'pink-yellow dahlia?', 'cautleya spicata',
    'japanese anemone', 'black-eyed susan', 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus',
    'bearded iris', 'windflower', 'tree poppy', 'gazania', 'azalea', 'water lily', 'rose', 'thorn apple',
    'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium', 'frangipani', 'clematis', 'hibiscus',
    'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen', 'watercress', 'canna lily', 'hippeastrum',
    'bee balm', 'ball moss', 'foxglove', 'bougainvillea', 'camellia', 'mallow', 'mexican petunia', 'bromelia',
    'blanket flower', 'trumpet creeper', 'blackberry lily'
]

# Generate samples
fixed_z = torch.randn(y_dim, z_dim, device=DEVICE)  # 1 sample per class
fixed_labels = torch.arange(y_dim, device=DEVICE)
fixed_labels_onehot = make_one_hot(fixed_labels, y_dim)

with torch.no_grad():
    fake_imgs = G(fixed_z, fixed_labels_onehot)

# Show and map flowers
fig, axes = plt.subplots(8, 13, figsize=(18, 12))
axes = axes.flatten()

fake_imgs = (fake_imgs + 1) / 2  # Denormalize from [-1,1] to [0,1]

for idx, (img, ax) in enumerate(zip(fake_imgs, axes)):
    img = img.detach().cpu().permute(1, 2, 0)
    ax.imshow(img)
    if idx < len(class_names):
        ax.set_title(f"{class_names[idx]}", fontsize=6)
    ax.axis('off')

# Hide extra axes if any
for ax in axes[len(fake_imgs):]:
    ax.axis('off')

plt.tight_layout()
plt.show()

### Markdown list of 102 Flower Classes (Oxford-102)

1. Pink primrose
2. Hard-leaved pocket orchid
3. Canterbury bells
4. Sweet pea
5. English marigold
6. Tiger lily
7. Moon orchid
8. Bird of paradise
9. Monkshood
10. Globe thistle
11. Snapdragon
12. Colt’s foot
13. King protea
14. Spear thistle
15. Yellow iris
16. Globe-flower
17. Purple coneflower
18. Peruvian lily
19. Balloon flower
20. Giant white arum lily
21. Fire lily
22. Pincushion flower
23. Fritillary
24. Red ginger
25. Grape hyacinth
26. Corn poppy
27. Prince of wales feathers
28. Stemless gentian
29. Artichoke
30. Sweet william
31. Carnation
32. Garden phlox
33. Love in the mist
34. Mexican aster
35. Alpine sea holly
36. Ruby-lipped cattleya
37. Cape flower
38. Great masterwort
39. Siam tulip
40. Lenten rose
41. Barbeton daisy
42. Daffodil
43. Sword lily
44. Poinsettia
45. Bolero deep blue
46. Wallflower
47. Marigold
48. Buttercup
49. Oxeye daisy
50. Common dandelion
51. Petunia
52. Wild pansy
53. Primula
54. Sunflower
55. Pelargonium
56. Bishop of llandaff
57. Gaura
58. Geranium
59. Orange dahlia
60. Pink-yellow dahlia
61. Cautleya spicata
62. Japanese anemone
63. Black-eyed susan
64. Silverbush
65. Californian poppy
66. Osteospermum
67. Spring crocus
68. Bearded iris
69. Windflower
70. Tree poppy
71. Gazania
72. Azalea
73. Water lily
74. Rose
75. Thorn apple
76. Morning glory
77. Passion flower
78. Lotus
79. Toad lily
80. Anthurium
81. Frangipani
82. Clematis
83. Hibiscus
84. Columbine
85. Desert-rose
86. Tree mallow
87. Magnolia
88. Cyclamen
89. Watercress
90. Canna lily
91. Hippeastrum
92. Bee balm
93. Ball moss
94. Foxglove
95. Bougainvillea
96. Camellia
97. Mallow
98. Mexican petunia
99. Bromelia
100. Blanket flower
101. Trumpet creeper
102. Blackberry lily

In [None]:
import os
import torch
import random
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import scipy.io

# Paths
real_image_folder = os.path.join(oxford102_root, "jpg")
labels_mat_path = os.path.join(oxford102_root, "imagelabels.mat")

class_names = [
    'pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold',
    'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle', 'snapdragon',
    "colt's foot", 'king protea', 'spear thistle', 'yellow iris', 'globe-flower', 'purple coneflower',
    'peruvian lily', 'balloon flower', 'giant white arum lily', 'fire lily', 'pincushion flower', 'fritillary',
    'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke',
    'sweet william', 'carnation', 'garden phlox', 'love in the mist', 'mexican aster', 'alpine sea holly',
    'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose', 'barbeton daisy',
    'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue', 'wallflower', 'marigold', 'buttercup',
    'oxeye daisy', 'common dandelion', 'petunia', 'wild pansy', 'primula', 'sunflower', 'pelargonium',
    'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia', 'pink-yellow dahlia?', 'cautleya spicata',
    'japanese anemone', 'black-eyed susan', 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus',
    'bearded iris', 'windflower', 'tree poppy', 'gazania', 'azalea', 'water lily', 'rose', 'thorn apple',
    'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium', 'frangipani', 'clematis', 'hibiscus',
    'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen', 'watercress', 'canna lily', 'hippeastrum',
    'bee balm', 'ball moss', 'foxglove', 'bougainvillea', 'camellia', 'mallow', 'mexican petunia', 'bromelia',
    'blanket flower', 'trumpet creeper', 'blackberry lily'
]

#### Example : Corn Poppy

In [None]:
# Selecting Flower
selected_flower_name = "corn poppy"
selected_label = class_names.index(selected_flower_name)

print(f"Selected Flower: {selected_flower_name} (Label {selected_label})")

# Load labels
labels_data = scipy.io.loadmat(labels_mat_path)
real_labels = labels_data["labels"][0].astype(int) - 1  # labels in imagelabels.mat are 1-indexed; shift to 0-based

# Find all images matching selected label
image_files = sorted(os.listdir(real_image_folder))
matching_indices = [i for i, lbl in enumerate(real_labels) if lbl == selected_label]

# Randomly pick one real sample
real_idx = random.choice(matching_indices)
real_img_path = os.path.join(real_image_folder, image_files[real_idx])

# Load real image
real_img = Image.open(real_img_path).convert("RGB")
real_img = real_img.resize((img_size, img_size))

# Generate fake image
z = torch.randn(1, z_dim, device=DEVICE)
label_tensor = torch.tensor([selected_label], device=DEVICE)
label_onehot = make_one_hot(label_tensor, y_dim)

with torch.no_grad():
    fake_img = G(z, label_onehot)

fake_img = (fake_img + 1) / 2  # Denormalize
fake_img = fake_img.squeeze(0).permute(1, 2, 0).cpu().numpy()

# side-by-side
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

# Real image
axes[0].imshow(real_img)
axes[0].set_title(f"Real: {selected_flower_name}")
axes[0].axis('off')

# Generated image
axes[1].imshow(fake_img)
axes[1].set_title(f"Generated: {selected_flower_name}")
axes[1].axis('off')

plt.tight_layout()
plt.show()

#### Example : English Marigold

In [None]:
# Selecting Flower
selected_flower_name = "english marigold"
selected_label = class_names.index(selected_flower_name)

print(f"Selected Flower: {selected_flower_name} (Label {selected_label})")

# Load labels
labels_data = scipy.io.loadmat(labels_mat_path)
real_labels = labels_data["labels"][0] - 1  

# Find all images matching selected label
image_files = sorted(os.listdir(real_image_folder))
matching_indices = [i for i, lbl in enumerate(real_labels) if lbl == selected_label]

# Randomly pick one real sample
real_idx = random.choice(matching_indices)
real_img_path = os.path.join(real_image_folder, image_files[real_idx])

# Load real image
real_img = Image.open(real_img_path).convert("RGB")
real_img = real_img.resize((img_size, img_size))

# Generate fake image
z = torch.randn(1, z_dim, device=DEVICE)
label_tensor = torch.tensor([selected_label], device=DEVICE)
label_onehot = make_one_hot(label_tensor, y_dim)

with torch.no_grad():
    fake_img = G(z, label_onehot)

fake_img = (fake_img + 1) / 2  # Denormalize
fake_img = fake_img.squeeze(0).permute(1, 2, 0).cpu().numpy()

# side-by-side
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

# Real image
axes[0].imshow(real_img)
axes[0].set_title(f"Real: {selected_flower_name}")
axes[0].axis('off')

# Generated image
axes[1].imshow(fake_img)
axes[1].set_title(f"Generated: {selected_flower_name}")
axes[1].axis('off')

plt.tight_layout()
plt.show()

#### Example : Snapdragon

In [None]:
# Selecting Flower
selected_flower_name = "snapdragon"
selected_label = class_names.index(selected_flower_name)

print(f"Selected Flower: {selected_flower_name} (Label {selected_label})")

# Load labels
labels_data = scipy.io.loadmat(labels_mat_path)
real_labels = labels_data["labels"][0] - 1  

# Find all images matching selected label
image_files = sorted(os.listdir(real_image_folder))
matching_indices = [i for i, lbl in enumerate(real_labels) if lbl == selected_label]

# Randomly pick one real sample
real_idx = random.choice(matching_indices)
real_img_path = os.path.join(real_image_folder, image_files[real_idx])

# Load real image
real_img = Image.open(real_img_path).convert("RGB")
real_img = real_img.resize((img_size, img_size))

# Generate fake image
z = torch.randn(1, z_dim, device=DEVICE)
label_tensor = torch.tensor([selected_label], device=DEVICE)
label_onehot = make_one_hot(label_tensor, y_dim)

with torch.no_grad():
    fake_img = G(z, label_onehot)

fake_img = (fake_img + 1) / 2  # Denormalize
fake_img = fake_img.squeeze(0).permute(1, 2, 0).cpu().numpy()

# side-by-side
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

# Real image
axes[0].imshow(real_img)
axes[0].set_title(f"Real: {selected_flower_name}")
axes[0].axis('off')

# Generated image
axes[1].imshow(fake_img)
axes[1].set_title(f"Generated: {selected_flower_name}")
axes[1].axis('off')

plt.tight_layout()
plt.show()

### FID / IS for Oxford-102 dataset

In [None]:
!pip install torch-fidelity

In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance

In [None]:
import torch
from torch.utils.data import DataLoader
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import numpy as np

# number of samples to evaluate
num_samples = 1000
batch_eval  = 64  

# Generate all fake images in CPU-friendly batches
fake_uint8_list = []
G.eval()
with torch.no_grad():
    for i in range(0, num_samples, batch_eval):
        n = min(batch_eval, num_samples - i)        # last batch is smaller so exactly num_samples are generated
        z_batch = torch.randn(n, z_dim, device=DEVICE)
        labels = torch.randint(0, y_dim, (n,), device=DEVICE)
        oh = make_one_hot(labels, y_dim)
        imgs = G(z_batch, oh).cpu()                 # (B,3,64,64) float in [-1,1]
        imgs = (imgs + 1) / 2                       # to [0,1]
        imgs = (imgs * 255).clamp(0,255).to(torch.uint8)
        fake_uint8_list.append(imgs)
fake_uint8 = torch.cat(fake_uint8_list, dim=0)

# Sample real images from the dataset in batches
real_uint8_list = []
real_loader = DataLoader(
    oxford_dataset, batch_size=batch_eval, shuffle=True, drop_last=True, num_workers=0
)
needed = num_samples
for real_imgs, _ in real_loader:
    if needed <= 0:
        break
    real = real_imgs[:needed].cpu()              # (B,3,64,64), normalized to [-1,1] by the dataset transform
    real = (real * 0.5 + 0.5)                    # to [0,1]
    real = (real * 255).clamp(0,255).to(torch.uint8)
    real_uint8_list.append(real)
    needed -= real.shape[0]
real_uint8 = torch.cat(real_uint8_list, dim=0)

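# Instantiate metrics (compute_on_cpu to avoid GPU OOM).
# Note: feature=64 uses a shallow Inception layer; published FID values normally use
# the 2048-dim pooling features, so this score is not comparable to literature numbers.
fid = FrechetInceptionDistance(feature=64, compute_on_cpu=True).to('cpu')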
is_scorer = InceptionScore(compute_on_cpu=True).to('cpu')

# Update FID in small chunks
for i in range(0, num_samples, batch_eval):
    fid.update(real_uint8[i:i+batch_eval], real=True)
    fid.update(fake_uint8[i:i+batch_eval], real=False)
fid_score = fid.compute().item()

# Update Inception Score on fake images
for i in range(0, num_samples, batch_eval):
    is_scorer.update(fake_uint8[i:i+batch_eval])
is_score, is_std = is_scorer.compute()


print(f"FID (real vs generated): {fid_score:.2f}")
print(f"Inception Score: {is_score:.2f} ± {is_std:.2f}")
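
The FID reported above is the Fréchet distance between two Gaussians fitted to Inception features of the real and generated images: d² = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). In one dimension this collapses to (μ_r − μ_g)² + (σ_r − σ_g)². The sketch below implements only that 1-D special case on made-up numbers, to show that both a shift in mean and a mismatch in spread raise the score; `frechet_distance_1d` is a hypothetical helper and not part of torchmetrics:

```python
import statistics

def frechet_distance_1d(xs, ys):
    """Fréchet distance between 1-D Gaussians fitted to the samples xs and ys."""
    mu_x, mu_y = statistics.fmean(xs), statistics.fmean(ys)
    sd_x, sd_y = statistics.stdev(xs), statistics.stdev(ys)
    return (mu_x - mu_y) ** 2 + (sd_x - sd_y) ** 2

real = [0.0, 1.0, 2.0, 3.0, 4.0]          # made-up "real" feature values
shifted = [x + 2.0 for x in real]         # same spread, mean moved by 2
wider = [2.0 * x - 2.0 for x in real]     # same mean, spread doubled

print(frechet_distance_1d(real, real))      # identical samples give 0.0
print(frechet_distance_1d(real, shifted))   # mean shift of 2 contributes 2**2
print(frechet_distance_1d(real, wider) > 0) # spread mismatch alone also raises the distance
```

Lower is better: a score of zero means the two fitted Gaussians coincide, which says nothing about individual images, only about feature statistics.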

### Loss Plot for Oxford-102

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os

def moving_average(x, window_size=200):
    return np.convolve(x, np.ones(window_size)/window_size, mode='valid')

# Path to saved losses
save_dir = "./outputs/Oxford102"
D_losses = np.load(os.path.join(save_dir, "D_losses.npy"))
G_losses = np.load(os.path.join(save_dir, "G_losses.npy"))

# Plot smoothed loss curves
plt.figure(figsize=(12, 6))
plt.plot(moving_average(D_losses), label="Discriminator Loss", color="tab:blue", alpha=0.9)
plt.plot(moving_average(G_losses), label="Generator Loss", color="tab:orange", alpha=0.9)
plt.xlabel("Training Iteration")
plt.ylabel("Loss")
plt.title("Oxford-102 cGAN Training Loss Curves (Smoothed)")
plt.legend()
plt.grid(True, linestyle='--', alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(save_dir, "oxford102_cgan_loss_curve.png"))
plt.show()
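
The smoothing used above, `np.convolve(..., mode='valid')` with a uniform kernel, keeps only windows that fit entirely inside the series, so the plotted curve is `window_size - 1` points shorter than the raw loss history. A plain-Python sketch of the same computation on made-up numbers (this stand-alone version is for illustration and assumes the series is at least as long as the window):

```python
def moving_average(values, window_size):
    """Mean of every full window of `window_size` consecutive values."""
    if window_size > len(values):
        raise ValueError("window_size is larger than the series")
    running = sum(values[:window_size])
    out = [running / window_size]
    for i in range(window_size, len(values)):
        # Slide the window: add the new value, drop the oldest one
        running += values[i] - values[i - window_size]
        out.append(running / window_size)
    return out

losses = [4.0, 2.0, 6.0, 8.0, 0.0]         # made-up loss values
smoothed = moving_average(losses, window_size=3)

print(smoothed)
print(len(smoothed) == len(losses) - 3 + 1)  # True
```

With the default window of 200 iterations, the x-axis of the smoothed plot therefore starts at the 200th recorded iteration rather than the first.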

## 2. Bird Dataset (CUB-200)

### Configuration & Setup

In [None]:
import os
import torch

dataset_name = "CUB200"

# Hyperparameters
batch_size  = 64
z_dim       = 100
y_dim       = 200          # 200 bird species
img_size    = 128
img_channels= 3
img_dim     = img_size * img_size * img_channels
lr          = 0.0002
epochs      = 200

save_dir = f"./outputs/{dataset_name}"
os.makedirs(save_dir, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {DEVICE}")

### Dataset Download and Extraction

In [None]:
import urllib.request
import tarfile

cub_root = "./data/CUB_200_2011"
cub_tar_path = "./data/CUB_200_2011.tgz"
cub_url = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz"

os.makedirs("./data", exist_ok=True)

# Download dataset if not present
if not os.path.exists(cub_tar_path):
    print("Downloading CUB-200-2011...")
    urllib.request.urlretrieve(cub_url, cub_tar_path)
    print("Downloaded!")

# Extract dataset if not already extracted
if not os.path.exists(cub_root):
    print("Extracting CUB-200-2011...")
    with tarfile.open(cub_tar_path, "r:gz") as tar:
        tar.extractall(path="./data")
    print("Extracted!")

print("CUB-200 dataset ready.")

### Dataset Loader

In [None]:
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

# Define transforms
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Custom Dataset
class CUBDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        
        # Read image paths
        with open(os.path.join(root_dir, "images.txt")) as f:
            self.image_paths = [line.strip().split()[1] for line in f.readlines()]
        
        # Read labels
        with open(os.path.join(root_dir, "image_class_labels.txt")) as f:
            self.labels = [int(line.strip().split()[1]) - 1 for line in f.readlines()]  # file uses 1-based class ids; convert to 0-based

        assert len(self.image_paths) == len(self.labels), "Mismatch between images and labels"

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, "images", self.image_paths[idx])
        image = Image.open(img_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# Instantiate
cub_dataset = CUBDataset(cub_root, transform=transform)
cub_loader = DataLoader(cub_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0, pin_memory=True)

print(f"CUB-200 Dataset ready: {len(cub_dataset)} images across {y_dim} classes.")

### Generator128 and Discriminator128

In [None]:
import torch
import torch.nn as nn

def expand_y(y, h, w):
    return y.view(y.size(0), y.size(1), 1, 1).expand(-1, -1, h, w)

class Generator128(nn.Module):
    def __init__(self):
        super().__init__()
        # project z+y → 8×8×512
        self.fc = nn.Sequential(
            nn.Linear(z_dim + y_dim, 8 * 8 * 512),
            nn.BatchNorm1d(8 * 8 * 512),
            nn.ReLU(True)
        )
        self.conv_blocks = nn.ModuleList([
            nn.ConvTranspose2d(512 + y_dim, 256, 4, 2, 1, bias=False),  # 8→16
            nn.ConvTranspose2d(256 + y_dim, 128, 4, 2, 1, bias=False),  # 16→32
            nn.ConvTranspose2d(128 + y_dim,  64, 4, 2, 1, bias=False),  # 32→64
            nn.ConvTranspose2d( 64 + y_dim,   3, 4, 2, 1, bias=False),  # 64→128
        ])
        self.bns = nn.ModuleList([
            nn.BatchNorm2d(256),
            nn.BatchNorm2d(128),
            nn.BatchNorm2d(64),
        ])

    def forward(self, z, y):
        # z: (B, z_dim), y: (B, y_dim)
        x = torch.cat([z, y], dim=1)
        x = self.fc(x).view(-1, 512, 8, 8)
        for i, conv in enumerate(self.conv_blocks):
            h, w = x.shape[2:]
            y_exp = expand_y(y, h, w)
            x = torch.cat([x, y_exp], dim=1)
            x = conv(x)
            if i < len(self.bns):
                x = self.bns[i](x)
                x = torch.relu(x)
            else:
                x = torch.tanh(x)
        return x

class Discriminator128(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_blocks = nn.ModuleList([
            nn.Conv2d(3 + y_dim,   64, 4, 2, 1, bias=True),  # 128→64
            nn.Conv2d(64 + y_dim, 128, 4, 2, 1, bias=True),  # 64→32
            nn.Conv2d(128 + y_dim,256, 4, 2, 1, bias=True),  # 32→16
            nn.Conv2d(256 + y_dim,512, 4, 2, 1, bias=True),  # 16→8
        ])
        self.fc = nn.Linear(8 * 8 * 512, 1)

    def forward(self, x, y):
        # x: (B, 3, 128,128), y: (B, y_dim)
        for conv in self.conv_blocks:
            h, w = x.shape[2:]
            y_exp = expand_y(y, h, w)
            x = torch.cat([x, y_exp], dim=1)
            x = conv(x)
            x = nn.functional.leaky_relu(x, 0.2)
        x = x.view(x.size(0), -1)
        return self.fc(x)
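
The `# 8→16` and `# 128→64` comments in the layer lists follow from the standard output-size formulas for `kernel_size=4, stride=2, padding=1`. A small worked sketch of that arithmetic is given below, assuming dilation 1 and no output padding; the helper names `conv_out`, `deconv_out` and `chain` are hypothetical.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a ConvTranspose2d layer (dilation 1, output_padding 0)."""
    return (size - 1) * stride - 2 * padding + kernel


def chain(fn, start, layers=4):
    """Apply the same size formula `layers` times and record every size."""
    sizes = [start]
    for _ in range(layers):
        sizes.append(fn(sizes[-1]))
    return sizes


generator_sizes = chain(deconv_out, 8)        # upsampling path of Generator128
discriminator_sizes = chain(conv_out, 128)    # downsampling path of Discriminator128
print(generator_sizes)
print(discriminator_sizes)
```

Each layer doubles (generator) or halves (discriminator) the resolution, which is why the final `nn.Linear` in the discriminator expects `8 * 8 * 512` features.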

### Optimizers and Loss

In [None]:
import torch.optim as optim

G = Generator128().to(DEVICE)
D = Discriminator128().to(DEVICE)

G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

criterion = nn.BCEWithLogitsLoss()

print("Success!!")
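
`Discriminator128.forward` returns raw logits (there is no sigmoid at the end), which is why the criterion is `nn.BCEWithLogitsLoss` rather than `nn.BCELoss`. The sketch below works through the loss for a single logit using the usual numerically stable form; the function name `bce_with_logits` is hypothetical and this is an illustration, not the PyTorch implementation.

```python
import math


def bce_with_logits(logit, target):
    """Binary cross-entropy on one raw logit, written in a numerically stable form."""
    # Equivalent to -[t * log(sigmoid(x)) + (1 - t) * log(1 - sigmoid(x))].
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


undecided = bce_with_logits(0.0, 1.0)       # logit 0 means probability 0.5
confident_right = bce_with_logits(4.0, 1.0)
confident_wrong = bce_with_logits(4.0, 0.0)
print(undecided, confident_right, confident_wrong)
```

A logit of 0 gives a loss of log 2, a confident correct prediction gives a small loss, and a confident wrong prediction gives a large one.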

### Helper Functions (make_one_hot, save_image_grid)

In [None]:
def make_one_hot(labels, num_classes):
    return torch.zeros(labels.size(0), num_classes, device=labels.device).scatter_(1, labels.unsqueeze(1), 1)

def save_image_grid(images, labels, epoch, save_dir, nrow=8):
    import matplotlib.pyplot as plt
    images = images.detach().cpu()
    images = (images + 1) / 2
    batch_size = images.size(0)
    nrows = (batch_size + nrow - 1) // nrow  # ceiling division: no empty extra row, never zero rows
    fig, axes = plt.subplots(nrows=nrows, ncols=nrow, figsize=(nrow*2, nrows*2), squeeze=False)
    axes = axes.flatten()

    for img, ax in zip(images, axes):
        img = img.permute(1, 2, 0)
        ax.imshow(img)
        ax.axis('off')

    for ax in axes[batch_size:]:
        ax.axis('off')

    plt.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, f"epoch_{epoch:03d}.png"))
    plt.close()

### Training Loop for CUB-200 (do NOT run unless you intend to retrain)

In [None]:
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

fixed_z = torch.randn(y_dim * 2, z_dim, device=DEVICE)
fixed_labels = torch.arange(y_dim, device=DEVICE).repeat_interleave(2)
fixed_labels_onehot = make_one_hot(fixed_labels, y_dim)

D_losses = []
G_losses = []

for epoch in range(1, epochs + 1):
    loop = tqdm(cub_loader, leave=False, desc=f"Epoch {epoch}/{epochs}")
    for real_imgs, labels in loop:
        real_imgs, labels = real_imgs.to(DEVICE), labels.to(DEVICE)
        bsize = real_imgs.size(0)

        real_targets = torch.ones(bsize, 1, device=DEVICE)
        fake_targets = torch.zeros(bsize, 1, device=DEVICE)

        labels_onehot = make_one_hot(labels, y_dim)

        D_optimizer.zero_grad()
        D_real = D(real_imgs, labels_onehot)
        D_loss_real = criterion(D_real, real_targets)

        z = torch.randn(bsize, z_dim, device=DEVICE)
        fake_imgs = G(z, labels_onehot)
        D_fake = D(fake_imgs.detach(), labels_onehot)
        D_loss_fake = criterion(D_fake, fake_targets)

        D_loss = D_loss_real + D_loss_fake
        D_loss.backward()
        D_optimizer.step()

        G_optimizer.zero_grad()
        z = torch.randn(bsize, z_dim, device=DEVICE)
        fake_imgs = G(z, labels_onehot)
        D_fake = D(fake_imgs, labels_onehot)
        G_loss = criterion(D_fake, real_targets)

        G_loss.backward()
        G_optimizer.step()

        D_losses.append(D_loss.item())
        G_losses.append(G_loss.item())

    print(f"Epoch [{epoch}/{epochs}]  D_loss: {D_loss.item():.4f}  G_loss: {G_loss.item():.4f}")

    if epoch % 10 == 0 or epoch == 1:
        with torch.no_grad():
            fake_imgs_fixed = G(fixed_z, fixed_labels_onehot)

        save_image_grid(fake_imgs_fixed, fixed_labels, epoch, save_dir)

        fig, axes = plt.subplots(4, 8, figsize=(16, 8))
        axes = axes.flatten()
        fake_imgs_fixed = fake_imgs_fixed.detach().cpu()
        fake_imgs_fixed = (fake_imgs_fixed + 1) / 2

        for img, ax in zip(fake_imgs_fixed, axes):
            img = img.permute(1, 2, 0)
            ax.imshow(img)
            ax.axis('off')

        for ax in axes[len(fake_imgs_fixed):]:
            ax.axis('off')

        plt.suptitle(f"Generated Samples at Epoch {epoch}", fontsize=16)
        plt.tight_layout()
        plt.show()

np.save(os.path.join(save_dir, "D_losses.npy"), np.array(D_losses))
np.save(os.path.join(save_dir, "G_losses.npy"), np.array(G_losses))
torch.save(G.state_dict(), os.path.join(save_dir, "generator.pth"))
torch.save(D.state_dict(), os.path.join(save_dir, "discriminator.pth"))

print("Training done; models and losses saved.")

### One Generated Image per Class (with Class Names)

In [None]:
import torch
import matplotlib.pyplot as plt
import math

# Load the trained Generator
G = Generator128().to(DEVICE)
G.load_state_dict(torch.load(os.path.join(save_dir, "generator.pth"), map_location=DEVICE))
G.eval()

# Load class names
cub_classes = []
with open(os.path.join(cub_root, "classes.txt")) as f:
    for line in f:
        cub_classes.append(line.strip().split(" ", 1)[1])

# Generate one bird per class
def generate_bird_for_each_class():
    z = torch.randn(y_dim, z_dim, device=DEVICE)  # 200 random noise vectors
    labels = torch.arange(y_dim, device=DEVICE)   # Labels 0 to 199
    labels_onehot = make_one_hot(labels, y_dim)

    with torch.no_grad():
        fake_imgs = G(z, labels_onehot)

    fake_imgs = (fake_imgs + 1) / 2  # Denormalize
    fake_imgs = fake_imgs.detach().cpu()

    # Plotting
    ncols = 10
    nrows = math.ceil(y_dim / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 2))
    axes = axes.flatten()

    for img, label, ax in zip(fake_imgs, labels, axes):
        img = img.permute(1, 2, 0)
        ax.imshow(img)
        ax.set_title(cub_classes[label], fontsize=5)
        ax.axis('off')

    for ax in axes[len(fake_imgs):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Call the function
generate_bird_for_each_class()

### All CUB-200 Bird Class Names

1. Black_Footed_Albatross  
2. Laysan_Albatross  
3. Sooty_Albatross  
4. Groove_Billed_Ani  
5. Crested_Auklet  
6. Least_Auklet  
7. Parakeet_Auklet  
8. Rhinoceros_Auklet  
9. Brewer_Blackbird  
10. Red_Winged_Blackbird  
11. Rusty_Blackbird  
12. Yellow_Headed_Blackbird  
13. Bobolink  
14. Indigo_Bunting  
15. Lazuli_Bunting  
16. Painted_Bunting  
17. Cardinal  
18. Spotted_Catbird  
19. Gray_Catbird  
20. Yellow_Breasted_Chat  
21. Eastern_Towhee  
22. Chuck_Will_Widow  
23. Brandt_Cormorant  
24. Red_Faced_Cormorant  
25. Pelagic_Cormorant  
26. Bronzed_Cowbird  
27. Shiny_Cowbird  
28. Brown_Creeper  
29. American_Crow  
30. Fish_Crow  
31. Black_Billed_Cuckoo  
32. Mangrove_Cuckoo  
33. Yellow_Billed_Cuckoo  
34. Gray_Crowned_Rosy_Finch  
35. Purple_Finch  
36. Northern_Flicker  
37. Acadian_Flycatcher  
38. Great_Crested_Flycatcher  
39. Least_Flycatcher  
40. Olive_Sided_Flycatcher  
41. Scissor_Tailed_Flycatcher  
42. Vermilion_Flycatcher  
43. Yellow_Bellied_Flycatcher  
44. Frigatebird  
45. Northern_Fulmar  
46. Gadwall  
47. American_Goldfinch  
48. European_Goldfinch  
49. Boat_Tailed_Grackle  
50. Eared_Grebe  
51. Horned_Grebe  
52. Pied_Billed_Grebe  
53. Western_Grebe  
54. Blue_Grosbeak  
55. Evening_Grosbeak  
56. Pine_Grosbeak  
57. Rose_Breasted_Grosbeak  
58. Pigeon_Guillemot  
59. California_Gull  
60. Glaucous_Winged_Gull  
61. Heermann_Gull  
62. Herring_Gull  
63. Ivory_Gull  
64. Ring_Billed_Gull  
65. Slaty_Backed_Gull  
66. Western_Gull  
67. Anna_Hummingbird  
68. Ruby_Throated_Hummingbird  
69. Rufous_Hummingbird  
70. Green_Violetear  
71. Long_Tailed_Jaeger  
72. Pomarine_Jaeger  
73. Blue_Jay  
74. Florida_Jay  
75. Green_Jay  
76. Dark_Eyed_Junco  
77. Northern_Jacana  
78. Green_Kingfisher  
79. Pied_Kingfisher  
80. Ringed_Kingfisher  
81. Belted_Kingfisher  
82. White_Breasted_Kingfisher  
83. Red_Legged_Kittiwake  
84. Horned_Lark  
85. Pacific_Loon  
86. Mallard  
87. Western_Meadowlark  
88. Hooded_Merganser  
89. Red_Breasted_Merganser  
90. Mockingbird  
91. Nighthawk  
92. Clark_Nutcracker  
93. White_Breasted_Nuthatch  
94. Baltimore_Oriole  
95. Hooded_Oriole  
96. Orchard_Oriole  
97. Scott_Oriole  
98. Ovenbird  
99. Brown_Pelican  
100. White_Pelican  
101. Western_Wood_Pewee  
102. Sayornis  
103. American_Pipit  
104. Whip_Poor_Will  
105. Horned_Puffin  
106. Common_Raven  
107. White_Necked_Raven  
108. American_Redstart  
109. Geococcyx  
110. Loggerhead_Shrike  
111. Great_Grey_Shrike  
112. Baird_Sparrow  
113. Black_Throated_Sparrow  
114. Brewer_Sparrow  
115. Chipping_Sparrow  
116. Clay_Colored_Sparrow  
117. House_Sparrow  
118. Field_Sparrow  
119. Fox_Sparrow  
120. Grasshopper_Sparrow  
121. Harris_Sparrow  
122. Henslow_Sparrow  
123. Le_Conte_Sparrow  
124. Lincoln_Sparrow  
125. Nelson_Sharp_Tailed_Sparrow  
126. Savannah_Sparrow  
127. Seaside_Sparrow  
128. Song_Sparrow  
129. Tree_Sparrow  
130. Vesper_Sparrow  
131. White_Crowned_Sparrow  
132. White_Throated_Sparrow  
133. Cape_Glossy_Starling  
134. Bank_Swallow  
135. Barn_Swallow  
136. Cliff_Swallow  
137. Tree_Swallow  
138. Scarlet_Tanager  
139. Summer_Tanager  
140. Artic_Tern  
141. Black_Tern  
142. Caspian_Tern  
143. Common_Tern  
144. Elegant_Tern  
145. Forsters_Tern  
146. Least_Tern  
147. Green_Tailed_Towhee  
148. Brown_Thrasher  
149. Sage_Thrasher  
150. Black_Capped_Vireo  
151. Blue_Headed_Vireo  
152. Philadelphia_Vireo  
153. Red_Eyed_Vireo  
154. Warbling_Vireo  
155. White_Eyed_Vireo  
156. Yellow_Throated_Vireo  
157. Bay_Breasted_Warbler  
158. Black_And_White_Warbler  
159. Black_Throated_Blue_Warbler  
160. Blue_Winged_Warbler  
161. Canada_Warbler  
162. Cape_May_Warbler  
163. Cerulean_Warbler  
164. Chestnut_Sided_Warbler  
165. Golden_Winged_Warbler  
166. Hooded_Warbler  
167. Kentucky_Warbler  
168. Magnolia_Warbler  
169. Mourning_Warbler  
170. Myrtle_Warbler  
171. Nashville_Warbler  
172. Orange_Crowned_Warbler  
173. Palm_Warbler  
174. Pine_Warbler  
175. Prairie_Warbler  
176. Prothonotary_Warbler  
177. Swainson_Warbler  
178. Tennessee_Warbler  
179. Wilson_Warbler  
180. Worm_Eating_Warbler  
181. Yellow_Warbler  
182. Northern_Waterthrush  
183. Louisiana_Waterthrush  
184. Bohemian_Waxwing  
185. Cedar_Waxwing  
186. American_Three_Toed_Woodpecker  
187. Downy_Woodpecker  
188. Hairy_Woodpecker  
189. Red_Bellied_Woodpecker  
190. Red_Cockaded_Woodpecker  
191. Pileated_Woodpecker  
192. Red_Headed_Woodpecker  
193. White_Breasted_Woodpecker  
194. American_Three_Toed_Woodpecker  
195. Bewick_Wren  
196. Cactus_Wren  
197. Carolina_Wren  
198. House_Wren  
199. Marsh_Wren  
200. Rock_Wren

##### To generate a specific bird, set `target_label` to its 1-based class number in `classes.txt` minus 1 (labels are 0-based).
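
Because the dataset loader derives labels from the CUB metadata files, the safest way to find a label is to read it from `classes.txt` rather than count by hand. The sketch below parses a few lines in the `<id> <NNN.Class_Name>` layout; the three sample lines and the helper name `parse_classes` are assumptions made for illustration, and the real file should be read from `cub_root`.

```python
# Hypothetical sample in the assumed "<id> <NNN.Class_Name>" layout of classes.txt.
sample_lines = [
    "1 001.Black_footed_Albatross",
    "2 002.Laysan_Albatross",
    "3 003.Sooty_Albatross",
]


def parse_classes(lines):
    """Map each class name to its 0-based label (file id minus 1)."""
    mapping = {}
    for line in lines:
        class_id, raw_name = line.strip().split(" ", 1)
        name = raw_name.split(".", 1)[-1]   # drop a leading "001."-style prefix if present
        mapping[name] = int(class_id) - 1
    return mapping


name_to_label_demo = parse_classes(sample_lines)
print(name_to_label_demo)
```

The value looked up this way is what should be assigned to `target_label` in the example cells.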

#### Example : Yellow_Warbler (Label = 181)

In [None]:
import torch
import matplotlib.pyplot as plt
import random
from PIL import Image

# Load trained Generator
G = Generator128().to(DEVICE)
G.load_state_dict(torch.load(os.path.join(save_dir, "generator.pth"), map_location=DEVICE))
G.eval()

# Load class names
class_names = []
with open(os.path.join(cub_root, "classes.txt")) as f:
    class_names = [line.strip().split(" ", 1)[1] for line in f.readlines()]

target_label = 181

# 1. Pick a random real image
real_indices = [i for i, lbl in enumerate(cub_dataset.labels) if lbl == target_label]
random_real_idx = random.choice(real_indices)
real_img_path = os.path.join(cub_root, "images", cub_dataset.image_paths[random_real_idx])
real_img = Image.open(real_img_path).convert("RGB")
real_img = transform(real_img)  # Apply same transform as training
real_img = (real_img + 1) / 2   # Denormalize to [0,1]
real_img = real_img.permute(1, 2, 0).cpu()

# 2. Generate a fake image
z = torch.randn(1, z_dim, device=DEVICE)
label_tensor = torch.tensor([target_label], device=DEVICE)
label_onehot = make_one_hot(label_tensor, y_dim)

with torch.no_grad():
    fake_img = G(z, label_onehot)

fake_img = (fake_img + 1) / 2  # Denormalize
fake_img = fake_img.squeeze(0).detach().cpu()
fake_img = fake_img.permute(1, 2, 0)

# Plot Real vs Fake
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

axes[0].imshow(real_img)
axes[0].set_title(f"Real: {class_names[target_label]}")
axes[0].axis('off')

axes[1].imshow(fake_img)
axes[1].set_title(f"Fake: {class_names[target_label]}")
axes[1].axis('off')

plt.suptitle(f"Comparison: Real vs Fake - {class_names[target_label]}", fontsize=14)
plt.tight_layout()
plt.show()

#### Example : Blue_Winged_Warbler (Label = 160)

In [None]:
import torch
import matplotlib.pyplot as plt
import random
from PIL import Image

# Load trained Generator
G = Generator128().to(DEVICE)
G.load_state_dict(torch.load(os.path.join(save_dir, "generator.pth"), map_location=DEVICE))
G.eval()

# Load class names
class_names = []
with open(os.path.join(cub_root, "classes.txt")) as f:
    class_names = [line.strip().split(" ", 1)[1] for line in f.readlines()]

target_label = 160

# 1. Pick a random real image
real_indices = [i for i, lbl in enumerate(cub_dataset.labels) if lbl == target_label]
random_real_idx = random.choice(real_indices)
real_img_path = os.path.join(cub_root, "images", cub_dataset.image_paths[random_real_idx])
real_img = Image.open(real_img_path).convert("RGB")
real_img = transform(real_img)  # Apply same transform as training
real_img = (real_img + 1) / 2   # Denormalize to [0,1]
real_img = real_img.permute(1, 2, 0).cpu()

# 2. Generate a fake image
z = torch.randn(1, z_dim, device=DEVICE)
label_tensor = torch.tensor([target_label], device=DEVICE)
label_onehot = make_one_hot(label_tensor, y_dim)

with torch.no_grad():
    fake_img = G(z, label_onehot)

fake_img = (fake_img + 1) / 2  # Denormalize
fake_img = fake_img.squeeze(0).detach().cpu()
fake_img = fake_img.permute(1, 2, 0)

# Plot Real vs Fake
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

axes[0].imshow(real_img)
axes[0].set_title(f"Real: {class_names[target_label]}")
axes[0].axis('off')

axes[1].imshow(fake_img)
axes[1].set_title(f"Fake: {class_names[target_label]}")
axes[1].axis('off')

plt.suptitle(f"Comparison: Real vs Fake - {class_names[target_label]}", fontsize=14)
plt.tight_layout()
plt.show()

#### Example : Black_Throated_Sparrow (Label = 113)

In [None]:
import torch
import matplotlib.pyplot as plt
import random
from PIL import Image

# Load trained Generator
G = Generator128().to(DEVICE)
G.load_state_dict(torch.load(os.path.join(save_dir, "generator.pth"), map_location=DEVICE))
G.eval()

# Load class names
class_names = []
with open(os.path.join(cub_root, "classes.txt")) as f:
    class_names = [line.strip().split(" ", 1)[1] for line in f.readlines()]

target_label = 113

# 1. Pick a random real image
real_indices = [i for i, lbl in enumerate(cub_dataset.labels) if lbl == target_label]
random_real_idx = random.choice(real_indices)
real_img_path = os.path.join(cub_root, "images", cub_dataset.image_paths[random_real_idx])
real_img = Image.open(real_img_path).convert("RGB")
real_img = transform(real_img)  # Apply same transform as training
real_img = (real_img + 1) / 2   # Denormalize to [0,1]
real_img = real_img.permute(1, 2, 0).cpu()

# 2. Generate a fake image
z = torch.randn(1, z_dim, device=DEVICE)
label_tensor = torch.tensor([target_label], device=DEVICE)
label_onehot = make_one_hot(label_tensor, y_dim)

with torch.no_grad():
    fake_img = G(z, label_onehot)

fake_img = (fake_img + 1) / 2  # Denormalize
fake_img = fake_img.squeeze(0).detach().cpu()
fake_img = fake_img.permute(1, 2, 0)

# Plot Real vs Fake
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

axes[0].imshow(real_img)
axes[0].set_title(f"Real: {class_names[target_label]}")
axes[0].axis('off')

axes[1].imshow(fake_img)
axes[1].set_title(f"Fake: {class_names[target_label]}")
axes[1].axis('off')

plt.suptitle(f"Comparison: Real vs Fake - {class_names[target_label]}", fontsize=14)
plt.tight_layout()
plt.show()

#### Example : Florida_Jay (Label = 73)

In [None]:
import torch
import matplotlib.pyplot as plt
import random
from PIL import Image

# Load trained Generator
G = Generator128().to(DEVICE)
G.load_state_dict(torch.load(os.path.join(save_dir, "generator.pth"), map_location=DEVICE))
G.eval()

# Load class names
class_names = []
with open(os.path.join(cub_root, "classes.txt")) as f:
    class_names = [line.strip().split(" ", 1)[1] for line in f.readlines()]

target_label = 73

# 1. Pick a random real image
real_indices = [i for i, lbl in enumerate(cub_dataset.labels) if lbl == target_label]
random_real_idx = random.choice(real_indices)
real_img_path = os.path.join(cub_root, "images", cub_dataset.image_paths[random_real_idx])
real_img = Image.open(real_img_path).convert("RGB")
real_img = transform(real_img)  # Apply same transform as training
real_img = (real_img + 1) / 2   # Denormalize to [0,1]
real_img = real_img.permute(1, 2, 0).cpu()

# 2. Generate a fake image
z = torch.randn(1, z_dim, device=DEVICE)
label_tensor = torch.tensor([target_label], device=DEVICE)
label_onehot = make_one_hot(label_tensor, y_dim)

with torch.no_grad():
    fake_img = G(z, label_onehot)

fake_img = (fake_img + 1) / 2  # Denormalize
fake_img = fake_img.squeeze(0).detach().cpu()
fake_img = fake_img.permute(1, 2, 0)

# Plot Real vs Fake
fig, axes = plt.subplots(1, 2, figsize=(8, 4))

axes[0].imshow(real_img)
axes[0].set_title(f"Real: {class_names[target_label]}")
axes[0].axis('off')

axes[1].imshow(fake_img)
axes[1].set_title(f"Fake: {class_names[target_label]}")
axes[1].axis('off')

plt.suptitle(f"Comparison: Real vs Fake - {class_names[target_label]}", fontsize=14)
plt.tight_layout()
plt.show()

### Generating using Class Labels

In [None]:
import difflib

name_to_label = {}
class_names = []
with open(os.path.join(cub_root, "classes.txt")) as f:
    for idx, line in enumerate(f):
        name = line.strip().split(" ", 1)[1].split(".", 1)[-1]  # drop a leading "001."-style prefix if present
        norm = name.lower().replace("_", " ")
        class_names.append(name)
        name_to_label[norm] = idx

def prompt_to_label(prompt):
    """Fuzzy-match a text prompt to the closest CUB-200 class label."""
    key = prompt.lower().strip()
    key = key.replace("_", " ")
    if key in name_to_label:
        return name_to_label[key]
    # otherwise use difflib to pick the closest match
    guesses = difflib.get_close_matches(key, name_to_label.keys(), n=1, cutoff=0.6)
    if not guesses:
        raise ValueError(f"No close match found for '{prompt}'")
    return name_to_label[guesses[0]]

def generate_from_text(prompt):
    """Generate a single fake bird given its English name."""
    label = prompt_to_label(prompt)
    z = torch.randn(1, z_dim, device=DEVICE)
    oh = make_one_hot(torch.tensor([label], device=DEVICE), y_dim)
    with torch.no_grad():
        img = G(z, oh)
    img = (img + 1)/2
    img = img.squeeze(0).cpu().permute(1,2,0)
    plt.figure(figsize=(4,4))
    plt.imshow(img)
    plt.title(f"{class_names[label]} ({label})", fontsize=10)
    plt.axis("off")
    plt.show()

In [None]:
generate_from_text("Yellow Warbler")

In [None]:
generate_from_text("blue jay")

In [None]:
generate_from_text("American Three Toed Woodpecker")
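
The lookup in `prompt_to_label` normalises the prompt and then falls back to `difflib.get_close_matches` when there is no exact hit. A minimal sketch of that fallback on a toy candidate list is shown below; the list `toy_classes` and the helper `closest_class` are hypothetical and only illustrate how the `cutoff` behaves.

```python
import difflib

# Hypothetical three-class subset, already lower-cased with spaces instead of underscores.
toy_classes = ["yellow warbler", "blue jay", "florida jay"]


def closest_class(prompt, candidates, cutoff=0.6):
    """Return the best fuzzy match for `prompt`, or None if nothing clears the cutoff."""
    key = prompt.lower().strip().replace("_", " ")
    matches = difflib.get_close_matches(key, candidates, n=1, cutoff=cutoff)
    return matches[0] if matches else None


typo_match = closest_class("Yelow_Warbler", toy_classes)   # misspelled on purpose
no_match = closest_class("penguin", toy_classes)           # nothing similar in the list
print(typo_match, no_match)
```

A misspelled prompt still resolves to the intended class, while an unrelated prompt yields no match, which is the case where `prompt_to_label` raises `ValueError`.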

### FID / IS for CUB-200

In [None]:
import torch
import numpy as np
import random
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torch.utils.data import DataLoader

seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

num_gen = 1000
batch_eval = 64


G = Generator128().to(DEVICE)
G.load_state_dict(torch.load(os.path.join(save_dir, "generator.pth"), map_location=DEVICE))
G.eval()

def make_one_hot(labels, num_classes):
    return torch.zeros(labels.size(0), num_classes, device=labels.device).scatter_(1, labels.unsqueeze(1), 1)


fake_imgs_list = []

with torch.no_grad():
    for i in range(0, num_gen, batch_eval):
        n = min(batch_eval, num_gen - i)  # final chunk is smaller so exactly num_gen images are generated
        z = torch.randn(n, z_dim, device=DEVICE)
        labels = torch.randint(0, y_dim, (n,), device=DEVICE)
        one_hot_labels = make_one_hot(labels, y_dim)
        fake_imgs = G(z, one_hot_labels).cpu()
        fake_imgs = (fake_imgs + 1) / 2  # [-1,1] to [0,1]
        fake_imgs = (fake_imgs * 255).clamp(0, 255).to(torch.uint8)
        fake_imgs_list.append(fake_imgs)

fake_imgs_uint8 = torch.cat(fake_imgs_list, dim=0)


real_imgs_list = []
real_loader = DataLoader(cub_dataset, batch_size=batch_eval, shuffle=True, drop_last=True, num_workers=0)

needed = num_gen
for real_imgs, _ in real_loader:
    if needed <= 0:
        break
    real = real_imgs[:needed]
    real = (real * 0.5 + 0.5)  # [-1,1] to [0,1]
    real = (real * 255).clamp(0, 255).to(torch.uint8)
    real_imgs_list.append(real)
    needed -= real.shape[0]

real_imgs_uint8 = torch.cat(real_imgs_list, dim=0)

fid = FrechetInceptionDistance(feature=64).to('cpu')  # inputs are uint8 in [0, 255]; normalize=True is only for floats in [0, 1]
for i in range(0, num_gen, batch_eval):
    fid.update(real_imgs_uint8[i:i+batch_eval], real=True)
    fid.update(fake_imgs_uint8[i:i+batch_eval], real=False)
fid_score = fid.compute().item()

is_scorer = InceptionScore().to('cpu')  # inputs are uint8 in [0, 255]; normalize=True is only for floats in [0, 1]
for i in range(0, num_gen, batch_eval):
    is_scorer.update(fake_imgs_uint8[i:i+batch_eval])
is_score, is_std = is_scorer.compute()

print(f"\nFID (real vs generated): {fid_score:.2f}")
print(f"Inception Score: {is_score:.2f} ± {is_std:.2f}")
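
Both metrics above are fed `uint8` images, so every generated and real tensor is mapped from the training range [-1, 1] to [0, 255] first. The sketch below reproduces that mapping for single values, assuming the `uint8` cast truncates the fractional part; the helper name `to_uint8` is hypothetical.

```python
def to_uint8(value):
    """Map a value from [-1, 1] to an integer in [0, 255], truncating like an integer cast."""
    scaled = (value + 1) / 2 * 255            # [-1, 1] -> [0, 255]
    clamped = min(max(scaled, 0.0), 255.0)    # guard against values slightly out of range
    return int(clamped)


converted = [to_uint8(v) for v in (-1.0, 0.0, 1.0, 1.2)]
print(converted)
```

Values outside [-1, 1] are clamped instead of wrapping around, which is the role of `.clamp(0, 255)` before the cast.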

### Loss Plot for CUB-200

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os

def moving_average(x, window_size=200):
    return np.convolve(x, np.ones(window_size)/window_size, mode='valid')

# Path to saved losses
save_dir = "./outputs/CUB200"
D_losses = np.load(os.path.join(save_dir, "D_losses.npy"))
G_losses = np.load(os.path.join(save_dir, "G_losses.npy"))

# Plot smoothed loss curves
plt.figure(figsize=(12, 6))
plt.plot(moving_average(D_losses), label="Discriminator Loss", color="tab:blue", alpha=0.9)
plt.plot(moving_average(G_losses), label="Generator Loss", color="tab:orange", alpha=0.9)
plt.xlabel("Training Iteration")
plt.ylabel("Loss")
plt.title("CUB-200 cGAN Training Loss Curves (Smoothed)")
plt.legend()
plt.grid(True, linestyle='--', alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(save_dir, "cub200_cgan_loss_curve.png"))
plt.show()

# Completed.