### Latent Action Model Demo

1. Clone the repository and install dependencies

In [None]:
!git clone https://github.com/microsoft/villa-x.git
!curl -LsSf https://astral.sh/uv/install.sh | sh
!sudo apt-get install -y build-essential zlib1g-dev libffi-dev libssl-dev libbz2-dev libreadline-dev libsqlite3-dev liblzma-dev libncurses-dev tk-dev python3-dev ffmpeg -y
!cd villa-x && uv sync

2. Download the villa-X checkpoint

In [None]:
# Clone the checkpoint repository from Hugging Face into ./villax_checkpoint.
# NOTE(review): presumably the weight files are stored via Git LFS — confirm that
# git-lfs is installed, otherwise the clone may contain pointer files only.
!git clone https://huggingface.co/microsoft/villa-x villax_checkpoint

3. Load the latent action model

In [1]:
from lam import IgorModel

# Build the model from the local checkpoint directory cloned in the previous step,
# then move it to the GPU.
lam = IgorModel.from_pretrained("villax_checkpoint")
lam = lam.cuda()

4. Extract latent actions from a video

In [2]:
import torch


def read_video(fp: str):
    """Decode the video file at ``fp`` and return its frames channel-first.

    Only the first element returned by ``torchvision.io.read_video`` (the video
    frames) is used; the remaining elements are discarded.

    Args:
        fp: Path of the video file to decode.

    Returns:
        The frame tensor with its last axis moved to position 1. With
        torchvision's default (T, H, W, C) output this yields (T, C, H, W).
    """
    # Aliased so the torchvision reader does not share a name with this wrapper.
    from torchvision.io import read_video as _decode_video

    decoded = _decode_video(fp, pts_unit="sec")
    frames = decoded[0]
    return frames.permute(0, 3, 1, 2)


def read_image(fp: str):
    """Load the image at ``fp`` as an RGB tensor with the channel axis first.

    Args:
        fp: Path of the image file to open.

    Returns:
        A tensor built from the RGB pixel array, permuted so that the last
        (channel) axis comes first, i.e. (C, H, W).
    """
    import numpy as np
    import torch
    from PIL import Image

    # Force three channels regardless of the mode stored in the file.
    rgb = Image.open(fp).convert("RGB")
    pixels = np.array(rgb)
    return torch.tensor(pixels).permute(2, 0, 1)


def save_video(frames, output_path, fps=30):
    """Encode ``frames`` into a video file at ``output_path``.

    Args:
        frames: Frame stack in the layout ``torchvision.io.write_video``
            expects (time, height, width, channel). Anything accepted by
            ``torch.as_tensor`` works, on any device.
        output_path: Destination path of the encoded video.
        fps: Frame rate of the written video. Defaults to 30.
    """
    import torch
    from torchvision.io import write_video

    # The writer converts its input to a NumPy array. A plain ``.numpy()`` call
    # raises for tensors that live on the GPU or carry autograd history, which
    # is what the cells in this notebook pass in, so hand over a detached host
    # copy instead of relying on the installed torchvision to do it.
    frames = torch.as_tensor(frames).detach().cpu()
    write_video(output_path, frames, fps=fps)


# Decode the clip with the helper above and move the frames to the GPU.
video = read_video("example_01.mp4").cuda()  # Load your video here

# Inference only: disable autograd so no computation graph is retained alongside
# the extracted latent actions.
with torch.no_grad():
    latent_action = lam.idm(video)

4. Use image FDM to generate reconstructed frames

In [3]:
# Reconstruction from ground-truth inputs: step i applies latent action i to the
# real frame video[i]. The first output frame is the unmodified first input frame.
# NOTE(review): assumes latent_action[0] holds one action per consecutive frame
# pair, so that `recon` ends up as long as `video` — confirm against lam.idm.
recon = [video[0]]
with torch.no_grad():  # inference only; do not keep a graph for every frame
    for i, la in enumerate(latent_action[0]):
        recon.append(lam.apply_latent_action(video[i], la))

# Concatenate source and reconstruction along dim=3 (side by side), then permute
# back to channel-last for the video writer.
save_video(
    torch.cat([video, torch.stack(recon)], dim=3).permute(0, 2, 3, 1),
    "recon_video.mp4",
    fps=5,
)

5. Iteratively generate future frames from latent actions using the image FDM. (This step is to be replaced with the world model.)

In [None]:
cur_frame = read_image("example_target_01.png").cuda()  # Load your target frame here

# Open-loop rollout: every predicted frame becomes the input for the next latent
# action. Autograd is disabled because each step feeds on the previous output and
# would otherwise chain the graph through the whole rollout.
rollout = [cur_frame]
with torch.no_grad():
    for la in latent_action[0]:
        cur_frame = lam.apply_latent_action(cur_frame, la)
        rollout.append(cur_frame)

# torch.cat along dim=3 needs both stacks to agree on every other dimension.
# NOTE(review): assumes the target image matches the video frames in height and
# that there is one latent action per frame transition — confirm.
frames = torch.cat([video, torch.stack(rollout)], dim=3).permute(0, 2, 3, 1)
save_video(frames, "iterative_video.mp4", fps=5)