# Visualize preprocessing and data augmentation transformations

In [None]:
# manage imports
from torchvision import transforms
import torchio as tio
import sys
import os
import matplotlib.pyplot as plt
import numpy as np

# add ProstateCancer src directory to sys.path and import dataset
# The directory one level above the current working directory is prepended to
# sys.path so that the project-local `dataset` package imported below resolves;
# this assumes the notebook is launched from a subdirectory of that src dir.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.insert(0, parent_dir)
from dataset.Dataset import OneSliceDataset, TranformedMaskedDataset
from dataset.PICAITumor3DMultimodality import PICAI3DMultimodality

### Data Preprocessing Transformations
Data preprocessing transformations are applied to both the training and the test sets.

In [None]:
# Rescale intensities to [0, 1]: the 0th percentile (the minimum) maps to 0 and
# the 99.5th percentile maps to 1, so a handful of very bright voxels do not
# dominate the output range.
transform = tio.RescaleIntensity(
    percentiles=(0, 99.5),
    out_min_max=(0, 1),
)

# Visualize One Slice Dataset

Run only one of the following two cells, depending on which dataset you want to use.

In [None]:
# PRIVATE_1_SLICE
# "+"-separated channel names, in the order the channels appear in the image.
modalities = "t2w+adc+pet+mask"
img_id = 3  # sample shown raw here and augmented further below
dataset = OneSliceDataset(
    root_dir="../../data",
    modality_transform=transform,
)
sample = dataset[img_id]
img = sample["image"]
print("Dataset Dimensions:", img.shape)

In [None]:
# PICAI_1_SLICE
# "+"-separated channel names, in the order the channels appear in the image.
modalities = "t2w+adc+dwi+mask"
# Single source of truth for the sample index: it is used both for the raw plot
# (via `img`) and for the augmented plot (via `transformed_dataset[img_id]`),
# so the two figures must refer to the same sample.
img_id = 20
dataset = PICAI3DMultimodality(
    root_dir="../../data",
    version="NNUNet_Lesion/Picai_Extracted_1_Slice_Numpy",
    include_mask=True,
    modality_transform=transform,
)
img = dataset[img_id]["image"]
print("Dataset Dimensions:", img.shape)

In [None]:
num_channels = img.shape[0]


def plot_slice(image=None, modality_string=None):
    """Show every channel of one single-slice sample side by side in one row.

    Args:
        image: Sample to plot, indexed as ``image[channel, 0, :, :]``
            (assumed layout (C, 1, H, W) -- confirm against the dataset).
            Defaults to the notebook-level ``img``.
        modality_string: ``"+"``-separated channel names used for the
            subplot titles. Defaults to the notebook-level ``modalities``.
    """
    if image is None:
        image = img
    if modality_string is None:
        modality_string = modalities

    channel_names = modality_string.split("+")
    n_channels = image.shape[0]

    # squeeze=False keeps `axes` 2D even when there is a single channel.
    _, axes = plt.subplots(1, n_channels, figsize=(15, 5), squeeze=False)
    for channel, ax in enumerate(axes[0]):
        ax.imshow(image[channel, 0, :, :], cmap=plt.cm.Greys_r)
        # Fall back to a generic label instead of raising IndexError when
        # there are more channels than names.
        if channel < len(channel_names):
            label = channel_names[channel].upper()
        else:
            label = f"CHANNEL {channel}"
        ax.set_title(f'{label} input')

    plt.show()


plot_slice()

In [None]:
# Plot the intensity histogram of each channel, excluding the channel's minimum
# value (0 after the RescaleIntensity preprocessing above).
channel_names = modalities.split("+")

for v in range(img.shape[0]):
    if v < len(channel_names):
        channel_name = channel_names[v].upper()
    else:
        channel_name = f"channel {v}"
    print("Intensity distribution for channel", v)
    intensity_values = img[v].flatten()

    # `.min()` reduces in native code; the builtin min() would iterate the
    # tensor/array element by element at Python speed.
    min_int = intensity_values.min()
    intensity_values = intensity_values[intensity_values != min_int]  # comment this line out if you want to include the minimum (zero) values

    plt.hist(intensity_values, bins=200, color='black')
    plt.title(f'Intensity Distribution ({channel_name})')
    plt.xlabel('Intensity')
    plt.ylabel('Frequency')
    plt.show()

# Data Augmentation Transformations
Data augmentation transformations are applied only to the training set.

In [None]:
# For 2D (single-slice) datasets.
# NOTE(review): samples in this notebook are indexed as image[channel, 0, :, :],
# so torchio's first spatial axis (axis 0) is the singleton slice axis, and the
# in-plane axes are 1 (rows) and 2 (columns, i.e. horizontal in the plots).
# The rotation below is likewise about axis 0. Confirm that
# TranformedMaskedDataset passes the tensor to torchio in this layout.

# Spatial transforms: applied to all channels, mask included.
displacement_transform = tio.Compose([
    # Horizontal (left-right) flip. Axis 0 has size 1, so flipping it would be a no-op.
    tio.RandomFlip(axes=(2,)),
    # Scaling plus in-plane rotation of up to 7 degrees about the slice axis (axis 0).
    tio.RandomAffine(scales=(0.9, 1.1), degrees=(-7, 7, 0, 0, 0, 0)),
])

# Intensity transforms: applied to the non-mask channels only.
non_masked_transform = tio.Compose([
    tio.RandomGamma(log_gamma=(-0.3, 0.3)),  # contrast
    tio.RandomBlur(std=(0, 0.05)),  # Gaussian blurring
    tio.RandomNoise(mean=0, std=(0, 0.05)),  # Gaussian noise
    #tio.RandomBiasField(coefficients=(0, 0.1))
])

transformed_dataset = TranformedMaskedDataset(dataset, displacement_transform, non_masked_transform)

In [None]:
# show transformed images


def plot_transformed_slice(n_samples=4):
    """Plot ``n_samples`` draws of augmented sample ``img_id``, one draw per row.

    Each row re-reads ``transformed_dataset[img_id]`` (presumably a fresh
    random augmentation on every access -- TODO confirm in
    TranformedMaskedDataset); each column is one channel, titled with its
    modality name on the top row.

    Args:
        n_samples: Number of rows (augmentation draws) to plot.
    """
    channel_names = modalities.split("+")
    # squeeze=False keeps `axes` a 2D (rows, channels) grid for any shape.
    _, axes = plt.subplots(n_samples, num_channels, figsize=(12, 12), squeeze=False)

    for row_axes in axes:
        timg = transformed_dataset[img_id]["image"]
        for channel, ax in enumerate(row_axes):
            ax.imshow(timg[channel, 0, :, :], cmap=plt.cm.Greys_r)
            ax.axis('off')

    # Label each column once, on the top row.
    for channel, ax in enumerate(axes[0]):
        if channel < len(channel_names):
            ax.set_title(channel_names[channel].upper())

    plt.tight_layout()
    plt.show()


plot_transformed_slice()