In [1]:
# from build_dataset import MRILesionDatasetBuilder

In [2]:
import os
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

class MRILesionDatasetBuilder:
    def __init__(self, data_folder="/home/benet/data", input_folder="VH", output_folder="lesion2D", folders=["train", "test"], flair_image="flair.nii.gz",
                 mask_image="lesionMask.nii.gz", slices_per_example=13, slices_step=1, start_slice=85, train_split=0.7, seed=17844, skip_empty_masks=True,
                 fill_lesion=False):
        self.data_folder = data_folder
        self.input_folder = input_folder
        self.output_folder = os.path.join(data_folder, output_folder)
        self.folders = folders
        self.flair_image = flair_image
        self.mask_image = mask_image
        self.slices_per_example = slices_per_example
        self.slices_step = slices_step
        self.start_slice = start_slice
        self.train_split = train_split
        self.seed = seed
        np.random.seed(seed)
        self.skip_empty_masks = skip_empty_masks
        self.fill_lesion = fill_lesion
        self._create_output_dirs()
    
    def _create_output_dirs(self):
        """Creates necessary output directories."""
        sub_dirs = ["train/flair", "train/mask", "test/flair", "test/mask"]
        for sub_dir in sub_dirs:
            os.makedirs(os.path.join(self.output_folder, sub_dir), exist_ok=True)
    
    def build_dataset(self):
        """Processes all specified folders (train/test)."""
        empty_masks, train_count, test_count = 0, 0, 0
        if self.input_folder == "VH":
            train_examples, test_examples = 0, 0
            total_examples = sum(len(os.listdir(os.path.join(self.data_folder, self.input_folder, folder))) for folder in self.folders)
            for folder in self.folders:
                folder_path = os.path.join(self.data_folder, self.input_folder, folder)
                examples = sorted(os.listdir(folder_path))
                
                for example in examples:
                    if train_examples >= total_examples * self.train_split:
                        train_test = "test"
                        example_empty_masks = self._process_example(folder_path, example, folder, train_test)
                    elif test_examples >= total_examples * (1 - self.train_split):
                        train_test = "train"
                        example_empty_masks = self._process_example(folder_path, example, folder, train_test)
                    else:
                        train_test = folder
                        example_empty_masks = self._process_example(folder_path, example, folder, train_test)
                    
                    empty_masks += example_empty_masks
                    increment = self.slices_per_example - example_empty_masks
                    train_count += (train_test == "train") * increment
                    test_count += (train_test == "test") * increment
                    train_examples += (train_test == "train")
                    test_examples += (train_test == "test")

        elif self.input_folder == "SHIFTS_preprocessedMNI":
            for folder in self.folders:
                folder_path = os.path.join(self.data_folder, self.input_folder, folder)
                examples = os.listdir(folder_path)
                np.random.shuffle(examples)
                examples_train = examples[:int(len(examples) * self.train_split)]
                examples_test = examples[int(len(examples) * self.train_split):]
                for example in examples_train:
                    example_empty_masks = self._process_example(folder_path, example, folder, "train")
                    empty_masks += example_empty_masks
                    train_count += self.slices_per_example - example_empty_masks
                for example in examples_test:
                    example_empty_masks = self._process_example(folder_path, example, folder, "test")
                    empty_masks += example_empty_masks
                    test_count += self.slices_per_example - example_empty_masks

        elif self.input_folder == "WMH2017_preprocessedMNI":
            folder_path = os.path.join(self.data_folder, self.input_folder)
            examples = os.listdir(folder_path)
            np.random.shuffle(examples)
            examples_train = examples[:int(len(examples) * self.train_split)]
            examples_test = examples[int(len(examples) * self.train_split):]
            for example in examples_train:
                example_empty_masks = self._process_example(folder_path, example, None, "train")
                empty_masks += example_empty_masks
                train_count += self.slices_per_example - example_empty_masks
            for example in examples_test:
                example_empty_masks = self._process_example(folder_path, example, None, "test")
                empty_masks += example_empty_masks
                test_count += self.slices_per_example - example_empty_masks
        
        else:
            print(f"Unknown input folder: {self.input_folder}, only 'VH' and 'SHIFTS_preprocessedMNI' and 'WMH2017_preprocessedMNI' are supported.")
            return
                
        print(f"Total empty masks skipped: {empty_masks}")

        print(f"In the preprocessed folder: Total examples: {train_count + test_count}, train examples: {train_count} ({train_count/(train_count + test_count) * 100:.2f}%), test examples: {test_count} ({test_count/(train_count + test_count) * 100:.2f}%)")

        train_count_total = len(os.listdir(os.path.join(self.output_folder, "train/flair")))
        test_count_total = len(os.listdir(os.path.join(self.output_folder, "test/flair")))
        print(f"In the hole dataset: Total examples: {train_count_total + test_count_total}, train examples: {train_count_total} ({train_count_total/(train_count_total + test_count_total) * 100:.2f}%), test examples: {test_count_total} ({test_count_total/(train_count_total + test_count_total) * 100:.2f}%)")

    def _process_example(self, folder_path, example, folder, train_test=None):
        """Processes a single example folder."""
        example_path = os.path.join(folder_path, example)
        if not os.path.isdir(example_path):
            print(f"Skipping {example_path}")
            return
        
        flair_path = os.path.join(example_path, self.flair_image)
        mask_path = os.path.join(example_path, self.mask_image)
        flair_data, mask_data = self._load_nifti_images(flair_path, mask_path)
        
        return self._save_slices(example, flair_data, mask_data, folder, train_test)   
      
    def _load_nifti_images(self, flair_path, mask_path):
        """Loads NIfTI images and returns their data arrays."""
        flair = nib.load(flair_path).get_fdata()
        mask = nib.load(mask_path).get_fdata()
        assert flair.shape == mask.shape, "Flair and Mask shapes do not match!"
        return flair, mask
    
    def _save_slices(self, example, flair_data, mask_data, folder, train_test=None):
        """Extracts and saves slices as PNG files."""
        end_slice = self.start_slice + self.slices_per_example * self.slices_step
        empty_masks = 0
        for j, i in enumerate(range(self.start_slice, end_slice, self.slices_step)):
            flair_slice = np.rot90(flair_data[:, :, i])
            mask_slice = np.rot90(mask_data[:, :, i])
            
            # If mask_slice is empty, skip saving
            if self.skip_empty_masks and np.sum(mask_slice) == 0:
                print(f"Skipping empty mask for {folder} {example} at slice {i}")
                empty_masks += 1
                continue
        
            # If fill_lesion is True, fill the lesion in the flair image with a gray value (0-255)
            if self.fill_lesion:
                # mid = (np.max(flair_slice) - np.min(flair_slice)) / 2
                # mean of flair_slice excuding where pixels are 0
                # zero_pixels = np.where(flair_slice == 0)
                # print(f"Number of zero pixels in flair slice: {len(zero_pixels[0])}")
                mean = np.mean(flair_slice[flair_slice > 0])
                flair_slice[mask_slice > 0] = mean
                # print(f"Filling lesion for {folder} {example} at slice {i}")
                # print the maximum value in the flair slice and the minimum
                # print(f"Max value in flair slice: {np.max(flair_slice)}, Min value in flair slice: {np.min(flair_slice)}")

            self._save_image(flair_slice, "flair", example, j, folder, train_test)
            self._save_image(mask_slice, "mask", example, j, folder, train_test)

        return empty_masks
    
    def _save_image(self, slice_data, image_type, example, index, folder, train_test=None):
        """Saves a single image slice as PNG."""
        if self.input_folder == "VH": # VH dataset folder is already train/test
            path = os.path.join(self.output_folder, train_test, image_type, f"{self.input_folder}_{example}_{index}.png")
        elif self.input_folder == "SHIFTS_preprocessedMNI":
            path = os.path.join(self.output_folder, train_test, image_type, f"{folder}_{example}_{index}.png")
        elif self.input_folder == "WMH2017_preprocessedMNI":
            path = os.path.join(self.output_folder, train_test, image_type, f"WMH2017_{example}_{index}.png")
        else:
            raise ValueError(f"Unknown input folder: {self.input_folder}")
               
        plt.imsave(path, slice_data, cmap="gray")


# Build the dataset for the brain generation model
- Inpaint lesions with simple method in the slices with lesions
- Include SHIFTS and WMH2017 datasets

### VH

In [3]:
data_folder="/home/benet/data"
input_folder="VH"
output_folder="generation2D_VH-SHIFTS-WMH2017"
folders=["train", "test"]
flair_image="flair.nii.gz"
mask_image="lesionMask.nii.gz"
slices_per_example=13
slices_step=1
start_slice=85
train_split=1
seed = 17844
skip_empty_masks=False
fill_lesion=True

# dataset_builder = MRILesionDatasetBuilder(data_folder, input_folder, output_folder, folders, flair_image, mask_image, slices_per_example, slices_step, start_slice, train_split, seed, skip_empty_masks, fill_lesion)
# dataset_builder.build_dataset()


# output_folder="lesion2D_VH"
# fill_lesion=False
# dataset_builder = MRILesionDatasetBuilder(data_folder, input_folder, output_folder, folders, flair_image, mask_image, slices_per_example, slices_step, start_slice, train_split, seed, skip_empty_masks, fill_lesion)
# dataset_builder.build_dataset()

output_folder="lesion2D_VH_split"
fill_lesion=False
train_split=0.7
dataset_builder = MRILesionDatasetBuilder(data_folder, input_folder, output_folder, folders, flair_image, mask_image, slices_per_example, slices_step, start_slice, train_split, seed, skip_empty_masks, fill_lesion)
dataset_builder.build_dataset()

Total empty masks skipped: 0
In the preprocessed folder: Total examples: 741, train examples: 507 (68.42%), test examples: 234 (31.58%)
In the hole dataset: Total examples: 741, train examples: 507 (68.42%), test examples: 234 (31.58%)


### SHIFTS

In [6]:
data_folder="/home/benet/data"
input_folder="SHIFTS_preprocessedMNI" ###
output_folder="generation2D_VH-SHIFTS-WMH2017"
folders=["dev_in", "dev_out", "eval_in", "train"] ###
flair_image="flair.nii.gz"
mask_image="lesionMask.nii.gz"
slices_per_example=13
slices_step=1
start_slice=85
train_split=1
seed = 17844
skip_empty_masks=False
fill_lesion=True

dataset_builder = MRILesionDatasetBuilder(data_folder, input_folder, output_folder, folders, flair_image, mask_image, slices_per_example, slices_step, start_slice, train_split, seed, skip_empty_masks, fill_lesion)
dataset_builder.build_dataset()


# output_folder="lesion2D_SHIFTS"
# fill_lesion=False
# dataset_builder = MRILesionDatasetBuilder(data_folder, input_folder, output_folder, folders, flair_image, mask_image, slices_per_example, slices_step, start_slice, train_split, seed, skip_empty_masks, fill_lesion)
# dataset_builder.build_dataset()

Total empty masks skipped: 0
In the preprocessed folder: Total examples: 1274, train examples: 1274 (100.00%), test examples: 0 (0.00%)
In the hole dataset: Total examples: 2015, train examples: 2015 (100.00%), test examples: 0 (0.00%)


### WMH2017

In [7]:
data_folder="/home/benet/data"
input_folder="WMH2017_preprocessedMNI" ###
output_folder="generation2D_VH-SHIFTS-WMH2017"
folders=None ###
flair_image="flair.nii.gz"
mask_image="lesionMask.nii.gz"
slices_per_example=13
slices_step=1
start_slice=85
train_split=1
seed = 17844
skip_empty_masks=False
fill_lesion=True

dataset_builder = MRILesionDatasetBuilder(data_folder, input_folder, output_folder, folders, flair_image, mask_image, slices_per_example, slices_step, start_slice, train_split, seed, skip_empty_masks, fill_lesion)
dataset_builder.build_dataset()


# output_folder="lesion2D_WMH2017"
# fill_lesion=False
# dataset_builder = MRILesionDatasetBuilder(data_folder, input_folder, output_folder, folders, flair_image, mask_image, slices_per_example, slices_step, start_slice, train_split, seed, skip_empty_masks, fill_lesion)
# dataset_builder.build_dataset()

Total empty masks skipped: 0
In the preprocessed folder: Total examples: 780, train examples: 780 (100.00%), test examples: 0 (0.00%)
In the hole dataset: Total examples: 2795, train examples: 2795 (100.00%), test examples: 0 (0.00%)
