In [1]:
import os 
import sys
import torch 
import torch.nn as nn
import open3d as o3d
import numpy as np
import json
from sklearn.metrics.pairwise import cosine_similarity
sys.path.append(os.path.abspath('/home/shirshak/00_teeth_similarity_matching/models'))
sys.path.append(os.path.abspath('/home/shirshak/00_teeth_similarity_matching/src/data_preprocess'))
from dgcnn import DGCNN
from load_obj_save_pcd import preprocess_and_save_obj
import plotly.io as pio
import plotly.graph_objects as go
import io
import matplotlib.pyplot as plt
from PIL import Image
import re
from collections import defaultdict
from tqdm import tqdm

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def mesh_to_pcd(mesh,num_points =2048):

    target_faces = 8000

    while len(mesh.triangles) < target_faces:
        mesh = mesh.subdivide_midpoint(number_of_iterations=1)  # Increase iterations as needed

    # Decimation (reduce number of faces)
    mesh = mesh.simplify_quadric_decimation(target_faces)

    #mesh to PCD
    pcd = mesh.sample_points_uniformly(number_of_points=num_points)

    
    #normalize PCD
    points = np.asarray(pcd.points)
    centroid = np.mean(points, axis=0)
    points -= centroid  # Centering
    max_distance = np.max(np.linalg.norm(points, axis=1))
    points /= max_distance  # Scaling
    pcd.points = o3d.utility.Vector3dVector(points)

    np_points = np.asarray(pcd.points)

    return np_points


def load_point_cloud(file_path, num_points=2048):
    """Load and preprocess a single obj file into point cloud file."""
    mesh = o3d.io.read_triangle_mesh(file_path)

    data = mesh_to_pcd(mesh, num_points=num_points)
    
    # Randomly sample `num_points` points from the data if needed
    if data.shape[0] > num_points:
        indices = np.random.choice(data.shape[0], num_points, replace=False)
    else:
        indices = np.random.choice(data.shape[0], num_points, replace=True)
    data = data[indices, :]
    
    return torch.tensor(data, dtype=torch.float32)

In [3]:
def mesh_to_voxel_grid(mesh, voxel_size=2):
    """Convert a mesh to a voxel grid using Open3D."""
    return o3d.geometry.VoxelGrid.create_from_triangle_mesh(mesh, voxel_size=voxel_size)


def load_voxel_grid(file_path, voxel_size=2):
    """Load and preprocess a single obj file into a voxel grid."""
    mesh = o3d.io.read_triangle_mesh(file_path)
    return mesh_to_voxel_grid(mesh, voxel_size)

In [4]:
def load_data_from_json(filename):
    with open(filename, 'r') as f:
        data = json.load(f)
    
    # Convert feature vectors back to numpy arrays
    for entry in data:
        entry['feature_vector'] = np.array(entry['feature_vector'])
    
    return data

In [5]:
def compute_voxel_dice_score(voxel_grid1, voxel_grid2):
    """
    Compute the Dice score between two voxel grids.
    """
    # Get the set of voxel coordinates for each voxel grid
    voxels1 = set([tuple(voxel.grid_index) for voxel in voxel_grid1.get_voxels()])
    voxels2 = set([tuple(voxel.grid_index) for voxel in voxel_grid2.get_voxels()])
    
    # Compute intersection of the voxel sets
    intersection = voxels1.intersection(voxels2)
    
    # Dice Score = 2 * |A ∩ B| / (|A| + |B|)
    dice_score = (2 * len(intersection)) / (len(voxels1) + len(voxels2)) if (len(voxels1) + len(voxels2)) > 0 else 0
    
    return dice_score

In [6]:
class SimilaritySearch:
    def __init__(self):
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        self.load_model_and_checkpoint()
        self.model.eval()

    def load_model_and_checkpoint(self):
        self.model = DGCNN(output_channels=32).to(self.device)
        self.model = self.model.to(self.device)

        self.checkpoint = torch.load("/home/shirshak/00_teeth_similarity_matching/model_ckpt/best_model.pth")
        self.model.load_state_dict(self.checkpoint["model_state_dict"])
        self.model.dp2 = nn.Identity()
        self.model.linear3 = nn.Identity()

    def find_top_n_similar_feature_vectors(self, query_vector, data, top_n=-1):
        all_feature_vectors = np.array([entry['feature_vector'] for entry in data])
        similarities = cosine_similarity(query_vector, all_feature_vectors).flatten()
        if top_n != -1:
            top_n_indices = np.argsort(-similarities)[:top_n]
        else:
            top_n_indices = np.argsort(-similarities)[:top_n]

        similarities = [similarities[index] for index in top_n_indices]
        return top_n_indices, similarities
    
    def get_pid_fid_from_indices(self, indices, data):
        pids = []
        fids = []
        for idx in indices:
            thumbnail_path = data[idx]['thumbnail_location']
            pids.append(thumbnail_path.split('/')[-1].split('_')[0])
            fids.append(thumbnail_path.split('/')[-1].split('.')[0].split('_')[-1])
        return pids, fids
    
    def pack_json(self, pids, fids, similarity_score, dice_score):
        ranked_data = [
            {
                'pid': pid,
                'fid': fid,
                'similarity_score': score,
                'dice_score': dice
            }
            for pid, fid, score, dice in zip(pids, fids, similarity_score, dice_score)
        ]

        # Return the list of dictionaries, FastAPI will handle JSON serialization
        return ranked_data
    
    def get_similarity(self, obj_path):

        # orig_mesh = load_voxel_grid(obj_path)
        data_point_cloud_orig = load_point_cloud(obj_path)
        data_point_cloud = data_point_cloud_orig.to(self.device).unsqueeze(0).permute(0,2,1)

        with torch.no_grad():
            original_feature_256 = self.model(data_point_cloud)

        original_feature_256 = original_feature_256.cpu().numpy()

        feature_data = load_data_from_json('/home/shirshak/00_teeth_similarity_matching/feature_info.json')

        top_10_indices, simil = self.find_top_n_similar_feature_vectors(original_feature_256, feature_data, top_n=-1)

        pids, fids = self.get_pid_fid_from_indices(top_10_indices, feature_data)

        # dices = []
        # category = str()
        # for pid, fid in zip(pids, fids):

        #     if int(fid.split("fid")[-1]) < 30:
        #         category = 'upper'
        #     if int(fid.split("fid")[-1]) > 30:
        #         category = 'lower'

        #     sim_obj_path = f"/home/shirshak/00_teeth_similarity_matching/individual_teeth/{pid}_{category}_{fid}.obj"
        #     similar_mesh = load_voxel_grid(sim_obj_path)
        #     dices.append(compute_voxel_dice_score(orig_mesh, similar_mesh))

        # # print(dices)
        # # return self.pack_json(pids, fids, simil, dices)

        return pids, fids, simil#, dices

    def get_similarity_multiple_teeth(self, teeth_paths, tooth_labels):
        """
        Find the most similar teeth from multiple provided tooth
        teeth_paths ===> [left_adjacent, right_adjacent, opposite]
        tooth_labels ===> [left_tooth_label, right_tooth_label, opposite_tooth_label]
        """
        if len(teeth_paths) != len(tooth_labels):
            raise ValueError("Number of tooth paths must match number of tooth labels")

        print(self.model)
        
        # Similarity for each tooth
        all_results = {}
        for i, tooth_path in tqdm(enumerate(teeth_paths)):

            target_tooth_number = int(tooth_path.split("/")[-1].split(".")[0].split("fid")[1])
            pids, fids, similarity_scores = self.get_similarity(tooth_path)

            for pid, fid, similarity_score in zip(pids, fids, similarity_scores):
                # print(pid)
                # print(fid)
                fid_number = int(fid.split("fid")[-1])

                if fid_number == target_tooth_number: # because fid ====> fid32, so to only extract 32
                    if pid in all_results:
                        all_results[pid].append([fid, similarity_score])
                    else:
                        all_results[pid] = [[fid, similarity_score]]
        
        # print(all_results)

        avg_similarity_score = {}
        for result_key in all_results.keys():
            # print(all_results[result_key])
            # print(all_results[result_key][0][2])
            # print(all_results[result_key][1][2])
            # print(all_results[result_key][2][2])
            # print(len(all_results[result_key]))

            avg_similarity_score[result_key] = sum(item[1] for item in all_results[result_key]) / len(all_results[result_key])

        return avg_similarity_score, all_results

In [7]:
similarity_search = SimilaritySearch()

In [8]:
# left_tooth_path = "/home/shirshak/Teeth3DS_individual_teeth/individual_teeth/0KPHM46Q_upper_fid14.obj"
# right_tooth_path = "/home/shirshak/Teeth3DS_individual_teeth/individual_teeth/0KPHM46Q_upper_fid16.obj"
# opposite_tooth_path = "/home/shirshak/Teeth3DS_individual_teeth/individual_teeth/0KPHM46Q_lower_fid45.obj"


# teeth_paths = [left_tooth_path, right_tooth_path, opposite_tooth_path]

# tooth_labels = [14, 16, 45]


left_tooth_path = "/home/shirshak/Teeth3DS_individual_teeth/individual_teeth/MXWIBTGF_upper_fid12.obj"
right_tooth_path = "/home/shirshak/Teeth3DS_individual_teeth/individual_teeth/MXWIBTGF_upper_fid13.obj"
opposite_tooth_path = "/home/shirshak/Teeth3DS_individual_teeth/individual_teeth/MXWIBTGF_upper_fid14.obj"


teeth_paths = [left_tooth_path, right_tooth_path, opposite_tooth_path]

tooth_labels = [12, 13, 14]

In [9]:
avg_similarity_score, all_results = similarity_search.get_similarity_multiple_teeth(teeth_paths, tooth_labels)

DGCNN(
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn5): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv1): Sequential(
    (0): Conv2d(6, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (conv2): Sequential(
    (0): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
  )
  (conv3): Sequential(
    (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=Fal

3it [00:10,  3.60s/it]


In [10]:
avg_similarity_score = sorted(avg_similarity_score.items(), key=lambda item: item[1], reverse=True)

In [15]:
avg_similarity_score

[('BEONHKS1', np.float64(0.8141875917540076)),
 ('MXWIBTGF', np.float64(0.7849632861977902)),
 ('BNQ3Q4G3', np.float64(0.7700917536536369)),
 ('JY7F4AAF', np.float64(0.7682196819575581)),
 ('01FG729R', np.float64(0.7676360100901748)),
 ('017UWE8F', np.float64(0.763409780412227)),
 ('01FC2D4A', np.float64(0.7625863278424854)),
 ('UNKC1VVC', np.float64(0.7584407397587304)),
 ('baliwish', np.float64(0.7565012525139935)),
 ('WQOGLZY4', np.float64(0.7557677289203136)),
 ('DC8VMT30', np.float64(0.7545690548984597)),
 ('KAHYFGOY', np.float64(0.7544989212911773)),
 ('HE565KIU', np.float64(0.7519870152680306)),
 ('RSUL1J8U', np.float64(0.7489312465730089)),
 ('45UD35GC', np.float64(0.7477625369821452)),
 ('KS0FGXGB', np.float64(0.7467236319387948)),
 ('016PTY8K', np.float64(0.7457588141249346)),
 ('0165W7J4', np.float64(0.7435790906993557)),
 ('01A6HE9H', np.float64(0.7434108197049102)),
 ('0148UTKX', np.float64(0.7430720657484403)),
 ('I9TWNSD1', np.float64(0.7426803779067784)),
 ('0140YFGV', 

In [11]:
print(avg_similarity_score[0][0])
print('-'*80)
print(all_results[avg_similarity_score[0][0]])

BEONHKS1
--------------------------------------------------------------------------------
[['fid12', np.float64(0.8293766156113642)], ['fid13', np.float64(0.8566252201705776)], ['fid14', np.float64(0.7565609394800807)]]


In [17]:
def get_dice(all_results, teeth_paths, pid):
    dice_scores = []
    for row, teeth_path in zip(all_results[pid], teeth_paths):
        if int(row[0].split("fid")[-1]) < 30:
            category = 'upper'
        if int(row[0].split("fid")[-1]) > 30:
            category = 'lower'

        sim_obj_path = f"/home/shirshak/00_teeth_similarity_matching/individual_teeth/{pid}_{category}_{row[0]}.obj"

        orig_mesh = load_voxel_grid(teeth_path)
        similar_mesh = load_voxel_grid(sim_obj_path)

        dice_score = compute_voxel_dice_score(orig_mesh, similar_mesh)

        dice_scores.append(dice_score)

    return dice_scores, float(np.array(dice_score).mean())


In [18]:
dice_scores

NameError: name 'dice_scores' is not defined

In [None]:
pid = avg_similarity_score[0][0]
print(pid)
print(get_dice(all_results, teeth_paths, pid))

BEONHKS1
([0.75, 0.7413793103448276, 0.6825396825396826], 0.6825396825396826)


In [14]:
pid = avg_similarity_score[1][0]
print(pid)
print(get_dice(all_results, teeth_paths, pid))

MXWIBTGF
([1.0, 1.0, 1.0], 1.0)
