In [1]:
import argparse
import os
os.chdir("..")
import math
import yaml
import logging
import random
import numpy as np
import sys
import imageio
from train.trainUtils import *

import torch
import torch.nn.functional as F



In [2]:
def parse_args(arg_list=None):
    parser = argparse.ArgumentParser(
        description="Unconditioned Video Diffusion Inference"
    )
    parser.add_argument(
        "--dataset-path", type=str, required=True,
        help="Directory containing input reference videos."
    )
    parser.add_argument(
        "--pretrained-model-name-or-path", type=str, required=True,
        help="Path or HF ID where transformer/vae/scheduler are stored."
    )
    parser.add_argument(
        "--checkpoint-path", type=str, required=True,
        help="Path to fine‐tuned checkpoint containing transformer state_dict."
    )
    parser.add_argument(
        "--output-dir", type=str, required=True,
        help="Where to write generated videos."
    )
    parser.add_argument(
        "--model-config", type=str, required=True,
        help="YAML file describing model params (height, width, num_reference, num_target, etc.)"
    )
    parser.add_argument(
        "--batch-size", type=int, default=1,
        help="Batch size per device (usually 1 for inference)."
    )
    parser.add_argument(
        "--num-inference-steps", type=int, default=50,
        help="Number of reverse diffusion steps to run."
    )
    parser.add_argument(
        "--mixed-precision", type=str, default="bf16",
        help="Whether to run backbone in 'fp16', 'bf16', or 'fp32'."
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--shuffle", type=int, default=False,
        help="Whether to shuffle dataset. Usually False for inference."
    )
    parser.add_argument(
        "--is-uncond", type=bool, default=False,
        help=""
    )
    parser.add_argument(
        "--learning-rate", type=float, default=1e-4
    )

    # If arg_list is None, argparse picks up sys.argv; 
    # otherwise it treats arg_list as the full argv list.
    return parser.parse_args(arg_list)

args = [
    "--dataset-path", "/scratch/ondemand28/harryscz/head_audio/data/data256/uv",
    "--pretrained-model-name-or-path", "/scratch/ondemand28/harryscz/model/CogVideoX-2b",
    "--checkpoint-path",  "/scratch/ondemand28/harryscz/head_audio/trainOutput/checkpoint-1000.pt",
    "--output-dir",  "/scratch/ondemand28/harryscz/diffusion/videoOut",
    "--model-config",  "/scratch/ondemand28/harryscz/diffusion/train/model_config.yaml",
    "--batch-size",  "1",
    "--num-inference-steps",  "50",
    "--mixed-precision",  "no",
    "--seed",  "42",
    "--shuffle",  "0",
    "--learning-rate", "0.0001"
]

args = parse_args(args)

with open(args.model_config, "r") as f: model_config = yaml.safe_load(f)


In [3]:
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from accelerate.logging import get_logger

with open(args.model_config, "r") as f: model_config = yaml.safe_load(f)
if args.mixed_precision.lower() == "fp16":
    dtype = torch.float16
elif args.mixed_precision.lower() == "bf16":
    dtype = torch.bfloat16
else:
    dtype = torch.float32

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir,
                                                    logging_dir=os.path.join(args.output_dir, "logs"))
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
accelerator = Accelerator(mixed_precision=args.mixed_precision,
                            project_config=accelerator_project_config,
                            kwargs_handlers=[ddp_kwargs])

# 2.4 Set random seed
if args.seed is not None:
    set_seed(args.seed + accelerator.process_index)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

logger = get_logger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info("Accelerator state:", accelerator.state)

--- Logging error ---
Traceback (most recent call last):
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 1083, in emit
    msg = self.format(record)
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 927, in format
    return fmt.format(record)
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 663, in format
    record.message = record.getMessage()
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 367, in getMessage
    msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/runpy.py", line 87, in

In [4]:
#### Dataset #####
# Video data have shape [B, C, F, H, W]

from data.VideoDataset import VideoDataset 
from torch.utils.data import DataLoader, DistributedSampler

dataset = VideoDataset(
    videos_dir=args.dataset_path,
    num_ref_frames=1,
    num_target_frames=49
)
if args.shuffle:
    sampler = DistributedSampler(
        dataset,
        num_replicas=accelerator.num_processes,
        rank=accelerator.process_index,
        shuffle=True
    )
else:
    sampler = None
data_loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    # sampler=sampler,
    collate_fn=lambda x: x[0],   # since dataset returns already‐batched items
    num_workers=2,
    pin_memory=True,
)
logger.info(f"Number of test examples: {len(data_loader)}")

06/13/2025 12:33:50 - INFO - __main__ - Number of test examples: 1


In [5]:
#### Load Model ####
device = "cuda"
weight_dtype = torch.float32

from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from model.cap_transformer import CAPVideoXTransformer3DModel
from diffusers.optimization import get_scheduler
from inference.inference_pipeline import *

transformer = CAPVideoXTransformer3DModel.from_pretrained(
    args.pretrained_model_name_or_path,
    low_cpu_mem_usage=False,
    device_map=None,
    ignore_mismatched_sizes=True,
    subfolder="transformer",
    torch_dtype=torch.float32,
    cond_in_channels=1,  # only one channel (the ref_mask)
    sample_width=model_config["width"] // 8,
    sample_height=model_config["height"] // 8,
    max_text_seq_length=1,
    max_n_references=model_config["max_n_references"],
    apply_attention_scaling=model_config["use_growth_scaling"],
    use_rotary_positional_embeddings=False,
)

vae = AutoencoderKLCogVideoX.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="vae"
)
scheduler = CogVideoXDDIMScheduler.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="scheduler",
    prediction_type="v_prediction"
)

# if args.enable_slicing: vae.enable_slicing()
# if args.enable_tiling:  vae.enable_tiling()

vae.eval().to(dtype)
transformer.eval().to(dtype)

vae, transformer, scheduler, data_loader = accelerator.prepare(vae, transformer, scheduler, data_loader)

Some weights of the model checkpoint at /scratch/ondemand28/harryscz/model/CogVideoX-2b were not used when initializing CAPVideoXTransformer3DModel: 
 ['patch_embed.text_proj.weight, patch_embed.text_proj.bias']
Some weights of CAPVideoXTransformer3DModel were not initialized from the model checkpoint at /scratch/ondemand28/harryscz/model/CogVideoX-2b and are newly initialized: ['patch_embed.cond_proj.weight', 'patch_embed.ref_temp_proj.bias', 'patch_embed.audio_proj.bias', 'patch_embed.ref_temp_proj.weight', 'patch_embed.cond_proj.bias', 'patch_embed.audio_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
trainable_parameters =  list(filter(lambda p: p.requires_grad, transformer.parameters()))
params_to_optimize = [{"params": trainable_parameters, "lr": args.learning_rate}]

optimizer = get_optimizer(
    learning_rate=args.learning_rate,
    adam_beta1=0.9, 
    adam_beta2=0.95, 
    adam_epsilon=1e-8, 
    adam_weight_decay=1e-4, 
    params_to_optimize=params_to_optimize, 
    # use_deepspeed=use_deepspeed_optimizer
)

def encode_video(vae, video):
    with torch.no_grad():
        dist = vae.encode(video).latent_dist.sample()
    latent = dist * vae.config.scaling_factor
    return latent.permute(0,2,1,3,4).contiguous()

def unwrap_model(m):
    m = accelerator.unwrap_model(m)
    return m._orig_mod if hasattr(m, "_orig_mod") else m

In [None]:
weight_dtype=torch.float32
transformer.train()

for step, batch in enumerate(data_loader):
    # models_to_accumulate =  [transformer] !!!!!!!!!!!!!!
    # with accelerator.accumulate(models_to_accumulate): !!!!!!!!!!!!!!!
        latent_chunks = []
        ref_mask_chunks = []

        # Initialize necessary data for diffusion
        for i, video in enumerate(batch["video_chunks"]):
            video = video.to(accelerator.device).to(weight_dtype)

            # Encode Video
            latent = encode_video(vae, video) # [B, F, C_z, H, W]
            latent_chunks.append(latent)

            # Ref Mask Chunk, Mask of shape [B, F, H, W, C]
            B, F_z, C, H, W = latent.shape
            rm = torch.zeros((B, F_z, 1, H, W), device=accelerator.device, dtype=weight_dtype)
            rm[:, 0] = 1.0
            ref_mask_chunks.append(rm)

        sequence_infos = [[False, torch.arange(chunk.shape[1])]for chunk in latent_chunks]
        
        # Sample Random Noise
        B, F_z, C_z, H_z, W_z = latent_chunks[0].shape
        timesteps = torch.randint(
            1,
            scheduler.config.num_train_timesteps,
            (B,),
            device=accelerator.device
        ).long()

        noised_latents = []
        for idx, latent in enumerate(latent_chunks):
            noise = torch.randn_like(latent, device=accelerator.device, dtype=weight_dtype)
            noisy_latent = scheduler.add_noise(latent, noise, timesteps)
            noised_latents.append(noisy_latent)

        # Trivial Audio, Text, and Condition
        audio_embeds = torch.zeros((B, F_z, 768), dtype=weight_dtype, device=accelerator.device)
        text_embeds  = torch.zeros((B, 1,
            unwrap_model(transformer).config.attention_head_dim * unwrap_model(transformer).config.num_attention_heads
        ), dtype=weight_dtype, device=accelerator.device)
        B, F_z, C_z, H_z, W_z = noised_latents[0].shape
        zero_cond = [torch.zeros((B, F_z, 1, H_z, W_z), dtype=weight_dtype, device=accelerator.device)] * len(noised_latents)

        # Predict Noise
        model_outputs = transformer(
            hidden_states=noised_latents,
            encoder_hidden_states=text_embeds,
            audio_embeds=audio_embeds,
            condition=zero_cond,
            timestep=timesteps,
            sequence_infos=sequence_infos,
            image_rotary_emb=None,
            return_dict=False
        )[0]

        # ref_mask = torch.cat(ref_mask_chunks, dim=1)
        # non_ref_mask = 1. - ref_mask

        model_output = torch.cat(model_outputs, dim=1)
        model_input = torch.cat(latent_chunks, dim=1)
        noisy_input = torch.cat(noised_latents, dim=1)

        # print("model_output", model_output.min(), model_output.max())
        model_pred = scheduler.get_velocity(model_output, noisy_input, timesteps)

        alpha_bar = scheduler.alphas_cumprod[timesteps].to(weight_dtype)
        sigma_bar = (1 - alpha_bar).sqrt()
        eps = (model_input - alpha_bar.sqrt() * noisy_input) / sigma_bar
        v_true = alpha_bar.sqrt() * eps - sigma_bar * model_input
        loss = F.mse_loss(model_pred, v_true)

        accelerator.backward(loss)