# Helpers

In [None]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt
import scipy
import SimpleITK as sitk
from tqdm import tqdm

# Put the repository root (two directories above the current working directory)
# on sys.path so `multitask_method` can be imported without installing it.
# The membership check keeps re-running this cell from adding duplicates.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
already_on_path = module_path in sys.path
if not already_on_path:
    sys.path.append(module_path)


from multitask_method.plotting_utils import display_cross_section, display_normalised_cross_section

# Anomaly shape

In [None]:
from multitask_method.tasks.utils import nsa_sample_dimension
from multitask_method.tasks.task_shape import CombinedDeformedHypershapePatchMaker, EitherDeformedHypershapePatchMaker


# Draw one random deformed hyper-shape patch mask in a 128^3 volume and show it.
shape_maker = EitherDeformedHypershapePatchMaker(nsa_sample_dimension)

target_shape = np.array([128, 128, 128])
# Per-axis (lower, upper) bounds handed to the shape maker: 6% and 80% of each
# axis length, rounded to whole voxels.
lb, ub = (np.round(frac * target_shape).astype(int) for frac in (0.06, 0.80))

mask = shape_maker([(lo, hi) for lo, hi in zip(lb, ub)], target_shape)

print(mask.shape)
display_cross_section(mask)

# Example Images

In [None]:
from multitask_method.utils import make_exp_config
from multitask_method.data.mood import MOODDatasetCoordinator

# Load the experiment config and take items 0 and 1 from its dataset coordinator.
exp_config = make_exp_config('experiments/exp_HCP_low_res_1_train.py')
brain_coordinator = exp_config.curr_dset_coord
# Alternative coordinator for the MOOD abdomen data:
# brain_coordinator = MOODDatasetCoordinator(exp_config.mood_root, 'Abdomen', False, True)
brain_dset_container = brain_coordinator.make_container([0, 1])

# Each item is unpacked as image, mask, identifier; only img1's identifier is kept.
img1, img1_mask, img1_id = brain_dset_container[0]
img2, img2_mask, _ = brain_dset_container[1]

# For each example: index 0 of the image (normalised display), then its mask.
for example_img, example_mask in ((img1, img1_mask), (img2, img2_mask)):
    display_normalised_cross_section(example_img[0])
    display_cross_section(example_mask)

In [None]:
# Distinct values present in img1's mask.
print(np.unique(img1_mask))

# Histogram of the non-zero values of img1[0]; zero voxels (presumably
# background) are excluded so they do not swamp the plot.
img1_voxels = img1[0].flatten()
_ = plt.hist(img1_voxels[img1_voxels != 0], bins=100)

In [None]:
from multitask_method.paths import base_data_input_dir
# Re-load img1 directly from the file named by img1_id under
# base_data_input_dir/hcp/lowres and keep index 1 of its first axis.
# NOTE(review): presumably a different channel than the one the dataset
# container returned — confirm what index 1 holds.
img1_new_all = np.load(base_data_input_dir / 'hcp' / 'lowres' / img1_id)[1]

new_voxels = img1_new_all.flatten()
_ = plt.hist(new_voxels[new_voxels != 0], bins=100)

In [None]:
# Min-max rescale each volume to [0, 1] so both cross-sections share a scale.
for vol in (img1[0], img1_new_all):
    display_cross_section(np.interp(vol, [vol.min(), vol.max()], [0, 1]))

# Tasks

## Patch blending tasks

In [None]:
from multitask_method.tasks.cutout_task import Cutout
from multitask_method.tasks.patch_blending_task import TestCutPastePatchBlender, TestPatchInterpolationBlender, \
    TestPoissonImageEditingMixedGradBlender, TestPoissonImageEditingSourceGradBlender

from multitask_method.tasks.labelling import FlippedGaussianLabeller

# One labeller shared by every patch-blending task below.
labeller = FlippedGaussianLabeller(0.2)

# Cutout only takes the labeller; the blending tasks are also given img2 and
# img2_mask (presumably the patch source — confirm against the task classes).
cutout_task = Cutout(labeller)
blend_args = (labeller, img2, img2_mask)
cutpaste_task = TestCutPastePatchBlender(*blend_args)
patch_interp_task = TestPatchInterpolationBlender(*blend_args)
poisson_image_editing_mixed_task = TestPoissonImageEditingMixedGradBlender(*blend_args)
poisson_image_editing_source_task = TestPoissonImageEditingSourceGradBlender(*blend_args)

# Plotting order: the source-gradient Poisson task comes before the mixed one.
all_test_tasks = [
    cutout_task,
    cutpaste_task,
    patch_interp_task,
    poisson_image_editing_source_task,
    poisson_image_editing_mixed_task,
]

In [None]:

# Apply each patch-blending task once to img1. Per task, one row of three axes
# shows the augmented image and the next row shows its label.
fig, axes = plt.subplots(2 * len(all_test_tasks), 3, figsize=(18, 12 * len(all_test_tasks)))
ax_row_params = {'fontsize': 30, 'labelpad': 15}

# Wrap the list (not the enumerate object) in tqdm: a list has a length, so the
# progress bar can show a total and ETA.
for i, t in enumerate(tqdm(all_test_tasks)):
    aug_image, aug_image_label = t(img1, img1_mask)
    display_normalised_cross_section(aug_image[0], existing_fig_ax=(fig, axes[2 * i]))
    display_cross_section(aug_image_label, existing_fig_ax=(fig, axes[2 * i + 1]))
    # Name the task on the first axis of its image row.
    axes[2 * i][0].set_ylabel(t.__class__.__name__, **ax_row_params)

## Deformation tasks

In [None]:
from multitask_method.tasks.deformation_task import SourceDeformationTask, SinkDeformationTask, FPISinkDeformationTask, IdentityDeformationTask

# Deformation tasks to visualise. FPISinkDeformationTask(None, None) is imported
# above but currently left out of this list.
deformation_tasks = [SourceDeformationTask(None, None), SinkDeformationTask(None, None)]

fig, axes = plt.subplots(2 * len(deformation_tasks), 3, figsize=(18, 12 * len(deformation_tasks)))
# Defined here as well so this cell does not depend on the patch-blending cell.
ax_row_params = {'fontsize': 30, 'labelpad': 15}

# tqdm wraps the list so the progress bar knows its total length.
for i, t in enumerate(tqdm(deformation_tasks)):
    aug_image, aug_image_label = t(img1, img1_mask)
    display_normalised_cross_section(aug_image[0], existing_fig_ax=(fig, axes[2 * i]))
    display_cross_section(aug_image_label[0], existing_fig_ax=(fig, axes[2 * i + 1]))
    axes[2 * i][0].set_ylabel(t.__class__.__name__, **ax_row_params)

In [None]:
from scipy.ndimage import label

# 10x10 checkerboard with 10-pixel squares. `label` gives every non-zero square
# its own integer id; under its default connectivity, squares that only touch
# diagonally are separate components. This lets individual tiles be told apart
# after deformation.
square_px = 10
chess_grid = label(np.kron(np.tile(np.array([[0, 1], [1, 0]]), (5, 5)), np.ones((square_px, square_px))))[0]

ax_title_params = {'fontsize': 30, 'pad': 15}

# squeeze=False keeps `axes` 2-D, so axes[row, i] also works with a single task.
fig, axes = plt.subplots(2, len(deformation_tasks), figsize=(6 * len(deformation_tasks), 12), squeeze=False)
for i, t in enumerate(tqdm(deformation_tasks)):
    # Add a leading axis to the grid; the whole grid is used as the mask.
    aug_image, aug_image_label = t(chess_grid[None], np.ones_like(chess_grid))
    axes[0, i].imshow(aug_image[0])
    axes[1, i].imshow(aug_image_label[0])
    axes[0, i].set_title(t.__class__.__name__, **ax_title_params)

In [None]:
from scipy.ndimage import label

# Same experiment as the previous cell but with 100-pixel squares, giving a
# 1000x1000 checkerboard. `label` gives each non-zero square its own integer id.
square_px = 100
chess_grid = label(np.kron(np.tile(np.array([[0, 1], [1, 0]]), (5, 5)), np.ones((square_px, square_px))))[0]

ax_title_params = {'fontsize': 30, 'pad': 15}

# squeeze=False keeps `axes` 2-D, so axes[row, i] also works with a single task.
fig, axes = plt.subplots(2, len(deformation_tasks), figsize=(6 * len(deformation_tasks), 12), squeeze=False)
for i, t in enumerate(tqdm(deformation_tasks)):
    # Add a leading axis to the grid; the whole grid is used as the mask.
    aug_image, aug_image_label = t(chess_grid[None], np.ones_like(chess_grid))
    axes[0, i].imshow(aug_image[0])
    axes[1, i].imshow(aug_image_label[0])
    axes[0, i].set_title(t.__class__.__name__, **ax_title_params)

## Intensity tasks

In [None]:
from scipy.ndimage import distance_transform_edt

from multitask_method.tasks.task_shape import EitherDeformedHypershapePatchMaker

# NOTE(review): constructed without the `nsa_sample_dimension` argument used in
# the "Anomaly shape" cell — confirm the default is intended here.
shape_maker = EitherDeformedHypershapePatchMaker()

# 2-D patch mask in a 256x256 image, and each pixel's Euclidean distance to the
# nearest zero pixel of that mask.
mask = shape_maker.get_patch_mask([(20, 200), (20, 200)], np.array([256, 256]))
dist = distance_transform_edt(mask)

fig, ax = plt.subplots(ncols=2, figsize=(12, 6))
for panel, picture in zip(ax, (mask, dist)):
    panel.imshow(picture)
plt.show()

In [None]:

# Sample 100 random 3-D patch masks in a 256^3 volume. For each mask, record the
# largest distance to the nearest zero voxel and the mean distance over the
# mask's non-zero voxels, then plot a histogram of each statistic.
# NOTE: every iteration runs an EDT over 256^3 voxels, so this cell is slow.
max_distances, mean_distances = [], []

for _ in tqdm(range(100)):
    patch = shape_maker.get_patch_mask([(15, 205)] * 3, np.array([256, 256, 256]))
    d = distance_transform_edt(patch)
    max_distances.append(np.max(d))
    mean_distances.append(np.mean(d[d > 0]))

for samples in (max_distances, mean_distances):
    plt.hist(samples, bins=50)
    plt.show()

In [None]:
# For one random 2-D patch, divide the inside distances by a random smoothing
# distance and saturate the result at 1.
img_shape = np.array([256, 256])
mask = shape_maker.get_patch_mask([(20, 200), (20, 200)], img_shape)

dist = distance_transform_edt(mask)
min_shape_dim = np.min(img_shape)

# Smoothing distance: (0.02 + Gamma(shape=3, scale=0.01)) times the smallest
# image side, capped at the largest distance inside the mask.
smooth_dist = np.minimum(min_shape_dim * (0.02 + np.random.gamma(3, 0.01)), np.max(dist))
smooth_dist_map = np.minimum(dist / smooth_dist, 1)

print('Smooth distance: ', smooth_dist)
# Fraction of the map's non-zero area that is saturated at exactly 1.
print('Saturated area: ', np.sum(smooth_dist_map == 1) / np.sum(smooth_dist_map > 0))

fig, ax = plt.subplots(ncols=3, figsize=(18, 6))
for panel, picture in zip(ax, (mask, dist, smooth_dist_map)):
    panel.imshow(picture)
plt.show()

In [None]:
from multitask_method.tasks.intensity_tasks import SmoothIntensityChangeTask

# Apply the smooth intensity-change task to img1 and inspect what it changed.
intensity_task = SmoothIntensityChangeTask(None, 0.2)
intensity_aug_image, intensity_aug_image_label = intensity_task(img1, img1_mask)

# Absolute per-voxel change introduced by the task, and its maximum.
diff = np.abs(intensity_aug_image - img1)
print(np.max(diff))

# Original and augmented image (normalised display), then the label and the
# change map (plain display); index 0 of each.
for vol in (img1, intensity_aug_image):
    display_normalised_cross_section(vol[0])
for vol in (intensity_aug_image_label, diff):
    display_cross_section(vol[0])

# Positional Encoding

In [None]:
from multitask_method.pos_encoding import PosEnc, ConvCoordEnc, GaussianRFFEnc

# Positional encoders to visualise. The shared first argument 3 is presumably
# the number of spatial dimensions. NOTE(review): the meaning of
# GaussianRFFEnc's (32, 12) arguments is not visible here — check its signature.
cc_enc, fourier_enc = ConvCoordEnc(3), GaussianRFFEnc(3, 32, 12)


def display_pos_enc_example(p_e: PosEnc, shape=(128, 128, 128), max_channels=10):
    """Show cross-sections of a positional encoding evaluated for ``shape``.

    The encoding's raw value range is printed. The whole encoding is then
    min-max rescaled to [0, 1] so that it is visible, and up to ``max_channels``
    entries along its first axis (presumably channels — confirm for each
    encoder) are displayed with ``display_cross_section``.

    Args:
        p_e: positional encoder; it is called with the spatial shape tuple.
        shape: spatial shape passed to ``p_e`` (defaults to 128^3).
        max_channels: maximum number of leading entries to display.
    """
    p_e_example = p_e(shape)

    p_e_min = p_e_example.min()
    p_e_max = p_e_example.max()
    print(p_e.__class__.__name__, f'range [{p_e_min}, {p_e_max}]')

    # Rescale jointly to [0, 1]. A constant encoding would give 0/0, so map it
    # to all zeros instead of dividing by a zero range.
    p_e_range = p_e_max - p_e_min
    if p_e_range == 0:
        p_e_example = p_e_example - p_e_min
    else:
        p_e_example = (p_e_example - p_e_min) / p_e_range

    for channel in p_e_example[:max_channels]:
        display_cross_section(channel)


In [None]:
# Visualise the ConvCoordEnc encoding.
display_pos_enc_example(p_e=cc_enc)

In [None]:
# Visualise the GaussianRFFEnc encoding.
display_pos_enc_example(p_e=fourier_enc)

# Labelling

## Smoothness investigation

In [None]:
from multitask_method.tasks.patch_blending_task import TestPoissonImageEditingMixedGradBlender
from multitask_method.tasks.deformation_task import BendSourceDeformationTask
from multitask_method.tasks.labelling import FlippedGaussianLabeller

# Task whose output is labelled in the next cell. To inspect the Poisson
# blending task instead, swap in:
#     main_task = TestPoissonImageEditingMixedGradBlender(None, img2, img2_mask)
main_task = BendSourceDeformationTask(None, min_push_dist=0, max_push_dist=5)

# Labeller whose `label_fn` is applied in the next cell.
curr_labeller = FlippedGaussianLabeller(0.2)

In [None]:
from scipy import ndimage
from skimage.morphology import reconstruction
# Create one anomaly. The "direct" label is the labeller's label_fn applied to
# the mean, over axis 0, of the absolute image change masked by the task's
# binary label.
aug_image, aug_image_binary_label = main_task(img1, img1_mask)
direct_label = curr_labeller.label_fn(np.mean(aug_image_binary_label * np.abs(aug_image - img1), axis=0))

# 3-D cross footprint (face neighbours only), and the same cross iterated twice,
# i.e. all offsets within Manhattan distance 2.
neighbour_footprint = ndimage.generate_binary_structure(3, 1)
extended_neighbour_footprint = ndimage.iterate_structure(neighbour_footprint, 2)


def _erosion_recon_seed(img):
    """Build the seed for hole-filling ``img`` by reconstruction-by-erosion.

    Interior voxels are raised to ``img.max()`` and the one-voxel border keeps
    ``img``'s own values, so the seed is >= ``img`` everywhere, as
    ``reconstruction(method='erosion')`` requires.
    """
    seed = np.copy(img)
    seed[(slice(1, -1),) * img.ndim] = img.max()
    return seed


recon_seed_img = _erosion_recon_seed(direct_label)

# The closed label needs a seed of its own. Closing can raise border values
# above direct_label's, and skimage's reconstruction by erosion raises
# ValueError when the seed is below the mask anywhere.
closed_label = ndimage.grey_closing(direct_label, footprint=neighbour_footprint)

base_images = [
    ('Image', aug_image[0]),
    ('Binary label', aug_image_binary_label[0]),
    ('Direct label', direct_label),
    ('Direct Morph recon', reconstruction(recon_seed_img, direct_label, method='erosion', footprint=neighbour_footprint)),
    ('Closed Morph recon', reconstruction(_erosion_recon_seed(closed_label), closed_label, method='erosion', footprint=neighbour_footprint)),
]

def make_sphere_mask(r, n_dims=3):
    """Return a boolean ball of integer radius ``r`` in a cube of side ``2 * r + 1``.

    An element is True when its Euclidean distance from the centre element is
    at most ``r``.

    Args:
        r: integer radius in voxels.
        n_dims: dimensionality of the ball; defaults to 3.

    Returns:
        Boolean array of shape ``(2 * r + 1,) * n_dims``.
    """
    diam_range = np.arange(-r, r + 1)
    grids = np.meshgrid(*([diam_range] * n_dims), indexing='ij')
    return np.sum([g ** 2 for g in grids], axis=0) <= r ** 2

# Footprints for grey closing of the direct label, compared below.
closing_shapes = [
    ('Direct neighbours', neighbour_footprint),
    # All offsets with squared distance <= 2, i.e. within radius 1.5.
    ('1.5 radius', ndimage.generate_binary_structure(3, 2)),
    ('3 kernel', np.ones((3, 3, 3), dtype=bool)),
    ('2 manhattan', extended_neighbour_footprint),
    # iterate_structure grows the array to 7x7x7. binary_dilation keeps the input
    # shape and would drop the axis-aligned offsets at Manhattan distance 3.
    ('3 manhattan', ndimage.iterate_structure(neighbour_footprint, 3)),
    ('2 radius', make_sphere_mask(2)),
    ('5 Kernel', np.ones((5, 5, 5), dtype=bool))
]

num_base_rows = len(base_images)
num_rows = num_base_rows + len(closing_shapes)

fig, axes = plt.subplots(num_rows, 3, figsize=(18, 6 * num_rows))
ax_row_params = {'fontsize': 30, 'labelpad': 15}

# One row per base image: floating-point images of any precision use the
# normalised display; everything else is shown as-is.
for a, (img_name, img_to_show) in zip(axes[:num_base_rows], base_images):
    if np.issubdtype(img_to_show.dtype, np.floating):
        display_normalised_cross_section(img_to_show, existing_fig_ax=(fig, a))
    else:
        display_cross_section(img_to_show, existing_fig_ax=(fig, a))
    a[0].set_ylabel(img_name, **ax_row_params)

# One row per closing footprint, applied to the direct label.
for a, (c_name, c) in zip(axes[num_base_rows:], closing_shapes):
    display_cross_section(ndimage.grey_closing(direct_label, footprint=c), existing_fig_ax=(fig, a))
    a[0].set_ylabel(c_name, **ax_row_params)

In [None]:
# Does morphological reconstruction leave the direct label unchanged?
base_images_by_name = dict(base_images)
np.array_equal(base_images_by_name['Direct label'], base_images_by_name['Direct Morph recon'])

## Function investigation

In [None]:
from multitask_method.tasks.deformation_task import SourceDeformationTask, SinkDeformationTask
from multitask_method.tasks.labelling import SaturatingLabeller, FlippedGaussianLabeller

from multitask_method.tasks.patch_blending_task import TestPoissonImageEditingMixedGradBlender
from multitask_method.tasks.intensity_tasks import SmoothIntensityChangeTask


# Candidate tasks; the next cell picks one of them via TASK_INDEX.
all_tasks = [
    TestPoissonImageEditingMixedGradBlender(None, img2, img2_mask),
    SmoothIntensityChangeTask(None, 0.2),
    SourceDeformationTask(None),
    SinkDeformationTask(None),
]

# Per-dataset value passed as SaturatingLabeller's first argument.
# OLD, never calculated for VinDr
LABEL_CONFIGS = {
    'brain': 133.75381550463723,
}

# The existing saturating labeller versus the flipped-Gaussian alternative.
curr_labeller = SaturatingLabeller(LABEL_CONFIGS['brain'], 0.1)
new_labeller = FlippedGaussianLabeller(0.037)

In [None]:
from scipy.ndimage import center_of_mass

# Pick one task (index 3 is SinkDeformationTask in all_tasks), apply it to img1
# and compute both labellers' outputs for the result.
TASK_INDEX = 3

aug_img, aug_img_mask = all_tasks[TASK_INDEX](img1, img1_mask)

# Keep index 0 of the returned mask's leading axis.
aug_img_mask = aug_img_mask[0]

curr_label = curr_labeller(aug_img, img1, aug_img_mask)
new_label = new_labeller(aug_img, img1, aug_img_mask)

# Slice indices for the animation: the mask's centre of mass, truncated to ints.
# An empty mask has a NaN centre of mass, which astype(int) would silently turn
# into meaningless indices, so fail loudly instead.
if not np.any(aug_img_mask):
    raise ValueError(f'Task {TASK_INDEX} produced an empty anomaly mask.')
z, y, x = np.array(center_of_mass(aug_img_mask)).astype(int)

In [None]:
from matplotlib.animation import FuncAnimation
from IPython import display


# Frames cycled by the animation: original image, augmented image, then a label
# (current labeller on the top row, new labeller on the bottom row).
top_frames = [img1[0], aug_img[0], curr_label]
bot_frames = [img1[0], aug_img[0], new_label]

fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(12, 9))

# Initial images: slices of each row's first frame at index z of the first axis,
# y of the second and x of the third. These are the same indices animate() uses.
# NOTE(review): vmin/vmax are fixed from the first frame (the original image),
# and later set_data calls keep them. The label frames are therefore shown on
# the image's intensity range — confirm this is intended.
ax_imgs = []
for sub_a, fs in zip(ax, [top_frames, bot_frames]):
    first = fs[0]
    show_kw = {'cmap': 'gray', 'vmin': first.min(), 'vmax': first.max()}
    ax_imgs.append([
        sub_a[0].imshow(first[z, :, :], **show_kw),
        sub_a[1].imshow(first[:, y, :], origin='lower', **show_kw),
        sub_a[2].imshow(first[:, :, x], origin='lower', **show_kw),
    ])

fig.tight_layout()

# Hide ticks on every panel.
for sub_a in ax:
    for a in sub_a:
        a.set_xticks([])
        a.set_yticks([])

def animate(frame_num):
    """Show frame ``frame_num`` (wrapping around) of each row's frame list.

    Row 0 cycles through ``top_frames`` and row 1 through ``bot_frames``. In each
    row, the three images show the slices at index ``z`` of the first axis, ``y``
    of the second and ``x`` of the third. Returns the nested list of image artists.
    """
    for sub_ax_img, fs in zip(ax_imgs, [top_frames, bot_frames]):
        # Wrap around each row's own frame list.
        curr_frame = fs[frame_num % len(fs)]
        sub_ax_img[0].set_data(curr_frame[z, :, :])
        sub_ax_img[1].set_data(curr_frame[:, y, :])
        sub_ax_img[2].set_data(curr_frame[:, :, x])
    return ax_imgs

# Render 50 frames, 600 ms apart, as an HTML5 video and embed it in the cell
# output. The figure is then closed, presumably so the static figure is not
# rendered as well.
# NOTE(review): to_html5_video needs a movie writer (typically ffmpeg) to be
# available to matplotlib.
anim = FuncAnimation(fig, animate, frames=50, interval=600)
video = anim.to_html5_video()
html = display.HTML(video)
display.display(html)

plt.close()



# Random

In [None]:
from multitask_method.utils import make_exp_config
# Load the HCP debug experiment config and build a container over items [0, 1]
# of its dataset coordinator.
exp_config = make_exp_config('experiments/exp_HCP_debug.py')
hcp_coordinator = exp_config.curr_dset_coord
hcp_dset_container = hcp_coordinator.make_container([0, 1])

In [None]:
# NOTE: this rebinds img1 / img1_mask (first loaded in the "Example Images"
# cell) to item 0 of the HCP debug container.
img1, img1_mask, _ = hcp_dset_container[0]

# Echo the image's shape as the cell output.
img1.shape