# Install and Import Libraries
To begin, we install the required libraries and import the necessary modules for building and training the diffusion model.

In [None]:
# Install required libraries
!pip install diffusers torch torchvision matplotlib

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import inception_v3
from torchvision.models.inception import Inception_V3_Weights
from diffusers import DDPMScheduler, UNet2DModel
import matplotlib.pyplot as plt
import numpy as np
from scipy import linalg
import os

# Prepare Dataset

## Data Preprocessing

We preprocess each image by resizing its shorter side to 64 pixels, center-cropping to 64x64, converting it to a tensor, and normalizing every channel to the range [-1, 1].

## Dataset Creation

We use the ImageFolder class to load the dataset from the specified directory and apply the defined transformations.

In [None]:
# Preprocessing pipeline applied to every image loaded from disk:
# resize, square center crop, tensor conversion, then per-channel
# normalization with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Root directory that ImageFolder scans for images
dataroot = "data/wiki"

# Dataset yielding (image, label) pairs; the training loop below only uses
# the image part of each pair.
dataset = ImageFolder(dataroot, transform=transform)

# Report how many samples were found
print(f"Number of images in the dataset: {len(dataset)}")

# Setup Data Loader

## Data Batching

We initialize a DataLoader to batch the dataset for training. The batch_size parameter defines the number of images per batch, and shuffle=True ensures the data is shuffled for better training performance.

In [None]:
# Batching configuration
batch_size = 16  # images per batch
shuffle = True   # draw samples in a new random order every epoch

# Wrap the dataset so it is served in shuffled mini-batches
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)

# Report how many batches make up one pass over the data
num_batches = len(dataloader)
print(f"Number of batches: {num_batches}")

# Define Diffusion Model and Trainer

## Model Architecture
We define the diffusion model using the UNet2DModel from the Hugging Face Diffusers library. The model follows a U-Net architecture with hierarchical feature extraction and attention mechanisms for better image generation.

## Noise Scheduler
The DDPMScheduler is used to manage the forward and reverse diffusion processes, with 1000 diffusion steps.

## Optimizer
We use the AdamW optimizer with a learning rate of 1e-4 to train the model.

## Device Setup
The model is moved to the appropriate device (GPU if available, otherwise CPU) for efficient computation.

In [None]:
# Select the compute device: GPU when available, CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# U-Net noise predictor for 64x64 RGB images. Attention is used only in the
# deepest down block and the matching first up block.
model = UNet2DModel(
    sample_size=64,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 256, 512, 512),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
)
model.to(device)

# Scheduler driving the forward (noising) and reverse (denoising) processes
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

# AdamW over all model parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Number of full passes over the dataset
num_epochs = 15

# Training Loop 

## Training Strategy

The training loop iterates over the dataset for a specified number of epochs. For each batch:
1. Gaussian noise is added to the images to simulate the forward diffusion process.
2. The model predicts the added noise.
3. The loss is computed using Mean Squared Error (MSE) between the predicted and actual noise.
4. Backpropagation is performed to update the model's weights using the AdamW optimizer.
5. Progress is monitored by printing the loss every 100 steps.

In [None]:
# Training loop: for every batch, noise the images at random timesteps and
# train the model to predict the noise that was added (MSE objective).

# Read the step count from the scheduler's config object; reading config
# values directly as scheduler attributes is deprecated in diffusers.
num_train_timesteps = noise_scheduler.config.num_train_timesteps

# Ensure training mode even when this cell is re-run after the sampling cell,
# which switches the model to eval mode.
model.train()

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    for step, (images, _) in enumerate(dataloader):
        # Move images to the device
        images = images.to(device)

        # Sample noise; randn_like already matches the device and dtype of `images`
        noise = torch.randn_like(images)

        # Sample one random timestep per image in the batch
        timesteps = torch.randint(
            0,
            num_train_timesteps,
            (images.shape[0],),
            device=device,
            dtype=torch.long,
        )

        # Forward diffusion: add noise to the images at the sampled timesteps
        noisy_images = noise_scheduler.add_noise(images, noise, timesteps)

        # Predict the noise
        noise_pred = model(noisy_images, timesteps).sample

        # Compute loss (mean squared error between predicted and true noise)
        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print loss every 100 steps
        if step % 100 == 0:
            print(f"Step {step}/{len(dataloader)}, Loss: {loss.item()}")

# Generate and Display Images

## Image Generation
After training, the model is set to evaluation mode. We generate images by starting with random noise and applying the reverse diffusion process step-by-step.

## Visualization
The generated images are denormalized and displayed in a 4x4 grid using Matplotlib.

In [None]:
# Generate and display images after training using the reverse diffusion loop
model.eval()

num_images = 16  # Total images to generate
rows, cols = 4, 4  # 4x4 grid

with torch.no_grad():
    # Denoise all images together as one batch: a single pass over the
    # timestep schedule instead of one full pass per image.
    samples = torch.randn(num_images, 3, 64, 64, device=device)

    # Iterate the scheduler's own timestep schedule (descending) rather than
    # reading the deprecated `num_train_timesteps` attribute directly.
    for t in noise_scheduler.timesteps:
        # Get noise prediction from the model
        noise_pred = model(samples, t).sample
        # Perform a de-noising step using the predicted noise
        samples = noise_scheduler.step(noise_pred, t, samples).prev_sample

# Denormalize from [-1, 1] to [0, 1] and move channels last for Matplotlib
generated_images = (samples * 0.5 + 0.5).clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy()

fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3))
axes = axes.flatten()
for ax, generated_image in zip(axes, generated_images):
    ax.imshow(generated_image)
    ax.axis("off")

plt.tight_layout()
plt.show()

# FID Metric

The Fréchet Inception Distance (FID) is used to evaluate the quality of the generated images by comparing their features to those of real images.

## Feature Extraction
We use an InceptionV3 model to extract features from both real and generated images.

## FID Calculation
The FID score is computed based on the mean and covariance of the features from real and generated images. A lower FID score indicates better image quality and diversity.

## Real Image Collection
We collect up to 1,000 real images (the default `max_imgs`) from the DataLoader for FID evaluation.

## FID Computation
The FID score is computed between the collected real images and up to 100 images generated by the diffusion model, providing a quantitative measure of its performance.

In [None]:
class InceptionV3FeatureExtractor:
    """Pretrained InceptionV3 used as a fixed feature extractor.

    The final fully connected layer is replaced by an identity module, so a
    forward pass returns the activations that would normally feed the
    classifier. Inputs are run through the preprocessing transform bundled
    with the pretrained weights.
    """

    def __init__(self, device='cpu'):
        """Load the default pretrained weights and place the network on `device`."""
        pretrained = Inception_V3_Weights.DEFAULT
        network = inception_v3(weights=pretrained)
        network.eval()
        # Drop the classification head so the network outputs features
        network.fc = nn.Identity()
        network.to(device)

        self.device = device
        self.inception = network
        self.preprocess = pretrained.transforms()

    def extract_features(self, images):
        """Return the stacked feature rows for an iterable of image tensors.

        Each image is preprocessed and pushed through the network on its own
        (batch of one), without gradient tracking. The per-image results are
        concatenated along axis 0 into a single NumPy array.
        """
        with torch.no_grad():
            per_image = [
                self.inception(
                    self.preprocess(image).unsqueeze(0).to(self.device)
                ).cpu().numpy()
                for image in images
            ]
        return np.concatenate(per_image, axis=0)

def calculate_fid(real_features, fake_features, eps=1e-6):
    """Compute the Fréchet distance between two sets of feature vectors.

    Each set is summarized by its mean vector and covariance matrix; the
    result is ``||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 * sqrt(S_r @ S_f))``.

    Args:
        real_features: Array-like of shape (n_real, dim), one row per sample.
        fake_features: Array-like of shape (n_fake, dim), one row per sample.
        eps: Value added to the diagonal of both covariance matrices when the
            matrix square root of their product contains non-finite entries.

    Returns:
        The Fréchet distance as a NumPy floating-point scalar.
    """
    # Work in float64 so mean/covariance estimation does not lose precision
    # when the features arrive as float32.
    real_features = np.asarray(real_features, dtype=np.float64)
    fake_features = np.asarray(fake_features, dtype=np.float64)

    mu_real = real_features.mean(axis=0)
    mu_fake = fake_features.mean(axis=0)

    # np.cov returns a 0-d array when there is a single feature column;
    # promote to 2-D so the matrix operations below always see a matrix.
    sigma_real = np.atleast_2d(np.cov(real_features, rowvar=False))
    sigma_fake = np.atleast_2d(np.cov(fake_features, rowvar=False))

    mean_diff_squared = np.sum((mu_real - mu_fake) ** 2)

    covmean = linalg.sqrtm(sigma_real.dot(sigma_fake))
    if not np.isfinite(covmean).all():
        # The product can be (near-)singular; regularize the diagonals and retry.
        offset = np.eye(sigma_real.shape[0]) * eps
        covmean = linalg.sqrtm((sigma_real + offset).dot(sigma_fake + offset))
    if np.iscomplexobj(covmean):
        # Numerical error can leave a small imaginary component; keep the real part.
        covmean = covmean.real

    trace_term = np.trace(sigma_real) + np.trace(sigma_fake) - 2.0 * np.trace(covmean)
    return mean_diff_squared + trace_term

def generate_fake_image(model, noise_scheduler, device, image_shape=(3, 64, 64)):
    """Generate one image by running the reverse diffusion process.

    Starts from Gaussian noise and, for every timestep in the scheduler's
    schedule, asks the model for a noise prediction and lets the scheduler
    compute the previous (less noisy) sample.

    Args:
        model: Callable ``model(sample, t)`` whose result exposes the noise
            prediction as ``.sample``.
        noise_scheduler: Scheduler exposing ``timesteps`` and
            ``step(noise_pred, t, sample)`` whose result exposes ``.prev_sample``.
        device: Device on which the sample is denoised.
        image_shape: ``(channels, height, width)`` of the generated image.

    Returns:
        Tensor of shape ``image_shape`` (the batch dimension is removed).
        Values are whatever the final denoising step yields; no rescaling is
        applied here.
    """
    with torch.no_grad():
        sample = torch.randn(1, *image_shape).to(device)
        # Follow the scheduler's own timestep schedule instead of reading the
        # config value as a scheduler attribute, which diffusers deprecates.
        for t in noise_scheduler.timesteps:
            noise_pred = model(sample, t).sample
            sample = noise_scheduler.step(noise_pred, t, sample).prev_sample
    return sample.squeeze(0)

def compute_fid(real_imgs, model, noise_scheduler, device, num_samples=100):
    """Compute the FID between real images and images sampled from the model.

    Args:
        real_imgs: Iterable of real image tensors normalized to [-1, 1], as
            produced by this notebook's preprocessing transform.
        model: Trained noise-prediction model used for sampling.
        noise_scheduler: Scheduler driving the reverse diffusion process.
        device: Device used for both sampling and feature extraction.
        num_samples: Number of images to generate for the comparison.

    Returns:
        The FID score returned by ``calculate_fid``.
    """
    feature_extractor = InceptionV3FeatureExtractor(device)

    def _to_unit_range(img):
        # Map from the training range [-1, 1] to [0, 1].
        return (img * 0.5 + 0.5).clamp(0, 1)

    # Real and generated images must go through the same rescaling before
    # feature extraction; otherwise the two feature distributions differ
    # because of the value range rather than the image content.
    real_features = feature_extractor.extract_features(
        _to_unit_range(img) for img in real_imgs
    )

    # Images are generated lazily, one at a time, as the extractor consumes them.
    fake_features = feature_extractor.extract_features(
        _to_unit_range(generate_fake_image(model, noise_scheduler, device))
        for _ in range(num_samples)
    )

    return calculate_fid(real_features, fake_features)

# Prepare real images from your dataloader (using the same dataloader from earlier)
def get_all_real_images(dataloader, max_imgs=1000):
    """Collect up to `max_imgs` images from `dataloader` into one tensor.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches; the labels
            are ignored.
        max_imgs: Maximum number of images to return.

    Returns:
        The collected batches concatenated along dimension 0 and truncated to
        at most `max_imgs` images. An empty dataloader is not handled
        specially: ``torch.cat`` is then called on an empty list.
    """
    batches = []
    collected = 0
    for imgs, _ in dataloader:
        batches.append(imgs)
        # Keep a running count instead of concatenating every batch seen so
        # far just to measure the total length.
        collected += len(imgs)
        if collected >= max_imgs:
            break
    return torch.cat(batches, dim=0)[:max_imgs]


# Step 1: gather the real reference images from the training dataloader
print("Collecting real images for FID evaluation...")
real_imgs = get_all_real_images(dataloader=dataloader)

# Step 2: sample from the model and compare feature statistics.
# No more than 100 images are generated, and never more than the number of
# real images collected.
print("Computing FID score for diffusion model...")
fid_score = compute_fid(
    real_imgs=real_imgs,
    model=model,
    noise_scheduler=noise_scheduler,
    device=device,
    num_samples=min(100, len(real_imgs)),
)

# Step 3: report the score to four decimals
print(f"Fréchet Inception Distance (FID): {fid_score:.4f}")