In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import logging
import sys
import pickle
from pathlib import Path

import torch
import wandb
import optuna
from optuna.trial import TrialState

from config import Environment, TrainConfig
from denoising.train import prepare_training
from denoising.utils import seed_everything
from denoising.models.utils import count_parameters

In [3]:
CWD = Path.cwd()
env = Environment(_env_file=CWD / '../env')
wandb.login(key=env.wandb_api_key)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/d.nesterov/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdmitrylala[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [6]:
def define_train_cfg(trial) -> TrainConfig:
    n_layers = trial.suggest_int('n_layers', 1, 20, log=True)
    hidden_channels = trial.suggest_int('hidden_channels', 4, 64, log=True)
    n_modes = trial.suggest_int('n_modes', 4, 32, log=True)
    lifting_channel_ratio = trial.suggest_int('lifting_channel_ratio', 2, 32, log=True)
    projection_channel_ratio = trial.suggest_int('projection_channel_ratio', 2, 32, log=True)
    lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)

    print(
        f'Running with {n_layers=}, {hidden_channels=}, {n_modes=}, {lifting_channel_ratio=}, {projection_channel_ratio=} {lr=}'
    )

    cfg = TrainConfig(
        # Datasets params
        train_dset='mri_pm_train',
        test_dset='mri_pm_test',
        train_batch_size=32,
        test_batch_size=32,
        # Model params
        name_model='mri-hno-v2',
        cfg_fno={
            'n_modes': (n_modes, n_modes),
            'in_channels': 1,
            'hidden_channels': hidden_channels,
            'lifting_channel_ratio': lifting_channel_ratio,
            'projection_channel_ratio': projection_channel_ratio,
            'out_channels': 1,
            'factorization': 'dense',
            'n_layers': n_layers,
            'rank': 0.42,
            'spectral': 'hartley',
        },
        # Run params
        random_seed=42,
        device='cuda:2',
        run_name='Run optuna',
        # Train params
        n_epochs=3,
        lr=lr,
        verbose=True,
    )

    return cfg


def objective(trial):
    cfg = define_train_cfg(trial)
    trainer, train_kwargs, _ = prepare_training(env, cfg)

    if count_parameters(trainer.model) > 10_000_000:
        print('Pruned by model params')
        raise optuna.exceptions.TrialPruned

    seed_everything(cfg.random_seed)

    trial_obj = 'test_h1'
    metrics = trainer.train(trial=trial, trial_obj=trial_obj, **train_kwargs)

    return float(metrics[trial_obj])

In [7]:
run_name = 'hno-v2-h1'
study_name = f'{run_name}-optuna'
sampler_path = Path(f'./{run_name}-sampler')

In [None]:
# init trials storage and sampler pickle
optuna.logging.get_logger('optuna').addHandler(logging.StreamHandler(sys.stdout))
storage_name = 'sqlite:///{}.db'.format(study_name)
restored_sampler = None
if sampler_path.exists():
    print(f'Restore sampler from path: {sampler_path}')
    restored_sampler = pickle.load(Path.open(sampler_path, 'rb'))

# create new study or restore
study = optuna.create_study(
    study_name=study_name,
    storage=storage_name,
    direction='minimize',
    sampler=restored_sampler,
    load_if_exists=True,
)

if not sampler_path.exists():
    print(f'Caching sampler in: {sampler_path}')
    pickle.dump(study.sampler, Path.open(sampler_path, 'wb'))

# run optimization
study.optimize(objective, n_trials=1000, timeout=600)

[I 2025-04-18 11:16:07,158] Using an existing study with name 'hno-v2-h1-optuna' instead of creating a new one.


Restore sampler from path: hno-v2-h1-sampler
Using an existing study with name 'hno-v2-h1-optuna' instead of creating a new one.
Using an existing study with name 'hno-v2-h1-optuna' instead of creating a new one.
Using an existing study with name 'hno-v2-h1-optuna' instead of creating a new one.
Running with n_layers=2, hidden_channels=43, n_modes=15, lifting_channel_ratio=10, projection_channel_ratio=5 lr=0.011927374642042675
Got n_samples = 8380  in dataset mri_pm_train        with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_pm_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_gt_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 137   in dataset bsd_synth_0.01_train with sample size = torch.Size([1, 321, 481])
Got n_samples = 77    in dataset bsd_synth_0.01_test with sample size = torch.Size([1, 321, 481])
Got n_samples = 12296 in dataset sidd_train          with sample size = torch.

[I 2025-04-18 11:16:45,942] Trial 16 pruned. 


Trial 16 pruned. 
Trial 16 pruned. 
Trial 16 pruned. 
Running with n_layers=3, hidden_channels=22, n_modes=9, lifting_channel_ratio=6, projection_channel_ratio=5 lr=0.0022351622959278783
Got n_samples = 8380  in dataset mri_pm_train        with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_pm_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_gt_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 137   in dataset bsd_synth_0.01_train with sample size = torch.Size([1, 321, 481])
Got n_samples = 77    in dataset bsd_synth_0.01_test with sample size = torch.Size([1, 321, 481])
Got n_samples = 12296 in dataset sidd_train          with sample size = torch.Size([3, 512, 512])
Got n_samples = 3008  in dataset sidd_test           with sample size = torch.Size([3, 512, 512])
torch.Size([32, 1, 145, 145]) torch.Size([32, 1, 145, 145])
Loaded  model mri-fno-neuralop with n_parameters = 2010449
Loaded 

[I 2025-04-18 11:17:13,850] Trial 17 pruned. 


Trial 17 pruned. 
Trial 17 pruned. 
Trial 17 pruned. 
Running with n_layers=4, hidden_channels=9, n_modes=4, lifting_channel_ratio=3, projection_channel_ratio=2 lr=1.9468162041961396e-05
Got n_samples = 8380  in dataset mri_pm_train        with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_pm_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_gt_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 137   in dataset bsd_synth_0.01_train with sample size = torch.Size([1, 321, 481])
Got n_samples = 77    in dataset bsd_synth_0.01_test with sample size = torch.Size([1, 321, 481])
Got n_samples = 12296 in dataset sidd_train          with sample size = torch.Size([3, 512, 512])
Got n_samples = 3008  in dataset sidd_test           with sample size = torch.Size([3, 512, 512])
torch.Size([32, 1, 145, 145]) torch.Size([32, 1, 145, 145])
Loaded  model mri-fno-neuralop with n_parameters = 2010449
Loaded 

[I 2025-04-18 11:17:33,474] Trial 18 pruned. 


Trial 18 pruned. 
Trial 18 pruned. 
Trial 18 pruned. 
Running with n_layers=1, hidden_channels=34, n_modes=19, lifting_channel_ratio=9, projection_channel_ratio=11 lr=0.0004347370541490722
Got n_samples = 8380  in dataset mri_pm_train        with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_pm_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_gt_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 137   in dataset bsd_synth_0.01_train with sample size = torch.Size([1, 321, 481])
Got n_samples = 77    in dataset bsd_synth_0.01_test with sample size = torch.Size([1, 321, 481])
Got n_samples = 12296 in dataset sidd_train          with sample size = torch.Size([3, 512, 512])
Got n_samples = 3008  in dataset sidd_test           with sample size = torch.Size([3, 512, 512])
torch.Size([32, 1, 145, 145]) torch.Size([32, 1, 145, 145])
Loaded  model mri-fno-neuralop with n_parameters = 2010449
Loade

[I 2025-04-18 11:17:59,635] Trial 19 pruned. 


Trial 19 pruned. 
Trial 19 pruned. 
Trial 19 pruned. 
Running with n_layers=4, hidden_channels=20, n_modes=6, lifting_channel_ratio=5, projection_channel_ratio=3 lr=0.008872729287776196
Got n_samples = 8380  in dataset mri_pm_train        with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_pm_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_gt_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 137   in dataset bsd_synth_0.01_train with sample size = torch.Size([1, 321, 481])
Got n_samples = 77    in dataset bsd_synth_0.01_test with sample size = torch.Size([1, 321, 481])
Got n_samples = 12296 in dataset sidd_train          with sample size = torch.Size([3, 512, 512])
Got n_samples = 3008  in dataset sidd_test           with sample size = torch.Size([3, 512, 512])
torch.Size([32, 1, 145, 145]) torch.Size([32, 1, 145, 145])
Loaded  model mri-fno-neuralop with n_parameters = 2010449
Loaded  

[I 2025-04-18 11:19:27,225] Trial 20 finished with value: 0.10982754081487656 and parameters: {'n_layers': 4, 'hidden_channels': 20, 'n_modes': 6, 'lifting_channel_ratio': 5, 'projection_channel_ratio': 3, 'lr': 0.008872729287776196}. Best is trial 13 with value: 0.10859087109565735.


Saved training state to ckpt
Trial 20 finished with value: 0.10982754081487656 and parameters: {'n_layers': 4, 'hidden_channels': 20, 'n_modes': 6, 'lifting_channel_ratio': 5, 'projection_channel_ratio': 3, 'lr': 0.008872729287776196}. Best is trial 13 with value: 0.10859087109565735.
Trial 20 finished with value: 0.10982754081487656 and parameters: {'n_layers': 4, 'hidden_channels': 20, 'n_modes': 6, 'lifting_channel_ratio': 5, 'projection_channel_ratio': 3, 'lr': 0.008872729287776196}. Best is trial 13 with value: 0.10859087109565735.
Trial 20 finished with value: 0.10982754081487656 and parameters: {'n_layers': 4, 'hidden_channels': 20, 'n_modes': 6, 'lifting_channel_ratio': 5, 'projection_channel_ratio': 3, 'lr': 0.008872729287776196}. Best is trial 13 with value: 0.10859087109565735.
Running with n_layers=2, hidden_channels=52, n_modes=9, lifting_channel_ratio=3, projection_channel_ratio=13 lr=0.0291840183541753
Got n_samples = 8380  in dataset mri_pm_train        with sample size

[I 2025-04-18 11:20:57,470] Trial 21 pruned. 


Trial 21 pruned. 
Trial 21 pruned. 
Trial 21 pruned. 
Running with n_layers=2, hidden_channels=4, n_modes=8, lifting_channel_ratio=4, projection_channel_ratio=8 lr=0.02104741326193117
Got n_samples = 8380  in dataset mri_pm_train        with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_pm_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_gt_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 137   in dataset bsd_synth_0.01_train with sample size = torch.Size([1, 321, 481])
Got n_samples = 77    in dataset bsd_synth_0.01_test with sample size = torch.Size([1, 321, 481])
Got n_samples = 12296 in dataset sidd_train          with sample size = torch.Size([3, 512, 512])
Got n_samples = 3008  in dataset sidd_test           with sample size = torch.Size([3, 512, 512])
torch.Size([32, 1, 145, 145]) torch.Size([32, 1, 145, 145])
Loaded  model mri-fno-neuralop with n_parameters = 2010449
Loaded  mo

[I 2025-04-18 11:21:11,984] Trial 22 pruned. 


Trial 22 pruned. 
Trial 22 pruned. 
Trial 22 pruned. 
Running with n_layers=2, hidden_channels=50, n_modes=11, lifting_channel_ratio=4, projection_channel_ratio=23 lr=0.08808860574819805
Got n_samples = 8380  in dataset mri_pm_train        with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_pm_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_gt_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 137   in dataset bsd_synth_0.01_train with sample size = torch.Size([1, 321, 481])
Got n_samples = 77    in dataset bsd_synth_0.01_test with sample size = torch.Size([1, 321, 481])
Got n_samples = 12296 in dataset sidd_train          with sample size = torch.Size([3, 512, 512])
Got n_samples = 3008  in dataset sidd_test           with sample size = torch.Size([3, 512, 512])
torch.Size([32, 1, 145, 145]) torch.Size([32, 1, 145, 145])
Loaded  model mri-fno-neuralop with n_parameters = 2010449
Loaded 

[I 2025-04-18 11:22:06,678] Trial 23 pruned. 


Trial 23 pruned. 
Trial 23 pruned. 
Trial 23 pruned. 
Running with n_layers=4, hidden_channels=63, n_modes=13, lifting_channel_ratio=6, projection_channel_ratio=13 lr=0.0028160287651627877
Got n_samples = 8380  in dataset mri_pm_train        with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_pm_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_gt_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 137   in dataset bsd_synth_0.01_train with sample size = torch.Size([1, 321, 481])
Got n_samples = 77    in dataset bsd_synth_0.01_test with sample size = torch.Size([1, 321, 481])
Got n_samples = 12296 in dataset sidd_train          with sample size = torch.Size([3, 512, 512])
Got n_samples = 3008  in dataset sidd_test           with sample size = torch.Size([3, 512, 512])
torch.Size([32, 1, 145, 145]) torch.Size([32, 1, 145, 145])
Loaded  model mri-fno-neuralop with n_parameters = 2010449
Loade

In [11]:
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

print('Study statistics: ')
print('  Number of finished trials: ', len(study.trials))
print('  Number of pruned trials: ', len(pruned_trials))
print('  Number of complete trials: ', len(complete_trials))

print('Best trial:')
trial = study.best_trial

print('  Value: ', trial.value)

print('  Params: ')
for key, value in trial.params.items():
    print('    {}: {}'.format(key, value))

Study statistics: 
  Number of finished trials:  16
  Number of pruned trials:  9
  Number of complete trials:  7
Best trial:
  Value:  0.10859087109565735
  Params: 
    n_layers: 2
    hidden_channels: 41
    n_modes: 8
    lifting_channel_ratio: 6
    projection_channel_ratio: 32
    lr: 0.0026019787737744096


In [16]:
# HNO-v2, test_l2

# Study statistics:
#   Number of finished trials:  31
#   Number of pruned trials:  13
#   Number of complete trials:  17
# Best trial:
#   Value:  0.043848007917404175
#   Params:
#     n_layers: 3
#     hidden_channels: 49
#     n_modes: 16
#     lifting_channel_ratio: 6
#     projection_channel_ratio: 12
#     lr: 0.0074820412780186325