In [1]:
%load_ext autoreload
%autoreload 2

import argparse
import os

# os.environ['CUDA_VISIBLE_DEVICES'] = '8'
os.chdir("..")

import math
import yaml
import logging
import random
import numpy as np
import sys
import imageio
import torch

os.getcwd()

'/nfs/horai.dgpsrv/ondemand28/harryscz/diffusion'

In [2]:
def parse_args(arg_list=None):
    parser = argparse.ArgumentParser(
        description="Unconditioned Video Diffusion Inference"
    )
    parser.add_argument(
        "--dataset-path", type=str, required=True,
        help="Directory containing input reference videos."
    )
    parser.add_argument(
        "--pretrained-model-name-or-path", type=str, required=True,
        help="Path or HF ID where transformer/vae/scheduler are stored."
    )
    parser.add_argument(
        "--checkpoint-path", type=str, required=True,
        help="Path to fine‐tuned checkpoint containing transformer state_dict."
    )
    parser.add_argument(
        "--output-dir", type=str, required=True,
        help="Where to write generated videos."
    )
    parser.add_argument(
        "--model-config", type=str, required=True,
        help="YAML file describing model params (height, width, num_reference, num_target, etc.)"
    )
    parser.add_argument(
        "--batch-size", type=int, default=1,
        help="Batch size per device (usually 1 for inference)."
    )
    parser.add_argument(
        "--num-inference-steps", type=int, default=50,
        help="Number of reverse diffusion steps to run."
    )
    parser.add_argument(
        "--mixed-precision", type=str, default="bf16",
        help="Whether to run backbone in 'fp16', 'bf16', or 'fp32'."
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--shuffle", type=int, default=False,
        help="Whether to shuffle dataset. Usually False for inference."
    )
    parser.add_argument(
        "--is-uncond", type=bool, default=False,
        help=""
    )
    parser.add_argument(
        "--sample-frames", type=int, default=50
    )

    # If arg_list is None, argparse picks up sys.argv; 
    # otherwise it treats arg_list as the full argv list.
    return parser.parse_args(arg_list)

args = [
    "--dataset-path", "/scratch/ondemand28/harryscz/head_audio/head/data/vfhq-fit",
    "--pretrained-model-name-or-path", "/scratch/ondemand28/harryscz/model/CogVideoX-2b",
    "--checkpoint-path",  "/scratch/ondemand28/harryscz/head_audio/trainOutput/checkpoint-6000.pt",
    "--output-dir",  "/scratch/ondemand28/harryscz/diffusion/videoOut",
    "--model-config",  "/scratch/ondemand28/harryscz/diffusion/train/model_config.yaml",
    "--batch-size",  "1",
    "--num-inference-steps",  "50",
    "--mixed-precision",  "no",
    "--seed",  "42",
    "--shuffle",  "0",
    "--sample-frames", "29"
]

args = parse_args(args)

with open(args.model_config, "r") as f: model_config = yaml.safe_load(f)

In [3]:
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from accelerate.logging import get_logger

with open(args.model_config, "r") as f: model_config = yaml.safe_load(f)
if args.mixed_precision.lower() == "fp16":
    dtype = torch.float16
elif args.mixed_precision.lower() == "bf16":
    dtype = torch.bfloat16
else:
    dtype = torch.float32

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir,
                                                    logging_dir=os.path.join(args.output_dir, "logs"))
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
accelerator = Accelerator(mixed_precision=args.mixed_precision,
                            project_config=accelerator_project_config,
                            kwargs_handlers=[ddp_kwargs])

# 2.4 Set random seed
if args.seed is not None:
    set_seed(args.seed + accelerator.process_index)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

logger = get_logger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info("Accelerator state:", accelerator.state)

from data.VideoDataset import *
from torch.utils.data import DataLoader, DistributedSampler

dataset = VideoPathDataset(
    source_dir=args.dataset_path,
)
if args.shuffle:
    sampler = DistributedSampler(
        dataset,
        num_replicas=accelerator.num_processes,
        rank=accelerator.process_index,
        shuffle=True
    )
else:
    sampler = None
data_loader = DataLoader(
    dataset,
    batch_size=2,
    # sampler=sampler,
    collate_fn=lambda x: x,   
    num_workers=2,
    pin_memory=True,
)
logger.info(f"Number of test examples: {len(data_loader)}")

--- Logging error ---
Traceback (most recent call last):
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 1083, in emit
    msg = self.format(record)
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 927, in format
    return fmt.format(record)
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 663, in format
    record.message = record.getMessage()
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/logging/__init__.py", line 367, in getMessage
    msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/scratch/ondemand28/harryscz/anaconda3/envs/pytorch3d/lib/python3.9/runpy.py", line 87, in

In [4]:
def encode_video(vae, video, grad=False):
    video = video.to(accelerator.device, dtype=vae.dtype)
    video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]

    if grad:
        latent_dist = vae.encode(video).latent_dist.sample() * vae.config.scaling_factor
    else: 
        with torch.no_grad(): latent_dist = vae.encode(video).latent_dist.sample() * vae.config.scaling_factor
    return latent_dist.permute(0, 2, 1, 3, 4).to(memory_format=torch.contiguous_format)

def decode_latents(vae, latents: torch.Tensor, grad=False) -> torch.Tensor:
    latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
    latents = 1 / vae.config.scaling_factor * latents

    if grad:
        rames = vae.decode(latents).sample
    else:
        with torch.no_grad(): frames = vae.decode(latents).sample
    return frames.permute(0,2,1,3,4)

In [5]:
device = "cuda"
dtype = torch.float32

from diffusers import AutoencoderKLCogVideoX

vae = AutoencoderKLCogVideoX.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="vae"
)
vae.eval().to(dtype)

AutoencoderKLCogVideoX(
  (encoder): CogVideoXEncoder3D(
    (conv_in): CogVideoXCausalConv3d(
      (conv): CogVideoXSafeConv3d(3, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    )
    (down_blocks): ModuleList(
      (0): CogVideoXDownBlock3D(
        (resnets): ModuleList(
          (0-2): 3 x CogVideoXResnetBlock3D(
            (nonlinearity): SiLU()
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): CogVideoXCausalConv3d(
              (conv): CogVideoXSafeConv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
            )
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): CogVideoXCausalConv3d(
              (conv): CogVideoXSafeConv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
            )
          )
        )
        (downsamplers): ModuleList(
          (0): CogVideoXDownsample3D(
     

In [6]:
vae, data_loader = accelerator.prepare(vae, data_loader)

In [7]:
from model.flameObj import *

flamePath = "/scratch/ondemand28/harryscz/head_audio/head/code/flame/flame2023_no_jaw.npz"
sourcePath = "/scratch/ondemand28/harryscz/head_audio/head/data/vfhq-fit"
dataPath = [os.path.join(os.path.join(sourcePath, data), "fit.npz") for data in os.listdir(sourcePath)]
seqPath = "/scratch/ondemand28/harryscz/head/_-91nXXjrVo_00/fit.npz"

head = Flame(flamePath, device="cuda")



In [8]:
batch = next(iter(data_loader))
batch

['/scratch/ondemand28/harryscz/head_audio/head/data/vfhq-fit/g1eIAelVFq4_02/fit.npz',
 '/scratch/ondemand28/harryscz/head_audio/head/data/vfhq-fit/jJ62WENxR78_01/fit.npz']

In [9]:
def encode_video(vae, video, grad=False):
    video = video.to(accelerator.device, dtype=vae.dtype)
    video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]

    if grad:
        latent_dist = vae.encode(video).latent_dist.sample() * vae.config.scaling_factor
    else: 
        with torch.no_grad(): latent_dist = vae.encode(video).latent_dist.sample() * vae.config.scaling_factor
    return latent_dist.permute(0, 2, 1, 3, 4).to(memory_format=torch.contiguous_format)

def decode_latents(vae, latents: torch.Tensor, grad=False) -> torch.Tensor:
    latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
    latents = 1 / vae.config.scaling_factor * latents

    if grad:
        frames = vae.decode(latents).sample
    else:
        with torch.no_grad(): frames = vae.decode(latents).sample
    return frames.permute(0,2,1,3,4)

In [10]:
uvs = head.batch_uv(batch[:1], resolution=256, rotation=False, sample_frames=29).permute(0,1,4,2,3) 

uvs = uvs.to(accelerator.device, dtype=vae.dtype)
uvs = uvs.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
z = vae.encode(uvs)
latents = z.latent_dist.sample() * vae.config.scaling_factor
latents = 1 / vae.config.scaling_factor * latents
recon = vae.decode(latents).sample

OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 47.52 GiB of which 32.25 MiB is free. Including non-PyTorch memory, this process has 47.40 GiB memory in use. Of the allocated memory 46.28 GiB is allocated by PyTorch, and 324.18 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
def compute_kl_loss(z):
    mu = z.latent_dist.mean
    logvar = z.latent_dist.logvar
    # kl per latent unit
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    return kl

kl = compute_kl_loss(z).sum()
l1 = F.l1_loss(uvs, recon, reduction='mean')

tensor(146.7351, device='cuda:0', grad_fn=<MulBackward0>)

In [None]:
z = encode_video(vae, uvs)

In [37]:
y = decode_latents(vae, z)

In [48]:
from tqdm import tqdm

In [None]:
epoch = 20
batch_size = 2
train_steps = epoch * (len(dataset) // batch_size)

progress_bar = tqdm(
    range(0, train_steps),
    initial=1,
    desc="Steps",
)

for ep in epoch:
    for i, batch in enumerate(data_loader):
        uvs = head.batch_uv(batch, resolution=256, rotation=False, sample_frames=150).permute(0,1,4,2,3) 

        uvs = uvs.to(accelerator.device, dtype=vae.dtype)
        uvs = uvs.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
        latent_dist = vae.encode(uvs).latent_dist.sample() * vae.config.scaling_factor
        
        y = decode_latents(vae, z)

        recon_loss = F.mse_loss(recon_x, x)
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.numel()
        loss = recon_loss + beta * kl_loss

OutOfMemoryError: CUDA out of memory. Tried to allocate 320.00 MiB. GPU 0 has a total capacity of 47.52 GiB of which 190.25 MiB is free. Including non-PyTorch memory, this process has 47.24 GiB memory in use. Of the allocated memory 45.79 GiB is allocated by PyTorch, and 673.78 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [14]:
torch.cuda.empty_cache()

NameError: name 'uvs' is not defined