In [1]:
import sys, os
sys.path.append('/home/daniel/Documents/Uni/MT/poi-prediction')
from src.dataset.dataset import GruberDataset
from src.transforms.transforms import RandAffine

from utils.misc import np_to_bids_nii
from BIDS import POI, NII
from BIDS.vert_constants import conversion_poi, conversion_poi2text
from BIDS.POI_plotter import visualize_pois
import torch
import pandas as pd

import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

gruber_master_df = pd.read_csv('/home/daniel/Data/Gruber/cutouts_scale-1-1-1/master_df.csv')
ds = GruberDataset(
    master_df = gruber_master_df,
    input_shape = (128,128,96),
    include_com = False,
    flip_prob = 0
)

In [2]:
#Iterate through the ds to create a mapping from subject-vertebra pair to corresponding index
sub_vert_to_idx = {}
for idx, dd in enumerate(ds):
    sub_vert_to_idx[(dd['subject'], dd['vertebra'])] = idx

In [3]:
sub_vert_to_idx['WS-45', 8]

103

In [4]:
# Define some useful utility functions
def get_dd_ctd(dd, poi_list = [90,91,92,93]):
    ctd = {}
    vertebra = dd['vertebra']

    for poi_coords, poi_idx in zip(dd['target'], dd['target_indices']):
        coords = (poi_coords[0].item(), poi_coords[1].item(), poi_coords[2].item())
        if poi_list is None or poi_idx in poi_list:
            ctd[vertebra, poi_idx.item()] = coords

    ctd = POI(centroids = ctd, orientation = ('L', 'A', 'S'), zoom = (1,1,1), shape = (128,128,96))
    return ctd

def get_ctd(target, target_indices, vertebra, poi_list):
    ctd = {}
    for poi_coords, poi_idx in zip(target, target_indices):
        coords = (poi_coords[0].item(), poi_coords[1].item(), poi_coords[2].item())
        if poi_list is None or poi_idx in poi_list:
            ctd[vertebra, poi_idx.item()] = coords

    ctd = POI(centroids = ctd, orientation = ('L', 'A', 'S'), zoom = (1,1,1), shape = (128,128,96))
    return ctd

def get_vert_msk_nii(dd):
    vertebra = dd['vertebra']
    msk = dd['input'].squeeze(0)
    return vertseg_to_vert_msk_nii(vertebra, msk)

def vertseg_to_vert_msk_nii(vertebra, msk):
    vert_msk = (msk != 0) * vertebra
    vert_msk_nii = np_to_bids_nii(vert_msk.numpy().astype(np.int32))
    vert_msk_nii.seg = True
    return vert_msk_nii

def get_vertseg_nii(dd):
    vertseg = dd['input'].squeeze(0)
    vertseg_nii = np_to_bids_nii(vertseg.numpy().astype(np.int32))
    vertseg_nii.seg = True
    return vertseg_nii

def get_vert_points(dd):
    msk = dd['input'].squeeze(0)
    vert_points = torch.where(msk)
    vert_points = torch.stack(vert_points, dim=1)
    return vert_points

def get_target_entry_points(dd):
    ctd = get_ctd(dd)
    vertebra = dd['vertebra']
    p_90 = torch.tensor(ctd[vertebra, 90])
    p_92 = torch.tensor(ctd[vertebra, 92])

    p_91 = torch.tensor(ctd[vertebra, 91])
    p_93 = torch.tensor(ctd[vertebra, 93])

    return p_90, p_92, p_91, p_93

def tensor_to_ctd(t, vertebra, origin, rotation, idx_list = None, shape = (128, 128, 96), zoom = (1,1,1), offset = (0,0,0)):
    ctd = {}
    for i, coords in enumerate(t):
        coords = coords.float() - torch.tensor(offset)
        coords = (coords[0].item(), coords[1].item(), coords[2].item())
        if idx_list is None:
            ctd[vertebra, i] = coords
        elif i < len(idx_list):
            ctd[vertebra, idx_list[i]] = coords

    ctd = POI(centroids = ctd, orientation = ('L', 'A', 'S'), zoom = zoom, shape = shape, origin = origin, rotation = rotation)
    return ctd

In [5]:
poi_types = {f"{key}: {value}": key for key, value in conversion_poi2text.items()}

def display_dd_pois(dd):
    display_pois(
        seg_vert = get_vert_msk_nii(dd),
        target = dd['target'],
        target_indices = dd['target_indices'],
        vertebra = dd['vertebra'],
        subject = dd['subject']
    )

def display_pois(seg_vert, target, target_indices, vertebra, subject):
    # Multi-Select for POI Types
    poi_type_select = widgets.SelectMultiple(
        options=poi_types,
        rows=23,
        description='POI Types',
        tooltip='Select the POI types to visualize',
        disabled=False
    )

    # Button for updating the visualization
    update_button = widgets.Button(
        description='Update',
        disabled=False,
        button_style='', # 'success', 'info', 'warning', 'danger' or ''
        tooltip='Update the visualization',
        icon='check' # (FontAwesome names without the `fa-` prefix)
    )

    def on_button_clicked(b):
        print(target, target_indices, vertebra, subject)
        ctd = get_ctd(target, target_indices, vertebra, poi_list=poi_type_select.value)
        print(f'Visualizing Subject {subject}, Vertebra {vertebra}, POIs {poi_type_select.value}')
        print(ctd)
        visualize_pois(
            ctd = ctd,
            seg_vert = seg_vert,
            vert_idx_list = [vertebra],
        )

    update_button.on_click(on_button_clicked)

    display(poi_type_select, update_button)

In [6]:
print(ds[466]['vertebra'])
display_dd_pois(ds[103])

17


SelectMultiple(description='POI Types', options={'81: SSL': 81, '109: ALL_CR_S': 109, '101: ALL_CR': 101, '117…

Button(description='Update', icon='check', style=ButtonStyle(), tooltip='Update the visualization')

tensor([[ 61.0810,  28.5310,  55.2150],
        [ 63.4540,  95.3610,  52.5420],
        [ 64.2100,  85.2550,  54.6300],
        [ 64.7460,  93.1310,  31.5540],
        [ 63.9710,  75.1100,  40.0050],
        [ 72.1630,  96.2220,  52.5420],
        [ 71.8860,  82.3260,  55.0350],
        [ 73.2640,  88.2600,  32.1480],
        [ 56.9180,  75.0340,  40.6080],
        [ 53.7120,  95.1600,  52.3260],
        [ 56.0090,  82.1350,  54.9720],
        [ 54.9080,  89.2650,  32.5710],
        [ 70.3740,  74.9670,  40.5450],
        [ 63.9620,  68.8890,  61.8210],
        [ 65.0330,  63.5200,  50.9940],
        [ 62.0570,  32.7520,  59.1390],
        [ 60.3440,  29.7760,  50.8860],
        [ 52.5150,  77.3020,  62.3070],
        [105.4300,  80.9190,  56.3760],
        [ 49.5960,  75.0910,  48.6720],
        [ 22.3210,  77.4740,  54.7560],
        [ 75.5890,  78.8230,  61.6320],
        [ 77.9920,  74.8610,  49.4910]]) tensor([ 81, 101, 102, 103, 104, 109, 110, 111, 112, 117, 118, 119, 120, 125,
 

100%|██████████| 1/1 [00:00<00:00, 27.34it/s]


In [5]:
class LandmarksRandAffine:
    def __init__(
            self,
            prob,
            rotate_range,
            shear_range,
            translate_range,
            scale_range,
            device = 'cpu'
    ):
        self.prob = prob
        self.rotate_range = rotate_range
        self.shear_range = shear_range
        self.translate_range = translate_range
        self.scale_range = scale_range

        self.image_transform = RandAffine(
            prob = prob,
            rotate_range = rotate_range,
            shear_range = shear_range,
            translate_range = translate_range,
            scale_range = scale_range,
            mode='nearest',
            padding_mode = 'zeros',
            device = device
        )

    def __call__(self, dd):
        volume = dd['input']
        landmarks = dd['target']

        #Apply MonAI's RandAffine to the volume
        transformed_volume, affine_matrix = self.image_transform(volume)

        #Convert landmarks to homogeneous coordinates
        ones = torch.ones(landmarks.shape[0], 1, dtype=landmarks.dtype, device=landmarks.device)
        homogeneous_landmarks = torch.cat([landmarks, ones], dim=1)

        #Apply the affine transformation to the landmarks
        transformed_landmarks = torch.mm(homogeneous_landmarks, torch.linalg.inv(torch.tensor(affine_matrix, dtype=torch.float).t()))[:, :3]

        dd['input'] = transformed_volume
        dd['target'] = transformed_landmarks

        return dd
    
class LandMarksRandHorizontalFlip:
    def __init__(self, prob, flip_pairs, device = 'cpu'):
        self.prob = prob
        self.flip_pairs = flip_pairs

    def __call__(self, dd):
        if torch.rand(1) < self.prob:
            target_indices = dd['target_indices']

            #Flip the volume horizontally, since the orientation is LAS (Left, Anterior, Superior), this means flipping along the L axis
            dd['input'] = torch.flip(dd['input'], dims=[1])

            #Flip the landmarks horizontally
            dd['target'][:, 0] = dd['input'].shape[1] - dd['target'][:, 0]

            #Reorder the landmarks according to the swap indices
            indices_map = {k.item(): v for v, k in enumerate(target_indices)} 
            new_positions = [indices_map[self.flip_pairs[k.item()]] for k in target_indices]

            dd['target'] = dd['target'][new_positions]
            
        return dd
    
class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, dd):
        for transform in self.transforms:
            dd = transform(dd)
        return dd

In [6]:
affine_transform = LandmarksRandAffine(
    prob = 1,
    rotate_range = ([-10, 10], [-10, 10], [-10, 10]),
    shear_range = ([-0.1, -0.1], [0.1, 0.1], [0.1, 0.1]),
    translate_range = ([-10, 10], [-10, 10], [-10, 10]),
    scale_range = ([-0.3, -0.3], [0.1, 0.1], [0.5, 0.5]),
    device = 'cpu'
)

flip_transform = LandMarksRandHorizontalFlip(
    prob = 1,
    flip_pairs = {
                # These are the middle points, i.e. the ones that are not flipped
                81 : 81,
                101: 101,
                103: 103,
                102: 102,
                104: 104,
                125: 125,
                127: 127,
                134: 134,
                136: 136,
                
                # Flipped left to right
                109: 117,
                111: 119,
                110: 118,
                112: 120,
                149: 141,
                151: 143,
                142: 144,

                # Flipped right to left
                117: 109,
                119: 111,
                118: 110,
                120: 112,
                141: 149,
                143: 151,
                144: 142,

                #Center of mass, does not need to be flipped
                41: 41,
                42: 42,
                43: 43,
                44: 44,
                45: 45,
                46: 46,
                47: 47,
                48: 48,
                49: 49,
                50: 50,

                0:0
            },
)

transforms = Compose([
    affine_transform,
    flip_transform
])

# Get the volume and landmarks
dd = ds[7]

# Apply the affine transformation
dd = transforms(dd)

# Get the transformed landmarks
transformed_landmarks = dd['target']
transformed_volume = dd['input']

transformed_vert_msk_nii = vertseg_to_vert_msk_nii(dd['vertebra'], transformed_volume.squeeze(0))

# Display the transformed volume and landmarks
display_pois(
    seg_vert = transformed_vert_msk_nii,
    target = transformed_landmarks,
    target_indices = dd['target_indices'],
    vertebra = dd['vertebra'],
    subject = dd['subject']
)

SelectMultiple(description='POI Types', options={'81: SSL': 81, '109: ALL_CR_S': 109, '101: ALL_CR': 101, '117…

Button(description='Update', icon='check', style=ButtonStyle(), tooltip='Update the visualization')

Visualizing Subject WS-06, Vertebra 8, POIs (81, 109, 101, 117, 111, 103, 119, 110, 102, 118, 112, 104, 120, 149, 125, 141, 151, 127, 143, 134, 136, 142, 144)
POI(centroids={8: {81: (111.60933685302734, 63.394046783447266, 91.57066345214844), 101: (44.054168701171875, 69.55589294433594, 54.38749694824219), 102: (61.600303649902344, 65.45057678222656, 60.80192565917969), 103: (36.9173583984375, 78.61259460449219, 59.467750549316406), 104: (58.354705810546875, 76.40855407714844, 69.40667724609375), 109: (40.258697509765625, 63.855281829833984, 58.468299865722656), 110: (57.27146911621094, 59.513954162597656, 64.8820571899414), 111: (33.5614013671875, 72.78521728515625, 63.892601013183594), 112: (53.832977294921875, 71.10198974609375, 71.9076919555664), 117: (52.62055969238281, 74.59538269042969, 53.145729064941406), 118: (71.84156799316406, 69.90131378173828, 60.54643249511719), 119: (44.85203552246094, 83.4410400390625, 58.1368408203125), 120: (65.68084716796875, 80.63076782226562, 68.2

100%|██████████| 1/1 [00:00<00:00, 40.23it/s]


In [None]:
transformed_volume[0]

metatensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 