# Helpers

In [None]:
import os
from pathlib import Path
import sys


import numpy as np
import matplotlib.pyplot as plt
import scipy
import SimpleITK as sitk
from tqdm import tqdm

# Make the directory two levels above the current working directory importable
# (presumably the repository root, so the `multitask_method` import below resolves).
# A single-argument os.path.join was a no-op, so abspath is applied directly.
module_path = os.path.abspath('../..')

if module_path not in sys.path:
    sys.path.append(module_path)

    
from multitask_method.plotting_utils import display_cross_section, display_normalised_cross_section

# MOOD

In [None]:
from multitask_method.utils import make_exp_config

# Build the debug brain experiment config and fetch samples 0 and 1 from its dataset coordinator.
exp_config = make_exp_config('experiments/exp_BRAIN_debug.py')

brain_coordinator = exp_config.curr_dset_coord
brain_dset_container = brain_coordinator.make_container([0, 1])

(img1, img1_mask, _), (img2, img2_mask, _) = brain_dset_container[0], brain_dset_container[1]

# For each sample: entry 0 of the image along its leading axis, then its mask.
for volume in (img1[0], img1_mask, img2[0], img2_mask):
    display_cross_section(volume)

# HCP

## Raw

In [None]:
from multitask_method.preprocessing.brain_preproc import hcp_samples

In [None]:
from scipy.ndimage import find_objects

# Distinct image sizes / voxel spacings seen across the HCP T2 volumes (filled below;
# previously these stayed empty because the bookkeeping was commented out).
hcp_all_sizes = []
hcp_all_spacings = []

# Running union of all foreground boxes: per array axis, (smallest start, largest stop).
# Lower bounds start at +inf, as in the BraTS / ISLES cells, instead of hard-coded guesses.
hcp_outer_box = [(np.inf, 0), (np.inf, 0), (np.inf, 0)]
hcp_all_boxes = []

for hcp_sample_id, (t1_file, t2_file, wm_seg_file) in tqdm(hcp_samples):
    # Only the T2 volume is inspected here.
    t2_img = sitk.ReadImage(t2_file)

    t2_img_sz = t2_img.GetSize()
    if t2_img_sz not in hcp_all_sizes:
        hcp_all_sizes.append(t2_img_sz)
        print(t2_img_sz)

    t2_img_sp = t2_img.GetSpacing()
    if t2_img_sp not in hcp_all_spacings:
        hcp_all_spacings.append(t2_img_sp)

    # Reverse the second SimpleITK index axis before converting to a numpy array
    # (NOTE(review): presumably to match the preprocessing orientation — confirm).
    t2_arr = sitk.GetArrayFromImage(t2_img[:, ::-1, :])

    # find_objects returns an empty list for an all-zero volume; skip such samples
    # instead of crashing on the [0] index.
    t2_objects = find_objects((t2_arr > 0).astype(int))
    if not t2_objects or t2_objects[0] is None:
        print(f'No foreground in sample {hcp_sample_id}, skipping')
        continue

    # Foreground bounding box as (start, stop) per array axis.
    t2_box = [(s.start, s.stop) for s in t2_objects[0]]
    hcp_all_boxes.append(t2_box)

    hcp_outer_box = [(min(c_min, lb), max(c_max, ub)) for (c_min, c_max), (lb, ub) in zip(hcp_outer_box, t2_box)]

print(hcp_all_sizes)
print(hcp_all_spacings)
hcp_outer_box

In [None]:
from functools import reduce

# Per array axis: the largest value of 0.7 * (stop - start) over all HCP foreground boxes,
# starting from 0. NOTE(review): the purpose of the 0.7 factor is not recorded here.
reduce(
    lambda running_max, box: [max(cur, (hi - lo) * 0.7) for cur, (lo, hi) in zip(running_max, box)],
    hcp_all_boxes,
    [0, 0, 0],
)

In [None]:
hcp_example_sample_id, (hcp_example_sample_t1_file, hcp_example_sample_t2_file, hcp_example_sample_wm_seg_file) = sorted(hcp_samples)[0]

# Load the first (sorted) HCP sample's T1, T2 and WM segmentation as numpy arrays.
hcp_t1, hcp_t2, hcp_seg = (
    sitk.GetArrayFromImage(sitk.ReadImage(path))
    for path in (hcp_example_sample_t1_file, hcp_example_sample_t2_file, hcp_example_sample_wm_seg_file)
)

# Show each volume divided by its own maximum.
for hcp_volume in (hcp_t1, hcp_t2, hcp_seg):
    display_cross_section(hcp_volume / hcp_volume.max())

In [None]:
from multitask_method.preprocessing.brain_preproc import z_norm

# z_norm the example T2 volume, passing its > 0 voxels as the mask argument.
tmp_mask = hcp_t2 > 0
hcp_t2_norm = z_norm(hcp_t2, tmp_mask)

# Fill holes in the normalised volume's non-zero region and count how many voxels that changes.
nonzero_after_norm = hcp_t2_norm != 0
pp_mask = scipy.ndimage.binary_fill_holes(nonzero_after_norm)

(pp_mask != nonzero_after_norm).sum()

In [None]:
_ = plt.hist(hcp_t1[hcp_t1 > 0], bins=100)  # intensity histogram of positive T1 voxels

## Preprocessed

In [None]:
hcp_preprocessed_root = Path('/vol/biomedic3/mb4617/multitask_method_data/hcp/')
# Sorted listings of the 'fullres' and 'lowres' output directories
# (sorted() already materialises the iterator, so no list() wrapper is needed).
hcp_fullres_files = sorted((hcp_preprocessed_root / 'fullres').iterdir())
hcp_lowres_files = sorted((hcp_preprocessed_root / 'lowres').iterdir())

In [None]:
hcp_pp_example_file = hcp_fullres_files[0]
hcp_pp_example = np.load(hcp_pp_example_file)

print(hcp_pp_example.shape)

# Entries 0 and 1 of the stacked array with normalised display, the non-zero support of
# entry 1, then entry 2 divided by its maximum.
for entry_idx in (0, 1):
    display_normalised_cross_section(hcp_pp_example[entry_idx])
display_cross_section(hcp_pp_example[1] != 0)
display_cross_section(hcp_pp_example[2] / hcp_pp_example[2].max())

In [None]:
_ = plt.hist(hcp_pp_example[0][hcp_pp_example[0] != 0], bins=100)  # histogram of non-zero voxels in entry 0

In [None]:
hcp_pp_example_file = hcp_lowres_files[0]
hcp_pp_example = np.load(hcp_pp_example_file)
print(hcp_pp_example.shape)

# Same panels as the full-resolution cell, for the low-resolution counterpart.
for entry_idx in (0, 1):
    display_normalised_cross_section(hcp_pp_example[entry_idx])
display_cross_section(hcp_pp_example[1] != 0)
display_cross_section(hcp_pp_example[2] / hcp_pp_example[2].max())

In [None]:
_ = plt.hist(hcp_pp_example[0][hcp_pp_example[0] != 0], bins=100)  # histogram of non-zero voxels in entry 0

In [None]:
t = np.array([
    [0, 1, 3, -1],
    [3, 5, -2, 2],
])
# Clip every column to its own [min, max] (a no-op by construction) to check keepdims broadcasting.
np.clip(t, t.min(axis=0, keepdims=True), t.max(axis=0, keepdims=True))

In [None]:
t.max(axis=tuple(range(1, t.ndim)), keepdims=True)  # max over every axis except the first

# BRATS 2017

## Raw

In [None]:
from multitask_method.preprocessing.brain_preproc import brats_samples

In [None]:
from scipy.ndimage import find_objects

brats_all_sizes = []
brats_all_spacings = []
brats_all_boxes = []

# Running union of all foreground boxes: per array axis, (smallest start, largest stop).
brats_outer_box = [(np.inf, 0)] * 3

for brats_sample_id, (t1_file, t2_file, seg_file) in tqdm(brats_samples):
    t1_img, t2_img = sitk.ReadImage(t1_file), sitk.ReadImage(t2_file)

    # T1 and T2 must share size and spacing; record each distinct value once.
    t1_img_sz, t2_img_sz = t1_img.GetSize(), t2_img.GetSize()
    assert t1_img_sz == t2_img_sz, f'Size missmatch {t1_img_sz} - {t2_img_sz}'
    if t2_img_sz not in brats_all_sizes:
        brats_all_sizes.append(t2_img_sz)
        print(t2_img_sz)

    t1_img_sp, t2_img_sp = t1_img.GetSpacing(), t2_img.GetSpacing()
    assert t1_img_sp == t2_img_sp, f'Spacing missmatch {t1_img_sp} - {t2_img_sp}'
    if t2_img_sp not in brats_all_spacings:
        brats_all_spacings.append(t2_img_sp)

    t1_arr = sitk.GetArrayFromImage(t1_img)
    t2_arr = sitk.GetArrayFromImage(t2_img)

    # Foreground (> 0) bounding box of each modality as (start, stop) per array axis.
    t1_box, t2_box = (
        [(axis_slice.start, axis_slice.stop) for axis_slice in find_objects((arr > 0).astype(int))[0]]
        for arr in (t1_arr, t2_arr)
    )

    # Union of the two modality boxes for this sample, then merged into the running union.
    brats_sample_outer_box = [
        (min(lo1, lo2), max(hi1, hi2)) for (lo1, hi1), (lo2, hi2) in zip(t1_box, t2_box)
    ]
    brats_all_boxes.append(brats_sample_outer_box)

    brats_outer_box = [
        (min(run_lo, lo), max(run_hi, hi))
        for (run_lo, run_hi), (lo, hi) in zip(brats_outer_box, brats_sample_outer_box)
    ]

print(brats_all_sizes)
print(brats_all_spacings)
brats_outer_box

In [None]:
from functools import reduce

# Largest foreground extent (stop - start) along each array axis across all BraTS samples.
brats_max_brain_sz = reduce(
    lambda running_max, box: [max(cur, hi - lo) for cur, (lo, hi) in zip(running_max, box)],
    brats_all_boxes,
    [0, 0, 0],
)

brats_max_brain_sz

In [None]:
brats_example_id, (brats_example_t1_file, brats_example_t2_file, brats_example_seg_file) = sorted(brats_samples)[0]

# Load T1, T2 and segmentation, reversing the second SimpleITK index axis of each before conversion.
brats_t1, brats_t2, brats_seg = (
    sitk.GetArrayFromImage(sitk.ReadImage(path)[:, ::-1])
    for path in (brats_example_t1_file, brats_example_t2_file, brats_example_seg_file)
)

print(brats_t1.min(), brats_t1.max())
print(np.unique(brats_seg))

# Intensity volumes are divided by their maximum; the segmentation is displayed unscaled.
for intensity_vol in (brats_t1, brats_t2):
    display_cross_section(intensity_vol / intensity_vol.max())
display_cross_section(brats_seg)

In [None]:
from multitask_method.preprocessing.brain_preproc import load_and_crop

brats_first_sample_files = sorted(brats_samples)[0][1]
full_res, low_res = load_and_crop(*brats_first_sample_files, 'brats17')

# Mean / std of the non-zero voxels of entry 0, at full then low resolution.
for res_arr in (full_res, low_res):
    nonzero_vals = res_arr[0][res_arr[0] != 0]
    print(nonzero_vals.mean(), nonzero_vals.std())

# Does every low-res value of entry 2 also occur in the full-res entry 2?
print(np.isin(low_res[2], full_res[2]).all())

display_normalised_cross_section(full_res[0])
display_normalised_cross_section(full_res[1])
display_cross_section(low_res[2] / low_res[2].max())

In [None]:
_ = plt.hist(brats_t1[brats_t1 > 0], bins=100)  # intensity histogram of positive T1 voxels

## Preprocessed

In [None]:
brats_preprocessed_root = Path('/vol/biomedic3/mb4617/multitask_method_data/brats17/')
# Sorted listings of the 'fullres' and 'lowres' output directories
# (sorted() already materialises the iterator, so no list() wrapper is needed).
brats_fullres_files = sorted((brats_preprocessed_root / 'fullres').iterdir())
brats_lowres_files = sorted((brats_preprocessed_root / 'lowres').iterdir())

In [None]:
brats_pp_example_file = brats_fullres_files[0]
brats_pp_example = np.load(brats_pp_example_file)

print(brats_pp_example.shape)
print('Labels: ', np.unique(brats_pp_example[2]))
print(brats_pp_example[0].max())

# Each displayed entry is first divided by its own maximum.
for entry_idx, show in (
    (0, display_normalised_cross_section),
    (1, display_normalised_cross_section),
    (2, display_cross_section),
):
    show(brats_pp_example[entry_idx] / brats_pp_example[entry_idx].max())

In [None]:
brats_pp_example_file = brats_lowres_files[0]
brats_pp_example = np.load(brats_pp_example_file)

print(brats_pp_example.shape)
print('Labels: ', np.unique(brats_pp_example[2]))
print(brats_pp_example.min(), brats_pp_example.max())

# Each displayed entry is first divided by its own maximum.
for entry_idx, show in (
    (0, display_normalised_cross_section),
    (1, display_normalised_cross_section),
    (2, display_cross_section),
):
    show(brats_pp_example[entry_idx] / brats_pp_example[entry_idx].max())

In [None]:
_ = plt.hist(brats_pp_example[0][brats_pp_example[0] != 0], bins=100)  # histogram of non-zero voxels in entry 0

# ISLES 2015

## Raw

In [None]:
from multitask_method.preprocessing.brain_preproc import isles_samples

In [None]:
from scipy.ndimage import find_objects

isles_all_sizes = []
isles_all_spacings = []
isles_all_boxes = []

# Running union of all foreground boxes: per array axis, (smallest start, largest stop).
isles_outer_box = [(np.inf, 0)] * 3
# Number of samples observed for each distinct image size.
isles_size_counts = {}

for isles_sample_id, (t1_file, t2_file, seg_file) in tqdm(isles_samples):
    t1_img, t2_img = sitk.ReadImage(t1_file), sitk.ReadImage(t2_file)

    # T1 and T2 must share size and spacing; record each distinct value once.
    t1_img_sz, t2_img_sz = t1_img.GetSize(), t2_img.GetSize()
    assert t1_img_sz == t2_img_sz, f'Size missmatch {t1_img_sz} - {t2_img_sz}'
    if t2_img_sz not in isles_all_sizes:
        isles_all_sizes.append(t2_img_sz)
        print(t2_img_sz)
    isles_size_counts[t2_img_sz] = isles_size_counts.get(t2_img_sz, 0) + 1

    t1_img_sp, t2_img_sp = t1_img.GetSpacing(), t2_img.GetSpacing()
    assert t1_img_sp == t2_img_sp, f'Spacing missmatch {t1_img_sp} - {t2_img_sp}'
    if t2_img_sp not in isles_all_spacings:
        isles_all_spacings.append(t2_img_sp)

    t1_arr = sitk.GetArrayFromImage(t1_img)
    t2_arr = sitk.GetArrayFromImage(t2_img)

    # Foreground (> 0) bounding box of each modality as (start, stop) per array axis.
    t1_box, t2_box = (
        [(axis_slice.start, axis_slice.stop) for axis_slice in find_objects((arr > 0).astype(int))[0]]
        for arr in (t1_arr, t2_arr)
    )

    # Union of the two modality boxes for this sample, then merged into the running union.
    isles_sample_outer_box = [
        (min(lo1, lo2), max(hi1, hi2)) for (lo1, hi1), (lo2, hi2) in zip(t1_box, t2_box)
    ]
    isles_all_boxes.append(isles_sample_outer_box)

    isles_outer_box = [
        (min(run_lo, lo), max(run_hi, hi))
        for (run_lo, run_hi), (lo, hi) in zip(isles_outer_box, isles_sample_outer_box)
    ]

print(isles_all_sizes)
print(isles_size_counts)
print(isles_all_spacings)
isles_outer_box

In [None]:
# NOTE(review): 1069 and 2671 are hard-coded counts whose origin is not recorded here — confirm.
print(1069 + 2671)
# For each distinct ISLES size: (size along SimpleITK axis 2, minus 20) times the number of
# samples with that size, summed — presumably a slice count to compare with the figure above.
print(sum(count * (sz[2] - 20) for sz, count in isles_size_counts.items()))

In [None]:
from functools import reduce

# Largest foreground extent (stop - start) along each array axis across all ISLES samples.
isles_max_brain_sz = reduce(
    lambda running_max, box: [max(cur, hi - lo) for cur, (lo, hi) in zip(running_max, box)],
    isles_all_boxes,
    [0, 0, 0],
)

isles_max_brain_sz

In [None]:
isles_example_id, (isles_example_t1_file, isles_example_t2_file, isles_example_seg_file) = sorted(isles_samples)[0]

# Load the first (sorted) ISLES sample's T1, T2 and segmentation as numpy arrays.
isles_t1, isles_t2, isles_seg = (
    sitk.GetArrayFromImage(sitk.ReadImage(path))
    for path in (isles_example_t1_file, isles_example_t2_file, isles_example_seg_file)
)

# Intensity volumes are divided by their maximum; the segmentation is displayed unscaled.
for intensity_vol in (isles_t1, isles_t2):
    display_cross_section(intensity_vol / intensity_vol.max())
display_cross_section(isles_seg)

In [None]:
_ = plt.hist(isles_t1[isles_t1 > 0], bins=100)  # intensity histogram of positive T1 voxels

## Preprocessed

In [None]:
isles_preprocessed_root = Path('/vol/biomedic3/mb4617/multitask_method_data/isles2015/')
# Sorted listings of the 'fullres' and 'lowres' output directories
# (sorted() already materialises the iterator, so no list() wrapper is needed).
isles_fullres_files = sorted((isles_preprocessed_root / 'fullres').iterdir())
isles_lowres_files = sorted((isles_preprocessed_root / 'lowres').iterdir())

In [None]:
# isles_fullres_files is already sorted when it is built, so index it directly
# (re-sorting here was redundant and inconsistent with the low-res cell).
isles_pp_example_file = isles_fullres_files[0]
isles_pp_example = np.load(isles_pp_example_file)

print(isles_pp_example.shape)
print('Labels: ', np.unique(isles_pp_example[2]))

# Each displayed entry is divided by its own maximum.
display_normalised_cross_section(isles_pp_example[0] / isles_pp_example[0].max())
display_normalised_cross_section(isles_pp_example[1] / isles_pp_example[1].max())
display_cross_section(isles_pp_example[2] / isles_pp_example[2].max())

In [None]:
isles_pp_example_file = isles_lowres_files[0]
isles_pp_example = np.load(isles_pp_example_file)

print(isles_pp_example.shape)
print('Labels: ', np.unique(isles_pp_example[2]))

# Each displayed entry is first divided by its own maximum.
for entry_idx, show in (
    (0, display_normalised_cross_section),
    (1, display_normalised_cross_section),
    (2, display_cross_section),
):
    show(isles_pp_example[entry_idx] / isles_pp_example[entry_idx].max())

# VinDr-CXR

## Raw

In [None]:
from multitask_method.preprocessing.vindr_cxr_preproc import (
    raw_root,
    gen_vindr_structure,
    TRAIN,
    TEST,
)

# Locate the raw VinDr-CXR annotation / label CSVs and the test / train DICOM directories.
raw_annotations_dict, raw_image_labels, raw_test_dir, raw_train_dir = gen_vindr_structure(raw_root)

In [None]:
import pandas as pd
import pydicom

from multitask_method.preprocessing.vindr_cxr_preproc import vindr_preproc_func, generate_vindr_mask

# Number of test images to preview.
n_preview_samples = 6

test_anno_df = pd.read_csv(raw_annotations_dict[TEST])
test_label_df = pd.read_csv(raw_image_labels[1], index_col='image_id')

raw_test_samples = sorted(raw_test_dir.iterdir())


def plot_raw_cxr(curr_img, curr_ax):
    """Show ``curr_img`` in grey-scale on ``curr_ax`` with a colorbar covering its full range.

    The colorbar is attached to ``curr_ax`` explicitly: without ``ax=``, older Matplotlib
    versions take space from the figure's *current* Axes (the last subplot) instead.
    """
    curr_ax_im = curr_ax.imshow(curr_img, vmin=curr_img.min(), vmax=curr_img.max(), cmap='gray')
    curr_ax.figure.colorbar(curr_ax_im, ax=curr_ax)


for dicom_path in raw_test_samples[:n_preview_samples]:
    sample_id = dicom_path.stem
    raw_dicom = pydicom.dcmread(dicom_path)

    raw_arr = raw_dicom.pixel_array.astype(float)
    preproc_arr = vindr_preproc_func(raw_arr, raw_dicom)

    # Annotation rows for this image, passed to generate_vindr_mask with raw_arr as reference.
    raw_annotation = generate_vindr_mask(test_anno_df[test_anno_df['image_id'] == sample_id], raw_arr)
    # Title lists every label column equal to 1 for this image.
    sample_class_row = test_label_df.loc[sample_id]
    sample_class = ', '.join(sample_class_row[sample_class_row == 1].index.tolist())

    fig, ax = plt.subplots(ncols=3, figsize=(20, 6))
    plot_raw_cxr(raw_arr, ax[0])
    plot_raw_cxr(preproc_arr, ax[1])
    ax[2].imshow(raw_annotation)
    fig.suptitle(sample_class)

In [None]:
import pandas as pd
import pydicom

from multitask_method.preprocessing.vindr_cxr_preproc import vindr_preproc_func, generate_vindr_mask

# Number of training images to preview.
n_preview_samples = 6

train_anno_df = pd.read_csv(raw_annotations_dict[TRAIN])
train_label_df = pd.read_csv(raw_image_labels[0])

# Sum 'No finding' over each image's rows and keep images where it totals 3
# (NOTE(review): presumably one row per reader, i.e. every reader reported no finding — confirm).
train_labels_sum = train_label_df.groupby('image_id')['No finding'].sum()
train_sample_ids = sorted(train_labels_sum[train_labels_sum == 3].index.tolist())
raw_train_samples = [raw_train_dir / f'{f}.dicom' for f in train_sample_ids]


def plot_raw_cxr(curr_img, curr_ax):
    """Show ``curr_img`` in grey-scale on ``curr_ax`` with a colorbar covering its full range.

    The colorbar is attached to ``curr_ax`` explicitly: without ``ax=``, older Matplotlib
    versions take space from the figure's *current* Axes (the last subplot) instead.
    """
    curr_ax_im = curr_ax.imshow(curr_img, vmin=curr_img.min(), vmax=curr_img.max(), cmap='gray')
    curr_ax.figure.colorbar(curr_ax_im, ax=curr_ax)


for dicom_path in raw_train_samples[:n_preview_samples]:
    sample_id = dicom_path.stem
    raw_dicom = pydicom.dcmread(dicom_path)

    raw_arr = raw_dicom.pixel_array.astype(float)
    preproc_arr = vindr_preproc_func(raw_arr, raw_dicom)

    # Annotation rows for this image, passed to generate_vindr_mask with raw_arr as reference.
    raw_annotation = generate_vindr_mask(train_anno_df[train_anno_df['image_id'] == sample_id], raw_arr)

    fig, ax = plt.subplots(ncols=3, figsize=(20, 6))
    plot_raw_cxr(raw_arr, ax[0])
    plot_raw_cxr(preproc_arr, ax[1])
    ax[2].imshow(raw_annotation)
    fig.suptitle(sample_id)

## Preprocessed

In [None]:
from multitask_method.data.vindr_cxr import VinDrCXRDatasetCoordinator
from multitask_method.preprocessing.vindr_cxr_preproc import base_output_dir

# Sample indices to preview; shared by every container below so they cannot drift apart
# (previously this list was defined but the literal [5] was repeated instead).
test_samples = [5]

# NOTE(review): judging only by the variable names, the two positional booleans select
# full- vs low-resolution and train vs test data — confirm against VinDrCXRDatasetCoordinator.
full_res_test_container = VinDrCXRDatasetCoordinator(base_output_dir, True, False).make_container(test_samples)
low_res_test_container = VinDrCXRDatasetCoordinator(base_output_dir, False, False).make_container(test_samples)

full_res_train_container = VinDrCXRDatasetCoordinator(base_output_dir, True, True).make_container(test_samples)
low_res_train_container = VinDrCXRDatasetCoordinator(base_output_dir, False, True).make_container(test_samples)

In [None]:
for i in range(len(full_res_test_container)):
    full_res_test_pp_img, full_res_test_pp_m, sample_id = full_res_test_container[i]

    fig, ax = plt.subplots(ncols=2, figsize=(12, 6))
    # Entry 0 in grey-scale, with the intensity range taken from the whole image array.
    ax_im = ax[0].imshow(full_res_test_pp_img[0], vmin=full_res_test_pp_img.min(), vmax=full_res_test_pp_img.max(), cmap='gray')
    # Attach the colorbar to the image Axes explicitly; older Matplotlib would otherwise
    # take space from the current Axes (ax[1], the mask panel).
    fig.colorbar(ax_im, ax=ax[0])
    ax[1].imshow(full_res_test_pp_m)
    fig.suptitle(sample_id)

In [None]:
for i in range(len(low_res_test_container)):
    low_res_test_pp_img, low_res_test_pp_m, sample_id = low_res_test_container[i]

    fig, ax = plt.subplots(ncols=2, figsize=(12, 6))
    # Entry 0 in grey-scale, with the intensity range taken from the whole image array.
    ax_im = ax[0].imshow(low_res_test_pp_img[0], vmin=low_res_test_pp_img.min(), vmax=low_res_test_pp_img.max(), cmap='gray')
    # Attach the colorbar to the image Axes explicitly; older Matplotlib would otherwise
    # take space from the current Axes (ax[1], the mask panel).
    fig.colorbar(ax_im, ax=ax[0])
    ax[1].imshow(low_res_test_pp_m)
    fig.suptitle(sample_id)

In [None]:
for idx in range(len(low_res_train_container)):
    low_res_train_pp_img, _, sample_id = low_res_train_container[idx]

    # Grey-scale view of entry 0; the colour range spans the whole image array.
    intensity_lo, intensity_hi = low_res_train_pp_img.min(), low_res_train_pp_img.max()
    ax_im = plt.imshow(low_res_train_pp_img[0], vmin=intensity_lo, vmax=intensity_hi, cmap='gray')
    plt.colorbar(ax_im)
    plt.title(sample_id)
    plt.show()

In [None]:
for idx in range(len(full_res_train_container)):
    full_res_train_pp_img, _, sample_id = full_res_train_container[idx]

    # Grey-scale view of entry 0; the colour range spans the whole image array.
    intensity_lo, intensity_hi = full_res_train_pp_img.min(), full_res_train_pp_img.max()
    ax_im = plt.imshow(full_res_train_pp_img[0], vmin=intensity_lo, vmax=intensity_hi, cmap='gray')
    plt.colorbar(ax_im)
    plt.title(sample_id)
    plt.show()