# **Implementation of Frechet Inception Distance (FID) and Kernel Inception Distance (KID)**

In [1]:
import pickle,gzip,math,os,time,shutil,torch,random
import fastcore.all as fc
import matplotlib as mpl, matplotlib.pyplot as plt
import numpy as np
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from copy import copy
from contextlib import contextmanager
from scipy import linalg

from fastcore.foundation import L
import torchvision.transforms.functional as TF, torch.nn.functional as F
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torch.nn import init
from torch.optim import lr_scheduler
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder

from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.init import *
from miniai.sgd import *
from miniai.resnet import *
from miniai.augment import *
from miniai.accel import *

In [2]:
from fastcore.test import test_close
from torch import distributions

torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'

import logging
logging.disable(logging.WARNING)

set_seed(42)
if fc.defaults.cpus>8: fc.defaults.cpus=8

## **Evaluating Generative Performance**

We're reaching a point where our generated images are getting good enough to actually bias our own perception of what a _good image_ actually looks like. We need to be able to measure the quality of these generated images against a real life benchmark. The benchmark is an actual human being's comparison of generated images and a set of real images using the good old "mark-one eyeball" approach.

While the research community is constantly finding new ways to mathematically measure the difference between the two, the most commonly used metric is called the [Frechet Inception Distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_inception_distance).

> The Frechet Inception Distance score, or FID for short, is a metric that calculates the distance between feature vectors calculated for real and generated images.
>
>The score summarizes how similar the two groups are in terms of statistics on computer vision features of the raw images calculated using the inception v3 model used for image classification. Lower scores indicate the two groups of images are more similar, or have more similar statistics, with a perfect score being 0.0 indicating that the two groups of images are identical.
>
>The FID score is used to evaluate the quality of images generated by generative adversarial networks, and lower scores have been shown to correlate well with higher quality images.

## **Setup Classifier**

To implement FID, we will use our existing model `fashion_ddpm_mp` which was trained using mixed precision. Here we will refer to this model as the `sample model` or `smodel`.

In [7]:
# Dataset column names: `xl` is the image column, `yl` the label column.
xl = 'image'
yl = 'label'
# Dataset identifier passed to `load_dataset`, and the batch size.
name = "fashion_mnist"
bs = 512

# The `*2-1` rescales each image tensor to the (-1, 1) range (see NB 14, section
# AUGMENT 2), replacing the earlier subtract-mean / divide-by-std normalisation.
@inplace
def transformi(b):
    """Convert the `xl` column of batch `b` to padded, rescaled tensors in place.

    Each image is turned into a tensor, padded by 2 pixels on every side of its
    last two dimensions, and then mapped through `x*2-1`.
    """
    padded = (F.pad(TF.to_tensor(img), (2,2,2,2)) for img in b[xl])
    b[xl] = [p*2-1 for p in padded]

# Load the dataset named by `name` via the Hugging Face `datasets` library.
dsd = load_dataset(name)
# Attach `transformi` so it is applied on the fly whenever items are accessed.
tds = dsd.with_transform(transformi)
# Wrap the transformed dataset dict in miniai DataLoaders with batch size `bs`
# (presumably one loader per split — confirm in miniai.datasets).
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)

In [6]:
# Pull a single training batch; keep it whole as `b` and unpacked as inputs/targets.
b = next(iter(dls.train))
xb, yb = b

In [None]:
# Callbacks for device placement and mixed precision (both provided by miniai).
cbs = [DeviceCB(), MixedPrecision()]
# NOTE(review): torch.load unpickles arbitrary Python objects — only load trusted files.
model = torch.load('models/data_aug2.pkl')
# No optimizer (opt_func=None): below this learner is only run with train=False
# and through capture_preds.
learn = Learner(model, dls, F.cross_entropy, cbs=cbs, opt_func=None)

In [None]:
def append_outp(hook, mode, inp, outp):
    """Hook function: accumulate each module output (moved via `to_cpu`) on `hook.outp`.

    The list is created lazily on the first call; `mode` and `inp` are unused.
    """
    try:
        collected = hook.outp
    except AttributeError:
        collected = hook.outp = []
    collected.append(to_cpu(outp))

In [None]:
# Hook the module at index 6 of the classifier so `append_outp` records its outputs.
# `on_valid=True` presumably keeps the hook active during validation — confirm in miniai.
hcb = HooksCallback(append_outp, mods=[learn.model[6]], on_valid=True)

In [None]:
# One pass with training disabled (train=False), run only so the hook can collect activations.
learn.fit(1, train=False, cbs=[hcb])

In [None]:
# First recorded output of the hooked module, cast to float32, first 64 rows only.
feats = hcb.hooks[0].outp[0].float()[:64]
feats.shape

In [None]:
# Drop modules 8 and 7 from the classifier, higher index first so the second
# index is not shifted by the first deletion.
del learn.model[8]
del learn.model[7]

In [None]:
# Collect the outputs of the shortened model together with the targets.
# NOTE(review): which split capture_preds iterates is defined in miniai — confirm.
feats, y = learn.capture_preds()
# Cast to float32 for the statistics computed later.
feats = feats.float()
feats.shape, y

## **Calculate FID**

In [None]:
# Linear beta schedule over `n_steps` diffusion steps.
n_steps = 1000
beta_min, beta_max = 0.0001, 0.02
beta = torch.linspace(beta_min, beta_max, n_steps)
# alpha_t = 1 - beta_t, and alphabar_t is the running product of alpha up to step t.
alpha = 1. - beta
alphabar = torch.cumprod(alpha, dim=0)
# Per-step noise scale used by the sampler.
sigma = torch.sqrt(beta)

In [None]:
# Carrying over the noisify function from before
def noisify(x0, ᾱ):
    """Add noise to a batch of images at randomly chosen timesteps.

    Args:
        x0: batch of clean images; indexed as (batch, ...) and broadcast against
            a (batch, 1, 1, 1) coefficient, so it is expected to be 4-D.
        ᾱ: 1-D tensor of cumulative alpha products, one entry per timestep.

    Returns:
        `((xt, t), ε)` where `xt` is the noised batch, `t` the sampled timestep
        per image (long tensor on `x0`'s device) and `ε` the noise that was added.
    """
    device = x0.device
    n = len(x0)
    # Sample timesteps from the schedule that was actually passed in rather than
    # from a global step count, so the indices can never fall outside `ᾱ` and a
    # longer schedule is sampled over its full range.
    t = torch.randint(0, len(ᾱ), (n,), dtype=torch.long)
    ε = torch.randn(x0.shape, device=device)
    # One coefficient per image, shaped to broadcast over channel/height/width.
    ᾱ_t = ᾱ[t].reshape(-1, 1, 1, 1).to(device)
    xt = ᾱ_t.sqrt()*x0 + (1-ᾱ_t).sqrt()*ε
    return (xt, t.to(device)), ε

def collate_ddpm(b):
    """Collate raw items, take the image column `xl`, and noise it with `alphabar`."""
    images = default_collate(b)[xl]
    return noisify(images, alphabar)

def dl_ddpm(ds):
    """Build a DataLoader over `ds` whose batches are produced by `collate_ddpm`."""
    return DataLoader(
        dataset=ds,
        batch_size=bs,
        num_workers=4,
        collate_fn=collate_ddpm,
    )

In [None]:
# DataLoaders over the train and test splits whose batches are noised by `collate_ddpm`.
dls2 = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))

In [None]:
from diffusers import UNet2DModel

class UNet(UNet2DModel):
    """`UNet2DModel` variant that takes its inputs as one tuple and returns the raw tensor."""

    def forward(self, x):
        # `x` is unpacked into the parent's positional arguments; only the
        # `.sample` attribute of the parent's output is returned.
        output = super().forward(*x)
        return output.sample

In [None]:
# Load the pickled sample (diffusion) model and move it to the GPU.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files;
# `.cuda()` requires a CUDA device to be available.
smodel = torch.load('models/fashion_ddpm_mp.pkl').cuda()

In [None]:
@torch.no_grad()
def sample(model, sz, alpha, alphabar, sigma, n_steps):
    """Run the reverse diffusion loop and return the x_0 estimate from every step.

    Args:
        model: noise-prediction network, called as `model((x_t, t_batch))`.
        sz: shape of the batch to generate.
        alpha, alphabar, sigma: per-timestep schedule tensors, indexed by step.
        n_steps: number of reverse steps; iterated from `n_steps-1` down to 0.

    Returns:
        List of CPU tensors, one x_0 estimate per step, in the order the steps
        were executed (so the last entry comes from step 0).
    """
    device = next(model.parameters()).device
    x_t = torch.randn(sz, device=device)
    preds = []
    for t in range(n_steps - 1, -1, -1):
        final_step = t == 0
        t_batch = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
        # No fresh noise is injected on the final step.
        if final_step:
            z = torch.zeros(x_t.shape)
        else:
            z = torch.randn(x_t.shape)
        z = z.to(device)
        abar_t = alphabar[t]
        # The cumulative product "before" step 0 is defined as 1.
        abar_prev = torch.tensor(1) if final_step else alphabar[t-1]
        bbar_t = 1 - abar_t
        bbar_prev = 1 - abar_prev
        noise_pred = model((x_t, t_batch))
        # Estimate of the clean image implied by the predicted noise.
        x_0_hat = ((x_t - bbar_t.sqrt() * noise_pred)/abar_t.sqrt())
        from_x0 = x_0_hat * abar_prev.sqrt()*(1-alpha[t])/bbar_t
        from_xt = x_t * alpha[t].sqrt()*bbar_prev/bbar_t
        x_t = from_x0 + from_xt + sigma[t]*z
        preds.append(x_0_hat.cpu())
    return preds

In [None]:
%%time
# Generate a batch of 256 single-channel 32x32 images; `samples` is the list of
# x_0 estimates returned by `sample`, one per reverse step.
samples = sample(smodel, (256, 1, 32, 32), alpha, alphabar, sigma, n_steps)

In [None]:
# Take the estimate from the final step (t=0) and rescale it with x*2-1.
# NOTE(review): this presumably maps a [0, 1] model output into the classifier's
# [-1, 1] range, but `transformi` in this notebook already yields [-1, 1] data —
# confirm which range `fashion_ddpm_mp` was trained on.
s = samples[-1]*2-1

In [None]:
# Display the first 16 rescaled samples.
show_images(s[:16], imsize=1.5)

Taking 4 of the previously trained model's generated samples, we look at some statistics of select activations. Recall that we created `summary()` for the `TrainLearner`, which returns various output shapes of each layer of our model.

![title](imgs/table_summary.png)

We will take our samples and run them through the model that was pretrained to predict fashion classes. Afterwards, the output of the layer titled `GlobalAvgPool` will be extracted, and the mean of each channel across the batch will be calculated.

Channels would contain different feature characteristics from the fashion dataset, so the mean would be representative of the distributions of these characteristics.

In [None]:
# Create a new DataLoader which contains no training batches. It does contain one
# validation batch with the samples from above.
# NOTE(review): `yb` comes from the earlier training batch and only fills the target
# slot; the loss is `fc.noop`, so it presumably never affects the features — confirm.
clearn = TrainLearner(model, DataLoaders([], [(s, yb)]), loss_func=fc.noop, cbs=[DeviceCB()], opt_func=None)
# Classifier outputs for the generated samples.
feats2, y2 = clearn.capture_preds()
# Cast to float32 and drop any size-1 dimensions.
feats2 = feats2.float().squeeze()
feats2.shape

In [None]:
# Mean of `feats` over dim 0, i.e. one mean per feature column. Note this uses
# `feats` (captured from the dataset above), not `feats2` from the generated samples.
means = feats.mean(0)
means.shape

In [None]:
# Covariance between feature columns: torch's `cov` treats rows as variables and
# columns as observations, hence the transpose (assumes `feats` is 2-D — TODO confirm).
covs = feats.T.cov()
covs.shape