In [None]:
import random

import numpy as np
import torch

In [None]:
from seasalt.salt_net import (
    train_denoiser,
    train_noise_detector,
    Desnoiser,
    NoiseDetector,
    NoiseType,
    get_test_dataloader,
    get_train_dataloader,
    get_tensor_board_dataset,
)

In [None]:
# Single source of truth for every RNG seeded below, so a re-run reproduces results.
SEED = 101

# Seed every RNG the pipeline might draw from: Python's stdlib `random`, NumPy's
# legacy global RNG, and torch (torch.manual_seed seeds all devices).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer CUDA, then MPS, then CPU. `torch.backends.mps` is absent on older PyTorch
# builds, so look it up defensively instead of letting an AttributeError escape.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

device

In [None]:
# Noise configuration shared by every split. Min and max coincide, so every image
# gets the same noise level (0.5). NoiseType.SAP is presumably salt-and-pepper
# noise; TODO confirm in seasalt.salt_net.
noise_type = NoiseType.SAP
min_noise = max_noise = 0.5
batch_size = 128

# Keyword arguments passed identically to all three loaders, so they cannot drift apart.
noise_range = {"min_noise": min_noise, "max_noise": max_noise}

train_dataloader = get_train_dataloader(
    noise_type, batch_size=batch_size, **noise_range
)
# The test split is reused as the validation set during training.
val_dataloader = get_test_dataloader(
    noise_type, batch_size=batch_size, **noise_range
)
# Small fixed batch of 8, presumably for TensorBoard previews; TODO confirm.
tb_dataloader = get_tensor_board_dataset(
    noise_type, batch_size=8, **noise_range
)

In [None]:
# Build the noise detector on the selected device and train it on the loaders
# built earlier.
noise_detector_model = NoiseDetector().to(device)
# Keep the original (misspelled) global bound to the same model so any later cell
# that still refers to `noise_detecor_model` keeps working.
noise_detecor_model = noise_detector_model
# Left as the cell's final expression so its return value (if any) is displayed.
train_noise_detector(
    noise_detector_model,
    1e-3,  # presumably the learning rate; TODO confirm against train_noise_detector
    train_dataloader,
    val_dataloader,
    device,
    "noise_detector_1",  # presumably the run/experiment name; TODO confirm
    50,  # presumably the number of epochs; TODO confirm
    tb_dataloader,
)

In [None]:
# Build the denoiser and move it to the selected device in one step. `Desnoiser`
# is spelled that way in seasalt.salt_net. The training call is left as the cell's
# final expression so its return value (if any) is displayed.
denoiser_model = Desnoiser().to(device)
train_denoiser(
    denoiser_model,
    1e-3,  # presumably the learning rate; TODO confirm against train_denoiser
    train_dataloader,
    val_dataloader,
    device,
    "denoiser_1",  # presumably the run/experiment name; TODO confirm
    50,  # presumably the number of epochs; TODO confirm
    tb_dataloader,
)