In [5]:
import os
from pathlib import Path
import random     
import inspect
import math
import sys

from accelerate import Accelerator

from diffusers import DiffusionPipeline, CogVideoXDDIMScheduler
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [6]:
if 'flameObj' in sys.modules:
    del sys.modules['flameObj']

from flameObj import *

flamePath = flamePath = "/scratch/ondemand28/harryscz/head_audio/head/code/flame/flame2023_no_jaw.npz"
sourcePath = "/scratch/ondemand28/harryscz/head_audio/head/data/vfhq-fit"
dataPath = [os.path.join(os.path.join(sourcePath, data), "fit.npz") for data in os.listdir(sourcePath)]
seqPath = "/scratch/ondemand28/harryscz/head_aduiohead/_-91nXXjrVo_00/fit.npz"

head = Flame(flamePath, device='cuda')



In [19]:
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler
from cap_transformer import CAPVideoXTransformer3DModel
import yaml

weight_dtype=torch.float32
device="cuda"

In [20]:
model_config="/scratch/ondemand28/harryscz/diffusion/model_config.yaml"
pretrained_model_name_or_path = "/scratch/ondemand28/harryscz/model/CogVideoX-2b"
ckpt_path = "/scratch/ondemand28/harryscz/head_audio/trainOutput/checkpoint-1000.pt"

with open(model_config) as f: model_config_yaml = yaml.safe_load(f)

vae = AutoencoderKLCogVideoX.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae"
)

scheduler = CogVideoXDDIMScheduler.from_pretrained(
    pretrained_model_name_or_path, subfolder="scheduler"
)

transformer = CAPVideoXTransformer3DModel.from_pretrained(
    pretrained_model_name_or_path,
    low_cpu_mem_usage=False,
    device_map=None,
    ignore_mismatched_sizes=True,
    subfolder="transformer",
    torch_dtype=torch.float32,
    cond_in_channels=1,  # only one channel (the ref_mask)
    sample_width=model_config_yaml["width"] // 8,
    sample_height=model_config_yaml["height"] // 8,
    max_text_seq_length=1,
    max_n_references=model_config_yaml["max_n_references"],
    apply_attention_scaling=model_config_yaml["use_growth_scaling"],
    use_rotary_positional_embeddings=False,
)

ckpt = torch.load(ckpt_path, map_location="cpu")

if "state_dict" in ckpt:
    raw_state_dict = ckpt["state_dict"]
elif "model_state_dict" in ckpt:
    raw_state_dict = ckpt["model_state_dict"]
else:
    # If the .pt is literally just a pure state_dict, do this:
    raw_state_dict = ckpt

clean_state_dict = {}
for key, val in raw_state_dict.items():
    new_key = key
    # e.g. if your keys start with "module.", remove it:
    if key.startswith("module."):
        new_key = key[len("module."):]
    # or if saved under "model.", do:
    # if key.startswith("model."):
    #     new_key = key[len("model."):]
    clean_state_dict[new_key] = val

missing, unexpected = transformer.load_state_dict(clean_state_dict, strict=False)

print("==> Missing keys (these will be randomly initialized because they weren't in the checkpoint):")
for k in missing:
    print("   ", k)
print("==> Unexpected keys (these were in the checkpoint but didn't match any parameter in your model):")
for k in unexpected:
    print("   ", k)

vae.eval().to(device)
transformer.eval().to(device)

Some weights of the model checkpoint at /scratch/ondemand28/harryscz/model/CogVideoX-2b were not used when initializing CAPVideoXTransformer3DModel: 
 ['patch_embed.text_proj.bias, patch_embed.text_proj.weight']
Some weights of CAPVideoXTransformer3DModel were not initialized from the model checkpoint at /scratch/ondemand28/harryscz/model/CogVideoX-2b and are newly initialized: ['patch_embed.ref_temp_proj.bias', 'patch_embed.cond_proj.bias', 'patch_embed.audio_proj.weight', 'patch_embed.cond_proj.weight', 'patch_embed.audio_proj.bias', 'patch_embed.ref_temp_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  ckpt = torch.load(ckpt_path, map_location="cpu")


==> Missing keys (these will be randomly initialized because they weren't in the checkpoint):
==> Unexpected keys (these were in the checkpoint but didn't match any parameter in your model):


CAPVideoXTransformer3DModel(
  (patch_embed): CAPPatchEmbed(
    (proj): Conv2d(16, 1920, kernel_size=(2, 2), stride=(2, 2))
    (cond_proj): Conv2d(1, 1920, kernel_size=(2, 2), stride=(2, 2))
    (audio_proj): Linear(in_features=3072, out_features=1920, bias=True)
    (ref_temp_proj): Linear(in_features=2, out_features=480, bias=True)
  )
  (embedding_dropout): Dropout(p=0.0, inplace=False)
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=1920, out_features=512, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=512, out_features=512, bias=True)
  )
  (transformer_blocks): ModuleList(
    (0-29): 30 x CogVideoXBlock(
      (norm1): CogVideoXLayerNormZero(
        (silu): SiLU()
        (linear): Linear(in_features=512, out_features=11520, bias=True)
        (norm): LayerNorm((1920,), eps=1e-05, elementwise_affine=True)
      )
      (attn1): Attention(
        (norm_q): LayerNorm((64,), eps=1e-06, elementwise_affine=True

In [None]:
from torchvision import transforms

def preprocess_frame(pil_frame, target_resolution=(model_config_yaml["height"], model_config_yaml["width"])):
    """
    1) Resize the input PIL frame
    2) Convert to floatTensor in [-1, +1]
    3) Return a (1, 3, H, W) tensor on CPU
    """
    preprocess = transforms.Compose([
        transforms.Resize(target_resolution, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),            # → [0,1]
        transforms.Normalize([0.5]*3, [0.5]*3),  # → [-1, +1]
    ])
    x = preprocess(pil_frame)  # shape: (3, H, W), CPU float32
    return x.unsqueeze(0)      # shape: (1, 3, H, W)

def encode_single_frame(vae, frame_tensor):
    """
    frame_tensor: (1, 3, H, W), in [-1, +1], CPU or device
    returns: latent of shape (1, C_z, F_lat=1, h_lat, w_lat)
    """
    frame_tensor = frame_tensor.to(device=device, dtype=weight_dtype)
    with torch.no_grad():
        # Note: CogVideoX’s VAE expects a 5D “video” tensor: (B, 3, F, H, W).
        # For a single frame, F=1.
        video_input = frame_tensor.unsqueeze(2)  # → (1, 3, 1, H, W)
        z = vae.encode(video_input).latent_dist.sample()  # (1, C_z, 1, h_lat, w_lat), float32
        z = z * vae.config.scaling_factor
    return z.contiguous()  

preprocess = transforms.Compose([
    transforms.Resize((model_config_yaml["height"], model_config_yaml["width"]),
                      interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3),  # → [-1,+1]
])

def preprocess_frame_to_tensor(pil_img):
    x = preprocess(pil_img)         # (3, H, W) in [-1,+1]
    return x.unsqueeze(0)           # (1, 3, H, W)

@torch.no_grad()
def encode_single_frame(vae, frame_tensor):
    # frame_tensor: (1,3,H,W) in [-1,+1], CPU or already on device
    f = frame_tensor.to(device, dtype=weight_dtype)
    video_in = f.unsqueeze(2)       # (1,3,1,H,W)
    with torch.no_grad():
        z = vae.encode(video_in).latent_dist.sample()  # (1, C_z, 1, h_lat, w_lat)
        z = z * vae.config.scaling_factor
    return z.contiguous()           # (1, C_z, 1, h_lat, w_lat)

# (C) Build the masks, embeddings & sequence_infos for inference

def build_masks_and_embeddings(latent_ref):
    """
    latent_ref: (1, C_z, 1, h_lat, w_lat)
    returns:
      cond_chunks:      [ (1, 1, 1, h_lat, w_lat) ]
      sequence_infos:   [ (False, tensor([0], device=device)) ]
      fake_text_embeds: (1, 1, inner_dim)
      fake_audio_embeds: (1, 1, 768)
    """
    B, C_z, F_lat, h_lat, w_lat = latent_ref.shape
    assert B == 1 and F_lat == 1

    # C)1) Reference mask = all‐ones
    ref_mask_latent = torch.ones((B, 1, F_lat, h_lat, w_lat), device=device, dtype=weight_dtype)
    cond_mask_chunk = ref_mask_latent.permute(0, 2, 1, 3, 4)  # (1,1,1,h_lat,w_lat)
    cond_chunks = [cond_mask_chunk]

    # C)2) sequence_infos
    seq_idx = torch.arange(0, F_lat, device=device)  # → tensor([0])
    sequence_infos = [(False, seq_idx)]

    # C)3) fake_text_embeds
    inner_dim = transformer.config.num_attention_heads * transformer.config.attention_head_dim
    fake_text_embeds = torch.zeros((B, 1, inner_dim), device=device, dtype=weight_dtype)

    # C)4) fake_audio_embeds
    audio_feature_dim = 768
    fake_audio_embeds = torch.zeros((B, F_lat, audio_feature_dim), device=device, dtype=weight_dtype)

    return cond_chunks, sequence_infos, fake_text_embeds, fake_audio_embeds

@torch.no_grad()
def denoise_from_single_frame(
    vae,
    transformer,
    scheduler,
    latent_ref,            # (1, C_z, 1, h_lat, w_lat)
    cond_chunks,           # [ (1, 1, 1, h_lat, w_lat) ]
    sequence_infos,        # [ (False, tensor([0])) ]
    fake_text_embeds,      # (1, 1, inner_dim)
    fake_audio_embeds,     # (1, 1, 768)
    num_inference_steps=50,
):
    B, C_z, F_lat, h_lat, w_lat = latent_ref.shape
    assert B == 1 and F_lat == 1

    # D.1) Start from pure noise in latent space
    current_latent = torch.randn_like(latent_ref, device=device, dtype=weight_dtype)

    # D.2) Prepare timesteps (linear spacing from T-1 down to 0)
    scheduler.set_format("pt")
    T = scheduler.config.num_train_timesteps
    timesteps = torch.linspace(T - 1, 0, num_inference_steps, dtype=torch.long, device=device)

    # D.3) Reverse‐diffusion loop
    for t in timesteps:
        # (1) Permute current_latent into the shape (B, F_lat, C_z, h_lat, w_lat)
        noised_chunk = current_latent.permute(0, 2, 1, 3, 4)  # → (1,1,C_z,h_lat,w_lat)

        # (2) Infer velocity with the transformer
        out_list = transformer(
            hidden_states=[noised_chunk],
            condition=cond_chunks,
            sequence_infos=sequence_infos,
            timestep=torch.tensor([t], device=device),
            audio_embeds=fake_audio_embeds,
            encoder_hidden_states=fake_text_embeds,
            image_rotary_emb=None,
            return_dict=False,
        )[0]
        # out_list is a list of length 1; element[0].shape = (1,1,C_z,h_lat,w_lat)

        pred_velocity = torch.cat(out_list, dim=1)  # → (1,1,C_z,h_lat,w_lat)

        # (3) One scheduler.step call → produce the “previous sample” at t-1
        step_output = scheduler.step(pred_velocity, t, current_latent)
        current_latent = step_output.prev_sample

    # D.4) Invert the VAE scaling to get the final “clean” latent
    final_latent = current_latent / vae.config.scaling_factor

    # D.5) Decode with VAE
    with torch.no_grad():
        decoded = vae.decode(final_latent).sample  # (1,3,1,H,W)
        decoded_image = decoded.squeeze(2)          # (1,3,H,W)

    return decoded_image
