In [1]:
from information_estimation_metric import InformationEstimationMetric
from torchvision.io import read_image
from torchvision.transforms.functional import resize
import torch
import math
from torchvision.utils import save_image
import os
from torch import optim
from tqdm.notebook import tqdm

In [2]:
iem = InformationEstimationMetric('./checkpoints/imagenet_256x256_loguniform_00400000.pth', 'bf16', False).cuda()

# Simple computation of the IEM between two example images

In [3]:
# The input images should be of size 256x256. Change the diffusion model to support other image resolutions.

# ref_img = (resize(read_image('./examples/butter_flower.png'), 256).unsqueeze(0).cuda() / 255.) * 2. - 1.
# dist_img = (resize(read_image('./examples/butter_flower.fnoise.2.png'), 256).unsqueeze(0).cuda() / 255.) * 2. - 1.
ref_img = (resize(read_image('./examples/butter_flower.png'), 256).unsqueeze(0).cuda() / 255.) * 2. - 1.
dist_img = (resize(read_image('./examples/butter_flower.fnoise.2.png'), 256).unsqueeze(0).cuda() / 255.) * 2. - 1.

num_gamma = 64
sigma_min = 1.
sigma_max = 1e3
iem_type = 'standard'
seed = 42

with torch.no_grad():
    # Expects images in the range [-1, 1].
    print(iem(ref_img, dist_img, num_gamma=num_gamma, sigma_min=sigma_min, sigma_max=sigma_max, iem_type=iem_type, seed=seed))

tensor([14.4503], device='cuda:0')


# Maximum differentiation competition

In [4]:
def initialize_delta(target_l2, x):
    torch.manual_seed(42)
    delta = torch.randn_like(x)

    delta = target_l2 * delta / torch.norm(delta)
    delta = delta.clone().detach()
    delta.requires_grad_(True)
    with torch.no_grad():
        delta.add_(x)
        delta.clamp_(-1, 1)
        delta.sub_(x)
    return delta


dim = 256 * 256 * 3
def psnr_to_l2(psnr):
    return math.sqrt(dim) * 2 * 10 ** (-psnr / 20)
def psnr(x1, x2):
    mse = (x1 - x2).pow(2).mean()
    return 10 * torch.log10(1 / mse)

psnr_targets = [10, 15, 20, 25, 30, 35]
l2_targets = [psnr_to_l2(psnr) for psnr in psnr_targets]
lr = 5e-2

## Minimize or maximize the IEM between an image and its distorted version while keeping the PSNR fixed

In [5]:
def optimize_iem(diffusion_model, x, target_l2, sigma_min, sigma_max, num_opt_steps, lr=1e-3, minimize=True):
    torch.manual_seed(42)
    delta = initialize_delta(target_l2, x)
    initial_delta = delta.clone()

    opt = optim.Adam([delta], lr=lr)
    pbar = tqdm(range(num_opt_steps))
    gamma_min = 1 / (sigma_max ** 2)
    gamma_max = 1 / (sigma_min ** 2)
    gammas = torch.logspace(math.log(gamma_min), math.log(gamma_max), num_opt_steps).cuda()

    for i in pbar:
        opt.zero_grad()
        gamma = gammas[i]

        sigma = 1 / gamma.sqrt()
        noise = sigma.item() * torch.randn_like(x)
        y_sigma1 = x + noise
        y_sigma2 = (x + delta) + noise
        with torch.no_grad():
            est1 = diffusion_model(y_sigma1, sigma)
        est2 = diffusion_model(y_sigma2, sigma)
        diff = (x - est1) - ((x + delta) - est2)
        loss = gamma * diff.pow(2).mean()
        if not minimize:
            loss = - loss

        loss.backward()
        pbar.set_description(f"loss={loss.item()}")
        opt.step()
        with torch.no_grad():
            delta /= torch.norm(delta.data)
            delta *= target_l2
        with torch.no_grad():
            delta.add_(x)
            delta.clamp_(-1, 1)
            delta.sub_(x)
    return delta, initial_delta

deltas = []
for l2_target in l2_targets:
    delta_min, initial_delta = optimize_iem(iem.diffusion_model, ref_img, l2_target, sigma_min=sigma_min, sigma_max=sigma_max,
                             num_opt_steps=1000, lr=lr, minimize=True)
    delta_max, _ = optimize_iem(iem.diffusion_model, ref_img, l2_target, sigma_min=sigma_min, sigma_max=sigma_max,
                             num_opt_steps=1000, lr=lr, minimize=False)
    deltas.append((delta_min, delta_max, initial_delta))

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [6]:
folder_name = f'./mad_comp/iem/lr={lr}/'
os.makedirs(folder_name, exist_ok=True)
for i, ((delta_min, delta_max, initial_delta), psnr_target) in enumerate(zip(deltas, psnr_targets)):
        if i == 0:
            save_image(ref_img * 0.5 + 0.5,
                       os.path.join(folder_name, 'ref_img.png'))
        distorted_delta_min = ((ref_img + delta_min).clip(-1, 1)) * 0.5 + 0.5
        distorted_delta_max = ((ref_img + delta_max).clip(-1, 1)) * 0.5 + 0.5
        distorted_initial_delta = ((ref_img + initial_delta).clip(-1, 1)) * 0.5 + 0.5

        print(psnr_target, psnr(distorted_delta_min, ref_img * 0.5 + 0.5),
              psnr(distorted_delta_max, ref_img * 0.5 + 0.5), psnr(distorted_initial_delta, ref_img * 0.5 + 0.5))

        save_image(distorted_delta_min, os.path.join(folder_name,
                                                     f'distorted_delta_min_psnr={psnr_target}.png'))
        save_image(distorted_delta_max, os.path.join(folder_name,
                                                     f'distorted_delta_max_psnr={psnr_target}.png'))
        save_image(distorted_initial_delta, os.path.join(folder_name,
                                                     f'distorted_initial_delta_psnr={psnr_target}.png'))

10 tensor(10.0018, device='cuda:0', grad_fn=<MulBackward0>) tensor(10.3763, device='cuda:0', grad_fn=<MulBackward0>) tensor(12.2435, device='cuda:0', grad_fn=<MulBackward0>)
15 tensor(15.0017, device='cuda:0', grad_fn=<MulBackward0>) tensor(15.4222, device='cuda:0', grad_fn=<MulBackward0>) tensor(16.5815, device='cuda:0', grad_fn=<MulBackward0>)
20 tensor(20.0020, device='cuda:0', grad_fn=<MulBackward0>) tensor(20.4564, device='cuda:0', grad_fn=<MulBackward0>) tensor(21.1622, device='cuda:0', grad_fn=<MulBackward0>)
25 tensor(25.0106, device='cuda:0', grad_fn=<MulBackward0>) tensor(25.5256, device='cuda:0', grad_fn=<MulBackward0>) tensor(25.9441, device='cuda:0', grad_fn=<MulBackward0>)
30 tensor(30.0294, device='cuda:0', grad_fn=<MulBackward0>) tensor(30.5171, device='cuda:0', grad_fn=<MulBackward0>) tensor(30.8342, device='cuda:0', grad_fn=<MulBackward0>)
35 tensor(35.0893, device='cuda:0', grad_fn=<MulBackward0>) tensor(35.4313, device='cuda:0', grad_fn=<MulBackward0>) tensor(35.749

## Other metrics

In [7]:
from pyiqa import create_metric
def optimize_pyiqa(x, target_l2, num_opt_steps, lr=1e-3, minimize=True):
    torch.manual_seed(43)
    delta = initialize_delta(target_l2, x)

    opt = optim.Adam([delta], lr=lr)
    pbar = tqdm(range(num_opt_steps))
    for i in pbar:
        opt.zero_grad()
        if (x + delta).max() > 1 or (x + delta).min() < -1:
            raise Exception()
        loss = metric((x + delta) * 0.5 + 0.5, x * 0.5 + 0.5).mean()

        if not minimize:
            loss = - loss
        if not metric.lower_better:
            loss = - loss
        loss.backward()
        pbar.set_description(f"metric_name={metric_name};loss={loss.item()}")
        opt.step()
        with torch.no_grad():
            delta /= torch.norm(delta.data)
            delta *= target_l2
        with torch.no_grad():
            delta.add_(x)
            delta.clamp_(-1, 1)
            delta.sub_(x)
    return delta

# ['dists', 'lpips', 'vif', 'topiq_fr', 'fsim', 'ssim', 'nlpd', 'gmsd', 'pieapp', 'mad']
metric_name = 'dists'
metric = create_metric(metric_name, device=torch.device('cuda'), as_loss=True)
folder_name = f'./mad_comp/{metric_name}/lr={lr}/'
os.makedirs(folder_name, exist_ok=True)
deltas = []
for l2_target in l2_targets:
    try:
        delta_min = optimize_pyiqa(ref_img, l2_target, num_opt_steps=1000, lr=lr, minimize=True)
        delta_max = optimize_pyiqa(ref_img, l2_target, num_opt_steps=1000, lr=lr, minimize=False)
        deltas.append((delta_min, delta_max))
    except Exception as e:
        pass

for i, ((delta_min, delta_max), psnr_target) in enumerate(zip(deltas, psnr_targets)):
    if i == 0:
        save_image(ref_img * 0.5 + 0.5,
                   os.path.join(folder_name, 'ref_img.png'))
    distorted_delta_min = ((ref_img + delta_min)) * 0.5 + 0.5
    distorted_delta_max = ((ref_img + delta_max)) * 0.5 + 0.5

    print(psnr_target, psnr(distorted_delta_min, ref_img * 0.5 + 0.5),
          psnr(distorted_delta_max, ref_img * 0.5 + 0.5))

    save_image(distorted_delta_min, os.path.join(folder_name,
                                                 f'distorted_delta_min_psnr={psnr_target}.png'))
    save_image(distorted_delta_max, os.path.join(folder_name,
                                                 f'distorted_delta_max_psnr={psnr_target}.png'))

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /dev/shm/.cache-gohayon/torch/hub/checkpoints/vgg16-397923af.pth


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 528M/528M [00:11<00:00, 47.2MB/s]


Downloading: "https://huggingface.co/chaofengc/IQA-PyTorch-Weights/resolve/main/DISTS_weights-f5e65c96.pth" to /dev/shm/.cache-gohayon/torch/hub/pyiqa/DISTS_weights-f5e65c96.pth



100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12.0k/12.0k [00:00<00:00, 17.4MB/s]

Loading pretrained model DISTS from /dev/shm/.cache-gohayon/torch/hub/pyiqa/DISTS_weights-f5e65c96.pth





  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

10 tensor(10.0016, device='cuda:0', grad_fn=<MulBackward0>) tensor(10.0065, device='cuda:0', grad_fn=<MulBackward0>)
15 tensor(15.0021, device='cuda:0', grad_fn=<MulBackward0>) tensor(15.0073, device='cuda:0', grad_fn=<MulBackward0>)
20 tensor(20.0048, device='cuda:0', grad_fn=<MulBackward0>) tensor(20.0121, device='cuda:0', grad_fn=<MulBackward0>)
25 tensor(25.0157, device='cuda:0', grad_fn=<MulBackward0>) tensor(25.0165, device='cuda:0', grad_fn=<MulBackward0>)
30 tensor(30.0252, device='cuda:0', grad_fn=<MulBackward0>) tensor(30.0354, device='cuda:0', grad_fn=<MulBackward0>)
35 tensor(35.0775, device='cuda:0', grad_fn=<MulBackward0>) tensor(35.1097, device='cuda:0', grad_fn=<MulBackward0>)
