In [None]:
from NPHM.models.deepSDF import DeepSDF, DeformationNetwork
from NPHM.models.EnsembledDeepSDF import FastEnsembleDeepSDFMirrored
from NPHM import env_paths
from NPHM.utils.reconstruction import create_grid_points_from_bounds, mesh_from_logits
from NPHM.models.reconstruction import deform_mesh, get_logits, get_logits_backward
from NPHM.models.fitting import inference_iterative_root_finding_joint, inference_identity_space
from NPHM.data.manager import DataManager

import numpy as np
import argparse
import json, yaml
import os
import os.path as osp
import torch
import pyvista as pv
import matplotlib.pyplot as plt

In [None]:
# Grid resolution passed to create_grid_points_from_bounds / mesh_from_logits below.
resolution = 35

# Fitting config: provides the shape experiment name and the checkpoint epoch to load.
# NOTE(review): path is relative to the current working directory — run the notebook
# from the directory that contains the NPHM package.
FITTING_CFG_PATH = 'NPHM/scripts/configs/fitting_nphm.yaml'
with open(FITTING_CFG_PATH, 'r') as f:
    # Log exactly the path that is being opened.
    print('Loading config file from: ' + FITTING_CFG_PATH)
    CFG = yaml.safe_load(f)

print(json.dumps(CFG, sort_keys=True, indent=4))

# Experiment directory of the shape model; the trailing '/' is relied on below.
weight_dir_shape = env_paths.EXPERIMENT_DIR + '/{}/'.format(CFG['exp_name_shape'])

# Training config of the shape model, stored in the experiment directory.
fname_shape = weight_dir_shape + 'configs.yaml'
with open(fname_shape, 'r') as f:
    print('Loading config file from: ' + fname_shape)
    CFG_shape = yaml.safe_load(f)

In [None]:
# Everything in this notebook (decoder, latents, query points) runs on the CPU.
DEVICE_TYPE = "cpu"
device = torch.device(DEVICE_TYPE)

In [None]:
# Echo the shape-model training config as a framed, pretty-printed JSON dump.
print('\n'.join((
    '###########################################################################',
    '####################     Shape Model Configs     #############################',
    '###########################################################################',
)))
print(json.dumps(CFG_shape, sort_keys=True, indent=4))

# Anchor data: index array (lm_inds) and mean anchor positions. The mean anchors get two
# leading singleton axes ([None, None]) and are passed to the shape decoder below.
lm_inds = np.load(env_paths.ANCHOR_INDICES_PATH)
anchors = torch.from_numpy(np.load(env_paths.ANCHOR_MEAN_PATH)).float()[None, None].to(device)


In [None]:
# Build the shape decoder from the hyper-parameters stored with the experiment
# (CFG_shape, read from its configs.yaml) and load the trained weights.
decoder_cfg = CFG_shape['decoder']
decoder_shape = FastEnsembleDeepSDFMirrored(
    lat_dim_glob=decoder_cfg['decoder_lat_dim_glob'],
    lat_dim_loc=decoder_cfg['decoder_lat_dim_loc'],
    hidden_dim=decoder_cfg['decoder_hidden_dim'],
    n_loc=decoder_cfg['decoder_nloc'],
    n_symm_pairs=decoder_cfg['decoder_nsymm_pairs'],
    anchors=anchors,
    n_layers=decoder_cfg['decoder_nlayers'],
    pos_mlp_dim=decoder_cfg.get('pos_mlp_dim', 256),
)

decoder_shape = decoder_shape.to(device)

path = osp.join(weight_dir_shape, 'checkpoints/checkpoint_epoch_{}.tar'.format(CFG['checkpoint_shape']))
# NOTE: torch.load unpickles the file, which can execute arbitrary code — only load
# checkpoints from trusted sources.
checkpoint = torch.load(path, map_location=device)
decoder_shape.load_state_dict(checkpoint['decoder_state_dict'], strict=True)
# This notebook only runs inference: disable any train-mode behaviour of the network.
decoder_shape.eval()
# Report success only after the weights were actually loaded.
print('Loaded checkpoint from: {}'.format(path))

if 'latent_codes_state_dict' in checkpoint:
    train_codes_state = checkpoint['latent_codes_state_dict']
    val_codes_state = checkpoint['latent_codes_val_state_dict']
    # Size the embeddings from the stored weight matrices (num_codes, code_dim) rather
    # than a hard-coded latent size, so load_state_dict cannot hit a shape mismatch.
    n_train_subjects, train_code_dim = train_codes_state['weight'].shape
    n_val_subjects, val_code_dim = val_codes_state['weight'].shape
    latent_codes_shape = torch.nn.Embedding(n_train_subjects, train_code_dim)
    latent_codes_shape_val = torch.nn.Embedding(n_val_subjects, val_code_dim)

    latent_codes_shape.load_state_dict(train_codes_state)
    latent_codes_shape_val.load_state_dict(val_codes_state)
else:
    print('no latent codes in state dict')
    latent_codes_shape = None
    latent_codes_shape_val = None

# No expression decoder is used in this notebook.
decoder_expr = None

In [None]:
# Draw a random latent from a per-dimension Gaussian whose mean/std are loaded from the
# assets folder. The std is shrunk by LATENT_STD_SCALE (< 1), presumably to keep samples
# closer to the mean shape.
LATENT_SEED = 0          # fixed seed -> reproducible head; set to None for a fresh draw each run
LATENT_STD_SCALE = 0.85

lat_mean = torch.from_numpy(np.load(env_paths.ASSETS + 'nphm_lat_mean.npy'))
lat_std = torch.from_numpy(np.load(env_paths.ASSETS + 'nphm_lat_std.npy'))

latent_rng = torch.Generator()
if LATENT_SEED is None:
    latent_rng.seed()
else:
    latent_rng.manual_seed(LATENT_SEED)
lat_rep = (torch.randn(lat_mean.shape, generator=latent_rng) * lat_std * LATENT_STD_SCALE + lat_mean).to(device)
print(lat_rep.shape)  # original note: 40*32+64 — presumably n_loc * lat_dim_loc + lat_dim_glob; TODO confirm

# Bounding box (min / max corner, one value per axis) of the region that is meshed.
mini = [-.55, -.5, -.95]
maxi = [0.55, 0.75, 0.4]

grid_points_np = create_grid_points_from_bounds(mini, maxi, resolution)
# Show the shape instead of printing the (large) array itself.
print(grid_points_np.shape)
grid_points = torch.from_numpy(grid_points_np).to(device, dtype=torch.float)
grid_points = grid_points.reshape(1, -1, 3)  # (1, N, 3) batch of query points
print(grid_points.shape)

In [None]:
# Evaluate the shape decoder, conditioned on lat_rep, at all grid points
# (nbatch_points presumably sets how many points go through the decoder per call).
logits = get_logits(decoder_shape, lat_rep, grid_points, nbatch_points=100)
# NOTE: marching cubes itself runs in the next cell; at this point only the logits are ready.
print('starting mcubes')

In [None]:
# Extract a mesh from the logits with marching cubes, using the same bounds and
# resolution that were used to build the query grid, then render it with PyVista.
mesh = mesh_from_logits(logits, mini, maxi, resolution)
print('done mcubes')

pl = pv.Plotter(off_screen=True)
pl.add_mesh(mesh)
# Statement order matters here: reset_camera() runs first, and the explicit camera
# settings below are applied on top of it.
pl.reset_camera()
pl.camera.position = (0, 0, 3)
pl.camera.zoom(1.4)
pl.set_viewup((0, 1, 0)) # vertical direction of the camera = +Y axis
pl.camera.view_plane_normal = (-0, -0, 1) # camera looks at the XY plane
pl.show()
# Optional exports. `out_dir` and `step` are not defined in the visible cells — define
# them before re-enabling these lines.
#pl.show(screenshot=out_dir + '/step_{:04d}.png'.format(step))
#mesh.export(out_dir + '/mesh_{:04d}.ply'.format(step))
print(pl.camera)

In [None]:
from render import render

# Scene setup for the external `render` helper: camera on the +z axis at z = 2.
camera_position = torch.tensor([0.0, 0.0, 2.0])
max_ray_length = 4.0

# Phong shading: ambient / diffuse / specular weights and the specular exponent.
ambient_coeff, diffuse_coeff, specular_coeff, shininess = 0.1, 0.6, 0.3, 32.0

# Point light position in world coordinates.
light_position = torch.tensor([2.0, 1.0, 3.0])

def sdf_nphm(positions):
    """SDF callback for `render`: query the shape decoder at arbitrary points.

    `positions` is flattened into a (1, N, 3) batch and paired with the global latent
    `lat_rep` reshaped to (1, 1, D); the decoder presumably broadcasts that single latent
    over all N points (TODO confirm). Returns the decoder's first output with every
    singleton dimension squeezed away.

    NOTE: the ray-marching cell further down redefines `sdf_nphm` with a different body.
    """
    query = positions.reshape(1, -1, 3)
    latent = lat_rep.reshape(1, 1, -1)
    sdf_values, _aux = decoder_shape(query, latent, None)
    return sdf_values.squeeze()

In [None]:
# Render the decoded surface with the external `render` helper; 50 is presumably the
# image resolution in pixels per side (TODO confirm against render's signature).
# NOTE(review): max_ray_length defined above is not passed here — confirm whether
# render() should receive it.
image = render(sdf_nphm, 50, camera_position, light_position, ambient_coeff, diffuse_coeff, specular_coeff, shininess)

In [None]:
# Setup for the hand-written ray marcher below. The camera sits at z = 3 and ray_march
# starts 2.3 units along each ray, so only 4 - 2.3 units of a 4-unit ray are marched.
res = 50                                          # output image is res x res pixels
camera_position = torch.tensor([0.0, 0.0, 3.0])
max_ray_length = 4 - 2.3

# Phong shading: ambient / diffuse / specular weights and the specular exponent.
ambient_coeff, diffuse_coeff, specular_coeff, shininess = 0.1, 0.6, 0.3, 32.0

# Point light position in world coordinates.
light_position = torch.tensor([2.0, 1.0, 3.0])

# Black RGB canvas; the rendering loop writes shaded pixels into it in place.
image = torch.zeros(res, res, 3)

def phong_model(normal, light_dir, view_dir, ambient=None, diffuse=None, specular=None, exponent=None):
    """Phong intensity for a single surface point or a batch of points.

    Args:
        normal: surface normal(s), shape (..., 3); need not be unit length.
        light_dir: direction(s) from the surface point towards the light, shape (..., 3).
        view_dir: direction(s) from the camera towards the surface point, i.e. the ray
            direction, shape (..., 3). The reflection vector below is the *negated* mirror
            direction, so its dot product with this view_dir equals the usual R.V term.
        ambient, diffuse, specular, exponent: optional coefficient overrides; when None the
            notebook-level ambient_coeff, diffuse_coeff, specular_coeff and shininess are used.

    Returns:
        Scalar intensity per point, shape (...) (a 0-d tensor for single 3-vectors).
    """
    ka = ambient_coeff if ambient is None else ambient
    kd = diffuse_coeff if diffuse is None else diffuse
    ks = specular_coeff if specular is None else specular
    alpha = shininess if exponent is None else exponent

    # keepdim=True normalises (N, 3) batches row-wise; a bare (N,) norm would broadcast
    # against the last axis instead (silently wrong when N == 3).
    normal = normal / torch.norm(normal, dim=-1, keepdim=True)
    light_dir = light_dir / torch.norm(light_dir, dim=-1, keepdim=True)
    view_dir = view_dir / torch.norm(view_dir, dim=-1, keepdim=True)

    n_dot_l = torch.clamp(torch.sum(light_dir * normal, dim=-1, keepdim=True), min=0.0)
    diffuse_term = kd * n_dot_l.squeeze(-1)

    # L - 2(N.L)N: the mirror reflection of the light direction, negated.
    reflect_dir = light_dir - 2 * normal * n_dot_l
    r_dot_v = torch.clamp(torch.sum(reflect_dir * view_dir, dim=-1), min=0.0)
    # Points facing away from the light (N.L <= 0) get no specular highlight.
    lit = (n_dot_l.squeeze(-1) > 0).to(r_dot_v.dtype)
    specular_term = ks * lit * torch.pow(r_dot_v, alpha)

    return ka + diffuse_term + specular_term

def estimate_normal(sdf, point, epsilon=1e-3):
    """Estimate the unit surface normal at `point` from forward finite differences of `sdf`.

    Args:
        sdf: callable mapping a (3,) point to a single-element tensor (any shape, e.g. ()
            or (1, 1, 1)).
        point: query point, tensor of shape (3,).
        epsilon: finite-difference step along each axis.

    Returns:
        Tensor of shape (3,): the normalised SDF gradient at `point` (same dtype/device).
    """
    sdf_value = sdf(point).reshape(())

    # One offset row per axis, matching the point's dtype/device so the sum never mixes devices.
    offsets = epsilon * torch.eye(3, dtype=point.dtype, device=point.device)

    # torch.stack (instead of torch.tensor over a list of tensors) accepts any
    # single-element SDF output and keeps the result a proper tensor.
    gradient = torch.stack([sdf(point + offset).reshape(()) - sdf_value for offset in offsets])

    normal = gradient / torch.norm(gradient, p=2)
    return normal

def sdf_sphere(position, radius=0.75):
    """Signed distance from `position` (last axis = xyz) to an origin-centred sphere.

    Negative inside, zero on the surface, positive outside; reduces over the last axis.
    """
    distance_to_origin = position.norm(dim=-1)
    return distance_to_origin - radius

def sdf_nphm(position):
    """Ray-marcher SDF: query the shape decoder at one point (3,) or a batch (N, 3).

    The point(s) are reshaped to (1, N, 3) and paired with the global latent `lat_rep`
    repeated once per point. Returns the decoder's first output unchanged; presumably the
    SDF values with shape (1, N, 1) — TODO confirm.

    NOTE: this redefines the `sdf_nphm` from the `render` cell above, with a different body.
    """
    points = position.reshape(1, -1, 3)  # (1, N, 3); N == 1 for a single point
    # Assumes lat_rep is 1-D of length D (lat_dim): repeat(1, N, 1) then gives (1, N, D).
    latents = lat_rep.repeat(1, points.shape[1], 1)
    distance, _ = decoder_shape(points, latents, None)
    return distance

def ray_march(camera_position, direction, max_length, sdf=None, start_offset=2.3,
              step_size=0.01, hit_threshold=0.01):
    """Fixed-step ray marching against a signed distance function.

    The ray starts `start_offset` units along `direction` from `camera_position` and then
    advances in increments of `step_size * direction`, covering `max_length` units. The SDF
    value is used only as a hit test (value < `hit_threshold`); it does not set the step size.

    Args:
        camera_position: ray origin, tensor of shape (3,). Not modified.
        direction: ray direction, tensor of shape (3,) (the caller passes a unit vector).
        max_length: distance to march after the start offset.
        sdf: callable mapping a (3,) point to a single-element distance tensor; when None
            the notebook-level `sdf_nphm` is looked up at call time.
        start_offset, step_size, hit_threshold: marching parameters; the defaults are the
            values this notebook was tuned with.

    Returns:
        The first sampled point whose SDF value is below `hit_threshold`, or None if the
        ray misses within `max_length`.
    """
    if sdf is None:
        sdf = sdf_nphm
    position = camera_position + start_offset * direction
    # Small tolerance so float division (e.g. 0.3 / 0.1 = 2.9999999999999996) does not drop a step.
    n_steps = int(max_length / step_size + 1e-6)

    for _ in range(n_steps):
        if sdf(position) < hit_threshold:
            return position  # Ray hits the surface
        # Out-of-place update so no tensor is ever modified behind the caller's back.
        position = position + step_size * direction

    return None  # Ray misses the scene

# Rendering loop: one primary ray per pixel from a pinhole camera at `camera_position`
# looking down -z (image plane at distance 3), shaded with the Phong model on hit.
# Normals are estimated by finite differences, so no autograd graph is needed; no_grad
# stops every decoder call during marching from recording one.
with torch.no_grad():
    for v in range(res):
        for u in range(res):
            # Pixel centre in normalized coordinates [-1, 1]; v grows downward, so flip it.
            u_norm = 2.0 * (u + 0.5) / res - 1.0
            v_norm = 1.0 - 2.0 * (v + 0.5) / res
            # Ray direction for the current pixel.
            direction_unn = torch.tensor([u_norm, v_norm, -3.0])
            direction = direction_unn / torch.norm(direction_unn, dim=-1)

            hit_position = ray_march(camera_position, direction, max_ray_length)

            # Colour the pixel only if the ray hits the surface.
            if hit_position is not None:
                # The normal must come from the same SDF the ray was marched against.
                normal = estimate_normal(sdf_nphm, hit_position)
                light_dir = light_position - hit_position   # surface -> light
                # camera -> surface (the ray direction): the sign convention phong_model's
                # negated reflection vector expects, so that R.V is not flipped.
                view_dir = hit_position - camera_position
                reflection = phong_model(normal, light_dir, view_dir)
                # Grey-scale shading: same intensity in all three channels.
                image[v, u] = reflection * torch.tensor([1.0, 1.0, 1.0])


In [None]:
# Display the rendered image. Float RGB passed to imshow must lie in [0, 1]; Phong sums
# can overshoot 1.0 by round-off, so clip explicitly instead of relying on imshow.
image_np = image.detach().clamp(0.0, 1.0).numpy()
fig, ax = plt.subplots()
ax.imshow(image_np)
ax.set_title('Rendered image ({}x{})'.format(image_np.shape[0], image_np.shape[1]))
ax.axis('off')  # Turn off axes
plt.show()

# Summarise the red channel instead of dumping the whole res x res tensor.
red = image[:, :, 0].detach()
print('red channel: min={:.4f} max={:.4f} mean={:.4f}'.format(
    red.min().item(), red.max().item(), red.mean().item()))