In [None]:
%load_ext autoreload

# Re-import edited project modules (e.g. gaussian_denoiser.*) before each cell runs.
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell

# Display the value of *every* bare expression in a cell, not only the last one.
# Several later cells rely on this (e.g. `psnr.compute()` and `ssim.compute()`
# on consecutive lines both being shown).
InteractiveShell.ast_node_interactivity = "all"

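# Test Model

Evaluate a trained DnCNN denoiser checkpoint on CBSD68. PSNR/SSIM are reported on noisy patches, on full images with noise added on the fly, and on a pre-computed noisy test set.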

In [None]:
import os
from pathlib import Path

import omegaconf
import pyrootutils
import torch
import torchshow as ts
from icecream import ic
from PIL import Image
from torch.utils.data import DataLoader
from torchmetrics import image
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

from gaussian_denoiser import data, dataset, dncnn, transforms, utils

In [None]:
# Locate the repository root (the directory containing pyproject.toml). Per the
# keyword arguments below, it also becomes the working directory, is exported as
# $PROJECT_ROOT, and a .env file is loaded; sys.path is left untouched.
root = pyrootutils.setup_root(
    search_from=".",
    indicator="pyproject.toml",
    project_root_env_var=True,
    dotenv=True,
    pythonpath=False,
    cwd=True,
)

# Prefer the exported env var, but fall back to the detected root so that later
# `Path(PROJECT_ROOT)` calls never receive None (which raises a cryptic TypeError).
PROJECT_ROOT = os.getenv("PROJECT_ROOT") or str(root)

## Functions

In [None]:
def find_all_ckpt_files(directory: Path) -> list[Path]:
    """Recursively collect every ``*.ckpt`` file below ``directory``.

    Args:
        directory: Root directory to search, e.g. a training-run log folder.
            A plain ``str`` path is accepted as well.

    Returns:
        All matching checkpoint paths, sorted lexicographically. An empty list
        if nothing matches.
    """
    # `rglob` yields in filesystem order, which is not stable across machines.
    # Sort so that callers picking "the first checkpoint" get a reproducible result.
    return sorted(Path(directory).rglob("*.ckpt"))


def get_ckpt(ckpt_list: list[Path]) -> Path:
    """Print the candidate checkpoints and return the first one.

    Args:
        ckpt_list: Candidate checkpoint paths, e.g. from ``find_all_ckpt_files``.
            The first entry is the one that is returned.

    Returns:
        ``ckpt_list[0]``.

    Raises:
        FileNotFoundError: If ``ckpt_list`` is empty.
    """
    # Fail with an explicit message instead of an opaque IndexError from [0].
    if not ckpt_list:
        raise FileNotFoundError("No checkpoint files found.")

    print("Found checkpoint files:")
    for ckpt_file in ckpt_list:
        print(ckpt_file)

    return ckpt_list[0]

## Parameters

In [None]:
# Training-run directory produced by a previous run; its checkpoints are searched below.
MODEL_PATH = Path(PROJECT_ROOT, "logs", "train", "CDnCNN-B_2024-08-13_21-00-53")

# Hydra config that was saved alongside that run.
CFG_PATH = Path(MODEL_PATH, ".hydra", "config.yaml")

# NOTE(review): TEST_DATA is not referenced anywhere else in this notebook.
TEST_DATA = Path(PROJECT_ROOT)

DEVICE = "cpu"

In [None]:
# Hydra config of the training run. It is used below for the CBSD68 dataset path
# (cfg.datasets.cbsd68.original_path) and the patch size (cfg.experiment.data.patch_size).
cfg = omegaconf.OmegaConf.load(CFG_PATH)

# Print the whole config as a plain container for inspection.
ic(omegaconf.OmegaConf.to_container(cfg))

In [None]:
# Seed torch's global RNG so the (presumably random) noise added by the datasets below
# is reproducible. DataLoader worker seeds are also derived from this RNG.
# NOTE: the RNG is consumed sequentially, so metrics depend on the order in which the
# cells below are run. TODO: seed numpy/random too if the transforms use them.
SEED = 123
torch.manual_seed(SEED)

## Data

In [None]:
# Folder of clean CBSD68 images, resolved relative to the project root from the run config.
PATH = Path(PROJECT_ROOT) / cfg.datasets.cbsd68.original_path

In [None]:
# Noise parameter for on-the-fly corruption, shared by both datasets below.
# min == max, so (presumably) every sample is corrupted at exactly this level.
# NOTE(review): the pre-computed test set lives in `noisy25`; confirm whether this
# parameter is a variance or a sigma before comparing results across sections.
TEST_NOISE_LEVEL = 15

ds_test = dataset.ImageFolderDataset(path=PATH)


ds_test_patch = data.ImagePatchDenoiseDataset(
    ds=ds_test,
    # Identity transform. Unlike a lambda, torch.nn.Identity can be pickled, which
    # DataLoader worker processes (num_workers > 0) require under the "spawn"
    # start method (default on macOS/Windows).
    transform=torch.nn.Identity(),
    patch_size=cfg.experiment.data.patch_size,
    noise_transform=transforms.AWGNOnlyTransform(
        min_variance=TEST_NOISE_LEVEL, max_variance=TEST_NOISE_LEVEL
    ),
)

ds_test_denoise = data.ImageDenoiseDataset(
    ds_test,
    transforms.AWGNOnlyTransform(
        min_variance=TEST_NOISE_LEVEL, max_variance=TEST_NOISE_LEVEL
    ),
)

In [None]:
# Sanity check: item types returned by the raw folder dataset and by the noisy-pair dataset.
# NOTE(review): indexing ds_test_denoise presumably samples fresh noise and so advances
# the RNG, which shifts the noise drawn in the evaluation loops below — confirm.
type(ds_test[0])
type(ds_test_denoise[0])

In [None]:
# Full images, one per batch (presumably because image sizes differ, so they cannot be stacked).
dl = DataLoader(dataset=ds_test_denoise, batch_size=1, shuffle=False, num_workers=1)

# Fixed-size patches can be batched.
dl_patch = DataLoader(
    dataset=ds_test_patch,
    batch_size=128,
    shuffle=False,
    num_workers=4,
)

Test data with pre-computed noisy images (stored in the `noisy25` folder), evaluated in the last section:

In [None]:
# Clean CBSD68 originals and their pre-computed noisy counterparts
# (the folder name suggests noise level 25).
PATH_ORIGINAL = Path(PROJECT_ROOT, "data", "test", "cbsd68", "original_png")
PATH_NOISY = Path(PROJECT_ROOT, "data", "test", "cbsd68", "noisy25")

In [None]:
# Yields (original, noisy) image pairs read from disk (see the loop in the last section);
# presumably matched by filename — confirm in PreComputedTestDataset.
ds_precomputed = dataset.PreComputedTestDataset(PATH_ORIGINAL, PATH_NOISY)

## Model

In [None]:
# Find every checkpoint under the training-run directory; get_ckpt prints them
# and returns the first entry of the list.
ckpt_path_list = find_all_ckpt_files(MODEL_PATH)
ckpt_path = get_ckpt(ckpt_path_list)

In [None]:
# map_location remaps stored tensors onto DEVICE, so a checkpoint saved on a GPU
# can still be loaded on a CPU-only machine.
model = dncnn.DnCNNModule.load_from_checkpoint(ckpt_path, map_location=DEVICE)

# Inference setup: eval() disables training-time layer behaviour (e.g. BatchNorm
# statistic updates); freeze() turns off gradients for all parameters.
model.eval()
model.freeze()

# Use the configured device instead of a hard-coded "cpu".
model.to(DEVICE)

### Patch-based evaluation

In [None]:
# PSNR is computed per sample over (C, H, W) and averaged over samples.
# Both metrics assume pixel values in [0, 1].
psnr = image.PeakSignalNoiseRatio(
    data_range=(0, 1), dim=(1, 2, 3), reduction="elementwise_mean"
)
ssim = image.StructuralSimilarityIndexMeasure(data_range=1.0)

with torch.no_grad():
    for original_image, noisy_image, delta_noise in tqdm(dl_patch, desc="patches"):
        # The model predicts the noise; the clean estimate is input minus predicted
        # noise. Inputs go to DEVICE, and outputs come back to CPU for the metrics.
        noise_estimate = model(noisy_image.to(DEVICE)).cpu()
        denoised_image = torch.clip(noisy_image - noise_estimate, 0, 1.0)
        psnr.update(denoised_image, original_image)
        ssim.update(denoised_image, original_image)

In [None]:
# Visual check on the last patch batch from the loop above: denoised vs. clean.
# Uses that loop's variables; `x_denoised` / `x` are never defined in this notebook.
ts.show(denoised_image)
ts.show(original_image)

In [None]:
# Patch-level PSNR and SSIM accumulated over all batches of dl_patch (both are displayed).
psnr.compute()
ssim.compute()

### Full-image evaluation

In [None]:
# Fresh metric objects: PSNR per image over (C, H, W), averaged over images.
# Both metrics assume pixel values in [0, 1].
psnr = image.PeakSignalNoiseRatio(
    data_range=(0, 1), dim=(1, 2, 3), reduction="elementwise_mean"
)
ssim = image.StructuralSimilarityIndexMeasure(data_range=1.0)

with torch.no_grad():
    for original_image, noisy_image, delta_noise in tqdm(dl, desc="images"):
        # Residual denoising: subtract the predicted noise from the noisy input.
        # Inputs go to DEVICE, and outputs come back to CPU for the metrics.
        noise_estimate = model(noisy_image.to(DEVICE)).cpu()
        denoised_image = torch.clip(noisy_image - noise_estimate, 0, 1.0)
        psnr.update(denoised_image, original_image)
        ssim.update(denoised_image, original_image)

In [None]:
# Full-image PSNR over dl (mean of the per-image values).
psnr.compute()

In [None]:
# Full-image SSIM over dl.
ssim.compute()

In [None]:
# Last clean image from the full-image loop; squeeze() drops the size-1 batch dim (batch_size=1).
TF.to_pil_image(original_image.squeeze())

In [None]:
# The corresponding noisy model input.
TF.to_pil_image(noisy_image.squeeze())

In [None]:
# The corresponding (clipped) denoised output.
TF.to_pil_image(denoised_image.squeeze())

In [None]:
# Remaining error: denoised output minus clean image.
ts.show(denoised_image - original_image)

In [None]:
# Third item yielded by ImageDenoiseDataset for the last image — presumably the noise
# that was added (confirm in data.ImageDenoiseDataset).
ts.show(delta_noise)

## Evaluate Pre-Computed Testset

In [None]:
# Fresh metric objects so these results are not mixed with the sections above.
psnr = image.PeakSignalNoiseRatio(
    data_range=(0, 1), dim=(1, 2, 3), reduction="elementwise_mean"
)
ssim = image.StructuralSimilarityIndexMeasure(data_range=1.0)

# NOTE(review): these noisy images come from the `noisy25` folder, while the sections
# above corrupt on the fly with AWGNOnlyTransform(min_variance=15, max_variance=15),
# so the numbers are not directly comparable — confirm the intended noise levels.
dl_precomputed = DataLoader(ds_precomputed, batch_size=1, shuffle=False, num_workers=1)

with torch.no_grad():
    for original_image, noisy_image in tqdm(dl_precomputed, desc="pre-computed"):
        # Residual denoising, as above: inputs go to DEVICE, outputs come back to CPU.
        noise_estimate = model(noisy_image.to(DEVICE)).cpu()
        denoised_image = torch.clip(noisy_image - noise_estimate, 0, 1.0)
        psnr.update(denoised_image, original_image)
        ssim.update(denoised_image, original_image)

In [None]:
# Last pair from the pre-computed loop: clean, noisy, denoised, and clean minus denoised.
ts.show(original_image)
ts.show(noisy_image)
ts.show(denoised_image)
ts.show(original_image - denoised_image)

In [None]:
# PSNR and SSIM on the pre-computed noisy test set (both are displayed).
psnr.compute()
ssim.compute()