In [28]:
!pip install torch torchvision einops timm tqdm matplotlib



In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from typing import List
import math
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.optim as optim

# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [25]:
class SinusoidalEmbeddings(nn.Module):
    """
    Map a batch of timesteps [B] to sinusoidal feature vectors [B, embed_dim].

    The embedding is computed on every call (no lookup table is stored) and is
    laid out as [sin(t * f_0..f_{half-1}), cos(t * f_0..f_{half-1})], followed
    by a single zero column when embed_dim is odd. `time_steps` is accepted
    for signature compatibility but is not used by the computation.
    """
    def __init__(self, time_steps: int, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        n_freqs = self.embed_dim // 2
        # Guard the denominator so a single frequency does not divide by zero.
        denom = max(n_freqs - 1, 1)
        index = torch.arange(0, n_freqs, device=t.device)
        # Geometric frequency ladder running from 1 down to 1/10000.
        freqs = torch.exp(-math.log(10000) * index / denom)

        phase = t.float()[:, None] * freqs[None, :]             # [B, n_freqs]
        emb = torch.cat((phase.sin(), phase.cos()), dim=1)      # [B, 2*n_freqs]

        odd_width = self.embed_dim % 2 == 1
        if odd_width:
            # One trailing zero brings the width up to embed_dim.
            emb = F.pad(emb, (0, 1))
        return emb

In [26]:
class ResBlock(nn.Module):
    """
    Pre-activation residual block: (GroupNorm -> SiLU -> 3x3 conv) twice.

    The time vector goes through SiLU + Linear and is broadcast-added to the
    feature map as a per-channel bias right after the first convolution.
    The shortcut is the identity when in_channels == out_channels and a 1x1
    convolution otherwise. GroupNorm uses 32 groups when the channel count is
    a multiple of 32 and a single group otherwise.
    """
    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int, dropout_prob: float = 0.0):
        super().__init__()
        groups_in, groups_out = (32 if width % 32 == 0 else 1 for width in (in_channels, out_channels))

        # First half: normalise -> activate -> conv (may change channel count).
        self.norm1 = nn.GroupNorm(groups_in, in_channels)
        self.act = nn.SiLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        # Persistent projection from the time vector to one bias per channel.
        self.time_proj = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels))

        # Second half: normalise -> activate -> dropout -> conv.
        self.norm2 = nn.GroupNorm(groups_out, out_channels)
        self.dropout = nn.Dropout(dropout_prob)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor, t_vec: torch.Tensor) -> torch.Tensor:
        shortcut = self.skip(x)

        h = self.conv1(self.act(self.norm1(x)))
        # [B, C] -> [B, C, 1, 1] so the bias broadcasts over H and W.
        time_bias = self.time_proj(t_vec)[..., None, None]
        h = h + time_bias

        h = self.conv2(self.dropout(self.act(self.norm2(h))))
        return h + shortcut

In [27]:
class AttentionBlock(nn.Module):
    """
    Multi-head spatial self-attention over a [B, C, H, W] feature map.

    Every pixel is a token: queries, keys and values come from one 1x1
    convolution of the GroupNorm-ed input, attention weights are computed
    between all pairs of the H*W positions (per head), and the attended
    values are projected by a 1x1 convolution and added back to the input.

    Args:
        channels: number of input/output channels; must be divisible by
            `num_heads`.
        num_heads: number of attention heads the channels are split into.

    Raises:
        ValueError: if `channels` cannot be split evenly across `num_heads`.
    """
    def __init__(self, channels: int, num_heads: int = 1):
        super().__init__()
        if num_heads < 1 or channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        # Same group-count rule as ResBlock: 32 groups when possible, else 1,
        # so channel counts that are not multiples of 32 are still accepted.
        groups = 32 if channels % 32 == 0 else 1
        self.norm = nn.GroupNorm(groups, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, C, H, W]
        B, C, H, W = x.shape
        head_dim = C // self.num_heads

        qkv = self.qkv(self.norm(x))  # [B, 3*C, H, W]

        # Channel layout is (qkv, head, head_dim); flatten H*W into a token axis.
        q, k, v = qkv.reshape(B, 3, self.num_heads, head_dim, H * W).unbind(dim=1)
        # q, k, v: [B, heads, head_dim, H*W]

        # Contract over the per-head channel axis so that scores relate pairs
        # of spatial positions: [B, heads, H*W (query), H*W (key)].
        scores = torch.einsum('bhdi,bhdj->bhij', q, k) * (1.0 / math.sqrt(head_dim))
        attn = F.softmax(scores, dim=-1)  # normalise over key positions

        # Weighted sum of values over key positions: [B, heads, head_dim, H*W].
        out = torch.einsum('bhij,bhdj->bhdi', attn, v)

        # Merge heads back into channels and restore the spatial grid.
        out = out.reshape(B, C, H, W)

        # Output projection plus residual connection.
        return x + self.proj(out)

class Downsample(nn.Module):
    """Reduce spatial resolution with a learned stride-2 3x3 convolution (channels unchanged)."""
    def __init__(self, channels: int):
        super().__init__()
        # kernel 3 / stride 2 / padding 1: an even-sized input comes out at half size.
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        reduced = self.conv(x)
        return reduced

class Upsample(nn.Module):
    """Double spatial resolution with a learned stride-2 transposed convolution (channels unchanged)."""
    def __init__(self, channels: int):
        super().__init__()
        # kernel 4 / stride 2 / padding 1: output height and width are exactly 2x the input.
        self.conv = nn.ConvTranspose2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        enlarged = self.conv(x)
        return enlarged

class UNET(nn.Module):
    """
    Time-conditioned U-Net noise predictor.

    Layout for L = len(channel_mults) levels:
      * encoder: each level runs ResBlock -> ResBlock -> (attention | identity);
        every level except the last is followed by a stride-2 Downsample.
      * bottleneck: ResBlock -> AttentionBlock -> ResBlock at the lowest resolution.
      * decoder: each level concatenates the skip tensor of the matching
        resolution, runs ResBlock -> ResBlock -> (attention | identity), and
        every level except the last is followed by an Upsample.

    The skip tensor of an encoder level is the tensor that *enters* that level
    (the initial conv output for level 0, the Downsample output otherwise).
    Its channel count is what the decoder ResBlocks are sized for.

    Attention is placed at encoder level 1 and at the second-to-last decoder
    level, i.e. after one downsampling step (16x16 for a 32x32 input).

    Args:
        base_channels: channels produced by the initial convolution.
        channel_mults: per-level multipliers of `base_channels`; must be non-empty.
        input_channels: channels of the input image.
        output_channels: channels of the prediction.
        time_steps: stored on the module and forwarded to SinusoidalEmbeddings.
        time_emb_dim: width of the time vector given to every ResBlock.
        dropout_prob: dropout probability inside every ResBlock.
        num_heads: attention heads in every AttentionBlock.

    Raises:
        ValueError: if `channel_mults` is empty, or (in forward) if the input
            height/width is not divisible by 2 ** (L - 1).
    """
    def __init__(self, base_channels=64, channel_mults=(1, 2, 4, 8), input_channels=1, output_channels=1, time_steps=1000, time_emb_dim=256, dropout_prob=0.1, num_heads=1):
        super().__init__()
        channel_mults = list(channel_mults)
        if not channel_mults:
            raise ValueError("channel_mults must contain at least one multiplier")

        self.time_steps = time_steps
        self.time_emb = SinusoidalEmbeddings(time_steps, time_emb_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim * 4),
            nn.SiLU(),
            nn.Linear(time_emb_dim * 4, time_emb_dim)
        )

        self.initial_conv = nn.Conv2d(input_channels, base_channels, kernel_size=3, padding=1)

        # Channels leaving / entering each encoder level.
        level_out = [base_channels * mult for mult in channel_mults]
        level_in = [base_channels] + level_out[:-1]
        num_levels = len(level_out)

        # Downsampling path: three entries per level (ResBlock, ResBlock, attention-or-identity)
        self.downs = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        for level, (in_ch, out_ch) in enumerate(zip(level_in, level_out)):
            self.downs.append(ResBlock(in_ch, out_ch, time_emb_dim, dropout_prob))
            self.downs.append(ResBlock(out_ch, out_ch, time_emb_dim, dropout_prob))
            # Level 1 runs after exactly one downsampling step (16x16 for 32x32 inputs).
            self.downs.append(AttentionBlock(out_ch, num_heads) if level == 1 else nn.Identity())
            # No downsampling after the last (lowest-resolution) level.
            if level < num_levels - 1:
                self.downsamples.append(Downsample(out_ch))

        # Bottleneck at the channel width of the last encoder level
        bottleneck_channels = level_out[-1]
        self.mid_block1 = ResBlock(bottleneck_channels, bottleneck_channels, time_emb_dim, dropout_prob)
        self.mid_attn = AttentionBlock(bottleneck_channels, num_heads)
        self.mid_block2 = ResBlock(bottleneck_channels, bottleneck_channels, time_emb_dim, dropout_prob)

        # Upsampling path: mirrors the encoder, lowest resolution first
        self.ups = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        current_channels = bottleneck_channels
        for i, skip_channels in enumerate(reversed(level_in)):
            # The first ResBlock sees the running features concatenated with the skip tensor.
            self.ups.append(ResBlock(current_channels + skip_channels, skip_channels, time_emb_dim, dropout_prob))
            self.ups.append(ResBlock(skip_channels, skip_channels, time_emb_dim, dropout_prob))
            # Second-to-last decoder level is one step below full resolution.
            self.ups.append(AttentionBlock(skip_channels, num_heads) if i == num_levels - 2 else nn.Identity())
            # No upsampling after the last (full-resolution) level.
            if i < num_levels - 1:
                self.upsamples.append(Upsample(skip_channels))
            current_channels = skip_channels

        # Same group-count rule as ResBlock so base_channels need not be a multiple of 32.
        self.final_norm = nn.GroupNorm(32 if base_channels % 32 == 0 else 1, base_channels)
        self.final_act = nn.SiLU()
        self.final_conv = nn.Conv2d(base_channels, output_channels, kernel_size=3, padding=1)


    def forward(self, x, t):
        # x: [B, C, H, W]
        # t: [B]
        num_levels = len(self.downs) // 3
        num_resamples = len(self.downsamples)

        # Each Downsample halves (rounding up) and each Upsample exactly doubles,
        # so skip tensors only line up when H and W survive every halving evenly.
        factor = 2 ** num_resamples
        height, width = x.shape[-2:]
        if height % factor != 0 or width % factor != 0:
            raise ValueError(
                f"Input spatial size {height}x{width} must be divisible by {factor}"
            )

        t_emb = self.time_mlp(self.time_emb(t))  # [B, time_emb_dim]

        h = self.initial_conv(x)

        # Downsampling path: remember the tensor entering each level as its skip.
        skips = []
        for level in range(num_levels):
            skips.append(h)
            h = self.downs[3 * level](h, t_emb)       # First ResBlock
            h = self.downs[3 * level + 1](h, t_emb)   # Second ResBlock
            h = self.downs[3 * level + 2](h)          # Attention or Identity
            if level < num_resamples:
                h = self.downsamples[level](h)

        # Bottleneck
        h = self.mid_block1(h, t_emb)
        h = self.mid_attn(h)
        h = self.mid_block2(h, t_emb)

        # Upsampling path: consume skips last-in-first-out so resolutions match.
        for level in range(num_levels):
            h = torch.cat([h, skips.pop()], dim=1)
            h = self.ups[3 * level](h, t_emb)         # First ResBlock
            h = self.ups[3 * level + 1](h, t_emb)     # Second ResBlock
            h = self.ups[3 * level + 2](h)            # Attention or Identity
            if level < num_resamples:
                h = self.upsamples[level](h)

        h = self.final_act(self.final_norm(h))
        return self.final_conv(h)

DDPM uses a U-Net with four resolutions for 32×32 inputs (32→16→8→4 and back), two residual blocks per resolution, and skip connections via concatenation.
In contrast, the given code never downsamples/upsamples and adds skips instead of concatenating them.

**Missing self-attention** at 16×16: the DDPM paper inserts attention at the 16×16 feature map between conv blocks. The given model has no attention.

**Time conditioning**: The paper feeds a Transformer-style sinusoidal time embedding (via an MLP) into each residual block using a learned linear projection. The given ResBlock creates a brand-new Conv2d inside forward() (non-persistent) and injects time differently.

**Normalization/activations**: The paper uses GroupNorm throughout with the PixelCNN++-style U-Net; SiLU/Swish is standard in faithful implementations. The given blocks use ReLU and dynamic group counts that often reduce to group size 1.

**SinusoidalEmbeddings**

UPDATED: compute sinusoidal emb on the fly; return [B, D] (vector), not [B, D, 1, 1].

Reason: paper feeds a vector time embedding to every residual block and projects it there.

**ResBlock**

UPDATED: remove the Conv2d created inside forward(); add a persistent nn.Linear(time_dim→out_ch) and add this after conv1 as a bias.
- paper conditions the U-Net on the timestep inside every residual block, which means a learnable projection that turns the time vector into a per-channel bias for that block.

UPDATED: GroupNorm + SiLU activations; proper residual skip with 1×1 when channels change.

**UNET**

UPDATED: added the 32→16→8→4 hierarchy with Downsample/Upsample.

UPDATED: concatenate skip features (not add).

UPDATED: inserted AttentionBlock at 16×16.

UPDATED: added a small MLP over the sinusoidal time embedding and pass the resulting vector to every ResBlock

- Implemented the hierarchical structure with explicit Downsample and Upsample layers to achieve the 32->16->8->4 and back resolution changes.

- Introduced AttentionBlocks and placed them at the 16x16 resolution levels in both the downsampling and upsampling paths, as specified in the paper.

- Modified the skip connections to use concatenation (torch.cat) instead of addition, combining the feature maps from the downsampling path with the upsampling path.

- Added an MLP (time_mlp) after the SinusoidalEmbeddings to process the time vector before it is fed into the ResBlocks.

- Adjusted the forward pass logic to correctly apply the sequence of ResBlocks, Attention (where applicable), and Downsample/Upsample operations, and to handle the concatenation of skip connections.