In [None]:
# Re-import edited modules (e.g. the local utils / models / denoisers imported below)
# automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Silence Weights & Biases console output (WANDB_SILENT) for runs started in this notebook.
%env WANDB_SILENT=true

In [13]:
import math
import torch
from torchvision.transforms import Resize
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, LearnedPerceptualImagePatchSimilarity
import utils
from models import SkipNet, ProgNet, UNet, UNetMod
from denoisers import DIP, DIP_MWV, DIP_TV, DDIP, GaussianBlur, DDIP_P,  DIP2Self, DIP2Self_MWV, SGDIP, SGDIP_MWV

In [None]:
# Alternative inputs — swap one of these in for the butterfly image:
# clean = utils.load_celeba(30)
# clean = utils.load_images('./data/CBSD68/')
# clean = utils.load_images('./data/Set14/')
clean = utils.load_image("./data/set5/butterfly.png")

# Noise level passed to utils.get_noisy_image
# (presumably sigma on a 0-255 scale — TODO confirm in utils).
noise_level = 15
# Seed before sampling so the noisy inputs, and every score computed from them, are
# reproducible. Assumes get_noisy_image draws from torch's RNG — seed numpy/random too if not.
torch.manual_seed(0)
noisy = torch.stack([utils.get_noisy_image(img, noise_level) for img in clean])

utils.plot_row([clean[0], noisy[0]])

# data_range=2 assumes pixel values lie in [-1, 1] — TODO confirm against utils.load_image.
psnr = PeakSignalNoiseRatio(data_range=2)
print(f"PSNR(noisy, clean): {psnr(noisy, clean).item():.2f} dB")

In [15]:
def cos(step, start=0, end=1):
    """Cosine-squared ramp that rises smoothly from 0 to 1 between ``start`` and ``end``.

    Returns 0 for ``step < start`` and 1 for ``step >= end``. In between it returns
    ``cos(pi/2 * (1 - t))**2``, where ``t`` is ``step`` linearly rescaled to ``[0, 1]``.

    Args:
        step: Current position, e.g. an iteration count.
        start: Position where the ramp begins (value 0).
        end: Position where the ramp reaches 1.

    Returns:
        A number in ``[0, 1]``.
    """
    if step < start:
        return 0

    # ``>=`` also covers the degenerate ``start == end`` schedule, which would
    # otherwise divide by zero when ``step == start == end``.
    if step >= end:
        return 1

    t = (step - start) / (end - start)
    # cos(pi/2 * (1 - t)) == sin(pi/2 * t). The sine form gives exactly 0 at t == 0,
    # whereas math.cos(pi/2) leaves a residue of about 6e-17 (about 4e-33 after squaring).
    return math.sin((math.pi / 2) * t) ** 2

In [None]:
# DIP-style denoisers fit a randomly initialised network to one noisy image. A single
# shared `net` instance would let later methods (and later images) start from weights
# already fitted by earlier runs, unless `denoise` re-initialises it internally
# (not visible here). Every (method, image) pair therefore gets a fresh network.
def make_net():
    """Build the untrained UNet used by every denoiser in this comparison."""
    return UNetMod(hidden_ch=8, n_layers=4)


# Factories rather than instances, so each run constructs its own network.
denoiser_factories = [
    lambda: DIP_MWV(make_net()),
    lambda: DIP2Self_MWV(make_net()),
    lambda: SGDIP_MWV(make_net()),
    lambda: DDIP(make_net(), cos_offset=0.1, sqrt=True),
]


def make_options(denoiser_id, data_id):
    """Return the options dict passed to `denoiser.denoise` for one (method, image) run."""
    return {
        "mode": "local",
        "metrics": ["psnr", "ssim"],
        "save_images": False,
        "config": {
            "project": "zero-shot-das-denoising",
            "entity": "jmaen-team",
            "group": "sgdip",
            # NOTE(review): the data cell currently loads set5/butterfly, not Set14 —
            # keep this label in sync with the data actually loaded.
            "dataset": "set14-15",
            "denoiser_id": denoiser_id,
            "data_id": data_id,
        },
    }


results = {}
for i, make_denoiser in enumerate(denoiser_factories):
    outputs = []
    for j, (x_hat, x) in enumerate(zip(noisy, clean)):
        # Seed per image: every method starts from identical initial weights/RNG state
        # for image j, so the comparison is fair and reproducible.
        torch.manual_seed(j)
        denoiser = make_denoiser()
        output = denoiser.denoise(x_hat.unsqueeze(0), x.unsqueeze(0), make_options(i, j))
        outputs.append(output.detach().cpu())

    results[f"{i} - {denoiser.key()}"] = torch.cat(outputs)

In [None]:
# One figure row per image: the noisy input, each method's output, then the clean reference.
column_titles = ["Noisy", *results.keys(), "Clean"]
for row in zip(noisy, *results.values(), clean):
    utils.plot_row(row, column_titles)

In [None]:
# Per-image metrics (reduction=None). data_range=2 assumes images in [-1, 1].
metrics = {
    "PSNR": PeakSignalNoiseRatio(data_range=2, reduction=None, dim=[1, 2, 3]),
    "SSIM": StructuralSimilarityIndexMeasure(data_range=2, reduction=None),
    # "LPIPS": LearnedPerceptualImagePatchSimilarity(),
}


def format_scores(scores):
    """Format per-image scores as 'mean (variance)<TAB>[per-image values]'.

    The sample variance of a single score is undefined (torch returns NaN and warns),
    so it is shown as 'n/a' when only one image was evaluated.
    """
    spread = f"{scores.var().item():.4f}" if scores.numel() > 1 else "n/a"
    return f"{scores.mean().item():.4f} ({spread})\t{scores.tolist()}"


for key, metric in metrics.items():
    print(key)
    print("----------")
    # The noisy input is the baseline every method should improve on.
    print("Noisy:".ljust(60), format_scores(metric(noisy, clean)))
    for name, xs in results.items():
        print(f"{name}:".ljust(60), format_scores(metric(xs, clean)))
    print("\n")