In [None]:
import trimesh
import numpy as np
import torch
import open3d as o3d
import json
import torch.nn as nn 
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity
# mesh = trimesh.load("/home/shirshak/STL_client_data/11/2024-07-05_00003-InoueYui-11Cr-11-crown_cad.stl")
# mesh.show()
import glob 
from tqdm import tqdm 
import time
import os 
from pathlib import Path
from sklearn.decomposition import PCA
import re

# Utility Functions

In [None]:
def load_point_cloud(file_path, num_points=2048):
    """Sample a fixed-size point cloud from the surface of a mesh file.

    Args:
        file_path: Path of a triangle mesh readable by Open3D.
        num_points: Number of surface points to sample.

    Returns:
        A float32 ``torch.Tensor`` holding the sampled point coordinates,
        one point per row.
    """
    tri_mesh = o3d.io.read_triangle_mesh(file_path)
    tri_mesh.compute_vertex_normals()

    # Open3D's global RNG is re-seeded on every call so that sampling the
    # same mesh twice gives the same points.
    o3d.utility.random.seed(12345)
    sampled_cloud = tri_mesh.sample_points_uniformly(number_of_points=num_points)

    return torch.from_numpy(np.asarray(sampled_cloud.points, dtype=np.float32))


def mesh_to_voxel_grid(mesh, voxel_size=2):
    """Voxelise an Open3D triangle mesh.

    Args:
        mesh: Open3D triangle mesh to voxelise.
        voxel_size: Edge length of one voxel, in the mesh's own units.

    Returns:
        The ``o3d.geometry.VoxelGrid`` built from ``mesh``.
    """
    voxelise = o3d.geometry.VoxelGrid.create_from_triangle_mesh
    return voxelise(mesh, voxel_size=voxel_size)

def load_voxel_grid(file_path, voxel_size=2):
    """Read a mesh file and return its voxel grid.

    Args:
        file_path: Path of a triangle mesh readable by Open3D.
        voxel_size: Edge length of one voxel, forwarded to
            ``mesh_to_voxel_grid``.

    Returns:
        The voxel grid produced by ``mesh_to_voxel_grid``.
    """
    loaded_mesh = o3d.io.read_triangle_mesh(file_path)
    voxel_grid = mesh_to_voxel_grid(loaded_mesh, voxel_size)
    return voxel_grid

def load_data_from_json(filename):
    """Load a JSON list of records and turn each feature vector into an array.

    Args:
        filename: Path of a JSON file containing a list of dicts, each with
            a ``'feature_vector'`` entry.

    Returns:
        The decoded list, with every ``'feature_vector'`` replaced by a
        ``numpy.ndarray``.
    """
    with open(filename, 'r') as handle:
        records = json.load(handle)

    for record in records:
        as_array = np.array(record['feature_vector'])
        record['feature_vector'] = as_array

    return records

def compute_voxel_dice_score(voxel_grid1, voxel_grid2):
    """Return the Dice coefficient of the occupied cells of two voxel grids.

    Dice = 2 * |A ∩ B| / (|A| + |B|), where A and B are the sets of
    ``grid_index`` tuples of the voxels in each grid. When both grids are
    empty the score is 0.

    NOTE(review): cells are compared by ``grid_index`` only; this presumably
    assumes both grids share the same origin and voxel size -- confirm.
    """
    occupied_a = {tuple(v.grid_index) for v in voxel_grid1.get_voxels()}
    occupied_b = {tuple(v.grid_index) for v in voxel_grid2.get_voxels()}

    total_cells = len(occupied_a) + len(occupied_b)
    if not total_cells > 0:
        return 0

    shared_cells = occupied_a & occupied_b
    return (2 * len(shared_cells)) / total_cells

def get_similar_teeth_paths(pid, fid_and_similarity_score, base_path=""):
    """Build the OBJ paths of the left, right and opposite matched teeth.

    Args:
        pid: Identifier of the matched jaw, used as the file-name prefix.
        fid_and_similarity_score: Sequence whose first three entries are
            ``[fid, score]`` pairs (fid looks like ``"fid47"``), in
            left / right / opposite order.
        base_path: Directory that holds the individual tooth meshes.

    Returns:
        A list of three paths ``{base_path}/{pid}_{upper|lower}_{fid}.obj``.
        A tooth number below 30 is treated as upper jaw, otherwise lower.
    """
    similar_teeth_paths = []
    for position in range(3):  # left, right, opposite
        fid = fid_and_similarity_score[position][0]
        tooth_number = int(fid.split("fid")[-1])
        jaw = "upper" if tooth_number < 30 else "lower"
        similar_teeth_paths.append(f"{base_path}/{pid}_{jaw}_{fid}.obj")

    return similar_teeth_paths

# DGCNN model 

In [None]:
def intermediate(x, xx):
    """Return ``-xx - (-2 * x^T x)``, a partial term of the kNN distance.

    Args:
        x: Tensor whose last two axes are transposed and multiplied with
            ``x`` itself (``x.transpose(2, 1) @ x``).
        xx: Tensor subtracted (negated) from the scaled product; ``knn``
            passes the per-point squared norms.
    """
    gram = torch.matmul(x.transpose(2, 1), x)
    scaled = -2 * gram
    # Release cached CUDA blocks before the caller allocates more.
    torch.cuda.empty_cache()
    return -xx - scaled

def knn(x, k):
    """Return the indices of the ``k`` nearest neighbours of every point.

    The input is cast to float16 before the pairwise term is computed, and
    the ``k`` largest values of the negated squared distance are selected
    along the last axis.

    NOTE(review): the float16 cast limits range and precision; it is kept
    as-is because changing it would presumably change the neighbours (and
    therefore the embeddings) -- confirm before touching.

    Args:
        x: Tensor treated as (batch, channels, points) by the reductions
            below.
        k: Number of neighbours to return per point.

    Returns:
        Index tensor produced by ``topk`` over the last axis.
    """
    half_x = x.to(torch.float16)
    sq_norms = torch.sum(half_x**2, dim=1, keepdim=True)
    neg_sq_dist = intermediate(half_x, sq_norms) - sq_norms.transpose(2, 1)
    torch.cuda.empty_cache()
    return neg_sq_dist.topk(k=k, dim=-1)[1]

def get_graph_feature(x, device, k=20, idx=None, dim9=False):
    """Build edge features ``(neighbour - centre, centre)`` for every point.

    Args:
        x: Tensor viewed as (batch, channels, points).
        device: Device on which the per-batch index offsets are created.
        k: Number of neighbours per point.
        idx: Optional precomputed neighbour indices of shape
            (batch, points, k); computed with ``knn`` when ``None``.
        dim9: When not ``False``, neighbours are searched on channels 6 and
            up only.

    Returns:
        Contiguous tensor of shape (batch, 2 * channels, points, k).
    """
    n_batch, n_points = x.size(0), x.size(2)
    x = x.view(n_batch, -1, n_points)

    if idx is None:
        knn_input = x if dim9 is False else x[:, 6:]
        idx = knn(knn_input, k=k)

    # Offset each batch's indices so they address the flattened point list.
    batch_offsets = torch.arange(0, n_batch, device=device).view(-1, 1, 1)*n_points
    flat_idx = (idx + batch_offsets).view(-1)

    n_channels = x.size(1)
    points = x.transpose(2, 1).contiguous()

    neighbours = points.view(n_batch*n_points, -1)[flat_idx, :]
    neighbours = neighbours.view(n_batch, n_points, k, n_channels)
    centres = points.view(n_batch, n_points, 1, n_channels).repeat(1, 1, k, 1)

    edge_features = torch.cat((neighbours - centres, centres), dim=3)
    return edge_features.permute(0, 3, 1, 2).contiguous()

class DGCNN(nn.Module):
    """Dynamic-graph CNN used here as a point-cloud feature extractor.

    Four edge-convolution stages (each preceded by ``get_graph_feature``)
    are concatenated, projected to ``emb_dims`` channels, pooled with max
    and average pooling, and passed through two fully connected layers.
    ``forward`` returns the 256-dimensional activation after ``dp2``; the
    final ``linear3`` layer is defined but its call is commented out.
    """

    def __init__(self, device, output_channels=16,input_dims=3, k =20, emb_dims = 1024, dropout= 0.5):
        """Create the layers.

        Args:
            device: Device forwarded to ``get_graph_feature``.
            output_channels: Output size of ``linear3`` (not used by
                ``forward`` as written).
            input_dims: Number of channels per input point.
            k: Number of neighbours used by every graph-feature stage.
            emb_dims: Channel count produced by ``conv5``.
            dropout: Dropout probability of ``dp1`` and ``dp2``.
        """
        super(DGCNN, self).__init__()
        self.device = device
        self.input_dims = input_dims
        self.k = k
        self.emb_dims = emb_dims
        self.dropout = dropout
        # The batch-norm layers are registered both as direct attributes and
        # inside the nn.Sequential blocks below, so each appears under two
        # names in the module's parameters.
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(self.emb_dims)

        # Input channels are doubled because get_graph_feature concatenates
        # (neighbour - centre) with the centre features.
        self.conv1 = nn.Sequential(nn.Conv2d(self.input_dims*2, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 128, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(128*2, 256, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        # 512 = 64 + 64 + 128 + 256, the concatenated outputs of conv1..conv4.
        self.conv5 = nn.Sequential(nn.Conv1d(512, self.emb_dims, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        # emb_dims*2 because max-pooled and average-pooled features are concatenated.
        self.linear1 = nn.Linear(self.emb_dims*2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=self.dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=self.dropout)
        self.linear3 = nn.Linear(256, output_channels)

    def forward(self, x):
        """Compute the 256-dimensional feature of a batch of point clouds.

        Args:
            x: Tensor of shape (batch_size, input_dims, num_points).

        Returns:
            Tensor of shape (batch_size, 256): the activation after ``dp2``.
        """
        batch_size = x.size(0)
        x = get_graph_feature(x, k=self.k, device=self.device)      # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        x = self.conv1(x)                       # (batch_size, 3*2, num_points, k) -> (batch_size, 64, num_points, k)
        x1 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x1, k=self.k, device=self.device)     # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv2(x)                       # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x2 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x2, k=self.k, device=self.device)     # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv3(x)                       # (batch_size, 64*2, num_points, k) -> (batch_size, 128, num_points, k)
        x3 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 128, num_points, k) -> (batch_size, 128, num_points)

        x = get_graph_feature(x3, k=self.k, device=self.device)     # (batch_size, 128, num_points) -> (batch_size, 128*2, num_points, k)
        x = self.conv4(x)                       # (batch_size, 128*2, num_points, k) -> (batch_size, 256, num_points, k)
        x4 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 256, num_points, k) -> (batch_size, 256, num_points)

        x = torch.cat((x1, x2, x3, x4), dim=1)  # (batch_size, 64+64+128+256, num_points)

        x = self.conv5(x)                       # (batch_size, 64+64+128+256, num_points) -> (batch_size, emb_dims, num_points)
        # x1 / x2 are reused here for the pooled global descriptors.
        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
        x = torch.cat((x1, x2), 1)              # (batch_size, emb_dims*2)

        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2) # (batch_size, emb_dims*2) -> (batch_size, 512)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2) # (batch_size, 512) -> (batch_size, 256)
        x = self.dp2(x)
        # The classification head is intentionally skipped so the 256-d feature is returned.
        # x = self.linear3(x)                                             # (batch_size, 256) -> (batch_size, output_channels)
        
        return x

# Similarity Search Class

In [None]:
class SimilaritySearch:
    def __init__(self):
        """Seed torch, pick the device, load the model and switch it to eval mode."""
        torch.manual_seed(0)

        cuda_available = torch.cuda.is_available()
        self.device = torch.device("cuda:0") if cuda_available else torch.device("cpu")

        self.load_model_and_checkpoint()
        self.model.eval()

    def load_model_and_checkpoint(self):
        """Build the DGCNN, load ``best_model.pth`` and strip the training-only layers.

        Sets ``self.model`` and ``self.checkpoint``. After the weights are
        loaded, both dropout layers and ``linear3`` are replaced with
        ``nn.Identity`` so the model acts as a pure feature extractor.
        """
        self.model = DGCNN(device=self.device, output_channels=32)
        self.model = self.model.to(self.device)

        # map_location lets a checkpoint saved on a GPU be loaded on whatever
        # device was selected in __init__ (including CPU-only machines).
        self.checkpoint = torch.load("best_model.pth", map_location=self.device)
        self.model.load_state_dict(self.checkpoint["model_state_dict"])

        # Replaced only after load_state_dict so the checkpoint keys still match.
        self.model.dp1 = nn.Identity()
        self.model.dp2 = nn.Identity()
        self.model.linear3 = nn.Identity()

    def find_top_n_similar_feature_vectors(self, query_vector, data, top_n=-1):
        """Rank database entries by cosine similarity to a query vector.

        Args:
            query_vector: 2-D array holding the query feature (one row).
            data: Sequence of dicts, each with a ``'feature_vector'`` entry.
            top_n: Number of best matches to keep; ``-1`` keeps all of them.

        Returns:
            ``(indices, similarities)``: indices into ``data`` sorted by
            decreasing similarity, and the matching similarity values.
        """
        stacked_features = np.array([entry['feature_vector'] for entry in data])
        scores = cosine_similarity(query_vector, stacked_features).flatten()

        descending = np.argsort(-scores)
        top_n_indices = descending[:] if top_n == -1 else descending[:top_n]

        ranked_scores = [scores[i] for i in top_n_indices]
        return top_n_indices, ranked_scores
    
    def get_pid_fid_from_indices(self, indices, data):
        """Extract the jaw id and tooth id encoded in each entry's thumbnail file name.

        The file name is expected to look like ``<pid>_..._<fid>.<ext>``:
        the pid is the text before the first underscore and the fid is the
        last underscore-separated token of the part before the first dot.

        Args:
            indices: Iterable of positions into ``data``.
            data: Sequence of dicts with a ``'thumbnail_location'`` path.

        Returns:
            ``(pids, fids)``: two lists aligned with ``indices``.
        """
        pids, fids = [], []
        for position in indices:
            file_name = data[position]['thumbnail_location'].split('/')[-1]
            pids.append(file_name.split('_')[0])
            stem = file_name.split('.')[0]
            fids.append(stem.split('_')[-1])
        return pids, fids
    
    def pack_json(self, pids, fids, similarity_score, dice_score):
        """Zip the four parallel sequences into a list of result dicts.

        Each dict has the keys ``'pid'``, ``'fid'``, ``'similarity_score'``
        and ``'dice_score'``. The list is as long as the shortest input.
        """
        ranked_data = []
        for pid, fid, score, dice in zip(pids, fids, similarity_score, dice_score):
            ranked_data.append({
                'pid': pid,
                'fid': fid,
                'similarity_score': score,
                'dice_score': dice,
            })

        # Plain list of dicts; serialisation is left to the caller (e.g. FastAPI).
        return ranked_data
    
    def get_similarity(self, obj_path):
        """Rank every entry of ``feature_info.json`` against one tooth mesh.

        Args:
            obj_path: Path of the query tooth mesh.

        Returns:
            ``(pids, fids, similarities)`` for all database entries, ordered
            by decreasing cosine similarity to the query embedding.
        """
        sampled_points = load_point_cloud(obj_path)
        # (points, 3) -> (1, 3, points), the layout the model consumes.
        model_input = sampled_points.to(self.device).unsqueeze(0).permute(0,2,1)

        with torch.no_grad():
            query_embedding = self.model(model_input)
        query_embedding = query_embedding.cpu().numpy()

        database = load_data_from_json('feature_info.json')

        # top_n=-1 keeps the full ranking rather than a top-10 cut.
        ranked_indices, simil = self.find_top_n_similar_feature_vectors(query_embedding, database, top_n=-1)
        pids, fids = self.get_pid_fid_from_indices(ranked_indices, database)

        return pids, fids, simil

    # There might be some missing tooth as well, so we need to take that into account
    def filter_ids_with_min_teeth(self, all_results, min_teeth=3):
        """Keep only the ids whose value has at least ``min_teeth`` items.

        Values without a ``__len__`` count as zero teeth.

        Args:
            all_results: Mapping of id -> collection of per-tooth results.
            min_teeth: Minimum number of items required to keep an id.

        Returns:
            A new dict with the qualifying entries, in their original order.
        """
        def tooth_count(teeth_data):
            return len(teeth_data) if hasattr(teeth_data, '__len__') else 0

        return {
            id_code: teeth_data
            for id_code, teeth_data in all_results.items()
            if tooth_count(teeth_data) >= min_teeth
        }

    def get_dice(self, all_results, teeth_paths, pid):
        dice_scores = []
        dice_scores_dict = {}

        for row, teeth_path in zip(all_results[pid], teeth_paths):
            if int(row[0].split("fid")[-1]) < 30:
                category = 'upper'
            if int(row[0].split("fid")[-1]) > 30:
                category = 'lower'

            sim_obj_path = f"/home/shirshak/Teeth3DS_individual_teeth/individual_teeth/{pid}_{category}_{row[0]}.obj"

            orig_mesh = load_voxel_grid(teeth_path)
            similar_mesh = load_voxel_grid(sim_obj_path)

            dice_score = compute_voxel_dice_score(orig_mesh, similar_mesh)

            dice_scores.append(dice_score)

            dice_scores_dict[sim_obj_path] = dice_score

        return dice_scores_dict, dice_scores, float(np.array(dice_scores).mean())

    def get_tooth_path(self, indiv_tooth_fid_with_similarity, pid):
        """Return the Teeth3DS mesh path of one tooth of jaw ``pid``.

        Args:
            indiv_tooth_fid_with_similarity: ``[fid, similarity]`` row; only
                the fid (e.g. ``"fid47"``) is used.
            pid: Jaw identifier used as the file-name prefix.

        Returns:
            ``.../individual_teeth/{pid}_{upper|lower}_{fid}.obj``; a tooth
            number below 30 is treated as upper jaw, otherwise lower.
        """
        fid = indiv_tooth_fid_with_similarity[0]
        category = 'upper' if int(fid.split("fid")[-1]) < 30 else 'lower'
        return f"/home/shirshak/Teeth3DS_individual_teeth/individual_teeth/{pid}_{category}_{fid}.obj"
    

    def get_similarity_multiple_teeth(self, teeth_paths, tooth_labels):
        if len(teeth_paths) != len(tooth_labels):
            raise ValueError("Number of tooth paths must match number of tooth labels")
        
        # Similarity for each tooth
        all_results = {}
        for tooth_path, tooth_label in tqdm(zip(teeth_paths, tooth_labels)):
            pids, fids, similarity_scores = self.get_similarity(tooth_path)

            for pid, fid, similarity_score in zip(pids, fids, similarity_scores):
                # print(pid)
                # print(fid)
                fid_number = int(fid.split("fid")[-1])

                if fid_number == tooth_label: # because fid ====> fid32, so to only extract 32
                    if pid in all_results:
                        all_results[pid].append([fid, similarity_score])
                    else:
                        all_results[pid] = [[fid, similarity_score]]
        
        # print(all_results)

        avg_similarity_score = {}
        for result_key in all_results.keys():
            avg_similarity_score[result_key] = sum(item[1] for item in all_results[result_key]) / len(all_results[result_key])

        # Sort the avg_similarity score
        avg_similarity_score = sorted(avg_similarity_score.items(), key=lambda item: item[1], reverse=True)

        avg_dice_score = {}
        top_10_similar_embeddings = 10
        for i in range(top_10_similar_embeddings):
            pid = avg_similarity_score[i][0]
            # print(self.get_dice(all_results, teeth_paths, pid))
            _, _, total_dice = self.get_dice(all_results, teeth_paths, pid)
            avg_dice_score[pid] = total_dice

        # Sort the avg dice score
        avg_dice_score = sorted(avg_dice_score.items(), key=lambda item: item[1], reverse=True)
        print(avg_dice_score)

        pid = avg_dice_score[2][0] # Gives PID For 2nd position, as most similar on 2nd position tooth
        avg_dice_score = avg_dice_score[2][1] # Gives PID For 2nd position, as most similar on 1st position tooth

        # print("-----------------------------------")
        # print(self.get_dice(all_results, teeth_paths, pid))
        dice_scores_dict, dice_scores, _ = self.get_dice(all_results, teeth_paths, pid)

        # print(dice_scores_dict.items())
        # print(dice_scores)

        return pid, all_results[pid], {"dice_scores_dict":dice_scores_dict, "avg_dice_score":avg_dice_score}, 


    def get_similar_crowns(self, tooth_obj_path, label_damaged):
        data_point_cloud_orig = load_point_cloud(tooth_obj_path, num_points=2048)
        data_point_cloud = data_point_cloud_orig.transpose(0, 1).unsqueeze(0).to(self.device)
        # print(data_point_cloud.shape)
        orig_mesh = load_voxel_grid(tooth_obj_path)
        # print(data_point_cloud.shape)

        with torch.no_grad():
            original_feature_256 = self.model(data_point_cloud)

        original_feature_256 = original_feature_256.cpu().numpy()
        feature_data = load_data_from_json('feature_info_client_data.json')

        new_feature_data = self.get_feature_data_of_selected_label_only(feature_data, label_damaged)

        top_10_indices, simil = self.find_top_n_similar_feature_vectors(original_feature_256, new_feature_data, top_n=-1)

        # print(len(top_10_indices))
        # print(feature_data)
        final_scores = {}
        for index, similarity in zip(top_10_indices, simil):
            sim_obj_path = new_feature_data[index]['mesh_location']
            # print(sim_obj_path)

            similar_mesh = load_voxel_grid(sim_obj_path)
            dice_score = compute_voxel_dice_score(orig_mesh, similar_mesh)
            final_scores[sim_obj_path] =[f"Similarity Score : {similarity}", f"Dice Score : {dice_score}"]
        # Sort the final scores according to the dice_scores 
        sorted_final_scores = dict(sorted(final_scores.items(), key=lambda item: item[1][1], reverse=True))
        # print(sorted_final_scores)
        return sorted_final_scores

    def get_feature_data_of_selected_label_only(self, feature_data, label_damaged):
        """Return the entries of ``feature_data`` whose label equals ``label_damaged``.

        Labels are compared after ``int()`` conversion on both sides. Each
        returned dict is a fresh dict restricted to the keys
        ``"mesh_location"``, ``"label"`` and ``"feature_vector"``.
        """
        wanted_label = int(label_damaged)
        return [
            {
                "mesh_location": entry['mesh_location'],
                "label": entry['label'],
                "feature_vector": entry['feature_vector'],
            }
            for entry in feature_data
            if int(entry['label']) == wanted_label
        ]

# Extract Individual Teeth

In [None]:
# Tooth numbers (the "fdi" values in the label JSON) expected in each jaw.
fids_lower = [31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48]
fids_upper = [11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28]


def extract_teeth(indir, outdir):
    """Split every labelled jaw mesh in ``indir`` into one OBJ file per tooth.

    ``*.json`` and ``*.obj`` files in ``indir`` are sorted and paired
    positionally. The JSON must contain ``"cells"``, a list with one
    ``{"fdi": <tooth number>}`` entry per face of the OBJ, in face order.
    For each tooth number that occurs, the faces carrying that label and
    the vertices they use are written to
    ``{outdir}/{mesh name}_fid{tooth number - 1}.obj`` with vertex indices
    renumbered from 1.

    The jaw is chosen from the mesh file name: "lower"/"preparation"
    selects ``fids_lower``, "upper"/"antagonist" selects ``fids_upper``;
    any other name falls back to both lists.

    Args:
        indir: Directory holding the jaw meshes and label files (with or
            without a trailing slash).
        outdir: Directory that receives the per-tooth OBJ files; created if
            missing.
    """
    # os.path.join works whether or not indir ends with a separator; plain
    # string concatenation silently matched nothing without one.
    json_file_paths = sorted(glob.glob(os.path.join(indir, "*.json")))
    mesh_paths = sorted(glob.glob(os.path.join(indir, "*.obj")))

    for (json_file_path, mesh_path) in zip(json_file_paths, mesh_paths):
        print(json_file_path)
        print(mesh_path)

        base_mesh_name = os.path.splitext(os.path.basename(mesh_path))[0]

        with open(json_file_path) as json_file:
            json_data = json.load(json_file)
        face_labels = np.array([cell["fdi"] for cell in json_data["cells"]])

        with open(mesh_path) as obj_file:
            obj_lines = obj_file.readlines()

        # Index vertices and faces by their own running number, so that
        # interleaved lines (comments, "vn", "vt", ...) cannot shift them.
        vertex_lines = [line for line in obj_lines if line.startswith("v ")]
        face_lines = [line for line in obj_lines if line.startswith("f ")]

        os.makedirs(outdir, exist_ok=True)

        lowered_name = base_mesh_name.lower()
        if "lower" in lowered_name or "preparation" in lowered_name:
            fids_selection = fids_lower
        elif "upper" in lowered_name or "antagonist" in lowered_name:
            fids_selection = fids_upper
        else:
            # Unknown jaw: previously this left fids_selection undefined (or
            # stale from the previous file). Try every known tooth number.
            fids_selection = fids_lower + fids_upper

        for teeth_no in fids_selection:
            start_time = time.time()

            faces_to_extract = np.where(face_labels == teeth_no)[0]
            if len(faces_to_extract) == 0:
                continue

            used_vertices = set()
            valid_faces = []
            for face_idx in faces_to_extract:
                if face_idx >= len(face_lines):
                    continue

                parts = face_lines[face_idx].split()
                if len(parts) < 4:
                    continue

                try:
                    # "v/vt/vn" tokens: keep the vertex index, make it 0-based.
                    corners = tuple(int(token.split('/')[0]) - 1 for token in parts[1:4])
                except ValueError:
                    continue

                # A face pointing at a vertex that does not exist cannot be
                # written consistently, so it is dropped as a whole.
                if any(corner < 0 or corner >= len(vertex_lines) for corner in corners):
                    continue

                used_vertices.update(corners)
                valid_faces.append(corners)

            if len(used_vertices) == 0:
                print(f"No valid vertices found for tooth {teeth_no}")
                continue

            vertices_to_extract = sorted(used_vertices)
            vertex_mapping = {old_idx: new_idx + 1 for new_idx, old_idx in enumerate(vertices_to_extract)}

            output_path = os.path.join(outdir, f"{base_mesh_name}_fid{teeth_no-1}.obj") # TODO REMOVE THE -1 TERM FROM TEETH_NO. NOW WE'RE JUST USING IT BECAUSE TEETH 3DS DOESN'T HAVE LABEL 48 TOOTH SO......

            with open(output_path, "w") as new_obj_file:
                for vertex_idx in vertices_to_extract:
                    vertex_line = vertex_lines[vertex_idx]
                    # The last line of a file may lack its newline.
                    if not vertex_line.endswith("\n"):
                        vertex_line += "\n"
                    new_obj_file.write(vertex_line)

                for v1, v2, v3 in valid_faces:
                    new_obj_file.write(f"f {vertex_mapping[v1]} {vertex_mapping[v2]} {vertex_mapping[v3]}\n")

            end_time = time.time()
            print(f"Time taken for extracting tooth {teeth_no} from {indir}: {end_time - start_time:.2f} seconds")

            print(output_path)
        
        
def read_ply_file(ply_path):
    """Read a PLY mesh and return its vertices, colours and 1-based faces.

    Args:
        ply_path: Path of the PLY file.

    Returns:
        ``(vertices, rgb, faces)`` as numpy arrays. ``faces`` holds the
        triangle indices shifted by +1 (OBJ convention). ``rgb`` is the
        per-vertex colour, or a constant 0.501 grey when the mesh has none.
    """
    mesh = o3d.io.read_triangle_mesh(ply_path)
    vertices = np.asarray(mesh.vertices)
    faces = np.asarray(mesh.triangles) + 1

    rgb = (
        np.asarray(mesh.vertex_colors)
        if mesh.has_vertex_colors()
        else np.full((len(vertices), 3), 0.501)
    )

    print(f"Read PLY file: {len(vertices)} vertices, {len(faces)} faces")
    return vertices, rgb, faces

def convert_ply_to_obj(ply_path):
    """Write an OBJ copy of a PLY mesh next to it, keeping the vertex order.

    Vertices are written as ``v x y z r g b`` and faces as ``f a b c`` using
    the arrays returned by ``read_ply_file``.

    Args:
        ply_path: Path of the source PLY file.

    Returns:
        ``(obj_path, vertex_count)``: the ``Path`` of the written OBJ file
        (same name, ``.obj`` suffix) and the number of vertices.
    """
    vertices, rgb, faces = read_ply_file(ply_path)
    obj_path = Path(ply_path).with_suffix(".obj")

    with open(obj_path, 'w') as obj_out:
        for position, colour in zip(vertices, rgb):
            obj_out.write(
                f"v {position[0]} {position[1]} {position[2]} {colour[0]} {colour[1]} {colour[2]}\n"
            )
        for triangle in faces:
            obj_out.write(f"f {triangle[0]} {triangle[1]} {triangle[2]}\n")

    return obj_path, len(vertices)

# Convert the patient's jaw scans from PLY to OBJ: lower_jaw_1.ply => lower_jaw_1.obj and upper_jaw_1.ply => upper_jaw_1.obj

In [None]:
# _, _ = convert_ply_to_obj("/home/shirshak/00_teeth_similarity_matching/CLIENT_DATA/pdf_2c_20220629_084758_ITERO(100008941)_0/pdf_2c_20220629_084758_ITERO(100008941)_0-AntagonistScan.ply")
# _, _ = convert_ply_to_obj("/home/shirshak/00_teeth_similarity_matching/CLIENT_DATA/pdf_2c_20220629_084758_ITERO(100008941)_0/pdf_2c_20220629_084758_ITERO(100008941)_0-PreparationScan.ply")

# Write lower_jaw_1.obj / upper_jaw_1.obj next to the PLY scans; the returned
# (path, vertex count) pairs are not needed here.
_, _ = convert_ply_to_obj("/home/shirshak/00_teeth_similarity_matching/CLIENT_DATA_2/JAW_1234565/lower_jaw_1.ply")
_, _ = convert_ply_to_obj("/home/shirshak/00_teeth_similarity_matching/CLIENT_DATA_2/JAW_1234565/upper_jaw_1.ply")

# Extract each individual tooth of the patient.

In [None]:
mesh_folder_paths = glob.glob("/home/shirshak/00_teeth_similarity_matching/CLIENT_DATA_2/JAW_1234565/")      #input path of the original raw data
outdir = "/home/shirshak/00_teeth_similarity_matching/CLIENT_DATA_2/JAW_1234565/individual_teeth/"       #output folder of the individual teeth


# Run the per-tooth extraction for every matched folder and report the wall time.
for mesh_folder_path in mesh_folder_paths:
    # print(mesh_folder_path)
    # The path ends with "/", so [-2] is the folder name (used as patient id).
    patient_id = mesh_folder_path.split("/")[-2]
    # print(patient_id)
    start_time = time.time()
    extract_teeth(mesh_folder_path, os.path.join(outdir))
    # extract_teeth(mesh_folder_path, os.path.join(outdir + patient_id))
    end_time = time.time()
    print(f"\n Time taken for extracting all tooth from IOS: {patient_id}: {end_time - start_time:.2f} seconds")

In [None]:
# Tooth numbers used throughout the notebook: the damaged (abutment) tooth,
# its two neighbours, and the tooth opposing it in the other jaw.
left_tooth = 47
damaged_tooth = 46
right_tooth = 45
opposite_tooth = 16

# Visualize Client Data

In [None]:
# mesh_name_2 is the LOWER jaw scan and mesh_name_1 the UPPER jaw scan.
mesh_name_2 = "/home/shirshak/00_teeth_similarity_matching/CLIENT_DATA_2/JAW_1234565/lower_jaw_1.ply"
mesh_name_1 = "/home/shirshak/00_teeth_similarity_matching/CLIENT_DATA_2/JAW_1234565/upper_jaw_1.ply"

mesh1 = o3d.io.read_triangle_mesh(mesh_name_1)
mesh2 = o3d.io.read_triangle_mesh(mesh_name_2)

from IPython.display import display, HTML
display(HTML(f'<h1>Client Data : Damaged Jaw Visualization</h1>'))

# Each label must name the file it describes: the lower-jaw line prints the
# lower-jaw file (mesh_name_2) and the upper-jaw line the upper-jaw file
# (mesh_name_1). The two variables were previously swapped in these prints.
print(f"Jaw with one smaller tooth (Abutment Tooth) ==> Lower Jaw : {os.path.basename(mesh_name_2)}")
print(f"Teeth without abutment tooth ==> Upper Jaw : {os.path.basename(mesh_name_1)}")
print(f"Abutment tooth FDI Number : {damaged_tooth}")

# Show both jaws in a single plot.
o3d.visualization.draw_plotly([mesh1, mesh2])

#### For damaged tooth 46, select its left (47) and right (45) adjacent teeth and its opposing tooth (16)

In [None]:
left_tooth = 47
damaged_tooth = 46
right_tooth = 45
opposite_tooth = 16


# Empty string means "not found yet"; checked by the assert at the end.
left_tooth_path, right_tooth_path, opposite_tooth_path = "", "", ""


# Scan the extracted per-tooth files and remember the ones for the three
# query teeth. The number after "fid" in the file name is used as the label.
for file_name in sorted(glob.glob("/home/shirshak/00_teeth_similarity_matching/CLIENT_DATA_2/JAW_1234565/individual_teeth/*")):

    label = int(file_name.split("/")[-1].split(".obj")[0].split("fid")[-1])

    if label == left_tooth:
        left_tooth_path = file_name
    
    if  label == right_tooth:
        right_tooth_path = file_name

    if label == opposite_tooth:
        opposite_tooth_path = file_name

# All three query teeth must exist before the similarity search can run.
assert left_tooth_path != "" and right_tooth_path != "" and opposite_tooth_path != ""

In [None]:
# Parallel lists (left, right, opposite) consumed by the similarity search.
teeth_paths = [left_tooth_path, right_tooth_path, opposite_tooth_path]
tooth_labels = [left_tooth, right_tooth, opposite_tooth]

# Display both lists as the cell output.
teeth_paths, tooth_labels

# Similarity search: find the most similar jaw in the Teeth3DS database (900 jaws) using the left-adjacent, right-adjacent, and opposing teeth of the abutment tooth

In [None]:
# Build the searcher (loads the model checkpoint) and look up the jaw that
# matches the three query teeth. Returns the jaw id, its [fid, similarity]
# rows and a dict with the Dice results.
similarity_search = SimilaritySearch()
pid, fid_and_similarity_score, dice_dict =  similarity_search.get_similarity_multiple_teeth(teeth_paths, tooth_labels)

In [None]:
# Per-tooth Dice scores (keyed by matched tooth path) and their average.
dice_individual_dict = dice_dict["dice_scores_dict"]
dice_avg_score = dice_dict["avg_dice_score"]

In [None]:
# Display the label of the first (left) query tooth.
tooth_labels[0]

In [None]:
# Map each matched tooth's Dice score to "left_dice" / "right_dice" /
# "opposite_dice" according to the tooth number in its file name.
final_dice_scores = {}

for path, dice_score in dice_individual_dict.items():
    jaw_name = os.path.basename(path).split(".obj")[0]
    fid_num = os.path.basename(path).split(".obj")[0].split("fid")[1]

    if int(fid_num) == int(tooth_labels[0]):
        final_dice_scores["left_dice"] = dice_score
    
    elif int(fid_num) == int(tooth_labels[1]):
        final_dice_scores["right_dice"] = dice_score

    elif int(fid_num) == int(tooth_labels[2]):
        final_dice_scores["opposite_dice"] = dice_score

    # NOTE(review): this stores the tooth file stem (not just the jaw id) and
    # is overwritten on every iteration, so the last path wins -- confirm
    # whether only the jaw id was intended.
    final_dice_scores["jaw_name"] = jaw_name

In [None]:
# Display the collected Dice scores.
final_dice_scores

##### From the results above, jaw AD8EQEUR is the most similar match. Next, we look up each of its corresponding teeth.

In [None]:
# Paths of the matched jaw's left / right / opposite teeth in the Teeth3DS folder.
similar_teeth_paths = get_similar_teeth_paths(pid, fid_and_similarity_score, base_path="/home/shirshak/Teeth3DS_individual_teeth/individual_teeth")
# Display the three paths.
similar_teeth_paths

In [None]:
# mesh1 = o3d.io.read_triangle_mesh(teeth_paths[0])
# mesh2 = o3d.io.read_triangle_mesh(similar_teeth_paths[0])

# mesh2 = mesh2.translate([13, 0, 0])

# o3d.visualization.draw_plotly([mesh1, mesh2], window_name="Given Tooth VS Similar Tooth", mesh_show_wireframe=True)

# Visualize the most similar left, right and opposite teeth from the Teeth3DS database of 900 jaw scans

In [None]:
def align_orientation(mesh_to_align, reference_mesh):
    """
    Align the orientation of mesh_to_align to match reference_mesh using PCA.

    The principal axes of both vertex sets are computed and the mesh is
    rotated so that its axes coincide with the reference axes. The result
    is always a proper rotation (determinant +1), so the mesh is never
    mirrored.

    Args:
        mesh_to_align: The mesh whose orientation needs to be changed
        reference_mesh: The reference mesh to align to

    Returns:
        Aligned mesh (copy of mesh_to_align with corrected orientation)
    """

    # Work on a copy so the caller's mesh is left untouched.
    mesh_copy = o3d.geometry.TriangleMesh(mesh_to_align)

    vertices_to_align = np.asarray(mesh_copy.vertices)
    vertices_reference = np.asarray(reference_mesh.vertices)

    # Rows of components_ are the principal axes. The array that gets
    # sign-flipped below is copied so the fitted estimator is not mutated.
    components_to_align = PCA(n_components=3).fit(vertices_to_align).components_.copy()
    components_reference = PCA(n_components=3).fit(vertices_reference).components_

    # PCA axes have an arbitrary sign; flip those that point away from the
    # matching reference axis.
    for i in range(3):
        if np.dot(components_to_align[i], components_reference[i]) < 0:
            components_to_align[i] *= -1

    # R = R_ref^T * R_to_align maps each axis of the mesh onto the
    # corresponding reference axis.
    rotation_matrix = components_reference.T @ components_to_align

    # Independent sign flips can leave an improper matrix (det = -1), which
    # would mirror the mesh. Flipping the last axis restores a rotation.
    if np.linalg.det(rotation_matrix) < 0:
        components_to_align[2] *= -1
        rotation_matrix = components_reference.T @ components_to_align

    mesh_copy.rotate(rotation_matrix, center=(0, 0, 0))

    return mesh_copy

In [None]:
def transform_meshes(teeth_path, similar_teeth_path, which_tooth = "", dice_score = 0, rotation_degree = None):
    """Load a query tooth and its match, centre both and prepare them for display.

    Both meshes are translated so their centres sit at the origin, the match
    is PCA-aligned to the query, and then shifted 15 units along x so the
    two can be shown side by side. A heading and a short summary are
    printed as a side effect.

    Args:
        teeth_path: Mesh path of the original (query) tooth.
        similar_teeth_path: Mesh path of the most similar tooth.
        which_tooth: Label used in the displayed heading.
        dice_score: Dice score shown in the summary.
        rotation_degree: Unused.

    Returns:
        ``(mesh1, mesh2)``: the centred query mesh and the aligned, shifted match.
    """
    from IPython.display import display, HTML

    mesh1 = o3d.io.read_triangle_mesh(teeth_path)
    mesh2 = o3d.io.read_triangle_mesh(similar_teeth_path)

    # Move both meshes to the origin before aligning orientations.
    mesh1.translate(-mesh1.get_center())
    mesh2.translate(-mesh2.get_center())

    aligned = align_orientation(mesh2, mesh1)

    display(HTML(f'<h1>{which_tooth} Most Similar</h1>'))

    # Shift the match sideways so both fit in a single plot.
    mesh2 = aligned.translate([15, 0, 0])

    print(f"Left ==> Original Tooth : {os.path.basename(teeth_path)}")
    print(f"Right ==> Most Similar Tooth : {os.path.basename(similar_teeth_path)}")
    print(f"Dice Score = {dice_score}")

    return mesh1, mesh2

    # o3d.visualization.draw_plotly([mesh1, mesh2])

    # o3d.visualization.draw_plotly([mesh1])
    # o3d.visualization.draw_plotly([mesh2])

In [None]:
# Left neighbour: query tooth vs. its match from the selected jaw.
which_tooth = "Left Tooth"
dice_score = final_dice_scores["left_dice"]
mesh1, mesh2 = transform_meshes(teeth_paths[0], similar_teeth_paths[0], which_tooth, dice_score)

# Hand-picked display rotations (axis-angle vectors, angles given in degrees).
rotation_degree = 180
rotation_matrix = mesh2.get_rotation_matrix_from_axis_angle([0, 0, np.radians(rotation_degree)])
mesh2.rotate(rotation_matrix, center=(0, 0, 0))

rotation_matrix = mesh1.get_rotation_matrix_from_axis_angle([np.radians(47), 0, 0])
mesh1.rotate(rotation_matrix, center=(0, 0, 0))


o3d.visualization.draw_plotly([mesh1, mesh2])

In [None]:
# Right neighbour: query tooth vs. its match from the selected jaw.
which_tooth = "Right Tooth"
dice_score = final_dice_scores["right_dice"]
mesh1, mesh2 = transform_meshes(teeth_paths[1], similar_teeth_paths[1], which_tooth, dice_score)


# Hand-picked display rotations (axis-angle vectors, angles given in degrees).
rotation_matrix = mesh2.get_rotation_matrix_from_axis_angle([np.radians(15), np.radians(0), np.radians(5)])
mesh2.rotate(rotation_matrix, center=(0, 0, 0))

rotation_matrix = mesh1.get_rotation_matrix_from_axis_angle([np.radians(90), np.radians(55), np.radians(-15)])
mesh1.rotate(rotation_matrix, center=(0, 0, 0))



o3d.visualization.draw_plotly([mesh1, mesh2])

In [None]:
# Opposing tooth: query tooth vs. its match from the selected jaw.
which_tooth = "Opposite Tooth"
dice_score = final_dice_scores["opposite_dice"]
mesh1, mesh2 = transform_meshes(teeth_paths[2], similar_teeth_paths[2], which_tooth, dice_score)

# Hand-picked display rotations (axis-angle vectors, angles given in degrees).
rotation_matrix = mesh2.get_rotation_matrix_from_axis_angle([np.radians(0), np.radians(0), np.radians(180)])
mesh2.rotate(rotation_matrix, center=(0, 0, 0))

rotation_matrix = mesh1.get_rotation_matrix_from_axis_angle([np.radians(0), np.radians(0), np.radians(180)])
mesh1.rotate(rotation_matrix, center=(0, 0, 0))


o3d.visualization.draw_plotly([mesh1, mesh2])

# Visualize Teeth3DS most similar jaw with Client jaw

In [None]:
# NOTE(review): the jaw id is hard-coded here rather than taken from `pid`;
# confirm it still matches the search result before relying on this plot.
mesh1 = o3d.io.read_triangle_mesh("AD8EQEUR_lower.obj")
mesh2 = o3d.io.read_triangle_mesh("AD8EQEUR_upper.obj")

# Hand-picked rotations and offsets used to arrange the two jaws in the plot.
rotation_matrix = mesh2.get_rotation_matrix_from_axis_angle([np.radians(180), np.radians(0), np.radians(0)])
mesh2.rotate(rotation_matrix, center=(0, 0, 0))


rotation_matrix = mesh1.get_rotation_matrix_from_axis_angle([np.radians(0), np.radians(0), np.radians(180)])
mesh1.rotate(rotation_matrix, center=(0, 0, 0))

mesh2.translate([0, 0, -100])  # move UPPER jaw down
mesh1.translate([0, 0, 57])  # move LOWER jaw up



from IPython.display import display, HTML
# The heading shows the part of the file name before the first underscore.
display(HTML(f'''<h1> Most Similar Jaw : {os.path.basename("AD8EQEUR_lower.obj").split("_")[0]} </h1>'''))

o3d.visualization.draw_plotly([mesh1, mesh2])

# Visualize Teeth3DS most similar jaw's FID tooth (same FID as abutment tooth)

In [None]:
from IPython.display import display, HTML

# Re-target the first matched tooth path at the damaged tooth's label by
# replacing every "fid<digits>" occurrence in the path.
new_path = re.sub(r'fid\d+', f'fid{damaged_tooth}', similar_teeth_paths[0])
tooth_filename = os.path.basename(new_path)

# Load that tooth and move its centre onto the origin.
mesh = o3d.io.read_triangle_mesh(new_path)
center_m = mesh.get_center()
mesh.translate(-center_m)

# Heading: jaw id (text before the first "_") and the label that follows "fid".
display(HTML(f'''
             <h1> Most Similar Jaw : {tooth_filename.split("_")[0]}</h1>
             <h1> Tooth FID : {tooth_filename.split(".obj")[0].split("fid")[1]} </h1>
             '''))
print(f"Teeth 3DS most similar jaw's FID Tooth : {tooth_filename}")
print(f"It is tooth whose Most Similar Crown to be extracted")

o3d.visualization.draw_plotly([mesh])

# Find the most similar crowns and visualize them

In [None]:
# print(output_file)
if damaged_tooth < 30:
    category = 'upper'
else:
    category = 'lower'

In [None]:
# Path of the individual-tooth mesh, built from `pid`, `category` and `damaged_tooth`.
# NOTE(review): absolute, machine-specific directory — consider a configurable data root.
output_file = f"/home/shirshak/Teeth3DS_individual_teeth/individual_teeth/{pid}_{category}_fid{damaged_tooth}.obj"
# Look up crown templates for that tooth; later cells iterate the result with .items().
final_crown_templates = similarity_search.get_similar_crowns(output_file, label_damaged=damaged_tooth)

In [None]:
# Collect the query tooth's filename together with every crown match.
json_holder = {}
json_holder["Most Similar Jaw's Abutment Tooth Filename :"] = output_file

# Each "<rank> Similar" entry stores one (key, value) pair of
# final_crown_templates, in the mapping's iteration order. enumerate() walks
# the items once instead of rebuilding the whole items list on every iteration.
for i, file_i in enumerate(final_crown_templates.items()):
    json_holder[f"{i} Similar"] = file_i

In [None]:
# Bare expression: renders the assembled mapping as this cell's output.
json_holder

In [None]:
def transform_meshes(teeth_path, similar_teeth_path, which_tooth = "", dice_score = 0, rotation_degree = None):
    """Load a tooth and its most similar match, ready to be drawn side by side.

    Both meshes are read from disk and re-centred on the origin. The match is
    then passed through ``align_orientation`` (with the original tooth as the
    second argument) and shifted by 15 units along +x so the two do not overlap.

    Args:
        teeth_path: Mesh file of the original tooth.
        similar_teeth_path: Mesh file of the most similar tooth.
        which_tooth: Label inserted into the displayed HTML heading.
        dice_score: Value printed on the "Dice Score" line.
        rotation_degree: Accepted but not used by this function.

    Returns:
        ``(original_mesh, similar_mesh)`` after the transformations above.

    Side effects:
        Displays an ``<h1>`` heading and prints both file basenames and the
        Dice score.
    """
    from IPython.display import display, HTML

    original = o3d.io.read_triangle_mesh(teeth_path)
    candidate = o3d.io.read_triangle_mesh(similar_teeth_path)

    # Move each mesh so that its centre sits on the origin.
    for tooth in (original, candidate):
        tooth.translate(-tooth.get_center())

    candidate = align_orientation(candidate, original)

    display(HTML(f'<h1>{which_tooth} Most Similar</h1>'))

    # Offset the match along +x so both meshes fit in a single plot.
    candidate = candidate.translate([15, 0, 0])

    report_lines = (
        f"Left ==> Original Tooth : {os.path.basename(teeth_path)}",
        f"Right ==> Most Similar Tooth : {os.path.basename(similar_teeth_path)}",
        f"Dice Score = {dice_score}",
    )
    for line in report_lines:
        print(line)

    return original, candidate

In [None]:
def compare_most_similar_crown(abutment_teeth_most_similar_jaw_filename, most_similar_crown):
    """Load an abutment tooth and a crown, centred and placed side by side.

    Args:
        abutment_teeth_most_similar_jaw_filename: Mesh file of the tooth
            (first element of the returned pair).
        most_similar_crown: Mesh file of the crown (second element).

    Returns:
        ``(tooth_mesh, crown_mesh)``: both re-centred on the origin, with the
        crown additionally shifted by 15 units along +x.
    """
    def _load_centred(path):
        # Read the mesh and move its centre onto the origin.
        loaded = o3d.io.read_triangle_mesh(path)
        loaded.translate(-loaded.get_center())
        return loaded

    tooth_mesh = _load_centred(abutment_teeth_most_similar_jaw_filename)
    crown_mesh = _load_centred(most_similar_crown)

    # Offset the crown so the two meshes do not overlap when drawn together.
    return tooth_mesh, crown_mesh.translate([15, 0, 0])

In [None]:
from IPython.display import display, HTML

# Entry "0 Similar" is a (crown filename, scores) pair; scores[0] is read as
# the similarity score and scores[1] as the Dice score.
abutment_teeth_most_similar_jaw_filename = json_holder["Most Similar Jaw's Abutment Tooth Filename :"]
crown_entry = json_holder["0 Similar"]
most_similar_crown = crown_entry[0]
similarity_score, dice_score = crown_entry[1][0], crown_entry[1][1]

# mesh1 is the abutment tooth at the origin, mesh2 the crown shifted along +x.
mesh1, mesh2 = compare_most_similar_crown(
    abutment_teeth_most_similar_jaw_filename,
    most_similar_crown,
)
# Extra z offset applied to the tooth for this view only.
mesh1.translate([0, 0, 1.5])

display(HTML(f'<h1>Most Similar Crown</h1>'))

# NOTE(review): transform_meshes labels the +x-shifted mesh "Right", but here the
# +x-shifted mesh is the crown and it is labelled "Left" — confirm against the plot.
print(f"Left ==> Most Similar Crown : {os.path.basename(most_similar_crown)}")
print(f"Right ==> Teeth3DS Similar Jaw's Tooth of same FDI Label as abutment : {os.path.basename(abutment_teeth_most_similar_jaw_filename)}")
print(dice_score, similarity_score, sep="\n")

o3d.visualization.draw_plotly([mesh1, mesh2])

In [None]:
from IPython.display import display, HTML

# Entry "1 Similar" is a (crown filename, scores) pair; scores[0] is read as
# the similarity score and scores[1] as the Dice score.
abutment_teeth_most_similar_jaw_filename = json_holder["Most Similar Jaw's Abutment Tooth Filename :"]
crown_entry = json_holder["1 Similar"]
most_similar_crown = crown_entry[0]
similarity_score, dice_score = crown_entry[1][0], crown_entry[1][1]

# mesh1 is the abutment tooth at the origin, mesh2 the crown shifted along +x.
mesh1, mesh2 = compare_most_similar_crown(
    abutment_teeth_most_similar_jaw_filename,
    most_similar_crown,
)

display(HTML(f'<h1>2nd Most Similar Crown</h1>'))

print(f"Left ==> 2nd Most Similar Crown : {os.path.basename(most_similar_crown)}")
print(f"Right ==> Teeth3DS Similar Jaw's Tooth of same FDI Label as abutment: {os.path.basename(abutment_teeth_most_similar_jaw_filename)}")
print(dice_score, similarity_score, sep="\n")

o3d.visualization.draw_plotly([mesh1, mesh2])

In [None]:
from IPython.display import display, HTML

# NOTE(review): this cell repeats the preceding "1 Similar" cell statement for
# statement — presumably a copy-paste leftover; confirm whether it can be removed.
abutment_teeth_most_similar_jaw_filename = json_holder["Most Similar Jaw's Abutment Tooth Filename :"]
crown_entry = json_holder["1 Similar"]
most_similar_crown = crown_entry[0]
similarity_score, dice_score = crown_entry[1][0], crown_entry[1][1]

# mesh1 is the abutment tooth at the origin, mesh2 the crown shifted along +x.
mesh1, mesh2 = compare_most_similar_crown(
    abutment_teeth_most_similar_jaw_filename,
    most_similar_crown,
)

display(HTML(f'<h1>2nd Most Similar Crown</h1>'))

print(f"Left ==> 2nd Most Similar Crown : {os.path.basename(most_similar_crown)}")
print(f"Right ==> Teeth3DS Similar Jaw's Tooth of same FDI Label as abutment: {os.path.basename(abutment_teeth_most_similar_jaw_filename)}")
print(dice_score, similarity_score, sep="\n")

o3d.visualization.draw_plotly([mesh1, mesh2])

In [None]:
from IPython.display import display, HTML

# Entry "2 Similar" is a (crown filename, scores) pair; scores[0] is read as
# the similarity score and scores[1] as the Dice score.
abutment_teeth_most_similar_jaw_filename = json_holder["Most Similar Jaw's Abutment Tooth Filename :"]
crown_entry = json_holder["2 Similar"]
most_similar_crown = crown_entry[0]
similarity_score, dice_score = crown_entry[1][0], crown_entry[1][1]

# mesh1 is the abutment tooth at the origin, mesh2 the crown shifted along +x.
mesh1, mesh2 = compare_most_similar_crown(
    abutment_teeth_most_similar_jaw_filename,
    most_similar_crown,
)

display(HTML(f'<h1>3rd Most Similar Crown</h1>'))

print(f"Left ==> 3rd Most Similar Crown : {os.path.basename(most_similar_crown)}")
print(f"Right ==> Teeth3DS Similar Jaw's Tooth of same FDI Label as abutment: {os.path.basename(abutment_teeth_most_similar_jaw_filename)}")
print(dice_score, similarity_score, sep="\n")

o3d.visualization.draw_plotly([mesh1, mesh2])

In [None]:
from IPython.display import display, HTML

# Entry "3 Similar" is a (crown filename, scores) pair; scores[0] is read as
# the similarity score and scores[1] as the Dice score.
abutment_teeth_most_similar_jaw_filename = json_holder["Most Similar Jaw's Abutment Tooth Filename :"]
crown_entry = json_holder["3 Similar"]
most_similar_crown = crown_entry[0]
similarity_score, dice_score = crown_entry[1][0], crown_entry[1][1]

# mesh1 is the abutment tooth at the origin, mesh2 the crown shifted along +x.
mesh1, mesh2 = compare_most_similar_crown(
    abutment_teeth_most_similar_jaw_filename,
    most_similar_crown,
)

display(HTML(f'<h1>4th Most Similar Crown</h1>'))

print(f"Left ==> 4th Most Similar Crown : {os.path.basename(most_similar_crown)}")
print(f"Right ==> Teeth3DS Similar Jaw's Tooth of same FDI Label as abutment: {os.path.basename(abutment_teeth_most_similar_jaw_filename)}")
print(dice_score, similarity_score, sep="\n")

o3d.visualization.draw_plotly([mesh1, mesh2])