In [16]:
import sys
sys.path.append('/home/daniel/Documents/Uni/MT/poi-prediction')
from src.dataset.dataset import GruberDataset, ImplantsDataset

from utils.misc import np_to_bids_nii
from BIDS import POI
from BIDS.vert_constants import conversion_poi2text
from BIDS.POI_plotter import visualize_pois
import torch
import pandas as pd

import numpy as np
import ipywidgets as widgets

# Select which dataset to explore. Supported values: "Implants", "Gruber".
dataset = "Implants"
# dataset = "Gruber"

# Both master tables are read regardless of the selection because other
# cells in this notebook reference them by name.
gruber_master_df = pd.read_csv('/home/daniel/Data/Gruber/cutouts_scale-1-1-1/master_df.csv')
implants_master_df = pd.read_csv('/home/daniel/Data/Implants/cutouts_scale-1-1-1/master_df.csv')

if dataset == "Gruber":
    ds = GruberDataset(
        master_df = gruber_master_df,
        input_shape = (128,128,96),
        include_com = False,
        flip_prob = 0
    )
    # Gruber uses the POI id -> name table shipped with BIDS.
    poi2text = conversion_poi2text

elif dataset == "Implants":
    ds = ImplantsDataset(
        master_df = implants_master_df,
        input_shape = (128,128,96),
        include_com = False,
        flip_prob = 0,
        poi_file_ending = 'poi_surface.json'
    )
    # The Implants dataset only carries these four POI ids.
    poi2text = {
        90: 'Left Entry',
        91: 'Right Entry',
        92: 'Left Target',
        93: 'Right Target',
    }

else:
    # Fail loudly: otherwise `ds` / `poi2text` would be undefined, or worse,
    # silently left over from a previous run in a live kernel.
    raise ValueError(
        f'Unknown dataset {dataset!r}; expected "Gruber" or "Implants"'
    )

# Iterate through the ds to create a mapping from subject-vertebra pair to
# the corresponding dataset index. If a pair occurs more than once, the last
# index wins.
sub_vert_to_idx = {}
for idx, dd in enumerate(ds):
    sub_vert_to_idx[(dd['subject'], dd['vertebra'])] = idx


In [17]:
# Sanity check: look up the dataset index stored for subject '2', vertebra 20.
sub_vert_to_idx['2', 20]

1

In [18]:
# Define some useful utility functions
def get_dd_ctd(dd, poi_list = [90,91,92,93]):
    """Build a POI object from a dataset sample's target coordinates.

    Args:
        dd: Dataset sample (dict-like) providing 'target', 'target_indices'
            and 'vertebra'.
        poi_list: POI ids to keep; ``None`` keeps every POI. The default list
            is only read, never mutated.

    Returns:
        The POI object produced by ``get_ctd`` for this sample.
    """
    # Delegate to get_ctd so the filtering and POI construction live in one
    # place instead of being duplicated here.
    return get_ctd(dd['target'], dd['target_indices'], dd['vertebra'], poi_list)

def get_ctd(target, target_indices, vertebra, poi_list):
    """Collect selected POI coordinates into a POI object.

    Args:
        target: Iterable of per-POI coordinates; the first three components
            of each entry are used.
        target_indices: Iterable of POI ids, paired element-wise with
            ``target``. Each id must support ``.item()``.
        vertebra: Vertebra label used as the first part of every key.
        poi_list: POI ids to keep; ``None`` keeps every POI.

    Returns:
        POI with keys ``(vertebra, poi_id)``, orientation ('L', 'A', 'S'),
        zoom (1, 1, 1) and shape (128, 128, 96).
    """
    centroids = {}
    for point, label in zip(target, target_indices):
        xyz = tuple(point[axis].item() for axis in range(3))
        # Skip POIs that were not requested.
        if poi_list is not None and label not in poi_list:
            continue
        centroids[vertebra, label.item()] = xyz

    return POI(centroids = centroids, orientation = ('L', 'A', 'S'), zoom = (1,1,1), shape = (128,128,96))

def get_vert_msk_nii(dd):
    """Return the sample's mask as a NII labelled with the sample's vertebra.

    Reads 'vertebra' and 'input' from ``dd``; the leading axis of 'input' is
    squeezed away before conversion.
    """
    return vertseg_to_vert_msk_nii(dd['vertebra'], dd['input'].squeeze(0))

def vertseg_to_vert_msk_nii(vertebra, msk):
    """Convert a mask tensor into a segmentation NII labelled with ``vertebra``.

    Every non-zero voxel of ``msk`` takes the value ``vertebra``; zero voxels
    stay zero. The result is cast to int32 and flagged as a segmentation.
    """
    foreground = msk != 0
    labelled = (foreground * vertebra).numpy().astype(np.int32)
    nii = np_to_bids_nii(labelled)
    nii.seg = True
    return nii

def get_vertseg_nii(dd):
    """Return the sample's 'input' tensor as an int32 segmentation NII.

    The leading axis of 'input' is squeezed away; voxel values are kept
    (only cast to int32) and the NII is flagged as a segmentation.
    """
    seg_array = dd['input'].squeeze(0).numpy().astype(np.int32)
    nii = np_to_bids_nii(seg_array)
    nii.seg = True
    return nii

def get_vert_points(dd):
    """Return the indices of all non-zero voxels of the sample's 'input'.

    The leading axis of 'input' is squeezed away first. The result has one
    row per non-zero voxel and one column per remaining axis.
    """
    mask = dd['input'].squeeze(0)
    # torch.where yields one index tensor per axis; stack them column-wise.
    per_axis_indices = torch.where(mask)
    return torch.stack(per_axis_indices, dim=1)

def get_target_entry_points(dd):
    """Return the entry/target POI coordinates of a sample as tensors.

    Args:
        dd: Dataset sample providing 'target', 'target_indices' and
            'vertebra'; it must contain POIs 90, 91, 92 and 93.

    Returns:
        Tuple ``(p_90, p_92, p_91, p_93)``, i.e. left entry, left target,
        right entry, right target (names per the Implants ``poi2text``).
    """
    # get_ctd needs (target, target_indices, vertebra, poi_list); the
    # sample-based wrapper is get_dd_ctd. Calling get_ctd(dd) raised TypeError.
    ctd = get_dd_ctd(dd, poi_list = [90, 91, 92, 93])
    vertebra = dd['vertebra']

    # Left side: entry (90) and target (92).
    p_90 = torch.tensor(ctd[vertebra, 90])
    p_92 = torch.tensor(ctd[vertebra, 92])

    # Right side: entry (91) and target (93).
    p_91 = torch.tensor(ctd[vertebra, 91])
    p_93 = torch.tensor(ctd[vertebra, 93])

    return p_90, p_92, p_91, p_93

def tensor_to_ctd(t, vertebra, origin, rotation, idx_list = None, shape = (128, 128, 96), zoom = (1,1,1), offset = (0,0,0)):
    """Turn a tensor of points into a POI object keyed by ``(vertebra, id)``.

    Args:
        t: Iterable of coordinate tensors; the first three components of
            each (after subtracting ``offset``) are used.
        vertebra: Vertebra label used as the first part of every key.
        origin, rotation: Forwarded unchanged to ``POI``.
        idx_list: Optional POI ids, one per row of ``t``. When ``None`` the
            row number is used as the id. Rows beyond ``len(idx_list)`` are
            dropped.
        shape, zoom: Forwarded unchanged to ``POI``.
        offset: Subtracted from every point before conversion.

    Returns:
        POI with orientation ('L', 'A', 'S').
    """
    centroids = {}
    for row, point in enumerate(t):
        shifted = point.float() - torch.tensor(offset)
        xyz = tuple(shifted[axis].item() for axis in range(3))
        if idx_list is None:
            label = row
        elif row < len(idx_list):
            label = idx_list[row]
        else:
            # More points than labels: the surplus rows are ignored.
            continue
        centroids[vertebra, label] = xyz

    return POI(centroids = centroids, orientation = ('L', 'A', 'S'), zoom = zoom, shape = shape, origin = origin, rotation = rotation)

# Widget labels of the form "<id>: <name>", each mapped back to its POI id.
poi_types = {f"{poi_id}: {label}": poi_id for poi_id, label in poi2text.items()}

# Distinct subjects and vertebrae present in the dataset index, sorted.
subjects = sorted({subject for subject, _ in sub_vert_to_idx})
vertebrae = sorted({vert for _, vert in sub_vert_to_idx})

def display_pois():
    """Display widgets to pick a subject, vertebra and POI types, and visualize them.

    Shows a subject dropdown, a vertebra dropdown restricted to the vertebrae
    available for the chosen subject, a multi-select of POI types and an
    'Update' button. Clicking the button loads the matching sample from ``ds``
    and passes its mask and the selected POIs to ``visualize_pois``.

    Relies on the notebook globals ``ds``, ``sub_vert_to_idx``, ``subjects``,
    ``poi_types`` and ``dataset``.
    """
    def vertebrae_for(subject):
        # Use the same mapping the button handler looks samples up in, so that
        # every offered (subject, vertebra) pair exists for the active dataset.
        return sorted({vert for sub, vert in sub_vert_to_idx if sub == subject})

    # Dropdown for the subject
    subject_select = widgets.Dropdown(
        options=subjects,
        description='Subject:',
        disabled=False
    )

    # Dropdown for the vertebra, initialised for the initially selected subject
    vertebra_select = widgets.Dropdown(
        options=vertebrae_for(subject_select.value),
        description='Vertebra:',
        disabled=False
    )

    # Multi-Select for POI Types (its value is a tuple of the selected POI ids)
    poi_type_select = widgets.SelectMultiple(
        options=poi_types,
        rows=23 if dataset == "Gruber" else 4,
        description='POI Types',
        tooltip='Select the POI types to visualize',
        disabled=False
    )

    def update_vert_select(*args):
        # Refresh the vertebra choices whenever another subject is selected.
        vertebra_select.options = vertebrae_for(subject_select.value)

    subject_select.observe(update_vert_select, 'value')

    # Button for updating the visualization
    update_button = widgets.Button(
        description='Update',
        disabled=False,
        button_style='', # 'success', 'info', 'warning', 'danger' or ''
        tooltip='Update the visualization',
        icon='check' # (FontAwesome names without the `fa-` prefix)
    )

    def on_button_clicked(b):
        key = (subject_select.value, vertebra_select.value)
        if key not in sub_vert_to_idx:
            # Can only happen when no valid pair is selectable (e.g. empty dataset).
            print(f'No sample found for Subject {key[0]}, Vertebra {key[1]}')
            return

        dd = ds[sub_vert_to_idx[key]]

        seg_vert = get_vert_msk_nii(dd)
        target = dd['target']
        target_indices = dd['target_indices']
        vertebra = dd['vertebra']
        subject = dd['subject']
        # Only the POI ids currently selected in the widget are kept.
        ctd = get_ctd(target, target_indices, vertebra, poi_list=poi_type_select.value)
        print(f'Visualizing Subject {subject}, Vertebra {vertebra}, POIs {poi_type_select.value}')
        visualize_pois(
            ctd = ctd,
            seg_vert = seg_vert,
            vert_idx_list = [vertebra],
        )

    update_button.on_click(on_button_clicked)

    display(subject_select, vertebra_select, poi_type_select, update_button)

In [19]:
# Build and show the interactive POI viewer widgets for the selected dataset.
display_pois()

Dropdown(description='Subject:', options=('10', '2', '3', '5', '6', '7', '8', '9'), value='10')

Dropdown(description='Vertebra:', options=(2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21…

SelectMultiple(description='POI Types', options={'90: Left Entry': 90, '91: Right Entry': 91, '92: Left Target…

Button(description='Update', icon='check', style=ButtonStyle(), tooltip='Update the visualization')

Visualizing Subject 10, Vertebra 21, POIs (91,)
[0m[ ] Image reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
[0m[*] Centroids reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
('P', 'I', 'R') ('P', 'I', 'R')


100%|██████████| 1/1 [00:00<00:00, 25.02it/s]


Visualizing Subject 10, Vertebra 21, POIs (92,)
[0m[ ] Image reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
[0m[*] Centroids reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
('P', 'I', 'R') ('P', 'I', 'R')


100%|██████████| 1/1 [00:00<00:00, 28.79it/s]


Visualizing Subject 10, Vertebra 21, POIs (93,)
[0m[ ] Image reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
[0m[*] Centroids reoriented from ('L', 'A', 'S') to ('P', 'I', 'R')[0m[0m
('P', 'I', 'R') ('P', 'I', 'R')


100%|██████████| 1/1 [00:00<00:00, 26.53it/s]
