In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
import os
import glob
%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from matplotlib.animation import FuncAnimation

In [3]:
import nibabel as nib
from celluloid import Camera
from IPython.display import HTML
from tqdm.notebook import tqdm
import imgaug
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import imgaug.augmenters as iaa

In [4]:
# Creating Dataset Class
class Dataset(torch.utils.data.Dataset):
    """Slice-wise segmentation dataset backed by pre-saved ``.npy`` files.

    Expected on-disk layout (as read by ``extract_files`` and
    ``change_img_to_label_path``)::

        <root>/<subject>/data/<slice>.npy    # image slice
        <root>/<subject>/masks/<slice>.npy   # matching label mask

    Args:
        root: Split directory (e.g. ``Preprocessed/LGG/train``); ``str`` or ``Path``.
        augment_params: imgaug augmenter applied jointly to slice and mask,
            or ``None`` to disable augmentation.
    """

    def __init__(self, root, augment_params):
        self.all_files = self.extract_files(root)
        self.augment_params = augment_params

    @staticmethod
    def extract_files(root):
        """Return the paths of all ``<root>/<subject>/data/*.npy`` slices.

        Subjects and slices are sorted so that the index -> file mapping is the
        same on every machine and run (raw ``glob`` order is filesystem-dependent).
        A missing ``root`` yields an empty list.
        """
        root = Path(root)
        files = []
        for subject in sorted(root.glob("*")):
            files.extend(sorted((subject / "data").glob("*.npy")))
        return files

    @staticmethod
    def change_img_to_label_path(path):
        """Map ``.../<subject>/data/<slice>.npy`` to ``.../<subject>/masks/<slice>.npy``.

        Only the *last* ``data`` component is replaced, so a ``data`` directory
        higher up in the path (e.g. ``/home/me/data/...``) is left untouched.

        Raises:
            ValueError: if ``path`` contains no ``data`` component.
        """
        parts = list(Path(path).parts)
        last_data = len(parts) - 1 - parts[::-1].index("data")
        parts[last_data] = "masks"
        return Path(*parts)

    def augment(self, slice, mask):
        """Apply ``self.augment_params`` to a slice and its mask in a single call.

        Passing image and segmentation map together makes imgaug draw one set of
        random parameters for both, so they stay spatially aligned.

        Returns:
            ``(augmented_slice, augmented_mask_array)``.
        """
        # Reseed imgaug's global RNG from torch's RNG. Presumably this ties the
        # augmentation randomness to torch seeding (e.g. distinct per-DataLoader-
        # worker seeds) -- TODO confirm intent.
        random_seed = torch.randint(0, 1000000, (1,)).item()
        imgaug.seed(random_seed)

        # NOTE(review): imgaug expects an integer/bool label array; confirm the
        # dtype of the saved masks. `shape` is the shape of the *image*.
        segmap = SegmentationMapsOnImage(mask, shape=slice.shape)
        slice_aug, segmap_aug = self.augment_params(image=slice, segmentation_maps=segmap)
        return slice_aug, segmap_aug.get_arr()

    def __len__(self):
        """Number of slices in the dataset."""
        return len(self.all_files)

    def __getitem__(self, idx):
        """Load slice ``idx`` and its mask, augment if configured, add a channel axis.

        Returns:
            ``(image, mask)``, each with a leading length-1 channel axis; the image
            is loaded as float32.
        """
        image_path = self.all_files[idx]
        mask_path = self.change_img_to_label_path(image_path)
        image = np.load(image_path).astype(np.float32)  # float32 for torch
        mask = np.load(mask_path)

        if self.augment_params is not None:
            image, mask = self.augment(image, mask)

        return np.expand_dims(image, 0), np.expand_dims(mask, 0)

In [5]:
# Training-time augmentation pipeline, applied jointly to each slice and its mask.
seq = iaa.Sequential(
    children=[
        # Random rotation of up to +/-45 degrees and random zoom between 85% and 115%.
        iaa.Affine(rotate=(-45, 45), scale=(0.85, 1.15)),
        # Random elastic deformation with imgaug's default alpha/sigma.
        # NOTE(review): those defaults differ between imgaug releases; consider pinning them.
        iaa.ElasticTransformation(),
    ]
)

In [6]:
# Create the dataset objects.
# Training split gets augmentation; the held-out split is used as-is.
# NOTE(review): confirm the held-out directory on disk is named "test" (not "val").
train_path = Path("Preprocessed/LGG/train")
test_path = Path("Preprocessed/LGG/test")
# Keep the original LGG_* names for any later cells that refer to them.
LGG_train, LGG_test = train_path, test_path

# Fail fast: glob() on a missing directory would silently give an empty dataset.
assert train_path.is_dir(), f"missing training data directory: {train_path}"
assert test_path.is_dir(), f"missing test data directory: {test_path}"

train_dataset = Dataset(train_path, seq)
test_dataset = Dataset(test_path, None)