In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import logging
import sys
import pickle
from pathlib import Path

import torch
import wandb
import optuna
from optuna.trial import TrialState

from config import Environment, TrainConfig
from denoising.train import prepare_training
from denoising.utils import seed_everything
from denoising.models.utils import count_parameters

In [3]:
# Build the project settings object from the `env` file located one level
# above the current working directory.
# NOTE(review): `_env_file` looks like the pydantic-settings override — confirm
# against `config.Environment`.
CWD = Path.cwd()
env = Environment(_env_file=CWD / '../env')
# Authenticate with Weights & Biases using the API key carried by the settings
# object. Kept as the cell's last expression so the login result is displayed.
wandb.login(key=env.wandb_api_key)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/d.nesterov/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdmitrylala[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [7]:
def define_train_cfg(trial) -> TrainConfig:
    """Sample hyperparameters from an Optuna trial and build the training config.

    The following values are drawn from ``trial``:

    * ``n_layers``: integer in [2, 10], sampled on a log scale.
    * ``hidden_channels``: one of 4, 8, 16, 32.
    * ``n_modes``: single-choice categorical (always 16); it is still routed
      through the trial so the value is recorded with the trial's parameters.
    * ``lifting_channel_ratio`` / ``projection_channel_ratio``: one of
      2, 4, 8, 16, 32 each.
    * ``lr``: float in [1e-4, 1e-3], sampled on a log scale.

    Everything else (datasets, batch sizes, device, seed, epoch count) is fixed.

    Args:
        trial: Optuna trial (or any object exposing ``suggest_int``,
            ``suggest_categorical`` and ``suggest_float``).

    Returns:
        A ``TrainConfig`` populated with the sampled and the fixed values.
    """
    # NOTE(review): the study is persisted and reloaded elsewhere in this
    # notebook; changing the names or choices below may conflict with trials
    # that are already stored — confirm before editing the search space.
    n_layers = trial.suggest_int('n_layers', 2, 10, log=True)
    hidden_channels = trial.suggest_categorical('hidden_channels', [4, 8, 16, 32])
    n_modes = trial.suggest_categorical('n_modes', [16])
    lifting_channel_ratio = trial.suggest_categorical('lifting_channel_ratio', [2, 4, 8, 16, 32])
    projection_channel_ratio = trial.suggest_categorical('projection_channel_ratio', [2, 4, 8, 16, 32])
    lr = trial.suggest_float('lr', 1e-4, 1e-3, log=True)

    # Every sampled value is separated by ', ' (the separator before `lr` was
    # previously missing, which made the log line inconsistent).
    print(
        f'Running with {n_layers=}, {hidden_channels=}, {n_modes=}, '
        f'{lifting_channel_ratio=}, {projection_channel_ratio=}, {lr=}'
    )

    cfg = TrainConfig(
        # Datasets params
        train_dset='sidd_train',
        test_dset='sidd_test',
        train_batch_size=16,
        test_batch_size=32,
        # Model params
        name_model='sidd-hno',
        cfg_fno={
            # The same number of modes is used along both spatial dimensions.
            'n_modes': (n_modes, n_modes),
            'in_channels': 3,
            'hidden_channels': hidden_channels,
            'lifting_channel_ratio': lifting_channel_ratio,
            'projection_channel_ratio': projection_channel_ratio,
            'out_channels': 3,
            'factorization': 'dense',
            'n_layers': n_layers,
            'rank': 0.42,
            'spectral': 'hartley',
        },
        # Run params
        random_seed=42,
        device='cuda:2',
        run_name='Run optuna',
        # Train params
        n_epochs=1,
        lr=lr,
        verbose=True,
    )

    return cfg


def objective(trial):
    """Optuna objective: train one model for the sampled config and score it.

    Args:
        trial: Optuna trial. It is used to sample the hyperparameters and is
            also passed on to ``trainer.train``.

    Returns:
        The ``'test_h1'`` entry of the metrics returned by ``trainer.train``,
        converted to ``float``.

    Raises:
        optuna.exceptions.TrialPruned: if the constructed model has more than
            5,000,000 parameters; raised before any training is run.
    """
    cfg = define_train_cfg(trial)

    # Seed *before* the model is built so that anything random inside
    # `prepare_training` (e.g. weight initialisation) is covered by the seed.
    # Previously the only seeding happened after construction.
    # NOTE(review): if `prepare_training` already seeds internally this call is
    # redundant but harmless — confirm in denoising.train.
    seed_everything(cfg.random_seed)
    trainer, train_kwargs, _ = prepare_training(env, cfg)

    # Skip oversized models up front instead of spending a training run on them.
    if count_parameters(trainer.model) > 5_000_000:
        print('Pruned by model params')
        raise optuna.exceptions.TrialPruned

    # Re-seed right before training, as before, so the RNG state at the start
    # of training does not depend on how much randomness setup consumed.
    seed_everything(cfg.random_seed)

    trial_obj = 'test_h1'
    metrics = trainer.train(trial=trial, trial_obj=trial_obj, **train_kwargs)

    return float(metrics[trial_obj])

In [8]:
# Common prefix from which the study name and the sampler pickle path are derived.
run_name = 'hno-sidd'
study_name = run_name + '-optuna'
sampler_path = Path('.', f'{run_name}-sampler')

In [None]:
# Init trials storage and sampler pickle.

# Mirror Optuna's log messages to stdout. The check makes the cell re-runnable:
# without it every execution stacks one more handler on the logger and each
# message is printed one more time.
optuna_logger = optuna.logging.get_logger('optuna')
has_stdout_handler = any(
    isinstance(handler, logging.StreamHandler) and getattr(handler, 'stream', None) is sys.stdout
    for handler in optuna_logger.handlers
)
if not has_stdout_handler:
    optuna_logger.addHandler(logging.StreamHandler(sys.stdout))

# Trials are stored in a SQLite file named after the study.
storage_name = f'sqlite:///{study_name}.db'

restored_sampler = None
if sampler_path.exists():
    print(f'Restore sampler from path: {sampler_path}')
    # SECURITY: unpickling executes arbitrary code. Only restore sampler files
    # written by this notebook; never load one from an untrusted source.
    with sampler_path.open('rb') as sampler_file:
        restored_sampler = pickle.load(sampler_file)

# Create a new study, or reattach to the stored one with the same name.
study = optuna.create_study(
    study_name=study_name,
    storage=storage_name,
    direction='minimize',
    sampler=restored_sampler,
    load_if_exists=True,
)

if not sampler_path.exists():
    print(f'Caching sampler in: {sampler_path}')
    # The context manager guarantees the pickle is flushed and the handle closed.
    with sampler_path.open('wb') as sampler_file:
        pickle.dump(study.sampler, sampler_file)
# NOTE(review): the sampler is pickled only once, before any optimisation, so a
# later restore starts from the sampler's initial state. If the intent is to
# resume its state, it would need to be saved again after `study.optimize`.

# Run optimization: stops after 100 trials or 12000 seconds, whichever comes first.
study.optimize(objective, n_trials=100, timeout=12000)

[I 2025-05-10 16:20:45,492] Using an existing study with name 'hno-sidd-optuna' instead of creating a new one.


Restore sampler from path: hno-sidd-sampler
Using an existing study with name 'hno-sidd-optuna' instead of creating a new one.
Using an existing study with name 'hno-sidd-optuna' instead of creating a new one.
Running with n_layers=8, hidden_channels=4, n_modes=16, lifting_channel_ratio=8, projection_channel_ratio=32 lr=0.0003537913969983926
Got n_samples = 8380  in dataset mri_pm_train        with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_pm_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 6704  in dataset mri_gt_train        with sample size = torch.Size([1, 145, 145])
Got n_samples = 1676  in dataset mri_gt_val          with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_gt_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 137   in dataset bsd_synth_0.01_train with sample size = torch.Size([1, 321, 481])
Got n_samples = 77    in dataset bsd_synth_0.01_test with sample si

[I 2025-05-10 16:29:17,733] Trial 1 finished with value: 2.717987060546875 and parameters: {'n_layers': 8, 'hidden_channels': 4, 'n_modes': 16, 'lifting_channel_ratio': 8, 'projection_channel_ratio': 32, 'lr': 0.0003537913969983926}. Best is trial 1 with value: 2.717987060546875.


Saved training state to ckpt
Trial 1 finished with value: 2.717987060546875 and parameters: {'n_layers': 8, 'hidden_channels': 4, 'n_modes': 16, 'lifting_channel_ratio': 8, 'projection_channel_ratio': 32, 'lr': 0.0003537913969983926}. Best is trial 1 with value: 2.717987060546875.
Trial 1 finished with value: 2.717987060546875 and parameters: {'n_layers': 8, 'hidden_channels': 4, 'n_modes': 16, 'lifting_channel_ratio': 8, 'projection_channel_ratio': 32, 'lr': 0.0003537913969983926}. Best is trial 1 with value: 2.717987060546875.
Running with n_layers=3, hidden_channels=16, n_modes=16, lifting_channel_ratio=2, projection_channel_ratio=8 lr=0.0003695278044736918
Got n_samples = 8380  in dataset mri_pm_train        with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_pm_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 6704  in dataset mri_gt_train        with sample size = torch.Size([1, 145, 145])
Got n_samples = 1676  in dataset mri_

In [7]:
# Report how many trials ended in each state, then the best trial's value and
# the hyperparameters that produced it.
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

print('Study statistics: ')
# `study.trials` is the unfiltered list, so the first count covers every trial
# in the study regardless of its state.
for label, counted in (
    ('  Number of finished trials: ', study.trials),
    ('  Number of pruned trials: ', pruned_trials),
    ('  Number of complete trials: ', complete_trials),
):
    print(label, len(counted))

print('Best trial:')
trial = study.best_trial

print('  Value: ', trial.value)

print('  Params: ')
for key, value in trial.params.items():
    print(f'    {key}: {value}')

Study statistics: 
  Number of finished trials:  132
  Number of pruned trials:  54
  Number of complete trials:  77
Best trial:
  Value:  0.3016343414783478
  Params: 
    n_layers: 4
    hidden_channels: 32
    n_modes: 16
    lifting_channel_ratio: 16
    projection_channel_ratio: 2
    lr: 0.005083610315160994


In [31]:
# HNO on mri gt v2
# Study statistics: 
#   Number of finished trials:  132
#   Number of pruned trials:  54
#   Number of complete trials:  77
# Best trial:
#   Value:  0.3016343414783478
#   Params: 
#     n_layers: 4
#     hidden_channels: 32
#     n_modes: 16
#     lifting_channel_ratio: 16
#     projection_channel_ratio: 2
#     lr: 0.005083610315160994

# FNO on mri gt
# Study statistics: 
#   Number of finished trials:  91
#   Number of pruned trials:  49
#   Number of complete trials:  41
# Best trial:
#   Value:  0.29339370131492615
#   Params: 
#     n_layers: 3
#     hidden_channels: 32
#     n_modes: 32
#     lifting_channel_ratio: 32
#     projection_channel_ratio: 4
#     lr: 0.005342701181994739

# HNO on mri gt
# Study statistics: 
#   Number of finished trials:  32
#   Number of pruned trials:  10
#   Number of complete trials:  21
# Best trial:
#   Value:  0.3213663101196289
#   Params: 
#     n_layers: 6
#     hidden_channels: 8
#     n_modes: 16
#     lifting_channel_ratio: 8
#     projection_channel_ratio: 4
#     lr: 0.004377779250650843

# FNO
# Study statistics: 
#   Number of finished trials:  29
#   Number of pruned trials:  9
#   Number of complete trials:  20
# Best trial:
#   Value:  0.08233946561813354
#   Params: 
#     n_layers: 15
#     hidden_channels: 16
#     n_modes: 32
#     lifting_channel_ratio: 32
#     projection_channel_ratio: 2
#     lr: 0.006055187761870968

# HNO
# Study statistics: 
#   Number of finished trials:  55
#   Number of pruned trials:  31
#   Number of complete trials:  24
# Best trial:
#   Value:  0.08180370926856995
#   Params: 
#     n_layers: 10
#     hidden_channels: 16
#     n_modes: 16
#     lifting_channel_ratio: 32
#     projection_channel_ratio: 8
#     lr: 0.00433647012426727

In [16]:
# HNO-v2, test_h1
# Study statistics:
#   Number of finished trials:  25
#   Number of pruned trials:  16
#   Number of complete trials:  9
# Best trial:
#   Value:  0.10859087109565735
#   Params:
#     n_layers: 2
#     hidden_channels: 41
#     n_modes: 8
#     lifting_channel_ratio: 6
#     projection_channel_ratio: 32
#     lr: 0.0026019787737744096


# HNO-v2, test_l2

# Study statistics:
#   Number of finished trials:  31
#   Number of pruned trials:  13
#   Number of complete trials:  17
# Best trial:
#   Value:  0.043848007917404175
#   Params:
#     n_layers: 3
#     hidden_channels: 49
#     n_modes: 16
#     lifting_channel_ratio: 6
#     projection_channel_ratio: 12
#     lr: 0.0074820412780186325