In [None]:
import open3d as o3d
from extract_context_teeth import extract_teeth_and_gums, get_response, \
    extract_and_trim_tooth_mesh, fill_top_with_highest_point, merge_scaled_tooth_with_remaining, \
    mesh_to_pcd
import numpy as np

from pathlib import Path 
import glob 
from scipy.sparse import csr_matrix 
from scipy.sparse.csgraph import connected_components
import os
from sklearn.neighbors import NearestNeighbors

from scipy.spatial.distance import cdist

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def extract_teeth(mesh, segments, tooth_labels=(31,)):
    """
    Extract the selected segmented teeth from ``mesh`` as separate meshes.

    For every requested label, the faces whose entry in ``segments`` equals the
    label are collected. If they form several edge-connected pieces, only the
    piece with the most faces is kept (the first such piece on ties).

    Parameters
    ----------
    mesh : o3d.geometry.TriangleMesh
        Source mesh. Only ``mesh.vertices`` and ``mesh.triangles`` are read.
    segments : array-like
        One label per triangle of ``mesh`` (indexed like ``mesh.triangles``).
    tooth_labels : iterable of int, optional (default=(31,))
        Labels to extract. Only these labels are processed, not every label in
        ``segments``. Label 0 is always skipped; labels with no faces are
        reported and skipped.

    Returns
    -------
    dict
        Maps each extracted label (a numpy integer, usable with plain ``int``
        keys) to a new ``o3d.geometry.TriangleMesh`` holding only that tooth's
        vertices and renumbered faces. No normals are computed.
    """
    vertices = np.asarray(mesh.vertices)
    triangles = np.asarray(mesh.triangles)
    segments = np.asarray(segments)

    # Only the requested labels are processed.
    unique_segments = np.array(tooth_labels)

    extracted_teeth = {}

    for segment_id in unique_segments:
        # Label 0 is treated as non-tooth (background/gum) and never extracted.
        if segment_id == 0:
            continue

        tooth_face_indices = np.where(segments == segment_id)[0]

        if len(tooth_face_indices) == 0:
            print(f"No faces found for segment {segment_id}, skipping")
            continue

        tooth_triangles = triangles[tooth_face_indices]
        face_count = len(tooth_triangles)

        # Build the face-adjacency graph sparsely. A dense face x face matrix
        # costs O(F^2) memory, about 64 MB for an 8k-face tooth.
        # Every face contributes its three undirected edges as (low, high) pairs.
        face_edges = np.concatenate([
            tooth_triangles[:, [0, 1]],
            tooth_triangles[:, [1, 2]],
            tooth_triangles[:, [2, 0]],
        ])
        face_edges.sort(axis=1)
        edge_owner = np.tile(np.arange(face_count), 3)

        # Give each distinct edge an id, then sort so faces sharing an edge
        # become neighbours in the sorted order.
        _, edge_ids = np.unique(face_edges, axis=0, return_inverse=True)
        edge_ids = edge_ids.ravel()
        order = np.argsort(edge_ids, kind="stable")
        sorted_ids = edge_ids[order]
        sorted_faces = edge_owner[order]

        # Link consecutive faces with the same edge id. The chain joins every
        # face on a shared edge into one component.
        same_edge = sorted_ids[1:] == sorted_ids[:-1]
        rows = sorted_faces[:-1][same_edge]
        cols = sorted_faces[1:][same_edge]
        graph = csr_matrix(
            (np.ones(len(rows)), (rows, cols)),
            shape=(face_count, face_count),
        )
        n_components, labels = connected_components(graph, directed=False)

        print(f"Found {n_components} connected components for tooth segment {segment_id}")

        if n_components > 1:
            component_sizes = np.bincount(labels)
            largest_component = np.argmax(component_sizes)
            tooth_triangles = tooth_triangles[labels == largest_component]

            print(f"Keeping largest component with {component_sizes[largest_component]} faces")

        # Compact the vertex array to the vertices this tooth uses and renumber
        # the faces. ``inverse`` holds each corner's new vertex index.
        vertex_indices, inverse = np.unique(tooth_triangles, return_inverse=True)
        remapped_triangles = inverse.reshape(tooth_triangles.shape)
        tooth_vertices = vertices[vertex_indices]

        tooth_mesh = o3d.geometry.TriangleMesh()
        tooth_mesh.vertices = o3d.utility.Vector3dVector(tooth_vertices)
        tooth_mesh.triangles = o3d.utility.Vector3iVector(remapped_triangles)

        extracted_teeth[segment_id] = tooth_mesh

        print(f"Extracted tooth segment {segment_id} with {len(remapped_triangles)} faces")

    print(f"Extracted {len(extracted_teeth)} teeth")
    return extracted_teeth

In [3]:
def _build_submesh(vertices, sub_triangles):
    """Return a TriangleMesh with only the vertices referenced by ``sub_triangles``, renumbered."""
    used_vertices, inverse = np.unique(sub_triangles, return_inverse=True)
    sub_mesh = o3d.geometry.TriangleMesh()
    sub_mesh.vertices = o3d.utility.Vector3dVector(vertices[used_vertices])
    sub_mesh.triangles = o3d.utility.Vector3iVector(inverse.reshape(sub_triangles.shape))
    return sub_mesh


def extract_individual_tooth_and_gum(mesh, segments, target_tooth_label, gum_label=0, min_triangles=100, gum_extension_factor=0, connectivity_threshold=99990, vertical_axis=1):
    """
    Extract a single tooth and nearby gum triangles from the mesh.

    Only triangles labelled ``target_tooth_label`` or ``gum_label`` can end up in
    the result; all other tooth labels are excluded.

    Steps:
      1. Take the tooth's triangles. Drop connected pieces with fewer than
         ``min_triangles`` triangles. If every piece is that small, keep the
         largest one.
      2. Compute the axis-aligned bounding box of the cleaned tooth and extend
         its minimum along ``vertical_axis`` by
         ``gum_extension_factor * tooth height``.
      3. Gum candidates are gum triangles with at least one vertex inside that
         box (bounds inclusive).
      4. Keep the candidates whose centroid is within
         ``connectivity_threshold * tooth_size`` of some tooth vertex.
         ``tooth_size`` is the largest bounding-box extent.
      5. Build the result mesh (tooth + kept gum) and the remaining mesh.

    Parameters
    ----------
    mesh : o3d.geometry.TriangleMesh
        The original mesh containing all teeth and gums.
    segments : np.ndarray or list
        One label per triangle of ``mesh``.
    target_tooth_label : int
        The specific tooth label to extract (e.g., 32).
    gum_label : int, optional (default=0)
        The label used for gum triangles.
    min_triangles : int, optional (default=100)
        Minimum number of triangles for a tooth component to be kept.
    gum_extension_factor : float, optional (default=0)
        Fraction of the tooth height by which the bounding box is extended
        towards lower coordinates along ``vertical_axis``. The default of 0
        means no extension.
    connectivity_threshold : float, optional (default=99990)
        Multiplier on the tooth's largest bounding-box extent that gives the
        maximum centroid-to-tooth distance for a gum triangle. The default is
        so large that, in practice, every candidate in the box is kept.
    vertical_axis : int, optional (default=1)
        Coordinate axis treated as vertical (1 = y). Note that
        ``extract_and_trim_tooth_mesh`` in this notebook trims along z.

    Returns
    -------
    extracted_result : o3d.geometry.TriangleMesh
        Mesh of the kept tooth and gum triangles, with vertex normals computed.
    remaining_mesh : o3d.geometry.TriangleMesh
        The other triangles, with unreferenced vertices removed. If nothing
        remains, it holds all original vertices and no triangles.
    result_segments : np.ndarray
        Labels of the triangles in ``extracted_result``, in the same order.

    Raises
    ------
    ValueError
        If no triangle carries ``target_tooth_label``.
    """
    segments = np.asarray(segments)
    triangles = np.asarray(mesh.triangles)
    vertices = np.asarray(mesh.vertices)

    tooth_mask = (segments == target_tooth_label)
    if not np.any(tooth_mask):
        raise ValueError(f"No triangles found for tooth label {target_tooth_label}")

    print(f"Found {np.sum(tooth_mask)} triangles for tooth label {target_tooth_label}")

    # STEP 1: clean the tooth selection by removing small connected pieces.
    original_tooth_triangle_indices = np.where(tooth_mask)[0]
    temp_tooth_mesh = _build_submesh(vertices, triangles[tooth_mask])

    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Error):
        cluster_indices = temp_tooth_mesh.cluster_connected_triangles()

    # Per-triangle cluster ids, in the same order as triangles[tooth_mask].
    component_indices = np.asarray(cluster_indices[0])
    component_sizes = np.asarray(cluster_indices[1])
    keep_triangle_mask = component_sizes[component_indices] >= min_triangles

    if not np.any(keep_triangle_mask):
        print("Warning: All tooth components are smaller than the minimum size. Keeping the largest component.")
        largest_component_idx = np.argmax(component_sizes)
        keep_triangle_mask = (component_indices == largest_component_idx)

    clean_tooth_mask = np.zeros_like(segments, dtype=bool)
    clean_tooth_mask[original_tooth_triangle_indices[keep_triangle_mask]] = True

    # STEP 2: bounding box of the cleaned tooth, extended along the vertical axis.
    clean_tooth_vertices = vertices[np.unique(triangles[clean_tooth_mask])]
    min_bound = np.min(clean_tooth_vertices, axis=0)
    max_bound = np.max(clean_tooth_vertices, axis=0)

    tooth_height = max_bound[vertical_axis] - min_bound[vertical_axis]
    extended_min_bound = min_bound.copy()
    extended_min_bound[vertical_axis] -= gum_extension_factor * tooth_height

    # STEP 3: gum triangles with at least one vertex inside the extended box.
    # extended_min_bound equals min_bound except on the vertical axis, so one
    # vectorised comparison covers all three inclusive range checks.
    gum_mask = (segments == gum_label)
    vertex_in_box = np.all((vertices >= extended_min_bound) & (vertices <= max_bound), axis=1)
    candidate_gum_mask = gum_mask & vertex_in_box[triangles].any(axis=1)

    # STEP 4: keep candidates whose centroid is close enough to the tooth.
    selected_gum_mask = np.zeros_like(segments, dtype=bool)
    if np.any(candidate_gum_mask):
        candidate_gum_indices = np.where(candidate_gum_mask)[0]
        gum_triangle_centroids = vertices[triangles[candidate_gum_indices]].mean(axis=1)

        min_distances = cdist(gum_triangle_centroids, clean_tooth_vertices).min(axis=1)

        tooth_size = np.max(max_bound - min_bound)
        distance_threshold = tooth_size * connectivity_threshold

        selected_gum_mask[candidate_gum_indices[min_distances <= distance_threshold]] = True
    else:
        print("No candidate gum triangles found")

    # STEP 5: combine tooth and selected gums into the result mesh.
    final_mask = clean_tooth_mask | selected_gum_mask

    result_mesh = _build_submesh(vertices, triangles[final_mask])
    result_mesh.compute_vertex_normals()
    result_segments = segments[final_mask]

    # STEP 6: everything not extracted.
    remaining_mesh = o3d.geometry.TriangleMesh()
    remaining_mesh.vertices = o3d.utility.Vector3dVector(vertices)
    remaining_triangles = triangles[~final_mask]

    if remaining_triangles.size == 0:
        print("Warning: No remaining triangles after extraction")
        # Pass an explicit (0, 3) array rather than relying on how an empty
        # list is converted.
        remaining_mesh.triangles = o3d.utility.Vector3iVector(np.empty((0, 3), dtype=np.int32))
    else:
        remaining_mesh.triangles = o3d.utility.Vector3iVector(remaining_triangles)
        remaining_mesh.remove_unreferenced_vertices()
        remaining_mesh.compute_vertex_normals()

    # Sanity check: only the target tooth label and the gum label are expected.
    unique_labels = np.unique(result_segments)
    print(f"Labels in result: {unique_labels}")

    unwanted_labels = unique_labels[(unique_labels != target_tooth_label) & (unique_labels != gum_label)]
    if len(unwanted_labels) > 0:
        print(f"Warning: Found unwanted labels in result: {unwanted_labels}")

    return result_mesh, remaining_mesh, result_segments

In [4]:
def extract_and_trim_tooth_mesh(mesh, keep_ratio=0.5, axis=2):
    """
    Keep only the part of ``mesh`` that lies below a height cut-off.

    NOTE(review): this definition shadows ``extract_and_trim_tooth_mesh``
    imported from ``extract_context_teeth`` in the first cell. Cells run after
    this one use this version.

    The cut-off is ``min + keep_ratio * (max - min)`` of the vertex coordinates
    along ``axis``. A vertex is kept when its coordinate is strictly below the
    cut-off. A triangle is kept only when all three of its vertices are kept.

    Parameters
    ----------
    mesh : o3d.geometry.TriangleMesh
        Mesh to trim. Only ``mesh.vertices`` and ``mesh.triangles`` are read.
    keep_ratio : float, optional (default=0.5)
        Fraction of the extent along ``axis``, measured from the minimum, to keep.
    axis : int, optional (default=2)
        Coordinate axis used as "height" (2 = z). Note that
        ``extract_individual_tooth_and_gum`` treats y (axis 1) as vertical.

    Returns
    -------
    o3d.geometry.TriangleMesh
        New mesh with unreferenced vertices removed and vertex normals computed.
        It may have no triangles if nothing falls below the cut-off.
    """
    vertices = np.asarray(mesh.vertices)
    triangles = np.asarray(mesh.triangles).reshape(-1, 3)

    heights = vertices[:, axis]
    height_min, height_max = heights.min(), heights.max()
    threshold = height_min + keep_ratio * (height_max - height_min)

    keep_vertex = heights < threshold
    kept_indices = np.flatnonzero(keep_vertex)

    # Old vertex index -> index in the trimmed mesh (-1 marks dropped vertices,
    # which never appear in a kept triangle).
    new_index = np.full(len(vertices), -1, dtype=np.int64)
    new_index[kept_indices] = np.arange(len(kept_indices))

    keep_triangle = keep_vertex[triangles].all(axis=1)
    new_triangles = new_index[triangles[keep_triangle]].reshape(-1, 3)

    trimmed_tooth = o3d.geometry.TriangleMesh()
    trimmed_tooth.vertices = o3d.utility.Vector3dVector(vertices[kept_indices])
    trimmed_tooth.triangles = o3d.utility.Vector3iVector(new_triangles)
    trimmed_tooth.remove_unreferenced_vertices()
    trimmed_tooth.compute_vertex_normals()

    return trimmed_tooth

In [5]:
# Configuration: teeth (segment labels) to export, and the one that also gets a
# trimmed "partial" version.
three_teeth_segment_ids = [31, 32, 33]
tooth_id_to_trim = 32
assert tooth_id_to_trim in three_teeth_segment_ids, "tooth_id_to_trim must be one of the exported teeth"

out_dir = "outputs/GT_individual_tooth_for_DMC"

file_name = "teeth_files/01A91JH6_lower.obj"

mesh = o3d.io.read_triangle_mesh(file_name)
# read_triangle_mesh does not raise on a missing or unreadable file; it warns
# and returns an empty mesh. Fail here with a clear message instead of
# deep inside get_response.
if not mesh.has_triangles():
    raise FileNotFoundError(f"Could not load a triangle mesh from {file_name!r}")

# Segmentation result: per-triangle labels plus auxiliary outputs.
response = get_response(mesh)
segments = response['labels']

bounding_boxes = response['bounding_boxes']
color_map = response['color_map']

[128 128 128]
triangles.shape =  (15999, 3)
(47997,)


In [None]:
# Tooth whose mesh is also saved and cut into a partial (trimmed) version.
# Taken from the config cell so the two values cannot drift apart.
trim_segment_id = tooth_id_to_trim

# Points sampled per exported cloud. NOTE(review): an earlier comment mentioned
# 8192 points for ground truth, but 1024 is what is used — confirm the intent.
NUM_POINTS = 1024
# Fraction of the tooth height kept by extract_and_trim_tooth_mesh.
TRIM_KEEP_RATIO = 0.6

# File-name parsing and the output directory do not depend on the tooth, so
# compute them once. Expected name pattern: "<patient_id>_<jaw_part>.<ext>".
base_file_name = os.path.basename(file_name)
patient_id = base_file_name.split("_")[0]
jaw_part = base_file_name.split("_")[1].split(".")[0]

patient_dir = os.path.join(out_dir, patient_id)
os.makedirs(patient_dir, exist_ok=True)


def save_points_as_pcd(points, path):
    """Write a point array to ``path`` as an Open3D point cloud; warn if the write fails."""
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    if not o3d.io.write_point_cloud(path, cloud):
        print(f"Warning: failed to write {path}")
    return cloud


for segment_id in three_teeth_segment_ids:
    extracted_tooth = extract_teeth(mesh, segments, tooth_labels=[segment_id])
    if segment_id not in extracted_tooth:
        # extract_teeth skips labels with no faces, so there is nothing to export.
        print(f"Skipping segment {segment_id}: no tooth extracted")
        continue
    final_result_mesh = extracted_tooth[segment_id]

    # Clean up the extracted surface before sampling points from it.
    final_result_mesh.remove_degenerate_triangles()
    final_result_mesh.remove_duplicated_triangles()
    final_result_mesh.remove_duplicated_vertices()
    final_result_mesh.remove_non_manifold_edges()

    final_result_npy = mesh_to_pcd(final_result_mesh, num_points=NUM_POINTS)

    OUTPUT_PATH = os.path.join(patient_dir, f"{patient_id}_fid_{segment_id}_{jaw_part}.pcd")
    pcd = save_points_as_pcd(final_result_npy, OUTPUT_PATH)

    if segment_id == trim_segment_id:
        # Keep the full mesh of this tooth, then export a trimmed, top-filled
        # partial version of it.
        o3d.io.write_triangle_mesh(os.path.join(patient_dir, f"{patient_id}_fid_{segment_id}_{jaw_part}.ply"), final_result_mesh)

        trimmed_tooth = extract_and_trim_tooth_mesh(final_result_mesh, keep_ratio=TRIM_KEEP_RATIO)
        trimmed_filled_tooth = fill_top_with_highest_point(trimmed_tooth)

        trimmed_final_result_npy = mesh_to_pcd(trimmed_filled_tooth, num_points=NUM_POINTS)

        TRIMMED_OUTPUT_PATH = os.path.join(patient_dir, f"{patient_id}_fid_{segment_id}_partial_{jaw_part}.pcd")
        pcd = save_points_as_pcd(trimmed_final_result_npy, TRIMMED_OUTPUT_PATH)
        

Found 8 connected components for tooth segment 31
Keeping largest component with 7197 faces
Extracted tooth segment 31 with 7197 faces
Extracted 1 teeth
Found 1 connected components for tooth segment 32
Extracted tooth segment 32 with 8069 faces
Extracted 1 teeth
Filling the top of the cropped tooth using the highest point...
Found 3 connected components for tooth segment 33
Keeping largest component with 7786 faces
Extracted tooth segment 33 with 7786 faces
Extracted 1 teeth
