In [1]:
#| default_exp fid

One challenge when experimenting is that the generated images can look good, which makes it easy to convince ourselves that we're improving.
But no metric can tell us whether these generated images would look to a human like pictures of clothes; only a person can judge that.
There are some useful metrics that give an approximation, but they are not a replacement for humans.
We will look at *FID*, the most common metric, and another metric called *KID*.

# FID

In [2]:
#|export
import pickle,gzip,math,os,time,shutil,torch,random
import fastcore.all as fc,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from copy import copy
from contextlib import contextmanager
from scipy import linalg

from fastcore.foundation import L
import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torch.nn import init
from torch.optim import lr_scheduler
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder

from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.init import *
from miniai.sgd import *
from miniai.resnet import *
from miniai.augment import *
from miniai.accel import *

In [3]:
from fastcore.test import test_close
from torch import distributions

torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'

import logging
logging.disable(logging.WARNING)

set_seed(42)
if fc.defaults.cpus>8: fc.defaults.cpus=8

## Classifier

In [4]:
xl,yl = 'image','label'
name = "fashion_mnist"
bs = 512

@inplace
def transformi(b): b[xl] = [F.pad(TF.to_tensor(o), (2,2,2,2))*2-1 for o in b[xl]]

dsd = load_dataset(name)
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)

In [5]:
b = xb,yb = next(iter(dls.train))

We demonstrate FID and KID using a model saved in notebook 17_, and compute its FID.
In notebook 14 we created a `summary()` that shows the output shape of each block of the model.
The shapes change from block to block: e.g., the first block outputs a batch of 1024 with 16 channels of 28x28, the next one 32 channels of 14x14, and so on.
Just before the final linear layer we have a batch of 1024 with 512 channels and no height or width.
The idea of FID and KID is that the distribution of these 512 channels for real images has a particular signature, i.e., it looks a particular way.
We take our samples, run them through a model that has learned to predict fashion classes, and grab the output of the global average pooling layer.
Averaging that across a batch gives 512 numbers, the mean of each of those channels.
Those channels might represent "features", for example "does it have a pointed collar", "does it have smooth fabric", "does it have sharp heels", etc.
We could recognize that something is probably not a normal fashion image if it has sharp heels and flowing fabric: some combinations of means of these activations don't make sense.

This is not a metric for an individual image, but for a set of images.
We generate fashion images and ask: do they look like a set of real fashion images?
For example, if X% of real images have one feature and Y% have another, roughly the same fractions of our generated images should too.
Looking at those means amounts to comparing the distributions: do the generated images, taken together, have roughly the same amount of each feature as the real ones?
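To make this concrete, here is a minimal numpy sketch using synthetic random "features" in place of the 512 pooled activations (the array sizes and the shift of 0.5 are assumptions chosen for illustration only). No single image gets a score; the set whose per-channel means drift away from the real set's is the one that stands out.

```python
import numpy as np

rng = np.random.default_rng(0)
n_imgs, n_feats = 1000, 512  # hypothetical: 1000 images, 512 pooled channels each

# Synthetic stand-ins for pooled activations: one row per image, one column per channel.
real_feats = rng.normal(0.0, 1.0, size=(n_imgs, n_feats))
good_gen = rng.normal(0.0, 1.0, size=(n_imgs, n_feats))  # same distribution as the real set
bad_gen = rng.normal(0.0, 1.0, size=(n_imgs, n_feats))
bad_gen[:, :50] += 0.5  # pretend 50 "features" fire too strongly in this generated set

def mean_gap(a, b):
    "Sum of squared differences between the per-channel means of two feature sets."
    return float(((a.mean(0) - b.mean(0)) ** 2).sum())

print(mean_gap(real_feats, good_gen))  # small: the channel means roughly match
print(mean_gap(real_feats, bad_gen))   # much larger: some channel means are off
```

This mean comparison is exactly the first term of FID; the covariance terms add a comparison of how the features vary together.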

Let's start at that level, which is `feats.mean()`.
We take our samples and pass them through a pre-trained model that has learned to predict what type of fashion item something is.
We trained some of those in notebook 14_; specifically, the 20-epoch model in the data augmentation section reached 94.3% accuracy.
If we pass our samples through this model we expect to get some useful features.
There is a complication: that model was trained on data normalized by subtracting the mean and dividing by the standard deviation, which is not what our samples look like, since most diffusion models produce samples between -1 and 1.
So JH added a new section at the bottom of notebook 14_ that replaces the transform with one mapping inputs to the range -1 to 1, creates the data loaders, trains a fashion classifier on them, and saves it as `data_aug2.pkl`.
It is the same as before, except that this fashion classifier expects inputs between -1 and 1.
BUT our image samples are NOT between -1 and 1.
In notebook 17_ (DDPM v2) we used `TF.to_tensor()`, which produces images between 0 and 1.
This is arguably a bug, and we'll look at fixing it in a moment; for now we just want the FID of our existing model.
So we take the output of our model, multiply it by 2 so that it is between 0 and 2, and subtract 1, which puts our samples between -1 and 1.
Now we can pass them through our pre-trained fashion classifier.
But how do we get the output of the pooling layer, which is what we want?
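The rescaling is just an affine map. A tiny numpy sketch (the helper name `to_minus1_1` is hypothetical, not part of miniai) checks that the endpoints land where the classifier expects them:

```python
import numpy as np

def to_minus1_1(x):
    "Map values in [0, 1] (e.g. images from TF.to_tensor) to [-1, 1] via x*2 - 1."
    return x * 2 - 1

samples_01 = np.array([0.0, 0.25, 0.5, 1.0])  # hypothetical pixel values in [0, 1]
print(to_minus1_1(samples_01))  # 0 maps to -1, 0.5 to 0, and 1 to 1
```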

To flex our PyTorch muscles, let's look at a couple of ways to do it.
First, let's load the `data_aug2.pkl` model.

In [6]:
cbs = [DeviceCB(), MixedPrecision()]
model = torch.load('models/data_aug2.pkl')
learn = Learner(model, dls, F.cross_entropy, cbs=cbs, opt_func=None)

One way is to use a hook. We write a function `append_outp` that only looks at the module's output, appending it (moved to the CPU) to a list on the hook, and use it with `HooksCallback`.

In [None]:
def append_outp(hook, mod, inp, outp):
    if not hasattr(hook,'outp'): hook.outp = []
    hook.outp.append(to_cpu(outp))

As the model is an `nn.Sequential`, we can index into its layers to pick the one we want, `learn.model[6]`.

In [None]:
hcb = HooksCallback(append_outp, mods=[learn.model[6]], on_valid=True)

Once we've created that hook callback, we can pass it to `fit()`.
Calling `fit()` with `train=False` is a bit "weird": we just want to run the validation data through the model once and collect the outputs in our hook's `outp`.

In [None]:
learn.fit(1, train=False, cbs=[hcb])

We can then grab a few of them (64) from the first batch to have a look: it is a 64 by 512 set of features.

In [None]:
feats = hcb.hooks[0].outp[0].float()[:64]
feats.shape

Here is another way to do it.
`nn.Sequential` modules are Python collections that support the collection API you'd expect.
In particular, just as with a list, we can use Python's `del` statement to remove items.
We can delete the last 2 layers and be left with just the others, which means we can simply call `capture_preds`.
We delete layers 8 and 7, then call `capture_preds`, which gives us the features for all 10,000 images in the test set.
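Why delete layer 8 before layer 7? A plain Python list, standing in for the `nn.Sequential` (the layer names are made up), shows that deleting the higher index first keeps the lower index pointing at the same layer:

```python
# A plain list stands in for learn.model; the names are hypothetical.
layers = [f"layer{i}" for i in range(9)]  # indices 0..8

# Highest index first: removing index 8 leaves indices 0..7 unchanged,
# so index 7 still refers to the layer we meant.
del layers[8]
del layers[7]
print(layers)  # layer0 ... layer6

# The other order removes the wrong layer and then fails:
wrong = [f"layer{i}" for i in range(9)]
del wrong[7]  # the old layer8 shifts down to index 7
try:
    del wrong[8]  # there is no index 8 any more
except IndexError as e:
    print("IndexError:", e)
```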

In [None]:
del(learn.model[8])
del(learn.model[7])

In [None]:
feats,y = learn.capture_preds()
feats = feats.float()
feats.shape,y

## Calc FID

In [None]:
betamin,betamax,n_steps = 0.0001,0.02,1000
beta = torch.linspace(betamin, betamax, n_steps)
alpha = 1.-beta
alphabar = alpha.cumprod(dim=0)
sigma = beta.sqrt()

In [None]:
def noisify(x0, ᾱ):
    device = x0.device
    n = len(x0)
    t = torch.randint(0, n_steps, (n,), dtype=torch.long)
    ε = torch.randn(x0.shape, device=device)
    ᾱ_t = ᾱ[t].reshape(-1, 1, 1, 1).to(device)
    xt = ᾱ_t.sqrt()*x0 + (1-ᾱ_t).sqrt()*ε
    return (xt, t.to(device)), ε

def collate_ddpm(b): return noisify(default_collate(b)[xl], alphabar)
def dl_ddpm(ds): return DataLoader(ds, batch_size=bs, collate_fn=collate_ddpm, num_workers=4)

In [None]:
dls2 = DataLoaders(dl_ddpm(tds['train']), dl_ddpm(tds['test']))

In [None]:
from diffusers import UNet2DModel

class UNet(UNet2DModel):
    def forward(self, x): return super().forward(*x).sample

We demonstrate this using the model we modified in the last lesson, DDPM v2 in notebook 17_, which we trained with mixed precision and saved as `fashion_ddpm_mp.pkl`.
The goal is to get the FID for a model we've already trained.
We load the trained sampling model, `smodel`, with `torch.load`, and move it to the GPU with `cuda()`.

In [None]:
smodel = torch.load('models/fashion_ddpm_mp.pkl').cuda()

`sample()` is a copy of the DDPM sampler from last time.
We're going to sample from that model and try to calculate the FID score, which indicates how similar the samples are to real images.
To do so, we look at some statistics of some of the activations for the sampled images.

In [None]:
@torch.no_grad()
def sample(model, sz, alpha, alphabar, sigma, n_steps):
    device = next(model.parameters()).device
    x_t = torch.randn(sz, device=device)
    preds = []
    for t in reversed(range(n_steps)):
        t_batch = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
        z = (torch.randn(x_t.shape) if t > 0 else torch.zeros(x_t.shape)).to(device)
        ᾱ_t1 = alphabar[t-1]  if t > 0 else torch.tensor(1)
        b̄_t = 1 - alphabar[t]
        b̄_t1 = 1 - ᾱ_t1
        x_0_hat = ((x_t - b̄_t.sqrt() * model((x_t, t_batch)))/alphabar[t].sqrt())
        x_t = x_0_hat * ᾱ_t1.sqrt()*(1-alpha[t])/b̄_t + x_t * alpha[t].sqrt()*b̄_t1/b̄_t + sigma[t]*z
        preds.append(x_0_hat.cpu())
    return preds

In [None]:
%%time
samples = sample(smodel, (256, 1, 32, 32), alpha, alphabar, sigma, n_steps)

In [None]:
s = samples[-1]*2-1

In [None]:
show_images(s[:16], imsize=1.5)

We create new data loaders with no training batches and a single validation batch containing the samples.
The dependent variable doesn't matter, so we just reuse the one we already had, `yb`.
We use that to extract features from the classifier.

In [None]:
clearn = TrainLearner(model, DataLoaders([],[(s,yb)]), loss_func=fc.noop, cbs=[DeviceCB()], opt_func=None)
feats2,y2 = clearn.capture_preds()
feats2 = feats2.float().squeeze()
feats2.shape

In [None]:
means = feats.mean(0)
means.shape

In [None]:
covs = feats.T.cov()
covs.shape

### Comparing Matrices

There are ways of comparing two sets of data to evaluate whether they come from the same distribution.
The FID metric checks whether they have similar covariance matrices and similar mean vectors.
`_calc_stats` returns the means and the covariance matrix.
We call `_calc_stats` for both the sample features and the real (test) dataset features.
Given those statistics, `s1,s2`, we can calculate the "Fréchet Inception Distance" (FID).
Since the calculation multiplies the two covariance matrices together, the result is on a larger scale, so we need to scale it back down.
With scalars, after multiplying two of them we take the square root to get back to the original scale.
To "renormalize" these matrices we need the matrix square root.
Here we are going to cheat slightly, much as we have been doing with the floating-point square root from Python's standard library, which we never implemented ourselves.
The classic way to calculate a square root is Newton's method: solve a*a = x by repeatedly taking a step along the derivative.
We can calculate the matrix square root with a Newton method too, but for matrices it is slightly more complicated; see `_sqrtm_newton_schulz()`.
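For the scalar case, Newton's method on $f(a) = a^2 - x$ gives the update $a \leftarrow a - f(a)/f'(a) = (a + x/a)/2$. A minimal pure-Python sketch (the name `newton_sqrt`, the starting guess and the tolerance are our own choices, and it assumes $x > 0$):

```python
import math

def newton_sqrt(x, num_iters=50, tol=1e-12):
    "Square root of x > 0 by Newton's method on f(a) = a*a - x."
    a = x if x > 1 else 1.0  # any positive starting guess converges
    for _ in range(num_iters):
        # Newton step: a - (a*a - x)/(2a) simplifies to (a + x/a)/2
        new_a = (a + x / a) / 2
        if abs(new_a - a) <= tol * new_a: break
        a = new_a
    return new_a

print(newton_sqrt(2.0), math.sqrt(2.0))
```

Each step roughly doubles the number of correct digits, which is why only a handful of iterations are needed.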

In [None]:
#|export
def _sqrtm_newton_schulz(mat, num_iters=100):
    mat_nrm = mat.norm()
    mat = mat.double()
    Y = mat/mat_nrm
    n = len(mat)
    I = torch.eye(n, n).to(mat)
    Z = torch.eye(n, n).to(mat)

    for i in range(num_iters):
        T = (3*I - Z@Y)/2
        Y,Z = Y@T,T@Z
        res = Y*mat_nrm.sqrt()
        if ((mat-(res@res)).norm()/mat_nrm).abs()<=1e-6: break
    return res
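To check the iteration, here is a numpy port of `_sqrtm_newton_schulz` (our translation, not part of miniai) applied to a random symmetric positive-definite test matrix. In this coupled Newton-Schulz iteration `Y` converges to the square root of the normalized matrix and `Z` to its inverse square root, so squaring the result should recover the input.

```python
import numpy as np

def sqrtm_newton_schulz(mat, num_iters=100, tol=1e-6):
    "numpy port of _sqrtm_newton_schulz: returns res with res @ res close to mat."
    mat = np.asarray(mat, dtype=np.float64)
    mat_nrm = np.linalg.norm(mat)      # Frobenius norm, like torch's .norm()
    Y = mat / mat_nrm                  # normalize so the iteration converges
    I = np.eye(len(mat))
    Z = np.eye(len(mat))
    for _ in range(num_iters):
        T = (3 * I - Z @ Y) / 2
        Y, Z = Y @ T, T @ Z            # Y -> sqrt(mat/nrm), Z -> its inverse
        res = Y * np.sqrt(mat_nrm)     # undo the normalization
        if np.linalg.norm(mat - res @ res) / mat_nrm <= tol: break
    return res

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
spd = A @ A.T + 5 * np.eye(5)          # symmetric positive-definite test matrix
root = sqrtm_newton_schulz(spd)
print(np.abs(root @ root - spd).max())
```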

Having seen how it could be implemented from scratch, we actually use the one from scipy, `linalg.sqrtm()` (the Newton-Schulz version is left commented out in `_calc_fid`), to get a measure of similarity between the two covariance matrices.
The measure of similarity between the two mean vectors is just the sum of squared errors.
The "trace" is the sum of the diagonal elements: we add the traces of the two covariance matrices and subtract two times the trace of the matrix square root of their product.
The result is the Fréchet Inception Distance, a number which represents how similar a set of samples is to some real image data.
The name is a bit peculiar: the distance itself has nothing to do with Inception; it is called that because people compute the features with the famous Inception model (the ImageNet-winning model from Google).
Inception is not a good model to use for this; it just happens to be the one the original paper used, so everybody now uses it to compare results.
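Written out, with $\mu_1, \Sigma_1$ the mean vector and covariance matrix of one set of features and $\mu_2, \Sigma_2$ those of the other, the quantity computed by `_calc_fid` is:

$$\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^2 + \operatorname{Tr}(\Sigma_1) + \operatorname{Tr}(\Sigma_2) - 2\,\operatorname{Tr}\!\left((\Sigma_1 \Sigma_2)^{1/2}\right)$$

The first term is the sum of squared errors between the means; the trace terms compare the covariances, with the matrix square root bringing the product $\Sigma_1\Sigma_2$ back to the scale of a single covariance matrix. When both sets have identical statistics, $\operatorname{Tr}((\Sigma\Sigma)^{1/2}) = \operatorname{Tr}(\Sigma)$ and the FID is 0.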
 
We're going to get a more accurate metric by using a model that is good at recognizing Fashion MNIST.
It's better to use a model we have trained on our own data and know is good at it.
Strictly speaking, though, the result is then not an FID.

In [None]:
#|export
def _calc_stats(feats):
    feats = feats.squeeze()
    return feats.mean(0),feats.T.cov()

def _calc_fid(m1,c1,m2,c2):
#     csr = _sqrtm_newton_schulz(c1@c2)
    # scipy's sqrtm can return a complex array with tiny imaginary parts, so keep .real
    csr = tensor(linalg.sqrtm(c1@c2).real)
    return (((m1-m2)**2).sum() + c1.trace() + c2.trace() - 2*csr.trace()).item()
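As a cross-check that avoids an explicit matrix square root, here is a numpy sketch (our own, not miniai's; `calc_stats_np` and `calc_fid_np` are hypothetical names). Because covariance matrices are positive semi-definite, the eigenvalues of $\Sigma_1\Sigma_2$ are real and non-negative, and the trace of its square root is the sum of their square roots. `np.cov(..., rowvar=False)` plays the role of `feats.T.cov()`.

```python
import numpy as np

def calc_stats_np(feats):
    "Mean vector and covariance matrix of an (n_samples, n_features) array."
    return feats.mean(0), np.cov(feats, rowvar=False)

def calc_fid_np(m1, c1, m2, c2):
    "FID using Tr((c1 @ c2)^(1/2)) = sum of sqrt(eigenvalues of c1 @ c2)."
    eigs = np.linalg.eigvals(c1 @ c2)  # real and >= 0 up to rounding for PSD inputs
    tr_sqrt = np.sqrt(np.clip(eigs.real, 0, None)).sum()
    return float(((m1 - m2) ** 2).sum() + np.trace(c1) + np.trace(c2) - 2 * tr_sqrt)

rng = np.random.default_rng(0)
feats_a = rng.normal(size=(1000, 8))  # hypothetical 8-dimensional features
m, c = calc_stats_np(feats_a)
print(calc_fid_np(m, c, m, c))        # identical statistics: essentially 0
print(calc_fid_np(m, c, m + 1, c))    # every mean shifted by 1: adds 8, one per feature
```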

In [None]:
s1,s2 = _calc_stats(feats),_calc_stats(feats2)

In [None]:
np.isnan(s1[0].data).any(), np.isnan(s2[0].data).any(), np.isnan(s1[1].data).any(), np.isnan(s2[1].data).any()

In [None]:
_calc_fid(*s1, *s2)

There are two caveats with FID:
1) It depends on the number of samples used: it is biased, so with fewer samples the score comes out too high. Papers therefore need to report how many samples they used.
2) Because it uses the Inception network, all images are resized to 299x299, the size the Inception model was trained on.

Applying the Inception network to measure the distance means resizing images to 299x299, which may not make sense: e.g., Fashion MNIST images are 28x28, and we'd be resizing them to 299 :)
Larger images, e.g. 512x512 or 1024x1024, get shrunk to 299x299, losing a lot of detail.
This is a problem for some of the latest papers when comparing FID scores: the images are visually better, but the FID score doesn't capture that because it is computed on shrunken images.
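The first caveat, the bias, is easy to see in a toy setting. In this numpy sketch (synthetic Gaussian "features"; the dimension, sample sizes and number of trials are arbitrary choices) both sets are drawn from the same distribution, so the true FID is zero, yet the estimate is clearly positive and shrinks as the number of samples grows:

```python
import numpy as np

def fid_from_feats(a, b):
    "FID between two (n, d) feature arrays, via eigenvalues of the covariance product."
    m1, m2 = a.mean(0), b.mean(0)
    c1, c2 = np.cov(a, rowvar=False), np.cov(b, rowvar=False)
    tr_sqrt = np.sqrt(np.clip(np.linalg.eigvals(c1 @ c2).real, 0, None)).sum()
    return float(((m1 - m2) ** 2).sum() + np.trace(c1) + np.trace(c2) - 2 * tr_sqrt)

rng = np.random.default_rng(0)
d = 16  # hypothetical feature dimension
avg = {}
for n in (50, 200, 2000):
    # Both sets come from the SAME distribution, so any FID above 0 is pure estimation bias.
    scores = [fid_from_feats(rng.normal(size=(n, d)), rng.normal(size=(n, d))) for _ in range(20)]
    avg[n] = float(np.mean(scores))
    print(n, avg[n])
```

This is why FID scores are only comparable when computed with the same number of samples.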

In [None]:
#|export
def _squared_mmd(x, y):
    def k(a,b): return (a@b.transpose(-2,-1)/a.shape[-1]+1)**3
    m,n = x.shape[-2],y.shape[-2]
    kxx,kyy,kxy = k(x,x), k(y,y), k(x,y)
    kxx_sum = kxx.sum([-1,-2])-kxx.diagonal(0,-1,-2).sum(-1)
    kyy_sum = kyy.sum([-1,-2])-kyy.diagonal(0,-1,-2).sum(-1)
    kxy_sum = kxy.sum([-1,-2])
    return kxx_sum/m/(m-1) + kyy_sum/n/(n-1) - kxy_sum*2/m/n

### KID
The **KID (Kernel Inception Distance)** metric compares two distributions in a way that is not biased, so it isn't systematically higher or lower when you use fewer or more samples.
It is also simpler to calculate than FID.
We split the data into a set of groups (partitions), go through each partition, and take a few `x`'s and a few `y`'s at a time.
For each partition we calculate the **MMD** (maximum mean discrepancy), which does a matrix product, scales and shifts it, and takes the cube.
`k` is this kernel, and we apply it to the first batch of features compared with itself, the second compared with itself, and the first compared with the second.
We then normalize these in various ways, add the two self-comparisons together, and subtract (twice) the cross-comparison.
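In formula form, with a batch $x_1,\dots,x_m$ of one set of features, a batch $y_1,\dots,y_n$ of the other, and $d$ the feature dimension, `_squared_mmd` uses the cubic polynomial kernel and computes the unbiased estimate:

$$k(a,b) = \left(\frac{a^\top b}{d} + 1\right)^3$$

$$\widehat{\mathrm{MMD}}^2 = \frac{1}{m(m-1)}\sum_{i\neq j} k(x_i,x_j) + \frac{1}{n(n-1)}\sum_{i\neq j} k(y_i,y_j) - \frac{2}{mn}\sum_{i=1}^{m}\sum_{j=1}^{n} k(x_i,y_j)$$

Leaving out the $i=j$ terms (the `diagonal(...)` subtraction in the code) is what makes the estimate unbiased; `_calc_kid` then averages it over the partitions.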

In [None]:
#|export
def _calc_kid(x, y, maxs=50):
    xs,ys = x.shape[0],y.shape[0]
    n = max(math.ceil(min(xs/maxs, ys/maxs)), 4)
    mmd = 0.
    for i in range(n):
        cur_x = x[round(i*xs/n) : round((i+1)*xs/n)]
        cur_y = y[round(i*ys/n) : round((i+1)*ys/n)]
        mmd += _squared_mmd(cur_x, cur_y)
    return (mmd/n).item()

In [None]:
_calc_kid(feats, feats2)

KID does not use the stats (the means and covariance matrix); it uses the features directly.
The final result is the mean of the MMD calculated across several small batches.
It gives us a measure of the similarity of these two distributions.
JH was unsure why more people weren't using KID, since it doesn't have a bias problem.
After using it for a while, he found the reason: it has very high variance, i.e., when we call it multiple times on samples generated with different random seeds we get very different values, which makes it much less useful.
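Both properties show up in a small numpy port of `_squared_mmd` (our translation; the explicit-loop version and the synthetic Gaussian features are only for checking). The vectorized and loop forms agree, estimates for two samples from the same distribution scatter around zero (and can even be negative) rather than being pushed upward, and a shifted distribution gives a clearly positive value:

```python
import numpy as np

def squared_mmd(x, y):
    "Unbiased squared MMD with the cubic polynomial kernel, as in _squared_mmd."
    def k(a, b): return (a @ b.T / a.shape[-1] + 1) ** 3
    m, n = len(x), len(y)
    kxx, kyy, kxy = k(x, x), k(y, y), k(x, y)
    kxx_sum = kxx.sum() - np.trace(kxx)  # drop the i == j terms
    kyy_sum = kyy.sum() - np.trace(kyy)
    return kxx_sum / (m * (m - 1)) + kyy_sum / (n * (n - 1)) - 2 * kxy.sum() / (m * n)

def squared_mmd_loops(x, y):
    "The same estimate written with explicit sums, for checking."
    d = x.shape[1]
    k = lambda a, b: (a @ b / d + 1) ** 3
    m, n = len(x), len(y)
    sxx = sum(k(x[i], x[j]) for i in range(m) for j in range(m) if i != j)
    syy = sum(k(y[i], y[j]) for i in range(n) for j in range(n) if i != j)
    sxy = sum(k(x[i], y[j]) for i in range(m) for j in range(n))
    return sxx / (m * (m - 1)) + syy / (n * (n - 1)) - 2 * sxy / (m * n)

rng = np.random.default_rng(0)
x, y = rng.normal(size=(30, 16)), rng.normal(size=(40, 16))
print(squared_mmd(x, y), squared_mmd_loops(x, y))  # the two versions should agree

# Same distribution: estimates scatter around 0.
same = [squared_mmd(rng.normal(size=(50, 16)), rng.normal(size=(50, 16))) for _ in range(10)]
# Shifted distribution (every feature's mean moved by 1): clearly positive.
shifted = [squared_mmd(rng.normal(size=(50, 16)), rng.normal(1.0, 1.0, size=(50, 16))) for _ in range(10)]
print(np.mean(same), np.std(same), np.mean(shifted))
```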

We don't have a good unbiased metric.
Even if we did, it would only tell us how similar two distributions are, not whether the images look good.
That is why all good papers have a section on human testing.
Still, FID is useful for comparing fashion images.
Humans are good at looking at faces at a reasonably high resolution, but we're not good at judging 28x28 fashion images.
So these metrics are particularly helpful for things our brains aren't good at.

## FID class

We wrap this up into an `ImageEval` class.
We pass in a pre-trained classifier `model` and data loaders `dls` containing the real images, whose features we compute once.

We call `capture_preds` to get the features for the real images, and calculate their stats.
`fid` calls `_calc_fid`, passing in the `stats` for the real images and `_calc_stats` of the features from our samples, `samp`.
Those features come from `get_feats()`: we put `samp` in a single validation batch with any `y` value (a single tensor, `tensor([0])`, is fine), and call `capture_preds`.

In [None]:
#|export
class ImageEval:
    def __init__(self, model, dls, cbs=None):
        self.learn = TrainLearner(model, dls, loss_func=fc.noop, cbs=cbs, opt_func=None)
        self.feats = self.learn.capture_preds()[0].float().cpu().squeeze()
        self.stats = _calc_stats(self.feats)

    def get_feats(self, samp):
        self.learn.dls = DataLoaders([],[(samp, tensor([0]))])
        return self.learn.capture_preds()[0].float().cpu().squeeze()

    def fid(self, samp): return _calc_fid(*self.stats, *_calc_stats(self.get_feats(samp)))
    def kid(self, samp): return _calc_kid(self.feats, self.get_feats(samp))

We can now create an `ImageEval` object `ie` passing in our classifier, data loaders with the real data,
and any other callbacks we want.

In [None]:
ie = ImageEval(model, learn.dls, cbs=[DeviceCB()])

We call `ie.fid`, and 33.9 is the FID for these samples.

In [None]:
%%time
ie.fid(s)

KID is on a very different scale, e.g. only 0.05; it is generally much smaller than FID.
We are mainly going to be looking at FIDs.

In [None]:
%%time
ie.kid(s)

Here's what happens if we calculate the FID on sample 0, then sample 50, etc., all the way up to 950, and also on samples 975, 990 and 999.
Over time our samples improve, so that's a good test.
It is curious that they stop improving at some point...
JH has not seen anybody plot this graph before, and it's worth looking at because it tells us whether sampling is making consistent improvements.

To clarify, each point uses the predicted denoised image at that stage of sampling: if we stopped sampling there and jumped straight to the predicted x_0, what would the FID be?

In [None]:
xs = L.range(0,1000,50)+[975,990,999]
plt.plot(xs, [ie.fid(samples[i].clamp(-0.5,0.5)*2) for i in xs]);

The same for KID; the plot looks much the same.

In [None]:
xs = L.range(0,1000,50)+[975,990,999]
plt.plot(xs, [ie.kid(samples[i].clamp(-0.5,0.5)*2) for i in xs]);

It's a good idea to take the FID of an actual batch of real data, to tell us how good we could get.
That's a bit unfair because the sizes differ (our real batch has 512 images, our sample 256), but either way there's a pretty huge difference.

In [None]:
ie.fid(xb)

In [None]:
ie.kid(xb)

## Inception

What does it take to get a real FID with an Inception network?
We are not going to reimplement the Inception network, since it is obsolete; we just grab it from `pytorch_fid`.

In [None]:
from pytorch_fid.inception import InceptionV3

In [None]:
a = tensor([1,2,3])
a.repeat((3,1))

We create a wrapper `IncepWrap` around an Inception V3 model, with `resize_input=True` so that inputs are resized to 299 by 299.
Its `forward` takes a batch and replicates the single channel 3 times to create a 3-channel version of each black-and-white image.
TODO: Flex PyTorch muscles by getting an Inception model working on the Fashion MNIST samples.

In [None]:
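class IncepWrap(nn.Module):
    def __init__(self):
        super().__init__()
        # resize_input=True makes pytorch_fid resize inputs to 299x299 internally.
        # pytorch_fid's InceptionV3 also defaults to normalize_input=True, which expects inputs in [0, 1].
        self.m = InceptionV3(resize_input=True)
    # Replicate the single grayscale channel to 3 channels; [0] is the first requested output block.
    def forward(self, x): return self.m(x.repeat(1,3,1,1))[0]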
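Note that `Tensor.repeat` tiles the whole tensor; PyTorch's documentation compares it to `numpy.tile` rather than `numpy.repeat`. This numpy sketch (a random stand-in batch, not real samples) shows the shapes involved and that the three channels are identical copies, plus the `a.repeat((3,1))` example from earlier:

```python
import numpy as np

batch = np.random.default_rng(0).random((4, 1, 32, 32))  # hypothetical grayscale batch (N, C=1, H, W)
rgb = np.tile(batch, (1, 3, 1, 1))  # numpy analogue of torch's x.repeat(1,3,1,1)
print(batch.shape, rgb.shape)       # (4, 1, 32, 32) (4, 3, 32, 32)

a = np.array([1, 2, 3])
tiled = np.tile(a, (3, 1))          # like a.repeat((3,1)): three rows of [1, 2, 3]
print(tiled)
```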

In [None]:
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=fc.defaults.cpus)

We pass that to `ImageEval`: it gives an FID of 63.8 for our samples, and 27.9 on a real batch of data.
That is only a ratio of two to three between samples and real data, a sign that Inception is less effective here than our real Fashion MNIST classifier, with which the FID for real data was 6.6 (versus 33.9 for our samples); that is encouraging for our own classifier.
We now have a real FID; more precisely, we now have an `ImageEval` that works with either backbone.

In [None]:
ie = ImageEval(IncepWrap(), dls, cbs=[DeviceCB()])

In [None]:
%%time
ie.fid(s)

In [None]:
ie.fid(xb)

In [None]:
%%time
ie.kid(s)

In [None]:
ie.kid(xb)

J: Other reported FIDs are for CIFAR, tiny 32 by 32 pixel images resized up to 299.
FID is a slightly weird metric: if you save images as JPEGs and then load them back, the FID may be twice as bad.
The takeaway: it's useful when using the same backbone model, the same approach and the same number of samples, so that we can compare apples to apples.

For our own experiments these metrics are good, but they may not be suitable for comparing with other models; then it is best to rely on human studies.
This is useful for us all the time: we're going to use the same number of samples and the same Fashion MNIST specific classifier.

## Export -

In [None]:
import nbdev; nbdev.nbdev_export()