In [1]:
import glob
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import ipywidgets as ipw

# glob returns matches in a filesystem-dependent order; sort each group so the
# file list (and the dropdown built from it) is the same on every run.
files = (
    sorted(glob.glob('data/NIfTI/NIIori/Nor/*.nii.gz'))
    + sorted(glob.glob('data/NIfTI/NIIori/PD/*.nii.gz'))
)

In [2]:
class VolumeViewer:
    """Interactive slice browser for an array volume, built from nested ``ipw.interact`` calls.

    Creating an instance prints the volume shape and shows a toggle for the
    viewing axis (0, 1 or 2). Each time an axis is chosen, a slider sized to
    that axis is shown, and moving it draws the selected slice in grayscale.
    """

    def __init__(self, volume):
        """Keep ``volume`` and display the axis-selection control."""
        self.volume = volume
        self.axis = 0
        print(f'Volume shape: {self.volume.shape}')
        axis_buttons = ipw.ToggleButtons(
            description='Axis to view:',
            options=list(range(3)),
        )
        ipw.interact(self.axis_setter, axis=axis_buttons)

    def show_slices(self, idx):
        """Draw slice ``idx`` taken along the currently selected axis."""
        print(f'Args: {idx}')
        plane = np.take(self.volume, idx, axis=self.axis)
        plt.imshow(plane, cmap='gray')
        plt.axis('off')

    def axis_setter(self, axis):
        """Remember ``axis`` and display a slice slider covering that axis."""
        self.axis = axis
        last_index = self.volume.shape[axis] - 1
        index_slider = ipw.IntSlider(
            min=0,
            max=last_index,
            description='Slice index:',
            continuous_update=False,
        )
        ipw.interact(self.show_slices, idx=index_slider)

In [None]:
def view_nifti_file(fname):
    """Load the NIfTI file at ``fname`` and open a ``VolumeViewer`` on its data array."""
    nifti_image = nib.load(fname)
    VolumeViewer(nifti_image.get_fdata())

# Let the user pick one of the paths in `files`; each selection is passed to
# view_nifti_file as `fname`.
ipw.interact(view_nifti_file, fname=files)

In [None]:
## Display slices in one view (gridview)

def display_volume_slices(volume, axis=0, cols=10, fig_width=10):
    """Show every slice of ``volume`` along ``axis`` in one grid of grayscale images.

    Args:
        volume: Array to slice; ``volume.shape[axis]`` slices are drawn.
        axis: Axis along which the slices are taken.
        cols: Number of grid columns.
        fig_width: Total width of the figure in inches.
    """
    n_slices = volume.shape[axis]
    rows = int(np.ceil(n_slices / cols))
    # figsize is (width, height). Keep the width at fig_width and let the
    # height grow with the number of rows so each grid cell is square; the
    # previous (rows * fig_width, fig_width) made the figure wider, not taller,
    # as rows were added. max(..., 1) avoids a zero-height figure when there
    # are no slices.
    cell_size = fig_width / cols
    fig = plt.figure(figsize=(fig_width, max(rows, 1) * cell_size))

    for idx in range(n_slices):
        slc = np.take(volume, idx, axis=axis)
        ax = fig.add_subplot(rows, cols, idx + 1)  # subplot indices are 1-based
        ax.imshow(slc, cmap='gray')
        ax.axis('off')
        ax.set_aspect('equal')

    fig.subplots_adjust(wspace=0, hspace=0)

# Load one example file from the NIfTI_std directory and show all of its
# slices along axis 2 in a single grid.
volume = nib.load('data/NIfTI_std/NIIori_std/Nor/Nor108_std.nii.gz').get_fdata()
display_volume_slices(volume, axis=2)

In [None]:
## NIfTI file viewer widget
%matplotlib widget

import ipywidgets as ipw
import matplotlib.pyplot as plt
import nibabel as nib

class NiftiViewer(ipw.VBox):
    """Widget showing one slice of a NIfTI volume with axis and slice controls.

    The box stacks three children: a toggle with one button per array axis
    (labelled with the axis code derived from the image affine), a slice-index
    slider, and an output area holding the file name, volume shape and the
    matplotlib figure.

    Args:
        fname: Path of the NIfTI file to load.
        figsize: Figure size passed to ``plt.subplots``.
    """

    def __init__(self, fname, figsize=(4, 4)):
        self.nii = nib.load(fname)
        self.vol = self.nii.get_fdata()

        axcodes = nib.aff2axcodes(self.nii.affine)
        self.ax_selector = ipw.ToggleButtons(
            # Button label is the axis code, button value is the array axis index.
            options=[(axcode, i) for i, axcode in enumerate(axcodes)],
            description='Orientation: '
        )
        self.slc_selector = ipw.IntSlider(
            min=0, max=self.vol.shape[self.ax_selector.value]-1,
            description='Slice index:'
        )
        self.ax_selector.observe(self.set_axis, 'value')
        self.slc_selector.observe(self.set_slice, 'value')

        # Create the figure inside the Output widget so it is rendered as part
        # of this box.
        output = ipw.Output()
        with output:
            print(f'{fname} | {self.vol.shape}')
            self.fig, self.ax = plt.subplots(figsize=figsize)
        self.fig.canvas.header_visible = False
        self.fig.canvas.footer_visible = False
        self.fig.canvas.toolbar_visible = False

        super().__init__([self.ax_selector, self.slc_selector, output])
        self.set_axis({'new': 0})

    def set_axis(self, change):
        """Resize the slider for axis ``change['new']`` and show its middle slice."""
        n_slices = self.vol.shape[change['new']]
        middle = n_slices // 2
        self.slc_selector.max = n_slices - 1
        self.slc_selector.value = middle
        # The slider observer only fires when the value actually changes, so
        # redraw explicitly to cover the case where the middle index is the
        # same for the old and the new axis.
        self.set_slice({'new': middle})

    def set_slice(self, change):
        """Draw slice ``change['new']`` along the currently selected axis."""
        # Drop whatever was drawn before. Without this every call stacks one
        # more image on the axes, which leaks memory and leaves slices of a
        # different shape visible behind the current one.
        self.ax.clear()
        self.ax.imshow(
            self.vol.take(indices=change['new'], axis=self.ax_selector.value),
            cmap='gray'
        )
        # clear() re-enables the axis decorations, so hide them after drawing.
        self.ax.axis('off')
        self.fig.canvas.draw_idle()

# Open the viewer on one example file; as the last expression of the cell the
# widget instance is what gets displayed.
NiftiViewer('data/NIfTI_std/NIIori_std/Nor/Nor108_std.nii.gz')

In [None]:
## Check orientation and dimension of nifti files
# NOTE: this rebinds the notebook-level name `files`.
# glob order is filesystem-dependent; sort so the report is reproducible and
# easy to compare between runs.
files = sorted(glob.glob('data/NIfTI_std/NIIori_std/Nor/*'))
for f in files:
    img = nib.load(f)
    # One line per file: base name, array shape, axis codes from the affine.
    print(f'{f.split("/")[-1]}: {img.shape} {nib.aff2axcodes(img.affine)}')

In [None]:
## Using animation to show slices
from matplotlib import animation, rc
# Set matplotlib's 'animation.html' option to 'jshtml' for this session.
rc('animation', html='jshtml')

def vol_viewer(fname, figsize=(6, 6)):
    """Build an animation that steps through a NIfTI volume along its first axis.

    Args:
        fname: Path of the NIfTI file to load.
        figsize: Figure size passed to ``plt.figure``.

    Returns:
        A ``FuncAnimation`` with one frame per index of the first array axis,
        played at roughly 24 frames per second.
    """
    fig = plt.figure(figsize=figsize)
    nii = nib.load(fname)
    vol = nii.get_fdata()
    # Pin the gray scale to the intensity range of the whole volume. imshow
    # otherwise autoscales to the first slice only, and set_array() does not
    # rescale, so every later frame would be drawn with the first slice's
    # limits (badly wrong when that slice is empty background).
    plt_img = plt.imshow(vol[0], cmap='gray', vmin=vol.min(), vmax=vol.max())

    def slice_viewer(idx):
        # Swap the displayed data in place instead of creating a new image.
        plt_img.set_array(vol[idx])
        return [plt_img]

    return animation.FuncAnimation(fig, slice_viewer, frames=len(vol), interval=1000//24)

# Build the slice animation for one example file; the returned animation is
# the cell's last expression.
vol_viewer('data/NIfTI_std/NIIori_std/Nor/Nor108_std.nii.gz')

In [11]:
# Load one file from each dataset directory and report the axis codes of both affines.
nii1 = nib.load('data/NIfTI/NIIori/Nor/Nor108.nii.gz')
nii2 = nib.load('data/NIfTI_std/NIIori_std/Nor/Nor108_std.nii.gz')
axcodes1 = nib.aff2axcodes(nii1.affine)
axcodes2 = nib.aff2axcodes(nii2.affine)
print(f'nii1: {axcodes1} | nii2: {axcodes2}')

# Browse both volumes, separated by a ruler line in the output.
vol1 = nii1.get_fdata()
vol2 = nii2.get_fdata()
VolumeViewer(vol1)
print('===================================================')
VolumeViewer(vol2)

## Check if nii2 same as nii1 after reorientation
end_ornt = nib.orientations.axcodes2ornt(('P', 'S', 'R'))
start_ornt = nib.orientations.axcodes2ornt(axcodes2)
transform_ornt = nib.orientations.ornt_transform(start_ornt, end_ornt)
vol2_reoriented = nib.apply_orientation(vol2, transform_ornt)
print(vol1.shape, vol2_reoriented.shape)
# Last expression of the cell: True when the arrays match element for element.
np.array_equal(vol1, vol2_reoriented)

nii1: ('P', 'S', 'R') | nii2: ('R', 'A', 'S')
Volume shape: (256, 256, 178)


interactive(children=(ToggleButtons(description='Axis to view:', options=(0, 1, 2), value=0), Output()), _dom_…

Volume shape: (178, 256, 256)


interactive(children=(ToggleButtons(description='Axis to view:', options=(0, 1, 2), value=0), Output()), _dom_…

(256, 256, 178) (256, 256, 178)


True

In [15]:
from tqdm import tqdm

# For every file under data/NIfTI, load its counterpart from the parallel
# "_std" tree, bring both arrays to the same PSR orientation and report the
# pairs whose voxel data differ.
fnames = glob.glob('data/NIfTI/*/*/*.nii.gz')
end_ornt = nib.orientations.axcodes2ornt(('P', 'S', 'R'))

for fname in tqdm(fnames):
    # Derive the counterpart path: strip '.nii.gz' (7 characters), then append
    # '_std' to path components 1 and 2 and to the file stem.
    tkns = fname[:-7].split('/')
    for i in [1, 2, -1]:
        tkns[i] += '_std'
    fname_std = f'{"/".join(tkns)}.nii.gz'
    
    nii = nib.load(fname)
    nii_std = nib.load(fname_std)
    tornt = nib.orientations.ornt_transform(
        nib.io_orientation(nii.affine), # or use this nib.orientations.axcodes2ornt(nib.aff2axcodes(nii.affine)),
        end_ornt
    )
    tornt_std = nib.orientations.ornt_transform(
        nib.io_orientation(nii_std.affine), # or use this nib.orientations.axcodes2ornt(nib.aff2axcodes(nii_std.affine)),
        end_ornt
    )
    nii_img = nib.apply_orientation(nii.get_fdata(), tornt)
    nii_img_std = nib.apply_orientation(nii_std.get_fdata(), tornt_std)
    
    # Compare shapes first: np.allclose raises ValueError on shapes that cannot
    # be broadcast (aborting the whole run) and silently broadcasts the ones
    # that can, so a shape mismatch must be treated as a difference here.
    same_data = (
        nii_img.shape == nii_img_std.shape
        and np.allclose(nii_img, nii_img_std) # array_equal is too strict!
    )
    if not same_data:
        print(
            f'{fname} ({nii_img.shape}) | {fname_std} ({nii_img_std.shape}) : result in different orientation.'
        )
        print('------------------------------------')

 27%|██▋       | 477/1740 [02:41<46:33,  2.21s/it]

data/NIfTI/NIIorisk/Nor/Nor105sk.nii.gz ((512, 512, 256)) | data/NIfTI_std/NIIorisk_std/Nor/Nor105sk_std.nii.gz ((512, 512, 256)) : result in different orientation.
------------------------------------


 76%|███████▌  | 1320/1740 [21:42<12:20,  1.76s/it]

data/NIfTI/NIIoriss/Nor/Nor105ss.nii.gz ((512, 512, 256)) | data/NIfTI_std/NIIoriss_std/Nor/Nor105ss_std.nii.gz ((512, 512, 256)) : result in different orientation.
------------------------------------


100%|██████████| 1740/1740 [31:25<00:00,  1.08s/it]


## 3D Grad-CAM

In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from skimage.transform import resize
from tqdm import tqdm
import os
import cv2
import io

from utils import Params
from util.gradcam import gradcam

def plot_3d_gradcam_grid(fname, model, layer_name, labels=['Nor', 'PD']):
    """Draw every slice of a volume with its Grad-CAM heatmap blended on top.

    Args:
        fname: Path of the NIfTI file to explain.
        model: Model to explain. ``model.input_shape[1:-1]`` is the volume size
            the image is resized to, and the model is called directly on a
            batch of shape ``(1, *volume_size, 1)``.
        layer_name: Layer name forwarded to ``gradcam``.
        labels: Class names, indexed by the predicted class index.

    Returns:
        The matplotlib figure containing the slice grid.
    """
    ### Load and resize nii image
    volume_size = model.input_shape[1:-1]
    img = nib.load(fname)
    input_img = resize(img.get_fdata(), volume_size, order=1, mode='constant', anti_aliasing=True)
    batch = np.expand_dims(input_img, axis=(0, -1))

    ### Predict and receive gradcam heatmap
    pred = model(batch)[0]
    cls_idx = np.argmax(pred)
    heatmap = gradcam(batch, model, layer_name)
    heatmap = resize(heatmap, volume_size, order=1, mode='constant', anti_aliasing=True)

    ### Normalize both input_img and heatmap, then convert into uint8
    lowest = np.min(input_img)
    value_range = np.max(input_img) - lowest
    if value_range > 0:
        input_img = (input_img - lowest) / value_range
    else:
        # A constant image would divide by zero; show it as black instead.
        input_img = np.zeros_like(input_img)
    input_img = np.uint8(255 * input_img)
    # NOTE(review): assumes gradcam() returns values in [0, 1] — confirm.
    heatmap = np.uint8(255 * heatmap)

    ### Reorientate input_img and heatmap
    transform = nib.orientations.ornt_transform(
        nib.io_orientation(img.affine),
        nib.orientations.axcodes2ornt('PSR')
    )
    input_img = nib.apply_orientation(input_img, transform)
    heatmap = nib.apply_orientation(heatmap, transform)
    
    ### Visualize in a grid (slices are taken along axis 1 of the PSR volume)
    ncols = 10
    nrows = np.ceil(input_img.shape[1] / ncols).astype(int)
    fig, axarr = plt.subplots(nrows, ncols, figsize=(5*ncols, 5*nrows))
    jet = plt.get_cmap('jet')
    for i, ax in enumerate(axarr.ravel()):
        if i < input_img.shape[1]:
            slc = input_img.take(i, axis=1)
            slc = np.repeat(np.expand_dims(slc, -1), 3, axis=-1)

            # Colour the heatmap with one fixed 0..255 scale for all slices.
            # The earlier imsave/imread round-trip rescaled every slice to its
            # own min/max, so colours were not comparable between slices, and
            # it added JPEG compression artefacts.
            hmap = jet(heatmap.take(i, axis=1) / 255.0, bytes=True)[..., :3]
            hmap = np.ascontiguousarray(hmap)  # RGB uint8, alpha channel dropped

            ax.imshow(cv2.addWeighted(slc, 0.6, hmap, 0.4, 0))
            ax.set_aspect('auto')
        # Also hides the unused cells in the last row of the grid.
        ax.axis('off')
    fig.suptitle(f'[{layer_name}] {fname.split("/")[-1]} | {labels[cls_idx]} ({pred[cls_idx] * 100:.2f} %)', fontsize=80)
    
    return fig

# For each experiment, rebuild its test split, load the best model and save
# one Grad-CAM grid per test file.
exp_types = ['ori', 'oriss', 'orisk', 'reg']
for exp_type in exp_types:
    exp_dir = f'thesis_exp/3dcnn/{exp_type}'
    gradcam_dir = f'{exp_dir}/gradcam_result'
    # exist_ok lets the cell be re-run without failing on the existing folder.
    os.makedirs(gradcam_dir, exist_ok=True)
    print(f'Creating 3D Grad-CAM for {exp_dir}...')
    
    params = Params(f'{exp_dir}/train_params.json')
    df = pd.read_csv(params.data)
    X, y = df.Fpath.values, df.Label.values
    # Only the test paths (X_ts) are used below; the split is seeded from the
    # experiment's parameters.
    X_tr, X_ts, y_tr, y_ts = train_test_split(X, y, stratify=y, train_size=0.85, random_state=params.seed)
    model = tf.keras.models.load_model(f'{exp_dir}/model_best.h5')
    for fname in tqdm(X_ts):
        fig = plot_3d_gradcam_grid(fname, model, 'conv3d_3')
        fig.savefig(f'{gradcam_dir}/{fname.split("/")[-1].replace(".nii.gz", "")}.jpg')
        # Close this figure explicitly rather than whichever one is "current",
        # so the large grids do not accumulate in memory.
        plt.close(fig)

## 3D Attention Map

In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from skimage.transform import resize
from tqdm import tqdm
import os
import cv2
import io

from models import find_model
from utils import Params
from util.vis import attention_map

def plot_attn_map_grid(fname, model, labels=['Nor', 'PD']):
    """Draw every slice of a volume with the model's attention map blended on top.

    Args:
        fname: Path of the NIfTI file to explain.
        model: Model to explain. ``model.input_shape[1:-1]`` is the volume size
            the image is resized to, and the model is called directly on a
            batch of shape ``(1, *volume_size, 1)``.
        labels: Class names, indexed by the predicted class index.

    Returns:
        The matplotlib figure containing the slice grid.
    """
    ### Load and resize nii image
    volume_size = model.input_shape[1:-1]
    img = nib.load(fname)
    input_img = resize(img.get_fdata(), volume_size, order=1, mode='constant', anti_aliasing=True)
    batch = np.expand_dims(input_img, axis=(0, -1))
    
    ### Predict and receive attention map mask
    pred = model(batch)[0]
    cls_idx = np.argmax(pred)
    mask = attention_map(batch, model)
    mask = resize(mask, volume_size, order=1, mode='constant', anti_aliasing=True)

    ### Normalize both input_img and mask, then convert into uint8
    lowest = np.min(input_img)
    value_range = np.max(input_img) - lowest
    if value_range > 0:
        input_img = (input_img - lowest) / value_range
    else:
        # A constant image would divide by zero; show it as black instead.
        input_img = np.zeros_like(input_img)
    input_img = np.uint8(255 * input_img)
    # NOTE(review): assumes attention_map() returns values in [0, 1] — confirm.
    mask = np.uint8(255 * mask)

    ### Reorientate input_img and mask
    transform = nib.orientations.ornt_transform(
        nib.io_orientation(img.affine),
        nib.orientations.axcodes2ornt('PSR')
    )
    input_img = nib.apply_orientation(input_img, transform)
    mask = nib.apply_orientation(mask, transform)
    
    ### Visualize in a grid (slices are taken along axis 1 of the PSR volume)
    ncols = 10
    nrows = np.ceil(input_img.shape[1] / ncols).astype(int)
    fig, axarr = plt.subplots(nrows, ncols, figsize=(5*ncols, 5*nrows))
    jet = plt.get_cmap('jet')
    for i, ax in enumerate(axarr.ravel()):
        if i < input_img.shape[1]:
            slc = input_img.take(i, axis=1)
            slc = np.repeat(np.expand_dims(slc, -1), 3, axis=-1)

            # Colour the mask with one fixed 0..255 scale for all slices. The
            # earlier imsave/imread round-trip rescaled every slice to its own
            # min/max, so colours were not comparable between slices, and it
            # added JPEG compression artefacts.
            hmap = jet(mask.take(i, axis=1) / 255.0, bytes=True)[..., :3]
            hmap = np.ascontiguousarray(hmap)  # RGB uint8, alpha channel dropped

            ax.imshow(cv2.addWeighted(slc, 0.6, hmap, 0.4, 0))
            ax.set_aspect('auto')
        # Also hides the unused cells in the last row of the grid.
        ax.axis('off')
    fig.suptitle(f'{fname.split("/")[-1]} | {labels[cls_idx]} ({pred[cls_idx] * 100:.2f} %)', fontsize=80)

    return fig

# For each experiment, rebuild its test split, rebuild the model from the
# stored parameters, load the best weights and save one attention-map grid per
# test file.
exp_types = ['ori', 'oriss', 'orisk', 'reg']
for exp_type in exp_types:
    exp_dir = f'thesis_exp/3dvit/{exp_type}'
    attn_map_dir = f'{exp_dir}/attn_maps'
    # exist_ok lets the cell be re-run without failing on the existing folder.
    os.makedirs(attn_map_dir, exist_ok=True)
    print(f'Creating attention map for {exp_dir}...')

    params = Params(f'{exp_dir}/train_params.json')
    df = pd.read_csv(params.data)
    X, y = df.Fpath.values, df.Label.values
    # Only the test paths (X_ts) are used below; the split is seeded from the
    # experiment's parameters.
    X_tr, X_ts, y_tr, y_ts = train_test_split(X, y, stratify=y, train_size=0.85, random_state=params.seed)
    # Start from the two-class default and overlay any stored model parameters.
    model_params = {'classes': 2}
    if hasattr(params, 'model_params'):
        model_params.update(params.model_params)
    model = find_model(params.model)(input_shape=(*params.volume_size, 1), **model_params)
    model.load_weights(f'{exp_dir}/model_best.h5')
    for fname in tqdm(X_ts):
        fig = plot_attn_map_grid(fname, model)
        fig.savefig(f'{attn_map_dir}/{fname.split("/")[-1].replace(".nii.gz", "")}.jpg')
        # Close this figure explicitly rather than whichever one is "current",
        # so the large grids do not accumulate in memory.
        plt.close(fig)