In [None]:
# Install the necessary libraries
!pip install -q diffusers transformers accelerate ftfy

print("Libraries installed successfully!")

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Subset

# Transformation pipeline: resize to 64x64, convert to a tensor in [0, 1],
# then map every channel to [-1, 1] via (x - 0.5) / 0.5.
preprocess = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Download (if needed) and load the full CIFAR-10 training split.
full_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=preprocess,
)

# --- SPEEDUP CHANGE: Use a subset of the data ---
# Keep only the first `num_images_for_subset` images so training runs faster.
num_images_for_subset = 10000
subset_indices = list(range(num_images_for_subset))
subset_dataset = Subset(full_dataset, subset_indices)

print(f"Using a subset of {len(subset_dataset)} images instead of {len(full_dataset)}.")

# Shuffled batches of 128 images drawn from the subset.
dataloader = torch.utils.data.DataLoader(subset_dataset, batch_size=128, shuffle=True)

# Preview the first 32 images of one shuffled batch.
images, labels = next(iter(dataloader))
# Undo the [-1, 1] normalization so matplotlib receives values in [0, 1].
img_grid = torchvision.utils.make_grid(images[:32]) * 0.5 + 0.5
npimg = img_grid.numpy()
plt.figure(figsize=(15, 7))
plt.imshow(npimg.transpose(1, 2, 0))  # (C, H, W) -> (H, W, C)
plt.axis('off')
plt.title("Sample Images from CIFAR-10 Dataset Subset")
plt.show()

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from PIL import Image
from dataclasses import dataclass
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
from torch.optim import AdamW
from diffusers.optimization import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm
import os

# --- 1. Configuration ---
@dataclass
class TrainingConfig:
    """Hyper-parameters for the DDPM training run.

    Every field carries a type annotation, so ``@dataclass`` turns it into a real
    field: ``TrainingConfig()`` keeps the defaults below, while e.g.
    ``TrainingConfig(num_epochs=5)`` overrides individual values.
    """

    image_size: int = 64  # side length (pixels) the images are resized to
    train_batch_size: int = 128  # images per optimizer step
    eval_batch_size: int = 16  # images sampled per evaluation / final sample
    num_epochs: int = 15
    learning_rate: float = 2e-4
    lr_warmup_steps: int = 500  # warm-up steps of the cosine LR schedule
    save_image_epochs: int = 2  # sample and plot images every N epochs
    output_dir: str = "ddpm-cifar10-64-fast"  # where the final pipeline is saved
    seed: int = 0  # random seed

config = TrainingConfig()

# Seed torch's RNGs (CPU and all CUDA devices) so weight initialisation, data
# shuffling and the per-step noise/timestep draws are reproducible across runs.
torch.manual_seed(config.seed)

# --- 2. Dataloader ---
preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize((config.image_size, config.image_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]
])
dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=preprocess)
# NOTE(review): this loader uses the FULL CIFAR-10 training split, whereas the ViT-GAN
# cell trains on the 10,000-image `subset_dataset`; confirm the mismatch is intended
# before comparing the two models.
# When several GPUs are visible, nn.DataParallel (below) splits each batch of
# `train_batch_size` images across them.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=2)

# --- 3. Setup Device and Model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Noise-prediction UNet: six resolution levels, with self-attention in the fifth
# down block and in the matching second up block.
model = UNet2DModel(
    sample_size=config.image_size, in_channels=3, out_channels=3, layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
)

# Replicate the model across GPUs when more than one is available.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

model.to(device)  # Move the model (or its DataParallel wrapper) to the primary device

# --- 4. Scheduler and Optimizer ---
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = AdamW(model.parameters(), lr=config.learning_rate)
# Warm-up followed by cosine decay over the whole run (one scheduler step per batch).
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(dataloader) * config.num_epochs),
)

# --- 5. Training Loop ---
for epoch in range(config.num_epochs):
    progress_bar = tqdm(total=len(dataloader))
    progress_bar.set_description(f"Epoch {epoch + 1}/{config.num_epochs}")

    for step, batch in enumerate(dataloader):
        # Each batch is an (images, labels) pair; only the images are used.
        clean_images = batch[0].to(device)
        # Target noise, drawn directly on the batch's device with the batch's dtype.
        noise = torch.randn_like(clean_images)
        bs = clean_images.shape[0]

        # One uniformly random diffusion timestep per image.
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
        ).long()
        # Forward diffusion: corrupt the clean images to their sampled timesteps.
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        # Predict the noise residual and regress it onto the true noise.
        noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        # Backpropagation
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress_bar.update(1)
        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)
    progress_bar.close()

    # --- 6. Generate and Save Images ---
    if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
        # Access the original model through .module attribute
        unwrapped_model = model.module if isinstance(model, nn.DataParallel) else model

        pipeline = DDPMPipeline(unet=unwrapped_model, scheduler=noise_scheduler)
        # Dedicated generator: torch.manual_seed() would reseed the *global* CPU and
        # CUDA RNGs, so every epoch following an evaluation would replay the same
        # shuffle order, noise and timestep draws.
        generator = torch.Generator().manual_seed(config.seed)
        images = pipeline(generator=generator, batch_size=config.eval_batch_size, output_type="numpy").images

        print(f"\nGenerating {config.eval_batch_size} sample images at epoch {epoch+1}:")
        image_grid = torchvision.utils.make_grid(torch.from_numpy(images).permute(0, 3, 1, 2))
        # DDPMPipeline already returns images scaled to [0, 1]; de-normalizing again
        # would squash them into [0.5, 1] and wash them out.
        npimg = image_grid.numpy()
        plt.figure(figsize=(8, 8))
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        plt.axis("off")
        plt.title(f"DDPM samples - epoch {epoch + 1}")
        plt.show()

print("\nDiffusion model training complete!")

# --- 7. Save final model and store generated images ---
final_unwrapped_model = model.module if isinstance(model, nn.DataParallel) else model
final_pipeline = DDPMPipeline(unet=final_unwrapped_model, scheduler=noise_scheduler)

# Create the output directory; exist_ok makes this a no-op when it already exists
# and avoids the check-then-create race.
os.makedirs(config.output_dir, exist_ok=True)

print(f"Saving final model pipeline to {config.output_dir}...")
final_pipeline.save_pretrained(config.output_dir)

# Store the final generated images (channels-last numpy array) for later comparison.
# NOTE: torch.manual_seed also reseeds torch's global RNGs, so later cells continue
# from this seeded state.
diffusion_images = final_pipeline(
    generator=torch.manual_seed(config.seed),
    batch_size=config.eval_batch_size,
    output_type="numpy"
).images

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTForImageClassification
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torchvision

# --- 1. GAN Configuration ---
# Number of channels of the generator's (latent_dim, 1, 1) input noise.
latent_dim = 100
# Adam settings shared by the generator and discriminator optimizers.
lr_gan, beta1 = 2e-4, 0.5
# Training length and images per batch.
gan_epochs, gan_batch_size = 30, 128


# --- 2. Model Definitions ---

# Generator
class Generator(nn.Module):
    """Transposed-convolution generator mapping noise to 3-channel images in [-1, 1].

    Args:
        latent_dim: channel count of the input noise ``z``, shaped
            (batch, latent_dim, 1, 1). A 1x1 input is upsampled to 64x64.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # (in_channels, out_channels, stride, padding) for each hidden stage;
        # with kernel 4 the spatial size goes 1 -> 4 -> 8 -> 16 -> 32.
        hidden_stages = [
            (latent_dim, 512, 1, 0),
            (512, 256, 2, 1),
            (256, 128, 2, 1),
            (128, 64, 2, 1),
        ]
        layers = []
        for c_in, c_out, stride, pad in hidden_stages:
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        # Output stage: 32 -> 64 spatially, project to RGB and squash into [-1, 1].
        layers.extend([nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False), nn.Tanh()])
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        return self.model(z)

# The Discriminator using a pre-trained ViT
class ViTDiscriminator(nn.Module):
    """GAN discriminator built on the pretrained ``google/vit-base-patch16-224-in21k`` ViT.

    The classification head is replaced by a freshly initialised single-output
    linear layer, so ``forward`` returns one raw (pre-sigmoid) logit per image.
    """

    def __init__(self):
        super().__init__()
        backbone = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224-in21k',
            num_labels=1,
            ignore_mismatched_sizes=True,
        )
        # Swap in a new 1-logit head sized to the backbone's hidden dimension.
        backbone.classifier = nn.Linear(backbone.config.hidden_size, 1)
        self.vit = backbone

    def forward(self, img):
        # The pretrained ViT works on 224x224 inputs, so upsample the images
        # (64x64 in this notebook) before the forward pass.
        resized = F.interpolate(img, size=(224, 224), mode='bilinear', align_corners=False)
        logits = self.vit(resized).logits
        return logits


# --- 3. Initialization ---
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# One fixed batch of latents, reused after every epoch so the sample grids are
# directly comparable across epochs.
fixed_noise = torch.randn(64, latent_dim, 1, 1, device=device)

netG, netD = Generator(latent_dim), ViTDiscriminator()

# Replicate both networks across GPUs when more than one is visible.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs for ViT-GAN!")
    netG, netD = nn.DataParallel(netG), nn.DataParallel(netD)

netG.to(device)
netD.to(device)

# The discriminator returns raw logits, so pair it with BCE-with-logits.
criterion = nn.BCEWithLogitsLoss()
adam_kwargs = dict(lr=lr_gan, betas=(beta1, 0.999))
optimizerD = torch.optim.Adam(netD.parameters(), **adam_kwargs)
optimizerG = torch.optim.Adam(netG.parameters(), **adam_kwargs)

gan_dataloader = torch.utils.data.DataLoader(
    subset_dataset,
    batch_size=gan_batch_size,
    shuffle=True,
    num_workers=2,
)

# --- 4. GAN Training Loop ---
print("Starting ViT-GAN Training...")
for epoch in range(gan_epochs):
    progress_bar = tqdm(total=len(gan_dataloader))
    progress_bar.set_description(f"Epoch {epoch + 1}/{gan_epochs}")

    for i, data in enumerate(gan_dataloader, 0):
        ## Train Discriminator ##
        netD.zero_grad()
        real_images = data[0].to(device)  # data is an (images, labels) batch; labels unused
        b_size = real_images.size(0)
        real_label = torch.full((b_size,), 1., dtype=torch.float, device=device)
        fake_label = torch.full((b_size,), 0., dtype=torch.float, device=device)

        output_real = netD(real_images).view(-1)
        errD_real = criterion(output_real, real_label)
        errD_real.backward()

        noise = torch.randn(b_size, latent_dim, 1, 1, device=device)
        fake = netG(noise)
        # detach() so the discriminator update does not backpropagate into G.
        output_fake = netD(fake.detach()).view(-1)
        errD_fake = criterion(output_fake, fake_label)
        errD_fake.backward()

        errD = errD_real + errD_fake  # logging only; gradients were accumulated above
        optimizerD.step()

        ## Train Generator ##
        # G is rewarded when D classifies its samples as real, so `real_label`
        # (still all ones) is the target here.
        netG.zero_grad()
        output = netD(fake).view(-1)
        errG = criterion(output, real_label)
        errG.backward()
        optimizerG.step()

        progress_bar.update(1)
        progress_bar.set_postfix({"Loss_D": errD.item(), "Loss_G": errG.item()})
    progress_bar.close()

    # NOTE(review): netG stays in train mode here, so its BatchNorm layers normalise
    # with the statistics of `fixed_noise` (and update running stats) even under no_grad.
    with torch.no_grad():
        unwrapped_G = netG.module if isinstance(netG, nn.DataParallel) else netG
        fake_samples = unwrapped_G(fixed_noise).detach().cpu()

    print(f"\nGenerating samples after epoch {epoch+1}:")
    grid = torchvision.utils.make_grid(fake_samples, padding=2, normalize=True)
    plt.figure(figsize=(8,8))
    plt.axis("off")
    plt.title(f"ViT-GAN Generated Images - Epoch {epoch+1}")
    # make_grid returns a (C, H, W) tensor; matplotlib expects an (H, W, C) array.
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.show()

print("\nViT-GAN training complete!")

# --- 5. Store final generated images for comparison ---
with torch.no_grad():
    unwrapped_G = netG.module if isinstance(netG, nn.DataParallel) else netG
    # Generate the same number of images as the diffusion model for a fair comparison.
    # A dedicated generator seeded with config.seed (the seed the diffusion samples
    # use) makes this sample reproducible without touching torch's global RNG.
    sample_generator = torch.Generator().manual_seed(config.seed)
    final_noise = torch.randn(
        config.eval_batch_size, latent_dim, 1, 1, generator=sample_generator
    ).to(device)
    vit_gan_images_tensor = unwrapped_G(final_noise).detach().cpu()

    # (B, C, H, W) tanh output in [-1, 1] -> channels-last (B, H, W, C) in [0, 1].
    vit_gan_images = vit_gan_images_tensor.permute(0, 2, 3, 1).numpy()
    vit_gan_images = (vit_gan_images * 0.5) + 0.5

In [None]:
 # Install the library for calculating FID score
!pip install -q torch-fidelity

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torch_fidelity import calculate_metrics
from PIL import Image
import os

# --- 1. Helper function to save images ---
def save_images_to_folder(images_numpy, folder_path):
    """Write each image of a batch to ``folder_path`` as ``image_{i}.png``.

    Args:
        images_numpy: iterable of (H, W, C) float images with values in [0, 1],
            e.g. a numpy array of shape (B, H, W, C).
        folder_path: output directory; created (with parents) if missing.

    NOTE(review): files from earlier calls are not removed. If a previous run wrote
    more images, the leftover ``image_{i}.png`` files stay in the folder and are
    picked up by anything that reads the whole directory (such as FID).
    """
    os.makedirs(folder_path, exist_ok=True)

    for i, img_np in enumerate(images_numpy):
        # Clip before scaling so tiny out-of-range float errors cannot wrap around in
        # the uint8 cast, and round rather than truncate (0.999 -> 255, not 254).
        img_uint8 = np.round(np.clip(img_np, 0.0, 1.0) * 255).astype(np.uint8)
        img = Image.fromarray(img_uint8)
        img.save(os.path.join(folder_path, f"image_{i}.png"))



# --- 2. Qualitative (Visual) Comparison ---

def _show_image_grid(images_numpy, title):
    """Plot a batch of channels-last images with values in [0, 1] as one grid.

    Args:
        images_numpy: numpy array of shape (B, H, W, C).
        title: figure title.
    """
    # make_grid expects (B, C, H, W); matplotlib expects (H, W, C).
    grid = torchvision.utils.make_grid(torch.from_numpy(images_numpy).permute(0, 3, 1, 2))
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.axis("off")
    plt.title(title)
    plt.show()


print("--- VISUAL COMPARISON ---")

# Display Diffusion Model Images
print("\nImages from Diffusion Model:")
_show_image_grid(diffusion_images, "Diffusion Model Generated Images")

# Display ViT-GAN Images
print("\nImages from ViT-GAN Model:")
_show_image_grid(vit_gan_images, "ViT-GAN Generated Images")


# --- 3. Quantitative (FID Score) Comparison ---

print("\n--- QUANTITATIVE COMPARISON (FID SCORE) ---")
print("A lower FID score is better.")

REAL_IMAGES_DIR = './data/cifar10_real_for_fid'
DIFFUSION_IMAGES_DIR = './gen_images_diffusion'
VIT_GAN_IMAGES_DIR = './gen_images_vit_gan'
NUM_REAL_IMAGES_FOR_FID = 1000
# Fall back to CPU feature extraction instead of crashing when no GPU is present.
USE_CUDA_FOR_FID = torch.cuda.is_available()

save_images_to_folder(diffusion_images, DIFFUSION_IMAGES_DIR)
save_images_to_folder(vit_gan_images, VIT_GAN_IMAGES_DIR)

# NOTE(review): FID compares the mean and covariance of Inception features. With only
# a handful of generated images per model the estimate is strongly biased and noisy;
# thousands of samples per model are needed for a stable comparison.
print(f"Note: FID is estimated from only {len(diffusion_images)} diffusion and "
      f"{len(vit_gan_images)} ViT-GAN samples; treat the scores as a rough indication.")

print("\nPreparing a folder of real images for comparison...")
if not os.path.exists(REAL_IMAGES_DIR):
    real_images_for_fid = []
    for i, (img, _) in enumerate(dataset):
        if i >= NUM_REAL_IMAGES_FOR_FID:
            break
        # (C, H, W) in [-1, 1] -> (H, W, C) in [0, 1], matching the generated images.
        img_np = img.permute(1, 2, 0).numpy() * 0.5 + 0.5
        real_images_for_fid.append(img_np)
    save_images_to_folder(np.array(real_images_for_fid), REAL_IMAGES_DIR)
else:
    print("Real images folder already exists.")


def _compute_fid_metrics(generated_dir):
    """Return torch-fidelity's metrics dict comparing ``generated_dir`` with the real images."""
    return calculate_metrics(
        input1=generated_dir,
        input2=REAL_IMAGES_DIR,
        cuda=USE_CUDA_FOR_FID,
        fid=True,
        verbose=False,
    )


print("\nCalculating FID for Diffusion Model... (this may take a minute)")
metrics_dict_diffusion = _compute_fid_metrics(DIFFUSION_IMAGES_DIR)
print(f"Diffusion Model FID: {metrics_dict_diffusion['frechet_inception_distance']:.2f}")

print("\nCalculating FID for ViT-GAN Model...")
metrics_dict_vit_gan = _compute_fid_metrics(VIT_GAN_IMAGES_DIR)
print(f"ViT-GAN FID: {metrics_dict_vit_gan['frechet_inception_distance']:.2f}")