In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import jit, grad, jacobian, lax

import numpy as np
import matplotlib.pyplot as plt

In [2]:
n = 3

h0 = jnp.array(np.diag(np.sort(np.pi*np.random.uniform(size=n))))
mraw = np.random.normal(size=n**2).reshape((n,n))
m = jnp.array(0.5*(mraw + mraw.T))

print(jnp.mean(jnp.abs(h0 - h0.T)))
print(jnp.mean(jnp.abs(m - m.T)))

0.0
0.0


In [3]:
z = h0 - 0.25*m
evals, evecs = jnp.linalg.eigh(z)
expevals = jnp.exp(evals)

In [4]:
jnp.linalg.norm(z - evecs @ jnp.diag(evals) @ evecs.conj().T)

Array(2.28527812e-15, dtype=float64)

In [5]:
jnp.linalg.norm(evecs @ evecs.conj().T - jnp.eye(n))

Array(1.27859491e-15, dtype=float64)

In [6]:
a = evecs.conj().T @ m @ evecs

In [7]:
evals1, evals2 = jnp.meshgrid(evals, evals, indexing='ij')
mask = jnp.ones((n,n)) - jnp.eye(n)
expevals1, expevals2 = jnp.meshgrid(expevals, expevals, indexing='ij')

In [8]:
diagterm1 = expevals1*jnp.diag(jnp.diag(a*a))
print(diagterm1)

[[ 1.21999498  0.          0.        ]
 [ 0.          1.80893493  0.        ]
 [ 0.          0.         15.95829431]]


In [9]:
testmat3 = np.zeros((n,n))
for i in range(3):
    for k in range(3):
        if i==k:
            testmat3[i,k] = 0.5*expevals[i]*(a[i,i]*a[i,k])*2

print(jnp.linalg.norm(diagterm1-testmat3))

0.0


In [10]:
numer1 = -expevals1 + evals1*expevals1 - evals2*expevals1 + expevals2
denom1 = (evals1-evals2)**2 * mask + jnp.eye(n)
frac1 = numer1/denom1 * mask
term1 = frac1*2*(jnp.diag(a)*a).T

In [12]:
testmat = np.zeros((n,n))
testmat2 = np.zeros((n,n))
testmat3 = np.zeros((n,n))
for i in range(3):
    for k in range(3):
        if i != k:
            testmat[i,k] = (-expevals[i]+evals[i]*expevals[i]-evals[k]*expevals[i]+expevals[k])/(evals[i]-evals[k])**2
        testmat2[i,k] = a[i,i]*a[i,k] + a[i,i]*a[i,k]
        testmat3[i,k] = testmat[i,k] * testmat2[i,k]
print(jnp.linalg.norm(term1-testmat3))

4.577566798522237e-16


In [13]:
numer2 = -expevals1 + evals1*expevals2 - evals2*expevals2 + expevals2
denom2 = (evals1-evals2)**2 * mask + jnp.eye(n)
frac2 = numer2/denom2 * mask
term2 = frac2*2*(a*jnp.diag(a))

In [14]:
testmat = np.zeros((n,n))
testmat2 = np.zeros((n,n))
testmat3 = np.zeros((n,n))
for i in range(3):
    for k in range(3):
        if i != k:
            testmat[i,k] = (-expevals[i]+evals[i]*expevals[k]-evals[k]*expevals[k]+expevals[k])/(evals[i]-evals[k])**2
        testmat2[i,k] = a[i,k]*a[k,k] + a[i,k]*a[k,k]
        testmat3[i,k] = testmat[i,k] * testmat2[i,k]
print(jnp.linalg.norm(term2-testmat3))

0.0


In [15]:
paren = 2*jnp.einsum('ij,jk->ijk',a,a)

In [47]:
evals1, evals2, evals3 = jnp.meshgrid(evals, evals, evals, indexing='ij')
expevals1, expevals2, expevals3 = jnp.meshgrid(expevals, expevals, expevals, indexing='ij')

numer3 = evals1*(expevals2-expevals3)-evals2*(expevals1-expevals3)+evals3*(expevals1-expevals2)

mask = np.zeros((3,3,3))
for i in range(3):
    for j in range(3):
        for k in range(3):
            if i!=j and i!=k and j!=k:
                mask[i,j,k] = 1

mask = jnp.array(mask)
# print(mask)

denom3 = (evals1-evals2)*(evals1-evals3)*(evals2-evals3) + jnp.ones((3,3,3)) - mask

term3 = jnp.sum( (numer3/denom3 * mask) * paren, axis=1 )
print(term3)

[[ 0.         -0.50614114 -0.94470172]
 [-0.50614114  0.         -0.34890753]
 [-0.94470172 -0.34890753  0.        ]]


In [60]:
evals1, evals2 = jnp.meshgrid(evals, evals, indexing='ij')
expevals1, expevals2 = jnp.meshgrid(expevals, expevals, indexing='ij')
mask = jnp.ones((n,n)) - jnp.eye(n)

matij = mask*(1.0/((evals1-evals2) + jnp.eye(n)))
matind1 = expevals1 * matij
matind2 = expevals2 * matij

testmatL = np.zeros((n,n))
for i in range(n):
    for j in range(n):
        if i != j:
            testmatL[i,j] = a[i,j]*expevals[i]/(evals[i]-evals[j])

# print(testmatL)
# print(matind2*a)

testmatR = np.zeros((n,n))
for j in range(n):
    for k in range(n):
        if j != k:
            testmatR[j,k] = a[j,k]

# print(testmatR)
# print(matij*a)

testmat = np.zeros((n,n))
for i in range(n):
    for k in range(n):
        for j in range(n):
            if i != k and j != i and j != k:
                testmat[i,k] += (testmatL[i,j]*testmatR[j,k])/(evals[i]-evals[k])

# print(testmat)
# print(matij*( (matind1*a) @ (mask*a) ))


term3new = 2*mask*((matind2*a) @ (matij*a))
term3new -= 2*matij*( (matind1*a) @ (mask*a) )
term3new -= 2*matij*( (mask*a) @ (matind2*a) ) 

print(term3new)
print(term3)


[[ 0.         -0.50614114 -0.94470172]
 [-0.50614114  0.         -0.34890753]
 [-0.94470172 -0.34890753  0.        ]]
[[ 0.         -0.50614114 -0.94470172]
 [-0.50614114  0.         -0.34890753]
 [-0.94470172 -0.34890753  0.        ]]


In [17]:
testmat3 = np.zeros((3,3))
for i in range(3):
    for j in range(3):
        for k in range(3):
            if i!=j and i!=k and j!=k:
                fractermN = evals[i]*(expevals[j]-expevals[k])-evals[j]*(expevals[i]-expevals[k])+evals[k]*(expevals[i]-expevals[j])
                fractermD = (evals[i]-evals[j])*(evals[i]-evals[k])*(evals[j]-evals[k])
                fracterm = fractermN/fractermD
                parenterm = a[i,j]*a[j,k]*2
                testmat3[i,k] += fracterm*parenterm

In [18]:
print(jnp.linalg.norm(term3-testmat3))

1.5700924586837752e-16


In [19]:
newmask = np.zeros((3,3,3))
for i in range(3):
    for j in range(3):
        for k in range(3):
            if i!=j and i==k:
                newmask[i,j,k] = 1

diagnumer2 = -expevals1 + evals1*expevals1 - evals2*expevals1 + expevals2
diagdenom2 = (evals1-evals2)**2 + jnp.ones((3,3,3)) - newmask
diagfrac2 = (diagnumer2/diagdenom2)*newmask
diagterm2 = jnp.sum( diagfrac2*paren, axis=1 )
print(diagterm2)

[[0.41328424 0.         0.        ]
 [0.         2.28269261 0.        ]
 [0.         0.         3.03765877]]


In [72]:
numer1 = -expevals1 + evals1*expevals1 - evals2*expevals1 + expevals2
np.eye(n) * 2*((numer1/denom1 * a) @ ( a))

Array([[ 0.41328424,  0.        , -0.        ],
       [-0.        ,  2.28269261, -0.        ],
       [ 0.        ,  0.        ,  3.03765877]], dtype=float64)

In [20]:
testmat3 = np.zeros((3,3))
for i in range(3):
    for j in range(3):
        for k in range(3):
            if i!=j and i==k:
                fractermN = -expevals[i] + evals[i]*expevals[i] - evals[j]*expevals[i] + expevals[j]
                fractermD = (evals[i]-evals[j])**2
                fracterm = fractermN/fractermD
                parenterm = 2*a[i,j]*a[j,k]
                testmat3[i,k] += fracterm*parenterm

In [21]:
print(jnp.linalg.norm(diagterm2-testmat3))

0.0


In [22]:
r = evecs @ (term1 - term2 - term3 + diagterm1 + diagterm2) @ evecs.conj().T
print(r)

[[ 1.62698398  0.62982758 -1.28962078]
 [ 0.62982758  5.41068761 -5.11032187]
 [-1.28962078 -5.11032187 17.68318824]]
