In [1]:
import yaml
from data_processing.dataset import FaceDataset
from flame.flame_model_creation import FlameModelCreation
from pytorch3d.io import load_obj
import numpy as np
import pytorch3d.structures
import torch
import flame.flame as imfp
import flame.lbs as lbs
from dreifus.matrix import Pose, Intrinsics
import trimesh
import pytorch3d
from pytorch3d.ops import sample_points_from_meshes
import numpy as np
import k3d
import matplotlib.cm as cm
from gaussian_splatting.utils.sh_utils import SH2RGB
from PIL import Image


def create_flame_model(tracked_flame_params_path, config, timestep=0):
    """Evaluate the FLAME head model for one tracked frame.

    The tracked parameters are read with ``np.load`` (the keys ``shape``,
    ``expression``, ``jaw``, ``neck`` and ``eyes`` are accessed) and fed to
    ``FlameHead.forward3`` with global rotation and translation disabled.

    Args:
        tracked_flame_params_path: Path to the tracked FLAME parameter archive.
        config: Currently unused; kept for interface compatibility and for the
            disabled ``upsample_flame`` step below.
        timestep: Non-negative index of the tracked frame to evaluate.
            Defaults to 0, the frame that used to be hard-coded.

    Returns:
        Tuple ``(vertices, faces, face_midpoints)`` of NumPy arrays. The
        midpoints are the mean of the vertices referenced by each face.
        NOTE(review): after ``squeeze()`` the vertices are presumably (V, 3)
        and the faces (F, 3) -- confirm against ``FlameHead``.

    Raises:
        ValueError: If ``timestep`` is negative.
    """
    if timestep < 0:
        raise ValueError(f"timestep must be non-negative, got {timestep}")

    flame_model = imfp.FlameHead(shape_params=300, expr_params=100)
    flame_data = np.load(tracked_flame_params_path)

    # The same shape parameters are assumed for all timesteps, so they are
    # tiled just far enough for row `timestep` to exist (this replaces the
    # previous fixed 128 repetitions and selects the same row).
    shape_params = torch.tensor(np.tile(flame_data['shape'], (timestep + 1, 1)))
    expression_params = torch.tensor(flame_data['expression'])
    pose_params = torch.tensor(FlameModelCreation._create_pose_param(flame_data['jaw']))
    neck_pose = torch.tensor(flame_data['neck'])
    eye_pose = torch.tensor(flame_data['eyes'])
    jaw = torch.tensor(flame_data['jaw'])

    # Indexing with a one-element list keeps the leading batch dimension.
    frame = [timestep]
    flame_vertices, _flame_lms = flame_model.forward3(
        shape=shape_params[frame],
        expr=expression_params[frame],
        rotation=None,
        neck=neck_pose[frame],
        jaw=jaw[frame],
        eyes=eye_pose[frame],
        translation=None,
        pose_params=pose_params[frame],
    )

    # Model transformation: drop singleton dimensions from vertices and faces.
    flame_faces = flame_model.faces
    #flame_vertices, flame_faces = upsample_flame(flame_vertices, flame_model.faces, config)
    flame_vertices = flame_vertices.squeeze()
    flame_faces = flame_faces.squeeze()

    # Gather the vertices of every face and average them.
    flame_face_midpoints = flame_vertices[flame_faces].mean(dim=1)

    # detach() is a no-op for tensors without autograd history; it guards
    # against numpy() raising should the model outputs require gradients.
    return (
        flame_vertices.detach().cpu().numpy(),
        flame_faces.detach().cpu().numpy(),
        flame_face_midpoints.detach().cpu().numpy(),
    )

def upsample_flame(flame_vertices, flame_faces, config):
    """Subdivide the FLAME mesh as many times as the config requests.

    Args:
        flame_vertices: Mesh vertices; a 2-D tensor is given a leading batch
            dimension before subdivision.
        flame_faces: Mesh faces; a 2-D tensor is given a leading batch
            dimension before subdivision.
        config: Settings dict; ``config['data']['upsample_flame_iterations']``
            is the number of subdivision passes. A falsy value (0, ``None``)
            disables upsampling.

    Returns:
        Tuple ``(vertices, faces)``. When upsampling is disabled the inputs are
        returned unchanged; otherwise the subdivided vertices and faces are
        returned with a leading batch dimension of size 1.
    """
    upsample_flame_iterations = config['data']['upsample_flame_iterations']
    if not upsample_flame_iterations:
        # Return the same (vertices, faces) pair as the upsampling path so
        # callers can always unpack two values.
        return flame_vertices, flame_faces

    # Add the batch dimension if the caller passed unbatched tensors.
    if len(flame_vertices.shape) == 2:
        flame_vertices = flame_vertices.unsqueeze(0)
    if len(flame_faces.shape) == 2:
        flame_faces = flame_faces.unsqueeze(0)

    p3d_mesh = pytorch3d.structures.Meshes(flame_vertices, flame_faces)
    p3d_mesh_subdivision = pytorch3d.ops.SubdivideMeshes()

    for _ in range(upsample_flame_iterations):
        p3d_mesh = p3d_mesh_subdivision(p3d_mesh)

    # Only the first mesh of the batch is returned.
    vertices = p3d_mesh.verts_list()[0]
    faces = p3d_mesh.faces_list()[0]
    return vertices.unsqueeze(0), faces.unsqueeze(0)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


  warn(


In [2]:
# Load the run settings. The context manager closes the file handle instead of
# leaving it open until garbage collection.
with open('configs/settings_final.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# Validation split of the face dataset, built from the loaded settings.
val_dataset = FaceDataset(config, data_split='val')

# Index into `val_dataset.sequences` of the sequence inspected below.
idx = 1

def load_obj_vertices(filename):
    """Read the vertex positions from a Wavefront OBJ file.

    Only geometric vertex statements (first token exactly ``v``) are used;
    ``vn``, ``vt``, faces and comments are skipped. Tokens may be separated by
    any whitespace, and only the first three coordinates of each vertex are
    kept.

    Args:
        filename: Path to the OBJ file.

    Returns:
        Float NumPy array of shape ``(N, 3)``; ``(0, 3)`` if the file has no
        vertex statements.

    Raises:
        IndexError: If a vertex statement has fewer than three coordinates.
        ValueError: If a coordinate cannot be parsed as a float.
    """
    vertices = []

    with open(filename, 'r') as file:
        for line in file:
            parts = line.split()
            # Compare the first token rather than the raw prefix so that
            # tab-separated or indented vertex lines are recognised too.
            if parts and parts[0] == 'v':
                vertices.append([float(parts[1]), float(parts[2]), float(parts[3])])

    # reshape keeps the (N, 3) layout even when no vertex was found.
    return np.array(vertices, dtype=float).reshape(-1, 3)

# Pick the validation sequence and fetch its camera parameters and images.
sequence = val_dataset.sequences[idx]
camera_params = val_dataset.get_camera_params(sequence)
images = val_dataset.get_images(sequence, camera_params)

# FLAME geometry built from the tracked parameters of that sequence.
flame_vertices, flame_faces, flame_midpoints = create_flame_model(
    sequence['flame_params_path'],
    config,
)

# NOTE(review): the return value is discarded, so this presumably transforms
# `flame_vertices` in place -- confirm against FaceDataset.scale_and_rotate.
val_dataset.scale_and_rotate(flame_vertices, camera_params, sequence)

# Keep only the images whose camera serial is listed as an input serial.
input_images = dict(
    (serial, image)
    for serial, image in images.items()
    if serial in config['data']['input_serials']
)

# Colours obtained by projecting the FLAME vertices into the input images.
gaussian_colors = val_dataset.get_projected_colors(
    sequence,
    camera_params,
    flame_vertices,
    input_images,
    max_dist_factor=1.0,
)

ENCODER: F
val self.face_ids ['059', '070', '368', '369', '370', '371', '372', '373', '374', '375']


UnboundLocalError: local variable 'flame_faces' referenced before assignment

In [None]:
def _clamp_channel(value):
    """Truncate one colour channel to an int and clamp it to 0..255."""
    return min(max(int(value), 0), 255)


# Convert the projected colours to RGB and scale them to the 0..255 range.
colors = SH2RGB(gaussian_colors)
color_list = (colors*255).tolist()

# Pack every RGB triple into a single 0xRRGGBB integer. Each channel is clamped
# first: nothing here guarantees the scaled values stay within 0..255, and an
# out-of-range channel would spill into its neighbours or go negative.
color_map = [
    (_clamp_channel(r) << 16) | (_clamp_channel(g) << 8) | _clamp_channel(b)
    for r, g, b in color_list
]

# Initialize plot
plot = k3d.plot(grid_visible=False, height=1024)

# Show the FLAME vertices as uniformly grey points. The commented-out lines are
# the alternatives: a grey mesh, or points coloured with `color_map`.
#plot += k3d.mesh(flame_vertices, flame_faces, colors=[0x888888]*flame_vertices.shape[0])
plot += k3d.points(flame_vertices, point_size=0.005, colors=[0x888888]*flame_vertices.shape[0])
#plot += k3d.points(flame_vertices, point_size=0.005, colors=color_map)

# Add the mesh to the plot
#plot += mesh

# Display the plot
plot.display()

Output()

In [None]:
# import torch

# def save_obj(vertices, faces, filename):
#     with open(filename, 'w') as f:
#         # Write vertices
#         for vertex in vertices:
#             f.write(f'v {vertex[0]} {vertex[1]} {vertex[2]}\n')
        
#         # Write faces
#         for face in faces:
#             # OBJ format uses 1-based indexing
#             f.write(f'f {face[0] + 1} {face[1] + 1} {face[2] + 1}\n')
            
# # Save to OBJ file
# save_obj(flame_vertices, flame_faces, 'output.obj')