In [2]:
import numpy as np
from sklearn.cluster import KMeans
from stl import mesh  # You may need to install the 'numpy-stl' library for STL file manipulation
from tqdm import tqdm
import math
import matplotlib.pyplot as plt
from itertools import cycle
import plotly.express as px
import pandas as pd

import matplotlib
import numpy as np
import open3d as o3d
from PIL import ImageColor
import matplotlib.pyplot as plt

from skimage.filters import gaussian,unsharp_mask,threshold_multiotsu
from skimage.io import imshow
from skimage import color
from skimage.transform import hough_circle, hough_circle_peaks
from skimage.feature import canny
from skimage.draw import circle_perimeter,disk
from skimage.util import img_as_ubyte
from skimage.exposure import rescale_intensity

from scipy.spatial.distance import pdist,squareform
from sklearn.neighbors import KNeighborsClassifier

from random import randint
matplotlib.rcParams['font.size'] = 9

# Define a set of distinct colors with strong contrast
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
          '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
          '#ff9896', '#aec7e8', '#ffbb78', '#98df8a', '#ff9896',
          '#c5b0d5', '#c49c94', '#f7b6d2', '#c7c7c7', '#dbdb8d',
          '#9edae5', '#ad494a', '#8c6d31', '#aec7e8', '#ffbb78',
          '#393b79', '#5254a3', '#6b6ecf', '#9c9ede', '#637939',
          '#8ca252', '#b5cf6b', '#cedb9c', '#8c6d31', '#bd9e39',
          '#e7ba52', '#e7cb94', '#843c39', '#ad494a', '#d6616b']
label_color_map = dict(zip(list(range(40)), cycle(colors)))


def compute_angle_with_y_axis(points):
    """
    Compute the angle between the array of points and the y-axis.
    
    Parameters:
        points (numpy.ndarray): Array of shape (n, 3) representing n points in 3D space.
        
    Returns:
        float: Angle (in radians) between the array of points and the y-axis.
    """
    # Compute the vector from the origin to the centroid of the points
    centroid = np.mean(points, axis=0)
    centroid_normalized = centroid / np.linalg.norm(centroid)
    
    # Compute the angle between the centroid vector and the y-axis
    angle = np.arccos(np.dot(centroid_normalized, [1, 0, 0]))
    return angle

class Teeth:
    def __init__(self, data, num_cluster):
        self.class_number = num_cluster
        self.data=data
    def rotate_teeth(self, theta,axis):
            # Convert angle to radians
        angle_radians = math.radians(theta)

    # Define rotation matrices for each axis
        if axis == 0:
            rotation_matrix = np.array([[1, 0, 0],
                                        [0, np.cos(theta), -np.sin(theta)],
                                        [0, np.sin(theta), np.cos(theta)]])
        elif axis == 1:
            rotation_matrix = np.array([[np.cos(theta), 0, np.sin(theta)],
                                        [0, 1, 0],
                                        [-np.sin(theta), 0, np.cos(theta)]])
        elif axis == 2:
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
                                        [np.sin(theta), np.cos(theta), 0],
                                        [0, 0, 1]])

            # Apply rotation to each point in the teeth data
        self.data[["x","y","z"]] = np.dot(self.data.values[:3], rotation_matrix.T)

    def translate_teeth(self, translation):
        col=["x","y","z"]
        for i,transl in enumerate(translation):
            self.data[[col[i]]] +=transl

    def plot(self):
        array_color=[ImageColor.getrgb(color) for color in self.data['color']]
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(np.array(self.data[['x','y','z']].values))
        pcd.colors = o3d.utility.Vector3dVector(np.array(array_color)/255)
        pcd.estimate_normals()
        o3d.visualization.draw_geometries([pcd])

In [3]:
def segmentation(df,num_teeth): 
    img=np.zeros((int(np.max(df["x"])-np.min(df["x"])+1)*10,int(np.max(df["y"])-np.min(df["y"])+1)*10))
    df_temp=df.copy()
    df_temp["x"]=df_temp["x"]-np.min(df["x"])
    df_temp["y"]=df_temp["y"]-np.min(df["y"])
    for row in df_temp.index:
        row_=df_temp.loc[row]
        img[int(row_["x"]*10),int(row_["y"]*10)]+=1
    image=rescale_intensity(1-img.T,out_range=(0,1))

    result_3 = unsharp_mask(image, radius=20, amount=10)
    good_image=rescale_intensity(gaussian(result_3,sigma=2),out_range=(0,1))
    image2=rescale_intensity(gaussian(result_3,sigma=2),out_range=(0,1))
    image2[image2>=0.9999]=0
    # Applying multi-Otsu threshold for the default value, generating
    # three classes.
    thresholds = threshold_multiotsu(image2)

    # Using the threshold values, we generate the three regions.
    regions = np.digitize(image2, bins=thresholds)
    # Load picture and detect edges
    image = (regions==1).astype(np.uint8)
    #edges = canny(image, sigma=1, low_threshold=10, high_threshold=50)


    # Detect two radii
    hough_radii = np.arange(40, 80, 2)
    hough_res = hough_circle(image, hough_radii)

    # Select the most prominent 3 circles
    accums, cx, cy, radii = hough_circle_peaks(hough_res, hough_radii,min_ydistance=50,min_xdistance=50,total_num_peaks=9)

    x_list=np.sort(np.array([cy,cx]).T,axis=0)[:,1]
    y_list=np.sort(np.array([cy,cx]).T,axis=0)[:,0]
    window_x=x_list[np.argmax(np.diff(x_list))],x_list[np.argmax(np.diff(x_list))+1]
    window_y=y_list[np.argmax(np.diff(x_list))],y_list[np.argmax(np.diff(x_list))+1]


    # Load picture and detect edges
    image = (regions[:int(1.5*max(window_y)),window_x[0]:window_x[1]]==1).astype(np.uint8)
    #edges = canny(image, sigma=1, low_threshold=10, high_threshold=50)
    # Detect two radii
    hough_radii = np.arange(25, 40, 2)
    hough_res = hough_circle(image, hough_radii)

    # Select the most prominent 3 circles
    accums, cx_2, cy_2, radii_2 = hough_circle_peaks(hough_res, hough_radii,min_ydistance=30,min_xdistance=10,total_num_peaks=7)



    dict_color=np.zeros(20,dtype=object)
    for q in range(20):
        dict_color[q]=[randint(0, 255),randint(0, 255),randint(0, 255)]

    center_y=np.concatenate([cy,cy_2])
    center_x=np.concatenate([cx,cx_2+window_x[0]])
    radii_=np.concatenate([radii,radii_2])

    disk_list=np.array([center_x,center_y,radii_]).T

    dst=squareform(pdist(disk_list[:,:2]))
    dst[np.where(dst==0)]=np.inf
    to_delete=np.unique(np.max(np.array(np.where(dst<50)).T,axis=1))
    disk_list=np.delete(disk_list,to_delete,axis=0)

    hex_=['#%02x%02x%02x' % tuple(x) for x in dict_color]

    col=np.zeros(len(df_temp))-1
    for i,disk_ in enumerate(disk_list):
        within_i=(((df_temp[["x"]].values*10-disk_[0])**2+(df_temp[["y"]].values*10-disk_[1])**2)<=1.25*disk_[2]**2).astype(int).flatten()
        col[within_i==1]=i
    colors=[hex_[int(i)] for i in col]

    df_temp["label"]=col
    df_temp["color"]=colors
    X=df_temp[df_temp["label"]!=-1][["x","y","z"]].values
    y=np.ravel(df_temp[df_temp["label"]!=-1][['label']].values)

    neigh = KNeighborsClassifier(n_neighbors=10,algorithm="kd_tree")
    neigh.fit(X, y)
    label__=neigh.predict(df_temp[df_temp["label"]==-1][['x','y','z']].values)
    df_copy=df_temp[df_temp["label"]==-1].copy()
    df_copy['label']=list(label__)
    colors=[hex_[int(i)] for i in label__]
    df_copy['color']=colors=[hex_[int(i)] for i in label__]
    df_final=pd.concat([df_temp[df_temp["label"]!=-1],df_copy])
    return df_final,image,regions,disk_list,good_image,dict_color


In [6]:
class Jaw:

    def __init__(self, stl_files, num_teeth,jaw):
        self.stl_files = stl_files
        self.num_teeth=num_teeth
        self.jaw=jaw
        self.dict_teeth_name = {
        'Central Incisor': 0,
        'Lateral Incisor': 1,
        'Canine (Cuspid)': 2,
        'First Premolar (Bicuspid)': 3,
        'Second Premolar (Bicuspid)': 4,
        'First Molar': 5,
        'Second Molar': 6,
        'Third Molar (Wisdom Tooth)': 7,
        'Upper Right Central Incisor': 8,
        'Upper Right Lateral Incisor': 9,
        'Upper Right Canine (Cuspid)': 10,
        'Upper Right First Premolar (Bicuspid)': 11,
        'Upper Right Second Premolar (Bicuspid)': 12,
        'Upper Right First Molar': 13,
        'Upper Right Second Molar': 14,
        'Upper Right Third Molar (Wisdom Tooth)': 15,
        'Upper Left Central Incisor': 16,
        'Upper Left Lateral Incisor': 17,
        'Upper Left Canine (Cuspid)': 18,
        'Upper Left First Premolar (Bicuspid)': 19,
        'Upper Left Second Premolar (Bicuspid)': 20,
        'Upper Left First Molar': 21,
        'Upper Left Second Molar': 22,
        'Upper Left Third Molar (Wisdom Tooth)': 23,
        'Lower Right Central Incisor': 24,
        'Lower Right Lateral Incisor': 25,
        'Lower Right Canine (Cuspid)': 26,
        'Lower Right First Premolar (Bicuspid)': 27,
        'Lower Right Second Premolar (Bicuspid)': 28,
        'Lower Right First Molar': 29,
        'Lower Right Second Molar': 30,
        'Lower Right Third Molar (Wisdom Tooth)': 31,
        'Lower Left Central Incisor': 32,
        'Lower Left Lateral Incisor': 33,
        'Lower Left Canine (Cuspid)': 34,
        'Lower Left First Premolar (Bicuspid)': 35,
        'Lower Left Second Premolar (Bicuspid)': 36,
        'Lower Left First Molar': 37,
        'Lower Left Second Molar': 38,
        'Lower Left Third Molar (Wisdom Tooth)': 39
    }

        self.dict_teeth={"l":{0:32,1:31,2:30,3:29,4:28,5:27,6:26,7:25,8:24,9:23,10:22,11:21,12:20,13:19,14:18,15:17,},
                         "u":{0:1,1:2,2:3,3:4,4:5,5:6,6:7,7:8,8:9,9:10,10:11,11:12,12:13,13:14,14:15,15:16}}
        self.teeth_list = self.load_teeth()
        
    def cluster_teeth(self,df,num_teeth):
        return segmentation(df,num_teeth)
    
    def get_correct_labels(self,cluster_centers):
        return {i:np.argsort(np.argsort(cluster_centers,axis=0)[:,0])[i] for i in range(len(cluster_centers))}

    def getTeethByIdx(self,idx):
        return self.teeth_list[idx]
    
    def getTeethByName(self,label):
        dentist_nb=self.dict_teeth_name[label]
        idx=np.where(np.array(list(self.dict_teeth[self.jaw].values()))==dentist_nb)[0][0]
        return self.getTeethByIdx(idx)

    def load_teeth(self):
        teeth_list = []
        for i,stl_file in  enumerate(self.stl_files):
            mesh_data = mesh.Mesh.from_file(stl_file).vectors.reshape(-1, 3)
            df=pd.DataFrame(mesh_data,columns=['x','y','z'])
            df_final,image,regions,disk_list,good_image,dict_color=self.cluster_teeth(df,self.num_teeth[i])
            self.data_jaw=df_final
            label_dict=self.get_correct_labels(disk_list[:,:2])
            # Print or use the clustered data
            self.num_teeth[i]=len(np.unique(df_final[['label']]))
            for label in np.unique(df_final[['label']]):
                df_temp=df_final[df_final['label']==label]
                teeth_instance = Teeth(df_temp, self.dict_teeth[self.jaw][label_dict[label]])
                teeth_list.append(teeth_instance)
        return teeth_list
    
          
    def compute_distances_and_depths(self,distances=[]):
        # Assuming specific indices for the canine, second premolar, and first molar
        # Replace these indices with the actual indices of your teeth clusters
        res=[]
        for elements in distances:
            if type(elements)==list:
                teeth_1=self.getTeethByName(elements[0])
                teeth_2=self.getTeethByName(elements[1])
                res.append(np.linalg.norm(teeth_1.data - teeth_2.data))
            elif type(elements)==str:
                if elements.__contains__('canine') or elements.__contains__('premolar') or elements.__contains__('first_molar'):
                    teeth_1=self.getTeethByName(elements[0])
                    res.append(np.max(teeth_1.data[:,2]) - np.min(teeth_1.data[:,2]))
        return res
    
    def print_summary(self,detailed=False):
        print("Jaw Summary:")
        print("-" * 40)
        print(f"Number of Teeth: {len(self.teeth_list)}")
        print()
        if detailed:
            for i, teeth_instance in enumerate(self.teeth_list):
                print(f"Teeth {i + 1} Summary:")
                print(f"  Class Number: {teeth_instance.class_number}")
                print(f"  Data Shape: {teeth_instance.data.shape}")
                print()
        else:
            avg=0
            for i, teeth_instance in enumerate(self.teeth_list):
                avg+=teeth_instance.data.shape[0]
            print(f"  Avg size of teeth: {avg/np.mean(self.num_teeth)}")
            print()
            
    def align_z(self):
        avg_z=[]
        for i in range(len(self)):
            avg_z.append(np.mean(self.getTeethByIdx(i).data.values[:,2]))
        avg_z=np.mean(avg_z)
        for i in range(len(self)):
            self.getTeethByIdx(i).data.loc[:, ('z',)]+=avg_z-np.mean(self.getTeethByIdx(i).data["z"])
    def correct_rotation(self):
        # Compute the current angle between the points and the y-axis
        for i in range(len(self)):
            teeth=self.getTeethByIdx(i)
            angle = compute_angle_with_y_axis(teeth.data)
            teeth.rotate_teeth(angle,1)

    def plot(self):
        array_color=[ImageColor.getrgb(color) for color in self.data_jaw['color']]
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(np.array(self.data_jaw[['x','y','z']].values))
        pcd.colors = o3d.utility.Vector3dVector(np.array(array_color)/255)
        pcd.estimate_normals()
        o3d.visualization.draw_geometries([pcd])

    def __len__(self):
        return self.num_teeth[0]
# Example Usage:
# Provide a list of paths to your STL files and the desired number of clusters
#stl_files = [r"C:\Users\Emman\Desktop\JE\Ortho\STL-Segmentation\OrthoCAD_Export_43495989\43495989_shell_occlusion_l.stl"]
stl_files=['/Users/ecordina/Desktop/je/STL-Segmentation/OrthoCAD_Export_43495989/43495989_shell_occlusion_l.stl']
#r"C:\Users\Emman\Desktop\JE\Ortho\STL-Segmentation\OrthoCAD_Export_43495989\43495989_shell_occlusion_u.stl"]
num_clusters = [16,16]

# Create Jaw object
jaw_instance = Jaw(stl_files, num_clusters,'l')
jaw_instance.print_summary()

Jaw Summary:
----------------------------------------
Number of Teeth: 13

  Avg size of teeth: 39285.10344827586



In [8]:
jaw_instance.plot()

In [16]:
for i in range(11):
    print(np.percentile(jaw_instance.getTeethByIdx(i).data['z'],10))

28.723033905029297
29.41584701538086
28.84005355834961
28.12866973876953
28.146963119506836
28.647314071655273
29.762094497680664
28.274660110473633
29.109350204467773
28.112308502197266
27.836198806762695


In [98]:
# fig = px.scatter_3d(
#     df_final, x="x", y="y", z="z", labels="label",color='color',color_discrete_map="identity"
# )
# fig.update_layout(scene_aspectmode='manual',
#         scene_aspectratio=dict(x=1, y=1, z=1))
# fig.show()

In [99]:
# img=np.zeros((int(np.max(df["x"])-np.min(df["x"])+1)*10,int(np.max(df["y"])-np.min(df["y"])+1)*10,int(np.max(df["z"])-np.min(df["z"])+1)*10))
# df_temp=df.copy()
# df_temp["x"]=df_temp["x"]-np.min(df["x"])
# df_temp["y"]=df_temp["y"]-np.min(df["y"])
# df_temp["z"]=df_temp["z"]-np.min(df["z"])
# for row in df_temp.index:
#     row_=df_temp.loc[row]
#     img[int(row_["x"]*10),int(row_["y"]*10),int(row_["z"]*10)]+=1