In [1]:
%env CUDA_VISIBLE_DEVICES=3
import yaml, json
import torch
import os.path as osp
import matplotlib.pyplot as plt

from nphm_tum import env_paths as mono_env_paths
from nphm_tum.models.neural3dmm import construct_n3dmm, load_checkpoint

from utils.render import render
from utils.pipeline import get_optimal_params_color
from utils.pointcloud_fitting import get_latent_from_points

env: CUDA_VISIBLE_DEVICES=3


  from .autonotebook import tqdm as notebook_tqdm


ANCHORS HAVE SHAPE:  torch.Size([1, 1, 65, 3])
creating DeepSDF with...
lat dim 116
hidden_dim 400
Creating DeepSDF with input dim f119, hidden_dim f400 and output_dim 5
Loaded checkpoint from: /home/schmid/Text2Head/MonoNPHM/new_weights_mono//checkpoints/checkpoint_epoch_2500.tar
ANCHORS HAVE SHAPE:  torch.Size([1, 1, 65, 3])
creating DeepSDF with...
lat dim 116
hidden_dim 400
Creating DeepSDF with input dim f119, hidden_dim f400 and output_dim 5
Loaded checkpoint from: /home/schmid/Text2Head/MonoNPHM/new_weights_mono//checkpoints/checkpoint_epoch_2500.tar


In [2]:
# Load the trained MonoNPHM experiment: its config, the subject/expression
# indices used for training (they determine the latent codebook sizes), the
# neural 3DMM itself, and its checkpointed weights + latent codes.

# Modalities whose gradients are wanted; since 'exp' is absent, expression
# gradients are skipped in construct_n3dmm below.
grad_vars = ['geo']

# osp.join(..., '') yields exactly one trailing separator. The previous
# `+ '/'` doubled it (the logged checkpoint path contained
# `new_weights_mono//checkpoints`). The trailing separator is kept so code
# that does `weight_dir_shape + 'name'` keeps working.
weight_dir_shape = osp.join(mono_env_paths.EXPERIMENT_DIR_REMOTE, '')
fname_shape = osp.join(weight_dir_shape, 'configs.yaml')
with open(fname_shape, 'r') as f:
    CFG = yaml.safe_load(f)

# load participant IDs that were used for training
fname_subject_index = osp.join(weight_dir_shape, 'subject_train_index.json')
with open(fname_subject_index, 'r') as f:
    subject_index = json.load(f)

# load expression indices that were used for training
fname_expression_index = osp.join(weight_dir_shape, 'expression_train_index.json')
with open(fname_expression_index, 'r') as f:
    expression_index = json.load(f)


device = torch.device("cuda")
modalities = ['geo', 'exp', 'app']
# Codebook sizes, in `modalities` order: geometry and appearance are sized by
# the number of training subjects, expression by the number of training expressions.
n_lats = [len(subject_index), len(expression_index), len(subject_index)]

neural_3dmm, latent_codes = construct_n3dmm(
    cfg=CFG,
    modalities=modalities,
    n_latents=n_lats,
    device=device,
    neutral_only=False,
    include_color_branch=True,
    skip_exp_grads=('exp' not in grad_vars),
)

# load checkpoint from trained NPHM model, including the latent codes
CKPT_EPOCH = 2500
ckpt_path = osp.join(weight_dir_shape, 'checkpoints', f'checkpoint_epoch_{CKPT_EPOCH}.tar')
load_checkpoint(ckpt_path, neural_3dmm, latent_codes)

def sdf(sdf_inputs, lat_geo, lat_exp, lat_app):
    """Evaluate the global ``neural_3dmm`` at ``sdf_inputs`` for one set of latents.

    Each latent (geometry, expression, appearance) is flattened and reshaped to
    ``(1, 1, -1)`` before being passed to the model as conditioning.

    Args:
        sdf_inputs: query tensor passed to the model under ``"queries"``
            (presumably a batch of 3D points — TODO confirm expected shape).
        lat_geo, lat_exp, lat_app: geometry / expression / appearance latents.

    Returns:
        Tuple ``(sdf, color)`` taken from the model's output dict.
    """
    latents = {"geo": lat_geo, "exp": lat_exp, "app": lat_app}
    conditioning = {
        modality: torch.reshape(latent, (1, 1, -1))
        for modality, latent in latents.items()
    }
    prediction = neural_3dmm({"queries": sdf_inputs}, conditioning)
    return prediction["sdf"], prediction["color"]

def _codebook_stats(modality):
    """Return ``(mean, std)`` of a trained latent codebook, taken over its codes (dim 0)."""
    weights = latent_codes.codebook[modality].embedding.weight
    return weights.mean(dim=0), weights.std(dim=0)


# Per-dimension statistics of each trained codebook (geometry, expression, appearance).
geo_mean, geo_std = _codebook_stats('geo')
exp_mean, exp_std = _codebook_stats('exp')
app_mean, app_std = _codebook_stats('app')

ANCHORS HAVE SHAPE:  torch.Size([1, 1, 65, 3])
creating DeepSDF with...
lat dim 116
hidden_dim 400
Creating DeepSDF with input dim f119, hidden_dim f400 and output_dim 5
Loaded checkpoint from: /home/schmid/Text2Head/MonoNPHM/new_weights_mono//checkpoints/checkpoint_epoch_2500.tar


In [4]:
# Load a saved ground-truth latent representation, fit a latent to it with
# get_latent_from_points, then render both for a side-by-side visual check.
# NOTE(review): torch.load unpickles the file — only load trusted files.
LAT_REP_GT_PATH = 'lat_rep_3_100_3.pt'
lat_rep_gt = torch.load(LAT_REP_GT_PATH)

hparams = {
    'resolution': 500,
    'n_iterations': 20,
    'optimizer_lr': 6e-4,
    'lr_scheduler_factor': 0.5,  # 0.9
    'lr_scheduler_patience': 3,  # 5
    'lr_scheduler_min_lr': 1e-6,
}

camera_params, phong_params, light_params = get_optimal_params_color(hparams)

# The fit receives lat_rep_gt exactly as loaded (before the move to `device` below).
optimized_lat_rep = get_latent_from_points(lat_rep_gt, hparams, camera_params)


def show_image(image, title):
    """Display a rendered image tensor with a title and no axes.

    ``.cpu()`` is a no-op for CPU tensors and is required for CUDA tensors,
    on which ``.numpy()`` alone raises.
    """
    plt.imshow(image.detach().cpu().numpy())
    plt.title(title)
    plt.axis('off')
    plt.show()


lat_rep_gt = [tensor.to(device) for tensor in lat_rep_gt]
image_gt = render(sdf, lat_rep_gt, camera_params, phong_params, light_params, color=True, mesh_path=None, model_grads=[])
show_image(image_gt, 'Ground Truth Image')

image = render(sdf, optimized_lat_rep, camera_params, phong_params, light_params, color=True, mesh_path=None, model_grads=[])
show_image(image, 'Optimized Image')