In [1]:
import argparse
import os

os.chdir("..")

import math
import yaml
import logging
import random
import numpy as np
import sys
import imageio
import torch

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from accelerate.logging import get_logger

from data.VideoDataset import VideoDataset 
from torch.utils.data import DataLoader, DistributedSampler

from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from model.cap_transformer import CAPVideoXTransformer3DModel

from inference._inference_pipeline import *
from data.VideoDataset import *
from model.flameObj import *

import trimesh

os.getcwd()

'/nfs/horai.dgpsrv/ondemand28/harryscz/diffusion'

In [None]:
flamePath = flamePath = "/scratch/ondemand28/harryscz/head_audio/head/code/flame/flame2023_no_jaw.npz"
sourcePath = "/scratch/ondemand28/harryscz/head_audio/head/data/vfhq-fit"
dataPath = [os.path.join(os.path.join(sourcePath, data), "fit.npz") for data in os.listdir(sourcePath)]
seqPath = "/scratch/ondemand28/harryscz/head/_-91nXXjrVo_00/fit.npz"

In [None]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [59]:
head = Flame(flamePath, device="cuda")
head.loadSequence(dataPath[3])
seq = head.LSB(rotation=False)
uvMesh = head.convertUV()
uv = head.get_uv_animation(uvMesh, savePath="flameOut/a.mp4")
sampled = head.sampleFromUV(uv.squeeze(0), "flameOut/b.mp4")
final = head.sampleTo3D(sampledUV=sampled, savePath="flameOut/c.mp4", dist=1.2)



In [60]:
faces = head.faces3d.cpu().numpy()
verts = final[10, ...].cpu().numpy()

trimesh.Trimesh(vertices=verts, faces=faces, process=True).export("flameOut/i.ply")

b'ply\nformat binary_little_endian 1.0\ncomment https://github.com/mikedh/trimesh\nelement vertex 3908\nproperty float x\nproperty float y\nproperty float z\nelement face 9976\nproperty list uchar int vertex_indices\nend_header\n)\x9b\x1e\xbe9g\xa1\xbe"_\xdb\xbf\xd5I\x1e\xbe\xf7_\xa0\xbe\x1aR\xdb\xbf\xff& \xbem\xf1\x9e\xbe\xf6\xb4\xdc\xbf\x12\xd8\x1c\xbe\xea\xd9\x9d\xbefD\xdc\xbf\xb5-\x1b\xbe\xd4^\x9b\xbe|\xf0\xdb\xbf\xe8\x9f\x1e\xbe\xbd\xff\x9b\xbejO\xdc\xbf\xbc\x95[\xbe\x9eJ\xbe\xbe\xbdc\xe7\xbfZqr\xbe\x17\xa2\xbd\xbe-1\xe9\xbf\xecBu\xbe\xe2\x02\xc6\xbe\x86l\xe8\xbf\xd8\xfb`\xbe:\xa0\xc7\xbe\xa7\xe3\xe6\xbf\xde&I\xbelq\xd9\xbe\xdb\x98\xe2\xbfn\xc7N\xbe\xd0\x8b\xd3\xbe~\xc0\xe3\xbfNhS\xbe>#\xdd\xbe\x90\x97\xe4\xbf"\x11H\xbe\xea\x89\xdf\xbe^}\xe2\xbf\xae\x8c]\xbec\x0eq\xbea\xf1\xca\xbf=\x0fT\xbe\xf6\x9bq\xbeT\xc9\xcb\xbf\xd3\xdeT\xbe\x0cj{\xbe \x18\xcb\xbf\xb1\x1a\xc1\xbe22\xa0\xbe\xfb\xfa\xdb\xbf\xac\xf3\xc0\xbe\xf94\xa1\xbe\xb8\x06\xdc\xbfSB\xbf\xbeN\x9b\x9e\xbe\xee\x83\xdd\xbf]\xf9\

In [None]:
uvs = head.batch_uv(dataPath[i:i+1], resolution=256, sample_frames=29, rotation=False)
uvs.permute(0, 1, 4, 2, 3).shape



torch.Size([1, 29, 256, 256, 3])

In [70]:
uvs = head.batch_uv(dataPath[i:i+1], resolution=256, sample_frames=29, rotation=False) # load UVs of shape B, F, C, H, W
ref_frame = uvs[:, 0, :, :, :].unsqueeze(1) # B, 1, C, H, W



In [71]:
sampled = head.sampleFromUV(uvs.squeeze(0), "flameOut/_b.mp4")
final = head.sampleTo3D(sampledUV=sampled, savePath="flameOut/_c.mp4", dist=1.2)

faces = head.faces3d.cpu().numpy()
verts = final[0, ...].cpu().numpy()

trimesh.Trimesh(vertices=verts, faces=faces, process=True).export("flameOut/_i.ply")

b'ply\nformat binary_little_endian 1.0\ncomment https://github.com/mikedh/trimesh\nelement vertex 3911\nproperty float x\nproperty float y\nproperty float z\nelement face 9976\nproperty list uchar int vertex_indices\nend_header\n\n\x08\x1f\xbeP\xa8\xa1\xbe\x06M\xdb\xbfV\xb0\x1e\xbe$\xa2\xa0\xbeR?\xdb\xbf\xa5\x98 \xbe\xb5%\x9f\xbel\x9f\xdc\xbf\x8c?\x1d\xbe\x90\x14\x9e\xbeH.\xdc\xbfh\x86\x1b\xbe)\x9f\x9b\xbe\xe1\xd8\xdb\xbfF\x02\x1f\xbe\x0b;\x9c\xbe\x048\xdc\xbf\x84\x15\\\xbe\xdaq\xbe\xbeB^\xe7\xbf\xd0\xd2r\xbeb\xa8\xbd\xbe\t1\xe9\xbf\xce\xe3u\xbe&\x1c\xc6\xbe\x0cn\xe8\xbf\xce\xd1a\xbee\xdb\xc7\xbe\xe8\xde\xe6\xbf\xdc\xcdJ\xbe<\xcd\xd9\xbe\xcc\x87\xe2\xbfK\x1dP\xbe\t\xdd\xd3\xbe\x1a\xb4\xe3\xbf\xe2*U\xbe\x81\x94\xdd\xbe\xe6\x89\xe4\xbfj\x14J\xbe\x12\xf7\xdf\xbe\x19j\xe2\xbf\xca\\X\xbeU\x81f\xbe\xc4p\xcb\xbff\xf8N\xbe\xa3\x80g\xbefa\xcc\xbfs\xd7O\xbe:\xc5p\xbe\x9c\xc8\xcb\xbf\x0e\xd5\xc0\xbe\x8b5\xa0\xbe\xf8\x0c\xdc\xbf\x8e\xa7\xc0\xbe\xa45\xa1\xbeJ\x19\xdc\xbf\x99\xfe\xbe\xbe\xa5\x8e\x9e

In [73]:
vae = AutoencoderKLCogVideoX.from_pretrained(
    "/scratch/ondemand28/harryscz/model/CogVideoX-2b", subfolder="vae"
)
vae.to("cuda")

AutoencoderKLCogVideoX(
  (encoder): CogVideoXEncoder3D(
    (conv_in): CogVideoXCausalConv3d(
      (conv): CogVideoXSafeConv3d(3, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    )
    (down_blocks): ModuleList(
      (0): CogVideoXDownBlock3D(
        (resnets): ModuleList(
          (0-2): 3 x CogVideoXResnetBlock3D(
            (nonlinearity): SiLU()
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): CogVideoXCausalConv3d(
              (conv): CogVideoXSafeConv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
            )
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): CogVideoXCausalConv3d(
              (conv): CogVideoXSafeConv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
            )
          )
        )
        (downsamplers): ModuleList(
          (0): CogVideoXDownsample3D(
     

In [86]:
def encode_video(vae, video):
    video = video.to(vae.device, dtype=vae.dtype)
    video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    with torch.no_grad(): latent_dist = vae.encode(video).latent_dist.sample() * vae.config.scaling_factor
    return latent_dist.permute(0, 2, 1, 3, 4).to(memory_format=torch.contiguous_format)

def decode_latents(vae, latents: torch.Tensor) -> torch.Tensor:
    latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
    latents = 1 / vae.config.scaling_factor * latents

    with torch.no_grad(): frames = vae.decode(latents).sample
    return frames

In [92]:
z = encode_video(vae, uvs.permute(0,1,4,2,3).to("cuda"))
y = decode_latents(vae, z).squeeze(0).permute(1,2,3,0)

sampled = head.sampleFromUV(uvs.squeeze(0), "flameOut/__b.mp4")
final = head.sampleTo3D(sampledUV=sampled, savePath="flameOut/__c.mp4", dist=1.2)

faces = head.faces3d.cpu().numpy()
verts = final[0, ...].cpu().numpy()

trimesh.Trimesh(vertices=verts, faces=faces, process=True).export("flameOut/__i.ply")

b'ply\nformat binary_little_endian 1.0\ncomment https://github.com/mikedh/trimesh\nelement vertex 3910\nproperty float x\nproperty float y\nproperty float z\nelement face 9976\nproperty list uchar int vertex_indices\nend_header\n\n\x08\x1f\xbeP\xa8\xa1\xbe\x06M\xdb\xbfV\xb0\x1e\xbe$\xa2\xa0\xbeR?\xdb\xbf\xa5\x98 \xbe\xb5%\x9f\xbel\x9f\xdc\xbf\x8c?\x1d\xbe\x90\x14\x9e\xbeH.\xdc\xbfh\x86\x1b\xbe)\x9f\x9b\xbe\xe1\xd8\xdb\xbfF\x02\x1f\xbe\x0b;\x9c\xbe\x048\xdc\xbf\x84\x15\\\xbe\xdaq\xbe\xbeB^\xe7\xbf\xd0\xd2r\xbeb\xa8\xbd\xbe\t1\xe9\xbf\xce\xe3u\xbe&\x1c\xc6\xbe\x0cn\xe8\xbf\xce\xd1a\xbee\xdb\xc7\xbe\xe8\xde\xe6\xbf\xdc\xcdJ\xbe<\xcd\xd9\xbe\xcc\x87\xe2\xbfK\x1dP\xbe\t\xdd\xd3\xbe\x1a\xb4\xe3\xbf\xe2*U\xbe\x81\x94\xdd\xbe\xe6\x89\xe4\xbfj\x14J\xbe\x12\xf7\xdf\xbe\x19j\xe2\xbf\xca\\X\xbeU\x81f\xbe\xc4p\xcb\xbff\xf8N\xbe\xa3\x80g\xbefa\xcc\xbfs\xd7O\xbe:\xc5p\xbe\x9c\xc8\xcb\xbf\x0e\xd5\xc0\xbe\x8b5\xa0\xbe\xf8\x0c\xdc\xbf\x8e\xa7\xc0\xbe\xa45\xa1\xbeJ\x19\xdc\xbf\x99\xfe\xbe\xbe\xa5\x8e\x9e