Skip to content

Commit

Permalink
make pytorch-fid package optional
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 17, 2020
1 parent 5d9691f commit 25ccfcf
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 4 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,14 @@ Notes:

Thanks to <a href="https://github.com/GetsEclectic">GetsEclectic</a>, you can now calculate the FID score periodically! Again, made super simple with one extra argument, as shown below.

Firstly, install the `pytorch_fid` package

```bash
$ pip install pytorch-fid
```

Followed by

```bash
$ stylegan2_pytorch --data ./data --calculate-fid-every 5000
```
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
'torch',
'torchvision',
'pillow',
'pytorch-fid',
'vector-quantize-pytorch>=0.1.0'
],
classifiers=[
Expand Down
3 changes: 1 addition & 2 deletions stylegan2_pytorch/stylegan2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
from stylegan2_pytorch.version import __version__
from stylegan2_pytorch.diff_augment import DiffAugment

from pytorch_fid import fid_score

from vector_quantize_pytorch import VectorQuantize
from linear_attention_transformer import ImageLinearAttention

Expand Down Expand Up @@ -1102,6 +1100,7 @@ def tile(a, dim, n_tile):

@torch.no_grad()
def calculate_fid(self, num_batches):
from pytorch_fid import fid_score
torch.cuda.empty_cache()

real_path = str(self.results_dir / self.name / 'fid_real') + '/'
Expand Down
2 changes: 1 addition & 1 deletion stylegan2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.5.4'
__version__ = '1.5.5'

0 comments on commit 25ccfcf

Please sign in to comment.