In [60]:
import os
import sys
import importlib

import utils
importlib.reload(utils)

import keras_core as kerasjk
import jax
import jax.numpy as jnp
from jax import random, vmap, jit, grad
#assert jax.default_backend() == 'gpu'
import numpy as np
import pandas as pd
import time
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt
import tensorflow as tf
#import elegy # pip install elegy. # Trying to do this with keras core instead.
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

# Check new yearly data load

In [61]:
SLURM_ARRAY_TASK_ID = 9
SLURM_ARRAY_JOB_ID = 0
DEBUG = True

# Version specifications
model_version = 'v3.0' # v2.0 is MSE NN, v3.0 is MAE NN
hmc_version = 'v24.3'
file_version = '2024'

# Select experiment parameters
df = utils.index_mcmc_runs(file_version=file_version)  # List of all experiments (0-209) for '2023', 0-14 for '2024'
print(f'Found {df.shape[0]} combinations to run MCMC on. Performing MCMC on index {SLURM_ARRAY_TASK_ID}.')
df = df.iloc[SLURM_ARRAY_TASK_ID]

# Define parameters for HMC
seed = SLURM_ARRAY_TASK_ID + SLURM_ARRAY_JOB_ID
penalty = 1e6
integrate = False # If False, Chi2 is interpolated. If True, Chi2 is integrated.
par_equals_perr = True # If True, only 3 parameters will be sampled by the HMC and pwr1par==pwr1perr and pwr2par==pwr2perr
constant_vspoles = True # If True, vspoles is fixed to 400.0. If False, vspoles is specified in the data file.
specified_parameters = utils.get_parameters(df.filename_heliosphere, df.interval, constant_vspoles=constant_vspoles)

# Number of parameters for HMC to sample. 5 normally, 3 if par_equals_perr=True
if par_equals_perr:
    num_params = 3
else:
    num_params = 5

# Load observation data and define logprob. 
if file_version == '2023': 
    data_path = f'../data/oct2022/{df.experiment_name}/{df.experiment_name}_{df.interval}.dat'  # This data is the same.
elif file_version == '2024': 
    year = 2000 + SLURM_ARRAY_TASK_ID # assumes only negative intervals. If otherwise, fix this
    data_path = f'../data/2024/yearly/{year}.dat'
else:
    raise ValueError(f"Invalid file_version {file_version}. Must be '2023' or '2024'.")

model_path = f'../models/model_{model_version}_{df.polarity}.keras'

df.head()

Found 15 combinations to run MCMC on. Performing MCMC on index 9.


interval       2009
alpha         26.26
cmf            4.05
vspoles      712.95
alpha_std      8.35
Name: 13, dtype: object

In [62]:
dataset_ams = np.loadtxt(data_path, usecols=(0,1,2,3)) # Rigidity1, Rigidity2, Flux, Error, dataset (only if yearly dataset)
r1, r2 = dataset_ams[:,0], dataset_ams[:,1]

if 'yearly' in data_path:
    # Need to sort yearly datasets by r1
    sort_indices = np.argsort(r1)
    dataset_ams = dataset_ams[sort_indices, :]
    r1, r2 = dataset_ams[:,0], dataset_ams[:,1]

print(sort_indices)
print(dataset_ams[:,0])
print(dataset_ams[:,1])

[ 0 14 15 16 17 18 19 20 21 22 23 24 25 26 27  1 28 29  2 30  3 31 32  4
 33 34  5 35 36  6 37 38  7 39 40  8 41 42  9 43 44 10 45 46 11 47 48 12
 49 50 13 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91]
[ 0.28  0.4   0.43  0.46  0.47  0.49  0.51  0.55  0.57  0.59  0.63  0.66
  0.68  0.71  0.75  0.76  0.79  0.82  0.83  0.87  0.9   0.9   0.94  0.98
  1.    1.04  1.06  1.09  1.14  1.16  1.19  1.24  1.26  1.31  1.37  1.38
  1.43  1.5   1.51  1.57  1.65  1.66  1.72  1.8   1.83  1.89  1.98  2.02
  2.08  2.16  2.23  2.27  2.38  2.49  2.6   2.73  2.86  3.    3.13  3.29
  3.43  3.6   3.76  3.95  4.13  4.33  4.53  4.97  5.45  5.98  6.55  7.18
  7.87  8.63  9.46 10.38 11.38 12.47 13.68 15.   16.45 18.03 19.78 21.68
 23.77 26.07 28.58 31.34 34.37 37.68 41.32 45.3 ]
[ 0.32  0.43  0.46  0.47  0.49  0.51  0.55  0.57  0.59  0.63  0.66  0.68
  0.71  0.75  0.79  0.83  0.82  0.87  0.9   0.9   0.98  0.94  1.    1.06
  1.04  1.09

In [63]:
bins = np.concatenate([r1[:], r2[-1:]])
observed = dataset_ams[:,2]   # Observed Flux
uncertainty = dataset_ams[:,3]
assert len(bins) == len(observed)+1
bin_midpoints = (bins[:-1] * bins[1:]) ** 0.5  # Geometric mean seemed to work better in exp.

print(f'Bins: {bins}')
print(f'Bin midpoints: {bin_midpoints}')
print(f'Observed: {observed}')
print(f'Uncertainty: {uncertainty}')

Bins: [ 0.28  0.4   0.43  0.46  0.47  0.49  0.51  0.55  0.57  0.59  0.63  0.66
  0.68  0.71  0.75  0.76  0.79  0.82  0.83  0.87  0.9   0.9   0.94  0.98
  1.    1.04  1.06  1.09  1.14  1.16  1.19  1.24  1.26  1.31  1.37  1.38
  1.43  1.5   1.51  1.57  1.65  1.66  1.72  1.8   1.83  1.89  1.98  2.02
  2.08  2.16  2.23  2.27  2.38  2.49  2.6   2.73  2.86  3.    3.13  3.29
  3.43  3.6   3.76  3.95  4.13  4.33  4.53  4.97  5.45  5.98  6.55  7.18
  7.87  8.63  9.46 10.38 11.38 12.47 13.68 15.   16.45 18.03 19.78 21.68
 23.77 26.07 28.58 31.34 34.37 37.68 41.32 45.3  49.67]
Bin midpoints: [ 0.33466401  0.41472883  0.44474712  0.46497312  0.47989582  0.49989999
  0.52962251  0.55991071  0.57991379  0.60967204  0.64482556  0.66992537
  0.69483811  0.72972598  0.75498344  0.77485483  0.80486024  0.82498485
  0.84976467  0.88487287  0.9         0.91978258  0.95979164  0.98994949
  1.0198039   1.04995238  1.07489534  1.1147197   1.14995652  1.17490425
  1.21474277  1.24996     1.28475679  1.3396641

In [64]:
target_log_prob = utils.define_log_prob(model_path, data_path, specified_parameters, penalty=penalty, integrate=integrate, par_equals_perr=par_equals_perr)

2024-08-13 20:31:06.787437: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


# Check older data

In [49]:
SLURM_ARRAY_TASK_ID = 0
SLURM_ARRAY_JOB_ID = 0
DEBUG = True

# Version specifications
model_version = 'v3.0' # v2.0 is MSE NN, v3.0 is MAE NN
hmc_version = 'v24.3'
file_version = '2023'

# Select experiment parameters
df = utils.index_mcmc_runs(file_version=file_version)  # List of all experiments (0-209) for '2023', 0-14 for '2024'
print(f'Found {df.shape[0]} combinations to run MCMC on. Performing MCMC on index {SLURM_ARRAY_TASK_ID}.')
df = df.iloc[SLURM_ARRAY_TASK_ID]

# Load observation data and define logprob. 
if file_version == '2023': 
    data_path = f'../data/oct2022/{df.experiment_name}/{df.experiment_name}_{df.interval}.dat'  # This data is the same.
elif file_version == '2024': 
    year = 2000 + SLURM_ARRAY_TASK_ID # assumes only negative intervals. If otherwise, fix this
    data_path = f'../data/2024/yearly/{year}.dat'
else:
    raise ValueError(f"Invalid file_version {file_version}. Must be '2023' or '2024'.")

model_path = f'../models/model_{model_version}_{df.polarity}.keras'

df.head()

Found 133 combinations to run MCMC on. Performing MCMC on index 0.


interval     20110520-20110610
alpha                    51.49
cmf                       4.85
vspoles                 632.52
alpha_std                10.69
Name: 0, dtype: object

In [50]:
""" Load AMS data from Claudio. Each file contains measurements over a certain time interval. 
Args:
    filename = Filename of observations.
                Original dataset was '../data/BR2461.dat'
                New datasets are in '../data/oct2022/'
                New yearly datasets are in '../data/2024/yearly'
"""
dataset_ams = np.loadtxt(data_path, usecols=(0,1,2,3)) # Rigidity1, Rigidity2, Flux, Error, dataset (only if yearly dataset)
r1, r2 = dataset_ams[:,0], dataset_ams[:,1]

if 'yearly' in data_path:
    # Need to sort yearly datasets by r1
    sort_indices = np.argsort(r1)
    dataset_ams = dataset_ams[sort_indices, :]

bins = np.concatenate([r1[:], r2[-1:]])
observed = dataset_ams[:,2]   # Observed Flux
uncertainty = dataset_ams[:,3]
assert len(bins) == len(observed)+1

In [51]:
print(f"Loaded dataset from {data_path}.")
print(f'Bins: {bins}')
print(f'Observed: {observed}')
print(f'Uncertainty: {uncertainty}')

Loaded dataset from ../data/oct2022/AMS02_H-PRL2021/AMS02_H-PRL2021_20110520-20110610.dat.
Bins: [  1.     1.16   1.33   1.51   1.71   1.92   2.15   2.4    2.67   2.97
   3.29   3.64   4.02   4.43   4.88   5.37   5.9    6.47   7.09   7.76
   8.48   9.26  10.1   11.    13.    16.6   22.8   33.5   48.5   69.7
 100.  ]
Observed: [9.542576e+02 9.411921e+02 8.769211e+02 8.003528e+02 7.088688e+02
 6.185521e+02 5.311145e+02 4.512981e+02 3.792883e+02 3.166651e+02
 2.633877e+02 2.184992e+02 1.800327e+02 1.470573e+02 1.199892e+02
 9.722954e+01 7.856188e+01 6.343624e+01 5.116549e+01 4.125726e+01
 3.316556e+01 2.672519e+01 2.151074e+01 1.559146e+01 9.113706e+00
 4.317174e+00 1.667450e+00 5.843280e-01 2.086225e-01 7.601720e-02]
Uncertainty: [2.811600e+01 2.148905e+01 1.643800e+01 1.293201e+01 1.028839e+01
 8.293686e+00 6.719937e+00 5.448942e+00 4.419263e+00 3.583056e+00
 2.932180e+00 2.403835e+00 1.970323e+00 1.589138e+00 1.290966e+00
 1.053403e+00 8.565820e-01 6.939785e-01 5.606656e-01 4.578981e-0