<a href="https://colab.research.google.com/github/ejnunn/GAN_Research/blob/main/AFID.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# !rm -r GAN_Research
!git clone --quiet https://github.com/ejnunn/GAN_Research.git
!pip install torchextractor

Collecting torchextractor
  Downloading https://files.pythonhosted.org/packages/cc/94/f14591882d0459a626d6aa8ed3699b08e6b79192c26cae87cbd6081cb835/torchextractor-0.3.0-py3-none-any.whl
Installing collected packages: torchextractor
Successfully installed torchextractor-0.3.0


In [67]:
import torch
import torchvision
import torchextractor as tx
import multiprocessing
import numpy as np
import cv2
from scipy import linalg

In [3]:
original_model = torchvision.models.inception_v3(pretrained=True)

Downloading: "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-1a9a5a14.pth


HBox(children=(FloatProgress(value=0.0, max=108857766.0), HTML(value='')))




In [77]:
images1 = torch.rand(128, 3, 299, 299)
images2 = torch.rand(128, 3, 299, 299)

fid_value1, fid_value2, fid_value3 = calculate_fid(images1, images2, use_multiprocessing=False, batch_size=1)

In [78]:
print('fid_value1 =', fid_value1)
print('fid_value2 =', fid_value2)
print('fid_value3 =', fid_value3)

fid_value1 = 29066.984375359836
fid_value2 = 16503.98046884275
fid_value3 = 17.05095610770348


# FID Functions

In [63]:
def get_activations(images, batch_size):
    """
    Calculates activations for last pool layer for all iamges
    --
        Images: torch.array shape: (N, 3, 299, 299), dtype: torch.float32
        batch size: batch size used for inception network
    --
    Returns: np array shape: (N, 2048), dtype: np.float32
    """
    assert images.shape[1:] == (3, 299, 299), "Expected input shape to be: (N,3,299,299)" +\
                                              ", but got {}".format(images.shape)

    num_images = images.shape[0]
    original_model = torchvision.models.inception_v3(pretrained=True)
    inception_network = tx.Extractor(original_model, ['maxpool1', 'maxpool2', 'avgpool'])
    inception_network = to_cuda(inception_network)
    inception_network.eval()
    n_batches = int(np.ceil(num_images  / batch_size))
    inception_activations = np.zeros((num_images, 2048), dtype=np.float32)
    for batch_idx in range(n_batches):
        start_idx = batch_size * batch_idx
        end_idx = batch_size * (batch_idx + 1)

        ims = images[start_idx:end_idx]
        ims = to_cuda(ims)
        model_output, features = inception_network(ims)
        act1, act2, act3 = features.values()

        act1 = act1.detach().cpu().numpy().flatten()
        act1 = np.expand_dims(act1, axis=0)
        act2 = act2.detach().cpu().numpy().flatten()
        act2 = np.expand_dims(act2, axis=0)
        act3 = act3.detach().cpu().numpy().squeeze()
        act3 = np.expand_dims(act3, axis=0)
        
        assert act1.shape == (ims.shape[0], 341056), "Expexted output shape to be: {}, but was: {}".format((ims.shape[0], 341056), act1.shape)
        assert act2.shape == (ims.shape[0], 235200), "Expexted output shape to be: {}, but was: {}".format((ims.shape[0], 235200), act2.shape)
        assert act3.shape == (ims.shape[0], 2048), "Expexted output shape to be: {}, but was: {}".format((ims.shape[0], 2048), act3.shape)
        
    return act1, act2, act3

In [65]:
def calculate_activation_statistics(images, batch_size):
    """Calculates the statistics used by FID
    Args:
        images: torch.tensor, shape: (N, 3, H, W), dtype: torch.float32 in range 0 - 1
        batch_size: batch size to use to calculate inception scores
    Returns:
        mu:     mean over all activations from the last pool layer of the inception model
        sigma:  covariance matrix over all activations from the last pool layer 
                of the inception model.
    """
    act1, act2, act3 = get_activations(images, batch_size)
    
    mu1 = np.mean(act1, axis=0)
    mu2 = np.mean(act2, axis=0)
    mu3 = np.mean(act3, axis=0)

    sigma1 = np.cov(act1, rowvar=False)
    sigma2 = np.cov(act2, rowvar=False)
    sigma3 = np.cov(act3, rowvar=False)
    return mu1, mu2, mu3, sigma1, sigma2, sigma3

In [61]:
# Modified from: https://github.com/bioinf-jku/TTUR/blob/master/fid.py
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
            
    Stable version by Dougal J. Sutherland.
    Params:
    -- mu1 : Numpy array containing the activations of the pool_3 layer of the
             inception net ( like returned by the function 'get_predictions')
             for generated samples.
    -- mu2   : The sample mean over activations of the pool_3 layer, precalcualted
               on an representive data set.
    -- sigma1: The covariance matrix over activations of the pool_3 layer for
               generated samples.
    -- sigma2: The covariance matrix over activations of the pool_3 layer,
               precalcualted on an representive data set.
    Returns:
    --   : The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

    diff = mu1 - mu2
    # product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
        warnings.warn(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

In [72]:
def calculate_fid(images1, images2, use_multiprocessing, batch_size):
    """ Calculate FID between images1 and images2
    Args:
        images1: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8
        images2: np.array, shape: (N, H, W, 3), dtype: np.float32 between 0-1 or np.uint8
        use_multiprocessing: If multiprocessing should be used to pre-process the images
        batch size: batch size used for inception network
    Returns:
        FID (scalar)
    """
    # images1 = preprocess_images(images1, use_multiprocessing)
    # images2 = preprocess_images(images2, use_multiprocessing)
    mu11, mu12, mu13, sigma11, sigma12, sigma13 = calculate_activation_statistics(images1, batch_size)
    mu21, mu22, mu23, sigma21, sigma22, sigma23 = calculate_activation_statistics(images2, batch_size)
    fid1 = calculate_frechet_distance(mu11, sigma11, mu21, sigma21)
    fid2 = calculate_frechet_distance(mu12, sigma12, mu22, sigma22)
    fid3 = calculate_frechet_distance(mu13, sigma13, mu23, sigma23)
    return fid1, fid2, fid3

In [10]:
def preprocess_image(im):
    """Resizes and shifts the dynamic range of image to 0-1
    Args:
        im: np.array, shape: (H, W, 3), dtype: float32 between 0-1 or np.uint8
    Return:
        im: torch.tensor, shape: (3, 299, 299), dtype: torch.float32 between 0-1
    """
    assert im.shape[0] == 3
    assert len(im.shape) == 3
    if im.dtype == np.uint8:
        im = im.astype(np.float32) / 255
    im = cv2.resize(im, (299, 299))
    im = np.rollaxis(im, axis=2)
    im = torch.from_numpy(im)
    assert im.max() <= 1.0
    assert im.min() >= 0.0
    assert im.dtype == torch.float32
    assert im.shape == (3, 299, 299)

    return im

In [11]:
def preprocess_images(images, use_multiprocessing):
    """Resizes and shifts the dynamic range of image to 0-1
    Args:
        images: np.array, shape: (N, H, W, 3), dtype: float32 between 0-1 or np.uint8
        use_multiprocessing: If multiprocessing should be used to pre-process the images
    Return:
        final_images: torch.tensor, shape: (N, 3, 299, 299), dtype: torch.float32 between 0-1
    """
    if use_multiprocessing:
        with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
            jobs = []
            for im in images:
                job = pool.apply_async(preprocess_image, (im,))
                jobs.append(job)
            final_images = torch.zeros(images.shape[0], 3, 299, 299)
            for idx, job in enumerate(jobs):
                im = job.get()
                final_images[idx] = im#job.get()
    else:
        final_images = torch.stack([preprocess_image(im) for im in images], dim=0)
    assert final_images.shape == (images.shape[0], 3, 299, 299)
    assert final_images.max() <= 1.0
    assert final_images.min() >= 0.0
    assert final_images.dtype == torch.float32
    return final_images

In [12]:
def to_cuda(elements):
    """
    Transfers elements to cuda if GPU is available
    Args:
        elements: torch.tensor or torch.nn.module
        --
    Returns:
        elements: same as input on GPU memory, if available
    """
    if torch.cuda.is_available():
        return elements.cuda()
    return elements

# Random code

In [40]:
model = tx.Extractor(original_model, ['maxpool1', 'maxpool2', 'avgpool'])
model.eval()
dummy = torch.rand(1, 3, 299, 299)
model_output, features = model(dummy)
feature_shapes = {name: f.shape for name, f in features.items()}
print(feature_shapes)

{'maxpool1': torch.Size([1, 64, 73, 73]), 'maxpool2': torch.Size([1, 192, 35, 35]), 'avgpool': torch.Size([1, 2048, 1, 1])}


In [15]:
for name, f in features.items():
  print(name, f.flatten().shape)

maxpool1 torch.Size([341056])
maxpool2 torch.Size([235200])
avgpool torch.Size([2048])


In [None]:
[x for x in tx.list_module_names(model) if x.find('pool') != -1]