In [1]:
import os

# Pin OpenMP to a single thread. This has to happen before numpy/BLAS-backed
# libraries are imported, presumably so that every benchmarked implementation
# runs single-threaded and timings are comparable.
os.environ.update(OMP_NUM_THREADS="1")

In [2]:
import os
import sys

# Optionally use a local geomstats checkout. Set the GEOMSTATS_PATH environment
# variable to its directory; when unset, the installed geomstats package is used.
# The path is only appended once so re-running this cell is idempotent.
GEOMSTATS_PATH = os.environ.get('GEOMSTATS_PATH')
if GEOMSTATS_PATH and GEOMSTATS_PATH not in sys.path:
    sys.path.append(GEOMSTATS_PATH)

In [None]:
import pandas as pd
from time import time
from nilearn.connectome.connectivity_matrices import _geometric_mean
from pyriemann.utils.mean import mean_riemann

In [None]:
import geomstats.backend as gs
from geomstats.geometry.spd_matrices import SPDMatrices, SPDMetricAffine, SPDMetricBuresWasserstein
from geomstats.learning.frechet_mean import FrechetMean

Let's use geomstats to sample random symmetric positive-definite (SPD) matrices.

In [None]:
# Seed geomstats' random generator so the sampled matrices (and hence every
# estimate and distance below) are reproducible across kernel restarts.
gs.random.seed(0)

n_points = 100
dim = 10

# Sample n_points random SPD matrices of size dim x dim.
space = SPDMatrices(dim)
data = space.random_point(n_samples=n_points)

We now use geomstats' affine-invariant (AI) metric to compute the Fréchet mean of the sample.

In [None]:
metric = SPDMetricAffine(dim)

# Fréchet mean of `data` under the affine-invariant metric, estimated with
# geomstats' 'adaptive' optimisation scheme (at most 1000 iterations,
# stopping tolerance epsilon=1e-12, initial step size lr=1).
mean = FrechetMean(
    metric=metric,
    method='adaptive',
    lr=1.,
    epsilon=1e-12,
    max_iter=1000,
    verbose=True,
)
mean.fit(data)
geomstats_mean = mean.estimate_

The same AI Fréchet mean can also be computed with Nilearn and PyRiemann.

In [None]:
# Reference affine-invariant means from other libraries, with the same
# iteration cap (1000) and nominal tolerance (1e-12) as the geomstats run.
# NOTE: `_geometric_mean` is a private nilearn helper (leading underscore) and
# may change or disappear between nilearn releases.
nilearn_mean = _geometric_mean(data, tol=1e-12, max_iter=1000)
pyriemann_mean = mean_riemann(data, tol=1e-12, maxiter=1000)

We can also compute the Fréchet mean under the Bures–Wasserstein metric, as follows.

In [None]:
metric_bw = SPDMetricBuresWasserstein(dim)

# Fréchet mean of the same `data` under the Bures-Wasserstein metric, using
# the same optimiser settings as the affine-invariant run.
mean = FrechetMean(
    metric=metric_bw,
    method='adaptive',
    lr=1.,
    epsilon=1e-12,
    max_iter=1000,
    verbose=True,
)
mean.fit(data)
geomstats_mean_bw = mean.estimate_

We can compare the estimates using the pairwise affine-invariant (AI) distance.

In [None]:
# Pairwise affine-invariant distances between the four estimates. The first
# three all target the AI Fréchet mean, so they should agree up to solver
# tolerance; the Bures-Wasserstein mean is in general a different point.
# Displayed as a labeled table (assumes the default numpy backend).
labels = ['geomstats (AI)', 'nilearn', 'pyriemann', 'geomstats (BW)']
distances = metric.dist_pairwise(
    gs.stack([geomstats_mean, nilearn_mean, pyriemann_mean, geomstats_mean_bw]))
pd.DataFrame(distances, index=labels, columns=labels)

Let's benchmark the run time of each implementation for several matrix dimensions and numbers of points (this may take a while).

In [None]:
from time import perf_counter

# Grid of problem sizes. The result index is derived from these same lists,
# so the loop and the index can never disagree on the number of rows.
BENCH_DIMS = [3, 5, 10, 20, 30, 40, 50, 60]
BENCH_N_POINTS = [10, 100, 500]


def _elapsed(func, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` once and return its duration in seconds."""
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = perf_counter()
    func(*args, **kwargs)
    return perf_counter() - start


def _benchmark_one(dim, n_points):
    """Time every Fréchet-mean implementation on one random SPD sample.

    Parameters
    ----------
    dim : int
        Size of the sampled (dim x dim) SPD matrices.
    n_points : int
        Number of matrices sampled.

    Returns
    -------
    dict
        One benchmark record: ``dim``, ``n_points`` and the run time in
        seconds of each implementation.
    """
    # Locals keep the demo cells' globals (data, metric, ...) untouched.
    data = SPDMatrices(dim).random_point(n_samples=n_points)
    metric = SPDMetricAffine(dim)
    adaptive = FrechetMean(
        metric=metric, method='adaptive',
        max_iter=1000, verbose=False, epsilon=1e-12, lr=1.)
    default = FrechetMean(
        metric=metric, method='default', point_type='matrix',
        max_iter=1000, verbose=False, epsilon=1e-12, lr=1.)
    # Dict entries are evaluated in order, so the implementations run in the
    # sequence they are listed.
    return {
        'dim': dim,
        'n_points': n_points,
        'geomstats adaptive': _elapsed(adaptive.fit, data),
        'geomstats default': _elapsed(default.fit, data),
        'nilearn': _elapsed(_geometric_mean, data, max_iter=1000, tol=1e-12),
        'pyriemann': _elapsed(mean_riemann, data, maxiter=1000, tol=1e-12),
    }


gs.random.seed(0)
# Build all records first and create the frame once: DataFrame.append was
# removed in pandas 2.0 and growing a frame row by row is quadratic.
records = [
    _benchmark_one(dim, n_points)
    for dim in BENCH_DIMS
    for n_points in BENCH_N_POINTS
]
df = pd.DataFrame(records).set_index(['dim', 'n_points'])

In [None]:
# Benchmark run times in seconds, one row per (dim, n_points) pair.
df