# Visualize preprocessing and data augmentation transformations

In [None]:
# manage imports
from torchvision import transforms
import torchio as tio
import sys
import os
import matplotlib.pyplot as plt
import numpy as np

# Put the parent of the current working directory at the front of sys.path so
# that the `dataset` package imported below can be resolved.
# NOTE(review): this depends on the directory the notebook kernel is started
# from; presumably that parent is the ProstateCancer src directory - confirm.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.insert(0, parent_dir)
from dataset.Dataset import ProstateDataset, TransformedDataset, OneSliceDataset, TumorOnlyDataset
from dataset.PICAITumor3DMultimodality import PICAITumor3DMultimodality

### Data Preprocessing Transformations
Data preprocessing transformations are applied to both the training and the test sets.

In [None]:
# Z-score intensity normalization from torchio. The lambda is passed as the
# masking method; presumably only voxels with intensity > 0 are then used to
# estimate the mean and standard deviation - confirm against the torchio docs.
transform = tio.ZNormalization(masking_method=lambda x: x > 0)

# Choose Dataset and Slice

Run only one of the following cells (depending on which dataset you want to use)

In [None]:
# PRIVATE_PROSTATE
# `modalities` is first handed to the dataset and afterwards extended with a
# "+mask" label; the string is later split on "+" to title the plotted
# channels (presumably the segmentation adds one extra channel - confirm).
modalities = "t2w+adc+pet"
dataset = ProstateDataset(
    root_dir="../../data",
    modalities=modalities,
    modality_transform=transform,
    include_pca_segmentations=True,
)
modalities += "+mask"
img = dataset[30]["image"]
print("Dataset Dimensions:", img.shape)

In [None]:
# PICAI_3_SLICE
# Channel labels used for the plot titles, followed by the dataset itself
# and one example volume taken from it.
modalities = "t2w+adc+diff"
dataset = PICAITumor3DMultimodality(
    root_dir="../../data",
    version="NNUNet_Lesion/Picai_AI_Extracted_3_Slice_Numpy",
    modality_transform=transform,
)
img = dataset[29]["image"]
print("Dataset Dimensions:", img.shape)

In [None]:
# PICAI_ONLY_TUMOR
# NOTE(review): TumorPicaiDataset is not among the names imported in the
# import cell of this notebook, so this cell raises NameError unless the
# class is brought into scope somewhere else - confirm and add the import.
# NOTE(review): the other PICAI cell labels its channels "t2w+adc+diff";
# "pet" here may be a copy-paste leftover - confirm. In the visible cells the
# string is only read to build the plot titles.
modalities = "t2w+adc+pet"
dataset = TumorPicaiDataset(root_dir="../../data", modality_transform=transform)
img = dataset[30]["image"]
print("Dataset Dimensions:", img.shape)

In [None]:
# PRIVATE_ONLY_TUMOR
# Channel labels used for the plot titles, followed by the dataset itself
# and its first sample as the example volume.
modalities = "t2w+adc+pet+mask"
dataset = TumorOnlyDataset(
    root_dir="../../data",
    modality_transform=transform,
)
img = dataset[0]["image"]
print("Dataset Dimensions:", img.shape)

# Visualize slice

In [None]:
# Sizes of the selected volume; both names are reused by later cells.
# The indexing below treats img as (channel, slice, row, column).
num_channels = img.shape[0]
slices = img.shape[1]


def plot_slice(slice):
    """Show every channel of the global ``img`` at one slice index, side by side.

    The parameter is called ``slice`` (shadowing the builtin) on purpose: the
    interactive slider cell passes it by keyword, so the name must not change.

    Args:
        slice: Index along the second axis of ``img`` to display.
    """
    # Read the channel count from the current image so the plot stays correct
    # even if another dataset cell was re-run after this cell.
    n_channels = img.shape[0]
    # Split the label string once instead of once per channel.
    names = modalities.split("+")

    plt.figure(figsize=(15, 5))

    for channel in range(n_channels):
        plt.subplot(1, n_channels, channel + 1)  # Rows, columns, index
        plt.imshow(img[channel, slice, :, :], cmap=plt.cm.Greys_r)
        # Fall back to a generic label when the image has more channels than
        # `modalities` names, instead of failing with an IndexError.
        if channel < len(names):
            label = names[channel].upper()
        else:
            label = f"CHANNEL {channel}"
        plt.title(f'{label} input')

    plt.show()

In [None]:
# Interactive slider: plot_slice is called with the selected index as its
# `slice` keyword argument. The (min, max) tuple covers every valid slice
# index, 0 .. slices - 1, where `slices` was taken from img.shape[1].
from ipywidgets import interact

interact(plot_slice, slice=(0, slices - 1))

In [None]:
# Plot intensity values

# img.shape[0] is the channel axis (plot_slice indexes img[channel, slice, :, :]),
# so this draws one histogram per channel, not per slice.
for channel in range(img.shape[0]):
    print("Intensity distribution for channel", channel)
    # Flatten the whole channel volume into a 1-D run of voxel intensities.
    intensity_values = img[channel].flatten()

    # Use the array's own min() reduction: the builtin min() walks the
    # flattened volume element by element in Python, which is very slow.
    min_int = intensity_values.min()
    # Drop every voxel equal to the minimum (presumably the background) so it
    # does not dominate the histogram.
    intensity_values_filtered = intensity_values[intensity_values != min_int]

    plt.hist(intensity_values_filtered, bins=200, color='black')
    plt.title('Intensity Distribution')
    plt.xlabel('Intensity')
    plt.ylabel('Frequency')
    plt.show()

# 3D Data Augmentation Transformations
Data augmentation transformations are only applied to the training set

In [None]:
# Select the transformations that you'd like to visualize
# (comment a line in or out to toggle that augmentation).

train_transform = tio.Compose([
    # NOTE(review): this notebook indexes img as [channel, slice, :, :], so
    # spatial axis 0 looks like the slice axis rather than left-right - confirm
    # that the flip below really is a horizontal flip for this data layout.
    tio.RandomFlip(axes=(0,)),  # equivalent to horizontal flip; axes can be adjusted for 3D
    #tio.RandomAffine(scales=(0.9, 1.1), degrees=15),  # for rotation and scaling
    tio.RandomElasticDeformation(max_displacement=0.2),
    tio.RandomGamma(log_gamma=(-0.3,0.3)), # contrast
    tio.RandomBlur(std=(0, 0.05)),  # for Gaussian blurring
    #tio.RandomNoise(mean=0, std=(0, 0.03)), # for Gaussian noise
    #tio.RandomBiasField(coefficients=(0, 0.2))
])

In [None]:
# show transformed images

def plot_transformed_slice(slice):
    """Plot four independently augmented versions of ``img`` at one slice index.

    Each grid row is a fresh application of ``train_transform`` to the global
    ``img``; each column is one channel of that augmented volume.

    Args:
        slice: Index along the second axis of the volume to display. The name
            shadows the builtin because the interact call passes it by keyword.
    """
    n_rows = 4
    fig, axes = plt.subplots(n_rows, num_channels, figsize=(12, 12))
    flat_axes = axes.ravel()

    for row in range(n_rows):
        # One new random augmentation per row.
        augmented = train_transform(img)
        for channel in range(num_channels):
            # Row-major position of this (row, channel) cell in the flat grid.
            ax = flat_axes[row * num_channels + channel]
            ax.imshow(augmented[channel, slice, :, :], cmap=plt.cm.Greys_r)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

interact(plot_transformed_slice, slice=(0, slices - 1))