### Imports

In [None]:
import jax.numpy as jnp
import numpy as np

# load the weight-based error-statistics function from the population-error package
from population_error import error_statistics_from_weights

### Arrays of VT weights and event weights

For each found VT injection $\theta_j$ and each sample $\Lambda_n$ from the hyperposterior, the `vt_weights` array should contain the weights $p(\theta_j | \Lambda_n) / p(\theta_j | {\rm draw})$, where $p(\theta_j | {\rm draw})$ is the distribution the injections were drawn from. Its shape is `vt_weights.shape = (Nsamp, Nfound)`.
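These weights can be built by broadcasting a population model over the hyperposterior samples and the found injections. Below is a minimal sketch, assuming a hypothetical one-parameter population (a truncated power law in a single mass parameter, with the index `alpha` standing in for $\Lambda_n$) and injections drawn uniformly in mass. The names `log_powerlaw_pdf`, `compute_vt_weights`, `M_MIN`, `M_MAX` and the detection cut are illustrative only and are not part of `population_error`.

```python
import jax
import jax.numpy as jnp

M_MIN, M_MAX = 5.0, 100.0  # hypothetical mass range of the toy population

def log_powerlaw_pdf(m, alpha):
    # Log of the normalized p(m | alpha) proportional to m^(-alpha) on [M_MIN, M_MAX].
    # Assumes alpha != 1 (the normalization below is singular there).
    norm = (M_MAX ** (1.0 - alpha) - M_MIN ** (1.0 - alpha)) / (1.0 - alpha)
    return -alpha * jnp.log(m) - jnp.log(norm)

def compute_vt_weights(found_masses, draw_log_pdf, alphas):
    # found_masses, draw_log_pdf: shape (Nfound,); alphas: shape (Nsamp,)
    # Broadcasting gives log p(theta_j | Lambda_n) with shape (Nsamp, Nfound).
    log_pop = log_powerlaw_pdf(found_masses[None, :], alphas[:, None])
    # Divide by the draw density: p(theta_j | Lambda_n) / p(theta_j | draw)
    return jnp.exp(log_pop - draw_log_pdf[None, :])

key = jax.random.PRNGKey(0)
k_inj, k_hyper = jax.random.split(key)

# Toy injection campaign: masses drawn uniformly, so the draw density is 1 / (M_MAX - M_MIN)
injected = jax.random.uniform(k_inj, (10_000,), minval=M_MIN, maxval=M_MAX)
found = injected[injected > 30.0]  # purely illustrative "detection" cut
draw_log_pdf = jnp.full(found.shape, -jnp.log(M_MAX - M_MIN))

# Stand-in hyperposterior samples of the power-law index
alphas = jax.random.uniform(k_hyper, (50,), minval=1.5, maxval=3.5)

vt_weights = compute_vt_weights(found, draw_log_pdf, alphas)
print(vt_weights.shape)  # (Nsamp, Nfound) = (50, number of found injections)
```

Working in log space before exponentiating keeps the ratio numerically stable when the population and draw densities differ by many orders of magnitude.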

Similarly, for the $j^{\rm th}$ posterior sample $\theta_{ij}$ of the $i^{\rm th}$ event, `event_weights` should contain the weights $p(\theta_{ij} | \Lambda_n) / \pi(\theta_{ij} | {\rm PE})$, where $\pi(\theta_{ij} | {\rm PE})$ is the prior used during parameter estimation. Its shape is `event_weights.shape = (Nsamp, Nobs, NPE)`.
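The event weights follow the same pattern with one extra axis. A minimal sketch, again assuming the hypothetical truncated power-law population and, for simplicity, a PE prior that is uniform in mass; the stand-in PE samples are random numbers, and `compute_event_weights` is an illustrative name rather than a `population_error` function.

```python
import jax
import jax.numpy as jnp

M_MIN, M_MAX = 5.0, 100.0  # hypothetical mass range of the toy population

def log_powerlaw_pdf(m, alpha):
    # Log of the normalized p(m | alpha) proportional to m^(-alpha) on [M_MIN, M_MAX]; assumes alpha != 1
    norm = (M_MAX ** (1.0 - alpha) - M_MIN ** (1.0 - alpha)) / (1.0 - alpha)
    return -alpha * jnp.log(m) - jnp.log(norm)

def compute_event_weights(pe_masses, pe_log_prior, alphas):
    # pe_masses, pe_log_prior: shape (Nobs, NPE); alphas: shape (Nsamp,)
    # Broadcast to (Nsamp, Nobs, NPE): one population density per hyper-sample, event and PE sample
    log_pop = log_powerlaw_pdf(pe_masses[None, :, :], alphas[:, None, None])
    # Divide by the PE prior: p(theta_ij | Lambda_n) / pi(theta_ij | PE)
    return jnp.exp(log_pop - pe_log_prior[None, :, :])

key = jax.random.PRNGKey(1)
k_pe, k_hyper = jax.random.split(key)

# Stand-in PE samples for 20 events with 1000 samples each (not real posteriors)
pe_masses = jax.random.uniform(k_pe, (20, 1000), minval=M_MIN, maxval=M_MAX)
pe_log_prior = jnp.full(pe_masses.shape, -jnp.log(M_MAX - M_MIN))  # assumed uniform PE prior
alphas = jax.random.uniform(k_hyper, (50,), minval=1.5, maxval=3.5)

event_weights = compute_event_weights(pe_masses, pe_log_prior, alphas)
print(event_weights.shape)  # (Nsamp, Nobs, NPE) = (50, 20, 1000)
```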

Finally, `total_samples` is the total number of injections, $N_{\rm inj}$, including those that were not found (so in general $N_{\rm inj} > N_{\rm found}$).
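For context, $N_{\rm inj}$ enters because in the standard importance-sampling estimate of the detectable fraction, $\hat\xi(\Lambda_n) = \frac{1}{N_{\rm inj}} \sum_{j=1}^{N_{\rm found}} w_{nj}$, the sum runs over found injections only while the normalization counts every draw (missed injections contribute zero weight). The sketch below computes that estimate and its usual Monte Carlo variance, $\sigma^2 = \frac{1}{N_{\rm inj}^2}\sum_j w_{nj}^2 - \frac{\hat\xi^2}{N_{\rm inj}}$, from a `vt_weights` array. `detection_fraction` is a hypothetical helper, and this is the textbook estimator rather than a description of what `population_error` does internally.

```python
import jax.numpy as jnp

def detection_fraction(vt_weights, total_samples):
    # vt_weights: shape (Nsamp, Nfound); missed injections have zero weight and are simply absent
    # Monte Carlo mean, normalized by ALL injections, not just the found ones
    mu = jnp.sum(vt_weights, axis=-1) / total_samples
    # Variance of that Monte Carlo mean, one value per hyperposterior sample
    var = jnp.sum(vt_weights**2, axis=-1) / total_samples**2 - mu**2 / total_samples
    return mu, var

# Three found injections out of ten drawn in total
mu, var = detection_fraction(jnp.array([[1.0, 2.0, 3.0]]), 10)
print(mu, var)  # mu ≈ [0.6], var ≈ [0.104]
```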

In [2]:
# Random arrays, for illustration only. Dense weight arrays of realistic size often run into memory errors.
vt_weights = jnp.array(np.random.uniform(size=(500, 100_000)))
event_weights = jnp.array(np.random.uniform(size=(500, 100, 1000)))
total_samples = 5e6
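Each of the two arrays above holds $5\times10^7$ entries, which is why memory can become a problem. A quick footprint estimate follows; it counts only the stored arrays and ignores any temporaries the error computation itself allocates. Note that `np.random.uniform` returns float64, while JAX stores arrays as float32 unless 64-bit mode (`jax_enable_x64`) is enabled. `dense_nbytes` is a hypothetical helper.

```python
import numpy as np

def dense_nbytes(shape, dtype):
    # Memory for a dense array: number of elements times bytes per element
    return int(np.prod(shape)) * np.dtype(dtype).itemsize

shapes = {"vt_weights": (500, 100_000), "event_weights": (500, 100, 1000)}
for name, shape in shapes.items():
    mb32 = dense_nbytes(shape, np.float32) / 1e6
    mb64 = dense_nbytes(shape, np.float64) / 1e6
    print(f"{name}: {mb32:.0f} MB as float32, {mb64:.0f} MB as float64")
# vt_weights: 200 MB as float32, 400 MB as float64
# event_weights: 200 MB as float32, 400 MB as float64
```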

### Compute error statistics

In [3]:
statistics = error_statistics_from_weights(vt_weights, event_weights, total_samples, include_likelihood_correction=True)
print(statistics)

Running for 500 iterations: 100%|██████████| 500/500 [00:10<00:00, 45.96it/s]
Running for 500 iterations: 100%|██████████| 500/500 [00:10<00:00, 45.94it/s]


{'error_statistic': 0.14175374994676887, 'precision_statistic': 0.14175374564949028, 'accuracy_statistic': 4.297278581024669e-09}
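In this run, `error_statistic` equals `precision_statistic + accuracy_statistic` to floating-point precision, and the accuracy term (about $4\times10^{-9}$) is negligible next to the precision term (about $0.14$). The check below uses values copied from the printed output; whether this additive decomposition holds in general is an assumption to confirm against the `population_error` documentation.

```python
import math

# Values copied from the printed `statistics` dictionary above
printed = {
    "error_statistic": 0.14175374994676887,
    "precision_statistic": 0.14175374564949028,
    "accuracy_statistic": 4.297278581024669e-09,
}

# Does the total match the sum of its two components?
total = printed["precision_statistic"] + printed["accuracy_statistic"]
print(math.isclose(printed["error_statistic"], total, rel_tol=1e-12))  # True

# Relative size of the accuracy term compared with the precision term
print(printed["accuracy_statistic"] / printed["precision_statistic"])
```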
