In [1]:
import torch
import numpy as np
import random

In [None]:
from seasalt.salt_net import (
    train_denoiser,
    train_noise_detector,
    DenoiseNet,
    NoiseDetector,
    NoiseType,
    get_test_dataloader,
    get_train_dataloader,
    get_tensor_board_dataset,
)

In [None]:
# Seed every RNG this notebook touches (PyTorch, NumPy, Python's `random`)
# with one shared value so runs are repeatable.
SEED = 101
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(SEED)

# Prefer a CUDA GPU, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    backend = "cuda"
elif torch.backends.mps.is_available():
    backend = "mps"
else:
    backend = "cpu"
device = torch.device(backend)

device

device(type='cuda')

In [None]:
# Noise configuration shared by all three loaders.
noise_type = NoiseType.PROBALISTIC
min_noise = 0.4
max_noise = 0.8
batch_size = 32
# Smaller fixed batch for the loader passed to training as `tb_dataloader`
# (presumably used for TensorBoard previews — confirm in seasalt.salt_net).
tb_batch_size = 8

# Every loader sees the same noise type and noise range; only the split and
# batch size differ. Note the validation loader is built from the test split.
noise_settings = {"min_noise": min_noise, "max_noise": max_noise}

train_dataloader = get_train_dataloader(noise_type, **noise_settings, batch_size=batch_size)
val_dataloader = get_test_dataloader(noise_type, **noise_settings, batch_size=batch_size)
tb_dataloader = get_tensor_board_dataset(noise_type, **noise_settings, batch_size=tb_batch_size)

In [None]:
# Build the noise detector (squeeze-excitation + dropout variant).
# Training is opt-in instead of being left as commented-out code: set
# `retrain_noise_detector = True` to retrain; otherwise the following cell
# restores the weights from a saved checkpoint.
retrain_noise_detector = False

noise_detecor_model = NoiseDetector(squeeze_excitation=True, dropout=True)
if retrain_noise_detector:
    noise_detecor_model = noise_detecor_model.to(device)
    # NOTE(review): 1e-4 looks like the learning rate and 100 the epoch count,
    # and the run name appears to prefix the checkpoint file names — confirm
    # against train_noise_detector's signature.
    train_noise_detector(
        noise_detecor_model,
        1e-4,
        train_dataloader,
        val_dataloader,
        device,
        "noise_detector_prob_se_uneq_trick_low_lr_more_data_dropout",
        100,
        tb_dataloader,
    )

In [None]:
# Draw ONE batch and keep its first sample as the example input for tracing.
# Calling next(iter(...)) twice would build two independent iterators, so with
# a shuffling loader the input and its mask could come from different batches.
sample_batch = next(iter(train_dataloader))
X = sample_batch[0][:1, :]
mask = sample_batch[1][:1, :]

# map_location="cpu": the checkpoint may contain tensors saved from a CUDA/MPS
# run; mapping them to CPU lets this cell run on machines without that device.
# load_state_dict then copies them onto whatever device the model lives on.
# (The ".h5" suffix is just the file name used at save time; it is a torch checkpoint.)
noise_detecor_model.load_state_dict(
    torch.load(
        "./models/pytorch_noise_detector_prob_se_uneq_trick_low_lr_more_data_dropout_31.h5",
        map_location="cpu",
    ),
)
# eval() so dropout is disabled in the traced graph.
noise_detecor_model = noise_detecor_model.eval()

# The example input must be on the same device as the model's parameters,
# otherwise tracing fails (e.g. when the training cell moved the model to `device`).
model_device = next(noise_detecor_model.parameters()).device
traced_model = torch.jit.trace(noise_detecor_model, X.to(model_device))

In [None]:
# Export the traced detector as a TorchScript archive; it can be reloaded later
# with torch.jit.load.
torch.jit.save(traced_model, "./models/detector.pt")

In [None]:
# Optional denoiser training. Off by default so a full "Run All" does not
# start a long training job; set `retrain_denoiser = True` to run it.
retrain_denoiser = False

if retrain_denoiser:
    denoiser_model = DenoiseNet(
        output_cnn_depth=20,
        enable_seconv=True,
        enable_unet=False,
        enable_fft=True,
        enable_unet_post_processing=True,
    )
    denoiser_model = denoiser_model.to(device)
    # TODO(review): "test" looks like a placeholder run name — choose a
    # descriptive one before producing checkpoints worth keeping.
    train_denoiser(
        denoiser_model,
        1e-3,
        train_dataloader,
        val_dataloader,
        device,
        "test",
        100,
        tb_dataloader,
    )