# Scalable Diffusion Models with Transformer (DiT)

This notebook samples from pre-trained DiT models. DiTs are class-conditional latent diffusion models trained on ImageNet that use transformers in place of U-Nets as the DDPM backbone. DiT outperforms all prior diffusion models on the ImageNet benchmarks.

[Project Page](https://www.wpeebles.com/DiT) | [HuggingFace Space](https://huggingface.co/spaces/wpeebles/DiT) | [Paper](http://arxiv.org/abs/2212.09748) | [GitHub](github.com/facebookresearch/DiT)

In [1]:
!pip install diffusers timm --upgrade

Defaulting to user installation because normal site-packages is not writeable
Collecting diffusers
  Downloading diffusers-0.30.2-py3-none-any.whl.metadata (18 kB)
Collecting timm
  Downloading timm-1.0.9-py3-none-any.whl.metadata (42 kB)
Collecting importlib-metadata (from diffusers)
  Downloading importlib_metadata-8.4.0-py3-none-any.whl.metadata (4.7 kB)
Collecting huggingface-hub>=0.23.2 (from diffusers)
  Downloading huggingface_hub-0.24.6-py3-none-any.whl.metadata (13 kB)
Collecting regex!=2019.12.17 (from diffusers)
  Downloading regex-2024.7.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)
Collecting safetensors>=0.3.1 (from diffusers)
  Downloading safetensors-0.4.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Collecting zipp>=0.5 (from importlib-metadata->diffusers)
  Downloading zipp-3.20.1-py3-none-any.whl.metadata (3.7 kB)
Downloading diffusers-0.30.2-py3-none-any.whl (2.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━

In [1]:
import torch
# from torchvision.utils import save_image
from diffusion import create_diffusion
from models import DiT_XL_8
from PIL import Image
from IPython.display import display
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("GPU not found. Using CPU instead.")

In [2]:
image_size = 256 #@param [256, 512]
# Load model:
model = DiT_XL_8(input_size=image_size)
state_dict = torch.load('results/004-DiT-XL-8/checkpoints/latest.pt', map_location='cpu')
model.load_state_dict(state_dict['model'])
model = model.to(device)
model.eval() # important!

DiT(
  (x_embedder): PatchEmbed(
    (proj): Conv2d(5, 1152, kernel_size=(8, 8), stride=(8, 8))
    (norm): Identity()
  )
  (t_embedder): TimestepEmbedder(
    (mlp): Sequential(
      (0): Linear(in_features=256, out_features=1152, bias=True)
      (1): SiLU()
      (2): Linear(in_features=1152, out_features=1152, bias=True)
    )
  )
  (y_embedder): LabelEmbedder(
    (embedding_table): Embedding(1001, 1152)
  )
  (blocks): ModuleList(
    (0-27): 28 x DiTBlock(
      (norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=False)
      (attn): Attention(
        (qkv): Linear(in_features=1152, out_features=3456, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1152, out_features=1152, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=False)
      (mlp): Mlp(
        (fc1): Linear(in_features=11

In [4]:
# Set user inputs:
seed = 0 #@param {type:"number"}
torch.manual_seed(seed)
num_sampling_steps = 1000 #@param {type:"slider", min:0, max:1000, step:1}
# cfg_scale = 4 #@param {type:"slider", min:1, max:10, step:0.1}
cfg_scale = 1
class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:"raw"}
samples_per_row = 4 #@param {type:"number"}

# Create diffusion object:
diffusion = create_diffusion(str(num_sampling_steps))

# Create sampling noise:
n = len(class_labels)
z = torch.randn(n, 5, image_size, image_size, device=device)
y = torch.tensor(class_labels, device=device)

# Setup classifier-free guidance:
z = torch.cat([z, z], 0)
y_null = torch.tensor([1000] * n, device=device)
y = torch.cat([y, y_null], 0)
model_kwargs = dict(y=y, cfg_scale=cfg_scale)

# Sample images:
samples = diffusion.p_sample_loop(
    model.forward_with_cfg, z.shape, z, clip_denoised=False, 
    model_kwargs=model_kwargs, progress=True, device=device
)
samples, _ = samples.chunk(2, dim=0)  # Remove null class samples

# Save and display images:
# save_image(samples, "sample.png", nrow=int(samples_per_row), 
#            normalize=True, value_range=(-1, 1))
# samples = Image.open("sample.png")
# display(samples)

points = samples.view(samples.shape[0], samples.shape[1], -1).add(1).mul(.5) # in [0,1]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [5]:
rgb = points[:, :3].clamp(0,1).cpu()
pos = points[:, 3:].clamp(0,1).cpu()

In [None]:
import matplotlib.pyplot as plt

idx = 4
plt.scatter(pos[idx,0], pos[idx,1], c=rgb[idx].T, s=1)

<matplotlib.collections.PathCollection at 0x7f6820631cf0>

In [None]:
import matplotlib.pyplot as plt

coords = samples[0, -2:].view(2,-1).cpu()

plt.scatter(coords[0], coords[1], )