# GAN

# EXP-1

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import torchvision.utils as vutils

# Generator Network
class Generator(nn.Module):
    """DCGAN-style generator that maps latent noise to an ``nc``-channel image.

    A ``(N, nz, 1, 1)`` latent tensor is projected to 4x4 by a stride-1
    transposed convolution, then doubled in spatial size by each following
    stride-2 transposed convolution (4 -> 8 -> 16 -> 32 -> 64). The final
    ``Tanh`` keeps outputs in ``[-1, 1]``. The first two hidden blocks end in
    ``Dropout(0.3)`` to add sample diversity.

    Args:
        nz: number of latent channels.
        ngf: base width of the generator feature maps.
        nc: number of output image channels.
    """

    def __init__(self, nz, ngf, nc):
        super().__init__()
        widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]
        layers = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # First block: 1x1 -> 4x4 (stride 1, no padding); the rest double the size.
            stride, padding = (1, 0) if idx == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
            # Only the two widest blocks carry dropout.
            if idx < 2:
                layers.append(nn.Dropout(0.3))
        # Output block: project to nc channels and squash to [-1, 1].
        layers += [
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        return self.net(input)
    
    

# Discriminator Network
class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores 64x64 images with a probability.

    Four stride-2 convolutions halve the spatial size (64 -> 32 -> 16 -> 8 -> 4)
    while widening the features. A final 4x4 convolution plus ``Sigmoid``
    produces one value in ``[0, 1]`` per image, shaped ``(N, 1, 1, 1)``.

    Args:
        ndf: base width of the discriminator feature maps.
        nc: number of input image channels.
    """

    def __init__(self, ndf, nc):
        super().__init__()
        # (nc) x 64 x 64 -> (ndf) x 32 x 32; the first layer has no BatchNorm.
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Each block halves the spatial size and doubles the width: 32 -> 16 -> 8 -> 4.
        for mult in (1, 2, 4):
            layers += [
                nn.Conv2d(ndf * mult, ndf * mult * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * mult * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # (ndf*8) x 4 x 4 -> 1 x 1 x 1 probability.
        layers += [
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        return self.net(input)

# Hyperparameters
nz = 100  # Size of z latent vector (i.e. size of generator input)
ngf = 64  # Size of feature maps in generator
ndf = 64  # Size of feature maps in discriminator
nc = 3    # Number of channels in the training images. For color images this is 3
# NOTE(review): seeding is disabled, so results differ between runs; uncomment for reproducibility.
#torch.manual_seed(42)

# Create the generator and discriminator (their weights are re-initialised with weights_init next)
netG = Generator(nz, ngf, nc)
netD = Discriminator(ndf, nc)

# Initialize weights
def weights_init(m):
    """Apply the DCGAN weight initialisation to a single module, in place.

    Intended for ``model.apply(weights_init)``:

    * modules whose class name contains ``'Conv'`` get ``weight ~ N(0, 0.02)``
      (their bias, if any, is left untouched);
    * modules whose class name contains ``'BatchNorm'`` get
      ``weight ~ N(1, 0.02)`` and ``bias = 0``;
    * every other module is left unchanged.

    Missing parameters (e.g. ``BatchNorm2d(affine=False)`` has ``weight`` and
    ``bias`` set to ``None``) are skipped instead of raising ``AttributeError``.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

# DCGAN initialisation: conv weights ~ N(0, 0.02), BatchNorm weights ~ N(1, 0.02) with zero bias.
netG.apply(weights_init)
netD.apply(weights_init)

# Use the GPU when one is available; Module.to() moves the parameters in place.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG.to(device)
netD.to(device)

# Optimizers: Adam with lr=2e-4 and beta1=0.5 for both networks
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))


# Loss function: binary cross-entropy on the discriminator's sigmoid output
criterion = nn.BCELoss()

# Training Data: resize/crop to 64x64 (the size both networks are built for) and
# map ToTensor's [0, 1] range to [-1, 1] to match the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

# Training the GAN
num_epochs = 50
real_label = 1
fake_label = 0
G_losses = []  # generator loss, one entry per iteration
D_losses = []  # discriminator loss (real + fake), one entry per iteration
fixed_noise = torch.randn(64, nz, 1, 1, device=device)  # 64 is the number of images you want to generate
img_list = []  # one image grid per epoch (not displayed in this cell)

# Training loop: one discriminator step followed by one generator step per batch
for epoch in range(num_epochs):
    for i, data in enumerate(trainloader, 0):
        ############################
        # (1) Update Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        netD.zero_grad()

        # Train with real data
        real_data = data[0].to(device)
        batch_size = real_data.size(0)
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)

        # NOTE(review): squeeze() turns a batch of size 1 into a 0-d tensor whose shape
        # no longer matches `label`; view(-1) would be robust to that case.
        output = netD(real_data).squeeze()  # Flatten the output
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # Train with fake data; detach() keeps this loss from back-propagating into netG
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake_data = netG(noise)
        label.fill_(fake_label)
        output = netD(fake_data.detach()).squeeze()  # Flatten the output
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()

        # Gradients from both backward() calls have accumulated; apply them in one step.
        errD = errD_real + errD_fake  # Total discriminator loss
        optimizerD.step()

        ############################
        # (2) Update Generator: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # Fake labels are real for generator cost
        # fake_data is not detached here, so this backward also writes netD gradients;
        # those are cleared by netD.zero_grad() at the start of the next iteration.
        output = netD(fake_data).squeeze()  # Flatten the output
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()

        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z1)): %.4f\tD(G(z2)): %.4f'
                  % (epoch, num_epochs, i, len(trainloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        G_losses.append(errG.item())
        D_losses.append(errD.item())

    # Check progress by generating fake image after each epoch.
    # netG is never switched to eval() in this loop, so BatchNorm uses batch statistics
    # and Dropout is active for these snapshots.
    with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()
    img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

# Plot the per-iteration losses collected above
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

# EXP-2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torchvision.utils as vutils

# Generate a fixed random vector for image generation.
# NOTE: this rebinds the global `fixed_noise` (previously 64 vectors) to 100 vectors.
fixed_noise = torch.randn(100, nz, 1, 1, device=device)  # 100 for a 10x10 grid

# Generate images from the fixed noise
with torch.no_grad():
    # Eval mode: BatchNorm uses its running statistics and Dropout is disabled.
    # netG is left in eval mode after this cell.
    netG.eval()  # Set the generator to evaluation mode
    generated_images = netG(fixed_noise).detach().cpu()

# Create a 10-per-row grid; normalize=True min-max rescales the pixel values to [0, 1]
grid = vutils.make_grid(generated_images, nrow=10, normalize=True)

# Display the grid (CHW tensor transposed to HWC for imshow)
plt.figure(figsize=(15, 15))
plt.axis("off")
plt.title("Generated Images Grid")
plt.imshow(np.transpose(grid, (1, 2, 0)))  # Convert from Tensor image
plt.show()


# EXP-3

In [None]:
import torchvision.utils as vutils
from torch import optim, nn

def train_gan(training_ratios, num_epochs=50, dataloader=None):
    """Train one fresh DCGAN per discriminator:generator update ratio.

    For every ratio a new ``Generator``/``Discriminator`` pair is built and
    initialised with ``weights_init``. Both are trained with Adam (lr 2e-4,
    betas (0.5, 0.999)) and BCE loss. On every batch the discriminator is
    updated ``ratio['D']`` times, then the generator ``ratio['G']`` times.

    Args:
        training_ratios: iterable of dicts with integer ``'D'`` and ``'G'``
            update counts per batch (each >= 1) and a ``'label'`` used as the
            result key.
        num_epochs: number of passes over the data for every ratio.
        dataloader: iterable of ``(images, labels)`` batches with ``len()``.
            Defaults to the module-level ``trainloader`` when ``None``.

    Returns:
        dict mapping each ``ratio['label']`` to a dict with:

        * ``'G_losses'`` and ``'D_losses'``: one entry per batch, taken from
          the last G/D update on that batch;
        * ``'images'``: grids of a shared fixed-noise batch, captured every
          100 batches and after the final batch;
        * ``'netG'``: the trained generator.

    Raises:
        ValueError: if a ratio requests fewer than one D or one G update per
            batch. The per-batch loss logging needs at least one of each.
    """
    # Validate every ratio before any (long) training starts.
    training_ratios = list(training_ratios)
    for ratio in training_ratios:
        if ratio['D'] < 1 or ratio['G'] < 1:
            raise ValueError(
                f"training ratio {ratio.get('label')!r} must request at least one D and one "
                f"G update per batch (got D={ratio['D']}, G={ratio['G']})"
            )

    if dataloader is None:
        dataloader = trainloader
    num_batches = len(dataloader)

    results = {}
    # One noise batch shared by every ratio so the snapshots are directly comparable.
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)

    for ratio in training_ratios:
        netG = Generator(nz, ngf, nc).to(device)
        netD = Discriminator(ndf, nc).to(device)
        netG.apply(weights_init)
        netD.apply(weights_init)

        optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
        optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
        criterion = nn.BCELoss()

        G_losses = []
        D_losses = []
        image_snapshots = []

        for epoch in range(num_epochs):
            for i, data in enumerate(dataloader):
                real_data = data[0].to(device)
                batch_size = real_data.size(0)
                real_targets = torch.full((batch_size,), 1, dtype=torch.float, device=device)
                fake_targets = torch.full((batch_size,), 0, dtype=torch.float, device=device)

                # Discriminator: ratio['D'] updates, each on the real batch plus fresh fakes.
                for _ in range(ratio['D']):
                    netD.zero_grad()
                    # view(-1) (not squeeze) keeps shape (batch,) even when batch_size == 1.
                    output_real = netD(real_data).view(-1)
                    errD_real = criterion(output_real, real_targets)
                    errD_real.backward()

                    noise = torch.randn(batch_size, nz, 1, 1, device=device)
                    fake_data = netG(noise)
                    # detach(): the discriminator step must not push gradients into netG.
                    output_fake = netD(fake_data.detach()).view(-1)
                    errD_fake = criterion(output_fake, fake_targets)
                    errD_fake.backward()

                    optimizerD.step()

                # Generator: ratio['G'] updates on new noise, scored against "real"
                # targets (non-saturating generator loss).
                for _ in range(ratio['G']):
                    netG.zero_grad()
                    noise = torch.randn(batch_size, nz, 1, 1, device=device)
                    fake_data = netG(noise)
                    output = netD(fake_data).view(-1)
                    errG = criterion(output, real_targets)
                    errG.backward()
                    optimizerG.step()

                # Log the losses of the last D and G update on this batch.
                loss_d = (errD_real + errD_fake).item()
                loss_g = errG.item()
                G_losses.append(loss_g)
                D_losses.append(loss_d)

                is_final_batch = i == num_batches - 1 and epoch == num_epochs - 1
                if i % 100 == 0 or is_final_batch:
                    print(f"Ratio: {ratio['label']} | Epoch [{epoch+1}/{num_epochs}] | Batch [{i}/{num_batches}] | "
                          f"Loss_D: {loss_d:.4f} | Loss_G: {loss_g:.4f}")
                    with torch.no_grad():
                        fixed_fake = netG(fixed_noise).detach().cpu()
                        image_snapshots.append(vutils.make_grid(fixed_fake, padding=2, normalize=True))

        # Save results
        results[ratio['label']] = {
            'G_losses': G_losses,
            'D_losses': D_losses,
            'images': image_snapshots,
            'netG': netG
        }

    return results



def plot_losses(results):
    """Plot the generator and discriminator loss curves of every run on one figure.

    Args:
        results: mapping of run label -> dict holding per-iteration
            ``'G_losses'`` and ``'D_losses'`` sequences (as returned by
            ``train_gan``). For each run the G curve is drawn before the D curve.
    """
    plt.figure(figsize=(12, 8))
    curves = (('G_losses', 'G Loss'), ('D_losses', 'D Loss'))
    for run_label, run in results.items():
        for series_key, prefix in curves:
            plt.plot(run[series_key], label=f'{prefix} {run_label}')
    plt.title("Generator and Discriminator Loss During Training")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()


# Discriminator:generator update counts per batch to compare, each tagged by its label.
training_ratios = [
    {'D': d_steps, 'G': g_steps, 'label': tag}
    for d_steps, g_steps, tag in ((1, 2, '1:2'), (2, 1, '2:1'), (2, 2, '2:2'))
]

# Train one GAN per ratio, then compare their loss curves.
results = train_gan(training_ratios)
plot_losses(results)


In [None]:
def plot_generated_images_from_ratios(results, nrow=10, image_size=(18, 18)):
    """Show the most recent generator snapshot of every training ratio, one row each.

    Args:
        results: mapping of ratio label -> dict whose ``'images'`` entry is a
            non-empty list of snapshots (as returned by ``train_gan``).
        nrow: images per row when a snapshot is a raw ``(N, C, H, W)`` batch.
            ``train_gan`` stores already-assembled ``(C, H, W)`` grids; for
            those, ``make_grid`` only re-normalises and ``nrow`` has no effect.
        image_size: ``(width, height)`` in inches of each subplot.
    """
    n_runs = len(results)
    # squeeze=False always yields a 2-D array of Axes, including the single-run case.
    _, axes = plt.subplots(n_runs, 1, figsize=(image_size[0], image_size[1] * n_runs), squeeze=False)

    for ax, (label, run) in zip(axes[:, 0], results.items()):
        latest = run['images'][-1]  # most recent snapshot
        grid = vutils.make_grid(latest, nrow=nrow, normalize=True)
        ax.imshow(np.transpose(grid, (1, 2, 0)))
        ax.axis('off')
        ax.set_title(f"Generated Images - Training Ratio {label}")

    plt.tight_layout()
    plt.show()


plot_generated_images_from_ratios(results)


# FID

In [None]:
import os
from torchvision.utils import save_image
from torch.utils.data import DataLoader

# Export 1000 training images as the real reference set for FID.
# shuffle=True without a seed means a different random subset on every run.
os.makedirs("GAN_fid_images/real", exist_ok=True)

# batch_size=1 so each saved file holds exactly one image.
real_loader = DataLoader(trainset, batch_size=1, shuffle=True)
for i, (img, _) in enumerate(real_loader):
    if i == 1000:
        break
    # normalize=True min-max rescales each saved image on its own.
    save_image(img, f"GAN_fid_images/real/real_{i}.png", normalize=True)


In [None]:
def save_generated_images(netG, save_path, nz=100, num_images=1000, batch_size=100):
    """Sample images from ``netG`` and write each one to ``save_path`` as a PNG.

    Files are named ``fake_{i}.png`` for ``i`` in ``range(num_images)``; the
    directory is created if missing. Each image is saved with
    ``save_image(..., normalize=True)``, i.e. min-max rescaled on its own.

    The generator runs in eval mode, so BatchNorm uses its running statistics
    and Dropout is off. Samples therefore do not depend on how they are
    batched. The generator's previous train/eval mode is restored afterwards,
    even if an error occurs.

    Args:
        netG: generator taking latent noise of shape ``(N, nz, 1, 1)``.
        save_path: output directory.
        nz: latent dimensionality.
        num_images: number of images to write.
        batch_size: number of images generated per forward pass (>= 1).

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    os.makedirs(save_path, exist_ok=True)

    # Sample on whatever device the generator lives on; fall back to the global device.
    first_param = next(netG.parameters(), None)
    gen_device = first_param.device if first_param is not None else device

    was_training = netG.training
    netG.eval()
    try:
        with torch.no_grad():
            for start in range(0, num_images, batch_size):
                count = min(batch_size, num_images - start)
                noise = torch.randn(count, nz, 1, 1, device=gen_device)
                fakes = netG(noise).cpu()
                # One file per image so normalize=True rescales each image independently.
                for offset, fake_img in enumerate(fakes):
                    save_image(fake_img, f"{save_path}/fake_{start + offset}.png", normalize=True)
    finally:
        netG.train(was_training)


In [None]:
# Write samples from each ratio's trained generator into its own folder for FID
# against GAN_fid_images/real.
save_generated_images(results['1:2']['netG'], "GAN_fid_images/fake_1_2")
save_generated_images(results['2:1']['netG'], "GAN_fid_images/fake_2_1")
save_generated_images(results['2:2']['netG'], "GAN_fid_images/fake_2_2")


In [None]:
!python -m pytorch_fid GAN_fid_images/real GAN_fid_images/fake_1_2
!python -m pytorch_fid GAN_fid_images/real GAN_fid_images/fake_2_1
!python -m pytorch_fid GAN_fid_images/real GAN_fid_images/fake_2_2


# EXP-5

In [None]:
def interpolate_multiple_ratios(results, ratios=('1:2', '2:1', '2:2'), nz=100, steps=10):
    """Show a linear latent-space interpolation for each trained generator.

    For every label in ``ratios``, two latent vectors are drawn from N(0, I).
    ``steps`` evenly spaced points on the straight line between them
    (alpha = 0 .. 1) are decoded by that run's generator. Each figure row
    shows one generator.

    Args:
        results: mapping label -> dict holding the trained ``'netG'``.
        ratios: labels to plot, one row each, in order. Any iterable of labels
            works; the tuple default avoids a shared mutable default.
        nz: latent dimensionality expected by the generators.
        steps: number of interpolation points per row.

    Each generator runs in eval mode and is restored to its previous
    train/eval mode afterwards.
    """
    ratios = list(ratios)
    # squeeze=False always yields a 2-D array of Axes, including the single-row case.
    _, axes = plt.subplots(len(ratios), 1, figsize=(steps * 2, len(ratios) * 3), squeeze=False)

    for ax, label in zip(axes[:, 0], ratios):
        netG = results[label]['netG']
        gen_device = next(netG.parameters()).device

        z_start = torch.randn(1, nz, 1, 1, device=gen_device)
        z_end = torch.randn(1, nz, 1, 1, device=gen_device)
        # One alpha per output image, shaped to broadcast against (1, nz, 1, 1).
        alphas = torch.linspace(0, 1, steps, device=gen_device).view(-1, 1, 1, 1)
        interpolated_z = (1 - alphas) * z_start + alphas * z_end

        was_training = netG.training
        netG.eval()
        try:
            with torch.no_grad():
                interpolated_images = netG(interpolated_z).cpu()
        finally:
            netG.train(was_training)

        grid = vutils.make_grid(interpolated_images, nrow=steps, normalize=True)
        ax.imshow(grid.permute(1, 2, 0).numpy())
        ax.set_title(f"Latent Interpolation - {label}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
interpolate_multiple_ratios(results)  # one interpolation row per default ratio label ('1:2', '2:1', '2:2')


# DDPM

In [None]:
from datasets import load_dataset
import os
from PIL import Image
from tqdm import tqdm

# Load the full AFHQ dataset (train split)
dataset = load_dataset("huggan/AFHQ", split="train")

# Define output root directory; the ImageFolder used for DDPM training reads from afhq_all
root_dir = "afhq_all"
os.makedirs(root_dir, exist_ok=True)

# Map numerical labels to class names (used as sub-folder names)
label_map = {0: 'cat', 1: 'dog', 2: 'wild'}

# Iterate and save images in class-specific subfolders: afhq_all/<class>/afhq_<i>.jpg.
# NOTE(review): item['image'] is presumably a PIL image (it is written with .save);
# the .jpg extension means the images are re-encoded as JPEG.
for i in tqdm(range(len(dataset))):
    item = dataset[i]
    img = item['image']
    label = label_map[item['label']]
    
    class_dir = os.path.join(root_dir, label)
    os.makedirs(class_dir, exist_ok=True)
    
    img_path = os.path.join(class_dir, f"afhq_{i}.jpg")
    img.save(img_path)


In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision import utils as vutils
from torch.optim import AdamW
from tqdm import tqdm
import matplotlib.pyplot as plt
from transformers import get_cosine_schedule_with_warmup

def get_timestep_embedding(timesteps, embedding_dim):
    """Build sinusoidal timestep embeddings as in the DDPM reference code.

    The scheme comes from Fairseq and matches tensor2tensor; it differs
    slightly from Section 3.5 of "Attention Is All You Need". Each row is
    ``[sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})]`` with
    ``h = embedding_dim // 2`` and ``f_k = exp(-k * ln(10000) / (h - 1))``.

    Args:
        timesteps: 1-D tensor of timesteps, shape ``(N,)``.
        embedding_dim: output width. Odd widths get one trailing zero column.

    Returns:
        float32 tensor of shape ``(N, embedding_dim)`` on ``timesteps.device``.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    log_step = np.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_step).to(device=timesteps.device)
    angles = timesteps.float()[:, None] * freqs[None, :]
    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd width: append a single zero column on the right.
        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))
    return embedding

def nonlinearity(x):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate


def Normalize(in_channels):
    """Return the GroupNorm used throughout the U-Net: 32 groups, eps=1e-6, affine.

    ``in_channels`` must be divisible by 32, which GroupNorm requires.
    """
    return nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)

In [None]:
class Upsample(nn.Module):
    """Double the spatial size with nearest-neighbour interpolation.

    Args:
        in_channels: channel count, preserved by the optional convolution.
        with_conv: if True, a 3x3 stride-1 padding-1 convolution follows the
            upsampling.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled

class Downsample(nn.Module):
    """Halve the spatial size with a strided convolution or 2x2 average pooling.

    Args:
        in_channels: channel count, preserved in both modes.
        with_conv: if True, pad one pixel on the right and bottom, then apply a
            3x3 stride-2 convolution without padding. Otherwise use 2x2 average
            pooling.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Asymmetric (left=0, right=1, top=0, bottom=1) padding so that an even
        # input size is exactly halved by the unpadded stride-2 convolution.
        padded = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)

In [None]:
class ResnetBlock(nn.Module):
    """Pre-activation residual block conditioned on a timestep embedding.

    ``forward`` computes::

        h = conv1(act(norm1(x))) + temb_proj(act(temb))   # broadcast over H, W
        h = conv2(dropout(act(norm2(h))))
        return shortcut(x) + h

    where ``act`` is ``nonlinearity``. The shortcut is the identity when the
    channel count is unchanged. Otherwise it is a 3x3 convolution (when
    ``conv_shortcut`` is True) or a 1x1 convolution.

    Args:
        in_channels: channels of ``x``.
        out_channels: output channels; defaults to ``in_channels``.
        conv_shortcut: use a 3x3 instead of a 1x1 shortcut conv when channels change.
        dropout: dropout probability before ``conv2`` (required, keyword-only).
        temb_channels: width of the ``temb`` tensor passed to ``forward``.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.temb_proj = nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # The shortcut needs a projection only when the channel count changes.
        if in_channels == out_channels:
            return
        if conv_shortcut:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        else:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, temb):
        h = self.conv1(nonlinearity(self.norm1(x)))
        # Inject the timestep: one bias per channel, broadcast over the spatial dims.
        h += self.temb_proj(nonlinearity(temb))[:, :, None, None]
        h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

        if self.in_channels == self.out_channels:
            skip = x
        elif self.use_conv_shortcut:
            skip = self.conv_shortcut(x)
        else:
            skip = self.nin_shortcut(x)
        return skip + h

In [None]:
class Unet(nn.Module):
    """Timestep-conditioned U-Net; ``sampler`` treats its output as predicted noise.

    Built from these ``config`` keys: ``ch`` (base width), ``out_ch``,
    ``ch_mult``, ``num_res_blocks``, ``attn_resolutions``, ``dropout``,
    ``in_channels``, ``image_size`` and ``resamp_with_conv``. Other keys are
    not read here. The keys mean:

    * ``ch_mult``: width multiplier per resolution level.
    * ``num_res_blocks``: ResnetBlocks per level on the way down; the way up
      uses one more per level.
    * ``attn_resolutions``: spatial sizes at which every ResnetBlock is
      followed by an AttnBlock.

    ``AttnBlock`` is looked up when ``__init__`` runs, so it must be defined
    before the first instantiation.

    Args:
        config: configuration dict as described above.
        verbose: print each resolution while building the downsampling path.
    """

    def __init__(self, config, verbose=False):
        super().__init__()

        # --- Configuration Parameters ---
        self.config = config
        ch = config['ch']
        out_ch = config['out_ch']
        ch_mult = tuple(config['ch_mult'])
        num_res_blocks = config['num_res_blocks']
        attn_resolutions = config['attn_resolutions']
        dropout = config['dropout']
        in_channels = config['in_channels']
        resolution = config['image_size']
        resamp_with_conv = config['resamp_with_conv']

        self.ch = ch
        self.temb_ch = ch * 4  # timestep-embedding width is 4x the base width
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # --- Time Embedding ---
        # Two linear layers applied to the sinusoidal embedding in forward().
        self.temb = nn.Module()
        self.temb.dense = nn.ModuleList([
            nn.Linear(ch, self.temb_ch),
            nn.Linear(self.temb_ch, self.temb_ch)
        ])

        # --- Input Convolution ---
        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)

        # --- Downsampling Blocks ---
        self.down = nn.ModuleList()
        curr_res = resolution
        # Input multiplier of level i is the output multiplier of level i-1 (1 for level 0).
        in_ch_mult = (1,) + ch_mult

        for i_level in range(self.num_resolutions):
            if verbose:
                print(f'Downsampling - Resolution: {curr_res}')
            block = nn.ModuleList()
            attn = nn.ModuleList()

            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]

            for i_block in range(num_res_blocks):
                block.append(ResnetBlock(
                    in_channels=block_in,
                    out_channels=block_out,
                    temb_channels=self.temb_ch,
                    dropout=dropout
                ))
                block_in = block_out

                # One AttnBlock per ResnetBlock at this resolution; forward() indexes attn[i_block].
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn

            # Every level except the last one halves the resolution.
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2

            self.down.append(down)

        # --- Middle Blocks ---
        # ResnetBlock -> AttnBlock -> ResnetBlock at the lowest resolution.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout
        )

        # --- Upsampling Blocks ---
        # Built from the lowest resolution upwards; each block concatenates one skip tensor.
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()

            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]

            for i_block in range(num_res_blocks + 1):
                # The last block of a level consumes the skip pushed just before this
                # level's down blocks ran (previous level's downsample output, or
                # conv_in for level 0), whose width is ch * in_ch_mult[i_level].
                if i_block == num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]

                block.append(ResnetBlock(
                    in_channels=block_in + skip_in,
                    out_channels=block_out,
                    temb_channels=self.temb_ch,
                    dropout=dropout
                ))
                block_in = block_out

                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))

            up = nn.Module()
            up.block = block
            up.attn = attn

            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2

            # insert(0, ...) so that self.up[i_level] matches the level index.
            self.up.insert(0, up)

        # --- Output Layers ---
        self.norm_out = Normalize(block_in)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)


    def forward(self, x, t):
        """Run the U-Net on images ``x`` at timesteps ``t``.

        Args:
            x: image batch with ``in_channels`` channels (consumed by conv_in);
                both spatial dims must equal ``image_size`` (asserted below).
            t: 1-D tensor of timesteps, one per sample (get_timestep_embedding
                asserts that it is 1-D).

        Returns:
            Tensor with ``out_ch`` channels from ``conv_out``.
        """
        assert x.shape[2] == x.shape[3] == self.resolution, "Input resolution mismatch."

        # --- Timestep Embedding ---
        temb = get_timestep_embedding(t, self.ch)
        temb = nonlinearity(self.temb.dense[0](temb))
        temb = self.temb.dense[1](temb)

        # --- Downsampling ---
        # hs collects every skip activation: conv_in output, each ResnetBlock (+attn)
        # output and each downsample output. The up path pops num_res_blocks + 1 per
        # level, which exactly empties it.
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)

            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # --- Middle ---
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # --- Upsampling ---
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                # Concatenate the most recent skip along channels before each block.
                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)

            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # --- Final Output ---
        h = self.norm_out(h)
        h = nonlinearity(h)
        return self.conv_out(h)


In [None]:
class AttnBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    Every spatial position attends to every position of the same feature map.
    Queries, keys and values come from 1x1 convolutions of the group-normalised
    input. The attended values are projected by another 1x1 convolution and
    added back to ``x``, so the output has the same shape as the input.

    Args:
        in_channels: channel count of the input (unchanged by the block).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels)

        def _pointwise():
            return nn.Conv2d(in_channels, in_channels, kernel_size=1)

        # Query / key / value projections, then the output projection.
        self.q = _pointwise()
        self.k = _pointwise()
        self.v = _pointwise()
        self.proj_out = _pointwise()

    def forward(self, x):
        normed = self.norm(x)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)

        batch, channels, height, width = query.shape
        tokens = height * width
        query = query.reshape(batch, channels, tokens).permute(0, 2, 1)  # [B, HW, C]
        key = key.reshape(batch, channels, tokens)                       # [B, C, HW]
        value = value.reshape(batch, channels, tokens)                   # [B, C, HW]

        # weights[b, i, j]: softmax-normalised affinity of query position i to key position j.
        weights = torch.softmax(torch.bmm(query, key) * (channels ** -0.5), dim=2)
        # attended[b, :, i] = sum_j value[b, :, j] * weights[b, i, j]
        attended = torch.bmm(value, weights.permute(0, 2, 1)).reshape(batch, channels, height, width)

        return x + self.proj_out(attended)


In [None]:
class NoiseScheduler:
    """Linear-beta diffusion schedule: closed-form forward noising and DDPM reverse steps.

    Args:
        diffusion_config: dict with ``'num_timesteps'``, ``'beta_start'``,
            ``'beta_end'``, ``'sampling_scheme'`` and ``'eta'``.
        device: device on which the schedule tensors are allocated.

    Only ``sampling_scheme == 'DDPM'`` is implemented by ``sample_prev_timestep``.
    ``eta`` is stored but not used by it.
    """

    def __init__(self, diffusion_config, device):
        self.device = device
        self.eta = diffusion_config['eta']
        self.sampling_scheme = diffusion_config['sampling_scheme']
        self.num_timesteps = diffusion_config['num_timesteps']

        # Linear beta schedule from beta_start to beta_end
        self.betas = torch.linspace(
            diffusion_config['beta_start'],
            diffusion_config['beta_end'],
            self.num_timesteps,
            device=device
        )

        # alpha_t = 1 - beta_t and its running product alpha_bar_t, plus the square roots
        # used by the forward and reverse formulas.
        self.alphas = 1.0 - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1.0 - self.alpha_cum_prod)

    def add_noise(self, original, noise, t):
        """Sample ``x_t`` from ``q(x_t | x_0)`` in closed form.

            x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

        where ``alpha_bar_t`` is the cumulative product of ``1 - beta`` up to ``t``.

        Args:
            original: clean batch ``x_0``; the first dimension is the batch.
            noise: tensor with the same shape as ``original``.
            t: timestep indices, one per sample.

        Returns:
            The noised batch, same shape as ``original``.
        """
        batch_size = original.shape[0]

        # Select sqrt(alpha_bar_t) and sqrt(1 - alpha_bar_t) for each sample
        sqrt_alpha = self.sqrt_alpha_cum_prod[t].reshape(batch_size)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alpha_cum_prod[t].reshape(batch_size)

        # Append singleton dims so the per-sample factors broadcast over the rest of the shape
        for _ in range(len(original.shape) - 1):
            sqrt_alpha = sqrt_alpha.unsqueeze(-1)
            sqrt_one_minus_alpha = sqrt_one_minus_alpha.unsqueeze(-1)

        return sqrt_alpha * original + sqrt_one_minus_alpha * noise

    def sample_prev_timestep(self, x_t, noise_pred, t_i_minus_1, t_i=None):
        """Take one ancestral DDPM step from timestep ``t_i`` to ``t_i_minus_1``.

        Args:
            x_t: current noisy batch at timestep ``t_i``.
            noise_pred: the model's noise prediction for ``x_t``.
            t_i_minus_1: integer index of the timestep being produced.
            t_i: integer index of ``x_t``; defaults to ``t_i_minus_1 + 1``.

        Returns:
            ``(x_prev, x0_pred)``: the sample for the previous timestep and the
            clamped ``x_0`` estimate, which is returned for inspection only.

        Raises:
            ValueError: if ``sampling_scheme`` is not ``'DDPM'``.
        """
        if self.sampling_scheme != 'DDPM':
            raise ValueError(
                f"Unsupported sampling_scheme {self.sampling_scheme!r}; only 'DDPM' is implemented."
            )
        if t_i is None:
            t_i = t_i_minus_1 + 1
        t_i, t_i_minus_1 = int(t_i), int(t_i_minus_1)

        # Predict x0 from x_t and predicted noise (clamped to the [-1, 1] data range)
        x0_pred = (x_t - self.sqrt_one_minus_alpha_cum_prod[t_i] * noise_pred) / self.sqrt_alpha_cum_prod[t_i]
        x0_pred = torch.clamp(x0_pred, -1.0, 1.0)

        # Posterior mean in the epsilon parameterisation:
        #   (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = x_t - (self.betas[t_i] * noise_pred) / self.sqrt_one_minus_alpha_cum_prod[t_i]
        mean = mean / torch.sqrt(self.alphas[t_i])

        # DDPM adds posterior noise on every step except the one out of t == 0.
        # Checking t_i_minus_1 == 0 would also skip the noise on the 1 -> 0 step.
        if t_i == 0:
            return mean, x0_pred

        # Posterior variance: (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
        var_ratio = (1.0 - self.alpha_cum_prod[t_i_minus_1]) / (1.0 - self.alpha_cum_prod[t_i])
        variance = var_ratio * self.betas[t_i]
        sigma = torch.sqrt(variance)

        noise = torch.randn_like(x_t)
        return mean + sigma * noise, x0_pred


In [None]:
# --- Constants ---
img_size = 64
img_channels = 3
batch_size = 16
epochs = 100
lr = 3e-4
Gen_samples = 100  # number of images the sampler generates per call

# --- Device Setup ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Configurations ---
# Unet.__init__ reads ch, out_ch, ch_mult, num_res_blocks, attn_resolutions, dropout,
# in_channels, image_size and resamp_with_conv; 'type', 'var_type', 'ema_rate' and
# 'ema' are not read by Unet.__init__.
model_config = {
    'type': "simple",
    'in_channels': img_channels,
    'out_ch': img_channels,
    'ch': 128,
    'ch_mult': [1, 1, 2, 2, 4],
    'num_res_blocks': 3,
    'attn_resolutions': [16],
    'dropout': 0.0,
    'var_type': 'fixedsmall',
    'ema_rate': 0.999,
    'ema': True,
    'resamp_with_conv': True,
    'image_size': img_size
}

# 'eta' is stored by NoiseScheduler but not used by its DDPM sampling step.
diffusion_config = {
    'num_timesteps': 1000,
    'beta_start': 0.0001,
    'beta_end': 0.02,
    'sampling_scheme': 'DDPM',
    'eta': 0,
}

train_config = {
    'batch_size': batch_size,
    'num_epochs': epochs,
    'Gen_samples': Gen_samples,
    'num_grid_rows': 10,
    'lr': lr,
    'betas': (0.9, 0.95),
    'max_grad_norm': 1
}

# --- Data Loading ---
# NOTE(review): ToTensor() yields pixels in [0, 1], while the sampler clamps samples to
# [-1, 1] and maps them back with (x + 1) / 2. Unless the training loop rescales each
# batch (e.g. x * 2 - 1), add transforms.Normalize((0.5,) * 3, (0.5,) * 3) here — confirm.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor()
])



# Automatically get the current working directory and dataset path
data_root = os.path.join(os.getcwd(), "afhq_all")
dataset = ImageFolder(root=data_root, transform=transform)


#dataset = ImageFolder(root="/home/adityab/ADRL/A1/Dlcv_ass/afhq_all", transform=transform)
# Hold out 1000 images for validation; random_split is unseeded, so the split changes per run.
val_size = 1000
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# --- Model and Scheduler ---
# NOTE: the timestep sweep re-creates the model, scheduler and optimizer for every setting.
#model = Unet(model_config).to(device)
model = torch.compile(Unet(model_config).to(device))
scheduler = NoiseScheduler(diffusion_config, device)
optimizer = AdamW(model.parameters(), lr=train_config['lr'], betas=train_config['betas'], weight_decay=0.002)

# Two epochs of warmup, then cosine decay over all epochs (one scheduler step per batch assumed).
num_warmup_steps = len(train_loader) * 2
num_training_steps = len(train_loader) * epochs
opt_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

criterion = nn.MSELoss()




In [None]:
# --- Sampler ---
@torch.no_grad()
def sampler(epoch, model, scheduler, diffusion_config, model_config, train_config, show_img=True):
    """Generate images by running the full reverse diffusion chain.

    Starts from standard Gaussian noise and applies
    ``scheduler.sample_prev_timestep`` for ``t = T-1, ..., 0``. The final
    sample is clamped to [-1, 1] and mapped to [0, 1].

    Args:
        epoch: only used in the plot title.
        model: noise-prediction network called as ``model(x, t)``.
        scheduler: object providing ``sample_prev_timestep`` (e.g. ``NoiseScheduler``).
        diffusion_config: ``'num_timesteps'`` is read.
        model_config: ``'image_size'`` and ``'out_ch'`` are read.
        train_config: ``'Gen_samples'`` and ``'num_grid_rows'`` are read.
        show_img: if True, display the samples as a grid.

    Returns:
        Tensor of shape ``(Gen_samples, out_ch, image_size, image_size)`` with
        values in [0, 1], on the model's device.

    The model runs in eval mode and its previous train/eval mode is restored
    before returning, even if sampling fails.
    """
    num_samples = train_config['Gen_samples']
    img_size = model_config['image_size']
    channels = model_config['out_ch']
    num_timesteps = diffusion_config['num_timesteps']

    # Sample on the model's own device; fall back to the global device if it has no parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()
    try:
        # Start from Gaussian noise, allocated directly on the target device
        x = torch.randn(num_samples, channels, img_size, img_size, device=run_device)

        # Reverse diffusion loop
        for i in reversed(range(num_timesteps)):
            t = torch.full((num_samples,), i, dtype=torch.long, device=run_device)
            noise_pred = model(x, t)
            x, _ = scheduler.sample_prev_timestep(x, noise_pred, t_i_minus_1=max(i - 1, 0), t_i=i)
    finally:
        model.train(was_training)

    # Convert to [0,1] range
    x = (x.clamp(-1, 1) + 1) / 2.0
    grid = vutils.make_grid(x, nrow=train_config['num_grid_rows'])

    if show_img:
        plt.figure(figsize=(10, 10))
        plt.axis('off')
        plt.title(f"Generated Samples (Epoch {epoch})")
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
        plt.show()

    return x

In [None]:
timesteps_list = [200, 500, 1000]  # Diffusion chain lengths T to compare
all_epoch_losses = {}             # T -> list of per-epoch average training losses

for steps in timesteps_list:
    print(f"\n====== Training with {steps} timesteps ======\n")

    # --- Update config ---
    # NOTE: mutates the shared diffusion_config; after the loop it holds the
    # last value in timesteps_list, which later cells read.
    diffusion_config['num_timesteps'] = steps

    # --- Re-init model, noise scheduler, optimizer, LR scheduler ---
    # A fresh model for every T so each run starts from scratch.
    # (Use plain `Unet(model_config).to(device)` to debug in eager mode.)
    model = torch.compile(Unet(model_config).to(device))
    scheduler = NoiseScheduler(diffusion_config, device)

    optimizer = AdamW(model.parameters(), lr=train_config['lr'], betas=train_config['betas'], weight_decay=0.002)

    num_warmup_steps = len(train_loader) * 2
    num_training_steps = len(train_loader) * epochs
    opt_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

    criterion = nn.MSELoss()
    epoch_losses = []  # Average loss per epoch for this T

    # --- Training Loop ---
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0

        for images, _ in tqdm(train_loader, desc=f"[{steps} steps] Epoch {epoch+1}/{epochs}"):
            optimizer.zero_grad()
            images = images.to(device) * 2 - 1  # [0, 1] from ToTensor -> [-1, 1]

            # Noise each image at a uniformly drawn timestep and train the
            # model to predict the added noise (MSE objective).
            noise = torch.randn_like(images)  # already on images' device
            t = torch.randint(0, diffusion_config['num_timesteps'], (images.size(0),), device=device)
            noisy_images = scheduler.add_noise(images, noise, t)

            noise_pred = model(noisy_images, t)
            loss = criterion(noise_pred, noise)

            loss.backward()
            # Clip gradients to the norm configured in train_config.
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config['max_grad_norm'])
            optimizer.step()
            opt_scheduler.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_loader)
        epoch_losses.append(avg_loss)

        print(f"Epoch [{epoch+1}/{epochs}] | Avg Loss: {avg_loss:.4f} | LR: {opt_scheduler.get_last_lr()[0]:.6f}")

        # --- Show Samples Every 20 Epochs ---
        if (epoch + 1) % 20 == 0:
            sampler(epoch + 1, model, scheduler, diffusion_config, model_config, train_config, show_img=True)

    # Store for plotting
    all_epoch_losses[steps] = epoch_losses
    print(f"\n[✓] Training complete for {steps} timesteps.\n")

# --- Plot all loss curves ---
plt.figure(figsize=(10, 6))
for steps, losses in all_epoch_losses.items():
    plt.plot(range(1, len(losses)+1), losses, label=f"{steps} steps")

plt.title("Average Training Loss Across Timesteps")
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.legend(title="Timesteps")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

### Note: Because the previous cell produces very large output, its generated samples and logs are not fully visible in VS Code's notebook interface. To work around this, the following cell extracts the saved images (the DDPM 1000-step samples and the loss plot) directly from the .ipynb file and writes them to disk. (When the notebook is opened in Google Colab, the full output of the previous cell is visible, but it does not appear in VS Code.)

In [None]:
import json
import base64
import os

import re

# Locate the notebook (expected in the current working directory) and the
# folder that will receive the extracted PNGs.
notebook_filename = "BenjaminDebbarma3(22554).ipynb"  # Name only, no path
notebook_path = os.path.join(os.getcwd(), notebook_filename)
output_dir = os.path.join(os.getcwd(), "extracted_images_1000")
os.makedirs(output_dir, exist_ok=True)

# Which training run to extract, identified by the header the training cell
# prints: "====== Training with <T> timesteps ======".
target_steps = 1000
run_header = re.compile(r"Training with (\d+) timesteps")

with open(notebook_path, 'r', encoding='utf-8') as f:
    notebook = json.load(f)

img_count = 0

for cell in notebook.get("cells", []):
    if cell.get("cell_type") != "code":
        continue
    # Run whose output is currently being read; reset per cell so images from
    # unrelated cells are never picked up.
    current_steps = None
    for output in cell.get("outputs", []):
        output_type = output.get("output_type")
        if output_type == "stream":
            # One stream output may contain several merged prints; the last run
            # header in it decides which run the following images belong to.
            # Matching the header (not any "1000" substring) avoids false hits
            # on logged values such as "Avg Loss: 0.1000" or "LR: 0.001000".
            headers = run_header.findall("".join(output.get("text", "")))
            if headers:
                current_steps = int(headers[-1])
        elif output_type == "display_data" and current_steps == target_steps:
            png = output.get("data", {}).get("image/png")
            if not png:
                continue
            if isinstance(png, list):  # tolerate multi-line string encoding
                png = "".join(png)
            out_path = os.path.join(output_dir, f"image_{img_count:03d}.png")
            with open(out_path, "wb") as img_file:
                img_file.write(base64.b64decode(png))
            img_count += 1

print(f"Extracted {img_count} image(s) associated with {target_steps} steps to {output_dir}")


In [None]:
# --- Plot all loss curves ---
# One curve per diffusion chain length, using an explicit Figure/Axes pair.
loss_fig, loss_ax = plt.subplots(figsize=(10, 6))
for steps, losses in all_epoch_losses.items():
    epoch_axis = range(1, len(losses) + 1)
    loss_ax.plot(epoch_axis, losses, label=f"{steps} steps")

loss_ax.set_title("Average Training Loss Across Timesteps")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Average Loss")
loss_ax.legend(title="Timesteps")
loss_ax.grid(True, alpha=0.3)
loss_fig.tight_layout()
plt.show()

In [None]:
import os
import torch
from torchvision.utils import save_image
from torchvision import transforms
from tqdm import tqdm
import subprocess

# --- Setup ---
real_dir = "ddpm_fid_images/real"
fake_dir = "ddpm_fid_images/fake"
num_images = 1000  # Total images for FID

os.makedirs(real_dir, exist_ok=True)
os.makedirs(fake_dir, exist_ok=True)

# --- Save real images ---
def save_real_images(dataloader, out_dir, max_images):
    """Write up to ``max_images`` images from ``dataloader`` to ``out_dir``.

    Files are named ``00000.png``, ``00001.png``, ... . The loader is expected
    to yield ``(images, labels)`` batches whose pixel values are in [0, 1]
    (as ``transforms.ToTensor`` produces).

    Returns:
        Number of images written; fewer than ``max_images`` if the loader runs
        out first, and 0 when ``max_images <= 0``.
    """
    count = 0
    if max_images <= 0:
        # Without this guard the loop below would still write one image.
        return count
    for batch, _ in tqdm(dataloader, desc="Saving Real Images"):
        for img in batch:
            save_image(img, os.path.join(out_dir, f"{count:05d}.png"))
            count += 1
            if count >= max_images:
                return count
    return count

save_real_images(val_loader, real_dir, num_images)

# --- Save generated images ---
@torch.no_grad()
def save_generated_images(model, scheduler, out_dir, diffusion_config, model_config, max_images):
    """Sample ``max_images`` images with the reverse diffusion chain and save them.

    Images are generated in batches of at most 50 and written to ``out_dir``
    as ``00000.png``, ``00001.png``, ... .

    Args:
        model: Noise-prediction network, called as ``model(x, t)``; switched
            to eval mode by this function.
        scheduler: Object exposing ``sample_prev_timestep(x, noise_pred,
            t_i_minus_1=..., t_i=...)`` that returns ``(x_prev, _)``.
        out_dir: Destination directory (must already exist).
        diffusion_config: Must contain ``'num_timesteps'``.
        model_config: Must contain ``'image_size'`` and ``'out_ch'``.
        max_images: Number of images to write (nothing is written if <= 0).

    Returns:
        Number of images written.
    """
    model.eval()
    batch_size = 50
    img_size = model_config['image_size']
    channels = model_config['out_ch']
    num_timesteps = diffusion_config['num_timesteps']
    device = next(model.parameters()).device

    # Round up (ceil division) so a max_images that is not a multiple of
    # batch_size, or is smaller than it, still gets every requested image.
    num_batches = -(-max(max_images, 0) // batch_size)

    count = 0
    for _ in tqdm(range(num_batches), desc="Saving Generated Images"):
        # Sample only as many as are still needed (the last batch may be smaller).
        n = min(batch_size, max_images - count)
        x = torch.randn(n, channels, img_size, img_size, device=device)
        for t in reversed(range(num_timesteps)):
            t_tensor = torch.full((n,), t, dtype=torch.long, device=device)
            noise_pred = model(x, t_tensor)
            x, _ = scheduler.sample_prev_timestep(x, noise_pred, t_i_minus_1=t-1 if t > 0 else 0, t_i=t)
        # Clamp to [-1, 1] and map to [0, 1], the range save_image writes correctly.
        x = (x.clamp(-1, 1) + 1) / 2.0
        for img in x:
            save_image(img, os.path.join(out_dir, f"{count:05d}.png"))
            count += 1
    return count

save_generated_images(model, scheduler, fake_dir, diffusion_config, model_config, num_images)

In [None]:
# --- Compute FID ---
import sys

print("\nComputing FID...")
# Run pytorch_fid with the kernel's own interpreter (a bare "python" on PATH
# may belong to a different environment) and capture its output so the score
# is printed in the notebook instead of only on the kernel's terminal.
fid_result = subprocess.run(
    [sys.executable, "-m", "pytorch_fid", real_dir, fake_dir],
    capture_output=True,
    text=True,
)
print(fid_result.stdout)
if fid_result.returncode != 0:
    raise RuntimeError(
        f"pytorch_fid failed (exit code {fid_result.returncode}):\n{fid_result.stderr}"
    )