In [None]:
from seasalt.noise_to_salt.noise_to_salt import (
    NoiseDetector,
    Desnoiser,
    weighted_mean_conv,
    calculate_wa_kernel,
    train_noise_detector,
    train_denoiser,
    noise_adder,
    data_folder,
)
from functools import partial
from torch.nn import functional as F
import torch
from torchvision import datasets, transforms
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import numpy as np

In [None]:
def _pick_device() -> torch.device:
    """Return the preferred torch device: CUDA first, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Fixed seed so the global torch / numpy random draws made by later cells
# (e.g. random_split and the torch.rand noise levels) repeat across runs,
# provided the cells are executed top to bottom.
SEED = 101
device = _pick_device()
torch.manual_seed(SEED)
np.random.seed(SEED)
device

In [None]:
# Per-image preprocessing handed to ImageFolder: grayscale, fixed-size center
# crop, then conversion to a tensor.
CROP_SIZE = 320
transform = transforms.Compose(
    [transforms.Grayscale(), transforms.CenterCrop(CROP_SIZE), transforms.ToTensor()]
)

# NOTE(review): `wa_kerne` is not referenced anywhere else in this notebook and
# the name looks like a typo of `wa_kernel`; the meaning of the argument 27
# (presumably a kernel size) is defined by calculate_wa_kernel. Confirm it is
# unused elsewhere before removing it.
wa_kerne = calculate_wa_kernel(27)


def collate_images(
    noise_type,
    batch,
    *,
    noise_low: float = 0.4,
    noise_high: float = 0.95,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Stack a batch of images and corrupt it with ``noise_adder``.

    Meant to be used as a ``DataLoader`` ``collate_fn`` through
    ``partial(collate_images, noise_type)``.

    Args:
        noise_type: Forwarded unchanged as the third argument of
            ``noise_adder`` (this notebook passes ``"sap"``).
        batch: Sequence of dataset items. Only ``item[0]`` is used (labels are
            discarded); it is assumed to be a ``(C, H, W)`` image tensor.
        noise_low: Lower bound of the per-element values drawn for the second
            argument of ``noise_adder``.
        noise_high: Upper bound of those values. Their meaning (presumably a
            noise level / probability) is defined by ``noise_adder`` - confirm.

    Returns:
        ``(noisy_images, masks, clean_images)``: the first two are whatever
        ``noise_adder`` returns; ``clean_images`` is the zero-padded stack of
        the inputs with shape ``(B, C, max_H, max_W)``.

    Raises:
        ValueError: if ``noise_low`` is greater than ``noise_high``.
    """
    if noise_low > noise_high:
        raise ValueError(
            f"noise_low ({noise_low}) must not exceed noise_high ({noise_high})"
        )
    images = [item[0] for item in batch]
    max_height = max(img.shape[1] for img in images)
    max_width = max(img.shape[2] for img in images)
    # F.pad lists the last dimension first: (left, right, top, bottom), so each
    # image is zero-padded on the right and bottom up to the batch maximum.
    padded_images = [
        F.pad(img, (0, max_width - img.shape[2], 0, max_height - img.shape[1]))
        for img in images
    ]
    stacked_images = torch.stack(padded_images)
    # Affine map of U[0, 1) onto (noise_low, noise_high], one value per element,
    # allocated next to the images so this also works for non-CPU tensors.
    noise_levels = (noise_low - noise_high) * torch.rand(
        stacked_images.shape, device=stacked_images.device
    ) + noise_high
    noisy_images, masks = noise_adder(stacked_images, noise_levels, noise_type)
    return noisy_images, masks, stacked_images


# Train / validation split of the image folder and the loaders feeding training.
TRAIN_FRACTION = 0.8
BATCH_SIZE = 8
SPLIT_SEED = 101

dataset = datasets.ImageFolder(root=str(data_folder), transform=transform)
# Derive the validation length from the remainder so the two lengths always sum
# to len(dataset), whatever TRAIN_FRACTION is set to.
train_length = round(len(dataset) * TRAIN_FRACTION)
lengths = [train_length, len(dataset) - train_length]
# A dedicated seeded generator keeps the split identical when this cell is
# re-run on its own; the global RNG would have advanced in the meantime.
train_dataset, val_dataset = random_split(
    dataset, lengths, generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Both loaders corrupt the clean images on the fly with "sap" noise
# (presumably salt-and-pepper - see noise_adder).
collate_sap = partial(collate_images, "sap")
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_sap,
)
# Validation does not need a random order. NOTE(review): collate_images draws
# fresh noise on every call, so validation inputs still change each epoch;
# consider a fixed pre-noised validation set for comparable val losses.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_sap,
)

In [None]:
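# Noise-detector training is disabled in this run; uncomment the lines below to
# train a NoiseDetector with train_noise_detector.
# model = NoiseDetector()
# train_noise_detector(
#     model,
#     1e-3,
#     train_dataloader,
#     val_dataloader,
#     device,
#     "noise_detector",
#     100,
#     True,
# )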

In [None]:
# Build the denoiser directly on the selected device, then hand it to the
# project's training loop together with both loaders.
model = Desnoiser().to(device)
train_denoiser(
    model,
    1e-3,  # NOTE(review): presumably the learning rate - confirm in train_denoiser
    train_dataloader,
    val_dataloader,
    device,
    "denoise_me_daddy_9510",  # run name forwarded to train_denoiser
    100,  # NOTE(review): presumably the number of epochs - confirm
    True,  # NOTE(review): flag whose meaning is defined by train_denoiser
)