In [38]:
import pyvista as pv
import meshio
import SimpleITK as sitk
import vtk
import numpy as np
from tqdm import tqdm
from pathlib import Path
from meshpy.tet import MeshInfo, build
from scipy.ndimage import label, map_coordinates
from sklearn.decomposition import PCA
import math
import os
import meshlib.mrmeshpy as mr
import meshlib.mrmeshnumpy as mrn
import random

## Functions

In [61]:
def hex_to_5tets(hexes):
    a,b,c,d,e,f,g,h = hexes[:,0],hexes[:,1],hexes[:,2],hexes[:,3], \
                      hexes[:,4],hexes[:,5],hexes[:,6],hexes[:,7]
    return np.vstack([
        np.stack([a,b,d,e],axis=1),
        np.stack([b,c,d,g],axis=1),
        np.stack([b,e,f,g],axis=1),
        np.stack([b,d,e,g],axis=1),
        np.stack([d,e,g,h],axis=1),
    ])

def numpy_to_vtk_coords(coords):
    vtkarr = vtk.vtkDoubleArray()
    vtkarr.SetNumberOfValues(len(coords))
    for i, v in enumerate(coords):
        vtkarr.SetValue(i, float(v))
    return vtkarr

def replace_header_and_elements(inp_file, new_header_file, output_file):
    # Read new header
    with open(new_header_file, "r") as f:
        new_header = f.read().rstrip("\n") + "\n"

    lines = []
    header_inserted = False

    with open(inp_file, "r") as f:
        for line in f:
            stripped = line.strip()

            # --- 1. Replace header before *NODE
            if not header_inserted and stripped.upper().startswith("*NODE"):
                lines.append(new_header)
                header_inserted = True
                # skip writing the *NODE line itself
                continue

            # --- 2. Modify *ELEMENT, TYPE= lines
            if stripped.upper().startswith("*ELEMENT, TYPE="):
                parts = line.split("=", 1)
                if len(parts) == 2:
                    elem_type = parts[1].strip().upper()
                    line = f"*Element, type={elem_type}\n"

            # Add the processed line
            lines.append(line)

    # Write result
    with open(output_file, "w") as f:
        f.writelines(lines)

def append_footer(inp_file, footer_file, output_file):
    # Read footer
    with open(footer_file, "r") as f:
        footer = f.read()

    with open(inp_file, "r") as f:
        content = f.read()

    with open(output_file, "w") as f:
        f.write(content.rstrip("\n") + "\n")  # ensure last line ends
        f.write(footer.rstrip("\n") + "\n")

#Clean mask from small voxels
def remove_small_islands(binary_matrix, area_threshold):
    """Remove small connected components from a binary mask."""
    labeled_array, num_features = label(binary_matrix)
    for i in range(1, num_features + 1):
        component = (labeled_array == i)
        if component.sum() < area_threshold:
            binary_matrix[component] = 0
    return binary_matrix

def mask_region_of_interest(mask,ranges):
    zero_mask=np.zeros(np.shape(mask))
    zero_mask[ranges[0]:ranges[1],ranges[2]:ranges[3],ranges[4]:ranges[5]]=mask[ranges[0]:ranges[1],ranges[2]:ranges[3],ranges[4]:ranges[5]]
    return zero_mask

def create_blocks(mask, thickness, axis):
    if (axis=="X"):
        xmin=np.min(np.where(mask==1)[0])
        xmax=np.max(np.where(mask==1)[0])+1
        ymin=np.min([np.min(np.where(mask[xmin:xmin+thickness,:,:]==1)[1]),np.min(np.where(mask[xmax-thickness:xmax,:,:]==1)[1])])
        ymax=np.max([np.max(np.where(mask[xmin:xmin+thickness,:,:]==1)[1]),np.max(np.where(mask[xmax-thickness:xmax,:,:]==1)[1])])+1
        zmin=np.min([np.min(np.where(mask[xmin:xmin+thickness,:,:]==1)[2]),np.min(np.where(mask[xmax-thickness:xmax,:,:]==1)[2])])
        zmax=np.max([np.max(np.where(mask[xmin:xmin+thickness,:,:]==1)[2]),np.max(np.where(mask[xmax-thickness:xmax,:,:]==1)[2])])+1
        blocks=np.zeros(np.shape(mask))
        blocks[xmin:xmin+thickness,ymin:ymax,zmin:zmax]=2
        blocks[xmax-thickness:xmax,ymin:ymax,zmin:zmax]=2
        final_mask=blocks+mask
        return final_mask

def _skew(v):
    return np.array([[0.0, -v[2], v[1]],
                     [v[2],  0.0, -v[0]],
                     [-v[1], v[0],  0.0]], dtype=float)

def align_by_centroids_roi(image, iso_value, roi=[0,0,0,0,0,0], fraction=0.1, axis="Z", debug=True):
    """
    Align the volume so the line connecting top/bottom centroids (computed inside ROI)
    maps to the chosen world axis ("X","Y","Z").

    Returns: (numpy_array (z,y,x), rotated_sitk_image)
    roi = [x0,x1,y0,y1,z0,z1] using python slices (x0 inclusive, x1 exclusive).
    """
    arr = sitk.GetArrayFromImage(image)  # (z,y,x)

    print(roi)

    # ---------- 1) ROI crop and centroids in ROI index space (z,y,x) ----------
    x0, x1, y0, y1, z0, z1 = roi
    if np.prod(roi)!=0:
        roi_arr = arr[z0:z1, y0:y1, x0:x1]
    else:
        roi_arr=arr
    mask = roi_arr > iso_value

    nz, ny, nx = mask.shape
    if nz <= 0 or ny <= 0 or nx <= 0:
        raise ValueError("ROI is empty or invalid")

    n_slices = max(1, int(nz * fraction))

    # bottom region (first n_slices in ROI z)
    bottom_mask = mask[:n_slices]
    if bottom_mask.sum() == 0:
        raise ValueError("No voxels in the bottom ROI slices. Decrease fraction or change ROI.")
    bottom_coords = np.argwhere(bottom_mask)  # (z,y,x) within ROI
    bottom_centroid_roi = bottom_coords.mean(axis=0)  # (z,y,x) in ROI coords

    # top region (last n_slices in ROI z)
    top_mask = mask[-n_slices:]
    if top_mask.sum() == 0:
        raise ValueError("No voxels in the top ROI slices. Decrease fraction or change ROI.")
    top_coords = np.argwhere(top_mask)
    top_coords[:, 0] += nz - n_slices   # shift z indices back inside ROI
    top_centroid_roi = top_coords.mean(axis=0)

    # Map ROI centroids back to full-image indices (still z,y,x)
    bottom_idx_zyx = bottom_centroid_roi + np.array([z0, y0, x0])
    top_idx_zyx    = top_centroid_roi    + np.array([z0, y0, x0])

    # ---------- 2) Convert centroids to PHYSICAL coords (x,y,z) ----------
    spacing = np.array(image.GetSpacing(), dtype=float)    # (sx,sy,sz)
    origin  = np.array(image.GetOrigin(), dtype=float)     # (ox,oy,oz)
    D = np.array(image.GetDirection(), dtype=float).reshape(3,3)  # direction matrix
    M = D @ np.diag(spacing)   # maps index_xyz -> physical offset (without origin)

    # convert z,y,x -> x,y,z index order
    bottom_idx_xyz = bottom_idx_zyx[::-1]
    top_idx_xyz    = top_idx_zyx[::-1]

    bottom_phys = origin + M @ bottom_idx_xyz
    top_phys    = origin + M @ top_idx_xyz

    # ---------- 3) Build direction vector in PHYSICAL space ----------
    v = top_phys - bottom_phys
    norm_v = np.linalg.norm(v)
    if norm_v == 0:
        raise ValueError("Top and bottom centroids coincide (zero-length). Choose different ROI/fraction.")
    v_unit = v / norm_v

    axis = axis.upper()
    if axis == "X":
        t = np.array([1.0, 0.0, 0.0], dtype=float)
    elif axis == "Y":
        t = np.array([0.0, 1.0, 0.0], dtype=float)
    else:
        t = np.array([0.0, 0.0, 1.0], dtype=float)

    # ---------- 4) Compute rotation R that maps v_unit -> t ----------
    k = np.cross(v_unit, t)          # not normalized; ||k|| = sin(theta)
    s = np.linalg.norm(k)
    c = np.dot(v_unit, t)

    if s < 1e-12 and c > 0.999999:
        R = np.eye(3)
    elif s < 1e-12 and c < -0.999999:
        # 180-degree rotation: pick an arbitrary orthogonal axis
        arbitrary = np.array([1.0, 0.0, 0.0], dtype=float)
        if abs(np.dot(arbitrary, v_unit)) > 0.9:
            arbitrary = np.array([0.0, 1.0, 0.0], dtype=float)
        axis_orth = np.cross(v_unit, arbitrary)
        axis_orth /= np.linalg.norm(axis_orth)
        K = _skew(axis_orth)
        R = np.eye(3) + 2.0 * (K @ K)   # since sin(pi)=0, 1-cos(pi)=2
    else:
        K = _skew(k)   # use k (not normalized)
        # Rodrigues variant for non-normalized k:
        # R = I + K + K^2 * ((1 - c) / (s^2))
        R = np.eye(3) + K + (K @ K) * ((1.0 - c) / (s * s))

    # ---------- diagnostics ----------
    if debug:
        ang = math.degrees(math.acos(np.clip(c, -1.0, 1.0)))
        print("bottom_idx_zyx (full image):", bottom_idx_zyx)
        print("top_idx_zyx    (full image):", top_idx_zyx)
        print("bottom_phys (mm):", bottom_phys)
        print("top_phys    (mm):", top_phys)
        print(f"Angle between centroid vector and target axis: {ang:.6f} deg")
        print("Rotation matrix R:\n", R)

    # ---------- 5) Build output reference grid (expand bounding box so rotated volume fits) ----------
    size = np.array(image.GetSize(), dtype=float)   # (sx,sy,sz)
    # generate corners in index_xyz order
    corners_idx_xyz = np.array([[0,0,0],
                                [size[0]-1,0,0],
                                [0,size[1]-1,0],
                                [0,0,size[2]-1],
                                [size[0]-1,size[1]-1,0],
                                [size[0]-1,0,size[2]-1],
                                [0,size[1]-1,size[2]-1],
                                [size[0]-1,size[1]-1,size[2]-1]], dtype=float)
    # corners -> physical
    corners_phys = (corners_idx_xyz @ M.T) + origin  # (8,3)

    # center of rotation: use image geometric center (physical)
    center_idx_xyz = (size - 1.0) / 2.0
    center_phys = origin + M @ center_idx_xyz

    # rotate corners about center: p' = R @ (p - c) + c
    rotated_corners = ((corners_phys - center_phys) @ R.T) + center_phys

    out_spacing = spacing.copy()
    mins = rotated_corners.min(axis=0)
    maxs = rotated_corners.max(axis=0)
    extent = maxs - mins
    out_size = np.ceil(extent / out_spacing).astype(int)
    out_size = np.maximum(out_size, 1)
    out_origin = mins

    # build reference image with identity direction
    ref = sitk.Image(int(out_size[0]), int(out_size[1]), int(out_size[2]), image.GetPixelID())
    ref.SetSpacing(tuple(float(x) for x in out_spacing))
    ref.SetOrigin(tuple(float(x) for x in out_origin))
    ref.SetDirection(tuple(np.eye(3).ravel().tolist()))

    # ---------- 6) Build ITK transform mapping output->input (p_in = R^{-1}(p_out - c) + c) ----------
    R_inv = R.T
    T = sitk.AffineTransform(3)
    T.SetMatrix(list(R_inv.ravel()))
    T.SetCenter(tuple(center_phys.tolist()))

    # ---------- 7) Resample with nearest neighbor to preserve densities ----------
    out = sitk.Resample(
        image,
        ref,
        T,
        sitk.sitkNearestNeighbor,
        0.0,
        image.GetPixelID()
    )

    out_arr = sitk.GetArrayFromImage(out)  # (z,y,x)
    if debug:
        print("Output image shape (z,y,x):", out_arr.shape,
              " intensity range:", out_arr.min(), out_arr.max())
    return out_arr, out


## Import the .nddr CT file

In [62]:
# ------------------------
# PARAMETERS
# ------------------------
input_nrrd = "microCT_volume_preview045423.nrrd"
iso_value = 8000         # Threshold for segmentation
downsample_factor = 4     # Reduce voxel count for speed (optional)
roi = [0,70,0,70,0,140]

# ------------------------
# 1. LOAD IMAGE
# ------------------------
if not os.path.exists(input_nrrd):
    raise FileNotFoundError(f"File not found: {input_nrrd}")

image = sitk.ReadImage(input_nrrd)

array, rot_im = align_by_centroids_roi(image, iso_value, roi,fraction=0.1, axis="Z", debug=True)
image=rot_im

orig_size = np.array(image.GetSize())
orig_spacing = np.array(image.GetSpacing())
origin = np.array(image.GetOrigin())

# ------------------------
# 2. DOWNSAMPLE (OPTIONAL)
# ------------------------
new_size = (orig_size / downsample_factor).astype(int)
new_spacing = orig_spacing * (orig_size / new_size)

resampler = sitk.ResampleImageFilter()
resampler.SetSize([int(s) for s in new_size])
resampler.SetOutputSpacing([float(s) for s in new_spacing])
resampler.SetOutputOrigin(origin)
resampler.SetOutputDirection(image.GetDirection())
resampler.SetInterpolator(sitk.sitkLinear)
image = resampler.Execute(image)

# Convert to numpy array
array = sitk.GetArrayFromImage(image)  # z, y, x
spacing = np.array(image.GetSpacing())
origin = np.array(image.GetOrigin())
print(f"Loaded image: shape={array.shape}, intensity range=({array.min()}, {array.max()})")

# ------------------------
# 3. APPLY THRESHOLD
# ------------------------
mask = array >= iso_value
n_voxels = np.sum(mask)
print(f"Threshold applied: {n_voxels} voxels above threshold")
if n_voxels == 0:
    raise ValueError("No voxels above threshold. Reduce iso_value.")

mask = remove_small_islands(mask, 30)
mask = mask_region_of_interest(mask,roi)
mask = create_blocks(mask,5,"X")

# Generate STL surface
simpleVolume = mrn.simpleVolumeFrom3Darray(np.float32(mask>0))
floatGrid = mr.simpleVolumeToDenseGrid(simpleVolume)
mesh_stl = mr.gridToMesh(floatGrid, mr.Vector3f(1.0, 1.0, 1.0), 0.5)
stl_path = Path(input_nrrd).stem + "_TETmesh.stl"
mr.saveMesh(mesh_stl, stl_path)

mesh_nuclei = pv.read(stl_path)
if mesh_nuclei.volume > 0.0:
    mesh_nuclei.decimate(target_reduction=0.8, inplace=True)

# ------------------------
# 4. TETRAHEDRALIZATION
# ------------------------
surface_mesh = pv.read(stl_path)
points = np.array(surface_mesh.points)
faces = surface_mesh.faces.reshape(-1, 4)[:, 1:]

mesh_info = MeshInfo()
mesh_info.set_points(points)
mesh_info.set_facets(faces.tolist())

tet_mesh = build(mesh_info, max_volume=1.0)

tet_points = np.array(tet_mesh.points)
tet_elements = []
for tet in tet_mesh.elements:
    tet_elements.extend([4, *tet])  # 4 nodes per tetra
tet_elements = np.array(tet_elements)
celltypes = np.full(len(tet_mesh.elements), pv.CellType.TETRA, dtype=np.uint8)

grid = pv.UnstructuredGrid(tet_elements, celltypes, tet_points)

# ------------------------
# DEBUGGING COORDINATES
# ------------------------
print("\n--- Coordinate Sanity Check ---")
print("Voxel index limits (z,y,x):", array.shape)
print("Centroid range:", tet_points.min(axis=0), tet_points.max(axis=0))

# ------------------------
# 6. DENSITY & YOUNG MODULUS SAMPLING
# ------------------------
cell_densities = []
cell_ymodulus = []
for tet in tet_mesh.elements:
    centroid = tet_points[tet].mean(axis=0)
    v_coord = centroid[[0, 1, 2]]
    if mask[round(v_coord[0]), round(v_coord[1]), round(v_coord[2])]<2:
        density_val = map_coordinates(array, v_coord.reshape(3, 1), order=1, mode='nearest')
        #Conversion
        density_val=(density_val[0]*3.85)/6.6e4
        ymodulus_val=(density_val**2.0)*1.4e4
        cell_densities.append(density_val)
        cell_ymodulus.append(ymodulus_val)
    else:
        cell_densities.append(5.0)
        cell_ymodulus.append(100000)

grid.cell_data["Density (g/cm3)"] = np.array(cell_densities)
grid.cell_data["YM (MPa)"] = np.array(cell_ymodulus)

# ------------------------
# 6. SAVE VTK
# ------------------------
vtk_path = Path(input_nrrd).stem + "_TETmesh_rot.vtk"
grid.save(vtk_path)
print(f"Tetrahedral mesh with densities saved to {vtk_path}")

[0, 70, 0, 70, 0, 140]
bottom_idx_zyx (full image): [ 33.33308626 204.11938329 144.64609493]
top_idx_zyx    (full image): [568.39482783 108.90119412 154.05157567]
bottom_phys (mm): [2.89273675 4.08212639 0.66661906]
top_phys    (mm): [ 3.08083433  2.17788449 11.36716901]
Angle between centroid vector and target axis: 10.138665 deg
Rotation matrix R:
 [[ 0.99984911  0.00152756 -0.01730382]
 [ 0.00152756  0.9845355   0.17517851]
 [ 0.01730382 -0.17517851  0.98438461]]
Output image shape (z,y,x): (643, 378, 325)  intensity range: 0 60098
Loaded image: shape=(160, 94, 81), intensity range=(0, 29846)
Threshold applied: 47979 voxels above threshold

--- Coordinate Sanity Check ---
Voxel index limits (z,y,x): (160, 94, 81)
Centroid range: [ 1.5        40.5        14.83333302] [69.5        66.5        59.29999924]
Tetrahedral mesh with densities saved to microCT_volume_preview045423_TETmesh_rot.vtk


In [47]:
from collections import defaultdict

tolerance = 0.01  # 1% difference allowed
sorted_elements = sorted(
    [(i + 1, E) for i, E in enumerate(cell_ymodulus)],
    key=lambda x: x[1]
)

E_groups = []
group_values = []

for elem_id, E in sorted_elements:
    placed = False
    for idx, g_val in enumerate(group_values):
        if abs(E - g_val) / g_val <= tolerance:
            E_groups[idx].append(elem_id)
            # update group value as average to keep cluster centered
            group_values[idx] = (group_values[idx] * (len(E_groups[idx]) - 1) + E) / len(E_groups[idx])
            placed = True
            break
    if not placed:
        group_values.append(E)
        E_groups.append([elem_id])

In [48]:
# Write INP
inp_path = Path(input_nrrd).stem + "_TETmesh_binned_E.inp"
with open(inp_path, "w") as f:
    f.write("*Heading\n")
    f.write("** Generated by Python script\n")
    
    # Nodes
    f.write("*Node\n")
    for i, p in enumerate(tet_points, start=1):
        f.write(f"{i}, {p[0]:.6f}, {p[1]:.6f}, {p[2]:.6f}\n")
    
    # Elements
    f.write("*Element, type=C3D4\n")
    for i, e in enumerate(tet_mesh.elements, start=1):
        f.write(f"{i}, {e[0]+1}, {e[1]+1}, {e[2]+1}, {e[3]+1}\n")

    # Materials and sections per group
    for mat_idx, (E_val, elems) in enumerate(zip(group_values, E_groups), start=1):
        f.write(f"*Elset, elset=ESET{mat_idx}\n")
        for e_id in elems:
            f.write(f"{e_id},\n")
        
        f.write(f"*Material, name=MAT{mat_idx}\n")
        f.write("*Elastic\n")
        f.write(f"{E_val:.6f}, 0.3\n")  # Poisson's ratio constant
        
        f.write(f"*Solid Section, elset=ESET{mat_idx}, material=MAT{mat_idx}\n")

In [4]:
import vtk

# --- Read rectilinear grid ---
reader = vtk.vtkRectilinearGridReader()
reader.SetFileName(Path(input_file).stem + ".vtk")
reader.Update()

rgrid = reader.GetOutput()

# --- Get max indices along each axis ---
dims = rgrid.GetDimensions()  # (nx, ny, nz) = number of points in each direction
xmax = dims[0] - 1
ymax = dims[1] - 1
zmax = dims[2] - 1

print(f"Max index X: {xmax}, Y: {ymax}, Z: {zmax}")

# --- Extract subset (VOI = Volume of Interest) ---
extract = vtk.vtkExtractRectilinearGrid()
extract.SetInputData(rgrid)

# Example: keep voxels from (x: 10–50, y: 20–80, z: 5–40)
extract.SetVOI(0, xmax,   # x-min, x-max (indices)
               0, 150,   # y-min, y-max
               0, 40)    # z-min, z-max

extract.Update()
subgrid = extract.GetOutput()

# --- Write subset to file ---
writer = vtk.vtkRectilinearGridWriter()
writer.SetFileName(Path(input_file).stem + ".vtk")
writer.SetInputData(subgrid)
writer.Write()

print(f"Subset saved as {Path(input_file).stem}.vtk")

Max index X: 238, Y: 256, Z: 298
Subset saved as microCT_volume_preview.vtk
