This is a utility script for shape-targeting simulation. It includes code for:
1. manipulating OBJ files: reading and writing vertex positions
2. generating the hexahedral simulation mesh and its trilinear mapping
3. performing the "zero rest-length spring" initialization mentioned in the paper
Most cells can be run one after another, in order. However, the paths and file-system layout are hard-coded; adjust them as needed.

Link to the voxelizer used:
https://drububu.com/miscellaneous/voxelizer/?out=obj

### Sample Import

In [1]:
import sys
sys.path.append('../')
import os

from pathlib import Path
import time
import numpy as np
import scipy.optimize
import pickle
import matplotlib.pyplot as plt

from py_diff_pd.common.common import ndarray, create_folder
from py_diff_pd.common.common import print_info, print_ok, print_error, print_warning
from py_diff_pd.common.grad_check import check_gradients
from py_diff_pd.common.display import export_gif
from py_diff_pd.core.py_diff_pd_core import StdRealVector
from py_diff_pd.env.soft_starfish_env_3d import SoftStarfishEnv3d
from py_diff_pd.common.project_path import root_path
from py_diff_pd.core.py_diff_pd_core import HexMesh3d, HexDeformable, StdRealVector
import py_diff_pd.common.hex_mesh as hex

### Load and write objs

In [2]:
import os
import copy
# Number of starfish frames in the dataset (presumably starfish_1.obj .. starfish_120.obj — confirm against the data folder).
file_count = 120
# Folder holding the starfish_<n>.obj surface meshes, written as a WSL mount path.
input_dir =  "/mnt/e/muscleCode/sample_muscle_data/starfish/"
# Windows-style spelling of what looks like the same folder; not referenced by the visible cells of this notebook.
output_dir = "E:/muscleCode/sample_muscle_data/starfish/"
# The helpers below read a starfish triangle-mesh OBJ file (e.g. starfish_1.obj)
# and store the xyz coordinates of its vertex lines in a list.

def load_tri_starfish_obj(input_dir, file_name, header_line_count=3, vertex_end_line=1085):
    """Read a triangle-mesh OBJ file and split it into header, vertices and remainder.

    The file is partitioned purely by 0-based line index, not by line prefix:

    * lines ``[0, header_line_count)`` are kept verbatim as header lines,
    * lines ``[header_line_count, vertex_end_line)`` are parsed as ``v x y z``,
    * lines ``[vertex_end_line, EOF)`` are kept verbatim as the remainder.

    The defaults (3 and 1085) reproduce the layout of the starfish dataset this
    notebook was written for; pass other values for OBJ files laid out
    differently.

    Args:
        input_dir: Directory containing the OBJ file.
        file_name: Name of the OBJ file inside ``input_dir``.
        header_line_count: Number of leading lines to keep verbatim.
        vertex_end_line: 0-based index of the first line after the vertex block.

    Returns:
        ``(vertex_lines, first_lines, rest_lines)`` where ``vertex_lines`` is a
        list of ``[x, y, z]`` float lists and the other two are lists of raw
        lines (newline included), suitable for ``write_tri_starfish_obj``.

    Raises:
        IndexError: A line in the vertex range has fewer than four fields.
        ValueError: A coordinate in the vertex range is not a valid float.
    """
    vertex_lines, first_lines, rest_lines = [], [], []
    with open(os.path.join(input_dir, file_name), 'r') as file:
        for line_no, line in enumerate(file):
            if line_no < header_line_count:
                first_lines.append(line)
            elif line_no >= vertex_end_line:
                rest_lines.append(line)
            else:
                # parts[0] is the record tag (expected to be 'v'); the next
                # three fields are the coordinates.
                parts = line.split()
                vertex_lines.append([float(parts[1]), float(parts[2]), float(parts[3])])
    return vertex_lines, first_lines, rest_lines


def load_hex_starfish_obj(input_dir, file_name):
    """Split an OBJ file into vertex positions and every other line.

    Lines beginning with ``'v '`` are parsed into ``[x, y, z]`` float lists;
    all remaining lines are collected unchanged (newline included), in file
    order.

    Returns:
        ``(vertex_lines, rest_lines)``.
    """
    obj_path = os.path.join(input_dir, file_name)
    positions = []
    other_lines = []
    with open(obj_path, 'r') as obj_file:
        for raw in obj_file:
            if raw.startswith('v '):
                tokens = raw.strip().split()
                # tokens[0] is the 'v' tag; fields 1..3 are the coordinates.
                positions.append([float(tokens[axis]) for axis in (1, 2, 3)])
                continue
            other_lines.append(raw)
    return positions, other_lines

# Write a starfish OBJ file whose vertex block is replaced by new positions
def write_tri_starfish_obj(output_dir, output_name, first_lines, rest_lines, new_verts):
    """Write ``first_lines``, one ``v x y z`` line per vertex, then ``rest_lines``.

    ``first_lines`` and ``rest_lines`` are written verbatim (they are expected
    to carry their own newlines); ``new_verts`` is a sequence of ``[x, y, z]``.
    """
    target_path = os.path.join(output_dir, output_name)
    with open(target_path, 'w') as out:
        out.writelines(first_lines)
        out.writelines(f"v {v[0]} {v[1]} {v[2]}\n" for v in new_verts)
        out.writelines(rest_lines)
 

Generate the mapping between the hex mesh and the default triangle mesh using trilinear interpolation, and write it to JSON.
For each surface vertex:
1. find the closest element center
2. generate a weight map with key = vid and val = weight (8 entries in total)
* the element stores its vertex ids in the order 000, 001, 010, 011, 100, 101, 110, 111. The default range is [0.05, 0.05, 0.05]

In [3]:
# Load the default voxelised starfish as a hex mesh. The trilinear helper
# functions further down read this module-level `mesh`.
default_hex_bin_str = str('/mnt/e/wsl_projects/diff_pd_public/python/example/quasi_starfish/ground_truth/starfish_demo_voxel.bin')
mesh = HexMesh3d()
mesh.Initialize(default_hex_bin_str)
# Number of elements in the mesh; bounds the nearest-element search below.
py_element_count = mesh.NumOfElements()
# Surface vertices of frame 1, plus the OBJ's leading and trailing lines kept verbatim.
obj1_verts, first_lines, rest_lines = load_tri_starfish_obj(input_dir, "starfish_1.obj") 

In [4]:
# Sanity check: vertex count, header-line count and trailing-line count of starfish_1.obj.
len(obj1_verts), len(first_lines), len(rest_lines)

(1082, 3, 3486)

In [71]:
# functions related to trilinear interpolation and integration with hex mesh
def trilinear_weights(point_pos_np, element_id):
    """Return the 8 trilinear weights of a point relative to one hex element.

    The point is expressed in the element's local coordinates, measured from
    the element's first vertex and divided by the element edge length. The
    weights are ordered 000, 001, 010, 011, 100, 101, 110, 111 over the
    (x, y, z) bits.

    Reads the module-level ``mesh``. Asserts that the edge length is 0.05.
    """
    corner_ids = mesh.py_element(element_id)
    origin = np.array(mesh.py_vertex(corner_ids[0]))
    neighbour = np.array(mesh.py_vertex(corner_ids[1]))
    # Largest component of (vertex 1 - vertex 0) is taken as the edge length.
    edge = max(neighbour - origin)
    assert abs(edge - 0.05) < 1e-6
    u, v, w = (point_pos_np - origin) / edge
    # Per axis: index 0 is the weight towards bit 0, index 1 towards bit 1.
    along_x, along_y, along_z = ((1 - t, t) for t in (u, v, w))
    return [wx * wy * wz for wx in along_x for wy in along_y for wz in along_z]

'''
Both functions here are only applicable to hex meshes aligned with the coordinate axes
'''
def get_element_center(element_id):
    """Return the midpoint between an element's vertex 0 and vertex 7.

    Vertices 0 and 7 are the 000 and 111 corners, so for an axis-aligned
    element this midpoint is the element center. Reads the module-level
    ``mesh``.
    """
    corner_ids = mesh.py_element(element_id)
    corner_000 = np.array(mesh.py_vertex(corner_ids[0]))
    corner_111 = np.array(mesh.py_vertex(corner_ids[7]))
    diagonal = corner_111 - corner_000
    return corner_000 + diagonal / 2
    
# Rebuild a surface vertex from its element corners and trilinear weights.
# verts holds 8 corner positions (3 components each), weights holds 8 scalars.
def reconstruct_trilinear(verts, weights):
    """Return the weighted sum of the 8 corner positions as a length-3 array."""
    assert len(verts) == 8 and len(weights) == 8
    blended = np.zeros(3)
    for corner, weight in zip(verts, weights):
        blended += corner * weight
    return blended

# Collect the positions of all 8 vertices of an element
def get_element_verts_pos(element_id):
    """Return a list of 8 numpy arrays, one per element vertex, in element order.

    Reads the module-level ``mesh``.
    """
    v_ids = mesh.py_element(element_id)
    return [np.array(mesh.py_vertex(v_ids[corner])) for corner in range(8)]

In [89]:
# Build the surface-vertex -> hex-element trilinear mapping for frame 1.
# (Writing it to JSON happens in a later cell.)
# The next three names are leftovers of the older one-to-one mapping and are
# not filled in here.
one_to_hex_mapping = {}
hex_to_one_mapping = {}
trilinear_weights_mapping = []
all_one_to_hex_dist = 0

if py_element_count == 0:
    raise ValueError('hex mesh has no elements; cannot build the trilinear mapping')

# Element centers do not depend on the surface vertex, so compute them once
# instead of once per (vertex, element) pair.
element_centers = np.array([get_element_center(e_id) for e_id in range(py_element_count)])

for v in obj1_verts:
    # Distance from this surface vertex to every element center.
    dists = np.linalg.norm(element_centers - np.array(v), axis=1)
    # argmin returns the first minimum, matching a strict '<' linear scan.
    # int() keeps the dict key a plain Python int so json.dump accepts it.
    min_idx = int(np.argmin(dists))
    weights = trilinear_weights(v, min_idx)
    # One single-key dict per surface vertex: {element_id: [8 weights]}.
    trilinear_weights_mapping.append({min_idx: weights})

dict_to_write = {
    'trilinear_weights_mapping': trilinear_weights_mapping
}

In [None]:
from collections import Counter

# Report how many surface vertices were mapped, then how the hex elements are
# shared between them.
print(f'total length of mapping: {len(trilinear_weights_mapping)}')

# element id -> number of surface vertices whose nearest element it is.
# Each mapping entry is a single-key dict {element_id: weights}.
all_Key_map = Counter(list(dic.keys())[0] for dic in dict_to_write['trilinear_weights_mapping'])

# Histogram of that usage: how many elements serve exactly 1, 2, 3, 4 and
# 5-or-more surface vertices (everything above 4 is folded into bucket 5).
usage_histogram = Counter(min(hits, 5) for hits in all_Key_map.values())
one_count, two_count, three_count, four_count, five_plus_count = (
    usage_histogram[bucket] for bucket in range(1, 6)
)
print(one_count, two_count, three_count, four_count, five_plus_count)
print(trilinear_weights_mapping[0])

In [None]:
# write to json
# Destination file for the per-vertex trilinear weight table held in dict_to_write.
tri_map_path = '/mnt/e/wsl_projects/diff_pd_public/python/example/quasi_starfish/ground_truth/trilinear_weights.json'
import json
# Serialise dict_to_write; the verification cell further down reloads this file.
with open(tri_map_path, 'w') as f:
    json.dump(dict_to_write, f)
    
# The rest is deprecated. It was once used to construct a one-to-one mapping.
# hex_to_one_json = asset_folder / 'hex_to_one.json'
# one_to_hex_json = asset_folder / 'one_to_hex.json'
# with open(hex_to_one_json, 'w') as f:
#     json.dump(hex_to_one_mapping, f)
# with open(one_to_hex_json, 'w') as f:
#     json.dump(one_to_hex_mapping, f)

In [None]:
# # write default surface mesh verts for gt
# for s in range(1,121):
#     obj_v_map = {}
#     obj_verts, _, __ = load_tri_starfish_obj(input_dir, "starfish_"+str(s)+".obj")
#     obj_v_map['starfish'] = obj_verts
#     path_to_store = Path(f'quasi_starfish/ground_truth/default_pos/starfish_obj_{s}_verts.json')
#     with open(path_to_store, 'w') as f:
#         json.dump(obj_v_map, f)
#     print(f'starfish_{s} done')


In [75]:
# Quick verification: reload the saved mapping and check that every surface
# vertex of frame 1 is reproduced by trilinear interpolation of its element.
# Needs `json`, `tri_map_path`, `obj1_verts`, `mesh` and the trilinear helper
# functions from the cells above.
with open(tri_map_path, 'r') as f:  # 'with' closes the handle, unlike a bare open()
    dict_to_write = json.load(f)
trilinear_weights_mapping = dict_to_write['trilinear_weights_mapping']
gt = obj1_verts
# A shorter mapping would otherwise pass the loop below silently.
assert len(trilinear_weights_mapping) == len(gt), (
    f'mapping has {len(trilinear_weights_mapping)} entries but mesh has {len(gt)} vertices'
)
for i, mapping in enumerate(trilinear_weights_mapping):
    # JSON object keys come back as strings, hence the int() below.
    element_id = list(mapping.keys())[0]
    weights = mapping[element_id]
    verts_pos = get_element_verts_pos(int(element_id))
    new_vert = np.array(reconstruct_trilinear(verts_pos, weights))
    gt_vert = np.array(gt[i])
    error = np.linalg.norm(new_vert - gt_vert)
    assert error < 1e-6, f'vertex {i}: reconstruction error {error}'

In [88]:
# Scratch check: np.zeros accepts an array's shape tuple and returns a float
# array of zeros with that shape.
x = np.full((2, 2), 2)
a = np.zeros(x.shape)
a

array([[0., 0.],
       [0., 0.]])

## PD for autoencoder groundtruth

### import and Initialize

In [None]:

import sys
import os
from pathlib import Path
import time
import numpy as np
from py_diff_pd.core.py_diff_pd_core import HexMesh3d, HexDeformable, StdRealVector
import py_diff_pd.common.hex_mesh as hex
from py_diff_pd.core.py_diff_pd_core import StdRealVector, StdIntVector

In [None]:
# Implement one iteration of the PD with zero rest length
# Then use the deformation gradient's diagonal as A
target_frame = 90

# Env __init__: input assets and render outputs
asset_folder = Path('/mnt/e/muscleCode/sample_muscle_data/starfish')
mesh_bin = asset_folder.joinpath('starfish_demo_voxel.bin')
mesh_bin_str = str(mesh_bin)
voxel_output = asset_folder.joinpath('starfish_demo_voxel_output.obj')
json_file_path = asset_folder.joinpath('starfish_demo_48x9x46.json')
render_folder = Path('/mnt/e/wsl_projects/diff_pd_public/python/example/quasi_starfish')
# Per-frame hex mesh file that is saved and then rendered in the forward pass.
render_bin_path = render_folder.joinpath(f'starfish_voxel_{target_frame}.bin')
render_bin_str = str(render_bin_path)

# Deformable param
youngs_modulus = 5e5
poissons_ratio = 0.45
# Lame parameters derived from Young's modulus and Poisson's ratio.
la = youngs_modulus * poissons_ratio / ((1 + poissons_ratio) * (1 - 2 * poissons_ratio))
mu = youngs_modulus / (2 * (1 + poissons_ratio))
density = 1e3
thread_ct = 20
dt = 1e-2

# Solver options handed to PyForward.
# NOTE(review): 'thread_ct' repeats the value of thread_ct above by hand.
pd_opt = dict(
    max_pd_iter=500,
    max_ls_iter=10,
    abs_tol=1e-9,
    rel_tol=1e-4,
    verbose=0,
    thread_ct=20,
    use_bfgs=1,
    bfgs_history_size=10,
)
foward_method = 'pd_eigen'

In [None]:
# Initialize active objects and param related to deformable
# Hex mesh loaded from the default voxel file; used below to read the initial vertex positions.
default_hex_mesh = HexMesh3d()
default_hex_mesh.Initialize(mesh_bin_str) 
# Deformable body built from the same mesh file.
# NOTE(review): the meaning of the 'none' argument is not visible here — confirm against HexDeformable.Initialize.
deformable = HexDeformable()
deformable.Initialize(mesh_bin_str, density, 'none', youngs_modulus, poissons_ratio)
# Two PD energies: 'corotated' with parameter list [2*mu] and 'volume' with [la].
deformable.AddPdEnergy('corotated', [2 * mu,], [])
deformable.AddPdEnergy('volume', [la,], [])

dof = deformable.dofs()
# All-zero vector sized by deformable.act_dofs(); passed to PyForward as act_maps.
act_maps = np.zeros(deformable.act_dofs())

# Current state: positions taken from the mesh, velocities all zero.
q_curr = default_hex_mesh.py_vertices()
v_curr = np.zeros(deformable.dofs())
# Fresh containers that the forward-pass cell hands to PyForward as q_next / v_next / contact_index.
q_next, v_next, contact_index = StdRealVector(dof), StdRealVector(dof), StdIntVector(0)

# Compute initial f_ext
# Surface vertices of the target frame and of frame 1 (the starting pose).
obj_target_verts, _, _ = load_tri_starfish_obj(input_dir, "starfish_"+str(target_frame)+".obj")
obj_1_verts, _, _ = load_tri_starfish_obj(input_dir, "starfish_1.obj")
# load mapping
# NOTE(review): both JSON files must already exist in input_dir; the only code
# in this notebook that writes files with these names is commented out as deprecated.
hex_to_one_json = input_dir + 'hex_to_one.json'
one_to_hex_json = input_dir + 'one_to_hex.json'
hex_to_one_mapping = {}
one_to_hex_mapping = {}
import json
# hex_to_one_mapping is read by accumulate_forces_on_hex below.
with open(hex_to_one_json, 'r') as f:
    hex_to_one_mapping = json.load(f)
# one_to_hex_mapping is read by the forward-pass loop (surface vertex id -> index into q).
with open(one_to_hex_json, 'r') as f:
    one_to_hex_mapping = json.load(f)
# accumulate forces on hex vertices

def accumulate_forces_on_hex(obj_target_verts, obj_1_verts, close_flag, *, hex_to_one=None, dof_count=None):
    """Accumulate zero-rest-length spring forces from surface vertices onto hex DOFs.

    Every surface vertex is pulled towards its target position with force
    ``stiffness * (target - current)``. The force of each surface vertex listed
    under a key of the hex-to-surface mapping is added to the three DOFs of
    that key's hex vertex.

    Args:
        obj_target_verts: Target surface positions, one ``[x, y, z]`` per vertex.
        obj_1_verts: Current surface positions, same layout. The caller is
            responsible for keeping these up to date between calls.
        close_flag: If truthy, use stiffness 1 instead of 1e2.
        hex_to_one: Optional mapping ``{hex_dof_index: [surface_vertex_id, ...]}``
            (keys and ids may be strings, as loaded from JSON). Defaults to the
            module-level ``hex_to_one_mapping``.
        dof_count: Optional length of the returned vector. Defaults to
            ``deformable.dofs()`` of the module-level ``deformable``.

    Returns:
        1-D numpy array of length ``dof_count`` holding the external forces.
    """
    if hex_to_one is None:
        hex_to_one = hex_to_one_mapping
    if dof_count is None:
        dof_count = deformable.dofs()

    stiffness = 1 if close_flag else 1e2
    forces_on_verts = (np.array(obj_target_verts) - np.array(obj_1_verts)) * stiffness

    f_ext = np.zeros(dof_count)
    for dof_key, surface_ids in hex_to_one.items():
        # Round the key down to the first DOF of its hex vertex. This is done
        # once per key: recomputing it from an already-divided value for every
        # surface vertex would send later forces to the wrong DOFs.
        base = (int(dof_key) // 3) * 3
        for surface_id in surface_ids:
            f_ext[base:base + 3] += forces_on_verts[int(surface_id)][:3]
    return f_ext
    


### Render

In [None]:
# visualize the hex mesh
from py_diff_pd.common.renderer import PbrtRenderer
# Output image for the render of the default mesh issued at the end of this cell.
png_file = render_folder / 'starfish_default.png'
def render_quasi_starfish(mesh_file, png_file):
    """Render the hex mesh stored in ``mesh_file`` to the image ``png_file``.

    Loads the mesh, adds it to a ``PbrtRenderer`` with voxel edges drawn,
    adds the checkerboard ground mesh, and renders.
    """
    renderer = PbrtRenderer({
        'file_name': png_file,
        'light_map': 'uffizi-large.exr',
        'sample': 4,
        'max_depth': 2,
        'camera_pos': (2, 3, 5),
        'camera_lookat': (1, -1, 0),  # roughly the center of starfish obj
    })

    hex_mesh = HexMesh3d()
    hex_mesh.Initialize(mesh_file)
    hex_transforms = [
        ('r', [90, 1, 0, 0]),  # Rotate 90 degrees around the x-axis
        ('t', [0, 0, 0]),
    ]
    renderer.add_hex_mesh(hex_mesh, render_voxel_edge=True, color=(.3, .7, .5),
                          transforms=hex_transforms)

    ground_obj = Path(root_path) / 'asset/mesh/flat_ground.obj'
    ground_transforms = [
        ('s', 4),
        ('t', [0, 0, -1]),
    ]
    renderer.add_tri_mesh(ground_obj, texture_img='chkbd_24_0.7',
                          transforms=ground_transforms)

    renderer.render()

# Render the default voxel mesh (mesh_bin_str) to png_file.
render_quasi_starfish(mesh_bin_str, png_file)



### Accumulate forces - Forward pass

In [None]:

# Cumulative velocity damping factor; the forward-pass loop below multiplies it
# by a per-iteration decay. Re-run this cell before re-running the loop,
# otherwise the decay keeps compounding from its previous value.
speed_decay = 1

In [None]:
# Forward pass driven by zero-rest-length springs.
#
# Spring force on x_i attached to x_j:
#     f_i = -k * (length - rest_length) * (x_i - x_j) / length
# which for rest_length = 0 reduces to
#     f_i = -k * (x_i - x_j)
# Original note: the initial k guess is 1e3, same as density, so F and m are
# about the same level.
#
# Rough pipeline, repeated num_iters times:
#     1. accumulate the spring forces as f_ext
#     2. PyForward -> q_next, v_next, contact_index
#     3. render q_next and check whether v is close to 0
#     4. move the surface vertices by the hex displacement and damp v
num_iters = 20
close_flag = False
for i in range(num_iters):
    speed_decay = speed_decay * 0.982

    # One simulation step under the current spring forces.
    f_ext = accumulate_forces_on_hex(obj_target_verts, obj_1_verts, close_flag)
    deformable.PyForward(foward_method, q_curr, v_curr, act_maps, f_ext, dt, pd_opt, q_next, v_next, contact_index)

    # Save the new state as a mesh file and render it.
    png_file = render_folder / f'starfish_{target_frame}_init_{i}.png'
    deformable.PySaveToMeshFile(q_next, render_bin_str)
    render_quasi_starfish(render_bin_str, png_file)

    # Move every surface vertex by the displacement of its mapped hex DOFs.
    q_diff = np.array(q_next) - np.array(q_curr)
    for v, surface_vert in enumerate(obj_1_verts):
        hex_idx = int(one_to_hex_mapping[str(v)])
        for axis in range(3):
            surface_vert[axis] += q_diff[hex_idx + axis]

    verts_diff = np.array(obj_1_verts) - np.array(obj_target_verts)
    l2_diff = np.linalg.norm(verts_diff)
    print(f'iter {i} l2_diff {l2_diff}')
    # Disabled experiment: once l2_diff < 1.2, set close_flag = True and
    # scale speed_decay by 0.3.

    # Damp the velocity before feeding it into the next step.
    v_next = StdRealVector(np.array(v_next) * speed_decay)
    avg_speed = np.mean(np.abs(v_next))
    print(f'iter {i} avg_speed {avg_speed}')

    # Advance the state and allocate fresh output containers.
    q_curr, v_curr = q_next, v_next
    q_next, v_next, contact_index = StdRealVector(dof), StdRealVector(dof), StdIntVector(0)


## Repeat for all key frames

In [None]:
import sys
import os
from pathlib import Path
import time
import numpy as np
from py_diff_pd.core.py_diff_pd_core import HexMesh3d, HexDeformable, StdRealVector
import py_diff_pd.common.hex_mesh as hex
from py_diff_pd.core.py_diff_pd_core import StdRealVector, StdIntVector

In [None]:
# Parameters shared by every frame of the batch run

asset_folder = Path('/mnt/e/muscleCode/sample_muscle_data/starfish')
mesh_bin = asset_folder.joinpath('starfish_demo_voxel.bin')
mesh_bin_str = str(mesh_bin)
# Folder receiving one ground-truth JSON per frame.
output_folder = Path('/mnt/e/wsl_projects/diff_pd_public/python/example/quasi_starfish/stretched_gt_corotate_volume/')

# Deformable param
youngs_modulus = 5e5
poissons_ratio = 0.45
# Lame parameters derived from Young's modulus and Poisson's ratio.
la = youngs_modulus * poissons_ratio / ((1 + poissons_ratio) * (1 - 2 * poissons_ratio))
mu = youngs_modulus / (2 * (1 + poissons_ratio))
density = 1e3
thread_ct = 20
dt = 1e-2

# Solver options handed to PyForward.
# NOTE(review): 'thread_ct' repeats the value of thread_ct above by hand.
pd_opt = dict(
    max_pd_iter=500,
    max_ls_iter=10,
    abs_tol=1e-9,
    rel_tol=1e-4,
    verbose=0,
    thread_ct=20,
    use_bfgs=1,
    bfgs_history_size=10,
)
foward_method = 'pd_eigen'

# load mapping
# NOTE(review): both JSON files must already exist in input_dir; the only code
# in this notebook that writes files with these names is commented out as deprecated.
hex_to_one_json = input_dir + 'hex_to_one.json'
one_to_hex_json = input_dir + 'one_to_hex.json'
hex_to_one_mapping = {}
one_to_hex_mapping = {}
import json
# hex_to_one_mapping is read by accumulate_forces_on_hex below.
with open(hex_to_one_json, 'r') as f:
    hex_to_one_mapping = json.load(f)
# one_to_hex_mapping is read by the per-frame loop (surface vertex id -> index into q).
with open(one_to_hex_json, 'r') as f:
    one_to_hex_mapping = json.load(f)
# accumulate forces on hex vertices
def accumulate_forces_on_hex(obj_target_verts, obj_1_verts, *, hex_to_one=None, dof_count=None):
    """Accumulate zero-rest-length spring forces from surface vertices onto hex DOFs.

    Every surface vertex is pulled towards its target position with force
    ``1e2 * (target - current)``. The force of each surface vertex listed under
    a key of the hex-to-surface mapping is added to the three DOFs of that
    key's hex vertex.

    Args:
        obj_target_verts: Target surface positions, one ``[x, y, z]`` per vertex.
        obj_1_verts: Current surface positions, same layout. The caller is
            responsible for keeping these up to date between calls.
        hex_to_one: Optional mapping ``{hex_dof_index: [surface_vertex_id, ...]}``
            (keys and ids may be strings, as loaded from JSON). Defaults to the
            module-level ``hex_to_one_mapping``.
        dof_count: Optional length of the returned vector. Defaults to
            ``deformable.dofs()`` of the module-level ``deformable``.

    Returns:
        1-D numpy array of length ``dof_count`` holding the external forces.
    """
    if hex_to_one is None:
        hex_to_one = hex_to_one_mapping
    if dof_count is None:
        dof_count = deformable.dofs()

    stiffness = 1e2
    forces_on_verts = (np.array(obj_target_verts) - np.array(obj_1_verts)) * stiffness

    f_ext = np.zeros(dof_count)
    for dof_key, surface_ids in hex_to_one.items():
        # Round the key down to the first DOF of its hex vertex. This is done
        # once per key: recomputing it from an already-divided value for every
        # surface vertex would send later forces to the wrong DOFs.
        base = (int(dof_key) // 3) * 3
        for surface_id in surface_ids:
            f_ext[base:base + 3] += forces_on_verts[int(surface_id)][:3]
    return f_ext

In [None]:
# For every frame, run the zero-rest-length spring initialization with several
# velocity-decay rates and save the hex state that brings the surface closest
# to the frame's target mesh.
frame_count_start = 1
frame_count = 120
total_diff = 0

# Iteration parameters; identical for every frame, so defined once.
decay_rates = [0.99, 0.987, 0.985, 0.982, 0.98]
num_iters = 20

# Rest-pose surface vertices: read from disk once. Each run below works on its
# own copy because the vertices are moved in place.
rest_surface_verts, _, _ = load_tri_starfish_obj(input_dir, "starfish_1.obj")

for frame in range(frame_count_start, frame_count_start + frame_count):
    output_file = output_folder / f'starfish_{frame}_init_ground_truth.json'
    obj_target_verts, _, _ = load_tri_starfish_obj(input_dir, "starfish_"+str(frame)+".obj")
    target_verts_np = np.array(obj_target_verts)

    # Best result over all decay rates and iterations of this frame. Starting
    # from +inf guarantees that the first finite distance is recorded.
    min_diff = float('inf')
    min_q = []
    for decay_rate in decay_rates:
        # Fresh mesh and deformable for every run so that no state carries over.
        default_hex_mesh = HexMesh3d()
        default_hex_mesh.Initialize(mesh_bin_str)
        deformable = HexDeformable()
        deformable.Initialize(mesh_bin_str, density, 'none', youngs_modulus, poissons_ratio)
        # NOTE(review): the single-frame setup earlier in this notebook also adds
        # a 'volume' energy ([la]) and the output folder is named
        # "..._corotate_volume", but only 'corotated' is added here — confirm
        # this is intended.
        deformable.AddPdEnergy('corotated', [2 * mu,], [])

        dof = deformable.dofs()
        act_maps = np.zeros(deformable.act_dofs())
        q_curr = default_hex_mesh.py_vertices()
        v_curr = np.zeros(deformable.dofs())
        q_next, v_next, contact_index = StdRealVector(dof), StdRealVector(dof), StdIntVector(0)

        obj_1_verts = [list(vert) for vert in rest_surface_verts]
        speed_decay = 1

        for step in range(num_iters):
            speed_decay *= decay_rate
            f_ext = accumulate_forces_on_hex(obj_target_verts, obj_1_verts)
            deformable.PyForward(foward_method, q_curr, v_curr, act_maps, f_ext, dt, pd_opt, q_next, v_next, contact_index)

            # Move every surface vertex by the displacement of its mapped hex DOFs.
            q_diff = np.array(q_next) - np.array(q_curr)
            for v, surface_vert in enumerate(obj_1_verts):
                hex_idx = int(one_to_hex_mapping[str(v)])
                surface_vert[0] += q_diff[hex_idx]
                surface_vert[1] += q_diff[hex_idx+1]
                surface_vert[2] += q_diff[hex_idx+2]

            l2_diff = np.linalg.norm(np.array(obj_1_verts) - target_verts_np)
            if l2_diff < min_diff:
                min_diff = l2_diff
                # Snapshot the values so the best state cannot alias a live buffer.
                min_q = np.array(q_next)

            # Damp the velocity, advance the state, allocate fresh outputs.
            v_next = StdRealVector(np.array(v_next) * speed_decay)
            q_curr = q_next
            v_curr = v_next
            q_next, v_next, contact_index = StdRealVector(dof), StdRealVector(dof), StdIntVector(0)

    # save min_q to json
    min_q = np.array(min_q).tolist()
    with open(output_file, 'w') as f:
        json.dump(min_q, f)
    print(f'frame {frame} , min_diff {min_diff}, saved to {output_file}')
    total_diff += min_diff

print(f'average diff {total_diff / frame_count}')