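```python
import os

# XLA environment variables must be set before JAX is imported.
# Site-specific CUDA location (only used when JAX runs on a GPU)
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/cluster/shared/software/libs/cuda/11a"
# Allocate GPU memory on demand instead of preallocating most of it
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax

jax.config.update("jax_platform_name", "cpu")  # force CPU execution
jax.config.update("jax_enable_x64", True)      # enable 64-bit floats

import jax.numpy as jnp
import numpy as np
import spax
```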

# Fisher information
Let $p(\mathbf{d}|\boldsymbol\theta)$, with $\mathbf{d} \in D$, be a probability density function parameterized by $\boldsymbol\theta$. The Fisher information matrix is then defined as
$$\mathcal{F}_{ij} = \operatorname{E}\left[\left.\left(\frac{\partial}{\partial\theta_i} \log p(D|\boldsymbol\theta)\right)\left(\frac{\partial}{\partial\theta_j} \log p(D|\boldsymbol\theta)\right)\right|\boldsymbol\theta\right] = -\operatorname{E}\left[\left.\frac{\partial^2}{\partial\theta_i\, \partial\theta_j} \log p(D|\boldsymbol\theta)\right|\boldsymbol\theta\right]\,.$$
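The two forms of the definition can be checked numerically on a toy model. The sketch below assumes a hypothetical one-parameter model $d \sim \mathcal{N}(\theta, s^2)$ with known $s$, for which $\mathcal{F} = 1/s^2$; it estimates the first form from the squared score and the second from a finite-difference second derivative of $\log p$, both averaged over samples.

```python
import numpy as np

# Hypothetical toy model: d ~ N(theta, s^2) with known s and a single parameter theta.
rng = np.random.default_rng(0)
theta, s = 2.0, 0.5
d = rng.normal(theta, s, size=200_000)

def log_p(d, theta, s):
    """Log-density of N(theta, s^2) evaluated at the samples d."""
    return -0.5 * ((d - theta) / s) ** 2 - np.log(s * np.sqrt(2 * np.pi))

# First form: E[(d log p / d theta)^2], using the analytic score (d - theta) / s^2
score = (d - theta) / s**2
fisher_from_score = np.mean(score**2)

# Second form: -E[d^2 log p / d theta^2], with the second derivative taken
# by central finite differences of log_p
h = 1e-3
second_deriv = (log_p(d, theta + h, s) - 2 * log_p(d, theta, s)
                + log_p(d, theta - h, s)) / h**2
fisher_from_hessian = -np.mean(second_deriv)

print(fisher_from_score, fisher_from_hessian, 1 / s**2)  # both close to 4.0
```

The score-based estimate carries Monte Carlo noise, while the second-derivative form is exact here because $\log p$ is quadratic in $\theta$.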

## Gaussian-distributed data
In the simplest case, the underlying pdf is a multivariate Gaussian, $\log p(\mathbf{d}|\boldsymbol\theta) = -\frac{1}{2}\,(\mathbf{d} - \boldsymbol\mu(\boldsymbol\theta))^\textsf{T} \, \Sigma^{-1} \, (\mathbf{d} - \boldsymbol\mu(\boldsymbol\theta)) + \text{const}$, where we assume that the covariance matrix $\Sigma$ does not depend on $\boldsymbol\theta$.

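Substituting this into the second form of the definition, and using $\operatorname{E}[\mathbf{d} - \boldsymbol\mu] = 0$ to drop the term containing second derivatives of $\boldsymbol\mu$, gives $$ \mathcal{F}_{ij} = \frac{\partial\boldsymbol\mu^\textsf{T}}{\partial\theta_i}\Sigma^{-1}\frac{\partial\boldsymbol\mu}{\partial\theta_j}$$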
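As a small worked example of this formula, assume a hypothetical straight-line model $\mu(\boldsymbol\theta) = \theta_0 + \theta_1 x$ evaluated at $x = 0, 1, 2$, so each column of the Jacobian is $\partial\boldsymbol\mu/\partial\theta_i$, together with an assumed diagonal covariance. The sketch uses `np.linalg.solve` rather than an explicit inverse.

```python
import numpy as np

# Hypothetical model mu(theta) = theta_0 + theta_1 * x at x = 0, 1, 2,
# so J[:, i] = d mu / d theta_i does not depend on theta.
x = np.array([0.0, 1.0, 2.0])
J = np.column_stack([np.ones_like(x), x])   # n x p Jacobian, here 3 x 2
Sigma = np.diag([1.0, 2.0, 4.0])            # hypothetical data covariance

# F_ij = (d mu / d theta_i)^T Sigma^{-1} (d mu / d theta_j);
# solve() computes Sigma^{-1} J without forming the inverse explicitly
F = J.T @ np.linalg.solve(Sigma, J)
print(F)  # F == [[1.75, 1.0], [1.0, 1.5]]
```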
If $X$ is an $n \times m$ data matrix, where $n$ is the dimensionality of the data and $m$ the number of samples, the sample covariance matrix is simply $$\Sigma = \frac{1}{m-1} X \, X^T \, ,$$
assuming the mean of each row has been removed, i.e. `X -= X.mean(axis=1, keepdims=True)`.
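A minimal NumPy sketch of this estimator, on hypothetical data with a nonzero mean, cross-checked against `np.cov` (which also treats rows as variables and divides by $m-1$). Note that `keepdims=True` is what lets the row means broadcast against the $n \times m$ matrix.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 4, 50_000                                   # dimensionality n, sample count m
stds = np.array([1.0, 2.0, 3.0, 4.0])
X = 3.0 + stds[:, None] * rng.normal(size=(n, m))  # hypothetical n x m data, nonzero mean

# Remove the per-dimension mean; keepdims=True keeps shape (n, 1) for broadcasting
X_c = X - X.mean(axis=1, keepdims=True)
Sigma = X_c @ X_c.T / (m - 1)

# np.cov treats rows as variables and uses the same 1/(m-1) normalization
assert np.allclose(Sigma, np.cov(X))
print(np.diag(Sigma))  # close to stds**2 = [1, 4, 9, 16]
```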

The derivatives of the mean can be estimated with finite differences: if $X_i^{+}$ and $X_i^{-}$ are $n \times m'$ matrices of data points generated at parameter values that differ by $\Delta\theta_i$ in the $i$-th parameter, then
$$\frac{\partial\boldsymbol\mu}{\partial\theta_i} \approx \frac{1}{m'}\sum_k \frac{\mathbf{d}_{ki}^{+} - \mathbf{d}_{ki}^{-}}{\Delta\theta_i} \, ,$$
where $\mathbf{d}_{ki}^{\pm}$ is the $k$-th column of $X_i^{\pm}$.
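The finite-difference estimate can be sketched as follows. The simulator `simulate` and the straight-line model are hypothetical; the pairs $X_i^{\pm}$ are generated at $\theta_i \pm \Delta\theta_i/2$ and, as an assumption that reduces noise (common random numbers), share the same random seed so the noise cancels in the difference. Without shared noise the same formula still applies, but the estimate converges more slowly in $m'$.

```python
import numpy as np

x = np.array([0.0, 1.0, 2.0])  # hypothetical model: mu(theta) = theta_0 + theta_1 * x

def simulate(theta, m_prime, rng):
    """Hypothetical simulator: n x m' matrix of noisy draws around mu(theta)."""
    mu = theta[0] + theta[1] * x
    return mu[:, None] + rng.normal(0.0, 1.0, size=(x.size, m_prime))

theta = np.array([1.0, 0.5])
dtheta = np.array([0.1, 0.1])
m_prime = 1000

J = np.empty((x.size, theta.size))
for i in range(theta.size):
    step = np.zeros_like(theta)
    step[i] = dtheta[i] / 2
    # Same seed for X_i^+ and X_i^-, so identical noise cancels in the difference
    X_plus = simulate(theta + step, m_prime, np.random.default_rng(i))
    X_minus = simulate(theta - step, m_prime, np.random.default_rng(i))
    # (1/m') * sum_k (d_ki^+ - d_ki^-) / dtheta_i
    J[:, i] = np.mean(X_plus - X_minus, axis=1) / dtheta[i]

print(J)  # close to [[1, 0], [1, 1], [1, 2]]
```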

## Need for compression
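When $n \ge m$, the mean-subtracted sample covariance matrix has rank at most $m-1 < n$ and is therefore not invertible, so the data must be compressed first. Here we implement a simple PCA compression of order $N$: $\widetilde{X}_N = U_N^T \, X$, where the tilde denotes the compressed space and $U_N$ is the $n \times N$ matrix whose orthonormal columns are the eigenvectors of $\Sigma$ with the $N$ largest eigenvalues (the first $N$ principal directions).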

The covariance matrix of the compressed data is then
$$ \widetilde{\Sigma}_N = \frac{1}{m-1} \widetilde{X}_N \, \widetilde{X}_N^T = \frac{1}{m-1} U_N^T \, X \, X^T U_N = \sigma_N^2 \, ,$$
where $\sigma_N^2$ is the diagonal matrix of the $N$ largest eigenvalues of $\Sigma$, i.e. the variances along the first $N$ principal components.
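A minimal sketch of the compression, assuming $U_N$ is obtained from an eigendecomposition of the sample covariance (`np.linalg.eigh`, re-sorted to descending eigenvalues). It checks that the compressed covariance is diagonal with the $N$ largest eigenvalues on the diagonal; the data are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(2)
n, m, N = 6, 20_000, 3
X = rng.normal(size=(n, m)) * np.arange(1, n + 1)[:, None]  # hypothetical data
X = X - X.mean(axis=1, keepdims=True)
Sigma = X @ X.T / (m - 1)

# eigh returns ascending eigenvalues of a symmetric matrix; reverse the order
# so that principal components are sorted by decreasing variance
eigvals, eigvecs = np.linalg.eigh(Sigma)
order = np.argsort(eigvals)[::-1]
sigma2 = eigvals[order]
U = eigvecs[:, order]

U_N = U[:, :N]               # n x N matrix with orthonormal columns
X_N = U_N.T @ X              # compressed data, N x m
Sigma_tilde = X_N @ X_N.T / (m - 1)

assert np.allclose(Sigma_tilde, np.diag(sigma2[:N]))
```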

Denoting the Jacobian of the mean by $J \equiv \partial\boldsymbol\mu / \partial\boldsymbol\theta$, linearity of the compression gives $\widetilde{J}_N = U_N^T J$. Therefore, starting from $F = J^T \Sigma^{-1} J$, after PCA compression one has
$$ F_N = J^T \, U_N  \, \sigma_N^{-2} \, U_N^T J  \, .$$
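The compressed Fisher matrix can be sketched with a hypothetical helper `fisher_pca`, under the assumption that $U_N$ and $\sigma_N^2$ come from `np.linalg.eigh`. Each added component contributes a positive semi-definite term, so the diagonal of $F_N$ cannot decrease with $N$, and at $N = n$ the full $F = J^T \Sigma^{-1} J$ is recovered. The covariance and Jacobian below are random placeholders.

```python
import numpy as np

def fisher_pca(J, Sigma, N):
    """Hypothetical helper: F_N = J^T U_N sigma_N^{-2} U_N^T J with N leading PCs."""
    eigvals, eigvecs = np.linalg.eigh(Sigma)
    order = np.argsort(eigvals)[::-1][:N]
    U_N, s2_N = eigvecs[:, order], eigvals[order]
    J_tilde = U_N.T @ J                            # compressed Jacobian, N x p
    return J_tilde.T @ (J_tilde / s2_N[:, None])   # J~^T sigma_N^{-2} J~

rng = np.random.default_rng(3)
n, p = 5, 2
A = rng.normal(size=(n, n))
Sigma = A @ A.T + n * np.eye(n)   # hypothetical positive-definite covariance
J = rng.normal(size=(n, p))       # hypothetical Jacobian d mu / d theta

F_full = J.T @ np.linalg.solve(Sigma, J)
for N in range(1, n + 1):
    print(N, np.diag(fisher_pca(J, Sigma, N)))

assert np.allclose(fisher_pca(J, Sigma, n), F_full)
```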

### Extras
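Projecting the compressed data back onto the original $n$-dimensional space gives
$$ X_N = U_N \, U_N^T \, X \, ,$$
whose covariance is
$$ \Sigma_N \equiv \frac{1}{m-1} \, X_N \, X_N^T = \frac{1}{m-1} \, U_N \, U_N^T \, X \, X^T \, U_N \, U_N^T = U_N \, \sigma_N^2 \, U_N^T \, .$$
Since $\Sigma_N$ has rank $N$, it is singular for $N < n$; in place of its inverse we define
$$ \Sigma_N^{-1} \equiv U_N \, \sigma_N^{-2} \, U_N^T \, .$$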
Using this and $U_N^T \, U_N = I$, one can verify that
$$ \Sigma_N^{-1} \Sigma_N \Sigma_N^{-1} = \Sigma_N^{-1} \, ,$$
$$ \Sigma_N \Sigma_N^{-1} \Sigma_N = \Sigma_N  \, ,$$
and, since $\Sigma_N \Sigma_N^{-1} = \Sigma_N^{-1} \Sigma_N = U_N U_N^T$ is symmetric, $\Sigma_N^{-1}$ is the Moore-Penrose pseudoinverse of $\Sigma_N$. Moreover,
$$ F_N = J^T \, \Sigma_N^{-1} \, J \, .$$
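The four Moore-Penrose conditions, and the equivalence of the two expressions for $F_N$, can be checked numerically. This sketch uses hypothetical data and a random placeholder Jacobian; it tests the conditions directly instead of comparing against `np.linalg.pinv`, whose result for a numerically rank-deficient matrix depends on its cutoff tolerance.

```python
import numpy as np

rng = np.random.default_rng(4)
n, m, N = 6, 1000, 3
X = rng.normal(size=(n, m)) * np.arange(1, n + 1)[:, None]  # hypothetical data
X = X - X.mean(axis=1, keepdims=True)
Sigma = X @ X.T / (m - 1)

eigvals, eigvecs = np.linalg.eigh(Sigma)
order = np.argsort(eigvals)[::-1][:N]
U_N, s2_N = eigvecs[:, order], eigvals[order]

Sigma_N = U_N @ np.diag(s2_N) @ U_N.T          # rank-N covariance U_N sigma_N^2 U_N^T
Sigma_N_inv = U_N @ np.diag(1 / s2_N) @ U_N.T  # U_N sigma_N^{-2} U_N^T

# The four Moore-Penrose conditions
assert np.allclose(Sigma_N @ Sigma_N_inv @ Sigma_N, Sigma_N)
assert np.allclose(Sigma_N_inv @ Sigma_N @ Sigma_N_inv, Sigma_N_inv)
assert np.allclose((Sigma_N @ Sigma_N_inv).T, Sigma_N @ Sigma_N_inv)
assert np.allclose((Sigma_N_inv @ Sigma_N).T, Sigma_N_inv @ Sigma_N)

# F_N computed in the compressed space equals J^T Sigma_N^{-1} J
J = rng.normal(size=(n, 2))                    # hypothetical Jacobian
J_tilde = U_N.T @ J
F_N_compressed = J_tilde.T @ np.diag(1 / s2_N) @ J_tilde
F_N_full_space = J.T @ Sigma_N_inv @ J
assert np.allclose(F_N_compressed, F_N_full_space)
```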

## Testing the code

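```python
# Construct the data: a 16-dimensional zero-mean Gaussian whose
# standard deviations (the parameters) are 1, 2, ..., 16.
# Note: the mean of this data is zero for any sigmas, so d mu / d theta = 0
# and the mean-based Fisher matrix above is expected to be (close to) zero.
N_dim, N_samples = 16, 10000
random_sample = np.random.normal(0, 1, size=(N_dim, N_samples))
sigmas = np.arange(1, N_dim + 1)[:, np.newaxis]
data = random_sample * sigmas

# Finite-difference step for each parameter
δθ = sigmas * 0.01
# derivative[i, 0] / derivative[i, 1]: data at sigmas -/+ δθ/2 for parameter slot i
# (as written, every parameter slot shifts all sigmas at once)
derivative = np.empty((N_dim, 2, N_dim, N_samples))
derivative[:, 0, ...] = random_sample * (sigmas - δθ / 2)
derivative[:, 1, ...] = random_sample * (sigmas + δθ / 2)

Fisher = spax.Fisher()
Fisher.fit(data, derivative, δθ, batch_size=4)
# Fisher information for PCA compression orders N = 1, ..., N_dim - 1
F = [Fisher.compute(N=n) for n in range(1, N_dim)]
print(F)
```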

[DeviceArray(-0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(-0., dtype=float32), DeviceArray(-0., dtype=float32), DeviceArray(-0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(-0., dtype=float32), DeviceArray(-0., dtype=float32), DeviceArray(-0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(-0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(-0., dtype=float32), DeviceArray(-0., dtype=float32)]
