# Configuring project directory

In [None]:
# --- SETUP: Work Locally but Import Existing Models ---
from google.colab import drive
import os
import shutil

# Expose Google Drive under /content/drive, then run everything from the
# local Colab disk.
drive.mount('/content/drive')
os.chdir('/content/')

# Drive folder that holds previously trained checkpoints, and the checkpoint
# files to pull from it into /content/.
drive_models_path = '/content/drive/MyDrive/589-mini-diffusion/MNISTMiniDiffusionResults' #change this to the location where existing models are
existing_models = [
    "ddpm_mnist_baseline_epoch_50.pth",
    "mnist_classifier.pth",
    "ddpm_mnist_pruned_finetuned_epoch_20.pth",
    "ddpm_mnist_reduced_channels_epoch_50.pth",
    "ddpm_mnist_shallow_network_epoch_50.pth",
    "ddpm_mnist_reduced_channels_finetuned_epoch_10.pth",
    "ddpm_mnist_shallow_network_finetuned_epoch_20.pth",
]

print("Copying existing models from Drive to local...")
for model_file in existing_models:
    drive_path = f"{drive_models_path}/{model_file}"
    # A missing checkpoint is only reported; the copy is skipped for it.
    if not os.path.exists(drive_path):
        print(f"Not found: {model_file} (will train from scratch)")
        continue
    shutil.copy2(drive_path, f"/content/{model_file}")
    print(f"Copied: {model_file}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Copying existing models from Drive to local...
Copied: ddpm_mnist_baseline_epoch_50.pth
Copied: mnist_classifier.pth
Copied: ddpm_mnist_pruned_finetuned_epoch_20.pth
Copied: ddpm_mnist_reduced_channels_epoch_50.pth
Copied: ddpm_mnist_shallow_network_epoch_50.pth
Copied: ddpm_mnist_reduced_channels_finetuned_epoch_10.pth
Copied: ddpm_mnist_shallow_network_finetuned_epoch_20.pth


# Imports and Installs

In [None]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.8.1-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->torchmetrics)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->torchmetrics)
  D

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import math
import os
import gc
import time
import torch.nn.utils.prune as prune
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM

# Baseline Model

In [None]:
# --- 1. Configuration ---
class Config:
    """Hyperparameters and runtime settings read by the rest of the notebook."""

    # Data
    image_size = 28
    batch_size = 128

    # Optimisation
    learning_rate = 1e-3
    num_epochs = 50  # Baseline training epochs
    fine_tune_epochs = 20  # Fine-tuning epochs after pruning

    # Pruning ratios
    pruning_structured_amount = 0.1  # 10% structured (filter) pruning
    pruning_unstructured_amount = 0.3  # 30% unstructured (weight) pruning

    # Diffusion process
    timesteps = 200
    beta_start = 1e-4
    beta_end = 0.02

    # Runtime / evaluation
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_generated_samples_for_eval = 10000  # Number of samples for FID/SSIM evaluation


config = Config()

# --- 2. Data Preparation ---
# ToTensor followed by Normalize(mean=0.5, std=0.5) maps pixels to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

# Stack every test-set image into one tensor on the target device; these are
# the "real" images that generated samples are compared against (FID / SSIM).
all_real_images = torch.cat(
    [image_batch for image_batch, _ in test_loader], dim=0
).to(config.device)


# --- 3. Diffusion Utilities (Noise Schedule) ---
def cosine_beta_schedule(timesteps, s=0.008):
    """Build a cosine noise schedule.

    Args:
        timesteps: Number of diffusion steps; the result has this many entries.
        s: Offset added to the normalised step position before the cosine.

    Returns:
        1-D tensor of betas, clipped to [0.0001, 0.9999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * math.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    # Rescale so the cumulative product starts at exactly 1.
    cumulative = cumulative / cumulative[0]
    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clip(1 - step_ratio, 0.0001, 0.9999)

def linear_beta_schedule(timesteps, beta_start, beta_end):
    """Return `timesteps` betas evenly spaced from `beta_start` to `beta_end` inclusive."""
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return schedule

# Per-step schedule tensors, all resident on config.device. Only `betas` needs
# an explicit move; everything derived from it is created on the same device.
betas = cosine_beta_schedule(timesteps=config.timesteps).to(config.device)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# Cumulative product shifted right by one step, with 1.0 in the first slot.
alphas_cumprod_prev = torch.cat([torch.ones(1, device=config.device), alphas_cumprod[:-1]])

# Coefficients of the forward (noising) process.
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# Variance used when sampling the reverse step.
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

@torch.no_grad()
def p_sample(model, x, t, t_index):
    """Run one reverse-diffusion step on a batch.

    Args:
        model: Noise-prediction network, called as `model(x, t)`.
        x: Current noisy batch.
        t: Timestep tensor handed to the model.
        t_index: Integer position in the module-level schedule tensors
            (`betas`, `alphas`, `sqrt_one_minus_alphas_cumprod`,
            `posterior_variance`).

    Returns:
        The batch after one denoising step. No noise is added when
        `t_index` is 0.
    """
    betas_t = betas[t_index]
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t_index]

    pred_noise = model(x, t)

    # Mean of the reverse step, expressed through the predicted noise. The
    # earlier x0 estimate was never used and has been dropped.
    mean = (x - betas_t * pred_noise / sqrt_one_minus_alphas_cumprod_t) / torch.sqrt(alphas[t_index])

    variance = posterior_variance[t_index]

    noise = torch.randn_like(x) if t_index > 0 else 0.
    return mean + torch.sqrt(variance) * noise

@torch.no_grad()
def p_sample_loop(model, shape):
    """Generate a batch of `shape` by denoising Gaussian noise over every timestep.

    Steps run from `config.timesteps - 1` down to 0, calling `p_sample` once
    per step. The result is returned on `config.device`.
    """
    sample = torch.randn(shape, device=config.device)
    countdown = range(config.timesteps - 1, -1, -1)
    for step in tqdm(countdown, desc='sampling loop time step', total=config.timesteps):
        step_batch = torch.full((sample.shape[0],), step, device=config.device, dtype=torch.long)
        sample = p_sample(model, sample, step_batch, step)
    return sample

@torch.no_grad()
def p_sample_loop_chunked(model, total_samples, batch_size=100, image_shape=(1, 28, 28)):
    """Generate `total_samples` images in independent chunks of at most `batch_size`.

    Args:
        model: Noise-prediction network; switched to eval mode here.
        total_samples: Number of images to generate. Zero or a negative value
            yields an empty tensor of shape `(0, *image_shape)`.
        batch_size: Maximum images per chunk; must be positive.
        image_shape: Per-image shape (any sequence of ints).

    Returns:
        Tensor of shape `(total_samples, *image_shape)` on `config.device`.

    Raises:
        ValueError: If `batch_size` is not positive.
    """
    model.eval()

    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Accept lists as well as tuples for the per-image shape.
    image_shape = tuple(image_shape)

    # torch.cat cannot concatenate an empty list, so handle "nothing to do" here.
    if total_samples <= 0:
        return torch.empty((0,) + image_shape, device=config.device)

    all_samples = []
    num_batches = (total_samples + batch_size - 1) // batch_size

    for batch_idx in range(num_batches):
        # The last chunk may be smaller than batch_size.
        current_batch_size = min(batch_size, total_samples - batch_idx * batch_size)

        # Each chunk starts from its own fresh noise.
        img = torch.randn((current_batch_size,) + image_shape, device=config.device)

        for i in tqdm(reversed(range(0, config.timesteps)),
                     desc=f'Batch {batch_idx+1}/{num_batches}',
                     total=config.timesteps):
            t = torch.full((img.shape[0],), i, device=config.device, dtype=torch.long)
            img = p_sample(model, img, t, i)

        all_samples.append(img)

        print(f"Completed batch {batch_idx+1}/{num_batches} ({current_batch_size} samples)")

        # Clear cache between batches
        torch.cuda.empty_cache()

    return torch.cat(all_samples, dim=0)

def num_to_groups(num, divisor):
    """Split `num` into chunks of size `divisor`, plus one trailing chunk for any remainder.

    Example: `num_to_groups(10, 3)` gives `[3, 3, 3, 1]`.
    """
    full_groups, leftover = divmod(num, divisor)
    sizes = [divisor for _ in range(full_groups)]
    if leftover > 0:
        sizes += [leftover]
    return sizes

# --- 4. Model Definition (Simplified U-Net) ---
class Block(nn.Module):
    """Two 3x3 conv + BatchNorm + ReLU stages with a time-embedding bias and a residual path.

    The time embedding is projected to `out_channels` and added between the two
    convolutional stages. The residual path is a 1x1 convolution when the
    channel count changes, otherwise the identity.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        # Submodule names and creation order are kept stable so existing
        # checkpoints keep loading.
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels))
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

        same_width = in_channels == out_channels
        self.residual_conv = (
            nn.Identity() if same_width else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )

    def forward(self, x, t):
        # First stage: conv -> norm -> activation.
        hidden = self.relu(self.bn1(self.conv1(x)))
        # Broadcast the per-channel time bias over the spatial dimensions.
        hidden = hidden + self.mlp(t)[:, :, None, None]
        # Second stage, then add the (possibly projected) input back.
        hidden = self.relu(self.bn2(self.conv2(hidden)))
        shortcut = self.residual_conv(x)
        return hidden + shortcut


class Downsample(nn.Module):
    """Strided 4x4 convolution (stride 2, padding 1) that keeps the channel count."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        reduced = self.conv(x)
        return reduced

class Upsample(nn.Module):
    """Transposed 4x4 convolution (stride 2, padding 1) that keeps the channel count."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        enlarged = self.conv(x)
        return enlarged

class SinusoidalPositionEmbeddings(nn.Module):
    """Map a 1-D tensor of timesteps to sinusoidal features.

    The output has `2 * (dim // 2)` columns: the sines of the scaled timesteps
    followed by their cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Return the embedding of `time` (shape `(batch,)`) on `time`'s device."""
        device = time.device
        half_dim = self.dim // 2
        # `half_dim - 1` is 0 when dim is 2 or 3; clamp the divisor so those
        # sizes work instead of raising ZeroDivisionError. Larger dims are
        # unaffected.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        angles = time[:, None] * frequencies[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

class SimpleUnet(nn.Module):
    """Small U-Net that predicts noise from a noisy image and its timestep.

    Args:
        image_channels: Channels of the input image.
        channels: Feature widths per resolution level; each consecutive pair
            defines one encoder stage and one mirrored decoder stage.
        out_channels: Channels of the predicted output.
        time_emb_dim: Width of the timestep embedding fed to every `Block`.
    """

    def __init__(self,
                 image_channels=1,
                 channels=(32, 64, 128), # C0, C1, C2
                 out_channels=1,
                 time_emb_dim=256):
        super().__init__()

        # Attribute names and construction order are kept stable so existing
        # checkpoints keep loading.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
        )

        self.conv0 = nn.Conv2d(image_channels, channels[0], kernel_size=3, padding=1)

        # (narrow, wide) width pairs, one per resolution change.
        level_pairs = list(zip(channels[:-1], channels[1:]))

        # Encoder: two blocks at the wider width, then halve the resolution.
        self.downs = nn.ModuleList([
            nn.ModuleList([
                Block(narrow, wide, time_emb_dim),
                Block(wide, wide, time_emb_dim),
                Downsample(wide),
            ])
            for narrow, wide in level_pairs
        ])

        bottleneck_width = channels[-1]
        self.mid_block1 = Block(bottleneck_width, bottleneck_width, time_emb_dim)
        self.mid_block2 = Block(bottleneck_width, bottleneck_width, time_emb_dim)

        # Decoder mirrors the encoder. The first block takes twice the wide
        # width because the upsampled tensor is concatenated with a skip of
        # the same width.
        self.ups = nn.ModuleList([
            nn.ModuleList([
                Block(wide + wide, narrow, time_emb_dim),
                Block(narrow, narrow, time_emb_dim),
                Upsample(wide),
            ])
            for narrow, wide in reversed(level_pairs)
        ])

        self.final_conv = nn.Conv2d(channels[0], out_channels, kernel_size=1)

    def forward(self, x, timestep):
        t_emb = self.time_mlp(timestep)
        x = self.conv0(x)

        # Encoder: remember the pre-downsample activation of each stage.
        skips = []
        for first, second, shrink in self.downs:
            x = second(first(x, t_emb), t_emb)
            skips.append(x)
            x = shrink(x)

        # Bottleneck
        x = self.mid_block2(self.mid_block1(x, t_emb), t_emb)

        # Decoder: each stage stores its upsample layer last, but it is applied
        # first, before the skip concatenation and the two blocks.
        for first, second, grow in self.ups:
            x = torch.cat((grow(x), skips.pop()), dim=1)
            x = second(first(x, t_emb), t_emb)

        return self.final_conv(x)

# --- 5. Training/Fine-tuning Function ---
def get_noise_pred(model, x_0, t):
    """Noise a clean batch to timestep `t` and ask the model to predict that noise.

    Args:
        model: Noise-prediction network, called as `model(x_t, t.float())`.
        x_0: Clean image batch.
        t: 1-D tensor of timestep indices, one per image.

    Returns:
        Tuple `(noise, predicted_noise)`.
    """
    # Per-sample mixing coefficients, shaped to broadcast over (C, H, W).
    signal_scale = sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1)
    noise_scale = sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1)

    noise = torch.randn_like(x_0)
    x_t = signal_scale * x_0 + noise_scale * noise

    return noise, model(x_t, t.float())

def train_model(model, loader, num_epochs, learning_rate, model_save_path_prefix, is_finetuning=False):
    """Train (or fine-tune) a noise-prediction model with an MSE loss on the noise.

    A checkpoint `{model_save_path_prefix}_epoch_{k}.pth` is written after every
    epoch. A 4x4 grid of generated samples is saved as a PNG every 10th epoch
    and after the final epoch.

    Args:
        model: Network to optimise in place.
        loader: Iterable of `(images, labels)` batches; labels are ignored.
        num_epochs: Number of passes over `loader`.
        learning_rate: Adam learning rate.
        model_save_path_prefix: Prefix for checkpoint and sample file names.
        is_finetuning: Only changes the wording of progress messages.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    mode_str = "Fine-tuning" if is_finetuning else "Training"
    print(f"Starting {mode_str} on {config.device} for {num_epochs} epochs...")

    for epoch in range(num_epochs):
        model.train()
        # Loss of the most recent batch. It stays NaN if the loader yields
        # nothing, rather than raising NameError or reporting a stale value
        # from the previous epoch.
        last_loss = float("nan")
        pbar = tqdm(loader, desc=f"{mode_str} Epoch {epoch+1}/{num_epochs}")
        for images, _ in pbar:
            images = images.to(config.device)
            optimizer.zero_grad()

            # One uniformly random timestep per image.
            t = torch.randint(0, config.timesteps, (images.shape[0],), device=config.device).long()

            noise, predicted_noise = get_noise_pred(model, images, t)

            loss = criterion(predicted_noise, noise)
            loss.backward()
            optimizer.step()

            last_loss = loss.item()
            pbar.set_postfix(loss=last_loss)

        current_model_save_path = f"{model_save_path_prefix}_epoch_{epoch+1}.pth"
        torch.save(model.state_dict(), current_model_save_path)
        print(f"Epoch {epoch+1} finished, Loss: {last_loss:.4f}. Model saved to {current_model_save_path}.")

        # Generate samples periodically
        if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
            print(f"Generating samples from {mode_str.lower()} model...")
            model.eval()
            sample_batch_size = 16
            # p_sample_loop is itself decorated with torch.no_grad().
            generated_samples = p_sample_loop(model, shape=(sample_batch_size, 1, config.image_size, config.image_size))
            final_samples_display = (generated_samples + 1) * 0.5 # Denormalize for display

            current_samples_save_path = f"{model_save_path_prefix}_generated_samples_epoch_{epoch+1}.png"
            torchvision.utils.save_image(final_samples_display, current_samples_save_path, nrow=4)
            print(f"Generated samples saved to {current_samples_save_path}.")

    print(f"{mode_str} complete!")

100%|██████████| 9.91M/9.91M [00:00<00:00, 39.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.03MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.39MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.77MB/s]


# Pruning

In [None]:
def apply_unstructured_pruning(model, amount=0.2):
    """Zero the smallest-magnitude weights in every Conv2d and Linear layer.

    Each matching layer is pruned on its own with L1 unstructured pruning, and
    the pruning reparameterization is removed immediately so the zeros are
    baked into the layer's `weight`.

    Args:
        model: Module whose sub-layers are pruned in place.
        amount: Pruning amount forwarded to `prune.l1_unstructured`.
    """
    print(f"\nApplying unstructured pruning with amount: {amount}")
    prunable_types = (nn.Conv2d, nn.Linear)
    targets = [layer for layer in model.modules() if isinstance(layer, prunable_types)]
    for layer in targets:
        prune.l1_unstructured(layer, name="weight", amount=amount)
        # Fold the mask into the weight and drop the mask / original copy.
        prune.remove(layer, "weight")

    print("Unstructured pruning applied.")

def apply_structured_pruning(model, amount=0.2):
    """Zero whole output filters of every Conv2d layer, ranked by L1 norm.

    Each Conv2d is pruned on its own along dim 0 (output channels), and the
    pruning reparameterization is removed immediately so the zeros are baked
    into the layer's `weight`.

    Args:
        model: Module whose Conv2d sub-layers are pruned in place.
        amount: Pruning amount forwarded to `prune.ln_structured`.
    """
    print(f"\nApplying structured pruning with amount: {amount}")
    conv_layers = [layer for layer in model.modules() if isinstance(layer, nn.Conv2d)]
    for conv in conv_layers:
        # n=1 ranks filters by L1 norm; dim=0 selects output channels.
        prune.ln_structured(conv, name="weight", amount=amount, n=1, dim=0)
        prune.remove(conv, "weight")

    print("Structured pruning applied.")

# Architecture Modifications

In [None]:
def _load_or_train_variant(model, name_stem, label, train_loader, do_finetuning,
                           finetune_epochs, finetune_lr):
    """Return `model` loaded from its final checkpoint, training it first if needed.

    The final checkpoint is `{name_stem}_epoch_{N}.pth` for a from-scratch run
    and `{name_stem}_finetuned_epoch_{N}.pth` for a fine-tuning run. The target
    checkpoint is checked first, so the base checkpoint is only read when
    fine-tuning actually has to run.
    """
    if do_finetuning:
        epochs, lr = finetune_epochs, finetune_lr
        prefix = f"{name_stem}_finetuned"
    else:
        epochs, lr = config.num_epochs, config.learning_rate
        prefix = name_stem
    model_path = f"{prefix}_epoch_{epochs}.pth"

    if os.path.exists(model_path):
        print(f"Loading existing model from {model_path}")
        model.load_state_dict(torch.load(model_path, map_location=config.device))
        return model

    if do_finetuning:
        # Fine-tuning starts from the fully trained from-scratch checkpoint.
        initial_path = f"{name_stem}_epoch_{config.num_epochs}.pth"
        print(f"Loading initial model from {initial_path} for fine-tuning")
        model.load_state_dict(torch.load(initial_path, map_location=config.device))

    print(f"Training {label} model...")
    train_model(model, train_loader, epochs, lr, prefix, is_finetuning=do_finetuning)
    # Reload the checkpoint written after the last epoch.
    model.load_state_dict(torch.load(model_path, map_location=config.device))
    return model


def train_reduced_channels_model(train_loader, do_finetuning=False):
    """Load or train the half-width U-Net (channels 16/32/64, time embedding 128).

    Args:
        train_loader: Training batches, used only if training has to run.
        do_finetuning: When True, fine-tune the trained model for 10 epochs at
            1/100 of the base learning rate instead of training from scratch.

    Returns:
        The model with its final weights loaded.
    """
    print("\n--- Training Reduced Channels Model ---")
    print("Description: 50% reduction in channel dimensions")
    print("Channels: (16, 32, 64)")
    print("Time embedding dim: 128")

    model = SimpleUnet(
        image_channels=1,
        channels=(16, 32, 64),
        out_channels=1,
        time_emb_dim=128
    ).to(config.device)

    return _load_or_train_variant(
        model,
        name_stem="ddpm_mnist_reduced_channels",
        label="reduced channels",
        train_loader=train_loader,
        do_finetuning=do_finetuning,
        finetune_epochs=10,
        finetune_lr=config.learning_rate / 100,
    )

def train_shallow_network_model(train_loader, do_finetuning=False):
    """Load or train the two-level U-Net (channels 64/128, time embedding 256).

    Args:
        train_loader: Training batches, used only if training has to run.
        do_finetuning: When True, fine-tune the trained model for
            `config.fine_tune_epochs` epochs at 1/10 of the base learning rate
            instead of training from scratch.

    Returns:
        The model with its final weights loaded.
    """
    print("\n--- Training Shallow Network Model ---")
    print("Description: Reduced network depth (2 levels vs 3)")
    print("Channels: (64, 128)")
    print("Time embedding dim: 256")

    model = SimpleUnet(
        image_channels=1,
        channels=(64, 128),
        out_channels=1,
        time_emb_dim=256
    ).to(config.device)

    return _load_or_train_variant(
        model,
        name_stem="ddpm_mnist_shallow_network",
        label="shallow network",
        train_loader=train_loader,
        do_finetuning=do_finetuning,
        finetune_epochs=config.fine_tune_epochs,
        finetune_lr=config.learning_rate / 10,
    )

# Evaluation

In [None]:
# Helper to get number of parameters
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Helper to get model size
def get_model_size(model, filename_prefix="temp_model"):
    """Return the size in MiB of the model's serialized `state_dict`.

    The state dict is written to a throwaway directory rather than the working
    directory. The scratch file can therefore never overwrite, and then
    delete, a real checkpoint that happens to share the sanitized name, and it
    is cleaned up even if saving fails.

    Args:
        model: Module whose `state_dict()` is measured.
        filename_prefix: Label used to name the scratch file.

    Returns:
        File size in mebibytes as a float.
    """
    import tempfile

    # Sanitize filename for model_name
    sanitized_name = filename_prefix.lower().replace(' ', '_').replace('&', '').replace('-', '')
    filename = f"{sanitized_name}.pth"

    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_path = os.path.join(scratch_dir, filename)
        torch.save(model.state_dict(), scratch_path)
        size_mb = os.path.getsize(scratch_path) / (1024 * 1024)
    return size_mb

# --- MNIST Classifier for Pseudo-FID & Classification Accuracy ---
class SimpleMNISTClassifier(nn.Module):
    """Three-convolution classifier, also used as a 128-d feature extractor.

    `get_features` returns the globally average-pooled output of the conv
    stack; `forward` applies a linear layer on top of those features.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        conv_stack = [
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 7x7
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.features = nn.Sequential(*conv_stack)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Output 1x1
        self.classifier = nn.Linear(128, num_classes)

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))

    def get_features(self, x):
        # Flattened pooled activations, one row per image.
        pooled = self.avgpool(self.features(x))
        return pooled.flatten(1)

def train_classifier(classifier_model, data_loader, num_epochs=10, lr=1e-3, device="cpu"):
    """Train a classifier with Adam and cross-entropy, printing the mean loss per epoch.

    Args:
        classifier_model: Module to train in place; moved to `device`.
        data_loader: Iterable of `(inputs, labels)` batches that supports `len()`.
        num_epochs: Number of passes over `data_loader`.
        lr: Adam learning rate.
        device: Device on which training runs.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier_model.parameters(), lr=lr)
    classifier_model.to(device)
    classifier_model.train()
    print(f"Training MNIST Classifier on {device}...")

    for epoch in range(num_epochs):
        batch_losses = []
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            loss = criterion(classifier_model(inputs), labels)
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())
        print(f"Classifier Epoch {epoch+1}, Loss: {sum(batch_losses) / len(data_loader):.4f}")

    print("MNIST Classifier training complete.")

def evaluate_classifier_accuracy(classifier_model, images, labels):
    """Return the fraction of `images` whose arg-max prediction equals `labels`.

    The model is switched to eval mode and run once on the whole batch with
    gradients disabled.
    """
    classifier_model.eval()
    with torch.no_grad():
        logits = classifier_model(images)
        predicted = torch.max(logits.data, 1).indices
        hits = (predicted == labels).sum().item()
    return hits / labels.size(0)

def calculate_pseudo_fid(classifier_model, real_images, generated_images, batch_size=128):
    """Fréchet distance between classifier features of real and generated images.

    Features come from `classifier_model.get_features`. With feature means
    mu_r, mu_g and covariances S_r, S_g the score is

        ||mu_r - mu_g||^2 + tr(S_r) + tr(S_g) - 2 * tr(sqrt(S_r S_g))

    The last term is computed as the sum of the square roots of the eigenvalues
    of `S_r @ S_g`. The product of two covariance matrices is generally not
    symmetric, so symmetrising it and using `eigh` (the previous approach) does
    not give its matrix square root.

    Args:
        classifier_model: Module exposing `get_features(batch)`; put in eval
            mode and moved to `config.device`.
        real_images: Tensor of real images, batch dimension first.
        generated_images: Tensor of generated images, batch dimension first.
        batch_size: Images per feature-extraction batch.

    Returns:
        The score as a Python float.
    """
    classifier_model.eval()
    classifier_model.to(config.device)

    def _collect_features(images, desc):
        # Extract features batch by batch without building an autograd graph.
        chunks = []
        with torch.no_grad():
            for start in tqdm(range(0, images.shape[0], batch_size), desc=desc):
                batch = images[start:start + batch_size].to(config.device)
                chunks.append(classifier_model.get_features(batch).cpu())
        # Double precision keeps the covariance / eigenvalue step stable.
        return torch.cat(chunks, dim=0).double()

    real_features = _collect_features(real_images, "Extracting real features")
    gen_features = _collect_features(generated_images, "Extracting gen features")

    # Calculate mean and covariance for real and generated features
    mu_real, sigma_real = real_features.mean(dim=0), torch.cov(real_features.T)
    mu_gen, sigma_gen = gen_features.mean(dim=0), torch.cov(gen_features.T)

    diff = mu_real - mu_gen

    # tr(sqrt(S_r S_g)) = sum_i sqrt(lambda_i(S_r S_g)). Rounding can produce
    # tiny negative or complex eigenvalues; keeping only the real part of the
    # square roots discards them.
    tr_covmean = torch.linalg.eigvals(sigma_real @ sigma_gen).sqrt().real.sum()

    fid_score = diff.square().sum() + torch.trace(sigma_real) + torch.trace(sigma_gen) - 2 * tr_covmean

    return fid_score.item()

def evaluate_model(model, model_name, classifier_model, all_real_images, test_loader):
    """Measure size, sampling latency and sample quality of a diffusion model.

    Args:
        model: Diffusion model to evaluate; put in eval mode on `config.device`.
        model_name: Label used in log messages and for the size measurement.
        classifier_model: Classifier used for pseudo-FID features and for
            labelling generated digits; put in eval mode on `config.device`.
        all_real_images: Real images normalised like the training data
            (the notebook builds them from `test_loader`, i.e. in [-1, 1]).
        test_loader: Loader of `(images, labels)` used to check the classifier.

    Returns:
        Dict with keys `params`, `size_mb`, `latency_s`, `ssim`, `pseudo_fid`
        and `classifier_test_accuracy`.
    """
    print(f"\n--- Evaluating {model_name} ---")
    model.eval()
    model.to(config.device)
    # Place the classifier explicitly instead of relying on
    # calculate_pseudo_fid moving it as a side effect.
    classifier_model.eval()
    classifier_model.to(config.device)

    on_cuda = str(config.device).startswith("cuda")

    # 1. Model Efficiency Metrics
    param_count = count_parameters(model)
    print(f"  Parameters: {param_count:,}")

    model_size_mb = get_model_size(model, model_name) # Pass model_name as prefix
    print(f"  Model Size: {model_size_mb:.2f} MB")

    # CUDA kernels run asynchronously: synchronise before reading the clock so
    # the interval covers the actual sampling work.
    if on_cuda:
        torch.cuda.synchronize()
    start_time = time.time()
    generated_samples_eval = p_sample_loop(model, shape=(config.num_generated_samples_for_eval, 1, config.image_size, config.image_size))
    if on_cuda:
        torch.cuda.synchronize()
    inference_latency = time.time() - start_time
    print(f"  Inference Latency ({config.num_generated_samples_for_eval} samples): {inference_latency:.4f} seconds")

    # Pair up equal numbers of real and generated images: a random subset of
    # the real set if it is larger, a prefix of the generated set if it is larger.
    num_real = all_real_images.shape[0]
    num_pairs = min(num_real, generated_samples_eval.shape[0])
    if num_real > num_pairs:
        indices = torch.randperm(num_real)[:num_pairs]
        real_images_for_eval = all_real_images[indices]
    else:
        real_images_for_eval = all_real_images
    generated_for_metrics = generated_samples_eval[:num_pairs]

    # 2. Image Quality Metrics
    # SSIM is configured with data_range=1.0, so map [-1, 1] to [0, 1] for it.
    ssim_metric = SSIM(data_range=1.0, reduction='elementwise_mean').to(config.device)
    ssim_score = ssim_metric((generated_for_metrics + 1) * 0.5, (real_images_for_eval + 1) * 0.5).item()
    print(f"  Average SSIM: {ssim_score:.4f}")

    # The classifier is trained and tested on normalised images in this
    # notebook, so its features are extracted from normalised images too,
    # matching the classification step below.
    pseudo_fid_score = calculate_pseudo_fid(classifier_model, real_images_for_eval, generated_for_metrics)
    print(f"  Pseudo-FID: {pseudo_fid_score:.4f}")

    # 3. Accuracy (Classification Accuracy of generated digits)
    # Evaluate classifier on generated samples
    with torch.no_grad():
        generated_labels_pred = classifier_model(generated_samples_eval).argmax(dim=1)
    unique_classes, counts = torch.unique(generated_labels_pred, return_counts=True)
    print(f"  Generated Digit Distribution (from classifier): {dict(zip(unique_classes.tolist(), counts.tolist()))}")

    # Classifier accuracy on REAL MNIST test data (to confirm classifier itself is good)
    # Collect images and labels in a single pass over the loader.
    image_batches, label_batches = [], []
    for images, labels in test_loader:
        image_batches.append(images)
        label_batches.append(labels)
    test_images_for_classifier = torch.cat(image_batches, dim=0).to(config.device)
    test_labels_for_classifier = torch.cat(label_batches, dim=0).to(config.device)
    classifier_test_accuracy = evaluate_classifier_accuracy(classifier_model, test_images_for_classifier, test_labels_for_classifier)
    print(f"  Classifier Accuracy on Real MNIST Test Set: {classifier_test_accuracy:.4f}")

    print(f"--- Evaluation of {model_name} Complete ---\n")
    return {
        "params": param_count,
        "size_mb": model_size_mb,
        "latency_s": inference_latency,
        "ssim": ssim_score,
        "pseudo_fid": pseudo_fid_score,
        "classifier_test_accuracy": classifier_test_accuracy
    }

# Full Execution Flow

In [None]:
# --- Step 1: Train/Load MNIST Classifier ---
# The classifier only serves as a scoring network for the evaluation metrics.
# evaluate_model moves the real test images to config.device before handing
# them to it, so the classifier is created on config.device as well instead of
# being pinned to the CPU (the checkpoint is already mapped to config.device
# and train_classifier is asked to train on config.device).
mnist_classifier = SimpleMNISTClassifier().to(config.device)
classifier_path = "mnist_classifier.pth"

if os.path.exists(classifier_path):
    # Reuse the cached weights so a full re-run does not retrain the classifier.
    print(f"Loading existing MNIST classifier from {classifier_path}")
    mnist_classifier.load_state_dict(torch.load(classifier_path, map_location=config.device))
else:
    print("Training MNIST Classifier (for evaluation metrics)...")
    classifier_train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    train_classifier(mnist_classifier, classifier_train_loader, num_epochs=10, lr=1e-3, device=config.device)
    torch.save(mnist_classifier.state_dict(), classifier_path)
    print(f"MNIST Classifier saved to {classifier_path}")

# From here on the classifier is used for inference only; eval mode keeps any
# dropout / batch-norm layers (if the architecture has them) deterministic.
mnist_classifier.eval()

# --- Step 2: Train Baseline Model ---
baseline_model = SimpleUnet(
    image_channels=1,
    channels=(32, 64, 128),
    out_channels=1,
    time_emb_dim=256,
).to(config.device)
baseline_model_path = "ddpm_mnist_baseline_epoch_{}.pth".format(config.num_epochs)

# Train only when no checkpoint is on disk. Either way, the weights evaluated
# below are the ones read back from baseline_model_path.
if not os.path.exists(baseline_model_path):
    print("Training baseline model...")
    train_model(baseline_model, train_loader, config.num_epochs, config.learning_rate, "ddpm_mnist_baseline")
else:
    print(f"Loading existing baseline model from {baseline_model_path}")
baseline_model.load_state_dict(torch.load(baseline_model_path, map_location=config.device))

# --- Step 3: Evaluate Baseline Model ---
baseline_eval_results = evaluate_model(
    baseline_model,
    "Baseline Model",
    mnist_classifier,
    all_real_images,
    test_loader,
)

# Release the baseline network before the next model is built.
del baseline_model
torch.cuda.empty_cache()
gc.collect()

# --- Step 4: Train Architecture Variants ---
print("\n" + "=" * 80, "TRAINING ARCHITECTURE VARIANTS", "=" * 80, sep="\n")


def _finetuned_variant_results(variant, display_name, train_variant, finetuned_path, initial_path):
    """Obtain the fine-tuned model of one architecture variant and evaluate it.

    Args:
        variant: Lower-case variant name inserted into the progress messages.
        display_name: Model name passed to evaluate_model.
        train_variant: Training entry point for the variant; called as
            train_variant(train_loader) for the initial run and as
            train_variant(train_loader, do_finetuning=True) for fine-tuning.
        finetuned_path: Checkpoint path of the fine-tuned model.
        initial_path: Checkpoint path of the initially trained model.

    Returns:
        The metrics dictionary returned by evaluate_model.
    """
    # The checkpoints only decide which message is shown and whether the
    # initial training pass has to run first; the fine-tuning call itself is
    # issued in every case.
    if os.path.exists(finetuned_path):
        print(f"Loading existing fine-tuned {variant} model...")
    elif os.path.exists(initial_path):
        print(f"Initial model exists, fine-tuning {variant} model...")
    else:
        print(f"Training initial then fine-tuning {variant} model...")
        train_variant(train_loader)  # Train initial
    model = train_variant(train_loader, do_finetuning=True)

    results = evaluate_model(model, display_name, mnist_classifier, all_real_images, test_loader)

    # Release the variant before the next model is built.
    del model
    torch.cuda.empty_cache()
    gc.collect()
    return results


# Reduced Channels
# NOTE(review): the fine-tune epoch count is hard-coded to 10 here, whereas the
# shallow variant uses config.fine_tune_epochs; presumably this matches what
# train_reduced_channels_model saves - confirm.
reduced_channels_finetuned_path = "ddpm_mnist_reduced_channels_finetuned_epoch_10.pth"
reduced_channels_initial_path = "ddpm_mnist_reduced_channels_epoch_{}.pth".format(config.num_epochs)
reduced_channels_finetuned_results = _finetuned_variant_results(
    "reduced channels",
    "Reduced Channels Fine-tuned",
    train_reduced_channels_model,
    reduced_channels_finetuned_path,
    reduced_channels_initial_path,
)

# Shallow Network
shallow_network_finetuned_path = "ddpm_mnist_shallow_network_finetuned_epoch_{}.pth".format(config.fine_tune_epochs)
shallow_network_initial_path = "ddpm_mnist_shallow_network_epoch_{}.pth".format(config.num_epochs)
shallow_network_finetuned_results = _finetuned_variant_results(
    "shallow network",
    "Shallow Network Fine-tuned",
    train_shallow_network_model,
    shallow_network_finetuned_path,
    shallow_network_initial_path,
)

# --- Step 5: Prepare and Prune Model ---
# Start from a network with the baseline architecture and the baseline weights.
pruned_model = SimpleUnet(
    image_channels=1,
    channels=(32, 64, 128),
    out_channels=1,
    time_emb_dim=256,
).to(config.device)
pruned_model.load_state_dict(torch.load(baseline_model_path, map_location=config.device))

print("\n--- Applying Pruning ---")
# Structured pruning runs first, then unstructured pruning on the same model.
apply_structured_pruning(pruned_model, amount=config.pruning_structured_amount)
apply_unstructured_pruning(pruned_model, amount=config.pruning_unstructured_amount)
# Keep a snapshot of the pruned weights as they are before any fine-tuning.
torch.save(pruned_model.state_dict(), "ddpm_mnist_pruned_initial.pth")
print("Initial pruned model saved (before fine-tuning).")

# --- Step 7: Fine-tune Pruned Model ---
pruned_model_path = "ddpm_mnist_pruned_finetuned_epoch_{}.pth".format(config.fine_tune_epochs)

# Fine-tune only when no fine-tuned checkpoint is on disk. Either way, the
# weights evaluated below are the ones read back from pruned_model_path.
if not os.path.exists(pruned_model_path):
    print("Fine-tuning pruned model...")
    fine_tune_lr = config.learning_rate / 10  # one tenth of the base learning rate
    train_model(
        pruned_model,
        train_loader,
        config.fine_tune_epochs,
        fine_tune_lr,
        "ddpm_mnist_pruned_finetuned",
        is_finetuning=True,
    )
else:
    print(f"Loading existing pruned and fine-tuned model from {pruned_model_path}")
pruned_model.load_state_dict(torch.load(pruned_model_path, map_location=config.device))

# --- Step 8: Evaluate Pruned and Fine-tuned Model ---
pruned_eval_results = evaluate_model(
    pruned_model,
    "Pruned & Fine-tuned Model",
    mnist_classifier,
    all_real_images,
    test_loader,
)

# --- Step 9: Summary of All Results ---
print("\n" + "=" * 80, "EXPERIMENT SUMMARY", "=" * 80, sep="\n")

# Display label -> metrics dictionary returned by evaluate_model for that model.
all_results = {
    "Baseline": baseline_eval_results,
    "Reduced Channels": reduced_channels_finetuned_results,
    "Shallow Network": shallow_network_finetuned_results,
    "Pruned & Fine-tuned": pruned_eval_results
}

# Column widths are shared between the header and the data rows so the table
# stays aligned; all columns are left-justified.
summary_header_format = "{:<20} {:<10} {:<8} {:<10} {:<8}"
summary_row_format = "{:<20} {:<10.2f} {:<8.4f} {:<10,} {:<8.1f}"

print("Model Comparison Summary:")
print(summary_header_format.format("Model", "Pseudo-FID", "SSIM", "Params", "Size(MB)"))
print("-" * 70)
for model_name, results in all_results.items():
    print(
        summary_row_format.format(
            model_name,
            results["pseudo_fid"],
            results["ssim"],
            results["params"],
            results["size_mb"],
        )
    )

Loading existing MNIST classifier from mnist_classifier.pth
Loading existing baseline model from ddpm_mnist_baseline_epoch_50.pth

--- Evaluating Baseline Model ---
  Parameters: 2,529,217
  Model Size: 9.72 MB


sampling loop time step: 100%|██████████| 200/200 [00:57<00:00,  3.45it/s]


  Inference Latency (10000 samples): 57.9701 seconds
  Average SSIM: 0.3475


Extracting real features: 100%|██████████| 79/79 [00:00<00:00, 1247.22it/s]
Extracting gen features: 100%|██████████| 79/79 [00:00<00:00, 1489.09it/s]


  Pseudo-FID: 10.7524
  Generated Digit Distribution (from classifier): {0: 1616, 1: 567, 2: 898, 3: 1009, 4: 760, 5: 873, 6: 1062, 7: 794, 8: 937, 9: 1484}
  Classifier Accuracy on Real MNIST Test Set: 0.9787
--- Evaluation of Baseline Model Complete ---


TRAINING ARCHITECTURE VARIANTS
Loading existing fine-tuned reduced channels model...

--- Training Reduced Channels Model ---
Description: 50% reduction in channel dimensions
Channels: (16, 32, 64)
Time embedding dim: 128
Loading initial model from ddpm_mnist_reduced_channels_epoch_50.pth for fine-tuning
Loading existing model from ddpm_mnist_reduced_channels_finetuned_epoch_10.pth

--- Evaluating Reduced Channels Fine-tuned ---
  Parameters: 634,081
  Model Size: 2.49 MB


sampling loop time step: 100%|██████████| 200/200 [00:31<00:00,  6.39it/s]


  Inference Latency (10000 samples): 31.2825 seconds
  Average SSIM: 0.3688


Extracting real features: 100%|██████████| 79/79 [00:00<00:00, 1646.12it/s]
Extracting gen features: 100%|██████████| 79/79 [00:00<00:00, 1779.08it/s]

  Pseudo-FID: 1.7628
  Generated Digit Distribution (from classifier): {0: 1060, 1: 1169, 2: 967, 3: 1196, 4: 824, 5: 804, 6: 817, 7: 1127, 8: 820, 9: 1216}





  Classifier Accuracy on Real MNIST Test Set: 0.9787
--- Evaluation of Reduced Channels Fine-tuned Complete ---

Loading existing fine-tuned shallow network model...

--- Training Shallow Network Model ---
Description: Reduced network depth (2 levels vs 3)
Channels: (64, 128)
Time embedding dim: 256
Loading initial model from ddpm_mnist_shallow_network_epoch_50.pth for fine-tuning
Loading existing model from ddpm_mnist_shallow_network_finetuned_epoch_20.pth

--- Evaluating Shallow Network Fine-tuned ---
  Parameters: 2,148,097
  Model Size: 8.24 MB


sampling loop time step: 100%|██████████| 200/200 [01:21<00:00,  2.45it/s]


  Inference Latency (10000 samples): 81.5870 seconds
  Average SSIM: 0.3659


Extracting real features: 100%|██████████| 79/79 [00:00<00:00, 1939.78it/s]
Extracting gen features: 100%|██████████| 79/79 [00:00<00:00, 1866.07it/s]

  Pseudo-FID: 2.3936
  Generated Digit Distribution (from classifier): {0: 993, 1: 1113, 2: 856, 3: 1345, 4: 818, 5: 802, 6: 992, 7: 1185, 8: 750, 9: 1146}





  Classifier Accuracy on Real MNIST Test Set: 0.9787
--- Evaluation of Shallow Network Fine-tuned Complete ---


--- Applying Pruning ---

Applying structured pruning with amount: 0.1
Structured pruning applied.

Applying unstructured pruning with amount: 0.3
Unstructured pruning applied.
Initial pruned model saved (before fine-tuning).
Loading existing pruned and fine-tuned model from ddpm_mnist_pruned_finetuned_epoch_20.pth

--- Evaluating Pruned & Fine-tuned Model ---
  Parameters: 2,529,217
  Model Size: 9.72 MB


sampling loop time step: 100%|██████████| 200/200 [00:57<00:00,  3.47it/s]


  Inference Latency (10000 samples): 57.6236 seconds
  Average SSIM: 0.3674


Extracting real features: 100%|██████████| 79/79 [00:00<00:00, 1965.17it/s]
Extracting gen features: 100%|██████████| 79/79 [00:00<00:00, 1595.41it/s]

  Pseudo-FID: 2.7686
  Generated Digit Distribution (from classifier): {0: 837, 1: 1062, 2: 1139, 3: 1400, 4: 859, 5: 879, 6: 799, 7: 1001, 8: 875, 9: 1149}





  Classifier Accuracy on Real MNIST Test Set: 0.9787
--- Evaluation of Pruned & Fine-tuned Model Complete ---


EXPERIMENT SUMMARY
Model Comparison Summary:
Model                Pseudo-FID SSIM     Params     Size(MB)
----------------------------------------------------------------------
Baseline             10.75      0.3475   2,529,217  9.7     
Reduced Channels     1.76       0.3688   634,081    2.5     
Shallow Network      2.39       0.3659   2,148,097  8.2     
Pruned & Fine-tuned  2.77       0.3674   2,529,217  9.7     


In [None]:
import os
source_directory = '/content/'
destination_directory = '/content/drive/MyDrive/589-project/'

# Make sure the Drive folder exists before anything is copied into it.
os.makedirs(destination_directory, exist_ok=True)
print(f"Destination directory ensured: {destination_directory}")


def _sample_grid_names(prefix, last_epoch):
    """Return '<prefix>_generated_samples_epoch_<N>.png' for N = 10, 20, ..., last_epoch."""
    return [
        f"{prefix}_generated_samples_epoch_{epoch}.png"
        for epoch in range(10, last_epoch + 1, 10)
    ]


# Checkpoints and sample grids to back up (adjust as needed). Entries that are
# missing locally are reported by the loop below and skipped.
files_to_copy = [
    "ddpm_mnist_baseline_epoch_50.pth",
    *_sample_grid_names("ddpm_mnist_baseline", 50),
    "ddpm_mnist_pruned_initial.pth",
    "ddpm_mnist_pruned_finetuned_epoch_20.pth",
    *_sample_grid_names("ddpm_mnist_pruned_finetuned", 20),
    "ddpm_mnist_reduced_channels_epoch_50.pth",
    "ddpm_mnist_reduced_channels_finetuned_epoch_10.pth",
    *_sample_grid_names("ddpm_mnist_reduced_channels", 50),
    *_sample_grid_names("ddpm_mnist_reduced_channels_finetuned", 10),
    "ddpm_mnist_shallow_network_epoch_50.pth",
    *_sample_grid_names("ddpm_mnist_shallow_network", 50),
    "ddpm_mnist_shallow_network_finetuned_epoch_20.pth",
    *_sample_grid_names("ddpm_mnist_shallow_network_finetuned", 20),
    "mnist_classifier.pth",
]

print("\nCopying files to Google Drive...")
for filename in files_to_copy:
    source_path = os.path.join(source_directory, filename)
    destination_path = os.path.join(destination_directory, filename)

    # Missing artifacts only produce a warning.
    if not os.path.exists(source_path):
        print(f"Warning: Source file '{filename}' not found at '{source_path}'")
        continue

    # Best effort: a failed copy is reported and the loop moves on to the next file.
    try:
        shutil.copy2(source_path, destination_path)
        print(f"Copied '{filename}' to '{destination_directory}'")
    except Exception as e:
        print(f"Error copying '{filename}': {e}")

print("\nCopying complete!")

Destination directory ensured: /content/drive/MyDrive/589-project/

Copying files to Google Drive...
Copied 'ddpm_mnist_baseline_epoch_50.pth' to '/content/drive/MyDrive/589-project/'
Copied 'ddpm_mnist_pruned_initial.pth' to '/content/drive/MyDrive/589-project/'
Copied 'ddpm_mnist_pruned_finetuned_epoch_20.pth' to '/content/drive/MyDrive/589-project/'
Copied 'ddpm_mnist_reduced_channels_epoch_50.pth' to '/content/drive/MyDrive/589-project/'
Copied 'ddpm_mnist_reduced_channels_finetuned_epoch_10.pth' to '/content/drive/MyDrive/589-project/'
Copied 'ddpm_mnist_shallow_network_epoch_50.pth' to '/content/drive/MyDrive/589-project/'
Copied 'ddpm_mnist_shallow_network_finetuned_epoch_20.pth' to '/content/drive/MyDrive/589-project/'
Copied 'mnist_classifier.pth' to '/content/drive/MyDrive/589-project/'

Copying complete!
