In [1]:
import trimesh
import numpy as np
import torch
import open3d as o3d
import json
import torch.nn as nn 
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity
# mesh = trimesh.load("/home/shirshak/STL_client_data/11/2024-07-05_00003-InoueYui-11Cr-11-crown_cad.stl")
# mesh.show()
import glob 
from tqdm import tqdm 
import time
import os 

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def intermediate(x, xx):
    inner = -2*torch.matmul(x.transpose(2, 1), x)
    torch.cuda.empty_cache()
    return -xx - inner

def knn(x, k):
    x = x.to(torch.float16)
    xx = torch.sum(x**2, dim=1, keepdim=True)
    pairwise_distance = intermediate(x, xx) - xx.transpose(2, 1)
    torch.cuda.empty_cache()
    idx = pairwise_distance.topk(k=k, dim=-1)[1]
    return idx

def get_graph_feature(x, device, k=20, idx=None, dim9=False):
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        if dim9 is False:
            idx = knn(x, k=k)
        else:
            idx = knn(x[:, 6:], k=k)

    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1)*num_points

    idx = idx + idx_base
    idx = idx.view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()
    feature = x.view(batch_size*num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    
    feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    return feature

class DGCNN(nn.Module):
    def __init__(self, device, output_channels=16,input_dims=3, k =20, emb_dims = 1024, dropout= 0.5):
        super(DGCNN, self).__init__()
        self.device = device
        self.input_dims = input_dims
        self.k = k
        self.emb_dims = emb_dims
        self.dropout = dropout
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(self.emb_dims)

        self.conv1 = nn.Sequential(nn.Conv2d(self.input_dims*2, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 128, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(128*2, 256, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv1d(512, self.emb_dims, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.linear1 = nn.Linear(self.emb_dims*2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=self.dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=self.dropout)
        self.linear3 = nn.Linear(256, output_channels)

    def forward(self, x):
        batch_size = x.size(0)
        x = get_graph_feature(x, k=self.k, device=self.device)      # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        x = self.conv1(x)                       # (batch_size, 3*2, num_points, k) -> (batch_size, 64, num_points, k)
        x1 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x1, k=self.k, device=self.device)     # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv2(x)                       # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x2 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x2, k=self.k, device=self.device)     # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv3(x)                       # (batch_size, 64*2, num_points, k) -> (batch_size, 128, num_points, k)
        x3 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 128, num_points, k) -> (batch_size, 128, num_points)

        x = get_graph_feature(x3, k=self.k, device=self.device)     # (batch_size, 128, num_points) -> (batch_size, 128*2, num_points, k)
        x = self.conv4(x)                       # (batch_size, 128*2, num_points, k) -> (batch_size, 256, num_points, k)
        x4 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 256, num_points, k) -> (batch_size, 256, num_points)

        x = torch.cat((x1, x2, x3, x4), dim=1)  # (batch_size, 64+64+128+256, num_points)

        x = self.conv5(x)                       # (batch_size, 64+64+128+256, num_points) -> (batch_size, emb_dims, num_points)
        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
        x = torch.cat((x1, x2), 1)              # (batch_size, emb_dims*2)

        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2) # (batch_size, emb_dims*2) -> (batch_size, 512)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2) # (batch_size, 512) -> (batch_size, 256)
        x = self.dp2(x)
        # x = self.linear3(x)                                             # (batch_size, 256) -> (batch_size, output_channels)
        
        return x

In [3]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

model = DGCNN(device=device, output_channels=32)
model = model.to(device)

checkpoint = torch.load("best_model.pth")
model.load_state_dict(checkpoint["model_state_dict"])

model.bn7 = nn.Identity()
# model.dp1 = nn.Identity()
# model.dp2 = nn.Identity()
model.linear3 = nn.Identity()


In [4]:
def load_point_cloud(file_path, num_points=2048):
    """Load and preprocess a single obj file into point cloud file."""
    mesh = o3d.io.read_triangle_mesh(file_path)
    mesh.compute_vertex_normals()
    o3d.utility.random.seed(12345)

    pcd = mesh.sample_points_uniformly(number_of_points=num_points)

    points = np.asarray(pcd.points, dtype=np.float32)

    return torch.from_numpy(points)

In [7]:
left_tooth_path = "/home/shirshak/Teeth3DS_individual_teeth/individual_teeth/Z1OCAFH9_upper_fid14.obj"
data_point_cloud_orig = load_point_cloud(left_tooth_path, num_points=2048)
data_point_cloud = data_point_cloud_orig.to(device).unsqueeze(0).permute(0,2,1)
data_point_cloud

tensor([[[ 24.5341,  18.5766,  22.0546,  ...,  23.0728,  22.1399,  24.6711],
         [ -4.8275,  -2.5837,  -7.5689,  ...,  -4.8816,  -2.8338,  -4.6487],
         [-87.4216, -86.0271, -89.2503,  ..., -85.9925, -87.4795, -88.4821]]],
       device='cuda:0')

In [18]:
model.eval()
original_feature_256 = model(data_point_cloud)
original_feature_256

tensor([[ 2.6096e+02, -1.6889e+01, -1.5018e+01, -5.0474e+01,  1.1233e+02,
         -2.1119e+00,  1.3017e+02, -5.6133e+01,  9.9382e+01, -4.7793e+01,
         -4.0620e+01, -6.1533e+01,  7.5770e+01, -2.7090e+01, -3.9253e+01,
         -2.1242e+01,  6.5983e+01,  1.3417e+02, -1.3519e+01,  1.6413e+02,
         -7.2891e+00,  1.8226e+02,  2.2112e+02, -5.3093e+01,  5.7111e+01,
          3.2182e+02, -1.1316e+00,  7.6645e+00, -6.5486e+01,  4.8812e+01,
         -5.4742e+01, -5.3664e+01, -2.5112e+01,  2.7728e+02, -1.7659e+01,
          2.0422e+02,  2.5514e+02, -2.8865e+00, -4.9963e+01,  7.1159e+01,
          1.1167e+02, -5.8754e+01, -1.0713e+01,  5.4551e+01, -6.2837e+01,
         -2.1939e+01, -4.8000e+01, -1.1533e+01,  1.2806e+02,  4.4449e+01,
          6.2052e+01, -3.5632e+01, -1.2182e+00,  2.5460e+02, -2.8137e+01,
         -3.5505e+01,  9.1578e+01, -2.5945e+01,  2.7675e+01,  4.7414e+01,
         -5.1192e+01,  4.8994e+01, -3.0818e+01, -4.5021e+01, -1.5069e+01,
          1.0353e+02,  4.5879e+01,  7.