# Autodiff testing

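Christoffel symbols (of the second kind), computed from the metric $g_{\mu\nu}$ and its inverse $g^{\mu\nu}$, with summation over the repeated index $\alpha$ implied: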

$$
\Gamma^\beta_{\mu \nu} = \frac{1}{2} g^{\beta \alpha} 
\left(\frac{\partial g_{\mu \alpha}}{\partial x^\nu} +
\frac{\partial g_{\alpha \nu}}{\partial x^\mu} -
\frac{\partial g_{\mu \nu}}{\partial x^\alpha}\right)
$$

In [1]:
import numpy as np

In [2]:
import jax

In [97]:
class Metric:
    """
    Utility class for storing a 4D spacetime metric g_ij as a 4x4
    matrix of functions of the coordinates (x0, x1, x2, x3), together
    with the first partial derivatives of every component.

    Components that are never set (and their derivatives) are the
    zero function.
    """

    def __init__(self, name):
        self.name = name

        def zero(x0, x1, x2, x3):
            return 0.0

        # components[i][j] holds g_ij
        self.components = [[zero] * 4 for _ in range(4)]
        # derivatives[i][j][k] holds d g_ij / d x^k
        self.derivatives = [[[zero] * 4 for _ in range(4)] for _ in range(4)]

    @staticmethod
    def _check_index(index):
        # Reject out-of-range indices up front; negative indices would
        # otherwise silently wrap around in the Python lists (-1 -> 3).
        if index not in range(4):
            raise IndexError(f"metric index must be in 0..3, got {index!r}")

    def add_component(self, i, j, func, *, symmetric=True):
        """
        Set g_ij to ``func`` and record its four partial derivatives.

        Args:
            i, j: metric indices in range(4).
            func: function of (x0, x1, x2, x3) that takes and returns
                floats. Its derivatives are taken with ``jax.grad``, so it
                must be JAX-traceable (e.g. use ``jax.numpy`` functions).
            symmetric: also set g_ji (a metric tensor is symmetric, and the
                Christoffel formula reads both index orders). Pass False to
                set only the (i, j) entry.

        Raises:
            IndexError: if i or j is not in range(4).
        """
        self._check_index(i)
        self._check_index(j)
        self.components[i][j] = func
        self.derivatives[i][j] = [jax.grad(func, argnums=k) for k in range(4)]
        if symmetric and i != j:
            self.components[j][i] = func
            # Separate list so the two entries can be replaced independently
            self.derivatives[j][i] = list(self.derivatives[i][j])

    def component(self, i, j):
        """Return the function giving the g_ij element of the metric."""
        self._check_index(i)
        self._check_index(j)
        return self.components[i][j]

    def diff(self, i, j, k):
        """Return the function giving d g_ij / d x^k."""
        self._check_index(i)
        self._check_index(j)
        self._check_index(k)
        return self.derivatives[i][j][k]

    def __repr__(self):
        return f"Metric({self.name})"

    def __str__(self):
        return repr(self)

In [98]:
class InverseMetric:
    """
    Utility class for storing a 4D inverse spacetime metric g^ij as a
    4x4 matrix of functions of the coordinates (x0, x1, x2, x3).

    Components that are never set are the zero function (returning
    the float 0.0, like Metric).
    """

    def __init__(self, name):
        self.name = name

        def zero(x0, x1, x2, x3):
            return 0.0

        # components[i][j] holds g^ij
        self.components = [[zero] * 4 for _ in range(4)]

    @staticmethod
    def _check_index(index):
        # Reject out-of-range indices up front; negative indices would
        # otherwise silently wrap around in the Python lists (-1 -> 3).
        if index not in range(4):
            raise IndexError(f"metric index must be in 0..3, got {index!r}")

    def add_component(self, i, j, func, *, symmetric=True):
        """
        Set g^ij to ``func``.

        Args:
            i, j: metric indices in range(4).
            func: function of (x0, x1, x2, x3) that takes and returns floats.
            symmetric: also set g^ji (the inverse of a symmetric metric is
                symmetric). Pass False to set only the (i, j) entry.

        Raises:
            IndexError: if i or j is not in range(4).
        """
        self._check_index(i)
        self._check_index(j)
        self.components[i][j] = func
        if symmetric:
            self.components[j][i] = func

    def component(self, i, j):
        """Return the function giving the g^ij element of the inverse metric."""
        self._check_index(i)
        self._check_index(j)
        return self.components[i][j]

    def __repr__(self):
        return f"InverseMetric({self.name})"

    def __str__(self):
        return repr(self)

In [231]:
class Christoffel:
    """
    Utility class for computing the Christoffel symbols (of the second
    kind) of a spacetime metric:

        Gamma^beta_{mu nu} = 1/2 g^{beta alpha}
            (d_nu g_{mu alpha} + d_mu g_{alpha nu} - d_alpha g_{mu nu})

    The metric must provide ``name``, ``component(i, j)`` and
    ``diff(i, j, k)`` (the derivative of g_ij with respect to x^k); the
    inverse metric must provide ``name`` and ``component(i, j)``.
    """

    def __init__(self, m, inv_m, debug=False):
        """
        Args:
            m: the metric g_ij.
            inv_m: the inverse metric g^ij; must have the same ``name`` as m.
            debug: stored as ``self.debug``; not otherwise used yet.

        Raises:
            ValueError: if the metric and inverse metric names differ.
        """
        if m.name != inv_m.name:
            raise ValueError(
                f"Your metric ({m.name}) and inverse metric "
                f"({inv_m.name}) do not match."
            )
        self.name = m.name + "_christoffels"
        self.metric = m
        self.inverse_metric = inv_m
        self.debug = debug

    def component(self, beta, mu, nu):
        """
        Return Gamma^beta_{mu nu} as a function of (x0, x1, x2, x3).

        The metric functions are looked up each time the returned function
        is called, so it reflects components added after this call.
        """
        inv_m = self.inverse_metric
        m = self.metric

        def term(alpha, x0, x1, x2, x3):
            # One alpha-term of
            # 1/2 g^{beta alpha} (d_nu g_{mu alpha} + d_mu g_{alpha nu} - d_alpha g_{mu nu}),
            # where m.diff(i, j, k) is d g_ij / d x^k.
            return (1 / 2) * inv_m.component(beta, alpha)(x0, x1, x2, x3) * (
                m.diff(mu, alpha, nu)(x0, x1, x2, x3)
                + m.diff(alpha, nu, mu)(x0, x1, x2, x3)
                - m.diff(mu, nu, alpha)(x0, x1, x2, x3)
            )

        def gamma(x0, x1, x2, x3):
            # Einstein summation over the repeated index alpha.
            return sum(term(alpha, x0, x1, x2, x3) for alpha in range(4))

        return gamma

    def matrix(self, x0, x1, x2, x3):
        """
        Evaluate every Christoffel symbol at one point.

        Returns:
            numpy array of shape (4, 4, 4) whose [beta, mu, nu] entry is
            Gamma^beta_{mu nu}(x0, x1, x2, x3).
        """
        gamma = np.zeros((4, 4, 4))
        for beta, mu, nu in np.ndindex(4, 4, 4):
            gamma[beta, mu, nu] = self.component(beta, mu, nu)(x0, x1, x2, x3)
        return gamma

    # TODO: add derivatives of Christoffel symbols later for Ricci/Riemann tensor

## Testing on the Minkowski metric in spherical coordinates

In [80]:
mink = Metric("minkowski_spherical")  # flat metric in coordinates (x0, x1, x2, x3) = (t, r, theta, phi)

In [81]:
mink  # NOTE(review): the output below predates the current Metric.__repr__, which renders "Metric(minkowski_spherical)"

Metric(name=minkowski_spherical)

In [82]:
# g_00 = -1 (time component)
mink.add_component(0, 0, lambda t, r, theta, phi: -1.0)

In [83]:
# g_11 = 1
mink.add_component(1, 1, lambda t, r, theta, phi: 1.0)

In [84]:
# g_22 = r^2
mink.add_component(2, 2, lambda t, r, theta, phi: r ** 2)

In [85]:
# g_33 = r^2 sin^2(theta); uses jax.numpy.sin rather than np.sin so that
# jax.grad (called inside add_component) can trace the function
mink.add_component(3, 3, lambda t, r, theta, phi: r ** 2 * (jax.numpy.sin(theta)) ** 2)

In [86]:
# The inverse metric is entered by hand (reciprocal of each diagonal entry);
# its name must match the metric's name for Christoffel to accept the pair
mink_inv = InverseMetric("minkowski_spherical")

In [91]:
mink_inv  # NOTE(review): the output below predates the current InverseMetric.__repr__, which renders "InverseMetric(minkowski_spherical)"

InverseMetric(name=minkowski_spherical)

In [92]:
# g^00 = -1
mink_inv.add_component(0, 0, lambda t, r, theta, phi: -1.0)

In [93]:
# g^11 = 1
mink_inv.add_component(1, 1, lambda t, r, theta, phi: 1.0)

In [94]:
# g^22 = 1 / r^2
mink_inv.add_component(2, 2, lambda t, r, theta, phi: 1.0 / (r ** 2))

In [95]:
# g^33 = 1 / (r^2 sin^2(theta))
mink_inv.add_component(3, 3, lambda t, r, theta, phi: 1.0 / (r ** 2 * (jax.numpy.sin(theta)) ** 2))

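Test the automatic derivative of $g_{33} = r^2 \sin^2 \theta$ with respect to $r$ at the point $(t, r, \theta, \phi) = (0, 1, \pi/2, 0)$:

$$
\frac{\partial}{\partial r}\left(r^2 \sin^2 \theta\right) = 2r \sin^2 \theta \,\Big|_{(0,\, 1,\, \pi/2,\, 0)} = 2
$$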

In [74]:
t0 = 0.0  # evaluation point (t, r, theta, phi) = (0, 1, pi/2, 0)
r0 = 1.0  # floats throughout, because jax.grad rejects integer-valued arguments
theta0 = np.pi / 2
phi0 = 0.0

In [159]:
# d(g_33)/dr = 2 r sin^2(theta) -> 2 at this point
mink.diff(3, 3, 1)(t0, r0, theta0, phi0) # should be 2

Array(2., dtype=float32, weak_type=True)

In [76]:
# g_00 is constant, so every derivative of it vanishes
mink.diff(0, 0, 1)(t0, r0, theta0, phi0) # should be 0

Array(0., dtype=float32, weak_type=True)

In [158]:
# d(g_33)/dtheta = 2 r^2 sin(theta) cos(theta) -> 0 at theta = pi/2; the
# float32 result is only approximately 0 because cos(pi/2) is inexact in float32
mink.diff(3, 3, 2)(t0, r0, theta0, phi0) # should be 0 (or basically 0)

Array(-8.742278e-08, dtype=float32, weak_type=True)

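Test the Christoffel symbols. For this metric $\Gamma^r_{\theta\theta} = -r$ and $\Gamma^\phi_{r\phi} = 1/r$, so at $r = 1$ we expect $-1$ and $1$: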

In [232]:
mink_ch = Christoffel(mink, mink_inv, debug=True)  # NOTE(review): debug is accepted but has no effect on Christoffel's computations

In [233]:
# Gamma^r_{theta theta} = -r, i.e. -1 at r = 1
mink_ch.component(1, 2, 2)(t0, r0, theta0, phi0) # should be -1

Array(-1., dtype=float32, weak_type=True)

In [237]:
# Gamma^phi_{r phi} = 1/r, i.e. 1 at r = 1
mink_ch.component(3, 1, 3)(t0, r0, theta0, phi0) # should be 1

Array(1., dtype=float32, weak_type=True)