In [None]:
# Reload edited project modules automatically so changes to the package
# are picked up without restarting the kernel.
%load_ext autoreload

%autoreload 2

from IPython.core.interactiveshell import InteractiveShell

# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# Test Models on Pre-Computed Test Sets

In [None]:
import os
from pathlib import Path

import omegaconf
import pyrootutils
import torch
import torchshow as ts
from torch.utils.data import DataLoader
from torchmetrics import image
from tqdm.notebook import tqdm

from gaussian_denoiser import dataset, dncnn, utils

In [None]:
# Locate the project root by searching upwards from the current directory
# for `pyproject.toml`.
root = pyrootutils.setup_root(
    search_from=".",
    indicator="pyproject.toml",
    project_root_env_var=True,  # export PROJECT_ROOT into the environment
    dotenv=True,  # load variables from a .env file, if present
    pythonpath=False,  # do not add the root to sys.path
    cwd=True,  # make the root the working directory
)

# Index os.environ directly so a missing variable fails here with a clear
# KeyError, instead of propagating None into later Path(...) calls.
PROJECT_ROOT = os.environ["PROJECT_ROOT"]

## Parameters


In [None]:
# Training run under evaluation and the Hydra config stored next to it.
MODEL_PATH = Path(PROJECT_ROOT) / "logs/train/CDnCNN-B_2024-08-13_21-00-53"
CFG_PATH = MODEL_PATH / ".hydra/config.yaml"

# Base directory for test data (currently the project root itself).
TEST_DATA = Path(PROJECT_ROOT)

# Torch-style device identifier.
DEVICE = "cpu"

In [None]:
# Load the configuration file at CFG_PATH; dataset locations are read from it below.
cfg = omegaconf.OmegaConf.load(CFG_PATH)

## Prepare Data

In [None]:
# Resolve the CBSD68 clean and noisy image locations from the config,
# relative to the project root.
_cbsd68_cfg = cfg.datasets.cbsd68
PATH_ORIGINAL = Path(PROJECT_ROOT) / _cbsd68_cfg.original_path
PATH_NOISY = Path(PROJECT_ROOT) / _cbsd68_cfg.noisy25

In [None]:
# Dataset built from the clean and noisy image paths; iterated below as
# (original, noisy) pairs.
ds_precomputed = dataset.PreComputedTestDataset(PATH_ORIGINAL, PATH_NOISY)

## Load Model

In [None]:
# Collect the checkpoint files found under the training run directory ...
ckpt_path_list = utils.find_all_ckpt_files(MODEL_PATH)
# ... and select the one to evaluate.
# NOTE(review): the selection rule lives in utils.get_ckpt — confirm it picks the intended checkpoint.
ckpt_path = utils.get_ckpt(ckpt_path_list)

In [None]:
# Restore the trained module from the selected checkpoint. `map_location`
# lets a checkpoint saved on another device load onto DEVICE.
model = dncnn.DnCNNModule.load_from_checkpoint(ckpt_path, map_location=DEVICE)

# Inference only: switch to evaluation mode and freeze the parameters.
model.eval()
model.freeze()

# Honour the DEVICE parameter instead of a hard-coded "cpu".
model.to(DEVICE)

## Evaluation

In [None]:
# Metrics accumulated over the whole test set.
psnr = image.PeakSignalNoiseRatio(data_range=(0, 1), dim=(1, 2, 3), reduction="elementwise_mean")
ssim = image.StructuralSimilarityIndexMeasure(data_range=1.0)

dl_precomputed = DataLoader(ds_precomputed, batch_size=1, shuffle=False, num_workers=1)

# Run the forward pass on whatever device the model's weights live on, so
# the loop keeps working when the model is moved off the CPU.
model_device = next(model.parameters()).device

# Per-image residual error (denoised - original), kept on the CPU for inspection.
delta_images = []

with torch.no_grad():
    for original_image, noisy_image in tqdm(dl_precomputed):
        # The model output is treated as the noise estimate, which is
        # subtracted from the noisy input to get the denoised image.
        noise_estimate = model(noisy_image.to(model_device)).cpu()
        denoised_image = noisy_image - noise_estimate
        # Clip to [0, 1] before computing the metrics and the residual.
        denoised_image = torch.clip(denoised_image, 0, 1.0)
        delta_image = denoised_image - original_image
        psnr.update(denoised_image, original_image)
        ssim.update(denoised_image, original_image)
        delta_images.append(delta_image)

In [None]:
# Aggregate the metrics accumulated by the evaluation loop. A single
# labelled expression is displayed regardless of the notebook's
# `ast_node_interactivity` setting.
{"psnr": float(psnr.compute()), "ssim": float(ssim.compute())}

In [None]:
# Visual check on the last evaluated batch: clean image, noisy input,
# estimated noise, and the remaining error after denoising.
for _tensor in (
    original_image,
    noisy_image,
    noise_estimate,
    denoised_image - original_image,
):
    ts.show(_tensor)

In [None]:
# Show the residual error of the first two test images. Slicing avoids an
# IndexError when fewer than two images were evaluated.
for _delta in delta_images[:2]:
    ts.show(_delta)