In [18]:
%load_ext autoreload
%autoreload 2

import argparse
import os
import math
import yaml
import logging
import random
import numpy as np
import sys
import imageio
import torch

os.chdir("..")
os.environ['CUDA_VISIBLE_DEVICES'] = '5'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
def parse_args(arg_list=None):
    parser = argparse.ArgumentParser(
        description="Unconditioned Video Diffusion Inference"
    )
    parser.add_argument(
        "--dataset-path", type=str, required=True,
        help="Directory containing input reference videos."
    )
    parser.add_argument(
        "--pretrained-model-name-or-path", type=str, required=True,
        help="Path or HF ID where transformer/vae/scheduler are stored."
    )
    parser.add_argument(
        "--checkpoint-path", type=str, required=True,
        help="Path to fine‐tuned checkpoint containing transformer state_dict."
    )
    parser.add_argument(
        "--output-dir", type=str, required=True,
        help="Where to write generated videos."
    )
    parser.add_argument(
        "--model-config", type=str, required=True,
        help="YAML file describing model params (height, width, num_reference, num_target, etc.)"
    )
    parser.add_argument(
        "--batch-size", type=int, default=1,
        help="Batch size per device (usually 1 for inference)."
    )
    parser.add_argument(
        "--num-inference-steps", type=int, default=50,
        help="Number of reverse diffusion steps to run."
    )
    parser.add_argument(
        "--mixed-precision", type=str, default="bf16",
        help="Whether to run backbone in 'fp16', 'bf16', or 'fp32'."
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--shuffle", type=int, default=False,
        help="Whether to shuffle dataset. Usually False for inference."
    )
    parser.add_argument(
        "--is-uncond", type=bool, default=False,
        help=""
    )

    # If arg_list is None, argparse picks up sys.argv; 
    # otherwise it treats arg_list as the full argv list.
    return parser.parse_args(arg_list)

args = [
    "--dataset-path", "/scratch/ondemand28/harryscz/head_audio/data/data256/uv",
    "--pretrained-model-name-or-path", "/scratch/ondemand28/harryscz/model/CogVideoX-2b",
    "--checkpoint-path",  "/scratch/ondemand28/harryscz/head_audio/trainOutput/checkpoint-1000.pt",
    "--output-dir",  "/scratch/ondemand28/harryscz/diffusion/videoOut",
    "--model-config",  "/scratch/ondemand28/harryscz/diffusion/train/model_config.yaml",
    "--batch-size",  "1",
    "--num-inference-steps",  "50",
    "--mixed-precision",  "no",
    "--seed",  "42",
    "--shuffle",  "0",
]

args = parse_args(args)

with open(args.model_config, "r") as f: model_config = yaml.safe_load(f)


In [20]:
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from accelerate.logging import get_logger

with open(args.model_config, "r") as f: model_config = yaml.safe_load(f)
if args.mixed_precision.lower() == "fp16":
    dtype = torch.float16
elif args.mixed_precision.lower() == "bf16":
    dtype = torch.bfloat16
else:
    dtype = torch.float32

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir,
                                                    logging_dir=os.path.join(args.output_dir, "logs"))
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
accelerator = Accelerator(mixed_precision=args.mixed_precision,
                            project_config=accelerator_project_config,
                            kwargs_handlers=[ddp_kwargs])

# 2.4 Set random seed
if args.seed is not None:
    set_seed(args.seed + accelerator.process_index)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

logger = get_logger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info("Accelerator state:", accelerator.state)

--- Logging error ---
Traceback (most recent call last):
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 1083, in emit
    msg = self.format(record)
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 927, in format
    return fmt.format(record)
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 663, in format
    record.message = record.getMessage()
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 367, in getMessage
    msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/runpy.py", line 87, in

In [21]:
#### Dataset #####
# Video data have shape [B, C, F, H, W]

from data.VideoDataset import VideoDataset 
from torch.utils.data import DataLoader, DistributedSampler

dataset = VideoDataset(
    videos_dir=args.dataset_path,
    num_ref_frames=1,
    num_target_frames=49
)
if args.shuffle:
    sampler = DistributedSampler(
        dataset,
        num_replicas=accelerator.num_processes,
        rank=accelerator.process_index,
        shuffle=True
    )
else:
    sampler = None
data_loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    # sampler=sampler,
    collate_fn=lambda x: x[0],   # since dataset returns already‐batched items
    num_workers=2,
    pin_memory=True,
)
logger.info(f"Number of test examples: {len(data_loader)}")

06/11/2025 17:18:31 - INFO - __main__ - Number of test examples: 10


In [22]:
#### Load Model ####
device = "cuda"
dtype = torch.float32

from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from model.cap_transformer import CAPVideoXTransformer3DModel

transformer = CAPVideoXTransformer3DModel.from_pretrained(
    args.pretrained_model_name_or_path,
    low_cpu_mem_usage=False,
    device_map=None,
    ignore_mismatched_sizes=True,
    subfolder="transformer",
    torch_dtype=torch.float32,
    cond_in_channels=1,  # only one channel (the ref_mask)
    sample_width=model_config["width"] // 8,
    sample_height=model_config["height"] // 8,
    max_text_seq_length=1,
    max_n_references=model_config["max_n_references"],
    apply_attention_scaling=model_config["use_growth_scaling"],
    use_rotary_positional_embeddings=False,
)
vae = AutoencoderKLCogVideoX.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="vae"
)
scheduler = CogVideoXDDIMScheduler.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="scheduler",
)

vae.eval().to(dtype)
transformer.eval().to(dtype)

vae = AutoencoderKLCogVideoX.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="vae"
)
scheduler = CogVideoXDDIMScheduler.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="scheduler",
)

Some weights of the model checkpoint at /scratch/ondemand28/harryscz/model/CogVideoX-2b were not used when initializing CAPVideoXTransformer3DModel: 
 ['patch_embed.text_proj.bias, patch_embed.text_proj.weight']
Some weights of CAPVideoXTransformer3DModel were not initialized from the model checkpoint at /scratch/ondemand28/harryscz/model/CogVideoX-2b and are newly initialized: ['patch_embed.cond_proj.bias', 'patch_embed.cond_proj.weight', 'patch_embed.audio_proj.weight', 'patch_embed.audio_proj.bias', 'patch_embed.ref_temp_proj.weight', 'patch_embed.ref_temp_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
ckpt_path = "/scratch/ondemand28/harryscz/head_audio/trainOutput/checkpoint-2000.pt"
ckpt = torch.load(ckpt_path, map_location="cpu")
transformer.load_state_dict(ckpt["state_dict"], strict=False)

  ckpt = torch.load(ckpt_path, map_location="cpu")


[2025-06-11 17:20:27,585] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)


06/11/2025 17:20:31 - INFO - root - gcc -pthread -B /scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/include -I/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/include -fPIC -O2 -isystem /scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/include -fPIC -c /tmp/tmp_2kz029o/test.c -o /tmp/tmp_2kz029o/test.o
06/11/2025 17:20:32 - INFO - root - gcc -pthread -B /scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/compiler_compat /tmp/tmp_2kz029o/test.o -laio -o /tmp/tmp_2kz029o/a.out
/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
06/11/2025 17:20:34 - INFO - root - gcc -pthread -B /scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -O2 -Wall -fPIC -O2 -isystem /scra

<All keys matched successfully>

In [8]:
vae, transformer, scheduler, data_loader = accelerator.prepare(vae, transformer, scheduler, data_loader)

In [None]:
for batch_id, batch in enumerate(data_loader):
    print(batch.keys())
    for chunk_id in range(len(batch["video_chunks"])):
        print(batch["video_chunks"][chunk_id].shape)
        print(batch["cond_chunks"].keys())  # Mask for each frames over H and W and channel suggesting which one works as a condition
                                            # list of tensor masks for cond chunks
        print(batch['cond_chunks']['ref_mask'][0].shape)
        print(batch["chunk_is_ref"]) # list of length frame of bool saying which on is a condition 
        print(batch["raw_audio"]) # passed as none
    break

dict_keys(['video_chunks', 'cond_chunks', 'chunk_is_ref', 'raw_audio'])
torch.Size([1, 3, 50, 256, 256])
dict_keys(['ref_mask'])
torch.Size([1, 50, 256, 256, 3])
[tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False],
       device='cuda:0')]
None


In [17]:
%load_ext autoreload
%autoreload 2

from inference.inference_pipeline import *

pipe = VideoDiffusionPipeline(
    vae=vae,
    transformer=transformer,
    scheduler=scheduler
)
batch = next(iter(data_loader))
videos = pipe(batch, num_inference_steps=50)
save_video(videos[0][0], "/scratch/ondemand28/harryscz/diffusion/videoOut/train.mp4")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Inference Progress: 100%|██████████| 50/50 [04:50<00:00,  5.80s/it]
  video_np = (video_np * 255).clip(0, 255).astype(np.uint8)


Saved !
