## Example Usage
A simple example of computing the FLS of a generative model for CIFAR10. The model used here is `RandomGAN`, a placeholder that outputs pure Gaussian noise.

We assume the generative model returns images with values in the range $[-1, 1]$ and shape $[B, C, H, W]$ (batch, channels, height, width).
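
If a generator produces images in another range, they can be rescaled inside the image-generating function before they reach the feature extractor. Below is a minimal sketch that assumes outputs in $[0, 1]$ (for example from a final sigmoid) or 8-bit values in $[0, 255]$. The helpers `unit_to_fls_range`, `uint8_to_fls_range` and `check_fls_batch` are hypothetical and are not part of `fls`:

```python
import torch


def unit_to_fls_range(imgs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map images from [0, 1] to [-1, 1]."""
    return imgs * 2.0 - 1.0


def uint8_to_fls_range(imgs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map 8-bit images in [0, 255] to [-1, 1]."""
    return imgs.float() / 127.5 - 1.0


def check_fls_batch(imgs: torch.Tensor) -> None:
    """Hypothetical sanity check: a 4-D batch with every value in [-1, 1]."""
    if imgs.dim() != 4:
        raise ValueError(f"expected a 4-D image batch, got shape {tuple(imgs.shape)}")
    lo, hi = imgs.min().item(), imgs.max().item()
    if lo < -1.0 or hi > 1.0:
        raise ValueError(f"values must lie in [-1, 1], got [{lo:.3f}, {hi:.3f}]")


# Example: a batch of sigmoid-style outputs in [0, 1) rescaled and checked
batch = torch.rand(4, 3, 32, 32)
scaled = unit_to_fls_range(batch)
check_fls_batch(scaled)  # passes silently
```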

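```python
import torch
import torch.nn as nn
import torchvision

class RandomGAN(nn.Module):
    """Placeholder generator: ignores the latent values and returns Gaussian noise."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # One 3x32x32 noise image per latent vector in the batch
        return torch.randn((x.shape[0], 3, 32, 32))

GAN = RandomGAN()

# Create a no-argument function that returns batches of images
def generate_imgs():
    x = torch.randn((128, 100))  # batch of 128 latent vectors of size 100
    return GAN(x)
```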
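
The same no-argument pattern applies to a real generator. Below is a minimal sketch that assumes a model mapping a latent batch of shape `(batch_size, latent_dim)` to images already in $[-1, 1]$. The factory `make_generate_fn`, its parameters and the toy `nn.Sequential` stand-in generator are hypothetical and only illustrate the wrapper:

```python
import torch
import torch.nn as nn


def make_generate_fn(generator: nn.Module, latent_dim: int, batch_size: int = 128):
    """Hypothetical helper: wrap a latent-to-image generator as a no-argument sampler."""
    # Assumes the generator has at least one parameter to read its device from
    device = next(generator.parameters()).device
    generator.eval()  # switch layers such as dropout / batch norm to inference mode

    @torch.no_grad()  # sampling needs no gradients, which saves memory
    def sample_batch() -> torch.Tensor:
        z = torch.randn(batch_size, latent_dim, device=device)
        return generator(z)

    return sample_batch


# Toy stand-in generator: linear layer -> tanh (values in [-1, 1]) -> reshape to [B, 3, 32, 32]
toy_gen = nn.Sequential(
    nn.Linear(100, 3 * 32 * 32),
    nn.Tanh(),
    nn.Unflatten(1, (3, 32, 32)),
)
sample_batch = make_generate_fn(toy_gen, latent_dim=100, batch_size=16)
imgs = sample_batch()
```

A function built this way can be passed to `feature_extractor.get_gen_features` in place of `generate_imgs`.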

```python
from fls.features.InceptionFeatureExtractor import InceptionFeatureExtractor  # or DINOv2FeatureExtractor / CLIPFeatureExtractor
from fls.metrics.FLS import FLS

# save_path determines where features are cached (useful for the fixed train/test sets)
feature_extractor = InceptionFeatureExtractor(save_path="data/features")

# FLS needs three sets of samples: train, test, and generated
train_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True
)
train_dataset.name = "CIFAR10_train"  # each dataset needs a name so its features can be cached

test_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True
)
test_dataset.name = "CIFAR10_test"

train_feat = feature_extractor.get_all_features(train_dataset)
test_feat = feature_extractor.get_all_features(test_dataset)

# For this example, the generated samples come from RandomGAN
gen_feat = feature_extractor.get_gen_features(generate_imgs, size=10000)

# 1.322 is a dataset-specific constant (here, for CIFAR10)
cifar_fls = FLS("", 1.322).compute_metric(train_feat, test_feat, gen_feat)
print(f"Random GAN FLS: {cifar_fls}")
```

Output (the FLS value will vary between runs, since `RandomGAN` draws fresh noise each time):

```text
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
100%|██████████| 170498071/170498071 [00:05<00:00, 31348001.74it/s]
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
100%|██████████| 79/79 [00:13<00:00,  5.98it/s]
Random GAN FLS: 59.4827917098999
```
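
Why three sets of samples? As we understand the FLS method (the example above does not show this), it builds a density from the generated features (a mixture of isotropic Gaussians centered on the generated samples, with bandwidths fitted on the train features) and then measures how likely the held-out test features are under that density. The sketch below illustrates only that final scoring idea, with a single fixed bandwidth `sigma` in place of fitted ones. All function names are hypothetical; the sketch does not apply the 1.322 constant and is not the library's `FLS.compute_metric`:

```python
import math
from typing import Sequence

Vector = Sequence[float]


def log_isotropic_gaussian(x: Vector, mean: Vector, sigma: float) -> float:
    """log N(x; mean, sigma^2 I) for a d-dimensional point x."""
    d = len(x)
    sq_dist = sum((xi - mi) ** 2 for xi, mi in zip(x, mean))
    return -0.5 * d * math.log(2.0 * math.pi * sigma ** 2) - sq_dist / (2.0 * sigma ** 2)


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def mean_test_log_likelihood(
    test_feats: Sequence[Vector], gen_feats: Sequence[Vector], sigma: float
) -> float:
    """Average log-density of the test points under an equal-weight mixture of
    isotropic Gaussians (shared bandwidth sigma) centered on the generated points."""
    log_n_gen = math.log(len(gen_feats))
    total = 0.0
    for x in test_feats:
        component_logs = [log_isotropic_gaussian(x, g, sigma) for g in gen_feats]
        total += logsumexp(component_logs) - log_n_gen
    return total / len(test_feats)


# Toy 2-D "features": test points near the generated samples score higher
gen = [[0.0, 0.0], [1.0, 1.0]]
near = mean_test_log_likelihood([[0.1, 0.0], [0.9, 1.0]], gen, sigma=0.5)
far = mean_test_log_likelihood([[5.0, 5.0], [-4.0, 3.0]], gen, sigma=0.5)
print(near > far)  # True
```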
