In [1]:
import sys
sys.path.append('../')
from T2IBenchmark import calculate_fid

List of parameters:

In [2]:
help(calculate_fid)

Help on function calculate_fid in module T2IBenchmark.pipelines:

calculate_fid(input1: Union[str, List[str], T2IBenchmark.loaders.BaseImageLoader, T2IBenchmark.metrics.fid.FIDStats], input2: Union[str, List[str], T2IBenchmark.loaders.BaseImageLoader, T2IBenchmark.metrics.fid.FIDStats], device: torch.device = 'cuda', seed: Union[int, NoneType] = 42, batch_size: int = 128, dataloader_workers: int = 16, verbose: bool = True) -> (<class 'int'>, typing.Tuple[dict, dict])
    Calculate the Frechet Inception Distance (FID) between two sets of images.
    
    Parameters
    ----------
    input1 : Union[str, List[str], BaseImageLoader]
        The first set of images to compute the FID score for. This can either be
        a path to directory, a path to .npz file, a list of image file paths, an instance
        of BaseImageLoader or an instance of FIDStats.
    input2 : Union[str, List[str], BaseImageLoader]
        The second set of images to compute the FID score for. This can either be
  

**Advanced usage:**

In [3]:
fid, fid_data = calculate_fid(
    '../assets/images/cats/', 
    '../assets/images/dogs/',
    seed=111,
    batch_size=2,
    dataloader_workers=8,
    verbose=True
)

Processing: ImageDataset(5 items)


100%|██████████| 3/3 [00:00<00:00,  3.63it/s]


Processing: ImageDataset(5 items)


100%|██████████| 3/3 [00:00<00:00,  3.06it/s]


FID is 278.9133501791375


**Additional data**

`fid_data` contains two dictionaries, one per input, each holding the extracted InceptionV3 features and the statistics computed from them.
Index `0` refers to the first input and index `1` to the second.

Features and stats for `../assets/images/cats/`:

In [4]:
print(type(fid_data[0]['features']), type(fid_data[0]['stats']))
print('Features shape:', fid_data[0]['features'].shape) # features have shape (num_images, 2048)
print('FID stats:', fid_data[0]['stats']) 

<class 'numpy.ndarray'> <class 'T2IBenchmark.metrics.fid.FIDStats'>
Features shape: (5, 2048)
FID stats: <T2IBenchmark.metrics.fid.FIDStats object at 0x7f66acb70b50>
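For orientation, FID is conventionally computed from the mean and covariance of each feature set. The sketch below applies that standard formula to small synthetic arrays using NumPy only. It is not T2IBenchmark's implementation, and `feature_statistics` and `frechet_distance` are hypothetical names introduced for illustration.

```python
import numpy as np


def feature_statistics(features):
    """Mean vector and covariance matrix of a (num_images, dim) feature array."""
    mu = features.mean(axis=0)
    sigma = np.cov(features, rowvar=False)
    return mu, sigma


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Standard FID formula for two Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    # Tr(sqrt(sigma1 @ sigma2)) equals the sum of the square roots of the
    # eigenvalues of sigma1 @ sigma2. For covariance matrices these are real
    # and non-negative, so round-off negatives and imaginary parts are dropped.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * trace_sqrt)


# Synthetic 8-dimensional "features" stand in for the 2048-dimensional ones.
rng = np.random.default_rng(0)
features_a = rng.normal(loc=0.0, size=(500, 8))
features_b = rng.normal(loc=0.5, size=(500, 8))

stats_a = feature_statistics(features_a)
stats_b = feature_statistics(features_b)

same = frechet_distance(*stats_a, *stats_a)       # identical statistics: close to 0
different = frechet_distance(*stats_a, *stats_b)  # shifted mean: clearly positive
print(same, different)
```

Because the covariance of 2048-dimensional features is estimated from the images themselves, a score obtained from only a handful of images (five per set in this notebook) is best read as a smoke test rather than a meaningful measurement.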


You can save FID stats to a `.npz` file and reuse them later to avoid recomputing the features.

In [5]:
cats_fid_stats = fid_data[0]['stats']
cats_fid_stats.to_npz('cats_stats.npz')
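A `.npz` file is NumPy's archive format for storing several named arrays together. As a rough illustration of the round trip, the sketch below writes a mean vector and a covariance matrix with `np.savez` and reads them back with `np.load`. The key names `mu` and `sigma` are assumptions made for this example and may differ from what `FIDStats.to_npz` actually writes.

```python
import os
import tempfile

import numpy as np

# Toy statistics; real ones would have 2048 dimensions.
mu = np.zeros(4)
sigma = np.eye(4)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "example_stats.npz")
    # Keyword names become the keys of the archive (hypothetical names here).
    np.savez(path, mu=mu, sigma=sigma)
    with np.load(path) as archive:
        loaded_mu = archive["mu"]
        loaded_sigma = archive["sigma"]

print(loaded_mu.shape, loaded_sigma.shape)
```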

**FID stats usage**

In [6]:
from T2IBenchmark import FIDStats

# loading saved FID stats
cats_fid_stats = FIDStats.from_npz('cats_stats.npz')

In [7]:
fid, _ = calculate_fid(
    cats_fid_stats, 
    '../assets/images/dogs/',
    seed=111,
    batch_size=2,
    dataloader_workers=8
)

Processing: <T2IBenchmark.metrics.fid.FIDStats object at 0x7f66a922b910>
Processing: ImageDataset(5 items)


100%|██████████| 3/3 [00:00<00:00,  3.14it/s]


FID is 278.9133501791375
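The help text above says an input may also be a path to a `.npz` file, so a small helper can choose between cached statistics and the raw image directory. This is a minimal sketch: `pick_fid_input` is a hypothetical helper, not part of T2IBenchmark, and the `calculate_fid` call is left commented out because it needs the library and the image folders.

```python
from pathlib import Path


def pick_fid_input(stats_path, image_dir):
    """Return the cached .npz path if it exists, otherwise the image directory."""
    return str(stats_path) if Path(stats_path).is_file() else str(image_dir)


# Intended use (requires T2IBenchmark and the example images):
# fid, _ = calculate_fid(
#     pick_fid_input('cats_stats.npz', '../assets/images/cats/'),
#     '../assets/images/dogs/',
# )

# With no cached file present, the helper falls back to the image directory.
chosen = pick_fid_input('no_such_stats_file.npz', '../assets/images/cats/')
print(chosen)
```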
