In [28]:
import os
import sys
import importlib

sys.path.append('../')

import utils
importlib.reload(utils)

import keras_core as keras
import jax
import jax.numpy as jnp
from jax import random, vmap, jit, grad
#assert jax.default_backend() == 'gpu'
import numpy as np
import pandas as pd
import time
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt
import tensorflow as tf
#import elegy # pip install elegy. # Trying to do this with keras core instead.
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

import preprocess.preprocess
from preprocess.preprocess import INPUTS, PARAMETERS, PARAMETERS_SPECIFIED, RIGIDITY_VALS
from preprocess.preprocess import transform_input, untransform_input
from preprocess.preprocess import PARAMETERS_MIN, PARAMETERS_MAX
from chi2 import CalculateChi2

# Check the new yearly (2024) data load

In [134]:
# Notebook-local version of the run index, with data paths relative to this notebook (see `data_dir`).
def index_mcmc_runs(file_version, data_dir='../../data'):
    """Make a list of combinations for which we want to run MCMC.

    Parameters
    ----------
    file_version : str
        Data release to index: '2023' (one heliosphere file per experiment)
        or '2024' (a single yearly heliosphere summary file).
    data_dir : str, optional
        Root of the data tree. Defaults to '../../data', the path that was
        previously hard-coded here.

    Returns
    -------
    pandas.DataFrame
        The table produced by the ``utils`` reader(s), with two added columns:
        ``experiment_name`` and ``filename_heliosphere``. For '2023' the
        per-experiment tables are stacked and re-indexed 0..N-1, so index
        labels coincide with the positions used by ``df.iloc[...]``.

    Raises
    ------
    ValueError
        If ``file_version`` is neither '2023' nor '2024'.
    """
    if file_version == '2023':
        experiments = ['AMS02_H-PRL2021', 'PAMELA_H-ApJ2013', 'PAMELA_H-ApJL2018']
        dfs = []
        for experiment_name in experiments:
            filename = f'{data_dir}/2023/{experiment_name}_heliosphere.dat'
            experiment_df = utils.index_experiment_files(filename)
            experiment_df['experiment_name'] = experiment_name
            experiment_df['filename_heliosphere'] = filename
            dfs.append(experiment_df)
        # ignore_index=True gives the combined table a unique 0..N-1 index instead
        # of repeating each experiment's own labels (passing 0 means False).
        df = pd.concat(dfs, axis=0, ignore_index=True)

    elif file_version == '2024':
        filename = f'{data_dir}/2024/yearly_heliosphere.dat'
        df = utils.read_experiment_summary(filename)
        df['experiment_name'] = 'yearly'
        df['filename_heliosphere'] = filename

    else:
        raise ValueError(f"Unknown file_version {file_version}. Must be '2023' or '2024'.")

    return df

In [135]:
# Stand-ins for the SLURM array variables, so a single run can be inspected interactively.
SLURM_ARRAY_TASK_ID = 12
SLURM_ARRAY_JOB_ID = 0
DEBUG = True

# Version specifications: model v2.0 = MSE-trained NN, v3.0 = MAE-trained NN.
model_version, hmc_version, file_version = 'v3.0', 'v25.0', '2024'

# Index every run ('2023': 0-209, '2024': 0-14) and keep the row for this task.
df = index_mcmc_runs(file_version=file_version)
print(f'Found {df.shape[0]} combinations to run MCMC on. Performing MCMC on index {SLURM_ARRAY_TASK_ID}.')
df = df.iloc[SLURM_ARRAY_TASK_ID]

# HMC settings.
seed = SLURM_ARRAY_JOB_ID + SLURM_ARRAY_TASK_ID
penalty = 1e6
integrate = False         # False: interpolate the model at bin midpoints; True: integrate it over each bin.
par_equals_perr = False   # True: sample only 3 parameters and tie pwr1perr=pwr1par, pwr2perr=pwr2par.
constant_vspoles = False  # True: fix vspoles to 400.0; False: take vspoles from the data file.
specified_parameters = utils.get_parameters(df.filename_heliosphere, df.interval, constant_vspoles=constant_vspoles)

# How many parameters the HMC samples (3 when par == perr, otherwise 5).
num_params = 3 if par_equals_perr else 5

# Observation file for the selected run.
if file_version == '2024':
    # NOTE(review): assumes task index i maps to year 2000 + i and that only negative
    # intervals are present. If otherwise, fix this.
    year = 2000 + SLURM_ARRAY_TASK_ID
    data_path = f'../../data/2024/yearly/{year}.dat'
elif file_version == '2023':
    # This data is the same.
    data_path = f'../../data/oct2022/{df.experiment_name}/{df.experiment_name}_{df.interval}.dat'
else:
    raise ValueError(f"Invalid file_version {file_version}. Must be '2023' or '2024'.")

model_path = f'../../models/model_{model_version}_{df.polarity}.keras'

# Log-probability function for the HMC, built from the model, data and fixed parameters.
target_log_prob = utils.define_log_prob(
    model_path,
    data_path,
    specified_parameters,
    penalty=penalty,
    integrate=integrate,
    par_equals_perr=par_equals_perr,
)

Found 15 combinations to run MCMC on. Performing MCMC on index 12.


In [136]:
# Load the observed spectrum for the selected run: bin edges, flux and uncertainty.
bins, observed, uncertainty = utils.load_data_ams(data_path)

# Show bin edges and observed flux, one per line.
print(f'Bins: {bins}', f'Observed: {observed}', sep='\n')

Bins: [  0.28   0.4    0.43   0.46   0.47   0.49   0.51   0.55   0.57   0.59
   0.63   0.66   0.68   0.71   0.75   0.76   0.79   0.82   0.83   0.87
   0.9    0.9    0.94   0.98   1.     1.     1.04   1.06   1.09   1.14
   1.16   1.16   1.19   1.24   1.26   1.31   1.33   1.37   1.38   1.43
   1.5    1.51   1.51   1.57   1.65   1.66   1.71   1.72   1.8    1.83
   1.89   1.92   1.98   2.02   2.08   2.15   2.16   2.23   2.27   2.38
   2.4    2.49   2.6    2.67   2.73   2.86   2.97   3.     3.13   3.29
   3.29   3.43   3.6    3.64   3.76   3.95   4.02   4.13   4.33   4.43
   4.53   4.88   4.97   5.37   5.45   5.9    5.98   6.47   6.55   7.09
   7.18   7.76   7.87   8.48   8.63   9.26   9.46  10.1   10.38  11.
  11.38  12.47  13.    13.68  15.    16.45  16.6   18.03  19.78  21.68
  22.8   23.77  26.07  28.58  31.34  33.5   34.37  37.68  41.32  45.3
  48.5   69.7  100.  ]
Observed: [5.737101e+01 2.579455e+02 2.868582e+02 3.134762e+02 3.262379e+02
 3.710233e+02 4.302186e+02 4.687991e+02 4.8867

## Check that the recomputed chi2 matches the stored HMC results

In [137]:
# HMC outputs for this version live in results/<hmc_version>/ (trailing slash kept for path concatenation).
results_dir = f'../../../results/{hmc_version}/'

# Read the header-less per-sample log-probabilities for this task,
# logprobs_{SLURM_ARRAY_TASK_ID}_yearly_{year}_neg.csv, and find the largest value and its row.
# max() / idxmax() return per-column Series, hence the [0] lookups.
logprobs_path = results_dir + f'logprobs_{SLURM_ARRAY_TASK_ID}_yearly_{year}_neg.csv'
logprobs = pd.read_csv(logprobs_path, header=None)
max_logprob, max_logprob_idx = logprobs.max(), logprobs.idxmax()

print(f'Max logprob: {max_logprob[0]} at index {max_logprob_idx[0]}')

Max logprob: -1557.79248046875 at index 95772


In [138]:
# Load the matching parameter samples (samples_{SLURM_ARRAY_TASK_ID}_yearly_{year}_neg.csv)
# and pull out the row at the max-logprob index.
samples_path = results_dir + f'samples_{SLURM_ARRAY_TASK_ID}_yearly_{year}_neg.csv'
samples = pd.read_csv(samples_path, header=None)

# read_csv without index_col gives a default RangeIndex, so the idxmax label is also the row position.
max_logprob_sample = samples.iloc[max_logprob_idx[0]]

print(f'Max logprob sample: {max_logprob_sample}')

Max logprob sample: 0    444.447693
1      1.575198
2      1.180917
3      0.400036
4      0.400027
Name: 95772, dtype: float64


In [139]:
# Gather the five sampled parameters of the best sample into a float array.
# NOTE(review): hard-codes 5 columns; with par_equals_perr=True (num_params == 3) the samples
# file presumably has only 3 columns and this lookup would fail. TODO confirm.
xs = np.array([max_logprob_sample[col] for col in range(5)])
print(f'xs: {xs}')

# Rescale the parameters into the model's 0-1 input range.
xs = transform_input(xs)
print(f'Transformed xs: {xs}')

xs: [4.44447693e+02 1.57519841e+00 1.18091655e+00 4.00035560e-01
 4.00027394e-01]
Transformed xs: [4.47334666e-01 9.03998778e-01 6.00705037e-01 1.87158585e-05
 1.44180499e-05]


### Break apart the `define_log_prob` function and check each piece

In [140]:
# Load the trained NN that maps 8 parameters to predicted flux at RIGIDITY_VALS.
model = keras.models.load_model(model_path)
# NOTE(review): this flag was required under elegy for ppmodel; confirm it is still needed with keras_core.
model.run_eagerly = True

# Reload the observation data from Claudio. Bin midpoints use the geometric mean of the edges,
# which worked better in experiments than the arithmetic mean.
bins, observed, uncertainty = utils.load_data_ams(data_path)
bin_midpoints = (bins[:-1] * bins[1:]) ** 0.5

# Map the fixed (specified) parameters into the model's 0-1 input range and show both versions.
parameters_specified_transformed = transform_input(jnp.array(specified_parameters))
print(f'specified_parameters: {specified_parameters}')
print(f'parameters_specified_transformed: {parameters_specified_transformed}')

specified_parameters: (67.47, 5.47, 542.65)
parameters_specified_transformed: [0.7937647  0.4242857  0.47550005]


In [141]:
# When par == perr the sampler yields only ['cpa', 'pwr1par', 'pwr2par']; expand to the
# model's ['cpa', 'pwr1par', 'pwr1perr', 'pwr2par', 'pwr2perr'] layout by repeating entries.
if par_equals_perr:
    xs = jnp.array([xs[k] for k in (0, 1, 1, 2, 2)])

# Box prior folded into the log-likelihood: each transformed parameter pays `penalty` times its
# distance outside [0, 1]. This keeps HMC from wandering into no-man's land.
nlogprior = 0.
for i in range(5):
    below = jnp.abs(jnp.minimum(0., xs[i]))       # non-zero only when xs[i] < 0
    above = jnp.abs(jnp.maximum(1., xs[i]) - 1.)  # non-zero only when xs[i] > 1
    nlogprior += penalty * below
    nlogprior += penalty * above

# log_prob = -chi2/2 - nlogprior  =>  chi2 = -2 * (log_prob + nlogprior)
chi2_max_logprob = -2. * (max_logprob[0] + nlogprior)
print(f'Max logprob Chi2: {chi2_max_logprob}')

Max logprob Chi2: 3115.5849609375


In [142]:
# Recompute the model prediction and chi2 for the best HMC sample, step by step.

# Batch the sampled + specified parameters and run the NN.
batch = utils._form_batch(xs, parameters_specified_transformed)
# np.array works on whatever tensor type the keras_core backend returns; `.numpy()` is a
# TensorFlow-tensor method and breaks on other backends. np.array also yields a writable copy.
yhat = np.array(model(batch))[0, :]  # Remove batch dimension.

# Undo output scaling and minmax, then move flux and rigidity to log space for CalculateChi2.
yhat = utils.untransform_output(yhat.reshape((1, -1))).reshape(-1)
log_yhat = jnp.log(yhat)
log_rigidity = jnp.log(jnp.array(RIGIDITY_VALS))

# Evaluate the model per observation bin. Either average its integral over the bin, or
# interpolate it at the bin midpoint. The calculator gets its own name so that `chi2`
# below always refers to the per-bin contributions.
chi2_calculator = CalculateChi2(log_rigidity, log_yhat)
if integrate:
    yhat_interp_integrated = [
        chi2_calculator.compute_integral(x1, x2) / (x2 - x1)
        for x1, x2 in zip(bins[:-1], bins[1:])
    ]
else:
    yhat_interp_integrated = [chi2_calculator.interpolate_model(x) for x in bin_midpoints]

# Per-bin chi2 contributions (the next cell reads `chi2`), their total, and the log prob.
chi2 = ((jnp.asarray(yhat_interp_integrated) - observed) / uncertainty) ** 2
cum_chi2 = jnp.sum(chi2)
log_prob = -cum_chi2 / 2. - nlogprior

In [143]:
# Summarize the best HMC sample in physical units and compare the stored logprob/chi2
# with the values recomputed in the previous cell.
xs_untransformed = untransform_input(xs)
print(
    f"Max logprob model index (0-based): {max_logprob_idx[0]}; "
    f"logprob = {max_logprob[0]:.3f}, chi2 = {chi2_max_logprob:.3f}; "
    f"cpa = {xs_untransformed[0]:.2f}, pwr1par = {xs_untransformed[1]:.2f}, "
    f"pwr1perr = {xs_untransformed[2]:.2f}, pwr2par = {xs_untransformed[3]:.2f}, "
    f"pwr2perr = {xs_untransformed[4]:.2f}"
)
print(f"Calculated chi2 and logprob from sample parameters: logprob = {log_prob:.3f}; chi2 = {cum_chi2:.3f}\n")

# Per-bin breakdown: midpoint rigidity and bin edges, data, uncertainty, model value and running chi2.
for i in range(len(yhat_interp_integrated)):
    fields = (
        f'rig = {bin_midpoints[i]:.6f} [{bins[i]:.2f}, {bins[i+1]:.2f}]; data = {observed[i]:.6f}',
        f'; unc = {uncertainty[i]:.6f}; mod: {yhat_interp_integrated[i]:.6f}',
        f'; cum_chi2 = {(chi2[:i+1].sum()):.6f}',
    )
    print(*fields)

Max logprob model index (0-based): 95772; logprob = -1557.792, chi2 = 3115.585; cpa = 444.45, pwr1par = 1.58, pwr1perr = 1.18, pwr2par = 0.40, pwr2perr = 0.40
Calculated chi2 and logprob from sample parameters: logprob = -1557.794; chi2 = 3115.588

rig = 0.334664 [0.28, 0.40]; data = 57.371010 ; unc = 0.591263; mod: 63.503342 ; cum_chi2 = 107.569572
rig = 0.414729 [0.40, 0.43]; data = 257.945500 ; unc = 27.355140; mod: 113.547440 ; cum_chi2 = 135.433655
rig = 0.444747 [0.43, 0.46]; data = 286.858200 ; unc = 30.813750; mod: 135.260269 ; cum_chi2 = 159.638229
rig = 0.464973 [0.46, 0.47]; data = 313.476200 ; unc = 37.132920; mod: 151.193771 ; cum_chi2 = 178.737854
rig = 0.479896 [0.47, 0.49]; data = 326.237900 ; unc = 34.221780; mod: 163.638718 ; cum_chi2 = 201.313049
rig = 0.499900 [0.49, 0.51]; data = 371.023300 ; unc = 34.646030; mod: 179.258759 ; cum_chi2 = 231.948868
rig = 0.529623 [0.51, 0.55]; data = 430.218600 ; unc = 40.724330; mod: 203.897797 ; cum_chi2 = 262.833405
rig = 0.5599

# Check older data

In [51]:
# Stand-ins for the SLURM array variables, pointed at the first run of the older ('2023') release.
SLURM_ARRAY_TASK_ID = 0
SLURM_ARRAY_JOB_ID = 0
DEBUG = True

# Version specifications: model v2.0 = MSE-trained NN, v3.0 = MAE-trained NN.
model_version, hmc_version, file_version = 'v3.0', 'v24.3', '2023'

# Index every run ('2023': 0-209, '2024': 0-14) and keep the row for this task.
df = index_mcmc_runs(file_version=file_version)
print(f'Found {df.shape[0]} combinations to run MCMC on. Performing MCMC on index {SLURM_ARRAY_TASK_ID}.')
df = df.iloc[SLURM_ARRAY_TASK_ID]

# Observation file for the selected run.
if file_version == '2024':
    # NOTE(review): assumes task index i maps to year 2000 + i and that only negative
    # intervals are present. If otherwise, fix this.
    year = 2000 + SLURM_ARRAY_TASK_ID
    data_path = f'../../data/2024/yearly/{year}.dat'
elif file_version == '2023':
    # This data is the same.
    data_path = f'../../data/oct2022/{df.experiment_name}/{df.experiment_name}_{df.interval}.dat'
else:
    raise ValueError(f"Invalid file_version {file_version}. Must be '2023' or '2024'.")

model_path = f'../../models/model_{model_version}_{df.polarity}.keras'

# Display the first fields of the selected run.
df.head()

Found 133 combinations to run MCMC on. Performing MCMC on index 0.


interval     20110520-20110610
alpha                    51.49
cmf                       4.85
vspoles                 632.52
alpha_std                10.69
Name: 0, dtype: object

In [52]:
# Load the observations for the selected 2023 run and show edges, flux and uncertainty.
bins, observed, uncertainty = utils.load_data_ams(data_path)
print(f"Loaded dataset from {data_path}.")
print(f'Bins: {bins}', f'Observed: {observed}', f'Uncertainty: {uncertainty}', sep='\n')

Loaded dataset from ../../data/oct2022/AMS02_H-PRL2021/AMS02_H-PRL2021_20110520-20110610.dat.
Bins: [  1.     1.16   1.33   1.51   1.71   1.92   2.15   2.4    2.67   2.97
   3.29   3.64   4.02   4.43   4.88   5.37   5.9    6.47   7.09   7.76
   8.48   9.26  10.1   11.    13.    16.6   22.8   33.5   48.5   69.7
 100.  ]
Observed: [9.542576e+02 9.411921e+02 8.769211e+02 8.003528e+02 7.088688e+02
 6.185521e+02 5.311145e+02 4.512981e+02 3.792883e+02 3.166651e+02
 2.633877e+02 2.184992e+02 1.800327e+02 1.470573e+02 1.199892e+02
 9.722954e+01 7.856188e+01 6.343624e+01 5.116549e+01 4.125726e+01
 3.316556e+01 2.672519e+01 2.151074e+01 1.559146e+01 9.113706e+00
 4.317174e+00 1.667450e+00 5.843280e-01 2.086225e-01 7.601720e-02]
Uncertainty: [2.811600e+01 2.148905e+01 1.643800e+01 1.293201e+01 1.028839e+01
 8.293686e+00 6.719937e+00 5.448942e+00 4.419263e+00 3.583056e+00
 2.932180e+00 2.403835e+00 1.970323e+00 1.589138e+00 1.290966e+00
 1.053403e+00 8.565820e-01 6.939785e-01 5.606656e-01 4.578981