In [None]:
import os
import numpy as np
import scipy
from PIL import Image
import torch
from torchvision.models import inception_v3
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [None]:
# Preprocessing for the pretrained torchvision InceptionV3 used in this notebook:
# resize the shorter side to 299, center-crop to 299x299, convert to a [0, 1]
# float tensor, then normalize with the ImageNet channel statistics.
#
# Why ImageNet statistics: torchvision's `inception_v3` switches on
# `transform_input` whenever pretrained weights are loaded. That option
# re-maps ImageNet-normalized input to the (0.5, 0.5) scaling internally, so
# normalizing with (0.5, 0.5) here would feed the network wrongly scaled data.
transform = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [None]:
class CustomImageDataset(Dataset):
    """Dataset over the image files found directly inside one directory.

    Files are selected by extension (case-insensitive) and listed in sorted
    order so that an index always maps to the same file across runs.

    Args:
        root_dir: path of the directory that contains the images.
        transform: optional callable applied to each loaded PIL image.
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.gif')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort for determinism.
        self.image_files = sorted(
            f for f in os.listdir(root_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as 3-channel RGB and apply ``transform`` if set."""
        image_path = os.path.join(self.root_dir, self.image_files[idx])

        # The context manager releases the file handle. convert('RGB') decodes
        # the pixels into a new image and guarantees three channels, so
        # grayscale, RGBA and palette (PNG/GIF) files work with the 3-channel
        # normalization applied downstream.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image

In [None]:
# Folders holding the two image sets that are compared.
real_data = 'real_images/'            # real images
generated_data = 'generated_images/'  # generated images

# Both datasets go through the same preprocessing pipeline.
custom_dataset_real = CustomImageDataset(real_data, transform=transform)
custom_dataset_generated = CustomImageDataset(generated_data, transform=transform)

In [None]:
# Batch both datasets for feature extraction.
# Shuffling is disabled: the extracted features are only aggregated into
# statistics, so a random order adds nothing and makes runs non-reproducible.
data_loader_real = DataLoader(custom_dataset_real, batch_size=32, shuffle=False)
data_loader_generated = DataLoader(custom_dataset_generated, batch_size=32, shuffle=False)

In [None]:
# Load the ImageNet-pretrained InceptionV3 and keep it on the CPU.
# NOTE(review): torchvision is documented to enable `transform_input` for
# pretrained weights -- confirm the input normalization matches what it expects.
inception_model = inception_v3(weights='Inception_V3_Weights.IMAGENET1K_V1')
inception_model = inception_model.to('cpu')

# Switch to evaluation mode before using the network for inference.
inception_model = inception_model.eval()

# Swap the final fully connected layer for an identity so the forward pass
# returns the features that would otherwise be fed to that layer.
inception_model.fc = torch.nn.Identity()

In [None]:
def extract_features(model, data_loader):
    """Run ``model`` over every batch of ``data_loader`` and collect the outputs.

    Args:
        model: callable applied to each batch (here the InceptionV3 feature
            extractor).
        data_loader: iterable yielding batches of image tensors.

    Returns:
        A list with one output tensor per batch, in loader order.
    """
    features = []
    with torch.no_grad():  # inference only: no autograd graph is needed
        for batch in data_loader:
            features.append(model(batch))
    return features


# Each loader is consumed on its own. Zipping the two loaders would stop at the
# shorter one and silently drop images whenever the folders differ in size.
real_features_list = extract_features(inception_model, data_loader_real)
generated_features_list = extract_features(inception_model, data_loader_generated)

# torch.cat fails with an unhelpful message on an empty list; fail early instead.
if not real_features_list or not generated_features_list:
    raise ValueError('Both image folders must contain at least one image.')

generated_features_all = torch.cat(generated_features_list)
real_features_all = torch.cat(real_features_list)

4it [00:22,  5.50s/it]


In [None]:
# Per-feature mean of each feature matrix (averaged over dimension 0).
mu_real = real_features_all.mean(dim=0)
mu_generated = generated_features_all.mean(dim=0)

# Covariance of the features; rowvar=False tells np.cov that columns are the
# variables and rows are the observations. The result is stored as float32.
sigma_real = torch.from_numpy(
    np.cov(real_features_all.detach().numpy(), rowvar=False)
).float()
sigma_generated = torch.from_numpy(
    np.cov(generated_features_all.detach().numpy(), rowvar=False)
).float()

In [None]:
def matrix_sqrt(x):
    """Return the matrix square root of a square matrix as a torch tensor.

    Args:
        x: square 2-D ``torch.Tensor``.

    Returns:
        A tensor on the same device as ``x`` holding the real part of
        ``scipy.linalg.sqrtm(x)``. A floating-point input keeps its dtype;
        any other input dtype yields torch's default floating dtype.
    """
    # Local import: a bare `import scipy` does not guarantee that the
    # `linalg` submodule is loaded on every SciPy version.
    import scipy.linalg

    # Compute in float64 to limit round-off in the decomposition.
    y = x.detach().cpu().double().numpy()
    y = scipy.linalg.sqrtm(y)

    out_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    # sqrtm may return a complex array; only the real part is kept.
    # torch.as_tensor honours `device`, whereas the legacy
    # torch.Tensor(data, device=...) constructor rejects non-CPU devices.
    return torch.as_tensor(y.real, dtype=out_dtype, device=x.device)

In [None]:
# Frechet distance between the two Gaussians described by (mu, sigma):
#   ||mu_g - mu_r||^2 + Tr(S_g) + Tr(S_r) - 2 * Tr(sqrt(S_g @ S_r))
fid = (
    torch.sum((mu_generated - mu_real) ** 2)
    + torch.trace(sigma_generated)
    + torch.trace(sigma_real)
    - 2 * torch.trace(matrix_sqrt(sigma_generated @ sigma_real))
)
print(fid)