In [1]:
from seasalt.noise_to_salt.noise_to_salt import (
    NoiseDetector,
    weighted_mean_conv,
    train_model,
    noise_adder,
    data_folder,
)
from functools import partial
from torch.nn import BCELoss
from torch.nn import functional as F
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import numpy as np

In [2]:
# Pick the best available accelerator, falling back to CPU so the notebook
# also runs on machines that have neither CUDA nor Apple-silicon MPS
# (previously the non-CUDA branch always chose "mps", which fails elsewhere).
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Seed torch (all devices) and NumPy's global RNG for reproducible runs.
# NOTE(review): if noise_adder uses Python's `random` module, seed that too.
SEED = 101
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
# Only convert images to tensors; nothing is resized, so images in a batch can
# differ in size and are padded to a common size by the collate function.
transform = transforms.Compose([transforms.ToTensor()])


def collate_pad_to_biggest(noise_parameter, noise_type, batch):
    """Zero-pad a batch of images to a common size, stack it, and add noise.

    Meant to be used as a DataLoader ``collate_fn`` with the first two
    arguments bound via ``functools.partial``.

    Args:
        noise_parameter: forwarded unchanged to ``noise_adder``.
        noise_type: forwarded unchanged to ``noise_adder``.
        batch: sequence of dataset items; only ``item[0]`` (an image tensor,
            presumably (C, H, W) as produced by ToTensor) is used and any
            labels are discarded.

    Returns:
        The ``(noisy_images, masks)`` pair produced by ``noise_adder``.
    """
    images = [sample[0] for sample in batch]
    target_height = max(image.shape[1] for image in images)
    target_width = max(image.shape[2] for image in images)

    padded = []
    for image in images:
        # F.pad lists padding from the last dim backwards: (left, right, top,
        # bottom). Zeros are appended on the right and bottom only.
        extra_cols = target_width - image.shape[2]
        extra_rows = target_height - image.shape[1]
        padded.append(F.pad(image, (0, extra_cols, 0, extra_rows)))

    # NOTE(review): the zero padding is noised as well, so masks presumably
    # cover padded regions too — confirm this is intended.
    batch_tensor = torch.stack(padded)
    noisy_images, masks = noise_adder(batch_tensor, noise_parameter, noise_type)
    return noisy_images, masks


dataset = datasets.ImageFolder(root=str(data_folder), transform=transform)

# 80/20 train/validation split. The validation size is the remainder so the
# lengths always sum to len(dataset), as random_split requires. A dedicated,
# seeded generator keeps the split identical when this cell is re-run; drawing
# from the global RNG would reshuffle and leak validation images into training.
train_len = round(len(dataset) * 0.8)
lengths = [train_len, len(dataset) - train_len]
train_dataset, val_dataset = random_split(
    dataset, lengths, generator=torch.Generator().manual_seed(101)
)

# Both loaders pad each batch to its largest image and add noise in the
# collate step (noise_parameter=0.2, noise_type="gaussian").
collate_fn = partial(collate_pad_to_biggest, 0.2, "gaussian")

train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)
# Validation order does not need shuffling; a fixed order keeps evaluation
# comparable across epochs. NOTE(review): noise is still drawn fresh in the
# collate step on every pass, so validation inputs vary between epochs.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_fn,
)

In [4]:
# Place the model on the selected device with .to(device) rather than .cuda(),
# so the cell also works on MPS/CPU. Moving it before building the optimizer
# follows the PyTorch recommendation for optimizer construction.
model = NoiseDetector().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# BCELoss expects probabilities in [0, 1]; presumably NoiseDetector ends in a
# sigmoid — confirm, otherwise use BCEWithLogitsLoss.
criterion = BCELoss()

In [5]:
# Train using the same noise settings as the dataloaders' collate step.
# NOTE(review): the meanings below are inferred from position and value —
# confirm against train_model's signature.
NOISE_TYPE = "gaussian"
NOISE_PARAMETER = 0.2
RUN_NAME = "efteta7"  # presumably an experiment/checkpoint identifier
N_EPOCHS = 100  # presumably the number of training epochs

train_model(
    model,
    NOISE_TYPE,
    NOISE_PARAMETER,
    optimizer,
    criterion,
    train_dataloader,
    val_dataloader,
    device,
    RUN_NAME,
    N_EPOCHS,
)

Output()