In [None]:
%load_ext autoreload

%autoreload 2

from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

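# Export the Model with TorchScript

Export a trained DnCNN checkpoint to TorchScript, reload the saved file, and check that the scripted model reproduces the original module's output on a test image, both on the full image and patch by patch.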

In [None]:
import os
from pathlib import Path

import omegaconf
import pyrootutils
import torch
from PIL import Image
from torchvision.transforms import functional as TF

from gaussian_denoiser import data, dncnn, transforms, utils

In [None]:
# Find the repository root by searching upward from the current directory for
# pyproject.toml. The flags below (per the pyrootutils options) export the root
# as the PROJECT_ROOT environment variable, load a .env file, leave sys.path
# unchanged, and switch the working directory to the root.
root = pyrootutils.setup_root(
    search_from=".",
    indicator="pyproject.toml",
    project_root_env_var=True,
    dotenv=True,
    pythonpath=False,
    cwd=True,
)

# Read the root back as a string; later cells build all paths from it.
PROJECT_ROOT = os.environ.get("PROJECT_ROOT")

## Parameters

In [None]:
# Training run whose checkpoint is exported; its saved Hydra config is read below.
MODEL_PATH = Path(PROJECT_ROOT).joinpath("logs/train/CDnCNN-B_2024-08-13_21-00-53")

CFG_PATH = MODEL_PATH.joinpath(".hydra/config.yaml")

# NOTE(review): TEST_DATA is not referenced anywhere in this notebook.
TEST_DATA = Path(PROJECT_ROOT)

# Both paths are relative to PROJECT_ROOT.
SAVE_PATH = "models/test.pt"
TEST_IMAGE = "docs/cactus.jpg"

# Device to run inference on (defined once; it was previously assigned twice).
DEVICE = "cpu"

In [None]:
# Patch size recorded in the run's config (experiment.data.patch_size); it is
# reused for patch-wise inference further down.
cfg = omegaconf.OmegaConf.load(CFG_PATH)
data_cfg = cfg.experiment.data
PATCH_SIZE = data_cfg.patch_size

## Export Model

In [None]:
# Collect the checkpoint files found under the run directory and let
# utils.get_ckpt pick the one to export; the full list stays available as
# ckpt_path_list for inspection.
ckpt_path = utils.get_ckpt(
    ckpt_path_list := utils.find_all_ckpt_files(MODEL_PATH)
)

In [None]:
# Rebuild the LightningModule from the selected checkpoint. map_location places
# the restored weights on the configured DEVICE instead of a fixed location.
model = dncnn.DnCNNModule.load_from_checkpoint(ckpt_path, map_location=DEVICE)

In [None]:
# Compile the module to TorchScript by scripting. "script" is Lightning's default
# method; it is spelled out so the export path is explicit.
compiled_model = model.to_torchscript(method="script")

In [None]:
# Write the scripted model under PROJECT_ROOT. The target directory (e.g.
# models/) is created first, because saving into a missing directory fails.
save_path = Path(PROJECT_ROOT).joinpath(SAVE_PATH)
save_path.parent.mkdir(parents=True, exist_ok=True)
compiled_model.save(save_path)

## Load Model

In [None]:
# Reload from the same absolute location the export cell writes to, rather than
# a path relative to the current working directory. This cell therefore does not
# depend on the cwd or on `save_path` from the export section. The weights are
# mapped straight onto DEVICE.
loaded_model = torch.jit.load(Path(PROJECT_ROOT).joinpath(SAVE_PATH), map_location=DEVICE)

In [None]:
# Load the demo image and shrink it in place so that neither side exceeds
# 1024 px (Image.thumbnail keeps the aspect ratio and never enlarges).
test_image_path = Path(PROJECT_ROOT) / TEST_IMAGE
test_image = Image.open(test_image_path)
test_image.thumbnail((1024, 1024))

# Display the image, then its (width, height).
test_image
test_image.size

In [None]:
# Run the scripted model and the original LightningModule on the same image.
# Both models and the input go to DEVICE. Previously the original module was
# pinned to "cpu" and the input never moved, which breaks for any other DEVICE.
# `x` itself stays on the CPU, and both outputs are copied back to the CPU, so
# the comparison and display cells below keep working for every DEVICE.
loaded_model = loaded_model.to(DEVICE).eval()
model = model.to(DEVICE).eval()

# Add a leading batch dimension of size 1 to the image tensor.
x = TF.to_tensor(test_image).unsqueeze(0)
with torch.no_grad():
    x_device = x.to(DEVICE)
    outputs = loaded_model(x_device).cpu()
    output_orig = model(x_device).cpu()

In [None]:
# The scripted export must reproduce the original module's prediction, within
# torch.testing.assert_close's default tolerances.
torch.testing.assert_close(actual=outputs, expected=output_orig)

In [None]:
# Range of the model output (the predicted noise, which is subtracted from the
# input in the next step). It is not constrained to [0, 1]. TF.to_pil_image
# multiplies floats by 255 and casts to uint8, so out-of-range values would
# wrap; the output is therefore min-max rescaled to [0, 1] before display.
noise_min = torch.min(outputs)
noise_max = torch.max(outputs)
noise_min
noise_max
# clamp_min guards against division by zero for a constant output.
outputs_vis = (outputs - noise_min) / (noise_max - noise_min).clamp_min(1e-12)
# Drop only the batch dimension, so a size-1 spatial dim is never squeezed away.
TF.to_pil_image(outputs_vis.squeeze(0))

In [None]:
# Denoised estimate: subtract the model's predicted noise from the noisy input
# and keep intensities inside [0, 1].
x_denoised = (x - outputs).clamp(min=0, max=1)

In [None]:
# Denoised result first, then the original input. Indexing [0] drops the
# size-1 batch dimension added before inference.
TF.to_pil_image(x_denoised[0])
TF.to_pil_image(x[0])

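### Patch-wise Inference

Run the scripted model on `PATCH_SIZE` patches produced by `utils.patchify`, stitch the predicted noise back together with `utils.depatchify`, and compare the result with full-image denoising.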

In [None]:
from gaussian_denoiser import utils

In [None]:
# Patch-wise inference on the unbatched image. Note that this reuses the name
# `x` for the image without a batch dimension; the cells below expect this form.
x = TF.to_tensor(test_image)
patches, patchify_padding = utils.patchify(x, PATCH_SIZE)
with torch.no_grad():
    # The patches are fed straight to the model (presumably one batched tensor).
    # Move them to the model's DEVICE and bring the prediction back to the CPU,
    # so it can be stitched and combined with `x`.
    noise_patches = loaded_model(patches.to(DEVICE)).cpu()

noise = utils.depatchify(
    noise_patches, original_size=x.shape, patch_size=PATCH_SIZE, padding=patchify_padding
)

In [None]:
# The stitched noise is not confined to [0, 1]. Min-max rescale it for display
# so TF.to_pil_image (float * 255 -> uint8) does not wrap out-of-range values.
# clamp_min guards against division by zero for constant noise.
noise_vis = (noise - noise.min()) / (noise.max() - noise.min()).clamp_min(1e-12)
TF.to_pil_image(noise_vis)

In [None]:
# torchshow is used only for this visual check of the stitched noise, so it is
# imported locally instead of in the top import cell.
# NOTE(review): this relies on ts.show handling values outside [0, 1] itself;
# confirm against the torchshow documentation.
import torchshow as ts

ts.show(noise)

In [None]:
# Reconstruct the clean image from the patch-wise noise estimate, clamped to
# [0, 1], and display it.
x_rec = (x - noise).clamp(min=0, max=1)
image_rec = TF.to_pil_image(x_rec)
image_rec

In [None]:
# Patch-wise reconstruction must match full-image denoising (batch dim dropped).
# NOTE(review): if the model is convolutional (presumably, given DnCNN), patch
# borders see different padding than the full image. In that case this can
# only pass if utils.patchify/depatchify compensate for it (e.g. overlap or
# padding); check there first if this assertion fails.
torch.testing.assert_close(actual=x_rec, expected=x_denoised.squeeze(0))