# Setup Logging & Imports

In [None]:
# Global verbosity switch: selects the logger level in the logging setup,
# is passed to the training apps, and enables the sample-projection plots.
DEBUG = True

### Setting up logging

In [None]:
import logging

# Root handler is configured once; the "pipeline" logger is what the notebook uses.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("pipeline")

# The logger's threshold follows the DEBUG flag: DEBUG-level messages are
# emitted only when the flag is set, otherwise INFO and above.
logger.setLevel(logging.DEBUG if DEBUG else logging.INFO)
if DEBUG:
    logger.debug("DEBUG mode is enabled. Detailed logs will be shown.")
else:
    logger.info("DEBUG mode is disabled. Only essential logs will be shown.")

### Imports

In [None]:
import os
# Must be in the environment before torch is imported
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch

# Pin device enumeration and visibility before the first CUDA query below
for _env_key, _env_value in (("CUDA_DEVICE_ORDER", "PCI_BUS_ID"), ("CUDA_VISIBLE_DEVICES", "1")):
    os.environ[_env_key] = _env_value

# Use the (single visible) GPU when there is one, otherwise fall back to the CPU
CUDA_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if CUDA_DEVICE.type == "cuda":
    logger.info(f"CUDA is available. Using device: {CUDA_DEVICE}")
else:
    logger.error("CUDA is not available. Please check your PyTorch installation. Using CPU instead...this will be slow.")

In [None]:
from pipeline.proj import load_projection_mat, reformat_sinogram, interpolate_projections, pad_and_reshape, divide_sinogram, undersample_projections, get_even_index
from pipeline.aggregate_prj import aggregate_saved_projections
from pipeline.aggregate_ct import aggregate_saved_recons
from pipeline.apply_model import apply_model_to_projections, load_model, apply_model_to_recons
from pipeline.utils import read_scans_agg_file, CTorchReconstruct
from pipeline.paths import Directories, Files
import scipy.io
import matlab.engine
import matplotlib.pyplot as plt
import numpy as np
import yaml
import importlib
import copy
from tqdm import tqdm
import gc

# Configuration

In [None]:
# Scans to convert to PyTorch tensors
# Put None if you don't have any scans to convert
# See the README for how to write this file correctly
# (the conversion cell reads one whitespace-separated
#  "patient scan scan_type" triple per non-empty line)
# NOTE: A scan whose nonstop-gated .pt file already exists is skipped
#       with a warning by the conversion cell.
#       If you would like to re-convert a scan,
#       you can delete the file manually
SCANS_CONVERT = 'txt_files/scans_convert_to_pt_liver.txt'
# SCANS_CONVERT = None

# Phase of the project (all data, models, etc. will be saved under this phase)
PHASE = "8"

# If this data version already exists in this phase, it will be loaded
# Otherwise it will be created using whatever the most updated data creation script is
DATA_VERSION = '03_liver'
# PANCREAS and LIVER are forwarded to FILES.get_projection_mat_filepath when
# locating the raw .mat projections
PANCREAS = False # Whether the data is pancreas
LIVER = True # Whether the data is liver


# Scans to use for training, val, and testing
# You should set this even if you are not doing aggregation
# See the README for how to write this file correctly
# NOTE: This will NOT throw an error if there are already aggregated scans
#       it will just give a warning and skip the aggregation step
SCANS_AGG = 'txt_files/scans_to_agg_liver.txt'
# SCANS_AGG = None

# Whether to use even indices, too, for the projections
# (when False, only the odd=True variant of each scan is processed)
USE_EVEN_INDICES = False

# Reconstruction method ('CTorch' or 'matlab')
RECON_METHOD = 'matlab'

# Whether to augment the data for the image domain
# This will only be used if you are doing image domain aggregation
AUGMENT_ID = True

# List of yaml files that contain configurations for the pipeline
# Each file should contain the parameters for a specific model/ensemble
CONFIG_FILES = [
    "config.yaml",
]

# Base directory
WORK_ROOT = "D:/NoahSilverberg/ngCBCT"

# NSG_CBCT Path where the raw matlab data is stored
# NSG_CBCT_PATH = "D:/MitchellYu/NSG_CBCT"

# Directory with all files specific to this phase/data version
# (currently only referenced by the commented-out aggregate directories below)
PHASE_DATAVER_DIR = os.path.join(
    WORK_ROOT, f"phase{PHASE}", f"DS{DATA_VERSION}"
)

# Locations of every input/output directory used by the pipeline.
# The "H:\Public" literals are raw strings: "\P" is not a valid escape
# sequence, so a plain literal triggers a warning (and will eventually be a
# syntax error). The resulting path strings are unchanged.
# NOTE(review): the model directories point at DS01_panc rather than
# DS{DATA_VERSION} -- presumably so models from that data version are reused;
# confirm this is intended.
DIRECTORIES = Directories(
    mat_projections_dir="H:/Public/Noah/liver_prj_mat",
    pt_projections_dir=os.path.join("H:/Public/Noah", "prj_pt_liver"),
    # projections_aggregate_dir=os.path.join(PHASE_DATAVER_DIR, "aggregates", "projections"),
    projections_model_dir=os.path.join(rf'H:\Public/Noah/phase{PHASE}/DS01_panc', "models", "projections"),
    projections_results_dir=os.path.join(rf'H:\Public/Noah/phase{PHASE}/DS{DATA_VERSION}', "results", "projections"),
    projections_gated_dir=os.path.join(r"H:\Public/Noah", "gated_liver", "prj_mat"),
    reconstructions_dir=os.path.join(rf'H:\Public/Noah/phase{PHASE}/DS{DATA_VERSION}', "reconstructions"),
    reconstructions_gated_dir=os.path.join(r"H:\Public/Noah", "gated_liver", "fdk_recon" if RECON_METHOD == 'matlab' else "ctorch_recon"),
    # pl_reconstructions_dir=os.path.join(r"H:\Public/Noah", "gated", "pl_recon", 'HF'),
    # images_aggregate_dir=os.path.join(PHASE_DATAVER_DIR, "aggregates", "images"),
    images_model_dir=os.path.join(rf'H:\Public/Noah/phase{PHASE}/DS01_panc', "models", "images"),
    images_results_dir=os.path.join(rf'H:\Public/Noah/phase{PHASE}/DS{DATA_VERSION}', "results", "images"),
    # error_results_dir= os.path.join(rf'H:\Public/Noah/phase{PHASE}/DS{DATA_VERSION}', "results", "error_results"),
)

# File-path helper built on top of the directory layout above
FILES = Files(DIRECTORIES)

# Data Preparation: projection interpolation

In [None]:
# Convert the raw .mat projections of every listed scan into gated and
# simulated nonstop-gated .pt tensors.
if SCANS_CONVERT is not None:
    # SCANS_CONVERT starts out as the path of a text file and is replaced by
    # the parsed list of (patient, scan, scan_type) tuples, which later cells
    # iterate over. Parse only while it is still a path so that re-running
    # this cell does not try to open() the already-parsed list.
    if isinstance(SCANS_CONVERT, str):
        with open(SCANS_CONVERT, "r") as f:
            parsed_scans = []
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # Each non-empty line must hold exactly three whitespace-separated fields
                patient, scan, scan_type = line.split()
                parsed_scans.append((patient, scan, scan_type))
        SCANS_CONVERT = parsed_scans

    logger.debug(f"Loaded scan list for conversion: {SCANS_CONVERT}")

    logger.info("Starting to process projection data...")

    for patient, scan, scan_type in SCANS_CONVERT:
        for odd in [True, False] if USE_EVEN_INDICES else [True]:
            g_path = FILES.get_projection_pt_filepath(patient, scan, scan_type, gated=True)
            ng_path = FILES.get_projection_pt_filepath(patient, scan, scan_type, gated=False, odd=odd)

            # Make sure the files do not already exist
            if os.path.exists(ng_path):
                logger.warning(f"Projection files already exist for patient {patient}, scan {scan}, type {scan_type}. Skipping...")
                continue

            # Load the projection data from the matlab files
            mat_path = FILES.get_projection_mat_filepath(patient, scan, scan_type, PANCREAS, LIVER)
            odd_index, angles, prj = load_projection_mat(mat_path)

            # Log shapes of loaded data
            logger.debug(f'Processing patient {patient}, scan {scan}, type {scan_type}')
            logger.debug(f'Loaded odd_index shape: {odd_index.shape}')
            logger.debug(f'Loaded angles shape: {angles.shape}')
            logger.debug(f'Loaded projection shape: {prj.shape}')

            # Flip and permute to get it in the right format
            prj_gcbct, angles1, _ = reformat_sinogram(prj, angles)

            # Log shapes after reformatting
            logger.debug(f'Reformatted projection shape: {prj_gcbct.shape}')

            # Simulate ngCBCT projections
            prj_ngcbct_li = interpolate_projections(prj_gcbct, odd_index, odd=odd)

            # Log shapes after interpolation
            logger.debug(f'Interpolated ngCBCT projection shape: {prj_ngcbct_li.shape}')

            # Split the projections into patches so they are good dimensions for the CNN
            # (3 patches for FF scans with more than 520 projections, otherwise 2)
            patches = 3 if scan_type == 'FF' and prj_gcbct.shape[0] > 520 else 2
            v_dim = 512 if scan_type == "HF" else 256
            logger.debug(f'Splitting projections into {patches}')
            combined_gcbct = divide_sinogram(pad_and_reshape(prj_gcbct), v_dim=v_dim, patches=patches)
            combined_ngcbct = divide_sinogram(pad_and_reshape(prj_ngcbct_li), v_dim=v_dim, patches=patches)

            # Log shapes after dividing sinograms
            logger.debug(f'Combined gCBCT shape: {combined_gcbct.shape}')
            logger.debug(f'Combined ngCBCT shape: {combined_ngcbct.shape}')

            logger.debug('Saving projections...')

            # Save the projections
            torch.save(combined_gcbct, g_path)
            torch.save(combined_ngcbct, ng_path)

            logger.debug(f'Done with patient {patient}, scan {scan}, type {scan_type}\n')

    logger.info("All projections saved successfully.")

    # Free up memory. The names only exist if at least one scan was actually
    # converted (all may have been skipped), so a NameError here is expected
    # and harmless; anything else should surface.
    try:
        del odd_index, angles, prj, prj_gcbct, angles1, prj_ngcbct_li, combined_gcbct, combined_ngcbct
    except NameError:
        pass
else:
    logger.info("No scans to convert. Skipping projection data processing.")

### DEBUG: Sample projections

In [None]:
def _first_scan_of_type(scans, wanted_type):
    """Return the first (patient, scan, scan_type) entry whose type matches, or None."""
    for entry in scans:
        if entry[2] == wanted_type:
            return entry
    return None


def _show_panels(panels, label):
    """Draw (title, 2-D array) pairs side by side in grayscale, 6x6 inches each."""
    plt.figure(figsize=(6 * len(panels), 6))
    for column, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), column)
        plt.imshow(image, cmap='gray')
        plt.title(f'{title} - {label}')
        plt.axis('off')
    plt.tight_layout()
    plt.show()


def _show_sample_projections(sample):
    """Plot the [0, 0] slice of the saved gated / nonstop-gated tensors of one scan.

    With USE_EVEN_INDICES the odd and even nonstop-gated variants are shown next
    to the gated projection, followed by binary difference maps against the
    gated projection. Otherwise only the odd variant is shown, because that is
    the only nonstop-gated file the conversion step writes in that mode.
    """
    patient, scan_num, scan_type = sample
    label = f'{scan_type} p{patient}_{scan_num}'

    def _first_slice(path):
        # Only the first 2-D slice is displayed; the full tensor is released
        # as soon as this helper returns.
        return torch.load(path)[0, 0, :, :].cpu().numpy()

    gated = _first_slice(FILES.get_projection_pt_filepath(patient, scan_num, scan_type, gated=True))
    ng_odd = _first_slice(FILES.get_projection_pt_filepath(patient, scan_num, scan_type, gated=False, odd=True))

    if not USE_EVEN_INDICES:
        _show_panels(
            [('Gated Projection', gated), ('Nonstop-Gated Projection', ng_odd)],
            label,
        )
        return

    ng_even = _first_slice(FILES.get_projection_pt_filepath(patient, scan_num, scan_type, gated=False, odd=False))
    _show_panels(
        [
            ('Gated Projection', gated),
            ('Odd Nonstop-Gated Projection', ng_odd),
            ('Even Nonstop-Gated Projection', ng_even),
        ],
        label,
    )

    # Difference maps between odd/gated and even/gated, with every non-zero
    # difference set to 1
    diff_panels = []
    for parity_name, ng_slice in (('Odd', ng_odd), ('Even', ng_even)):
        diff_map = np.abs(ng_slice - gated)
        diff_map[diff_map > 0] = 1
        diff_panels.append((f'Difference Map ({parity_name})', diff_map))
    _show_panels(diff_panels, label)


if DEBUG and SCANS_CONVERT is not None:
    # Pick the first HF scan and the first FF scan from the conversion list
    hf_scan = _first_scan_of_type(SCANS_CONVERT, "HF")
    ff_scan = _first_scan_of_type(SCANS_CONVERT, "FF")

    # Show the gated and nonstop-gated projections on subplots
    if hf_scan:
        _show_sample_projections(hf_scan)

    # Repeat for FF scan
    if ff_scan:
        _show_sample_projections(ff_scan)

# Aggregate projections for train/val/test

In [None]:
# Build the aggregated TRAIN / VALIDATION projection sets, unless they are
# already present or no aggregation directory is configured.
if SCANS_AGG is None:
    logger.info("No scans to aggregate. Skipping projection data aggregation.")
else:
    scans_agg, scan_type = read_scans_agg_file(SCANS_AGG)
    logger.debug(f"Loaded scan list for aggregation: {scans_agg}")

    # Aggregation only runs into an empty, configured directory
    agg_dir = DIRECTORIES.projections_aggregate_dir
    if agg_dir is None:
        logger.warning("No aggregation directory specified. Skipping projection data aggregation.")
    elif os.listdir(agg_dir):
        logger.warning(f"Aggregated projection data already exists in {agg_dir}. Skipping...")
    else:
        logger.info("Starting to aggregate projection data...")

        for split in ('TRAIN', 'VALIDATION'):
            split_scans = scans_agg[split]
            if len(split_scans) == 0:
                logger.debug(f"No scans to aggregate for {scan_type} {split}. Skipping aggregation.")
                continue

            # Nonstop-gated inputs: every odd file first, then (optionally) every even file
            parities = [True, False] if USE_EVEN_INDICES else [True]
            ng_paths = [
                FILES.get_projection_pt_filepath(pid, sid, stype, gated=False, odd=parity)
                for parity in parities
                for pid, sid, stype in split_scans
            ]
            ng_agg_path = FILES.get_projections_aggregate_filepath(split, gated=False)
            aggregate_saved_projections(ng_paths, ng_agg_path)
            logger.debug("Done with nonstop-gated...")

            # Gated targets: listed twice when both parities are used, so they
            # line up one-to-one with the nonstop-gated inputs
            g_paths = [
                FILES.get_projection_pt_filepath(pid, sid, stype, gated=True)
                for pid, sid, stype in split_scans
            ]
            if USE_EVEN_INDICES:
                g_paths = g_paths + g_paths
            g_agg_path = FILES.get_projections_aggregate_filepath(split, gated=True)
            aggregate_saved_projections(g_paths, g_agg_path)
            logger.debug("Done with gated...")

            logger.debug(f"Aggregated projections saved for {scan_type} {split}.\n")

        logger.info("Projection data aggregation completed successfully.")
        logger.info("Aggregated projection data saved in: %s", agg_dir)

# Training PD CNN

In [None]:
# Train the projection-domain (PD) model(s) described by each configuration file.
for config_file in CONFIG_FILES:
    # Load the yaml configuration file
    with open(config_file, "r") as f:
        config = yaml.safe_load(f)

    logger.debug(f"Loaded configuration from {config_file}")

    # Skip this config if the user has set PD_training to False
    if not config['PD_settings']['training']:
        logger.info(f"Skipping PD training for {config_file} as PD training is set to False.")
        continue

    # Resolve the training application: "module.Class" inside the pipeline package
    module_name, class_name = config['PD_settings']['training_app'].rsplit('.', 1)
    module = importlib.import_module("pipeline." + module_name)
    cls = getattr(module, class_name)

    logger.debug(f"Loaded class {class_name} from module {module_name}")

    # Get the model version (for naming purposes)
    model_version = config['PD_settings']['model_version']

    # Get the ensemble size, and loop through it
    ensemble_size = config['PD_settings']['ensemble_size']
    for i in range(ensemble_size):
        # If we are training an ensemble, we add an identifier to the model version
        if ensemble_size > 1:
            # Deepcopy config so we don't affect the original
            cfg = copy.deepcopy(config)
            cfg['PD_settings']['model_version'] = f"{model_version}_{i+1:02}" # e.g., "v1_01"
        else:
            cfg = config

        # Add the data version to the configuration
        cfg['PD_settings']['data_version'] = DATA_VERSION

        # NOTE: Resuming from PD_settings.start_checkpoint (handing epoch /
        #       optimizer / state_dict to the training app) is currently
        #       disabled, so the checkpoint is deliberately NOT loaded here:
        #       loading it without using it only kept the whole checkpoint
        #       in memory for the duration of training.

        # Instantiate with the loaded configuration
        instance = cls(cfg, "PROJ", DEBUG, FILES)

        logger.info(f"Going to try training the {i + 1}-th model with configuration from {config_file}...")

        # Run the training
        instance.main()

        logger.info(f"Finished training the {i + 1}-th model.\n")

        # Drop the trainer before the next ensemble member is built
        del instance, cfg
        gc.collect()

    # Free up memory
    del module, cls, config, module_name, class_name

# Apply PD model and FDK to all scans

In [None]:
if SCANS_AGG is None:
    logger.info("Skipping model application as the aggregation scan list is not provided.")
else:
    if RECON_METHOD == 'matlab':
        eng = matlab.engine.start_matlab()

        matlab_script_path = 'D:/NoahSilverberg/CudaRecon'
        cuda_tools = r'D:\NoahSilverberg\CudaTools'
        matlab_functions = r'D:\NoahSilverberg\CommonMatlabFunctions_HZ'

        eng.addpath(cuda_tools, nargout=0)
        eng.addpath(matlab_functions, nargout=0)
        eng.addpath(matlab_script_path, nargout=0)

    # Loop through the configurations again
    for config_file in CONFIG_FILES:
        # Load the yaml configuration file
        with open(config_file, "r") as f:
            config = yaml.safe_load(f)

        logger.debug(f"Loaded configuration from {config_file}")

        # Get the ensemble size, and loop through it
        ensemble_size = config['PD_settings']['ensemble_size']
        for i in range(ensemble_size):
            model_version = config['PD_settings']['model_version']

            # If we are training an ensemble, we add an identifier to the model version
            if ensemble_size > 1:
                model_version = f"{model_version}_{i+1:02}"

            # Load the trained PD model onto the GPU
            model_path = FILES.get_model_filepath(model_version, "PROJ")
            PD_model = load_model(config['PD_settings']['network_name'], config['PD_settings']['network_kwargs'], model_path, CUDA_DEVICE)

            passthrough_count = config['PD_settings']['passthrough_count']

            scans_agg, scan_type = read_scans_agg_file(SCANS_AGG)
            if scan_type != config['PD_settings']['scan_type']:
                raise ValueError(f"Scan type in aggregation file ({scan_type}) does not match scan type in config ({config['PD_settings']['scan_type']}).")
            
            for split in ['TRAIN', 'VALIDATION', 'TEST']:
                for patient, scan, scan_type in tqdm(scans_agg[split], desc=f"Applying model {model_version} to projections split {split}"):
                    # Get the matlab dict for the nonstop-gated projections
                    mat_path = FILES.get_projection_mat_filepath(patient, scan, scan_type, PANCREAS, LIVER)

                    # Get the acquired nonstop-gated projections, indices, and angles from the .mat file
                    odd_index, angles, prj = load_projection_mat(mat_path)

                    # Flip and permute to get it in the right format
                    prj_gcbct_, angles1, flipped = reformat_sinogram(prj, angles)

                    for odd in [True, False] if USE_EVEN_INDICES else [True]:

                        # Simulate ngCBCT projections
                        prj_ngcbct_li_ = interpolate_projections(prj_gcbct_.detach(), odd_index, odd=odd)

                        # Reformat the projections to be in the right shape for the CNN
                        prj_gcbct = pad_and_reshape(prj_gcbct_).detach()
                        prj_ngcbct_li = pad_and_reshape(prj_ngcbct_li_).detach()

                        for passthrough_num in range(passthrough_count):
                        
                            # Save paths for the the ground truth, CNN-processed, and nonstop-gated projections
                            cnn_path = FILES.get_projections_results_filepath(model_version, patient, scan, scan_type, gated=False, odd=odd, passthrough_num=passthrough_num if passthrough_count > 1 else None)
                            g_recon_path = FILES.get_recon_filepath(model_version, patient, scan, scan_type, gated=True, odd=odd)
                            cnn_recon_path = FILES.get_recon_filepath(model_version, patient, scan, scan_type, gated=False, odd=odd, passthrough_num=passthrough_num if passthrough_count > 1 else None)
                            nsg_recon_path = FILES.get_recon_filepath('nsg', patient, scan, scan_type, gated=False, odd=odd)
                            raw_recon_path = FILES.get_recon_filepath('raw', patient, scan, scan_type, gated=False, odd=odd)

                            # For training we only save the reconstructions
                            if split == 'TRAIN'and os.path.exists(cnn_recon_path):
                                logger.info(f"CNN projections and reconstructions already exist for {scan_type} p{patient}_{scan} for model {model_version}. Skipping...")
                                continue

                            # For validation and testing we save both projections and reconstructions
                            if split != 'TRAIN' and os.path.exists(cnn_recon_path): # and os.path.exists(cnn_path):
                                logger.info(f"CNN projections and reconstructions already exist for {scan_type} p{patient}_{scan} for model {model_version}. Skipping...")
                                continue

                            if os.path.exists(cnn_path):
                                cnn_mat = scipy.io.loadmat(cnn_path)
                                if split == 'TRAIN':
                                    os.remove(cnn_path)  # Remove the .mat file to save space

                                logger.debug(f"Loaded existing CNN projections for {scan_type} p{patient}_{scan} from {cnn_path}.")
                            else:
                                g_mat, cnn_mat = apply_model_to_projections(PD_model, scan_type, odd_index, angles, prj_gcbct, prj_ngcbct_li, odd, CUDA_DEVICE, train_at_inference=config['PD_settings']['train_at_inference'], _batch_size=16)
                                logger.debug(f"Applied model {model_version} to projections for {scan_type} p{patient}_{scan}.")

                                # NOTE: Uncomment this is you want to save the CNN output projections
                                #       these take up a lot of space and we don't really need them typically...so I commented it out
                                # if split != 'TRAIN':
                                #     scipy.io.savemat(cnn_path, cnn_mat)
                                #     logger.debug(f"Saved CNN projections for {scan_type} p{patient}_{scan} to {cnn_path}.")

                            # We only need to do & save FDK recons once
                            if not os.path.exists(g_recon_path):   
                                if RECON_METHOD == 'matlab':
                                    prj_ = matlab.single(g_mat['prj'].astype(np.float32))
                                    angles_ = matlab.single(g_mat['angles'].astype(np.float32))
                                    del g_mat
                                    if scan_type == "HF":
                                        g_fdk = np.array(eng.HFrecon_nsFDK(prj_, angles_, nargout=1))
                                    else:
                                        g_fdk = np.array(eng.FFrecon_reconFDK(prj_, angles_, nargout=1))
                                    del prj_, angles_

                                    # Convert to Pytorch tensor
                                    g_fdk = torch.from_numpy(g_fdk).detach()
                                    g_fdk = torch.permute(g_fdk, (2, 0, 1))
                                elif RECON_METHOD == 'CTorch':
                                    g_fdk = CTorchReconstruct(torch.flip(torch.from_numpy(g_mat['prj'].copy()), (0, 2) if flipped else (2,)), angles - np.pi/2, scan_type, CUDA_DEVICE)
                                else:
                                    raise ValueError(f"Unknown reconstruction method: {RECON_METHOD}")

                                if scan_type == "FF":
                                    g_fdk = g_fdk[:, 128:-128, 128:-128].clone()
                                
                                # Save the recon results as .pt
                                torch.save(g_fdk, g_recon_path)
                                del g_fdk
                                logger.debug(f"Saved gated reconstruction for {scan_type} p{patient}_{scan} to {g_recon_path}.")
                            else:
                                logger.debug(f"Gated reconstruction already exists for {scan_type} p{patient}_{scan}.")

                            # We only need to do & save FDK recons once
                            if not os.path.exists(raw_recon_path) and split != 'TRAIN':
                                if not odd:
                                    odd_index = get_even_index(odd_index, prj_gcbct_.shape[0]) # NOTE: this is actually even indices now
                                if RECON_METHOD == 'matlab': 
                                    raw_prj = undersample_projections(prj_gcbct_, odd_index)[0]
                                    raw_prj = raw_prj.contiguous()
                                    prj_ = matlab.single(raw_prj[odd_index.astype(np.int64) - 1].numpy().astype(np.float32))
                                    del raw_prj
                                    angles_ = matlab.single(angles1[odd_index.astype(np.int64) - 1].numpy().astype(np.float32))
                                    if scan_type == "HF":
                                        raw_fdk = np.array(eng.HFrecon_nsFDK(prj_, angles_, nargout=1))
                                    else:
                                        raw_fdk = np.array(eng.FFrecon_reconFDK(prj_, angles_, nargout=1))
                                    del prj_, angles_

                                    # Convert to Pytorch tensor
                                    raw_fdk = torch.from_numpy(raw_fdk).detach()
                                    raw_fdk = torch.permute(raw_fdk, (2, 0, 1))
                                elif RECON_METHOD == 'CTorch':
                                    raw_fdk = CTorchReconstruct(torch.flip(prj_gcbct_[odd_index.astype(np.int64) - 1], (0, 2) if flipped else (2,)), (torch.flip(torch.flip(angles, (0,))[odd_index.astype(np.int64) - 1], (0,)) if flipped else angles[odd_index.astype(np.int64) - 1]) - np.pi/2, scan_type, CUDA_DEVICE)

                                if scan_type == "FF":
                                    raw_fdk = raw_fdk[:, 128:-128, 128:-128].clone()
                                
                                # Save the recon results as .pt
                                torch.save(raw_fdk, raw_recon_path)
                                del raw_fdk
                                logger.debug(f"Saved raw reconstruction for {scan_type} p{patient}_{scan} to {raw_recon_path}.")
                            else:
                                logger.debug(f"Raw reconstruction already exists for {scan_type} p{patient}_{scan}.")

                            # # NOTE: This only needs to be done if you are using an auxiliary model for error prediction -- otherwise it's not needed
                            # # We only need to do & save FDK recons once
                            # if not os.path.exists(nsg_recon_path):
                            #     if RECON_METHOD == 'matlab':
                            #         prj_ = matlab.single(prj_ngcbct_li_.numpy().astype(np.float32))
                            #         angles_ = matlab.single(angles1.numpy().astype(np.float32))
                            #         if scan_type == "HF":
                            #             nsg_fdk = np.array(eng.HFrecon_nsFDK(prj_, angles_, nargout=1))
                            #         else:
                            #             nsg_fdk = np.array(eng.FFrecon_reconFDK(prj_, angles_, nargout=1))
                            #         del prj_, angles_

                            #         # Convert to Pytorch tensor
                            #         nsg_fdk = torch.from_numpy(nsg_fdk).detach()
                            #         nsg_fdk = torch.permute(nsg_fdk, (2, 0, 1))
                            #     elif RECON_METHOD == 'CTorch':
                            #         nsg_fdk = CTorchReconstruct(torch.flip(prj_ngcbct_li_, (0, 2) if flipped else (2,)), angles - np.pi/2, scan_type, CUDA_DEVICE)

                            #     if scan_type == "FF":
                            #         nsg_fdk = nsg_fdk[:, 128:-128, 128:-128].clone()
                                
                            #     # Save the recon results as .pt
                            #     torch.save(nsg_fdk, nsg_recon_path)
                            #     del nsg_fdk
                            #     logger.debug(f"Saved nonstop-gated reconstruction for {scan_type} p{patient}_{scan} to {nsg_recon_path}.")
                            # else:
                            #     logger.debug(f"Nonstop-gated reconstruction already exists for {scan_type} p{patient}_{scan}.")

                            # We only need to do & save FDK recons once
                            if not os.path.exists(cnn_recon_path):
                                if RECON_METHOD == 'matlab':
                                    prj_ = matlab.single(cnn_mat['prj'].astype(np.float32))
                                    angles_ = matlab.single(cnn_mat['angles'].astype(np.float32))
                                    del cnn_mat
                                    if scan_type == "HF":
                                        cnn_fdk = np.array(eng.HFrecon_nsFDK(prj_, angles_, nargout=1))
                                    else:
                                        cnn_fdk = np.array(eng.FFrecon_reconFDK(prj_, angles_, nargout=1))
                                    del prj_, angles_

                                    # Convert to Pytorch tensor
                                    cnn_fdk = torch.from_numpy(cnn_fdk).detach()
                                    cnn_fdk = torch.permute(cnn_fdk, (2, 0, 1))
                                elif RECON_METHOD == 'CTorch':
                                    cnn_fdk = CTorchReconstruct(torch.flip(torch.from_numpy(cnn_mat['prj'].copy()), (0, 2) if flipped else (2,)), angles - np.pi/2, scan_type, CUDA_DEVICE)
                                    del cnn_mat

                                if scan_type == "FF":
                                    cnn_fdk = cnn_fdk[:, 128:-128, 128:-128].clone()

                                # Save the recon results as .pt
                                torch.save(cnn_fdk, cnn_recon_path)
                                logger.debug(f"Saved CNN reconstruction for {scan_type} p{patient}_{scan} to {cnn_recon_path} with shape {cnn_fdk.shape}.")
                                del cnn_fdk
                            else:
                                logger.debug(f"CNN reconstruction already exists for {scan_type} p{patient}_{scan}.")

                            logger.debug(f"Saved projections for {scan_type} p{patient}_{scan}.")

                        del prj_gcbct, prj_ngcbct_li, prj_ngcbct_li_

                    del odd_index, angles, prj, prj_gcbct_, angles1

            # Free up memory
            del PD_model

    if RECON_METHOD == 'matlab':
        eng.quit()

    logger.info("All models applied to projections.")

# Aggregate CT volumes for train/val/test

In [None]:
if SCANS_AGG is not None:
    scans_agg, scan_type = read_scans_agg_file(SCANS_AGG)
    logger.debug(f"Loaded scan list for aggregation: {scans_agg}")

    # Which index halves to aggregate: odd only, or odd followed by even.
    # Every path list below is ordered "all odd entries, then all even entries".
    odd_flags = [True, False] if USE_EVEN_INDICES else [True]

    def _recon_paths(version, scans, gated=False):
        """Return recon file paths for every scan in `scans`, for each entry of `odd_flags`.

        `version` is forwarded as the first argument of `FILES.get_recon_filepath`
        (e.g. 'nsg', a PD model version, or None for gated recons).
        """
        return [
            FILES.get_recon_filepath(version, pt, sc, st, gated=gated, odd=odd)
            for odd in odd_flags
            for pt, sc, st in scans
        ]

    def _id_result_paths(version, scans):
        """Return ID-model result file paths for every scan in `scans`, for each entry of `odd_flags`."""
        return [
            FILES.get_images_results_filepath(version, pt, sc, odd=odd)
            for odd in odd_flags
            for pt, sc, _ in scans
        ]

    for config_file in CONFIG_FILES:

        # Load the yaml configuration file
        with open(config_file, "r") as f:
            config = yaml.safe_load(f)

        logger.debug(f"Loaded configuration from {config_file}")

        # Skip if they want on-the-fly aggregation
        if config['ID_settings']['augmented_id_training_set']:
            logger.warning(f"Skipping reconstruction data aggregation for {config_file} as on-the-fly aggregation is enabled.")
            continue

        # Get the ensemble size, and loop through it
        ensemble_size = config['ID_settings']['ensemble_size']

        # If the input type is a dict, we assume it is an error-predicting auxiliary model
        # The input type is of the form {'PD': ['v1', 10], 'ID': ['v2', 1]}
        error = isinstance(config['ID_settings']['input_type'], dict)

        # For simplicity, we only allow one ensemble size for error-predicting auxiliary models
        # There's really not much point in using an ensemble for this
        if ensemble_size > 1 and error:
            raise ValueError("For error-predicting auxiliary models, please use an ensemble size of 1.")

        for i in range(ensemble_size):
            # For an ensemble, optionally add a per-member identifier to the input type
            if ensemble_size > 1 and config['ID_settings']['input_type_match_ensemble']:
                input_type = f"{config['ID_settings']['input_type']}_{i+1:02}" # e.g., "v1_01"
            else:
                input_type = config['ID_settings']['input_type']

            # Version name used to locate the aggregate directory/files for the inputs
            agg_version = input_type['ID'][0] if error else input_type

            # Only aggregate reconstructions if they don't already exist
            agg_dir = DIRECTORIES.get_images_aggregate_dir(agg_version)
            if os.listdir(agg_dir):
                logger.warning(f"Aggregated reconstruction data already exists in {os.path.dirname(agg_dir)}. Skipping...")
                continue

            logger.info("Starting to aggregate reconstruction data...")

            # Aggregate and save reconstruction data sets
            for split in ['TRAIN', 'VALIDATION']:
                scans = scans_agg[split]
                ng_agg_path = FILES.get_images_aggregate_filepath(agg_version, split, truth=False, error=error)

                if len(scans) == 0:
                    logger.debug(f"No scans to aggregate for {scan_type} {split}. Skipping aggregation.")
                    continue

                # --- Model inputs (nonstop-gated side) ---
                if error:
                    if input_type['PD'][1] != 1 or input_type['ID'][1] != 1:
                        logger.warning("Skipping aggregation for error-predicting auxiliary model since passthrough count is greater than 1. Please use on-the-fly aggregation during training instead.")
                        continue

                    # Three parallel path lists: nonstop-gated recon, PD model recon, ID model output
                    ng_paths = [
                        _recon_paths('nsg', scans),
                        _recon_paths(input_type['PD'][0], scans),
                        _id_result_paths(input_type['ID'][0], scans),
                    ]
                else:
                    ng_paths = _recon_paths(input_type, scans)

                aggregate_saved_recons(ng_paths, ng_agg_path, scan_type)
                logger.debug("Done with nonstop-gated...")

                # --- Targets (gated side) ---
                g_agg_path = FILES.get_images_aggregate_filepath(input_type['ID'][0] if error else 'fdk', split, truth=True, error=error)
                if os.path.exists(g_agg_path):
                    logger.warning(f"Gated aggregation file {g_agg_path} already exists. Skipping aggregation for gated data.")
                    continue

                if error:
                    # Two parallel path lists: ID model output and gated recon
                    g_paths = [
                        _id_result_paths(input_type['ID'][0], scans),
                        _recon_paths(None, scans, gated=True),
                    ]
                else:
                    g_paths = [FILES.get_recon_filepath(None, pt, sc, st, gated=True) for pt, sc, st in scans]
                    if USE_EVEN_INDICES:
                        # Repeat the list so it has one entry per input entry (odd + even).
                        # NOTE(review): presumably the same gated recon is the target for both halves -- confirm.
                        g_paths *= 2

                aggregate_saved_recons(g_paths, g_agg_path, scan_type, compute_errors=error)
                logger.debug("Done with gated...")

                logger.debug(f"Aggregated reconstructions saved for {scan_type} {split}.\n")

            # Only reported when aggregation actually ran (not when it was skipped above)
            logger.info("Reconstruction data aggregation completed successfully.")
            logger.info("Aggregated reconstruction data saved in: %s", agg_dir)
else:
    logger.info("No scans to aggregate. Skipping reconstruction data aggregation.")

# Train ID CNN

In [None]:
# Train one image-domain (ID) model per ensemble member for every configuration file.
for config_file in CONFIG_FILES:
    # Load the yaml configuration file
    with open(config_file, "r") as f:
        config = yaml.safe_load(f)

    logger.debug(f"Loaded configuration from {config_file}")

    base_settings = config['ID_settings']

    # Nothing to do for configurations that have ID training switched off
    if not base_settings['training']:
        logger.info(f"Skipping ID training for {config_file} as ID training is set to False.")
        continue

    # Resolve the training application class from its "module.ClassName" string
    module_name, class_name = base_settings['training_app'].rsplit('.', 1)
    module = importlib.import_module("pipeline." + module_name)
    cls = getattr(module, class_name)

    logger.debug(f"Loaded class {class_name} from module {module_name}")

    # Base model version (for naming purposes) and number of ensemble members
    model_version = base_settings['model_version']
    ensemble_size = base_settings['ensemble_size']

    for i in range(ensemble_size):
        if ensemble_size > 1:
            # Work on a deep copy so the per-member suffixes never touch the original config
            member_tag = f"{i+1:02}" # e.g., "01"
            cfg = copy.deepcopy(config)
            member_settings = cfg['ID_settings']
            member_settings['model_version'] = f"{model_version}_{member_tag}" # e.g., "v1_01"
            if base_settings['input_type_match_ensemble']:
                member_settings['input_type'] = f"{member_settings['input_type']}_{member_tag}"
        else:
            # Single model: the original config object is used (and updated) directly
            cfg = config
            member_settings = cfg['ID_settings']

        # Add the data version to the configuration
        member_settings['data_version'] = DATA_VERSION

        # NOTE(review): the scan list is read here but not used afterwards in this cell;
        # only the augmentation flag is stored in the configuration -- confirm whether the read is still needed.
        scans_agg, _ = read_scans_agg_file(SCANS_AGG)
        member_settings['augment_id'] = AUGMENT_ID

        # Resume from a saved checkpoint only if one is configured and its file exists
        start_checkpoint = member_settings['start_checkpoint']
        checkpoint_path = FILES.get_model_filepath(model_version=member_settings['model_version'], domain='IMAG', checkpoint=start_checkpoint, ensure_exists=False)
        resume = start_checkpoint is not None and os.path.exists(checkpoint_path)

        if not resume:
            instance = cls(cfg, "IMAG", DEBUG, FILES)
        else:
            saved = torch.load(checkpoint_path)
            epoch = saved['epoch']
            state_dict = saved['state_dict']
            optimizer = saved['optimizer']

            # Instantiate with the loaded training state
            instance = cls(cfg, "IMAG", DEBUG, FILES, epoch, optimizer, state_dict)

        logger.info(f"Going to try training the {i + 1}-th model with configuration from {config_file}...")

        # Run the training
        instance.main()

        logger.info(f"Finished training the {i + 1}-th model.\n")

        # Drop references to this member's objects before the next one is built
        del instance, cfg, member_settings
        gc.collect()

    # Free up memory
    del module, cls, config, base_settings, module_name, class_name

# Pass all samples through the ID model

In [None]:
# Apply every trained ID model (all ensemble members, all passthroughs) to the
# VALIDATION and TEST scans and save the outputs, skipping results that already exist.
if SCANS_AGG is None:
    logger.info("Skipping model application as the aggregation scan list is not provided.")
else:
    # Loop through the configurations again
    for config_file in CONFIG_FILES:
        # Load the yaml configuration file
        with open(config_file, "r") as f:
            config = yaml.safe_load(f)

        logger.debug(f"Loaded configuration from {config_file}")

        # Check if the model is evidential or error
        is_evidential = config['ID_settings']['is_evidential']
        # A dict input type, e.g. {'PD': ['v1', 10], 'ID': ['v2', 1]}, marks an error-predicting model
        error = isinstance(config['ID_settings']['input_type'], dict)
        if is_evidential or error:
            # For evidential models (deterministic), passthrough_count should be 1
            # For error-predicting auxiliary models, we enforce that passthrough_count is 1
            #     although they technically don't need to be deterministic. It's kind of unnecessary to use >1.
            if config['ID_settings']['passthrough_count'] > 1:
                raise ValueError("Evidential/error models are deterministic. 'passthrough_count' should be 1.")

        passthrough_count = config['ID_settings']['passthrough_count']

        # The scan list does not depend on the ensemble member, so read and validate it
        # once per configuration, before any model is loaded.
        scans_agg, scan_type = read_scans_agg_file(SCANS_AGG, list_=False)
        if scan_type != config['ID_settings']['scan_type']:
            raise ValueError(f"Scan type in aggregation file ({scan_type}) does not match scan type in config ({config['ID_settings']['scan_type']}).")

        # Get the ensemble size, and loop through it
        ensemble_size = config['ID_settings']['ensemble_size']
        for i in range(ensemble_size):
            model_version = config['ID_settings']['model_version']

            # Deep copy: for error models the input type is a nested dict/list that is
            # edited below. Editing the config's own object would make the suffixes
            # accumulate across ensemble members ("v1_01", then "v1_01_02", ...).
            input_type = copy.deepcopy(config['ID_settings']['input_type'])

            # For an ensemble, add a per-member identifier to the model version (and input type)
            if ensemble_size > 1:
                model_version = f"{model_version}_{i+1:02}"
                if config['ID_settings']['input_type_match_ensemble']:
                    if error:
                        input_type['PD'][0] = f"{input_type['PD'][0]}_{i+1:02}"
                        input_type['ID'][0] = f"{input_type['ID'][0]}_{i+1:02}"
                    else:
                        input_type = f"{input_type}_{i+1:02}"

            # Load the trained ID model onto the GPU
            model_path = FILES.get_model_filepath(model_version, "IMAG")
            ID_model = load_model(config['ID_settings']['network_name'], config['ID_settings']['network_kwargs'], model_path, CUDA_DEVICE)

            # NOTE: this loop rebinds `scan_type` to each scan's own type
            for patient, scan, scan_type in tqdm(scans_agg["VALIDATION"] + scans_agg["TEST"], desc=f"Applying model {model_version} to reconstructions"):
                for passthrough_num in range(passthrough_count):

                    # Determine input path
                    # For evidential models, which are deterministic, we assume the input should also be
                    # deterministic. We use passthrough 0 from the stochastic predecessor model.
                    if is_evidential:
                        input_passthrough = 0
                    elif error:
                        input_passthrough = 0 if input_type['PD'][1] > 1 else None
                    else:
                        input_passthrough = passthrough_num if passthrough_count > 1 else None

                    for odd in [True, False] if USE_EVEN_INDICES else [True]:

                        if error:
                            # Error models take three inputs: nonstop-gated recon, PD model recon, ID model output
                            ng_pt_path = [
                                FILES.get_recon_filepath('nsg', patient, scan, scan_type, gated=False, odd=odd),
                                FILES.get_recon_filepath(input_type['PD'][0], patient, scan, scan_type, gated=False, passthrough_num=input_passthrough, odd=odd),
                                FILES.get_images_results_filepath(input_type['ID'][0], patient, scan, odd=odd),
                            ]

                            ng_path = FILES.get_error_results_filepath(model_version, patient, scan, passthrough_num=None, odd=odd)
                        else:
                            ng_pt_path = FILES.get_recon_filepath(input_type, patient, scan, scan_type, gated=False, passthrough_num=input_passthrough, odd=odd)

                            ng_path = FILES.get_images_results_filepath(
                                model_version, patient, scan, passthrough_num=passthrough_num if passthrough_count > 1 else None, odd=odd
                            )

                        # Results are written once; existing files are never overwritten
                        if os.path.exists(ng_path):
                            logger.info(f"ID CNN results already exist for {scan_type} p{patient}_{scan} for model {model_version} (passthrough {passthrough_num}). Skipping...")
                            continue

                        # Apply model and save results
                        results = apply_model_to_recons(
                            ID_model, ng_pt_path, CUDA_DEVICE,
                            scan_type=scan_type,
                            train_at_inference=config['ID_settings']['train_at_inference'],
                            _batch_size=8,
                        )

                        if is_evidential:
                            gamma, nu, alpha, beta = results
                            # Save all 4 outputs in a dictionary. Squeeze the channel dimension.
                            results_dict = {
                                'gamma': gamma.squeeze(1),
                                'nu': nu.squeeze(1),
                                'alpha': alpha.squeeze(1),
                                'beta': beta.squeeze(1)
                            }
                            torch.save(results_dict, ng_path)
                            del gamma, nu, alpha, beta, results_dict
                        else:
                            # Original behavior for non-evidential models
                            torch.save(results, ng_path)

                        logger.debug(f"Saved results for {scan_type} p{patient}_{scan} (passthrough {passthrough_num}).")
                        del results

            # Free up memory
            del ID_model

    logger.info("All ID models applied to reconstructions.")
