In [1]:
import numpy as np
import pandas as pd

import torch

from geomstats.geometry.spd_matrices import SPDMatrices, SPDBuresWassersteinMetric

spd = SPDMatrices(3)
spd.equip_with_metric(SPDBuresWassersteinMetric)


INFO: Using numpy backend


<font size=5> Bures Metric
    
<font size=5> $\mathcal{B}(\Sigma_x, \Sigma_y) = tr\bigg(\Sigma_x + \Sigma_y - 2\big(\Sigma_x^{\frac{1}{2}} \Sigma_y \Sigma_x^{\frac{1}{2}}\big)^{\frac{1}{2}}\bigg)$
    
Most other methods use $sqrtm$ function to calculate the matrix square roots. This is slow, and produces imaginary numbers if the covariance matrix is not full rank. Instead, we use the fact that the symmetric matrices are orthogonally parallelizable to calculate faster matrix square roots that are always real. 

In [2]:
def my_squared_bures(sigma_x,sigma_y):
    
    Lx, Qx = np.linalg.eigh(sigma_x)
    sigma_x_sqrt = Qx @ np.diag(np.sqrt(Lx*(Lx>0))) @ Qx.T

    Lc,Qc = np.linalg.eigh(sigma_x_sqrt@sigma_y@sigma_x_sqrt)
    cross_term = Qc @ np.diag(np.sqrt((Lc*(Lc>0)))) @ Qc.T

    return np.trace(sigma_x + sigma_y - 2*cross_term)


In [3]:
errors=[]

for i in range(1000):
    sigma_x = spd.random_point()
    sigma_y = spd.random_point()
    
    errors.append(my_squared_bures(sigma_x,sigma_y)-spd.metric.squared_dist(sigma_x,sigma_y))
    
max(errors)

3.419486915845482e-14

In [4]:
%%timeit
my_squared_bures(sigma_x,sigma_y)

80.7 µs ± 1.18 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [5]:
%%timeit
spd.metric.squared_dist(sigma_x,sigma_y)

283 µs ± 433 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
def my_squared_bures_torch(sigma_x,sigma_y):
    
    Lx, Qx = torch.linalg.eigh(sigma_x)
    sigma_x_sqrt = Qx @ torch.diag_embed(torch.sqrt(Lx*(Lx>0))) @ Qx.mH

    Lc,Qc = torch.linalg.eigh(sigma_x_sqrt@sigma_y.unsqueeze(1)@sigma_x_sqrt.unsqueeze(0))
    cross_term = Qc @ torch.diag_embed(torch.sqrt((Lc*(Lc>0)))) @ Qc.mH

    return torch.einsum('ijkk -> ij', sigma_x.unsqueeze(0) + sigma_y.unsqueeze(1) - 2*cross_term).T

In [7]:
N=2000

sigma_x = torch.tensor(spd.random_point(N))
sigma_y = torch.tensor(spd.random_point(N))

sigma_x2 = sigma_x.to('cuda')
sigma_y2 = sigma_y.to('cuda')

In [8]:
%%timeit
my_squared_bures_torch(sigma_x,sigma_y)

6.75 s ± 75.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
def bures_distance_matrix2(Sigma_x,Sigma_y):
    Lx, Qx = torch.linalg.eigh(Sigma_x)
    Sigma_x_sqrt = Qx @ torch.diag_embed(torch.sqrt(Lx*(Lx>0))) @ Qx.mH
    
    cross_term = Sigma_x_sqrt @ Sigma_y.unsqueeze(1) @ Sigma_x_sqrt.unsqueeze(0)
    
    #Split data up for gpu memory
    N=cross_term.shape[0]
    M=int(1000000/cross_term.shape[1])
    
    for i in np.arange(0,N,M):
        Lc,Qc = torch.linalg.eigh(cross_term[i:(i+M)])
        cross_term[i:i+M] = Qc @ torch.diag_embed(torch.sqrt((Lc*(Lc>0)))) @ Qc.mH #This is actually the square root of the cross term - but because it is such a large array, we just write over the original variable to save space
        
    return torch.einsum('ijkk -> ij', Sigma_x.unsqueeze(0) + Sigma_y.unsqueeze(1) - 2*cross_term).T #the transpose here should be fixed

In [10]:
%%timeit
bures_distance_matrix2(sigma_x2,sigma_y2)

3.09 s ± 38.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
