In [10]:
%load_ext autoreload
%autoreload 2

import argparse
import os
import math
import yaml
import logging
import random
import numpy as np
import sys
import imageio
import torch

os.chdir("..")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
def parse_args(arg_list=None):
    parser = argparse.ArgumentParser(
        description="Unconditioned Video Diffusion Inference"
    )
    parser.add_argument(
        "--dataset-path", type=str, required=True,
        help="Directory containing input reference videos."
    )
    parser.add_argument(
        "--pretrained-model-name-or-path", type=str, required=True,
        help="Path or HF ID where transformer/vae/scheduler are stored."
    )
    parser.add_argument(
        "--checkpoint-path", type=str, required=True,
        help="Path to fine‐tuned checkpoint containing transformer state_dict."
    )
    parser.add_argument(
        "--output-dir", type=str, required=True,
        help="Where to write generated videos."
    )
    parser.add_argument(
        "--model-config", type=str, required=True,
        help="YAML file describing model params (height, width, num_reference, num_target, etc.)"
    )
    parser.add_argument(
        "--batch-size", type=int, default=1,
        help="Batch size per device (usually 1 for inference)."
    )
    parser.add_argument(
        "--num-inference-steps", type=int, default=50,
        help="Number of reverse diffusion steps to run."
    )
    parser.add_argument(
        "--mixed-precision", type=str, default="bf16",
        help="Whether to run backbone in 'fp16', 'bf16', or 'fp32'."
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--shuffle", type=int, default=False,
        help="Whether to shuffle dataset. Usually False for inference."
    )
    parser.add_argument(
        "--is-uncond", type=bool, default=False,
        help=""
    )

    # If arg_list is None, argparse picks up sys.argv; 
    # otherwise it treats arg_list as the full argv list.
    return parser.parse_args(arg_list)

args = [
    "--dataset-path", "/scratch/ondemand28/harryscz/head_audio/data/data256/uv",
    "--pretrained-model-name-or-path", "/scratch/ondemand28/harryscz/model/CogVideoX-2b",
    "--checkpoint-path",  "/scratch/ondemand28/harryscz/head_audio/trainOutput/checkpoint-1000.pt",
    "--output-dir",  "/scratch/ondemand28/harryscz/diffusion/videoOut",
    "--model-config",  "/scratch/ondemand28/harryscz/diffusion/train/model_config.yaml",
    "--batch-size",  "1",
    "--num-inference-steps",  "50",
    "--mixed-precision",  "no",
    "--seed",  "42",
    "--shuffle",  "0",
]

args = parse_args(args)

with open(args.model_config, "r") as f: model_config = yaml.safe_load(f)


In [4]:
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from accelerate.logging import get_logger

with open(args.model_config, "r") as f: model_config = yaml.safe_load(f)
if args.mixed_precision.lower() == "fp16":
    dtype = torch.float16
elif args.mixed_precision.lower() == "bf16":
    dtype = torch.bfloat16
else:
    dtype = torch.float32

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir,
                                                    logging_dir=os.path.join(args.output_dir, "logs"))
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
accelerator = Accelerator(mixed_precision=args.mixed_precision,
                            project_config=accelerator_project_config,
                            kwargs_handlers=[ddp_kwargs])

# 2.4 Set random seed
if args.seed is not None:
    set_seed(args.seed + accelerator.process_index)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

logger = get_logger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info("Accelerator state:", accelerator.state)

--- Logging error ---
Traceback (most recent call last):
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 1083, in emit
    msg = self.format(record)
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 927, in format
    return fmt.format(record)
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 663, in format
    record.message = record.getMessage()
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 367, in getMessage
    msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/runpy.py", line 87, in

In [5]:
#### Dataset #####
# Video data have shape [B, C, F, H, W]

from data.VideoDataset import VideoDataset 
from torch.utils.data import DataLoader, DistributedSampler

dataset = VideoDataset(
    videos_dir=args.dataset_path,
    num_ref_frames=1,
    num_target_frames=49
)
if args.shuffle:
    sampler = DistributedSampler(
        dataset,
        num_replicas=accelerator.num_processes,
        rank=accelerator.process_index,
        shuffle=True
    )
else:
    sampler = None
data_loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    # sampler=sampler,
    collate_fn=lambda x: x[0],   # since dataset returns already‐batched items
    num_workers=2,
    pin_memory=True,
)
logger.info(f"Number of test examples: {len(data_loader)}")

06/10/2025 14:37:17 - INFO - __main__ - Number of test examples: 10


In [6]:
#### Load Model ####
device = "cuda"
dtype = torch.float32

from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from model.cap_transformer import CAPVideoXTransformer3DModel

transformer = CAPVideoXTransformer3DModel.from_pretrained(
    args.pretrained_model_name_or_path,
    low_cpu_mem_usage=False,
    device_map=None,
    ignore_mismatched_sizes=True,
    subfolder="transformer",
    torch_dtype=torch.float32,
    cond_in_channels=1,  # only one channel (the ref_mask)
    sample_width=model_config["width"] // 8,
    sample_height=model_config["height"] // 8,
    max_text_seq_length=1,
    max_n_references=model_config["max_n_references"],
    apply_attention_scaling=model_config["use_growth_scaling"],
    use_rotary_positional_embeddings=False,
)

vae = AutoencoderKLCogVideoX.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="vae"
)
scheduler = CogVideoXDDIMScheduler.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="scheduler",
)

vae.eval().to(dtype)
transformer.eval().to(dtype)

vae, transformer, scheduler, data_loader = accelerator.prepare(vae, transformer, scheduler, data_loader)

Some weights of the model checkpoint at /scratch/ondemand28/harryscz/model/CogVideoX-2b were not used when initializing CAPVideoXTransformer3DModel: 
 ['patch_embed.text_proj.weight, patch_embed.text_proj.bias']
Some weights of CAPVideoXTransformer3DModel were not initialized from the model checkpoint at /scratch/ondemand28/harryscz/model/CogVideoX-2b and are newly initialized: ['patch_embed.audio_proj.weight', 'patch_embed.cond_proj.bias', 'patch_embed.ref_temp_proj.bias', 'patch_embed.audio_proj.bias', 'patch_embed.cond_proj.weight', 'patch_embed.ref_temp_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
for batch_id, batch in enumerate(data_loader):
    print(batch.keys())
    for chunk_id in range(len(batch["video_chunks"])):
        print(batch["video_chunks"][chunk_id].shape)
        print(batch["cond_chunks"].keys())  # Mask for each frames over H and W and channel suggesting which one works as a condition
                                            # list of tensor masks for cond chunks
        print(batch['cond_chunks']['ref_mask'][0].shape)
        print(batch["chunk_is_ref"]) # list of length frame of bool saying which on is a condition 
        print(batch["raw_audio"]) # passed as none
    break

dict_keys(['video_chunks', 'cond_chunks', 'chunk_is_ref', 'raw_audio'])
torch.Size([1, 3, 50, 256, 256])
dict_keys(['ref_mask'])
torch.Size([1, 50, 256, 256, 3])
[tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False],
       device='cuda:0')]
None


In [39]:
%load_ext autoreload
%autoreload 2

from inference.inference_pipeline import *

pipe = VideoDiffusionPipeline(
    vae=vae,
    transformer=transformer,
    scheduler=scheduler
)
batch = next(iter(data_loader))
videos = pipe(batch, num_inference_steps=50)
save_video(videos[0][0], "/scratch/ondemand28/harryscz/diffusion/videoOut/try.mp4")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Inference Progress: 100%|██████████| 50/50 [01:27<00:00,  1.75s/it]


Saved !


In [None]:
from typing import Dict, List, Optional, Union, Tuple
from diffusers import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
from diffusers.video_processor import VideoProcessor

def save_video(video : torch.tensor, save_path : str, fps : int = 16):
    video_np = video.permute(1, 2, 3, 0).cpu().numpy()  # [49, 256, 256, 3]

    video_np = (video_np * 255).clip(0, 255).astype(np.uint8)

    imageio.mimsave(save_path, video_np, fps=fps)
    
    print("Saved !")

class VideoDiffusionPipeline(DiffusionPipeline):
    """
    A custom diffusion pipeline that mirrors your manual inference loop,
    but inherits from DiffusionPipeline to leverage no-grad, mixed-precision,
    and buffer reuse for maximum efficiency.
    """
    def __init__(
        self,
        vae,
        transformer,
        scheduler,
    ):
        super().__init__()
        self.register_modules(vae=vae, transformer=transformer, scheduler=scheduler)

        # Scale factors for spatial/temporal axes
        self.vae_scale_factor_spatial = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.vae_scale_factor_temporal = getattr(self.vae.config, "temporal_compression_ratio", 1)

        # Video post-processor
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
        latents = 1 / self.vae.config.scaling_factor * latents

        frames = self.vae.decode(latents).sample
        return frames

    @torch.no_grad()
    def __call__(
        self,
        batch: Dict[str, Union[List[torch.FloatTensor], Dict[str, List[torch.FloatTensor]]]],
        num_inference_steps: int = 50,
        guidance_scale: float = 1.0,
        seed: int = 42,
        output_type: str = "pil",
        return_dict: bool = False,
        return_latent: bool = False,
        return_pil: bool = False,
        return_decode_latent : bool = True,
    ) -> Union[List, Dict]:
        device = self._execution_device
        dtype = self.transformer.dtype
        generator = torch.Generator(device=device)
        generator.manual_seed(seed+1)

        # 1) Extract & encode
        latent_chunks: List[torch.Tensor] = []
        ref_mask_chunks: List[torch.Tensor] = []
        sequence_infos: List[tuple] = []

        for i, video in enumerate(batch["video_chunks"]):
            # video: [B, C, F, H, W]
            video = video.to(device=device, dtype=dtype)
            # Initialize first frame and set rest as random noise
            video[:, :, 1:, :, :] = torch.randn(video[:, :, 1:, :, :].shape, generator=generator, device=device)
            with torch.no_grad(): dist = self.vae.encode(video).latent_dist.sample()
            latent = dist * self.vae.config.scaling_factor
            latent = latent.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C_z, h, w]
            latent_chunks.append(latent)

            # mask: batch["cond_chunks"]["ref_mask"][i] shape [B, F, H, W, C_mask]
            rm = batch["cond_chunks"]["ref_mask"][i]
            rm = rm.to(device=device, dtype=dtype).permute(0, 4, 1, 2, 3)
            # now [B, C_mask, F, H, W]
            ref_mask_chunks.append(rm)

            # sequence info
            is_ref = batch.get("chunk_is_ref", [False] * len(latent_chunks))[i]
            seq = torch.arange(0, latent.shape[1], device=device)
            sequence_infos.append((is_ref, seq))

        # 2) Build 2× for classifier-free guidance
        latents = latent_chunks
        masks   = [torch.cat([m, torch.zeros_like(m)], dim=0) for m in ref_mask_chunks]
        # keep ref_latents for mixing
        ref_latents = [torch.cat([z, torch.zeros_like(z)], dim=0) for z in latent_chunks]

        # 3) dummy audio/text embeddings (adjust if you have real ones)
        B2 = latents[0].shape[0] * 2
        total_F = sum(z.shape[1] for z in latents)
        audio_embeds = torch.zeros((B2, total_F, 768), dtype=dtype, device=device)
        text_embeds  = torch.zeros((B2, 1,
            self.transformer.config.attention_head_dim * self.transformer.config.num_attention_heads
        ), dtype=dtype, device=device)

        # 4) timesteps
        timesteps, _ = retrieve_timesteps(self.scheduler, num_inference_steps, device=device)

        # 5) optional fuse QKV once
        # try:
        #     self.transformer.fuse_qkv_projections()
        # except Exception:
        #     pass

        # 6) denoising loop
        old_pred_original_samples = [None] * len(latents)
        for i, t in enumerate(tqdm(timesteps, desc="Inference Progress")):
            latent_model_inputs = [torch.cat([chunks] * 2, dim=0)for chunks in latents]
            B2, F, C, H, W = latent_model_inputs[0].shape
            # one zero condition tensor
            zero_cond = torch.zeros((B2, F, 1, H, W), dtype=dtype, device=device)

            # single forward
            noise_preds = self.transformer(
                hidden_states=latent_model_inputs,
                encoder_hidden_states=text_embeds,
                audio_embeds=audio_embeds,
                condition=[zero_cond] * len(latent_model_inputs),
                sequence_infos=[[False, torch.arange(chunk.shape[1])]for chunk in latents],
                timestep=t.expand(B2),
                image_rotary_emb=None,
                return_dict=False,
            )[0]

            # apply guidance, scheduler.step, then mixing
            new_latents = []
            new_old_pred_original_samples = []

            for noise_pred, old_pred_original_sample, latent in zip(noise_preds, old_pred_original_samples, latents):
                noise_pred, noise_pred_uncond = noise_pred.chunk(2, dim=0)
                # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
            
                latent, old_pred_original_sample = scheduler.step(
                    noise_pred,
                    old_pred_original_sample,
                    t,
                    timesteps[i - 1] if i > 0 else None,
                    latent,
                    eta=0.0,
                    generator=generator
                )

                new_latents.append(latent)
                new_old_pred_original_samples.append(old_pred_original_sample)

            latents = new_latents
            old_pred_original_samples = list(new_old_pred_original_samples)

        # 7) decode to videos
        videos = []
        if return_latent:
            return latent
        
        if return_decode_latent:
            return self.decode_latents(latent)
        
        for latent in latents:
            dec = latent.permute(0, 2, 1, 3, 4) / self.vae.config.scaling_factor
            frames = self.vae.decode(dec).sample
            video = self.video_processor.postprocess_video(video=frames, output_type=output_type)
            videos.append(video)

        return {"frames": videos} if return_dict else videos

In [None]:
generator = torch.Generator(device=device)
generator.manual_seed(args.seed)
pipe = VideoDiffusionPipeline(vae, transformer, scheduler)
pipe = pipe.to(device).to(torch.float32)   
batch = next(iter(data_loader))

In [None]:
videos = pipe(batch, num_inference_steps=50)

Inference Progress: 100%|██████████| 50/50 [01:29<00:00,  1.79s/it]


In [None]:
save_video(videos[0], "/scratch/ondemand28/harryscz/diffusion/videoOut/try.mp4")

Saved !
