In [None]:
#| hide
%load_ext autoreload
%autoreload 2

# fad_score

> Compute the Fréchet Audio Distance (FAD) score from files of embeddings of real and fake data

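$$ \mathrm{FAD} = \lVert \mu_r - \mu_f \rVert^2 + \mathrm{tr}\left(\Sigma_r + \Sigma_f - 2 \sqrt{\Sigma_r \Sigma_f}\right)$$

where $\mu_r, \Sigma_r$ and $\mu_f, \Sigma_f$ are the mean and covariance matrix of the real and fake embeddings, respectively.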

The embeddings are small enough that this can typically be run in a single process on a CPU. However, all the supporting code is GPU-friendly if desired.

In [None]:
#| default_exp fad_score

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch 
import argparse
from fad_pytorch.sqrtm import sqrtm
from aeiou.core import fast_scandir

In [None]:
#| export
def read_embeddings(emb_path='real_emb_clap/', debug=False,
                    fallback_root='/fsx/shawley/code/fad_pytorch/'):
    """Reads any .pt files in emb_path and concatenates them into one tensor.

    Args:
        emb_path: directory searched (via `fast_scandir`) for files with a 'pt' extension.
        debug: if True, print the directory being searched.
        fallback_root: prefix prepended to `emb_path` for a second search when the
            first search finds nothing. The default is the original author's testing
            location (a convenience for running from the nbs/ dir); pass None or ''
            to disable the fallback.

    Returns:
        The loaded tensors concatenated along dim 0.

    Raises:
        AssertionError: if no .pt files are found.
    """
    if debug: print("searching in ",emb_path)
    _, file_list = fast_scandir(emb_path, ['pt'])
    if file_list == [] and fallback_root:
        _, file_list = fast_scandir(fallback_root + emb_path, ['pt'])
    if file_list == []:
        # Raised explicitly (instead of a bare `assert`) so the check survives `python -O`
        # and tells the user which path came up empty.
        raise AssertionError(f"No .pt embedding files found in {emb_path!r}")
    # NOTE(review): torch.load unpickles file contents; only point this at trusted files.
    embeddings = [torch.load(file_path, map_location='cpu') for file_path in file_list]
    return torch.cat(embeddings, dim=0)

In [None]:
#| eval: false
# Quick smoke test: load the embeddings from the default 'real_emb_clap/' path and inspect the resulting shape.
e = read_embeddings()
e.shape

torch.Size([256, 512])

In [None]:
#| export 
def calc_mu_sigma(emb):
    """Returns the pair (mean, covariance) for a batch of embeddings.

    The mean is taken over dim 0; the covariance comes from `torch.cov`
    applied to the transposed batch.
    """
    covariance = torch.cov(emb.T)
    mean = emb.mean(dim=0)
    return mean, covariance

In [None]:
#| eval: false
# Quick test: check the shapes of the mean and covariance for a random (32, 512) batch.
x = torch.rand(32,512) 
mu, sigma = calc_mu_sigma(x) 
mu.shape, sigma.shape 

(torch.Size([512]), torch.Size([512, 512]))

In [None]:
#| export
def calc_score(args, debug=False):
    """Computes the FAD score between real and fake embeddings.

    FAD = ||mu_r - mu_f||^2 + tr(Sigma_r) + tr(Sigma_f) - 2 tr(sqrt(Sigma_r Sigma_f))

    Args:
        args: any object exposing `real_emb_path` and `fake_emb_path` attributes
            (e.g. an argparse namespace); each is passed to `read_embeddings`.
        debug: if True, print the embedding shapes and each term of the score.

    Returns:
        The score, as returned by summing the four terms (a torch tensor).
    """
    real_emb_path, fake_emb_path = args.real_emb_path, args.fake_emb_path
    emb_real = read_embeddings(emb_path=real_emb_path, debug=debug)
    emb_fake = read_embeddings(emb_path=fake_emb_path, debug=debug)
    if debug: print(emb_real.shape, emb_fake.shape)

    mu_real, sigma_real = calc_mu_sigma(emb_real)
    mu_fake, sigma_fake = calc_mu_sigma(emb_fake)
    if debug:
        print("mu_real.shape, sigma_real.shape =",mu_real.shape, sigma_real.shape)
        print("mu_fake.shape, sigma_fake.shape =",mu_fake.shape, sigma_fake.shape)

    diff = mu_real - mu_fake
    # The matrix square root is the costly step: compute it exactly once and reuse it
    # for both the debug output and the final score.
    score_p = sqrtm( torch.matmul( sigma_real, sigma_fake) )
    score1 = diff.dot(diff)                         # squared distance between the means
    score2 = torch.trace(sigma_real)
    score3 = torch.trace(sigma_fake)
    score4 = -2* torch.trace ( torch.real ( score_p ) )  # only the real part of the sqrtm result is used
    if debug:
        print("diff = ",diff)
        print("score1 = ",score1)
        print("score2 = ", score2)
        print("score3 = ",score3)
        print("score_p.shape = ",score_p.shape)
        print("score4 = ",score4)
    return score1 + score2 + score3 + score4

In [None]:
#| eval: false
# test the score function
class DictToObject:
    "Minimal wrapper that exposes each key of a dictionary as an attribute of the instance"
    def __init__(self, dictionary):
        for name in dictionary:
            setattr(self, name, dictionary[name])
# Stand-in for parsed command-line arguments: calc_score only reads these two attributes.
args_dict = {'real_emb_path':'real_emb_clap/', 'fake_emb_path':'fake_emb_clap/'} 
score = calc_score( DictToObject(args_dict) )
score

tensor(0.1561)

In [None]:
#| export
def main():
    "Command-line entry point: parses the two embedding paths, computes the FAD score and prints it"
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # nargs='?' makes the positionals optional; without it argparse requires them,
    # so the defaults declared here would never be used.
    parser.add_argument('real_emb_path', nargs='?', help='Path of files of embeddings of real data', default='real_emb_clap/')
    parser.add_argument('fake_emb_path', nargs='?', help='Path of files of embeddings of fake data', default='fake_emb_clap/')
    args = parser.parse_args()
    # The original called an undefined `score(args)`; the scoring function is `calc_score`.
    score = calc_score( args )
    print("FAD score = ", score.item())

In [None]:
#| export
# Run main() only when executed as a script, and skip it when "get_ipython" is present
# in the module namespace (presumably an IPython/Jupyter session).
if __name__ == '__main__' and "get_ipython" not in dir():
    main()

In [None]:
#| hide
# Run nbdev's export step (presumably writes the `#| export` cells to the module named by `#| default_exp`).
import nbdev; nbdev.nbdev_export()