In [None]:
import json
import os.path as osp

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml

from nphm_tum import env_paths as mono_env_paths
from nphm_tum.models.neural3dmm import construct_n3dmm, load_checkpoint

from NPHM import env_paths
from NPHM.models.EnsembledDeepSDF import FastEnsembleDeepSDFMirrored

from utils.render import render

# Device string used for every model/tensor placement below (no CPU fallback).
device: str = "cuda"

In [None]:
# Fitting config for the original NPHM shape model; it names the shape experiment
# (`exp_name_shape`) and the checkpoint epoch (`checkpoint_shape`) to load.
with open('../NPHM/scripts/configs/fitting_nphm.yaml', 'r') as f:
    CFG = yaml.safe_load(f)

# Training config of that shape experiment; provides the decoder hyper-parameters.
weight_dir_shape = env_paths.EXPERIMENT_DIR + '/{}/'.format(CFG['exp_name_shape'])
fname_shape = weight_dir_shape + 'configs.yaml'
with open(fname_shape, 'r') as f:
    CFG_shape = yaml.safe_load(f)

lm_inds = np.load(env_paths.ANCHOR_INDICES_PATH)
# Mean anchor positions as float32 with two leading singleton dims, moved to `device`.
anchors = torch.from_numpy(np.load(env_paths.ANCHOR_MEAN_PATH)).float().unsqueeze(0).unsqueeze(0).to(device)

decoder_cfg = CFG_shape['decoder']
decoder_shape = FastEnsembleDeepSDFMirrored(
    lat_dim_glob=decoder_cfg['decoder_lat_dim_glob'],
    lat_dim_loc=decoder_cfg['decoder_lat_dim_loc'],
    hidden_dim=decoder_cfg['decoder_hidden_dim'],
    n_loc=decoder_cfg['decoder_nloc'],
    n_symm_pairs=decoder_cfg['decoder_nsymm_pairs'],
    anchors=anchors,
    n_layers=decoder_cfg['decoder_nlayers'],
    pos_mlp_dim=decoder_cfg.get('pos_mlp_dim', 256),
).to(device)

path = osp.join(weight_dir_shape, 'checkpoints/checkpoint_epoch_{}.tar'.format(CFG['checkpoint_shape']))
# NOTE(review): torch.load unpickles the file; only point this at trusted experiment directories.
model_checkpoint = torch.load(path, map_location=device)
decoder_shape.load_state_dict(model_checkpoint['decoder_state_dict'], strict=True)
# load_state_dict copies the tensors into the decoder's own parameters, so the checkpoint
# (mapped onto `device`) is only a duplicate from here on; release it instead of keeping it global.
del model_checkpoint
# The decoder is only queried for inference in this notebook.
decoder_shape.eval()

def sdf(sdf_inputs, lat_rep):
    """Query the NPHM shape decoder using the first row of ``lat_rep`` as the latent code.

    ``lat_rep[0]`` is reshaped to ``(1, 1, -1)`` and passed to the module-level
    ``decoder_shape`` together with ``None`` as its third argument; element 0 of the
    decoder's output is returned. Any further rows of ``lat_rep`` are ignored.
    """
    first_code = lat_rep[0].reshape(1, 1, -1)
    decoder_out = decoder_shape(sdf_inputs, first_code, None)
    return decoder_out[0]

In [None]:
# Mono-NPHM (nphm_tum) model. Distinct names keep the NPHM fitting config loaded above
# (`CFG`, `weight_dir_shape`, `fname_shape`) from being silently overwritten.
weight_dir_mono = mono_env_paths.EXPERIMENT_DIR_REMOTE + '/'
MONO_CHECKPOINT_EPOCH = 6500

# Read the config and close the file right away; nothing below needs it open.
with open(osp.join(weight_dir_mono, 'configs.yaml'), 'r') as f:
    CFG_mono = yaml.safe_load(f)

# Participant IDs and expression indices used for training; their counts size the latent codebooks.
with open(osp.join(weight_dir_mono, 'subject_train_index.json'), 'r') as f:
    subject_index = json.load(f)
with open(osp.join(weight_dir_mono, 'expression_train_index.json'), 'r') as f:
    expression_index = json.load(f)

# Construct the NPHM model and latent codebook on the notebook-wide `device`
# (passed as torch.device without rebinding the global).
neural_3dmm, latent_codes = construct_n3dmm(
    cfg=CFG_mono,
    modalities=['geo', 'exp'],
    n_latents=[len(subject_index), len(expression_index)],
    device=torch.device(device),
)

# Load weights of the trained model, including the latent codes.
ckpt_path = osp.join(weight_dir_mono, 'checkpoints', f'checkpoint_epoch_{MONO_CHECKPOINT_EPOCH}.tar')
load_checkpoint(ckpt_path, neural_3dmm, latent_codes)


def mono_sdf(sdf_inputs, lat_rep):
    """Query the mono-NPHM model at ``sdf_inputs``.

    ``lat_rep[0]`` is used as the "geo" code and ``lat_rep[1]`` as the "exp" code, each
    reshaped to ``(1, 1, -1)``. Returns the ``"sdf"`` entry of the model output.
    """
    dict_in = {
        "queries": sdf_inputs
    }
    cond = {
        "geo": torch.reshape(lat_rep[0], (1, 1, -1)),
        "exp": torch.reshape(lat_rep[1], (1, 1, -1))
    }
    return neural_3dmm(dict_in, cond)["sdf"]

In [None]:
# Per-dimension mean/std over the trained mono-NPHM latent codebooks (rows = codes).
# Computed under no_grad so the statistics are plain detached tensors. Without it, if the
# embedding weights require grad, every later use of the means would record an autograd graph
# back into the codebook.
with torch.no_grad():
    geo_codes = latent_codes.codebook['geo'].embedding.weight
    exp_codes = latent_codes.codebook['exp'].embedding.weight
    geo_mean = geo_codes.mean(dim=0)
    geo_std = geo_codes.std(dim=0)
    exp_mean = exp_codes.mean(dim=0)
    exp_std = exp_codes.std(dim=0)

In [None]:
# Row 0 = mean geometry code, row 1 = mean expression code.
lat_rep = torch.stack((geo_mean, exp_mean), dim=0)

# Camera setup passed to `render`.
camera_params = dict(
    camera_distance=0.21 * 2.57,
    camera_angle=45.,
    focal_length=2.57,
    max_ray_length=3,
    # output image resolution
    resolution_y=200,
    resolution_x=200,
)

# Phong shading coefficients and RGB colors.
phong_params = dict(
    ambient_coeff=0.51,
    diffuse_coeff=0.75,
    specular_coeff=0.64,
    shininess=0.5,
    object_color=torch.tensor([0.53, 0.24, 0.64]),
    background_color=torch.tensor([0.36, 0.77, 0.29]),
)

# Lighting: an ambient color, light "1" (given by a direction) and light "p" (given by a position).
light_params = dict(
    amb_light_color=torch.tensor([0.9, 0.16, 0.55]),
    light_intensity_1=1.42,
    light_color_1=torch.tensor([0.8, 0.97, 0.89]),
    light_dir_1=torch.tensor([-0.6, -0.4, -0.67]),
    light_intensity_p=0.62,
    light_color_p=torch.tensor([0.8, 0.97, 0.89]),
    light_pos_p=torch.tensor([1.19, -1.27, 2.24]),
)

In [None]:
geo_mean.shape

In [None]:
lat_rep.shape

In [None]:
img = render(sdf, lat_rep, camera_params, phong_params, light_params)

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.imshow(img)

In [None]:
in_dict = {
    "queries":torch.zeros((1, 6, 3)).cuda()
}

cond = {
    "geo": geo_mean,
    "exp": exp_mean
}

In [None]:
neural_3dmm(in_dict, cond)["sdf"]

In [None]:
res