### Imports and logging

In [None]:
import logging

# Root logger at INFO; loggers without their own level inherit that.
# (basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(level=logging.INFO)
# This notebook's logger is lowered to DEBUG so its debug messages are emitted too.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)

In [None]:
import os

# Restrict CUDA to a single physical GPU. CUDA reads these variables when it is
# initialised, so they are set *before* importing torch: nothing imported here
# can have initialised CUDA before they take effect. (In a live kernel where
# CUDA was already initialised, changing them has no effect; restart the kernel.)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # number GPUs by PCI bus ID, matching nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch  # noqa: E402  (must come after the CUDA environment setup above)

if torch.cuda.is_available():
    # With CUDA_VISIBLE_DEVICES="1", "cuda:0" is the only visible device,
    # i.e. physical GPU 1.
    CUDA_DEVICE = torch.device("cuda:0")
    logger.info("CUDA is available. Using device: %s", CUDA_DEVICE)
else:
    logger.error("CUDA is not available. Please check your PyTorch installation. Using CPU instead...this will be slow.")
    CUDA_DEVICE = torch.device("cpu")

In [None]:
from pipeline.proj import load_projection_mat, reformat_sinogram, interpolate_projections, pad_and_reshape, divide_sinogram
from pipeline.aggregate_prj import aggregate_saved_projections
from pipeline.aggregate_ct import aggregate_saved_recons
from pipeline.apply_model import apply_model_to_projections, load_model, apply_model_to_recons
from pipeline.utils import ensure_dir, read_scans_agg_file
from pipeline.paths import Directories, Files
import scipy.io
import matplotlib.pyplot as plt
import numpy as np
import yaml
import importlib
import copy
from tqdm import tqdm
import gc
import tigre.utilities.gpu as gpu
from pipeline.FDK_half.FDK_half import FDKHalf
from pipeline.utils import get_geometry
import time

### Pick scan

In [None]:
# Scan selection: string identifiers used to build the paths below.
PHASE, DATA_VERSION = "7", "13"
PATIENT, SCAN, SCAN_TYPE = "01", "01", "HF"

# PD model: the network applied to the projections.
PD_MODEL_VERSION, PD_NETWORK_NAME = "MK7_07", "IResNet"

# ID model: the network applied to the reconstructions.
ID_MODEL_VERSION, ID_NETWORK_NAME = "MK7_07", "IResNet"

### File system

In [None]:
# Base working directory; the .pt projections, gated data and per-phase
# outputs configured below all live under it.
WORK_ROOT = "D:/NoahSilverberg/ngCBCT"

# Root of the NSG_CBCT dataset, where the raw MATLAB (.mat) projection data is stored
NSG_CBCT_PATH = "D:/MitchellYu/NSG_CBCT"

# Directory holding everything specific to this phase / data version
PHASE_DATAVER_DIR = os.path.join(
    WORK_ROOT, f"phase{PHASE}", f"DS{DATA_VERSION}"
)

DIRECTORIES = Directories(
    # Follow SCAN_TYPE instead of hard-coding "HF" so the raw-projection
    # directory stays consistent with the selected scan (same path for "HF").
    # NOTE(review): assumes other scan types use the same
    # data/prj/<SCAN_TYPE>/mat layout — confirm.
    mat_projections_dir=os.path.join(NSG_CBCT_PATH, f"data/prj/{SCAN_TYPE}/mat"),
    pt_projections_dir=os.path.join(WORK_ROOT, "prj_pt"),
    projections_aggregate_dir=os.path.join(PHASE_DATAVER_DIR, "aggregates", "projections"),
    projections_model_dir=os.path.join(PHASE_DATAVER_DIR, "models", "projections"),
    projections_results_dir=os.path.join(PHASE_DATAVER_DIR, "results", "projections"),
    projections_gated_dir=os.path.join(WORK_ROOT, "gated", "prj_mat"),
    reconstructions_dir=os.path.join(PHASE_DATAVER_DIR, "reconstructions"),
    reconstructions_gated_dir=os.path.join(WORK_ROOT, "gated", "fdk_recon"),
    images_aggregate_dir=os.path.join(PHASE_DATAVER_DIR, "aggregates", "images"),
    images_model_dir=os.path.join(PHASE_DATAVER_DIR, "models", "images"),
    images_results_dir=os.path.join(PHASE_DATAVER_DIR, "results", "images"),
)

FILES = Files(DIRECTORIES)

### Run

In [None]:
start_time: float = time.perf_counter()  # reference point for the timing table printed at the end

In [None]:
### Load PD model
# Resolve the projection ("PROJ") checkpoint and load the network,
# passing CUDA_DEVICE as the target device.
model_path = FILES.get_model_filepath(
    PD_MODEL_VERSION,
    "PROJ",
)
PD_model = load_model(
    PD_NETWORK_NAME,
    model_path,
    CUDA_DEVICE,
)
pd_model_load_time = time.perf_counter()

In [None]:
### Load scans and apply PD model

# Raw .mat projections, plus the gated and non-gated .pt paths that
# apply_model_to_projections also receives.
mat_path = FILES.get_projection_mat_filepath(PATIENT, SCAN, SCAN_TYPE)
gated_pt_path, ng_pt_path = (
    FILES.get_projection_pt_filepath(PATIENT, SCAN, SCAN_TYPE, gated=is_gated)
    for is_gated in (True, False)
)

# Only the second return value is kept; its 'prj' and 'angles' entries are
# the FDK input below.
_, cnn_mat = apply_model_to_projections(
    PD_model, SCAN_TYPE, mat_path, gated_pt_path, ng_pt_path, CUDA_DEVICE
)
pd_model_apply_time = time.perf_counter()

In [None]:
### FDK

# Only SCAN_TYPE 'HF' has a reconstruction path here (FDKHalf); fail early
# with a message that says so instead of a bare NotImplementedError.
if SCAN_TYPE != 'HF':
    raise NotImplementedError(
        f"FDK reconstruction is only implemented for SCAN_TYPE 'HF', got {SCAN_TYPE!r}"
    )

# Reconstruct from the PD-model projections with a Hann filter and parker=True.
fdk = FDKHalf()(cnn_mat['prj'], get_geometry(), cnn_mat['angles'].flatten(), filter="hann", parker=True)
# NOTE(review): `fdk` is not used by any later cell (the ID-model step takes a
# file path, not this tensor), so this device copy only adds to the FDK timing.
# Confirm whether the result should be saved or handed to the ID model.
fdk = torch.from_numpy(fdk).to(CUDA_DEVICE)

fdk_time = time.perf_counter()

In [None]:
### Load ID model
# Resolve the image ("IMAG") checkpoint and load the network,
# passing CUDA_DEVICE as the target device.
model_path = FILES.get_model_filepath(ID_MODEL_VERSION, "IMAG")
ID_model = load_model(
    ID_NETWORK_NAME,
    model_path,
    CUDA_DEVICE,
)
id_model_load_time = time.perf_counter()

In [None]:
### Pass FDK through ID model

# Use a distinct name so the non-gated *projection* path (ng_pt_path) from the
# PD step is not silently overwritten with a reconstruction path.
# NOTE(review): the ID model is fed the reconstruction file at ng_recon_path,
# not the `fdk` tensor computed in the FDK cell; confirm that file is the
# reconstruction this run is meant to process.
ng_recon_path = FILES.get_recon_filepath(PD_MODEL_VERSION, PATIENT, SCAN, SCAN_TYPE, gated=False)

apply_model_to_recons(ID_model, ng_recon_path, CUDA_DEVICE)

# CUDA kernels run asynchronously and the return value above is discarded, so
# wait for any queued GPU work before stopping the clock for this step.
if CUDA_DEVICE.type == "cuda":
    torch.cuda.synchronize(CUDA_DEVICE)

id_model_apply_time = time.perf_counter()

In [None]:
### Print timings
# (label, seconds) per stage, from consecutive perf_counter() checkpoints.
steps = [
    ("Load PD model", pd_model_load_time - start_time),
    ("Apply PD model", pd_model_apply_time - pd_model_load_time),
    ("FDK", fdk_time - pd_model_apply_time),
    ("Load ID model", id_model_load_time - fdk_time),
    ("Apply ID model", id_model_apply_time - id_model_load_time),
]

# Build the header once so the separator matches its width. The column widths
# add up to 57 characters; a hard-coded 55 dashes left the rule 2 short.
header = f"{'Step':<20}|{'Time (s)':<15}|{'Cumulative Time (s)':<20}"
print(header)
print("-" * len(header))

cumulative_time = 0
for step, duration in steps:
    cumulative_time += duration
    print(f"{step:<20}|{duration:<15.2f}|{cumulative_time:<20.2f}")