## Example Usage
A simple example of computing the FLD of a generative model on CIFAR10 (here, a placeholder `RandomGAN`).

We assume the generative model returns images in the range $[0, 1]$ with shape $[B, C, W, H]$.
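Many GAN generators end in a `tanh` layer and produce values in $[-1, 1]$ rather than the assumed $[0, 1]$. Below is a minimal sketch of a rescaling step and a sanity check you could wrap around such a generator. `tanh_to_unit_range` and `check_image_batch` are hypothetical helpers, not part of the FLD package. Both use only methods shared by torch tensors and NumPy arrays (`.clip`, `.ndim`, `.shape`, `.min`, `.max`).

```python
def tanh_to_unit_range(imgs):
    """Map images from [-1, 1] (e.g. a tanh output layer) to [0, 1].

    Works on torch tensors and NumPy arrays alike; the final clip guards
    against values that fall slightly outside [-1, 1].
    """
    return ((imgs + 1) / 2).clip(0, 1)


def check_image_batch(imgs, channels=3):
    """Check that a batch is 4-D, has `channels` on axis 1, and lies in [0, 1].

    Only the rank and the channel axis are checked; the spatial axes are not.
    """
    if imgs.ndim != 4:
        raise ValueError(f"expected a 4-D batch, got shape {tuple(imgs.shape)}")
    if imgs.shape[1] != channels:
        raise ValueError(f"expected {channels} channels, got {imgs.shape[1]}")
    lo, hi = float(imgs.min()), float(imgs.max())
    if lo < 0 or hi > 1:
        raise ValueError(f"values must lie in [0, 1], got [{lo:.3f}, {hi:.3f}]")
    return imgs
```

Suppose you have a tanh-based generator `G` and a latent batch `z` (both hypothetical here). The no-argument sampling function could then return `check_image_batch(tanh_to_unit_range(G(z)))`.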

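```python
import torch
import torch.nn as nn
import torchvision


class RandomGAN(nn.Module):
    """Dummy "generator" that ignores its input and outputs clipped Gaussian noise."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # One 3x32x32 image per latent vector, values clipped to [0, 1]
        return torch.randn((x.shape[0], 3, 32, 32)).clip(0, 1)


GAN = RandomGAN()


# Create a no-argument function that returns batches of images
def generate_imgs():
    x = torch.randn((128, 100))  # batch of 128 latent vectors
    return GAN(x).cuda()  # requires a CUDA device
```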

### Mapping images to features
To work in a more perceptually meaningful space, we must first map samples to features.

```python
from Codes_Vis.utils.fld.fld.features.InceptionFeatureExtractor import (
    InceptionFeatureExtractor,
)  # or DINOv2FeatureExtractor/CLIPFeatureExtractor

feature_extractor = InceptionFeatureExtractor()

# FLD needs 3 sets of samples: train, test and gen
train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True)
train_dataset.name = "CIFAR10_train"  # Dataset needs a name to cache features

test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True)
test_dataset.name = "CIFAR10_test"

# Use get_dataset_features to compute features of a torch.utils.data.Dataset
train_feat = feature_extractor.get_dataset_features(train_dataset)
test_feat = feature_extractor.get_dataset_features(test_dataset)

# Features can also be computed directly from a model (e.g. RandomGAN) with get_model_features
gen_feat = feature_extractor.get_model_features(generate_imgs, num_samples=10_000)

# If you've already generated images, features can be obtained from a directory as well!
gen_dir_feat = feature_extractor.get_dir_features("/path/to/images", extension="png")
```
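Each dataset needs a `.name`, presumably because computed features are cached and looked up under that name. The sketch below illustrates that idea with a hypothetical `cached_features` helper that stores NumPy arrays on disk. The library's actual cache location, key, and file format may differ.

```python
from pathlib import Path

import numpy as np


def cached_features(name, compute_fn, cache_dir="./feature_cache"):
    """Return features for `name`, calling `compute_fn` at most once per name.

    Hypothetical helper: the cache key is only the name, so two different
    datasets given the same name would silently share (wrong) features.
    """
    path = Path(cache_dir) / f"{name}.npy"
    if path.exists():
        return np.load(path)  # cache hit: skip the expensive computation
    feats = np.asarray(compute_fn())
    path.parent.mkdir(parents=True, exist_ok=True)
    np.save(path, feats)
    return feats
```

Because only the name is used as the key, the dataset needs a new name (or the cache file must be deleted) whenever its contents change.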

### Computing metrics
Once the three sets of features are ready, metrics can be computed!

```python
# FLD, FID and many other metrics are available
from Codes_Vis.utils.fld import FLD
from Codes_Vis.utils.fld.fld.metrics.FID import FID

# Compute FLD (uses train, test and gen) and FID (no test set needed, hence None)
cifar_fld = FLD().compute_metric(train_feat, test_feat, gen_feat)
cifar_fid = FID().compute_metric(train_feat, None, gen_feat)

print(f"Random GAN FLD: {cifar_fld:.2f}")
print(f"Random GAN FID: {cifar_fid:.2f}")

# Check that RandomGAN isn't overfitting (more negative means more overfit)
gen_gap = FLD("gap").compute_metric(train_feat, test_feat, gen_feat)
print(f"Random GAN FLD Generalization Gap: {gen_gap:.2f}")
```

```text
Random GAN FLD: 55.90
Random GAN FID: 440.23
Random GAN FLD Generalization Gap: 0.06
```
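FID compares only the train and generated feature sets, which is why the test set is passed as `None`. Below is a minimal NumPy sketch of the standard Fréchet distance between Gaussians fitted to two feature sets, $\|\mu_a - \mu_b\|^2 + \mathrm{Tr}\big(\Sigma_a + \Sigma_b - 2(\Sigma_a \Sigma_b)^{1/2}\big)$, to give intuition for what that number measures. It is an illustration rather than the library's implementation, and `frechet_distance` is a hypothetical name.

```python
import numpy as np


def _psd_sqrt(mat: np.ndarray) -> np.ndarray:
    """Symmetric square root of a positive semi-definite matrix via eigh."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # guard against tiny negative eigenvalues
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(feat_a: np.ndarray, feat_b: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two [N, D] feature arrays."""
    mu_a, mu_b = feat_a.mean(axis=0), feat_b.mean(axis=0)
    cov_a = np.cov(feat_a, rowvar=False)
    cov_b = np.cov(feat_b, rowvar=False)
    # Tr((cov_a cov_b)^{1/2}) equals Tr((S cov_b S)^{1/2}) with S = cov_a^{1/2};
    # the latter matrix is symmetric PSD, so eigh can be used safely
    s = _psd_sqrt(cov_a)
    tr_covmean = np.trace(_psd_sqrt(s @ cov_b @ s))
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * tr_covmean)
```

Two identical feature sets give a distance of (numerically) zero. Shifting every one of the $D$ features by a constant $c$ leaves the covariances unchanged, so the distance becomes $D c^2$.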
