# Magic commands and import section

In [None]:
# Best-effort setup: interactive (ipympl) matplotlib figures and automatic reloading of
# edited modules. Failures are ignored so the notebook still runs without these magics.
# NOTE(review): `%` lines are not plain Python; if a kernel does not pre-process them at all,
# the cell fails with a SyntaxError that this `except` cannot catch — confirm for the target kernel.
try:  # for ipykernel
    %matplotlib widget
    %load_ext autoreload
    %autoreload 2
except Exception:  # magic commands don't work with xeus-python (which supports debugging)
    pass  # deliberately ignored: the magics are conveniences, not requirements

In [None]:
from context import uncertify

import logging
from uncertify.log import setup_logging
# Project logging setup (uncertify helper) plus a logger for this notebook.
setup_logging()
LOG = logging.getLogger(name=__name__)

# Matplotlib's DEBUG-level logging is extremely noisy; only let WARNING and above through.
mpl_logger = logging.getLogger(name='matplotlib')
mpl_logger.setLevel(level=logging.WARNING)

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from torchvision.utils import make_grid
import nibabel as nib
import numpy as np
import cv2

from typing import List

from uncertify.visualization.grid import imshow_grid
from uncertify.visualization.plotting import imshow

# IBSR v2.0 and CANDI

In [None]:
def show_center_views_nib_img(nib_img) -> None:
    """Plot the middle slice along each of the three spatial axes of a nibabel image.

    Args:
        nib_img: loaded nibabel image (anything with ``.shape`` and ``.get_fdata()``).
            If it is 4-D, only index 0 of the last axis is shown.
    """
    # Read from the argument, not the notebook-global `img`: the previous version ignored
    # `nib_img` and plotted whatever `img` happened to be bound to.
    np_image = nib_img.get_fdata()
    if len(nib_img.shape) == 4:
        np_image = np_image[:, :, :, 0]
    print(f'Data shape: {np_image.shape}')
    i, j, k = np.array(np_image.shape) // 2
    _, axes = plt.subplots(1, 3, figsize=(8, 4))
    imshow(np.flipud(np_image[i, :, :]), ax=axes[0], add_colorbar=False, axis='on')
    imshow(np.flipud(np_image[:, j, :].T), ax=axes[1], add_colorbar=False, axis='on')
    # NOTE(review): hard-coded display aspect for the second view; presumably compensates
    # for anisotropic voxel spacing — confirm.
    axes[1].set_aspect(1.5)
    imshow(np.flipud(np_image[:, :, k].T), ax=axes[2], add_colorbar=False, axis='on')
    plt.tight_layout()
    plt.show()

# Center views of one CANDI SchizBull 2008 scan (series SS, subject 084).
img = nib.load(
    '/mnt/2TB_internal_HD/datasets/raw/CANDI/SchizBull_2008_SS_segimgreg_V1.2/'
    'SchizBull_2008/SS/SS_084/SS_084_procimg.nii.gz'
)
show_center_views_nib_img(img)

In [None]:
def to_torch_tensor(nib_image, remove_last_dim: bool = False, permutation: tuple = None, scale: float = None,
                    to_binary: bool = False, mask_img=None, rotate90_ax: tuple = None) -> torch.Tensor:
    """Convert a nibabel scan into a ``(n_slices, 1, height, width)`` float64 torch tensor.

    The processing steps run in this order: drop the 4th axis, transpose, binarize, rotate,
    and finally shrink the width of every slice.

    Args:
        nib_image: loaded nibabel image (anything with ``.shape`` and ``.get_fdata()``).
        remove_last_dim: if True, keep only index 0 of the 4th axis (``[:, :, :, 0]``).
        permutation: optional axis order for ``np.transpose``. After all steps, the first
            axis is the one iterated over to produce the individual slices.
        scale: if given, each slice's width is divided by this factor (the height is kept)
            and the slice is resized with ``cv2.resize``.
        to_binary: if True, non-zero voxels become 1.0 and zero voxels 0.0.
        mask_img: unused; kept only so existing calls keep working.
        rotate90_ax: optional pair of axes for a single 90-degree ``np.rot90`` rotation.

    Returns:
        A float64 ``torch.Tensor`` of shape ``(n_slices, 1, height, width)``.
    """
    print(f'Original data shape {nib_image.shape}')
    np_image = nib_image.get_fdata()
    if remove_last_dim:
        np_image = np_image[:, :, :, 0]
    if permutation is not None:
        np_image = np.transpose(np_image, permutation)
    if to_binary:
        np_image = np.array(np_image != 0, dtype=float)
    if rotate90_ax:
        np_image = np.rot90(np_image, k=1, axes=rotate90_ax)
    n_axial_views, height, width = np_image.shape
    print(f'{n_axial_views} axial views')
    if scale is None:
        # Resizing each slice to its unchanged size would only copy it, so skip cv2 here.
        # The contiguous float64 copy also drops any negative strides from np.rot90.
        axial_views = np.ascontiguousarray(np_image[:, np.newaxis, :, :], dtype=np.float64)
    else:
        width = int(width / scale)
        axial_views = np.empty((n_axial_views, 1, height, width))
        for axial_idx in range(n_axial_views):
            # cv2.resize takes the target size as (width, height).
            axial_views[axial_idx, 0, :, :] = cv2.resize(np_image[axial_idx], (width, height))
    return torch.tensor(axial_views)
    
def visualize_tensor_as_grid(tensor: torch.Tensor, **kwargs) -> None:
    """Tile a batch of images (16 per row) into one grid image and display it.

    Extra keyword arguments are passed straight through to ``imshow_grid``.
    """
    imshow_grid(make_grid(tensor, nrow=16), one_channel=True, figsize=(10, 10), axis='off', **kwargs)
    plt.show()


def visualize_scan_mask_grid(scan_tensor: torch.Tensor, mask_tensor: torch.Tensor, show_orig_and_mask: bool = False, cmap: str = None) -> None:
    """Display ``scan_tensor * mask_tensor`` as an image grid.

    Args:
        scan_tensor: stack of scan slices.
        mask_tensor: mask broadcast-compatible with ``scan_tensor``; where it is True (1) the
            scan value is kept, where it is False (0) the value becomes zero.
        show_orig_and_mask: if True, first show the raw scan and the mask as separate grids.
        cmap: colormap forwarded to ``visualize_tensor_as_grid``.
    """
    if show_orig_and_mask:
        visualize_tensor_as_grid(scan_tensor, cmap=cmap, add_colorbar=False)
        # Callers in this notebook pass a bool mask; cast to a CPU float32 tensor for display.
        visualize_tensor_as_grid(mask_tensor.type(torch.FloatTensor), cmap=cmap, add_colorbar=False)
    visualize_tensor_as_grid(scan_tensor * mask_tensor, cmap=cmap, add_colorbar=False)

    
def visualize_isbr_scan_mask_grid(dir_path: Path, sample_nr: str, cmap: str = 'hot') -> None:
    """Show one IBSR v2.0 scan multiplied by its brain mask as a slice grid.

    Args:
        dir_path: directory containing the ``IBSR_<sample_nr>`` sub-directories.
        sample_nr: sample identifier as used in the file names, e.g. ``'02'``.
        cmap: colormap name forwarded to the grid plot.
    """
    data_dir_path = dir_path / f'IBSR_{sample_nr}'
    scan_img = nib.load(data_dir_path / f'IBSR_{sample_nr}_ana.nii.gz')
    mask_img = nib.load(data_dir_path / f'IBSR_{sample_nr}_ana_brainmask.nii.gz')
    scan_tensor = to_torch_tensor(scan_img, remove_last_dim=True, permutation=(1, 0, 2), scale=1.5, rotate90_ax=(1, 2))
    # NOTE(review): the mask uses a different axis permutation than the scan ((2, 0, 1) vs
    # (1, 0, 2)), yet the two are multiplied element-wise downstream, so they must end up with
    # broadcast-compatible shapes — confirm the files really store their axes in different orders.
    mask_tensor = to_torch_tensor(mask_img, remove_last_dim=True, permutation=(2, 0, 1), scale=1.5, rotate90_ax=(1, 2)).type(torch.BoolTensor)
    # Forward cmap; previously it was accepted but silently ignored.
    visualize_scan_mask_grid(scan_tensor, mask_tensor, cmap=cmap)

def visualize_candi_scan_mask_grid(dir_path: Path, series: str, sample_nr: str, cmap: str = 'hot') -> None:
    """Show a CANDI scan, its segmentation mask and the masked scan as slice grids.

    Args:
        dir_path: directory containing the ``<series>_<sample_nr>`` sub-directories.
        series: series prefix used in the directory and file names, e.g. ``'SS'``.
        sample_nr: sample identifier, e.g. ``'084'``.
        cmap: colormap name forwarded to the grid plots.
    """
    # The sub-directory was hard-coded to 'SS_' even though the file names use `series`.
    # NOTE(review): assumes other series follow the same '<series>_<nr>' directory naming — confirm.
    data_dir_path = dir_path / f'{series}_{sample_nr}'
    scan_file_path = data_dir_path / f'{series}_{sample_nr}_procimg.nii.gz'
    mask_file_path = data_dir_path / f'{series}_{sample_nr}.seg.nii.gz'
    scan_img = nib.load(scan_file_path)
    mask_img = nib.load(mask_file_path)
    scan_tensor = to_torch_tensor(scan_img, remove_last_dim=False, permutation=(1, 0, 2), scale=1.5, rotate90_ax=(1, 2))
    # Casting to bool turns every non-zero segmentation value into True.
    mask_tensor = to_torch_tensor(mask_img, remove_last_dim=False, permutation=(1, 0, 2), scale=1.5, rotate90_ax=(1, 2)).type(torch.BoolTensor)
    visualize_scan_mask_grid(scan_tensor, mask_tensor, show_orig_and_mask=True, cmap=cmap)
    
#visualize_isbr_scan_mask_grid(Path('/mnt/2TB_internal_HD/datasets/raw/IBSR/IBSR_V2.0_nifti_stripped/IBSR_nifti_stripped'), sample_nr='02')
visualize_candi_scan_mask_grid(Path('/mnt/2TB_internal_HD/datasets/raw/CANDI/SchizBull_2008_SS_segimgreg_V1.2/SchizBull_2008/SS'), series='SS', sample_nr='084')

# Disabled examples for other datasets, kept as comments rather than a bare string literal
# (a string as the cell's last expression is echoed by Jupyter as the cell output).
# # BraTS
# img = nib.load('/mnt/2TB_internal_HD/datasets/raw/BraTS17/training/HGG/Brats17_CBICA_ABN_1/Brats17_CBICA_ABN_1_t1.nii.gz')
# brats_tensor = to_torch_tensor(img, remove_last_dim=False, permutation=(2, 0, 1), rotate90_ax=(2, 1))
# visualize_tensor_as_grid(brats_tensor, cmap='hot', add_colorbar=False, vmin=0)
#
# # CamCAN
# img = nib.load('/mnt/2TB_internal_HD/datasets/raw/CamCAN/T1w/sub-CC110037_T1w_unbiased.nii.gz')
# camcan_tensor = to_torch_tensor(img, remove_last_dim=False, permutation=(2, 0, 1), rotate90_ax=(2, 1))
# visualize_tensor_as_grid(camcan_tensor, cmap='hot', add_colorbar=False, vmin=0)

In [None]:
import numpy as np

In [None]:
# 4x4 float64 array of zeros used for quick experiments.
zero_array = np.zeros(shape=(4, 4))

In [None]:
# NOTE(review): `np.matrix()` without its required `data` argument always raises TypeError,
# which breaks Restart & Run All. NumPy also recommends plain 2-D ndarrays over np.matrix,
# so this scratch call is disabled.
# np.matrix()

# Using the mrivis package

In [None]:
from mrivis import SlicePicker
from mrivis import Collage
from mrivis.utils import scale_0to1

# Load one IBSR v2.0 volume (the commented lines load a BraTS17 volume instead).
img = nib.load('/mnt/2TB_internal_HD/datasets/raw/IBSR/IBSR_V2.0_nifti_stripped/IBSR_nifti_stripped/IBSR_03/IBSR_03_ana.nii.gz')
#img = nib.load('/mnt/2TB_internal_HD/datasets/raw/BraTS17/training/HGG/Brats17_CBICA_ABN_1/Brats17_CBICA_ABN_1_t1.nii.gz')
np_img = img.get_fdata()[:, :, :, 0]  # last dimension (probably time) can be discarded
#np_img = img.get_fdata()
sp = SlicePicker(np_img, view_set=(1, ), num_slices=20)

# Plotting individual slices. Each 2-D slice gets its own name so that `np_img` still holds
# the full 3-D volume afterwards (the collage cell below reads `np_img`).
for sl_data in sp.get_slices():
    slice_img = np.flipud(sl_data)
    height, width = slice_img.shape
    # Shrink the width by 1.5 (cv2.resize takes the target size as (width, height)).
    slice_img = cv2.resize(slice_img, (int(width / 1.5), height))
    imshow(slice_img, add_colorbar=False, figsize=(4, 4))
    plt.show()

In [None]:
# Plotting a collage.
# `np_img` must be the 3-D volume from the previous cell (make sure the slice loop there does
# not rebind it); Collage.attach presumably expects a 3-D image — confirm.
collage = Collage()
collage.attach(scale_0to1(np_img))
# pyplot.show() takes no figure argument (only keyword `block`); it displays all open figures.
# NOTE(review): assumes `collage.fig` is a pyplot-managed figure created by Collage() — confirm.
plt.show()