# Creating the splits for the dataset with synthetic data
### The validation set contains only real data, and the training set must not contain any case that appears in the validation set (not even a version of it with a synthetic tumour)
Each split follows the same fold assignment that was used when training on real data only.
To use this notebook, you first need to create the 5 folds using only real data.
See the contents of the example/splits_final* files for reference.

In [None]:
import json
import os
from os import listdir
from os.path import join

def create_json(Dataset_name_real, Dataset_name_fake):
    """Rewrite the 5-fold split file so each fold trains on synthetic cases.

    Reads ``./nnUNet_preprocessed/<Dataset_name_real>/splits_final.json``,
    keeps every fold's ``"val"`` list unchanged and replaces its ``"train"``
    list with all cases found in ``./nnUNet_raw/<Dataset_name_fake>/labelsTr``
    whose name does not contain the name of any validation case of that fold.
    The result is written back to the same ``splits_final.json`` (the file
    that was read is overwritten), then re-read and checked for leakage.

    Args:
        Dataset_name_real: Folder name of the dataset whose split file is
            read and rewritten.
        Dataset_name_fake: Folder name of the dataset whose ``labelsTr``
            directory lists the candidate training cases.

    Raises:
        FileNotFoundError: If the split file or the ``labelsTr`` directory
            does not exist.
        IndexError: If the split file holds fewer than 5 folds.
    """
    # Specify the path to the JSON file
    Dataset_splits_json = f"./nnUNet_preprocessed/{Dataset_name_real}/splits_final.json"

    # Open the file and load its contents
    with open(Dataset_splits_json, 'r') as file:
        Dataset_splits = json.load(file)

    # The directory content does not depend on the fold, so list it only once.
    # listdir() returns entries in arbitrary order; sorting keeps the written
    # split reproducible between runs and machines.
    synthetic_dataset_L = sorted(
        case.split(".nii.gz")[0]
        for case in listdir(f"./nnUNet_raw/{Dataset_name_fake}/labelsTr")
    )

    all_folds = []
    for fold in range(5):
        # Get the validation splits for this fold
        Dataset_splits_val = Dataset_splits[fold]['val']
        print(f"len(synthetic_dataset_L): {len(synthetic_dataset_L)}")

        # Drop every synthetic case derived from a validation case. Substring
        # matching is deliberate: the check relies on the synthetic case name
        # embedding the name of the real case it was generated from.
        fake_train_L = [
            s for s in synthetic_dataset_L
            if not any(case in s for case in Dataset_splits_val)
        ]
        print(f"len(fake_train_L): {len(fake_train_L)}")

        # create the split dict for this fold
        all_folds.append({
            "train": fake_train_L,
            "val": Dataset_splits_val,
        })

    # Save the new data split
    with open(Dataset_splits_json, 'w') as file:
        json.dump(all_folds, file, indent=4)

    print("If something prints after this, something is wrong with the dataset")
    # Double check that everything is correct: reload what was actually
    # written and report every validation case that leaks into training.
    with open(Dataset_splits_json, 'r') as file:
        Dataset_splits = json.load(file)

    for split in Dataset_splits:
        for val_case in split["val"]:
            for train_case in split["train"]:
                if val_case in train_case:
                    print(val_case)

# Dataset folder names passed to create_json: the first provides the existing
# split file, the second provides the synthetic training cases.
Dataset_name_real, Dataset_name_fake = (
    "Dataset231_BraTS_2023",
    "Dataset232_BraTS_2023_rGANs",
)
create_json(Dataset_name_real, Dataset_name_fake)

## For the MedNeXt

In [None]:
import pickle
import json
import os
from os import listdir
from os.path import join
import numpy as np
from collections import OrderedDict


def create_json(Dataset_name_real, Dataset_name_fake):
    """Rewrite the 5-fold pickle split file so each fold trains on synthetic cases.

    Despite its name, this variant works on ``splits_final.pkl``. It reads
    ``./nnUNet_preprocessed/<Dataset_name_real>/splits_final.pkl``, keeps
    every fold's ``"val"`` entries and replaces ``"train"`` with all cases
    found in ``./nnUNet_preprocessed/<Dataset_name_fake>/gt_segmentations``
    whose name does not contain the name of any validation case of that fold.
    The new folds are pickled back to the same ``splits_final.pkl`` (the file
    that was read is overwritten), mirroring the JSON variant of this function.

    Args:
        Dataset_name_real: Folder name of the dataset whose split file is
            read and rewritten.
        Dataset_name_fake: Folder name of the dataset whose
            ``gt_segmentations`` directory lists the candidate training cases.

    Raises:
        FileNotFoundError: If the split file or the ``gt_segmentations``
            directory does not exist.
        IndexError: If the split file holds fewer than 5 folds.
    """
    # Path to your .pkl file
    Dataset_splits_pkl = f'./nnUNet_preprocessed/{Dataset_name_real}/splits_final.pkl'

    # Open the file in binary read mode.
    # NOTE: unpickling can execute arbitrary code; only load split files you
    # created yourself.
    with open(Dataset_splits_pkl, 'rb') as file:
        Dataset_splits = pickle.load(file)

    # The directory content does not depend on the fold, so list it only once.
    # The path must be an f-string, otherwise the literal text
    # "{Dataset_name_fake}" is looked up on disk. Sorted because listdir()
    # returns entries in arbitrary order.
    synthetic_dataset_L = sorted(
        case.split(".nii.gz")[0]
        for case in listdir(f"./nnUNet_preprocessed/{Dataset_name_fake}/gt_segmentations")
    )

    all_folds = []
    for fold in range(5):
        Dataset_splits_val = Dataset_splits[fold]['val']
        print(f"len(synthetic_dataset_L): {len(synthetic_dataset_L)}")

        # Drop every synthetic case derived from a validation case. Substring
        # matching is deliberate: the check relies on the synthetic case name
        # embedding the name of the real case it was generated from.
        fake_train_L = [
            s for s in synthetic_dataset_L
            if not any(case in s for case in Dataset_splits_val)
        ]
        print(f"len(fake_train_L): {len(fake_train_L)}")

        # create the split dict for this fold.
        # dtype=str lets NumPy size the string dtype from the longest name; a
        # fixed '<U19' would silently cut longer case names to 19 characters.
        train_array = np.array(fake_train_L, dtype=str)
        val_array = np.array(Dataset_splits_val, dtype=str)

        all_folds.append(OrderedDict([("train", train_array), ("val", val_array)]))

    # Save the new data split; without this step the computed folds are lost.
    with open(Dataset_splits_pkl, 'wb') as file:
        pickle.dump(all_folds, file)


# Task folder names passed to create_json: the first provides the existing
# split file, the second provides the synthetic training cases.
Dataset_name_real, Dataset_name_fake = (
    "Task231_BraTS_2023",
    "Task232_BraTS_2023_rGANs",
)
create_json(Dataset_name_real, Dataset_name_fake)

