In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to the imported
# local modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Sets the WANDB_SILENT environment variable for this kernel.
# NOTE(review): presumably this suppresses wandb's console output - confirm.
%env WANDB_SILENT=true

In [7]:
import math
import torch
from torchvision.transforms import Resize
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, LearnedPerceptualImagePatchSimilarity
import utils
from models import SkipNet, ProgNet, UNet, UNetMod
from denoisers import DIP, DIP_MWV, DIP_TV, DDIP, GaussianBlur, DDIP_P, DIP_Noise, DDIP_AE, DDIP_Dual

In [None]:
# Load the clean reference images. Exactly one loader is active; the
# commented-out lines are alternative datasets kept for quick switching.
# NOTE(review): the Resize((256, 256)) transform is presumably applied to every
# loaded image by utils.load_images - confirm against utils.
# clean = utils.load_celeba(30)
# clean = utils.load_images('./data/CBSD68/', Resize((256, 256)))[10:30]
clean = utils.load_images('./data/Set14/', Resize((256, 256)))
# clean = utils.load_image("./data/set5/butterfly.png")
# Build one noisy counterpart per clean image and stack them into one tensor.
# NOTE(review): the second argument (10) is presumably the noise level; its
# units, and whether the result is clamped, are not visible here. No random
# seed is set in this notebook, so the noise likely differs between runs.
noisy = torch.stack([utils.get_noisy_image(img, 10) for img in clean])

# Visual sanity check: first clean image next to its noisy version.
utils.plot_row([clean[0], noisy[0]])

In [9]:
def cos(step, start=0, end=1):
    """Squared-cosine ramp that rises from 0 to 1 over the window [start, end].

    Args:
        step: Current position (same units as ``start`` and ``end``).
        start: Position at which the ramp begins; below it the result is 0.
        end: Position at which the ramp reaches 1; at or above it the result is 1.

    Returns:
        0 for ``step < start``, 1 for ``step >= end``, and otherwise
        ``cos((pi / 2) * (1 - t)) ** 2`` with ``t = (step - start) / (end - start)``.
        A zero-width window (``start == end``) behaves as a hard step at ``start``.
    """
    if step < start:
        return 0

    # `>=` rather than `>`: the value at step == end is 1 either way, but this
    # also catches start == end, which would otherwise divide by zero below.
    if step >= end:
        return 1

    # Here start <= step < end, so the window width is strictly positive.
    progress = (step - start) / (end - start)
    return math.cos((math.pi / 2) * (1 - progress))**2

In [None]:
# Four SkipNet variants that share the first two constructor arguments and
# differ only in the third list (and the label): all zeros ("none"), all fours
# ("full"), non-zero only in the last two positions ("late"), and non-zero only
# in the first two positions ("early").
# NOTE(review): the third list presumably sets the skip-connection channels per
# level - confirm against models.SkipNet.
skipnet_none = SkipNet(3, [8, 16, 32, 64, 128], [0, 0, 0, 0, 0], label="none")
skipnet_full = SkipNet(3, [8, 16, 32, 64, 128], [4, 4, 4, 4, 4], label="full")
skipnet_late = SkipNet(3, [8, 16, 32, 64, 128], [0, 0, 0, 4, 4], label="late")
skipnet_early = SkipNet(3, [8, 16, 32, 64, 128], [4, 4, 0, 0, 0], label="early")

# ProgNet variants that differ only in their skip schedules: a list of four
# callables, each mapping one argument x to a weight.
# NOTE(review): x is presumably the training progress in [0, 1] and each
# callable presumably controls one skip connection - confirm in models.ProgNet.
#
# Every lambda that depends on the comprehension variable binds it through the
# default argument `i=i`. Without it, all four closures would read the final
# value of `i` (3) when called and the schedules would be identical.
prognet_fade = ProgNet(label="fade")
prognet_full = ProgNet(skip_schedules=[lambda x: 1 for _ in range(4)], label="full")
prognet_none = ProgNet(skip_schedules=[lambda x: 0 for _ in range(4)], label="none")
# Schedules 0 and 1 are always off, schedules 2 and 3 are always on.
prognet_late = ProgNet(skip_schedules=[lambda x, i=i: 1 if i > 1 else 0 for i in range(4)], label="late")
# Staggered ramps/steps: schedule i starts at a multiple of 0.2; the list is
# reversed so the last schedule starts first.
prognet_gradual0 = ProgNet(skip_schedules=[lambda x, i=i: cos(x, 0.2*i) for i in range(4)][::-1], label="gradual0")
prognet_gradual1 = ProgNet(skip_schedules=[lambda x, i=i: cos(x, 0.2*(i + 1)) for i in range(4)][::-1], label="gradual1")
prognet_step = ProgNet(skip_schedules=[lambda x, i=i: 1 if x > 0.2*(i + 1) else 0 for i in range(4)][::-1], label="step")

# Two U-Net style models built with the same hidden width and depth.
unet = UNet(hidden_ch=8, n_layers=4)
unetmod = UNetMod(hidden_ch=8, n_layers=4)

# Denoisers compared in the experiment loop; each wraps one model in DDIP.
# NOTE(review): `unetmod` is wrapped twice and `unet` is never used. The two
# DDIP wrappers share one model instance, and if both return the same key()
# the second run overwrites the first in `results`. Confirm whether one of
# the entries was meant to be DDIP(unet, ...).
denoisers = [DDIP(skipnet_full, cos_offset=0.1), DDIP(unetmod, cos_offset=0.1), DDIP(unetmod, cos_offset=0.1)]

# Run every denoiser on every (noisy, clean) pair and collect the outputs,
# keyed by the denoiser's key(), as one tensor concatenated over all images.
results = {}
for denoiser_id, denoiser in enumerate(denoisers):
    denoised = []
    for data_id, (noisy_img, clean_img) in enumerate(zip(noisy, clean)):
        # Run metadata handed to the denoiser; identifies this
        # (denoiser, image) combination.
        run_config = {
            "project": "zero-shot-das-denoising",
            "entity": "jmaen-team",
            "group": "architecture comparison 5",
            "dataset": "set14-10",
            "denoiser_id": denoiser_id,
            "data_id": data_id,
        }
        options = {
            "mode": "wandb",
            "metrics": ["psnr", "ssim", "lpips"],
            "save_images": False,
            "config": run_config,
        }
        # unsqueeze(0) adds a leading dimension so each image is passed as a
        # batch of one.
        restored = denoiser.denoise(noisy_img.unsqueeze(0), clean_img.unsqueeze(0), options)
        # Detach and move to CPU before storing the result.
        denoised.append(restored.detach().cpu())

    results[denoiser.key()] = torch.cat(denoised)

In [None]:
# One plot row per image: noisy input, each denoiser's output, clean reference.
# Iterating a dict yields its keys, so `*results` supplies the column titles.
for row in zip(noisy, *results.values(), clean):
    utils.plot_row(row, ["Noisy", *results, "Clean"])

In [None]:
# Full-reference image-quality metrics evaluated against the clean images.
# PSNR and SSIM use reduction=None so each returns one score per image.
metrics = {
    "PSNR": PeakSignalNoiseRatio(reduction=None, data_range=1, dim=[1, 2, 3]),
    "SSIM": StructuralSimilarityIndexMeasure(reduction=None, data_range=1),
    # normalize=True tells LPIPS the inputs are in [0, 1]; the default (False)
    # expects [-1, 1] and would silently score mis-scaled images.
    # NOTE(review): this assumes images lie in [0, 1], consistent with
    # data_range=1 above - confirm that the noisy images are clamped.
    "LPIPS": LearnedPerceptualImagePatchSimilarity(normalize=True),
}

# For each metric print the noisy baseline, then per denoiser:
# mean, (variance), and the raw score(s).
for key, metric in metrics.items():
    print(key)
    print("----------")
    print("Noisy:".ljust(60), f"{metric(noisy, clean).mean()}")
    for name, xs in results.items():
        s = metric(xs, clean)
        # LPIPS returns a single averaged value; variance of one element is
        # undefined, so report nan directly instead of triggering torch's
        # degrees-of-freedom warning.
        spread = s.var() if s.numel() > 1 else float("nan")
        print(f"{name}:".ljust(60), f"{s.mean()} ({spread})\t{s.tolist()}")
    print("\n")