In [None]:
# Reload imported modules before each cell executes so that edits to the
# project code are picked up without restarting the kernel.
%load_ext autoreload

%autoreload 2

from IPython.core.interactiveshell import InteractiveShell

# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# Inference Tests

In [None]:
import os
from pathlib import Path

import omegaconf
import pyrootutils
import torch
import torchshow as ts
from PIL import Image
from torchmetrics.image.psnr import PeakSignalNoiseRatio
from torchvision.transforms import functional as TF

from gaussian_denoiser import data, dncnn, transforms, utils

In [None]:
# Locate the project root, starting from the current directory and using
# `pyproject.toml` as the marker file. The flags ask pyrootutils to export the
# root as an environment variable, load a `.env` file, leave `sys.path` alone
# and change the working directory to the root.
root = pyrootutils.setup_root(
    search_from=".",
    indicator="pyproject.toml",
    project_root_env_var=True,
    dotenv=True,
    pythonpath=False,
    cwd=True,
)

# NOTE(review): os.getenv returns None when the variable is unset, which would
# only surface later as a TypeError in Path(PROJECT_ROOT).
PROJECT_ROOT = os.getenv("PROJECT_ROOT")

## Functions

In [None]:
def find_all_ckpt_files(directory: Path) -> list[Path]:
    """Recursively collect every ``*.ckpt`` file below ``directory``.

    Args:
        directory: Folder to search; all of its sub-folders are searched too.

    Returns:
        The matching paths in sorted order, or an empty list when nothing
        matches. ``Path.rglob`` yields entries in a filesystem-dependent
        order, so the result is sorted to keep it (and any "first checkpoint"
        pick made from it) reproducible across runs and machines.
    """
    return sorted(directory.rglob("*.ckpt"))


def get_ckpt(ckpt_list: list[Path]) -> Path:
    """Print the available checkpoints and return the first one.

    Args:
        ckpt_list: Candidate checkpoint paths.

    Returns:
        The first entry of ``ckpt_list``.

    Raises:
        IndexError: If ``ckpt_list`` is empty. The exception type matches what
            plain indexing would raise, but the message names the actual
            problem instead of "list index out of range".
    """
    if not ckpt_list:
        print("No checkpoint files found.")
        raise IndexError(
            "No checkpoint files found: cannot select a checkpoint from an empty list."
        )

    print("Found checkpoint files:")
    for file in ckpt_list:
        print(file)

    return ckpt_list[0]

## Parameters

In [None]:
# Directory of the training run under test (presumably a Hydra run folder,
# given the `.hydra/config.yaml` it contains - adjust the timestamp to test
# another run).
MODEL_PATH = Path(PROJECT_ROOT) / "logs/train/2024-08-12_11-42-38"

# Config file stored inside the run directory.
CFG_PATH = MODEL_PATH / ".hydra/config.yaml"
# Test image, relative to the project root.
TEST_IMAGE = "docs/cherry.jpg"

DEVICE = "cpu"

In [None]:
# Load the run's config; later cells read `cfg.data.patch_size` from it.
cfg = omegaconf.OmegaConf.load(CFG_PATH)

## Load Image

In [None]:
test_image_path = Path(PROJECT_ROOT) / TEST_IMAGE

test_image = Image.open(test_image_path)

# `thumbnail` resizes the image in place so that it fits within 1024 x 1024.
test_image.thumbnail((1024, 1024))
# Show the (possibly downscaled) image and its final size.
test_image
test_image.size

## Preprocess Image

In [None]:
# Convert the PIL image to a tensor (torchvision's `to_tensor`).
image_tensor = TF.to_tensor(test_image)

In [None]:
# Split the image into patches of the size configured for training data; the
# returned `padding` is passed back to `depatchify` later.
image_tensor_patches, padding = utils.patchify(image_tensor, patch_size=cfg.data.patch_size)

In [None]:
# Inspect the shape of the patch tensor.
image_tensor_patches.shape

In [None]:
# Reassemble the untouched patches into an image; used below to verify that
# patchify -> depatchify is a lossless round trip.
image_tensor_rec = utils.depatchify(
    image_tensor_patches, image_tensor.shape, patch_size=cfg.data.patch_size, padding=padding
)

In [None]:
# Original image.
ts.show(image_tensor)

In [None]:
# Individual patches.
ts.show(image_tensor_patches)

In [None]:
# Image reassembled from the patches.
ts.show(image_tensor_rec)

In [None]:
# Round-trip check: the reassembled image must match the original.
torch.testing.assert_close(image_tensor_rec, image_tensor)

In [None]:
# Shape of the original image tensor.
image_tensor.shape

In [None]:
# Shape of the reassembled image tensor, for comparison with the original.
image_tensor_rec.shape

## Load Model

In [None]:
# Search the run directory for checkpoints and use the first one returned.
ckpt_path_list = find_all_ckpt_files(MODEL_PATH)
ckpt_path = get_ckpt(ckpt_path_list)

In [None]:
# Restore the trained module from the selected checkpoint.
model = dncnn.DnCNNModule.load_from_checkpoint(ckpt_path)

# Inference only: switch to eval mode and freeze the parameters.
model.eval()
model.freeze()

# Place the model on the device chosen in the parameters cell rather than a
# second hard-coded "cpu", so changing DEVICE actually takes effect.
model.to(DEVICE)

## Test inference on full image and patches

### Full image

In [None]:
# Add a leading batch dimension so the whole image can be fed to the model at once.
image_tensor_batch = TF.to_tensor(test_image).unsqueeze(0)
image_tensor_batch.shape

In [None]:
# Confirm which device the model is on.
model.device

In [None]:
with torch.no_grad():
    # Feed the input on the model's device and bring the output back to the
    # CPU, so the cell works for any DEVICE setting.
    tensor_noise_estimate = model(image_tensor_batch.to(model.device)).cpu()
    # The model output is treated as a noise estimate: subtracting it from the
    # input gives the denoised image.
    tensor_denoised = image_tensor_batch - tensor_noise_estimate
    # Drop only the batch axis; a bare `.squeeze()` would also remove the
    # channel axis of a single-channel image.
    tensor_denoised = torch.clip(tensor_denoised, 0, 1.0).squeeze(0)

In [None]:
# Inspect the raw noise estimate as an image and via its value range.
# NOTE(review): the estimate may contain negative values, which `to_pil_image`
# presumably does not render faithfully - rely on min/max below for the range.
TF.to_pil_image(tensor_noise_estimate.squeeze(0))

ts.show(tensor_noise_estimate)
tensor_noise_estimate.min()
tensor_noise_estimate.max()

In [None]:
# Visual comparison: input image, then the full-image denoising result.
# (`ts` is already imported in the notebook's import cell.)
ts.show(image_tensor)
ts.show(tensor_denoised)

### Patches

In [None]:
with torch.no_grad():
    # The patch tensor is fed to the model as one batch, on the model's device.
    tensor_noise_estimate_patches = model(image_tensor_patches.to(model.device)).cpu()
    # Keep the per-patch result under its own name so that
    # `tensor_denoised_patches` only ever refers to the reassembled image.
    denoised_patches = torch.clip(image_tensor_patches - tensor_noise_estimate_patches, 0, 1.0)

# Stitch the denoised patches back together. The original size is passed as
# `image_tensor.shape`, exactly as in the round-trip check earlier;
# `squeeze(0)` would silently drop the channel axis of a single-channel image.
tensor_denoised_patches = utils.depatchify(
    denoised_patches,
    original_size=image_tensor.shape,
    patch_size=cfg.data.patch_size,
    padding=padding,
)
tensor_denoised_patches.shape

In [None]:
# Visual comparison: input image, then the result reassembled from denoised patches.
ts.show(image_tensor)
ts.show(tensor_denoised_patches)

In [None]:
# Full-image and patch-wise inference are expected to give the same result.
# NOTE(review): if the model uses spatial context (e.g. padded convolutions),
# pixels near patch borders may differ from the full-image pass - confirm this
# check is meant to hold at the default tolerances.
torch.testing.assert_close(tensor_denoised, tensor_denoised_patches)

In [None]:
# Mean squared difference between the input image and the denoised result.
((image_tensor - tensor_denoised) ** 2).mean()

In [None]:
# PSNR between the denoised result and the input image, for values in [0, 1].
# `PeakSignalNoiseRatio` is imported in the notebook's import cell.
psnr = PeakSignalNoiseRatio(data_range=(0, 1))

# torchmetrics expects (preds, target): the denoised output is the prediction.
psnr(tensor_denoised, image_tensor)

In [None]:
# Convert the denoised tensor back to a PIL image.
image_denoised = TF.to_pil_image(tensor_denoised.squeeze(0))

In [None]:
# Display the denoised PIL image.
image_denoised

## Test with some Noise

In [None]:
# Noise standard deviation on the 0-255 intensity scale, and a fixed seed so
# the noisy image is identical on every run.
NOISE_SIGMA = 50
NOISE_SEED = 0

x = TF.to_tensor(test_image)

# Zero-mean Gaussian noise via `torch.randn`. `torch.rand_like` draws uniform
# values in [0, 1), which would only ever brighten the image and is not the
# noise model a Gaussian denoiser targets.
noise_generator = torch.Generator().manual_seed(NOISE_SEED)
gaussian_noise = torch.randn(x.shape, generator=noise_generator) * (NOISE_SIGMA / 255.0)
x_noisy = torch.clip(x + gaussian_noise, 0, 1)

In [None]:
# Shape of the noisy image tensor.
x_noisy.shape

In [None]:
model.eval()
with torch.no_grad():
    # Add a batch axis, run on the model's device, and bring the estimate back
    # to the CPU so the cell works for any DEVICE setting.
    noise_estimate = model(x_noisy.unsqueeze(0).to(model.device)).cpu()
    # Remove the batch axis from the estimate before subtracting it.
    x_denoised = x_noisy - noise_estimate.squeeze(0)
    # The result already has the same shape as `x_noisy`; no further squeeze is
    # applied, so a single-channel image keeps its channel axis.
    x_denoised = torch.clip(x_denoised, 0, 1.0)

In [None]:
# Clean image, noisy image, their difference, and the denoised result.
ts.show(x)
ts.show(x_noisy)
ts.show(x - x_noisy)
ts.show(x_denoised)