In [1]:
from seasalt.noise_to_salt.noise_to_salt import (
    NoiseDetector,
    weighted_mean_conv,
    train_model,
    noise_adder,
    data_folder,
)
from functools import partial
from torch.nn import functional as F
import torch
from torchvision import datasets, transforms
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import numpy as np

In [2]:
# Pick the compute backend in order of preference: CUDA, then Apple-silicon
# MPS, and finally the CPU as a universal fallback.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)

# Seed the global RNGs so that torch-driven randomness (e.g. random_split and
# DataLoader shuffling) and numpy randomness are reproducible on re-run.
SEED = 101
np.random.seed(SEED)
torch.manual_seed(SEED)

device

device(type='cuda')

In [3]:
# Preprocessing applied to every image loaded by ImageFolder.
_preprocessing_steps = [
    transforms.Grayscale(),  # collapse to a single (grayscale) channel
    transforms.ToTensor(),  # PIL image -> float tensor of shape (C, H, W)
]
transform = transforms.Compose(_preprocessing_steps)


def collate_pad_to_biggest(noise_parameter, noise_type, batch):
    """Collate differently sized images into one batch and add noise to it.

    Every image is zero-padded on its bottom and right edges up to the
    largest height and width present in the batch, the padded images are
    stacked, and the stacked tensor is handed to ``noise_adder``.

    Args:
        noise_parameter: Forwarded unchanged to ``noise_adder``.
        noise_type: Forwarded unchanged to ``noise_adder``.
        batch: Sequence of dataset samples whose first element is an image
            tensor indexed as (channels, height, width), as produced by
            ``ToTensor``. Any remaining elements (e.g. ImageFolder class
            labels) are discarded.

    Returns:
        A 2-tuple of the values returned by ``noise_adder`` (the noisy
        images and their masks).
    """
    images = [sample[0] for sample in batch]
    heights = [image.shape[1] for image in images]
    widths = [image.shape[2] for image in images]
    target_height, target_width = max(heights), max(widths)

    def _pad(image):
        # F.pad pads the last dimension first: (left, right, top, bottom).
        return F.pad(
            image,
            (0, target_width - image.shape[2], 0, target_height - image.shape[1]),
        )

    batch_tensor = torch.stack([_pad(image) for image in images])
    # noise_adder receives the padded tensor, so padded regions may be
    # corrupted (and masked) as well.
    noisy_images, masks = noise_adder(batch_tensor, noise_parameter, noise_type)
    return noisy_images, masks


# Data / noise configuration (previously repeated as magic numbers below).
NOISE_PARAMETER = 0.2  # forwarded to noise_adder via collate_pad_to_biggest
NOISE_TYPE = "gaussian"
BATCH_SIZE = 2
TRAIN_FRACTION = 0.8

dataset = datasets.ImageFolder(root=str(data_folder), transform=transform)

# Derive the validation size as the complement so the two lengths always sum
# to len(dataset), whatever TRAIN_FRACTION is (random_split requires this).
n_train = round(len(dataset) * TRAIN_FRACTION)
lengths = [n_train, len(dataset) - n_train]
train_dataset, val_dataset = random_split(dataset, lengths)

# One collate function shared by both loaders instead of two copy-pasted partials.
collate_with_noise = partial(collate_pad_to_biggest, NOISE_PARAMETER, NOISE_TYPE)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_with_noise,
)
# Validation order does not need shuffling; a fixed order keeps validation
# batches reproducible from epoch to epoch.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_with_noise,
)

In [4]:
# Build the model directly on the device chosen in the setup cell. An
# unconditional `.cuda()` would crash on MPS- or CPU-only machines, defeating
# the fallback logic there; `.to(device)` is equivalent when device is CUDA.
model = NoiseDetector().to(device)

In [5]:
# NOTE(review): these positional arguments are named from their position only;
# confirm the meanings against train_model's signature.
LEARNING_RATE = 1e-1  # presumably the optimizer learning rate
RUN_NAME = "efteta7"  # presumably a run / checkpoint identifier
N_EPOCHS = 100  # presumably the number of training epochs

# NOTE(review): the last recorded run of this cell raised CUDA OutOfMemoryError
# (~6 GiB allocated by PyTorch on a ~7.8 GiB GPU). Because batches are padded
# to their largest image, memory use depends on the biggest images in the
# dataset; consider resizing/cropping in the transform or lowering the
# DataLoader batch size if this recurs.
train_model(
    model,
    LEARNING_RATE,
    train_dataloader,
    val_dataloader,
    device,
    RUN_NAME,
    N_EPOCHS,
)

Output()

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacty of 7.79 GiB of which 17.75 MiB is free. Process 844025 has 498.00 MiB memory in use. Including non-PyTorch memory, this process has 7.13 GiB memory in use. Of the allocated memory 5.99 GiB is allocated by PyTorch, and 137.39 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF