In [9]:
# Install dependencies
!pip install timm diffusers



In [10]:
# IMPORTANT: This notebook is designed for YOUR CUSTOM DiT fork with x2 modifications.
# 
# Option 1: If you've pushed your code to GitHub, replace the URL below:
# !git clone https://github.com/YOUR_USERNAME/YOUR_REPO.git
# %cd YOUR_REPO
#
# Option 2: For now, we clone the base DiT repo and overwrite with your custom files:
!git clone https://github.com/facebookresearch/DiT.git
%cd DiT

# The next cells will overwrite models.py (with x2 modifications) and add new scripts

Cloning into 'DiT'...
remote: Enumerating objects: 102, done.[K
remote: Total 102 (delta 0), reused 0 (delta 0), pack-reused 102 (from 1)[K
Receiving objects: 100% (102/102), 6.37 MiB | 56.19 MiB/s, done.
Resolving deltas: 100% (55/55), done.
/content/DiT/DiT


In [11]:
%%writefile models.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import torch
import torch.nn as nn
import numpy as np
import math
import timm
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp, VisionTransformer, Block


def modulate(x, shift, scale):
    """
    Per-sample affine modulation, broadcast over axis 1 (the token axis for DiT inputs).

    Computes x * (1 + scale) + shift, where `shift` and `scale` carry one entry per
    batch element and get a singleton axis inserted at position 1 before broadcasting.
    """
    gain = 1 + scale[:, None]
    offset = shift[:, None]
    return x * gain + offset


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Maps scalar diffusion timesteps to hidden_size-dimensional vectors: a fixed
    sinusoidal featurization followed by a Linear -> SiLU -> Linear MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Sinusoidal features for a batch of (possibly fractional) timesteps.

        :param t: 1-D tensor with one timestep per batch element.
        :param dim: number of output features; when odd, a zero column is appended.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: (N, dim) tensor laid out as [cos(t * f), sin(t * f)] (+ zero column).
        """
        # Same construction as glide-text2im:
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * steps / half).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        features = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2 == 1:
            # Right-pad the last axis with one zero column.
            features = nn.functional.pad(features, (0, 1))
        return features

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))


class LabelEmbedder(nn.Module):
    """
    Class-label embedding with built-in label dropout for classifier-free guidance.

    When dropout_prob > 0 the table gets one extra row, at index num_classes, which
    dropped labels are mapped to.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        null_rows = 1 if dropout_prob > 0 else 0
        self.embedding_table = nn.Embedding(num_classes + null_rows, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Replace some labels with index num_classes (the extra "null" row).

        If force_drop_ids is given, exactly the positions where it equals 1 are replaced;
        otherwise each label is replaced independently with probability dropout_prob.
        """
        if force_drop_ids is not None:
            drop_mask = force_drop_ids == 1
        else:
            drop_mask = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_mask, self.num_classes, labels)

    def forward(self, labels, train, force_drop_ids=None):
        dropping = force_drop_ids is not None or (train and self.dropout_prob > 0)
        if dropping:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)


#################################################################################
#                                 Core DiT Model                                #
#################################################################################

class DiTBlock(nn.Module):
    """
    Transformer block conditioned through adaptive LayerNorm-Zero (adaLN-Zero).

    The conditioning vector `c` is projected to six chunks: shift, scale and gate for
    the attention branch, then shift, scale and gate for the MLP branch.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()

        def gelu_tanh():
            # Factory handed to Mlp as act_layer: tanh-approximated GELU.
            return nn.GELU(approximate="tanh")

        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=gelu_tanh,
            drop=0,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=1)

        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.attn(attn_in)

        mlp_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(mlp_in)
        return x


class FinalLayer(nn.Module):
    """
    Output head of DiT: an adaLN-modulated LayerNorm followed by a linear map from
    hidden_size to patch_size * patch_size * out_channels values per token.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))


class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone, extended with an auxiliary "x2" branch.

    The x2 branch patchifies the same input with twice the main patch size (a quarter as
    many tokens), resamples that coarse token grid onto the main token grid, runs it
    through a single ViT block (taken from a timm ViT when available), and adds the
    result to the output of the last DiT block, right before the final layer.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        verbose=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # Per-call shape/timestep logging in forward() and forward_with_cfg(). Off by
        # default: it would run on every sampling/training step, and t.min().item() /
        # t.max().item() copy values to the host, synchronizing with the GPU each call.
        self.verbose = verbose

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        # Coarser patchification of the same input: twice the patch size, so a quarter
        # as many tokens as x_embedder produces.
        self.x2_embedder = PatchEmbed(input_size, patch_size * 2, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        # ViT block for processing x2. load_pretrained_vit_weights() (called from
        # initialize_weights) may replace it with a timm ViT block of another width.
        self.x2_vit_block = Block(
            dim=hidden_size,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=True
        )
        # Width adapters around x2_vit_block; only created when the ViT block's
        # width/head count differ from hidden_size/num_heads.
        self.x2_vit_proj_in = None
        self.x2_vit_proj_out = None
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

        # Take x2_vit_block from a (pre-trained, if reachable) timm ViT:
        self.load_pretrained_vit_weights()

    def load_pretrained_vit_weights(self, vit_model_name='vit_large_patch16_224'):
        """
        Initialize x2_vit_block from the last block of a timm ViT.

        The timm block is adopted as-is, so it keeps the ViT's own configuration
        (e.g. its LayerNorm eps, which a freshly built default `Block` would not match).
        If the ViT's width or head count differs from this model's, linear projections
        hidden_size -> vit_dim -> hidden_size are created around it.

        If the pretrained weights cannot be fetched, the same ViT architecture is built
        with random weights (timm ``pretrained=False``), so the module structure -- and
        therefore the state_dict keys of saved checkpoints -- does not depend on network
        availability. If even that fails, the randomly initialized hidden_size-wide block
        created in __init__ is kept, without projections.

        :param vit_model_name: name of the timm model to take the block from.
        """
        try:
            print(f"[DiT] Loading pre-trained ViT model: {vit_model_name}", flush=True)
            pretrained = True
            try:
                source_vit = timm.create_model(vit_model_name, pretrained=True)
            except Exception as e:
                print(f"[DiT] Warning: Could not load pre-trained ViT weights: {e}", flush=True)
                print(f"[DiT] Building {vit_model_name} with random weights to keep the x2 architecture consistent", flush=True)
                source_vit = timm.create_model(vit_model_name, pretrained=False)
                pretrained = False
            source_vit.eval()

            source_block = source_vit.blocks[-1]
            source_dim = source_vit.embed_dim
            source_heads = self._infer_num_heads(source_vit, source_block, source_dim)
            print(f"[DiT] Pre-trained ViT block: dim={source_dim}, num_heads={source_heads}", flush=True)
            print(f"[DiT] Target x2_vit_block: dim={self.x2_vit_block.norm1.normalized_shape[0]}, num_heads={self.num_heads}", flush=True)
            self._adopt_vit_block(source_block, source_dim, source_heads, pretrained)
        except Exception as e:
            # Best effort: keep the block built in __init__ (no projections needed).
            print(f"[DiT] Warning: Could not load pre-trained ViT weights: {e}", flush=True)
            print(f"[DiT] Using randomly initialized weights for x2_vit_block", flush=True)

    @staticmethod
    def _infer_num_heads(vit, block, embed_dim):
        """
        Best-effort head count of a timm ViT block.

        Prefers the attention module's `num_heads`, then a model-level `num_heads`,
        then a table of common ViT widths, and finally assumes head_dim = 64.
        """
        if hasattr(block.attn, 'num_heads'):
            return block.attn.num_heads
        if hasattr(vit, 'num_heads'):
            return vit.num_heads
        common_widths = {768: 12, 1024: 16, 1280: 16}
        if embed_dim in common_widths:
            heads = common_widths[embed_dim]
        else:
            heads = embed_dim // 64
            if heads <= 0 or heads > 32:
                heads = 16  # fallback to common value
        print(f"[DiT] Inferred num_heads={heads} from embed_dim={embed_dim}", flush=True)
        return heads

    def _adopt_vit_block(self, source_block, source_dim, source_heads, pretrained=True):
        """
        Install `source_block` as x2_vit_block, adding width projections when needed.

        All new modules are built before any attribute is assigned, so a failure part-way
        cannot leave a source_dim-wide block without the projections that adapt it.
        """
        origin = "pre-trained" if pretrained else "randomly initialized"
        if source_dim == self.hidden_size and source_heads == self.num_heads:
            print(f"[DiT] Dimensions match! Using the ViT block directly...", flush=True)
            self.x2_vit_block = source_block
            print(f"[DiT] ✓ Using {origin} ViT weights for x2_vit_block", flush=True)
            return

        print(f"[DiT] Dimensions don't match. Using the ViT block with projection layers...", flush=True)
        proj_in = nn.Linear(self.hidden_size, source_dim)
        proj_out = nn.Linear(source_dim, self.hidden_size)
        for proj in (proj_in, proj_out):
            nn.init.xavier_uniform_(proj.weight)
            nn.init.constant_(proj.bias, 0)

        self.x2_vit_block = source_block
        self.x2_vit_proj_in = proj_in
        self.x2_vit_proj_out = proj_out
        print(f"[DiT] ✓ Using {origin} ViT weights in block with dim={source_dim}", flush=True)
        print(f"[DiT] Using projection layers to adapt dimensions: {self.hidden_size} -> {source_dim} -> {self.hidden_size}", flush=True)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    @staticmethod
    def _resample_x2_tokens(x2, num_tokens):
        """
        Resample the coarse x2 token grid onto the main token grid.

        x2: (N, T2, D) tokens laid out row-major on a square grid (PatchEmbed output).
        Returns (N, num_tokens, D).

        The resampling is done in 2-D so each token keeps its spatial neighbours.
        Interpolating the flattened sequence in 1-D would stretch each coarse row across
        several fine rows and mix the end of one row with the start of the next.
        """
        n, t2, d = x2.shape
        src = int(round(t2 ** 0.5))
        dst = int(round(num_tokens ** 0.5))
        assert src * src == t2 and dst * dst == num_tokens, "x2 and x tokens must form square grids"
        grid = x2.transpose(1, 2).reshape(n, d, src, src)              # (N, D, h2, w2)
        grid = torch.nn.functional.interpolate(grid, size=(dst, dst), mode='bilinear', align_corners=False)
        return grid.flatten(2).transpose(1, 2)                          # (N, T, D)

    def forward(self, x, t, y):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        if self.verbose:
            print(f"[DiT Forward] Batch: {x.shape}, Timesteps: [{t.min().item():.0f}-{t.max().item():.0f}], Classes: {y[:min(4,len(y))].tolist()}", flush=True)

        x2 = self.x2_embedder(x)                 # (N, T/4, D): patch size doubled
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                   # (N, D)
        y = self.y_embedder(y, self.training)    # (N, D)
        c = t + y                                # (N, D)

        # Bring x2 onto the same token grid as x (size derived from x, not hard-coded).
        x2 = self._resample_x2_tokens(x2, x.shape[1])  # (N, T, D)
        if self.verbose:
            print(f"[DiT Forward] x2 after resample: {x2.shape}", flush=True)

        if self.x2_vit_proj_in is not None:
            x2 = self.x2_vit_proj_in(x2)   # project to the ViT block's width
        x2 = self.x2_vit_block(x2)
        if self.x2_vit_proj_out is not None:
            x2 = self.x2_vit_proj_out(x2)  # project back to hidden_size
        if self.verbose:
            print(f"[DiT Forward] x2 after ViT block: {x2.shape}", flush=True)

        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
        # Inject the x2 branch once, as a residual after the last DiT block.
        x = x + x2
        if self.verbose:
            print(f"[DiT Forward] x after addition: {x.shape}", flush=True)

        x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
        """
        if self.verbose:
            print(f"\n{'='*60}")
            print(f"[CFG] Input: {x.shape}, CFG scale: {cfg_scale}")

        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)

        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        guided = torch.cat([eps, rest], dim=1)

        if self.verbose:
            print(f"[CFG] ✓ Guidance applied, output: {guided.shape}")
            print(f"{'='*60}\n")
        return guided


#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    Fixed 2-D sine/cosine positional embedding for a square grid_size x grid_size grid.

    Returns a float64 array of shape (grid_size * grid_size, embed_dim). When cls_token
    is True and extra_tokens > 0, `extra_tokens` all-zero rows are prepended.
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # np.meshgrid uses 'xy' indexing by default: coords[0] holds the column index,
    # coords[1] the row index.
    coords = np.stack(np.meshgrid(axis, axis), axis=0).reshape([2, 1, grid_size, grid_size])
    table = get_2d_sincos_pos_embed_from_grid(embed_dim, coords)
    if not (cls_token and extra_tokens > 0):
        return table
    prefix = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([prefix, table], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Combine two 1-D sin/cos embeddings: the first embed_dim/2 columns encode grid[0],
    the remaining embed_dim/2 columns encode grid[1].
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    first = get_1d_sincos_pos_embed_from_grid(half, grid[0])   # (H*W, D/2)
    second = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)
    return np.concatenate([first, second], axis=1)             # (H*W, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: array of positions; flattened to size (M,)
    out: (M, embed_dim) array [sin(pos * omega), cos(pos * omega)],
         with omega_i = 1 / 10000 ** (2 * i / embed_dim)
    """
    assert embed_dim % 2 == 0
    exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.)
    omega = 1. / 10000 ** exponents                  # (D/2,)
    angles = np.outer(pos.reshape(-1), omega)        # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


#################################################################################
#                                   DiT Configs                                  #
#################################################################################

def _dit_preset(depth, hidden_size, patch_size, num_heads, kwargs):
    """Build a DiT with one of the standard size presets plus the caller's extra options."""
    return DiT(depth=depth, hidden_size=hidden_size, patch_size=patch_size, num_heads=num_heads, **kwargs)


def DiT_XL_2(**kwargs):
    print("Creating DiT-XL/2 model...")
    return _dit_preset(28, 1152, 2, 16, kwargs)

def DiT_XL_4(**kwargs):
    print("Creating DiT-XL/4 model...")
    return _dit_preset(28, 1152, 4, 16, kwargs)

def DiT_XL_8(**kwargs):
    return _dit_preset(28, 1152, 8, 16, kwargs)

def DiT_L_2(**kwargs):
    return _dit_preset(24, 1024, 2, 16, kwargs)

def DiT_L_4(**kwargs):
    return _dit_preset(24, 1024, 4, 16, kwargs)

def DiT_L_8(**kwargs):
    return _dit_preset(24, 1024, 8, 16, kwargs)

def DiT_B_2(**kwargs):
    return _dit_preset(12, 768, 2, 12, kwargs)

def DiT_B_4(**kwargs):
    return _dit_preset(12, 768, 4, 12, kwargs)

def DiT_B_8(**kwargs):
    return _dit_preset(12, 768, 8, 12, kwargs)

def DiT_S_2(**kwargs):
    return _dit_preset(12, 384, 2, 6, kwargs)

def DiT_S_4(**kwargs):
    return _dit_preset(12, 384, 4, 6, kwargs)

def DiT_S_8(**kwargs):
    return _dit_preset(12, 384, 8, 6, kwargs)


# Registry: model name (as passed on the command line) -> constructor.
DiT_models = {
    'DiT-XL/2': DiT_XL_2,
    'DiT-XL/4': DiT_XL_4,
    'DiT-XL/8': DiT_XL_8,
    'DiT-L/2': DiT_L_2,
    'DiT-L/4': DiT_L_4,
    'DiT-L/8': DiT_L_8,
    'DiT-B/2': DiT_B_2,
    'DiT-B/4': DiT_B_4,
    'DiT-B/8': DiT_B_8,
    'DiT-S/2': DiT_S_2,
    'DiT-S/4': DiT_S_4,
    'DiT-S/8': DiT_S_8,
}


Overwriting models.py


In [12]:
%%writefile train_x2_finetune.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
A minimal training script for DiT using PyTorch DDP.
Modified for Fine-tuning ONLY the x2 block on a subset of classes.
"""
import torch
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Subset
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import numpy as np
from collections import OrderedDict
from PIL import Image
from copy import deepcopy
from glob import glob
from time import time
import argparse
import logging
import os

from models import DiT_models
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL


#################################################################################
#                             Training Helper Functions                         #
#################################################################################

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999, skip_frozen=False):
    """
    Step the EMA model towards the current model: ema = decay * ema + (1 - decay) * model.

    :param ema_model: module holding the EMA copy; must have the same parameter names as `model`.
    :param model: the module being trained (pass ``model.module`` when it is wrapped in DDP).
    :param decay: EMA decay; ``decay=0`` copies every parameter of `model` into `ema_model`.
    :param skip_frozen: if True, parameters of `model` with ``requires_grad=False`` are left
        untouched in `ema_model`. This avoids small floating-point drift on frozen tensors
        (e.g. the fixed pos_embed, or the frozen backbone when only the x2 branch is trained).
        Keep it False when syncing with ``decay=0`` after loading weights into `model`,
        since frozen tensors must be copied in that case too.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        if skip_frozen and not param.requires_grad:
            continue
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)


def requires_grad(model, flag=True):
    """
    Turn gradient tracking on (flag=True) or off (flag=False) for every parameter of `model`.
    """
    for param in list(model.parameters()):
        param.requires_grad = flag


def cleanup():
    """
    End DDP training by destroying the default process group.

    Does nothing when torch.distributed is unavailable or no process group is active,
    so it is safe to call from error paths (destroy_process_group raises otherwise).
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def create_logger(logging_dir):
    """
    Return the module logger for this process.

    On rank 0 the root logger is configured via logging.basicConfig to emit INFO-level
    messages to stderr and to `{logging_dir}/log.txt`. On every other rank only a
    NullHandler is attached, so those ranks set up no output of their own.
    """
    logger = logging.getLogger(__name__)
    if dist.get_rank() != 0:  # dummy logger (does nothing)
        logger.addHandler(logging.NullHandler())
        return logger
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
    )
    return logger


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Repeatedly halves the image with a box filter while its short side is at least
    2 * image_size, rescales with bicubic filtering so the short side equals
    image_size, then returns the central image_size x image_size crop.
    """
    img = pil_image
    while min(img.size) >= 2 * image_size:
        width, height = img.size
        img = img.resize((width // 2, height // 2), resample=Image.BOX)

    ratio = image_size / min(img.size)
    width, height = img.size
    img = img.resize((round(width * ratio), round(height * ratio)), resample=Image.BICUBIC)

    pixels = np.array(img)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top:top + image_size, left:left + image_size]
    return Image.fromarray(window)


#################################################################################
#                                  Training Loop                                #
#################################################################################

def _load_pretrained_backbone(model, args, rank, logger):
    """
    Initialize ``model`` from a pre-trained DiT checkpoint before its backbone is frozen.

    The checkpoint is ``args.ckpt`` when that attribute exists and is not None (a local
    path or an official checkpoint name understood by ``download.find_model``). Otherwise
    the official ``DiT-XL-2-{image_size}x{image_size}.pt`` is used, which exists only
    for DiT-XL/2. For any other model the backbone is left randomly initialized and a
    warning is logged.

    Tensors whose key is absent from ``model`` or whose shape differs are skipped. One
    example is the x2 branch, which official checkpoints do not contain. Skipped model
    tensors keep their constructor initialization.
    """
    from download import find_model  # DiT repo helper: downloads official ckpts or loads local ones

    ckpt = getattr(args, "ckpt", None)
    if ckpt is None:
        if args.model != "DiT-XL/2":
            logger.warning(f"No pre-trained checkpoint available for {args.model}: "
                           f"the frozen backbone stays randomly initialized.")
            return
        ckpt = f"DiT-XL-2-{args.image_size}x{args.image_size}.pt"

    # Rank 0 resolves (and possibly downloads) first so ranks do not race on the same file.
    if rank == 0:
        state_dict = find_model(ckpt)
    dist.barrier()
    if rank != 0:
        state_dict = find_model(ckpt)

    own_state = model.state_dict()
    compatible = {k: v for k, v in state_dict.items()
                  if k in own_state and own_state[k].shape == v.shape}
    model.load_state_dict(compatible, strict=False)

    skipped = sorted(set(state_dict) - set(compatible))
    not_loaded = sorted(k for k in set(own_state) - set(compatible) if not k.startswith("x2_"))
    logger.info(f"Loaded {len(compatible)}/{len(own_state)} tensors from {ckpt}")
    if skipped:
        logger.warning(f"Skipped {len(skipped)} checkpoint tensors (missing/shape mismatch): {skipped[:10]}")
    if not_loaded:
        logger.warning(f"{len(not_loaded)} non-x2 tensors were NOT loaded and stay at init: {not_loaded[:10]}")


def _freeze_all_but_x2(model, logger):
    """
    Freeze every parameter of ``model`` except the x2 branch.

    Unfrozen: ``x2_embedder``, ``x2_vit_block`` and, when not None, ``x2_vit_proj_in`` /
    ``x2_vit_proj_out``. Logs the trainable/total parameter counts.
    """
    logger.info("Freezing main DiT model...")
    requires_grad(model, False)

    logger.info("Unfreezing x2 components...")
    requires_grad(model.x2_embedder, True)
    requires_grad(model.x2_vit_block, True)
    for proj_name in ("x2_vit_proj_in", "x2_vit_proj_out"):
        proj = getattr(model, proj_name)
        if proj is not None:
            logger.info(f"Unfreezing {proj_name}...")
            requires_grad(proj, True)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable Parameters: {trainable_params:,} / {total_params:,} ({trainable_params/total_params:.2%})")


def _build_class_subset(args, transform, logger):
    """
    Build an ImageFolder over ``args.data_path`` restricted to ``args.classes``.

    ImageFolder numbers class folders in *lexicographic* order. With un-padded numeric
    folder names ("0", "1", "10", "100", ...), that index is not the ImageNet class id.
    When every folder name is a decimal integer, that integer is therefore used as the
    label. Non-numeric folder names (e.g. WNIDs) keep ImageFolder's sorted-order index.

    Raises:
        ValueError: if a requested class is outside [0, num_classes) or no image matches.
    """
    full_dataset = ImageFolder(args.data_path, transform=transform)
    if all(name.isdecimal() for name in full_dataset.classes):
        idx_to_label = {idx: int(name) for name, idx in full_dataset.class_to_idx.items()}
        # DatasetFolder.__getitem__ reads labels from .samples; keep .imgs/.targets in sync.
        full_dataset.samples = [(path, idx_to_label[idx]) for path, idx in full_dataset.samples]
        full_dataset.imgs = full_dataset.samples
        full_dataset.targets = [label for _, label in full_dataset.samples]
        logger.info("Class folders are numeric: using folder names as class labels.")

    out_of_range = [c for c in args.classes if not 0 <= c < args.num_classes]
    if out_of_range:
        raise ValueError(f"Classes {out_of_range} are outside [0, {args.num_classes}).")

    logger.info(f"Filtering dataset for classes: {args.classes}")
    wanted = set(args.classes)
    selected_indices = [i for i, label in enumerate(full_dataset.targets) if label in wanted]
    if not selected_indices:
        raise ValueError(f"No images found for classes {args.classes}. Check your dataset or class indices.")

    dataset = Subset(full_dataset, selected_indices)
    logger.info(f"Filtered Dataset contains {len(dataset):,} images (from {len(full_dataset)} total)")
    return dataset


def _save_checkpoint(model, ema, opt, args, train_steps, checkpoint_dir, logger):
    """
    Save ``{checkpoint_dir}/{train_steps:07d}.pt`` (call on rank 0 only).

    The "model" entry holds only the trainable (x2) parameters, to save space. The "ema"
    entry holds the full EMA state dict, so standard DiT inference code can load it directly.
    """
    trainable_keys = {k for k, p in model.module.named_parameters() if p.requires_grad}
    model_state = {k: v for k, v in model.module.state_dict().items() if k in trainable_keys}
    checkpoint = {
        "model": model_state,
        "ema": ema.state_dict(),
        "opt": opt.state_dict(),
        "args": args,
        "x2_finetune_only": True,
    }
    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
    torch.save(checkpoint, checkpoint_path)
    logger.info(f"Saved checkpoint to {checkpoint_path} (Model contains only trainable params)")


def main(args):
    """
    Fine-tune only the x2 branch of a DiT model on a subset of ImageNet classes.

    The DiT backbone is initialized from a pre-trained checkpoint (official DiT-XL/2 by
    default, or ``args.ckpt`` when set) and frozen; only the x2 modules are optimized.
    Expects the env:// rendezvous variables that torchrun sets (one process per GPU).
    Checkpoints are written every ``args.ckpt_every`` steps and once more at the end of
    training if the last step was not already saved.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # Setup DDP:
    dist.init_process_group("nccl")
    assert args.global_batch_size % dist.get_world_size() == 0, "Batch size must be divisible by world size."
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()  # local GPU index on this node
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # Setup an experiment folder (rank 0 only; checkpoint_dir is only used by rank 0):
    checkpoint_dir = None
    if rank == 0:
        os.makedirs(args.results_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        experiment_index = len(glob(f"{args.results_dir}/*"))
        model_string_name = args.model.replace("/", "-")  # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
        experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}-x2-finetune"
        checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
    else:
        logger = create_logger(None)

    # Create model (parameter initialization is done within the DiT constructor):
    assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.image_size // 8
    model = DiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes
    )
    # The backbone is frozen below, so it must hold pre-trained weights, not random init.
    _load_pretrained_backbone(model, args, rank, logger)
    _freeze_all_but_x2(model, logger)

    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
    requires_grad(ema, False)
    # device_ids takes the *local* GPU index; the global rank only equals it on a single node.
    model = DDP(model.to(device), device_ids=[device])
    diffusion = create_diffusion(timestep_respacing="")  # default: 1000 steps, linear noise schedule
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)

    # Setup optimizer (constant lr 1e-4 as in the DiT paper); only trainable params are passed:
    opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4, weight_decay=0)

    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    dataset = _build_class_subset(args, transform, logger)

    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=True,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=int(args.global_batch_size // dist.get_world_size()),
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )

    # Prepare models for training:
    update_ema(ema, model.module, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    # Variables for monitoring/logging purposes:
    train_steps = 0
    log_steps = 0
    running_loss = 0
    start_time = time()

    logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(args.epochs):
        sampler.set_epoch(epoch)
        logger.info(f"Beginning epoch {epoch}...")
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            with torch.no_grad():
                # Map input images to latent space + normalize latents:
                x = vae.encode(x).latent_dist.sample().mul_(0.18215)
            t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
            model_kwargs = dict(y=y)
            loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
            loss = loss_dict["loss"].mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            update_ema(ema, model.module)

            # Log loss values:
            running_loss += loss.item()
            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                # Reset monitoring variables:
                running_loss = 0
                log_steps = 0
                start_time = time()

            # Save DiT checkpoint:
            if train_steps % args.ckpt_every == 0 and train_steps > 0:
                if rank == 0:
                    _save_checkpoint(model, ema, opt, args, train_steps, checkpoint_dir, logger)
                dist.barrier()

    # Always persist the final weights; short runs may never reach a ckpt_every boundary.
    # train_steps is identical on every rank (DistributedSampler pads, drop_last=True),
    # so all ranks agree on whether to enter the barrier.
    if train_steps > 0 and train_steps % args.ckpt_every != 0:
        if rank == 0:
            _save_checkpoint(model, ema, opt, args, train_steps, checkpoint_dir, logger)
        dist.barrier()

    model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...

    logger.info("Done!")
    cleanup()


if __name__ == "__main__":
    # Defaults reproduce the DiT-XL/2 paper hyperparameters (except training length);
    # --classes selects the ImageNet class indices to fine-tune on.
    parser = argparse.ArgumentParser()
    cli_spec = (
        ("--data-path", dict(type=str, required=True)),
        ("--results-dir", dict(type=str, default="results")),
        ("--model", dict(type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")),
        ("--image-size", dict(type=int, choices=[256, 512], default=256)),
        ("--num-classes", dict(type=int, default=1000)),
        ("--epochs", dict(type=int, default=1400)),
        ("--global-batch-size", dict(type=int, default=256)),
        ("--global-seed", dict(type=int, default=0)),
        ("--vae", dict(type=str, choices=["ema", "mse"], default="ema")),  # Choice doesn't affect training
        ("--num-workers", dict(type=int, default=4)),
        ("--log-every", dict(type=int, default=100)),
        ("--ckpt-every", dict(type=int, default=50_000)),
        ("--classes", dict(type=int, nargs="+", default=list(range(10)),
                           help="List of ImageNet class indices to train on (default: 0-9)")),
    )
    for flag, options in cli_spec:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)


Writing train_x2_finetune.py


In [13]:
%%writefile test_freezing.py

import torch
import sys
import os

# Add repo root to path
sys.path.append(os.getcwd())

from models import DiT_models

def verify_freezing():
    """
    Check that the x2 fine-tuning freeze pattern leaves exactly the x2 branch trainable.

    Builds DiT-XL/2 and applies the same freezing as train_x2_finetune.py. Then asserts,
    for EVERY named parameter, the following: parameters under ``x2_embedder``,
    ``x2_vit_block``, or the non-None ``x2_vit_proj_in`` / ``x2_vit_proj_out`` are
    trainable, and all others are frozen.

    Raises:
        AssertionError: on the first parameter whose requires_grad is wrong.
    """
    print("Initializing model...")
    model = DiT_models['DiT-XL/2'](
        input_size=32,
        num_classes=1000
    )

    print("Applying freezing logic...")
    # --- FREEZING LOGIC MIRRORED FROM train_x2_finetune.py (keep in sync) ---
    for p in model.parameters():
        p.requires_grad = False

    unfrozen_modules = ["x2_embedder", "x2_vit_block"]
    for proj_name in ("x2_vit_proj_in", "x2_vit_proj_out"):
        if getattr(model, proj_name) is not None:
            unfrozen_modules.append(proj_name)

    for module_name in unfrozen_modules:
        for p in getattr(model, module_name).parameters():
            p.requires_grad = True
    # -------------------------------------------------------------------------

    print("Verifying parameters...")
    # Exhaustive check instead of a few hand-picked tensors (which also missed x2_vit_proj_out).
    trainable_prefixes = tuple(f"{module_name}." for module_name in unfrozen_modules)
    for name, p in model.named_parameters():
        expected = name.startswith(trainable_prefixes)
        state = "trainable" if expected else "frozen"
        assert p.requires_grad == expected, f"{name} {tuple(p.shape)} should be {state}!"
    assert any(p.requires_grad for p in model.parameters()), "No trainable parameters at all!"

    print("SUCCESS: Freezing logic verified correctly.")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total Params: {total_params:,}")
    print(f"Trainable Params: {trainable_params:,} ({trainable_params/total_params:.2%})")

if __name__ == "__main__":
    verify_freezing()


Writing test_freezing.py


In [14]:
# Sanity-check the freeze pattern before a long run: builds DiT-XL/2 with the x2 branch
# (this loads the pre-trained ViT weights) and asserts that only the x2 modules are trainable.
!python test_freezing.py

Initializing model...
Creating DiT-XL/2 model...
[DiT] Loading pre-trained ViT model: vit_large_patch16_224
[DiT] Pre-trained ViT block: dim=1024, num_heads=16
[DiT] Target x2_vit_block: dim=1152, num_heads=16
[DiT] Dimensions don't match. Creating block with pretrained dimensions and projection layers...
[DiT] ✓ Successfully loaded pre-trained ViT weights into block with dim=1024
[DiT] Using projection layers to adapt dimensions: 1152 -> 1024 -> 1152
Applying freezing logic...
Verifying parameters...
SUCCESS: Freezing logic verified correctly.
Total Params: 690,162,208
Trainable Params: 15,032,576 (2.18%)


In [15]:
# Download ImageNet Validation Dataset (ILSVRC2012)
# This will download ~6.3GB and organize it into class folders

import os
import tarfile
from pathlib import Path
import xml.etree.ElementTree as ET

# Download validation set
print('Downloading ImageNet validation set (~6.3GB)...')
!wget -nc https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar

# Download validation ground truth annotations
print('Downloading validation ground truth...')
!wget -nc https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz

# Extract validation images
print('Extracting validation images...')
os.makedirs('imagenet_val', exist_ok=True)
!tar -xf ILSVRC2012_img_val.tar -C imagenet_val

# Extract devkit to get ground truth
print('Extracting devkit...')
!tar -xzf ILSVRC2012_devkit_t12.tar.gz

# Parse ground truth from devkit
print('Parsing validation ground truth...')
gt_file = 'ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt'
with open(gt_file, 'r') as f:
    # Ground truth file has 1-indexed class IDs, convert to 0-indexed
    labels = [int(line.strip()) - 1 for line in f]

# Organize into class folders
print('Organizing into class folders...')
val_dir = Path('imagenet_val')
organized_dir = Path('imagenet_val_organized')

# Create class directories
for class_id in set(labels):
    (organized_dir / str(class_id)).mkdir(parents=True, exist_ok=True)

# Move images to class folders
val_images = sorted(val_dir.glob('ILSVRC2012_val_*.JPEG'))
for idx, img_path in enumerate(val_images):
    class_id = labels[idx]
    target_path = organized_dir / str(class_id) / img_path.name
    img_path.rename(target_path)

print(f'✓ ImageNet validation set organized into {len(set(labels))} class folders')
print(f'Total images: {len(val_images)}')
print('Dataset ready at: ./imagenet_val_organized')

Downloading ImageNet validation set (~6.3GB)...
--2025-12-07 20:14:12--  https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar
Resolving image-net.org (image-net.org)... 171.64.68.16
Connecting to image-net.org (image-net.org)|171.64.68.16|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6744924160 (6.3G) [application/x-tar]
Saving to: ‘ILSVRC2012_img_val.tar’


2025-12-07 20:18:53 (22.9 MB/s) - ‘ILSVRC2012_img_val.tar’ saved [6744924160/6744924160]

Downloading validation ground truth...
--2025-12-07 20:18:53--  https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz
Resolving image-net.org (image-net.org)... 171.64.68.16
Connecting to image-net.org (image-net.org)|171.64.68.16|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2568145 (2.4M) [application/x-gzip]
Saving to: ‘ILSVRC2012_devkit_t12.tar.gz’


2025-12-07 20:18:54 (18.3 MB/s) - ‘ILSVRC2012_devkit_t12.tar.gz’ saved [2568145/2568145]

Extracting validation i

In [35]:
# Run Fine-Tuning
# Using ImageNet validation set organized above
!torchrun --nnodes=1 --nproc_per_node=1 train_x2_finetune.py \
    --model DiT-XL/2 \
    --data-path ./imagenet_val_organized \
    --classes 972 973 974 975 976 \
    --epochs 2 \
    --global-batch-size 4 \
    --log-every 1   # Changed from 10 to 1

  self.setter(val)
2025-12-07 20:34:12.797882: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765139652.819410    7704 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765139652.825587    7704 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1765139652.840981    7704 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765139652.841006    7704 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765139652.841010    7704 computation_placer.cc:177] com

In [40]:
# # Check if checkpoints exist
# !ls  results/*/checkpoints/

# # Check training logs
!cat results/002-DiT-XL-2-x2-finetune/log.txt
!ls results/002-DiT-XL-2-x2-finetune/

checkpoints  log.txt
