# EDA: Exploratory Data Analysis
#### Initial General Notes: The classes in the data are imbalanced, because the abdominal images contain multiple organs other than the pancreas. There is no missing data, as every image has exactly one corresponding label (a 1:1 image-to-label ratio).

## Pixel Intensity

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def _subsample_pixels(pixels, max_samples):
    """Return ``pixels`` capped at ``max_samples`` values, drawn without replacement."""
    if len(pixels) > max_samples:
        return np.random.choice(pixels, max_samples, replace=False)
    return pixels


def plot_intensity_distribution(dataloader, num_batches=1, samples_per_volume=100000,
                                *, padding_threshold=0):
    """Plot intensity histograms for whole volumes and for labeled pixels.

    The first ``num_batches`` batches of ``dataloader`` are read. Each batch is
    squeezed, converted to NumPy and iterated along its first remaining axis.
    For every element of that axis, the pixels above ``padding_threshold`` and
    the pixels whose label is positive are collected (randomly subsampled when
    there are too many), and two side-by-side histograms are shown.

    Args:
        dataloader: Iterable yielding ``(volume, label)`` pairs of objects that
            support ``.squeeze().numpy()`` (e.g. CPU torch tensors).
        num_batches: Number of batches to read from ``dataloader``.
        samples_per_volume: Upper bound on the number of pixels kept per
            iterated element. NOTE: the cap is applied inside the inner loop,
            i.e. per element of the first axis after squeezing, not once per
            whole batch.
        padding_threshold: Pixels with intensity ``<= padding_threshold`` are
            treated as padding and left out of the volume histogram.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Function-scope import keeps this block self-contained.
    from itertools import islice

    volume_chunks = []
    label_chunks = []

    # islice stops *before* fetching batch number ``num_batches``; the previous
    # enumerate/break pattern loaded one extra batch only to discard it.
    for volume, label in islice(dataloader, max(num_batches, 0)):

        # Drop all singleton dimensions (presumably batch and channel when
        # batch_size == 1 -- confirm against the dataset's output shape).
        volume = volume.squeeze().numpy()
        label = label.squeeze().numpy()

        # Iterate along the first remaining axis of the squeezed arrays
        for vol_slice, lbl_slice in zip(volume, label):

            # Filter out padding values
            non_padded_pixels = vol_slice[vol_slice > padding_threshold]
            # Intensities of the pixels that carry a positive label
            labeled_pixels = vol_slice[lbl_slice > 0]

            # Randomly subsample to bound memory; arrays are kept as arrays
            # instead of being unpacked element by element into a Python list.
            volume_chunks.append(_subsample_pixels(non_padded_pixels, samples_per_volume))
            label_chunks.append(_subsample_pixels(labeled_pixels, samples_per_volume))

    # np.concatenate rejects an empty sequence, so fall back to an empty array
    # when no batches were read.
    volume_intensities = np.concatenate(volume_chunks) if volume_chunks else np.empty(0)
    label_intensities = np.concatenate(label_chunks) if label_chunks else np.empty(0)

    # Plot the histograms of the sampled intensities
    fig, axs = plt.subplots(1, 2, figsize=(15, 6))

    axs[0].hist(volume_intensities, bins=256, color='skyblue', alpha=0.75)  # Adjust bins if needed
    axs[0].set_title('Pixel Intensity Distribution of Volumes')
    axs[0].set_xlabel('Pixel Intensity')
    axs[0].set_ylabel('Frequency')
    axs[0].grid(True)

    axs[1].hist(label_intensities, bins=30, color='lightcoral', alpha=0.75)  # Adjust bins if needed
    axs[1].set_title('Pixel Intensity Distribution of Labeled Pancreas')
    axs[1].set_xlabel('Pixel Intensity')
    axs[1].set_ylabel('Frequency')
    axs[1].grid(True)

    plt.tight_layout()
    plt.show()

## Sample Splitting

In [None]:
import torch
from torch.utils.data import DataLoader, random_split

def create_data_loaders(dataset, batch_size=1, *, seed=832, train_frac=0.6, val_frac=0.2):
    """Split ``dataset`` into train/validation/test subsets and wrap each in a DataLoader.

    The split is random but reproducible: it is driven by a dedicated
    ``torch.Generator`` seeded with ``seed``. The train and validation sizes
    are ``int(frac * len(dataset))``; the test split receives every remaining
    sample, so the three sizes always add up to ``len(dataset)``.

    Args:
        dataset: Any object accepted by ``torch.utils.data.random_split``
            (it must support ``len()`` and integer indexing).
        batch_size: Batch size used for all three loaders.
        seed: Seed for the generator that drives the random split.
        train_frac: Fraction of the samples assigned to the training split.
        val_frac: Fraction of the samples assigned to the validation split.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles its samples.

    Raises:
        ValueError: If a fraction is negative or the two fractions sum to
            more than 1.
    """
    # Small tolerance so that pairs such as 0.58 + 0.42 are not rejected
    # because of floating-point rounding.
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1 + 1e-9:
        raise ValueError(
            "train_frac and val_frac must be non-negative and sum to at most 1"
        )

    # Determine sizes of each split; the test split absorbs the rounding remainder
    total_size = len(dataset)
    train_size = int(train_frac * total_size)
    val_size = int(val_frac * total_size)
    test_size = total_size - train_size - val_size

    # Randomly split the dataset into train, validation, and test sets
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(seed)  # Fixed seed -> reproducible split
    )

    # Create data loaders for each split (evaluation loaders keep a fixed order)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

# Build the dataset from the image and label directories, then derive the
# train / validation / test loaders from it with a batch size of 1.
dataset = MedicalImageDataset(
    image_root_dir,
    label_root_dir,
    transforms=transforms,
)
train_loader, val_loader, test_loader = create_data_loaders(
    dataset=dataset,
    batch_size=1,
)