In [1]:
# Reload edited project modules automatically before each cell runs, so that
# changes to the local `models` package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

import torch
import scipy
import numpy as np

sys.path.append(os.path.abspath('..'))

from models.scot import ScOT

# FPD Calculation

In this notebook, we demonstrate how to calculate the Fréchet Poseidon Distance (FPD) between two sets of solutions using the pre-trained PDE foundation model, Poseidon-B. The FPD is a metric to compare the distribution of solutions generated by different models. It is inspired by the Fréchet Inception Distance (FID) but uses the activations of the Poseidon-B model.

In [3]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (slowly) on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
poseidon = ScOT.from_pretrained("camlab-ethz/Poseidon-B").to(device)
poseidon.eval()  # evaluation mode: the model is only used as a feature extractor
print(f'Poseidon-B loaded with {sum(p.numel() for p in poseidon.parameters())} parameters')

Poseidon-B loaded with 157729988 parameters


In [4]:
@torch.no_grad()
def get_activations(x):
    """Embed a batch of 2D fields with Poseidon and return pooled features.

    The input is placed in the first of four channels (the remaining three
    are zero-filled) and passed through the global ``poseidon`` model at
    time 0. The last entry of the returned hidden states is averaged over
    dimension 1.

    Params:
    -- x : 3-dimensional torch tensor; presumably (batch, height, width).

    Returns:
    --   : Numpy array holding the pooled activations, copied to the CPU.
    """
    assert x.dim() == 3
    padding = torch.zeros_like(x)
    channels = [x] + [padding] * 3
    model_input = torch.stack(channels, dim=1)
    t0 = torch.tensor(0, device=x.device)
    outputs = poseidon(model_input, time=t0, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]
    pooled = last_hidden.mean(1)
    return pooled.cpu().numpy()


def get_data_stats(x, batch_size=128):
    """Compute the mean and covariance of Poseidon activations for a data set.

    Params:
    -- x          : Torch tensor of shape (B, s, s) for static fields or
                    (B, s, s, t) for trajectories with t frames.
    -- batch_size : Number of fields passed to `get_activations` at once.

    Returns:
    --   : Tuple (mu, sigma). For 3-dimensional input these are the mean and
           covariance of the activations over the B samples. For
           4-dimensional input they are stacked per frame along a leading
           axis of length t.

    Raises:
    -- ValueError : If `x` is neither 3- nor 4-dimensional.
    """
    ndim = x.dim()
    if ndim not in (3, 4):
        raise ValueError(f'Unsupported dim {x.dim()}!')

    def _embed(fields):
        # Run the model chunk by chunk and join the per-chunk activations.
        chunks = fields.split(batch_size, dim=0)
        return np.concatenate([get_activations(chunk) for chunk in chunks], axis=0)

    if ndim == 3:  # (B, s, s)
        feats = _embed(x)
        return np.mean(feats, axis=0), np.cov(feats, rowvar=False)

    # (B, s, s, t): move time next to the batch axis and embed every frame as
    # an independent field, then restore the (B, t, features) layout.
    n_samples, side, _, n_frames = x.size()
    frames = x.permute(0, 3, 1, 2).reshape(n_samples * n_frames, side, side)
    feats = _embed(frames).reshape(n_samples, n_frames, -1)
    mu = np.mean(feats, axis=0)
    sigma = np.stack(
        [np.cov(feats[:, frame], rowvar=False) for frame in range(n_frames)],
        axis=0,
    )
    return mu, sigma


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : The sample mean over activations of the first data set.
    -- sigma1: The covariance matrix over activations of the first data set.
    -- mu2   : The sample mean over activations of the second data set.
    -- sigma2: The covariance matrix over activations of the second data set.
    -- eps   : Value added to the diagonal of both covariance matrices when
               the matrix square root of their product is not finite.

    Returns:
    --   : The Frechet Distance d^2 as a scalar.

    Raises:
    -- AssertionError: If the means or the covariances differ in shape.
    -- ValueError    : If the matrix square root has an imaginary part on its
                       diagonal that exceeds the 1e-3 tolerance.
    """
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.linalg` is loaded on every SciPy version.
    from scipy import linalg

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert (
            mu1.shape == mu2.shape
    ), "Training and test mean vectors have different lengths"
    assert (
            sigma1.shape == sigma2.shape
    ), "Training and test covariances have different dimensions"

    diff = mu1 - mu2

    # Product might be almost singular. `sqrtm` is called without the
    # deprecated `disp` flag; a failed root is detected by the finiteness
    # check below instead.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        msg = (
                  "fid calculation produces singular product; "
                  "adding %s to diagonal of cov estimates"
              ) % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean


def calculate_frechet_distance_batch(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Compute one Frechet distance per entry of the stacked statistics.

    Params:
    -- mu1, mu2      : Sequences of mean vectors, iterated in parallel.
    -- sigma1, sigma2: Sequences of covariance matrices, iterated in parallel.
    -- eps           : Forwarded to `calculate_frechet_distance`.

    Returns:
    --   : Numpy array with one distance per zipped entry.
    """
    distances = []
    for stats in zip(mu1, sigma1, mu2, sigma2):
        distances.append(calculate_frechet_distance(*stats, eps=eps))
    return np.array(distances)

To calculate the FPD between two sets of solutions, first use `get_data_stats` to compute the mean and covariance of the activations for each set. Then use `calculate_frechet_distance` to compute the FPD for 2D data with shape `(n_sample, nx, nt)`, or `calculate_frechet_distance_batch` for 3D data with shape `(n_sample, nx, ny, nt)`, which returns the per-frame FPD.

In [5]:
# Load each tensor straight onto `device`; `map_location` avoids first
# materialising it on whatever device it was saved from.
gt = torch.load('../gen/stokes_ic/gt.pt', map_location=device)
mu_gt, sigma_gt = get_data_stats(gt, batch_size=128)
gen = torch.load('../gen/stokes_ic/eci.pt', map_location=device)
mu_gen, sigma_gen = get_data_stats(gen, batch_size=128)
# Single FPD between the two sets of activation statistics.
fpd = calculate_frechet_distance(mu_gt, sigma_gt, mu_gen, sigma_gen)
print(f'FPD: {fpd:.6f}')

FPD: 0.075787


In [6]:
# Load each tensor straight onto `device`; `map_location` avoids first
# materialising it on whatever device it was saved from.
gt = torch.load('../gen/ns/gt.pt', map_location=device)
mu_gt, sigma_gt = get_data_stats(gt, batch_size=128)
gen = torch.load('../gen/ns/eci.pt', map_location=device)
mu_gen, sigma_gen = get_data_stats(gen, batch_size=128)
# One FPD per entry of the stacked statistics; the mean is reported.
fpds = calculate_frechet_distance_batch(mu_gt, sigma_gt, mu_gen, sigma_gen)
print(f'FPD: {fpds.mean():.6f}')

FPD: 1.131029
