In [3]:
from datasets import load_dataset

dataset_name = "dim/nfs_pix2pix_1920_1080_v5"
dataset = load_dataset(
    dataset_name,
    cache_dir="/code/dataset/nfs_pix2pix_1920_1080_v5",
)
dataset = dataset["train"]

In [4]:
import random

test_images_ids = list(range(0, len(dataset), 20))
rng = random.Random(42)
amount = min(100, len(test_images_ids))
selected_ids = rng.sample(test_images_ids, amount)
# selected_ids

### LPIPS

In [5]:
import lpips
import torch
from tqdm import tqdm
from torchvision import transforms

resolution = 512
valid_transforms = transforms.Compose(
    [
        transforms.Resize(
            resolution,
            interpolation=transforms.InterpolationMode.LANCZOS,
        ),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5],
            [0.5],
        ),
    ]
)
loss_fn_vgg = lpips.LPIPS(net="vgg").requires_grad_(False).cuda()
total_loss = 0.0
with torch.no_grad():
    for num in tqdm(selected_ids):
        item_1 = valid_transforms(dataset[num]["input_image"].convert("RGB")).cuda()
        # item_1 = valid_transforms(dataset[num]["edited_image"].convert("RGB")).cuda()
        item_2 = valid_transforms(dataset[num]["edited_image"].convert("RGB")).cuda()

        d = loss_fn_vgg(item_1, item_2).item()
        # print(d)
        total_loss += d
total_loss / len(selected_ids)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /opt/conda/lib/python3.11/site-packages/lpips/weights/v0.1/vgg.pth


100%|██████████| 43/43 [00:05<00:00,  7.58it/s]


0.35777404765750087

## SSIM, MSE

In [6]:
import numpy as np

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error


ssim_preds = []
mse_preds = []
with torch.no_grad():
    for num in tqdm(selected_ids, desc="compute ssim"):
        original = valid_transforms(dataset[num]["input_image"].convert("RGB")).numpy()
        # item_2 = valid_transforms(dataset[num]["input_image"].convert("RGB")).numpy()
        generated = valid_transforms(
            dataset[num]["edited_image"].convert("RGB")
        ).numpy()
        ssim_res = ssim(
            original,
            generated,
            data_range=generated.max() - generated.min(),
            channel_axis=0,
        )
        mse_res = mean_squared_error(original, generated)
        ssim_preds.append(ssim_res)
        mse_preds.append(mse_res)
np.mean(ssim_preds), np.mean(mse_preds)

compute ssim: 100%|██████████| 43/43 [00:05<00:00,  7.25it/s]


(np.float32(0.55576855), np.float64(0.05032231584813293))

In [7]:
import torch
import piqa

ssim = piqa.SSIM().cuda()

valid_transforms = transforms.Compose(
    [
        transforms.Resize(
            resolution,
            interpolation=transforms.InterpolationMode.LANCZOS,
        ),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
    ]
)

ssim_preds = []
mse_preds = []
with torch.no_grad():
    for num in tqdm(selected_ids, desc="compute ssim"):
        original = (
            valid_transforms(dataset[num]["input_image"].convert("RGB"))
            .cuda()
            .unsqueeze(0)
        )
        # item_2 = valid_transforms(dataset[num]["input_image"].convert("RGB")).numpy()
        generated = (
            valid_transforms(dataset[num]["edited_image"].convert("RGB"))
            .cuda()
            .unsqueeze(0)
        )
        ssim_res = ssim(
            original,
            generated,
        ).item()
        ssim_preds.append(ssim_res)

np.mean(ssim_preds)

compute ssim: 100%|██████████| 43/43 [00:04<00:00, 10.28it/s]


np.float64(0.6675837677578593)

### Dists

In [8]:
from torch import rand
from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity

valid_transforms = transforms.Compose(
    [
        transforms.Resize(
            resolution,
            interpolation=transforms.InterpolationMode.LANCZOS,
        ),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
    ]
)

dists = DeepImageStructureAndTextureSimilarity().cuda()
dists_preds = []
with torch.no_grad():
    for num in tqdm(selected_ids, desc="compute dists"):
        original = (
            valid_transforms(dataset[num]["input_image"].convert("RGB"))
            .cuda()
            .unsqueeze(0)
        )
        # generated = (
        #     valid_transforms(dataset[num]["input_image"].convert("RGB"))
        #     .cuda()
        #     .unsqueeze(0)
        # )
        generated = (
            valid_transforms(dataset[num]["edited_image"].convert("RGB"))
            .cuda()
            .unsqueeze(0)
        )
        dists_res = dists(
            generated,
            original,
        ).item()
        dists_preds.append(dists_res)

np.round(np.mean(dists_preds), 4)

compute dists: 100%|██████████| 43/43 [00:29<00:00,  1.45it/s]


np.float64(0.1629)

### psnr

In [66]:
from torchmetrics.image import PeakSignalNoiseRatio

psnr = PeakSignalNoiseRatio(data_range=1.0).cuda()

valid_transforms = transforms.Compose(
    [
        transforms.Resize(
            resolution,
            interpolation=transforms.InterpolationMode.LANCZOS,
        ),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
    ]
)
psnr_preds = []
with torch.no_grad():
    for num in tqdm(selected_ids, desc="compute psnr"):
        original = (
            valid_transforms(dataset[num]["input_image"].convert("RGB"))
            .cuda()
            .unsqueeze(0)
        )
        # generated = (
        #     valid_transforms(dataset[num]["input_image"].convert("RGB"))
        #     .cuda()
        #     .unsqueeze(0)
        # )
        generated = (
            valid_transforms(dataset[num]["edited_image"].convert("RGB"))
            .cuda()
            .unsqueeze(0)
        )
        psnr_res = psnr(
            generated,
            original,
        ).item()
        psnr.update(generated, original)
        psnr_preds.append(psnr_res)

np.round(np.mean(psnr_preds), 4), psnr.compute().item()

compute psnr: 100%|██████████| 43/43 [00:04<00:00,  9.88it/s]


(np.float64(19.2376), 19.00299072265625)

### FID

In [55]:
from torch import rand
from torchmetrics.image.fid import FrechetInceptionDistance
import torch

fid = FrechetInceptionDistance(
    feature=2048,
).cuda()

valid_transforms = transforms.Compose(
    [
        transforms.Resize(
            resolution,
            interpolation=transforms.InterpolationMode.LANCZOS,
        ),
        transforms.CenterCrop(resolution),
    ]
)

with torch.no_grad():
    for num in tqdm(selected_ids, desc="compute fid"):
        original = (
            torch.tensor(
                np.array(valid_transforms(dataset[num]["input_image"].convert("RGB"))),
                dtype=torch.uint8,
            )
            .cuda()
            .permute((2, 0, 1))
            .unsqueeze(0)
        )
        # generated = (
        #     torch.tensor(np.array(dataset[num]["input_image"].convert("RGB")))
        #     .cuda()
        #     .permute((2, 0, 1))
        #     .unsqueeze(0)
        # )
        generated = (
            torch.tensor(
                np.array(valid_transforms(dataset[num]["edited_image"].convert("RGB"))),
                dtype=torch.uint8,
            )
            .cuda()
            .permute((2, 0, 1))
            .unsqueeze(0)
        )
        fid.update(original, real=True)
        fid.update(generated, real=False)
    final_fid = fid.compute()

np.round(final_fid.item(), 4)

compute fid: 100%|██████████| 43/43 [00:03<00:00, 11.51it/s]


np.float64(83.6904)

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import lpips
import piqa
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision import transforms
from tqdm import tqdm
from typing import List
from PIL import Image


class ImageEvaluator:
    def __init__(
        self, metrics_list: List[str], device: str = "cuda", resolution: int = 512
    ):
        """
        ВАЖНО, при использовании усреднении через средства питона, накапливаются ошибки
        и результат отличается если сначала все сложить через либу, а затем вызвать метрику compute.
        ключевые различия можно обнаружить на метрике psnr и fid
        Инициализация класса для оценки качества изображений.

        Args:
            metrics_list: Список метрик ('lpips', 'mse', 'ssim', 'dists', 'psnr', 'fid').
            device: Устройство для вычислений ('cuda' или 'cpu').
            resolution: Разрешение для ресайза и кропа изображений.
        """
        self.metrics_list = [m.lower() for m in metrics_list]
        self.device = device
        self.resolution = resolution

        # --- Определение трансформаций ---

        # 1. Базовая трансформация: [0, 1] (для SSIM, PSNR, MSE, DISTS)
        self.base_transform = transforms.Compose(
            [
                transforms.Resize(
                    resolution, interpolation=transforms.InterpolationMode.LANCZOS
                ),
                transforms.CenterCrop(resolution),
                transforms.ToTensor(),
            ]
        )

        # 2. Трансформация для LPIPS: [-1, 1]
        self.lpips_transform = transforms.Compose(
            [
                transforms.Resize(
                    resolution, interpolation=transforms.InterpolationMode.LANCZOS
                ),
                transforms.CenterCrop(resolution),
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.5, 0.5, 0.5),
                    (0.5, 0.5, 0.5),
                ),
            ]
        )

        # 3. Трансформация для FID: [0, 255] uint8 tensor
        # FID в torchmetrics ожидает (N, 3, H, W) типа uint8
        self.fid_resize_crop = transforms.Compose(
            [
                transforms.Resize(
                    resolution, interpolation=transforms.InterpolationMode.LANCZOS
                ),
                transforms.CenterCrop(resolution),
            ]
        )

        # --- Инициализация моделей ---
        self._init_models()

    def _init_models(self):
        """Загрузка необходимых моделей в память."""

        if "lpips" in self.metrics_list:
            print("Initializing LPIPS...")
            self.loss_fn_lpips = (
                lpips.LPIPS(net="vgg").requires_grad_(False).to(self.device)
            )

        if "ssim" in self.metrics_list:
            print("Initializing SSIM...")
            self.loss_fn_ssim = piqa.SSIM().to(self.device)

        if "dists" in self.metrics_list:
            print("Initializing DISTS...")
            self.loss_fn_dists = DeepImageStructureAndTextureSimilarity().to(
                self.device
            )

        if "psnr" in self.metrics_list:
            print("Initializing PSNR...")
            self.loss_fn_psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)

        if "fid" in self.metrics_list:
            print("Initializing FID...")
            self.loss_fn_fid = FrechetInceptionDistance(feature=2048).to(self.device)

    def _prepare_fid_tensor(self, pil_img: Image.Image) -> torch.Tensor:
        """Конвертирует PIL в тензор uint8 (C, H, W) для FID."""
        img = self.fid_resize_crop(pil_img)
        # Convert to numpy (H, W, C) then to tensor (C, H, W)
        array = np.array(img)
        tensor = torch.tensor(array, dtype=torch.uint8).to(self.device).permute(2, 0, 1)
        return tensor.unsqueeze(0)

    def evaluate(
        self,
        originals: List[Image.Image],
        generated: List[Image.Image],
    ) -> dict:
        """
        Вычисляет метрики для двух списков изображений.

        Args:
            originals: Список оригинальных PIL изображений.
            generated: Список сгенерированных PIL изображений.

        Returns:
            Словарь с усредненными значениями метрик.
        """
        assert len(originals) == len(
            generated
        ), "Списки изображений должны быть одной длины"

        # Хранилище результатов
        results = {k: [] for k in self.metrics_list if k != "fid"}

        # Сброс состояния FID перед новым расчетом
        if "fid" in self.metrics_list:
            self.loss_fn_fid.reset()

        with torch.no_grad():
            for orig_pil, gen_pil in tqdm(
                zip(originals, generated), total=len(originals), desc="Evaluating"
            ):

                # --- Подготовка данных ---

                # Для LPIPS (нормализация [-1, 1])
                if "lpips" in self.metrics_list:
                    orig_lpips = (
                        self.lpips_transform(orig_pil.convert("RGB"))
                        .to(self.device)
                        .unsqueeze(0)
                    )
                    gen_lpips = (
                        self.lpips_transform(gen_pil.convert("RGB"))
                        .to(self.device)
                        .unsqueeze(0)
                    )

                    val = self.loss_fn_lpips(orig_lpips, gen_lpips).item()
                    results["lpips"].append(val)

                # Для SSIM, DISTS, PSNR, MSE (диапазон [0, 1])
                need_base = any(
                    m in self.metrics_list for m in ["ssim", "dists", "psnr", "mse"]
                )
                if need_base:
                    orig_base = (
                        self.base_transform(orig_pil.convert("RGB"))
                        .to(self.device)
                        .unsqueeze(0)
                    )
                    gen_base = (
                        self.base_transform(gen_pil.convert("RGB"))
                        .to(self.device)
                        .unsqueeze(0)
                    )

                    if "ssim" in self.metrics_list:
                        # piqa.SSIM ожидает (N, C, H, W) в [0, 1]
                        val = self.loss_fn_ssim(orig_base, gen_base).item()
                        results["ssim"].append(val)

                    if "dists" in self.metrics_list:
                        # DISTS(preds, target)
                        val = self.loss_fn_dists(gen_base, orig_base).item()
                        results["dists"].append(val)

                    if "psnr" in self.metrics_list:
                        # PSNR(preds, target)
                        val = self.loss_fn_psnr(gen_base, orig_base).item()
                        results["psnr"].append(val)

                    if "mse" in self.metrics_list:
                        val = F.mse_loss(gen_base, orig_base).item()
                        results["mse"].append(val)

                # Для FID (накопление статистики)
                if "fid" in self.metrics_list:
                    orig_fid = self._prepare_fid_tensor(orig_pil.convert("RGB"))
                    gen_fid = self._prepare_fid_tensor(gen_pil.convert("RGB"))

                    self.loss_fn_fid.update(orig_fid, real=True)
                    self.loss_fn_fid.update(gen_fid, real=False)

        # --- Агрегация результатов ---
        final_metrics = {}

        # Среднее для поэлементных метрик
        for name, values in results.items():
            if values:
                final_metrics[name] = float(np.mean(values))

        # Вычисление итогового FID
        if "fid" in self.metrics_list:
            print("Computing FID score...")
            fid_score = self.loss_fn_fid.compute()
            final_metrics["fid"] = float(fid_score.item())

        # Округление
        for k, v in final_metrics.items():
            final_metrics[k] = round(v, 4)

        return final_metrics


image_evaluator = ImageEvaluator(
    metrics_list=[
        "lpips",
        "mse",
        "ssim",
        "dists",
        "psnr",
        "fid",
    ],
    device="cuda",
    resolution=512,
)

originals = [dataset[num]["input_image"] for num in selected_ids]
generated = [dataset[num]["edited_image"] for num in selected_ids]
image_evaluator.evaluate(
    originals=originals,
    generated=generated,
)

Initializing LPIPS...
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /opt/conda/lib/python3.11/site-packages/lpips/weights/v0.1/vgg.pth
Initializing SSIM...
Initializing DISTS...
Initializing PSNR...
Initializing FID...


Evaluating: 100%|██████████| 43/43 [00:31<00:00,  1.34it/s]


Computing FID score...


{'lpips': 0.3578,
 'mse': 0.0126,
 'ssim': 0.6676,
 'dists': 0.1629,
 'psnr': 19.2376,
 'fid': 83.6904}

In [None]:
{
    "lpips": 0.3578,
    "mse": 0.0126,
    "ssim": 0.6676,
    "dists": 0.1629,
    "psnr": 19.2376,
    "fid": 83.6904,
}

In [64]:
import torch
import torch.nn.functional as F
import numpy as np
import lpips
import piqa
from torch.utils.data import Dataset, DataLoader
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision import transforms
from tqdm import tqdm
from typing import List, Dict
from PIL import Image


# --- DATASET (Остается прежним, он эффективен) ---


# --- EVALUATOR ---
class ImageEvaluator:
    class EvaluationDataset(Dataset):
        def __init__(
            self,
            originals: List[Image.Image],
            generated: List[Image.Image],
            metrics_list: List[str],
            resolution: int,
        ):
            self.originals = originals
            self.generated = generated
            self.metrics_list = metrics_list
            self.resolution = resolution

            self.common_transform = transforms.Compose(
                [
                    transforms.Resize(
                        resolution, interpolation=transforms.InterpolationMode.LANCZOS
                    ),
                    transforms.CenterCrop(resolution),
                ]
            )
            self.to_tensor = transforms.ToTensor()
            self.normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

        def __len__(self):
            return len(self.originals)

        def _process_image(self, pil_img):
            output = {}
            img_resized = self.common_transform(pil_img.convert("RGB"))

            need_base = any(
                m in self.metrics_list for m in ["ssim", "dists", "psnr", "mse"]
            )
            need_lpips = "lpips" in self.metrics_list
            need_fid = "fid" in self.metrics_list

            if need_base or need_lpips:
                tensor_base = self.to_tensor(img_resized)
                if need_base:
                    output["base"] = tensor_base
                if need_lpips:
                    output["lpips"] = self.normalize(tensor_base)

            if need_fid:
                arr = np.array(img_resized)
                tensor_fid = (
                    torch.from_numpy(arr).permute(2, 0, 1).to(dtype=torch.uint8)
                )
                output["fid"] = tensor_fid

            return output

        def __getitem__(self, idx):
            return {
                "orig": self._process_image(self.originals[idx]),
                "gen": self._process_image(self.generated[idx]),
            }

    def __init__(
        self,
        metrics_list: List[str],
        device: str = "cuda",
        resolution: int = 512,
        num_workers: int = 4,
    ):
        self.metrics_list = [m.lower() for m in metrics_list]
        self.device = device
        self.resolution = resolution
        self.num_workers = num_workers
        self._init_models()

    def _init_models(self):
        # 1. LPIPS (нет параметра reduction в init, делаем sum при вызове)
        if "lpips" in self.metrics_list:
            print("Initializing LPIPS...")
            self.loss_fn_lpips = (
                lpips.LPIPS(net="vgg").requires_grad_(False).to(self.device)
            )

        # 2. SSIM (piqa поддерживает reduction='sum')
        if "ssim" in self.metrics_list:
            print("Initializing SSIM...")
            self.loss_fn_ssim = piqa.SSIM(reduction="sum").to(self.device)

        # 3. TorchMetrics (PSNR, DISTS, FID) - они сами умеют накапливать состояние через update()
        if "dists" in self.metrics_list:
            print("Initializing DISTS...")
            self.metric_dists = DeepImageStructureAndTextureSimilarity().to(self.device)

        if "psnr" in self.metrics_list:
            print("Initializing PSNR...")
            self.metric_psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)

        if "fid" in self.metrics_list:
            print("Initializing FID...")
            self.metric_fid = FrechetInceptionDistance(feature=2048).to(self.device)

    def evaluate(
        self,
        originals: List[Image.Image],
        generated: List[Image.Image],
        batch_size: int = 16,
    ) -> Dict[str, float]:

        n_samples = len(originals)
        dataset = self.EvaluationDataset(
            originals, generated, self.metrics_list, self.resolution
        )
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

        # Сброс метрик torchmetrics
        if "dists" in self.metrics_list:
            self.metric_dists.reset()
        if "psnr" in self.metrics_list:
            self.metric_psnr.reset()
        if "fid" in self.metrics_list:
            self.metric_fid.reset()

        # Аккумуляторы для метрик, которые считаем вручную (LPIPS, MSE, SSIM)
        manual_sums = {
            k: 0.0 for k in ["lpips", "mse", "ssim"] if k in self.metrics_list
        }

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):

                # --- LPIPS ---
                if "lpips" in self.metrics_list:
                    orig_l = batch["orig"]["lpips"].to(self.device, non_blocking=True)
                    gen_l = batch["gen"]["lpips"].to(self.device, non_blocking=True)
                    # LPIPS возвращает (B, 1, 1, 1), суммируем
                    manual_sums["lpips"] += (
                        self.loss_fn_lpips(orig_l, gen_l).sum().item()
                    )

                # --- Base Transforms (SSIM, MSE, DISTS, PSNR) ---
                need_base = any(
                    m in self.metrics_list for m in ["ssim", "dists", "psnr", "mse"]
                )
                if need_base:
                    orig_b = batch["orig"]["base"].to(self.device, non_blocking=True)
                    gen_b = batch["gen"]["base"].to(self.device, non_blocking=True)

                    # SSIM (piqa c reduction='sum')
                    if "ssim" in self.metrics_list:
                        manual_sums["ssim"] += self.loss_fn_ssim(orig_b, gen_b).item()

                    # MSE (ручной расчет суммы средних ошибок)
                    if "mse" in self.metrics_list:
                        # (gen - orig)^2 -> mean по пикселям -> sum по батчу
                        # Это дает нам сумму MSE каждой картинки
                        batch_mse_sum = (
                            F.mse_loss(gen_b, orig_b, reduction="none")
                            .mean(dim=[1, 2, 3])
                            .sum()
                            .item()
                        )
                        manual_sums["mse"] += batch_mse_sum

                    # DISTS (torchmetrics update)
                    if "dists" in self.metrics_list:
                        self.metric_dists.update(gen_b, orig_b)

                    # PSNR (torchmetrics update)
                    if "psnr" in self.metrics_list:
                        self.metric_psnr.update(gen_b, orig_b)

                # --- FID (torchmetrics update) ---
                if "fid" in self.metrics_list:
                    orig_f = batch["orig"]["fid"].to(self.device, non_blocking=True)
                    gen_f = batch["gen"]["fid"].to(self.device, non_blocking=True)
                    self.metric_fid.update(orig_f, real=True)
                    self.metric_fid.update(gen_f, real=False)

        # --- Сборка результатов ---
        final_metrics = {}

        # 1. Метрики с ручным суммированием делим на кол-во сэмплов
        for k, v in manual_sums.items():
            final_metrics[k] = v / n_samples

        # 2. Метрики torchmetrics вычисляют результат сами
        if "dists" in self.metrics_list:
            final_metrics["dists"] = float(self.metric_dists.compute().item())

        if "psnr" in self.metrics_list:
            final_metrics["psnr"] = float(self.metric_psnr.compute().item())

        if "fid" in self.metrics_list:
            print("Computing FID score...")
            final_metrics["fid"] = float(self.metric_fid.compute().item())

        # Округление
        for k, v in final_metrics.items():
            final_metrics[k] = round(v, 4)

        return final_metrics


evaluator = ImageEvaluator(
    metrics_list=[
        "lpips",
        "mse",
        "ssim",
        "dists",
        "psnr",
        "fid",
    ],
    device="cuda",
    num_workers=4,  # Используйте 0 для отладки, >0 для скорости (например, cpu_count() / 2)
)


results = evaluator.evaluate(originals, generated, batch_size=16)
print(results)

Initializing LPIPS...
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
Loading model from: /opt/conda/lib/python3.11/site-packages/lpips/weights/v0.1/vgg.pth
Initializing SSIM...
Initializing DISTS...
Initializing PSNR...
Initializing FID...


Evaluating: 100%|██████████| 3/3 [00:03<00:00,  1.12s/it]


Computing FID score...
{'lpips': 0.3578, 'mse': 0.0126, 'ssim': 0.6676, 'dists': 0.1629, 'psnr': 19.003, 'fid': 83.7086}


In [60]:
{
    "lpips": 0.3578,
    "mse": 0.0126,
    "ssim": 0.6676,
    "dists": 0.1629,
    "psnr": 19.2376,
    "fid": 83.6904,
}

{'lpips': 0.3578,
 'mse': 0.0126,
 'ssim': 0.6676,
 'dists': 0.1629,
 'psnr': 19.2376,
 'fid': 83.6904}

In [1]:
{
    "lpips": 0.3578,
    "mse": 0.0126,
    "ssim": 0.6676,
    "dists": 0.1629,
    "psnr": 19.2376,
    "fid": 83.6904,
    **{
		"test": 123
	}
}

{'lpips': 0.3578,
 'mse': 0.0126,
 'ssim': 0.6676,
 'dists': 0.1629,
 'psnr': 19.2376,
 'fid': 83.6904,
 'test': 123}