In [11]:
import os
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision.utils import make_grid
import torch.optim as opt
import numpy as np
import torchvision
import torch.nn.functional as func
import torchvision.transforms as tf_transforms
from scipy import linalg
from torchvision import models

In [12]:
class InceptionV3FeatureExtractor(nn.Module):
    """Pretrained torchvision InceptionV3 cut into up to four sequential feature blocks.

    Only the blocks up to the highest requested index are built. ``forward``
    returns the activations of every requested block as a list.
    """

    # Block index used when the caller does not request specific outputs.
    DEFAULT_BLOCK_INDEX = 3
    # Lookup from a feature dimensionality (per the constant's name) to the
    # index of the block that is requested for it.
    FEATURE_DIM_TO_BLOCK = {
        64: 0,
        192: 1,
        768: 2,
        2048: 3
    }

    def __init__(self,
                 outputs=(DEFAULT_BLOCK_INDEX,),
                 resize=True,
                 normalize=True,
                 is_grad=False):
        """Build the feature blocks from pretrained InceptionV3 weights.

        Args:
            outputs: Iterable of block indices (0-3) whose activations
                ``forward`` should return.
            resize: If True, inputs are bilinearly resized to 299x299.
            normalize: If True, inputs are rescaled with ``2 * x - 1``
                (assumes inputs lie in [0, 1] -- confirm against callers).
            is_grad: Value assigned to ``requires_grad`` on every parameter;
                the default False freezes the network.

        Raises:
            AssertionError: If the largest requested block index exceeds 3.
        """
        super().__init__()

        self.resize = resize
        self.normalize = normalize
        self.outputs = sorted(outputs)
        # Take the maximum from the materialized list so one-shot iterables work too.
        self.final_block_index = max(self.outputs)

        assert self.final_block_index <= 3, \
            'Maximum allowed block index is 3.'

        inception_model = models.inception_v3(pretrained=True)
        self.feature_blocks = nn.ModuleList(self._build_blocks(inception_model))

        # Must run after the blocks are registered: before that point
        # self.parameters() is empty and nothing would be frozen.
        for param in self.parameters():
            param.requires_grad = is_grad

    def _build_blocks(self, inception_model):
        """Return the list of nn.Sequential blocks up to ``self.final_block_index``."""
        # Block 0: first three convolutions followed by a max pool.
        blocks = [
            nn.Sequential(
                inception_model.Conv2d_1a_3x3,
                inception_model.Conv2d_2a_3x3,
                inception_model.Conv2d_2b_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            )
        ]

        # Block 1: next two convolutions followed by a max pool.
        if self.final_block_index >= 1:
            blocks.append(nn.Sequential(
                inception_model.Conv2d_3b_1x1,
                inception_model.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ))

        # Block 2: Mixed_5b through Mixed_6e.
        if self.final_block_index >= 2:
            blocks.append(nn.Sequential(
                inception_model.Mixed_5b,
                inception_model.Mixed_5c,
                inception_model.Mixed_5d,
                inception_model.Mixed_6a,
                inception_model.Mixed_6b,
                inception_model.Mixed_6c,
                inception_model.Mixed_6d,
                inception_model.Mixed_6e,
            ))

        # Block 3: Mixed_7a through Mixed_7c, pooled down to a 1x1 spatial map.
        if self.final_block_index >= 3:
            blocks.append(nn.Sequential(
                inception_model.Mixed_7a,
                inception_model.Mixed_7b,
                inception_model.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ))
        return blocks

    def forward(self, input_images):
        """Run the blocks in order and collect the requested activations.

        Args:
            input_images: Batch of images; indexed as a 4-D tensor by the
                resize step (presumably (N, 3, H, W) -- confirm with callers).

        Returns:
            List of tensors, one per index in ``self.outputs``, in ascending
            block order.
        """
        features = []
        x = input_images

        if self.resize:
            x = func.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
        if self.normalize:
            x = 2 * x - 1

        for block_idx, block in enumerate(self.feature_blocks):
            x = block(x)
            if block_idx in self.outputs:
                features.append(x)
            # Nothing past the last requested block is needed.
            if block_idx == self.final_block_index:
                break

        return features

In [20]:
def find_image_mu_sigma(images, inception_model, use_cuda=False):
    """Compute the mean and covariance of feature activations for a batch of images.

    Args:
        images: Tensor batch passed to ``inception_model`` in a single forward pass.
        inception_model: Module whose call returns a sequence; element 0 is
            read as a 4-D activation tensor.
        use_cuda: If True, ``images`` is moved to the GPU before the forward pass.

    Returns:
        Tuple ``(mu, sigma)``: the per-feature mean over the batch and the
        feature covariance matrix (``np.cov`` with ``rowvar=False``).
    """
    inception_model.eval()

    if use_cuda:
        images = images.cuda()

    # Pure inference: disable autograd so no graph is kept alive for the batch.
    with torch.no_grad():
        prediction = inception_model(images)[0]
        # Collapse any remaining spatial extent so each sample yields one feature vector.
        if prediction.size(2) != 1 or prediction.size(3) != 1:
            prediction = func.adaptive_avg_pool2d(prediction, output_size=(1, 1))

    activations = prediction.detach().cpu().numpy().reshape(prediction.size(0), -1)
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)

def find_for_real_fake_images(real, fake, model, use_cuda=False):
    """Return activation statistics for a real batch and a fake batch.

    Both batches go through ``find_image_mu_sigma`` with the same model,
    real first and fake second.

    Returns:
        4-tuple ``(real_mean, real_covariance, fake_mean, fake_covariance)``.
    """
    real_stats = find_image_mu_sigma(real, model, use_cuda=use_cuda)
    fake_stats = find_image_mu_sigma(fake, model, use_cuda=use_cuda)
    return (*real_stats, *fake_stats)

In [21]:
def find_distance_frechet(mu1, sigma1, mu2, sigma2, epsilon=1e-6):
    """Frechet distance between two sets of (mean, covariance) statistics.

    Evaluates
    ``|mu1 - mu2|^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrtm(sigma1 . sigma2))``.

    Args:
        mu1, mu2: Mean vectors (scalars are promoted to 1-D).
        sigma1, sigma2: Covariance matrices (promoted to at least 2-D).
        epsilon: Amount added to both covariance diagonals when the matrix
            square root of their product contains non-finite values.

    Returns:
        The distance as a scalar.

    Raises:
        AssertionError: If the means or the covariances differ in shape.
        ValueError: If the square root has a diagonal imaginary part that is
            not close to zero (absolute tolerance 1e-3).
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, 'Means of real and fake distributions are different'
    assert sigma1.shape == sigma2.shape, 'Dimensions of real and fake covariances are different'

    delta = mu1 - mu2

    cov_sqrt, _ = linalg.sqrtm(np.dot(sigma1, sigma2), disp=False)
    if not np.all(np.isfinite(cov_sqrt)):
        # Retry with a small ridge on both diagonals.
        print(f"FID calculation gives singular product; Adding {epsilon} to diagonal of covariance estimates")
        ridge = np.eye(sigma1.shape[0]) * epsilon
        cov_sqrt = linalg.sqrtm(np.dot(sigma1 + ridge, sigma2 + ridge))

    if np.iscomplexobj(cov_sqrt):
        # A small imaginary part is discarded; a large one on the diagonal is an error.
        diagonal_is_real = np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3)
        if not diagonal_is_real:
            max_imag = np.max(np.abs(cov_sqrt.imag))
            raise ValueError(f'Imaginary component {max_imag}')
        cov_sqrt = cov_sqrt.real

    sqrt_trace = np.trace(cov_sqrt)

    return (delta.dot(delta) + np.trace(sigma1) + np.trace(sigma2) - 2 * sqrt_trace)

In [22]:
def calculate_fid_score(real_images, fake_images, use_cuda=None):
    """Compute the FID score between a batch of real and a batch of fake images.

    A fresh ``InceptionV3FeatureExtractor`` limited to the block registered
    for 2048-dimensional features is built on every call.

    Args:
        real_images: Tensor batch of real images.
        fake_images: Tensor batch of generated images.
        use_cuda: Whether to run the extractor on the GPU. ``None`` (default)
            uses CUDA when ``torch.cuda.is_available()`` and falls back to
            the CPU otherwise, instead of failing on CPU-only hosts.

    Returns:
        The Frechet distance between the activation statistics of both batches.
    """
    if use_cuda is None:
        use_cuda = torch.cuda.is_available()

    inception_block_idx = InceptionV3FeatureExtractor.FEATURE_DIM_TO_BLOCK[2048]
    inception_model = InceptionV3FeatureExtractor([inception_block_idx])
    if use_cuda:
        inception_model = inception_model.cuda()

    r_mu, r_sigma, f_mu, f_sigma = find_for_real_fake_images(
        real_images, fake_images, inception_model, use_cuda=use_cuda)
    fid_score = find_distance_frechet(r_mu, r_sigma, f_mu, f_sigma)
    return fid_score