# JAXKnife Tutorial

Installation Instructions:
make a new conda environment with python=3.10
pip install jax
pip install jaxlib

In [7]:
# !pip install jax jaxlib

In [6]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import bootstrap as scp_bootstrap

In [29]:
def bootstrap(data,estimator,n_resamples=1000,seed=0):
    """
    Input: 
    data        (array-like)
    estimator   (function)
    n_samples   (int)
    seed        (int)
    """

    key = jax.random.PRNGKey(seed)

    est_arr = np.zeros(n_resamples)
 
    for i in range(n_resamples):
        key, subkey = jax.random.split(key)
        resample = jax.random.choice(subkey,data,(len(data),),replace=True,axis=0)
        
        est_arr[i] = estimator(resample)

    return est_arr

def com1D(data):

    # assume m=1 for all particles

    xs = data[:,0]
    

    com_x = np.sum(xs) / len(xs)
    

    return com_x

np.random.seed(0)
test_data = np.random.uniform(-1,1,size=(100,3))

print(test_data)

com_arr = bootstrap(test_data,com1D,n_resamples=10,seed=0)

print(f'COM median: {np.median(com_arr)}\nCOM stdev: {np.std(com_arr)}')

[[ 0.09762701  0.43037873  0.20552675]
 [ 0.08976637 -0.1526904   0.29178823]
 [-0.12482558  0.783546    0.92732552]
 [-0.23311696  0.58345008  0.05778984]
 [ 0.13608912  0.85119328 -0.85792788]
 [-0.8257414  -0.95956321  0.66523969]
 [ 0.5563135   0.7400243   0.95723668]
 [ 0.59831713 -0.07704128  0.56105835]
 [-0.76345115  0.27984204 -0.71329343]
 [ 0.88933783  0.04369664 -0.17067612]
 [-0.47088878  0.54846738 -0.08769934]
 [ 0.1368679  -0.9624204   0.23527099]
 [ 0.22419145  0.23386799  0.88749616]
 [ 0.3636406  -0.2809842  -0.12593609]
 [ 0.39526239 -0.87954906  0.33353343]
 [ 0.34127574 -0.57923488 -0.7421474 ]
 [-0.3691433  -0.27257846  0.14039354]
 [-0.12279697  0.97674768 -0.79591038]
 [-0.58224649 -0.67738096  0.30621665]
 [-0.49341679 -0.06737845 -0.51114882]
 [-0.68206083 -0.77924972  0.31265918]
 [-0.7236341  -0.60683528 -0.26254966]
 [ 0.64198646 -0.80579745  0.67588981]
 [-0.80780318  0.95291893 -0.0626976 ]
 [ 0.95352218  0.20969104  0.47852716]
 [-0.92162442 -0.43438607