# 0. Imports

In [None]:
from pipeline.proj import load_projection_mat, reformat_sinogram, interpolate_projections, pad_and_reshape, divide_sinogram
from pipeline.aggregate_prj import aggregate_saved_projections
from pipeline.launcher import run_app
from pipeline.apply_model import apply_model_to_projections
from pipeline.infer3d import inference_3d
from pipeline.utils import ensure_dir
from pipeline.config import DEBUG, SCANS, DATA_DIR, data_version, PD_training_app, PD_epochs, PD_network_name, PD_model_name, PD_batch_size, PD_optimizer, ID_training_app, ID_epochs, ID_network_name, ID_model_name, ID_batch_size, ID_optimizer
import torch
import os
# TODO: run FDK via FFrecon_reconFDK(input_mat, output_mat) (in file "FFrecon_fullFDK.m")
# TODO: add a DEBUG mode that prints helpful info such as array shapes
# TODO: log the hyperparameters used in each run

# 1. Data Preparation: projection interpolation

In [None]:
# The output directories are the same for every scan, so build and create
# them once instead of on every loop iteration.
g_dir = os.path.join(DATA_DIR, 'gated')
ng_dir = os.path.join(DATA_DIR, 'ng')
ensure_dir(g_dir)
ensure_dir(ng_dir)

for patient, scan, scan_type in SCANS:
    # Load the projection data from the MATLAB files
    odd_index, angles, prj = load_projection_mat(patient, scan, scan_type)

    # Flip and permute to get it in the right format
    # (angles1 is not used by any visible cell)
    prj_gcbct, angles1 = reformat_sinogram(prj, angles)

    # Simulate ngCBCT projections
    prj_ngcbct_li = interpolate_projections(prj_gcbct, odd_index)

    # v_dim handed to divide_sinogram: 512 for HF scans, 256 for every other scan type.
    # Computed once so the gated and ng splits are guaranteed to agree.
    split_v_dim = 512 if scan_type == "HF" else 256

    # Split the projections into two halves so they are good dimensions for the CNN
    combined_gcbct = divide_sinogram(pad_and_reshape(prj_gcbct), v_dim=split_v_dim)
    combined_ngcbct = divide_sinogram(pad_and_reshape(prj_ngcbct_li), v_dim=split_v_dim)

    # Save both versions under the same file name in their respective directories
    out_name = f'{scan_type}_p{patient}_{scan}.pt'  # e.g., HF_p01_01.pt
    torch.save(combined_gcbct, os.path.join(g_dir, out_name))
    torch.save(combined_ngcbct, os.path.join(ng_dir, out_name))

# 2. Aggregate projections for train/val/test

In [None]:
# Aggregate and save the projection data sets for every split:
# all HF splits first (train, validation, test), then all FF splits.
# Loop names are underscored so they do not overwrite the notebook-global
# `scan_type` left behind by the data-preparation loop.
for _prj_scan_type in ('HF', 'FF'):
    for _split in ('train', 'validation', 'test'):
        aggregate_saved_projections(_prj_scan_type, _split)

# 3. Training PD CNN

In [None]:
# Command-line flags for the PD training app, one per line for readability.
# NOTE(review): PD_batch_size is imported from pipeline.config but is not
# passed here — confirm whether the training app should receive it.
pd_train_args = [
    f'--epoch={PD_epochs}',
    f'--network={PD_network_name}',
    f'--model_name={PD_model_name}',
    f'--data_ver={data_version}',
    f'--optimizer={PD_optimizer}',
    '--shuffle=True',
    f'--DEBUG={DEBUG}',
]
run_app(PD_training_app, pd_train_args)

# 4. Apply model to projections and reconstruct

In [None]:
# Apply the model named by `model_name` to the projections of one HF scan.
# NOTE(review): patient_id, scan_id, model_name, v_dim, prj_gcbct_tensor and
# prj_ngcbct_li_tensor are not defined in any visible cell and must be set
# before this runs (model_name is presumably PD_model_name — confirm).
# odd_index and angles are whatever the last iteration of the
# data-preparation loop left in the notebook namespace.
prj_mix = apply_model_to_projections(
    patient_id,
    scan_id,
    'HF',
    data_version,
    model_name,
    v_dim,
    odd_index,
    angles,
    prj_gcbct_tensor,
    prj_ngcbct_li_tensor,
)

# 5. Aggregate CT volumes

In [None]:
# Aggregate the training-split CT volumes: first without, then with augmentation.
# NOTE(review): aggregate_ct_volumes is not imported in the imports cell —
# add its import before running this cell.
# NOTE(review): scan_type=0 is an int, whereas the projection cells use the
# strings 'HF'/'FF'; confirm which scan type 0 selects.
for _augment in (False, True):
    aggregate_ct_volumes(data_version, 'train', scan_type=0, augment=_augment)
# TODO: repeat for the validation and test splits

# 6. Inference on test scans for full 3D

In [None]:
# Run 3D inference for one HF scan, passing 'tumor_location_panc.pt' as the
# final argument.
# NOTE(review): patient_id, scan_id and model_name are not defined in any
# visible cell; set them before running this cell.
vol = inference_3d(
    patient_id,
    scan_id,
    'HF',
    data_version,
    model_name,
    'tumor_location_panc.pt',
)