# GAN Evaluation - Fréchet Inception Distance (FID) Score

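## Basic Concept: FID from scratch with NumPy ([reference](https://machinelearningmastery.com/how-to-implement-the-frechet-inception-distance-fid-from-scratch/))

Each set of feature vectors is summarised by its mean $\mu$ and covariance $\Sigma$, and the score is

$$\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^2 + \operatorname{Tr}\left(\Sigma_1 + \Sigma_2 - 2\,(\Sigma_1 \Sigma_2)^{1/2}\right).$$

Lower is better; two identical sets of feature vectors give an FID of 0.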

In [1]:
import numpy as np
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy.random import random

from scipy.linalg import sqrtm

In [9]:
def calculate_fid_score(fvector1, fvector2, eps=1e-6):
    """Compute the Frechet Inception Distance (FID) between two sets of feature vectors.

    Each set is summarised by its mean vector and covariance matrix, and the
    score is ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2)).

    Parameters
    ----------
    fvector1, fvector2 : array_like, shape (n_samples, n_features)
        Feature vectors, one sample per row. Both sets must have the same
        number of columns; the number of rows may differ.
    eps : float, optional
        Offset added to the diagonals of both covariance matrices, used only
        when the matrix square root of their product is not finite (e.g. when
        the covariances are singular because n_samples <= n_features).

    Returns
    -------
    float
        The FID score. Identical inputs give ~0; tiny negative values can
        appear from floating-point round-off.
    """
    act1 = np.asarray(fvector1)
    act2 = np.asarray(fvector2)

    # calculate mean and covariance statistics; atleast_2d because np.cov
    # squeezes the single-feature case to a 0-d array, which sqrtm rejects
    mean1, sigma1 = act1.mean(axis=0), np.atleast_2d(cov(act1, rowvar=False))
    mean2, sigma2 = act2.mean(axis=0), np.atleast_2d(cov(act2, rowvar=False))

    # calculate sum squared difference between two mean vectors
    ssdiff = np.sum((mean1 - mean2) ** 2.0)

    # calculate sqrt of product between the two covariance matrices
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # singular product: regularise both covariances slightly and retry
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # some elements in the resulting matrix may be imaginary due to numerical
    # error; keep only the real part
    if iscomplexobj(covmean):
        covmean = covmean.real

    # calculate score
    return float(ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean))

In [10]:
# Define two collections of random "activations": N_SAMPLES rows x N_FEATURES
# columns each (2048 presumably mirrors the Inception-v3 pool feature size).
# A seeded generator makes the FID values printed below reproducible.
# NOTE: with N_SAMPLES < N_FEATURES the covariance matrices are rank-deficient,
# so these scores only illustrate the computation.
N_SAMPLES, N_FEATURES = 10, 2048
rng = np.random.default_rng(42)
fvector1 = rng.random((N_SAMPLES, N_FEATURES))
fvector2 = rng.random((N_SAMPLES, N_FEATURES))

In [11]:
# Sanity check: the FID of fvector1 against itself should be (numerically)
# zero; floating-point round-off can make it print as -0.000.
fid = calculate_fid_score(fvector1, fvector1)
print('FID (identical feature vectors): %.3f' % (fid,))

FID (identical feature vectors): -0.000


In [13]:
# FID between the two independently generated random feature sets
fid = calculate_fid_score(fvector1, fvector2)
print('FID (different feature vectors): %.3f' % (fid,))

FID (different feature vectors): 354.907


## FID in PyTorch

In [5]:
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
import torch.utils.data

from torchvision.models.inception import inception_v3

In [None]:
def _frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # singular product of covariances: regularise the diagonals and retry
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        # drop the tiny imaginary parts sqrtm can produce from round-off
        covmean = covmean.real
    return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2)
                 - 2.0 * np.trace(covmean))


def _inception_features(model, images, batch_size, resize, device):
    """Run `images` through `model` in batches and stack the outputs into an (N, D) array."""
    loader = torch.utils.data.DataLoader(images, batch_size=batch_size)
    features = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device=device, dtype=torch.float32)
            if resize:
                # upsample to 299x299 (assumed standard Inception-v3 input size)
                batch = F.interpolate(batch, size=(299, 299), mode='bilinear',
                                      align_corners=False)
            features.append(model(batch).cpu().numpy())
    if not features:
        raise ValueError("cannot compute FID on an empty image set")
    return np.concatenate(features, axis=0)


def calculate_fid_score(images, batch_size=32, resize=False, cuda=False,
                        reference_images=None):
    """Compute the FID between `images` and `reference_images` from Inception-v3 features.

    NOTE: this re-defines (and therefore shadows) the NumPy
    `calculate_fid_score(fvector1, fvector2)` from earlier in the notebook, so
    the distance itself is computed by the private `_frechet_distance` helper.

    Parameters
    ----------
    images, reference_images : Dataset or sequence of image tensors
        The two image sets to compare. Each item is assumed to be a float
        tensor of shape (3, H, W) scaled to [-1, 1], like the items of
        `IgnoreLabelDataset(cifar)` — TODO confirm for other inputs.
        `reference_images` defaults to None only to keep the original
        signature callable; it must be provided.
    batch_size : int
        Number of images per forward pass (must be >= 1).
    resize : bool
        If True, bilinearly resize every batch to 299x299 before the forward
        pass (presumably required for small images such as 32x32 CIFAR-10).
    cuda : bool
        Run the model on the CUDA device instead of the CPU.

    Returns
    -------
    float
        The FID score; lower means the two image sets are more similar.

    Raises
    ------
    ValueError
        If `reference_images` is None, `batch_size` < 1, or an image set is empty.

    NOTE(review): these are torchvision's ImageNet Inception-v3 weights, which
    presumably differ from the TensorFlow weights behind published FID numbers;
    treat the scores as relative, not comparable to values in papers.
    """
    if reference_images is None:
        raise ValueError("FID compares two image sets: pass reference_images")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    device = torch.device('cuda' if cuda else 'cpu')
    # transform_input=False: inputs are expected to already be scaled to [-1, 1]
    model = inception_v3(pretrained=True, transform_input=False).to(device)
    # replace the final classifier so the forward pass returns pooled features
    model.fc = nn.Identity()
    model.eval()

    act1 = _inception_features(model, images, batch_size, resize, device)
    act2 = _inception_features(model, reference_images, batch_size, resize, device)
    return _frechet_distance(act1.mean(axis=0), np.atleast_2d(np.cov(act1, rowvar=False)),
                             act2.mean(axis=0), np.atleast_2d(np.cov(act2, rowvar=False)))

In [11]:
import torchvision.datasets as dset
import torchvision.transforms as transforms

# CIFAR-10, downloaded into data/ if it is not already there. Each image is
# passed through: Resize(32) (presumably a no-op, since CIFAR-10 images are
# already 32x32) -> ToTensor -> per-channel Normalize with mean 0.5 / std 0.5,
# i.e. (x - 0.5) / 0.5.
cifar = dset.CIFAR10(
    root='data/',
    download=True,
    transform=transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

Files already downloaded and verified


In [None]:
class IgnoreLabelDataset(torch.utils.data.Dataset):
    """Label-free view of another indexable dataset.

    Every item of the wrapped dataset is indexed and only its first element
    (e.g. the image of an (image, label) pair) is returned, so a DataLoader
    over this view yields inputs without labels.
    """

    def __init__(self, orig):
        # the wrapped dataset; it is referenced, not copied
        self.orig = orig

    def __len__(self):
        return len(self.orig)

    def __getitem__(self, index):
        item = self.orig[index]
        return item[0]

# Compute the FID for the label-free CIFAR-10 images (resize=True asks for the
# inputs to be resized for Inception-v3). Use the GPU only when one exists so
# the cell also runs on CPU-only machines.
# NOTE(review): FID compares two distributions, but only one image set is
# passed here — TODO supply a reference set (e.g. generated samples).
print("Calculating FID score...")
fid_score = calculate_fid_score(IgnoreLabelDataset(cifar), batch_size=32, resize=True,
                                cuda=torch.cuda.is_available())
print("FID score:", fid_score)

---