## Example: Making figures for Shapley Analysis
- Assumes that RF model has already been trained and exists under "FI_GEO_RFdat_AIMFAHR_forest.pkl"
- Assumes that data for model is curated and stored under "FI_GEO_RFdat_AIMFAHR_data.h5"

Both were created from Kyle Murphy's notebook; the notebook was modifed slightly to pickle the trained RF model and save the training/test/out-of-sample data to file.

See Murphy+2025 for RF model: https://doi.org/10.1029/2024SW003928

In [1]:
import numpy as np
import numpy.random as rand
import pandas as pd
import shap_analysis as shpa
import time as pytime
import matplotlib.pyplot as plt
from matplotlib import gridspec

In [2]:
import fasttreeshap as fts
from fasttreeshap.plots import beeswarm, waterfall, bar
import shap

In [4]:
NJOBS = 8 # change based on your computer

In [5]:
# hack to save minimal viable explanation, since it takes a while to make
def save_explanation(fname, expln):
    # hack to avoid pickling and associated errors
    dta = expln.data.astype(float)
    np.savez(fname, values=expln.values, 
             base_values=expln.base_values, data=dta,
             feature_names=expln.feature_names, 
             output_names=expln.output_names, compute_time=expln.compute_time)

def load_explanation(fname):
    dat_dict = np.load(fname)
    return fts.Explanation(values = dat_dict["values"], base_values=dat_dict["base_values"], 
                           data=dat_dict["data"], feature_names=dat_dict["feature_names"], 
                           output_names=dat_dict["output_names"], 
                           compute_time=dat_dict["compute_time"])

In [6]:
# hack to blend multiple explanations together
def stack_expln(expln1, expln2):
    return fts.Explanation(values = np.vstack([expln1.values, expln2.values]), 
                             base_values = np.vstack([expln1.base_values, expln2.base_values]), 
                             data = np.vstack([expln1.data, expln2.data]),
                             feature_names=expln1.feature_names,
                            )

In [7]:
# load in needed RF, shap_explainer, and dataset
rf, explainer = shpa.load_tools("FI_GEO_RFdat_AIMFAHR_forest.pkl", n_jobs=NJOBS) #uncomment if you need to run model/explainer
fgeo_col  = shpa.fgeo_col # data columns used in best-performing Murphy RF model
full_data = shpa.load_data(all_cols=True, option="test_d") # loads Grace B test data

In [8]:
# sort into more specific datasets
storm_data = full_data[full_data['storm'] == 1]
quiet_data = full_data[full_data['storm'] == -1]

recovery_data = storm_data[storm_data["storm phase"] == 2]
mainphase_data = storm_data[storm_data["storm phase"] == 1]

sorted_data = full_data.sort_values("DateTime")

In [9]:
pb = shpa.PolarBear(n_lat = 12, n_mlt = 24)

## Summary of data

In [10]:
full_data.shape

(271092, 14)

In [25]:
full_data.describe()

Unnamed: 0,1300_02,43000_09,85550_13,94400_18,SYM_H index,AE,SatLat,cos_SatMagLT,sin_SatMagLT,400kmDensity,DateTime,storm,storm phase,400kmDensity_pred
count,271092.0,271092.0,271092.0,271092.0,271092.0,271092.0,271092.0,271092.0,271092.0,271092.0,271092,271092.0,271092.0,271092.0
mean,7.187411,9.859163,9.434794,9.207575,-10.990863,171.890853,-0.151048,-0.000522,0.000385,1.308555,2007-06-11 12:51:30.882431488,0.10381,0.436759,1.308646
min,6.395051,9.715523,9.369749,9.147455,-469.0,1.0,-89.03936,-1.0,-1.0,0.0023,2003-01-01 00:30:00,-1.0,-1.0,0.066742
25%,6.826294,9.768229,9.392063,9.163332,-17.0,38.0,-45.065897,-0.695579,-0.719182,0.475864,2005-03-08 14:01:15,-1.0,-1.0,0.500467
50%,7.195337,9.832408,9.419841,9.189895,-8.0,85.0,-0.12629,-0.001545,-0.000995,0.907904,2007-05-09 07:35:00,1.0,1.0,0.935888
75%,7.536688,9.937945,9.472079,9.242422,-1.0,228.0,45.01862,0.69313,0.719841,1.73887,2009-07-05 17:37:30,1.0,2.0,1.767111
max,9.525276,26.979264,9.709964,31.735991,97.0,3529.0,89.03638,1.0,1.0,23.6808,2012-06-30 23:55:00,1.0,2.0,16.188876
std,0.394671,0.186431,0.048081,0.214995,18.554718,209.125574,52.010454,0.700119,0.714028,1.213174,,0.994599,1.344655,1.141864


In [186]:
storm_data.shape

(149617, 14)

In [26]:
storm_data.describe()

Unnamed: 0,1300_02,43000_09,85550_13,94400_18,SYM_H index,AE,SatLat,cos_SatMagLT,sin_SatMagLT,400kmDensity,DateTime,storm,storm phase,400kmDensity_pred
count,149617.0,149617.0,149617.0,149617.0,149617.0,149617.0,149617.0,149617.0,149617.0,149617.0,149617,149617.0,149617.0,149617.0
mean,7.269417,9.880465,9.444359,9.218007,-15.78846,226.35115,-0.175444,-0.001048,0.001413,1.537729,2006-10-22 07:55:03.559087616,1.0,1.603274,1.534656
min,6.449483,9.73036,9.376063,9.148476,-469.0,1.0,-89.03829,-1.0,-1.0,0.003382,2003-01-01 00:30:00,1.0,1.0,0.071081
25%,6.9671,9.791568,9.401458,9.171335,-24.0,52.0,-45.07429,-0.684184,-0.729076,0.621265,2004-07-20 04:40:00,1.0,1.0,0.646251
50%,7.342521,9.86943,9.436674,9.204232,-13.0,136.0,-0.17662,-0.001178,0.00199,1.146216,2006-06-16 19:05:00,1.0,2.0,1.18395
75%,7.590228,9.957018,9.482035,9.253163,-4.0,322.0,45.00366,0.682388,0.731258,2.082349,2008-08-10 10:05:00,1.0,2.0,2.103849
max,9.525276,26.979264,9.709964,31.735991,97.0,3529.0,89.03152,1.0,1.0,23.6808,2012-06-22 03:40:00,1.0,2.0,16.188876
std,0.380843,0.234827,0.048743,0.286155,22.304878,242.094674,51.982336,0.695019,0.718994,1.288445,,0.0,0.48922,1.200584


In [187]:
quiet_data.shape

(121475, 14)

In [27]:
quiet_data.describe()

Unnamed: 0,1300_02,43000_09,85550_13,94400_18,SYM_H index,AE,SatLat,cos_SatMagLT,sin_SatMagLT,400kmDensity,DateTime,storm,storm phase,400kmDensity_pred
count,121475.0,121475.0,121475.0,121475.0,121475.0,121475.0,121475.0,121475.0,121475.0,121475.0,121475,121475.0,121475.0,121475.0
mean,7.086407,9.832926,9.423013,9.194726,-5.081811,104.813789,-0.121001,0.000125,-0.000882,1.026288,2008-03-23 12:52:39.106811904,-1.0,-1.0,1.030277
min,6.395051,9.715523,9.369749,9.147455,-81.0,2.0,-89.03936,-1.0,-1.0,0.0023,2003-01-05 07:30:00,-1.0,-1.0,0.066742
25%,6.745853,9.752132,9.385549,9.157656,-10.0,29.0,-45.058305,-0.708826,-0.706903,0.362907,2006-03-28 12:47:30,-1.0,-1.0,0.390073
50%,7.040624,9.806662,9.408161,9.178545,-4.0,54.0,-0.0878,-0.002068,-0.003351,0.679256,2008-07-07 17:25:00,-1.0,-1.0,0.693351
75%,7.430072,9.909953,9.455955,9.226672,1.0,126.0,45.03114,0.708133,0.704482,1.276441,2009-12-10 05:37:30,-1.0,-1.0,1.289399
max,8.489174,10.141226,9.609864,9.370111,55.0,1991.0,89.03638,1.0,1.0,10.0752,2012-06-30 23:55:00,-1.0,-1.0,8.354893
std,0.387929,0.091646,0.044511,0.044726,9.606497,131.372336,52.045264,0.706352,0.707866,1.046669,,0.0,0.0,0.997017


In [28]:
mainphase_data.describe()

Unnamed: 0,1300_02,43000_09,85550_13,94400_18,SYM_H index,AE,SatLat,cos_SatMagLT,sin_SatMagLT,400kmDensity,DateTime,storm,storm phase,400kmDensity_pred
count,59357.0,59357.0,59357.0,59357.0,59357.0,59357.0,59357.0,59357.0,59357.0,59357.0,59357,59357.0,59357.0,59357.0
mean,7.291208,9.882281,9.446647,9.216659,-8.327661,219.015584,-0.178185,-0.002439,-0.002205,1.533563,2006-10-11 10:18:27.549235968,1.0,1.0,1.5224
min,6.450426,9.73036,9.376622,9.148476,-469.0,1.0,-89.03777,-1.0,-1.0,0.00902,2003-01-01 05:05:00,1.0,1.0,0.071081
25%,6.999895,9.793853,9.402021,9.172416,-15.0,44.0,-45.28952,-0.691708,-0.728395,0.589399,2004-05-28 03:45:00,1.0,1.0,0.621803
50%,7.367244,9.881559,9.443245,9.211713,-5.0,107.0,-0.08739,-0.003508,-0.002906,1.126245,2006-05-17 22:15:00,1.0,1.0,1.159533
75%,7.60426,9.963187,9.485718,9.256993,3.0,307.0,45.03942,0.685405,0.722883,2.096799,2008-09-30 12:05:00,1.0,1.0,2.089964
max,8.981911,10.350459,9.648032,9.460077,97.0,3529.0,89.02986,1.0,1.0,23.6808,2012-06-17 12:30:00,1.0,1.0,16.188876
std,0.375779,0.098472,0.048865,0.048023,21.863546,260.385554,52.089084,0.696866,0.717206,1.329464,,0.0,0.0,1.233295


In [29]:
recovery_data.describe()

Unnamed: 0,1300_02,43000_09,85550_13,94400_18,SYM_H index,AE,SatLat,cos_SatMagLT,sin_SatMagLT,400kmDensity,DateTime,storm,storm phase,400kmDensity_pred
count,90260.0,90260.0,90260.0,90260.0,90260.0,90260.0,90260.0,90260.0,90260.0,90260.0,90260,90260.0,90260.0,90260.0
mean,7.255086,9.879271,9.442855,9.218894,-20.694848,231.175183,-0.17364,-0.000133,0.003792,1.540468,2006-10-29 11:57:29.734101504,1.0,2.0,1.542716
min,6.449483,9.731566,9.376063,9.149251,-426.0,1.0,-89.03829,-1.0,-1.0,0.003382,2003-01-01 00:30:00,1.0,2.0,0.097776
25%,6.944725,9.790106,9.401117,9.170291,-27.0,61.0,-45.042732,-0.680404,-0.72943,0.640422,2004-07-29 14:12:30,1.0,2.0,0.663336
50%,7.310173,9.86297,9.433119,9.201211,-17.0,153.0,-0.24153,0.000497,0.006021,1.159604,2006-07-06 15:07:30,1.0,2.0,1.19668
75%,7.572505,9.95242,9.479378,9.250661,-9.0,330.0,44.975295,0.680126,0.736615,2.071376,2008-07-12 23:56:15,1.0,2.0,2.113241
max,9.525276,26.979264,9.709964,31.735991,64.0,2966.0,89.03152,1.0,1.0,21.39224,2012-06-22 03:40:00,1.0,2.0,15.355804
std,0.383465,0.291595,0.048605,0.366355,21.205026,229.145755,51.912306,0.693805,0.720161,1.260742,,0.0,0.0,1.178516


## Overall Look
1. Randomly sample storm times and non-storm times
2. run all events through explainer
3. make beeswarm and bar plots for all classes

In [10]:
# rng sampler
gen = rand.default_rng(693993)

In [11]:
NUM_SAMP = 2000
storm_sample = gen.choice(range(storm_data.shape[0]), size = NUM_SAMP, replace=False)
quiet_sample = gen.choice(range(quiet_data.shape[0]), size = NUM_SAMP, replace=False)

In [12]:
storm_samp_data = storm_data.iloc[storm_sample]
quiet_samp_data = quiet_data.iloc[quiet_sample]

In [13]:
recovery_sample = gen.choice(range(recovery_data.shape[0]), size = NUM_SAMP, replace=False)
mainphase_sample = gen.choice(range(mainphase_data.shape[0]), size = NUM_SAMP, replace=False)

In [14]:
mainphase_samp_data = mainphase_data.iloc[mainphase_sample]
recovery_samp_data = recovery_data.iloc[recovery_sample]

### Get shapley values
WARNING: this will take a while on a laptop, depending on the NUM_SAMP. For 2000 pts on my laptop (NJOBS=8), it took about 25-30 mins, depending on multitasking.

In [12]:
"""
print("start shap storm")
tfirst = pytime.perf_counter()
storm_shap = explainer(storm_samp_data[fgeo_col])
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("storm1", storm_shap)
"""

start shap storm
Time taken: 1739.8983653000032


In [19]:
"""
print("start shap quiet")
tfirst = pytime.perf_counter()
quiet_shap = explainer(quiet_samp_data[fgeo_col])
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("quiet1", quiet_shap)
"""

start shap quiet
Time taken: 1506.1464909000206


In [154]:
"""
print("start shap main")
tfirst = pytime.perf_counter()
mainphase_shap = explainer(mainphase_samp_data[fgeo_col])
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("mainphase1", mainphase_shap)
"""

start shap main
Time taken: 1556.1542722999584


In [155]:
"""
print("start shap recovery")
tfirst = pytime.perf_counter()
recovery_shap = explainer(recovery_samp_data[fgeo_col])
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("recovery1", recovery_shap)
"""

start shap main
Time taken: 1481.5337197000626


In [156]:
# alternatively, if you have already calculated shap values
storm_shap = load_explanation("storm1.npz")

In [157]:
recovery_shap = load_explanation("recovery1.npz")

In [159]:
quiet_shap = load_explanation("quiet1.npz")

In [18]:
mainphase_shap = load_explanation("mainphase1.npz")

In [160]:
total_shap = stack_expln(storm_shap, quiet_shap)

### Make plots for global shap values

#### beeswarm

In [13]:
beeswarm(storm_shap)

No data for colormapping provided via 'c'. Parameters 'vmin', 'vmax' will be ignored


In [18]:
f = plt.gcf()
f.savefig("beeswarmStorm.png", bbox_inches="tight")
f.clf()

In [20]:
beeswarm(quiet_shap)

No data for colormapping provided via 'c'. Parameters 'vmin', 'vmax' will be ignored


In [21]:
f = plt.gcf()
f.savefig("beeswarmQuiet.png", bbox_inches="tight")
f.clf()

In [156]:
beeswarm(recovery_shap)

No data for colormapping provided via 'c'. Parameters 'vmin', 'vmax' will be ignored


In [157]:
f = plt.gcf()
f.savefig("beeswarmRecover.png", bbox_inches="tight")
f.clf()

In [18]:
beeswarm(mainphase_shap)

NameError: name 'mainphase_shap' is not defined

In [159]:
f = plt.gcf()
f.savefig("beeswarmMainPhase.png", bbox_inches="tight")
f.clf()

In [146]:
beeswarm(total_shap)

In [147]:
f = plt.gcf()
f.savefig("beeswarmTotal.png", bbox_inches="tight")
f.clf()

In [137]:
beeswarm(storm_shap+quiet_shap) # this is the wrong method

No data for colormapping provided via 'c'. Parameters 'vmin', 'vmax' will be ignored


In [138]:
f = plt.gcf()
f.savefig("beeswarmBadTotal.png", bbox_inches="tight")
f.clf()

#### bar

In [106]:
fts.plots.bar(storm_shap)

In [107]:
f = plt.gcf()
f.savefig("barStorm.png", bbox_inches="tight")
f.clf()

In [108]:
fts.plots.bar(quiet_shap)

In [109]:
f = plt.gcf()
f.savefig("barQuiet.png", bbox_inches="tight")
f.clf()

In [148]:
fts.plots.bar(total_shap)

In [149]:
f = plt.gcf()
f.savefig("barTotal.png", bbox_inches="tight")
f.clf()

In [215]:
fts.plots.bar(recovery_shap)

In [216]:
f = plt.gcf()
f.savefig("barRecovery.png", bbox_inches="tight")
f.clf()

In [217]:
fts.plots.bar(mainphase_shap)

In [218]:
f = plt.gcf()
f.savefig("barMainPhase.png", bbox_inches="tight")
f.clf()

In [255]:
fts.plots.heatmap(storm_shap)

In [256]:
f = plt.gcf()
f.savefig("heatmapStorm.png", bbox_inches="tight")
f.clf()

#### Clustering analysis
Check redundancies of input features via clustering analysis

In [21]:
storm_samp_data[fgeo_col].values.shape

(2000, 9)

In [20]:
tfirst = pytime.perf_counter()
storm_clustering = shap.utils.hclust(storm_samp_data[fgeo_col].values, storm_samp_data["400kmDensity_pred"].values, metric="cosine")
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")

Time taken: 0.007312900037504733


Ignoring the y argument passed to shap.utils.hclust since the given clustering metric is not based on label fitting!


In [22]:
tfirst = pytime.perf_counter()
storm_clustering_xg = shap.utils.hclust(storm_samp_data[fgeo_col].values, storm_samp_data["400kmDensity_pred"].values, metric="auto")
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")

Time taken: 0.8806039001792669


In [52]:
fts.plots.bar(storm_shap, clustering=storm_clustering_xg)

In [53]:
f = plt.gcf()
f.savefig("barStorm_clust_xg.png", bbox_inches="tight")
f.clf()

In [23]:
total_samp_data = pd.concat([storm_samp_data, quiet_samp_data])

In [24]:
tfirst = pytime.perf_counter()
total_clustering_xg = shap.utils.hclust(total_samp_data[fgeo_col].values, total_samp_data["400kmDensity_pred"].values, metric="auto")
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")

Time taken: 0.8806074999738485


In [50]:
fts.plots.bar(total_shap, clustering=total_clustering_xg)

In [51]:
f = plt.gcf()
f.savefig("barTotal_clust_xg.png", bbox_inches="tight")
f.clf()

In [21]:
fgeo_col

['1300_02',
 '43000_09',
 '85550_13',
 '94400_18',
 'SYM_H index',
 'AE',
 'SatLat',
 'cos_SatMagLT',
 'sin_SatMagLT']

## Single Events
The events we are looking at are semi-randomly chosen: 2 during a storm, one during a non-storm. One has a "high" density prediction, one "medium", and one "low".

1. "High": storm_data index 4803
2. "Medium": storm_data index 4804
3. "Low": quiet_data index 604

In [11]:
high_evt = storm_data.iloc[4803]
med_evt = storm_data.iloc[4804]
low_evt = quiet_data.iloc[604]

In [18]:
high_evt

1300_02                         7.778282
43000_09                       10.102009
85550_13                         9.56579
94400_18                        9.332493
SYM_H index                        -25.0
AE                                 112.0
SatLat                          43.13684
cos_SatMagLT                   -0.815219
sin_SatMagLT                    0.579153
400kmDensity                    4.485137
DateTime             2003-03-07 05:15:00
storm                                  1
storm phase                            2
400kmDensity_pred               4.847418
Name: 62847, dtype: object

In [29]:
print("MLT high: ", 24.*np.arctan2(high_evt["sin_SatMagLT"], high_evt["cos_SatMagLT"])/(2*np.pi))

MLT high:  9.6406


In [15]:
med_evt

1300_02                         7.391199
43000_09                        9.927624
85550_13                        9.469308
94400_18                        9.234825
SYM_H index                        -44.0
AE                                 169.0
SatLat                         -14.01763
cos_SatMagLT                   -0.051369
sin_SatMagLT                     0.99868
400kmDensity                    1.664955
DateTime             2005-01-23 07:45:00
storm                                  1
storm phase                            2
400kmDensity_pred               1.621514
Name: 261021, dtype: object

In [30]:
print("MLT med: ",24.*np.arctan2(med_evt["sin_SatMagLT"], med_evt["cos_SatMagLT"])/(2*np.pi))

MLT med:  6.1963


In [17]:
low_evt

1300_02                         7.160848
43000_09                        9.815979
85550_13                        9.411937
94400_18                        9.186509
SYM_H index                         -1.0
AE                                 220.0
SatLat                          38.91635
cos_SatMagLT                   -0.694483
sin_SatMagLT                     0.71951
400kmDensity                    0.944242
DateTime             2010-03-29 16:05:00
storm                                 -1
storm phase                           -1
400kmDensity_pred               0.888769
Name: 805729, dtype: object

In [32]:
print("MLT low: ", 24.*np.arctan2(low_evt["sin_SatMagLT"], low_evt["cos_SatMagLT"])/(2*np.pi))

MLT low:  8.9324


### individual shap values and waterfall

In [12]:
high_shaps = explainer(high_evt[fgeo_col].to_frame().transpose())

In [13]:
# example waterfall plot
# copy shenangians needed since single event data somehow becomes a 2D array
from copy import deepcopy
high_shaps2 = deepcopy(high_shaps)
high_shaps2.base_values = high_shaps2.base_values[0][0]
high_shaps2.values = high_shaps2.values.squeeze()
high_shaps2.data = high_shaps2.data.squeeze()

In [None]:
fts.plots.waterfall(high_shaps2)

In [99]:
f = plt.gcf()
f.savefig("waterfallHighDen.png", bbox_inches="tight")
f.clf()

In [20]:
med_shaps = explainer(med_evt[fgeo_col].to_frame().transpose())

In [21]:
med_shaps2 = deepcopy(med_shaps)
med_shaps2.base_values = med_shaps2.base_values[0][0]
med_shaps2.values = med_shaps2.values.squeeze()
med_shaps2.data = med_shaps2.data.squeeze()

In [102]:
fts.plots.waterfall(med_shaps2)

In [103]:
f = plt.gcf()
f.savefig("waterfallMedDen.png", bbox_inches="tight")
f.clf()

In [22]:
low_shaps = explainer(low_evt[fgeo_col].to_frame().transpose())

In [23]:
low_shaps2 = deepcopy(low_shaps)
low_shaps2.base_values = low_shaps2.base_values[0][0]
low_shaps2.values = low_shaps2.values.squeeze()
low_shaps2.data = low_shaps2.data.squeeze()

In [None]:
fts.plots.waterfall(low_shaps2)

In [78]:
f = plt.gcf()
f.savefig("waterfallLowDen.png", bbox_inches="tight")
f.clf()

### Expand from single point to 2D Hemisphere Grid

In [16]:
# now expand to a 2D grid
high_north_qq = high_evt["SatLat"] >= 0.
high_grid, high_sat_lat_msh = pb.make_grid(high_evt, north=high_north_qq)

In [45]:
# Uncomment to make shap vals for high grid
"""
print("start event mapping")
tfirst = pytime.perf_counter()
high_shap_vals = explainer(high_grid)
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("high_evt_shap", high_shap_vals)
"""

start event mapping
Time taken: 214.44375979993492


In [19]:
# if already done shap calculation
high_shap_vals = load_explanation("high_evt_shap.npz")

In [21]:
# now expand to a 2D grid
med_north_qq = med_evt["SatLat"] >= 0.
med_grid, med_sat_lat_msh = pb.make_grid(med_evt, north=med_north_qq)

In [82]:
#Uncomment to get SHAP values on grid for medium event
"""
print("start event mapping")
tfirst = pytime.perf_counter()
med_shap_vals = explainer(med_grid)
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("med_evt_shap", med_shap_vals)
"""

start event mapping
Time taken: 198.80716650001705


In [23]:
med_shap_vals = load_explanation("med_evt_shap.npz")

In [24]:
# now expand to a 2D grid
low_north_qq = low_evt["SatLat"] >= 0.
low_grid, low_sat_lat_msh = pb.make_grid(low_evt, north=low_north_qq)

In [83]:
# Uncomment to generate shap_values on Grid for low event
"""
print("start event mapping")
tfirst = pytime.perf_counter()
low_shap_vals = explainer(low_grid)
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("low_evt_shap", low_shap_vals)
"""

start event mapping
Time taken: 217.43469040002674


In [26]:
low_shap_vals = load_explanation("low_evt_shap.npz")

#### Plot density predictions for hemisphere

In [17]:
figP = plt.Figure()
axP = figP.add_subplot(111, projection="polar")

img = pb.plot_density(axP, rf, high_evt, high_grid, high_sat_lat_msh, high_north_qq, vr=[0.5,6])

figP.colorbar(img)
figP.tight_layout()

set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.


In [86]:
figP.savefig("highDen.png")

In [18]:
# plot geo den plot
pb.full_den_movie_plot(high_evt, rf, out_prefix="highGEO", vr = [0.5,6.5])

set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.


In [91]:
figP = plt.Figure()
axP = figP.add_subplot(111, projection="polar")

img = pb.plot_density(axP, rf, med_evt, med_grid, med_sat_lat_msh, med_north_qq, vr=[0.5,6])

figP.colorbar(img)
figP.tight_layout()

FixedFormatter should only be used together with FixedLocator


In [92]:
figP.savefig("medDen.png")

In [22]:
# plot geo den plot
pb.full_den_movie_plot(med_evt, rf, out_prefix="medGEO", vr = [0.5,6.5])

set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.


In [18]:
figP = plt.Figure()
axP = figP.add_subplot(111, projection="polar")

img = pb.plot_density(axP, rf, low_evt, low_grid, low_sat_lat_msh, low_north_qq, vr=[0.5,6])

figP.colorbar(img)
figP.tight_layout()

NameError: name 'low_grid' is not defined

In [89]:
figP.savefig("lowDen.png")

In [25]:
# plot geo den plot
pb.full_den_movie_plot(low_evt, rf, out_prefix="lowGEO", vr = [0.5,2.5])

set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.


#### Beeswarm for hemisphere

In [None]:
beeswarm(high_shap_vals, show=False)

In [None]:
f = plt.gcf()
f.savefig("high_evt_beeswarm.png", bbox_inches="tight")
f.clf()

#### Shap Values plots

In [30]:
high_shap_vals.feature_names

['1300_02',
 '43000_09',
 '85550_13',
 '94400_18',
 'SYM_H index',
 'AE',
 'SatLat',
 'cos_SatMagLT',
 'sin_SatMagLT']

In [27]:
figP = plt.Figure()

for nme in high_shap_vals.feature_names:
    i = shpa._fgeo_col_dict[nme]
    figP.clf()
    shp = high_shap_vals.values[:,i]
    min_shap = shp.min()
    max_shap = shp.max()
    #min_shap = min(0, shp.min())
    #max_shap = max(0, shp.max())
    ax = figP.add_subplot(111, projection="polar")
    img = pb.plot_shap(ax, high_evt, shp, nme, high_sat_lat_msh, high_north_qq, [min_shap, max_shap])
    figP.colorbar(img)
    ax.set_title(f"Factor: {nme}", pad=30, size=14)
    #ax.text(0.06, 0.06, f"{date}", transform=figP.transFigure, size=13)
    figP.tight_layout()
    figP.savefig("High"+f"_{nme[:5]}_shap.png")

set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after se

In [28]:
fig = plt.Figure(figsize=(12,8))
gs = fig.add_gridspec(2,3)
ax0 = fig.add_subplot(gs[0,0], projection="polar")
ax1 = fig.add_subplot(gs[0,1], projection="polar")
ax2 = fig.add_subplot(gs[0,2], projection="polar")
ax3 = fig.add_subplot(gs[1,0], projection="polar")
ax4 = fig.add_subplot(gs[1,1], projection="polar")
ax5 = fig.add_subplot(gs[1,2], projection="polar")

for nme,axis in zip(["43000_09", "85550_13","1300_02","cos_SatMagLT", "sin_SatMagLT","SYM_H index"], [ax0, ax1, ax2, ax3, ax4, ax5]):
    i = shpa._fgeo_col_dict[nme]
    shp = high_shap_vals.values[:,i]
    min_shap = shp.min()
    max_shap = shp.max()
    #min_shap = min(0, shp.min())
    #max_shap = max(0, shp.max())
    img = pb.plot_shap(axis, high_evt, shp, nme, high_sat_lat_msh, high_north_qq, [min_shap, max_shap])
    fig.colorbar(img, ax=axis, fraction=0.046, pad=0.06)
    axis.set_title(f"Factor: {nme}", pad=25, size=16)
    #ax.text(0.06, 0.06, f"{date}", transform=figP.transFigure, size=13)
    fig.tight_layout()
    fig.savefig("Hex_High_Shap.png")

set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.


In [39]:
fig = plt.Figure(figsize=(12,8))
gs = fig.add_gridspec(2,3)
ax0 = fig.add_subplot(gs[0,0], projection="polar")
ax1 = fig.add_subplot(gs[0,1], projection="polar")
ax2 = fig.add_subplot(gs[0,2], projection="polar")
ax3 = fig.add_subplot(gs[1,0], projection="polar")
ax4 = fig.add_subplot(gs[1,1], projection="polar")
ax5 = fig.add_subplot(gs[1,2], projection="polar")

for nme,axis in zip(["43000_09", "85550_13","1300_02","cos_SatMagLT", "sin_SatMagLT","SYM_H index"], [ax0, ax1, ax2, ax3, ax4, ax5]):
    i = shpa._fgeo_col_dict[nme]
    shp = med_shap_vals.values[:,i]
    min_shap = shp.min()
    max_shap = shp.max()
    #min_shap = min(0, shp.min())
    #max_shap = max(0, shp.max())
    img = pb.plot_shap(axis, med_evt, shp, nme, med_sat_lat_msh, med_north_qq, [min_shap, max_shap])
    fig.colorbar(img, ax=axis, fraction=0.046, pad=0.06)
    axis.set_title(f"Factor: {nme}", pad=25, size=16)
    #ax.text(0.06, 0.06, f"{date}", transform=figP.transFigure, size=13)
    fig.tight_layout()
    fig.savefig("Hex_Med_Shap.png")

set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.


In [36]:
fig = plt.Figure(figsize=(12,8))
gs = fig.add_gridspec(2,3)
ax0 = fig.add_subplot(gs[0,0], projection="polar")
ax1 = fig.add_subplot(gs[0,1], projection="polar")
ax2 = fig.add_subplot(gs[0,2], projection="polar")
ax3 = fig.add_subplot(gs[1,0], projection="polar")
ax4 = fig.add_subplot(gs[1,1], projection="polar")
ax5 = fig.add_subplot(gs[1,2], projection="polar")

for nme,axis in zip(["43000_09", "85550_13","1300_02","cos_SatMagLT", "sin_SatMagLT","SYM_H index"], [ax0, ax1, ax2, ax3, ax4, ax5]):
    i = shpa._fgeo_col_dict[nme]
    shp = low_shap_vals.values[:,i]
    min_shap = shp.min()
    max_shap = shp.max()
    #min_shap = min(0, shp.min())
    #max_shap = max(0, shp.max())
    img = pb.plot_shap(axis, low_evt, shp, nme, low_sat_lat_msh, low_north_qq, [min_shap, max_shap])
    fig.colorbar(img, ax=axis, fraction=0.046, pad=0.06)
    axis.set_title(f"Factor: {nme}", pad=25, size=16)
    #ax.text(0.06, 0.06, f"{date}", transform=figP.transFigure, size=13)
    fig.tight_layout()
    fig.savefig("Hex_Low_Shap.png")

set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
