In [1]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import interact, Dropdown, IntSlider

In [17]:
# Root of the MRNet-v1.0 dataset. Set the MRNET_DATA_PATH environment variable
# to point elsewhere instead of editing a user-specific absolute path here.
data_path = os.environ.get("MRNET_DATA_PATH", "/Users/ksamgam/Downloads/MRNet-v1.0")
# Derived from data_path so the two paths can never refer to different copies.
train_path = os.path.join(data_path, "train")

In [18]:
# Every figure below uses matplotlib's 'grayscale' style.
# (For zoom/pan figures, the `%matplotlib notebook` magic can be enabled.)
plt.style.use('grayscale')

# Training-split abnormality labels. The CSV has no header row, so the column
# names are supplied here. 'Case' is read as text; presumably the ids look like
# '0000', and an integer dtype would drop the leading zeros.
train_abnl = pd.read_csv(
    "{}/train-abnormal.csv".format(data_path),
    header=None,
    names=['Case', 'Abnormal'],
    dtype={'Case': str, 'Abnormal': np.int64},
)

In [19]:
# data loading functions
def load_one_stack(case, data_path=train_path, plane='coronal'):
    """Load the saved array for a single case in a single plane.

    Reads ``<data_path>/<plane>/<case>.npy`` via ``np.load`` and returns the
    array as stored.

    NOTE(review): the ``data_path`` default is captured from ``train_path`` when
    this cell runs, so re-run this cell after changing ``train_path``.
    """
    stack_file = "{}/{}/{}.npy".format(data_path, plane, case)
    return np.load(stack_file)

def load_stacks(case, data_path=train_path, planes=('coronal', 'sagittal', 'axial')):
    """Load several imaging planes of one case into a dict.

    Parameters
    ----------
    case : str
        Case id; used as the ``<case>.npy`` file stem.
    data_path : str
        Directory holding one sub-directory per plane. Defaults to
        ``train_path`` (captured when this cell is run).
    planes : iterable of str
        Plane sub-directories to load, in the order they are inserted into
        the result. Defaults to coronal, sagittal, axial.

    Returns
    -------
    dict
        Maps each plane name to the array returned by ``load_one_stack``.
    """
    # data_path must be forwarded explicitly; otherwise load_one_stack falls back
    # to its own default and any directory passed here is silently ignored.
    return {plane: load_one_stack(case, data_path=data_path, plane=plane)
            for plane in planes}

# interactive viewer
class KneePlot():
    """Interactive slice browser for a dict of image stacks.

    ``x`` maps a plane name to an array indexed as ``x[plane][slice, :, :]``.
    ``draw()`` shows a plane dropdown plus a slice slider and re-renders the
    selected slice whenever either control changes.
    """

    def __init__(self, x, figsize=(10, 10)):
        """Store the stacks and cache each plane's slice count.

        Parameters
        ----------
        x : dict
            Plane name -> array whose axis 0 is the slice axis.
        figsize : tuple
            Matplotlib figure size used for every render.

        Raises
        ------
        ValueError
            If ``x`` is empty (``draw()`` would have nothing to show).
        """
        if not x:
            raise ValueError("KneePlot needs at least one plane to display")
        self.x = x
        self.planes = list(x.keys())
        # Axis 0 is the axis the slider indexes, so its length is the slice count.
        self.slice_nums = {plane: self.x[plane].shape[0] for plane in self.planes}
        self.figsize = figsize

    def _plot_slices(self, plane, im_slice):
        """Render slice ``im_slice`` of ``plane``; used as the ``interact`` callback."""
        _, ax = plt.subplots(1, 1, figsize=self.figsize)
        # Pass the colormap explicitly so the viewer does not rely on another cell
        # having set the 'grayscale' style beforehand.
        ax.imshow(self.x[plane][im_slice, :, :], cmap='gray')
        ax.set_title("{} (slice {}/{})".format(plane, im_slice, self.slice_nums[plane] - 1))
        plt.show()

    def draw(self):
        """Display the plane dropdown, the slice slider, and the linked plot."""
        planes_widget = Dropdown(options=self.planes)
        first_plane = self.planes[0]
        last_slice = self.slice_nums[first_plane] - 1
        slices_widget = IntSlider(min=0, max=last_slice, value=last_slice // 2)

        def update_slices_widget(*args):
            # Re-range the slider for the newly selected plane and re-centre it.
            # This observer is registered before interact() (traitlets dispatches
            # observers in registration order), so the slider is re-ranged before
            # interact redraws, keeping im_slice valid for the new plane.
            slices_widget.max = self.slice_nums[planes_widget.value] - 1
            slices_widget.value = slices_widget.max // 2

        planes_widget.observe(update_slices_widget, 'value')
        interact(self._plot_slices, plane=planes_widget, im_slice=slices_widget)

    def resize(self, figsize):
        """Set the figure size used by subsequent renders."""
        self.figsize = figsize

In [16]:
# Example: browse every plane of the second training case interactively.
# .iloc is positional; read_csv gave train_abnl a default RangeIndex, so this
# picks the same row as label 1.
case = train_abnl['Case'].iloc[1]
x = load_stacks(case)
plot = KneePlot(x, figsize=(8, 8))
plot.draw()

interactive(children=(Dropdown(description=u'plane', options=('axial', 'coronal', 'sagittal'), value='axial'),…
