In [5]:
import numpy as np
import pickle
import shap
import numpy as np
# Re-create the `np.float` alias, which NumPy removed in 1.24 (deprecated since 1.20).
# NOTE(review): presumably needed by one of the libraries imported here
# that still references `np.float` — confirm which one and drop this if unneeded.
np.float = float

import numpy as np
import pandas as pd
from pymatgen.core import Structure
from pymatgen.util.plotting import pretty_plot
import pickle
from pymatgen.analysis.xas.spectrum import XAS
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit
from scipy import signal

# === Load experimental spectra (CDF features) ===
# Only the first 8 rows of the JSON table are kept; each row's "cdf" entry
# becomes one row of the feature matrix X_exp.
df_exp_nmc622 = pd.read_json(
    r"C:\Users\Bill\Downloads\FeatureXAS-main\FeatureXAS-main\dataset\nmc_exp_xas.json"
).iloc[0:8, :]
X_exp = np.array(df_exp_nmc622["cdf"].to_list())

# === Load separate models ===
# NOTE: unpickling can execute arbitrary code; only load model files from a
# trusted source. The `with` blocks make sure the file handles are closed.
with open(
    r"C:\Users\Bill\Downloads\FeatureXAS-main\FeatureXAS-main\example\oxidation_regressor.pkl", "rb"
) as model_file:
    ox_model = pickle.load(model_file)
with open(
    r"C:\Users\Bill\Downloads\FeatureXAS-main\FeatureXAS-main\example\bondlength_regressor.pkl", "rb"
) as model_file:
    bl_model = pickle.load(model_file)

# === Predictions ===
ox_preds = ox_model.predict(X_exp)
bl_preds = bl_model.predict(X_exp)

# === SHAP explainers ===
explainer_ox = shap.TreeExplainer(ox_model)
explainer_bl = shap.TreeExplainer(bl_model)

shap_values_ox = explainer_ox.shap_values(X_exp)
shap_values_bl = explainer_bl.shap_values(X_exp)

# === Energy bins ===
# One energy value per feature column, evenly spaced over [energy_min, energy_max].
# NOTE(review): assumes the CDF features map linearly onto 8330-8370 eV — confirm
# against the featurisation that produced the "cdf" column.
energy_min, energy_max = 8330, 8370
n_bins = X_exp.shape[1]
energy_bins = np.linspace(energy_min, energy_max, n_bins)

# === Helper: top contributors ===
def top_contributors(expected, pred, shap_vals, energy_bins, top_k=20):
    """Return the energy bins whose SHAP value pushes `pred` away from `expected`.

    A feature counts as a contributor when its SHAP value has the same sign
    as ``pred - expected``. Features are ranked by absolute SHAP value, the
    first ``top_k`` ranked entries are taken, and entries whose aligned SHAP
    value is zero are then dropped, so fewer than ``top_k`` pairs may be
    returned (none at all when ``pred == expected``).

    Args:
        expected: Baseline (expected) model output.
        pred: Model prediction for the sample.
        shap_vals: SHAP values for the sample, expected to be a 1-D array.
        energy_bins: Energy value for each feature index.
        top_k: Number of ranked entries considered.

    Returns:
        List of ``(energy, shap_value)`` tuples in ranked order; ``energy`` is
        a Python float, ``shap_value`` is the element taken from `shap_vals`.
    """
    sign = np.sign(pred - expected)
    agrees = (shap_vals * sign) > 0
    # Zero out the features that oppose the deviation so they rank last.
    aligned = shap_vals * agrees
    ranking = np.argsort(-np.abs(aligned))

    picked = []
    for j in ranking[:top_k]:
        if aligned[j] == 0:
            continue
        picked.append((float(energy_bins[j]), shap_vals[j]))
    return picked

# === Helper: Random Forest uncertainty ===
def rf_uncertainty(rf_model, X):
    """Return the mean and spread of the per-estimator predictions for `X`.

    Every estimator in ``rf_model.estimators_`` predicts on `X`; the results
    are stacked along a new leading axis and reduced over that axis.

    Args:
        rf_model: Fitted ensemble exposing an ``estimators_`` sequence whose
            members have a ``predict`` method.
        X: Feature matrix handed to each estimator's ``predict``.

    Returns:
        Tuple ``(mean, std)`` taken across estimators (population standard
        deviation, NumPy's default ``ddof=0``).
    """
    per_estimator = []
    for estimator in rf_model.estimators_:
        per_estimator.append(estimator.predict(X))
    stacked = np.stack(per_estimator, axis=0)
    return np.mean(stacked, axis=0), np.std(stacked, axis=0)

# === Loop over spectra ===
# The SHAP base values do not depend on the spectrum, so read them once.
# Depending on the shap version/model, `expected_value` may be a bare scalar
# or a length-1 array; np.ravel handles both, whereas `[0]` on a scalar fails.
ox_exp_val = np.ravel(explainer_ox.expected_value)[0]
bl_exp_val = np.ravel(explainer_bl.expected_value)[0]

for i, (spectrum, ox_pred, bl_pred) in enumerate(zip(X_exp, ox_preds, bl_preds)):
    # Uncertainty: spread of the individual estimators' predictions for this
    # spectrum (the estimators expect a 2-D input, hence the reshape).
    row = spectrum.reshape(1, -1)
    _, ox_std = rf_uncertainty(ox_model, row)
    _, bl_std = rf_uncertainty(bl_model, row)

    # SHAP top features: bins pushing the prediction away from the base value.
    top_ox = top_contributors(ox_exp_val, ox_pred, shap_values_ox[i], energy_bins)
    top_bl = top_contributors(bl_exp_val, bl_pred, shap_values_bl[i], energy_bins)

    # === Output ===
    print(f"\n=== Spectrum {i} ===")
    print(f"Predicted Oxidation State: {ox_pred:.3f} ± {ox_std[0]:.3f}")
    print("Top 20 energy bins driving oxidation prediction:")
    for e, s in top_ox:
        print(f"  {e:.2f} eV (SHAP: {s:.4f})")

    print(f"Predicted Bond Length: {bl_pred:.3f} ± {bl_std[0]:.3f} Å")
    print("Top 20 energy bins driving bond length prediction:")
    for e, s in top_bl:
        print(f"  {e:.2f} eV (SHAP: {s:.4f})")



=== Spectrum 0 ===
Predicted Oxidation State: 3.428 ± 0.316
Top 20 energy bins driving oxidation prediction:
  8360.30 eV (SHAP: -0.0293)
  8361.11 eV (SHAP: -0.0227)
  8359.90 eV (SHAP: -0.0224)
  8359.09 eV (SHAP: -0.0224)
  8361.52 eV (SHAP: -0.0217)
  8359.49 eV (SHAP: -0.0203)
  8360.71 eV (SHAP: -0.0200)
  8358.69 eV (SHAP: -0.0190)
  8361.92 eV (SHAP: -0.0155)
  8362.32 eV (SHAP: -0.0148)
  8358.28 eV (SHAP: -0.0118)
  8357.88 eV (SHAP: -0.0075)
  8332.83 eV (SHAP: -0.0071)
  8350.61 eV (SHAP: -0.0066)
  8350.20 eV (SHAP: -0.0055)
  8349.39 eV (SHAP: -0.0055)
  8349.80 eV (SHAP: -0.0054)
  8351.41 eV (SHAP: -0.0053)
  8333.23 eV (SHAP: -0.0051)
  8332.42 eV (SHAP: -0.0050)
Predicted Bond Length: 1.912 ± 0.028 Å
Top 20 energy bins driving bond length prediction:
  8359.90 eV (SHAP: 0.0028)
  8360.30 eV (SHAP: 0.0024)
  8361.11 eV (SHAP: 0.0021)
  8360.71 eV (SHAP: 0.0020)
  8361.92 eV (SHAP: 0.0017)
  8359.49 eV (SHAP: 0.0017)
  8359.09 eV (SHAP: 0.0016)
  8362.32 eV (SHAP: 0.00