In [1]:
!pip install torch-fidelity



In [46]:
from torchmetrics.image.fid import FrechetInceptionDistance
import torch
import numpy as np
import tensorflow as tf
import os
import cv2

def _to_fid_batch(images, name, bgr=False):
    """Convert an (H, W, 3) image or (N, H, W, 3) batch to an (N, 3, H, W) uint8 tensor with N >= 2."""
    if images is None:
        # cv2.imread returns None instead of raising when a file cannot be read.
        raise ValueError(f"{name} is None (was the image file read successfully?)")
    array = np.asarray(images)
    if array.ndim == 3:
        array = array[np.newaxis]  # single image -> batch of one
    if array.ndim != 4 or array.shape[-1] != 3:
        raise ValueError(
            f"{name} must have shape (H, W, 3) or (N, H, W, 3), got {array.shape}"
        )
    if bgr:
        # Flip the channel axis (BGR -> RGB); copy so torch receives positive strides.
        array = np.ascontiguousarray(array[..., ::-1])
    batch = torch.tensor(array, dtype=torch.uint8)
    # (N, H, W, C) -> (N, C, H, W): FrechetInceptionDistance takes channel-first images.
    batch = batch.permute(0, 3, 1, 2)
    if batch.shape[0] < 2:
        # FID requires at least two samples per set, so a lone image is duplicated.
        batch = batch.repeat(2, 1, 1, 1)
    return batch


def get_fid(images1, images2, feature=64, bgr=False):
    """Compute the Frechet Inception Distance (FID) between two image sets.

    Each argument may be a single image shaped (H, W, 3) or a batch shaped
    (N, H, W, 3). Values are cast to uint8, which is what torchmetrics'
    FrechetInceptionDistance expects with its default ``normalize=False``
    (pixel values in [0, 255]).

    Args:
        images1: Reference ("real") image(s).
        images2: Candidate ("generated") image(s) compared against images1.
        feature: Inception feature size passed to FrechetInceptionDistance
            (64, 192, 768 or 2048). Defaults to 64.
        bgr: If True, inputs are in OpenCV BGR channel order (as returned by
            cv2.imread) and are flipped to RGB before scoring. Defaults to
            False, i.e. inputs are used as given.

    Returns:
        A float scalar tensor holding the FID score.

    Raises:
        ValueError: If an input is None or is not shaped (H, W, 3) / (N, H, W, 3).

    Note:
        When a set holds a single image it is duplicated to satisfy the
        two-sample minimum. Identical samples have zero covariance, so the
        score then essentially reduces to the squared distance between the two
        images' Inception features. That is a rough similarity measure, not a
        distribution-level FID.
    """
    # Validate/convert before building the metric so bad input fails fast.
    reference_batch = _to_fid_batch(images1, "images1", bgr)
    candidate_batch = _to_fid_batch(images2, "images2", bgr)

    # https://lightning.ai/docs/torchmetrics/stable/image/frechet_inception_distance.html
    fid = FrechetInceptionDistance(feature=feature)
    fid.update(reference_batch, real=True)
    fid.update(candidate_batch, real=False)
    return fid.compute()

In [47]:
images1 = cv2.imread('./reference.jpeg')
images2 = cv2.imread('./lion-sunglasses-our.png')

In [48]:
x = get_fid(images1, images2)

tensor(1.7544)


In [49]:
print(x)

tensor(1.7544)


In [50]:
images3 = cv2.imread('./lion-sunglasses-baseline.png')

In [51]:
y = get_fid(images1, images3)

tensor(5.7031)


In [52]:
i4 = cv2.imread('./dog-sunglasses-baseline.png')

In [53]:
i5 = cv2.imread('./dog-sunglasses-our.png')


In [54]:
z = get_fid(images1, i4)

tensor(30.9550)


In [55]:
a = get_fid(images1, i5)

tensor(3.1471)


In [56]:
i6 = cv2.imread('./horse-running-baseline.png')

In [57]:
i7 = cv2.imread('./horse-running-our.png')

In [58]:
b = get_fid(images1, i6)

tensor(40.7026)


In [59]:
c = get_fid(images1, i7)

tensor(5.7273)
