# Scalable Diffusion Models with Transformers (DiT)

This notebook samples from pre-trained DiT models. DiTs are class-conditional latent diffusion models trained on ImageNet that use transformers in place of U-Nets as the DDPM backbone. DiT outperforms all prior diffusion models on the ImageNet benchmarks.

[Project Page](https://www.wpeebles.com/DiT) | [HuggingFace Space](https://huggingface.co/spaces/wpeebles/DiT) | [Paper](https://arxiv.org/abs/2212.09748) | [GitHub](https://github.com/facebookresearch/DiT)

# 1. Setup

We recommend using GPUs (Runtime > Change runtime type > Hardware accelerator > GPU). Run this cell to clone the DiT GitHub repo and set up PyTorch. You only have to run this once.

In [1]:
# Clone the DiT repo, make it the working directory, install the extra
# dependencies, then import everything the sampling cells below use.
# NOTE(review): not idempotent — after the chdir below, re-running this cell
# would clone into DiT/DiT and chdir into it; run it only once per kernel.
!git clone https://github.com/facebookresearch/DiT.git
# `DiT` here is the freshly cloned directory; it is not otherwise used in this cell.
import DiT, os
os.chdir('DiT')
# NOTE(review): setting PYTHONPATH at runtime presumably only reaches child
# processes (e.g. later `!` commands); the repo-local imports below appear to
# resolve because the working directory is now the repo root — confirm.
os.environ['PYTHONPATH'] = '/env/python:/content/DiT'
# NOTE(review): unpinned install; `%pip` would target the kernel's environment.
!pip install diffusers timm --upgrade
# DiT imports:
import torch
from torchvision.utils import save_image
# `diffusion`, `download` and `models` are modules from the cloned DiT repo.
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from download import find_model
from models import DiT_XL_2
from PIL import Image
from IPython.display import display
# Sampling only: disable autograd globally for the rest of the session.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("GPU not found. Using CPU instead.")

Cloning into 'DiT'...
remote: Enumerating objects: 102, done.[K
remote: Counting objects: 100% (79/79), done.[K
remote: Compressing objects: 100% (46/46), done.[K
remote: Total 102 (delta 55), reused 33 (delta 33), pack-reused 23[K
Receiving objects: 100% (102/102), 6.37 MiB | 10.69 MiB/s, done.
Resolving deltas: 100% (55/55), done.
Collecting timm
  Downloading timm-1.0.3-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
Downloading timm-1.0.3-py3-none-any.whl (2.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: timm
  Attempting uninstall: timm
    Found existing installation: timm 0.9.16
    Uninstalling timm-0.9.16:
      Successfully uninstalled timm-0.9.16
Successfully installed timm-1.0.3
GPU not found. Using CPU instead.


  return torch._C._cuda_getDeviceCount() > 0


# Download DiT-XL/2 Models

You can choose between a 512x512 model and a 256x256 model. You can swap out the LDM VAE, too.

In [2]:
image_size = 256 #@param [256, 512]
vae_model = "stabilityai/sd-vae-ft-ema" #@param ["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"]

# Side length of the latent grid the DiT operates on: image_size // 8.
latent_size = int(image_size) // 8

# Build DiT-XL/2 for that latent size on the chosen device, then load the
# checkpoint whose file name matches the selected resolution.
model = DiT_XL_2(
    input_size=latent_size,
).to(device)
state_dict = find_model(
    f"DiT-XL-2-{image_size}x{image_size}.pt"
)
model.load_state_dict(state_dict)
model.eval()  # important! (evaluation mode before sampling)

# Pre-trained autoencoder selected in the form above, moved to the same device.
vae = AutoencoderKL.from_pretrained(
    vae_model
).to(device)

Downloading https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt to pretrained_models/DiT-XL-2-256x256.pt


  0%|          | 0/2700611775 [00:00<?, ?it/s]

# 2. Sample from Pre-trained DiT Models

You can customize several sampling options. For the full list of ImageNet classes, [see this list](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a).

In [3]:
# Set user inputs:
seed = 0 #@param {type:"number"}
torch.manual_seed(seed)
num_sampling_steps = 250 #@param {type:"slider", min:0, max:1000, step:1}
cfg_scale = 4 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:"raw"}
samples_per_row = 4 #@param {type:"number"}

# The form slider allows 0, which leaves no steps to sample with;
# fail early with a clear message instead of deep inside the sampler.
assert int(num_sampling_steps) >= 1, "num_sampling_steps must be at least 1"

# Create diffusion object (the step count is passed as a string):
diffusion = create_diffusion(str(num_sampling_steps))

# Create sampling noise: one 4-channel latent of size latent_size x latent_size
# per requested class label.
n = len(class_labels)
z = torch.randn(n, 4, latent_size, latent_size, device=device)
y = torch.tensor(class_labels, device=device)

# Setup classifier-free guidance: duplicate the noise and pair the second copy
# with label 1000 — presumably the "null"/unconditional class (one past the
# 1000 ImageNet labels); confirm against the model definition.
z = torch.cat([z, z], 0)
y_null = torch.tensor([1000] * n, device=device)
y = torch.cat([y, y_null], 0)
model_kwargs = dict(y=y, cfg_scale=cfg_scale)

# Sample images:
samples = diffusion.p_sample_loop(
    model.forward_with_cfg, z.shape, z, clip_denoised=False,
    model_kwargs=model_kwargs, progress=True, device=device
)
samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
# NOTE(review): 0.18215 presumably undoes the latent scaling used with this VAE — confirm.
samples = vae.decode(samples / 0.18215).sample

# Save and display images. `samples` stays the decoded tensor; the PIL image
# read back from disk gets its own name instead of shadowing it.
save_image(samples, "sample.png", nrow=int(samples_per_row),
           normalize=True, value_range=(-1, 1))
sample_grid = Image.open("sample.png")
display(sample_grid)

  0%|          | 0/250 [00:00<?, ?it/s]

PYTORCH

In [2]:
import torch as th
import torch.nn.functional as F

In [11]:
# Demonstrate that softmax over dim=2 turns every slice along that axis into a
# probability distribution: summing over the same dim gives all ones.
a = th.rand(1, 2, 3, 4)
print(a)
probs = F.softmax(a, dim=2)
print(probs)
# Reducing the softmax dim leaves shape (1, 2, 4), every entry ~1.0.
k = th.sum(probs, dim=2)
print(k)
assert th.allclose(k, th.ones_like(k)), "softmax slices along dim=2 should sum to 1"

tensor([[[[0.5582, 0.8384, 0.8436, 0.2354],
          [0.2947, 0.9504, 0.3292, 0.8884],
          [0.1030, 0.4064, 0.5709, 0.3271]],

         [[0.3816, 0.7789, 0.2523, 0.6280],
          [0.1720, 0.6519, 0.8004, 0.7060],
          [0.1095, 0.9498, 0.1172, 0.2554]]]])
tensor([[[[0.4162, 0.3613, 0.4239, 0.2489],
          [0.3198, 0.4041, 0.2534, 0.4783],
          [0.2640, 0.2346, 0.3227, 0.2728]],

         [[0.3887, 0.3261, 0.2775, 0.3610],
          [0.3152, 0.2872, 0.4801, 0.3903],
          [0.2961, 0.3868, 0.2424, 0.2487]]]])
tensor([[[1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000]]])


In [10]:
# Batch, time and feature sizes shared by every example tensor below.
N, T, D = 2, 4, 3

t1 = th.rand(N, T, D)        # uniform samples in [0, 1)
t2 = th.randn(N, T, D)       # standard-normal samples
t3 = th.zeros(N, T, D)       # all zeros
t4 = th.ones(N, T, D)        # all ones
# NOTE(review): integer fill value — presumably yields an integer dtype in
# recent PyTorch, unlike the float tensors above; confirm if it matters.
t5 = th.full((N, T, D), 6)
t6 = t4 * 3                  # scaled copy of the ones tensor

(t1, t2, t3, t4, t5, t6)

(tensor([[[0.3357, 0.0761, 0.1998],
          [0.6062, 0.0993, 0.9665],
          [0.9869, 0.7946, 0.7680],
          [0.7404, 0.5994, 0.1234]],
 
         [[0.7745, 0.2094, 0.2938],
          [0.5635, 0.9947, 0.3598],
          [0.3647, 0.1045, 0.3568],
          [0.2566, 0.8713, 0.9351]]]),
 tensor([[[-0.1619,  1.4103, -0.6718],
          [-1.4412, -0.5914,  1.9748],
          [-0.0396, -0.1573,  0.8775],
          [ 0.1963,  1.4566,  2.2303]],
 
         [[ 1.8923,  0.1595,  0.4503],
          [-0.0341, -0.5649,  0.4817],
          [ 0.9184, -0.6458,  1.2508],
          [ 0.0972, -0.6601, -1.7134]]]),
 tensor([[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]),
 tensor([[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
 
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 

In [24]:
# With T == 4, time index -4 is the first step; omitting the trailing `:` keeps all features.
(t1, t1[:, -4].shape, t1[:, -4])

(tensor([[[0.3357, 0.0761, 0.1998],
          [0.6062, 0.0993, 0.9665],
          [0.9869, 0.7946, 0.7680],
          [0.7404, 0.5994, 0.1234]],
 
         [[0.7745, 0.2094, 0.2938],
          [0.5635, 0.9947, 0.3598],
          [0.3647, 0.1045, 0.3568],
          [0.2566, 0.8713, 0.9351]]]),
 torch.Size([2, 3]),
 tensor([[0.3357, 0.0761, 0.1998],
         [0.7745, 0.2094, 0.2938]]))

In [5]:
# `nn` is not imported anywhere else in this notebook; without this the cell
# raises NameError on a fresh kernel.
import torch.nn as nn

# nn.Linear(in_features, out_features): this layer maps 6 inputs to 4 outputs,
# so its weight matrix has shape (out_features, in_features) == (4, 6).
input_size = 6
output_size = 4
l = nn.Linear(input_size, output_size, bias=True)
l.weight

Parameter containing:
tensor([[-0.2622,  0.3692,  0.0302,  0.0108, -0.3587,  0.2882],
        [ 0.1102,  0.2648, -0.1780, -0.0627, -0.3972, -0.3387],
        [-0.3600, -0.2968, -0.0796, -0.0439,  0.1984,  0.2068],
        [ 0.1410, -0.3960,  0.0128,  0.2042, -0.0569,  0.1763]],
       requires_grad=True)

In [2]:
import numpy as np

In [7]:
a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# axis=1 places the rows side by side: shape (3, 6).
c = np.concatenate([a, b], axis=1)
# reshape returns a new array rather than modifying `c`; the original discarded
# the result, so keep it under its own name.
c_flat = c.reshape(-1)

d = np.random.rand(3, 2, 4)
print(d)
# Default row-major (C-order) flattening: the last axis varies fastest,
# matching the printed nesting read left to right, top to bottom.
d = d.reshape(-1)
d


[[[0.99653848 0.74497918 0.74118636 0.55491099]
  [0.6008518  0.50982715 0.80114502 0.49731264]]

 [[0.1201585  0.5096922  0.25460918 0.82141725]
  [0.09198959 0.21342022 0.55442237 0.70965507]]

 [[0.76007615 0.65889005 0.84394497 0.17147846]
  [0.91701943 0.90906091 0.97572201 0.6849114 ]]]


array([0.99653848, 0.74497918, 0.74118636, 0.55491099, 0.6008518 ,
       0.50982715, 0.80114502, 0.49731264, 0.1201585 , 0.5096922 ,
       0.25460918, 0.82141725, 0.09198959, 0.21342022, 0.55442237,
       0.70965507, 0.76007615, 0.65889005, 0.84394497, 0.17147846,
       0.91701943, 0.90906091, 0.97572201, 0.6849114 ])

In [12]:
import random


# Show that re-seeding `random` before each shuffle reproduces the same order.
data_paths = ["path1", "path2", "path3", "path4", "path5"]
print(data_paths)

# First copy: seed, copy, shuffle.
random.seed(42)
a = data_paths[:]
print(a)
random.shuffle(a)
print(a)

# Second copy: copy, seed, shuffle — copying consumes no randomness, so the
# order matches the first copy.
b = list(data_paths)
print(b)
random.seed(42)
random.shuffle(b)
print(b)




['path1', 'path2', 'path3', 'path4', 'path5']
['path1', 'path2', 'path3', 'path4', 'path5']
['path4', 'path2', 'path3', 'path5', 'path1']
['path1', 'path2', 'path3', 'path4', 'path5']
['path4', 'path2', 'path3', 'path5', 'path1']


In [1]:

import numpy as np
# Demo input: 480 x 512 x 512 float64 values drawn uniformly from [0, 1)
# — about 1 GB of memory (480*512*512*8 bytes).
a = np.random.rand(480,512,512)
def pad_array(t, target=640):
    """Zero-pad ``t`` along axis 0 so that ``t.shape[0] == target``.

    The padding is split between the front and back of axis 0; when the
    amount needed is odd, the extra row goes at the end, so the result
    always has exactly ``target`` rows. All other axes are left untouched,
    for arrays of any dimensionality >= 1.

    Args:
        t: array to pad.
        target: desired length of axis 0 (default 640).

    Returns:
        ``t`` itself if it already has ``target`` rows, otherwise a new
        zero-padded array.

    Raises:
        ValueError: if ``t.shape[0]`` is larger than ``target``.
    """
    missing = target - t.shape[0]
    if missing == 0:
        return t
    if missing < 0:
        raise ValueError(
            f"axis 0 has length {t.shape[0]}, which exceeds target {target}"
        )
    before = missing // 2
    after = missing - before  # absorbs the odd row so the total is exact
    padding = ((before, after),) + ((0, 0),) * (t.ndim - 1)
    return np.pad(t, padding, mode='constant', constant_values=0)

# Pad the demo volume along axis 0 and show only the resulting shape
# (the array itself is far too large to print).
b = pad_array(t=a)
b.shape

(640, 512, 512)