# 0. Imports

In [None]:
from .config import CUDA_DEVICE, DEBUG, SCANS, DATA_DIR, RESULT_DIR, data_version, PD_training_app, PD_epochs, PD_network_name, PD_model_name, PD_batch_size, PD_optimizer, PD_num_workers, ID_training_app, ID_epochs, ID_network_name, ID_model_name, ID_batch_size, ID_optimizer, ID_num_workers
from .proj import load_projection_mat, reformat_sinogram, interpolate_projections, pad_and_reshape, divide_sinogram
from .aggregate_prj import aggregate_saved_projections
from .launcher import run_app
from .apply_model import apply_model_to_projections, load_model
from .infer3d import inference_3d
from .utils import ensure_dir
import torch
import scipy.io
import os
# TODO run FDK via: FFrecon_reconFDK(input_mat, output_mat); in file "FFrecon_fullFDK.m"
# TODO add DEBUG mode that prints out helpful info like shapes, etc. (and be sure to set logging level so 'info' shows up too)
# TODO add some kind of logging for the hyperparameters used in each run
# TODO go through the training code and make sure it is consistent with the new pipeline
# TODO save everything as numpy arrays instead of torch tensors, and then convert to torch tensors when needed
# TODO add more options for training like learning rate, etc.
# TODO the user needs to pick which scans go in the training, val, and test sets

# 1. Data Preparation: projection interpolation

In [None]:
# Output directories are the same for every scan, so create them once up front
# instead of on every loop iteration.
g_dir = os.path.join(DATA_DIR, 'gated')
ng_dir = os.path.join(DATA_DIR, 'ng')
ensure_dir(g_dir)
ensure_dir(ng_dir)

for patient, scan, scan_type in SCANS:
    # Load the projection data from the matlab files
    odd_index, angles, prj = load_projection_mat(patient, scan, scan_type)

    # Flip and permute to get it in the right format
    prj_gcbct, angles1 = reformat_sinogram(prj, angles)

    # Simulate ngCBCT projections
    prj_ngcbct_li = interpolate_projections(prj_gcbct, odd_index)

    # Split the projections into two halves so they are good dimensions for the CNN.
    # HF scans are divided with v_dim=512; every other scan type uses 256.
    v_dim = 512 if scan_type == "HF" else 256
    combined_gcbct = divide_sinogram(pad_and_reshape(prj_gcbct), v_dim=v_dim)
    combined_ngcbct = divide_sinogram(pad_and_reshape(prj_ngcbct_li), v_dim=v_dim)

    # Save the projections under the same file name in both directories
    fname = f'{scan_type}_p{patient}_{scan}.pt'  # e.g., HF_p01_01.pt
    torch.save(combined_gcbct, os.path.join(g_dir, fname))
    torch.save(combined_ngcbct, os.path.join(ng_dir, fname))

    # Free up memory: drop this scan's arrays before the next scan is loaded,
    # so the previous scan is no longer referenced here while the next one is
    # read. Deleting inside the loop also avoids a NameError when SCANS is empty.
    del odd_index, angles, prj, prj_gcbct, angles1, prj_ngcbct_li, combined_gcbct, combined_ngcbct

# 2. Aggregate projections for train/val/test

In [None]:
# Directory for aggregated data saving
agg_dir = os.path.join(DATA_DIR, "agg")
ensure_dir(agg_dir)

# Aggregate and save projection data sets
for scan_type in ['HF', 'FF']:
    for sample in ['train', 'validation', 'test']:
        prj_gcbct, prj_ngcbct = aggregate_saved_projections(scan_type, sample)
        torch.save(prj_gcbct, os.path.join(agg_dir, f"{scan_type}_{sample}_gated.pt"))
        torch.save(prj_ngcbct, os.path.join(agg_dir, f"{scan_type}_{sample}_ng.pt"))

# Free up memory
del prj_gcbct, prj_ngcbct

# 3. Training PD CNN

In [None]:
# Train the PD CNN. Every command-line flag belongs in the single argument list
# passed to run_app. Previously --num_workers sat outside the list, as a stray
# third positional argument, so it was never passed as one of the flags.
run_app(PD_training_app, [
    f'--epoch={PD_epochs}',
    f'--network={PD_network_name}',
    f'--model_name={PD_model_name}',
    f'--data_ver={data_version}',
    f'--optimizer={PD_optimizer}',
    '--shuffle=True',
    f'--DEBUG={DEBUG}',
    f'--batch_size={PD_batch_size}',
    f'--num_workers={PD_num_workers}',
])

# 4. Apply PD model to all nonstop-gated sinograms

In [None]:
# Load the trained PD model onto the GPU
PD_model = load_model(PD_network_name, PD_model_name, device=torch.device(CUDA_DEVICE))

for patient, scan, scan_type in SCANS:
    # Get the matlab dicts for the ground truth and CNN projections
    g_mat, cnn_mat = apply_model_to_projections(patient, scan, scan_type, PD_model)

    # Save the ground truth and CNN projections
    scipy.io.savemat(os.path.join(RESULT_DIR, f'{scan_type}_p{patient}_{scan}_gated.mat'), g_mat)
    scipy.io.savemat(os.path.join(RESULT_DIR, f'{scan_type}_p{patient}_{scan}_ng.mat'), cnn_mat)

# Free up memory
del PD_model, g_mat, cnn_mat

# 5. TODO: FDK reconstruction (via FFrecon_reconFDK in "FFrecon_fullFDK.m")

# 6. Aggregate CT volumes

In [None]:
# Aggregate CT volumes for the training split, once without and once with augmentation.
# NOTE(review): aggregate_ct_volumes is not imported in the imports cell shown here,
# so this raises NameError unless it is defined in a cell not visible here. Add the
# import from its defining module.
# NOTE(review): scan_type=0 is an int, while the projection steps above use the
# strings 'HF'/'FF'. Confirm which encoding aggregate_ct_volumes expects.
aggregate_ct_volumes(data_version, 'train', scan_type=0, augment=False)
aggregate_ct_volumes(data_version, 'train', scan_type=0, augment=True)
# TODO: repeat for the 'validation' and 'test' splits (confirm whether augmentation applies to them)

# 7. Train ID CNN

In [None]:
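# TODO: train the ID CNN (presumably via run_app(ID_training_app, [...]) with the ID_* settings imported from config)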

# 8. Inference on test scans for full 3D

In [None]:
# Run full 3D inference for one HF scan and store the return value of inference_3d in `vol`.
# NOTE(review): patient_id, scan_id and model_name are not defined in any cell visible
# here, so this line raises NameError as written. Presumably they should come from a
# test-split entry of SCANS and from the ID model settings (e.g. ID_model_name from
# config). TODO: confirm against inference_3d's signature.
# NOTE(review): 'tumor_location_panc.pt' is a hard-coded file name; confirm where
# inference_3d expects to find it.
vol = inference_3d(patient_id, scan_id, 'HF', data_version, model_name, 'tumor_location_panc.pt')