In [1]:
# Imports and device setup
import torch
from dit import DiT_models
from vae import VAE_models
from torchvision.io import read_video, write_video
from utils import load_prompt, load_actions, sigmoid_beta_schedule
from tqdm import tqdm
from einops import rearrange
from torch import autocast
from safetensors.torch import load_model
import os

assert torch.cuda.is_available(), "CUDA is required for this notebook"
device = 'cuda:0'

torch.manual_seed(0)



  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7b1c641fe650>

In [2]:
# User-defined parameters (editable)
oasis_ckpt = '/root/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/oasis500m.safetensors'
vae_ckpt = '/root/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/vit-l-20.safetensors'
prompt_path = 'sample_data/sample_image_0.png'
actions_path = 'sample_data/sample_actions_0.one_hot_actions.pt'
video_offset = None  # or set an integer
n_prompt_frames = 1  # number of frames to condition on
num_frames = 32  # total frames to generate
output_path = 'video-sparsity-100.mp4'
fps = 20
ddim_steps = 10


In [3]:
# Load DiT model
torch.manual_seed(0)
torch.cuda.manual_seed(0)
model = DiT_models['DiT-S/2'](streaming=True)
model.half()

print(f"Loading Oasis-500M checkpoint from {oasis_ckpt}...")
if oasis_ckpt.endswith('.pt'):
    ckpt = torch.load(oasis_ckpt, weights_only=True)
    model.load_state_dict(ckpt, strict=False)
else:
    load_model(model, oasis_ckpt)
model = model.to(device).eval()


Loading Oasis-500M checkpoint from /root/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/oasis500m.safetensors...


In [4]:
# Load VAE model
vae = VAE_models['vit-l-20-shallow-encoder']()
vae.half()

print(f"Loading ViT-VAE-L/20 checkpoint from {vae_ckpt}...")
if vae_ckpt.endswith('.pt'):
    vae_state = torch.load(vae_ckpt, weights_only=True)
    vae.load_state_dict(vae_state)
else:
    load_model(vae, vae_ckpt)
vae = vae.to(device).eval()


Loading ViT-VAE-L/20 checkpoint from /root/.cache/huggingface/hub/models--Etched--oasis-500m/snapshots/4ca7d2d811f4f0c6fd1d5719bf83f14af3446c0c/vit-l-20.safetensors...


In [5]:
model = torch.compile(model)
vae = torch.compile(vae)

In [6]:
# Prepare sampling and noise schedules
max_noise_level = 1000
noise_range = torch.linspace(-1, max_noise_level - 1, ddim_steps + 1)
noise_abs_max = 20
stabilization_level = 15


In [7]:
n_prompt_frames

1

In [8]:
# Load prompt and actions, move to device
x = load_prompt(prompt_path, video_offset=video_offset, n_prompt_frames=n_prompt_frames)
actions = load_actions(actions_path, action_offset=video_offset)[:, :num_frames]
x = x.to(device).half()
actions = actions.to(device).half()

# VAE encoding
B, _, C, H, W = x.shape[0], x.shape[1], x.shape[2], x.shape[3], x.shape[4]
scaling_factor = 0.07843137255
x = rearrange(x, 'b t c h w -> (b t) c h w')
with autocast('cuda', dtype=torch.half):
    x = vae.encode(x * 2 - 1).mean * scaling_factor
x = rearrange(x, '(b t) (h w) c -> b t c h w', t=n_prompt_frames, h=H // vae.patch_size, w=W // vae.patch_size)
x = x[:, :n_prompt_frames]


prompt is image; ignoring video_offset and n_prompt_frames


In [9]:
# Prepare alpha schedules
betas = sigmoid_beta_schedule(max_noise_level).float().to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod = rearrange(alphas_cumprod, 'T -> T 1 1 1')


In [10]:
t = torch.full((B, 1), stabilization_level - 1, dtype=torch.long, device=device)
x_curr = x.clone()

with torch.no_grad():
    with autocast('cuda', dtype=torch.half):
        v = model(x_curr, t, actions[:, :1], last_only=True, last_frame=True)
        # v = model(x_curr, t, last_only=True, last_frame=True)

In [11]:
# Sampling loop
for i in tqdm(range(n_prompt_frames, num_frames)):
    # Initialize noise chunk
    chunk = torch.randn((B, 1, *x.shape[-3:]), device=device, dtype=torch.float16)
    chunk = torch.clamp(chunk, -noise_abs_max, noise_abs_max)
    x = torch.cat([x, chunk], dim=1)
    start_frame = max(0, i + 1 - model.max_frames)

    for noise_idx in reversed(range(1, ddim_steps + 1)):
        # Build timesteps
        t_ctx = torch.full((B, i), stabilization_level - 1, dtype=torch.long, device=device)
        t = torch.full((B, 1), noise_range[noise_idx], dtype=torch.long, device=device)
        t_next = torch.full((B, 1), noise_range[noise_idx - 1], dtype=torch.long, device=device)
        t_next = torch.where(t_next < 0, t, t_next)
        t = torch.cat([t_ctx, t], dim=1)
        t_next = torch.cat([t_ctx, t_next], dim=1)

        # Prepare model inputs
        x_curr = x.clone()[:, start_frame:]
        t = t[:, start_frame:]
        t_next = t_next[:, start_frame:]

        # Predict
        with torch.no_grad():
            # if model.streaming:
            #     v = model(x_curr[:, -1:], t[:, -1:], last_only=True, last_frame=False)
            # else:
            #     v = model(x_curr, t, last_only=True, last_frame=False)
            if model.streaming:
                v = model(x_curr[:, -1:], t[:, -1:], actions[:, i:i+1], last_only=True, last_frame=False)
            else:
                v = model(x_curr, t, actions[:, :i+1], last_only=True, last_frame=False)

        # DDIM noise update
        x_start = alphas_cumprod[t].sqrt() * x_curr - (1 - alphas_cumprod[t]).sqrt() * v
        x_noise = ((1 / alphas_cumprod[t]).sqrt() * x_curr - x_start) / (1 / alphas_cumprod[t] - 1).sqrt()

        # Compute next frame
        alpha_next = alphas_cumprod[t_next]
        alpha_next[:, :-1] = 1
        if noise_idx == 1:
            alpha_next[:, -1:] = 1
        x_pred = alpha_next.sqrt() * x_start[:, -1:] + x_noise[:, -1:] * (1 - alpha_next).sqrt()
        x[:, -1:] = x_pred[:, -1:]
        
    t = torch.full((B, 1), stabilization_level - 1, dtype=torch.long, device=device)
        
    # if model.streaming:
    #     # Streaming model only needs the very last frame for the cache update
    #     v = model(x[:, -1:].clone(), t, last_only=True, last_frame=True)
    if model.streaming:
        v = model(x[:, -1:], t[:, -1:], actions[:, i:i+1], last_only=True, last_frame=True)


  0%|          | 0/31 [00:00<?, ?it/s]W0611 22:38:27.201000 41259 site-packages/torch/_dynamo/convert_frame.py:844] [1/8] torch._dynamo hit config.cache_size_limit (8)
W0611 22:38:27.201000 41259 site-packages/torch/_dynamo/convert_frame.py:844] [1/8]    function: 'rearrange' (/open-oasis/.conda/lib/python3.11/site-packages/einops/einops.py:545)
W0611 22:38:27.201000 41259 site-packages/torch/_dynamo/convert_frame.py:844] [1/8]    last reason: 1/0: tensor 'L['tensor']' dispatch key set mismatch. expected DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA, AutocastCUDA), actual DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA)
W0611 22:38:27.201000 41259 site-packages/torch/_dynamo/convert_frame.py:844] [1/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0611 22:38:27.201000 41259 site-packages/torch/_dynamo/convert_frame.py:844] [1/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html

In [12]:
# VAE decoding and save output video
x = rearrange(x, 'b t c h w -> (b t) (h w) c')
with torch.no_grad():
    x = (vae.decode(x / scaling_factor) + 1) / 2
x = rearrange(x, '(b t) c h w -> b t h w c', t=num_frames)
x = torch.clamp(x, 0, 1)
x = (x * 255).byte()
write_video(output_path, x[0].cpu(), fps=fps)
print(f"Generation saved to {output_path}.")


Generation saved to video-sparsity-100.mp4.


In [13]:
model

OptimizedModule(
  (_orig_mod): DiT(
    (x_embedder): PatchEmbed(
      (proj): Conv2d(16, 1024, kernel_size=(2, 2), stride=(2, 2))
      (norm): Identity()
    )
    (t_embedder): TimestepEmbedder(
      (mlp): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): SiLU()
        (2): Linear(in_features=1024, out_features=1024, bias=True)
      )
    )
    (spatial_rotary_emb): RotaryEmbedding()
    (temporal_rotary_emb): RotaryEmbedding()
    (external_cond): Linear(in_features=25, out_features=1024, bias=True)
    (blocks): ModuleList(
      (0-15): 16 x SpatioTemporalDiTBlock(
        (s_norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=False)
        (s_attn): SpatialAxialAttention(
          (to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
          (to_out): Linear(in_features=1024, out_features=1024, bias=True)
          (rotary_emb): RotaryEmbedding()
        )
        (s_norm2): LayerNorm((1024,), eps=1e-06, elemen

In [14]:
from IPython.display import HTML
from base64 import b64encode

mp4 = open(output_path, "rb").read()
b64 = b64encode(mp4).decode()

# Display HTML video
HTML(f"""
<video width="640" height="480" controls>
  <source src="data:video/mp4;base64,{b64}" type="video/mp4">
  Your browser does not support the video tag.
</video>
""")