In [4]:
!pip install torchmetrics torch-fidelity

Collecting torch-fidelity
  Downloading torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Downloading torch_fidelity-0.3.0-py3-none-any.whl (37 kB)
Installing collected packages: torch-fidelity
Successfully installed torch-fidelity-0.3.0
[0m

In [2]:
from torchmetrics.image.fid import FrechetInceptionDistance
import glob
from PIL import Image
import torch
from torchvision import transforms
import numpy as np

In [3]:
def load_image_as_tensor(image_path, size=(299, 299)):
    """Load an image file as an RGB uint8 tensor in CHW layout, resized to `size`.

    Args:
        image_path: Path of the image file to read.
        size: Target ``(width, height)`` passed to PIL's ``Image.resize``.
            Defaults to 299x299, Inception-v3's input resolution.

    Returns:
        A ``torch.uint8`` tensor of shape ``(3, height, width)``.
    """
    # The context manager closes the underlying file handle deterministically
    # instead of leaving it to the garbage collector; convert()/resize() return
    # new images, so `rgb` stays valid after the file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB').resize(size)
    # An RGB PIL image becomes an (H, W, 3) uint8 array; permute HWC -> CHW.
    return torch.from_numpy(np.array(rgb)).permute(2, 0, 1)

def _update_fid_from_files(fid, files, real, batch_size):
    """Load `files` in chunks of `batch_size` and feed each chunk to `fid.update`."""
    for start in range(0, len(files), batch_size):
        batch = torch.stack(
            [load_image_as_tensor(path) for path in files[start:start + batch_size]]
        )
        fid.update(batch, real=real)


def compute_fids(target_directory, feature_size=2048, real_directory="small_coco",
                 generated_root="generated", batch_size=64):
    """Compute, print and return the FID of one generated-image set vs. the reference set.

    Reference images are ``{real_directory}/*.png``. Generated images are
    ``{generated_root}/{target_directory}/*.png``. Both lists are sorted and loaded
    with ``load_image_as_tensor``.

    Args:
        target_directory: Sub-directory of ``generated_root`` holding the generated
            PNGs (e.g. ``"sd15"``); also used as the label in the printed line.
        feature_size: Passed as ``feature=`` to ``FrechetInceptionDistance``.
        real_directory: Directory containing the reference PNGs.
        generated_root: Parent directory of the per-model generated-image folders.
        batch_size: Number of images loaded and passed to ``fid.update`` per call,
            so the whole dataset is never materialised as a single tensor.

    Returns:
        The FID value as produced by ``fid.compute().numpy()``.

    Raises:
        FileNotFoundError: If either directory contains no ``*.png`` files.
        ValueError: If ``batch_size`` is smaller than 1.

    Note:
        FID is biased for small sample counts. With fewer images than feature
        dimensions (2048 by default), the sample covariance is rank-deficient.
        Only compare values computed on equally sized image sets.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # glob.escape keeps characters such as '[' in directory names from being
    # interpreted as wildcard syntax.
    real_files = sorted(glob.glob(f"{glob.escape(real_directory)}/*.png"))
    fake_dir = f"{generated_root}/{target_directory}"
    fake_files = sorted(glob.glob(f"{glob.escape(fake_dir)}/*.png"))

    # torch.stack on an empty list fails with an unhelpful message; report the
    # actual cause (usually a wrong or empty directory) instead.
    if not real_files:
        raise FileNotFoundError(f"no PNG images found in '{real_directory}'")
    if not fake_files:
        raise FileNotFoundError(f"no PNG images found in '{fake_dir}'")

    fid = FrechetInceptionDistance(feature=feature_size)
    # update() accumulates feature statistics across calls, so feeding the
    # images in batches yields the same metric as one large update.
    _update_fid_from_files(fid, real_files, True, batch_size)
    _update_fid_from_files(fid, fake_files, False, batch_size)
    fid_value = fid.compute().numpy()
    print(target_directory, fid_value)
    return fid_value

In [4]:
# Score each Stable Diffusion variant's generated images against the same
# reference set; each call prints "<name> <fid>".
for model_name in ("sd15", "sd21", "sdxl"):
    compute_fids(model_name)

sd15 191.24498
sd21 190.45113
sdxl 187.00229
