In [None]:
%load_ext autoreload
%autoreload 2

from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
setup_logging()
LOG = logging.getLogger(name=__name__)

# matplotlib's DEBUG-level logging is extremely noisy once logging is configured,
# so only let its warnings (and anything more severe) through.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel('WARNING')

In [None]:
from pathlib import Path

import torch
from torchvision.utils import make_grid
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
import os

# Directory holding the T1 volumes used below. Defaults to the original scratch location;
# set the HIST_MATCHING_DIR environment variable to point the notebook somewhere else
# without editing this cell.
WORK_DIR = Path(os.environ.get('HIST_MATCHING_DIR', '/scratch_net/samuylov/maheer/datasets/hist_matching'))

In [None]:
def get_name(modality: str, is_ref: bool, is_unbiased: bool, idx: int = None) -> str:
    """Build the NIfTI file name of one volume.

    The name has the form ``<modality>[_ref][_unbiased][_<idx>].nii.gz``, e.g.
    ``t1.nii.gz``, ``t1_ref_unbiased.nii.gz`` or ``t1_unbiased_2.nii.gz``.

    Args:
        modality: modality prefix, e.g. ``'t1'``.
        is_ref: append ``_ref`` (reference volume).
        is_unbiased: append ``_unbiased``.
        idx: optional numeric suffix; omitted when ``None`` (``0`` is kept).

    Returns:
        The file name (without directory).
    """
    ref_part = '_ref' if is_ref else ''
    # Previously this tested ``is_ref`` by mistake, so the flag was ignored.
    unbiased_part = '_unbiased' if is_unbiased else ''
    idx_part = f'_{idx}' if idx is not None else ''
    return f'{modality}{ref_part}{unbiased_part}{idx_part}.nii.gz'

def transform_image(slices: np.ndarray) -> np.ndarray:
    """Reorient and crop a raw 3-D volume for display.

    Steps, in order:
      1. move the last axis to the front, so the volume is indexed ``[slice, ., .]``
         (assumes the last axis of the raw data is the slice axis — TODO confirm);
      2. rotate 90 degrees within each slice (``np.rot90`` over axes ``(2, 1)``);
      3. crop each slice to rows ``27:227`` and columns ``20:220`` (at most 200x200).
    """
    slice_first = slices.transpose(2, 0, 1)
    rotated = np.rot90(slice_first, 1, (2, 1))
    # Fixed crop window; the bounds are hand-picked, not derived from the data.
    return rotated[:, 27:227, 20:220]

def create_masks(slices: np.ndarray) -> np.ndarray:
    """Return an integer foreground mask: 1 where a voxel is non-zero, 0 elsewhere.

    The test is element-wise, so it works on raw volumes as well as on volumes that
    were already passed through ``transform_image``; the mask has the input's shape.
    """
    is_foreground = slices != 0
    return is_foreground.astype(int)

In [None]:
def _load_t1(**name_kwargs) -> np.ndarray:
    """Load one T1 volume from WORK_DIR and return its voxel data via nibabel's ``get_fdata``."""
    return nib.load(WORK_DIR / get_name('t1', **name_kwargs)).get_fdata()


# Volumes requested without the "_unbiased" flag, plus their non-zero masks.
t1_img = _load_t1(is_ref=False, is_unbiased=False)
t1_img_2 = _load_t1(is_ref=False, is_unbiased=False, idx=2)
t1_ref = _load_t1(is_ref=True, is_unbiased=False)

t1_mask, t1_mask_2, t1_ref_mask = (create_masks(volume) for volume in (t1_img, t1_img_2, t1_ref))

# Volumes requested with the "_unbiased" flag; these are the inputs to histogram matching.
t1_img_unbiased = _load_t1(is_ref=False, is_unbiased=True)
t1_img_unbiased_2 = _load_t1(is_ref=False, is_unbiased=True, idx=2)
t1_ref_unbiased = _load_t1(is_ref=True, is_unbiased=True)

In [None]:
# Sanity check: one slice of each unbiased volume, followed by the foreground mask of the
# corresponding volume at the same slice. The titles previously had the image and reference
# swapped (t1_img_unbiased was labelled as the reference and vice versa).
for volume, mask, slice_idx, title in [
    (t1_img_unbiased, t1_mask, 50, 'T1 unbiased'),
    (t1_ref_unbiased, t1_ref_mask, 60, 'T1 reference unbiased'),
]:
    plt.imshow(transform_image(volume)[slice_idx], vmax=500)
    plt.colorbar()
    plt.title(title)
    plt.show()

    plt.imshow(transform_image(mask)[slice_idx])
    plt.title(f'{title} mask')
    plt.show()

In [None]:
from uncertify.data.preprocessing.histogram_matching.histogram_matching import MatchHistogramsTwoImages

def plot_nii_samples(slices_list, step, titles):
    """Plot the same slice index of several volumes side by side, one figure per slice index.

    Slice indices ``step, 2*step, ...`` below ``len(slices_list[0])`` are plotted; every
    volume is indexed with the same slice index, so all volumes are expected to have at
    least as many slices as the first one.

    Args:
        slices_list: sequence of volumes indexed ``[slice, ...]`` (e.g. ``transform_image`` output).
        step: spacing between plotted slice indices; the first plotted index is ``step``.
            Previously this argument was ignored and 20 was always used.
        titles: one title per volume.
    """
    assert len(slices_list) == len(titles), 'need exactly one title per volume'
    if not slices_list:
        return
    for slice_idx in range(step, len(slices_list[0]), step):
        # squeeze=False keeps ``axes`` 2-D even for a single volume, so indexing always works.
        fig, axes = plt.subplots(ncols=len(slices_list), figsize=(18, 6), squeeze=False)
        print(f'slice: {slice_idx}')
        for ax, sample, title in zip(axes[0], slices_list, titles):
            im = ax.imshow(sample[slice_idx], cmap='hot', vmax=500)
            ax.set_title(title)
            # Dedicated colorbar axis to the right of each image, matching its height.
            cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
            plt.colorbar(im, cax=cax)
        fig.tight_layout()
        plt.show()
        # Release the figure; this loop can create many of them.
        plt.close(fig)


def run_histogram_matching(orig_img: np.ndarray, ref_img: np.ndarray, ref_mask: np.ndarray, orig_mask: np.ndarray) -> np.ndarray:
    """Match the intensity histogram of ``orig_img`` to ``ref_img`` using ``MatchHistogramsTwoImages``.

    The reference volume is passed first and the volume to transform second. ``ref_mask``
    goes in as ``train_mask`` and ``orig_mask`` as ``test_mask``.

    Returns:
        Whatever ``MatchHistogramsTwoImages`` returns (presumably the matched volume — confirm).
    """
    # Fixed matcher settings; their meaning is defined by MatchHistogramsTwoImages.
    matching_params = {'L': 200, 'nbins': 246, 'begval': 0.05, 'finval': 0.98}
    return MatchHistogramsTwoImages(
        ref_img,
        orig_img,
        train_mask=ref_mask,
        test_mask=orig_mask,
        **matching_params,
    )

In [None]:
# Match the first unbiased T1 volume to the unbiased reference.
matched = run_histogram_matching(
    orig_img=t1_img_unbiased,
    ref_img=t1_ref_unbiased,
    ref_mask=t1_ref_mask,
    orig_mask=t1_mask,
)

In [None]:
# Match the second unbiased T1 volume to the same unbiased reference.
matched2 = run_histogram_matching(
    orig_img=t1_img_unbiased_2,
    ref_img=t1_ref_unbiased,
    ref_mask=t1_ref_mask,
    orig_mask=t1_mask_2,
)

In [None]:
# Compare each matched volume with the exact inputs the matcher received: the unbiased
# reference and the unbiased originals. The non-unbiased t1_ref / t1_img / t1_img_2 volumes
# were never passed to run_histogram_matching, so showing them here would mix the effect of
# the unbiasing step into the comparison.
comparison = [
    (t1_ref_unbiased, 'reference T1 (unbiased)'),
    (t1_img_unbiased, 'original T1 1 (unbiased)'),
    (matched, 'matched T1 1'),
    (t1_img_unbiased_2, 'original T1 2 (unbiased)'),
    (matched2, 'matched T1 2'),
]
plot_nii_samples([transform_image(volume) for volume, _ in comparison], step=20,
                 titles=[title for _, title in comparison])

