In [None]:
import vtk
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt

In [1]:

## Plot the two label atlases side by side to highlight the structures they have in common (and count them)


class scrollable_image():
    """A 2D VTK view of one z slice of a 3D label image, scrollable through z.

    Pipeline methods each consume ``self.current_step`` and replace it with the
    object they create, e.g. ``read_image`` -> ``resize_image`` ->
    ``add_colour_mapper`` -> ``map_image`` -> ``add_actor`` / ``add_text_actor``.
    Scrolling callbacks come from ``create_mousewheel_callbacks`` and
    ``create_key_press_scroll``.
    """

    # Index of the grey entry in vtkColorSeries.BREWER_QUALITATIVE_SET3; it is
    # skipped so comparable structures are never drawn in grey.
    GREY_IDX_IN_SERIES = 8

    def __init__(self, image_path, reader_type, z_slice, colour_window, colour_level):
        """Store display settings; no VTK objects are created here.

        image_path: DICOM directory (reader_type == "DICOM") or NIfTI file path.
        reader_type: "DICOM" selects vtkDICOMImageReader, anything else vtkNIFTIImageReader.
        z_slice: initial slice index shown by the mapper.
        colour_window, colour_level: window/level handed to vtkImageMapper.
        """
        self.image_path = image_path
        self.reader_type = reader_type
        self.z_slice = z_slice
        self.colour_window = colour_window
        self.colour_level = colour_level

        # keep updated with the object at the current step in the pipeline
        # so that the set of pipeline steps can vary
        self.current_step = None

        # number of entries written into the colour bar lookup table
        self.colour_bar_num_colours = 0

        # render window redrawn after scrolling; set via set_render_window()
        self.window = None

    def read_image(self):
        """Create and run the reader for image_path and make it the current step.

        Records the number of z slices and clamps z_slice into [0, z_slices - 1]
        so the initial slice is always a valid index.
        """
        if self.reader_type == "DICOM":
            self.reader = vtk.vtkDICOMImageReader()
            self.reader.SetDirectoryName(self.image_path)
        else:
            self.reader = vtk.vtkNIFTIImageReader()
            self.reader.SetFileName(self.image_path)

        self.reader.Update()
        self.current_step = self.reader

        # number of z slices in the image
        self.z_slices = self.reader.GetOutput().GetDimensions()[2]
        self.z_slice = max(0, min(self.z_slice, self.z_slices - 1))

    @staticmethod
    def _fit_in_plane_dims(original_dims, render_size):
        """Return [x, y, z] with x and y scaled so the larger equals render_size.

        The aspect ratio is preserved and z is left unchanged, so slice indices
        stay valid. An all-zero in-plane size is returned unchanged.
        """
        dims = [int(d) for d in original_dims]
        largest_in_plane = max(dims[:2])
        if largest_in_plane <= 0:
            return dims
        scale = render_size / largest_in_plane
        return [max(1, int(round(d * scale))) for d in dims[:2]] + dims[2:]

    def resize_image(self, render_size):
        """Resample so the larger in-plane side is render_size pixels; z is kept."""
        # Read the size from the already-updated upstream step: the resizer's
        # own output stays empty (0, 0, 0) until the resizer has been updated.
        original_dims = self.current_step.GetOutput().GetDimensions()
        new_dims = self._fit_in_plane_dims(original_dims, render_size)

        self.resizer = vtk.vtkImageResize()
        self.resizer.SetInputConnection(self.current_step.GetOutputPort())
        self.resizer.SetOutputDimensions(*new_dims)
        self.resizer.Update()

        self.current_step = self.resizer

    def map_image(self):
        """Create a vtkImageMapper showing z_slice with the stored window/level."""
        self.mapper = vtk.vtkImageMapper()
        self.mapper.SetInputConnection(self.current_step.GetOutputPort())

        self.mapper.SetColorWindow(self.colour_window)
        self.mapper.SetColorLevel(self.colour_level)
        self.mapper.SetZSlice(self.z_slice)
        self.mapper.Update()

        self.current_step = self.mapper

    @staticmethod
    def _comparable_colour_choice(count, usable_indices):
        """Return (colour series index, channel divisor) for the count-th comparable structure.

        Colours cycle through usable_indices. The first cycle divides channels
        by 255 (true colour), the second by 150 and later ones by 5 (brighter;
        callers clamp the result to 1.0).
        """
        per_cycle = len(usable_indices)
        series_idx = usable_indices[count % per_cycle]
        cycle = count // per_cycle
        if cycle == 0:
            scale = 255
        elif cycle == 1:
            scale = 150
        else:
            scale = 5
        return series_idx, scale

    def add_colour_mapper(self, num_values, num_unique_structures=250, num_comparable_structures=22):
        """Map label values to RGBA colours and make the colour mapper the current step.

        Labels 0..num_unique_structures get a grey ramp (0 is black). Labels above
        that are treated as structures present in both atlases: they get
        Brewer Set3 colours, which are also stored in colour_bar_lut for the
        colour bar.

        num_values: largest label value; the table holds num_values + 1 entries.
        num_unique_structures, num_comparable_structures: label split points
            (defaults match the atlases this notebook was written for).
        """
        # reset so a repeated call does not keep counting from a previous run
        self.colour_bar_num_colours = 0

        # one entry per label value, including the background at 0
        lut = vtk.vtkLookupTable()
        lut.SetNumberOfTableValues(num_values + 1)
        lut.SetRange(0, num_values)
        lut.Build()

        self.colour_bar_lut = vtk.vtkLookupTable()
        self.colour_bar_lut.SetNumberOfTableValues(num_comparable_structures)
        self.colour_bar_lut.SetRange(num_unique_structures + 1, num_unique_structures + num_comparable_structures)
        self.colour_bar_lut.Build()

        colour_series = vtk.vtkColorSeries()
        colour_series.SetColorScheme(vtk.vtkColorSeries.BREWER_QUALITATIVE_SET3)
        usable_indices = [
            idx for idx in range(colour_series.GetNumberOfColors())
            if idx != self.GREY_IDX_IN_SERIES
        ]

        # SetTableValue expects components in [0, 1]; index 0 stays black.
        grey_levels = np.linspace(0.0, 1.0, num_unique_structures + 1)

        # fill the lookup table by hand - vtkColorSeries can't do it automatically
        for i in range(num_values + 1):
            if i <= num_unique_structures:
                grey = float(grey_levels[i])
                lut.SetTableValue(i, grey, grey, grey, 1.0)
                continue

            # structure present in both atlases
            comparable_idx = i - num_unique_structures - 1
            series_idx, scale = self._comparable_colour_choice(comparable_idx, usable_indices)
            colour = colour_series.GetColor(series_idx)
            rgb = [min(1.0, channel / scale)
                   for channel in (colour.GetRed(), colour.GetGreen(), colour.GetBlue())]
            lut.SetTableValue(i, *rgb, 1.0)

            # save comparable structures for the colour bar (bounded by its size)
            if comparable_idx < num_comparable_structures:
                self.colour_bar_lut.SetTableValue(comparable_idx, *rgb, 1.0)
                self.colour_bar_num_colours = comparable_idx + 1

        self.colour_mapper = vtk.vtkImageMapToColors()
        self.colour_mapper.SetInputConnection(self.current_step.GetOutputPort())
        self.colour_mapper.SetLookupTable(lut)
        self.colour_mapper.PassAlphaToOutputOn()
        self.colour_mapper.Update()

        self.current_step = self.colour_mapper

    def add_actor(self):
        """Wrap the current step (expected to be the image mapper) in a 2D actor."""
        self.actor = vtk.vtkActor2D()
        self.actor.SetMapper(self.current_step)

    def _slice_label(self):
        """Return the slice caption, e.g. "Slice: 12/40"."""
        return "Slice: " + str(self.z_slice) + "/" + str(self.z_slices)

    def add_text_actor(self):
        """Create the white, 24 pt slice caption; requires read_image() first."""
        self.text_actor = vtk.vtkTextActor()
        self.text_actor.GetTextProperty().SetFontSize(24)
        self.text_actor.GetTextProperty().SetColor(1, 1, 1)
        self.text_actor.SetInput(self._slice_label())

    def set_text_position(self, pos_x, pos_y):
        """Place the slice caption at display coordinates (pos_x, pos_y)."""
        self.text_actor.SetPosition(pos_x, pos_y)

    def add_colour_bar(self, custom_labels):
        """Create a scalar bar showing the comparable-structure colours.

        Must be called after add_colour_mapper(), which builds colour_bar_lut.
        custom_labels is currently unused (see TODO below).
        """
        self.colour_bar = vtk.vtkScalarBarActor()
        self.colour_bar.SetLookupTable(self.colour_bar_lut)
        self.colour_bar.SetTitle("Comparable Structures")
        self.colour_bar.SetNumberOfLabels(self.colour_bar_num_colours)
        self.colour_bar.SetVerticalTitleSeparation(10)

        # TODO: show custom_labels on the colour bar. SetCustomLabels only takes
        # a numeric (double) array, so structure names probably need lookup-table
        # annotations or separate text actors instead.

    def _show_slice(self, z_slice):
        """Move the mapper and caption to z_slice and redraw if a window is set."""
        self.z_slice = z_slice
        self.mapper.SetZSlice(z_slice)
        self.text_actor.SetInput(self._slice_label())
        if self.window is not None:
            self.window.Render()

    def _step_slice(self, step):
        """Move by `step` slices, clamped to the mapper's whole z range."""
        target = self.z_slice + step
        target = min(self.mapper.GetWholeZMax(), max(self.mapper.GetWholeZMin(), target))
        self._show_slice(target)

    def create_mousewheel_callbacks(self):
        """Return (forward, backward) observers that step one slice up / down."""

        def scroll_forward(obj=None, event=None):
            self._step_slice(1)

        def scroll_backward(obj=None, event=None):
            self._step_slice(-1)

        return scroll_forward, scroll_backward

    def create_key_press_scroll(self):
        """Return a KeyPressEvent observer: Up/Down arrows step one slice."""

        def key_press_scroll(obj, event):
            key = obj.GetKeySym()
            if key == "Up":
                self._step_slice(1)
            elif key == "Down":
                self._step_slice(-1)

        return key_press_scroll

    def set_render_window(self, window):
        """Register the render window to redraw after a slice change."""
        self.window = window
    


In [2]:
# NOTE: the resampling of these labels should match the resampling of the
# images to use for training!
#
# Shows the mouse and human label atlases side by side. The mouse wheel scrolls
# the mouse atlas; the Up/Down keys scroll the human atlas.

human_labels_image_path = "./atlases/human_labels_common_proc.nii"
mouse_labels_image_path = "./atlases/mouse_labels_common_proc.nii"


reader_type = "NIFTY"  # anything other than "DICOM" selects the NIfTI reader
z_slice = 300
render_size_x = 500  # minimum width of each atlas panel (pixels)
render_size_y = 700  # window height (pixels)
render_size = max(render_size_x, render_size_y)

# The largest label value over both atlases sizes the colour lookup table.
human_labels_sitk = sitk.ReadImage(human_labels_image_path)
human_labels_np = sitk.GetArrayFromImage(human_labels_sitk)
human_labels_num_values = human_labels_np.max()

mouse_labels_sitk = sitk.ReadImage(mouse_labels_image_path)
mouse_labels_np = sitk.GetArrayFromImage(mouse_labels_sitk)
mouse_labels_num_values = mouse_labels_np.max()
num_values = int(max(human_labels_num_values, mouse_labels_num_values))

# add_colour_mapper runs before map_image, so vtkImageMapper receives 8-bit
# RGBA colours, not raw label values. The identity window/level for 8-bit data
# (255 / 127.5) shows those colours unchanged. A window of num_values would
# scale every colour channel by 255 / num_values.
min_intensity = 0
max_intensity = 255
colour_window = max_intensity - min_intensity
colour_level = (min_intensity + max_intensity) / 2


mouse_labels = scrollable_image(mouse_labels_image_path, reader_type, z_slice, colour_window, colour_level)
mouse_labels.read_image()
mouse_labels.resize_image(render_size)
mouse_labels.add_colour_mapper(num_values)
mouse_labels.map_image()
mouse_labels.add_actor()
mouse_labels.add_text_actor()

human_labels = scrollable_image(human_labels_image_path, reader_type, z_slice, colour_window, colour_level)
human_labels.read_image()
human_labels.resize_image(render_size)
human_labels.add_colour_mapper(num_values)
human_labels.map_image()
human_labels.add_actor()
human_labels.add_text_actor()

custom_labels = ['amygdala', 'hippocampus', 'lateral ventricle', '4th ventricle', 'thalamus', 'third ventricle', 'basal forebrain', 'optic tract', 'cerebellum white matter', 'accumbens', 'pallidus']

# one right and one left entry per structure
custom_labels_LR = [side_label for label in custom_labels for side_label in [label + " R", label + " L"]]
human_labels.add_colour_bar(custom_labels_LR)

actors = {
    "mouse": mouse_labels.actor,
    "human": human_labels.actor,
    "mouse_text": mouse_labels.text_actor,
    "human_text": human_labels.text_actor,
    "colour_bar": human_labels.colour_bar,
}

mappers = {
    "mouse": mouse_labels.mapper,
    "human": human_labels.mapper,
}

# define interactor and rendering
interactor = vtk.vtkRenderWindowInteractor()
renderer = vtk.vtkRenderer()
renderer.SetBackground(0, 0, 0)  # make sure background is black

# Each panel is render_size_x wide unless the (resized) mouse image is wider.
# In that case the human panel is pushed right so the two never overlap.
mouse_width = mouse_labels.resizer.GetOutput().GetDimensions()[0]
panel_width = max(render_size_x, mouse_width)

window = vtk.vtkRenderWindow()
actors["human"].GetPositionCoordinate().SetValue(panel_width, 0)  # offset the human panel
# two panels plus an extra 200 px (presumably room for the colour bar)
window.SetSize(2*panel_width + 200, render_size_y)

# position the text TODO: make centered!!
mouse_labels.set_text_position(210, render_size_y - 100)
human_labels.set_text_position(panel_width + 185, render_size_y - 100)

window.AddRenderer(renderer)
interactor.SetRenderWindow(window)


fwd, bwd = mouse_labels.create_mousewheel_callbacks()
interactor.AddObserver('MouseWheelForwardEvent', fwd)
interactor.AddObserver('MouseWheelBackwardEvent', bwd)
# lets the mouse-wheel callbacks redraw after changing slice
mouse_labels.set_render_window(window)

kp = human_labels.create_key_press_scroll()
interactor.AddObserver("KeyPressEvent", kp)
# set render window to update on keypress
human_labels.set_render_window(window)

# Add actors to the renderer
for actor in actors.values():
    renderer.AddActor(actor)

# Display
window.Render()
interactor.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera())  # Ensure proper interaction style
interactor.Initialize()
interactor.Start()

NameError: name 'sitk' is not defined

### TODO:
- add the structure names (`custom_labels_LR`) as labels on the colour bar