# Exploratory Data Analysis

A description of the data can be found in the [Shifts 2.0: Extending The Dataset of Real Distributional Shifts](https://arxiv.org/pdf/2206.15407.pdf) paper.

All input images have the following preprocessing applied:
- Denoising
- Skull stripping (brain mask calculated from the T1 images registered to the FLAIR space)
- Bias field correction
- Interpolation to the 1 mm isovoxel space (input images are linearly interpolated while all masks are interpolated using nearest neighbour)

The data is shared as a series of compressed *.nii* files.

In [13]:
import numpy as np
import os
import re
import pandas as pd
from glob import glob
import nibabel as nib
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='white', font_scale=1.5)

## Collect paths to images, labels and masks from train, dev_in, eval_in and dev_out sets

In [3]:
# Root folders: train / dev_in / eval_in are read from two roots (msseg, best),
# dev_out from a single root (ljubljana).
_MSSEG_BEST_ROOTS = ['../data/shifts_ms_pt1/msseg', '../data/shifts_ms_pt2/best']
_LJUBLJANA_ROOTS = ['../data/shifts_ms_pt2/ljubljana']

# For every split there are three parallel folder lists:
# FLAIR images (flair), label volumes (gt) and foreground masks (fg_mask).
train_paths = [f'{root}/train/flair' for root in _MSSEG_BEST_ROOTS]
train_lbl_paths = [f'{root}/train/gt' for root in _MSSEG_BEST_ROOTS]
train_mask_paths = [f'{root}/train/fg_mask' for root in _MSSEG_BEST_ROOTS]

devin_paths = [f'{root}/dev_in/flair' for root in _MSSEG_BEST_ROOTS]
devin_lbl_paths = [f'{root}/dev_in/gt' for root in _MSSEG_BEST_ROOTS]
devin_mask_paths = [f'{root}/dev_in/fg_mask' for root in _MSSEG_BEST_ROOTS]

evalin_paths = [f'{root}/eval_in/flair' for root in _MSSEG_BEST_ROOTS]
evalin_lbl_paths = [f'{root}/eval_in/gt' for root in _MSSEG_BEST_ROOTS]
evalin_mask_paths = [f'{root}/eval_in/fg_mask' for root in _MSSEG_BEST_ROOTS]

devout_paths = [f'{root}/dev_out/flair' for root in _LJUBLJANA_ROOTS]
devout_lbl_paths = [f'{root}/dev_out/gt' for root in _LJUBLJANA_ROOTS]
devout_mask_paths = [f'{root}/dev_out/fg_mask' for root in _LJUBLJANA_ROOTS]

In [4]:
def get_paths(img_paths, lbl_paths, mask_paths):
    """Collect FLAIR image, label and foreground-mask file paths.

    The three arguments are parallel lists of directories. Directories are
    visited in the given order; inside each directory the matching files are
    sorted by the number embedded in their file name.

    Args:
        img_paths: directories searched for ``*FLAIR_isovox.nii.gz``.
        lbl_paths: directories searched for ``*gt_isovox.nii.gz``.
        mask_paths: directories searched for ``*isovox_fg_mask.nii.gz``.

    Returns:
        Tuple ``(images_paths, labels_paths, masks_paths)`` of path lists.
    """
    def numeric_key(path):
        # Only the file name is inspected, so digits in parent folders
        # (e.g. "shifts_ms_pt1") cannot influence the ordering.
        digits = re.sub(r'\D', '', os.path.basename(path))
        # A file name without any digit sorts first instead of raising ValueError.
        return int(digits) if digits else -1

    def numbered_files(folder, pattern):
        return sorted(glob(os.path.join(folder, pattern)), key=numeric_key)

    images_paths, labels_paths, masks_paths = [], [], []
    for flair_path, gts_path, mask_path in zip(img_paths, lbl_paths, mask_paths):
        images_paths += numbered_files(flair_path, "*FLAIR_isovox.nii.gz")
        labels_paths += numbered_files(gts_path, "*gt_isovox.nii.gz")
        masks_paths += numbered_files(mask_path, "*isovox_fg_mask.nii.gz")

    return images_paths, labels_paths, masks_paths

train_images, train_labels, train_masks = get_paths(train_paths, train_lbl_paths, train_mask_paths)
devin_images, devin_labels, devin_masks = get_paths(devin_paths, devin_lbl_paths, devin_mask_paths)
evalin_images, evalin_labels, evalin_masks = get_paths(evalin_paths, evalin_lbl_paths, evalin_mask_paths)
devout_images, devout_labels, devout_masks = get_paths(devout_paths, devout_lbl_paths, devout_mask_paths)

Check:

In [None]:
train_images[5], train_labels[5], train_masks[5], len(train_images), len(train_labels), len(train_masks)

In [None]:
devin_images[5], devin_labels[5], devin_masks[5], len(devin_images), len(devin_labels), len(devin_masks)

In [None]:
evalin_images[5], evalin_labels[5], evalin_masks[5], len(evalin_images), len(evalin_labels), len(evalin_masks)

In [None]:
devout_images[5], devout_labels[5], devout_masks[5], len(devout_images), len(devout_labels), len(devout_masks)

## Images

In [410]:
def plot_example_images(img_path, title, sag_idx, cor_idx, ax_idx, lbl_path=None):
    """Show one slice along each of the three array axes of a volume.

    The top row shows the slices of the image taken at index ``sag_idx`` on
    the first axis, ``cor_idx`` on the second axis and ``ax_idx`` on the last
    axis. When ``lbl_path`` is given, a second row repeats the same slices
    with the label volume drawn on top, using the label values themselves as
    per-pixel alpha.

    Args:
        img_path: path of the image volume to load.
        title: figure title.
        sag_idx: index of the slice titled 'Sagittal'.
        cor_idx: index of the slice titled 'Coronal'.
        ax_idx: index of the slice titled 'Axial'.
        lbl_path: optional path of the label volume to overlay.
    """
    volume = nib.load(img_path).get_fdata()
    overlay = nib.load(lbl_path).get_fdata() if lbl_path else None

    # (title, index expression) for each view, in subplot order.
    views = [
        ('Sagittal', (sag_idx, Ellipsis)),
        ('Coronal', (slice(None), cor_idx, slice(None))),
        ('Axial', (Ellipsis, ax_idx)),
    ]

    plt.figure(figsize=(30,15), tight_layout=True)
    plt.suptitle(title)
    for position, (view_title, index) in enumerate(views, start=1):
        plt.subplot(2, 3, position)
        plt.imshow(volume[index], cmap='gray')
        plt.title(view_title)

    if overlay is None:
        return

    # Second row: same slices with the labels overlaid (untitled panels).
    for position, (_, index) in enumerate(views, start=4):
        plt.subplot(2, 3, position)
        plt.imshow(volume[index], cmap='gray')
        plt.imshow(overlay[index], alpha=overlay[index])

In [None]:
plot_example_images(train_images[0], 'train', 100, 120, 150, train_labels[0])

In [None]:
plot_example_images(devin_images[-2], 'dev_in', 100, 120, 150, devin_labels[-2])

In [None]:
plot_example_images(evalin_images[-1], 'eval_in', 100, 140, 70, evalin_labels[-1])

In [None]:
plot_example_images(devout_images[10], 'dev_out', 100, 120, 150, devout_labels[10])

### Shapes

In [9]:
def get_shapes(paths):
    """Return the array shape of every image in ``paths``.

    The shape is read from the loaded image object itself, so the voxel data
    does not have to be decoded into memory just to learn its dimensions.

    Args:
        paths: iterable of image file paths readable by ``nib.load``.

    Returns:
        ``np.ndarray`` with one row per image holding its shape.
    """
    return np.asarray([nib.load(img_path).shape for img_path in paths])

In [None]:
train_shapes = get_shapes(train_images)
train_shapes

Min and max shape in each axis:

In [170]:
for i in range(3):
    print(i, np.min(train_shapes[:,i]), np.max(train_shapes[:,i]))

0 158 212
1 212 260
2 151 265


In [11]:
devin_shapes = get_shapes(devin_images)
for i in range(3):
    print(i, np.min(devin_shapes[:,i]), np.max(devin_shapes[:,i]))

0 158 212
1 212 260
2 151 265


In [12]:
evalin_shapes = get_shapes(evalin_images) 
for i in range(3):
    print(i, np.min(evalin_shapes[:,i]), np.max(evalin_shapes[:,i]))

0 158 212
1 212 256
2 151 270


In [13]:
devout_shapes = get_shapes(devout_images) 
for i in range(3):
    print(i, np.min(devout_shapes[:,i]), np.max(devout_shapes[:,i]))

0 154 173
1 241 241
2 241 241


In [27]:
def shapes_scatterplot(x, y, title, label):
    """Scatter ``y`` against ``x`` on the current axes with fixed 145-275 limits.

    Args:
        x: values for the horizontal axis (labelled 'width').
        y: values for the vertical axis (labelled 'height').
        title: title of the axes.
        label: legend label of the plotted points.
    """
    axes = sns.scatterplot(x=x, y=y, s=600, alpha=0.4, label=label, linewidth=0)
    axes.set_title(title)
    axes.set_xlabel('width')
    axes.set_ylabel('height')
    # Same limits on both axes so that the panels are directly comparable.
    axes.set_xlim(145, 275)
    axes.set_ylim(145, 275)
    
def plot_shapes(shapes, label):
    """Draw three side-by-side scatter plots of image extents, one per plane.

    Args:
        shapes: 2-D array with one row per image and one column per axis.
        label: legend label passed on to every panel.
    """
    # (panel title, column on the x axis, column on the y axis)
    panels = [('Sagittal', 1, 2), ('Coronal', 0, 2), ('Axial', 0, 1)]
    for position, (plane, x_col, y_col) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        shapes_scatterplot(shapes[:,x_col], shapes[:,y_col], plane, label)

In [None]:
plt.figure(figsize=(30,10))
plot_shapes(train_shapes, 'train')

In [None]:
plt.figure(figsize=(30,10))
plot_shapes(devin_shapes, 'dev_in')

In [None]:
plt.figure(figsize=(30,10))
# Plot the eval_in shapes (this cell previously re-plotted the dev_in shapes).
plot_shapes(evalin_shapes, 'eval_in')

In [None]:
plt.figure(figsize=(30,10))
# Plot the dev_out shapes (this cell previously re-plotted the dev_in shapes).
plot_shapes(devout_shapes, 'dev_out')

### Header

In [None]:
# Load one image explicitly so the cell does not rely on leftover kernel state;
# the first training image is used as a representative example.
img_nii = nib.load(train_images[0])
print(img_nii.header)

In [180]:
def check_pixdim(paths):
    """Print the distinct ``pixdim`` header rows found among the files in ``paths``."""
    pixdim_rows = [nib.load(img_path).header['pixdim'] for img_path in paths]
    # axis=0 compares whole rows, so each distinct pixdim vector is printed once.
    print(np.unique(pixdim_rows, axis=0))

In [332]:
for paths in [train_images, devin_images, evalin_images, devout_images]:
    check_pixdim(paths)
    print()

[[-1.  1.  1.  1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.]]

[[-1.  1.  1.  1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.]]

[[-1.  1.  1.  1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.]]

[[1. 1. 1. 1. 1. 1. 1. 1.]]



In [183]:
def check_datatype(paths):
    """Print the set of data types reported by the headers of the files in ``paths``."""
    dtype_names = [str(nib.load(img_path).header.get_data_dtype()) for img_path in paths]
    print(set(dtype_names))

In [333]:
for paths in [train_images, devin_images, evalin_images, devout_images]:
    check_datatype(paths)
    print()

{'float32'}

{'float32'}

{'float32'}

{'float32'}



In [184]:
def check_bitpix(paths):
    """Print the set of ``bitpix`` header values found among the files in ``paths``."""
    bitpix_values = [nib.load(img_path).header['bitpix'].item() for img_path in paths]
    print(set(bitpix_values))

In [334]:
for paths in [train_images, devin_images, evalin_images, devout_images]:
    check_bitpix(paths)
    print()

{32}

{32}

{32}

{32}



In [151]:
def check_sform_code(paths):
    """Print the set of ``sform_code`` header values found among the files in ``paths``."""
    sform_codes = [nib.load(img_path).header['sform_code'].item() for img_path in paths]
    print(set(sform_codes))

In [152]:
for paths in [train_images, devin_images, evalin_images, devout_images]:
    check_sform_code(paths)
    print()

{2}

{2}

{2}

{2}



### Intensity

Investigate intensity ranges and intensity histograms for each study.

In [28]:
def min_max_intensity(paths):
    """Print the minimum and maximum voxel value of every image in ``paths``."""
    for img_path in paths:
        voxels = nib.load(img_path).get_fdata()
        lowest, highest = voxels.min(), voxels.max()
        print(lowest, highest)

In [51]:
# for paths in [train_images, devin_images, evalin_images, devout_images]:
#     min_max_intensity(paths)
#     print()

In [11]:
def get_max_intensities(paths):
    """Return a list with the maximum voxel value of every image in ``paths``."""
    return [nib.load(img_path).get_fdata().max() for img_path in paths]

In [16]:
max_intensities = []
for paths in [train_images, devin_images, evalin_images, devout_images]:
    max_intensity = get_max_intensities(paths)
    max_intensities.append(max_intensity)

In [17]:
x = []
for i, set_ in enumerate(['train', 'dev_in', 'eval_in', 'dev_out']):
    x.append([i]*len(max_intensities[i]))
x = [item for sublist in x for item in sublist]
    
max_intensities = [item for sublist in max_intensities for item in sublist]

In [18]:
df = pd.DataFrame({'x': x, 'y': max_intensities})

In [None]:
plt.figure(figsize=(20,9), tight_layout=True)
plt.title('Distribution of max intensity value')
# A single call draws one box per value of `x` (the subset index of each sample);
# repeating it per subset only re-drew the identical plot on top of itself.
sns.boxplot(data=df, x='x', y='y')
plt.xticks([0, 1, 2, 3],['train', 'dev_in', 'eval_in', 'dev_out'])
plt.ylabel('Max intensity value')
plt.xlabel('');

#### Intensity histograms for each study

While calculating the image statistics, zero values - black background voxels - were excluded. We are interested in foreground analysis since background carries no information here. Including zero values in calculating mean and standard deviation would significantly skew the metrics toward 0.

In [85]:
def plot_intensity_hist(paths, title):
    """Plot a log-scaled intensity histogram for every image in ``paths``.

    The histograms are laid out in a two-column grid. Each legend reports the
    mean and standard deviation computed over voxels greater than zero only;
    the histogram itself is drawn from all voxels.

    Args:
        paths: sequence of image file paths.
        title: figure title.
    """
    image_count = len(paths)
    num_rows = (image_count + 1) // 2  # two panels per row, rounded up
    plt.figure(figsize=(25,4*num_rows), tight_layout=True)
    plt.suptitle(title)

    for position, img_path in enumerate(paths, start=1):
        voxels = nib.load(img_path).get_fdata()

        # exclude black background from the statistics
        foreground = voxels > 0
        mean = np.mean(voxels, where=foreground)
        std = np.std(voxels, where=foreground)

        plt.subplot(num_rows, 2, position)
        sns.histplot(voxels.flatten())
        plt.yscale('log')
        plt.title('Image ' + str(position))
        plt.legend([f'mean={int(mean)}, std={int(std)}'], title='Intensity')

In [None]:
plot_intensity_hist(train_images, 'train')

In [None]:
plot_intensity_hist(devin_images, 'dev_in')

In [None]:
plot_intensity_hist(evalin_images, 'eval_in')

In [None]:
plot_intensity_hist(devout_images, 'dev_out')

### Mean 3D image

Judging from pixdim and mean images, we know that images aren't always oriented the same way.

In [263]:
from monai.transforms import Compose, LoadImage, AddChannel, CropForeground, ResizeWithPadOrCrop, Orientation
import torch

In [420]:
def plot_mean_image(paths, subset, orientation=False):
    """Plot every slice of the voxel-wise mean image of a set of volumes.

    Each volume is loaded, given a channel axis, optionally reoriented to RAS,
    cropped to its foreground and then padded/cropped to the largest cropped
    extent found in the set. The stacked volumes are averaged and the mean
    volume is shown slice by slice, in one figure per axis, seven slices per
    row.

    Args:
        paths: sequence of image file paths.
        subset: name of the subset, used in the figure titles.
        orientation: when True, reorient every volume to RAS before cropping.
    """
    transforms_list = [LoadImage(image_only=True), AddChannel()]
    if orientation:
        transforms_list.append(Orientation(axcodes='RAS'))
    transforms_list.append(CropForeground())
    transforms = Compose(transforms_list)

    # First pass: find the largest foreground extent along every spatial axis.
    shapes = []
    for img_path in paths:
        output = transforms(img_path)
        shapes.append(output.shape)

    shapes = np.asarray(shapes)
    # Column 0 is the channel axis; plain ints keep the printed shape readable.
    mean_img_shape = tuple(int(extent) for extent in shapes[:, 1:4].max(axis=0))
    print('Mean image shape:', mean_img_shape)

    # Second pass: bring every volume to the common shape so they can be stacked.
    transforms_list.append(ResizeWithPadOrCrop(spatial_size=mean_img_shape))
    transforms = Compose(transforms_list)

    images = []
    for img_path in paths:
        output = transforms(img_path)
        images.append(output)

    images = torch.stack(images)
    mean_img = torch.mean(images, dim=0)

    slices_per_row = 7
    for ax, (axis, height) in enumerate(zip(['Sagittal', 'Coronal', 'Axial'], [5, 4, 3])):
        mean_img_ = torch.movedim(mean_img[0], ax, 0)
        num_slices = mean_img_.shape[0]
        # Ceiling division: just enough rows to hold all slices.
        num_rows = -(-num_slices // slices_per_row)

        plt.figure(figsize=(28, num_rows*height), layout="constrained")
        plt.suptitle('Mean image - ' + subset + ' - ' + axis)
        for i, sl in enumerate(mean_img_):
            plt.subplot(num_rows, slices_per_row, i+1)
            plt.imshow(sl, cmap='gray')
            plt.axis('off')

In [None]:
plot_mean_image(train_images, 'train')

In [None]:
plot_mean_image(devin_images, 'dev_in')

In [None]:
plot_mean_image(evalin_images, 'eval_in')

In [None]:
plot_mean_image(devout_images, 'dev_out')

Visualize mean images after changing images' orientation to be the same:

In [None]:
plot_mean_image(train_images, 'train', orientation=True)

In [None]:
plot_mean_image(devin_images, 'dev_in', orientation=True)

In [None]:
plot_mean_image(evalin_images, 'eval_in', orientation=True)

In [None]:
plot_mean_image(devout_images, 'dev_out', orientation=True)

## Labels

In [247]:
def check_unique_values(paths):
    """Print the sorted distinct voxel values found across all volumes in ``paths``.

    Args:
        paths: iterable of label file paths.
    """
    lbl_values = [np.unique(nib.load(lbl_path).get_fdata()) for lbl_path in paths]
    # The per-file arrays may differ in length (e.g. a volume that contains a
    # single value), so they are concatenated before the final np.unique
    # rather than passed as a ragged list.
    merged = np.concatenate(lbl_values) if lbl_values else np.array([])
    print(np.unique(merged))

In [336]:
for paths in [train_labels, devin_labels, evalin_labels, devout_labels]:
    check_unique_values(paths)
    print()

[0. 1.]

[0. 1.]

[0. 1.]

[0. 1.]



In [24]:
def check_redundant_values(paths, mask_paths):
    """Report label voxels equal to 1 that lie where the mask equals 0.

    For every label file the matching mask is taken from ``mask_paths`` at the
    same position. Files with at least one such voxel are printed together
    with the number of offending voxels.

    Args:
        paths: sequence of label file paths.
        mask_paths: sequence of mask file paths, parallel to ``paths``.
    """
    for idx, lbl_path in enumerate(paths):
        lbl = nib.load(lbl_path).get_fdata()
        mask = nib.load(mask_paths[idx]).get_fdata()

        outside_mask = (lbl == 1) & (mask == 0)
        redundant = outside_mask.sum()
        if not redundant:
            continue
        print(lbl_path, redundant)

In [25]:
for paths, mask_paths in zip([train_labels, devin_labels, evalin_labels, devout_labels],
                             [train_masks, devin_masks, evalin_masks, devout_masks]):
    check_redundant_values(paths, mask_paths)

../data/shifts_ms_pt1/msseg/train/gt/20_gt_isovox.nii.gz 2
../data/shifts_ms_pt2/ljubljana/dev_out/gt/1_gt_isovox.nii.gz 34
../data/shifts_ms_pt2/ljubljana/dev_out/gt/10_gt_isovox.nii.gz 4
../data/shifts_ms_pt2/ljubljana/dev_out/gt/14_gt_isovox.nii.gz 4
../data/shifts_ms_pt2/ljubljana/dev_out/gt/15_gt_isovox.nii.gz 2
../data/shifts_ms_pt2/ljubljana/dev_out/gt/20_gt_isovox.nii.gz 65
../data/shifts_ms_pt2/ljubljana/dev_out/gt/24_gt_isovox.nii.gz 3


It turns out that some ground truth voxels are outside brain masks.

### Header

In [249]:
lbl_nii = nib.load(train_labels[0])
lbl = lbl_nii.get_fdata()

In [None]:
print(lbl_nii.header)

In [337]:
for paths in [train_labels, devin_labels, evalin_labels, devout_labels]:
    check_datatype(paths)
    print()

{'float32'}

{'float32'}

{'float32'}

{'float32'}



In [338]:
for paths in [train_labels, devin_labels, evalin_labels, devout_labels]:
    check_bitpix(paths)
    print()

{32}

{32}

{32}

{32}



In [339]:
for paths in [train_labels, devin_labels, evalin_labels, devout_labels]:
    check_pixdim(paths)
    print()

[[-1.  1.  1.  1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.]]

[[-1.  1.  1.  1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.]]

[[-1.  1.  1.  1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.]]

[[1. 1. 1. 1. 1. 1. 1. 1.]]



### Lesion volume

In [85]:
def plot_lesion_load(labels_paths, masks_paths, title):
    """Plot the lesion load of every image as two bar charts.

    The left chart shows the sum of the label volume as a percentage of the
    total number of voxels; the right chart shows it as a percentage of the
    sum of the foreground mask.

    Args:
        labels_paths: sequence of label file paths.
        masks_paths: sequence of foreground-mask file paths, parallel to
            ``labels_paths``.
        title: figure title.
    """
    share_of_image, share_of_mask = [], []
    for idx, lbl_path in enumerate(labels_paths):
        lbl = nib.load(lbl_path).get_fdata()
        lesion_sum = np.sum(lbl)
        share_of_image.append(lesion_sum/np.size(lbl)*100)

        fg_mask = nib.load(masks_paths[idx]).get_fdata()
        share_of_mask.append(lesion_sum/np.sum(fg_mask)*100)

    plt.figure(figsize=(25,9), tight_layout=True)
    plt.suptitle(title)

    # (subplot position, values, panel title, y-axis limits)
    panels = [
        (121, share_of_image, 'Lesion volume as percent of the whole image', (0, 1)),
        (122, share_of_mask, 'Lesion volume as percent of the brain area', (0, 5.5)),
    ]
    for position, loads, panel_title, y_limits in panels:
        plt.subplot(position)
        sns.barplot(x=list(range(1,len(loads)+1)), y=loads)
        plt.ylabel('Lesion load [%]')
        plt.xlabel('Image')
        plt.title(panel_title)
        plt.ylim(*y_limits)

Lesions occupy a very small volume in each study: up to around 5% of the brain area and less than 1% of the whole NIfTI image.

In [None]:
plot_lesion_load(train_labels, train_masks, 'train')

In [None]:
plot_lesion_load(devin_labels, devin_masks, 'dev_in')

In [None]:
plot_lesion_load(evalin_labels, evalin_masks, 'eval_in')

In [None]:
plot_lesion_load(devout_labels, devout_masks, 'dev_out')

In [423]:
def plot_lesions_volume(labels_paths, subset, chart_height):
    """Scatter the volume of every individual lesion, one row per image.

    Lesions are the connected components found by ``ndimage.label``; the
    volume of a lesion is the sum of the label values inside it. The
    background (component 0) is not a lesion and is left out.

    NOTE(review): a function with the same name but a different signature is
    defined in a later cell and replaces this one once that cell has run.

    Args:
        labels_paths: sequence of label file paths.
        subset: name of the subset, used in the title.
        chart_height: figure height passed to ``plt.figure``.
    """
    # Imported locally because the notebook-level import sits in a later cell.
    from scipy import ndimage

    plt.figure(figsize=(25,chart_height))
    plt.title('Lesion volumes on each image - ' + subset)

    for i, lbl_path in enumerate(labels_paths):
        lbl = nib.load(lbl_path).get_fdata()

        labeled_seg, num_labels = ndimage.label(lbl)
        # Per-component sum of label values; entry 0 (background) is dropped.
        num_elements_by_lesion = np.bincount(
            labeled_seg.ravel(), weights=lbl.ravel(), minlength=num_labels + 1)[1:]
        sns.scatterplot(y=[i+1]*len(num_elements_by_lesion), x=num_elements_by_lesion, s=100, alpha=0.7)

    # Use the sequence length instead of the loop variable so that an empty
    # input does not raise NameError.
    num_images = len(labels_paths)
    plt.xlabel('Lesion volume [voxels]')
    plt.ylabel('Image')
    plt.legend('', frameon=False)
    plt.ylim(num_images+0.5, 0.5)  # inverted axis: image 1 at the top
    plt.yticks(list(range(1,num_images+1)), list(range(1,num_images+1)))
    plt.xscale('log')

In [153]:
def plot_lesions_volume(labels_paths, subset):
    """Draw one box per image showing the distribution of its lesion volumes.

    Lesions are the connected components found by ``ndimage.label``; the
    volume of a lesion is the sum of the label values inside it. The
    background (component 0) is not a lesion and is left out so that it does
    not enter the distribution as a zero-volume entry.

    Args:
        labels_paths: sequence of label file paths.
        subset: name of the subset, used in the title.
    """
    # Imported locally because the notebook-level import sits in a later cell.
    from scipy import ndimage

    plt.figure(figsize=(25,9))
    plt.title('Distribution of lesion volume on each image - ' + subset)

    volumes = []
    for lbl_path in labels_paths:
        lbl = nib.load(lbl_path).get_fdata()

        labeled_seg, num_labels = ndimage.label(lbl)
        # Per-component sum of label values; entry 0 (background) is dropped.
        num_elements_by_lesion = np.bincount(
            labeled_seg.ravel(), weights=lbl.ravel(), minlength=num_labels + 1)[1:]
        volumes.append(num_elements_by_lesion)

    sns.boxplot(data=volumes)
    plt.ylabel('Lesion volume [voxels]')
    plt.xlabel('Image')
    plt.yscale('log')
    plt.ylim(1, 10e4)
    plt.xticks(range(len(volumes)), list(range(1,len(volumes)+1)))

In [None]:
plot_lesions_volume(train_labels, 'train')

In [None]:
plot_lesions_volume(devin_labels, 'dev_in')

In [None]:
plot_lesions_volume(evalin_labels, 'eval_in')

In [None]:
plot_lesions_volume(devout_labels, 'dev_out')

In [8]:
def lesion_volume(labels_paths):
    """Return, for every label file, the sum of all its voxel values."""
    return [np.sum(nib.load(lbl_path).get_fdata()) for lbl_path in labels_paths]

In [9]:
train_lesion_vol = lesion_volume(train_labels)
devin_lesion_vol = lesion_volume(devin_labels)
evalin_lesion_vol = lesion_volume(evalin_labels)
devout_lesion_vol = lesion_volume(devout_labels)

In [None]:
plt.figure(figsize=(20,9))
plt.title('Distribution of lesion volume per image')
sns.boxplot(data=[train_lesion_vol, devin_lesion_vol, evalin_lesion_vol, devout_lesion_vol])
plt.xticks(range(4), ['train', 'dev_in', 'eval_in', 'dev_out'])
plt.ylabel('Lesion volume [voxels]');

In [None]:
plt.figure(figsize=(15,7))
sns.histplot(data=[train_lesion_vol, devin_lesion_vol, evalin_lesion_vol, devout_lesion_vol], bins=150)
plt.xlabel('Lesion volume')
plt.ylabel('Lesion count')
plt.legend(['train', 'dev_in', 'eval_in', 'dev_out']);

### Number of lesions

In [8]:
from scipy import ndimage

In [9]:
def count_number_of_lesions(labels_paths):
    """Count the lesions (connected components) in every label volume.

    Args:
        labels_paths: iterable of label file paths.

    Returns:
        List with the number of connected components found by
        ``ndimage.label`` in each volume.
    """
    lesion_count = []
    for lbl_path in labels_paths:
        lbl = nib.load(lbl_path).get_fdata()

        # ndimage.label already returns the number of components; counting the
        # distinct values of the labelled array would also count background 0.
        _, num_labels = ndimage.label(lbl)
        lesion_count.append(num_labels)

    return lesion_count

def plot_number_of_lesions(lesion_counts, subset):
    """Draw a bar chart of ``lesion_counts`` with images numbered from 1.

    Args:
        lesion_counts: sequence with one count per image.
        subset: name of the subset, used in the title.
    """
    plt.figure(figsize=(25,7))
    plt.title('Number of lesions per image - ' + subset)
    image_numbers = list(range(1, len(lesion_counts)+1))
    axes = sns.barplot(x=image_numbers, y=lesion_counts)
    axes.set_xlabel('Image')
    axes.set_ylabel('Number of lesions')

In [None]:
train_lesion_count = count_number_of_lesions(train_labels)
plot_number_of_lesions(train_lesion_count, 'train')

In [None]:
devin_lesion_count = count_number_of_lesions(devin_labels)
plot_number_of_lesions(devin_lesion_count, 'dev_in')

In [None]:
evalin_lesion_count = count_number_of_lesions(evalin_labels)
plot_number_of_lesions(evalin_lesion_count, 'eval_in')

In [None]:
devout_lesion_count = count_number_of_lesions(devout_labels)
plot_number_of_lesions(devout_lesion_count, 'dev_out')

In [None]:
plt.figure(figsize=(20,9))
plt.title('Distribution of lesion count per image')
sns.boxplot(data=[train_lesion_count, devin_lesion_count, evalin_lesion_count, devout_lesion_count])
plt.xticks(range(4), ['train', 'dev_in', 'eval_in', 'dev_out']);

### Heatmaps

In [229]:
from monai.transforms import Compose, LoadImaged, AddChanneld, CropForegroundd, ResizeWithPadOrCropd, Orientationd

In [259]:
def plot_label_heatmap(img_paths, lbl_paths, subset):
    """Plot where labelled voxels accumulate across a set of volumes.

    Every image/label pair is loaded, given a channel axis, reoriented to RAS,
    cropped to the image foreground and padded/cropped to the largest cropped
    extent in the set. The label volumes are then summed over all pairs, and
    the sum is projected (summed) along each spatial axis to give three 2-D
    heatmaps.

    Reorientation is applied before the crop so that all volumes share the
    same axis order when the common shape is computed and when the labels are
    stacked.

    Args:
        img_paths: sequence of image file paths.
        lbl_paths: sequence of label file paths, parallel to ``img_paths``.
        subset: name of the subset, used in the figure title.
    """
    keys = ["image", "label"]
    transforms_list = [
        LoadImaged(keys=keys),
        AddChanneld(keys=keys),
        Orientationd(keys=keys, axcodes='RAS'),
        CropForegroundd(keys=keys, source_key='image'),
    ]
    transforms = Compose(transforms_list)

    # First pass: find the largest foreground extent along every spatial axis.
    shapes = []
    for img_path, lbl_path in zip(img_paths, lbl_paths):
        output = transforms({'image': img_path, 'label': lbl_path})
        shapes.append(output['image'].shape)

    shapes = np.asarray(shapes)
    # Column 0 is the channel axis.
    mean_img_shape = (np.max(shapes[:,1]), np.max(shapes[:,2]), np.max(shapes[:,3]))

    # Second pass: bring every pair to the common shape so labels can be stacked.
    transforms_list.append(ResizeWithPadOrCropd(keys=keys, spatial_size=mean_img_shape))
    transforms = Compose(transforms_list)

    labels = []
    for img_path, lbl_path in zip(img_paths, lbl_paths):
        output = transforms({'image': img_path, 'label': lbl_path})
        labels.append(output['label'])

    labels = torch.stack(labels)
    # Sum over the sample and channel axes, leaving a 3-D count volume.
    sum_labels = torch.sum(labels, dim=(0,1))

    plt.figure(figsize=(30,12), layout="constrained")
    plt.suptitle('Labels heatmap - ' + subset)
    for ax, axis in enumerate(['Sagittal', 'Coronal', 'Axial']):

        labels_ = torch.movedim(sum_labels, ax, 0)

        plt.subplot(1,3,ax+1)
        # Project along the chosen axis.
        plt.imshow(torch.sum(labels_, dim=0), cmap='turbo')
        plt.axis('off')
        plt.title(axis)

In [None]:
plot_label_heatmap(train_images, train_labels, 'train')

In [None]:
plot_label_heatmap(devin_images, devin_labels, 'dev_in')

In [None]:
plot_label_heatmap(evalin_images, evalin_labels, 'eval_in')

In [None]:
plot_label_heatmap(devout_images, devout_labels, 'dev_out')