In [17]:
import trimesh
import numpy as np
import copy
import os
from scipy.spatial import KDTree

import torch
import pyfqmr
import k3d
from pysdf import SDF
from ML.models.dilated_tooth_seg_network import LitDilatedToothSegmentationNetwork
from sklearn.decomposition import PCA
import random
from ML.teeth_numbering import color_mesh,_teeth_labels,_teeth_color
import igl
from lightning.pytorch import seed_everything
import fast_simplification 

# DEFs

In [18]:
def ML_extract_teeth(mesh,model):
    def process_mesh(mesh: trimesh, labels: torch.tensor = None):
        mesh_faces = torch.from_numpy(mesh.faces.copy()).float()
        mesh_triangles = torch.from_numpy(mesh.vertices[mesh.faces]).float()
        mesh_face_normals = torch.from_numpy(mesh.face_normals.copy()).float()
        mesh_vertices_normals = torch.from_numpy(mesh.vertex_normals[mesh.faces]).float()
        return mesh_faces, mesh_triangles, mesh_vertices_normals, mesh_face_normals, labels
    
    def preporces(data):
        mesh_faces, mesh_triangles, mesh_vertices_normals, mesh_face_normals, labels = data
        mesh = trimesh.Trimesh(**trimesh.triangles.to_kwargs(mesh_triangles.cpu().detach().numpy()))
    
        points = torch.from_numpy(mesh.vertices)
        v_normals = torch.from_numpy(mesh.vertex_normals)
    
        s, _ = mesh_faces.size()
        x = torch.zeros(s, 24).float()
        x[:, :3] = mesh_triangles[:, 0]
        x[:, 3:6] = mesh_triangles[:, 1]
        x[:, 6:9] = mesh_triangles[:, 2]
        x[:, 9:12] = mesh_triangles.mean(dim=1)
        x[:, 12:15] = mesh_vertices_normals[:, 0]
        x[:, 15:18] = mesh_vertices_normals[:, 1]
        x[:, 18:21] = mesh_vertices_normals[:, 2]
        x[:, 21:] = mesh_face_normals
    
        maxs = points.max(dim=0)[0]
        mins = points.min(dim=0)[0]
        means = points.mean(axis=0)
        stds = points.std(axis=0)
        nmeans = v_normals.mean(axis=0)
        nstds = v_normals.std(axis=0)
        nmeans_f = mesh_face_normals.mean(axis=0)
        nstds_f = mesh_face_normals.std(axis=0)
        for i in range(3):
            # normalize coordinate
            x[:, i] = (x[:, i] - means[i]) / stds[i]  # point 1
            x[:, i + 3] = (x[:, i + 3] - means[i]) / stds[i]  # point 2
            x[:, i + 6] = (x[:, i + 6] - means[i]) / stds[i]  # point 3
            x[:, i + 9] = (x[:, i + 9] - mins[i]) / (maxs[i] - mins[i])  # centre
            # normalize normal vector
            x[:, i + 12] = (x[:, i + 12] - nmeans[i]) / nstds[i]  # normal1
            x[:, i + 15] = (x[:, i + 15] - nmeans[i]) / nstds[i]  # normal2
            x[:, i + 18] = (x[:, i + 18] - nmeans[i]) / nstds[i]  # normal3
            x[:, i + 21] = (x[:, i + 21] - nmeans_f[i]) / nstds_f[i]  # face normal
    
        pos = x[:, 9:12]
    
        return pos, x, labels
    
    def PostProces(data_OG_def,x_def):
        _, mesh_triangles, _, mesh_face_normals, _ = data_OG_def
        mesh = trimesh.Trimesh(**trimesh.triangles.to_kwargs(mesh_triangles.cpu().detach().numpy()))
           
        maxs = mesh.vertices.max(axis=0)
        mins =  mesh.vertices.min(axis=0)
        means =  mesh.vertices.mean(axis=0)
        stds =  mesh.vertices.std(axis=0)
        nmeans = mesh.vertex_normals.mean(axis=0)
        nstds = mesh.vertex_normals.std(axis=0)
        nmeans_f = mesh_face_normals.mean(axis=0)
        nstds_f = mesh_face_normals.std(axis=0)
        for i in range(3):
            #  coordinate
            x_def[:, i] = (x_def[:, i] + means[i]) * stds[i]  # point 1
            x_def[:, i + 3] = (x_def[:, i + 3] + means[i]) * stds[i]  # point 2
            x_def[:, i + 6] = (x_def[:, i + 6] + means[i]) * stds[i]  # point 3
            x_def[:, i + 9] = (x_def[:, i + 9] + mins[i]) * (maxs[i] - mins[i])  # centre
            #  normal vector
            x_def[:, i + 12] = (x_def[:, i + 12] + nmeans[i]) * nstds[i]  # normal1
            x_def[:, i + 15] = (x_def[:, i + 15] + nmeans[i]) * nstds[i]  # normal2
            x_def[:, i + 18] = (x_def[:, i + 18] + nmeans[i]) * nstds[i]  # normal3
            x_def[:, i + 21] = (x_def[:, i + 21] + nmeans_f[i]) * nstds_f[i]  # face normal
        return x_def
    
    def get_pca_rotation(mesh):
        vertices = mesh.vertices
        pca = PCA(n_components=3)
        pca.fit(vertices)
        rotation_matrix = pca.components_
        if np.linalg.det(rotation_matrix) < 0:
            rotation_matrix[:, 2] = -rotation_matrix[:, 2]  # Adjusting the z-axis
        centroid = np.mean(vertices, axis=0)
        return rotation_matrix, centroid
    
    def align_meshes(mesh,src_rot_matrix,src_centroid,tgt_rot_matrix,tgt_centroid):
        # Compute the transformation matrix to align source to target
        transformation_matrix = src_rot_matrix.T @ tgt_rot_matrix
        # Apply  rotation
        rotated_vertices = mesh.vertices.dot(transformation_matrix)
        rotated_mesh = trimesh.Trimesh(vertices=rotated_vertices, faces=mesh.faces)
    
        # Translate source centroid to target centroid
        translation_vector = tgt_centroid - np.mean(rotated_mesh.vertices, axis=0)
        rotated_mesh.vertices += translation_vector
        return rotated_mesh
    
    def Downsample(mesh): 
        points_out, faces_out, collapses = fast_simplification.simplify(mesh.vertices, mesh.faces,(1-16000/mesh.faces.shape[0]) , return_collapses=True)
        points_out, faces_out, indice_mapping = fast_simplification.replay_simplification(mesh.vertices.astype('float32') , mesh.faces.astype('float32') , collapses.astype('int32'))
        mesh_simple = trimesh.Trimesh(vertices=points_out, faces=faces_out)
        
        vertices = mesh_simple.vertices
        faces = mesh_simple.faces
        if faces.shape[0] < 16000:
            fs_diff = 16000 - faces.shape[0]
            faces = np.append(faces, np.zeros((fs_diff, 3), dtype="int"), 0)
        elif faces.shape[0] > 16000:
            mesh_simple = trimesh.Trimesh(vertices=vertices, faces=faces)
            samples, face_index = trimesh.sample.sample_surface_even(mesh_simple, 16000)
            mesh_simple = trimesh.Trimesh(vertices=mesh_simple.vertices, faces=mesh_simple.faces[face_index])
            faces = mesh_simple.faces
            vertices = mesh_simple.vertices
        mesh_simple = trimesh.Trimesh(vertices=vertices, faces=faces)
        labels=np.zeros(faces.shape[0])
        return mesh_simple,labels,indice_mapping
    
    def Erosion(Extr_part_def,how_much_to_reduce=1):
         for lo_none in range(how_much_to_reduce): # reguliuoti kiek pasalinama
             Face_indices=np.arange(Extr_part_def.faces.shape[0])
             Boundary_edges_segment=Extr_part_def.edges[trimesh.grouping.group_rows(Extr_part_def.edges_sorted, require_count=1)]
             Boundary_edges_segment=np.unique(Boundary_edges_segment.flatten())
             mask=np.unique(Extr_part_def.vertex_faces[Boundary_edges_segment,:])[1:]     
             Extr_part_def.update_faces(Face_indices[~np.isin(Face_indices,mask)])
             Extr_part_def.remove_unreferenced_vertices()
           
             Extr_part_splited=Extr_part_def.split(only_watertight=False)
             Extr_part_def_number=np.argmax([i.area for i in Extr_part_splited])
             
             Extr_part_splited=Extr_part_splited[Extr_part_def_number]
            
         return Extr_part_splited   
      
    def Corect_plane(mesh_simple,src_rot_matrix, src_centroid): 
    # Ideja: rasti okliucija apibreziancias dantu virsuniu noramaliu vidurkiu.
    # Jei normales elemetai visi taigiami-kaip ir didzioji dalis dantu lanku duonbazeje tada nenaudojam flipinimo. 
         mesh_simple_first= align_meshes(mesh_simple,src_rot_matrix, src_centroid,tgt_rot_matrix,tgt_centroid)
         
         #_, _, max_pv_15, min_pv_15 = igl.principal_curvature(mesh_simple_first.vertices, mesh_simple_first.faces, radius=5)
         mean_curv=trimesh.curvature.discrete_mean_curvature_measure(mesh_simple, mesh_simple.vertices, 0.01)
         disc_means=np.tanh(mean_curv*np.arctanh((2**8 - 2) / (2**8 - 1)))
         return np.all(mesh_simple_first.vertex_normals[np.where(disc_means>0.7)].mean(axis=0)>0)
    
    #-----------Stable rotaion values for aligning to dataset -------------------
    tgt_rot_matrix =np.array([[ 0.99481732,  0.08303923, -0.05867689],
                              [-0.0920596 ,  0.98060088, -0.17305184],
                              [ 0.04316852,  0.17755674,  0.9831633 ]])
    
    tgt_centroid=np.array([[2.03561511,  -0.65064242, -90.05015842]])
   
    #----Import model
    mesh_OG=copy.copy(mesh)
    #----Downsample
    mesh_simple,labels,corespondance=Downsample(mesh)
    
    #----Correct aligment
    src_rot_matrix, src_centroid = get_pca_rotation(mesh_simple)
    Is_not_fliped=Corect_plane(mesh_simple,src_rot_matrix, src_centroid)
    
    if Is_not_fliped:
        mesh_simple= align_meshes(mesh_simple,src_rot_matrix, src_centroid,tgt_rot_matrix,tgt_centroid)
    else:
        flip_y_matrix = np.array([[-1, 0, 0],[0, 1, 0],[0, 0, -1]])
        tgt_rot_matrix = np.dot(tgt_rot_matrix, flip_y_matrix)
        mesh_simple= align_meshes(mesh_simple,src_rot_matrix, src_centroid,tgt_rot_matrix,tgt_centroid)
    
    #----Preporcess
    data = process_mesh(mesh_simple, torch.from_numpy(labels).long())
    data_OG=copy.copy(data)
    
    data =preporces(data)
    
    #----Use model
    # ground_truth = data[2]
    pre_labels = model.predict_labels(data).cpu().numpy()
    
    x=PostProces(data_OG,data[1]) # Postprocess
    
    triangles = x[:, :9].reshape(-1, 3, 3)
    mesh = trimesh.Trimesh(**trimesh.triangles.to_kwargs(triangles.cpu().detach().numpy()))
    mesh= align_meshes(mesh,tgt_rot_matrix,tgt_centroid,src_rot_matrix, src_centroid) # Back to original
    mesh_simple=align_meshes(mesh_simple,tgt_rot_matrix,tgt_centroid,src_rot_matrix, src_centroid) 
    
    mesh_pred = color_mesh(mesh, pre_labels)
    
    #----Extract teeth from original mesh
    # """
    test_mesh_predicted=copy.copy(mesh_pred)
    Store_segments=[]
    # Which_teeth_store=[]
    for i in range(len(np.unique(pre_labels))):
        first_unique_color = np.unique(test_mesh_predicted.visual.face_colors, axis=0)[i]
        matches_first_color = np.where(np.all(test_mesh_predicted.visual.face_colors == first_unique_color, axis=1))[0]
        Vert_decimated=np.unique(test_mesh_predicted.faces[matches_first_color])
        
        tree = KDTree(mesh_simple.vertices)
        Vert_decimated_true=[]
        for ids in Vert_decimated:
            _, indices = tree.query(test_mesh_predicted.vertices[ids], k=1)
            Vert_decimated_true.append(indices)
     
        Where_are_they=np.where(np.isin(corespondance,Vert_decimated_true))[0]
        
        mask = np.isin(mesh_OG.faces, Where_are_they)
        faces_with_vertices_mask = mask.any(axis=1)
        unique_faces = np.where(faces_with_vertices_mask)[0]
        unique_faces=np.unique(np.hstack(unique_faces))
         
        Extr_part=mesh_OG.submesh([unique_faces])[0] # create submesh
    
        #-------------REMOVE some triangles (erosion)
        Which_teeth=[i for i in np.unique(pre_labels) if np.all(first_unique_color[:3]==_teeth_color[i])][0]
            
        if Which_teeth!=0:
            Extr_part=Erosion(Extr_part)
        
        Extr_part.metadata['file_name']=_teeth_labels[Which_teeth] # changing names
        Extr_part.metadata['color']=np.array(_teeth_color[Which_teeth])/255
        
        Store_segments.append(Extr_part)
     
    return Store_segments

def plane_fit(X): # plane construct
    p = np.mean(X, axis=0) # point
    R = X - p
    V, D, _ = np.linalg.svd(R.T)
    n = V[:, -1]# plane normal
    V = V[:, :-1] # principal directio
    return n,p

def project_into_plane(Normal, Point, New_koord_vektor): # plane normal and point, New_koord_vektor point for projection
    v = New_koord_vektor - Point
    dist = np.dot(v, Normal[:, np.newaxis]).flatten()
    projected_points = New_koord_vektor - dist[:, np.newaxis] * Normal
    return projected_points

# Parameters to adjust

In [19]:
# Change contact range in your desire
Contac_parameter={
    "Contact_detected":0.04,  # contact detected
    "Good_contact":-0.015, 
    "Bad_contact":-0.05}    

# Change directory names
Dir_UPPER='UPPER_dir' # directory contains maxilla models. File names should follow a numbering format such as 1.stl, 2.stl, etc.
Dir_LOWER='LOWER_dir' # The directory contains mandibula models. File names should follow a numbering format such as 1.stl, 2.stl, etc.

#----ML----
SEED = 42
use_gpu=True
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)
torch.set_float32_matmul_precision('medium')
random.seed(SEED)
seed_everything(SEED, workers=True)
# Model
model = LitDilatedToothSegmentationNetwork.load_from_checkpoint('ML/ML_param.ckpt')
if use_gpu==True:
   model = model.cuda()

Seed set to 42


# MAIN

In [20]:
Dir_UPPER_list=os.listdir(Dir_UPPER)
Dir_LOWER_list=os.listdir(Dir_LOWER)

for upper,lower in zip(Dir_UPPER_list,Dir_LOWER_list):
    Up=trimesh.load(os.path.join(Dir_UPPER,upper)) 
    Lo=trimesh.load(os.path.join(Dir_LOWER,lower))

    Segmented_parts=ML_extract_teeth(Up,model)
    
    Lo_OG=copy.copy(Lo)
    Up_OG=copy.copy(Up)
    
    #---- Initial segmentation
    SDsDF_uper, _, _ = igl.signed_distance(Up.vertices.astype(float), Lo.vertices.astype(float), Lo.faces.astype(int))
    SDsDF_uper = np.where(SDsDF_uper < Contac_parameter["Contact_detected"])

    #----Sample initial
    Up.update_faces(np.unique(Up.vertex_faces[SDsDF_uper[0]])[1:])
    Up.remove_unreferenced_vertices()
    
    initial_sample=abs(Contac_parameter["Contact_detected"])/2
    select_sample=(initial_sample)**2
    Sample_size=int(Up.area/(select_sample))
    
    Up_sample,face_index=trimesh.sample.sample_surface_even(Up,Sample_size)
    Up_sampes_normals=Up.face_normals[face_index]
    
    #----Refine found samples
    SDsDF_uper, _, _ = igl.signed_distance(Up_sample.astype(float), Lo.vertices.astype(float), Lo.faces.astype(int))
    Up_sample=Up_sample[np.where(SDsDF_uper <= Contac_parameter["Contact_detected"]),:][0]
    Up_sampes_normals=Up_sampes_normals[np.where(SDsDF_uper <= Contac_parameter["Contact_detected"]),:][0]
    
    SDsDF_uper=SDsDF_uper[np.where(SDsDF_uper <= Contac_parameter["Contact_detected"])]
    
    SDsDF_uper_OG=copy.deepcopy(SDsDF_uper)
    
    SDsDF_uper_colors = np.zeros(SDsDF_uper.shape[0], dtype=np.int8)
    SDsDF_uper_colors[(SDsDF_uper >= Contac_parameter["Bad_contact"]) & (SDsDF_uper <  Contac_parameter["Good_contact"])] = 1  
    SDsDF_uper_colors[SDsDF_uper >= Contac_parameter["Good_contact"]] = 2   
    
    # ---- Stats
    Area_all=((Up.area*Up_sample.shape[0])/Sample_size).round(0)
    Threshold_Good_contac=((Up.area*np.where(SDsDF_uper > Contac_parameter["Good_contact"])[0].shape[0])/Sample_size).round(0)
    Threshold_Bad_contac=((Up.area*np.where(SDsDF_uper < Contac_parameter["Bad_contact"])[0].shape[0])/Sample_size).round(0)
    
    # colors
    palette = {0: 0xFF0000, 1: 0xFFA500, 2: 0x00FF00}
    
    labels = SDsDF_uper_colors.astype(int)
    colors = np.array([palette[val] for val in labels], dtype=np.uint32)

    print('----'+upper[0].split('.')[0]+' file ----')
    print(f"All contact area: {Area_all} mm²")
    print(f"Good contact (green): {Threshold_Good_contac} mm²")
    print(f"Bad contact (red): {Threshold_Bad_contac} mm²")
    print(f"Intermediate contact (orange): {(Area_all-Threshold_Good_contac-Threshold_Bad_contac).round(0)} mm²")
    
    # --- Plot results---
    plot = k3d.plot(grid_visible=False, background_color=0xFFFFFF)
    k3d_mesh = k3d.mesh(Up_OG.vertices, Up_OG.faces, color=0xCCCCCC, wireframe=False, opacity=1.0, flat_shading=False,name='Main model')
    k3d_points = k3d.points(Up_sample, point_size=0.05, colors=colors)
    plot += k3d_mesh
    plot += k3d_points
    
    NO_model_contact_list=[]
    for ids,ire in enumerate(Segmented_parts):
            f=SDF(ire.vertices,ire.faces)
            if np.where(f(Up_sample) > 10e-09)[0].size == 0 and (ire.metadata['file_name'] != 'gum') :
              NO_model_contact_list.append(ire)
    
    NO_model_contact=trimesh.util.concatenate(NO_model_contact_list)
    k3d_mesh_combine = k3d.mesh(NO_model_contact.vertices, NO_model_contact.faces, color=0x717171, wireframe=False, opacity=1.0, flat_shading=False,name="Teeth with no occlusion")
    plot += k3d_mesh_combine
     
    plot.display()

  return np.all(mesh_simple_first.vertex_normals[np.where(disc_means>0.7)].mean(axis=0)>0)
  ret = um.true_divide(


----0 file ----
All contact area: 187.0 mm²
Good contact (green): 33.0 mm²
Bad contact (red): 123.0 mm²
Intermediate contact (orange): 31.0 mm²




Output()