## Imports

In [None]:
import itertools
from pathlib import Path
import yaml

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

import red_psm_utils
import plots
import red_psm_models
import red_psm_train
import red_psm_metrics

## Flags & Hyperparameters

In [None]:
# --- Time / view-angle acquisition ------------------------------------------
P = 256  # number of time instances
NUM_INST = P  # number of projections in the full (reference) projection data
ANG_PERIOD = None  # number of distinct view angles; None = no angular period
VIEW_ANG_SCH = 'bit_reversal'  # one of: 'linear', 'random', 'bit_reversal', 'golden_angle'
PI_SYMM = True  # exploit the pi-symmetry of the projections
# View-angle range: [0, pi) when pi-symmetry is exploited, otherwise [0, 2*pi).
ang_range = (1 if PI_SYMM else 2) * np.pi

# --- Object / phantom ----------------------------------------------------------
MOTION = 'piecewise_affine_transform'  # one of: 'piecewise_affine_transform', 'cardiac', 'LLNL'
OBJ_TYPE = 'walnut'  # one of: 'walnut', 'material', 'cardiac_rep_sq', 'LLNL'
SPATIAL_DIM = 128  # recon is d x d, projections have length d (default 128; LLNL data: 80)

# --- Measurement noise & data I/O ----------------------------------------------
NOISE_STD = 0.01 / 2  # standard deviation of the additive measurement noise
GEN_G_ALL = False  # True: generate the full set of measurements; False: load saved data
SAVE_NOISY_MEAS = False  # True: generate (and save) noisy projections; False: load saved ones
FWD_MODEL_PATH = 'forward_model/'  # directory holding the saved forward operator

## Load variables

In [None]:
# Compute all supported view-angle schedules; only the one selected by
# VIEW_ANG_SCH is used below.
theta_dict = {}
[theta_dict['linear'], theta_dict['random'], theta_dict['bit_reversal'],
 theta_dict['golden_angle']] = red_psm_utils.generate_theta(
    P, ang_range=ang_range, period=ANG_PERIOD)
theta_exp = theta_dict[VIEW_ANG_SCH]
# Opposite view angles (theta + pi), used to build the pi-symmetric data.
theta_exp_pi_symm = theta_exp + np.pi

# Load the ground-truth dynamic object (phantom)
f = red_psm_utils.load_f(OBJ_TYPE, MOTION, SPATIAL_DIM, P)

# Compute subsampled (noiseless) measurements from the true object at both the
# selected angles and their pi-shifted counterparts, then combine them.
g_radon = red_psm_utils.obtain_projections(f, theta_exp, P)
g_radon_pi_symm = red_psm_utils.obtain_projections(f, theta_exp_pi_symm, P)
g_radon_symm_long = red_psm_utils.construct_pi_symm_g(
    g_radon, g_radon_pi_symm, ang_range)

# Data directory and file-name suffixes shared by every cached array below.
PATH = 'data/'
ADD_PATH = '_mean_corr' if OBJ_TYPE == 'material' else ''
STR_PERIOD = '' if ANG_PERIOD is None else '_' + str(ANG_PERIOD)

# Compute (and save) or load the full set of measurements
if GEN_G_ALL:
    f_pol_radon, f_true_recon_radon = red_psm_utils.generate_f_pol(
        f, P, SPATIAL_DIM, NUM_INST, theta_exp, OBJ_TYPE,
        ANG_PERIOD, PATH, save=True, add_path=ADD_PATH)
else:
    f_pol_radon = np.load(PATH + 'f_full_meas_%s%s_%d%s.npy' % (
        OBJ_TYPE, ADD_PATH, P, STR_PERIOD))

# Compute or load the measurements corrupted by noise with std=NOISE_STD.
# The noise level passed to add_meas_noise is relative to the peak of the full
# measurement set (f_pol_radon.max()).
if SAVE_NOISY_MEAS:
    g_radon_noisy, g_radon_symm_long_noisy = red_psm_utils.add_meas_noise(
        g_radon, g_radon_symm_long, NOISE_STD, f_pol_radon.max(), ang_range,
        OBJ_TYPE, P, ANG_PERIOD, PATH, save=SAVE_NOISY_MEAS, add_path=ADD_PATH)
elif ang_range == np.pi:
    g_radon_symm_long_noisy = np.load(
        PATH + 'g_radon_symm_long_noisy_%s%s_%d%s_noise_std_%.2e.npy' % (
            OBJ_TYPE, ADD_PATH, P, STR_PERIOD, NOISE_STD))
else:
    # Only pi-symmetric (ang_range == pi) noisy data can be loaded from disk.
    raise NotImplementedError(
        'Loading saved noisy measurements is only supported when '
        'ang_range == pi (PI_SYMM=True); set SAVE_NOISY_MEAS=True to '
        'generate them instead.')

# Load the forward (Radon) operator; R_cuda is the variant passed to training.
R, R_cuda = red_psm_utils.load_radon_op(
    PI_SYMM, SPATIAL_DIM, P, period=ANG_PERIOD, path=FWD_MODEL_PATH)

## Display $f$, view angle acquisition scheme $\theta(t)$, and projections $g_t$

In [None]:
# Visualize the ground-truth object f, the view-angle schedule theta_exp over
# the P time instances, and the projections. num_frames presumably controls how
# many frames are shown — confirm in plots.display_inputs.
plots.display_inputs(f, theta_exp, P, num_frames=8)

# RED-PSM

## Load training parameter configuration

In [None]:
# Read the training configuration for this number of time instances P
# (hyperparameter sweep lists, learning rates, iteration counts, ...).
# NOTE(review): FullLoader is kept as-is; yaml.safe_load would suffice if these
# local config files only use plain YAML types.
with open("config/red_psm_train_cfg_P%d.yaml" % P, "r") as cfg_file:
    params = yaml.load(cfg_file, Loader=yaml.FullLoader)
print(params)

## Run RED-PSM with various configurations

In [None]:
print('RED-PSM: Simultaneous PSM updates')
# One entry per configuration:
# [best PSNR_f, metrics history, last-epoch loss, train_params, train_data,
#  denoiser_type]
results_per_config = []

# Run RED-PSM for each configuration in the Cartesian product of the sweeps
# defined in the yaml file. itertools.product is iterated directly; there is
# no need to materialize it with list() first.
for tempBasis, K, z_dim, beta, xi, chi, lmbda, pSize, pStride, num_layers in (
        itertools.product(
            params['temporal_basis'], params['K_sweep'], params['d_sweep'],
            params['beta_sweep'], params['xi_sweep'], params['chi_sweep'],
            params['lmbda_sweep'], params['pSize_sweep'],
            params['pStride_sweep'], params['num_layers_sweep'])):
    print('Temp basis type:%s K:%d d:%d beta:%.1e xi:%.1e chi:%.1e lambda:%.1e'
          '\nDenoiser type:%s Denoiser patch size:%s patch stride:%s' % (
              tempBasis, K, z_dim, beta, xi, chi, lmbda,
              params['denoiser_type'], pSize, pStride))

    # Load the RED denoiser, plus a patchifier if denoiser_type is patch-based
    model_denoiser, patchifier = red_psm_utils.denoising_network_loader(
        'dncnn', denoiser_type=params['denoiser_type'], pSize=pSize,
        pStride=pStride, num_layers=num_layers, spatial_dim=SPATIAL_DIM,
        obj_type=OBJ_TYPE, num_channels=params['num_channels'],
        noise_est_type=params['noise_est_type'],
        epochs=params['denoiser_epochs'])

    # Initialize the spatial and temporal basis functions and the object f
    model = red_psm_models.RedPsm(
        P, K, SPATIAL_DIM, z_dim, tempBasis, obj_type=OBJ_TYPE,
        temp_init_type=params['temp_init_type'],
        f_init_type=params['f_init_type'],
        spatial_init_type=params['spatial_init_type'],
        temporal_mode=params['temporal_mode'],
        noise_std=NOISE_STD, mask=True, rep=params['rep']).cuda()

    # Initialize the scaled dual variable gamma
    gamma_est = torch.zeros(model.f_est.shape).cuda()

    # Optimizer/scheduler for the first (PSM-basis) primal ADMM step. The LR
    # is halved once every num_primal_iter * num_epoch scheduler steps.
    print('Optimizing for: Spatial & temporal basis fcts')
    optimizer = optim.Adam([model.spatial_basis_fcts, model.temporal_fcts],
                           lr=params['lr_primal'])
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=params['num_primal_iter'] * params['num_epoch'],
        gamma=0.5)
    if params['criterion'] == 'MSE':
        criterion = nn.MSELoss()
    else:
        raise NotImplementedError

    # Per-epoch reconstruction metrics for the RED estimate (_f) and the PSM
    # estimate (_psm).
    metrics = {key: [] for key in (
        'PSNR_f', 'PSNR_psm', 'MAE_f', 'MAE_psm',
        'SSIM_f', 'SSIM_psm', 'HFEN_f', 'HFEN_psm')}
    best_psnr_f = 1e-1  # initial best PSNR; updated by update_metrics
    best_f_est = None

    Nepoch = tqdm(range(params['num_epoch']),
                  desc=('PSNR:%.1e/%.1e SSIM:%.2e/%.2e MAE:%.2e/%.2e '
                        'HFEN:%.2e/%.2e' % ((0,) * 8)),
                  leave=True, ncols=160, colour='green')
    for epoch in Nepoch:
        # ADMM primal PSM step
        [loss_epoch, f_psm_est, g_f_est, spatial_basis_fcts, temporal_fcts,
         psi_mtx] = red_psm_train.learn_psm_bases(
            model, optimizer, scheduler, criterion, theta_exp,
            g_radon_symm_long_noisy, params['num_primal_iter'], P,
            SPATIAL_DIM, R=R_cuda, temporal_mode=params['temporal_mode'],
            gamma_est=gamma_est, beta=beta, xi=xi, chi=chi, rep=params['rep'])
        with torch.no_grad():
            # ADMM primal f step (RED updates using the denoiser)
            for _ in range(params['num_red_iter']):
                model.f_est = red_psm_train.f_update(
                    lmbda, beta, model_denoiser, patchifier, model.f_est,
                    f_psm_est, gamma_est, params['denoiser_type'], P)
            # ADMM dual step
            gamma_est = red_psm_train.dual_variable_update(
                gamma_est, model.f_est, f_psm_est)

        if epoch % 100 == 0:
            plots.plot_psm_basis_fcts(psi_mtx, spatial_basis_fcts, K)

        # Compute reconstruction accuracy metrics every 10 epochs
        if epoch % 10 == 0:
            if epoch % 100 == 0 and epoch != 0:
                plots.train_visualization(
                    loss_epoch, model.f_est.detach().cpu(),
                    f_psm_est.detach().cpu(), gamma_est.detach().cpu(), f, P//8)

            # Update accuracy metrics (and the best RED estimate so far)
            metrics, best_psnr_f, best_f_est = red_psm_metrics.update_metrics(
                f[..., ::params['rep']],
                model.f_est.permute(1, 2, 0).detach().cpu().numpy(),
                f_psm_est.permute(1, 2, 0).detach().cpu().numpy(),
                metrics, best_psnr_f, best_f_est)

            Nepoch.set_description(
                'PSNR:%.1e/%.1e SSIM:%.2e/%.2e MAE:%.2e/%.2e HFEN:%.2e/%.2e' % (
                    metrics['PSNR_f'][-1], metrics['PSNR_psm'][-1],
                    metrics['SSIM_f'][-1], metrics['SSIM_psm'][-1],
                    metrics['MAE_f'][-1], metrics['MAE_psm'][-1],
                    metrics['HFEN_f'][-1], metrics['HFEN_psm'][-1]))

    # Hyperparameters of this configuration
    train_params = {
        'K': K, 'z_dim': z_dim, 'temporal_basis': tempBasis, 'beta': beta,
        'lmbda': lmbda, 'xi': xi, 'chi': chi, 'pSize': pSize,
        'pStride': pStride, 'num_layers': num_layers,
    }
    # Learned quantities of this configuration, copied to the CPU.
    temporal_fcts_cpu = None
    if model.temporal_fcts is not None:
        # Only detach/copy the temporal functions when the model has them.
        temporal_fcts_cpu = model.temporal_fcts.detach().cpu().clone()
    train_data = {
        'f_est': best_f_est,
        'spatial_basis_fcts': model.spatial_basis_fcts.detach().cpu().clone(),
        'temporal_latent_fcts': temporal_fcts_cpu,
        'psi_mtx': psi_mtx,
    }

    results_per_config.append(
        [max(metrics['PSNR_f']), metrics, loss_epoch, train_params, train_data,
         params['denoiser_type']])

    # Release this configuration's GPU-resident objects before the next one is
    # built. The optimizer/scheduler still reference the model parameters (and
    # Adam's state), so they must be dropped too for that memory to be freed.
    del model_denoiser, model, optimizer, scheduler

In [None]:
# Label for the PSM update scheme used in the sweep above (spatial and temporal
# bases optimized together); forwarded to the plotting helper.
psm_update_type = 'simultaneous_psm_updates'
# Plot the per-configuration results collected in results_per_config.
# save_fig=False presumably disables writing figures to disk — confirm in
# plots.plot_and_save_red_psm_results.
plots.plot_and_save_red_psm_results(
    params, OBJ_TYPE, results_per_config, P, psm_update_type,
    noise_std=NOISE_STD, ang_period=ANG_PERIOD, save_fig=False)