BTCV classes:  
- spleen
- right kidney
- left kidney
- gallbladder
- esophagus
- liver
- stomach
- aorta
- inferior vena cava
- portal vein and splenic vein
- pancreas
- right adrenal gland
- left adrenal gland

AMOS classes:
- spleen
- right kidney
- left kidney
- gallbladder
- esophagus
- liver
- stomach
- aorta
- inferior vena cava
- pancreas
- right adrenal gland
- left adrenal gland
- duodenum
- bladder
- prostate/uterus

TotalSegmentator classes: include all the classes from BTCV and AMOS.

We use the intersection of the three datasets rather than the union. This reduces the number of classes, but every retained class is annotated in all three datasets, which should help the future model generalize better. Masks for classes outside the intersection must therefore be removed from the labels of each dataset. The intersection contains the following classes:
- spleen
- right kidney
- left kidney
- gallbladder
- esophagus
- liver
- stomach
- aorta
- inferior vena cava
- pancreas
- right adrenal gland
- left adrenal gland
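
The intersection can also be derived programmatically rather than by hand. The following is a minimal sketch, assuming the class names are spelled identically in both lists (they are transcribed from the lists above). TotalSegmentator is left out because it is stated to cover every BTCV and AMOS class, so it does not shrink the result. The names `btcv_classes`, `amos_classes` and `label_ids` are illustrative only.

```python
# Class lists transcribed from the BTCV and AMOS lists above.
btcv_classes = [
    "spleen", "right kidney", "left kidney", "gallbladder", "esophagus",
    "liver", "stomach", "aorta", "inferior vena cava",
    "portal vein and splenic vein", "pancreas",
    "right adrenal gland", "left adrenal gland",
]
amos_classes = [
    "spleen", "right kidney", "left kidney", "gallbladder", "esophagus",
    "liver", "stomach", "aorta", "inferior vena cava", "pancreas",
    "right adrenal gland", "left adrenal gland",
    "duodenum", "bladder", "prostate/uterus",
]

# Keep the BTCV order so that the result maps onto consecutive label IDs.
amos_set = set(amos_classes)
intersection = [name for name in btcv_classes if name in amos_set]

# Assign label IDs 1..N in list order (0 is reserved for background).
label_ids = {name: idx for idx, name in enumerate(intersection, start=1)}
print(len(intersection), label_ids["pancreas"])  # 12 10
```

With this ordering, pancreas ends up with ID 10 and the adrenal glands with IDs 11 and 12, which is the numbering the BTCV remapping further below produces.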

In [1]:
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import os
import shutil
from tqdm import tqdm
import pandas as pd



<h2>TotalSeg</h2>

In [None]:
# Map each TotalSegmentator mask file to its label ID in the merged dataset.
# 'portal_vein_and_splenic_vein.nii.gz' is left out: it is not in the intersection (AMOS lacks it).
accepted_masks = {
    'spleen.nii.gz': 1, 'kidney_right.nii.gz': 2, 'kidney_left.nii.gz': 3, 'gallbladder.nii.gz': 4,
    'esophagus.nii.gz': 5, 'liver.nii.gz': 6, 'stomach.nii.gz': 7, 'aorta.nii.gz': 8,
    'inferior_vena_cava.nii.gz': 9, 'pancreas.nii.gz': 10,
    'adrenal_gland_right.nii.gz': 11, 'adrenal_gland_left.nii.gz': 12,
}

In [2]:
metadata = pd.read_csv('/home/nalvarez/data_analysis/totalseg/totalseg_metadata_def.csv')

In [11]:
scans_dir = "/data/TotalSeg_raw/"

scan_names = sorted(os.listdir(scans_dir))

destination_folder = '/data/TotalSeg_raw_combined_masks'

In [None]:
for scan_name in tqdm(scan_names[1:]):
    scan_dir = os.path.join(scans_dir, scan_name)
    masks_dir = os.path.join(scan_dir, 'segmentations')
    ct_scan = nib.load(os.path.join(scan_dir, 'ct.nii.gz'))

    # Base label map; each accepted binary mask is written into it with its class ID
    combined_segmentation = np.zeros(ct_scan.shape, dtype=np.uint8)
    for filename in os.listdir(masks_dir):  # Iterate through the masks of this scan
        if filename not in accepted_masks:  # Skip masks outside the accepted classes
            continue
        mask_data = nib.load(os.path.join(masks_dir, filename)).get_fdata()
        if not mask_data.any():  # Skip empty masks
            continue
        combined_segmentation[mask_data != 0] = accepted_masks[filename]

    # Crop along the third axis at the last slice that contains liver
    liver_label = accepted_masks['liver.nii.gz']
    for slice_index in reversed(range(combined_segmentation.shape[2])):
        if liver_label in combined_segmentation[:, :, slice_index]:
            # Keep the slices from the first one up to and including the last liver slice.
            # The CT affine is used for both outputs, so no mask needs to have been loaded.
            combined_nii = nib.Nifti1Image(combined_segmentation[:, :, :slice_index + 1], affine=ct_scan.affine)
            ct_data = ct_scan.get_fdata()
            new_nifti_image = nib.Nifti1Image(ct_data[:, :, :slice_index + 1], affine=ct_scan.affine)

            # Save the cropped label map and the cropped CT
            output_dir = os.path.join(destination_folder, scan_name)
            os.makedirs(output_dir, exist_ok=True)  # Do not fail when the cell is re-run
            nib.save(combined_nii, os.path.join(output_dir, f'{scan_name}_combined_segmentation.nii.gz'))
            nib.save(new_nifti_image, os.path.join(output_dir, 'ct.nii.gz'))
            break
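
The two steps of the loop above (merging per-organ binary masks into one label map, then cropping at the last slice that contains liver) can be tried on a toy array without any NIfTI files. This is only a sketch under stated assumptions: the volume shape, the mask positions and the two-class `label_ids` subset are made up for illustration, and the slice search is vectorized with `np.flatnonzero` instead of the reversed loop. Whether the highest slice index corresponds to the top of the body depends on the image orientation, which is not checked here.

```python
import numpy as np

# Hypothetical toy volume: 4 x 4 voxels in-plane, 6 slices along the last axis.
shape = (4, 4, 6)
label_ids = {"spleen": 1, "liver": 6}  # illustrative subset of the class IDs

# One binary mask per organ, as in the TotalSegmentator folder layout.
binary_masks = {name: np.zeros(shape, dtype=np.uint8) for name in label_ids}
binary_masks["spleen"][0:2, 0:2, 1:3] = 1  # spleen occupies slices 1-2
binary_masks["liver"][2:4, 2:4, 2:4] = 1   # liver occupies slices 2-3

# Merge the binary masks into a single label map.
combined = np.zeros(shape, dtype=np.uint8)
for name, mask in binary_masks.items():
    combined[mask != 0] = label_ids[name]

# Indices of all slices (along axis 2) that contain at least one liver voxel.
liver_slices = np.flatnonzero((combined == label_ids["liver"]).any(axis=(0, 1)))
last_liver_slice = int(liver_slices[-1])

# Keep everything up to and including the last liver slice.
cropped = combined[:, :, : last_liver_slice + 1]
print(last_liver_slice, cropped.shape)  # 3 (4, 4, 4)
```

If a scan contains no liver voxels at all, `liver_slices` is empty and indexing it fails, whereas the loop above silently skips such scans; a real implementation would need to handle that case explicitly.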

<h2>AMOS</h2>

In [None]:
data_dir = "/data/AMOS/amos22/"
dst_img_dir1 = "/data/merged_dataset/images/"
dst_label_dir1 = "/data/merged_dataset/labels/"
os.makedirs(dst_img_dir1, exist_ok=True)
os.makedirs(dst_label_dir1, exist_ok=True)

# AMOS label IDs 1-12 follow the same order as the intersected class list, so they are kept unchanged.
# IDs 13-15 (duodenum, bladder, prostate/uterus) are dropped.
amos_accepted_labels = range(1, 13)

for images_folder, labels_folder in tqdm(zip(['imagesTr', 'imagesVa'], ['labelsTr', 'labelsVa']), total=2):
    scans_dir = os.path.join(data_dir, labels_folder)
    images_dir = os.path.join(data_dir, images_folder)
    for scan_name in sorted(os.listdir(scans_dir)):
        mask = nib.load(os.path.join(scans_dir, scan_name))
        mask_data = mask.get_fdata()

        # Base label map; only the accepted label IDs are copied into it
        combined_segmentation = np.zeros(mask_data.shape, dtype=np.uint8)
        for accepted_label in amos_accepted_labels:
            combined_segmentation[mask_data == accepted_label] = accepted_label

        combined_nii = nib.Nifti1Image(combined_segmentation, affine=mask.affine)
        nib.save(combined_nii, os.path.join(dst_label_dir1, scan_name))
        shutil.copyfile(os.path.join(images_dir, scan_name), os.path.join(dst_img_dir1, scan_name))
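
Filtering a label map down to a set of accepted IDs can also be written without a Python loop. The sketch below uses `np.isin` on a toy array; it assumes, as the AMOS class list above suggests, that IDs 1-12 are the classes to keep and that 13-15 are to be zeroed. The array contents and the name `accepted_ids` are illustrative only.

```python
import numpy as np

# Assumed set of label IDs to keep (1-12); everything else becomes background.
accepted_ids = np.arange(1, 13)

# Toy label slice stored as floats, mimicking what nibabel's get_fdata() returns.
labels = np.array([[0, 1, 6], [12, 13, 15]], dtype=np.float64)

# Keep accepted IDs, zero out the rest, and store the result as compact integers.
filtered = np.where(np.isin(labels, accepted_ids), labels, 0).astype(np.uint8)
print(filtered.tolist())  # [[0, 1, 6], [12, 0, 0]]
```

Casting to `uint8` keeps the saved label files small and avoids storing class IDs as floating-point values.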

<h2>BTCV</h2>

The `btcv_accepted_masks` mapping is needed because the BTCV label IDs do not line up with the intersected list above: the "portal vein and splenic vein" class (ID 10) sits in the middle of the BTCV list. The mapping drops that class and shifts the IDs that follow it (pancreas and both adrenal glands) down by one, so that the BTCV masks use the same label IDs as the other datasets.

In [2]:
btcv_accepted_masks = {1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 11:10, 12:11, 13:12}

In [None]:
data_dir = "/data/BTCV/Abdomen/Task050_BTCV"
dst_img_dir1 = "/data/merged_dataset/images/"
dst_label_dir1 = "/data/merged_dataset/labels/"

for images_folder, labels_folder in zip(['imagesTr'], ['labelsTr']):
    scans_dir = os.path.join(data_dir, labels_folder)
    images_dir = os.path.join(data_dir, images_folder)
    scan_names = sorted(os.listdir(scans_dir))
    for scan_name in tqdm(scan_names):
        scan_path = os.path.join(scans_dir, scan_name)
        image_path = os.path.join(images_dir, scan_name)
        
        mask = nib.load(scan_path)
        mask_data = mask.get_fdata()
        
        # Remap BTCV label IDs to the merged IDs; unmapped IDs (portal vein and splenic vein) stay 0
        combined_segmentation = np.zeros(mask_data.shape, dtype=np.uint8)
        for accepted_mask_key, accepted_mask_value in btcv_accepted_masks.items():
            combined_segmentation[mask_data == accepted_mask_key] = accepted_mask_value

        combined_nii = nib.Nifti1Image(combined_segmentation, affine=mask.affine)
        nib.save(combined_nii, os.path.join(dst_label_dir1, scan_name))
        shutil.copyfile(os.path.join(images_dir, scan_name), os.path.join(dst_img_dir1, scan_name))
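
The same remapping can be expressed as a lookup table, which relabels a whole volume in a single indexing operation. This is a sketch of an alternative, not the method used above: it assumes the label volume holds only integer IDs in the range 0-13, and the toy `labels` array and the name `lut` are illustrative.

```python
import numpy as np

# Same mapping as in the notebook: BTCV ID -> merged ID (ID 10 is dropped).
btcv_accepted_masks = {1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9,
                       11: 10, 12: 11, 13: 12}

# Lookup table indexed by BTCV ID; unmapped IDs default to 0 (background).
lut = np.zeros(14, dtype=np.uint8)
for src_id, dst_id in btcv_accepted_masks.items():
    lut[src_id] = dst_id

# Toy label values; a real volume loaded as floats would first need .astype(np.int64).
labels = np.array([0, 9, 10, 11, 12, 13])
remapped = lut[labels]
print(remapped.tolist())  # [0, 9, 0, 10, 11, 12]
```

Because every voxel is looked up in the original array, this approach cannot accidentally remap a value twice, which is a risk when relabeling in place with overlapping source and target IDs.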