This notebook serves as a testbed for developing and testing the custom shape-targeting forward and backward passes.

Every cell below should be able to run independently once the first (import) cell has been run.

In [None]:
# First, add a custom function to deformable.h and deformable.cpp, and make sure it compiles and can modify data
import sys
sys.path.append('../')
import os
import json

from pathlib import Path
import time
import numpy as np
import scipy.optimize
import pickle
import matplotlib.pyplot as plt

from py_diff_pd.common.common import ndarray, create_folder
from py_diff_pd.common.common import print_info, print_ok, print_error, print_warning
from py_diff_pd.common.grad_check import check_gradients
from py_diff_pd.common.display import export_gif
from py_diff_pd.core.py_diff_pd_core import StdRealVector
from py_diff_pd.env.soft_starfish_env_3d import SoftStarfishEnv3d
from py_diff_pd.common.project_path import root_path
from py_diff_pd.core.py_diff_pd_core import HexMesh3d, HexDeformable, StdRealVector
import py_diff_pd.common.hex_mesh as hex

Visualize

In [None]:
# visualize the hex mesh
from py_diff_pd.common.renderer import PbrtRenderer
def render_quasi_starfish(mesh_file, png_file):
    """Render a hex-mesh ``.bin`` file to an image with the PBRT renderer.

    Args:
        mesh_file: path (str or pathlib.Path) of the mesh file loaded via
            ``HexMesh3d.Initialize``.
        png_file: output image path (str or pathlib.Path), forwarded to
            ``PbrtRenderer`` as its ``'file_name'`` option.
    """
    options = {
        # Callers in this notebook pass pathlib.Path objects here; normalise to
        # str so the renderer always receives a plain string path.
        'file_name': str(png_file),
        'light_map': 'uffizi-large.exr',
        'sample': 4,
        'max_depth': 2,
        'camera_pos': (2, 3, 5),
        'camera_lookat': (1, -1, 0),  # roughly the center of starfish obj
    }
    renderer = PbrtRenderer(options)

    mesh = HexMesh3d()
    # HexMesh3d comes from the compiled py_diff_pd_core module; give it a str path.
    mesh.Initialize(str(mesh_file))
    renderer.add_hex_mesh(mesh, render_voxel_edge=True, color=(.3, .7, .5), transforms=[
        ('r', [90, 1, 0, 0]),  # Rotate 90 degrees around the x-axis
        ('t', [0, 0, 0]),
    ])
    # No ground plane is added; only the starfish mesh is rendered.
    renderer.render()



Shape-target forward test.

In [None]:

# global parameters
obj_num = 30  # selects which starfish_{obj_num}_init_ground_truth.json is loaded and rendered
asset_folder = Path('/mnt/e/muscleCode/sample_muscle_data/starfish')
default_hex_bin_str = str(asset_folder / 'starfish_demo_voxel.bin')
gt_folder = Path('/mnt/e/wsl_projects/diff_pd_public/python/example/quasi_starfish/init_ground_truth')
gt_json = str(gt_folder / f'starfish_{obj_num}_init_ground_truth.json')
render_folder = Path('/mnt/e/wsl_projects/diff_pd_public/python/example/quasi_starfish')
render_bin_path = render_folder / f'starfish_voxel_{obj_num}.bin'
render_bin_str = str(render_bin_path)

# Deformable param
youngs_modulus = 5e5
poissons_ratio = 0.45
# Lame parameters derived from Young's modulus and Poisson's ratio.
la = youngs_modulus * poissons_ratio / ((1 + poissons_ratio) * (1 - 2 * poissons_ratio))
mu = youngs_modulus / (2 * (1 + poissons_ratio))

density = 1e3
thread_ct = 20
dt = 1e-2
# Solver options. 'thread_ct' reuses the constant above rather than repeating
# the literal, so the thread count is configured in one place.
options = {
    'max_pd_iter': 1000,
    'thread_ct': thread_ct,
    'abs_tol': 1e-6,
    'rel_tol': 1e-6,
    'verbose': 0,
    'use_bfgs': 1,
    'bfgs_history_size': 10,
    'max_ls_iter': 10,
}
# Reference (rest-shape) mesh; used for its vertices and element count.
default_hex_mesh = HexMesh3d()
default_hex_mesh.Initialize(default_hex_bin_str)


def get_idea_q(gt_json):
    """Load a ground-truth deformation and derive its shape-target actuation.

    Args:
        gt_json: path of a JSON file holding the target vertex coordinates as a
            list of numbers (nested lists are flattened in row-major order).

    Returns:
        (act, q_ideal): ``act`` is a NumPy array filled by
        ``PyGetShapeTargetSMatrixFromDeformation``; ``q_ideal`` is the list of
        target coordinates as Python floats.
    """
    # json.load handles surrounding whitespace / trailing newlines, empty lists
    # and nested lists, all of which broke the old strip('[]').split(',') parser.
    with open(gt_json, 'r') as f:
        q_ideal = np.asarray(json.load(f), dtype=float).ravel().tolist()

    # get q_ideal
    deformable_default = HexDeformable()
    deformable_default.Initialize(default_hex_bin_str, density, 'none', youngs_modulus, poissons_ratio)
    act = StdRealVector(0)
    deformable_default.PyGetShapeTargetSMatrixFromDeformation(q_ideal, act)
    act = np.array(act)
    # Sanity check printed as a bool: 48 entries per element in `act`
    # (presumably the per-element S-matrix layout — TODO confirm).
    print(act.shape[0] // 48 == default_hex_mesh.NumOfElements())
    return act, q_ideal

def do_shape_targeting(act, q_ideal, stiffness_scale=2000.0):
    """Run one shape-targeting forward solve from the rest mesh and compare it to q_ideal.

    Renders the rest mesh to ``render_folder/starfish_default.png`` and the
    solved mesh to ``render_folder/starfish_{obj_num}_shape_target.png``, then
    prints error statistics between the solution and ``q_ideal``.

    Args:
        act: shape-target actuation array (as returned by ``get_idea_q``).
        q_ideal: target vertex coordinates (sequence of floats).
        stiffness_scale: shape-target stiffness for the solve, in units of ``mu``.
            The default of 2000 is the value used previously.
    """
    q_ideal = np.array(q_ideal)
    deformable_shapeTarget = HexDeformable()
    deformable_shapeTarget.Initialize(default_hex_bin_str, density, 'none', youngs_modulus, poissons_ratio)
    # NOTE(review): this stiffness is overridden below before the forward solve.
    deformable_shapeTarget.SetShapeTargetStiffness(2 * mu)
    dof = deformable_shapeTarget.dofs()
    print('deform2 dof:', dof)

    # Render the undeformed (rest) mesh for reference.
    q_curr = default_hex_mesh.py_vertices()
    deformable_shapeTarget.PySaveToMeshFile(q_curr, render_bin_str)
    render_quasi_starfish(render_bin_str, render_folder / 'starfish_default.png')

    # found a strong correlation between the stiffness and the convergence of the shape targeting
    deformable_shapeTarget.SetShapeTargetStiffness(stiffness_scale * mu)
    q_next = StdRealVector(dof)
    deformable_shapeTarget.PyShapeTargetingForward(q_curr, act, options, q_next)
    q_next = np.array(q_next)

    # Render the shape-targeted result.
    deformable_shapeTarget.PySaveToMeshFile(q_next, render_bin_str)
    render_quasi_starfish(render_bin_str, render_folder / f'starfish_{obj_num}_shape_target.png')

    # A signed mean lets positive and negative errors cancel out, so report
    # magnitudes instead.
    diff_q_ideal = q_next - q_ideal
    print("mean |diff|:", np.mean(np.abs(diff_q_ideal)))
    print("l2 diff:", np.linalg.norm(diff_q_ideal))


act, q_ideal = get_idea_q(gt_json)
do_shape_targeting(act, q_ideal)

Shape-target backward test.

Use iteration 90 as the ground truth, and either the identity or iteration 70 (less movement, but in the same direction) as the initialization. Then update for a few iterations.

In [None]:
# global parameters
asset_folder = Path('/mnt/e/muscleCode/sample_muscle_data/starfish')
default_hex_bin_str = str(asset_folder / 'starfish_demo_voxel.bin')
gt_folder = Path('/mnt/e/wsl_projects/diff_pd_public/python/example/quasi_starfish/init_ground_truth')
render_folder = Path('/mnt/e/wsl_projects/diff_pd_public/python/example/quasi_starfish')
youngs_modulus = 5e5
poissons_ratio = 0.45
# Lame parameters derived from Young's modulus and Poisson's ratio.
la = youngs_modulus * poissons_ratio / ((1 + poissons_ratio) * (1 - 2 * poissons_ratio))
mu = youngs_modulus / (2 * (1 + poissons_ratio))

density = 1e3
thread_ct = 20
dt = 1e-2
# Solver options. 'thread_ct' reuses the constant above rather than repeating
# the literal, so the thread count is configured in one place.
options = {
    'max_pd_iter': 1000,
    'thread_ct': thread_ct,
    'abs_tol': 1e-6,
    'rel_tol': 1e-6,
    'verbose': 0,
    'use_bfgs': 1,
    'bfgs_history_size': 10,
    'max_ls_iter': 10,
}
# Reference (rest-shape) mesh and the single deformable shared by the
# functions and the loop in this cell.
default_hex_mesh = HexMesh3d()
default_hex_mesh.Initialize(default_hex_bin_str)
deformable_shapeTarget = HexDeformable()
deformable_shapeTarget.Initialize(default_hex_bin_str, density, 'none', youngs_modulus, poissons_ratio)

# Functions


def render_deformable(render_id, q_curr):
    """Render vertex positions ``q_curr`` to ``render_folder/starfish_{render_id}_shape_target.png``.

    The mesh is first written to a temporary ``.bin`` next to the PNG via
    ``deformable_shapeTarget.PySaveToMeshFile``. That file is deleted after
    rendering, even if rendering raises.

    Args:
        render_id: tag used in the output file names (e.g. an iteration index or 'default').
        q_curr: vertex coordinates accepted by ``PySaveToMeshFile``.
    """
    png_file = render_folder / f'starfish_{render_id}_shape_target.png'
    render_bin_str = str(render_folder / f'starfish_{render_id}_shape_target.bin')
    deformable_shapeTarget.PySaveToMeshFile(q_curr, render_bin_str)
    try:
        render_quasi_starfish(render_bin_str, png_file)
    finally:
        # Always clean up the temporary mesh file.
        os.remove(render_bin_str)
    
def get_idea_q(gt_folder, obj_num):
    """Load ground-truth file ``obj_num`` from ``gt_folder`` and derive its shape-target actuation.

    Reads ``gt_folder/starfish_{obj_num}_init_ground_truth.json``, a JSON list of
    target vertex coordinates (nested lists are flattened in row-major order).

    Returns:
        (act, q_ideal): ``act`` is a NumPy array filled by
        ``PyGetShapeTargetSMatrixFromDeformation``; ``q_ideal`` is the list of
        target coordinates as Python floats.
    """
    gt_json = Path(gt_folder) / f'starfish_{obj_num}_init_ground_truth.json'
    # json.load handles surrounding whitespace / trailing newlines, empty lists
    # and nested lists, all of which broke the old strip('[]').split(',') parser.
    with open(gt_json, 'r') as f:
        q_ideal = np.asarray(json.load(f), dtype=float).ravel().tolist()

    # get q_ideal
    deformable_default = HexDeformable()
    deformable_default.Initialize(default_hex_bin_str, density, 'none', youngs_modulus, poissons_ratio)
    act = StdRealVector(0)
    deformable_default.PyGetShapeTargetSMatrixFromDeformation(q_ideal, act)
    act = np.array(act)
    # Sanity check printed as a bool: 48 entries per element in `act`
    # (presumably the per-element S-matrix layout — TODO confirm).
    print(act.shape[0] // 48 == default_hex_mesh.NumOfElements())
    return act, q_ideal

def forward_pass(act, q_curr, stiffness_scale=2000.0):
    """Run a shape-targeting forward solve on the shared deformable, starting from q_curr.

    Args:
        act: shape-target actuation array.
        q_curr: starting vertex coordinates.
        stiffness_scale: shape-target stiffness, in units of ``mu``. The default
            of 2000 is the value used previously.

    Returns:
        (q_next, q_next_np): the solver output as a ``StdRealVector`` and as a NumPy array.
    """
    dof = deformable_shapeTarget.dofs()
    print('deform2 dof:', dof)
    # found a strong correlation between the stiffness and the convergence of the shape targeting
    deformable_shapeTarget.SetShapeTargetStiffness(stiffness_scale * mu)
    q_next = StdRealVector(dof)
    deformable_shapeTarget.PyShapeTargetingForward(q_curr, act, options, q_next)
    return q_next, np.array(q_next)

def loss(q_next, q_ideal):
    """Print and return the L2 distance ``||q_next - q_ideal||``.

    Both arguments are converted with ``np.asarray``, so a ``StdRealVector``,
    a Python list or a NumPy array is accepted for either one. Subtracting a
    list from a ``StdRealVector`` directly raises ``TypeError``.
    """
    diff = np.asarray(q_next, dtype=float) - np.asarray(q_ideal, dtype=float)
    l2_diff = np.linalg.norm(diff)
    print("l2_diff:", l2_diff)
    return l2_diff

init_obj_num = 30
target_obj_num = 30

# initialize local parameters
act_init, _ = get_idea_q(gt_folder, init_obj_num)
_, q_ideal = get_idea_q(gt_folder, target_obj_num)
q_ideal_np = np.asarray(q_ideal, dtype=float)
q_curr = default_hex_mesh.py_vertices()
render_deformable('default', q_curr)

# main loop
num_iter = 10
for i in range(num_iter):
    q_next, q_next_np = forward_pass(act_init, q_curr)
    render_deformable(i, q_next)
    # Use NumPy arrays: `StdRealVector - list` raises TypeError.
    l2_diff = loss(q_next_np, q_ideal_np)

    # NOTE(review): dl_dq_next is passed empty although it is marked as an input;
    # confirm whether PyShapeTargetingBackward derives the loss gradient from
    # q_ideal itself, or whether (q_next - q_ideal) / l2_diff should be passed here.
    dl_dq_next = StdRealVector(0)  # input

    dl_dq = StdRealVector(0)  # output
    dl_dact = StdRealVector(0)
    dl_dmat_w = StdRealVector(0)
    dl_dmat_v = StdRealVector(0)
    deformable_shapeTarget.PyShapeTargetingBackward(q_next_np, q_ideal, dl_dq_next, options, dl_dq, dl_dact, dl_dmat_w, dl_dmat_v)
    # Printing the StdRealVector directly only shows a proxy repr; summarise it instead.
    dl_dact_np = np.array(dl_dact)
    print("dl_dact: size", dl_dact_np.size, "norm", np.linalg.norm(dl_dact_np))
    # TODO: act_init is never updated, so every iteration repeats the same
    # forward/backward pass; apply a gradient step on act_init to make this an
    # optimization loop.
    
 