# Scalable Diffusion Models with Transformers (DiT)

This notebook samples from pre-trained DiT models. DiTs are class-conditional latent diffusion models trained on ImageNet that use transformers in place of U-Nets as the DDPM backbone. DiT outperforms all prior diffusion models on the ImageNet benchmarks.

[Project Page](https://www.wpeebles.com/DiT) | [HuggingFace Space](https://huggingface.co/spaces/wpeebles/DiT) | [Paper](http://arxiv.org/abs/2212.09748) | [GitHub](https://github.com/facebookresearch/DiT)

# 1. Setup

We recommend using GPUs (Runtime > Change runtime type > Hardware accelerator > GPU). Run this cell to import the DiT dependencies and set up PyTorch; it assumes the DiT GitHub repo has already been cloned and is on the Python path. You only have to run this once.

In [1]:
import os
# DiT imports:
import torch
from torchvision.utils import save_image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from download import find_model
from models import DiT_XL_2
from PIL import Image
from IPython.display import display
# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Prefer the GPU when one is visible; otherwise fall back to the CPU and say so.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    print("GPU not found. Using CPU instead.")

  from .autonotebook import tqdm as notebook_tqdm


# Download DiT-XL/2 Models

You can choose between a 512x512 model and a 256x256 model. You can also swap out the LDM VAE.

In [2]:
image_size = 256 #@param [256, 512]
dtype = torch.float16 #@param [torch.float16, torch.float32, torch.bfloat16]
num_classes = 1000 #@param [100, 1000]
vae_model = "stabilityai/sd-vae-ft-ema" #@param ["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"]
# Side length of the latent grid: 1/8 of the pixel resolution.
latent_size = int(image_size) // 8

# Resolve which checkpoint to load.
#   * DIT_CKPT_DIR  overrides the directory holding the pre-trained weights.
#   * DIT_CKPT_PATH overrides the full file path (e.g. a custom training
#     checkpoint such as ".../pretrained_models/0001000.pt").
# The default filename follows `image_size`, so selecting 512 above no longer
# silently loads the 256x256 weights. With nothing set and image_size = 256 the
# path is identical to the previously hard-coded one.
ckpt_dir = os.environ.get(
    "DIT_CKPT_DIR", "C:/Users/wg19671/Downloads/DiT/pretrained_models"
)
ckpt_path = os.environ.get(
    "DIT_CKPT_PATH",
    f"{ckpt_dir}/DiT-XL-2-{int(image_size)}x{int(image_size)}.pt",
)

# Load model:
model = DiT_XL_2(
    input_size=latent_size,
    num_classes=num_classes,
    enable_flashattn=True,
    dtype=dtype
).to(device).to(dtype)
state_dict = find_model(ckpt_path)

model.load_state_dict(state_dict)
model.eval() # important!
# The VAE is moved to the device but keeps its default dtype (it is not cast to `dtype`).
vae = AutoencoderKL.from_pretrained(vae_model).to(device)

AttributeError: 'PatchEmbed' object has no attribute 'embedding_table'

# 2. Sample from Pre-trained DiT Models

You can customize several sampling options. For the full list of ImageNet classes, [check out this](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a).

In [None]:
# User-tunable sampling options (the #@param comments drive the Colab form).
seed = 0 #@param {type:"number"}
torch.manual_seed(seed)
num_sampling_steps = 50 #@param {type:"slider", min:0, max:1000, step:1}
cfg_scale = 4 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:"raw"}
# class_labels = tuple(map(int, torch.randint(0, 1000, (8,)).tolist()))
print(class_labels)
samples_per_row = 4 #@param {type:"number"}

# Diffusion process configured for the requested number of sampling steps.
diffusion = create_diffusion(str(num_sampling_steps))

# One 4-channel latent noise map and one label per requested class.
n = len(class_labels)
z = torch.randn(n, 4, latent_size, latent_size, device=device).to(dtype)
y = torch.tensor(class_labels, device=device)

# Classifier-free guidance: duplicate the batch and give the second half the
# extra "null" class index (num_classes), so the batch holds 2 * n entries.
z = torch.cat((z, z), dim=0)
y_null = torch.tensor((num_classes,) * n, device=device)
y = torch.cat((y, y_null), dim=0)
model_kwargs = {"y": y, "cfg_scale": cfg_scale}

# Run the reverse process over the doubled batch.
samples = diffusion.p_sample_loop(
    model.forward_with_cfg,
    z.shape,
    z,
    clip_denoised=False,
    model_kwargs=model_kwargs,
    progress=True,
    device=device,
)
# Keep only the first half (the class-conditioned samples).
samples, _ = samples.chunk(2, dim=0)
# NOTE(review): 0.18215 is presumably the VAE latent scaling factor — confirm.
samples = vae.decode(samples / 0.18215).sample

# # Save and display images:
# save_image(samples, "sample.png", nrow=int(samples_per_row), 
#            normalize=True, value_range=(-1, 1))
# samples = Image.open("sample.png")
# display(samples)

(207, 360, 387, 974, 88, 979, 417, 279)


100%|██████████| 50/50 [00:02<00:00, 20.88it/s]


In [None]:
# Scratch experiment: replace the class-label embedder with a patch embedder so
# that the conditioning signal `y` can be an image-shaped latent instead of an
# integer class id. NOTE: this cell rebinds x, y and t to their embedded
# versions, so it is not idempotent — later cells see the embedded tensors.
import torch
import torch.nn as nn

num_classes = 1000
hidden_size = 1152
dropout_prob = 0.1
patch_size = 2


from timm.models.vision_transformer import PatchEmbed
from models import TimestepEmbedder, LabelEmbedder


# Original label embedder, kept for comparison with the patch-based one below.
old_y_embedder = LabelEmbedder(num_classes, hidden_size, dropout_prob)
x_embedder = PatchEmbed(
    img_size=32, patch_size=patch_size, in_chans=4, embed_dim=hidden_size, bias=True
)
t_embedder = TimestepEmbedder(hidden_size)
# Same configuration as x_embedder, but with its own weights, applied to the label.
y_embedder = PatchEmbed(
    img_size=32, patch_size=patch_size, in_chans=4, embed_dim=hidden_size, bias=True
)

num_patches = x_embedder.num_patches
# Frozen all-zero positional embedding (placeholder; adding it changes nothing).
pos_embed = nn.Parameter(
    torch.zeros(1, num_patches, hidden_size), requires_grad=False
)


old_y = torch.randint(0, 1000, (128,)) # (B,) integer class ids in [0, 1000)
x = torch.rand(128,4,32,32) # B C H W (input)
y = torch.rand(128,4,32,32) # B C H W (label)
t = torch.randint(0, 1000, (x.shape[0],)) # (B,) integer timesteps in [0, 1000)

x = (
    x_embedder(x) + pos_embed
)  # (N, T, D), where T = H * W / patch_size ** 2
t = t_embedder(t, dtype=x.dtype)  # (N, D)
y = y_embedder(y)  # (N, T, D) — one embedding per patch, not a single (N, D) vector
# Remaining DiT forward pass, left disabled: with y now (N, T, D), `t + y`
# no longer produces the (N, D) conditioning vector the blocks expect.
# c = t + y  # (N, D)
# for block in blocks:
#     x = block(x, c)  # (N, T, D)
# x = final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)
# x = unpatchify(x)  # (N, out_channels, H, W)
# x

In [None]:
# Embed the integer class ids with the original LabelEmbedder for comparison.
# NOTE(review): train=True presumably enables label dropout (dropout_prob) —
# confirm against models.LabelEmbedder.
old_y1 = old_y_embedder(old_y, train=True)
old_y1.shape

torch.Size([128, 1152])

In [None]:
# Shapes of the embedded input, embedded label and timestep embedding, one per line.
print(x.shape, y.shape, t.shape, sep="\n")

torch.Size([128, 256, 1152])
torch.Size([128, 256, 1152])
torch.Size([128, 1152])


In [None]:
# Conditioning vector as in the original design: timestep embedding plus
# class-label embedding (both are (128, 1152) at this point).
c = t + old_y1
c.shape

torch.Size([128, 1152])

In [None]:
# NOTE(review): this raises "ValueError: not enough values to unpack (expected 4, got 3)"
# once the embedding cell has run, because `x` has been rebound to the already
# embedded 3-D tensor and is no longer the raw 4-D (B, C, H, W) input.
x_embedder(x).shape

ValueError: not enough values to unpack (expected 4, got 3)

In [3]:
import torch
import torch.nn.functional as F
from torch import nn


class PatchEmbed3D(nn.Module):
    """Video to Patch Embedding.

    Cuts a video tensor into non-overlapping 3D patches with a strided
    ``nn.Conv3d`` and projects every patch to ``embed_dim`` channels. Inputs
    whose depth/height/width are not multiples of the patch size are
    zero-padded at the end of the offending axis first.

    Args:
        patch_size (int | tuple[int, int, int]): Patch size along
            (depth, height, width). A single int is applied to all three
            axes. Default: (2, 4, 4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer class, built as
            ``norm_layer(embed_dim)``. Default: None
        flatten (bool): If True, return tokens shaped (B, N, embed_dim);
            otherwise keep the (B, embed_dim, D', H', W') grid. Default: True.

    Raises:
        ValueError: If ``patch_size`` is a sequence whose length is not 3.
    """

    def __init__(
        self,
        patch_size=(2, 4, 4),
        in_chans=3,
        embed_dim=96,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        # forward() indexes the patch size per axis, so a bare int (which the
        # docstring permits and Conv3d accepts) must be expanded to a triple.
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size, patch_size)
        else:
            patch_size = tuple(patch_size)
            if len(patch_size) != 3:
                raise ValueError(
                    f"patch_size must be an int or a 3-element sequence, got {patch_size!r}"
                )
        self.patch_size = patch_size
        self.flatten = flatten

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Embed a video tensor.

        Args:
            x (torch.Tensor): Input of shape (B, in_chans, D, H, W).

        Returns:
            torch.Tensor: (B, N, embed_dim) when ``flatten`` is True, where N is
            the number of patches; otherwise (B, embed_dim, D', H', W').
        """
        _, _, D, H, W = x.size()
        p_d, p_h, p_w = self.patch_size

        # Amount needed to reach the next multiple of the patch size
        # (0 when the axis is already aligned).
        pad_d = (-D) % p_d
        pad_h = (-H) % p_h
        pad_w = (-W) % p_w
        if pad_d or pad_h or pad_w:
            # F.pad lists (before, after) pairs starting from the last axis:
            # width, then height, then depth.
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))

        x = self.proj(x)  # (B C T H W)
        if self.norm is not None:
            # Normalise over channels: go to (B, N, C), apply the norm, then
            # restore the (B, C, D, H, W) grid.
            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
        return x
    


In [16]:
# Stand-in for a VAE-encoded video label, laid out as (B, C, frames, H, W).
vid_label = torch.rand((1, 4, 10, 32, 32))

# 3D patch embedder consuming the (B, C, frames, H, W) layout.
video_label_embedder = PatchEmbed3D(in_chans=4, embed_dim=1152, patch_size=(2, 4, 4))

# Token sequence shape (B, N, D); here N = (10 / 2) * (32 / 4) * (32 / 4) = 320.
video_label_embedder(vid_label).shape

torch.Size([1, 320, 1152])

In [46]:
# Toy shapes: a per-sample vector t of (N, D) and a per-token tensor y of (N, T, D).
t = torch.rand([84, 384])
y = torch.rand([84, 16, 384])

# Tile t once per token -> (T, N, D), then swap the first two axes -> (N, T, D).
t = t.repeat(y.size(1), 1, 1).transpose(0, 1)
t.shape

torch.Size([84, 16, 384])