In [1]:
import jax 
import jax.numpy as jnp

In [2]:
minkowski_metric = jnp.diag(jnp.array([-1, 1, 1, 1]))



In [3]:
def metric(coords, met_type='minkowski'):
    if met_type == 'minkowski':
        return minkowski_metric
    if met_type == 'sphere':
        return jnp.diag(jnp.array([1, jnp.sin(coords[0])**2]))
    return minkowski_metric

In [4]:
pd_metric = jax.jacfwd(metric) # derivative on the last dimension

In [5]:
metric(jnp.array([0., 0, 0, 0]))

DeviceArray([[-1,  0,  0,  0],
             [ 0,  1,  0,  0],
             [ 0,  0,  1,  0],
             [ 0,  0,  0,  1]], dtype=int32)

In [6]:
metric(jnp.array([jnp.pi/3, 0]), met_type='sphere').shape

(2, 2)

In [7]:
pd_metric(jnp.array([0., 0, 0, 0, 0]), met_type='sphere').shape

(2, 2, 5)

In [8]:
def christoffel(coords, met_type='minkowski'):
    met = metric(coords, met_type=met_type)
    inv_met = jnp.linalg.inv(met)
    partial_derivs = jnp.einsum('mns -> smn', pd_metric(coords, met_type=met_type))
    sum_partial_derivs = partial_derivs + jnp.einsum('nrm -> mnr', partial_derivs) - jnp.einsum('rmn -> mnr', partial_derivs)
    christ = 0.5 * jnp.einsum('sr, mnr -> smn', inv_met, sum_partial_derivs)
    return christ

In [9]:
pd_christoffel = jax.jacfwd(christoffel)

In [10]:
christoffel(jnp.array([0., 0, 0, 0])).shape

(4, 4, 4)

In [11]:
christoffel(jnp.array([jnp.pi/3, 0]), met_type='sphere')[0, 1, 1]

DeviceArray(-0.4330127, dtype=float32)

In [12]:
pd_christoffel(jnp.array([0., 0, 0, 0])).shape

(4, 4, 4, 4)

In [19]:
def riemann_curvature(coords, met_type='minkowski'):
    christ = christoffel(coords, met_type=met_type)
    pd_christ = jnp.einsum('rmns -> srmn', pd_christoffel(coords, met_type=met_type))
    return jnp.einsum('mrns -> rsmn', pd_christ) - jnp.einsum('nrms -> rsmn', pd_christ) + jnp.einsum('rml, lns -> rsmn', christ, christ) - jnp.einsum('rnl, lms -> rsmn', christ, christ)

In [37]:
def ricci_tensor(coords, met_type='minkowski'):
    riemann = riemann_curvature(coords, met_type=met_type)
    return jnp.einsum('rsru -> su', riemann)

def ricci_scalar(coords, met_type='minkowski'):
    return jnp.einsum('mn, mn -> ', jnp.linalg.inv(metric(coords, met_type=met_type)), ricci_tensor(coords, met_type=met_type))

In [38]:
riemann_curvature(jnp.array([0., 0, 0, 0])).shape

(4, 4, 4, 4)

In [39]:
riemann_curvature(jnp.array([jnp.pi/3, 0]), met_type='sphere')[0, 1, 0, 1]

DeviceArray(0.7500001, dtype=float32)

In [40]:
ricci_tensor(jnp.array([jnp.pi/3, 0]), met_type='sphere'), ricci_scalar(jnp.array([jnp.pi/3, 0]), met_type='sphere')

(DeviceArray([[1.0000001, 0.       ],
              [0.       , 0.7500001]], dtype=float32),
 DeviceArray(2.0000002, dtype=float32))