In [None]:
import math
import os

import torch
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms.functional import resize
from torchvision.utils import save_image

from information_estimation_metric import InformationEstimationMetric

In [None]:
# Build the information-estimation metric from the pretrained 256x256 ImageNet checkpoint,
# then move it to the GPU.
# NOTE(review): 'bf16' presumably selects bfloat16 precision, and the meaning of the trailing
# False flag is not visible from this notebook — confirm in InformationEstimationMetric.
checkpoint_path = './checkpoints/imagenet_256x256_loguniform_00400000.pth'
iem = InformationEstimationMetric(checkpoint_path, 'bf16', False)
iem = iem.cuda()

## Simple computation of the IEM between two example images

In [None]:
# For our trained diffusion model, the input images must be of size 256x256 (RGB, scaled to [-1, 1]).
def load_image(path, size=256):
    """Load an image file as a (1, 3, size, size) CUDA tensor scaled from [0, 255] to [-1, 1].

    ``ImageReadMode.RGB`` forces exactly three channels (e.g. drops an alpha channel or
    expands grayscale). Passing ``[size, size]`` to ``resize`` yields an exact square; a plain
    int would only match the shorter edge and keep the aspect ratio.
    """
    img = read_image(path, mode=ImageReadMode.RGB)
    img = resize(img, [size, size])
    return (img.unsqueeze(0).cuda() / 255.) * 2. - 1.

ref_img = load_image('./examples/butter_flower.png')
dist_img = load_image('./examples/butter_flower.fnoise.2.png')

num_gamma = 64
sigma_min = 1.
sigma_max = 1e3
iem_type = 'standard'
seed = 42

with torch.no_grad():
    iem_value = iem(ref_img, dist_img, num_gamma=num_gamma, sigma_min=sigma_min, sigma_max=sigma_max, iem_type=iem_type, seed=seed)
iem_value

## Minimize or maximize the IEM between an image and an optimized distortion of it while keeping the PSNR (approximately) fixed

In [None]:
from torch import optim
from tqdm.notebook import tqdm

def optimize(diffusion_model, x, target_l2, sigma_min, sigma_max, num_opt_steps, lr=1e-3, minimize=True, seed=42):
    """Optimize a fixed-L2-norm perturbation ``delta`` of ``x`` with Adam.

    At every step a precision gamma = 1 / sigma**2 is drawn log-uniformly from
    [1 / sigma_max**2, 1 / sigma_min**2]. The same Gaussian noise of std sigma is added to
    ``x`` and to ``(x + delta).clip(-1, 1)``, and both noisy images are passed through
    ``diffusion_model``. The objective is gamma times the mean squared difference between
    ``x - est1`` and ``(x + delta).clip(-1, 1) - est2``. A penalty keeps ``x + delta`` inside
    [-1, 1]. After each step ``delta`` is rescaled back onto ||delta||_2 == target_l2.

    Args:
        diffusion_model: callable invoked as ``diffusion_model(noisy_image, sigma)``, where
            sigma is a 1-element tensor. Presumably a denoiser returning a tensor shaped like
            its input — TODO confirm.
        x: reference image tensor; values assumed to lie in [-1, 1], since it is compared
            against ``(x + delta).clip(-1, 1)``.
        target_l2: L2 norm that ``delta`` is projected onto after every step.
        sigma_min, sigma_max: range of noise standard deviations that are sampled.
        num_opt_steps: number of Adam steps.
        lr: Adam learning rate.
        minimize: if False the discrepancy term is negated, i.e. maximized.
        seed: value passed to ``torch.manual_seed`` (default 42, the previous hard-coded value).

    Returns:
        The optimized perturbation (same shape as ``x``), detached from autograd.
    """
    torch.manual_seed(seed)
    # NOTE(review): delta is initialised from x itself plus small noise (then rescaled),
    # not from pure noise — kept as-is; confirm this is intended.
    delta = x.clone() + torch.randn_like(x) * 0.01
    delta = target_l2 * delta / torch.norm(delta)
    delta = delta.clone().detach()
    delta.requires_grad_(True)

    opt = optim.Adam([delta], lr=lr)

    # Loop-invariant bounds of log(gamma), gamma = 1 / sigma**2.
    gamma_min = 1 / (sigma_max ** 2)
    gamma_max = 1 / (sigma_min ** 2)
    log_gamma_min = math.log(gamma_min)
    log_gamma_range = math.log(gamma_max) - log_gamma_min

    pbar = tqdm(range(num_opt_steps))
    for _ in pbar:
        opt.zero_grad()
        # Draw from the CPU generator (as before), then move to x's device: the random
        # stream is unchanged on CUDA, and CPU tensors now work too.
        u = torch.rand(1).to(x.device) * log_gamma_range + log_gamma_min
        gamma = u.exp()
        sigma = 1 / gamma.sqrt()

        noise = sigma.item() * torch.randn_like(x)
        x_pert = (x + delta).clip(-1, 1)
        y_sigma1 = x + noise
        y_sigma2 = x_pert + noise
        with torch.no_grad():
            # Reference branch does not depend on delta, so no gradient is needed.
            est1 = diffusion_model(y_sigma1, sigma)
        est2 = diffusion_model(y_sigma2, sigma)
        diff = (x - est1) - (x_pert - est2)
        loss = diff.pow(2).mean()
        if not minimize:
            loss = -loss
        # clip() has zero gradient outside [-1, 1]; this L1 penalty pulls x + delta back inside.
        loss = gamma * loss + 10 * (x_pert.detach() - (x + delta)).abs().mean()
        loss.backward()
        pbar.set_description(f"loss={loss.item()}")
        opt.step()
        with torch.no_grad():
            # Project back onto the sphere ||delta||_2 == target_l2.
            delta /= torch.norm(delta.data)
            delta *= target_l2
    return delta.detach()

# PSNR levels (dB) at which the IEM is minimized and maximized.
psnr_targets = list(range(15, 40, 5))
# Number of scalar values in one 3-channel 256x256 image.
dim = 3 * 256 ** 2
# Filled below with one (delta_min, delta_max) pair per PSNR target.
deltas = list()
def psnr_to_l2(psnr, num_elements=None, data_range=2.):
    """Convert a target PSNR (dB) into the L2 norm of a perturbation that achieves it.

    From PSNR = 20 * log10(data_range / RMSE) and RMSE = ||delta||_2 / sqrt(num_elements):
        ||delta||_2 = sqrt(num_elements) * data_range * 10 ** (-PSNR / 20)

    Args:
        psnr: target PSNR in dB.
        num_elements: number of tensor elements. Defaults to the notebook-level ``dim``,
            looked up at call time as before.
        data_range: peak-to-peak value range; 2 for images in [-1, 1] (the previous behavior).

    Returns:
        The L2 norm as a Python float.
    """
    if num_elements is None:
        num_elements = dim
    return math.sqrt(num_elements) * data_range * 10 ** (-psnr / 20)
l2_targets = list(map(psnr_to_l2, psnr_targets))
# Hyperparameters shared by the minimizing and maximizing runs at every PSNR target.
opt_kwargs = dict(sigma_min=1e-1, sigma_max=1e3, num_opt_steps=100, lr=5e-3)
for l2_target in l2_targets:
    delta_min = optimize(iem.diffusion_model, ref_img, l2_target, minimize=True, **opt_kwargs)
    delta_max = optimize(iem.diffusion_model, ref_img, l2_target, minimize=False, **opt_kwargs)
    deltas.append((delta_min, delta_max))

In [None]:
def psnr(x1, x2, data_range=1.):
    """Peak signal-to-noise ratio in dB between two tensors.

    Args:
        x1, x2: tensors of broadcast-compatible shapes.
        data_range: peak-to-peak value range. Use 1 for images in [0, 1] (the default and
            previous behavior) and 2 for images in [-1, 1].

    Returns:
        0-dim tensor ``10 * log10(data_range**2 / MSE)``; ``inf`` when the inputs are identical.
    """
    mse = (x1 - x2).pow(2).mean()
    return 10 * torch.log10(data_range ** 2 / mse)
out_dir = './examples_distorted'
os.makedirs(out_dir, exist_ok=True)

# Map the reference from [-1, 1] to [0, 1] once; save it regardless of how many deltas exist.
ref_img_01 = ref_img * 0.5 + 0.5
save_image(ref_img_01, os.path.join(out_dir, 'ref_img.png'))

# The deltas may still require grad; avoid building autograd graphs for these evaluations.
with torch.no_grad():
    for (delta_min, delta_max), psnr_target in zip(deltas, psnr_targets):
        distorted_delta_min = (ref_img + delta_min).clip(-1, 1) * 0.5 + 0.5
        distorted_delta_max = (ref_img + delta_max).clip(-1, 1) * 0.5 + 0.5

        # Achieved PSNR can exceed the target: clipping only shrinks the perturbation.
        print(psnr_target, psnr(distorted_delta_min, ref_img_01), psnr(distorted_delta_max, ref_img_01))

        save_image(distorted_delta_min, os.path.join(out_dir, f'distorted_delta_min_psnr={psnr_target}.png'))
        save_image(distorted_delta_max, os.path.join(out_dir, f'distorted_delta_max_psnr={psnr_target}.png'))
    