-
Notifications
You must be signed in to change notification settings - Fork 3
/
fad_score.py
81 lines (71 loc) · 3.68 KB
/
fad_score.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_fad_score.ipynb.
# %% auto 0
# Names exported by `from <this module> import *`.
__all__ = ['read_embeddings', 'calc_mu_sigma', 'calc_score', 'main']
# %% ../nbs/03_fad_score.ipynb 5
import torch
import argparse
from .sqrtm import sqrtm
from aeiou.core import fast_scandir
# %% ../nbs/03_fad_score.ipynb 6
def read_embeddings(emb_path='real_emb_clap/', debug=False):
    "reads any .pt files in emb_path and concatenates them into one tensor"
    # Each .pt file is loaded onto the CPU and the loaded batches are joined along dim 0,
    # so every file must hold a tensor with matching trailing dimensions.
    if debug: print("searching in ",emb_path)
    _, file_list = fast_scandir(emb_path, ['pt'])
    fallback_path = '/fsx/shawley/code/fad_pytorch/' + emb_path
    if not file_list:
        # HACK (kept for backward compatibility): retry under the author's checkout so that
        # relative paths used from the nbs/ directory still resolve. This is plain string
        # concatenation, so it is only meaningful when emb_path is relative.
        _, file_list = fast_scandir(fallback_path, ['pt'])
    if not file_list:
        # Explicit error instead of `assert`, which is stripped under `python -O` and would
        # otherwise surface as an opaque torch.cat() error on an empty list.
        raise FileNotFoundError(
            f"No .pt embedding files found in {emb_path!r} (also tried {fallback_path!r})")
    # NOTE(security): torch.load unpickles its input -- only point this at trusted files.
    embeddings = [torch.load(file_path, map_location='cpu') for file_path in file_list]
    return torch.cat(embeddings, dim=0)
# %% ../nbs/03_fad_score.ipynb 8
def calc_mu_sigma(emb):
    "returns (mean vector, covariance matrix) of embeddings stacked along dim 0"
    # Rows of `emb` are observations. torch.cov expects variables as rows and
    # observations as columns, hence the transpose. It uses the unbiased
    # estimator by default (correction=1).
    per_feature_mean = emb.mean(dim=0)
    observations_as_columns = emb.T
    covariance = torch.cov(observations_as_columns)
    return per_feature_mean, covariance
# %% ../nbs/03_fad_score.ipynb 10
def calc_score(real_emb_path, # where real embeddings are stored
               fake_emb_path, # where fake embeddings are stored
               method='maji', # sqrtm calc method: 'maji'|'li'
               debug=False
              ):
    "computes the FAD score between the embeddings stored in real_emb_path and fake_emb_path"
    # score = |mu_r - mu_f|^2 + Tr(sigma_r) + Tr(sigma_f) - 2*Tr(Re(sqrtm(sigma_r @ sigma_f)))
    # Returns a 0-dim tensor (main() converts it with .cpu().numpy()).
    print(f"Calculating FAD score for files in {real_emb_path}/ vs. {fake_emb_path}/")
    emb_real = read_embeddings(emb_path=real_emb_path, debug=debug)
    emb_fake = read_embeddings(emb_path=fake_emb_path, debug=debug)
    if debug: print(emb_real.shape, emb_fake.shape)

    mu_real, sigma_real = calc_mu_sigma(emb_real)
    mu_fake, sigma_fake = calc_mu_sigma(emb_fake)
    if debug:
        print("mu_real.shape, sigma_real.shape =",mu_real.shape, sigma_real.shape)
        print("mu_fake.shape, sigma_fake.shape =",mu_fake.shape, sigma_fake.shape)

    # Compute every term exactly once and reuse it for both the debug output and the
    # result. The matrix square root is the expensive step. Previously, debug mode
    # evaluated it three times, once without honoring `method`.
    mu_diff = mu_real - mu_fake
    mean_term = mu_diff.dot(mu_diff)
    trace_real = torch.trace(sigma_real)
    trace_fake = torch.trace(sigma_fake)
    covmean = sqrtm(torch.matmul(sigma_real, sigma_fake), method=method)
    # torch.real discards any imaginary part the square-root routine may return.
    cross_term = torch.trace(torch.real(covmean))

    if debug:
        print("mu_diff = ",mu_diff)
        print("score1: mu_diff.dot(mu_diff) = ",mean_term)
        print("score2: torch.trace(sigma_real) = ", trace_real)
        print("score3: torch.trace(sigma_fake) = ",trace_fake)
        print("score_p.shape (matmul) = ",covmean.shape)
        print("score4 (-2*tr(sqrtm(matmul(sigma_r sigma_f)))) = ",-2*cross_term)

    return mean_term + trace_real + trace_fake - 2*cross_term
# %% ../nbs/03_fad_score.ipynb 16
def main():
    "command-line entry point: parse arguments, compute the FAD score and print it"
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # nargs='?' is what makes the declared defaults take effect. Without it argparse treats
    # a positional as required and never uses `default=`, even though the help text
    # (ArgumentDefaultsHelpFormatter) advertises it.
    parser.add_argument('real_emb_path', nargs='?', help='Path of files of embeddings of real data', default='real_emb_clap/')
    parser.add_argument('fake_emb_path', nargs='?', help='Path of files of embeddings of fake data', default='fake_emb_clap/')
    parser.add_argument('-d','--debug', action='store_true', help='Enable debugging')
    parser.add_argument('-m','--method', default='maji', help='Method for sqrtm calculation: "maji" or "li" ')
    args = parser.parse_args()
    score = calc_score( args.real_emb_path, args.fake_emb_path, method=args.method, debug=args.debug )
    print("FAD score = ",score.cpu().numpy())
# %% ../nbs/03_fad_score.ipynb 17
# Run the CLI only when this file is executed as a script. The "get_ipython" check skips the
# call whenever that name is visible in this namespace -- presumably so that evaluating this
# code inside an IPython/Jupyter session does not trigger argument parsing (confirm intent).
if __name__ == '__main__' and "get_ipython" not in dir():
    main()