In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from nphm_tum import env_paths as mono_env_paths
from nphm_tum.models.neural3dmm import construct_n3dmm, load_checkpoint
from utils.pipeline import get_image_clip_embedding, get_latent_from_text
import json, yaml

In [None]:
torch.cuda.empty_cache()
# Pick the compute device exactly once, falling back to the CPU when CUDA is
# not available (previously a later assignment forced "cuda" unconditionally).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory (with trailing '/') that holds the experiment config and the
# training index files read below.
weight_dir_shape = mono_env_paths.EXPERIMENT_DIR_REMOTE + '/'
fname_shape = weight_dir_shape + 'configs.yaml'
with open(fname_shape, 'r') as f:
    CFG = yaml.safe_load(f)

# load participant IDs that were used for training
# (weight_dir_shape already ends with '/', so no extra separator is inserted)
fname_subject_index = weight_dir_shape + 'subject_train_index.json'
with open(fname_subject_index, 'r') as f:
    subject_index = json.load(f)

# load expression indices that were used for training
fname_expression_index = weight_dir_shape + 'expression_train_index.json'
with open(fname_expression_index, 'r') as f:
    expression_index = json.load(f)

# One latent count per modality, in the same order as `modalities`:
# 'geo' and 'app' are sized by the subject index, 'exp' by the expression index.
modalities = ['geo', 'exp', 'app']
n_lats = [len(subject_index), len(expression_index), len(subject_index)]

# NOTE(review): `load_checkpoint` is imported but never called in this notebook.
# Confirm that construct_n3dmm restores trained weights; otherwise the codebook
# statistics computed from `latent_codes` come from freshly initialised embeddings.
_, latent_codes = construct_n3dmm(
    cfg=CFG,
    modalities=modalities,
    n_latents=n_lats,
    device=device,
    include_color_branch=True,
)

def get_latent_mean():
    """Return the mean latent code of each modality codebook.

    Reads the module-level ``latent_codes`` object and averages the embedding
    weight of the 'geo', 'exp' and 'app' codebooks over dim 0 (presumably the
    axis that enumerates codebook entries -- TODO confirm).

    Also prints the shapes of the appearance mean/std and of the geometry and
    expression means as a quick sanity check.

    Returns:
        list: ``[geo_mean, exp_mean, app_mean]`` tensors, in that order.

    NOTE(review): only the appearance statistics are detached; the geometry
    and expression means are returned without ``.detach()``. Confirm whether
    this asymmetry is intentional before changing it.
    """
    codebook = latent_codes.codebook

    geo_mean = codebook['geo'].embedding.weight.mean(dim=0)
    exp_mean = codebook['exp'].embedding.weight.mean(dim=0)

    app_weight = codebook['app'].embedding.weight
    app_mean = app_weight.mean(dim=0).detach()
    # The std is only needed for the diagnostic print below.
    app_std = app_weight.std(dim=0).detach()

    print('mean', app_mean.shape)
    print('std', app_std.shape)
    print('geo', geo_mean.shape)
    print('exp', exp_mean.shape)

    return [geo_mean, exp_mean, app_mean]

Enter the text prompt and the optimization hyperparameters here

In [None]:
# Text description that get_latent_from_text receives as its first argument.
prompt = 'Kate Winslet'

# Settings handed to get_latent_from_text as its second argument.
# NOTE(review): the meaning of each key is defined by that function, which is
# not visible in this notebook -- check utils.pipeline before tuning.
hparams = dict(
    exp_name='test',
    resolution=180,
    n_iterations=50,
    # weighting terms
    lambda_geo=0.6,
    lambda_app=0.6,
    gamma_geo=0.0,
    gamma_app=0.0,
    alpha=0.1,
    color_prob=0.3,
    # optimizer settings
    optimizer_lr=0.2,
    optimizer_lr_app=0.2,
    optimizer_beta1=0.9,
    batch_size=10,
    # learning-rate scheduler settings
    lr_scheduler_factor=0.95,
    lr_scheduler_patience=2,
    lr_scheduler_min_lr=0.01,
)

# Mean latent codes [geo, exp, app]; passed as init_lat in the next cell.
lat_mean = get_latent_mean()

In [None]:
# Run the text-driven latent search, starting from the mean latent codes.
# Only the first element of the returned pair is kept; no CLIP/DINO
# ground-truth targets are supplied.
best_latent, _ = get_latent_from_text(
    prompt,
    hparams,
    init_lat=lat_mean,
    CLIP_gt=None,
    DINO_gt=None,
)

Visualize the result: render the optimized latent code and display the image

In [None]:
# Side length of the square output image (used for both resolution_x and
# resolution_y below).
resolution = 700

camera_params = {
    "camera_distance": 0.21 * 2.57,
    "camera_angle_rho": 0.,
    "camera_angle_theta": 0.,
    "focal_length": 2.57,
    "max_ray_length": 3.5,
    # Image
    "resolution_y": resolution,
    "resolution_x": resolution,
}

phong_params = {
    "ambient_coeff": 0.32,
    "diffuse_coeff": 0.85,
    "specular_coeff": 0.34,
    "shininess": 25,
    # Colors
    "background_color": torch.tensor([1., 1., 1.]),
}

light_params = {
    "amb_light_color": torch.tensor([0.65, 0.65, 0.65]),
    # light 1
    "light_intensity_1": 1.69,
    "light_color_1": torch.tensor([1., 1., 1.]),
    "light_dir_1": torch.tensor([0, -0.18, -0.8]),
    # light p (presumably a point light: it is given a position, not a direction)
    "light_intensity_p": 0.52,
    "light_color_p": torch.tensor([1., 1., 1.]),
    "light_pos_p": torch.tensor([0.17, 2.77, -2.25]),
}

# Render the optimized latent; the first return value is not needed here.
_, image = get_image_clip_embedding(
    best_latent,
    camera_params,
    phong_params,
    light_params,
    with_app_grad=False,
    color=True,
)

# Tensor.numpy() only works on CPU tensors, so move the image to host memory
# first; .cpu() is a no-op when the tensor already lives on the CPU.
plt.imshow(image.detach().cpu().numpy())
plt.axis('off')  # Turn off axes
plt.show()
