In [27]:
import numpy as np
from numpy.linalg import inv, det
from numpy import trace, abs, sqrt, zeros, pi, diag

import torch as th

In [23]:
def vech(A):
    """Half-vectorize `A`: stack its lower triangle (diagonal included) column by column.

    With c = A.shape[0], only the leading c columns of `A` are read. The result is
    a float64 array of length c*(c+1)//2 ordered
    A[0,0], A[1,0], ..., A[c-1,0], A[1,1], A[2,1], ..., A[c-1,c-1].
    """
    size = A.shape[0]
    # triu_indices walks the upper triangle row by row; swapping the roles of
    # row and column turns that into a column-by-column walk of the lower triangle.
    col_idx, row_idx = np.triu_indices(size)
    out = np.zeros(size * (size + 1) // 2)
    out[:] = A[row_idx, col_idx]
    return out

def uvech(v, n):
    """Inverse of `vech`: rebuild an n x n lower-triangular float64 matrix.

    The first n*(n+1)//2 entries of `v` fill the lower triangle column by column
    (same order as `vech`); the strict upper triangle stays zero. Any further
    entries of `v` are ignored, and a too-short `v` raises IndexError.
    """
    slots = ((row, col) for col in range(n) for row in range(col, n))
    L = np.zeros((n, n))
    for k, (row, col) in enumerate(slots):
        L[row, col] = v[k]
    return L

In [24]:
def py_matel(n, vechLk, vechLl, Sym):
    """Normalized overlap ``skl`` of basis functions k and l, and its gradients.

    The parameters are the vech (lower triangle stacked column by column, as
    produced by `vech`) of lower-triangular factors ``Lk`` and ``Ll``::

        Ak  = Lk @ Lk.T
        PLl = Sym.T @ Ll            # symmetry projection applied to Ll
        Al  = PLl @ PLl.T
        Akl = Ak + Al
        skl = 2**(3n/2) * (|det Lk| * |det Ll| / det Akl)**(3/2)

    Parameters
    ----------
    n : int
        Size of the n x n factors; each vech vector is read for its first
        n*(n+1)//2 entries.
    vechLk, vechLl : array_like
        vech of ``Lk`` and ``Ll``.
    Sym : ndarray, shape (n, n)
        Symmetry matrix applied as ``Sym.T @ Ll``. The formula uses
        ``|det Ll|`` rather than ``sqrt(det Al)``; the two agree only when
        ``|det Sym| == 1`` (e.g. a permutation matrix).

    Returns
    -------
    dict
        ``'skl'``: the overlap;
        ``'dsk'``: d skl / d vechLk, in vech order;
        ``'dsl'``: d skl / d vechLl, in vech order.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``Lk`` or ``Ll`` has a zero on its diagonal (singular factor), or
        if ``Akl`` is singular.
    """
    Lk = uvech(vechLk, n)
    Ll = uvech(vechLl, n)

    # A triangular factor is singular iff its diagonal contains a zero; fail
    # loudly instead of returning inf gradients from 1/diag(L) below.
    if not (np.all(np.diag(Lk)) and np.all(np.diag(Ll))):
        raise np.linalg.LinAlgError("Lk or Ll has a zero diagonal entry (singular factor)")

    # Apply the symmetry projection to Ll.
    PLl = Sym.T @ Ll

    Ak = Lk @ Lk.T
    Al = PLl @ PLl.T
    Akl = Ak + Al

    # Overlap (normalized).
    ratio = np.abs(np.linalg.det(Lk)) * np.abs(np.linalg.det(Ll)) / np.linalg.det(Akl)
    skl = 2 ** (3 * n / 2) * np.sqrt(ratio ** 3)

    # d skl = 3/2 * skl * d log(|det L| / det Akl), with
    #   d/dL  log|det L|     = L^{-T}
    #   d/dLk log det Akl    = 2 Akl^{-1} Lk
    #   d/dLl log det Akl    = 2 Sym Akl^{-1} Sym.T Ll = 2 Sym Akl^{-1} PLl
    # vech keeps only the lower triangle, and for lower-triangular L the lower
    # triangle of L^{-T} is just diag(1/diag(L)).
    # Solving against Akl avoids forming its explicit inverse.
    invAkl_Lk = np.linalg.solve(Akl, Lk)
    invAkl_PLl = np.linalg.solve(Akl, PLl)
    prefactor = 3 / 2 * skl
    dsk = vech(prefactor * (np.diag(1 / np.diag(Lk)) - 2 * invAkl_Lk))
    dsl = vech(prefactor * (np.diag(1 / np.diag(Ll)) - 2 * Sym @ invAkl_PLl))

    return {'skl': skl, 'dsk': dsk, 'dsl': dsl}


In [25]:
def test_matel():
    """Regression and consistency check for `py_matel` on a fixed 3x3 example.

    Prints the overlap and both gradients, then verifies:
    * the values against the reference output recorded for this example, and
    * the analytic gradients against central finite differences of ``skl``.

    Raises AssertionError if any check fails.
    """
    n = 3
    vechLk = np.array([1.00000039208682,
                       0.02548044275764261,
                       0.3525161612610669,
                       1.6669144815242515,
                       0.9630555318946559,
                       1.8382882034659822])

    vechLl = np.array([1.3353550436464964,
                       0.9153272033682132,
                       0.7958636766525028,
                       1.8326931436447955,
                       0.3450426931160630,
                       1.8711839323167831])

    # Permutation matrix swapping coordinates 0 and 2.
    Sym = np.array([[0, 0, 1],
                    [0, 1, 0],
                    [1, 0, 0]])

    matels = py_matel(n, vechLk, vechLl, Sym)

    print('skl: ', matels['skl'])
    print('dsk: ', matels['dsk'])
    print('dsl: ', matels['dsl'])

    # Reference values recorded from a previous run of this example
    # (gradients were printed to 8 decimals, hence the absolute tolerance).
    np.testing.assert_allclose(matels['skl'], 0.5333557299037264, rtol=1e-9)
    np.testing.assert_allclose(
        matels['dsk'],
        [0.48982708, 0.07856709, -0.05598214, 0.11792033, -0.11130223, -0.16323205],
        rtol=0, atol=1e-7)
    np.testing.assert_allclose(
        matels['dsl'],
        [0.31975945, -0.06663576, -0.14954734, -0.07506522, -0.0351545, -0.19168515],
        rtol=0, atol=1e-7)

    # Central finite differences of skl must reproduce the analytic gradients.
    def skl_at(vk, vl):
        return py_matel(n, vk, vl, Sym)['skl']

    h = 1e-6
    for i in range(vechLk.size):
        step = np.zeros_like(vechLk)
        step[i] = h
        fd_k = (skl_at(vechLk + step, vechLl) - skl_at(vechLk - step, vechLl)) / (2 * h)
        fd_l = (skl_at(vechLk, vechLl + step) - skl_at(vechLk, vechLl - step)) / (2 * h)
        np.testing.assert_allclose(matels['dsk'][i], fd_k, rtol=0, atol=1e-6)
        np.testing.assert_allclose(matels['dsl'][i], fd_l, rtol=0, atol=1e-6)
    

In [26]:
# Run the fixed 3x3 example; prints skl, dsk and dsl.
test_matel()

skl:  0.5333557299037264
dsk:  [ 0.48982708  0.07856709 -0.05598214  0.11792033 -0.11130223 -0.16323205]
dsl:  [ 0.31975945 -0.06663576 -0.14954734 -0.07506522 -0.0351545  -0.19168515]
