Skip to content

Commit

Permalink
fix fad calculation for newer versions of scipy (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhvng committed Jan 25, 2024
1 parent cf50298 commit 99ab594
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions frechet_audio_distance/fad.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,13 +327,13 @@ def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
diff = mu1 - mu2

# Product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2).astype(complex), disp=False)
if not np.isfinite(covmean).all():
msg = ('fid calculation produces singular product; '
'adding %s to diagonal of cov estimates') % eps
print(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset).astype(complex))

# Numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy==1.23.4
numpy
torch
scipy==1.10.1
scipy
tqdm
soundfile
resampy
Expand Down

0 comments on commit 99ab594

Please sign in to comment.