In [1]:
import numpy as np
import pandas as pd

import torch

from geomstats.geometry.spd_matrices import SPDMatrices, SPDBuresWassersteinMetric

def bures_distance_matrix2(Sigma_x,Sigma_y):
    Lx, Qx = torch.linalg.eigh(Sigma_x)
    Sigma_x_sqrt = Qx @ torch.diag_embed(torch.sqrt(Lx*(Lx>0))) @ Qx.mH
    
    cross_term = torch.matmul(torch.matmul(Sigma_x_sqrt, Sigma_y.unsqueeze(1)),Sigma_x_sqrt.unsqueeze(0))
    
    N=cross_term.shape[0]
    M=int(1000000/cross_term.shape[1])
    
    for i in np.arange(0,N,M):
        Lc,Qc = torch.linalg.eigh(cross_term[i:(i+M)])
        cross_term[i:i+M] = Qc @ torch.diag_embed(torch.sqrt((Lc*(Lc>0)))) @ Qc.mH

    return torch.einsum('ijkk -> ij', Sigma_x.unsqueeze(0) + Sigma_y.unsqueeze(1) - 2*cross_term).T #the transpose here should be fixed


INFO: Using numpy backend


In [2]:
N= 100

spd = SPDMatrices(3)
spd.equip_with_metric(SPDBuresWassersteinMetric)

Sigma_x = spd.random_point(N)
Sigma_y = spd.random_point(N)



In [3]:
D0 = np.array([[spd.metric.squared_dist(sigma0,sigma1) for sigma0 in Sigma_y] for sigma1 in Sigma_x])

In [4]:
D1 = bures_distance_matrix2(torch.tensor(Sigma_x), torch.tensor(Sigma_y))

In [5]:
torch.max(D1-D0)

tensor(2.6645e-14, dtype=torch.float64)

<font size= "5"> - We see that the largest difference in the 10000 distance calculations is on the order of (e-14)
    
Next, we'll try real data

In [6]:
df = pd.read_pickle("data_files/dtmri_dataframe_11_2.pkl")
df = df[df['labels']!='#NULL!']

In [7]:
mu0 = df.at[0,'Cingulum_Frontal_Parahippocampal_L']
mu1 = df.at[2,'Cingulum_Frontal_Parahippocampal_L']

In [8]:
Sigma_x = spd.projection(mu0.covariances_)
Sigma_y = spd.projection(mu1.covariances_)

In [9]:
D1 = bures_distance_matrix2(torch.tensor(Sigma_x), torch.tensor(Sigma_y))

In [10]:
spd.belongs(Sigma_x).all()
spd.belongs(Sigma_y).all()

True

In [11]:
spd.metric.squared_dist(Sigma_x[0],Sigma_y[0])

array(0.00108953-7.41880324e-12j)

In [12]:
D1[0,0].cpu().numpy()

array(0.00108953)

<font size= "5"> - Get the same (real) answer. Why does their matrix square root have an imaginary value?

Because the matrix is ill-conditioned (I think), and Geomstats uses sqrtm, a more general (and slower) matrix squareroot function than is necessary. We choose instead to use the fact that our matrices are symmetric, and thus orthogonally diagonalizable, and thus
    
$\Sigma = PDP^{-1} \implies \Sigma^{\frac{1}{2}} = PD^{\frac{1}{2}}P^{-1}$, since
    
$PD^{\frac{1}{2}}P^{-1} PD^{\frac{1}{2}}P^{-1} = PDP^{-1}$
    

In [13]:
D0 = np.array([[np.real(spd.metric.squared_dist(sigma0,sigma1)) for sigma0 in Sigma_y] for sigma1 in Sigma_x])



Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to find a square root.
Failed to 

KeyboardInterrupt: 