In [6]:
import sys
sys.path.append('/home/daniel/Documents/Uni/MT/poi-prediction')
from src.dataset.dataset import ImplantsDataset
import time
from utils.misc import np_to_bids_nii
from utils.dataloading_utils import compute_surface
from tqdm import tqdm

from BIDS import POI, NII
from BIDS.POI_plotter import visualize_pois
import torch
import pandas as pd

import numpy as np

implants_master_df = pd.read_csv('/home/daniel/Data/Implants/cutouts_scale-1-1-1/master_df.csv')
ds = ImplantsDataset(
    implants_master_df,
    flip_prob = 0,
    poi_file_ending='poi_old.json'
)

In [7]:
# Define some useful utility functions
def get_ctd(dd, poi_list = [90,91,92,93]):
    ctd = {}
    vertebra = dd['vertebra']

    for poi_coords, poi_idx in zip(dd['target'], dd['target_indices']):
        coords = (poi_coords[0].item(), poi_coords[1].item(), poi_coords[2].item())
        if poi_list is None or poi_idx in poi_list:
            ctd[vertebra, poi_idx.item()] = coords

    ctd = POI(centroids = ctd, orientation = ('L', 'A', 'S'), zoom = (1,1,1), shape = (152,152,152))
    return ctd

def get_corpus_mks_nii(dd):
    vertebra = dd['vertebra']
    msk = dd['input'].squeeze(0)
    corpus_msk = (msk == 50) + (msk == 49)
    corpus_msk = corpus_msk * vertebra #For visualization purposes
    corpus_msk_nii = np_to_bids_nii(corpus_msk.numpy().astype(np.int32))
    corpus_msk_nii.seg = True
    return corpus_msk_nii

def get_vert_msk_nii(dd):
    vertebra = dd['vertebra']
    msk = dd['input'].squeeze(0)
    vert_msk = (msk != 0) * vertebra
    vert_msk_nii = np_to_bids_nii(vert_msk.numpy().astype(np.int32))
    vert_msk_nii.seg = True
    return vert_msk_nii

def get_vertseg_nii(dd):
    vertseg = dd['input'].squeeze(0)
    vertseg_nii = np_to_bids_nii(vertseg.numpy().astype(np.int32))
    vertseg_nii.seg = True
    return vertseg_nii

def get_vert_points(dd):
    msk = dd['input'].squeeze(0)
    vert_points = torch.where(msk)
    vert_points = torch.stack(vert_points, dim=1)
    return vert_points

def get_corpus_points(dd):
    msk = dd['input'].squeeze(0)
    corpus_points = torch.where((msk == 50) + (msk == 49))
    corpus_points = torch.stack(corpus_points, dim=1)
    return corpus_points

def get_target_entry_points(dd):
    ctd = get_ctd(dd)
    vertebra = dd['vertebra']
    p_90 = torch.tensor(ctd[vertebra, 90])
    p_92 = torch.tensor(ctd[vertebra, 92])

    p_91 = torch.tensor(ctd[vertebra, 91])
    p_93 = torch.tensor(ctd[vertebra, 93])

    return p_90, p_92, p_91, p_93

def tensor_to_ctd(t, vertebra, origin, rotation, idx_list = None, shape = (128, 128, 96), zoom = (1,1,1), offset = (0,0,0)):
    ctd = {}
    for i, coords in enumerate(t):
        offset = torch.tensor(offset) if isinstance(offset, tuple) else offset
        coords = coords.float() - offset.float()
        coords = (coords[0].item(), coords[1].item(), coords[2].item())
        if idx_list is None:
            ctd[vertebra, i] = coords
        elif i < len(idx_list):
            ctd[vertebra, idx_list[i]] = coords

    ctd = POI(centroids = ctd, orientation = ('L', 'A', 'S'), zoom = zoom, shape = shape, origin = origin, rotation = rotation)
    return ctd

In [14]:
dd = ds[38]
target = dd['target']
vertebra = dd['vertebra']
offset = dd['offset']
origin = NII.load(dd['msk_path'], seg = True).origin
rotation = NII.load(dd['msk_path'], seg = True).rotation
shape = NII.load(dd['msk_path'], seg = True).shape
target_ctd = tensor_to_ctd(target, vertebra, origin = origin, idx_list = [90, 91, 92, 93], shape = shape, zoom = (1,1,1), offset = offset, rotation = rotation)
dd['subject'], vertebra, offset

('3', 20, tensor([26., 21., 22.]))

In [15]:
vertebra = dd['vertebra']
vert_msk_nii = NII.load(dd['msk_path'],seg=True)

visualize_pois(
    ctd = target_ctd,
    seg_vert = vert_msk_nii,
    vert_idx_list=[vertebra] 
)

#90: Left entry point
#91: Right entry point
#92: Left target point
#93: Right target point

[0m[ ] Image reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
[0m[*] Centroids reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
('P', 'I', 'R') ('P', 'I', 'R')


100%|██████████| 1/1 [00:00<00:00, 29.83it/s]


Widget(value="<iframe src='http://localhost:44411/index.html?ui=P_0x79adcc1153c0_4&reconnect=auto' style='widt…

Now we want to find out which of the points of the vertebra mask actually lie on the connecting line of the target and entry points, to define more robust versions.

In [16]:
def find_projection_scalars_and_errors(points, P, d):
    """
    Finds the scalar `a` for each point in `points` such that P + a*d is the closest point on the
    line defined by P and direction d to the original point, and calculates the reconstruction error.
    
    :param points: A tensor of shape (n, 3) representing n points in 3D space.
    :param P: A tensor of shape (3,) representing the point P.
    :param d: A tensor of shape (3,) representing the direction vector d.
    :return: A tuple containing:
        - A tensor of shape (n,) with the scalar `a` for each point.
        - A tensor of shape (n,) with the reconstruction error for each point.
    """
    # Ensure d is a unit vector
    d_normalized = d / torch.norm(d)
    
    # Calculate the difference between each point and P
    Q_minus_P = points - P
    
    # Calculate the dot product of d_normalized with each Q - P
    dot_d_Q_minus_P = torch.matmul(Q_minus_P, d_normalized)
    
    # Calculate a for each point
    a_values = dot_d_Q_minus_P / torch.dot(d_normalized, d_normalized)
    
    # Calculate the projection of each point onto the line
    projections = P + a_values.unsqueeze(-1) * d_normalized
    
    # Calculate the reconstruction error for each point
    errors = torch.norm(points - projections, dim=1)
    
    return a_values, errors

In [17]:
def calculate_screw_surface_points(dd):
    p_90, p_92, p_91, p_93 = get_target_entry_points(dd)
    
    d_90_92 = p_92 - p_90
    d_90_92 = d_90_92 / torch.norm(d_90_92)

    d_91_93 = p_93 - p_91
    d_91_93 = d_91_93 / torch.norm(d_91_93)

    points = get_vert_points(dd)

    a_values_90_92, errors_90_92 = find_projection_scalars_and_errors(points, p_90, d_90_92)
    a_values_91_93, errors_91_93 = find_projection_scalars_and_errors(points, p_91, d_91_93)

    screw_trajectory_surface_points_90_92 = torch.stack([p_90 + a_values_90_92[errors_90_92 < 0.5].min() * d_90_92, p_90 + a_values_90_92[errors_90_92 < 0.5].max() * d_90_92])
    screw_trajectory_surface_points_91_93 = torch.stack([p_91 + a_values_91_93[errors_91_93 < 0.5].min() * d_91_93, p_91 + a_values_91_93[errors_91_93 < 0.5].max() * d_91_93])

    screw_surface_points_all = torch.cat([screw_trajectory_surface_points_90_92, screw_trajectory_surface_points_91_93], dim=0)

    return screw_surface_points_all

def calculate_corpus_entry_points(dd):
    p_90, p_92, p_91, p_93 = get_target_entry_points(dd)
    
    d_90_92 = p_92 - p_90
    d_90_92 = d_90_92 / torch.norm(d_90_92)

    d_91_93 = p_93 - p_91
    d_91_93 = d_91_93 / torch.norm(d_91_93)

    points = get_corpus_points(dd)

    a_values_90_92, errors_90_92 = find_projection_scalars_and_errors(points, p_90, d_90_92)
    a_values_91_93, errors_91_93 = find_projection_scalars_and_errors(points, p_91, d_91_93)

    corpus_entry_points_90_92 = p_90 + a_values_90_92[errors_90_92 < 0.5].min() * d_90_92
    corpus_entry_points_91_93 = p_91 + a_values_91_93[errors_91_93 < 0.5].min() * d_91_93

    corpus_entry_points_all = torch.stack([corpus_entry_points_90_92, corpus_entry_points_91_93], dim=0)

    return corpus_entry_points_all

def create_new_ctd(dd):
    #Load metadata
    vertebra = dd['vertebra']
    offset = torch.tensor(dd['offset'])
    vert_msk_nii = NII.load(dd['msk_path'], seg = True)
    origin = vert_msk_nii.origin
    rotation = vert_msk_nii.rotation
    shape = vert_msk_nii.shape
    zoom = vert_msk_nii.zoom
    
    print(shape)
    screw_surface_points = calculate_screw_surface_points(dd)
    corpus_entry_points = calculate_corpus_entry_points(dd)
    combined_points = torch.cat([screw_surface_points, corpus_entry_points], dim=0)

    screw_trajectory_ctd = tensor_to_ctd(combined_points, vertebra, origin = origin, rotation = rotation, idx_list=[90, 92, 91, 93, 94, 95], shape = shape, zoom = zoom, offset = offset)
    return screw_trajectory_ctd

In [18]:
screw_surface_points = calculate_screw_surface_points(dd)
corpus_entry_points = calculate_corpus_entry_points(dd)

combined_points = torch.cat([screw_surface_points, corpus_entry_points], dim=0)

screw_trajectory_ctd = tensor_to_ctd(combined_points, dd['vertebra'], origin = None, rotation = None, idx_list=[90, 92, 91, 93, 94, 95], shape = (128,128,96), zoom = (1,1,1))
vert_msk_nii = get_vert_msk_nii(dd)

new_ctd = create_new_ctd(dd)

#Keep only pois 90 and 91 in the ctd to check correct alignemnt
new_ctd.centroids = {(k0, k1): v for k0, k1, v in new_ctd.centroids.items() if k1 in [90, 91, 94, 95]}

visualize_pois(
    ctd = new_ctd,
    seg_vert = NII.load(dd['msk_path'], seg = True),
    vert_idx_list=[dd['vertebra']] 
)

####################################
  offset = torch.tensor(dd['offset'])
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance
    app.start()
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 736, in start
    self.io_loop.start()
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/as

(76, 85, 52)
[0m[ ] Image reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
[0m[*] Centroids reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
('P', 'I', 'R') ('P', 'I', 'R')


100%|██████████| 1/1 [00:00<00:00, 67.35it/s]


Widget(value="<iframe src='http://localhost:44411/index.html?ui=P_0x79adda44f640_5&reconnect=auto' style='widt…

In [71]:
import os
for dd in ds:
    new_ctd = create_new_ctd(dd)
    poi_path = dd['poi_path']
    #Rename the file at old poi path to *_old.json
    os.rename(poi_path, poi_path.replace('.json', '_old.json'))
    new_ctd.save(poi_path)

####################################
  offset = torch.tensor(dd['offset'])
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance
    app.start()
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 736, in start
    self.io_loop.start()
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/home/daniel/anaconda3/envs/thesis/lib/python3.10/as

(55, 75, 43)
[96m[*] Centroids saved: /home/daniel/Data/Implants/cutouts_scale-1-1-1/002/19/poi.json in format POI[0m[0m
(64, 82, 50)
[96m[*] Centroids saved: /home/daniel/Data/Implants/cutouts_scale-1-1-1/002/20/poi.json in format POI[0m[0m
(75, 87, 52)
[96m[*] Centroids saved: /home/daniel/Data/Implants/cutouts_scale-1-1-1/002/21/poi.json in format POI[0m[0m
(87, 88, 57)
[96m[*] Centroids saved: /home/daniel/Data/Implants/cutouts_scale-1-1-1/002/22/poi.json in format POI[0m[0m
(80, 90, 50)
[96m[*] Centroids saved: /home/daniel/Data/Implants/cutouts_scale-1-1-1/002/23/poi.json in format POI[0m[0m
(86, 83, 42)
[96m[*] Centroids saved: /home/daniel/Data/Implants/cutouts_scale-1-1-1/002/24/poi.json in format POI[0m[0m
(86, 87, 63)
[96m[*] Centroids saved: /home/daniel/Data/Implants/cutouts_scale-1-1-1/008/20/poi.json in format POI[0m[0m
(91, 91, 65)
[96m[*] Centroids saved: /home/daniel/Data/Implants/cutouts_scale-1-1-1/008/21/poi.json in format POI[0m[0m
(104, 96

In [9]:
vert_nii = NII.load('/home/daniel/Data/Implants/cutouts_scale-1-1-1/002/22/vertseg.nii.gz', seg=True)
poi = POI.load('/home/daniel/Data/Implants/cutouts_scale-1-1-1/002/22/poi_surface.json')

vert_nii.rescale_()
poi.rescale_()

visualize_pois(
    ctd = poi,
    seg_vert = vert_nii,
    vert_idx_list=[22],
    radius = 5
)

[0m[ ] Image reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
[0m[*] Centroids reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
('P', 'I', 'R') ('P', 'I', 'R')


100%|██████████| 1/1 [00:00<00:00, 57.88it/s]




Widget(value="<iframe src='http://localhost:41823/index.html?ui=P_0x7a7d3826b520_2&reconnect=auto' style='widt…