In [1]:
import torch
import sys
import importlib
import models.tsg_loss as tsg_loss
from models.tsg_centroid_module import get_model
importlib.reload(tsg_loss)


<module 'models.tsg_loss' from 'G:\\mesh\\segmentation\\models\\tsg_loss.py'>

In [2]:
model = get_model()
model.cuda()
model.load_state_dict(torch.load("pretrained_centroid_model.h5"))

<All keys matched successfully>

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import open3d as o3d
import os
import numpy as np
from models.tsg_centroid_module import get_model
from models.tsg_loss import centroid_loss
from torch.optim.lr_scheduler import ExponentialLR
from models.pointnet2_utils import square_distance

class CenterPointGenerator(Dataset):
    """Toy dataset that serves the same pre-sampled scan / centroid pair repeatedly.

    Each item is read from two NumPy files inside ``data_dir``:

    * ``align_low_sampled.npy`` -- one row per sampled point; columns ``[:6]``
      are used as input features and columns ``[6:]`` as segmentation label(s).
    * ``align_low.txt.npy`` -- one row per ground-truth centroid.

    Args:
        data_dir: directory holding the two files above. Previously this value
            was stored but ignored; it is now actually used.
        length: number of items reported by ``__len__``. The default of 2 lets
            ``DataLoader(..., batch_size=2)`` yield exactly one batch.
    """

    FEATURE_FILE = "align_low_sampled.npy"
    CENTROID_FILE = "align_low.txt.npy"

    def __init__(self, data_dir="data/case1/sampled", length=2):
        self.data_dir = data_dir
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(features, centroids, seg_label)`` as channel-first tensors.

        ``idx`` is accepted for the ``Dataset`` protocol but ignored: every
        index yields the same sample.

        Returns:
            features:  float32 tensor (6, N), transposed from file columns [:6].
            centroids: float32 tensor (D, M), transposed from the (M, D) file.
            seg_label: integer tensor (L, N), transposed from file columns [6:].
        """
        # Load the per-point file once and slice both features and labels from it.
        sampled = np.load(os.path.join(self.data_dir, self.FEATURE_FILE))
        low_feat = torch.from_numpy(sampled[:, :6].astype("float32")).permute(1, 0)
        seg_label = torch.from_numpy(sampled[:, 6:].astype("int")).permute(1, 0)

        centroid = np.load(os.path.join(self.data_dir, self.CENTROID_FILE)).astype("float32")
        centroid = torch.from_numpy(centroid).permute(1, 0)
        return low_feat, centroid, seg_label

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [4]:
# Run the pretrained model on the toy dataset.
# NOTE(review): there is no torch.no_grad() here, so `output` keeps its autograd
# graph (the loss cell below shows a grad_fn); wrap the loop in no_grad if
# gradients are not needed.
point_loader = DataLoader(CenterPointGenerator(), batch_size=2)

model.eval()
# Only the last batch's `points` / `centroids` / `output` survive the loop. With the
# dataset's default length of 2 and batch_size=2 there is exactly one batch.
for item in point_loader:
    points = item[0].cuda()
    centroids = item[1].cuda()
    output = model(points)
    

In [5]:
from models.pointnet2_utils import square_distance
def distance_loss(pred_distance, sample_xyz, centroid):
    """Smooth-L1 loss between predicted and true nearest-centroid distances.

    For every sampled point, the target is its pairwise distance (as returned
    by ``square_distance``, presumably squared Euclidean) to the closest
    ground-truth centroid.

    Args:
        pred_distance: predicted distance per sampled point; any shape that
            reshapes to (B, N), e.g. (B, 1, N).
        sample_xyz: sampled point coordinates, channel-first (B, C, N).
        centroid: ground-truth centroids, channel-first (B, C, M).

    Returns:
        (loss, min_dists): scalar smooth-L1 loss and the (B, N) target tensor.
    """
    n_points = sample_xyz.shape[2]
    # reshape (not view) so non-contiguous predictions are accepted too.
    pred_distance = pred_distance.reshape(-1, n_points)

    dists = square_distance(sample_xyz.permute(0, 2, 1), centroid.permute(0, 2, 1))  # (B, N, M)
    # Only the nearest centroid is needed, so take the min rather than sorting all M.
    min_dists = dists.min(dim=-1).values
    loss = torch.nn.functional.smooth_l1_loss(pred_distance, min_dists)
    return loss, min_dists

def np_to_pcd(arr, color=[1,0,0]):
    """Wrap a coordinate array in an Open3D point cloud painted a single colour.

    Args:
        arr: point coordinates accepted by ``o3d.utility.Vector3dVector``
             (the notebook passes (N, 3) arrays).
        color: RGB triple applied to every point (default red). The default
               list is only read, never mutated.

    Returns:
        An ``o3d.geometry.PointCloud`` with ``points`` and uniform ``colors``.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(arr)
    uniform_colors = [color for _ in range(len(cloud.points))]
    cloud.colors = o3d.utility.Vector3dVector(uniform_colors)
    return cloud


def centroid_dist_loss(pred_offset, sample_xyz, distance, centroid, threshold=0.02):
    """Two-sided nearest-neighbour loss between predicted and true centroids.

    Predicted centroids are ``sample_xyz + pred_offset``. The loss is the sum of:

    1. for every sampled point whose *predicted* distance is <= ``threshold``,
       the distance (from ``square_distance``) between its predicted centroid
       and the nearest ground-truth centroid; and
    2. for every ground-truth centroid, the distance to its nearest predicted
       centroid, counted only where that distance is <= ``threshold``.

    Args:
        pred_offset: predicted offsets, channel-first (B, C, N).
        sample_xyz: sampled coordinates, channel-first (B, C, N).
        distance: predicted distance per sampled point, reshapeable to (B, N).
        centroid: ground-truth centroids, channel-first (B, C, M).
        threshold: cut-off used by both masks (default 0.02, the previous
            hard-coded value).

    Returns:
        Scalar tensor: the *sum* (not the mean) of the selected distances.
    """
    n_points = sample_xyz.shape[2]
    distance = distance.reshape(-1, n_points)
    pred_centroid = pred_offset.permute(0, 2, 1) + sample_xyz.permute(0, 2, 1)  # (B, N, C)
    centroid = centroid.permute(0, 2, 1)  # (B, M, C)

    # Predicted -> ground truth, restricted to points the model predicts are near a centroid.
    pred_to_gt = square_distance(pred_centroid, centroid).min(dim=-1).values  # (B, N)
    loss = torch.masked_select(pred_to_gt, distance.le(threshold)).sum()

    # Ground truth -> predicted, ignoring centroids with no nearby prediction.
    gt_to_pred = square_distance(centroid, pred_centroid).min(dim=-1).values  # (B, M)
    # Out-of-place add: avoids mutating a tensor that is part of the autograd graph.
    loss = loss + torch.masked_select(gt_to_pred, gt_to_pred.le(threshold)).sum()
    return loss

def chamfer_distance_loss(pred_offset, sample_xyz, centroid, threshold=0.02, eps=1e-12):
    """Nearest / second-nearest distance-ratio loss for predicted centroids.

    Despite the name, this is not a Chamfer distance. For every predicted
    centroid (``sample_xyz + pred_offset``) it takes the distances d1 <= d2
    (from ``square_distance``) to the two closest ground-truth centroids. It
    then sums d1 / d2 over points whose second-nearest distance d2 is
    <= ``threshold``. The ratio is small when a prediction is much closer to
    one centroid than to the next; presumably this penalises predictions
    that sit between two centroids.

    Args:
        pred_offset: predicted offsets, channel-first (B, C, N).
        sample_xyz: sampled coordinates, channel-first (B, C, N).
        centroid: ground-truth centroids, channel-first (B, C, M), M >= 2.
        threshold: mask cut-off on d2 (default 0.02, the previous hard-coded value).
        eps: lower bound on the denominator; avoids 0/0 = NaN when two
            ground-truth centroids coincide.

    Returns:
        Scalar tensor: the sum of the selected ratios.

    Raises:
        ValueError: if fewer than two ground-truth centroids are given.
    """
    pred_centroid = pred_offset.permute(0, 2, 1) + sample_xyz.permute(0, 2, 1)  # (B, N, C)
    centroid = centroid.permute(0, 2, 1)  # (B, M, C)
    if centroid.shape[1] < 2:
        raise ValueError("chamfer_distance_loss needs at least two ground-truth centroids")

    # Only the two closest centroids matter; topk avoids sorting all M.
    two_nearest = square_distance(pred_centroid, centroid).topk(2, dim=-1, largest=False).values
    nearest, second = two_nearest[..., 0], two_nearest[..., 1]

    mask = second.le(threshold)
    ratio = nearest / second.clamp_min(eps)
    return torch.masked_select(ratio, mask).sum()

In [6]:
output_gt_dist_cpu = distance_loss(output[5], output[3], centroids)[1].cpu().detach().numpy()[1,:].T

In [7]:
distance_loss(output[5], output[3], centroids)[0]
centroid_dist_loss(output[4], output[3], output[5], centroids)
chamfer_distance_loss(output[4], output[3], centroids)

tensor(13.2991, device='cuda:0', grad_fn=<SumBackward0>)

In [8]:
output_offset_cpu = output[4].cpu().detach().numpy()[0,:].T
output_xyz_cpu = output[3].cpu().detach().numpy()[0,:].T
cen_cpu = centroids.cpu().detach().numpy()[0,:].T
output_dist_cpu = output[5].cpu().detach().numpy()[0,:].T
output_input_xyz_cpu = output[2].cpu().detach().numpy()[0,:].T

In [18]:
import torch
import models.tsg_utils as utils
import importlib
importlib.reload(utils)

# Scratch visualisations (kept commented out) of predicted centroids, sampled
# points and DBSCAN clusters against the ground-truth centroids.
#o3d.visualization.draw_geometries([np_to_pcd(output_offset_cpu+output_xyz_cpu, color=[0,1,0]), np_to_pcd(cen_cpu)], mesh_show_back_face=False,mesh_show_wireframe=True)
#o3d.visualization.draw_geometries([np_to_pcd(output_xyz_cpu, color=[0,1,0]), np_to_pcd(cen_cpu)], mesh_show_back_face=False,mesh_show_wireframe=True)
#o3d.visualization.draw_geometries([np_to_pcd(output_xyz_cpu[output_gt_dist_cpu.reshape(256)<0.02], color=[0,1,0]), np_to_pcd(cen_cpu)], mesh_show_back_face=False,mesh_show_wireframe=True)
#o3d.visualization.draw_geometries([np_to_pcd(output_xyz_cpu[output_dist_cpu.reshape(256)<0.02], color=[0,1,0]), np_to_pcd(cen_cpu)], mesh_show_back_face=False,mesh_show_wireframe=True)
#sampled_db_scan = utils.dbscan_pc(output[3], output[4], output[5])
#o3d.visualization.draw_geometries([np_to_pcd(sampled_db_scan[0], color=[0,1,0]), np_to_pcd(cen_cpu)], mesh_show_back_face=False,mesh_show_wireframe=True)

# NOTE(review): `ne_poi` is not defined anywhere in this notebook; this cell
# only runs with leftover kernel state and will fail on Restart & Run All.
test_cropped = ne_poi.cpu().detach().numpy()[8].T
print(test_cropped.shape)
o3d.visualization.draw_geometries([np_to_pcd(test_cropped, color=[0,1,0]), np_to_pcd(cen_cpu)], mesh_show_back_face=False,mesh_show_wireframe=True)


(4096, 3)


In [169]:
sampled_db_scan

array([list([array([0.20406094, 0.53706926, 0.4847145 ], dtype=float32), array([0.5025953 , 0.28543174, 0.47124866], dtype=float32), array([0.82142013, 0.73999816, 0.47141054], dtype=float32), array([0.69704455, 0.337916  , 0.47812864], dtype=float32), array([0.30652386, 0.3644412 , 0.48304662], dtype=float32), array([0.8159354 , 0.8780104 , 0.48148608], dtype=float32), array([0.81238014, 0.465721  , 0.45423186], dtype=float32), array([0.14626619, 0.6568919 , 0.4773864 ], dtype=float32), array([0.3769169 , 0.30961534, 0.48987198], dtype=float32), array([0.8180218 , 0.5975608 , 0.46019837], dtype=float32), array([0.637063  , 0.30266887, 0.48509288], dtype=float32), array([0.25585672, 0.43449455, 0.47864455], dtype=float32), array([0.7567793 , 0.38862902, 0.47854447], dtype=float32), array([0.44348887, 0.2856421 , 0.48213133], dtype=float32), array([0.57994336, 0.28647494, 0.48113176], dtype=float32)]),
       list([array([0.81653816, 0.8730276 , 0.481049  ], dtype=float32), array([0.257

In [38]:
import torch
import models.tsg_utils as utils
import importlib
importlib.reload(utils)
sampled_db_scan = utils.dbscan_pc(output[3], output[4], output[5])
nearest_n = utils.get_nearest_neighbor_idx(output[2], sampled_db_scan)
nei_features,_ = utils.get_nearest_neighbor_points_with_centroids(output[0],nearest_n, sampled_db_scan)
nei_points,cropped_sampled_clusters = utils.get_nearest_neighbor_points_with_centroids(output[2],nearest_n, sampled_db_scan)

concated = utils.concat_seg_input(nei_features, nei_points, cropped_sampled_clusters)
concated.shape

torch.Size([15, 20, 4096])

In [196]:
len(nearest_n[0])

15

In [241]:
tsg_loss.distance_loss(distance, xyz, centroid)

tensor(0.1347)

In [242]:
tsg_loss.centroid_loss(offset_xyz, xyz, distance, centroid)

tensor(40.4253)

In [243]:
tsg_loss.chamfer_distance_loss(offset_xyz, xyz, centroid)

tensor(38.7016)

In [114]:
xyz = torch.rand(2, 3, 32)
offset_xyz = torch.rand(2, 3, 32)
distance = torch.rand(2, 1, 32)
centroid = torch.rand(1,3,9)
centroid = centroid.repeat(2,1,1)

In [79]:
distance = distance.view(2,-1)

In [80]:
distance.shape

torch.Size([2, 32])

In [58]:
dists = square_distance(centroid, xyz)


In [59]:
# Distances from each source point (row) to each target point (column), sorted
# ascending along the last axis so column 0 is the nearest target, e.g. dists[0, :1].
dists, idx = dists.sort(dim=-1)

In [71]:
dists[:, 0].shape

torch.Size([2, 32])

In [60]:
dists

tensor([[[0.0739, 0.0818, 0.1504, 0.2462, 0.3607, 0.4563, 0.4701, 0.4807,
          0.5001, 0.5645, 0.6194, 0.6347, 0.6497, 0.6565, 0.6925, 0.7109,
          0.7572, 0.7584, 0.7673, 0.8503, 0.8913, 0.8913, 0.9732, 1.0682,
          1.1045, 1.1850, 1.1942, 1.3112, 1.5498, 1.5609, 1.7778, 1.9340],
         [0.0268, 0.0345, 0.0552, 0.0694, 0.1129, 0.1444, 0.1448, 0.1511,
          0.1532, 0.1580, 0.1802, 0.1842, 0.2023, 0.2691, 0.3020, 0.3391,
          0.3510, 0.4399, 0.4724, 0.4911, 0.5349, 0.5467, 0.5713, 0.6040,
          0.6298, 0.6576, 0.6640, 0.7225, 0.7906, 0.8345, 0.9125, 0.9649],
         [0.0132, 0.0598, 0.0689, 0.0768, 0.1261, 0.1396, 0.1563, 0.1909,
          0.2247, 0.2319, 0.2581, 0.2925, 0.3036, 0.3299, 0.3416, 0.3561,
          0.5043, 0.5204, 0.6210, 0.6223, 0.6292, 0.6379, 0.6464, 0.6707,
          0.7058, 0.7392, 0.7809, 0.9788, 1.0331, 1.0527, 1.1398, 1.1427],
         [0.0234, 0.0400, 0.0617, 0.0677, 0.0894, 0.1213, 0.1213, 0.1845,
          0.2183, 0.2332, 0.2362, 0

In [119]:
centroid

tensor([[[0.3751, 0.3346, 0.5492, 0.9032, 0.6448, 0.1328, 0.1621, 0.4308,
          0.1521],
         [0.9229, 0.0413, 0.0669, 0.1982, 0.3968, 0.4664, 0.4098, 0.4463,
          0.9226],
         [0.7608, 0.8707, 0.2984, 0.8064, 0.3226, 0.1087, 0.6555, 0.3357,
          0.2989]],

        [[0.3751, 0.3346, 0.5492, 0.9032, 0.6448, 0.1328, 0.1621, 0.4308,
          0.1521],
         [0.9229, 0.0413, 0.0669, 0.1982, 0.3968, 0.4664, 0.4098, 0.4463,
          0.9226],
         [0.7608, 0.8707, 0.2984, 0.8064, 0.3226, 0.1087, 0.6555, 0.3357,
          0.2989]]])

In [53]:
import open3d as o3d
import numpy as np
from glob import glob
import os
stl_path_list = glob(os.path.join("data","case1","seg","*.stl"))
full_stl_path_list = glob(os.path.join("data","case1","full","*.stl"))
def get_number_from_name(path):
    """Return the integer after the last '_' in the file name.

    Anything from the first '.' of that token onward is ignored,
    e.g. ``".../tooth_31.stl" -> 31``.
    """
    last_token = os.path.basename(path).rsplit("_", 1)[-1]
    number_text, _, _ = last_token.partition(".")
    return int(number_text)
def get_up_from_name(path):
    """True when the file name's last '_'-token (before its first '.') is exactly 'up'."""
    last_token = os.path.basename(path).rsplit("_", 1)[-1].partition(".")[0]
    return last_token == "up"

In [54]:
# For each full scan: Poisson-sample 16000 points and normalise xyz into [0, 1]
# with a single shared offset (global min) and scale (global max). Then, for every
# per-tooth mesh of the matching jaw, Poisson-sample 9000 points, apply the same
# normalisation and record their mean as that tooth's centroid. The centroids are
# saved as "<scan name with stl->txt>.npy".
# NOTE(review): the `pcd` objects are built but never written (the write calls are
# commented out). Sampling each tooth only to average it is slow (~2 s per tooth
# in the log); averaging raw vertices, as a later cell does, is cheaper.
for global_item in full_stl_path_list:
    is_up = get_up_from_name(global_item) 
    global_mesh = o3d.io.read_triangle_mesh(global_item)
    global_mesh = global_mesh.sample_points_poisson_disk(16000)
    global_mesh_arr = np.asarray(global_mesh.points)
    global_mesh_n_arr = np.asarray(global_mesh.normals)
    
    global_min = np.min(global_mesh_arr)
    global_mesh_arr -= global_min
    global_max = np.max(global_mesh_arr)
    global_mesh_arr /= (global_max)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(global_mesh_arr)
    pcd.normals = o3d.utility.Vector3dVector(global_mesh_n_arr)
    #o3d.io.write_point_cloud(os.path.join("data","case1","sampled",os.path.basename(global_item).replace("stl","ply")), pcd)
    cp_arr = [] 
    for item in stl_path_list:
        # Jaw filter: tooth numbers >= 30 are kept only for a non-"up" scan,
        # numbers < 30 only for an "up" scan.
        if(get_number_from_name(item)>=30):
            if is_up:
                continue
        else:
            if not is_up:
                continue
        mesh = o3d.io.read_triangle_mesh(item)
        
        import time
        start = time.time()  # record the start time
        # timed work
        mesh = mesh.sample_points_poisson_disk(9000)
        print("time :", time.time() - start)  # now - start = elapsed seconds
        
        mesh_arr = np.asarray(mesh.points)
        mesh_n_arr = np.asarray(mesh.normals)
        # Same offset/scale as the full scan so centroids share its coordinate frame.
        mesh_arr = (mesh_arr - global_min) / (global_max)
        
        cp_arr.append(np.mean(mesh_arr,axis=0))
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(mesh_arr)
        pcd.normals = o3d.utility.Vector3dVector(mesh_n_arr)
        #o3d.io.write_point_cloud(os.path.join("data","case1","sampled",os.path.basename(item).replace("stl","ply")), pcd)

    # np.save appends ".npy" to the "...txt" name.
    np.save(os.path.join("data","case1","sampled",os.path.basename(global_item).replace("stl","txt")), cp_arr)

time : 1.9586846828460693
time : 2.003563404083252
time : 1.986609697341919


KeyboardInterrupt: 

In [55]:
from sklearn.neighbors import KDTree

# Label each vertex of the first full scan with the tooth number (tooth id % 10)
# of the per-tooth mesh whose vertices map onto it. Column 3 holds the label;
# 0 means "no tooth" (e.g. gingiva).
global_mesh = o3d.io.read_triangle_mesh(full_stl_path_list[0])
global_mesh = global_mesh.remove_duplicated_vertices()
is_up = get_up_from_name(full_stl_path_list[0])
global_mesh_arr = np.asarray(global_mesh.vertices)
global_mesh_arr = np.concatenate([global_mesh_arr, np.zeros((global_mesh_arr.shape[0], 1))], axis=1)

tree = KDTree(global_mesh_arr[:, :3], leaf_size=2)

for stl_path in stl_path_list:
    tooth_id = get_number_from_name(stl_path)
    # Jaw filter: ids >= 30 are kept only for a non-"up" scan and ids < 30 only
    # for an "up" scan (presumably FDI lower/upper jaw numbering).
    if (tooth_id >= 30) == is_up:
        continue
    tooth_number = tooth_id % 10
    mesh = o3d.io.read_triangle_mesh(stl_path)
    mesh = mesh.remove_duplicated_vertices()
    mesh_arr = np.asarray(mesh.vertices)
    _, indexs = tree.query(mesh_arr, k=1)
    # Vectorised: label the nearest full-scan vertex of every tooth vertex.
    global_mesh_arr[indexs[:, 0], 3] = tooth_number




In [56]:
# Poisson-sample the labelled scan. Each sample gets the tooth label of its
# nearest labelled vertex, provided that vertex is within LABEL_RADIUS.
N_SAMPLE_POINTS = 16000
LABEL_RADIUS = 0.2  # in the raw mesh units (before normalisation)

sample_global_mesh = global_mesh.sample_points_poisson_disk(N_SAMPLE_POINTS)

sample_global_mesh_arr = np.asarray(sample_global_mesh.points)
sample_global_mesh_n_arr = np.asarray(sample_global_mesh.normals)
# Columns: [x, y, z, nx, ny, nz, label]; the size follows the actual sample count.
sample_global_mesh_arr = np.concatenate(
    [sample_global_mesh_arr, sample_global_mesh_n_arr, np.zeros((sample_global_mesh_arr.shape[0], 1))],
    axis=1,
)

# KDTree.query_radius does not sort neighbours by distance, so "first index within
# the radius" was an arbitrary vertex, not the closest one. Query the true nearest
# neighbour instead and apply the radius as a cut-off.
nearest_dist, nearest_idx = tree.query(sample_global_mesh_arr[:, :3], k=1)
within_radius = nearest_dist[:, 0] <= LABEL_RADIUS
sample_global_mesh_arr[within_radius, 6] = global_mesh_arr[nearest_idx[within_radius, 0], 3]

# Normalise xyz into [0, 1] with one shared offset and scale. sample_min and
# sample_max are reused by the centroid cell so both share this frame.
sample_min = np.min(sample_global_mesh_arr[:, :3])
sample_global_mesh_arr[:, :3] -= sample_min
sample_max = np.max(sample_global_mesh_arr[:, :3])
sample_global_mesh_arr[:, :3] /= sample_max

# NOTE(review): the output name says "low" regardless of is_up.
np.save(os.path.join("data", "case1", "sampled", "align_low_sampled"), sample_global_mesh_arr)

In [60]:
# Ground-truth tooth centroids: the mean vertex position of each per-tooth mesh in
# the same jaw as the full scan. They are normalised with the sampled scan's
# sample_min / sample_max so both share one coordinate frame.
cp_arr = []
for stl_path in stl_path_list:
    # Jaw filter: numbers >= 30 are kept only for a non-"up" scan, < 30 only for an "up" scan.
    if(get_number_from_name(stl_path)>=30):
        if is_up:
            continue
    else:
        if not is_up:
            continue
    mesh = o3d.io.read_triangle_mesh(stl_path)
    mesh = mesh.remove_duplicated_vertices()
    mesh_arr = np.asarray(mesh.vertices)
    # NOTE(review): np.asarray may share memory with mesh.vertices, so these in-place
    # ops likely also modify the mesh; harmless here because the mesh is discarded.
    mesh_arr -= sample_min
    mesh_arr /= sample_max
    cp_arr.append(np.mean(mesh_arr,axis=0))
# np.save appends ".npy", so this writes align_low.txt.npy (the file CenterPointGenerator loads).
np.save(os.path.join("data","case1","sampled","align_low.txt"), np.array(cp_arr))

In [59]:
np.array(cp_arr)

array([[0.57526349, 0.28411775, 0.47570913],
       [0.63925299, 0.30164196, 0.48554865],
       [0.69622766, 0.33861827, 0.4764211 ],
       [0.76302964, 0.39599004, 0.47923847],
       [0.812054  , 0.46345009, 0.45362384],
       [0.81951826, 0.5987912 , 0.46122558],
       [0.82055346, 0.73099579, 0.47666455],
       [0.81358786, 0.87691399, 0.48686669],
       [0.51238362, 0.28726439, 0.46832352],
       [0.44248962, 0.27834032, 0.4801039 ],
       [0.37989492, 0.30959741, 0.48769708],
       [0.30246302, 0.36574401, 0.48404154],
       [0.26126574, 0.43361285, 0.47823913],
       [0.21187077, 0.53071579, 0.48924161],
       [0.1513252 , 0.64290975, 0.48427326]])

In [64]:
label_colors = [[1,0,0],[0,1,0],[0,0,1],[1,1,0],[1,0,1],[0,1,1],[1/2,0,0],[0,1/2,0],[0,0,1/2]]

array([[-39.4018631 ,   2.27123451, -14.50884342,   0.        ],
       [-39.38402939,   2.10378671, -14.61677551,   0.        ],
       [-39.10272217,   2.12576747, -14.5669508 ,   0.        ],
       ...,
       [ 18.34498787, -10.02244377,  -9.96066666,   0.        ],
       [ 18.50905037, -10.22513008,  -9.99820709,   0.        ],
       [ 17.71750832,  -9.09092617,  -9.82278156,   0.        ]])

In [63]:
# NOTE(review): np_to_by_label is defined in a later cell; run that cell first
# (or move the definition above) or this fails on Restart & Run All.
o3d.visualization.draw_geometries([np_to_by_label(sample_global_mesh_arr,6), np_to_pcd(cp_arr,[0,1,0])], mesh_show_back_face=False,mesh_show_wireframe=True)
    

In [62]:
def np_to_by_label(arr, axis=3):
    """Build an Open3D point cloud coloured by an integer label column.

    Args:
        arr: 2-D array whose first three columns are xyz and whose column
             ``axis`` holds per-point labels.
        axis: index of the label column (e.g. 6 for [xyz, normal, label] rows).

    Returns:
        A PointCloud in which label k (0 <= k <= 8) gets the k-th colour of a
        fixed 9-colour palette; any other label value stays black.
    """
    palette = [[1,0,0],[0,1,0],[0,0,1],[1,1,0],[1,0,1],[0,1,1],[1/2,0,0],[0,1/2,0],[0,0,1/2]]
    labels = arr[:, axis]
    colors = np.zeros((arr.shape[0], 3))
    for label_value, rgb in enumerate(palette):
        colors[labels == label_value] = rgb

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(arr[:, :3])
    pcd.colors = o3d.utility.Vector3dVector(colors)
    return pcd


In [133]:
import torch
import models.tsg_utils as utils
import importlib
importlib.reload(utils)


<module 'models.tsg_utils' from 'G:\\mesh\\segmentation\\models\\tsg_utils.py'>

In [None]:
torch.rand((2,3))

In [158]:
import torch

a = torch.tensor([2., 3.], requires_grad=False)
b = torch.tensor([6., 4.], requires_grad=False)
c = torch.tensor([6., 4.], requires_grad=True)
d = torch.tensor([6., 4.], requires_grad=True)
optimizer = torch.optim.Adam([a,b], lr=1e-3)

In [159]:
Q = 3*a**3 - b**2
R = 2*a**2 - b**4
T = 2*c**2 - d**4
W = 4*c**2 - d**3

loss = torch.sum(4-Q)
loss_r = torch.sum(3-R)
loss_t = torch.sum(Q-R)
loss_c1 = torch.sum(T-R)
loss_c2 = torch.sum(W-R)

In [160]:
loss_c1.backward()

In [161]:
loss_c2.backward()

In [141]:
(loss+loss_r+loss_t).backward()

In [130]:
optimizer.zero_grad()

In [131]:
loss.backward(retain_graph=True)

In [132]:
loss_r.backward(retain_graph=True)

In [133]:
loss_t.backward()

In [60]:
optimizer.step()

In [142]:
print(a.grad)
print(b.grad)

None
tensor([1728.,  512.])


In [75]:
a

tensor([2., 3.], requires_grad=True)

In [28]:
loss.backward()

In [31]:
optimizer.state_dict()

{'state': {},
 'param_groups': [{'lr': 0.001,
   'betas': (0.9, 0.999),
   'eps': 1e-08,
   'weight_decay': 0,
   'amsgrad': False,
   'params': [0, 1]}]}

In [26]:

print(a.grad)
print(b.grad)

tensor([-36., -81.])
tensor([12.,  8.])


In [6]:
Q.backward()

RuntimeError: grad can be implicitly created only for scalar outputs