In [1]:
import sys
sys.path.append("/home/tmartorella/ddpm")

from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from basicsr.metrics.metric_util import reorder_image, to_y_channel
from ddpm.datasets.celebahq import CelebAHQ
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2

t = transforms.Compose([transforms.ToTensor(), transforms.Resize([256]), transforms.CenterCrop([256, 256])])
dataset = CelebAHQ(root="/mnt/scitas/bastien/CelebAMask-HQ/CelebA-HQ-img/", transform=t)


def to_pil(img):
    return Image.fromarray((img * 255).permute(1, 2, 0).numpy().astype(np.uint8))


def _ssim(img1, img2):
    """Calculate SSIM (structural similarity) for one channel images.

    It is called by func:`calculate_ssim`.

    Args:
        img1 (ndarray): Images with range [0, 255] with order 'HWC'.
        img2 (ndarray): Images with range [0, 255] with order 'HWC'.

    Returns:
        float: ssim result.
    """

    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the SSIM calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: ssim result.
    """

    assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
    img1 = reorder_image(img1, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    ssims = []
    for i in range(img1.shape[2]):
        ssims.append(_ssim(img1[..., i], img2[..., i]))
    return np.array(ssims).mean()


def get_metrics(img1, img2):
    ssim = calculate_ssim(img1, img2, 0)
    psnr = calculate_psnr(img1, img2, 0)
    return ssim, psnr

  warn(f"Failed to load image Python extension: {e}")


In [82]:
from pathlib import Path
from tqdm.auto import tqdm

files = [f for f in tqdm(list(Path("/home/tmartorella/imgs_for_metrics/fid/").glob("**/*.png"))) if f.is_file()]
files[:20]

  0%|          | 0/14584 [00:00<?, ?it/s]

[PosixPath('/home/tmartorella/imgs_for_metrics/fid/2024-05-22_12-40/pixelate/25048/corrupted.png'),
 PosixPath('/home/tmartorella/imgs_for_metrics/fid/2024-05-22_12-40/pixelate/25048/original.png'),
 PosixPath('/home/tmartorella/imgs_for_metrics/fid/2024-05-22_12-40/pixelate/25048/sde/non_clipped/399_0.png'),
 PosixPath('/home/tmartorella/imgs_for_metrics/fid/2024-05-22_12-40/pixelate/25048/ode/clipped/200_0.0001_0.png'),
 PosixPath('/home/tmartorella/imgs_for_metrics/fid/2024-05-22_12-40/pixelate/25048/ode/clipped/200_0.002_0.png'),
 PosixPath('/home/tmartorella/imgs_for_metrics/fid/2024-05-22_12-40/pixelate/25048/ode/clipped/200_0.000447214_0.png'),
 PosixPath('/home/tmartorella/imgs_for_metrics/fid/2024-05-22_12-40/pixelate/25048/ode/non_clipped/200_0.0001_0.png'),
 PosixPath('/home/tmartorella/imgs_for_metrics/fid/2024-05-22_12-40/pixelate/25048/ode/non_clipped/200_0.002_0.png'),
 PosixPath('/home/tmartorella/imgs_for_metrics/fid/2024-05-22_12-40/pixelate/25048/ode/non_clipped/200_

In [73]:
import pandas as pd
info = []
for f in files:
    f_split = str(f).split("/")
    if f.name in ["corrupted.png", "original.png"]:
        info.append({
            "path": f,
            "corruption": f_split[-3],
            "img_id": f_split[-2],
            "filename": f_split[-1],
            "algo": "none",
            "clipping": "none",
            "timestamp": f_split[-4],
        })
    else:
        info.append({
            "path": f,
            "corruption": f_split[-5],
            "img_id": f_split[-4],
            "filename": f_split[-1],
            "algo": f_split[-3],
            "clipping": f_split[-2],
            "timestamp": f_split[-6],
        })

df = pd.DataFrame(info)

In [74]:
df

Unnamed: 0,path,corruption,img_id,filename,algo,clipping,timestamp
0,/home/tmartorella/imgs_for_metrics/fid/2024-05...,pixelate,25048,corrupted.png,none,none,2024-05-22_12-40
1,/home/tmartorella/imgs_for_metrics/fid/2024-05...,pixelate,25048,original.png,none,none,2024-05-22_12-40
2,/home/tmartorella/imgs_for_metrics/fid/2024-05...,pixelate,25048,399_0.png,sde,non_clipped,2024-05-22_12-40
3,/home/tmartorella/imgs_for_metrics/fid/2024-05...,pixelate,25048,200_0.0001_0.png,ode,clipped,2024-05-22_12-40
4,/home/tmartorella/imgs_for_metrics/fid/2024-05...,pixelate,25048,200_0.002_0.png,ode,clipped,2024-05-22_12-40
...,...,...,...,...,...,...,...
14579,/home/tmartorella/imgs_for_metrics/fid/2024-05...,plasma_noise,913,200_0.0002_0.png,ode,non_clipped,2024-05-21_23-02
14580,/home/tmartorella/imgs_for_metrics/fid/2024-05...,plasma_noise,913,200_0.001_0.png,ode,non_clipped,2024-05-21_23-02
14581,/home/tmartorella/imgs_for_metrics/fid/2024-05...,plasma_noise,913,200_0.000447214_0.png,ode,non_clipped,2024-05-21_23-02
14582,/home/tmartorella/imgs_for_metrics/fid/2024-05...,plasma_noise,913,reconstruction_ddim.png,reconstruction,clipped,2024-05-21_23-02


In [5]:
df.groupby("corruption")['img_id'].unique().apply(len)

corruption
brownish_noise                 16
caustic_noise                  16
caustic_refraction             16
checkerboard_cutout           206
cocentric_sine_waves           16
fish_eye                       16
gaussian_blur                  16
inverse_sparkles               16
masking_random_color          206
masking_vline_random_color    206
perlin_noise                  206
pinch_and_twirl                16
pixelate                       16
plasma_noise                  206
scatter                        16
water_drop                     16
Name: img_id, dtype: int64

In [15]:
celebahq_fid = FrechetInceptionDistance(reset_real_features=False).to("cuda")

import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from ddpm.datasets.celebahq import CelebAHQ
from torchvision import transforms
t =  transforms.Compose([transforms.ToTensor(), transforms.Resize([256]), transforms.CenterCrop([256, 256])])

class ADataset(Dataset):
    def __init__(self, ds):
        self.ds = ds
        
    def __len__(self):
        return len(self.ds)
    
    def __getitem__(self, idx):
        return self.ds[idx][0]

celeba_dataset = CelebAHQ(root="/mnt/scitas/bastien/CelebAMask-HQ/CelebA-HQ-img/", transform=t)
loader = DataLoader(ADataset(celeba_dataset), batch_size=16, num_workers=4, pin_memory=True)

for batch in tqdm(loader):
    celebahq_fid.update((batch * 255).to(torch.uint8).to("cuda"), real=True)

100%|█████████████████████████████████████████████████████████████████████████████████| 469/469 [11:24<00:00,  1.46s/it]


In [35]:
from tqdm import tqdm
import torch
from torchmetrics.image import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
    LearnedPerceptualImagePatchSimilarity
)

def get_metrics(df, fid_metric):
    metrics = []
    
    for g in tqdm(df.groupby(["corruption"])):
        c = g[0]    
        
        psnr_src_metric = PeakSignalNoiseRatio().to("cuda")
        psnr_tgt_metric = PeakSignalNoiseRatio().to("cuda")
        ssim_src_metric = StructuralSimilarityIndexMeasure(data_range=(0, 1)).to("cuda")
        ssim_tgt_metric = StructuralSimilarityIndexMeasure(data_range=(0, 1)).to("cuda")
        lpips_src_metric = LearnedPerceptualImagePatchSimilarity().to("cuda")
        lpips_tgt_metric = LearnedPerceptualImagePatchSimilarity().to("cuda")
        l2_src_metric = []
        l2_tgt_metric = []
        fid_metric.reset()
        
        list_imgs = list(g[1].iterrows())
        for i, img in list_imgs:
            src = t(Image.open(img["path"].parents[0 if img["algo"] == "none" else 2] / "corrupted.png")).to("cuda")
            tgt = t(Image.open(img["path"].parents[0 if img["algo"] == "none" else 2] / "original.png")).to("cuda")
            img = t(Image.open(img["path"])).to("cuda")
            
            psnr_src_metric.update(img, src)
            psnr_tgt_metric.update(img, tgt)
            ssim_src_metric.update(img.unsqueeze(0), src.unsqueeze(0))
            ssim_tgt_metric.update(img.unsqueeze(0), tgt.unsqueeze(0))
            lpips_src_metric.update((img * 2 - 1).unsqueeze(0), (src * 2 - 1).unsqueeze(0))
            lpips_tgt_metric.update((img * 2 - 1).unsqueeze(0), (tgt * 2 - 1).unsqueeze(0))
            l2_src_metric.append(torch.linalg.norm((img - src).reshape(-1)))
            l2_tgt_metric.append(torch.linalg.norm((img - tgt).reshape(-1)))
            fid_metric.update((img * 255).to(torch.uint8).unsqueeze(0), real=False)
    
        psnr_to_source = psnr_src_metric.compute()
        psnr_to_target = psnr_tgt_metric.compute()
        ssim_to_source = ssim_src_metric.compute()
        ssim_to_target = ssim_tgt_metric.compute()
        l2_to_source = torch.stack(l2_src_metric).mean()
        l2_to_target = torch.stack(l2_tgt_metric).mean()
        lpips_to_source = lpips_src_metric.compute()
        lpips_to_target = lpips_tgt_metric.compute()
        fid = fid_metric.compute()
        
        metrics.append({
            "corruption": c[0],
            "psnr_to_source": psnr_to_source.item(),
            "psnr_to_target": psnr_to_target.item(),
            "ssim_to_source": ssim_to_source.item(),
            "ssim_to_target": ssim_to_target.item(),
            "l2_to_source": l2_to_source.item(),
            "l2_to_target": l2_to_target.item(),
            "lpips_to_source": lpips_to_source.item(),
            "lpips_to_target": lpips_to_target.item(),
            "fid": fid.item(),
            "num_imgs": len(list(g[1].iterrows()))
        })
        
    return pd.DataFrame(metrics)

In [84]:
import pandas as pd
info = []
for f in files:
    f_split = str(f).split("/")
    if f.name in ["corrupted.png", "original.png"]:
        info.append({
            "path": f,
            "corruption": f_split[-3],
            "img_id": f_split[-2],
            "filename": f_split[-1],
            "algo": "none",
            "clipping": "none",
            "timestamp": f_split[-4],
        })
    else:
        info.append({
            "path": f,
            "corruption": f_split[-5],
            "img_id": f_split[-4],
            "filename": f_split[-1],
            "algo": f_split[-3],
            "clipping": f_split[-2],
            "timestamp": f_split[-6],
        })

df = pd.DataFrame(info)

In [85]:
df

Unnamed: 0,path,corruption,img_id,filename,algo,clipping,timestamp
0,/home/tmartorella/imgs_for_metrics/fid/2024-05...,pixelate,25048,corrupted.png,none,none,2024-05-22_12-40
1,/home/tmartorella/imgs_for_metrics/fid/2024-05...,pixelate,25048,original.png,none,none,2024-05-22_12-40
2,/home/tmartorella/imgs_for_metrics/fid/2024-05...,pixelate,25048,399_0.png,sde,non_clipped,2024-05-22_12-40
3,/home/tmartorella/imgs_for_metrics/fid/2024-05...,pixelate,25048,200_0.0001_0.png,ode,clipped,2024-05-22_12-40
4,/home/tmartorella/imgs_for_metrics/fid/2024-05...,pixelate,25048,200_0.002_0.png,ode,clipped,2024-05-22_12-40
...,...,...,...,...,...,...,...
14579,/home/tmartorella/imgs_for_metrics/fid/2024-05...,plasma_noise,913,200_0.0002_0.png,ode,non_clipped,2024-05-21_23-02
14580,/home/tmartorella/imgs_for_metrics/fid/2024-05...,plasma_noise,913,200_0.001_0.png,ode,non_clipped,2024-05-21_23-02
14581,/home/tmartorella/imgs_for_metrics/fid/2024-05...,plasma_noise,913,200_0.000447214_0.png,ode,non_clipped,2024-05-21_23-02
14582,/home/tmartorella/imgs_for_metrics/fid/2024-05...,plasma_noise,913,reconstruction_ddim.png,reconstruction,clipped,2024-05-21_23-02


In [None]:
dfs = {
    "reconstruction_without_clipping": {
        "df": df[(df["algo"] == "reconstruction") & (df["clipping"] == "non_clipped")].copy().drop(columns=["timestamp"]).copy()
    },
    "reconstruction_with_clipping": {
        "df": df[(df["algo"] == "reconstruction") & (df["clipping"] == "clipped")].copy().drop(columns=["timestamp"]).copy()
    },
    "ode_without_clipping": {
        "df": df[(df["algo"] == "ode") & (df["clipping"] == "non_clipped")].copy().drop(columns=["timestamp"]).copy()
    },
    "ode_with_clipping": {
        "df": df[(df["algo"] == "ode") & (df["clipping"] == "clipped")].copy().drop(columns=["timestamp"]).copy()
    },
    "sde_without_clipping": {
        "df": df[(df["algo"] == "sde") & (df["clipping"] == "non_clipped")].copy().drop(columns=["timestamp"]).copy()
    },
    "original": {
        "df": df[(df["algo"] == "none") & (df["filename"] == "original.png")].copy().drop(columns=["timestamp"]).copy()
    },
    "corrupted": {
        "df": df[(df["algo"] == "none") & (df["filename"] == "corrupted.png")].copy().drop(columns=["timestamp"]).copy()
    },
}

In [86]:
from IPython.display import display, Markdown

for k, v in dfs.items():
    if k not in ["corrupted", "original"]:
        continue
    
    metrics_per_corruption = get_metrics(v["df"].copy(), celebahq_fid)
    display(Markdown(f"## {k}"))
    display(metrics_per_corruption)

    v["metrics_per_corruption"] = metrics_per_corruption

    _fake_df = v["df"].copy()
    _fake_df["corruption"] = "all"

    metrics_aggregated = get_metrics(_fake_df.copy(), celebahq_fid)
    display(Markdown(f"## {k} - aggregated"))
    display(metrics_aggregated)
    
    v["metrics_aggregated"] = metrics_aggregated

  0%|          | 0/16 [00:00<?, ?it/s]

## reconstruction_without_clipping

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,brownish_noise,33.225513,16.662786,0.946103,0.741651,6.936468,64.84568,0.046587,0.326223,196.388412,12
1,caustic_noise,27.140509,17.575453,0.873486,0.740701,15.553489,58.117165,0.141721,0.208787,176.033417,12
2,caustic_refraction,25.192287,19.399538,0.836159,0.615427,20.963127,46.459045,0.184805,0.341567,206.411652,12
3,checkerboard_cutout,29.253408,16.562868,0.890337,0.70501,13.594751,64.467697,0.107694,0.35517,242.501373,201
4,cocentric_sine_waves,21.86768,18.237171,0.833034,0.446332,28.404007,50.151634,0.167484,0.463822,356.27597,12
5,fish_eye,27.003433,17.611645,0.83815,0.419762,17.235472,57.514568,0.184375,0.467431,357.767365,12
6,gaussian_blur,35.28616,21.69207,0.919496,0.538507,7.363235,36.12101,0.17409,0.507168,186.975159,12
7,inverse_sparkles,23.904257,9.511764,0.7406,0.508467,26.953619,145.905685,0.301696,0.434771,274.620697,12
8,masking_random_color,39.966961,11.689943,0.993666,0.237367,4.285335,113.057686,0.001225,1.074974,233.357834,201
9,masking_vline_random_color,38.997715,11.506734,0.996968,0.161408,4.682489,117.437515,0.001743,0.870393,354.384613,201


  0%|          | 0/1 [00:00<?, ?it/s]

## reconstruction_without_clipping - aggregated

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,27.162848,14.31604,0.894763,0.475805,13.408826,78.688309,0.120334,0.589465,147.212692,1137


  0%|          | 0/16 [00:00<?, ?it/s]

## reconstruction_with_clipping

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,brownish_noise,26.121008,16.665329,0.838245,0.694333,20.130707,64.842552,0.17731,0.384999,186.885651,12
1,caustic_noise,24.327139,17.971403,0.784794,0.723403,25.95075,55.543327,0.223206,0.240173,172.621185,12
2,caustic_refraction,23.106516,21.786087,0.720731,0.704153,30.723995,35.611797,0.281053,0.267829,191.160568,12
3,checkerboard_cutout,22.168167,17.770895,0.742535,0.675802,33.903587,55.51778,0.385263,0.332165,109.424927,201
4,cocentric_sine_waves,22.303226,24.638702,0.592351,0.74769,33.737755,25.377188,0.376157,0.198236,183.25766,12
5,fish_eye,23.165741,19.306591,0.738789,0.538276,30.293221,47.491699,0.444426,0.400741,181.379181,12
6,gaussian_blur,28.390585,21.145803,0.922678,0.547009,16.240461,38.49604,0.206648,0.479005,194.682678,12
7,inverse_sparkles,21.53957,9.45681,0.619701,0.497884,36.73344,147.687927,0.449115,0.489915,210.760651,12
8,masking_random_color,17.793606,14.219425,0.806903,0.318623,52.291435,82.99514,0.374036,0.777271,149.777832,201
9,masking_vline_random_color,16.406998,15.527749,0.734294,0.294387,66.852051,73.17041,0.281918,0.664414,298.121735,201


  0%|          | 0/1 [00:00<?, ?it/s]

## reconstruction_with_clipping - aggregated

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,20.053694,16.823309,0.771474,0.551432,39.276142,59.79269,0.31479,0.450601,84.666672,1137


  0%|          | 0/16 [00:00<?, ?it/s]

## ode_without_clipping

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,brownish_noise,19.256689,16.32379,0.512312,0.507242,46.204094,67.459557,0.302242,0.378966,132.249634,36
1,caustic_noise,20.435217,17.078411,0.561114,0.540488,40.873226,61.852879,0.23112,0.241409,126.1931,36
2,caustic_refraction,20.055094,20.439356,0.513504,0.557496,42.88266,41.577591,0.267652,0.227296,138.559952,36
3,checkerboard_cutout,19.384977,15.998348,0.53018,0.495348,46.576077,69.121674,0.422252,0.340637,43.878151,794
4,cocentric_sine_waves,19.683847,21.575821,0.360887,0.576972,45.320652,36.14146,0.407709,0.200878,130.784668,36
5,fish_eye,20.825779,18.482861,0.563595,0.478409,38.942005,52.24892,0.426021,0.294639,128.090317,36
6,gaussian_blur,24.254257,20.281818,0.726637,0.521745,25.678545,42.293194,0.524971,0.267856,125.149765,36
7,inverse_sparkles,20.598839,9.507915,0.507744,0.390023,40.114197,146.104446,0.408331,0.487912,165.143204,36
8,masking_random_color,13.520676,14.728349,0.147915,0.47974,91.878456,80.002449,1.03745,0.443043,42.969093,794
9,masking_vline_random_color,12.519159,15.807421,0.107939,0.426255,104.659729,71.154999,0.80012,0.406525,68.092079,794


  0%|          | 0/1 [00:00<?, ?it/s]

## ode_without_clipping - aggregated

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,16.203596,16.360512,0.395433,0.503081,61.54916,65.140305,0.552198,0.337842,23.434536,4366


  0%|          | 0/16 [00:00<?, ?it/s]

## ode_with_clipping

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,brownish_noise,18.847744,16.28931,0.508399,0.508693,48.950623,67.614166,0.322136,0.374244,130.405304,36
1,caustic_noise,19.622019,16.919168,0.542042,0.540579,45.402393,62.591099,0.256363,0.247011,123.796585,36
2,caustic_refraction,19.303566,20.117157,0.49542,0.560948,47.297813,42.994968,0.311166,0.235693,134.792618,36
3,checkerboard_cutout,18.48344,16.758698,0.518448,0.515241,52.113293,63.029068,0.445382,0.308977,38.901421,790
4,cocentric_sine_waves,19.207561,21.152426,0.339193,0.587879,48.044144,37.888996,0.431545,0.197457,125.790138,36
5,fish_eye,19.8563,18.608702,0.5512,0.493996,44.294979,51.493809,0.457936,0.284566,123.988266,36
6,gaussian_blur,23.287668,19.752218,0.712878,0.497371,29.404999,45.098518,0.528375,0.28695,121.960548,36
7,inverse_sparkles,19.299818,9.330025,0.475586,0.389907,47.576233,149.977783,0.47414,0.496193,149.280975,36
8,masking_random_color,12.755767,16.376781,0.148555,0.506442,100.459541,65.604744,1.072847,0.36733,38.398689,790
9,masking_vline_random_color,12.177476,17.486828,0.11909,0.47675,108.931442,58.136829,0.800525,0.330255,60.602184,790


  0%|          | 0/1 [00:00<?, ?it/s]

## ode_with_clipping - aggregated

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,15.570461,17.242403,0.390493,0.520305,67.111809,58.992222,0.571618,0.304058,21.87541,4346


  0%|          | 0/16 [00:00<?, ?it/s]

## sde_without_clipping

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,brownish_noise,20.146896,16.538826,0.506674,0.513831,43.489246,65.736908,0.298748,0.377522,168.984985,16
1,caustic_noise,21.031092,17.369131,0.552113,0.537869,39.277027,59.904411,0.222925,0.229087,144.125748,16
2,caustic_refraction,20.477188,21.083298,0.512651,0.579308,41.813652,38.916134,0.261094,0.199185,140.547226,16
3,checkerboard_cutout,20.394911,16.165697,0.550498,0.507901,42.09314,67.769615,0.399148,0.327353,80.512245,202
4,cocentric_sine_waves,20.080894,22.496641,0.348283,0.607563,43.814995,33.007401,0.403271,0.165341,147.68988,16
5,fish_eye,21.447824,18.742651,0.567261,0.496325,37.369629,50.877831,0.423532,0.280955,151.167526,16
6,gaussian_blur,25.263248,20.754078,0.728531,0.537031,24.01787,40.370026,0.538773,0.24556,154.54068,16
7,inverse_sparkles,21.012999,9.817159,0.493476,0.407607,39.263115,141.135406,0.414608,0.460137,205.214508,16
8,masking_random_color,13.669659,14.979639,0.149588,0.510177,90.263557,77.853256,1.051342,0.417342,73.719513,202
9,masking_vline_random_color,12.608151,16.284155,0.104408,0.454759,103.668686,67.36145,0.81621,0.368527,96.242264,202


  0%|          | 0/1 [00:00<?, ?it/s]

## sde_without_clipping - aggregated

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,16.683807,16.692589,0.412422,0.523698,57.465794,62.312519,0.533724,0.312979,38.042252,1186


0it [00:00, ?it/s]

## original

0it [00:00, ?it/s]

## original - aggregated

0it [00:00, ?it/s]

## corrupted

0it [00:00, ?it/s]

## corrupted - aggregated

In [91]:
from IPython.display import display, Markdown

for k, v in dfs.items():
    if k not in ["corrupted", "original"]:
        continue
    
    metrics_per_corruption = get_metrics(v["df"].copy(), celebahq_fid)
    display(Markdown(f"## {k}"))
    display(metrics_per_corruption)

    v["metrics_per_corruption"] = metrics_per_corruption

    _fake_df = v["df"].copy()
    _fake_df["corruption"] = "all"

    metrics_aggregated = get_metrics(_fake_df.copy(), celebahq_fid)
    display(Markdown(f"## {k} - aggregated"))
    display(metrics_aggregated)
    
    v["metrics_aggregated"] = metrics_aggregated

  0%|          | 0/16 [00:00<?, ?it/s]

## original

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,brownish_noise,16.859682,inf,0.795461,1.0,63.374878,0.0,0.283796,0.0,136.829468,16
1,caustic_noise,18.234344,inf,0.849742,1.0,54.236118,0.0,0.099865,0.0,136.829468,16
2,caustic_refraction,21.453802,inf,0.757663,1.0,37.165993,0.0,0.192048,0.0,136.829468,16
3,checkerboard_cutout,16.49535,inf,0.780126,0.999997,64.899422,0.0,0.341479,0.0,50.443966,206
4,cocentric_sine_waves,23.022499,inf,0.578347,1.0,31.305069,0.0,0.349037,0.0,136.829468,16
5,fish_eye,18.183718,inf,0.470227,1.0,54.175911,0.0,0.464492,0.0,136.829468,16
6,gaussian_blur,22.162159,inf,0.587452,1.0,34.142033,0.0,0.630953,0.0,136.829468,16
7,inverse_sparkles,9.964433,inf,0.578729,1.0,138.519394,0.0,0.433153,0.0,136.829468,16
8,masking_random_color,11.802941,inf,0.243072,0.999997,111.564552,0.0,1.071532,0.0,50.443966,206
9,masking_vline_random_color,11.63918,inf,0.165897,0.999997,115.654999,0.0,0.870197,0.0,50.443966,206


  0%|          | 0/1 [00:00<?, ?it/s]

## original - aggregated

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,14.746831,inf,0.563867,0.999998,73.49295,0.0,0.502755,0.0,48.817364,1206


  0%|          | 0/16 [00:00<?, ?it/s]

## corrupted

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,brownish_noise,inf,16.859682,1.0,0.795461,0.0,63.374878,0.0,0.283796,169.313644,16
1,caustic_noise,inf,18.234344,1.0,0.849742,0.0,54.236118,0.0,0.099865,154.926849,16
2,caustic_refraction,inf,21.453802,1.0,0.757663,0.0,37.165993,0.0,0.192048,167.762604,16
3,checkerboard_cutout,inf,16.49535,0.999998,0.780126,0.0,64.899422,0.0,0.341479,270.342499,206
4,cocentric_sine_waves,inf,23.022499,1.0,0.578347,0.0,31.305069,0.0,0.349037,337.507294,16
5,fish_eye,inf,18.183718,0.999992,0.470227,0.0,54.175911,0.0,0.464492,408.461365,16
6,gaussian_blur,inf,22.196289,0.999995,0.587452,0.0,34.142033,0.0,0.630953,231.14949,16
7,inverse_sparkles,inf,9.964433,0.999999,0.578729,0.0,138.519394,0.0,0.433153,303.855286,16
8,masking_random_color,inf,11.802941,1.0,0.243072,0.0,111.564552,0.0,1.071532,230.866531,206
9,masking_vline_random_color,inf,11.63918,1.0,0.165897,0.0,115.654999,0.0,0.870197,355.248138,206


  0%|          | 0/1 [00:00<?, ?it/s]

## corrupted - aggregated

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,inf,14.746831,1.0,0.563867,0.0,73.49295,0.0,0.502755,127.835823,1206


In [93]:
from IPython.display import display, Markdown

for k, v in dfs.items():
    # if k not in ["corrupted", "original"]:
    #     continue
    
    # metrics_per_corruption = get_metrics(v["df"].copy(), celebahq_fid)
    # display(Markdown(f"## {k}"))
    # display(metrics_per_corruption)

    # v["metrics_per_corruption"] = metrics_per_corruption

    _fake_df = v["df"].copy()
    _fake_df[_fake_df["corruption"].isin(["checkerboard_cutout", "perlin_noise", "masking_random_color", "masking_vline_random_color", "plasma_noise"])] 
    _fake_df["corruption"] = "all"

    metrics_aggregated = get_metrics(_fake_df.copy(), celebahq_fid)
    display(Markdown(f"## {k} - aggregated (only ablation)"))
    display(metrics_aggregated)
    
    v["metrics_aggregated"] = metrics_aggregated

  0%|          | 0/1 [00:00<?, ?it/s]

## reconstruction_without_clipping - aggregated (only ablation)

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,27.162848,14.31604,0.894763,0.475805,13.408826,78.688309,0.120334,0.589465,147.212692,1137


  0%|          | 0/1 [00:00<?, ?it/s]

## reconstruction_with_clipping - aggregated (only ablation)

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,20.053694,16.823309,0.771474,0.551432,39.276142,59.79269,0.31479,0.450601,84.666672,1137


  0%|          | 0/1 [00:00<?, ?it/s]

## ode_without_clipping - aggregated (only ablation)

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,16.203596,16.360512,0.395433,0.503081,61.54916,65.140305,0.552198,0.337842,23.434536,4366


  0%|          | 0/1 [00:00<?, ?it/s]

## ode_with_clipping - aggregated (only ablation)

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,15.570461,17.242403,0.390493,0.520305,67.111809,58.992222,0.571618,0.304058,21.87541,4346


  0%|          | 0/1 [00:00<?, ?it/s]

## sde_without_clipping - aggregated (only ablation)

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,16.683807,16.692589,0.412422,0.523698,57.465794,62.312519,0.533724,0.312979,38.042252,1186


  0%|          | 0/1 [00:00<?, ?it/s]

## original - aggregated (only ablation)

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,14.746831,inf,0.563867,0.999998,73.49295,0.0,0.502755,0.0,48.817364,1206


  0%|          | 0/1 [00:00<?, ?it/s]

## corrupted - aggregated (only ablation)

Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid,num_imgs
0,all,inf,14.746831,1.0,0.563867,0.0,73.49295,0.0,0.502755,127.835823,1206


In [94]:
import pickle
with open("/home/tmartorella/imgs_for_metrics/metrics_new.pkl", "wb") as metrics_file:
    pickle.dump(dfs, metrics_file)

In [131]:
ode = df[(df["algo"] == "ode") & (df["corruption"].isin(["contrast", "fog"]))].drop(columns=["timestamp"])
ode_results_per_corruption = get_metrics(ode, celebahq_fid)
ode_results_per_corruption

100%|█████████████████████████████████████████████████████████████████████████████████████████| 2/2 [01:06<00:00, 33.48s/it]


Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid
0,contrast,19.456379,14.005204,0.704571,0.466285,37.227726,86.310837,0.566833,0.413938,51.584625
1,fog,18.370485,12.637938,0.599688,0.423423,49.261127,101.670357,0.444414,0.434942,50.019432


In [134]:
sde = df[(df["algo"] == "sde") & (df["corruption"].isin(["contrast", "fog"]))].drop(columns=["timestamp"])
sde_results_per_corruption = get_metrics(sde, celebahq_fid)
sde_results_per_corruption

100%|█████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:44<00:00, 22.06s/it]


Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid
0,contrast,31.507917,12.398882,0.825937,0.3526,9.956234,104.480896,0.535663,0.814965,330.311615
1,fog,26.394199,11.563744,0.707346,0.363064,20.327606,115.654778,0.416586,0.654471,196.135605


In [132]:
# ode = df[(df["algo"] == "ode") & (df["corruption"] == "fog")].drop(columns=["timestamp"])
ode_all = df[(df["algo"] == "ode")].copy()
ode_all["corruption"] = "all"
ode_results_all = get_metrics(ode_all, celebahq_fid)
ode_results_all

100%|████████████████████████████████████████████████████████████████████████████████████████| 1/1 [04:12<00:00, 252.90s/it]


Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid
0,all,15.8325,14.758712,0.531907,0.51153,55.553547,69.620918,0.544291,0.349669,27.974178


In [133]:
sde_all = df[(df["algo"] == "sde")].copy()
sde_all["corruption"] = "all"
sde_results_all = get_metrics(sde_all, celebahq_fid)
sde_results_all

100%|████████████████████████████████████████████████████████████████████████████████████████| 1/1 [05:42<00:00, 342.72s/it]


Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid
0,all,16.148903,14.417244,0.490403,0.517865,53.025021,72.267281,0.62709,0.368789,59.296925


In [128]:
# ode = df[(df["algo"] == "ode") & (df["corruption"] == "fog")].drop(columns=["timestamp"])
corrupted_all = df[(df["algo"] == "none") & (df["filename"] == "corrupted.png")].copy()
# corrupted_all["corruption"] = "all"
corrupted_results_all = get_metrics(corrupted_all, celebahq_fid)
corrupted_results_all

100%|█████████████████████████████████████████████████████████████████████████████████████████| 1/1 [01:01<00:00, 61.27s/it]


Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid
0,all,inf,12.579247,1.0,0.428789,0.0,94.061813,0.0,0.683682,113.806267


In [129]:
# ode = df[(df["algo"] == "ode") & (df["corruption"] == "fog")].drop(columns=["timestamp"])
corrupted_all = df[(df["algo"] == "none") & (df["filename"] == "corrupted.png")].copy()
# corrupted_all["corruption"] = "all"
corrupted_results_all = get_metrics(corrupted_all, celebahq_fid)
corrupted_results_all

100%|███████████████████████████████████████████████████████████████████████████████████████| 17/17 [01:41<00:00,  5.94s/it]


Unnamed: 0,corruption,psnr_to_source,psnr_to_target,ssim_to_source,ssim_to_target,l2_to_source,l2_to_target,lpips_to_source,lpips_to_target,fid
0,brightness,inf,9.852933,1.0,0.618916,0.0,140.778976,0.0,0.257829,98.784233
1,contrast,inf,12.546547,1.0,0.436881,0.0,102.690407,0.0,0.747116,112.282906
2,elastic_transform,inf,19.798203,1.0,0.566378,0.0,44.225132,0.0,0.254095,218.191284
3,fog,inf,11.679802,1.0,0.524399,0.0,114.024574,0.0,0.444118,183.971191
4,frost,inf,9.333415,1.0,0.436469,0.0,145.133148,0.0,0.561487,220.817673
5,gaussian_blur,inf,22.107393,1.0,0.618729,0.0,34.12458,0.0,0.609321,159.776321
6,gaussian_noise,inf,10.650517,1.0,0.058202,0.0,130.082214,0.0,1.331478,321.960663
7,glass_blur,inf,22.184576,1.0,0.62247,0.0,33.846889,0.0,0.461552,225.536407
8,impulse_noise,inf,10.400298,1.0,0.068457,0.0,133.768524,0.0,1.335539,288.324066
9,jpeg_compression,inf,25.767984,1.0,0.709763,0.0,22.577734,0.0,0.310088,125.927399


In [None]:
ode_results_per_corruption.to_csv("/home/martorel/dev/unsupervised_image_editing/ode_results_per_corruption_top_16_all.csv")
sde_results_per_corruption.to_csv("/home/martorel/dev/unsupervised_image_editing/sde_results_per_corruption_top_16_all.csv")
ode_results_all.to_csv("/home/martorel/dev/unsupervised_image_editing/ode_results_all_top_16_all.csv")
sde_results_all.to_csv("/home/martorel/dev/unsupervised_image_editing/sde_results_all_top_16_all.csv")