In [None]:
# from brats import AdultGliomaPreTreatmentSegmenter
# from brats.constants import AdultGliomaPreTreatmentAlgorithms
import SimpleITK as sitk
from pathlib import Path

def check_image_information(image_path):
    """Read, print and return the geometry metadata of an image file.

    Only the file header is read (``sitk.ImageFileReader.ReadImageInformation``),
    so the voxel data is never loaded. This keeps the check cheap even for
    large volumes.

    Args:
        image_path (str | pathlib.Path): Path to the image file.

    Returns:
        dict: Dictionary with keys ``"size"``, ``"spacing"``, ``"origin"`` and
        ``"direction"``, each holding the tuple reported by SimpleITK.

    Raises:
        Exception: Whatever SimpleITK raises when the header cannot be read
            (e.g. missing file or unsupported format); it is not caught here.
    """
    image_path = Path(image_path)

    # Header-only read: avoids decoding the full voxel buffer just to inspect metadata.
    reader = sitk.ImageFileReader()
    # str() keeps SimpleITK releases that do not accept os.PathLike working.
    reader.SetFileName(str(image_path))
    reader.ReadImageInformation()

    info = {
        "size": reader.GetSize(),
        "spacing": reader.GetSpacing(),
        "origin": reader.GetOrigin(),
        "direction": reader.GetDirection(),
    }

    # Keys double as the printed labels, so the output matches one line per field.
    for field, value in info.items():
        print(f"Image: {image_path.name} {field} is {value}")

    return info

# Example case IM0031: the four skull-stripped (BET) sequences plus the mask
# stored as example_gt (presumably the ground-truth tumour segmentation).
example_case = "IM0031"
example_brainles_dir = f"/data/glioma_data/skull_stripped_scans/{example_case}/{example_case}_brainles"
example_bet_dir = f"{example_brainles_dir}/raw_bet"

example_t1c = f"{example_bet_dir}/{example_case}_t1c_bet.nii.gz"
example_t1 = f"{example_bet_dir}/{example_case}_t1_bet.nii.gz"
example_t2 = f"{example_bet_dir}/{example_case}_t2_bet.nii.gz"
example_flair = f"{example_bet_dir}/{example_case}_fla_bet.nii.gz"
example_gt = f"{example_brainles_dir}/MASK.nii.gz"

# Each call prints size/spacing/origin/direction and returns them as a dict.
t1c_image_info = check_image_information(image_path=example_t1c)
t1_image_info = check_image_information(image_path=example_t1)
t2_image_info = check_image_information(image_path=example_t2)
flair_image_info = check_image_information(image_path=example_flair)
gt_image_info = check_image_information(image_path=example_gt)


Image: IM0031_t1c_bet.nii.gz size is (240, 240, 155)
Image: IM0031_t1c_bet.nii.gz spacing is (1.0, 1.0, 1.0)
Image: IM0031_t1c_bet.nii.gz origin is (0.0, -239.0, 0.0)
Image: IM0031_t1c_bet.nii.gz direction is (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
Image: IM0031_t1_bet.nii.gz size is (240, 240, 155)
Image: IM0031_t1_bet.nii.gz spacing is (1.0, 1.0, 1.0)
Image: IM0031_t1_bet.nii.gz origin is (0.0, -239.0, 0.0)
Image: IM0031_t1_bet.nii.gz direction is (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
Image: IM0031_t2_bet.nii.gz size is (240, 240, 155)
Image: IM0031_t2_bet.nii.gz spacing is (1.0, 1.0, 1.0)
Image: IM0031_t2_bet.nii.gz origin is (0.0, -239.0, 0.0)
Image: IM0031_t2_bet.nii.gz direction is (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
Image: IM0031_fla_bet.nii.gz size is (240, 240, 155)
Image: IM0031_fla_bet.nii.gz spacing is (1.0, 1.0, 1.0)
Image: IM0031_fla_bet.nii.gz origin is (0.0, -239.0, 0.0)
Image: IM0031_fla_bet.nii.gz direction is (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0

In [None]:
# EGD example (case EGD-0096): inspect the raw input data with the same helper.
egd_example_dir = Path("/data/glioma_data/prognosais_data/EGD-0096")
egd_example_t1, egd_example_t1_c, egd_example_t2, egd_example_flair, egd_example_gt = (
    egd_example_dir / file_name
    for file_name in ("T1.nii.gz", "T1GD.nii.gz", "T2.nii.gz", "FLAIR.nii.gz", "MASK.nii.gz")
)

# One metadata report per file, in the same order the paths were defined.
egd_example_t1_info = check_image_information(image_path=egd_example_t1)
egd_example_t1_c_info = check_image_information(image_path=egd_example_t1_c)
egd_example_t2_info = check_image_information(image_path=egd_example_t2)
egd_example_flair_info = check_image_information(image_path=egd_example_flair)
egd_example_gt_info = check_image_information(image_path=egd_example_gt)

In [None]:
from pathlib import Path

# The BraTS imports in the first cell are commented out; import them here so
# this cell also runs on a fresh kernel.
from brats import AdultGliomaPreTreatmentSegmenter
from brats.constants import AdultGliomaPreTreatmentAlgorithms

input_dir = "/data/glioma_data/skull_stripped_scans/"
# Set to True to recompute segmentations that already exist on disk.
overwrite_existing = False

# The segmenter configuration is identical for every patient, so build it once
# instead of once per loop iteration.
segmenter = AdultGliomaPreTreatmentSegmenter(
    algorithm=AdultGliomaPreTreatmentAlgorithms.BraTS23_3, cuda_devices="0"
)

# Only directories are patient folders (stray files are ignored); sorting gives
# a reproducible processing order.
for patient in sorted(p for p in Path(input_dir).iterdir() if p.is_dir()):
    brainles_dir = patient / f"{patient.name}_brainles"
    preprocessed_dir = brainles_dir / "raw_bet"
    segmentation_dir = brainles_dir / "segmentation"
    output_file = segmentation_dir / "brats_orchestrator_segmentation_3.nii.gz"

    # Keyword names match the infer_single parameters used originally.
    inputs = {
        "t1c": preprocessed_dir / f"{patient.name}_t1c_bet.nii.gz",
        "t1n": preprocessed_dir / f"{patient.name}_t1_bet.nii.gz",
        "t2f": preprocessed_dir / f"{patient.name}_fla_bet.nii.gz",
        "t2w": preprocessed_dir / f"{patient.name}_t2_bet.nii.gz",
    }

    # Skip (rather than crash mid-batch) when a patient lacks a required scan.
    missing = [str(path) for path in inputs.values() if not path.is_file()]
    if missing:
        print(f"Skipping {patient.name}: missing input(s) {missing}")
        continue

    # Inference is expensive; don't redo finished patients on a re-run.
    if output_file.exists() and not overwrite_existing:
        print(f"Skipping {patient.name}: {output_file.name} already exists")
        continue

    segmentation_dir.mkdir(parents=True, exist_ok=True)
    segmenter.infer_single(**inputs, output_file=output_file)
