In [1]:
# --- Imports ---
import torch
import numpy as np
import inr_sos
import scipy.io as sio
from inr_sos import DATA_DIR, clear_cache, inspect_mat_fileheader, load_ic_batch, load_L_matrix
from inr_sos.utils.data import USDataset
from inr_sos.utils.config import ExperimentConfig
# --- Load Global Data ---
# Every input file lives under the DL-based-SoS sub-folder of the shared data directory.
SOS_DATA_DIR = f"{DATA_DIR}/DL-based-SoS"

inverse_data_file = f"{SOS_DATA_DIR}/train-VS-8pairs-IC-081225.mat"
param_grid_file = f"{SOS_DATA_DIR}/forward_model_lr/grid_parameters.mat"

# Pre-computed analytical data; presumably L2/L1 reconstructions judging by the filename — TODO confirm.
analytical_data = inr_sos.load_mat(f"{SOS_DATA_DIR}/train_IC_10k_l2rec_l1rec_imcon.mat")

# Dataset pairing the inverse-problem samples with the forward-model grid parameters.
dataset = USDataset(inverse_data_file, param_grid_file)

In [2]:
# Choose distinct dataset samples to benchmark on. A dedicated seeded generator keeps
# the selection reproducible across kernel restarts (the global unseeded RNG did not).
SAMPLE_SEED = 42
N_SAMPLES = 2
sample_rng = np.random.default_rng(SAMPLE_SEED)
indices = sample_rng.choice(len(dataset), size=N_SAMPLES, replace=False)

In [3]:
def run_experiment(
    dataset: USDataset,
    indices: list[int],
    config: ExperimentConfig,
    use_wandb: bool = True,
    engines: "dict | None" = None,
    models: "dict | None" = None,
    stop_on_error: bool = True,
) -> list:
    """Benchmark every (optimization engine, model backbone) pair on the same samples.

    For each combination, a deep copy of ``config`` is tagged with the engine name
    (``experiment_group``) and the backbone name (``model_type``). The copy is passed
    to ``run_evaluation`` together with ``dataset`` and the target ``indices``.

    Args:
        dataset: Dataset the target samples are taken from.
        indices: Dataset indices to reconstruct. Any iterable of integers is accepted,
            e.g. the NumPy array returned by ``np.random.choice``. It is converted to a
            plain ``list[int]`` before use.
        config: Base experiment configuration. It is never mutated; each run gets its
            own deep copy.
        use_wandb: Forwarded to ``run_evaluation`` (formerly hard-coded to True).
        engines: Optional ``{name: engine_function}`` mapping that replaces the default
            engine grid (Full_Matrix, Sequential_SGD, Ray_Batching).
        models: Optional ``{name: model_class}`` mapping that replaces the default
            backbone grid (ReluMLP, FourierMLP, SirenMLP).
        stop_on_error: If True (default), an exception raised by one run propagates
            immediately and aborts the sweep. If False, the failure is printed and
            recorded, and the remaining combinations still run.

    Returns:
        A list of ``(engine_name, model_name, exception)`` tuples, one per failed run.
        It is always empty when ``stop_on_error`` is True, because the first failure raises.
    """
    import copy
    from inr_sos.models.mlp import FourierMLP, ReluMLP
    from inr_sos.models.siren import SirenMLP
    from inr_sos.training.engines import (
        optimize_full_forward_operator,
        optimize_sequential_views,
        optimize_stochastic_ray_batching,
    )
    from inr_sos.evaluation.pipeline import run_evaluation

    # --- 1. Define the grid (callers may override either axis) ---
    if engines is None:
        engines = {
            "Full_Matrix": optimize_full_forward_operator,
            "Sequential_SGD": optimize_sequential_views,
            "Ray_Batching": optimize_stochastic_ray_batching,
        }
    if models is None:
        models = {
            "ReluMLP": ReluMLP,
            "FourierMLP": FourierMLP,
            "SirenMLP": SirenMLP,
        }

    # The notebook passes a NumPy array of NumPy integers here. Normalise to the
    # list[int] promised by the signature, so downstream code sees native ints.
    # NumPy integer scalars are not JSON-serialisable, which matters if they end
    # up in logged config.
    target_indices = [int(i) for i in indices]

    failures = []

    # --- 2. Run the Gauntlet ---
    for method_name, engine_func in engines.items():
        print(f"\n{'='*60}")
        print(f" LAUNCHING OPTIMIZER: {method_name}")
        print(f"{'='*60}")

        for model_name, model_cls in models.items():
            print(f"\n---> Testing Backbone: {model_name}")

            # Deep-copy so per-run tags never leak into `config` or into other runs.
            cfg_clone = copy.deepcopy(config)

            # Tag this run with its engine and backbone names (used to organise runs in W&B).
            cfg_clone.experiment_group = method_name
            cfg_clone.model_type = model_name

            try:
                run_evaluation(
                    dataset=dataset,
                    model_class=model_cls,
                    train_engine=engine_func,
                    config=cfg_clone,
                    target_indices=target_indices,
                    use_wandb=use_wandb,  # streams loss curves to W&B when True
                )
            except Exception as exc:
                if stop_on_error:
                    raise
                # Best-effort sweep: report loudly and keep going with the other combinations.
                print(f"\n[FAILED] {method_name} / {model_name}: {exc!r}")
                failures.append((method_name, model_name, exc))

    if failures:
        print(f"\n{len(failures)} run(s) failed:")
        for method_name, model_name, exc in failures:
            print(f"  - {method_name} / {model_name}: {exc!r}")
    return failures

In [5]:
# Shared hyper-parameters for every (engine, backbone) combination in the sweep.
# Comments mark the fields that only one engine or backbone consumes.
config = ExperimentConfig(
    project_name="INR-SoS-Recon",
    in_features=2,
    hidden_features=256,
    hidden_layers=3,
    mapping_size=64,
    scale=0.6,       # For FourierMLP
    omega=30.0,      # For SirenMLP
    lr=1e-4,
    steps=2000,      # Used by Full_Matrix and Sequential_SGD
    epochs=150,      # Used by Ray_Batching
    batch_size=4096,
    tv_weight=0,
)

# Experimental run 1: sweep the full engine x backbone grid on the selected samples.
run_experiment(dataset=dataset, indices=indices, config=config)


 LAUNCHING OPTIMIZER: Full_Matrix

---> Testing Backbone: ReluMLP


Loss (us^2): 0.0000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:24<00:00, 80.84it/s, method=optimize_full_forward_operator, model=ReluMLP]


0,1
Final CNR,▁
Final MAE,▁
Final SSIM,▁
Learning Rate,█████▇▇▇▇▇▇▆▆▆▅▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
MSE Loss,███▇▅▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Total Loss,█▇▅▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Final CNR,2.94534
Final MAE,4.22997
Final SSIM,0.56443
Learning Rate,0.0
MSE Loss,0.0
Total Loss,0.0


Loss (us^2): 0.0000: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:25<00:00, 78.90it/s, method=optimize_full_forward_operator, model=ReluMLP]


0,1
Final CNR,▁
Final MAE,▁
Final SSIM,▁
Learning Rate,█████████▇▇▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁
MSE Loss,██▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Total Loss,███▇▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Final CNR,2.03069
Final MAE,3.7383
Final SSIM,0.75418
Learning Rate,0.0
MSE Loss,1e-05
Total Loss,1e-05


Results successfully logged to local database: benchmark_results.csv


Error: You must call wandb.init() before wandb.log()