# Base Channels Tuning

This file contains the results of tuning the `base_channels` hyperparameter in the UNet architecture. The `base_channels` parameter controls the number of channels in the first layer of the encoder, and all subsequent layers scale proportionally. We are testing the following values: 8, 16, 32, 64, 96, and 128. The results are found below.


## Base Channels = 8

This cell implements the UNet architecture with **`base_channels = 8`**. The `base_channels` parameter determines the number of feature channels in the first convolutional layer of the encoder, and all subsequent layers scale proportionally.

**Architecture Details:**

With `base_channels = 8`, the channel progression through the UNet encoder is:
- Initial layer: $8$ channels
- Encoder layer 1: $8 \rightarrow 16$ channels
- Encoder layer 2: $16 \rightarrow 32$ channels  
- Encoder layer 3: $32 \rightarrow 64$ channels
- Bottleneck: $64$ channels
- Decoder layers: $64 \rightarrow 32 \rightarrow 16 \rightarrow 8$ channels

**Model Characteristics:**
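- **Total Parameters:** Smallest model in the sweep (roughly 1.1 million parameters by estimate). About half of these sit in the time-embedding MLP, whose size does not depend on `base_channels`. Confirm with `sum(p.numel() for p in model.parameters())`.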
- **Memory Usage:** Lowest memory footprint
- **Training Speed:** Fastest training and inference
- **Representation Capacity:** Limited - may struggle with complex patterns

This is the most lightweight configuration, suitable for quick experimentation or resource-constrained environments.


In [None]:
# ============================================================
# UT Zappos50K Shoe Diffusion Model - base_channels = 8
# ============================================================

import os
import math
import glob
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt

# -----------------------------
# 0. Config & Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -----------------------------
# 1. Dataset Loading (EXACT REPLICATION)
# -----------------------------
import kagglehub

# Download (or reuse the cached copy of) the latest dataset version.
# dataset_download returns the local directory it used, so build every path
# from that value instead of a hard-coded cache location. A hard-coded path
# breaks for non-root users, custom cache dirs, and newer dataset versions.
path = kagglehub.dataset_download("aryashah2k/large-shoe-dataset-ut-zappos50k")

BASE_DIR = path
IMAGE_ROOT = os.path.join(BASE_DIR, "ut-zap50k-images-square")  # we'll use the square images

if not os.path.exists(IMAGE_ROOT):
    raise FileNotFoundError(f"IMAGE_ROOT '{IMAGE_ROOT}' does not exist. Check your base path.")

print("Image root:", IMAGE_ROOT)
print("Contents:", os.listdir(IMAGE_ROOT)[:10])

# Collect all image paths recursively from IMAGE_ROOT
extensions = ("*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tif", "*.tiff", "*.webp")
image_paths = []
for ext in extensions:
    image_paths.extend(glob.glob(os.path.join(IMAGE_ROOT, "**", ext), recursive=True))
# glob order is filesystem-dependent; sort so dataset index -> file is reproducible
image_paths.sort()

if len(image_paths) == 0:
    raise RuntimeError(f"No image files found under {IMAGE_ROOT} with extensions {extensions}.")

print(f"Found {len(image_paths)} image files.")

# Transform: resize & normalize (NO cropping, just warp to 64x64)
image_size = 64
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),                     # [0,1]
    transforms.Normalize([0.5]*3, [0.5]*3),    # [-1,1]
])

class ZapposImageDataset(Dataset):
    """Unconditional image dataset over a list of file paths.

    Each item is ``(image, 0)``. The image is loaded as RGB and passed through
    ``transform`` when one is given. The constant 0 is a dummy label, so items
    keep the usual ``(input, target)`` layout.

    Args:
        paths: sequence of image file paths.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        from PIL import Image

        path = self.paths[idx]
        # Open inside a context manager so the file handle is released right
        # away instead of whenever the PIL object is garbage-collected.
        # convert() returns a new, fully loaded image that outlives the `with`.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        # unconditional -> we don't care about labels, return dummy
        return img, 0

dataset = ZapposImageDataset(image_paths, transform=transform)

# Loader hyperparameters, kept identical across every base_channels run.
batch_size = 64
num_workers = 2

# Reshuffle each epoch. pin_memory=True asks the loader to return page-locked
# tensors, which can speed up host-to-GPU copies.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=True)

print("Dataset and DataLoader ready.")

# -----------------------------
# 2. Time Embedding (Sinusoidal)
# -----------------------------
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of scalar timesteps.

    Args:
        dim: output embedding size. Odd sizes get one trailing zero column.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """
        t: (B,) int or float timesteps
        returns: (B, dim). The first half is sin and the second half is cos,
        plus one zero column when dim is odd.
        """
        device = t.device
        half_dim = self.dim // 2
        # Frequencies run geometrically from 1 down to 1/10000. The max() guard
        # keeps dim in {2, 3} (half_dim == 1) from dividing by zero. For
        # dim >= 4 it equals log(10000) / (half_dim - 1) exactly.
        emb_factor = math.log(10000) / max(half_dim - 1, 1)
        exponents = torch.exp(torch.arange(half_dim, device=device) * -emb_factor)
        t = t.float().unsqueeze(1)  # (B,1)
        angles = t * exponents[None, :]  # (B, half_dim)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb

# -----------------------------
# 3. UNet Building Blocks
# -----------------------------
class ResidualBlock(nn.Module):
    """Two 3x3 convs with GroupNorm + SiLU, time conditioning, and a residual path.

    The time embedding is projected to ``out_channels`` and added as a
    per-channel bias after the first convolution. A 1x1 conv adapts the
    residual when the channel count changes.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        time_dim: size of the time-embedding vector passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        self.time_mlp = nn.Linear(time_dim, out_channels)

        # GroupNorm requires out_channels % num_groups == 0. Use 8 groups when
        # possible, which matches the original for multiples of 8. Otherwise
        # fall back to the largest divisor of 8 that also divides out_channels,
        # so base_channels values like 12 do not crash.
        num_groups = math.gcd(8, out_channels)
        self.norm1 = nn.GroupNorm(num_groups, out_channels)
        self.norm2 = nn.GroupNorm(num_groups, out_channels)
        self.act = nn.SiLU()

        if in_channels != out_channels:
            self.res_conv = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, t_emb):
        """
        x: (B, C, H, W)
        t_emb: (B, time_dim)
        returns: (B, out_channels, H, W)
        """
        h = self.conv1(x)
        # broadcast the projected time embedding over H and W
        t_added = self.time_mlp(t_emb)[:, :, None, None]
        h = h + t_added
        h = self.act(self.norm1(h))

        h = self.conv2(h)
        h = self.act(self.norm2(h))

        return h + self.res_conv(x)


class DownBlock(nn.Module):
    """Encoder stage: two residual blocks followed by a strided-conv downsample.

    The pre-downsample features are also returned so the decoder can use them
    as a skip connection.
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.res1 = ResidualBlock(in_channels, out_channels, time_dim)
        self.res2 = ResidualBlock(out_channels, out_channels, time_dim)
        # kernel 4, stride 2, padding 1 halves H and W for even sizes
        self.down = nn.Conv2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, t_emb):
        """Return ``(downsampled, skip)``; ``skip`` keeps the input resolution."""
        features = self.res2(self.res1(x, t_emb), t_emb)
        return self.down(features), features


class UpBlock(nn.Module):
    """Decoder stage: transposed-conv upsample, concat with skip, two residual blocks.

    Args:
        in_channels: channels of the tensor coming from the deeper stage.
        out_channels: channels after upsampling (and of the block's output).
        skip_channels: channels of the encoder skip tensor concatenated after upsampling.
        time_dim: size of the time embedding passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, skip_channels, time_dim):
        super().__init__()
        # kernel 4, stride 2, padding 1 doubles H and W
        self.up = nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1)
        self.res1 = ResidualBlock(out_channels + skip_channels, out_channels, time_dim)
        self.res2 = ResidualBlock(out_channels, out_channels, time_dim)

    def forward(self, x, skip, t_emb):
        x = self.up(x)
        # Match the skip tensor's spatial size before concatenating. Height and
        # width are aligned independently so non-square feature maps line up
        # too. A negative difference crops, because F.pad accepts negative
        # padding.
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = F.pad(x, (diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2))
        x = torch.cat([x, skip], dim=1)
        x = self.res1(x, t_emb)
        x = self.res2(x, t_emb)
        return x

# -----------------------------
# 4. UNet for 64x64 RGB Shoes (Unconditional)
# -----------------------------
class ShoeUNet(nn.Module):
    """Time-conditioned UNet that predicts the noise added to an image.

    Three encoder stages have widths ``base_channels * (2, 4, 8)`` after an
    initial ``base_channels`` conv. The decoder mirrors them with skip
    connections, and a final 1x1 conv maps back to ``in_channels``.

    Args:
        in_channels: image channels (3 for RGB).
        base_channels: width of the first feature map; deeper stages use 2x, 4x, 8x.
        time_dim: size of the time embedding fed to every residual block.
    """

    def __init__(self, in_channels=3, base_channels=64, time_dim=256):
        super().__init__()
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.time_dim = time_dim

        # Sinusoidal features -> two-layer MLP (expand 4x, SiLU, project back)
        hidden = time_dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, time_dim),
        )

        self.init_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Stage widths: b, 2b, 4b, 8b (b = base_channels)
        c0, c1, c2, c3 = (base_channels * m for m in (1, 2, 4, 8))

        # Encoder: each stage halves H and W (64 -> 32 -> 16 -> 8 for 64x64 input)
        self.down1 = DownBlock(c0, c1, time_dim)
        self.down2 = DownBlock(c1, c2, time_dim)
        self.down3 = DownBlock(c2, c3, time_dim)

        self.bottleneck = ResidualBlock(c3, c3, time_dim)

        # Decoder: each stage doubles H and W and consumes the matching encoder
        # skip, whose width equals that encoder stage's output width.
        self.up3 = UpBlock(c3, c2, skip_channels=c3, time_dim=time_dim)
        self.up2 = UpBlock(c2, c1, skip_channels=c2, time_dim=time_dim)
        self.up1 = UpBlock(c1, c0, skip_channels=c1, time_dim=time_dim)

        # 1x1 projection to the predicted noise
        self.out_conv = nn.Conv2d(base_channels, in_channels, 1)

    def forward(self, x, t):
        """Predict the noise in ``x`` at timesteps ``t``.

        x: (B, in_channels, H, W)
        t: (B,) timesteps
        """
        t_emb = self.time_mlp(t)
        h = self.init_conv(x)

        # Encoder, collecting skips from shallowest to deepest
        skips = []
        for stage in (self.down1, self.down2, self.down3):
            h, skip = stage(h, t_emb)
            skips.append(skip)

        h = self.bottleneck(h, t_emb)

        # Decoder, consuming skips from deepest to shallowest
        for stage in (self.up3, self.up2, self.up1):
            h = stage(h, skips.pop(), t_emb)

        return self.out_conv(h)

# -----------------------------
# 5. DDPM Diffusion Utilities
# -----------------------------
class Diffusion:
    """Linear-beta DDPM schedule with forward noising and ancestral sampling.

    Args:
        num_steps: number of diffusion timesteps T.
        beta_start: first value of the linear beta schedule.
        beta_end: last value of the linear beta schedule.
        device: device on which all schedule tensors are allocated.
    """

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, device="cpu"):
        self.device = device
        self.num_steps = num_steps

        self.betas = torch.linspace(beta_start, beta_end, num_steps, device=device)
        self.alphas = 1.0 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1 so the t = 0 entry is defined
        self.alpha_cumprod_prev = torch.cat(
            [torch.tensor([1.0], device=device), self.alpha_cumprod[:-1]], dim=0
        )

        self.sqrt_alphas = torch.sqrt(self.alphas)
        self.sqrt_alpha_cumprod = torch.sqrt(self.alpha_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alpha_cumprod)

        # Variance of q(x_{t-1} | x_t, x_0); exactly 0 at t = 0
        self.posterior_variance = (
            self.betas
            * (1.0 - self.alpha_cumprod_prev)
            / (1.0 - self.alpha_cumprod)
        )
        # log of the above, clamped so the t = 0 zero does not give -inf
        self.posterior_log_variance_clipped = torch.log(
            torch.clamp(self.posterior_variance, min=1e-20)
        )

    def sample_timesteps(self, batch_size):
        """Draw ``batch_size`` timesteps uniformly from ``[0, num_steps)``."""
        return torch.randint(0, self.num_steps, (batch_size,), device=self.device).long()

    def q_sample(self, x0, t, noise=None):
        """
        Forward diffusion: q(x_t | x_0)

        Returns sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise, where ``t`` is a
        (B,) tensor of timesteps and ``x0`` has shape (B, ...). Fresh Gaussian
        noise is drawn when ``noise`` is None.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        # Reshape the per-sample coefficients to (B, 1, ..., 1) so they
        # broadcast against x0 of any rank; for 4-D images this is (B, 1, 1, 1).
        coef_shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_bar_t = self.sqrt_alpha_cumprod[t].view(coef_shape)
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_cumprod[t].view(coef_shape)
        return sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise

    @torch.no_grad()
    def p_sample(self, model, x_t, t):
        """
        Reverse step: p(x_{t-1} | x_t)

        ``t`` is a Python int timestep shared by the whole batch. At t == 0 the
        predicted mean is returned without added noise.
        """
        B = x_t.size(0)
        t_batch = torch.full((B,), t, device=self.device, dtype=torch.long)

        eps_theta = model(x_t, t_batch)

        beta_t = self.betas[t]
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t]
        sqrt_recip_alpha_t = 1.0 / torch.sqrt(self.alphas[t])

        # mu_theta = (x_t - beta_t / sqrt(1 - abar_t) * eps_theta) / sqrt(alpha_t)
        model_mean = sqrt_recip_alpha_t * (
            x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * eps_theta
        )

        if t == 0:
            return model_mean
        posterior_var_t = self.posterior_variance[t]
        noise = torch.randn_like(x_t)
        return model_mean + torch.sqrt(posterior_var_t) * noise

    @torch.no_grad()
    def sample(self, model, image_size, batch_size=8, channels=3):
        """
        Sample x_0 from the model by starting from pure noise.

        Puts ``model`` in eval mode and runs every reverse step from
        t = num_steps - 1 down to 0. ``channels`` (default 3) must match the
        model's input channels. Returns a tensor of shape
        (batch_size, channels, image_size, image_size).
        """
        model.eval()
        x = torch.randn(batch_size, channels, image_size, image_size, device=self.device)

        for t in reversed(range(self.num_steps)):
            x = self.p_sample(model, x, t)

        return x

# -----------------------------
# 6. Instantiate Model & Diffusion
# -----------------------------
# Model, noise schedule and optimiser for this sweep point.
time_dim = 256
model = ShoeUNet(in_channels=3, base_channels=8, time_dim=time_dim).to(device)
diffusion = Diffusion(num_steps=1000, beta_start=1e-4, beta_end=0.02, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# -----------------------------
# 7. Training Loop (epsilon-prediction objective)
# -----------------------------
num_epochs = 30  # increase for better visuals

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for x, _ in train_loader:  # labels are dummies for this unconditional model
        x = x.to(device)  # the transform already maps pixels to [-1, 1]
        b = x.size(0)

        # Corrupt each image at its own random timestep ...
        t = diffusion.sample_timesteps(b)
        noise = torch.randn_like(x)
        x_t = diffusion.q_sample(x, t, noise=noise)

        # ... and regress the network output onto the injected noise.
        pred_noise = model(x_t, t)
        loss = F.mse_loss(pred_noise, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss is a batch mean; weight by batch size so the epoch value is a per-sample mean
        running_loss += loss.item() * b

    epoch_loss = running_loss / len(dataset)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss:.6f}")

print("Training complete.")

# -----------------------------
# 8. Sampling & Visualization
# -----------------------------
@torch.no_grad()
def show_samples(model, diffusion, n=8):
    """
    Generate and show shoe samples from the diffusion model.

    Draws ``n`` images with ``diffusion.sample`` at the module-level
    ``image_size``, maps them from [-1, 1] to [0, 1], and displays them in a
    grid with at most 4 images per row.

    Raises:
        ValueError: if ``n`` is not a positive integer.
    """
    # Fail fast instead of running the full reverse chain on an empty batch.
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    model.eval()
    samples = diffusion.sample(model, image_size=image_size, batch_size=n)
    samples = (samples.clamp(-1, 1) + 1) / 2.0  # back to [0,1]

    # Size the figure from the grid's actual shape so the square tiles are not
    # letterboxed. A fixed height would ignore how many rows the grid has.
    ncols = min(n, 4)
    nrows = math.ceil(n / ncols)
    grid = utils.make_grid(samples, nrow=ncols)
    plt.figure(figsize=(3 * ncols, 3 * nrows))
    plt.axis("off")
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.show()

print("\nBase Channels = 8 Results:")
show_samples(model, diffusion, n=8)


## Base Channels = 16

This cell implements the UNet architecture with **`base_channels = 16`**.

**Architecture Details:**

With `base_channels = 16`, the channel progression through the UNet encoder is:
- Initial layer: $16$ channels
- Encoder layer 1: $16 \rightarrow 32$ channels
- Encoder layer 2: $32 \rightarrow 64$ channels  
- Encoder layer 3: $64 \rightarrow 128$ channels
- Bottleneck: $128$ channels
- Decoder layers: $128 \rightarrow 64 \rightarrow 32 \rightarrow 16$ channels

**Model Characteristics:**
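- **Total Parameters:** Small model size (roughly 2.5 million parameters by estimate, including about 0.5 million in the `base_channels`-independent time-embedding MLP). Confirm with `sum(p.numel() for p in model.parameters())`.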
- **Memory Usage:** Low memory footprint
- **Training Speed:** Fast training and inference
- **Representation Capacity:** Moderate - good balance for many tasks

This configuration provides a good starting point for experimentation with reasonable capacity while maintaining efficiency.


In [None]:
# ============================================================
# UT Zappos50K Shoe Diffusion Model - base_channels = 16
# ============================================================

import os
import math
import glob
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt

# -----------------------------
# 0. Config & Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -----------------------------
# 1. Dataset Loading (EXACT REPLICATION)
# -----------------------------
import kagglehub

# Download (or reuse the cached copy of) the latest dataset version.
# dataset_download returns the local directory it used, so build every path
# from that value instead of a hard-coded cache location. A hard-coded path
# breaks for non-root users, custom cache dirs, and newer dataset versions.
path = kagglehub.dataset_download("aryashah2k/large-shoe-dataset-ut-zappos50k")

BASE_DIR = path
IMAGE_ROOT = os.path.join(BASE_DIR, "ut-zap50k-images-square")  # we'll use the square images

if not os.path.exists(IMAGE_ROOT):
    raise FileNotFoundError(f"IMAGE_ROOT '{IMAGE_ROOT}' does not exist. Check your base path.")

print("Image root:", IMAGE_ROOT)
print("Contents:", os.listdir(IMAGE_ROOT)[:10])

# Collect all image paths recursively from IMAGE_ROOT
extensions = ("*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tif", "*.tiff", "*.webp")
image_paths = []
for ext in extensions:
    image_paths.extend(glob.glob(os.path.join(IMAGE_ROOT, "**", ext), recursive=True))
# glob order is filesystem-dependent; sort so dataset index -> file is reproducible
image_paths.sort()

if len(image_paths) == 0:
    raise RuntimeError(f"No image files found under {IMAGE_ROOT} with extensions {extensions}.")

print(f"Found {len(image_paths)} image files.")

# Transform: resize & normalize (NO cropping, just warp to 64x64)
image_size = 64
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),                     # [0,1]
    transforms.Normalize([0.5]*3, [0.5]*3),    # [-1,1]
])

class ZapposImageDataset(Dataset):
    """Unconditional image dataset over a list of file paths.

    Each item is ``(image, 0)``. The image is loaded as RGB and passed through
    ``transform`` when one is given. The constant 0 is a dummy label, so items
    keep the usual ``(input, target)`` layout.

    Args:
        paths: sequence of image file paths.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        from PIL import Image

        path = self.paths[idx]
        # Open inside a context manager so the file handle is released right
        # away instead of whenever the PIL object is garbage-collected.
        # convert() returns a new, fully loaded image that outlives the `with`.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        # unconditional -> we don't care about labels, return dummy
        return img, 0

dataset = ZapposImageDataset(image_paths, transform=transform)

# Loader hyperparameters, kept identical across every base_channels run.
batch_size = 64
num_workers = 2

# Reshuffle each epoch. pin_memory=True asks the loader to return page-locked
# tensors, which can speed up host-to-GPU copies.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=True)

print("Dataset and DataLoader ready.")

# -----------------------------
# 2. Time Embedding (Sinusoidal)
# -----------------------------
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of scalar timesteps.

    Args:
        dim: output embedding size. Odd sizes get one trailing zero column.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """
        t: (B,) int or float timesteps
        returns: (B, dim). The first half is sin and the second half is cos,
        plus one zero column when dim is odd.
        """
        device = t.device
        half_dim = self.dim // 2
        # Frequencies run geometrically from 1 down to 1/10000. The max() guard
        # keeps dim in {2, 3} (half_dim == 1) from dividing by zero. For
        # dim >= 4 it equals log(10000) / (half_dim - 1) exactly.
        emb_factor = math.log(10000) / max(half_dim - 1, 1)
        exponents = torch.exp(torch.arange(half_dim, device=device) * -emb_factor)
        t = t.float().unsqueeze(1)  # (B,1)
        angles = t * exponents[None, :]  # (B, half_dim)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb

# -----------------------------
# 3. UNet Building Blocks
# -----------------------------
class ResidualBlock(nn.Module):
    """Two 3x3 convs with GroupNorm + SiLU, time conditioning, and a residual path.

    The time embedding is projected to ``out_channels`` and added as a
    per-channel bias after the first convolution. A 1x1 conv adapts the
    residual when the channel count changes.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        time_dim: size of the time-embedding vector passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        self.time_mlp = nn.Linear(time_dim, out_channels)

        # GroupNorm requires out_channels % num_groups == 0. Use 8 groups when
        # possible, which matches the original for multiples of 8. Otherwise
        # fall back to the largest divisor of 8 that also divides out_channels,
        # so base_channels values like 12 do not crash.
        num_groups = math.gcd(8, out_channels)
        self.norm1 = nn.GroupNorm(num_groups, out_channels)
        self.norm2 = nn.GroupNorm(num_groups, out_channels)
        self.act = nn.SiLU()

        if in_channels != out_channels:
            self.res_conv = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, t_emb):
        """
        x: (B, C, H, W)
        t_emb: (B, time_dim)
        returns: (B, out_channels, H, W)
        """
        h = self.conv1(x)
        # broadcast the projected time embedding over H and W
        t_added = self.time_mlp(t_emb)[:, :, None, None]
        h = h + t_added
        h = self.act(self.norm1(h))

        h = self.conv2(h)
        h = self.act(self.norm2(h))

        return h + self.res_conv(x)


class DownBlock(nn.Module):
    """Encoder stage: two residual blocks followed by a strided-conv downsample.

    The pre-downsample features are also returned so the decoder can use them
    as a skip connection.
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.res1 = ResidualBlock(in_channels, out_channels, time_dim)
        self.res2 = ResidualBlock(out_channels, out_channels, time_dim)
        # kernel 4, stride 2, padding 1 halves H and W for even sizes
        self.down = nn.Conv2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, t_emb):
        """Return ``(downsampled, skip)``; ``skip`` keeps the input resolution."""
        features = self.res2(self.res1(x, t_emb), t_emb)
        return self.down(features), features


class UpBlock(nn.Module):
    """Decoder stage: transposed-conv upsample, concat with skip, two residual blocks.

    Args:
        in_channels: channels of the tensor coming from the deeper stage.
        out_channels: channels after upsampling (and of the block's output).
        skip_channels: channels of the encoder skip tensor concatenated after upsampling.
        time_dim: size of the time embedding passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, skip_channels, time_dim):
        super().__init__()
        # kernel 4, stride 2, padding 1 doubles H and W
        self.up = nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1)
        self.res1 = ResidualBlock(out_channels + skip_channels, out_channels, time_dim)
        self.res2 = ResidualBlock(out_channels, out_channels, time_dim)

    def forward(self, x, skip, t_emb):
        x = self.up(x)
        # Match the skip tensor's spatial size before concatenating. Height and
        # width are aligned independently so non-square feature maps line up
        # too. A negative difference crops, because F.pad accepts negative
        # padding.
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = F.pad(x, (diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2))
        x = torch.cat([x, skip], dim=1)
        x = self.res1(x, t_emb)
        x = self.res2(x, t_emb)
        return x

# -----------------------------
# 4. UNet for 64x64 RGB Shoes (Unconditional)
# -----------------------------
class ShoeUNet(nn.Module):
    """Time-conditioned UNet that predicts the noise added to an image.

    Three encoder stages have widths ``base_channels * (2, 4, 8)`` after an
    initial ``base_channels`` conv. The decoder mirrors them with skip
    connections, and a final 1x1 conv maps back to ``in_channels``.

    Args:
        in_channels: image channels (3 for RGB).
        base_channels: width of the first feature map; deeper stages use 2x, 4x, 8x.
        time_dim: size of the time embedding fed to every residual block.
    """

    def __init__(self, in_channels=3, base_channels=64, time_dim=256):
        super().__init__()
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.time_dim = time_dim

        # Sinusoidal features -> two-layer MLP (expand 4x, SiLU, project back)
        hidden = time_dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, time_dim),
        )

        self.init_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Stage widths: b, 2b, 4b, 8b (b = base_channels)
        c0, c1, c2, c3 = (base_channels * m for m in (1, 2, 4, 8))

        # Encoder: each stage halves H and W (64 -> 32 -> 16 -> 8 for 64x64 input)
        self.down1 = DownBlock(c0, c1, time_dim)
        self.down2 = DownBlock(c1, c2, time_dim)
        self.down3 = DownBlock(c2, c3, time_dim)

        self.bottleneck = ResidualBlock(c3, c3, time_dim)

        # Decoder: each stage doubles H and W and consumes the matching encoder
        # skip, whose width equals that encoder stage's output width.
        self.up3 = UpBlock(c3, c2, skip_channels=c3, time_dim=time_dim)
        self.up2 = UpBlock(c2, c1, skip_channels=c2, time_dim=time_dim)
        self.up1 = UpBlock(c1, c0, skip_channels=c1, time_dim=time_dim)

        # 1x1 projection to the predicted noise
        self.out_conv = nn.Conv2d(base_channels, in_channels, 1)

    def forward(self, x, t):
        """Predict the noise in ``x`` at timesteps ``t``.

        x: (B, in_channels, H, W)
        t: (B,) timesteps
        """
        t_emb = self.time_mlp(t)
        h = self.init_conv(x)

        # Encoder, collecting skips from shallowest to deepest
        skips = []
        for stage in (self.down1, self.down2, self.down3):
            h, skip = stage(h, t_emb)
            skips.append(skip)

        h = self.bottleneck(h, t_emb)

        # Decoder, consuming skips from deepest to shallowest
        for stage in (self.up3, self.up2, self.up1):
            h = stage(h, skips.pop(), t_emb)

        return self.out_conv(h)

# -----------------------------
# 5. DDPM Diffusion Utilities
# -----------------------------
class Diffusion:
    """Linear-beta DDPM schedule with forward noising and ancestral sampling.

    Args:
        num_steps: number of diffusion timesteps T.
        beta_start: first value of the linear beta schedule.
        beta_end: last value of the linear beta schedule.
        device: device on which all schedule tensors are allocated.
    """

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, device="cpu"):
        self.device = device
        self.num_steps = num_steps

        self.betas = torch.linspace(beta_start, beta_end, num_steps, device=device)
        self.alphas = 1.0 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1 so the t = 0 entry is defined
        self.alpha_cumprod_prev = torch.cat(
            [torch.tensor([1.0], device=device), self.alpha_cumprod[:-1]], dim=0
        )

        self.sqrt_alphas = torch.sqrt(self.alphas)
        self.sqrt_alpha_cumprod = torch.sqrt(self.alpha_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alpha_cumprod)

        # Variance of q(x_{t-1} | x_t, x_0); exactly 0 at t = 0
        self.posterior_variance = (
            self.betas
            * (1.0 - self.alpha_cumprod_prev)
            / (1.0 - self.alpha_cumprod)
        )
        # log of the above, clamped so the t = 0 zero does not give -inf
        self.posterior_log_variance_clipped = torch.log(
            torch.clamp(self.posterior_variance, min=1e-20)
        )

    def sample_timesteps(self, batch_size):
        """Draw ``batch_size`` timesteps uniformly from ``[0, num_steps)``."""
        return torch.randint(0, self.num_steps, (batch_size,), device=self.device).long()

    def q_sample(self, x0, t, noise=None):
        """
        Forward diffusion: q(x_t | x_0)

        Returns sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise, where ``t`` is a
        (B,) tensor of timesteps and ``x0`` has shape (B, ...). Fresh Gaussian
        noise is drawn when ``noise`` is None.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        # Reshape the per-sample coefficients to (B, 1, ..., 1) so they
        # broadcast against x0 of any rank; for 4-D images this is (B, 1, 1, 1).
        coef_shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_bar_t = self.sqrt_alpha_cumprod[t].view(coef_shape)
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_cumprod[t].view(coef_shape)
        return sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise

    @torch.no_grad()
    def p_sample(self, model, x_t, t):
        """
        Reverse step: p(x_{t-1} | x_t)

        ``t`` is a Python int timestep shared by the whole batch. At t == 0 the
        predicted mean is returned without added noise.
        """
        B = x_t.size(0)
        t_batch = torch.full((B,), t, device=self.device, dtype=torch.long)

        eps_theta = model(x_t, t_batch)

        beta_t = self.betas[t]
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t]
        sqrt_recip_alpha_t = 1.0 / torch.sqrt(self.alphas[t])

        # mu_theta = (x_t - beta_t / sqrt(1 - abar_t) * eps_theta) / sqrt(alpha_t)
        model_mean = sqrt_recip_alpha_t * (
            x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * eps_theta
        )

        if t == 0:
            return model_mean
        posterior_var_t = self.posterior_variance[t]
        noise = torch.randn_like(x_t)
        return model_mean + torch.sqrt(posterior_var_t) * noise

    @torch.no_grad()
    def sample(self, model, image_size, batch_size=8, channels=3):
        """
        Sample x_0 from the model by starting from pure noise.

        Puts ``model`` in eval mode and runs every reverse step from
        t = num_steps - 1 down to 0. ``channels`` (default 3) must match the
        model's input channels. Returns a tensor of shape
        (batch_size, channels, image_size, image_size).
        """
        model.eval()
        x = torch.randn(batch_size, channels, image_size, image_size, device=self.device)

        for t in reversed(range(self.num_steps)):
            x = self.p_sample(model, x, t)

        return x

# -----------------------------
# 6. Instantiate Model & Diffusion
# -----------------------------
# Model, noise schedule and optimiser for this sweep point.
time_dim = 256
model = ShoeUNet(in_channels=3, base_channels=16, time_dim=time_dim).to(device)
diffusion = Diffusion(num_steps=1000, beta_start=1e-4, beta_end=0.02, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# -----------------------------
# 7. Training Loop (epsilon-prediction objective)
# -----------------------------
num_epochs = 30  # increase for better visuals

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for x, _ in train_loader:  # labels are dummies for this unconditional model
        x = x.to(device)  # the transform already maps pixels to [-1, 1]
        b = x.size(0)

        # Corrupt each image at its own random timestep ...
        t = diffusion.sample_timesteps(b)
        noise = torch.randn_like(x)
        x_t = diffusion.q_sample(x, t, noise=noise)

        # ... and regress the network output onto the injected noise.
        pred_noise = model(x_t, t)
        loss = F.mse_loss(pred_noise, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss is a batch mean; weight by batch size so the epoch value is a per-sample mean
        running_loss += loss.item() * b

    epoch_loss = running_loss / len(dataset)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss:.6f}")

print("Training complete.")

# -----------------------------
# 8. Sampling & Visualization
# -----------------------------
@torch.no_grad()
def show_samples(model, diffusion, n=8):
    """
    Generate and show shoe samples from the diffusion model.

    Draws ``n`` images with ``diffusion.sample`` at the module-level
    ``image_size``, maps them from [-1, 1] to [0, 1], and displays them in a
    grid with at most 4 images per row.

    Raises:
        ValueError: if ``n`` is not a positive integer.
    """
    # Fail fast instead of running the full reverse chain on an empty batch.
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    model.eval()
    samples = diffusion.sample(model, image_size=image_size, batch_size=n)
    samples = (samples.clamp(-1, 1) + 1) / 2.0  # back to [0,1]

    # Size the figure from the grid's actual shape so the square tiles are not
    # letterboxed. A fixed height would ignore how many rows the grid has.
    ncols = min(n, 4)
    nrows = math.ceil(n / ncols)
    grid = utils.make_grid(samples, nrow=ncols)
    plt.figure(figsize=(3 * ncols, 3 * nrows))
    plt.axis("off")
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.show()

print("\nBase Channels = 16 Results:")
show_samples(model, diffusion, n=8)


## Base Channels = 32

This cell implements the UNet architecture with **`base_channels = 32`**.

**Architecture Details:**

With `base_channels = 32`, the channel progression through the UNet encoder is:
- Initial layer: $32$ channels
- Encoder layer 1: $32 \rightarrow 64$ channels
- Encoder layer 2: $64 \rightarrow 128$ channels  
- Encoder layer 3: $128 \rightarrow 256$ channels
- Bottleneck: $256$ channels
- Decoder layers: $256 \rightarrow 128 \rightarrow 64 \rightarrow 32$ channels

**Model Characteristics:**
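- **Total Parameters:** Medium model size (roughly 8 million parameters by estimate). Convolution weights grow with the square of `base_channels`, so doubling it roughly quadruples this part of the model. Confirm with `sum(p.numel() for p in model.parameters())`.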
- **Memory Usage:** Moderate memory footprint
- **Training Speed:** Moderate training and inference speed
- **Representation Capacity:** Good - suitable for most image generation tasks

This configuration offers a solid balance between model capacity and computational efficiency.


In [None]:
# ============================================================
# UT Zappos50K Shoe Diffusion Model - base_channels = 32
# ============================================================

import os
import math
import glob
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt

# -----------------------------
# 0. Config & Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -----------------------------
# 1. Dataset Loading (EXACT REPLICATION)
# -----------------------------
import kagglehub

# Download (or reuse the cached copy of) the latest dataset version.
# dataset_download returns the local directory it used, so build every path
# from that value instead of a hard-coded cache location. A hard-coded path
# breaks for non-root users, custom cache dirs, and newer dataset versions.
path = kagglehub.dataset_download("aryashah2k/large-shoe-dataset-ut-zappos50k")

BASE_DIR = path
IMAGE_ROOT = os.path.join(BASE_DIR, "ut-zap50k-images-square")  # we'll use the square images

if not os.path.exists(IMAGE_ROOT):
    raise FileNotFoundError(f"IMAGE_ROOT '{IMAGE_ROOT}' does not exist. Check your base path.")

print("Image root:", IMAGE_ROOT)
print("Contents:", os.listdir(IMAGE_ROOT)[:10])

# Collect all image paths recursively from IMAGE_ROOT
extensions = ("*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tif", "*.tiff", "*.webp")
image_paths = []
for ext in extensions:
    image_paths.extend(glob.glob(os.path.join(IMAGE_ROOT, "**", ext), recursive=True))
# glob order is filesystem-dependent; sort so dataset index -> file is reproducible
image_paths.sort()

if len(image_paths) == 0:
    raise RuntimeError(f"No image files found under {IMAGE_ROOT} with extensions {extensions}.")

print(f"Found {len(image_paths)} image files.")

# Transform: resize & normalize (NO cropping, just warp to 64x64)
image_size = 64
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),                     # [0,1]
    transforms.Normalize([0.5]*3, [0.5]*3),    # [-1,1]
])

class ZapposImageDataset(Dataset):
    """Unconditional image dataset over a list of file paths.

    Each item is ``(image, 0)``. The image is loaded as RGB and passed through
    ``transform`` when one is given. The constant 0 is a dummy label, so items
    keep the usual ``(input, target)`` layout.

    Args:
        paths: sequence of image file paths.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        from PIL import Image

        path = self.paths[idx]
        # Open inside a context manager so the file handle is released right
        # away instead of whenever the PIL object is garbage-collected.
        # convert() returns a new, fully loaded image that outlives the `with`.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        # unconditional -> we don't care about labels, return dummy
        return img, 0

dataset = ZapposImageDataset(image_paths, transform=transform)

# Loader hyperparameters, kept identical across every base_channels run.
batch_size = 64
num_workers = 2

# Reshuffle each epoch. pin_memory=True asks the loader to return page-locked
# tensors, which can speed up host-to-GPU copies.
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=True)

print("Dataset and DataLoader ready.")

# -----------------------------
# 2. Time Embedding (Sinusoidal)
# -----------------------------
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar timesteps (Transformer positional style).

    With ``half = dim // 2`` frequencies
    ``w_i = 10000 ** (-i / max(half - 1, 1))`` for i = 0 .. half-1, the
    embedding is ``[sin(t * w), cos(t * w)]``, padded with one zero column
    when ``dim`` is odd.

    Args:
        dim: output embedding size.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """
        t: (B,) int or float timesteps
        returns: (B, dim)
        """
        device = t.device
        half_dim = self.dim // 2
        # Geometric frequency spacing from 1 down to 1/10000. The max() guard
        # keeps dim in {2, 3} (a single frequency, w_0 = 1) from dividing by
        # zero; for dim >= 4 it is the usual log(10000) / (half_dim - 1).
        emb_factor = math.log(10000) / max(half_dim - 1, 1)
        exponents = torch.exp(torch.arange(half_dim, device=device) * -emb_factor)
        t = t.float().unsqueeze(1)  # (B,1)
        angles = t * exponents[None, :]  # (B, half_dim)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))  # odd dim: append one zero column
        return emb

# -----------------------------
# 3. UNet Building Blocks
# -----------------------------
class ResidualBlock(nn.Module):
    """Two 3x3 convs with GroupNorm + SiLU, time conditioning and a residual.

    The time embedding is projected to ``out_channels`` and added (broadcast
    over H and W) after the first convolution. The input is added back at the
    end, through a 1x1 conv when the channel count changes.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        time_dim: size of the time embedding passed to ``forward``.
        num_groups: requested GroupNorm group count (default 8). The effective
            count is ``gcd(num_groups, out_channels)``, so any ``out_channels``
            is accepted; when ``out_channels`` is divisible by ``num_groups``
            it is exactly ``num_groups``.
    """

    def __init__(self, in_channels, out_channels, time_dim, num_groups=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        self.time_mlp = nn.Linear(time_dim, out_channels)

        # GroupNorm requires out_channels % groups == 0; the gcd guarantees it.
        groups = math.gcd(num_groups, out_channels)
        self.norm1 = nn.GroupNorm(groups, out_channels)
        self.norm2 = nn.GroupNorm(groups, out_channels)
        self.act = nn.SiLU()

        if in_channels != out_channels:
            self.res_conv = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, t_emb):
        """
        x: (B, C, H, W)
        t_emb: (B, time_dim)
        returns: (B, out_channels, H, W)
        """
        h = self.conv1(x)
        # Per-channel time bias, broadcast over the spatial dimensions.
        h = h + self.time_mlp(t_emb)[:, :, None, None]
        h = self.act(self.norm1(h))

        h = self.conv2(h)
        h = self.act(self.norm2(h))

        return h + self.res_conv(x)


class DownBlock(nn.Module):
    """Encoder stage: two residual blocks followed by a strided-conv downsample.

    ``forward`` returns ``(downsampled, skip)``, where ``skip`` holds the
    features just before downsampling, for use by the decoder.
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.res1 = ResidualBlock(in_channels, out_channels, time_dim)
        self.res2 = ResidualBlock(out_channels, out_channels, time_dim)
        # kernel 4, stride 2, padding 1: halves H and W for even sizes
        self.down = nn.Conv2d(out_channels, out_channels, 4, stride=2, padding=1)

    def forward(self, x, t_emb):
        for block in (self.res1, self.res2):
            x = block(x, t_emb)
        return self.down(x), x


class UpBlock(nn.Module):
    """Decoder stage: transposed-conv upsample, skip concat, two residual blocks."""

    def __init__(self, in_channels, out_channels, skip_channels, time_dim):
        """
        in_channels: from below
        out_channels: after upsample
        skip_channels: from skip connection
        """
        super().__init__()
        # kernel 4, stride 2, padding 1: doubles H and W
        self.up = nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1)
        self.res1 = ResidualBlock(out_channels + skip_channels, out_channels, time_dim)
        self.res2 = ResidualBlock(out_channels, out_channels, time_dim)

    def forward(self, x, skip, t_emb):
        x = self.up(x)
        # Match the skip tensor's spatial size before concatenating. This is a
        # no-op for the 64x64 -> 8x8 pyramid, but for other input sizes height
        # and width can differ independently, so each gets its own pad
        # (a negative amount crops).
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h != 0 or diff_w != 0:
            x = F.pad(
                x,
                (diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2),
            )
        x = torch.cat([x, skip], dim=1)
        x = self.res1(x, t_emb)
        x = self.res2(x, t_emb)
        return x

# -----------------------------
# 4. UNet for 64x64 RGB Shoes (Unconditional)
# -----------------------------
class ShoeUNet(nn.Module):
    """Unconditional noise-prediction UNet for RGB images (built for 64x64).

    Three encoder stages halve the resolution (64 -> 32 -> 16 -> 8) while the
    width grows to ``base_channels * (2, 4, 8)``. A residual bottleneck
    follows, then three decoder stages mirror the encoder and consume the
    encoder skip features. Every residual block is conditioned on a learned
    embedding of the timestep.

    Args:
        in_channels: image channels (also the number of predicted channels).
        base_channels: width of the initial convolution; stage widths scale
            from it.
        time_dim: size of the timestep embedding.
    """

    def __init__(self, in_channels=3, base_channels=64, time_dim=256):
        super().__init__()
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.time_dim = time_dim

        # Timestep -> sinusoidal features -> 2-layer MLP with a 4x hidden width.
        hidden = time_dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, time_dim),
        )

        self.init_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Stage widths [b, 2b, 4b, 8b] with b = base_channels.
        c0, c1, c2, c3 = (base_channels * k for k in (1, 2, 4, 8))

        # Submodule creation order fixes parameter order (and seeded init),
        # so the encoder/bottleneck/decoder are built strictly in sequence.
        self.down1 = DownBlock(c0, c1, time_dim)
        self.down2 = DownBlock(c1, c2, time_dim)
        self.down3 = DownBlock(c2, c3, time_dim)

        self.bottleneck = ResidualBlock(c3, c3, time_dim)

        # Each decoder stage upsamples and fuses the matching encoder skip.
        self.up3 = UpBlock(c3, c2, skip_channels=c3, time_dim=time_dim)
        self.up2 = UpBlock(c2, c1, skip_channels=c2, time_dim=time_dim)
        self.up1 = UpBlock(c1, c0, skip_channels=c1, time_dim=time_dim)

        # 1x1 projection back to image channels: the predicted noise epsilon.
        self.out_conv = nn.Conv2d(base_channels, in_channels, 1)

    def forward(self, x, t):
        """Predict the noise in ``x`` (B, 3, H, W) at timesteps ``t`` (B,)."""
        t_emb = self.time_mlp(t)
        h = self.init_conv(x)

        skips = []
        for stage in (self.down1, self.down2, self.down3):  # 64 -> 32 -> 16 -> 8
            h, skip = stage(h, t_emb)
            skips.append(skip)

        h = self.bottleneck(h, t_emb)

        for stage in (self.up3, self.up2, self.up1):  # 8 -> 16 -> 32 -> 64
            h = stage(h, skips.pop(), t_emb)

        return self.out_conv(h)

# -----------------------------
# 5. DDPM Diffusion Utilities
# -----------------------------
class Diffusion:
    """DDPM linear noise schedule with forward (q) and ancestral reverse (p) sampling.

    All schedule tensors are created on ``device``. The model passed to the
    sampling methods is called as ``model(x_t, t_batch)``, and its output is
    used as the predicted noise epsilon.

    Args:
        num_steps: number of diffusion steps T.
        beta_start, beta_end: endpoints of the linear beta schedule.
        device: device on which the schedule tensors (and samples) live.
    """

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, device="cpu"):
        self.device = device
        self.num_steps = num_steps

        self.betas = torch.linspace(beta_start, beta_end, num_steps, device=device)
        self.alphas = 1.0 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)  # alpha_bar_t
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1
        self.alpha_cumprod_prev = torch.cat(
            [torch.tensor([1.0], device=device), self.alpha_cumprod[:-1]], dim=0
        )

        self.sqrt_alphas = torch.sqrt(self.alphas)
        self.sqrt_alpha_cumprod = torch.sqrt(self.alpha_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alpha_cumprod)

        # Variance of q(x_{t-1} | x_t, x_0); exactly 0 at t = 0.
        self.posterior_variance = (
            self.betas
            * (1.0 - self.alpha_cumprod_prev)
            / (1.0 - self.alpha_cumprod)
        )
        # Clamped before the log so the t = 0 entry stays finite.
        self.posterior_log_variance_clipped = torch.log(
            torch.clamp(self.posterior_variance, min=1e-20)
        )

    def sample_timesteps(self, batch_size):
        """Return ``batch_size`` timesteps drawn uniformly from [0, num_steps)."""
        return torch.randint(0, self.num_steps, (batch_size,), device=self.device).long()

    def q_sample(self, x0, t, noise=None):
        """
        Forward diffusion: q(x_t | x_0)

        x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise, with one
        timestep per batch element (``t`` indexes the schedule; coefficients
        are reshaped to (B, 1, 1, 1) to broadcast over a 4-D ``x0``). Fresh
        Gaussian noise is drawn when ``noise`` is None.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        sqrt_alpha_bar_t = self.sqrt_alpha_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
        return sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise

    @torch.no_grad()
    def p_sample(self, model, x_t, t):
        """
        Reverse step: p(x_{t-1} | x_t)

        ``t`` is a single integer timestep shared by the whole batch. At
        t == 0 the model mean is returned without adding noise.
        """
        B = x_t.size(0)
        t_batch = torch.full((B,), t, device=self.device, dtype=torch.long)

        eps_theta = model(x_t, t_batch)

        beta_t = self.betas[t]
        alpha_t = self.alphas[t]
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t]
        sqrt_recip_alpha_t = 1.0 / torch.sqrt(alpha_t)

        # mu = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps)
        model_mean = sqrt_recip_alpha_t * (
            x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * eps_theta
        )

        if t == 0:
            return model_mean
        posterior_var_t = self.posterior_variance[t]
        noise = torch.randn_like(x_t)
        return model_mean + torch.sqrt(posterior_var_t) * noise

    @torch.no_grad()
    def sample(self, model, image_size, batch_size=8, channels=3):
        """
        Sample x_0 from the model by starting from pure noise.

        Draws Gaussian noise of shape (batch_size, channels, image_size,
        image_size) and applies ``p_sample`` for t = num_steps-1 down to 0.
        The model runs in eval mode; its previous train/eval mode is restored
        afterwards so sampling has no lasting side effect on the model.
        """
        was_training = model.training
        model.eval()
        try:
            x = torch.randn(batch_size, channels, image_size, image_size, device=self.device)
            for t in reversed(range(self.num_steps)):
                x = self.p_sample(model, x, t)
        finally:
            model.train(was_training)
        return x

# -----------------------------
# 6. Instantiate Model & Diffusion
# -----------------------------
# Model for this cell's experiment: base_channels=32, so the encoder stage
# widths are 32 -> 64 -> 128 -> 256.
time_dim = 256
model = ShoeUNet(
    in_channels=3,
    base_channels=32,
    time_dim=time_dim
).to(device)

# Linear beta schedule from 1e-4 to 0.02 over 1000 diffusion steps.
diffusion = Diffusion(
    num_steps=1000,
    beta_start=1e-4,
    beta_end=0.02,
    device=device
)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# -----------------------------
# 7. Training Loop
# -----------------------------
num_epochs = 30  # increase for better visuals

for epoch in range(num_epochs):
    model.train()  # ensure training mode at the start of every epoch
    running_loss = 0.0

    for x, _ in train_loader:   # ignore labels
        x = x.to(device)        # already normalized to [-1,1]

        b = x.size(0)
        # One uniformly random diffusion timestep per image in the batch.
        t = diffusion.sample_timesteps(b)

        # Forward-diffuse x to x_t with known noise; the model must recover it.
        noise = torch.randn_like(x)
        x_t = diffusion.q_sample(x, t, noise=noise)

        pred_noise = model(x_t, t)
        loss = F.mse_loss(pred_noise, noise)  # epsilon-prediction MSE objective

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is per-sample, even when
        # the last batch is smaller.
        running_loss += loss.item() * b

    epoch_loss = running_loss / len(dataset)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss:.6f}")

print("Training complete.")

# -----------------------------
# 8. Sampling & Visualization
# -----------------------------
@torch.no_grad()
def show_samples(model, diffusion, n=8):
    """Sample ``n`` shoe images from the diffusion model and plot them.

    Images are generated at the notebook-level ``image_size``, mapped from
    [-1, 1] back to [0, 1], tiled into a grid with at most 4 images per row,
    and displayed with matplotlib.
    """
    model.eval()
    imgs = diffusion.sample(model, image_size=image_size, batch_size=n)
    imgs = imgs.clamp(-1, 1).add(1).div(2.0)  # [-1, 1] -> [0, 1]

    per_row = n if n < 4 else 4
    grid = utils.make_grid(imgs, nrow=per_row)

    width = 4 * (n // 4 + 1)
    plt.figure(figsize=(width, 4))
    plt.axis("off")
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.show()

# Draw 8 samples from the model trained above and display them.
print("\nBase Channels = 32 Results:")
show_samples(model, diffusion, n=8)


## Base Channels = 64

This cell implements the UNet architecture with **`base_channels = 64`** (the baseline configuration from the original model).

**Architecture Details:**

With `base_channels = 64`, the channel progression through the UNet encoder is:
- Initial layer: $64$ channels
- Encoder layer 1: $64 \rightarrow 128$ channels
- Encoder layer 2: $128 \rightarrow 256$ channels  
- Encoder layer 3: $256 \rightarrow 512$ channels
- Bottleneck: $512$ channels
- Decoder layers: $512 \rightarrow 256 \rightarrow 128 \rightarrow 64$ channels

**Model Characteristics:**
- **Total Parameters:** Large model size (~30 million parameters)
- **Memory Usage:** High memory footprint
- **Training Speed:** Slower training and inference
- **Representation Capacity:** Excellent - strong capacity for complex patterns

This is the baseline configuration used in the original model, providing a good balance between quality and computational cost for most applications.


In [None]:
# ============================================================
# UT Zappos50K Shoe Diffusion Model - base_channels = 64
# ============================================================

import os
import math
import glob
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt

# -----------------------------
# 0. Config & Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -----------------------------
# 1. Dataset Loading (EXACT REPLICATION)
# -----------------------------
import kagglehub

# Download latest version. dataset_download returns the local directory of the
# version it fetched, so BASE_DIR is taken from it. Hard-coding the default
# cache path breaks for non-root users, for a custom kagglehub cache directory,
# or once a version other than 1 is the latest.
path = kagglehub.dataset_download("aryashah2k/large-shoe-dataset-ut-zappos50k")

BASE_DIR = path
IMAGE_ROOT = os.path.join(BASE_DIR, "ut-zap50k-images-square")  # we'll use the square images

if not os.path.exists(IMAGE_ROOT):
    raise FileNotFoundError(f"IMAGE_ROOT '{IMAGE_ROOT}' does not exist. Check your base path.")

print("Image root:", IMAGE_ROOT)
print("Contents:", os.listdir(IMAGE_ROOT)[:10])

# Collect all image paths recursively from IMAGE_ROOT
extensions = ("*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tif", "*.tiff", "*.webp")
image_paths = []
for ext in extensions:
    image_paths.extend(glob.glob(os.path.join(IMAGE_ROOT, "**", ext), recursive=True))

if len(image_paths) == 0:
    raise RuntimeError(f"No image files found under {IMAGE_ROOT} with extensions {extensions}.")

print(f"Found {len(image_paths)} image files.")

# Transform: resize & normalize (NO cropping, just warp to 64x64)
image_size = 64
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),                     # [0,1]
    transforms.Normalize([0.5]*3, [0.5]*3),    # [-1,1]
])

class ZapposImageDataset(Dataset):
    """Unconditional dataset over a list of image file paths.

    Each item is ``(image, 0)``: the file is decoded as RGB, passed through
    ``transform`` when one is given, and paired with a dummy label because the
    diffusion model is trained without class conditioning.

    Args:
        paths: indexable sequence of image file paths.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # PIL is not imported at the top of the cell, so import it here.
        from PIL import Image

        path = self.paths[idx]
        # Open with a context manager so the file handle is released right
        # away; convert() returns a new, fully loaded image that remains valid
        # after the source image is closed.
        with Image.open(path) as src:
            img = src.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        # unconditional -> we don't care about labels, return dummy
        return img, 0

dataset = ZapposImageDataset(image_paths, transform=transform)

batch_size = 64
num_workers = 2

# Pinned host memory only speeds up host -> CUDA copies. Without a GPU it has
# no benefit (recent PyTorch versions just emit a warning), so enable it only
# when CUDA is available.
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
)

print("Dataset and DataLoader ready.")

# -----------------------------
# 2. Time Embedding (Sinusoidal)
# -----------------------------
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar timesteps (Transformer positional style).

    With ``half = dim // 2`` frequencies
    ``w_i = 10000 ** (-i / max(half - 1, 1))`` for i = 0 .. half-1, the
    embedding is ``[sin(t * w), cos(t * w)]``, padded with one zero column
    when ``dim`` is odd.

    Args:
        dim: output embedding size.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """
        t: (B,) int or float timesteps
        returns: (B, dim)
        """
        device = t.device
        half_dim = self.dim // 2
        # Geometric frequency spacing from 1 down to 1/10000. The max() guard
        # keeps dim in {2, 3} (a single frequency, w_0 = 1) from dividing by
        # zero; for dim >= 4 it is the usual log(10000) / (half_dim - 1).
        emb_factor = math.log(10000) / max(half_dim - 1, 1)
        exponents = torch.exp(torch.arange(half_dim, device=device) * -emb_factor)
        t = t.float().unsqueeze(1)  # (B,1)
        angles = t * exponents[None, :]  # (B, half_dim)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))  # odd dim: append one zero column
        return emb

# -----------------------------
# 3. UNet Building Blocks
# -----------------------------
class ResidualBlock(nn.Module):
    """Two 3x3 convs with GroupNorm + SiLU, time conditioning and a residual.

    The time embedding is projected to ``out_channels`` and added (broadcast
    over H and W) after the first convolution. The input is added back at the
    end, through a 1x1 conv when the channel count changes.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        time_dim: size of the time embedding passed to ``forward``.
        num_groups: requested GroupNorm group count (default 8). The effective
            count is ``gcd(num_groups, out_channels)``, so any ``out_channels``
            is accepted; when ``out_channels`` is divisible by ``num_groups``
            it is exactly ``num_groups``.
    """

    def __init__(self, in_channels, out_channels, time_dim, num_groups=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        self.time_mlp = nn.Linear(time_dim, out_channels)

        # GroupNorm requires out_channels % groups == 0; the gcd guarantees it.
        groups = math.gcd(num_groups, out_channels)
        self.norm1 = nn.GroupNorm(groups, out_channels)
        self.norm2 = nn.GroupNorm(groups, out_channels)
        self.act = nn.SiLU()

        if in_channels != out_channels:
            self.res_conv = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, t_emb):
        """
        x: (B, C, H, W)
        t_emb: (B, time_dim)
        returns: (B, out_channels, H, W)
        """
        h = self.conv1(x)
        # Per-channel time bias, broadcast over the spatial dimensions.
        h = h + self.time_mlp(t_emb)[:, :, None, None]
        h = self.act(self.norm1(h))

        h = self.conv2(h)
        h = self.act(self.norm2(h))

        return h + self.res_conv(x)


class DownBlock(nn.Module):
    """Encoder stage: two residual blocks followed by a strided-conv downsample.

    ``forward`` returns ``(downsampled, skip)``, where ``skip`` holds the
    features just before downsampling, for use by the decoder.
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.res1 = ResidualBlock(in_channels, out_channels, time_dim)
        self.res2 = ResidualBlock(out_channels, out_channels, time_dim)
        # kernel 4, stride 2, padding 1: halves H and W for even sizes
        self.down = nn.Conv2d(out_channels, out_channels, 4, stride=2, padding=1)

    def forward(self, x, t_emb):
        for block in (self.res1, self.res2):
            x = block(x, t_emb)
        return self.down(x), x


class UpBlock(nn.Module):
    """Decoder stage: transposed-conv upsample, skip concat, two residual blocks."""

    def __init__(self, in_channels, out_channels, skip_channels, time_dim):
        """
        in_channels: from below
        out_channels: after upsample
        skip_channels: from skip connection
        """
        super().__init__()
        # kernel 4, stride 2, padding 1: doubles H and W
        self.up = nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1)
        self.res1 = ResidualBlock(out_channels + skip_channels, out_channels, time_dim)
        self.res2 = ResidualBlock(out_channels, out_channels, time_dim)

    def forward(self, x, skip, t_emb):
        x = self.up(x)
        # Match the skip tensor's spatial size before concatenating. This is a
        # no-op for the 64x64 -> 8x8 pyramid, but for other input sizes height
        # and width can differ independently, so each gets its own pad
        # (a negative amount crops).
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h != 0 or diff_w != 0:
            x = F.pad(
                x,
                (diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2),
            )
        x = torch.cat([x, skip], dim=1)
        x = self.res1(x, t_emb)
        x = self.res2(x, t_emb)
        return x

# -----------------------------
# 4. UNet for 64x64 RGB Shoes (Unconditional)
# -----------------------------
class ShoeUNet(nn.Module):
    """Unconditional noise-prediction UNet for RGB images (built for 64x64).

    Three encoder stages halve the resolution (64 -> 32 -> 16 -> 8) while the
    width grows to ``base_channels * (2, 4, 8)``. A residual bottleneck
    follows, then three decoder stages mirror the encoder and consume the
    encoder skip features. Every residual block is conditioned on a learned
    embedding of the timestep.

    Args:
        in_channels: image channels (also the number of predicted channels).
        base_channels: width of the initial convolution; stage widths scale
            from it.
        time_dim: size of the timestep embedding.
    """

    def __init__(self, in_channels=3, base_channels=64, time_dim=256):
        super().__init__()
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.time_dim = time_dim

        # Timestep -> sinusoidal features -> 2-layer MLP with a 4x hidden width.
        hidden = time_dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, time_dim),
        )

        self.init_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Stage widths [b, 2b, 4b, 8b] with b = base_channels.
        c0, c1, c2, c3 = (base_channels * k for k in (1, 2, 4, 8))

        # Submodule creation order fixes parameter order (and seeded init),
        # so the encoder/bottleneck/decoder are built strictly in sequence.
        self.down1 = DownBlock(c0, c1, time_dim)
        self.down2 = DownBlock(c1, c2, time_dim)
        self.down3 = DownBlock(c2, c3, time_dim)

        self.bottleneck = ResidualBlock(c3, c3, time_dim)

        # Each decoder stage upsamples and fuses the matching encoder skip.
        self.up3 = UpBlock(c3, c2, skip_channels=c3, time_dim=time_dim)
        self.up2 = UpBlock(c2, c1, skip_channels=c2, time_dim=time_dim)
        self.up1 = UpBlock(c1, c0, skip_channels=c1, time_dim=time_dim)

        # 1x1 projection back to image channels: the predicted noise epsilon.
        self.out_conv = nn.Conv2d(base_channels, in_channels, 1)

    def forward(self, x, t):
        """Predict the noise in ``x`` (B, 3, H, W) at timesteps ``t`` (B,)."""
        t_emb = self.time_mlp(t)
        h = self.init_conv(x)

        skips = []
        for stage in (self.down1, self.down2, self.down3):  # 64 -> 32 -> 16 -> 8
            h, skip = stage(h, t_emb)
            skips.append(skip)

        h = self.bottleneck(h, t_emb)

        for stage in (self.up3, self.up2, self.up1):  # 8 -> 16 -> 32 -> 64
            h = stage(h, skips.pop(), t_emb)

        return self.out_conv(h)

# -----------------------------
# 5. DDPM Diffusion Utilities
# -----------------------------
class Diffusion:
    """DDPM linear noise schedule with forward (q) and ancestral reverse (p) sampling.

    All schedule tensors are created on ``device``. The model passed to the
    sampling methods is called as ``model(x_t, t_batch)``, and its output is
    used as the predicted noise epsilon.

    Args:
        num_steps: number of diffusion steps T.
        beta_start, beta_end: endpoints of the linear beta schedule.
        device: device on which the schedule tensors (and samples) live.
    """

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, device="cpu"):
        self.device = device
        self.num_steps = num_steps

        self.betas = torch.linspace(beta_start, beta_end, num_steps, device=device)
        self.alphas = 1.0 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)  # alpha_bar_t
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1
        self.alpha_cumprod_prev = torch.cat(
            [torch.tensor([1.0], device=device), self.alpha_cumprod[:-1]], dim=0
        )

        self.sqrt_alphas = torch.sqrt(self.alphas)
        self.sqrt_alpha_cumprod = torch.sqrt(self.alpha_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alpha_cumprod)

        # Variance of q(x_{t-1} | x_t, x_0); exactly 0 at t = 0.
        self.posterior_variance = (
            self.betas
            * (1.0 - self.alpha_cumprod_prev)
            / (1.0 - self.alpha_cumprod)
        )
        # Clamped before the log so the t = 0 entry stays finite.
        self.posterior_log_variance_clipped = torch.log(
            torch.clamp(self.posterior_variance, min=1e-20)
        )

    def sample_timesteps(self, batch_size):
        """Return ``batch_size`` timesteps drawn uniformly from [0, num_steps)."""
        return torch.randint(0, self.num_steps, (batch_size,), device=self.device).long()

    def q_sample(self, x0, t, noise=None):
        """
        Forward diffusion: q(x_t | x_0)

        x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise, with one
        timestep per batch element (``t`` indexes the schedule; coefficients
        are reshaped to (B, 1, 1, 1) to broadcast over a 4-D ``x0``). Fresh
        Gaussian noise is drawn when ``noise`` is None.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        sqrt_alpha_bar_t = self.sqrt_alpha_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
        return sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise

    @torch.no_grad()
    def p_sample(self, model, x_t, t):
        """
        Reverse step: p(x_{t-1} | x_t)

        ``t`` is a single integer timestep shared by the whole batch. At
        t == 0 the model mean is returned without adding noise.
        """
        B = x_t.size(0)
        t_batch = torch.full((B,), t, device=self.device, dtype=torch.long)

        eps_theta = model(x_t, t_batch)

        beta_t = self.betas[t]
        alpha_t = self.alphas[t]
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t]
        sqrt_recip_alpha_t = 1.0 / torch.sqrt(alpha_t)

        # mu = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps)
        model_mean = sqrt_recip_alpha_t * (
            x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * eps_theta
        )

        if t == 0:
            return model_mean
        posterior_var_t = self.posterior_variance[t]
        noise = torch.randn_like(x_t)
        return model_mean + torch.sqrt(posterior_var_t) * noise

    @torch.no_grad()
    def sample(self, model, image_size, batch_size=8, channels=3):
        """
        Sample x_0 from the model by starting from pure noise.

        Draws Gaussian noise of shape (batch_size, channels, image_size,
        image_size) and applies ``p_sample`` for t = num_steps-1 down to 0.
        The model runs in eval mode; its previous train/eval mode is restored
        afterwards so sampling has no lasting side effect on the model.
        """
        was_training = model.training
        model.eval()
        try:
            x = torch.randn(batch_size, channels, image_size, image_size, device=self.device)
            for t in reversed(range(self.num_steps)):
                x = self.p_sample(model, x, t)
        finally:
            model.train(was_training)
        return x

# -----------------------------
# 6. Instantiate Model & Diffusion
# -----------------------------
# Model for this cell's experiment: base_channels=64, so the encoder stage
# widths are 64 -> 128 -> 256 -> 512.
time_dim = 256
model = ShoeUNet(
    in_channels=3,
    base_channels=64,
    time_dim=time_dim
).to(device)

# Linear beta schedule from 1e-4 to 0.02 over 1000 diffusion steps.
diffusion = Diffusion(
    num_steps=1000,
    beta_start=1e-4,
    beta_end=0.02,
    device=device
)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# -----------------------------
# 7. Training Loop
# -----------------------------
num_epochs = 30  # increase for better visuals

for epoch in range(num_epochs):
    model.train()  # ensure training mode at the start of every epoch
    running_loss = 0.0

    for x, _ in train_loader:   # ignore labels
        x = x.to(device)        # already normalized to [-1,1]

        b = x.size(0)
        # One uniformly random diffusion timestep per image in the batch.
        t = diffusion.sample_timesteps(b)

        # Forward-diffuse x to x_t with known noise; the model must recover it.
        noise = torch.randn_like(x)
        x_t = diffusion.q_sample(x, t, noise=noise)

        pred_noise = model(x_t, t)
        loss = F.mse_loss(pred_noise, noise)  # epsilon-prediction MSE objective

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is per-sample, even when
        # the last batch is smaller.
        running_loss += loss.item() * b

    epoch_loss = running_loss / len(dataset)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss:.6f}")

print("Training complete.")

# -----------------------------
# 8. Sampling & Visualization
# -----------------------------
@torch.no_grad()
def show_samples(model, diffusion, n=8):
    """Sample ``n`` shoe images from the diffusion model and plot them.

    Images are generated at the notebook-level ``image_size``, mapped from
    [-1, 1] back to [0, 1], tiled into a grid with at most 4 images per row,
    and displayed with matplotlib.
    """
    model.eval()
    imgs = diffusion.sample(model, image_size=image_size, batch_size=n)
    imgs = imgs.clamp(-1, 1).add(1).div(2.0)  # [-1, 1] -> [0, 1]

    per_row = n if n < 4 else 4
    grid = utils.make_grid(imgs, nrow=per_row)

    width = 4 * (n // 4 + 1)
    plt.figure(figsize=(width, 4))
    plt.axis("off")
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.show()

# Draw 8 samples from the model trained above and display them.
print("\nBase Channels = 64 Results:")
show_samples(model, diffusion, n=8)


## Base Channels = 96

This cell implements the UNet architecture with **`base_channels = 96`**.

**Architecture Details:**

With `base_channels = 96`, the channel progression through the UNet encoder is:
- Initial layer: $96$ channels
- Encoder layer 1: $96 \rightarrow 192$ channels
- Encoder layer 2: $192 \rightarrow 384$ channels  
- Encoder layer 3: $384 \rightarrow 768$ channels
- Bottleneck: $768$ channels
- Decoder layers: $768 \rightarrow 384 \rightarrow 192 \rightarrow 96$ channels

**Model Characteristics:**
- **Total Parameters:** Very large model size (~67 million parameters)
- **Memory Usage:** Very high memory footprint
- **Training Speed:** Slow training and inference
- **Representation Capacity:** Excellent - high capacity for capturing fine details

This configuration provides increased model capacity beyond the baseline, potentially improving generation quality at the cost of computational resources.


In [None]:
# ============================================================
# UT Zappos50K Shoe Diffusion Model - base_channels = 96
# ============================================================

import os
import math
import glob
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt

# -----------------------------
# 0. Config & Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -----------------------------
# 1. Dataset Loading (EXACT REPLICATION)
# -----------------------------
import kagglehub

# Download latest version. dataset_download returns the local directory of the
# version it fetched, so BASE_DIR is taken from it. Hard-coding the default
# cache path breaks for non-root users, for a custom kagglehub cache directory,
# or once a version other than 1 is the latest.
path = kagglehub.dataset_download("aryashah2k/large-shoe-dataset-ut-zappos50k")

BASE_DIR = path
IMAGE_ROOT = os.path.join(BASE_DIR, "ut-zap50k-images-square")  # we'll use the square images

if not os.path.exists(IMAGE_ROOT):
    raise FileNotFoundError(f"IMAGE_ROOT '{IMAGE_ROOT}' does not exist. Check your base path.")

print("Image root:", IMAGE_ROOT)
print("Contents:", os.listdir(IMAGE_ROOT)[:10])

# Collect all image paths recursively from IMAGE_ROOT
extensions = ("*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tif", "*.tiff", "*.webp")
image_paths = []
for ext in extensions:
    image_paths.extend(glob.glob(os.path.join(IMAGE_ROOT, "**", ext), recursive=True))

if len(image_paths) == 0:
    raise RuntimeError(f"No image files found under {IMAGE_ROOT} with extensions {extensions}.")

print(f"Found {len(image_paths)} image files.")

# Transform: resize & normalize (NO cropping, just warp to 64x64)
image_size = 64
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),                     # [0,1]
    transforms.Normalize([0.5]*3, [0.5]*3),    # [-1,1]
])

class ZapposImageDataset(Dataset):
    """Unconditional dataset over a list of image file paths.

    Each item is ``(image, 0)``: the file is decoded as RGB, passed through
    ``transform`` when one is given, and paired with a dummy label because the
    diffusion model is trained without class conditioning.

    Args:
        paths: indexable sequence of image file paths.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # PIL is not imported at the top of the cell, so import it here.
        from PIL import Image

        path = self.paths[idx]
        # Open with a context manager so the file handle is released right
        # away; convert() returns a new, fully loaded image that remains valid
        # after the source image is closed.
        with Image.open(path) as src:
            img = src.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        # unconditional -> we don't care about labels, return dummy
        return img, 0

dataset = ZapposImageDataset(image_paths, transform=transform)

batch_size = 64
num_workers = 2

# Pinned host memory only speeds up host -> CUDA copies. Without a GPU it has
# no benefit (recent PyTorch versions just emit a warning), so enable it only
# when CUDA is available.
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
)

print("Dataset and DataLoader ready.")

# -----------------------------
# 2. Time Embedding (Sinusoidal)
# -----------------------------
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar timesteps (Transformer positional style).

    With ``half = dim // 2`` frequencies
    ``w_i = 10000 ** (-i / max(half - 1, 1))`` for i = 0 .. half-1, the
    embedding is ``[sin(t * w), cos(t * w)]``, padded with one zero column
    when ``dim`` is odd.

    Args:
        dim: output embedding size.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """
        t: (B,) int or float timesteps
        returns: (B, dim)
        """
        device = t.device
        half_dim = self.dim // 2
        # Geometric frequency spacing from 1 down to 1/10000. The max() guard
        # keeps dim in {2, 3} (a single frequency, w_0 = 1) from dividing by
        # zero; for dim >= 4 it is the usual log(10000) / (half_dim - 1).
        emb_factor = math.log(10000) / max(half_dim - 1, 1)
        exponents = torch.exp(torch.arange(half_dim, device=device) * -emb_factor)
        t = t.float().unsqueeze(1)  # (B,1)
        angles = t * exponents[None, :]  # (B, half_dim)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))  # odd dim: append one zero column
        return emb

# -----------------------------
# 3. UNet Building Blocks
# -----------------------------
class ResidualBlock(nn.Module):
    """Two 3x3 convs with GroupNorm + SiLU, time conditioning and a residual.

    The time embedding is projected to ``out_channels`` and added (broadcast
    over H and W) after the first convolution. The input is added back at the
    end, through a 1x1 conv when the channel count changes.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        time_dim: size of the time embedding passed to ``forward``.
        num_groups: requested GroupNorm group count (default 8). The effective
            count is ``gcd(num_groups, out_channels)``, so any ``out_channels``
            is accepted; when ``out_channels`` is divisible by ``num_groups``
            it is exactly ``num_groups``.
    """

    def __init__(self, in_channels, out_channels, time_dim, num_groups=8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        self.time_mlp = nn.Linear(time_dim, out_channels)

        # GroupNorm requires out_channels % groups == 0; the gcd guarantees it.
        groups = math.gcd(num_groups, out_channels)
        self.norm1 = nn.GroupNorm(groups, out_channels)
        self.norm2 = nn.GroupNorm(groups, out_channels)
        self.act = nn.SiLU()

        if in_channels != out_channels:
            self.res_conv = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, t_emb):
        """
        x: (B, C, H, W)
        t_emb: (B, time_dim)
        returns: (B, out_channels, H, W)
        """
        h = self.conv1(x)
        # Per-channel time bias, broadcast over the spatial dimensions.
        h = h + self.time_mlp(t_emb)[:, :, None, None]
        h = self.act(self.norm1(h))

        h = self.conv2(h)
        h = self.act(self.norm2(h))

        return h + self.res_conv(x)


class DownBlock(nn.Module):
    """Encoder stage: two residual blocks followed by a strided-conv downsample.

    ``forward`` returns ``(downsampled, skip)``, where ``skip`` holds the
    features just before downsampling, for use by the decoder.
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.res1 = ResidualBlock(in_channels, out_channels, time_dim)
        self.res2 = ResidualBlock(out_channels, out_channels, time_dim)
        # kernel 4, stride 2, padding 1: halves H and W for even sizes
        self.down = nn.Conv2d(out_channels, out_channels, 4, stride=2, padding=1)

    def forward(self, x, t_emb):
        for block in (self.res1, self.res2):
            x = block(x, t_emb)
        return self.down(x), x


class UpBlock(nn.Module):
    """Decoder stage: transposed-conv upsample, skip concat, two residual blocks."""

    def __init__(self, in_channels, out_channels, skip_channels, time_dim):
        """
        in_channels: from below
        out_channels: after upsample
        skip_channels: from skip connection
        """
        super().__init__()
        # kernel 4, stride 2, padding 1: doubles H and W
        self.up = nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1)
        self.res1 = ResidualBlock(out_channels + skip_channels, out_channels, time_dim)
        self.res2 = ResidualBlock(out_channels, out_channels, time_dim)

    def forward(self, x, skip, t_emb):
        x = self.up(x)
        # Match the skip tensor's spatial size before concatenating. This is a
        # no-op for the 64x64 -> 8x8 pyramid, but for other input sizes height
        # and width can differ independently, so each gets its own pad
        # (a negative amount crops).
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h != 0 or diff_w != 0:
            x = F.pad(
                x,
                (diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2),
            )
        x = torch.cat([x, skip], dim=1)
        x = self.res1(x, t_emb)
        x = self.res2(x, t_emb)
        return x

# -----------------------------
# 4. UNet for 64x64 RGB Shoes (Unconditional)
# -----------------------------
class ShoeUNet(nn.Module):
    """Unconditional UNet noise predictor for RGB shoe images.

    Three encoder stages each double the channel width and halve the spatial
    size, a residual bottleneck follows, and three decoder stages mirror the
    encoder using the saved skip tensors. A sinusoidal + MLP time embedding
    conditions every residual block.
    """

    def __init__(self, in_channels=3, base_channels=64, time_dim=256):
        super().__init__()
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.time_dim = time_dim

        # Timestep t -> (B, time_dim) conditioning vector.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 4),
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim),
        )

        # Lift the image to base_channels feature maps.
        self.init_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Widths per level: base * [1, 2, 4, 8] (e.g. [96, 192, 384, 768] for base 96).
        c0, c1, c2, c3 = (base_channels * mult for mult in (1, 2, 4, 8))

        # Encoder: each stage halves H and W.
        self.down1 = DownBlock(c0, c1, time_dim)
        self.down2 = DownBlock(c1, c2, time_dim)
        self.down3 = DownBlock(c2, c3, time_dim)

        self.bottleneck = ResidualBlock(c3, c3, time_dim)

        # Decoder: each stage consumes the skip saved by the encoder stage at
        # the same resolution; that skip has the encoder stage's output width.
        self.up3 = UpBlock(c3, c2, skip_channels=c3, time_dim=time_dim)
        self.up2 = UpBlock(c2, c1, skip_channels=c2, time_dim=time_dim)
        self.up1 = UpBlock(c1, c0, skip_channels=c1, time_dim=time_dim)

        # 1x1 projection back to image channels: the predicted noise.
        self.out_conv = nn.Conv2d(base_channels, in_channels, 1)

    def forward(self, x, t):
        """Predict noise for images ``x`` (B, C, H, W) at timesteps ``t`` (B,)."""
        emb = self.time_mlp(t)
        h = self.init_conv(x)

        skips = []
        for stage in (self.down1, self.down2, self.down3):
            h, skip = stage(h, emb)
            skips.append(skip)

        h = self.bottleneck(h, emb)

        # Pop skips in reverse order: deepest encoder output meets up3 first.
        for stage in (self.up3, self.up2, self.up1):
            h = stage(h, skips.pop(), emb)

        return self.out_conv(h)

# -----------------------------
# 5. DDPM Diffusion Utilities
# -----------------------------
class Diffusion:
    """DDPM linear-beta noise schedule with forward (q) and ancestral reverse (p) sampling.

    All schedule tensors are created on ``device``.
    """

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, device="cpu"):
        self.device = device
        self.num_steps = num_steps

        self.betas = torch.linspace(beta_start, beta_end, num_steps, device=device)
        self.alphas = 1.0 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1 so the t = 0 posterior variance is 0.
        self.alpha_cumprod_prev = torch.cat(
            [torch.tensor([1.0], device=device), self.alpha_cumprod[:-1]], dim=0
        )

        self.sqrt_alphas = torch.sqrt(self.alphas)
        self.sqrt_alpha_cumprod = torch.sqrt(self.alpha_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alpha_cumprod)

        # Variance of q(x_{t-1} | x_t, x_0): beta_t * (1 - abar_{t-1}) / (1 - abar_t).
        self.posterior_variance = (
            self.betas
            * (1.0 - self.alpha_cumprod_prev)
            / (1.0 - self.alpha_cumprod)
        )
        # Clamped to avoid log(0) at t = 0 (not used by p_sample below).
        self.posterior_log_variance_clipped = torch.log(
            torch.clamp(self.posterior_variance, min=1e-20)
        )

    def sample_timesteps(self, batch_size):
        """Draw ``batch_size`` integer timesteps uniformly from [0, num_steps)."""
        return torch.randint(
            0, self.num_steps, (batch_size,), device=self.device, dtype=torch.long
        )

    def q_sample(self, x0, t, noise=None):
        """
        Forward diffusion: q(x_t | x_0) = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise.

        t: (B,) timestep indices; one schedule value is broadcast per batch item.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        sqrt_alpha_bar_t = self.sqrt_alpha_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
        return sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise

    @torch.no_grad()
    def p_sample(self, model, x_t, t):
        """
        Reverse step: p(x_{t-1} | x_t) for a single integer timestep ``t``.

        The model's noise prediction gives the posterior mean; posterior-variance
        noise is added for every step except t == 0.
        """
        B = x_t.size(0)
        t_batch = torch.full((B,), t, device=self.device, dtype=torch.long)

        eps_theta = model(x_t, t_batch)

        beta_t = self.betas[t]
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t]
        sqrt_recip_alpha_t = 1.0 / torch.sqrt(self.alphas[t])

        model_mean = sqrt_recip_alpha_t * (
            x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * eps_theta
        )

        if t == 0:
            return model_mean
        posterior_var_t = self.posterior_variance[t]
        noise = torch.randn_like(x_t)
        return model_mean + torch.sqrt(posterior_var_t) * noise

    @torch.no_grad()
    def sample(self, model, image_size, batch_size=8, channels=3):
        """
        Sample x_0 from the model by starting from pure noise.

        channels: number of image channels to generate (default 3, RGB).
        The model is run in eval mode and its previous train/eval mode is
        restored afterwards, even if sampling raises.
        """
        was_training = model.training
        model.eval()
        try:
            x = torch.randn(
                batch_size, channels, image_size, image_size, device=self.device
            )
            for t in reversed(range(self.num_steps)):
                x = self.p_sample(model, x, t)
        finally:
            model.train(was_training)
        return x

# -----------------------------
# 6. Instantiate Model & Diffusion
# -----------------------------
time_dim = 256
model = ShoeUNet(
    in_channels=3,
    base_channels=96,
    time_dim=time_dim
).to(device)

diffusion = Diffusion(
    num_steps=1000,
    beta_start=1e-4,
    beta_end=0.02,
    device=device
)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# -----------------------------
# 7. Training Loop
# -----------------------------
num_epochs = 30  # increase for better visuals

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # Count samples actually processed so the mean stays correct even if the
    # loader drops the last batch or uses a sampler that skips items.
    num_seen = 0

    for x, _ in train_loader:   # ignore labels
        x = x.to(device)        # already normalized to [-1,1]

        b = x.size(0)
        # DDPM objective: noise x to a random timestep, regress the added noise.
        t = diffusion.sample_timesteps(b)

        noise = torch.randn_like(x)
        x_t = diffusion.q_sample(x, t, noise=noise)

        pred_noise = model(x_t, t)
        loss = F.mse_loss(pred_noise, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss is a batch mean; weight by batch size for a per-sample epoch mean.
        running_loss += loss.item() * b
        num_seen += b

    epoch_loss = running_loss / max(num_seen, 1)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss:.6f}")

print("Training complete.")

# -----------------------------
# 8. Sampling & Visualization
# -----------------------------
@torch.no_grad()
def show_samples(model, diffusion, n=8):
    """
    Generate ``n`` shoe samples from the diffusion model and display them as a grid.

    Uses the notebook-level ``image_size`` for the sample resolution. Samples are
    mapped from [-1, 1] back to [0, 1] and laid out with at most 4 per row.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")

    model.eval()
    samples = diffusion.sample(model, image_size=image_size, batch_size=n)
    samples = (samples.clamp(-1, 1) + 1) / 2.0  # back to [0,1]

    ncols = min(n, 4)
    nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)
    grid = utils.make_grid(samples, nrow=ncols)
    # Size the figure from the actual grid shape so tiles are not letterboxed.
    plt.figure(figsize=(3 * ncols, 3 * nrows))
    plt.axis("off")
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.show()

# Draw and display 8 samples from the trained base_channels = 96 model.
print("\nBase Channels = 96 Results:")
show_samples(model=model, diffusion=diffusion, n=8)


## Base Channels = 128

This cell implements the UNet architecture with **`base_channels = 128`**.

**Architecture Details:**

With `base_channels = 128`, the channel progression through the UNet is:
- Initial layer: $128$ channels
- Encoder layer 1: $128 \rightarrow 256$ channels
- Encoder layer 2: $256 \rightarrow 512$ channels  
- Encoder layer 3: $512 \rightarrow 1024$ channels
- Bottleneck: $1024$ channels
- Decoder layers: $1024 \rightarrow 512 \rightarrow 256 \rightarrow 128$ channels

**Model Characteristics:**
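- **Total Parameters:** Largest model size (~118 million parameters, roughly 4× the `base_channels = 64` model, since most weights scale with the square of `base_channels`)
- **Memory Usage:** Very high memory footprint (may require a GPU with significant VRAM)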
- **Training Speed:** Slowest training and inference
- **Representation Capacity:** Maximum - highest capacity for capturing complex patterns and fine details

This is the largest configuration tested, providing maximum model capacity at the cost of significant computational resources. The bottleneck operates at 1024 channels on an 8×8 feature map, a width in the range used by the deepest stages of many high-capacity vision models.


In [None]:
# ============================================================
# UT Zappos50K Shoe Diffusion Model - base_channels = 128
# ============================================================

import os
import math
import glob
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt

# -----------------------------
# 0. Config & Device
# -----------------------------
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# -----------------------------
# 1. Dataset Loading (EXACT REPLICATION)
# -----------------------------
import kagglehub

# Download latest version (or reuse the cached copy). kagglehub returns the
# local directory of the downloaded dataset version.
path = kagglehub.dataset_download("aryashah2k/large-shoe-dataset-ut-zappos50k")

# Use the directory kagglehub actually returned. A hard-coded
# /root/.cache/... path breaks for non-root users, a custom cache directory,
# or a newer dataset version.
BASE_DIR = path
IMAGE_ROOT = os.path.join(BASE_DIR, "ut-zap50k-images-square")  # we'll use the square images

if not os.path.exists(IMAGE_ROOT):
    raise FileNotFoundError(f"IMAGE_ROOT '{IMAGE_ROOT}' does not exist. Check your base path.")

print("Image root:", IMAGE_ROOT)
print("Contents:", os.listdir(IMAGE_ROOT)[:10])

# Collect all image paths recursively from IMAGE_ROOT. Sorting makes the
# ordering independent of filesystem enumeration order, so runs are reproducible.
extensions = ("*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tif", "*.tiff", "*.webp")
image_paths = sorted(
    found
    for ext in extensions
    for found in glob.glob(os.path.join(IMAGE_ROOT, "**", ext), recursive=True)
)

if len(image_paths) == 0:
    raise RuntimeError(f"No image files found under {IMAGE_ROOT} with extensions {extensions}.")

print(f"Found {len(image_paths)} image files.")

# Transform: resize & normalize (NO cropping, just warp to 64x64)
image_size = 64
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),                     # [0,1]
    transforms.Normalize([0.5]*3, [0.5]*3),    # [-1,1]
])

class ZapposImageDataset(Dataset):
    """Unconditional image dataset over a list of file paths.

    Each item is ``(image, 0)``: the RGB image (after ``transform``, if given)
    and a dummy label, because the diffusion model ignores labels.
    """

    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        from PIL import Image  # local import keeps this cell self-contained

        path = self.paths[idx]
        # Open via a context manager so the file handle is closed promptly
        # (important with many DataLoader workers). convert() returns a new,
        # fully loaded image that remains valid after the handle is closed.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        # unconditional -> we don't care about labels, return dummy
        return img, 0

dataset = ZapposImageDataset(image_paths, transform=transform)

batch_size = 64
num_workers = 2

train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    # Pinned host memory only speeds up host->GPU copies; without CUDA,
    # PyTorch warns and ignores it, so enable it only when training on CUDA.
    pin_memory=(device.type == "cuda"),
)

print("Dataset and DataLoader ready.")

# -----------------------------
# 2. Time Embedding (Sinusoidal)
# -----------------------------
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of timesteps.

    Output is ``[sin(t * w_i), cos(t * w_i)]`` with ``dim // 2`` frequencies
    ``w_i`` spaced geometrically from 1 down to 1/10000. When ``dim`` is odd,
    one zero column is appended.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """
        t: (B,) int or float timesteps
        returns: (B, dim)
        """
        device = t.device
        half_dim = self.dim // 2
        # Guard the denominator: for dim in {2, 3}, half_dim - 1 == 0 and the
        # unguarded formula raised ZeroDivisionError. With a single frequency
        # the exponent is 0 regardless, giving w_0 = 1.
        emb_factor = math.log(10000) / max(half_dim - 1, 1)
        exponents = torch.exp(torch.arange(half_dim, device=device) * -emb_factor)
        t = t.float().unsqueeze(1)  # (B,1)
        angles = t * exponents[None, :]  # (B, half_dim)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb

# -----------------------------
# 3. UNet Building Blocks
# -----------------------------
class ResidualBlock(nn.Module):
    """Two 3x3 convs with GroupNorm + SiLU, time conditioning and a residual path.

    The projected time embedding is broadcast-added to the first conv's output
    before normalisation. The shortcut is a 1x1 conv when the channel count
    changes, otherwise the identity.
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        # Submodules are created in the original order so weight init draws
        # the same RNG sequence and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        # Maps the (B, time_dim) embedding to one additive bias per output channel.
        self.time_mlp = nn.Linear(time_dim, out_channels)

        # 8 groups: out_channels must be divisible by 8.
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.act = nn.SiLU()

        self.res_conv = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x, t_emb):
        """Apply the block to ``x`` (B, C_in, H, W) given ``t_emb`` (B, time_dim)."""
        time_bias = self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)  # (B, C_out, 1, 1)
        out = self.act(self.norm1(self.conv1(x) + time_bias))
        out = self.act(self.norm2(self.conv2(out)))
        return out + self.res_conv(x)


class DownBlock(nn.Module):
    """Encoder stage: two residual blocks followed by a strided downsampling conv.

    The activation computed just before downsampling is also returned as the
    skip tensor so the matching decoder stage can concatenate it.
    """

    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        # Submodules are created in the same order as before so weight
        # initialisation consumes the RNG identically.
        self.res1 = ResidualBlock(in_channels, out_channels, time_dim)
        self.res2 = ResidualBlock(out_channels, out_channels, time_dim)
        # kernel 4, stride 2, padding 1 halves H and W (for even sizes).
        self.down = nn.Conv2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, t_emb):
        """Return ``(downsampled, skip)`` for ``x`` conditioned on ``t_emb``."""
        features = self.res2(self.res1(x, t_emb), t_emb)
        return self.down(features), features


class UpBlock(nn.Module):
    """Decoder stage: transposed-conv upsample, concat the encoder skip, two residual blocks."""

    def __init__(self, in_channels, out_channels, skip_channels, time_dim):
        """
        in_channels: channels of the tensor coming from the deeper stage
        out_channels: channels after the upsample (and of the block output)
        skip_channels: channels of the encoder skip tensor concatenated after upsampling
        time_dim: width of the time embedding passed to the residual blocks
        """
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1)
        self.res1 = ResidualBlock(out_channels + skip_channels, out_channels, time_dim)
        self.res2 = ResidualBlock(out_channels, out_channels, time_dim)

    def forward(self, x, skip, t_emb):
        x = self.up(x)
        # Align the upsampled map with the skip tensor. Height and width are
        # handled independently (previously only the width was compared and its
        # difference was applied to both axes). Positive differences pad,
        # negative differences crop; F.pad accepts negative amounts.
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = F.pad(
                x,
                (diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2),
            )
        x = torch.cat([x, skip], dim=1)
        x = self.res1(x, t_emb)
        x = self.res2(x, t_emb)
        return x

# -----------------------------
# 4. UNet for 64x64 RGB Shoes (Unconditional)
# -----------------------------
class ShoeUNet(nn.Module):
    """Unconditional UNet noise predictor for RGB shoe images.

    Three encoder stages each double the channel width and halve the spatial
    size, a residual bottleneck follows, and three decoder stages mirror the
    encoder using the saved skip tensors. A sinusoidal + MLP time embedding
    conditions every residual block.
    """

    def __init__(self, in_channels=3, base_channels=64, time_dim=256):
        super().__init__()
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.time_dim = time_dim

        # Timestep t -> (B, time_dim) conditioning vector.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, time_dim * 4),
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim),
        )

        # Lift the image to base_channels feature maps.
        self.init_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Widths per level: base * [1, 2, 4, 8] (e.g. [128, 256, 512, 1024] for base 128).
        c0, c1, c2, c3 = (base_channels * mult for mult in (1, 2, 4, 8))

        # Encoder: each stage halves H and W.
        self.down1 = DownBlock(c0, c1, time_dim)
        self.down2 = DownBlock(c1, c2, time_dim)
        self.down3 = DownBlock(c2, c3, time_dim)

        self.bottleneck = ResidualBlock(c3, c3, time_dim)

        # Decoder: each stage consumes the skip saved by the encoder stage at
        # the same resolution; that skip has the encoder stage's output width.
        self.up3 = UpBlock(c3, c2, skip_channels=c3, time_dim=time_dim)
        self.up2 = UpBlock(c2, c1, skip_channels=c2, time_dim=time_dim)
        self.up1 = UpBlock(c1, c0, skip_channels=c1, time_dim=time_dim)

        # 1x1 projection back to image channels: the predicted noise.
        self.out_conv = nn.Conv2d(base_channels, in_channels, 1)

    def forward(self, x, t):
        """Predict noise for images ``x`` (B, C, H, W) at timesteps ``t`` (B,)."""
        emb = self.time_mlp(t)
        h = self.init_conv(x)

        skips = []
        for stage in (self.down1, self.down2, self.down3):
            h, skip = stage(h, emb)
            skips.append(skip)

        h = self.bottleneck(h, emb)

        # Pop skips in reverse order: deepest encoder output meets up3 first.
        for stage in (self.up3, self.up2, self.up1):
            h = stage(h, skips.pop(), emb)

        return self.out_conv(h)

# -----------------------------
# 5. DDPM Diffusion Utilities
# -----------------------------
class Diffusion:
    """DDPM linear-beta noise schedule with forward (q) and ancestral reverse (p) sampling.

    All schedule tensors are created on ``device``.
    """

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, device="cpu"):
        self.device = device
        self.num_steps = num_steps

        self.betas = torch.linspace(beta_start, beta_end, num_steps, device=device)
        self.alphas = 1.0 - self.betas
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1 so the t = 0 posterior variance is 0.
        self.alpha_cumprod_prev = torch.cat(
            [torch.tensor([1.0], device=device), self.alpha_cumprod[:-1]], dim=0
        )

        self.sqrt_alphas = torch.sqrt(self.alphas)
        self.sqrt_alpha_cumprod = torch.sqrt(self.alpha_cumprod)
        self.sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - self.alpha_cumprod)

        # Variance of q(x_{t-1} | x_t, x_0): beta_t * (1 - abar_{t-1}) / (1 - abar_t).
        self.posterior_variance = (
            self.betas
            * (1.0 - self.alpha_cumprod_prev)
            / (1.0 - self.alpha_cumprod)
        )
        # Clamped to avoid log(0) at t = 0 (not used by p_sample below).
        self.posterior_log_variance_clipped = torch.log(
            torch.clamp(self.posterior_variance, min=1e-20)
        )

    def sample_timesteps(self, batch_size):
        """Draw ``batch_size`` integer timesteps uniformly from [0, num_steps)."""
        return torch.randint(
            0, self.num_steps, (batch_size,), device=self.device, dtype=torch.long
        )

    def q_sample(self, x0, t, noise=None):
        """
        Forward diffusion: q(x_t | x_0) = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise.

        t: (B,) timestep indices; one schedule value is broadcast per batch item.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        sqrt_alpha_bar_t = self.sqrt_alpha_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
        return sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise

    @torch.no_grad()
    def p_sample(self, model, x_t, t):
        """
        Reverse step: p(x_{t-1} | x_t) for a single integer timestep ``t``.

        The model's noise prediction gives the posterior mean; posterior-variance
        noise is added for every step except t == 0.
        """
        B = x_t.size(0)
        t_batch = torch.full((B,), t, device=self.device, dtype=torch.long)

        eps_theta = model(x_t, t_batch)

        beta_t = self.betas[t]
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alpha_cumprod[t]
        sqrt_recip_alpha_t = 1.0 / torch.sqrt(self.alphas[t])

        model_mean = sqrt_recip_alpha_t * (
            x_t - (beta_t / sqrt_one_minus_alpha_cumprod_t) * eps_theta
        )

        if t == 0:
            return model_mean
        posterior_var_t = self.posterior_variance[t]
        noise = torch.randn_like(x_t)
        return model_mean + torch.sqrt(posterior_var_t) * noise

    @torch.no_grad()
    def sample(self, model, image_size, batch_size=8, channels=3):
        """
        Sample x_0 from the model by starting from pure noise.

        channels: number of image channels to generate (default 3, RGB).
        The model is run in eval mode and its previous train/eval mode is
        restored afterwards, even if sampling raises.
        """
        was_training = model.training
        model.eval()
        try:
            x = torch.randn(
                batch_size, channels, image_size, image_size, device=self.device
            )
            for t in reversed(range(self.num_steps)):
                x = self.p_sample(model, x, t)
        finally:
            model.train(was_training)
        return x

# -----------------------------
# 6. Instantiate Model & Diffusion
# -----------------------------
time_dim = 256
model = ShoeUNet(
    in_channels=3,
    base_channels=128,
    time_dim=time_dim
).to(device)

diffusion = Diffusion(
    num_steps=1000,
    beta_start=1e-4,
    beta_end=0.02,
    device=device
)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# -----------------------------
# 7. Training Loop
# -----------------------------
num_epochs = 30  # increase for better visuals

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # Count samples actually processed so the mean stays correct even if the
    # loader drops the last batch or uses a sampler that skips items.
    num_seen = 0

    for x, _ in train_loader:   # ignore labels
        x = x.to(device)        # already normalized to [-1,1]

        b = x.size(0)
        # DDPM objective: noise x to a random timestep, regress the added noise.
        t = diffusion.sample_timesteps(b)

        noise = torch.randn_like(x)
        x_t = diffusion.q_sample(x, t, noise=noise)

        pred_noise = model(x_t, t)
        loss = F.mse_loss(pred_noise, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss is a batch mean; weight by batch size for a per-sample epoch mean.
        running_loss += loss.item() * b
        num_seen += b

    epoch_loss = running_loss / max(num_seen, 1)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss:.6f}")

print("Training complete.")

# -----------------------------
# 8. Sampling & Visualization
# -----------------------------
@torch.no_grad()
def show_samples(model, diffusion, n=8):
    """
    Generate ``n`` shoe samples from the diffusion model and display them as a grid.

    Uses the notebook-level ``image_size`` for the sample resolution. Samples are
    mapped from [-1, 1] back to [0, 1] and laid out with at most 4 per row.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")

    model.eval()
    samples = diffusion.sample(model, image_size=image_size, batch_size=n)
    samples = (samples.clamp(-1, 1) + 1) / 2.0  # back to [0,1]

    ncols = min(n, 4)
    nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)
    grid = utils.make_grid(samples, nrow=ncols)
    # Size the figure from the actual grid shape so tiles are not letterboxed.
    plt.figure(figsize=(3 * ncols, 3 * nrows))
    plt.axis("off")
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.show()

# Draw and display 8 samples from the trained base_channels = 128 model.
print("\nBase Channels = 128 Results:")
show_samples(model=model, diffusion=diffusion, n=8)
