In [1]:
import sys
sys.path.append('../')
import os

from pathlib import Path
import time
import numpy as np
import scipy.optimize
import pickle
import matplotlib.pyplot as plt

from py_diff_pd.common.common import ndarray, create_folder
from py_diff_pd.common.common import print_info, print_ok, print_error, print_warning
from py_diff_pd.common.grad_check import check_gradients
from py_diff_pd.common.display import export_gif
from py_diff_pd.core.py_diff_pd_core import StdRealVector
from py_diff_pd.env.soft_starfish_env_3d import SoftStarfishEnv3d
from py_diff_pd.common.project_path import root_path
from py_diff_pd.core.py_diff_pd_core import HexMesh3d, HexDeformable, StdRealVector
import py_diff_pd.common.hex_mesh as hex

In [3]:
# Sanity check of the environment: print the contents of the starfish asset
# folder, then pick the demo surface mesh inside it.
asset_folder = Path('/mnt/e/muscleCode/sample_muscle_data/starfish')
files = os.listdir(asset_folder)  # bare entry names, in OS order
print(files)
input_obj = asset_folder.joinpath('starfish_demo.obj')

['starfish.mb', 'starfish_demo.bin', 'starfish_demo.obj', 'starfish_n_1.obj', 'starfish_n_10.obj', 'starfish_n_100.obj', 'starfish_n_101.obj', 'starfish_n_102.obj', 'starfish_n_103.obj', 'starfish_n_104.obj', 'starfish_n_105.obj', 'starfish_n_106.obj', 'starfish_n_107.obj', 'starfish_n_108.obj', 'starfish_n_109.obj', 'starfish_n_11.obj', 'starfish_n_110.obj', 'starfish_n_111.obj', 'starfish_n_112.obj', 'starfish_n_113.obj', 'starfish_n_114.obj', 'starfish_n_115.obj', 'starfish_n_116.obj', 'starfish_n_117.obj', 'starfish_n_118.obj', 'starfish_n_119.obj', 'starfish_n_12.obj', 'starfish_n_120.obj', 'starfish_n_13.obj', 'starfish_n_14.obj', 'starfish_n_15.obj', 'starfish_n_16.obj', 'starfish_n_17.obj', 'starfish_n_18.obj', 'starfish_n_19.obj', 'starfish_n_2.obj', 'starfish_n_20.obj', 'starfish_n_21.obj', 'starfish_n_22.obj', 'starfish_n_23.obj', 'starfish_n_24.obj', 'starfish_n_25.obj', 'starfish_n_26.obj', 'starfish_n_27.obj', 'starfish_n_28.obj', 'starfish_n_29.obj', 'starfish_n_3.obj', 

In [3]:
# Load the voxelised starfish hex mesh from its .bin file and convert it to an
# OBJ file with hex.hex2obj in 'tri' mode, timing both steps.
asset_folder = Path('/mnt/e/muscleCode/sample_muscle_data/starfish')
input_obj = asset_folder / 'starfish_demo_voxel.obj'
mesh_bin = asset_folder / 'starfish_demo_voxel.bin'
voxel_output = asset_folder / 'starfish_demo_voxel_output.obj'
json_file_path = asset_folder / 'starfish_demo_48x9x46.json'

# perf_counter is monotonic and high-resolution, unlike the wall clock
# time.time(), so it is the right clock for measuring durations.
time_check = time.perf_counter()
mesh = HexMesh3d()
mesh.Initialize(str(mesh_bin))  # Initialize takes a plain str path
print('initialize hex mesh time:', time.perf_counter() - time_check)

time_check = time.perf_counter()
hex.hex2obj(mesh, voxel_output, 'tri')
print('hex2obj time:', time.perf_counter() - time_check)


initialize hex mesh time: 0.012445688247680664
hex2obj time: 0.03848004341125488


In [14]:
import os
import copy

# Number of starfish animation frames -- presumably starfish_n_1.obj ..
# starfish_n_120.obj in the asset folder; TODO confirm against the data set.
file_count = 120
input_dir = "/mnt/e/muscleCode/sample_muscle_data/starfish/"
# Write results next to the inputs. The previous value "E:/muscleCode/..." was
# a Windows path; under WSL/Linux, where input_dir and the .bin mesh above are
# read, it is interpreted as a relative directory named "E:" and open() fails.
output_dir = input_dir

# OBJ helpers: read the per-frame starfish triangle meshes, read the hex-mesh
# OBJ, and write a triangle mesh back with new vertex positions.

def load_tri_starfish_obj(input_dir, file_name, header_lines=3, vertex_end=1085):
    """Split a starfish triangle-mesh OBJ into header, vertex and trailing lines.

    The split is positional, because the starfish frames share a fixed layout:
    lines ``[0, header_lines)`` are kept verbatim as the header, lines
    ``[header_lines, vertex_end)`` must be ``v x y z`` records, and every line
    from ``vertex_end`` on is kept verbatim. The defaults (3 header lines,
    1082 vertex lines) match the starfish dataset.

    Args:
        input_dir: directory containing the file.
        file_name: OBJ file name inside ``input_dir``.
        header_lines: number of leading lines copied unchanged.
        vertex_end: 0-based index of the first line after the vertex block.

    Returns:
        ``(vertex_lines, first_lines, rest_lines)``. ``vertex_lines`` is a list
        of ``[x, y, z]`` floats. ``first_lines`` and ``rest_lines`` are lists of
        raw lines (newlines included), suitable for ``write_tri_starfish_obj``.

    Raises:
        ValueError: if a line inside the vertex block is not a ``v`` record
            with at least three coordinates, i.e. the file does not follow the
            assumed layout.
    """
    vertex_lines, first_lines, rest_lines = [], [], []
    path = os.path.join(input_dir, file_name)
    with open(path, 'r') as file:
        for line_no, line in enumerate(file):
            if line_no < header_lines:
                first_lines.append(line)
            elif line_no >= vertex_end:
                rest_lines.append(line)
            else:
                parts = line.split()
                # Guard against a layout mismatch: without this check, 'vn'
                # lines would be silently taken as vertex positions.
                if len(parts) < 4 or parts[0] != 'v':
                    raise ValueError(
                        f"{path}:{line_no + 1}: expected a 'v x y z' line "
                        f"in the vertex block, got {line!r}")
                vertex_lines.append([float(parts[1]), float(parts[2]), float(parts[3])])
    return vertex_lines, first_lines, rest_lines


def load_hex_starfish_obj(input_dir, file_name):
    """Separate the vertex records of a hex-mesh OBJ from every other line.

    Returns ``(vertex_lines, rest_lines)``. ``vertex_lines`` holds
    ``[x, y, z]`` floats for each line starting with ``'v '``. ``rest_lines``
    holds all other lines verbatim, in file order; their positions relative
    to the vertex lines are not recorded.
    """
    positions = []
    others = []
    obj_path = os.path.join(input_dir, file_name)
    with open(obj_path, 'r') as handle:
        for raw in handle:
            if raw.startswith('v '):
                tokens = raw.strip().split()
                positions.append([float(tokens[1]), float(tokens[2]), float(tokens[3])])
                continue
            others.append(raw)
    return positions, others

def write_tri_starfish_obj(output_dir, output_name, first_lines, rest_lines, new_verts):
    """Write a starfish triangle OBJ whose vertex block is replaced.

    Creates or truncates ``output_dir/output_name``. It receives, in order:
    ``first_lines`` verbatim, one ``v x y z`` line per entry of ``new_verts``
    (each entry indexable as ``[x, y, z]``), then ``rest_lines`` verbatim.
    """
    target = os.path.join(output_dir, output_name)
    with open(target, 'w') as out:
        out.writelines(first_lines)
        out.writelines(f"v {p[0]} {p[1]} {p[2]}\n" for p in new_verts)
        out.writelines(rest_lines)
 

In [18]:
# Flat hex-mesh vertex coordinates [x0, y0, z0, x1, ...]. The saved output
# (16314 = 3 * 5438 vertices) is the list length, so display it explicitly
# so that a fresh Run All reproduces it.
obj2_verts = mesh.py_vertices()
len(obj2_verts)

16314

In [22]:

# Nearest-neighbour correspondence between the frame-1 triangle-mesh vertices
# and the hex-mesh vertices; the JSON export happens in a later cell.
# NOTE: hex-side indices are *flat offsets* into mesh.py_vertices()
# (3 * hex vertex index), matching the stride-3 layout of that list.
# Divide by 3 to get the hex vertex index.

def nearest_hex_offsets(tri_verts, hex_flat_verts):
    """Map every triangle vertex to its closest hex vertex (squared Euclidean).

    Args:
        tri_verts: sequence of ``[x, y, z]`` points.
        hex_flat_verts: flat sequence ``[x0, y0, z0, x1, ...]`` of hex vertices.

    Returns:
        ``(one_to_hex, hex_to_one, total_sq_dist)``:
        ``one_to_hex`` maps tri index -> flat hex offset (int);
        ``hex_to_one`` maps flat hex offset -> ascending list of tri indices;
        ``total_sq_dist`` is the sum of the nearest squared distances.

    Raises:
        ValueError: if there are triangle vertices but no hex vertices.
    """
    hex_pts = np.asarray(hex_flat_verts, dtype=float).reshape(-1, 3)
    if len(tri_verts) and not len(hex_pts):
        raise ValueError('hex mesh has no vertices to match against')
    one_to_hex, hex_to_one, total_sq_dist = {}, {}, 0.0
    for i, v in enumerate(tri_verts):
        # One vectorised pass over all hex vertices instead of a Python loop.
        sq_dist = ((hex_pts - np.asarray(v, dtype=float)) ** 2).sum(axis=1)
        nearest = int(np.argmin(sq_dist))  # first minimum wins on ties
        offset = 3 * nearest  # plain int, so the dict stays JSON-serialisable
        one_to_hex[i] = offset
        hex_to_one.setdefault(offset, []).append(i)
        total_sq_dist += float(sq_dist[nearest])
    return one_to_hex, hex_to_one, total_sq_dist


obj1_verts, first_lines, rest_lines = load_tri_starfish_obj(input_dir, "starfish_1.obj")
obj2_verts = mesh.py_vertices()
one_to_hex_mapping, hex_to_one_mapping, all_one_to_hex_dist = nearest_hex_offsets(obj1_verts, obj2_verts)

In [25]:
# Mean squared tri->hex nearest distance, then a histogram of how many hex
# vertices received 1, 2, 3, 4 or 5+ triangle vertices.
print(all_one_to_hex_dist / len(obj1_verts))
bucket_counts = [0, 0, 0, 0, 0]
for tri_indices in hex_to_one_mapping.values():
    # sizes 1..4 go to slots 0..3; every other size goes to the last (5+) slot
    slot = len(tri_indices) - 1 if 1 <= len(tri_indices) <= 4 else 4
    bucket_counts[slot] += 1
one_count, two_count, three_count, four_count, five_plus_count = bucket_counts
print(one_count, two_count, three_count, four_count, five_plus_count, len(obj2_verts) / 3)
print(len(obj1_verts))


0.0006167308275463548
814 84 16 5 5 5438.0
1082


In [None]:
# Persist both correspondence maps in the asset folder. json.dump turns the
# int dict keys into strings ("0", "3", ...); cast them back with int() after
# json.load.
import json

hex_to_one_json = asset_folder / 'hex_to_one.json'
one_to_hex_json = asset_folder / 'one_to_hex.json'
for out_path, mapping in ((hex_to_one_json, hex_to_one_mapping),
                          (one_to_hex_json, one_to_hex_mapping)):
    with open(out_path, 'w') as f:
        json.dump(mapping, f)