In [1]:
import sys
sys.path.append('../')

import importlib
from omegaconf import OmegaConf

import torch
import torchvision
from diffusers.models import AutoencoderKL
from diffusion import create_diffusion
from einops import rearrange
from tqdm import tqdm

# Experiment configuration: model and dataset settings come from the YAML file.
configs = OmegaConf.load("../configs/mario_t1_v0.yaml")

# Seed torch's RNGs (all devices) so the VAE latent sampling and diffusion noise
# drawn later in the notebook are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# NOTE(review): GPU index is machine-specific; adjust for the host you run on.
device = "cuda:5"

  from .autonotebook import tqdm as notebook_tqdm
  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")


In [33]:
# Build the model from the config and load the EMA weights from a training checkpoint.
CKPT_PATH = "../results/002-t1/checkpoints/0007000.pt"

model = importlib.import_module("models.t1").Model(
    **configs.get("model", {})
)
# NOTE(review): the checkpoint may hold non-tensor entries besides "ema";
# confirm before switching to torch.load(..., weights_only=True).
checkpoint = torch.load(CKPT_PATH, map_location="cpu")
model.load_state_dict(checkpoint["ema"])
del checkpoint  # drop the CPU copy of every checkpoint entry once weights are loaded

# Inference only: disable training-mode behaviour (e.g. dropout) before sampling.
# .to() and .eval() both return the module, so the repr is still displayed.
model.to(device).eval()

  model.load_state_dict(torch.load("../results/002-t1/checkpoints/0007000.pt", map_location="cpu")["ema"])


Model(
  (x_embedder): PatchEmbed(
    (proj): Conv2d(4, 1152, kernel_size=(2, 2), stride=(2, 2))
    (norm): Identity()
  )
  (t_embedder): TimestepEmbedder(
    (mlp): Sequential(
      (0): Linear(in_features=256, out_features=1152, bias=True)
      (1): SiLU()
      (2): Linear(in_features=1152, out_features=1152, bias=True)
    )
  )
  (blocks): ModuleList(
    (0-27): 28 x DiTBlock(
      (norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=False)
      (attn): Attention(
        (qkv): Linear(in_features=1152, out_features=3456, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1152, out_features=1152, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=False)
      (mlp): Mlp(
        (fc1): Linear(in_features=1152, out_features=4608, bias=True)
        (act): GELU(approximate='tanh')
    

In [5]:
# Override the dataset location with the local (machine-specific) clip directory.
configs.dataset["data_root"] = "/mnt/store/kmei1/projects/t1/datasets/super-mario-bros-reinforcement-learning/clips/"

In [6]:
# Load the SDXL VAE (pixel <-> latent codec) and the action-video dataset.
#
# `diffusers` was already imported in the imports cell, and it pulls in
# huggingface_hub, which resolves its cache directory from HF_HOME at import
# time. Setting HF_HOME only here would not change where the weights are
# downloaded, so the hub cache directory is passed explicitly. The variable is
# still exported for anything that reads it later (e.g. subprocesses).
import os

HF_HOME = "/mnt/store/kmei1/HF_HOME/"
os.environ["HF_HOME"] = HF_HOME

vae = AutoencoderKL.from_pretrained(
    "stabilityai/sdxl-vae", cache_dir=os.path.join(HF_HOME, "hub")
).to(device)
dataset = importlib.import_module("datasets.action_video").Dataset(
    **configs.get("dataset", {})
)

env: HF_HOME=/mnt/store/kmei1/HF_HOME/


In [32]:
# Sampling process used by the rollout below; "ddim20" is the respacing spec
# (presumably a 20-step DDIM schedule — see create_diffusion).
diffusion = create_diffusion(timestep_respacing="ddim20")

In [39]:
# Load one clip and encode its frames into VAE latent space.
data = dataset[11]
x, _, pos = data  # middle element is not used in this notebook

pos = torch.from_numpy(pos[None]).to(device)  # add a batch dimension
x = torch.from_numpy(x[None]).to(device)
with torch.no_grad():
    T = x.shape[2]
    # Fold time into the batch so the image VAE encodes every frame independently.
    x = rearrange(x, "N C T H W -> (N T) C H W")
    # Scale latents with the VAE's own configured factor (0.13025 for
    # stabilityai/sdxl-vae) rather than a hard-coded literal.
    x = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor)
    x = rearrange(x, "(N T) C H W -> N C T H W", T=T)

# Split off the last frame: the preceding frames (and their pos) are the
# conditioning context; the last frame and its pos are the generation target.
past_frame, x = x[:, :, :-1], x[:, :, -1:]
past_pos, pos = pos[:, :, :-1], pos[:, :, -1:]

bsz = x.shape[0]

In [40]:
# Autoregressive rollout: generate one latent frame per step, each conditioned
# on the most recent `memory_frames` frames (real context + frames generated so far).
all_frames = past_frame
all_poses = past_pos  # only needed for the sliding-window past_pos alternative below
to_be_generate_frames = 4

memory_frames = 16

for _ in tqdm(range(to_be_generate_frames)):
    context_frames = all_frames[:, :, -memory_frames:]
    # `past_pos` is kept fixed while the frame window slides (the alternative
    # `all_poses[:, :, -memory_frames:]` is deliberately not used). The two only
    # line up when they have the same length along dim 2, i.e. when the real
    # context has exactly `memory_frames` frames. Fail fast instead of feeding
    # the model mismatched context.
    # NOTE(review): a fixed past_pos is only meaningful if pos encodes
    # window-relative positions — confirm.
    assert context_frames.shape[2] == past_pos.shape[2], (
        f"context window has {context_frames.shape[2]} frames but past_pos has "
        f"{past_pos.shape[2]} entries; expected the clip context length to equal "
        f"memory_frames={memory_frames}"
    )

    z = torch.randn(bsz, 4, 1, 32, 32, device=device)  # noise for a single new latent frame
    model_kwargs = dict(
        pos=pos,
        past_frame=context_frames,
        past_pos=past_pos,
    )

    with torch.no_grad():
        samples = diffusion.p_sample_loop(
            model, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device
        )
    all_frames = torch.cat([all_frames, samples], dim=2)
    all_poses = torch.cat([all_poses, pos], dim=2)

100%|██████████| 4/4 [01:11<00:00, 18.00s/it]


In [41]:
# Decode only the generated latents (everything after the real context) back to pixels.
num_context = past_frame.shape[2]
_samples = rearrange(all_frames[:, :, num_context:], "N C T H W -> (N T) C H W")
with torch.no_grad():
    # One frame per decode call — presumably to keep VAE decoder memory bounded.
    samples = []
    for frame in _samples:
        samples.append(vae.decode(frame.unsqueeze(0) / vae.config.scaling_factor).sample)
    samples = torch.cat(samples)
samples = torch.clamp(samples, -1, 1)
del _samples

In [42]:
# Map decoded pixels from [-1, 1] to [0, 255] and round to uint8, the dtype
# torchvision.io.write_video expects, instead of relying on implicit truncation later.
video = (255 * (samples.clip(-1, 1) / 2 + 0.5)).round().to(torch.uint8)

In [43]:
import numpy as np
import PIL.Image as Image

# Tile all frames side by side into one (H, T*W, C) strip and save it as a PNG.
frames_np = video.cpu().numpy()
filmstrip = np.uint8(rearrange(frames_np, "T C H W -> H (T W) C"))
screenshot = Image.fromarray(filmstrip)
screenshot.save("samples.png")

In [44]:
# Encode the frames as an H.264 MP4 at 8 fps; write_video takes frames laid out as (T, H, W, C).
frames_thwc = video.permute(0, 2, 3, 1).cpu().numpy()
torchvision.io.write_video(
    filename="samples.mp4",
    video_array=frames_thwc,
    fps=8,
    video_codec="h264",
)