In [1]:
import os, argparse
import pickle
import torch
import time
from tqdm import tqdm
import numpy as np
from statistics import mean
from easydict import EasyDict as edict
import yaml
from matplotlib.pyplot import cm as colormap

from easydict import EasyDict

from regtr.cvhelpers.misc import prepare_logger

from regtr.data_loaders import get_dataloader
from regtr.models import get_model
from regtr.trainer import Trainer
from regtr.utils.misc import load_config
from regtr.utils.se3_numpy import se3_transform
from regtr.utils.se3_torch import se3_transform as se3_transform_torch
from regtr.data_loaders.eardataset import EarDataset, EarDatasetTest

import regtr.cvhelpers.visualization as cvv
import regtr.cvhelpers.colors as colors
from regtr.cvhelpers.torch_helpers import to_numpy

from deformationpyramid.model.geometry import *
from deformationpyramid.model.loss import compute_truncated_chamfer_distance
from deformationpyramid.model.registration import Registration
from deformationpyramid.utils.benchmark_utils import setup_seed
from deformationpyramid.utils.tiktok import Timers

from sklearn.neighbors import KDTree
from scipy.spatial.distance import cdist

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def visualize_result_regtr(src_xyz: np.ndarray, tgt_xyz: np.ndarray,
                     src_kp: np.ndarray, src2tgt: np.ndarray,
                     src_overlap: np.ndarray,
                     pose: np.ndarray,
                     threshold: float = 0.5):
    """Visualizes the RegTR (rigid) registration result in four renderers:
       - Top-left: Source point cloud and keypoints
       - Top-right: Target point cloud and predicted corresponding kp positions
       - Bottom-left: Source and target point clouds before registration
       - Bottom-right: Source and target point clouds after registration
    Keypoints are coloured by their predicted overlap probability (coolwarm colormap).
    Press 'q' to exit.
    Args:
        src_xyz: Source point cloud (M x 3)
        tgt_xyz: Target point cloud (N x 3)
        src_kp: Source keypoints (M' x 3)
        src2tgt: Corresponding positions of src_kp in target (M' x 3)
        src_overlap: Predicted probability the point lies in the overlapping region;
                     only column 0 is used (src_overlap[:, 0])
        pose: Estimated rigid transform, applied to src_xyz with se3_transform
        threshold: For clarity, only keypoints with predicted overlap prob > threshold
                   are shown. Set to 0 to show all keypoints, and a larger number to show
                   only points strictly within the overlap region.
    """

    small_pt_size = 4
    large_pt_size = 8
    # Pass the colormap by name: matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7
    # and removed in 3.9, while ScalarMappable accepts a colormap name directly.
    color_mapper = colormap.ScalarMappable(norm=None, cmap='coolwarm')
    overlap_colors = (color_mapper.to_rgba(src_overlap[:, 0])[:, :3] * 255).astype(np.uint8)
    m = src_overlap[:, 0] > threshold

    vis = cvv.Visualizer(
        win_size=(1600, 1000),
        num_renderers=4,
        bg_color=(255, 255, 255)
    )

    vis.add_object(
        cvv.create_point_cloud(src_xyz, colors=colors.BLUE, pt_size=small_pt_size, alpha=0.25),
        renderer_idx=0
    )
    vis.add_object(
        cvv.create_point_cloud(src_kp[m, :], colors=overlap_colors[m, :], pt_size=large_pt_size),
        renderer_idx=0
    )

    vis.add_object(
        cvv.create_point_cloud(tgt_xyz, colors=colors.ORANGE, pt_size=small_pt_size, alpha=0.25),
        renderer_idx=1
    )
    vis.add_object(
        cvv.create_point_cloud(src2tgt[m, :], colors=overlap_colors[m, :], pt_size=large_pt_size),
        renderer_idx=1
    )

    # Before registration
    vis.add_object(
        cvv.create_point_cloud(src_xyz, colors=colors.BLUE, pt_size=small_pt_size),
        renderer_idx=2
    )
    vis.add_object(
        cvv.create_point_cloud(tgt_xyz, colors=colors.ORANGE, pt_size=small_pt_size),
        renderer_idx=2
    )

    # After registration
    vis.add_object(
        cvv.create_point_cloud(se3_transform(pose, src_xyz), colors=colors.BLUE, pt_size=small_pt_size),
        renderer_idx=3
    )
    vis.add_object(
        cvv.create_point_cloud(tgt_xyz, colors=colors.ORANGE, pt_size=small_pt_size),
        renderer_idx=3
    )

    vis.set_titles(['Source point cloud (with keypoints)',
                    'Target point cloud (with predicted source keypoint positions)',
                    'Before Registration',
                    'After Registration'])

    vis.reset_camera()
    vis.start()

def visualize_result(src_xyz: np.ndarray, tgt_xyz: np.ndarray,
                     displacement: np.ndarray, inds: np.ndarray):
    """Visualizes a non-rigid registration result in two renderers:
       - Left: source (blue) and target (orange) before registration; the source
         points selected by ``inds`` are drawn again with pt_size=4
       - Right: displaced source (src_xyz + displacement) and target after registration
    Press 'q' to exit.
    Args:
        src_xyz: Source point cloud (M x 3)
        tgt_xyz: Target point cloud (N x 3)
        displacement: vector field to transform src to tgt non-rigidly (added to src_xyz)
        inds: indices into src_xyz of the points to highlight
    torch tensors (on any device) are also accepted and converted to NumPy first.
    """

    def _as_numpy(x):
        # The viewer is fed NumPy data; detach and move tensors (e.g. CUDA outputs) first.
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    src_xyz = _as_numpy(src_xyz)
    tgt_xyz = _as_numpy(tgt_xyz)
    displacement = _as_numpy(displacement)
    inds = _as_numpy(inds)

    vis = cvv.Visualizer(
        win_size=(1600, 1000),
        num_renderers=2)

    # Before registration
    vis.add_object(
        cvv.create_point_cloud(src_xyz, colors=colors.BLUE),
        renderer_idx=0
    )
    vis.add_object(
        cvv.create_point_cloud(tgt_xyz, colors=colors.ORANGE),
        renderer_idx=0
    )
    vis.add_object(
        cvv.create_point_cloud(src_xyz[inds], colors=colors.BLUE, pt_size=4),
        renderer_idx=0
    )

    # After registration
    vis.add_object(
        cvv.create_point_cloud(src_xyz + displacement, colors=colors.BLUE),
        renderer_idx=1
    )
    vis.add_object(
        cvv.create_point_cloud(tgt_xyz, colors=colors.ORANGE),
        renderer_idx=1
    )

    vis.set_titles(['Before Registration', 'After Registration'])

    vis.reset_camera()
    vis.start()

# Landmark name -> point indices into the source point cloud (used below as warped_pcd[landmarks[k]]).
with open('mesh_dataset/landmarks/landmarks.pkl', 'rb') as f:
    landmarks = pickle.load(f)
# The same index lists as tensors, in the dict's key order.
l_inds = list(map(torch.tensor, landmarks.values()))

def mean_displacement_error(dis_pred, dis_gt):
    """Mean Euclidean distance between predicted and ground-truth displacement vectors.

    Both inputs are torch tensors whose rows are per-point vectors; they are moved to
    the CPU before subtracting, so they may live on different devices.
    Returns the mean over rows of the row-wise L2 norm of the difference.
    """
    residual = dis_pred.cpu() - dis_gt.cpu()
    per_point_error = np.linalg.norm(residual, axis=1)
    return per_point_error.mean()

def denorm(arr, metadata):
    """Map normalised values back to the original scale: ``arr * std + mean``.

    ``metadata`` must provide ``'std'`` and ``'mean'`` entries that broadcast against
    ``arr`` (presumably the statistics used for a (x - mean) / std normalisation).
    """
    scale = metadata['std']
    offset = metadata['mean']
    return arr * scale + offset

def landmark_loss(pred, intra):
    """One-sided landmark distance between predicted and reference landmark sets.

    For each segment, every reference point in ``intra[segment]`` is matched to its
    nearest point in the corresponding predicted set and these distances are averaged.
    ``pred`` is paired with ``intra`` by position: ``pred[i]`` belongs to the i-th key
    of ``intra``.

    Args:
        pred: sequence of (P_i x D) predicted point sets.
        intra: mapping segment name -> (Q_i x D) reference point set, same length and
            order as ``pred``.

    Returns:
        Tuple ``(mean of the per-segment losses, {segment name: per-segment loss})``.
        With no segments the overall loss is NaN (instead of a ZeroDivisionError).

    Raises:
        AssertionError: if ``pred`` and ``intra`` have different lengths.
    """
    assert len(pred) == len(intra), 'len(pred) != len(intra)'
    single_loss = {}
    for segment, pred_points in zip(intra.keys(), pred):
        # cdist is (P x Q); min over axis 0 gives each reference point's distance
        # to its closest predicted point.
        single_loss[segment] = cdist(pred_points, intra[segment]).min(0).mean()
    if not single_loss:
        return float('nan'), single_loss
    # np.mean rather than the builtin sum(): unaffected if the name `sum` has been
    # rebound in the notebook namespace.
    return np.mean(list(single_loss.values())), single_loss

def computeCDRegTr(src_xyz: torch.Tensor, tgt_xyz: torch.Tensor, pose: torch.Tensor, metadata: dict, red='mean'):
    """Chamfer distance after applying the rigid RegTR pose, in de-normalised units.

    ``src_xyz`` is transformed by ``pose`` (se3_transform_torch), then both clouds are
    mapped back to the original scale with ``denorm(.., metadata)`` before calling
    ``compute_truncated_chamfer_distance`` with ``trunc=1e6`` (presumably large enough
    to disable truncation).

    Args:
        src_xyz, tgt_xyz: normalised point clouds (batched, as expected by the chamfer helper).
        pose: rigid transform(s) for se3_transform_torch.
        metadata: normalisation statistics with 'mean' and 'std'.
        red: batch reduction passed through as ``batch_reduction``.
    """
    transformed = se3_transform_torch(pose, src_xyz)
    transformed = denorm(transformed, metadata)
    return compute_truncated_chamfer_distance(transformed, denorm(tgt_xyz, metadata), batch_reduction=red, trunc=1e6)

def computeCD(src_xyz: torch.Tensor, tgt_xyz: torch.Tensor, displ: torch.Tensor, red='mean'):
    """Chamfer distance between the displaced source ``src_xyz + displ`` and ``tgt_xyz``.

    Calls ``compute_truncated_chamfer_distance`` with ``trunc=1e6`` (presumably large
    enough to disable truncation) and returns the result on the CPU. Callers pass
    batched inputs (e.g. ``x[None]``).

    Args:
        src_xyz: source point cloud(s).
        tgt_xyz: target point cloud(s).
        displ: per-point displacement added to ``src_xyz`` (zeros for a rigid-only result).
        red: batch reduction passed through as ``batch_reduction``.
    """
    transformed = displ + src_xyz
    return compute_truncated_chamfer_distance(transformed, tgt_xyz, batch_reduction=red, trunc=1e6).cpu()

In [3]:
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)
class conf:
    """Stand-in for the command-line options of the RegTR test script.

    Declared as a dataclass so every option is a real *instance* attribute (and still
    readable on the class as a default). With plain class attributes ``vars(opt)`` was
    empty, which is presumably why the "Arguments:" line logged by prepare_logger was
    blank. ``eq=False`` keeps the default identity-based hashing of a plain class.
    """
    benchmark: str = '3DMatch'
    config: Optional[str] = 'config/eardataset_regtr.yaml'
    logdir: str = 'logs'
    dev: bool = False
    num_workers: int = 0
    name: Optional[str] = None
    # Earlier checkpoint: 'D:/logs/eardataset/230224_074601_regtr_eardataset_standard/ckpt/model-88000.pth'
    # NOTE(review): absolute, machine-specific path; consider making it configurable.
    resume: Optional[str] = 'D:/mesh2mesh/trainResults/eardataset/240107_080636_regtr_eardataset_standard_resume/ckpt/model-66000.pth'

In [4]:
opt = conf()
logger, opt.log_path = prepare_logger(opt)

# Resolve the config file. Mirrors the CLI script: with --resume and no --config, the
# config.yaml in the parent of the checkpoint folder is used.
if opt.config is None:
    if opt.resume is None or not os.path.exists(opt.resume):
        logger.error('--config needs to be supplied unless resuming from checkpoint')
        # Raise instead of exit(): in a notebook exit() ends the kernel session instead of
        # just stopping this cell.
        raise FileNotFoundError('--config needs to be supplied unless resuming from checkpoint')
    else:
        resume_folder = opt.resume if os.path.isdir(opt.resume) else os.path.dirname(opt.resume)
        opt.config = os.path.normpath(os.path.join(resume_folder, '../config.yaml'))
        if os.path.exists(opt.config):
            logger.info(f'Using config file from checkpoint directory: {opt.config}')
        else:
            logger.error('Config not found in resume directory')
            raise FileNotFoundError(f'Config not found in resume directory: {opt.config}')
else:
    # Save a copy of the config (prefixed with its original path) into the log folder
    config_out_fname = os.path.join(opt.log_path, 'config.yaml')
    with open(opt.config, 'r') as in_fid, open(config_out_fname, 'w') as out_fid:
        out_fid.write(f'# Original file name: {opt.config}\n')
        out_fid.write(in_fid.read())

cfg = EasyDict(load_config(opt.config))
# Normalisation statistics ('mean' / 'std') consumed by denorm().
with open(os.path.join(cfg.root, 'metadata.pkl'), 'rb') as f:
    metadata = pickle.load(f)

# Config of the non-rigid (NDP) stage.
# NOTE(review): yaml.Loader can construct arbitrary Python objects; prefer yaml.safe_load
# if this file is plain YAML (confirm before switching).
with open(cfg.ndp_config_path, 'r') as f:
    p_config = yaml.load(f, Loader=yaml.Loader)

p_config = edict(p_config)
p_config.device = torch.cuda.current_device()  # requires an available CUDA device

# Force the test split. Because of this override, the `cfg.dataset == 'eardataset'`
# branches further down are not taken unless this line is changed.
cfg.dataset = 'eardataset_test'
test_loader = get_dataloader(cfg, phase='test')
Model = get_model(cfg.model)
model = Model(cfg)
trainer = Trainer(opt, niter=cfg.niter, grad_clip=cfg.grad_clip)

model_nonrigid = Registration(p_config)
timer = Timers()

[32m02/27 22:00:33[0m [1;30m[INFO][0m [34mroot[0m - Output and logs will be saved to logs\240227_220033
[32m02/27 22:00:33[0m [1;30m[INFO][0m [34mregtr.cvhelpers.misc[0m - Command: C:\Users\chenp\AppData\Roaming\Python\Python310\site-packages\ipykernel_launcher.py --f=c:\Users\chenp\AppData\Roaming\jupyter\runtime\kernel-v2-21752Hx2tmn2EHxlU.json
[32m02/27 22:00:33[0m [1;30m[INFO][0m [34mregtr.cvhelpers.misc[0m - Source is from Commit 95412c13 (2023-09-06): Fixed one small BUG in test_script.py!
[32m02/27 22:00:33[0m [1;30m[INFO][0m [34mregtr.cvhelpers.misc[0m - Arguments: 
[32m02/27 22:00:34[0m [1;30m[INFO][0m [34mRegTR[0m - Instantiating model RegTR
[32m02/27 22:00:35[0m [1;30m[INFO][0m [34mRegTR[0m - Loss weighting: {'overlap_5': 1.0, 'feature_5': 0.1, 'corr_5': 1.0, 'feature_un': 0.0}
[32m02/27 22:00:35[0m [1;30m[INFO][0m [34mRegTR[0m - Config: d_embed:64, nheads:8, pre_norm:True, use_pos_emb:True, sa_val_has_pos_emb:True, ca_val_has_pos_em

In [5]:
# Run RegTR on the whole test split (the checkpoint from opt.resume is loaded, see log).
# `outputs` is consumed below as a sequence of per-batch result dicts.
outputs = trainer.test(model, test_loader)

[32m02/27 22:00:39[0m [1;30m[INFO][0m [34mCheckPointManager[0m - Loaded models from D:/mesh2mesh/trainResults/eardataset/240107_080636_regtr_eardataset_standard_resume/ckpt/model-66000.pth
                                                                                

In [6]:
# Flatten the per-batch outputs of trainer.test into per-pair lists, then stack the
# ones that are combined into a single tensor.
pred_src, pred_tgt, tgt_gt, pred_pose = [], [], [], []
pred_src_kp, pred_tgt_kp, pred_overlap_score = [], [], []
displ_gt, inds, src_paths = [], [], []

for batch in outputs:
    pred_src += batch['src_xyz']
    pred_tgt += batch['tgt_xyz']
    tgt_gt += batch['full_tgt_xyz']
    pred_pose += batch['pose'][-1]  # only the last entry of 'pose'
    pred_src_kp += batch['src_kp']
    pred_tgt_kp += batch['src_kp_warped']
    pred_overlap_score += batch['src_overlap']
    displ_gt += batch['displ_gt']
    src_paths += batch['src_path']

# pred_tgt stays a list (presumably because target clouds differ in size).
pred_src = torch.stack(pred_src)
tgt_gt = torch.stack(tgt_gt)
pred_pose = torch.stack(pred_pose)
pred_src_kp = torch.stack(pred_src_kp)
# Keep the last entry along dim 1 for warped keypoints and overlap scores.
pred_tgt_kp = torch.stack(pred_tgt_kp)[:, -1]
# Overlap outputs are treated as logits and turned into probabilities.
pred_overlap_score = torch.sigmoid(torch.stack(pred_overlap_score)[:, -1])

In [11]:
# Evaluation loop: RegTR rigid pose + NDP non-rigid refinement for every test pair.
nn_thresh = 0.5  # keypoints with predicted overlap probability above this seed correspondences
NN_K = 10  # neighbours queried per selected keypoint
NN_MAX_DIST = 0.1  # neighbours farther than this (normalised coordinates) are dropped
NDP_SEARCH_RADIUS = 0.0375  # passed to model_nonrigid.load_pcds(search_radius=...)
PREDICTION_DIR = 'test_output_folder/predictions'
TARGET_DIR = 'test_output_folder/target shape'

transformations = []
nonregistered_cd_l, registered_cd_l, mean_displacement_error_l, landmark_loss_l = [], [], [], []
anulus, umbo, malleus_handle, long_process_of_incus, stape = [], [], [], [], []
regtr_cd_l = []
non_registered_landmark_loss_l = []
registered_partial_cd_l = []
wall_time = []
displ_l = []
side, status = [], []
# `inds` is also created in the aggregation cell; resetting it here keeps it aligned with the
# per-pair result lists when only this cell is re-run.
inds = []
cnt = 0

# Landmark segment name (as used in the in-vivo landmark dicts) -> list of per-pair losses.
segment_loss_lists = {
    'anulus': anulus,
    'Umbo': umbo,
    'malleus handle': malleus_handle,
    'long process of incus': long_process_of_incus,
    'stape': stape,
}

# Open3D's write_point_cloud returns a bool (False on failure) instead of raising, so a
# missing folder would silently produce no files.
os.makedirs(PREDICTION_DIR, exist_ok=True)
os.makedirs(TARGET_DIR, exist_ok=True)


def _write_point_cloud(points, path):
    """Save a tensor of 3-D points as a .ply file via Open3D; log a warning on failure."""
    # NOTE(review): `o3d` is never imported explicitly in this notebook; presumably it is
    # re-exported by `from deformationpyramid.model.geometry import *` -- confirm.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.array(points.cpu().detach()))
    if not o3d.io.write_point_cloud(path, pcd):
        logger.warning(f'Could not write point cloud to {path}')


test_dataset_real = EarDatasetTest(
    cfg, 'test'
)

samples = zip(pred_src, pred_tgt, tgt_gt, pred_src_kp, pred_overlap_score, pred_pose, displ_gt, src_paths)
for pair_ind, (src_norm, tgt_norm, tgt_full_norm, src_kp_norm, overlap_score, pose, displ, src_path) in tqdm(
        enumerate(samples), total=len(pred_src)):
    t1 = time.time()

    # Correspondences: source points close to keypoints predicted to lie in the overlap.
    tree = KDTree(src_norm.cpu(), leaf_size=cfg.nn_leaf_size)
    confident_kp = src_kp_norm[overlap_score.squeeze() > nn_thresh].cpu()
    distance, indices = tree.query(confident_kp, k=NN_K, return_distance=True)
    # (n_kp, k) results flattened; keep unique source indices within NN_MAX_DIST.
    indices, occ = np.unique(indices.ravel()[distance.ravel() < NN_MAX_DIST], return_counts=True)
    inds.append(indices)

    # Rigidly align the source with the RegTR pose, then de-normalise all clouds.
    src_transformed_norm = torch.tensor(se3_transform(pose.cpu().numpy(), src_norm.cpu().numpy()))
    src_denorm = denorm(src_norm, metadata)
    src_transformed_denorm = denorm(src_transformed_norm, metadata)
    tgt_denorm = denorm(tgt_norm, metadata)
    tgt_full_denorm = denorm(tgt_full_norm, metadata)

    # Non-rigid refinement (NDP), guided by the correspondence indices.
    model_nonrigid.load_pcds(src_transformed_denorm, tgt_denorm, inds=indices, search_radius=NDP_SEARCH_RADIUS)
    warped_pcd, hist, _, timer = model_nonrigid.register(visualize=False, timer=timer)
    # Measured from the un-aligned source, so the displacement includes the rigid part too.
    pred_displ = warped_pcd - src_denorm
    displ_l.append(pred_displ)
    wall_time.append(time.time() - t1)

    # Compute metrics
    if cfg.dataset == 'eardataset':
        nonregistered_cd_l.append(compute_truncated_chamfer_distance(src_denorm[None], tgt_full_denorm[None], trunc=1e6).item())

        registered_cd_l.append(computeCD(src_denorm[None], tgt_full_denorm[None], pred_displ[None]).item())

        mean_displacement_error_l.append(mean_displacement_error(pred_displ, displ).item())

        # landmark_loss(pred, intra) takes a list of predicted sets and a dict of reference
        # sets in the same order (as in the in-vivo branch below); it returns (loss, per_segment).
        pred_landmarks = [warped_pcd[i].cpu() for i in l_inds]
        gt_landmarks = {name: tgt_full_denorm[i].cpu() for name, i in zip(landmarks.keys(), l_inds)}
        lndmk, _ = landmark_loss(pred_landmarks, gt_landmarks)
        landmark_loss_l.append(lndmk)
        regtr_cd_l.append(computeCD(src_transformed_denorm.cpu()[None], tgt_full_denorm.cpu()[None], torch.zeros(src_transformed_denorm.shape)[None]).item())

    else:
        nonregistered_cd_l.append(compute_truncated_chamfer_distance(src_denorm[None], tgt_denorm[None], trunc=1e6).item())

        registered_cd_l.append(computeCD(src_denorm[None], tgt_denorm[None], pred_displ[None]).item())

        intra_data = test_dataset_real[pair_ind]
        patient_info = intra_data['metadata']['patient_info']
        side.append(patient_info['side'])
        status.append(patient_info['status'])

        landmarks_intra = intra_data['landmarks']

        if landmarks_intra:
            intra_landmarks = dict(landmarks_intra)
            pred_landmarks = [warped_pcd[landmarks[k]].cpu() for k in intra_landmarks]
            lndmk, single_loss = landmark_loss(pred_landmarks, intra_landmarks)
            landmark_loss_l.append(lndmk)
            # Segments absent for this patient are recorded as NaN.
            for segment_name, segment_values in segment_loss_lists.items():
                segment_values.append(single_loss.get(segment_name, float('nan')))

            # Same landmark loss without any registration, as a baseline.
            src_landmarks = [src_denorm[landmarks[k]].cpu() for k in intra_landmarks]
            lndmk, _ = landmark_loss(src_landmarks, intra_landmarks)
            non_registered_landmark_loss_l.append(lndmk)

        else:
            # -1 marks "no landmarks available" (converted to NaN when exporting).
            landmark_loss_l.append(-1)
            for segment_values in segment_loss_lists.values():
                segment_values.append(float('nan'))
            non_registered_landmark_loss_l.append(-1)

    _write_point_cloud(warped_pcd, f'{PREDICTION_DIR}/prediction_{pair_ind}.ply')
    _write_point_cloud(tgt_denorm, f'{TARGET_DIR}/target_{pair_ind}.ply')

0it [00:00, ?it/s]

43it [05:26,  7.60s/it]


In [12]:
# Summary of the evaluation loop; pairs without landmarks (-1) are excluded from the landmark mean.
valid_landmark_losses = [loss for loss in landmark_loss_l if loss != -1]
print('Non-registered cd score:', mean(nonregistered_cd_l))
print('Registered cd score:', mean(registered_cd_l))
if cfg.dataset == 'eardataset':
    print('Mean displacement error: ', mean(mean_displacement_error_l))
print('Landmark loss:', mean(valid_landmark_losses))
print('Wall time:', mean(wall_time), 's')

Non-registered cd score: 7.3261259766512135
Registered cd score: 0.8497772393531577
Landmark loss: 1.3019973623443006
Wall time: 7.563937364622604 s


In [13]:
# Show the last pair processed by the evaluation loop. visualize_result is typed for NumPy
# arrays, while the loop leaves (possibly CUDA) tensors behind, so convert them first.
visualize_result(
    src_denorm.detach().cpu().numpy(),
    tgt_denorm.detach().cpu().numpy(),
    pred_displ.detach().cpu().numpy(),
    indices,
)

In [18]:
ind = 4  # index of the test pair to inspect
# Pull pair `ind` out of every RegTR result container as NumPy arrays.
(src_np, tgt_np, src_kp_np, src2tgt_np, overlap_np, pose_np) = (
    arr[ind].cpu().numpy()
    for arr in (pred_src, pred_tgt, pred_src_kp, pred_tgt_kp, pred_overlap_score, pred_pose)
)
visualize_result_regtr(src_np, tgt_np, src_kp_np, src2tgt_np, overlap_np, pose_np, threshold=0)

In [29]:
# Export, for every test pair, the RegTR-aligned and de-normalised source cloud together
# with the correspondence indices found in the evaluation loop.
# Each pair is indexed by its own position `i` (not a leftover `ind` from another cell),
# and the number of pairs comes from the data instead of a hard-coded 43.
for i, (src_i, pose_i, inds_i) in enumerate(zip(pred_src, pred_pose, inds)):
    data = dict(
        src=denorm(se3_transform(pose_i.cpu().numpy(), src_i.cpu().numpy()), metadata),
        inds=inds_i,
    )
    with open(f'mesh_dataset/DIOME_FanShapeCorr/sample_{i}/regtr_pred.pkl', 'wb') as f:
        pickle.dump(data, f)
    

In [13]:
# Alternative correspondence search for the last processed pair: every source point within
# radius 0.1 of a confident keypoint (instead of the k-nearest query used in the loop).
confident_kp = (src_kp_norm[overlap_score.squeeze() > nn_thresh]).cpu()
tree = KDTree(src_norm.cpu(), leaf_size=cfg.nn_leaf_size)
neighbour_lists = tree.query_radius(confident_kp, r=0.1)
indices, occurences = np.unique(np.concatenate(neighbour_lists), return_counts=True)

In [15]:
""" On normal dataset
Non-registered cd score: 1.4596144970549785
Registered cd score: 0.29597098740516437
Registered partial cd score: 2.3302804929076717
Mean displacement error:  1.2692070097504926
Landmark loss: 0.3599099684169502
Wall time: 8.540204111336594 s """

""" On high variety dataset
Non-registered cd score: 1.4065552084325432
Registered cd score: 1.0310056967748418
Registered partial cd score: 1.5106343214397315
Mean displacement error:  1.7764664716548748
Landmark loss: 0.8699200562834195
Wall time: 6.008897996879555 s 

On corrected DIOME
Non-registered cd score: 4.179325458615325
Registered cd score: 0.8000558288984521
Landmark loss: 1.2649516742809437
Wall time: 6.067085615424222 s"""

In [13]:
import pandas as pd

# One row per in-vivo test pair; the -1 placeholders (pairs without landmarks) become NaN.
results_dict_invivo = {
    'nonregistered_cd_l': nonregistered_cd_l,
    'non_registered_landmark_loss_l': non_registered_landmark_loss_l,
    'registered_cd_l': registered_cd_l,
    'landmark_loss_l': [float('nan') if loss == -1 else loss for loss in landmark_loss_l],
    'wall_time_models': wall_time,
    'displ': [d.cpu().numpy() for d in displ_l],
    'side': side,
    'status': status,
    'anulus': anulus,
    'umbo': umbo,
    'malleus_handle': malleus_handle,
    'long_process_of_incus': long_process_of_incus,
    'stape': stape,
}

df = pd.DataFrame.from_dict(results_dict_invivo)
df.to_csv('finalResults/results_invivo_regtr2.csv')

In [14]:
import pandas as pd

# The ex-vivo metrics below are only filled when the evaluation loop ran with
# cfg.dataset == 'eardataset'; otherwise the columns would have mismatched lengths.
assert mean_displacement_error_l, (
    "ex-vivo metrics are empty: run the evaluation loop with cfg.dataset == 'eardataset' first")

# NOTE(review): `number_of_points` was never defined anywhere in this notebook (it only
# existed as hidden kernel state). Reconstructed here as the number of keypoints whose
# predicted overlap probability exceeds `nn_thresh` for each pair -- confirm the intent.
number_of_points = [int((score.squeeze() > nn_thresh).sum()) for score in pred_overlap_score]

results_dict_exvivo = dict(
    nonregistered_cd_l=nonregistered_cd_l,
    mean_displacement_error_l=mean_displacement_error_l,
    registered_cd_l=registered_cd_l,
    regtr_cd_l=regtr_cd_l,
    landmark_loss_l=[i if i != -1 else float('nan') for i in landmark_loss_l],
    wall_time_models=wall_time,
    displ=[i.cpu().numpy() for i in displ_l],
    len_inds=[len(i) for i in inds],
    overlap_scores=number_of_points,
)

df = pd.DataFrame.from_dict(results_dict_exvivo)
df.to_csv('finalResults/results_exvivo_regtr_variance.csv')

In [None]:
import math
import trimesh as trm
from glob import glob

# Relative volume of each middle-ear part per DIOME sample, from the segmented STL meshes.
# A mesh is assigned to the first part name contained in its path. Meshes matching none
# (e.g. the promontory) are skipped and do not count towards the total. Previously an
# unmatched mesh silently reused the previous file's part (or raised NameError).
PART_NAMES = ('tympanic_membrane', 'malleus', 'incus', 'stapes')

values = []

for path in glob('mesh_dataset/DIOME_FanShapeCorr/sample_*'):
    summm = 0
    parts = dict.fromkeys(PART_NAMES, 0)

    for mesh_path in glob(f'{path}/seg_*.stl'):
        index = next((name for name in PART_NAMES if name in mesh_path), None)
        if index is None:
            continue
        s = trm.load(mesh_path).volume
        # NOTE: a second mesh for the same part replaces its entry in `parts` but is still
        # added to `summm`, so such samples get fractions summing to < 1 (filtered below).
        parts[index] = s
        summm += s

    if summm == 0:
        continue  # no recognised meshes: nothing to normalise (previously ZeroDivisionError)
    parts = {k: v / summm for k, v in parts.items()}
    # math.fsum: correctly rounded, and unaffected if `sum` is rebound in the notebook.
    fraction_total = math.fsum(parts.values())
    print(fraction_total)
    if fraction_total > 0.9:
        values.append(parts)

with open('mesh_dataset/DIOME_FanShapeCorr/volume_distribution.pkl', 'wb') as f:
    pickle.dump(values, f)

# Fraction of intra-operative points per segmentation label, for each cached ear_dataset
# sample (label mapping as used here: 2 tympanic membrane, 1 malleus, 0 incus, 3 stapes).
values = []

for filename in glob('mesh_dataset/ear_dataset/*/data_cached.pkl'):
    with open(filename, 'rb') as f:
        data = pickle.load(f)
    segmentation = list(data['intra_segmentation'])
    # Named n_points (was `sum`, which shadowed the builtin for every later cell).
    n_points = len(segmentation)
    if n_points == 0:
        continue  # empty segmentation: fractions are undefined (previously ZeroDivisionError)
    values.append(dict(
        tympanic_membrane=segmentation.count(2) / n_points,
        malleus=segmentation.count(1) / n_points,
        incus=segmentation.count(0) / n_points,
        stapes=segmentation.count(3) / n_points,
    ))

with open('mesh_dataset/ear_dataset/volume_distribution.pkl', 'wb') as f:
    pickle.dump(values, f)