In [3]:
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

#import tinygp
#from tinygp import GaussianProcess
from tinygp import kernels

import json
import pickle

import ultranest

from scripts import numpy_compile, jax_compile

In [6]:
# Network description (weights/architecture) and the auxiliary info
# (data scaling + custom objects such as the inverse PCA).
with open('pitchfork/pitchfork.json', 'r') as fp:
    pitchfork_dict = json.load(fp)

with open('pitchfork/pitchfork_info.json', 'r') as fp:
    pitchfork_info = json.load(fp)

# Two evaluators built from the same network description:
# one for the NumPy path, one for the JAX path.
pitchfork = numpy_compile(pitchfork_dict)
jpitchfork = jax_compile(pitchfork_dict)

pitchfork_cov = np.loadtxt('pitchfork/pitchfork_covariance.txt')

In [7]:
# Inverse-PCA parameters used to map the network's astero head back to full outputs.
inverse_pca = pitchfork_info['custom_objects']['inverse_pca']
pca_comps, pca_mean = (np.array(inverse_pca[key]) for key in ('pca_comps', 'pca_mean'))

In [8]:
# Raw network output for one input vector, fed in directly (no log10/standardisation
# as `predict` applies): [classical head, PCA-coefficient head].
preds = pitchfork.forward_pass(np.full((1, 5), 0.5))
preds

[array([[0.45746291, 0.55022404, 0.37583447]]),
 array([[ 2.42869059e+00, -2.09924402e-02, -2.91607387e-02,
         -6.30696148e-03, -1.37554752e-03,  1.22286922e-03,
          4.84569803e-03,  2.07647610e-03,  2.68109677e-03,
          9.62887053e-04, -6.12263785e-04,  1.05727241e-03,
          1.32305702e-03, -1.92868063e-03, -1.74395222e-03]])]

In [10]:
def predict(x, emulator, n_min=6, n_max=40, jaxxed=False):
    """Evaluate a pitchfork emulator end-to-end and return physical outputs.

    Pipeline: log10 + standardise the inputs with the scaling stored in the
    global ``pitchfork_info`` -> ``emulator.forward_pass`` -> inverse-PCA the
    second network head -> de-standardise -> ``10**`` -> replace column 0 with
    Teff and keep column 2 in log space -> keep the requested astero columns.

    Parameters
    ----------
    x : array_like, shape (n_samples, n_inputs)
        Raw input parameters. Must be strictly positive because ``log10`` is
        applied here; ``n_inputs`` must match ``inp_mean`` in ``pitchfork_info``.
    emulator : object
        Compiled network exposing ``forward_pass(inputs)`` that returns
        ``[classical_head, pca_coefficients]`` (e.g. ``pitchfork`` for the
        NumPy path, ``jpitchfork`` for the JAX path).
    n_min, n_max : int, optional
        Inclusive range of astero outputs to keep. Full-output column ``j`` is
        treated as order ``n = j + 3``, so the kept slice is
        ``[n_min - 3, n_max - 2)``. NOTE(review): ``n_min < 6`` makes this
        slice overlap the three leading (classical) columns.
    jaxxed : bool, optional
        If True, do the post-processing with ``jax.numpy`` instead of ``numpy``.

    Returns
    -------
    array, shape (n_samples, 3 + number of kept astero columns)
        Column 0: Teff [K] from the Stefan-Boltzmann law, treating emulator
        column 0 as radius [R_sun] and column 1 as luminosity [L_sun].
        Column 1: that luminosity column. Column 2: star_feh, kept in dex.
        Remaining columns: the selected astero outputs (presumably mode
        frequencies -- TODO confirm units).
    """
    xp = jnp if jaxxed else np

    ## def constants
    L_sun = 3.828e+26          # nominal solar luminosity [W] (IAU 2015 B3)
    R_sun = 6.957e+8           # nominal solar radius [m] (IAU 2015 B3)
    SB_sigma = 5.670374419e-8  # Stefan-Boltzmann constant [W m^-2 K^-4]

    scaling = pitchfork_info["data_scaling"]
    inverse_pca = pitchfork_info['custom_objects']['inverse_pca']

    log_inputs_mean = xp.array(scaling["inp_mean"][0])
    log_inputs_std = xp.array(scaling["inp_std"][0])
    # Classical and astero scalings are stored separately; list `+` joins them in
    # the same order as the two network heads are concatenated below.
    log_outputs_mean = xp.array(scaling["classical_out_mean"][0] + scaling["astero_out_mean"][0])
    log_outputs_std = xp.array(scaling["classical_out_std"][0] + scaling["astero_out_std"][0])
    pca_comps = xp.array(inverse_pca['pca_comps'])
    pca_mean = xp.array(inverse_pca['pca_mean'])

    standardised_log_inputs = (xp.log10(x) - log_inputs_mean) / log_inputs_std

    # Run the caller-supplied emulator rather than a hard-wired global, so the
    # `emulator` argument actually selects the network that is evaluated.
    preds = emulator.forward_pass(standardised_log_inputs)

    # preds[1] holds PCA coefficients: project them back onto the full astero outputs.
    pca_preds = xp.tensordot(preds[1], pca_comps, 1) + pca_mean

    standardised_log_outputs = xp.concatenate((xp.array(preds[0]), pca_preds), axis=1)
    log_outputs = standardised_log_outputs * log_outputs_std + log_outputs_mean
    outputs = 10**log_outputs

    # L = 4 pi R^2 sigma Teff^4  =>  Teff = (L / (4 pi sigma R^2))**(1/4)
    teff = ((outputs[:, 1] * L_sun) / (4 * xp.pi * SB_sigma * (outputs[:, 0] * R_sun)**2))**0.25

    # Rebuild instead of writing in place so the same code works for mutable NumPy
    # and immutable JAX arrays: Teff replaces column 0, star_feh stays in dex.
    outputs = xp.concatenate(
        (teff[:, None], outputs[:, 1:2], log_outputs[:, 2:3], outputs[:, 3:]), axis=1
    )

    return xp.concatenate((outputs[:, :3], outputs[:, n_min - 3:n_max - 2]), axis=1)
        

In [11]:
predict(np.full((1, 5), 0.5), pitchfork)

array([[2.90103052e+03, 6.20954018e-02, 1.68595571e+00, 7.08014276e+03,
        3.19778967e+03, 1.91677524e+03, 3.10554663e+03, 4.67799255e+03,
        3.58019538e+03, 2.92512615e+03, 4.12115449e+03, 6.41977217e+03,
        7.11880497e+03, 6.16948384e+03, 5.66950715e+03, 6.24773686e+03,
        7.53749071e+03, 8.78368790e+03, 9.43642634e+03, 9.44136905e+03,
        9.04100965e+03, 8.58677296e+03, 8.17628446e+03, 7.67193348e+03,
        7.00492316e+03, 6.56804817e+03, 6.88645134e+03, 8.24118060e+03,
        1.04124600e+04, 1.23062569e+04, 1.27352167e+04, 1.18730838e+04,
        1.08532563e+04, 1.04488193e+04, 1.07413303e+04, 1.14497843e+04,
        1.22779177e+04, 1.30732578e+04]])

## Timing check: NumPy backend

In [12]:
# Benchmark inputs: one row and one million identical rows of the 0.5 test point.
single_point, million_points = (np.full((n_rows, 5), 0.5) for n_rows in (1, 1_000_000))

In [13]:
%%time
wtf_single_pred = predict(single_point, pitchfork)

CPU times: user 973 μs, sys: 0 ns, total: 973 μs
Wall time: 816 μs


In [14]:
%%time
wtf_million_preds = predict(million_points, pitchfork)

CPU times: user 1min 42s, sys: 31.6 s, total: 2min 14s
Wall time: 11 s


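## Timing check: JAX backend (jit)

Note: JAX dispatches work asynchronously, so a fair timing needs `.block_until_ready()` on the result; without it `%%time` only measures how long it takes to enqueue the computation. The first call for each input shape (and each `n_min`/`n_max`) also triggers compilation, which is why a warm-up cell runs before the timed cells.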

In [19]:
def jit_predict(x, emulator, n_min=6, n_max=40):
    """JAX version of the pitchfork prediction pipeline, written to be traced by ``jax.jit``.

    Pipeline: log10 + standardise the inputs with the scaling in the global
    ``pitchfork_info`` -> ``emulator.jit_forward_pass`` -> inverse-PCA the second
    network head -> de-standardise -> ``10**`` -> replace column 0 with Teff and
    keep column 2 in log space -> keep the requested astero columns.

    Parameters
    ----------
    x : jax array, shape (n_samples, n_inputs)
        Raw, strictly positive input parameters (``log10`` is applied here).
    emulator : object
        Compiled JAX network exposing ``jit_forward_pass(inputs)`` that returns
        ``[classical_head, pca_coefficients]`` (e.g. ``jpitchfork``). Must be
        static (hashable) when this function is jitted.
    n_min, n_max : int, optional
        Inclusive range of astero outputs to keep; the kept slice of the full
        output is ``[n_min - 3, n_max - 2)``. They fix the output shape, so they
        must be static under ``jax.jit``. NOTE(review): ``n_min < 6`` overlaps
        the three leading (classical) columns.

    Returns
    -------
    jax array, shape (n_samples, 3 + number of kept astero columns)
        Column 0: Teff [K]. Column 1: luminosity column (treated as [L_sun]).
        Column 2: star_feh in dex. Remaining columns: selected astero outputs.
    """
    ## def constants
    L_sun = 3.828e+26          # nominal solar luminosity [W] (IAU 2015 B3)
    R_sun = 6.957e+8           # nominal solar radius [m] (IAU 2015 B3)
    SB_sigma = 5.670374419e-8  # Stefan-Boltzmann constant [W m^-2 K^-4]

    scaling = pitchfork_info["data_scaling"]
    inverse_pca = pitchfork_info['custom_objects']['inverse_pca']

    log_inputs_mean = jnp.array(scaling["inp_mean"][0])
    log_inputs_std = jnp.array(scaling["inp_std"][0])
    # List `+` joins classical and astero scalings in the same order as the heads below.
    log_outputs_mean = jnp.array(scaling["classical_out_mean"][0] + scaling["astero_out_mean"][0])
    log_outputs_std = jnp.array(scaling["classical_out_std"][0] + scaling["astero_out_std"][0])
    pca_comps = jnp.array(inverse_pca['pca_comps'])
    pca_mean = jnp.array(inverse_pca['pca_mean'])

    standardised_log_inputs = (jnp.log10(x) - log_inputs_mean) / log_inputs_std

    # Use the emulator passed in (static under jit) instead of a hard-wired global.
    preds = emulator.jit_forward_pass(standardised_log_inputs)

    # preds[1] holds PCA coefficients: project them back onto the full astero outputs.
    pca_preds = jnp.tensordot(preds[1], pca_comps, 1) + pca_mean

    standardised_log_outputs = jnp.concatenate((jnp.array(preds[0]), pca_preds), axis=1)
    log_outputs = standardised_log_outputs * log_outputs_std + log_outputs_mean
    outputs = 10**log_outputs

    # L = 4 pi R^2 sigma Teff^4  =>  Teff = (L / (4 pi sigma R^2))**(1/4)
    teff = ((outputs[:, 1] * L_sun) / (4 * jnp.pi * SB_sigma * (outputs[:, 0] * R_sun)**2))**0.25

    # Teff replaces column 0; star_feh stays in dex.
    outputs = jnp.concatenate(
        (teff[:, None], outputs[:, 1:2], log_outputs[:, 2:3], outputs[:, 3:]), axis=1
    )

    return jnp.concatenate((outputs[:, :3], outputs[:, n_min - 3:n_max - 2]), axis=1)


# Give only static_argnames: jax.jit then maps each name to its positional index via
# inspect.signature, so emulator/n_min/n_max are static whether passed positionally or
# by keyword. Supplying static_argnums AND static_argnames disables that inference
# (e.g. jpredict(x, em, 6, 28) would trace n_min/n_max and fail on the slice).
jpredict = jax.jit(jit_predict, static_argnames=('emulator', 'n_min', 'n_max'))

In [20]:
# JAX benchmark inputs: a million copies of the 0.5 test point, and one solar-like point.
jmillion_points = jnp.full((1_000_000, 5), 0.5)
jsingle_point = jnp.array([[1, 0.0154, 0.2766, 1.92, 4.6]])

In [21]:
# Echo the single JAX test point (same values as `solar_inputs` in the solar test).
jsingle_point

Array([[1.    , 0.0154, 0.2766, 1.92  , 4.6   ]], dtype=float64)

In [22]:
# Warm-up: the first call for each input shape (and each static n_min/n_max) traces and
# compiles, so run both benchmark shapes once before timing. block_until_ready() makes
# sure this asynchronously dispatched work has finished before the next cell's timer starts.
j_single_pred = jpredict(jsingle_point, jpitchfork, n_min=6, n_max=28).block_until_ready()
j_million_preds = jpredict(jmillion_points, jpitchfork, n_min=6, n_max=28).block_until_ready()

In [24]:
%%time
j_single_pred = jpredict(jsingle_point, jpitchfork, n_min=6, n_max=28)

CPU times: user 535 μs, sys: 164 μs, total: 699 μs
Wall time: 352 μs


In [25]:
%%time
j_million_preds = jpredict(jmillion_points, jpitchfork, n_min=6, n_max=28)

CPU times: user 1.15 ms, sys: 745 μs, total: 1.9 ms
Wall time: 847 μs


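## Solar test

Sanity check with solar-like inputs: the predicted Teff (first column) should come out close to the nominal solar value of 5772 K.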

In [26]:
# Solar-like inputs; presumably (mass, Z, Y, alpha_MLT, age) -- TODO confirm the
# input ordering against pitchfork_info.
solar_inputs = jnp.array([[1.0, 0.0154, 0.2766, 1.92, 4.6]])
solar_inputs

Array([[1.    , 0.0154, 0.2766, 1.92  , 4.6   ]], dtype=float64)

In [27]:
jit_predict(solar_inputs, jpitchfork, 6, 28)

Array([[5.78467070e+03, 1.11421245e+00, 1.14405595e-02, 9.02033062e+02,
        1.03637525e+03, 1.17022160e+03, 1.30341926e+03, 1.43371597e+03,
        1.56043806e+03, 1.68555188e+03, 1.81102455e+03, 1.93682656e+03,
        2.06201327e+03, 2.18635236e+03, 2.31051803e+03, 2.43528857e+03,
        2.56082442e+03, 2.68677400e+03, 2.81288396e+03, 2.93916225e+03,
        3.06567519e+03, 3.19257919e+03, 3.31996822e+03, 3.44761916e+03,
        3.57521420e+03, 3.70282997e+03]], dtype=float64)

In [28]:
jpredict(solar_inputs, jpitchfork, n_min=6, n_max=40)

Array([[5.78467070e+03, 1.11421245e+00, 1.14405595e-02, 9.02033062e+02,
        1.03637525e+03, 1.17022160e+03, 1.30341926e+03, 1.43371597e+03,
        1.56043806e+03, 1.68555188e+03, 1.81102455e+03, 1.93682656e+03,
        2.06201327e+03, 2.18635236e+03, 2.31051803e+03, 2.43528857e+03,
        2.56082442e+03, 2.68677400e+03, 2.81288396e+03, 2.93916225e+03,
        3.06567519e+03, 3.19257919e+03, 3.31996822e+03, 3.44761916e+03,
        3.57521420e+03, 3.70282997e+03, 3.83077309e+03, 3.95890710e+03,
        4.08657067e+03, 4.21312110e+03, 4.33835058e+03, 4.46238924e+03,
        4.58510386e+03, 4.70585252e+03, 4.82399103e+03, 4.93984082e+03,
        5.05499279e+03, 5.17144098e+03]], dtype=float64)

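## Nice! The jitted `jpredict` matches the un-jitted `jit_predict` on the overlapping columns, and `n_min`/`n_max` select the expected range. Next step: get inference running.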