In [3]:
from datasets import load_dataset

dataset_name = "dim/nfs_pix2pix_1920_1080_v5"
dataset = load_dataset(
    dataset_name,
    cache_dir="/code/dataset/nfs_pix2pix_1920_1080_v5",
)
dataset = dataset["train"]

In [4]:
import random

test_images_ids = list(range(0, len(dataset), 20))
rng = random.Random(42)
amount = min(100, len(test_images_ids))
selected_ids = rng.sample(test_images_ids, amount)
# selected_ids

### LPIPS

In [5]:
import lpips
import torch
from tqdm import tqdm
from torchvision import transforms

resolution = 512
valid_transforms = transforms.Compose(
    [
        transforms.Resize(
            resolution,
            interpolation=transforms.InterpolationMode.LANCZOS,
        ),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5],
            [0.5],
        ),
    ]
)
loss_fn_vgg = lpips.LPIPS(net="vgg").requires_grad_(False).cuda()
total_loss = 0.0
with torch.no_grad():
    for num in tqdm(selected_ids):
        item_1 = valid_transforms(dataset[num]["input_image"].convert("RGB")).cuda()
        # item_1 = valid_transforms(dataset[num]["edited_image"].convert("RGB")).cuda()
        item_2 = valid_transforms(dataset[num]["edited_image"].convert("RGB")).cuda()

        d = loss_fn_vgg(item_1, item_2).item()
        # print(d)
        total_loss += d
total_loss / len(selected_ids)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /opt/conda/lib/python3.11/site-packages/lpips/weights/v0.1/vgg.pth


100%|██████████| 43/43 [00:05<00:00,  7.58it/s]


0.35777404765750087

## SSIM, MSE

In [6]:
import numpy as np

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error


ssim_preds = []
mse_preds = []
with torch.no_grad():
    for num in tqdm(selected_ids, desc="compute ssim"):
        original = valid_transforms(dataset[num]["input_image"].convert("RGB")).numpy()
        # item_2 = valid_transforms(dataset[num]["input_image"].convert("RGB")).numpy()
        generated = valid_transforms(
            dataset[num]["edited_image"].convert("RGB")
        ).numpy()
        ssim_res = ssim(
            original,
            generated,
            data_range=generated.max() - generated.min(),
            channel_axis=0,
        )
        mse_res = mean_squared_error(original, generated)
        ssim_preds.append(ssim_res)
        mse_preds.append(mse_res)
np.mean(ssim_preds), np.mean(mse_preds)

compute ssim: 100%|██████████| 43/43 [00:05<00:00,  7.25it/s]


(np.float32(0.55576855), np.float64(0.05032231584813293))

In [7]:
import torch
import piqa

ssim = piqa.SSIM().cuda()

valid_transforms = transforms.Compose(
    [
        transforms.Resize(
            resolution,
            interpolation=transforms.InterpolationMode.LANCZOS,
        ),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
    ]
)

ssim_preds = []
mse_preds = []
with torch.no_grad():
    for num in tqdm(selected_ids, desc="compute ssim"):
        original = (
            valid_transforms(dataset[num]["input_image"].convert("RGB"))
            .cuda()
            .unsqueeze(0)
        )
        # item_2 = valid_transforms(dataset[num]["input_image"].convert("RGB")).numpy()
        generated = (
            valid_transforms(dataset[num]["edited_image"].convert("RGB"))
            .cuda()
            .unsqueeze(0)
        )
        ssim_res = ssim(
            original,
            generated,
        ).item()
        ssim_preds.append(ssim_res)

np.mean(ssim_preds)

compute ssim: 100%|██████████| 43/43 [00:04<00:00, 10.28it/s]


np.float64(0.6675837677578593)

### Dists

In [8]:
from torch import rand
from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity

valid_transforms = transforms.Compose(
    [
        transforms.Resize(
            resolution,
            interpolation=transforms.InterpolationMode.LANCZOS,
        ),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
    ]
)

dists = DeepImageStructureAndTextureSimilarity().cuda()
dists_preds = []
with torch.no_grad():
    for num in tqdm(selected_ids, desc="compute dists"):
        original = (
            valid_transforms(dataset[num]["input_image"].convert("RGB"))
            .cuda()
            .unsqueeze(0)
        )
        # generated = (
        #     valid_transforms(dataset[num]["input_image"].convert("RGB"))
        #     .cuda()
        #     .unsqueeze(0)
        # )
        generated = (
            valid_transforms(dataset[num]["edited_image"].convert("RGB"))
            .cuda()
            .unsqueeze(0)
        )
        dists_res = dists(
            generated,
            original,
        ).item()
        dists_preds.append(dists_res)

np.round(np.mean(dists_preds), 4)

compute dists: 100%|██████████| 43/43 [00:29<00:00,  1.45it/s]


np.float64(0.1629)

### psnr

In [9]:
from torchmetrics.image import PeakSignalNoiseRatio

psnr = PeakSignalNoiseRatio(data_range=1.0).cuda()


psnr_preds = []
with torch.no_grad():
    for num in tqdm(selected_ids, desc="compute psnr"):
        original = (
            valid_transforms(dataset[num]["input_image"].convert("RGB"))
            .cuda()
            .unsqueeze(0)
        )
        # generated = (
        #     valid_transforms(dataset[num]["input_image"].convert("RGB"))
        #     .cuda()
        #     .unsqueeze(0)
        # )
        generated = (
            valid_transforms(dataset[num]["edited_image"].convert("RGB"))
            .cuda()
            .unsqueeze(0)
        )
        psnr_res = psnr(
            generated,
            original,
        ).item()
        psnr_preds.append(psnr_res)

np.round(np.mean(psnr_res), 4)

compute psnr: 100%|██████████| 43/43 [00:04<00:00, 10.46it/s]


np.float64(18.7176)

### FID

In [34]:
from torch import rand
from torchmetrics.image.fid import FrechetInceptionDistance
import torch

fid = FrechetInceptionDistance(
    feature=768,
).cuda()

valid_transforms = transforms.Compose(
    [
        transforms.Resize(
            resolution,
            interpolation=transforms.InterpolationMode.LANCZOS,
        ),
        transforms.CenterCrop(resolution),
    ]
)

with torch.no_grad():
    for num in tqdm(selected_ids, desc="compute fid"):
        original = (
            torch.tensor(np.array(dataset[num]["input_image"].convert("RGB")))
            .cuda()
            .permute((2, 0, 1))
            .unsqueeze(0)
        )
        # generated = (
        #     torch.tensor(np.array(dataset[num]["input_image"].convert("RGB")))
        #     .cuda()
        #     .permute((2, 0, 1))
        #     .unsqueeze(0)
        # )
        generated = (
            torch.tensor(np.array(dataset[num]["edited_image"].convert("RGB")))
            .cuda()
            .permute((2, 0, 1))
            .unsqueeze(0)
        )
        fid.update(original, real=True)
        fid.update(generated, real=False)
    final_fid = fid.compute()

np.round(final_fid.item(), 4)

compute fid: 100%|██████████| 43/43 [00:02<00:00, 17.79it/s]


np.float64(0.5267)