In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
os.environ["JAX_LOG_COMPILES"] = "0"

import sys
import jax.numpy as np
from jax import vmap

import numpy as onp
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # disables GUI backend
import matplotlib.pyplot as plt

from jax import jit, random
from inrmri.new_radon import get_weight_freqs, make_forward_radon_operator
from inrmri.basic_nn import weighted_loss 
from inrmri.utils import denoise_loss_batch, l1_loss, tikhonov_loss, tv_loss, create_center_mask
import optax 
from inrmri.advanced_training import OptimizerWithExtraState, train_with_updates_ms_nspokeswise_select
import itertools
from inrmri.utils_rdls import get_shedule, get_predim_direct_ms
from inrmri.utils_rdls import create_folder, save_frames_as_gif_with_pillow
from inrmri.utils_rdls import get_varying_keys, config_to_foldername
from inrmri.utils_rdls import plot_curves, plot_curves_and_mins, plot_multi_axis, multiple_images_visualization
from inrmri.dip import MS_TD_DIP_Net, multi_slice_circle_generator
from inrmri.utils_rdls import safe_normalize, get_center, pad_axis_to_length
from inrmri.utils import clear_jax_memory # parse_slices_list
import ast

In [2]:
# volunteer parameters
volunteer = 'MP'
dataset = 'DATA_0.55T'
total_slices = 8
base_path = "C:\\Users\\Legion\\Documents\\Repositories\\NF-cMRI-rafa\\datasets\\"

In [3]:
# --- DATA ---
base_folder           = base_path + dataset + '\\' + volunteer + '\\'
train_data_folder     = base_folder + 'traindata\\'
model_path            = base_folder + 'stDIP\\'
create_folder(model_path, reset=False)

# static parameters
dummy_iterations = 0 # iterations that are not shown in metrics
num_frames       = 30
saturation       = 0.3
n_coils          = 15

# training parameters
slices_list      = [1,2,3,4,5,6,7,8]
n_slices         = len(slices_list)
s_idxs           = [s_idx - 1 for s_idx in slices_list]
perform_plots    = True
val_frames       = np.array([0]  , dtype=np.int32 )

In [4]:
# --- PARMETERS ---
key = random.PRNGKey(0)
key_net, key_params, key_train, key_latent = random.split(key,4) # keys for reproducibility
key_eval_list = random.split(key, total_slices)

experiment_params = {              
    'N':                    [256],                 # elements of the readout, lenght of the spoke
    'mapnet_layers':        [16],                  # MapNet: layers
    'cnn_latent_shape':     [8],                   # Generatior: input shape, is the trained latent representation by MapNet
    'levels':               [4],                   # Generatior: levels of up sampling
    'features':             [16],                  # Generatior: features
    'iter':                 [500],                # number of iterations
    'bs':                   [1],                   # Number of cardiac phases per iteration (real batch size: bs * nspokes * nslices)
    'addConst':             [False],               # if add a constant to the fixed manifold
    'str_filter':           ['ramp'],              # frequency weighting
    'denoise_type':         ['tv'],                # Regularization ('tv', 'l1', 'tikhonov')
    'lambda':               [0],                   # lambda for regularization
    'lr_schedule':          ['constant_schedule'], # learning rate scheduler
    'lr_init_value':        [5e-3],                # learninig rate start (or constant value depending on the scheduler)
    'lr_end':               [1e-3],                # learning rate end (optional depending on the scheduler)
    'lr_transition_steps':  [3000],           # how many steps to decay over
    'lr_decay_rate':        [0.90],           # (parameter for 'exponential_decay' scheduler)
    'lr_power':             [20],             # (parameter for 'polynomial_schedule' scheduler)
    'metric_step':          [20],             # step of iterations to compute metrics
    'window_size':          [5],              # number of elements in windows to compute window metrics (variance)
    'nspokes':              [1],              # number of spokes used of each frame in the nspokes-wise training
    'select_by':           ['loss', 'mean_var'],  # The used metric for selecting the best parameters
    'nr_iqm':               [True],           # No-reference image quality metric (WMV)
    'fr_iqm':               [True],           # Full-reference image quality metric (PSNR, SSIM, AP)
    'debug':                [False],          # If debug, multiple shapes and time should be printed
    'data_ponderator':      [1],              # currently the data is normalized, so this ponderator would ponderate a normalized data
    'latent_r':             [1],              # radius of the latent representation
    'latent_z':             [10],             # the latent representation creates circles between [-Z_max, Z_max]
}
# Generate list of parameter combinations as dictionaries
keys = list(experiment_params.keys())
combinations = [dict(zip(keys, values)) for values in itertools.product(*experiment_params.values())]
# Identify the varying keys
varying_keys = get_varying_keys(combinations)
experiment_name = '_'.join(varying_keys)
# plot variables
xlim_botton = dummy_iterations
xlim_top    = int(np.max(np.array(experiment_params['iter'])))

In [5]:
# --- create the multislice dataset ---
Y_data_list      =   []
X_data_list      =   []
csm_list         =   []
spclim_list      =   []
hollow_mask_list =   []
recon_fs_list    =   []
exp_folder_path_list  =   []
for  slice_num in slices_list:
    dataset_name = 'slice_' + str(slice_num) + '_' + str(total_slices) +'_nbins' + str(num_frames)
    path_file = train_data_folder + dataset_name + '.npz'
    data = onp.load(    path_file         )
    y_data_item = data['Y_data']
    x_data_item = data['X_data']
    csm_item    = data['csm']
    spclim_item = data['spclim']
    y_data_item = data['Y_data']
    if dataset == 'DATA_0.55T':
        hollow_mask_item  = data['hollow_mask']    
    else:
        hollow_mask_item  = data['hollow_mask_computed']   
    # Check coils
    y_data_item         = pad_axis_to_length(y_data_item, 1, n_coils, pad_value=0+0j)
    csm_item            = pad_axis_to_length(csm_item, 0, n_coils, pad_value=0+0j)
    # append items in lists
    Y_data_list.append(       y_data_item    )
    X_data_list.append(       x_data_item    )  
    csm_list.append(          csm_item      )
    spclim_list.append(       spclim_item   ) 
    hollow_mask_list.append(  hollow_mask_item   ) 
    recon_fs    =  data['recon_fs']
    recon_fs    = get_center(recon_fs) 
    recon_fs    = safe_normalize(recon_fs) 
    recon_fs_list.append(   recon_fs[:,:,val_frames]   )
    save_folder  = model_path + dataset_name.replace(".", "_") + '\\'
    create_folder(save_folder, reset=False )
    save_frames_as_gif_with_pillow(save_folder, recon_fs, filename='recon_fs', vmax=1, saturation=saturation, fps=30)
    exp_folder_path = save_folder + experiment_name + '\\'
    exp_folder_path_list.append(  exp_folder_path  )
    create_folder(exp_folder_path, reset=False)
hollow_mask_array = np.stack(hollow_mask_list, axis=0)

## --- Radon operator list ---
radon_operator_list = []
for jj in range(len(csm_list)):
    radon_operator_list.append( make_forward_radon_operator(csm_list[jj], spclim_list[jj]) )

In [6]:
# --- Combinations LOOP ---
training_names_list = []
experiment_records = {}
for s_idx in s_idxs:
    experiment_records[s_idx] = []
log_list = []

for i, h_params in enumerate(combinations):
    # add information
    h_params['NFRAMES']    = num_frames
    h_params['n_slices']   = n_slices
    h_params['val_frames'] = val_frames
    # training name
    training_name = config_to_foldername(h_params,varying_keys)
    training_names_list.append(  training_name  )
    training_folder_list = []
    for s_idx in s_idxs:
        training_folder = exp_folder_path_list[s_idx] + training_name + '/'
        create_folder(training_folder, reset=False)
        training_folder_list.append(  training_folder  )
    # model config
    CONFIG_NET = {
        'mapnet_layers':    [h_params['mapnet_layers'], h_params['mapnet_layers'],],
        'cnn_latent_shape': (h_params['cnn_latent_shape'],h_params['cnn_latent_shape']),
        'levels':            h_params['levels'],
        'features':          h_params['features']
    }
    # data ponderation
    Y_data_list_current     =  []
    recon_fs_list_current   =  []
    for i in range(len(Y_data_list)):
        Y_data_list_current.append(     Y_data_list[i]    *  h_params['data_ponderator']   )
        recon_fs_list_current.append(   recon_fs_list[i]  ) #*  h_params['data_ponderator']   )

    # define the network
    net = MS_TD_DIP_Net(
        nframes = h_params['NFRAMES'],
        addConst = h_params['addConst'], 
        key_latent = key_latent,
        n_slices = h_params['n_slices'],
        latent_generator=multi_slice_circle_generator,
        imshape = [h_params['N'],h_params['N']],
        radius = h_params['latent_r'],
        z_min = -h_params['latent_z'],
        z_max = h_params['latent_z'],
        **CONFIG_NET
    )
    params = net.init_params(key_params) 
    WEIGHT_FREQS = get_weight_freqs(  h_params['N'], h_params['str_filter']  )
    weights        = (1. + WEIGHT_FREQS)[None, None, :]
    schedule       = get_shedule(h_params)
    optimizer      = OptimizerWithExtraState(optax.adam(learning_rate=schedule))
    lambda_denoise_reg = h_params['lambda']
    center_mask    = create_center_mask((h_params['N'],h_params['N']))
    if h_params['denoise_type'] == 'tv':
        den_loss = tv_loss
    elif h_params['denoise_type'] == 'tikhonov':
        den_loss = tikhonov_loss
    elif h_params['denoise_type'] == 'l1':
        den_loss = l1_loss
    else:
        raise ValueError(f"Unknown denoise_type: {h_params['denoise_type']}")

    def loss(params, X_list, Y_list, index_frames, key):
        n_slices = len(X_list)
        total_loss = 0.0
        def per_spoke_loss(im, x_spoke, y_spoke, radon_operator):
            # (im): (1, px, py), (x_spoke): scalar, (y_spoke): (cmap, nx)
            alphas_frame = x_spoke[0:1]  # shape (1,)
            y_data = y_spoke[..., 0]     # shape (cmap, nx)
            pred_kspace = radon_operator(im, alphas_frame)
            loss_val = weighted_loss(pred_kspace, y_data, weights)
            return loss_val
    
        def per_frame_loss(im, X_frame, Y_frame, radon_operator):
            # X_frame: (nspokes, features), Y_frame: (nspokes, cmap, nx, 1)
            spoke_loss_fn = lambda x_sp, y_sp: per_spoke_loss(im, x_sp, y_sp, radon_operator)
            loss_vals = vmap(spoke_loss_fn)(X_frame, Y_frame)
            return np.sum(loss_vals)
    
        for slice_idx in range(n_slices):
            X = X_list[slice_idx]              # (frames, spokes, features)
            Y = Y_list[slice_idx]              # (frames, spokes, cmap, nx, 1)
            radon_operator = radon_operator_list[slice_idx]
    
            ims, updates = net.train_forward_pass(params, key, index_frames, slice_idx)  # (frames, px, py)
    
            def frame_loss_fn(im, Xf, Yf):
                im_exp = im[None, ...]  # (1, px, py)
                frame_loss = per_frame_loss(im_exp, Xf, Yf, radon_operator)
                frame_loss += lambda_denoise_reg * denoise_loss_batch(im_exp[0, :, :][..., None], den_loss, center_mask)
                return frame_loss
            
            total_loss += np.sum(vmap(frame_loss_fn)(ims, X, Y))
        total_loss /= n_slices
        return total_loss, updates

    # this function uses the model to reconstruct frames for metrc computations and visualizations during training
    def recon_cine(params, t_idx, s_idx, hollow_mask, key):
        # Forward pass — should already output JAX arrays
        ims, _ = net.train_forward_pass(params, key, t_idx, s_idx)     # (frames, px, py)
        # Efficient axis permutation — avoid moveaxis in JAX hot loops
        ims = ims.transpose(1, 2, 0)  # (px, py, frames)
        # Apply mask — (1 - hollow_mask) is broadcasted safely
        ims = ims * (1.0 - hollow_mask)[..., None]
        # Crop center — ensure JAX-native `get_center`
        ims = get_center(ims)  # Output: (px//2, py//2, frames)
        # Normalize — assume JAX-safe operation
        ims = safe_normalize(ims)
        return ims
        
    # training
    loss_fn = jit(loss)
    recon_fn = jit(recon_cine)
    if h_params['select_by'] == 'mean_var':
        results = train_with_updates_ms_nspokeswise_select(loss_fn, X_data_list, Y_data_list_current, params, optimizer, key_train, h_params, recon_cine=recon_fn, hollow_mask_array=hollow_mask_array, val_slices=s_idxs, reference_list = recon_fs_list_current, debug=h_params['debug'])
    elif h_params['select_by'] == 'loss':
        results = train_with_updates_ms_nspokeswise_select(loss_fn, X_data_list, Y_data_list_current, params, optimizer, key_train, h_params, debug=h_params['debug'])
    # --- memory management (post-training)---
    del optimizer, schedule, WEIGHT_FREQS, weights
    clear_jax_memory()
    # save records
    for s_idx in s_idxs:
        record = {key: h_params[key] for key in varying_keys}
        record['duration [s]']  = results['time_s']
        record['duration [min]']  = results['time_s'] / 60
        record['duration']      = results['time_str']
        # First dictionary-level metrics
        for var_name in ['it', 'loss', 'mean_var', 'time']:
            if var_name in results.get('best', {}):
                record[var_name] = results['best'][var_name]
        
        # Second-level metrics indexed by s_idx
        best = results.get('best', {})
        if isinstance(best, dict) and s_idx in best:
            sub_metrics = best[s_idx]
            if isinstance(sub_metrics, dict):
                for var_name in ['var', 'ssim', 'psnr', 'ap']:
                    if var_name in sub_metrics:
                        record[var_name] = sub_metrics[var_name]
        record['experiment_folder'] = exp_folder_path_list[s_idx]  # optional: keep the folder name
        record['training_name']     = training_name
        record['training_folder']   = exp_folder_path_list[s_idx] + training_name + '/'
        experiment_records[s_idx].append(record)
    # log 
    if perform_plots:
        log_list.append(  results['log_queue']  )
    # save best image
    for s_idx in s_idxs:
        best_recon   =  get_predim_direct_ms(net, hollow_mask_array[s_idx], key_eval_list[s_idx], h_params, s_idx, results['best']['params'] )
        np.savez(training_folder_list[s_idx] + 'best_recon' + '.npz', **{'best_recon':best_recon,  **h_params, **record})
        np.savez(training_folder_list[s_idx] + 'log_queue'  + '.npz', **results['log_queue'][s_idx])
        save_frames_as_gif_with_pillow(training_folder_list[s_idx], best_recon, filename='best_recon', vmax=1, saturation=saturation, fps=num_frames)
        if perform_plots:
            log_queue = results.get('log_queue', {})
        
            def safe_get_metric(metric):
                return (
                    s_idx in log_queue and 
                    isinstance(log_queue[s_idx], dict) and 
                    metric in log_queue[s_idx]
                )
        
            # NR-IQM plots and images
            if h_params.get('nr_iqm', False):
                if 'it' in log_queue and safe_get_metric('var'):
                    plot_multi_axis(
                        log_queue[s_idx]['it'],
                        [log_queue[s_idx]['var']],
                        ['VAR'],
                        xlim=(xlim_botton, h_params['iter']),
                        save_path=training_folder_list[s_idx] + 'nr_iqm_curves.png'
                    )
        
                if 'best' in results and 'it' in results['best']:
                    multiple_images_visualization(
                        [recon_fs, best_recon],
                        ['GT', 'VAR (it' + str(results['best']['it']) + ')'],
                        frame=0,
                        save_path=training_folder_list[s_idx] + 'nr_iqm_images.png',
                        saturation=saturation
                    )
        
            # FR-IQM curves
            if h_params.get('fr_iqm', False):
                metrics = ['psnr', 'ssim', 'ap']
                available_metrics = [m for m in metrics if safe_get_metric(m)]
        
                if 'it' in log_queue and available_metrics:
                    plot_multi_axis(
                        log_queue[s_idx]['it'],
                        [log_queue[s_idx][m] for m in available_metrics],
                        [m.upper() for m in available_metrics],
                        xlim=(xlim_botton, h_params['iter']),
                        save_path=training_folder_list[s_idx] + 'fr_iqm_curves.png'
                    )
    # --- memory management (combinatory-level) ---
    del net, results, best_recon
    clear_jax_memory()
    plt.close('all')


[TRACING] Recompiling tDIP...


train iter:   0%|          | 0/501 [00:00<?, ?it/s]

[TRACING] Recompiling tDIP...
256


train iter:   0%|          | 1/501 [00:24<3:22:51, 24.34s/it]

[TRACING] Recompiling tDIP...


train iter: 100%|██████████| 501/501 [07:57<00:00,  1.05it/s]


[TRACING] Recompiling tDIP...


  plt.show()


[TRACING] Recompiling tDIP...


train iter:   0%|          | 0/501 [00:00<?, ?it/s]

[TRACING] Recompiling tDIP...
256
[TRACING] Recompiling tDIP...


train iter: 100%|██████████| 501/501 [08:59<00:00,  1.08s/it]


[TRACING] Recompiling tDIP...


In [12]:
for s_idx in s_idxs:
    # save dataset
    df = pd.DataFrame(experiment_records[s_idx])
    df.to_csv(exp_folder_path_list[s_idx] + "experiment_records.csv", index=False, sep=';')
    if perform_plots:
        log_list_slice = []
        for com_idx in range(len(log_list)):
            item = log_list[com_idx][s_idx]
            log_list_slice.append(item)
    
        def variable_exists(log_list_slice, var_name):
            return all(var_name in item for item in log_list_slice)
    
        # Plot 'loss' curve if available
        if variable_exists(log_list_slice, 'loss'):
            image_path = exp_folder_path_list[s_idx] + 'curves_loss.png'
            plot_curves_and_mins(log_list_slice, training_names_list, 'loss', save_path=image_path, xlim=(xlim_botton, xlim_top))
    
        # Plot 'var' curve if required and available
        if h_params['nr_iqm'] and variable_exists(log_list_slice, 'var'):
            image_path = exp_folder_path_list[s_idx] + 'curves_var.png'
            plot_curves_and_mins(log_list_slice, training_names_list, 'var', save_path=image_path, xlim=(xlim_botton, xlim_top))
    
        # Plot 'ssim', 'psnr', 'ap' if required and available
        if h_params['fr_iqm']:
            for r_iqm in ['ssim', 'psnr', 'ap']:
                if variable_exists(log_list_slice, r_iqm):
                    image_path = exp_folder_path_list[s_idx] + f'curves_{r_iqm}.png'
                    plot_curves(log_list_slice, training_names_list, r_iqm, save_path=image_path, xlim=(xlim_botton, xlim_top))

  plt.show()


## Post-processing

In [13]:
# after all, create a overall dataset
metrics_for_analysis = ['var', 'ssim', 'psnr', 'ap']
metrics_list_for_analysis = ['log_var', 'log_ssim', 'log_psnr', 'log_ap', 'log_loss', 'log_time']
stay_the_same = ['duration [s]', 'duration [min]', 'duration', 'it', 'log_it', 'training_name']
variables = varying_keys

In [14]:
def process_group(group):
    result = {}

    # 1. Scalar metrics → compute mean and std
    for metric in metrics_for_analysis:
        result[f'{metric}_mean'] = group[metric].mean()
        result[f'{metric}_std']  = group[metric].std()

    # 2. List-valued metrics → parse, stack, compute mean/std per position
    for col in metrics_list_for_analysis:
        try:
            lists = group[col].apply(ast.literal_eval).tolist()
            arr = np.array(lists, dtype=np.float32)
            result[f'{col}_mean'] = arr.mean(axis=0).tolist()
            result[f'{col}_std']  = arr.std(axis=0).tolist()
        except Exception as e:
            print(f"Error processing {col}: {e}")
            result[f'{col}_mean'] = np.nan
            result[f'{col}_std']  = np.nan

    # 3. Keep the first value of "stay_the_same" columns
    for col in stay_the_same:
        result[col] = group[col].iloc[0]

    return pd.Series(result)

In [15]:
# create general dataset
count = 0
for i in slices_list:
    experiment_path = model_path + 'slice_'+str(i)+'_'+str(total_slices)+'_nbins30' + '/'+experiment_name   
    csv_path = experiment_path + "/experiment_records.csv"
    if count == 0:
        df_slices = pd.read_csv(csv_path, delimiter=';')
        df_slices['slice'] = i
        count+=1
    else:
        df = pd.read_csv(csv_path, delimiter=';')
        df['slice'] = i
        df_slices = pd.concat([df_slices, df])

In [16]:
df_total = df_slices.reset_index(drop=True)
for index, row in df_total.iterrows():
    log_path = row['training_folder'] + 'log_queue.npz'
    log_queue = np.load(log_path)
    for key  in log_queue.keys():    
        df_total.at[index, 'log_'+key] = str(log_queue[key].tolist())
df_total.to_csv(model_path + experiment_name + ".csv" , index=False, sep=';')
summary_df = df_total.groupby(variables).apply(process_group).reset_index()
summary_df.to_csv(model_path + experiment_name + "_summary.csv" , index=False, sep=';')

Error processing log_var: malformed node or string: nan
Error processing log_ssim: malformed node or string: nan
Error processing log_psnr: malformed node or string: nan
Error processing log_ap: malformed node or string: nan


  summary_df = df_total.groupby(variables).apply(process_group).reset_index()
