In [None]:
# Reload edited project modules (e.g. gaussian_denoiser) before each cell runs,
# so changes to the package are picked up without restarting the kernel.
%load_ext autoreload

%autoreload 2

from IPython.core.interactiveshell import InteractiveShell

# Display the value of every bare expression in a cell, not only the last one.
# Later cells rely on this to show several shapes/values from a single cell.
InteractiveShell.ast_node_interactivity = "all"

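# Dataset

Sanity checks for the `gaussian_denoiser` data pipeline: load `config/train.yaml`, build the `DenoisingDataModule`, and inspect sample shapes, value ranges, and the PSNR of the noisy inputs against their clean targets.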

In [None]:
import os

import lightning as L
import omegaconf
import pyrootutils
import torch
import torchshow as ts
from torchmetrics.image import PeakSignalNoiseRatio

from gaussian_denoiser import data, transforms
from gaussian_denoiser.data import DenoisingDataModule, PatchDataset

In [None]:
# Locate the repository root by searching upward from "." for pyproject.toml.
# With these flags, setup_root also exports it as $PROJECT_ROOT, loads the
# root's .env file, and chdirs into the root so relative paths resolve there.
root = pyrootutils.setup_root(
    search_from=".",
    indicator="pyproject.toml",
    project_root_env_var=True,
    dotenv=True,
    pythonpath=False,
    cwd=True,
)

# Use the path setup_root returned rather than reading it back from the
# environment: os.getenv would silently yield None (and later paths like
# "None/config/train.yaml") if the variable were ever not exported.
PROJECT_ROOT = str(root)

In [None]:
# Load the training configuration from the repository's config directory.
# (omegaconf is imported with the other dependencies in the top import cell.)
cfg = omegaconf.OmegaConf.load(os.path.join(PROJECT_ROOT, "config", "train.yaml"))

In [None]:
# Render the full training config as YAML; the default DictConfig repr prints
# the whole nested structure as a single long dict, which is hard to read.
print(omegaconf.OmegaConf.to_yaml(cfg))

In [None]:
# Training data location from config/train.yaml; it is passed to the data module below.
cfg.data.train_path

In [None]:
# Seed the global RNGs (and, with workers=True, DataLoader worker processes) so
# the samples and batches drawn below are reproducible across runs, assuming
# the data pipeline draws its randomness from those seeded generators.
SEED = 123
L.seed_everything(SEED, workers=True)

## Sanity-check the training data pipeline

In [None]:
# Build the data module from the training config. DenoisingDataModule is
# imported in the top import cell and used directly here.
# batch_size is fixed to 16 for quick inspection rather than read from cfg.
# TODO confirm whether config/train.yaml defines a batch size to use instead.
data_module = DenoisingDataModule(
    train_path=cfg.data.train_path,
    val_path=cfg.data.val_path,
    test_path=cfg.data.test_path,
    batch_size=16,
    patch_size=cfg.data.patch_size,
    noise_level_interval=cfg.data.noise_level_interval,
    validation_noise_level_interval=cfg.data.validation_noise_level_interval,
)

# Prepare the "fit" stage; the cells below read data_module.train_dataset
# and data_module.train_dataloader().
data_module.setup("fit")

In [None]:
# Draw a single item from the training split; each item unpacks into a
# (clean, noisy, noise) triple.
ds = data_module.train_dataset

clean, noisy, noise = ds[0]

# Label the shapes: three bare torch.Size outputs cannot be told apart.
{"clean": clean.shape, "noisy": noisy.shape, "noise": noise.shape}

In [None]:
# Take the first collated batch from the training loader; like a single item,
# each batch unpacks into (clean, noisy, noise). Despite its name, `ds_train`
# holds whatever train_dataloader() returns (presumably a torch DataLoader),
# not a Dataset.
ds_train = data_module.train_dataloader()
clean_batch, noisy_batch, noise_batch = next(iter(ds_train))

In [None]:
# (min, max) of each tensor in the first training batch, labelled so the
# values can be told apart. The PSNR cell below assumes a (0, 1) data range.
# Clean values are therefore expected to lie in [0, 1]; noisy values may fall
# outside it once noise is added. TODO confirm against the transforms.
{
    name: (batch.min().item(), batch.max().item())
    for name, batch in (
        ("noisy", noisy_batch),
        ("clean", clean_batch),
        ("noise", noise_batch),
    )
}

In [None]:
# PSNR of the noisy input against its clean target, i.e. the baseline score a
# denoiser has to improve on. PeakSignalNoiseRatio is imported in the top cell.
#  - data_range=(0, 1): torchmetrics treats a tuple as (min, max), uses
#    max - min as the peak value AND clamps both inputs to [0, 1] before
#    scoring, so noisy pixels pushed outside that range are clipped first.
#  - dim=(1, 2, 3): one PSNR per sample (reducing over every dim except the
#    leading batch dim); reduction="elementwise_mean" then averages those
#    per-sample scores into a single scalar.
psnr = PeakSignalNoiseRatio(data_range=(0, 1), dim=(1, 2, 3), reduction="elementwise_mean")

# The metric is stateful: every forward call is also merged into its running
# state. Evaluate each case once and inspect the stored result, rather than
# calling it repeatedly on the same inputs.
psnr_batch = psnr(noisy_batch, clean_batch)
psnr_first = psnr(noisy_batch[:1], clean_batch[:1])

psnr_batch.shape, psnr_batch
psnr_first.shape, psnr_first