In [None]:
import argparse
import os

os.chdir("..")

import math
import yaml
import logging
import random
import numpy as np
import sys
import imageio
import torch

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from accelerate.logging import get_logger

from data.VideoDataset import VideoDataset 
from torch.utils.data import DataLoader, DistributedSampler

from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from model.cap_transformer import CAPVideoXTransformer3DModel

from inference._inference_pipeline import *
from data.VideoDataset import *
from model.flameObj import *
from wf_vae.model import *
from model.causalvideovae.model import *
import trimesh

os.getcwd()

'/nfs/horai.dgpsrv/ondemand28/harryscz/diffusion'

In [2]:
flamePath = flamePath = "/scratch/ondemand28/harryscz/head_audio/head/code/flame/flame2023_no_jaw.npz"
sourcePath = "/scratch/ondemand28/harryscz/head_audio/head/data/vfhq-fit"
dataPath = [os.path.join(os.path.join(sourcePath, data), "fit.npz") for data in os.listdir(sourcePath)]
seqPath = "/scratch/ondemand28/harryscz/head/_-91nXXjrVo_00/fit.npz"

In [None]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
head = Flame(flamePath, device="cuda")
head.loadSequence(dataPath[3])
seq = head.LSB(rotation=False)
uvMesh = head.convertUV()
uv = head.get_uv_animation(uvMesh, savePath="flameOut/a.mp4")
sampled = head.sampleFromUV(uv.squeeze(0), "flameOut/b.mp4")
final = head.sampleTo3D(sampledUV=sampled, savePath="flameOut/c.mp4", dist=1.2)



In [6]:
i = 10
uvs = head.batch_uv(dataPath[i:i+1], resolution=256, sample_frames=100, rotation=False)
uvs.permute(0, 1, 4, 2, 3).shape

torch.Size([1, 100, 3, 256, 256])

In [8]:
# vae = AutoencoderKLCogVideoX.from_pretrained(
#     "/scratch/ondemand28/harryscz/model/CogVideoX-2b", subfolder="vae"
# )
# vae.to("cuda")

model_cls = ModelRegistry.get_model("WFVAE")
vae = model_cls.from_pretrained("/scratch/ondemand28/harryscz/other/WF-VAE/weight")

ckpt = torch.load("/scratch/ondemand28/harryscz/diffusion/modelOut/vae/WFVAE-lr1.00e-05-bs1-rs256-sr1-fr25/checkpoint-44000.ckpt", map_location="cpu")
vae.load_state_dict(ckpt["state_dict"],  strict=False)

print("WF VAE checkpoint loaded!")

  ckpt = torch.load("/scratch/ondemand28/harryscz/diffusion/modelOut/vae/WFVAE-lr1.00e-05-bs1-rs256-sr1-fr25/checkpoint-44000.ckpt", map_location="cpu")


WF VAE checkpoint loaded!


In [None]:
def encode_video(vae, video):
    video = video.to(vae.device, dtype=vae.dtype)
    video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
    with torch.no_grad(): latent_dist = vae.encode(video).latent_dist.sample() 
    return latent_dist.permute(0, 2, 1, 3, 4).to(memory_format=torch.contiguous_format)

def decode_latents(vae, latents: torch.Tensor) -> torch.Tensor:
    latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
    latents = ·latents

    with torch.no_grad(): frames = vae.decode(latents).sample
    return frames

In [92]:
z = encode_video(vae, uvs.permute(0,1,4,2,3).to("cuda"))
y = decode_latents(vae, z).squeeze(0).permute(1,2,3,0)

sampled = head.sampleFromUV(uvs.squeeze(0), "flameOut/__b.mp4")
final = head.sampleTo3D(sampledUV=sampled, savePath="flameOut/__c.mp4", dist=1.2)

faces = head.faces3d.cpu().numpy()
verts = final[0, ...].cpu().numpy()

trimesh.Trimesh(vertices=verts, faces=faces, process=True).export("flameOut/__i.ply")

b'ply\nformat binary_little_endian 1.0\ncomment https://github.com/mikedh/trimesh\nelement vertex 3910\nproperty float x\nproperty float y\nproperty float z\nelement face 9976\nproperty list uchar int vertex_indices\nend_header\n\n\x08\x1f\xbeP\xa8\xa1\xbe\x06M\xdb\xbfV\xb0\x1e\xbe$\xa2\xa0\xbeR?\xdb\xbf\xa5\x98 \xbe\xb5%\x9f\xbel\x9f\xdc\xbf\x8c?\x1d\xbe\x90\x14\x9e\xbeH.\xdc\xbfh\x86\x1b\xbe)\x9f\x9b\xbe\xe1\xd8\xdb\xbfF\x02\x1f\xbe\x0b;\x9c\xbe\x048\xdc\xbf\x84\x15\\\xbe\xdaq\xbe\xbeB^\xe7\xbf\xd0\xd2r\xbeb\xa8\xbd\xbe\t1\xe9\xbf\xce\xe3u\xbe&\x1c\xc6\xbe\x0cn\xe8\xbf\xce\xd1a\xbee\xdb\xc7\xbe\xe8\xde\xe6\xbf\xdc\xcdJ\xbe<\xcd\xd9\xbe\xcc\x87\xe2\xbfK\x1dP\xbe\t\xdd\xd3\xbe\x1a\xb4\xe3\xbf\xe2*U\xbe\x81\x94\xdd\xbe\xe6\x89\xe4\xbfj\x14J\xbe\x12\xf7\xdf\xbe\x19j\xe2\xbf\xca\\X\xbeU\x81f\xbe\xc4p\xcb\xbff\xf8N\xbe\xa3\x80g\xbefa\xcc\xbfs\xd7O\xbe:\xc5p\xbe\x9c\xc8\xcb\xbf\x0e\xd5\xc0\xbe\x8b5\xa0\xbe\xf8\x0c\xdc\xbf\x8e\xa7\xc0\xbe\xa45\xa1\xbeJ\x19\xdc\xbf\x99\xfe\xbe\xbe\xa5\x8e\x9e