## Prepare dataset

In [None]:
import utils

# --- Dataset configuration --------------------------------------------------
DATASET_NAME = "CAS"
DATA_PATH = "/data/falcetta/brain_data/CAS/volumes"
# NOTE(review): the output folder is "CASJ" rather than "CAS" -- presumably
# intentional (vessel and brain masks are joined during preprocessing, see
# join_vessel_and_brain=True further below); confirm.
OUT_PATH = "/data/falcetta/brain_data/CASJ/preprocessed"


# --- Per-dataset path rules -------------------------------------------------
# Redefine these three functions for every new dataset (unless you are fixing
# a general rule). Each one maps the path of a brain volume to the path of one
# of its companion files (brain mask, vessel mask, weight map). They are also
# assigned onto the `utils` module at the end of this block, presumably so the
# helpers inside utils use these rules. The brain mask is optional: a dataset
# without one should make img_path_to_brain_path return None.

def img_path_to_brain_path(img_path):
    """Return the brain-mask path for a volume path.

    Swaps "CAS/volumes" for "CAS/brain_masks_A2V" and then inserts "_pred"
    before the ".nii.gz" extension.
    """
    in_mask_folder = img_path.replace("CAS/volumes", "CAS/brain_masks_A2V")
    return in_mask_folder.replace(".nii.gz", "_pred.nii.gz")


def img_path_to_vessel_path(img_path):
    """Return the vessel-mask path: "CAS/volumes/<f>" -> "CAS/vessels/vessels_<f>"."""
    return img_path.replace("CAS/volumes/", "CAS/vessels/vessels_")


def img_path_to_weight_path(img_path):
    """Return the weight-map path: "<f>.nii.gz" -> "<f>_weight.nii.gz" (same folder)."""
    return img_path.replace(".nii.gz", "_weight.nii.gz")


utils.img_path_to_brain_path = img_path_to_brain_path
utils.img_path_to_vessel_path = img_path_to_vessel_path
utils.img_path_to_weight_path = img_path_to_weight_path

# Regex filter for the .nii files found in the dataset folder, useful when it
# holds unrelated files. Left empty here; an example rule is "^\d{3}\.nii\.gz$".
PATH_RULE = ""

# Volumes (matched by file name) to discard from the dataset.
OUTLIERs = [
    '068.nii.gz',  # no brain mask
]

# Set to False to also preprocess volumes that have no vessel mask.
VESSELs_REQUIRED = True

# Unifies image orientation: set to True if the examples displayed below show
# the nose pointing downward instead of upward.
utils.do_flip = False

### Extract paths

In [None]:
import os

from utils import search_nii, extract_paths, load_info_from_checkpoint

In [None]:
if os.path.isfile(f"info_{DATASET_NAME}.pkl"):
    # Resume from an existing checkpoint.
    # NOTE(review): later cells overwrite this same file with the per-split
    # info built after the train/test split, so on a re-run `info` may hold
    # that later state rather than the raw path list -- confirm this is intended.
    info = load_info_from_checkpoint(f"info_{DATASET_NAME}.pkl")
else:
    # First run: a single "train" container that search_nii presumably fills
    # with the volumes found under DATA_PATH (filtered by PATH_RULE); the
    # actual train/test split happens further below.
    info = {
        "train": extract_paths(path_rule=PATH_RULE),
        "val": None,
        "test": None,
    }
    search_nii(DATA_PATH, info["train"])
    

### Display examples

In [None]:
import random

from utils import print_bold, load_and_display_middle_slice

In [None]:
# Show two seeded random volumes to check their orientation (see utils.do_flip);
# the second example is displayed along all three axes.
for example_label, example_seed, extra_kwargs in (
    ("[Example 01]", 0, {}),
    ("[Example 02]", 1, {"axis": [0, 1, 2]}),
):
    print_bold(example_label)
    random.seed(example_seed)
    load_and_display_middle_slice(
        random.choice(info["train"].paths),
        display_header=True,
        **extra_kwargs,
    )

### Split train/val/test

In [None]:
from utils import is_medical_volume

In total, there are

In [None]:
# Gather every known volume path: train always, val/test only when present.
IMGs = list(info["train"].paths)
for split_name in ("val", "test"):
    if info[split_name] is not None:
        IMGs.extend(info[split_name].paths)

len(IMGs)

volumes. First, we discard the outliers:

In [None]:
# Drop the outliers, matching on the file name only (not the full path).
_outlier_names = set(OUTLIERs)
IMGs = list(filter(lambda path: os.path.basename(path) not in _outlier_names, IMGs))

# Show the first 10 remaining image paths.
print(f"Found {len(IMGs)} images")
print("First 10 images:")
for shown_path in IMGs[:10]:
    print(shown_path)

We have vessel annotations for

In [None]:
# Keep the volumes whose vessel-mask path passes is_medical_volume
# (presumably: the mask file exists and is a valid volume).
HAVE_VESSELs = []
for candidate in IMGs:
    if is_medical_volume(img_path_to_vessel_path(candidate)):
        HAVE_VESSELs.append(candidate)

print(f"Found {len(HAVE_VESSELs)} images with vessel masks")
print("First 10 images:")
for shown_path in HAVE_VESSELs[:10]:
    print(f"Image: {shown_path} - Vessel: {img_path_to_vessel_path(shown_path)}")

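of them. The original plan was a 70-15-15 train/val/test ratio, but no validation set is currently used, so about 16% of the eligible volumes are held out for testing, i.e. we select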

In [None]:
# Size of the held-out set: 16% (rounded down) of the eligible volumes --
# those with vessel masks when VESSELs_REQUIRED, otherwise all non-outliers.
_n_eligible = len(HAVE_VESSELs) if VESSELs_REQUIRED else len(IMGs)
COUNT_TEST = _n_eligible * 16 // 100
COUNT_TEST

images to be part of the test set.

In case brain masks are available, we collect them too.

In [None]:
# Keep the volumes whose brain-mask path passes is_medical_volume
# (presumably: the mask file exists and is a valid volume).
HAVE_BRAINs = list(filter(lambda path: is_medical_volume(img_path_to_brain_path(path)), IMGs))

print(f"Found {len(HAVE_BRAINs)} images with brain masks")
print("First 10 images:")
for shown_path in HAVE_BRAINs[:10]:
    print(f"Image: {shown_path} - Brain: {img_path_to_brain_path(shown_path)}")

Finally, we randomly split our dataset into training and test sets, using fixed random seeds so that the split is reproducible.

In [None]:
# Fixed seeds keep the split reproducible across runs. Seed 0 precedes the
# (currently disabled) validation draw, so re-enabling that line would
# presumably reproduce the original validation split.
random.seed(0)

print("NO VAL SET ==> JUST TRAINING AND TEST SETS")
VAL_IMGs = [] #random.sample(sorted(HAVE_VESSELs), COUNT_TEST)

print("val: ", [os.path.basename(img) for img in VAL_IMGs])

random.seed(1)

# Test volumes are always drawn from HAVE_VESSELs (they need a vessel mask),
# but COUNT_TEST is computed from all of IMGs when VESSELs_REQUIRED is False,
# so it can exceed the number of candidates and make random.sample raise
# ValueError. Cap the draw instead; whenever COUNT_TEST fits, the sample is
# identical to the uncapped one (same population, same k, same seed).
_val_set = set(VAL_IMGs)
_test_candidates = sorted(img for img in HAVE_VESSELs if img not in _val_set)
_n_test = min(COUNT_TEST, len(_test_candidates))
if _n_test < COUNT_TEST:
    print(
        f"WARNING: requested {COUNT_TEST} test images, but only "
        f"{len(_test_candidates)} candidates have vessel masks; using all of them."
    )
TEST_IMGs = random.sample(_test_candidates, _n_test)

print("test: ", [os.path.basename(img) for img in TEST_IMGs])

# Everything not held out goes to training (restricted to volumes with vessel
# masks when VESSELs_REQUIRED is True). A set makes each membership check O(1)
# instead of scanning the held-out lists.
_held_out = _val_set | set(TEST_IMGs)
TRAIN_IMGs = [img for img in (HAVE_VESSELs if VESSELs_REQUIRED else IMGs) if img not in _held_out]

In [None]:
import pickle

# Persist each split's info object (its attribute dict, or None) to the
# checkpoint file.
# NOTE(review): TRAIN_IMGs / VAL_IMGs / TEST_IMGs are not part of this payload,
# so the next section still depends on them being in kernel memory.
with open(f"info_{DATASET_NAME}.pkl", "wb") as file:
    checkpoint_payload = {}
    for split in info:
        split_info = info[split]
        checkpoint_payload[split] = None if split_info is None else split_info.__dict__
    pickle.dump(checkpoint_payload, file)

### Extract spacings and shapes

In [None]:
from utils import load_info_from_checkpoint

# Reload the checkpoint saved above.
# NOTE(review): the next cell rebuilds `info` from TRAIN_IMGs / VAL_IMGs /
# TEST_IMGs, which are not stored in the checkpoint, so this reload does not
# by itself make the section runnable from a fresh kernel.
checkpoint_file = f"info_{DATASET_NAME}.pkl"
info = load_info_from_checkpoint(checkpoint_file)

In [None]:
from utils import extract_info_and_masks, loop_nii, display_info, get_target_spacing

In [None]:
# One fresh info container per split, each then filled by loop_nii from that
# split's volumes (presumably reading spacings/shapes and masks).
# Approximate timings noted earlier: ~1.30 min for train, ~20 sec for test;
# VAL_IMGs is empty in this notebook.
info = {split: extract_info_and_masks(path_rule=PATH_RULE) for split in ("train", "val", "test")}

for split, split_imgs in (("train", TRAIN_IMGs), ("val", VAL_IMGs), ("test", TEST_IMGs)):
    loop_nii(split_imgs, info[split])

In [None]:
# Display the collected info for the training and test splits (via
# utils.display_info); the validation split is omitted since no validation
# set is used in this notebook.
display_info(info["train"], info["test"])

In [None]:
# Target spacing for resampling, derived from the training split only
# (validation spacings/shapes are not included since no validation set is used).
_train_info = info["train"]
SPACING = get_target_spacing(_train_info.spacings, _train_info.shapesAfterCropping())

SPACING

In [None]:
import pickle

# Overwrite the checkpoint with the per-split info collected above (each
# split's attribute dict, or None).
with open(f"info_{DATASET_NAME}.pkl", "wb") as file:
    checkpoint_payload = {}
    for split in info:
        split_info = info[split]
        checkpoint_payload[split] = None if split_info is None else split_info.__dict__
    pickle.dump(checkpoint_payload, file)

### Crop, Metadata, Resize, Empty Slices Removal, Standardization

In [None]:
from utils import load_info_from_checkpoint

# Reload the per-split info saved at the end of the previous section, so this
# section can presumably start from disk.
checkpoint_file = f"info_{DATASET_NAME}.pkl"
info = load_info_from_checkpoint(checkpoint_file)

In [None]:
from utils import get_target_spacing

# Recompute the target spacing from the reloaded training info (presumably so
# this section does not rely on SPACING from the previous one). Only the
# training split is used; no validation set exists here.
train_spacings = info["train"].spacings
train_cropped_shapes = info["train"].shapesAfterCropping()
SPACING = get_target_spacing(train_spacings, train_cropped_shapes)

SPACING

In [None]:
import os

from utils import print_bold, preprocessing_loop

In [None]:
# Preprocess the training and test splits with identical settings, writing to
# OUT_PATH/numpy/<split>. No validation split is used in this notebook; to add
# one, append ("Validation", "val") to the tuple below.
for split_title, split_key in (("Training", "train"), ("Testing", "test")):
    print_bold(split_title)
    preprocessing_loop(
        info[split_key],
        os.path.join(OUT_PATH, "numpy", split_key),
        target_spacing=SPACING,
        discard_n_slices=5,
        join_vessel_and_brain=True,
    )

In [None]:
import pickle

# Save the per-split info once more after preprocessing (preprocessing_loop
# may have updated it -- unverified), overwriting the checkpoint.
with open(f"info_{DATASET_NAME}.pkl", "wb") as file:
    checkpoint_payload = {}
    for split in info:
        split_info = info[split]
        checkpoint_payload[split] = None if split_info is None else split_info.__dict__
    pickle.dump(checkpoint_payload, file)