In [None]:
# Re-import edited modules (e.g. the project's utils, models, denoisers) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Silence Weights & Biases console output; the denoising runs below log in "wandb" mode.
%env WANDB_SILENT=true

In [44]:
import torch
from torchvision.transforms import Resize
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, LearnedPerceptualImagePatchSimilarity
import utils
from models import UNet, SkipNet, ProgNet
from denoisers import DIP, DIP_MWV, DIP_TV, DDIP, GaussianBlur, SelfDIP, DDIP_MWV, DDIP_P, DIP_P

In [None]:
# Alternative datasets (swap in as needed):
# clean = utils.load_celeba(1)
# clean = utils.load_images('./data/CBSD68/', Resize((256, 256)))[:30]
# clean = utils.load_image("./data/set5/butterfly.png")
clean = utils.load_images('./data/Set14/', Resize((256, 256)))

# Noise level passed to utils.get_noisy_image (presumably a noise sigma — TODO confirm
# the scale in utils.py). Keep in sync with the "dataset" tag in the W&B config below.
NOISE_LEVEL = 10

# Seed so the noisy inputs are identical across runs.
# NOTE(review): assumes get_noisy_image draws from torch's global RNG — confirm in utils.py.
torch.manual_seed(0)
noisy = torch.stack([utils.get_noisy_image(img, NOISE_LEVEL) for img in clean])

utils.plot_row([clean[0], noisy[0]])

In [None]:
# Seed network initialisation (and any other torch-RNG draws) for reproducible runs.
torch.manual_seed(0)

SKIP_CHANNELS = [8, 16, 32, 64, 128]

# Network factories: every call returns a freshly initialised network, and nothing is
# allocated for architectures that are not used. Sharing one module instance between
# denoisers (or across images) would let weights trained in one run carry over into the
# next. NOTE(review): that risk assumes denoise() optimises the network in place —
# confirm in denoisers.py; building fresh networks is correct either way.
networks = {
    "none": lambda: SkipNet(3, SKIP_CHANNELS, [0, 0, 0, 0, 0], label="none"),
    "full": lambda: SkipNet(3, SKIP_CHANNELS, [4, 4, 4, 4, 4], label="full"),
    "late": lambda: SkipNet(3, SKIP_CHANNELS, [0, 4, 4, 4, 4], label="late"),
    "early": lambda: SkipNet(3, SKIP_CHANNELS, [4, 4, 0, 0, 0], label="early"),
    "unet": lambda: UNet(hidden_ch=8, n_layers=5),
    "prognet": lambda: ProgNet(),
}

print(networks["full"]())

# Each entry builds one denoiser wrapping its own fresh network.
denoiser_factories = [
    lambda: DIP_MWV(networks["late"]()),
    lambda: DIP(networks["full"]()),
    lambda: DDIP(networks["full"]()),
    lambda: DIP_P(networks["prognet"]()),
    lambda: DDIP_P(networks["prognet"]()),
]

results = {}
for make_denoiser in denoiser_factories:
    outputs = []
    key = None
    for i, (noisy_img, clean_img) in enumerate(zip(noisy, clean)):
        denoiser = make_denoiser()  # fresh weights for every image and every method
        options = {
            "mode": "wandb",
            "metrics": ["psnr", "ssim", "lpips"],
            "config": {
                "project": "zero-shot-das-denoising",
                "entity": "jmaen-team",
                "group": "prognet2",
                "dataset": "set14-10",
                "id": i,
            }
        }
        # Add a batch dimension; the clean image is passed alongside for metric logging.
        output = denoiser.denoise(noisy_img.unsqueeze(0), clean_img.unsqueeze(0), options)
        outputs.append(output.detach().cpu())
        key = denoiser.key()

    # Two denoisers reporting the same key would silently overwrite each other's results.
    assert key not in results, f"duplicate denoiser key: {key!r}"
    results[key] = torch.cat(outputs)

In [None]:
# Column order per row: noisy input, every denoiser's output (insertion order), clean target.
column_titles = ("Noisy", *results, "Clean")
for row_images in zip(noisy, *results.values(), clean):
    utils.plot_row(row_images, list(column_titles))

In [None]:
metrics = {
    "PSNR": PeakSignalNoiseRatio(reduction=None, data_range=1, dim=[1, 2, 3]),
    "SSIM": StructuralSimilarityIndexMeasure(reduction=None, data_range=1),
    # Images here are in [0, 1] (cf. data_range=1 above). LPIPS defaults to
    # normalize=False, which expects [-1, 1] inputs and would silently mis-score
    # [0, 1] images; normalize=True makes it rescale [0, 1] inputs itself.
    "LPIPS": LearnedPerceptualImagePatchSimilarity(normalize=True),
}


def score(metric, preds, target):
    """Evaluate ``metric`` on a batch of predictions against ``target``.

    torchmetrics' LPIPS with ``normalize=True`` raises on values outside [0, 1]. Noisy
    inputs and raw network outputs can overshoot that range, so they are clamped for
    LPIPS only. PSNR and SSIM receive the tensors unchanged.
    """
    if isinstance(metric, LearnedPerceptualImagePatchSimilarity):
        preds, target = preds.clamp(0, 1), target.clamp(0, 1)
    return metric(preds, target)


for key, metric in metrics.items():
    print(key)
    print("----------")
    print("Noisy:".ljust(30), f"{score(metric, noisy, clean).mean().item():.4f}")
    for name, xs in results.items():
        s = score(metric, xs, clean)
        print(f"{name}:".ljust(30), f"{s.mean().item():.4f}\t{s.tolist()}")
    print("\n")