In [1]:
import sys
sys.path.append("/media/janis/Storage1/development/faimed3d/")

In [2]:
from faimed3d.all import *

In [3]:
#import inspect
#inspect.getmodule(TensorDicom3D)

In [4]:
from ipywidgets import interactive, widgets
from IPython.display import display
import matplotlib.pyplot as plt
from itertools import chain, islice

# Example Data

In [5]:
## generate example data ##
import os

import pandas as pd

# Machine-specific data locations — adjust for your setup.
MRI_ROOT = '/media/scaleout/vahldiek/MRI/SIJ'
NIFTI_DIR = '/media/janis/Storage1/DeepSpA/T1_NIFTI'
NIFTI_SUFFIX = '_n4b_corrected.nii.gz'

# two images loaded from directory paths and one from a .nii.gz file
image = TensorDicom3D.create(f'{MRI_ROOT}/TRAINING_FINAL/01001/T1')
image2 = TensorDicom3D.create(f'{MRI_ROOT}/TRAINING_FINAL/01004/T1')
image3 = TensorDicom3D.create(f'{NIFTI_DIR}/01004{NIFTI_SUFFIX}')

# a list of several images (the first one appears twice)
image_list = [image, image2, image3, image]

# a dataloader from CSV:
# keep rows with both T1 and TIRM, drop rows with missing values in the used
# columns, point every row at the patient's `_n4b_corrected` NIfTI file and
# drop rows whose file does not exist. The CSV's own `path_T1` is replaced
# entirely, so it is not rewritten first.
d = (
    pd.read_csv(f'{MRI_ROOT}/mri_sij_labels_TRAINING_FINAL.csv')
    .loc[lambda df: (df['has_T1'] == 1) & (df['has_TIRM'] == 1),
         ['path_T1', 'active_changes', 'is_valid', 'patient']]
    .dropna()
    .assign(path_T1=lambda df: NIFTI_DIR + '/' + df['patient'] + NIFTI_SUFFIX)
)
# `.astype(bool)` keeps this a boolean mask even when the frame is empty
d = d.loc[d['path_T1'].map(os.path.isfile).astype(bool)]
assert len(d) > 0, f'no NIfTI files found in {NIFTI_DIR}'

# NOTE(review): '/media/..' normalises to '/'; presumably chosen so that a path
# prefix added by `from_df` leaves the absolute `path_T1` values intact — confirm.
# No rescale_method is passed, so faimed3d falls back to mean scaling (see the
# warning in this cell's output).
dls = ImageDataLoaders3D.from_df(
    d, '/media/..',
    item_tfms=ResizeCrop3D(crop_by=((0.01, 0.01), (0.01, 0.01), 0.01), resize_to=(16, 112, 112), perc_crop=True),
    # rescale_method=PiecewiseHistScaling(percs, standard_scale),
    batch_tfms=[
        # RandomPerspective3D(150, p=0.2),
        # RandomRotate3DBy(degrees=(20), axis=[-3], p=0.2),
        # RandomContrast3D(p=0.2),
        # RandomCrop3D((2, 25, 25), (1, 10, 10), p=0.5),
        # Resize3D((14, 112, 112), p=1),
    ],
    valid_col='is_valid',
    bs=5, val_bs=5,
)

Cleaning tmpdir.
removing 459 files from /tmp/faimed3d_metadata/
You can disable automatic cleanup of the tmpdir (e.g. when doing multiple sessions in parallel) with setting clean_tmpdir=False
No rescale method was used. This is not advisable due to high risk of exploding gradients. Falling back to mean scaling.


# Basic View Class

In [6]:
class DatasetExplorerBasicView():
    """Interactive ipywidgets viewer for a single 3D volume.

    Shows one 2D slice at a time, chosen with a slider along the first axis of
    `x` (``x[slice, :, :]``), plus optional text labels and, if requested, a
    histogram of all intensities in the volume.

    Args:
        x: the 3D volume (a `TensorDicom3D` or another 3D torch tensor).
        y: optional class label, rendered as "Class: <y>".
        prediction: optional prediction, rendered after the class label
            (only shown when `y` is set).
        description: optional text shown above the image, e.g. a file path.
        figsize: matplotlib figure size for the slice and histogram plots.
        cmap: matplotlib colormap for the slice view.
        show_hist: add a histogram panel next to the image.
        from_list: set by `DatasetExplorerListViewer`; together with
            `show_hist` it inserts placeholder label lines.

    Call `show()` to display the widget (`self.box`).
    """

    def __init__(self, x:TensorDicom3D, y:str=None, prediction:str=None, description: str=None, figsize=(5, 5), cmap:str='bone', show_hist=False, from_list=False):
        self.x = x

        # the slider is 1-based; `_plot_slice` converts back to a 0-based index
        self.slice_range = (1, len(x))
        self.cmap = cmap
        self.y = y
        self.prediction = prediction
        self.description = description
        self.figsize = figsize
        self.from_list = from_list
        self.show_hist = show_hist

        # Must run before `_generate_views`: `widgets.interactive_output`
        # renders the first plot as soon as it is created, so setting the
        # style afterwards only affected later redraws.
        plt.style.use('default')
        self._generate_views(figsize, show_hist)

    @staticmethod
    def _centered_label(text):
        """Return a full-width `widgets.Label` whose text is centred."""
        return widgets.Label(
            text,
            layout=widgets.Layout(
                width='100%',
                display='flex',
                justify_content="center"
            )
        )

    @staticmethod
    def _bound_slider(description):
        """Return one of the -1000..1000 bound sliders of the histogram panel."""
        return widgets.IntSlider(
            min=-1000,
            max=+1000,
            step=1,
            value=0,
            description=description,
            continuous_update=True,
            readout=True,
            layout=widgets.Layout(width='100%', min_width='200px'),
            style={'description_width': 'initial'}
        )

    def _plot_slice(self, im_slice):
        """Draw slice `im_slice` (1-based slider value) of the volume without ticks."""
        fig, ax = plt.subplots(1, 1, figsize=self.figsize)
        # detach/cpu so CUDA or grad-tracking tensors can be drawn too
        ax.imshow(self.x[im_slice-1, :, :].detach().cpu(), cmap=self.cmap)
        ax.set_xticks([])
        ax.set_yticks([])
        plt.show()

    def _plot_hist(self, upper_bound):
        """Plot a 50-bin histogram of all intensities in the volume.

        `upper_bound` is the value of the "Upper Bound" slider; it makes the
        plot redraw when that slider moves but is not used in the plot yet.
        """
        fig, ax = plt.subplots(figsize=self.figsize)
        ax.hist(self.x.detach().cpu().numpy().flatten(), 50)
        plt.show()

    def _add_hist_panel(self):
        """Build `self.hist_box` and place it to the right of `self.vbox` in `self.box`."""
        hist_slider = self._bound_slider('Upper Bound')
        # TODO: the lower-bound slider and the button below are not connected
        # to any callback yet; only the upper-bound slider triggers a redraw.
        hist_slider2 = self._bound_slider('Lower Bound')
        hist_button = widgets.Button(
            description='Apply to Image',
        )
        hist_output = widgets.interactive_output(
            self._plot_hist,
            {'upper_bound': hist_slider}
        )
        title_label = self._centered_label('3D Histogram')

        # torch.mean only accepts floating-point tensors
        x_float = self.x if self.x.is_floating_point() else self.x.float()
        values_label = self._centered_label(
            'Min: {:5.1f} | Max: {:5.1f} | Mean: {:5.1f}'.format(
                torch.min(self.x).item(),
                torch.max(self.x).item(),
                torch.mean(x_float).item(),
            )
        )

        self.hist_box = widgets.VBox(
            [title_label, values_label, hist_output, hist_slider, hist_slider2, hist_button],
            layout=widgets.Layout(
                margin='10px 5px 5px 5px',
                padding='5px',
                display='flex',
                flex_flow='column',
                align_items='center',
            )
        )
        # with a histogram the border surrounds image + histogram together
        self.box = widgets.HBox(
            [self.vbox, self.hist_box],
            layout=widgets.Layout(
                border='solid 1px lightgrey',
                margin='10px 5px 0px 0px',
                padding='5px',
                width='70%'
            )
        )

    def _generate_views(self, figsize, show_hist):
        """Build the widget tree.

        Sets `self.vbox` (labels, slice image, slice slider) and `self.box`
        (the container displayed by `show()`); with `show_hist` it also sets
        `self.hist_box`.
        """
        self.figsize = figsize

        items_in_box = []

        # description line; a '3D Image' placeholder in list view with histograms
        if self.description:
            description_text = self.description
        elif self.from_list and self.show_hist:
            description_text = '3D Image'
        else:
            description_text = None
        if description_text:
            items_in_box.append(self._centered_label(description_text))

        # "Class: y  |  Prediction: p" line; a blank placeholder in list view with histograms
        label_text = ''
        if self.y:
            label_text = 'Class: ' + str(self.y)
            if self.prediction:
                label_text += '  |  Prediction: ' + str(self.prediction)
        elif self.from_list and self.show_hist:
            label_text = ' '
        if label_text:
            items_in_box.append(self._centered_label(label_text))

        slice_slider = widgets.IntSlider(
            min=min(self.slice_range),
            max=max(self.slice_range),
            step=1,
            value=max(self.slice_range)//2,
            description='',
            continuous_update=True,
            readout=False,
            layout=widgets.Layout(width='99%', min_width='200px'),
            style={'description_width': 'initial'}
        )

        # re-runs `_plot_slice` whenever the slider value changes
        image_output = widgets.interactive_output(
            self._plot_slice,
            {'im_slice': slice_slider}
        )

        items_in_box.append(image_output)
        items_in_box.append(slice_slider)

        # with a histogram the outer HBox draws the border instead
        border_style_vbox = 'none' if show_hist else 'solid 1px lightgrey'

        self.vbox = widgets.VBox(
            items_in_box,
            layout=widgets.Layout(
                border=border_style_vbox,
                margin='10px 5px 0px 0px',
                padding='5px'
            )
        )

        self.box = widgets.HBox(children=[self.vbox])

        if show_hist:
            self._add_hist_panel()

    def show(self):
        """Display the assembled widget in the notebook."""
        display(self.box)

In [7]:
# single volume with class label, prediction and a description line
plot = DatasetExplorerBasicView(
    image,
    y='1',
    prediction='0',
    description='/media/images/t2.nii.gz',
)
plot.show()

HBox(children=(VBox(children=(Label(value='/media/images/t2.nii.gz', layout=Layout(display='flex', justify_con…

### With Histogram

In [8]:
# same view, with the intensity-histogram panel enabled
plot = DatasetExplorerBasicView(
    image,
    y='1',
    prediction='0',
    description='/media/images/t2.nii.gz',
    show_hist=True,
)
plot.show()

HBox(children=(VBox(children=(Label(value='/media/images/t2.nii.gz', layout=Layout(display='flex', justify_con…

# List Viewer Class

In [9]:
class DatasetExplorerListViewer():
    """Show several 3D volumes side by side, each in a `DatasetExplorerBasicView`.

    Args:
        images: a list or tuple of 3D tensors, or a fastai `DataLoaders`. For
            `DataLoaders`, one batch of the training loader is shown (first
            channel of each item) with its labels, plus a "next batch" button.
        num: views per row; only used by the deprecated `_show`.
        figsize: figure size passed to every view.
        cmap: colormap passed to every view.
        show_hist: show a histogram panel next to every image.

    Raises:
        TypeError: if `images` is neither `DataLoaders` nor a list/tuple.
    """

    def __init__(self, images:(list, DataLoaders), num:int=3, figsize=(5, 5), cmap:str='bone', show_hist=False):
        self.images = []

        if isinstance(images, DataLoaders):
            self.dls = images
            self.activate_batch()
        elif isinstance(images, (list, tuple)):
            # a list is kept as the same object (not copied)
            self.images = images if isinstance(images, list) else list(images)
        else:
            # Fail loudly: returning silently here left an object without
            # cmap/num/show_hist whose `show()` displayed an empty box.
            raise TypeError(
                'images must be a list/tuple of 3D tensors or a DataLoaders, '
                f'got {type(images).__name__}'
            )

        self.cmap = cmap
        self.num = num
        self.show_hist = show_hist
        self.figsize = figsize

    def chunk(self, seq, chunksize, process=iter):
        """ Yields items from an iterator in iterable chunks."""
        it = iter(seq)
        while True:
            try:
                yield process(chain([next(it)], islice(it, chunksize - 1)))
            except StopIteration:
                return

    def activate_batch(self):
        """Load one training batch into `self.images` and its labels into `self.yb`.

        Only the first channel of each item (``t[0]``) is kept. Images are
        moved to the CPU so they can be plotted; labels become a NumPy array.
        """
        xb, yb = self.dls.train.one_batch()
        self.images = [t[0].cpu() for t in xb]
        self.yb = yb.cpu().numpy()

    def next_batch_clicked(self, button):
        """Button callback: load a new batch, disable the old button, display a new grid."""
        self.activate_batch()
        button.disabled=True
        self.show()

    def show(self):
        """Display all views in one wrapping row (plus a next-batch button for `DataLoaders`)."""
        all_boxes = []

        for index, image in enumerate(self.images):
            y = str(self.yb[index]) if hasattr(self, 'dls') else None
            basic_view = DatasetExplorerBasicView(image, y=y, description=None, figsize=self.figsize, cmap=self.cmap, show_hist=self.show_hist, from_list=True)

            if self.show_hist:
                # image and histogram split the view's width 50/50
                basic_view.vbox.layout = widgets.Layout(
                    border = 'none', 
                    margin = '5px', 
                    padding =  '5px',
                    width = '50%'
                )
            
                basic_view.hist_box.layout = widgets.Layout(
                    margin = '10px 5px 5px 5px', 
                    padding =  '5px',
                    display='flex',
                    flex_flow='column',
                    align_items='center',
                    width='50%'
                )
                all_boxes.append(basic_view.box)
            else:
                basic_view.vbox.layout = widgets.Layout(
                    border = 'solid 1px lightgray', 
                    margin = '5px', 
                    padding =  '5px',
                    width = '30%'
                )
                all_boxes.append(basic_view.vbox)

        if hasattr(self, 'dls'):
            next_batch_button = widgets.Button(
                description='Show Next Batch >',
                layout=widgets.Layout(
                    width='100%',
                    margin='40px 0px'
                )
            )
            next_batch_button.on_click(self.next_batch_clicked)
            all_boxes.append(next_batch_button)

        # 'row wrap' lets the views flow onto new lines
        viewbox = widgets.HBox(
            children=all_boxes,
            layout=widgets.Layout(
                display = 'flex',
                flex_flow ='row wrap'
            )
        )

        display(viewbox)

    # deprecated
    def _show(self):
        """Deprecated: display the views in fixed rows of `self.num` per row."""
        all_boxes = []
        for image in self.images:
            basic_view = DatasetExplorerBasicView(image, figsize=self.figsize, cmap=self.cmap)
            all_boxes.append(basic_view.vbox)

        vboxes = []

        for chunked_boxes in self.chunk(all_boxes, self.num, list):
            vboxes.append(widgets.HBox(children=chunked_boxes))

        display(widgets.VBox(children=vboxes))

In [10]:
DatasetExplorerListViewer(images=image_list).show()

HBox(children=(VBox(children=(Output(), IntSlider(value=12, layout=Layout(min_width='200px', width='99%'), max…

### With Histogram

In [11]:
DatasetExplorerListViewer(images=image_list, show_hist=True).show()

HBox(children=(HBox(children=(VBox(children=(Label(value='3D Image', layout=Layout(display='flex', justify_con…

### From DataLoaders

In [12]:
DatasetExplorerListViewer(images=dls).show()

HBox(children=(VBox(children=(Label(value='Class: 0', layout=Layout(display='flex', justify_content='center', …

HBox(children=(VBox(children=(Label(value='Class: 1', layout=Layout(display='flex', justify_content='center', …

### From DataLoaders With Histogram

In [13]:
DatasetExplorerListViewer(images=dls, show_hist=True).show()

HBox(children=(HBox(children=(VBox(children=(Label(value='3D Image', layout=Layout(display='flex', justify_con…

HBox(children=(HBox(children=(VBox(children=(Label(value='3D Image', layout=Layout(display='flex', justify_con…