In [None]:
import time
from pathlib import Path
from datetime import datetime

import torch
from torch.utils.data import random_split, DataLoader
import monai
import gdown
import numpy as np
import pandas as pd
import torchio as tio
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import seaborn as sns

sns.set()
plt.rcParams['figure.figsize'] = 12, 8
monai.utils.set_determinism()

print('Last run on', time.ctime())

%load_ext tensorboard

First, we define the data module.
The paths to the image files, together with their metadata, must already be loaded into the pandas DataFrame `df` before running the cells below.

In [None]:
class PiCaiDataModule(pl.LightningDataModule):
    """Lightning data module for the PI-CAI prostate MRI data.

    Builds one ``tio.Subject`` per row of ``df``, splits the subjects
    randomly 70/15/15 into train/validation/test sets and wraps each split
    in a ``tio.SubjectsDataset``. All splits get the preprocessing
    transform; the training split additionally gets augmentation.

    Args:
        batch_size: Batch size used by all three data loaders.
        df: pandas DataFrame with one row per study; the columns read are
            those used in ``getSubjectDataFromDataFrame``.
        task: Optional task identifier, stored as ``self.task``. It is not
            read anywhere in this class.
    """

    def __init__(self, batch_size, df, task=None):
        super().__init__()
        # Previously assigned from an undefined name `task` (NameError);
        # now an optional argument so existing calls keep working.
        self.task = task
        self.batch_size = batch_size
        self.df = df

        # Filled in by prepare_data().
        self.subjects = None
        self.train_subjects = None
        self.val_subjects = None
        self.test_subjects = None

        # Filled in by setup().
        self.preprocess = None
        self.transform = None
        self.train_set = None
        self.val_set = None
        self.test_set = None

    def get_max_shape(self, subjects):
        """Return the element-wise maximum spatial shape over ``subjects``.

        Iterating the dataset loads every image of every subject, so this is
        not cheap. NOTE(review): torchio's ``Subject.spatial_shape`` expects
        all images of one subject to share a spatial shape; confirm that the
        adc/hbv/cor/sag images are resampled to the t2w grid.
        """
        dataset = tio.SubjectsDataset(subjects)
        shapes = np.array([s.spatial_shape for s in dataset])
        return shapes.max(axis=0)

    @staticmethod
    def getSubjectDataFromDataFrame(row):
        """Build a ``tio.Subject`` from one DataFrame row.

        Declared as a static method: it does not use the instance, and the
        original definition lacked ``self``, so ``self.getSubjectDataFromDataFrame(row)``
        raised a TypeError.

        Args:
            row: Mapping (e.g. one dict from ``df.to_dict('records')``) with
                the image-path and metadata keys read below.

        Returns:
            tio.Subject whose MRI sequences are ``ScalarImage``s, whose
            resampled label is a ``LabelMap`` and whose clinical metadata are
            plain (non-image) attributes.
        """
        subject = tio.Subject(
            # MRI images (file paths)
            adc=tio.ScalarImage(row['adc']),
            cor=tio.ScalarImage(row['cor']),
            hbv=tio.ScalarImage(row['hbv']),
            sag=tio.ScalarImage(row['sag']),
            t2w=tio.ScalarImage(row['t2w']),
            # NOTE(review): the name suggests a flag rather than an image
            # path; confirm this column really holds a path to an image.
            anythingInMask=tio.ScalarImage(row['anythingInMask']),

            # Metadata from the CSV, stored as plain attributes. Wrapping
            # them in tio.ScalarImage made torchio treat each value as an
            # image file path.
            patient_id=row['patient_id'],
            study_id=row['study_id'],
            patient_age=row['patient_age'],
            psa=row['psa'],
            psad=row['psad'],
            prostate_volume=row['prostate_volume'],
            histopath_type=row['histopath_type'],
            lesion_GS=row['lesion_GS'],

            # Labels resampled to t2w
            label=tio.LabelMap(row['reSampledPath']),
            # TODO(review): constant placeholder; every subject is 'negative'.
            diagnosis='negative',
        )
        return subject

    def prepare_data(self):
        """Create one subject per DataFrame row and split them 70/15/15.

        The split is drawn only once. Later calls return early (e.g. when a
        Lightning ``Trainer`` calls ``prepare_data`` again after it was run
        manually), so subjects cannot be re-shuffled between the training
        and test sets.

        NOTE(review): Lightning runs ``prepare_data`` on a single process and
        recommends not assigning state here. For multi-process training, the
        subject creation and split should move into ``setup``.
        """
        if self.subjects is not None:
            return

        # orient='records' yields one dict per row; the default orient
        # ('dict') yields one dict per *column*.
        rows = self.df.to_dict('records')
        # A list (not a lazy `map`) so that len() and indexing work.
        self.subjects = [self.getSubjectDataFromDataFrame(row) for row in rows]

        # Random split into train / validation / test. Fractional lengths
        # need torch >= 1.13. The global torch RNG is used; it is seeded by
        # monai.utils.set_determinism() at the top of the notebook.
        train_set, valid_set, test_set = torch.utils.data.random_split(
            self.subjects, [0.7, 0.15, 0.15]
        )

        # Materialize the Subset views as plain lists of subjects.
        self.train_subjects = list(train_set)
        self.val_subjects = list(valid_set)
        self.test_subjects = list(test_set)

        print('='*30)
        print('Train data set:', len(train_set))
        print('Test data set:', len(test_set))
        print('Valid data set:', len(valid_set))

    def get_preprocessing_transform(self):
        """Return the deterministic preprocessing applied to every split.

        Rescales intensities to [-1, 1], crops/pads every subject to the
        largest spatial shape among all subjects, makes each spatial
        dimension a multiple of 8 (for the U-Net) and one-hot encodes label
        maps.
        """
        preprocess = tio.Compose([
            tio.RescaleIntensity((-1, 1)),
            tio.CropOrPad(self.get_max_shape(self.subjects)),
            tio.EnsureShapeMultiple(8),  # for the U-Net
            tio.OneHot(),
        ])
        return preprocess

    def get_augmentation_transform(self):
        """Return the random augmentation applied to the training split only."""
        augment = tio.Compose([
            tio.RandomAffine(),
            tio.RandomMotion(p=0.1),
            tio.RandomBiasField(p=0.25),
        ])
        return augment

    def setup(self, stage=None):
        """Build the transforms and the three ``tio.SubjectsDataset``s.

        Requires ``prepare_data()`` to have been called first. ``stage`` is
        accepted for Lightning compatibility but ignored.
        """
        self.preprocess = self.get_preprocessing_transform()
        augment = self.get_augmentation_transform()
        self.transform = tio.Compose([self.preprocess, augment])
        self.train_set = tio.SubjectsDataset(self.train_subjects, transform=self.transform)
        self.val_set = tio.SubjectsDataset(self.val_subjects, transform=self.preprocess)
        self.test_set = tio.SubjectsDataset(self.test_subjects, transform=self.preprocess)

    def train_dataloader(self):
        """Training loader; shuffled so batch composition changes every epoch."""
        return DataLoader(self.train_set, self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Validation loader in fixed order."""
        return DataLoader(self.val_set, self.batch_size)

    def test_dataloader(self):
        """Test loader in fixed order."""
        return DataLoader(self.test_set, self.batch_size)

In [None]:
# Requires `df` (one row per study) to be defined before this cell.
data = PiCaiDataModule(
    df=df,
    batch_size=16,
)

data.prepare_data()
data.setup()

# Sanity check: the split sizes add up to the number of rows in `df`.
assert len(data.train_set) + len(data.val_set) + len(data.test_set) == len(df)