In [11]:
import sys
sys.path.append('/home/daniel/Documents/Uni/MT/poi-prediction')
from src.dataset.dataset import ImplantsDataset
import time
from utils.misc import np_to_bids_nii
from utils.dataloading_utils import compute_surface
from tqdm import tqdm

from BIDS import POI, NII
from BIDS.POI_plotter import visualize_pois
import torch
import pandas as pd

import numpy as np

# Master table (presumably one row per subject/vertebra sample) wrapped in the implants dataset.
# flip_prob=0 presumably disables random flip augmentation so ds[i] is deterministic — confirm in ImplantsDataset.
implants_master_df = pd.read_csv(
    '/home/daniel/Data/Implants/cutouts_scale-1-1-1/master_df.csv'
)
ds = ImplantsDataset(implants_master_df, flip_prob=0, poi_file_ending='poi.json')

In [12]:
# Define some useful utility functions
def get_ctd(dd, poi_list = (90, 91, 92, 93), shape = (152, 152, 152)):
    """Build a POI object from a dataset sample's target coordinates.

    Args:
        dd: dataset sample; reads 'vertebra', 'target' (one coordinate row per
            POI, presumably 3-element tensors) and 'target_indices' (the POI id
            of each row).
        poi_list: POI ids to keep; None keeps every POI. Any container that
            supports `in` works; an immutable tuple is used as the default so
            the default cannot be shared/mutated across calls.
        shape: volume shape passed to POI (default (152, 152, 152)).

    Returns:
        POI in ('L', 'A', 'S') orientation with zoom (1, 1, 1), keyed by
        (vertebra, poi_id).
    """
    vertebra = dd['vertebra']
    ctd = {}
    for poi_coords, poi_idx in zip(dd['target'], dd['target_indices']):
        # int() accepts both 0-d tensors and plain Python ints.
        poi_id = int(poi_idx)
        if poi_list is None or poi_id in poi_list:
            ctd[vertebra, poi_id] = (poi_coords[0].item(), poi_coords[1].item(), poi_coords[2].item())

    return POI(centroids = ctd, orientation = ('L', 'A', 'S'), zoom = (1,1,1), shape = shape)

def get_vert_msk_nii(dd):
    """Return the sample's mask as a BIDS NII segmentation in which every
    non-zero voxel is relabelled with the sample's vertebra id.

    The leading axis of dd['input'] is squeezed away (presumably a channel
    axis — confirm), and the result is cast to int32.
    """
    mask = dd['input'].squeeze(0)
    relabelled = ((mask != 0) * dd['vertebra']).numpy().astype(np.int32)
    nii = np_to_bids_nii(relabelled)
    nii.seg = True
    return nii

def get_vertseg_nii(dd):
    """Return the sample's input volume, labels unchanged, as a BIDS NII
    flagged as a segmentation.

    The leading axis of dd['input'] is squeezed away (presumably a channel
    axis — confirm), and values are cast to int32.
    """
    volume = dd['input'].squeeze(0).numpy()
    nii = np_to_bids_nii(volume.astype(np.int32))
    nii.seg = True
    return nii

def get_vert_points(dd):
    """Return the indices of all non-zero voxels of dd['input'] (after
    squeezing its leading axis) as an (N, ndim) int64 tensor, in row-major order."""
    return torch.nonzero(dd['input'].squeeze(0))

def get_target_entry_points(dd):
    """Return the four screw POIs of a sample as tensors.

    Order of the returned tuple: (p_90, p_92, p_91, p_93), i.e. right entry,
    right target, left entry, left target (per the POI legend used in this
    notebook).
    """
    centroids = get_ctd(dd)
    vert = dd['vertebra']
    right_entry, right_target, left_entry, left_target = (
        torch.tensor(centroids[vert, poi_id]) for poi_id in (90, 92, 91, 93)
    )
    return right_entry, right_target, left_entry, left_target

def tensor_to_ctd(t, vertebra, origin, rotation, idx_list = None, shape = (128, 128, 96), zoom = (1,1,1), offset = (0,0,0)):
    """Convert a stack of coordinates into a POI object for one vertebra.

    Args:
        t: iterable of coordinate rows (presumably an (N, 3) tensor); each row
            is cast to float and shifted by `-offset`.
        vertebra: vertebra label used as the first component of each key.
        origin, rotation, shape, zoom: forwarded unchanged to POI.
        idx_list: optional POI ids; row i is stored under idx_list[i] and rows
            beyond len(idx_list) are dropped. If None, row i is stored under i.
        offset: subtracted from every coordinate; tuple/list or tensor.

    Returns:
        POI in ('L', 'A', 'S') orientation keyed by (vertebra, poi_id).
    """
    # torch.as_tensor reuses an existing tensor instead of copy-constructing it
    # (torch.tensor(tensor) emits a UserWarning), and it is built only once.
    offset_t = torch.as_tensor(offset)

    ctd = {}
    for i, coords in enumerate(t):
        if idx_list is None:
            poi_id = i
        elif i < len(idx_list):
            poi_id = idx_list[i]
        else:
            continue
        shifted = coords.float() - offset_t
        ctd[vertebra, poi_id] = (shifted[0].item(), shifted[1].item(), shifted[2].item())

    ctd = POI(centroids = ctd, orientation = ('L', 'A', 'S'), zoom = zoom, shape = shape, origin = origin, rotation = rotation)
    return ctd

def fit_plane_to_points(P1, P2, P3, P4):
    """Least-squares fit of a plane A*x + B*y + C*z + D = 0 to four 3-D points.

    The coefficients are the right singular vector of the homogeneous point
    matrix [x y z 1] with the smallest singular value, so (A, B, C, D) has unit
    norm and an arbitrary overall sign.

    Args:
        P1, P2, P3, P4: 3-element coordinate tensors (same device). Integer
            inputs are promoted to float.

    Returns:
        plane_coefficients: tensor (A, B, C, D).
        distances: (4,) signed point-to-plane distances.
        points_projected: (4, 3) points orthogonally projected onto the plane.
    """
    points = torch.vstack([P1, P2, P3, P4])
    if not points.is_floating_point():
        points = points.float()  # SVD requires a floating dtype

    # Homogeneous coordinates; the ones column follows the points' dtype/device
    # so this also works for float64 and CUDA inputs.
    ones = torch.ones(points.shape[0], 1, dtype=points.dtype, device=points.device)
    X = torch.cat([points, ones], dim=1)

    # torch.linalg.svd returns Vh (= V^H) with singular values in descending
    # order, so its last row is the direction of the smallest singular value.
    _, _, Vh = torch.linalg.svd(X, full_matrices=False)
    plane_coefficients = Vh[-1, :]

    # Slice instead of rebuilding a tensor to keep dtype, device and autograd.
    normal = plane_coefficients[:3]
    D = plane_coefficients[3]
    normal_norm = torch.linalg.norm(normal)

    # Signed distance of each original (non-augmented) point to the plane.
    distances = (points @ normal + D) / normal_norm

    # Orthogonal projection: move each point back along the unit normal.
    points_projected = points - distances.unsqueeze(1) * (normal / normal_norm)

    return plane_coefficients, distances, points_projected

In [25]:
dd = ds[40]
target = dd['target']
vertebra = dd['vertebra']
offset = dd['offset']

# Load the mask once and reuse its geometry for the POI object.
msk_nii = NII.load(dd['msk_path'], seg = True)
origin = msk_nii.origin
rotation = msk_nii.rotation
shape = msk_nii.shape

# NOTE(review): idx_list labels the 3rd/4th target rows as 94/95. If
# dd['target_indices'] holds 92/93 at those positions, the labels will not match
# the 90-93 POI legend — confirm this relabelling is intentional.
target_ctd = tensor_to_ctd(target, vertebra, origin = origin, idx_list = [90, 91, 94, 95], shape = shape, zoom = (1,1,1), offset = offset, rotation = rotation)
dd['subject'], vertebra, offset

####################################
  coords = coords.float() - torch.tensor(offset)
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance
    app.start()
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 736, in start
    self.io_loop.start()
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/home/daniel/anaconda3/envs/thesis/lib/py

('3', 22, tensor([18., 16., 20.]))

In [26]:
# POI legend:
#   90: right entry point    91: left entry point
#   92: right target point   93: left target point
vertebra = dd['vertebra']
vert_msk_nii = NII.load(dd['msk_path'], seg=True)

visualize_pois(ctd=target_ctd, seg_vert=vert_msk_nii, vert_idx_list=[vertebra])

[0m[ ] Image reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
[0m[*] Centroids reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
('P', 'I', 'R') ('P', 'I', 'R')


100%|██████████| 1/1 [00:00<00:00, 35.32it/s]


Widget(value="<iframe src='http://localhost:35633/index.html?ui=P_0x7ea41c1a1bd0_9&reconnect=auto' style='widt…

In [27]:
def get_gt_trajectory_points(dd, n_points = 100):
    """Sample points along both ground-truth screw trajectories of a sample.

    Each trajectory is the straight segment from an entry POI to its target POI
    (90 -> 92 on the right, 91 -> 93 on the left). `n_points` evenly spaced
    samples are taken per segment, and both the entry and the target point are
    included.

    Returns:
        Tensor of shape (2 * n_points, 3): the right-side samples followed by
        the left-side samples.
    """
    p_90, p_92, p_91, p_93 = get_target_entry_points(dd)

    # linspace includes t = 1, so the target point itself is sampled and checked.
    t = torch.linspace(0.0, 1.0, n_points).unsqueeze(1)

    screw_trajectory_points_90_92 = p_90 + t * (p_92 - p_90)
    screw_trajectory_points_91_93 = p_91 + t * (p_93 - p_91)

    return torch.cat([screw_trajectory_points_90_92, screw_trajectory_points_91_93], dim=0)

In [28]:
dd = ds[2]
screw_trajectory_points = get_gt_trajectory_points(dd, n_points=100)

# NOTE(review): shape here is (128, 128, 96) while get_ctd uses (152, 152, 152);
# confirm which one matches dd['input'].
screw_trajectory_ctd = tensor_to_ctd(
    screw_trajectory_points,
    dd['vertebra'],
    origin=None,
    rotation=None,
    shape=(128, 128, 96),
    zoom=(1, 1, 1),
)

# Draw the sampled trajectory points (radius 2) over the relabelled vertebra mask.
visualize_pois(
    ctd=screw_trajectory_ctd,
    seg_vert=get_vert_msk_nii(dd),
    vert_idx_list=[dd['vertebra']],
    radius=2,
)

[0m[ ] Image reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
[0m[*] Centroids reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
('P', 'I', 'R') ('P', 'I', 'R')


100%|██████████| 1/1 [00:00<00:00, 28.90it/s]


Widget(value="<iframe src='http://localhost:35633/index.html?ui=P_0x7ea41c126a40_10&reconnect=auto' style='wid…

In [29]:
def compute_distance_to_nearest_vertebra_point(screw_trajectory_points, vertebra_points):
    """For each trajectory point, return the Euclidean distance to the closest
    vertebra point.

    Args:
        screw_trajectory_points: (N, D) float tensor.
        vertebra_points: (M, D) float tensor.

    Returns:
        (N,) tensor of minimum distances.
    """
    pairwise = torch.cdist(screw_trajectory_points, vertebra_points)
    nearest = pairwise.min(dim=1)
    return nearest.values

def trajectory_outside_vertebra(dd, thre = 1):
    """Return a 0-d bool tensor that is True when any ground-truth trajectory
    sample lies more than `thre` away from the nearest non-zero mask voxel.

    Distances are in the units of the trajectory/voxel-index coordinates
    (presumably voxels — confirm both are in the same space).
    """
    samples = get_gt_trajectory_points(dd, n_points = 100)
    mask_points = get_vert_points(dd).float()
    nearest = compute_distance_to_nearest_vertebra_point(samples, mask_points)
    return (nearest > thre).any()

# Flag every sample whose ground-truth screw trajectory leaves the vertebra mask.
data_dict = {
    'subject': [],
    'vertebra': [],
    'outside': []
}
# Iterate the dataset directly (no enumerate) so tqdm can read len(ds), when the
# dataset defines it, and show a total/ETA.
for dd in tqdm(ds):
    data_dict['subject'].append(dd['subject'])
    data_dict['vertebra'].append(dd['vertebra'])
    data_dict['outside'].append(trajectory_outside_vertebra(dd).item())

59it [00:02, 21.31it/s]


In [30]:
# One row per sample with its subject, vertebra and "trajectory outside" flag.
df = pd.DataFrame(data=data_dict)
df

Unnamed: 0,subject,vertebra,outside
0,2,19,False
1,2,20,False
2,2,21,False
3,2,22,False
4,2,23,False
5,2,24,True
6,8,20,False
7,8,21,False
8,8,22,False
9,8,23,False


In [31]:
# Subjects come out of the dataset as strings (e.g. '3'); cast to int, presumably to match the master table's key dtype for the merge.
df = df.astype({'subject': int})

In [32]:
# Build the cleaned master table: use_sample is False for samples whose
# ground-truth screw trajectory leaves the vertebra mask (outside == True).
# The result goes to a new name, so `df` is left intact and this cell can be
# re-run safely.
# NOTE: pd.merge defaults to an inner join, so rows of implants_master_df with no
# matching (subject, vertebra) in `df` are dropped from the output.
master_df_cleaned = pd.merge(implants_master_df, df, on = ['subject', 'vertebra'])
master_df_cleaned['use_sample'] = ~master_df_cleaned['outside']
master_df_cleaned = master_df_cleaned.drop(columns = ['outside'])

# bad_poi_list: an empty list per sample (for compatibility with the POI dataset).
master_df_cleaned['bad_poi_list'] = master_df_cleaned['subject'].apply(lambda _: [])

master_df_cleaned.to_csv('/home/daniel/Data/Implants/cutouts_scale-1-1-1/master_df_cleaned.csv', index = False)