In [None]:
import torch
import numpy as np
import random

In [None]:
from seasalt.salt_net import (
    train_denoiser,
    train_noise_detector,
    HybridModel,
    NoiseType,
    get_test_dataloader,
    get_train_dataloader,
    get_tensor_board_dataset,
    train_hybrid_model,
)

In [None]:
# Seed every RNG the pipeline may draw from (torch, numpy, stdlib `random`)
# so that Restart-&-Run-All reproduces the same run. Change SEED here only.
SEED = 101

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Backend preference: CUDA if available, otherwise Apple's MPS backend,
# otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Display the selected device as the cell output.
device

In [None]:
# Noise configuration shared by all three loaders.
# min_noise / max_noise are presumably the bounds of the noise amount sampled
# per image — TODO confirm against seasalt.salt_net.
noise_type = NoiseType.RANDOM
min_noise = 0.4
max_noise = 0.95
batch_size = 16
# Smaller batch used only for the TensorBoard visualisation loader.
tb_batch_size = 8

# Keyword arguments common to every loader, defined once to avoid drift.
noise_kwargs = {"min_noise": min_noise, "max_noise": max_noise}

train_dataloader = get_train_dataloader(
    noise_type, **noise_kwargs, batch_size=batch_size
)
# NOTE(review): validation uses the *test* loader. If the same split is also
# used for final reporting, test data influences model selection — confirm.
val_dataloader = get_test_dataloader(
    noise_type, **noise_kwargs, batch_size=batch_size
)
tb_dataloader = get_tensor_board_dataset(
    noise_type, **noise_kwargs, batch_size=tb_batch_size
)

In [None]:
# Training-run settings, named instead of passed as bare positional literals.
# The meanings below are inferred from argument position — TODO confirm
# against the signature of seasalt.salt_net.train_hybrid_model.
learning_rate = 1e-3  # presumably the optimizer learning rate
n_epochs = 100  # presumably the number of training epochs
# NOTE(review): the run name says "gausian" (sic) and 90, but this notebook
# trains with NoiseType.RANDOM and max_noise=0.95. Kept verbatim because it may
# determine where logs/checkpoints are written — confirm it is intended.
run_name = "new_gausian_90"

# Hybrid model built from pretrained denoiser and noise-detector weights.
denoiser_model = HybridModel(
    denoiser_weights_path="./models/pytorch_best_frank_model_88.h5",
    detector_weights_path="./models/pytorch_noise_detector_prob_se_"
    "uneq_trick_low_lr_more_data_dropout_40.h5",
)
denoiser_model.to(device)
train_hybrid_model(
    denoiser_model,
    learning_rate,
    train_dataloader,
    val_dataloader,
    device,
    run_name,
    n_epochs,
    tb_dataloader,
)