In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import neuralop
import torch
import wandb

from denoising import (
    Environment,
    make_model_config,
    ModelRegistry,
)
from config import model_load_configs

ImportError: cannot import name 'builder' from 'google.protobuf.internal' (/home/d.nesterov/denoising-fno/.venv/lib/python3.10/site-packages/google/protobuf/internal/__init__.py)

In [3]:
print('torch ' + torch.__version__)
print('neuralop ' + neuralop.__version__)

torch 2.6.0+cu124
neuralop 1.0.2


In [4]:
# загружаем переменные окружения
env = Environment(_env_file='../env')

In [5]:
wandb.login(key=env.wandb_api_key)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/d.nesterov/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdmitrylala[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Инициализация данных

In [6]:
import matplotlib.pyplot as plt


def show_sample(item: dict[str, torch.Tensor]) -> None:
    def prep(v):
        return v.detach().cpu().permute(1, 2, 0)

    x, y = prep(item['x']), prep(item['y'])

    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(1, 2, 1)
    ax.imshow(x)
    ax.set_title('input x')
    ax = fig.add_subplot(1, 2, 2)
    ax.imshow(y)
    ax.set_title('input y')
    fig.show()

In [7]:
from denoising import (
    make_bsd_dset_config,
    make_fno_dset_config,
    make_load_params,
)
from denoising.data import DatasetRegistry

In [8]:
mri_root = env.data / 'MRI/IXI_0_1/255'
bsd_root = env.data / 'BSDS300-horizontal-synthetic'
sidd_root = env.data / 'SIDD_Small_sRGB_Only'

mri_sketch_load_params = make_load_params('sketch', [145, 145], 'float32')
mri_image_load_params = make_load_params('image', [145, 145], 'float32')
mri_gt_load_params = make_load_params('gt', [255, 255], 'float32')
pm_load_params = [mri_sketch_load_params, mri_image_load_params]
gt_load_params = [mri_sketch_load_params, mri_gt_load_params]

sidd_noisy_load_params = make_load_params('noisy', [512, 512, 3], 'uint8')
sidd_gt_load_params = make_load_params('gt', [512, 512, 3], 'uint8')
sidd_load_params = [sidd_noisy_load_params, sidd_gt_load_params]

datasets_configs = {
    # MRI datasets
    'mri_pm_train': make_fno_dset_config(
        mri_root,
        env.data / 'MRI/lists/IXI_0_1/train_pmLR_gibbsnoiseLR_train.csv',
        pm_load_params,
    ),
    'mri_pm_test': make_fno_dset_config(
        mri_root,
        env.data / 'MRI/lists/IXI_0_1/train_pmLR_gibbsnoiseLR_val.csv',
        pm_load_params,
    ),
    'mri_gt_test': make_fno_dset_config(
        mri_root,
        env.data / 'MRI/lists/IXI_0_1/train_gtLR_gibbsnoiseLR_val.csv',
        gt_load_params,
    ),
    # BSD datasets
    'bsd_synth_0.01_train': make_bsd_dset_config(bsd_root, 0.01, 'train'),
    'bsd_synth_0.01_test': make_bsd_dset_config(bsd_root, 0.01, 'test'),
    # SIDD datasets, patches
    'sidd_train': make_fno_dset_config(
        sidd_root / 'train',
        sidd_root / 'patches_train.csv',
        sidd_load_params,
        normalize=True,
    ),
    'sidd_test': make_fno_dset_config(
        sidd_root / 'val',
        sidd_root / 'patches_val.csv',
        sidd_load_params,
        normalize=True,
    ),
}

In [9]:
dataset_registry = DatasetRegistry()
dataset_registry.load(datasets_configs, verbose=True)
dataset_registry

Got n_samples = 8380  in dataset mri_pm_train        with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_pm_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 2093  in dataset mri_gt_test         with sample size = torch.Size([1, 145, 145])
Got n_samples = 137   in dataset bsd_synth_0.01_train with sample size = torch.Size([1, 321, 481])
Got n_samples = 77    in dataset bsd_synth_0.01_test with sample size = torch.Size([1, 321, 481])
Got n_samples = 12296 in dataset sidd_train          with sample size = torch.Size([3, 512, 512])
Got n_samples = 3008  in dataset sidd_test           with sample size = torch.Size([3, 512, 512])


DatasetRegistry(['mri_pm_train', 'mri_pm_test', 'mri_gt_test', 'bsd_synth_0.01_train', 'bsd_synth_0.01_test', 'sidd_train', 'sidd_test'])

In [10]:
# MRI
train_loader = dataset_registry.make_dl('mri_pm_train', batch_size=64, shuffle=True)
test_loader = dataset_registry.make_dl('mri_pm_test', batch_size=128)

# BSD
# train_loader = dataset_registry.make_dl('bsd_synth_0.01_train', batch_size=32, shuffle=True)
# test_loader = dataset_registry.make_dl('bsd_synth_0.01_test', batch_size=64)

# SIDD patches
# train_loader = dataset_registry.make_dl('sidd_train', batch_size=16, shuffle=True)
# test_loader = dataset_registry.make_dl('sidd_test', batch_size=64)

In [11]:
for batch in train_loader:
    print(batch.x.size(), batch.y.size())
    break

torch.Size([64, 1, 145, 145]) torch.Size([64, 1, 145, 145])


In [12]:
# Uncomment to show sample

# idx = 10
# sample = train_loader.dataset[idx]

# x, y = sample['x'], sample['y']
# print(f'Training sample {idx} has shape: {x.size()}')

# show_sample(sample)

# Инициализация FNO

In [25]:
# Параметры запусков


# run 6, params from paper (best on MRI)
# fno_cfg = {
#     'n_modes': (32, 32),
#     'in_channels': 1,
#     'hidden_channels': 32,
#     'lifting_channel_ratio': 8,
#     'projection_channel_ratio': 2,
#     'out_channels': 1,
#     'factorization': 'tucker',
#     'n_layers': 4,
#     'rank': 0.42,
# }
# optimizer = torch.optim.Adam(train_model.parameters(), lr=1e-3, weight_decay=1e-4)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

# для BSD меньше lr
# optimizer = torch.optim.Adam(train_model.parameters(), lr=1e-4, weight_decay=1e-5)


# run 5
# fno_cfg = {
#     'n_modes': (16, 16),
#     'in_channels': 1,
#     'hidden_channels': 16,
#     'projection_channel_ratio': 2,
#     'out_channels': 1,
#     'factorization': 'tucker',
#     'n_layers': 32,
#     'rank': 0.42,
#     'positional_embedding': GridEmbedding2D(in_channels=1, grid_boundaries=[[0, 1], [0, 1.5]]),
# }

In [26]:
model_registry = ModelRegistry()
model_registry.load(model_load_configs, random_seed=env.random_seed, device='cpu', verbose=True)
model_registry

Loaded  model mri-fno-neuralop with n_parameters = 1015257
Loaded  model mri-fno-custom   with n_parameters = 1015257
Loaded  model sidd-fno-run2    with n_parameters = 1015899
Loaded  model sidd-fno-run3    with n_parameters = 1015899
Loaded  model sidd-fno-run4    with n_parameters = 1020595
Loaded  model bsd-fno          with n_parameters = 1015257


ModelRegistry(['mri-fno-neuralop', 'mri-fno-custom', 'sidd-fno-run2', 'sidd-fno-run3', 'sidd-fno-run4', 'bsd-fno'])

In [27]:
new_cfg = {
    'n_modes': (32, 32),
    'in_channels': 3,
    'hidden_channels': 16,  # less hidden: 32 -> 16
    'lifting_channel_ratio': 8,
    'projection_channel_ratio': 2,
    'out_channels': 3,
    'factorization': 'tucker',
    'n_layers': 16,  # stack more layers: 4 -> 16
    'rank': 0.42,
}
model_name = 'new-fno'

new_fno_cfg = {
    model_name: make_model_config(new_cfg, None, 'FNO'),
}

model_registry.load(new_fno_cfg, random_seed=env.random_seed, device=env.device, verbose=True)
model = model_registry[model_name]

Created model new-fno          with n_parameters = 1020595


# Обучение

In [29]:
from dataclasses import dataclass, field
from pathlib import Path

from neuralop import H1Loss, LpLoss, Trainer
from torch import nn
from torch.utils.data import DataLoader


@dataclass
class TrainConfig:
    train_loader: DataLoader
    test_loader: DataLoader
    model: nn.Module
    lr: float
    n_epochs: int
    device: str | torch.device

    run_name: str
    save_dir_run: Path
    save_weights_path: Path
    tags: list[str] = field(default_factory=list)


def make_run(cfg: TrainConfig):
    wandb_init_args = dict(
        project='Denoising MRI',
        name=cfg.run_name,
        group='FNO 2025',
        entity='Dmitrylala',
        tags=cfg.tags,
    )

    return wandb.init(**wandb_init_args)


def prepare_training(cfg: TrainConfig) -> tuple:
    trainer_cfg = {
        'model': cfg.model,
        'n_epochs': cfg.n_epochs,
        'device': cfg.device,
        'wandb_log': True,
        'eval_interval': 1,
        'log_output': True,
        'verbose': True,
    }

    trainer = Trainer(**trainer_cfg)
    print(f'Logging to wandb enabled: {trainer.wandb_log}')

    optimizer = torch.optim.Adam(cfg.model.parameters(), lr=cfg.lr, weight_decay=cfg.lr / 10.0)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

    l2loss = LpLoss(d=2, p=2)
    h1loss = H1Loss(d=2)

    trainer_cfg = {
        'train_loader': cfg.train_loader,
        'test_loaders': {'test': cfg.test_loader},
        'optimizer': optimizer,
        'scheduler': None,
        'scheduler': scheduler,
        'save_every': 1,
        'save_dir': cfg.save_dir_run,
        'training_loss': h1loss,
        'eval_losses': {'h1': h1loss, 'l2': l2loss},
    }

    return trainer, trainer_cfg

In [30]:
run_idx = 4
save_dir = Path('sidd_patches')

run_name = f'Run {run_idx}, SIDD, more layers'
save_dir_run = save_dir / f'run-{run_idx}'
save_weights_path = save_dir / f'run-{run_idx}-weights.pt'

train_cfg = TrainConfig(
    train_loader=train_loader,
    test_loader=train_loader,
    model=model,
    lr=1e-3,
    n_epochs=50,
    device=env.device,
    run_name=run_name,
    save_dir_run=save_dir_run,
    save_weights_path=save_weights_path,
    tags=['SIDD', 'no augs'],
)

In [18]:
run = make_run(train_cfg)

In [21]:
trainer, trainer_cfg = prepare_training(train_cfg)

Logging to wandb enabled: True


In [None]:
trainer.train(**trainer_cfg)
run.finish()

Training on 12296 samples
Testing on [12296] samples         on resolutions ['test'].
Raw outputs of shape torch.Size([16, 3, 512, 512])


In [27]:
torch.save(train_cfg.model.to('cpu').state_dict(), train_cfg.save_weights_path)