# The FID implementation is adapted from https://github.com/hukkelas/pytorch-frechet-inception-distance

In [None]:
import torch
from torch import nn
from torchvision.models import inception_v3
import cv2
import multiprocessing
import numpy as np
import glob
import os
from scipy import linalg
import random

In [2]:
# Run on Apple's Metal (MPS) backend when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [3]:
class PartialInceptionNetwork(nn.Module):
    """Inception-v3 wrapper that returns pooled ``Mixed_7c`` features.

    A forward hook registered on ``Mixed_7c`` stores that layer's output every
    time the wrapped network runs; ``forward`` then average-pools the stored
    feature map down to one 2048-dimensional vector per image.
    """

    def __init__(self, transform_input=True):
        super().__init__()
        self.inception_network = inception_v3(pretrained=True)
        self.inception_network.Mixed_7c.register_forward_hook(self.output_hook)
        # Only stored on the instance; it is not passed on to inception_v3().
        self.transform_input = transform_input

    def output_hook(self, module, input, output):
        """Forward-hook callback: remember the hooked layer's output.

        The stored tensor is expected to be N x 2048 x 8 x 8.
        """
        self.mixed_7c_output = output

    def forward(self, x):
        """
        Args:
            x: shape (N, 3, 299, 299) dtype: torch.float32 in range 0-1
        Returns:
            inception activations: torch.tensor, shape: (N, 2048), dtype: torch.float32
        """
        expected_tail = (3, 299, 299)
        assert x.shape[1:] == expected_tail, (
            "Expected input shape to be: (N,3,299,299)"
            + ", but got {}".format(x.shape)
        )

        # Map the [0, 1] input range onto [-1, 1].
        scaled = x * 2 - 1

        # The return value is ignored: the call is made only so that the
        # forward hook fires and refreshes self.mixed_7c_output.
        self.inception_network(scaled)

        # Spatial average: N x 2048 x 8 x 8 -> N x 2048 x 1 x 1 -> N x 2048.
        pooled = torch.nn.functional.adaptive_avg_pool2d(self.mixed_7c_output, (1, 1))
        return pooled.view(scaled.shape[0], 2048)


def get_activations(images, batch_size):
    """
    Calculates the pooled Inception activations for all images.
    --
        Images: torch.array shape: (N, 3, 299, 299), dtype: torch.float32
        batch size: batch size used for inception network (must be >= 1)
    --
    Returns: np array shape: (N, 2048), dtype: np.float32
    Raises:
        ValueError: if batch_size is smaller than 1
    """
    assert images.shape[1:] == (3, 299, 299), "Expected input shape to be: (N,3,299,299)" +\
                                              ", but got {}".format(images.shape)
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, but got {}".format(batch_size))

    num_images = images.shape[0]
    inception_network = PartialInceptionNetwork()
    inception_network = inception_network.to(device)
    inception_network.eval()
    inception_activations = np.zeros((num_images, 2048), dtype=np.float32)

    # Pure inference: disabling autograd avoids building (and holding on to)
    # a computation graph for every batch.
    with torch.no_grad():
        for start_idx in range(0, num_images, batch_size):
            end_idx = start_idx + batch_size

            # Slicing past the end is safe; the last batch is simply shorter.
            ims = images[start_idx:end_idx].to(device)
            activations = inception_network(ims)
            activations = activations.detach().cpu().numpy()
            assert activations.shape == (ims.shape[0], 2048), "Expexted output shape to be: {}, but was: {}".format((ims.shape[0], 2048), activations.shape)
            inception_activations[start_idx:end_idx, :] = activations
    return inception_activations



def calculate_activation_statistics(images, batch_size):
    """Computes the mean and covariance of the Inception activations (the FID statistics).
    Args:
        images: torch.tensor of images handed unchanged to get_activations,
                dtype: torch.float32 in range 0 - 1
        batch_size: batch size used when running the inception network
    Returns:
        mu:     per-feature mean of the activations over all images
        sigma:  covariance matrix of the activations, with each feature
                treated as a variable and each image as an observation
    """
    features = get_activations(images, batch_size)
    feature_mean = features.mean(axis=0)
    # rowvar=False: columns are the variables, rows the observations.
    feature_cov = np.cov(features, rowvar=False)
    return feature_mean, feature_cov


In [9]:
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : Mean vector of the activations of the first image set.
    -- sigma1: Covariance matrix of the activations of the first image set.
    -- mu2   : Mean vector of the activations of the second image set.
    -- sigma2: Covariance matrix of the activations of the second image set.
    -- eps   : Value added to the diagonal of both covariances when the
               matrix square root of their product is not finite.

    Returns:
    --   : The Frechet Distance.

    Raises:
    -- AssertionError: if the means or the covariances differ in shape.
    -- ValueError: if the matrix square root has a non-negligible imaginary
                   component on its diagonal.
    """
    import warnings

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

    diff = mu1 - mu2
    # product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
        # Tell the caller that the result was computed from regularised
        # covariances rather than silently changing the inputs.
        warnings.warn(msg)

        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean


def preprocess_image(im):
    """Resizes and shifts the dynamic range of image to 0-1
    Args:
        im: np.array, shape: (H, W, 3), dtype: float32 between 0-1 or np.uint8
    Return:
        im: torch.tensor, shape: (3, 299, 299), dtype: torch.float32 between 0-1
    Raises:
        AssertionError: if the input is not a 3-channel (H, W, 3) array, or the
            converted image is not float32 within [0, 1]
    """
    # Check the rank first so a 2-D (grayscale) array fails with this message
    # instead of an IndexError when the channel axis is inspected.
    assert len(im.shape) == 3, "Expected an image of shape (H, W, 3), but got {}".format(im.shape)
    assert im.shape[2] == 3, "Expected 3 channels, but got {}".format(im.shape[2])
    if im.dtype == np.uint8:
        im = im.astype(np.float32) / 255
    im = cv2.resize(im, (299, 299))
    # Channels-last (H, W, C) -> channels-first (C, H, W).
    im = np.transpose(im, (2, 0, 1))
    im = torch.from_numpy(im)
    assert im.max() <= 1.0
    assert im.min() >= 0.0
    assert im.dtype == torch.float32
    assert im.shape == (3, 299, 299)

    return im


# def preprocess_images(images, use_multiprocessing):
#     """Resizes and shifts the dynamic range of image to 0-1
#     Args:
#         images: np.array, shape: (N, H, W, 3), dtype: float32 between 0-1 or np.uint8
#         use_multiprocessing: If multiprocessing should be used to pre-process the images
#     Return:
#         final_images: torch.tensor, shape: (N, 3, 299, 299), dtype: torch.float32 between 0-1
#     """
#     if use_multiprocessing:
#         with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
#             jobs = []
#             for im in images:
#                 job = pool.apply_async(preprocess_image, (im,))
#                 jobs.append(job)
#             final_images = torch.zeros(images.shape[0], 3, 299, 299)
#             for idx, job in enumerate(jobs):
#                 im = job.get()
#                 final_images[idx] = im#job.get()
#     else:
#         final_images = torch.stack([preprocess_image(im) for im in images], dim=0)
#     assert final_images.shape == (images.shape[0], 3, 299, 299)
#     assert final_images.max() <= 1.0
#     assert final_images.min() >= 0.0
#     assert final_images.dtype == torch.float32
#     return final_images

def preprocess_images(images):
    """Runs preprocess_image on every image and stacks the results into one batch
    Args:
        images: np.array, shape: (N, H, W, 3), dtype: float32 between 0-1 or np.uint8
    Return:
        final_images: torch.tensor, shape: (N, 3, 299, 299), dtype: torch.float32 between 0-1
    """
    converted = [preprocess_image(single) for single in images]
    batch = torch.stack(converted, dim=0)

    # Sanity checks on the assembled batch.
    expected_shape = (images.shape[0], 3, 299, 299)
    assert batch.shape == expected_shape
    assert batch.max() <= 1.0
    assert batch.min() >= 0.0
    assert batch.dtype == torch.float32

    return batch

def calculate_fid(images1, images2, batch_size):
    """ Calculate FID between images1 and images2
    Args:
        images1: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8
        images2: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8
        batch_size: batch size used for inception network
    Returns:
        FID (scalar)
    """
    # Preprocess both sets first, then compute each set's activation statistics.
    prepared = [preprocess_images(image_set) for image_set in (images1, images2)]
    (mean_a, cov_a), (mean_b, cov_b) = (
        calculate_activation_statistics(tensor, batch_size) for tensor in prepared
    )
    return calculate_frechet_distance(mean_a, cov_a, mean_b, cov_b)

def resize_and_crop_images(images, image_size=64):
    """Upscale each image and cut a random square crop out of it.

    Args:
        images: iterable of images accepted by cv2.resize
        image_size: side length of the square crop; every image is first
            resized to round(image_size * 5 / 4) pixels per side
    Returns:
        np.array holding one cropped image per input image
    """
    enlarged = round(image_size * 5 / 4)
    # Largest top-left offset that still leaves room for a full crop.
    max_offset = max(0, enlarged - image_size)

    def _random_crop(picture):
        # Bicubic upscale followed by a uniformly random crop position.
        scaled = cv2.resize(picture, (enlarged, enlarged), interpolation=cv2.INTER_CUBIC)
        left = random.randint(0, max_offset)
        top = random.randint(0, max_offset)
        return scaled[top:top + image_size, left:left + image_size]

    return np.array([_random_crop(picture) for picture in images])



# def load_images(path):
#     """ 
#     Loads all .jpg images from a given path.
#     Warnings: Expects all images to be of the same dtype and shape.
#     Args:
#         path: relative path to directory
#     Returns:
#         final_images: np.array of image dtype and shape.
#     """
#     if not os.path.exists(path):
#         raise ValueError(f"Path {path} does not exist")
    
#     image_extensions = ["jpg", "png"]
#     image_paths = []
#     for ext in image_extensions:
#         image_paths.extend(glob.glob(os.path.join(path, f"*.{ext}")))



#     if not image_paths:
#         raise ValueError(f"No JPG images found in directory {path}")

#     # Initialize an empty list for storing image arrays
#     images = []

#     for impath in image_paths:
#         im = cv2.imread(impath, cv2.IMREAD_COLOR)
#         if im is None:
#             continue  # Skip if the image can't be read

#         im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # Convert from BGR to RGB

#         # Append to the list if the image is valid
#         images.append(im)

#     if not images:
#         raise ValueError("No valid images found in the directory")

#     # Check if all images have the same shape and dtype
#     first_image_shape = images[0].shape
#     first_image_dtype = images[0].dtype

#     if not all(im.shape == first_image_shape and im.dtype == first_image_dtype for im in images):
#         raise ValueError("Not all images have the same shape and dtype")

#     # Convert list of images to a numpy array
#     final_images = np.array(images)

#     return final_images


import glob
import os
import cv2
import numpy as np
import random

def load_images(path, max_images=1500):
    """
    Loads the .png and .jpg images found directly inside a given path as RGB.
    Randomly samples max_images of them if more are available.
    Warnings: Expects all images to be of same dtype and shape.
    Args:
        path: relative path to directory
        max_images: upper bound on the number of images loaded; None disables
            the cap
    Returns:
        final_images: np.array, shape: (N, H, W, 3), with the dtype returned by
            cv2.imread; an empty array if no image files were found
    Raises:
        ValueError: if a file cannot be read as an image, or if its shape or
            dtype differs from the first loaded image
    """
    image_paths = []
    image_extensions = ["png", "jpg"]
    for ext in image_extensions:
        pattern = os.path.join(path, f"*.{ext}")
        print("Looking for images in", pattern)
        image_paths.extend(glob.glob(pattern))

    # Randomly sample a subset if more images are available than allowed
    if max_images is not None and len(image_paths) > max_images:
        image_paths = random.sample(image_paths, max_images)

    if not image_paths:
        return np.array([])  # Return an empty array if no images found

    final_images = None
    for idx, impath in enumerate(image_paths):
        im = cv2.imread(impath)
        if im is None:
            # cv2.imread signals an unreadable file by returning None.
            raise ValueError(f"Could not read image: {impath}")
        im = im[:, :, ::-1] # Convert from BGR to RGB

        if final_images is None:
            # Size the output from the first image; its shape is already
            # (H, W, 3), so height and width cannot be swapped.
            final_images = np.zeros((len(image_paths),) + im.shape, dtype=im.dtype)
        elif im.shape != final_images.shape[1:] or im.dtype != final_images.dtype:
            raise ValueError(
                f"Image {impath} has shape {im.shape} and dtype {im.dtype}, "
                f"expected shape {final_images.shape[1:]} and dtype {final_images.dtype}"
            )
        final_images[idx] = im

    return final_images


In [10]:
# Load the real and the generated image sets from their directories.
real_images = load_images('fid_real_images/')
fake_images = load_images('single/')
# Only the real set is upscaled and randomly cropped (default crop: 64x64).
# NOTE(review): presumably the images in 'single/' already have that size — confirm.
real_images = resize_and_crop_images(real_images)

Looking for images in fid_real_images/*.png
Looking for images in fid_real_images/*.jpg
Looking for images in single/*.png
Looking for images in single/*.jpg


In [11]:
# FID between the real and generated sets; Inception runs in batches of 80.
fid_value = calculate_fid(real_images, fake_images, 80)

In [12]:
# Show the computed FID as the cell output.
fid_value

53.24366777965821