In [None]:
import tensorflow as tf
# Let TensorFlow grow its GPU memory allocation on demand instead of reserving
# it all up front, and turn on JIT compilation (global_jit_level = ON_1).
config_tf = tf.ConfigProto()
config_tf.gpu_options.allow_growth = True
config_tf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
sess = tf.Session(config=config_tf)

# Register the session with Keras. Without this, Keras lazily creates its own
# default session and the options configured above never apply to the models
# built later in the notebook.
import keras.backend as K  # same alias as the import further down this cell

K.set_session(sess)

import json
# from keras.models import model_from_json

import sys
from pathlib import Path

import matplotlib.pyplot as plt
# import mpl_scatter_density
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance
from keras.utils import multi_gpu_model
from keras.wrappers.scikit_learn import  KerasRegressor
import keras.backend as K
from importlib import reload
from scipy.stats import median_abs_deviation
import seaborn as sns
import shap

In [None]:
# Matplotlib defaults shared by every figure in this notebook, grouped by
# concern and merged in the original key order.
font_sizes = {
    "legend.fontsize": "x-large",
    "axes.labelsize": "xx-large",
    "axes.titlesize": "x-large",
    "xtick.labelsize": "x-large",
    "ytick.labelsize": "x-large",
}
# Ticks on all four sides of the axes, pointing inwards.
tick_layout = {
    "xtick.top": True,
    "ytick.right": True,
    "xtick.direction": "in",
    "ytick.direction": "in",
}
# Serif text with a matching math font.
typography = {
    "font.family": "serif",
    "mathtext.fontset": "dejavuserif",
}
params = {**font_sizes, "figure.facecolor": "w", **tick_layout, **typography}
plt.rcParams.update(params)

In [None]:
# Path where your software library is saved
# Clone the latest version of morphCaps branch from github
# The default location can be overridden without editing the notebook by
# setting the PHOTOZ_CAPSNET_PATH environment variable.
import os

path_photoz = os.environ.get("PHOTOZ_CAPSNET_PATH", '/home/bid13/code/photozCapsNet')

# Guard the insert so that re-running this cell does not keep adding the same
# entry to sys.path.
if path_photoz not in sys.path:
    sys.path.insert(1, path_photoz)
path_photoz = Path(path_photoz)

In [None]:
from encapzulate.data_loader.data_loader import load_data
from encapzulate.utils.fileio import load_model, load_config
from encapzulate.utils import metrics
from encapzulate.utils.utils import import_model
from encapzulate.utils.metrics import Metrics, probs_to_redshifts, bins_to_redshifts

# from encapzulate.utils.plots import better_step
# Re-execute the metrics module so edits made on disk are picked up without
# restarting the kernel.
# NOTE(review): names already bound with `from encapzulate.utils.metrics import ...`
# are not rebound by reload() -- confirm whether they need to be re-imported
# after this call.
reload(metrics)

In [None]:
# Parameters for the exploration
# run_name selects the results directory (and therefore the config) of the run.
run_name = "paper1_regression_80perc_0" # alternative run: "morphCapsDeep_multi_15"
# Checkpoint number; it is formatted into the weights file name when the
# weights are loaded.
checkpoint_eval = 100

In [None]:
# Resolve the on-disk locations of this run's outputs.
# path_output = "/data/bid13/photoZ/results"
path_output = Path("/home/bid13/code/photozCapsNet/results")
# Runs are grouped under a folder named after the first "_"-separated token of
# run_name.
run_group = run_name.split("_")[0]
path_results = path_output.joinpath(run_group, run_name, "results")
path_config = path_results.joinpath("config.yml")

In [None]:
# Read the run configuration from path_config.
config  = load_config(path_config)
# Value stored under the "image_scale" key of the configuration.
scale= config['image_scale']

In [None]:
# Temporary shim (this wont be needed in future): copy "image_shape" to the
# "input_shape" key before the whole config is passed to the model constructor.
config["input_shape"] = config["image_shape"]
CapsNet = import_model(model_name=config["model_name"])

# The constructor returns five models; unpack them by name.
(
    train_model,
    eval_model,
    manipulate_model,
    decoder_model,
    redshift_model,
) = CapsNet(**config)

# Load the weights of the requested checkpoint, matching layers by name.
weights_file = path_results / "weights" / f"weights-{checkpoint_eval:02d}.h5"
manipulate_model.load_weights(weights_file, by_name=True)
# model = multi_gpu_model(eval_model,gpus=2)

In [None]:
# Load the saved arrays for this run and checkpoint. The file name is built
# from run_name and checkpoint_eval so that it stays in sync with the weights
# loaded earlier; with the current parameter values it resolves to the same
# file as before.
# NOTE(review): assumes every run's archive follows the
# "all_<run_name>_<checkpoint>.npz" naming -- confirm for other runs.
# SECURITY: allow_pickle=True can execute arbitrary code when loading a
# crafted file; only load archives you produced yourself.
data = np.load(f"./z_phot/all_{run_name}_{checkpoint_eval}.npz", allow_pickle=True)

In [None]:
# Pull the test-set arrays out of the archive.
cat_test = data["cat_test"]
z_spec_test = data["z_spec_test"]
z_phot_test = data["z_phot_test"]
y_caps_all = data["y_caps_all_test"]
y_prob = data["y_prob_test"]

# For every row keep only the capsule picked by the arg-max of y_prob along
# its last axis.
morpho = y_prob.argmax(axis=-1)
row_index = np.arange(len(y_caps_all))
caps_test = y_caps_all[row_index, morpho, :]

In [None]:
# Same selection as for the test set, applied to the training arrays.
# Note: y_caps_all, y_prob and morpho are overwritten with training values here.
y_prob = data["y_prob_train"]
y_caps_all = data["y_caps_all_train"]
morpho = y_prob.argmax(axis=-1)
row_index = np.arange(len(y_caps_all))
caps_train = y_caps_all[row_index, morpho, :]

# SHAP explainer

### DEEP SHAP

In [None]:
# select a set of background examples to take an expectation over
# A seeded generator keeps the background sample (and hence the SHAP values)
# reproducible, and the sample size is capped at the number of training rows
# so small training sets do not raise.
rng_background = np.random.default_rng(0)
n_background = min(1000, caps_train.shape[0])
background = caps_train[
    rng_background.choice(caps_train.shape[0], n_background, replace=False)
]

# explain the redshift model's predictions for every row of caps_test
e = shap.DeepExplainer(redshift_model, background)
# ...or pass tensors directly
# e = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)
shap_values = e.shap_values(caps_test)



In [None]:
# Feature labels "Dim 1" ... "Dim N", one per capsule dimension.
n_dims = config["dim_capsule"]
names = [f"Dim {idx}" for idx in np.arange(1, n_dims + 1)]
# Bundle the first entry of shap_values with the inputs and labels for plotting.
explainer = shap.Explanation(
    shap_values[0],
    data=caps_test,
    feature_names=names,
)

In [None]:
# Seaborn "flare" palette, requested as a colormap object (as_cmap=True) so it
# can be passed to the SHAP plots below.
cmap = sns.color_palette("flare", as_cmap=True)

In [None]:
# Beeswarm summary of the Deep SHAP values, coloured with `cmap`; at most 16
# features are displayed.
shap.plots.beeswarm(explainer, max_display=16, color=cmap)

### Gradient Explainer

In [None]:
# Build a gradient-based explainer for the redshift model. This rebinds `e`,
# replacing the Deep SHAP explainer.
# NOTE(review): the background passed here is caps_test itself, whereas the
# Deep SHAP section samples its background from caps_train -- confirm this
# difference is intended.
e = shap.GradientExplainer(
    redshift_model,
    caps_test,
    batch_size=4096,
    local_smoothing=0,  # std dev of smoothing noise
)

In [None]:
# Compute SHAP values for every row of caps_test with the explainer currently
# bound to `e`; this overwrites the Deep SHAP values stored in shap_values.
shap_values = e.shap_values(caps_test)

In [None]:
# Feature labels "Dim 1" ... "Dim N", one per capsule dimension.
n_dims = config["dim_capsule"]
names = [f"Dim {idx}" for idx in np.arange(1, n_dims + 1)]
# Wrap the first entry of the gradient-explainer SHAP values for plotting.
explainer = shap.Explanation(
    shap_values[0],
    data=caps_test,
    feature_names=names,
)
# Colormap used to colour points in the beeswarm plot.
cmap = sns.color_palette("flare", as_cmap=True)


In [None]:
# fig,ax=plt.subplots(1,1,figsize=(10,20))
# show=False keeps the figure open so it can be written to disk below.
shap.plots.beeswarm(explainer, max_display=16, color=cmap, color_bar_label="Dimension Value",show=False)
# savefig does not create missing directories, so make sure the output folder
# exists before writing (matters on a fresh checkout).
Path("./figs").mkdir(parents=True, exist_ok=True)
plt.savefig("./figs/shap_feature_importance.pdf",bbox_inches="tight")

In [None]:
# Rank the features by their mean absolute SHAP value.
abs_shap = np.abs(explainer.values)
mean = abs_shap.mean(axis=0)
# Despite its name, `std` holds the standard deviation divided by sqrt(N),
# i.e. the standard error of the mean.
std = abs_shap.std(axis=0) / np.sqrt(len(abs_shap))
# Indices of the features from largest to smallest mean |SHAP|.
order = np.argsort(mean)[::-1]

In [None]:
# Mean |SHAP| per feature, largest first, followed by the matching
# std / sqrt(N) values in the same order.
print(mean[order])
print(std[order])

In [None]:
# Mean |SHAP| (x) against rank position (y). The number of positions is taken
# from `order` rather than hard-coded to 16, so the plot still works when the
# capsule dimension differs.
plt.errorbar(mean[order], np.arange(len(order)), xerr=std[order], fmt="o")

In [None]:
# clustering = shap.utils.hclust(caps_test, z_spec_test)

In [None]:
# shap.plots.bar(explainer, clustering=clustering)

In [None]:
# names = ["Dim " + s for s in np.arange(1, config["dim_capsule"] + 1).astype(str)]
# explainer = shap.Explanation(shap_values[0], data=caps_test, feature_names=names)
# cmap = sns.color_palette("flare", as_cmap=True)
# shap.plots.beeswarm(
#     explainer, max_display=16, color=cmap, clustering=clustering, cluster_threshold=0.5
# )

In [None]:
def logistic_trans(x,xmin=0,xmax=0.4):
    """Map values from the interval (xmin, xmax) onto the real line.

    Computes ``log((x - xmin) / (xmax - x))``, the logit of ``x`` rescaled to
    the given interval. ``x`` may be a scalar or a NumPy array.
    """
    distance_above_min = x - xmin
    distance_below_max = xmax - x
    ratio = distance_above_min / distance_below_max
    return np.log(ratio)

def logistic_trans_inv(x,xmin=0,xmax=0.4):
    """Inverse of ``logistic_trans``: map real values back into (xmin, xmax).

    Mathematically this is ``(exp(x) * xmax + xmin) / (exp(x) + 1)``, which
    equals ``xmin + (xmax - xmin) * sigmoid(x)``.

    Parameters
    ----------
    x : scalar or array-like
        Values in the transformed (logit) space.
    xmin, xmax : float
        Bounds of the target interval.

    Returns
    -------
    NumPy scalar or ndarray with values between ``xmin`` and ``xmax``.
    """
    x = np.asarray(x)
    # sigmoid(x) == 0.5 * (1 + tanh(x / 2)). The tanh form saturates at 0 and 1
    # for large |x|, whereas evaluating exp(x) directly overflows to inf and
    # turns the ratio into inf / inf = nan.
    sigmoid = 0.5 * (1.0 + np.tanh(0.5 * x))
    return xmin + (xmax - xmin) * sigmoid

In [None]:
def base_model():
    """Return the redshift network compiled for use via the scikit-learn wrapper.

    The model is replicated across GPUs with ``multi_gpu_model`` when possible.
    ``multi_gpu_model`` raises ``ValueError`` when fewer than two GPUs are
    available; in that case the plain ``redshift_model`` is compiled and
    returned instead, so the notebook also runs on single-GPU or CPU hosts.

    Returns
    -------
    The compiled Keras model (MSE loss, Adam optimizer).
    """
    try:
        model = multi_gpu_model(redshift_model)
    except ValueError:
        # Not enough devices for a multi-GPU replica: use the single-device model.
        model = redshift_model
    model.compile(loss='mse', optimizer = 'adam')
    return model

In [None]:
# Wrap the Keras network in the scikit-learn regressor interface so it can be
# handed to permutation_importance.
sklearn_model = KerasRegressor(build_fn=base_model, batch_size=2048, verbose=0)
# Attach an already-built model directly instead of calling fit().
# NOTE(review): presumably this relies on the wrapper's predict() reading
# `self.model` -- confirm against the installed Keras version.
sklearn_model.model = base_model()
# Predictions for the test capsules, mapped through logistic_trans_inv.
zz = logistic_trans_inv(sklearn_model.predict(caps_test))

In [None]:
def feature_importance_scorer(estimator, X, z_spec):
    """Score an estimator by the negated sigma_NMAD of its redshift residuals.

    Parameters
    ----------
    estimator : object with a ``predict(X)`` method
        Its predictions are mapped through ``logistic_trans_inv(., 0, 0.4)``.
    X : array-like
        Inputs forwarded to ``estimator.predict``.
    z_spec : array-like
        Reference redshifts, one per row of ``X``.

    Returns
    -------
    float
        ``-1.4826 * median(|e - median(e)|)`` with
        ``e = (z_phot - z_spec) / (1 + z_spec)``. The sign is flipped so that a
        larger score means a smaller scatter.
    """
    # Flatten both sides so an (N, 1) prediction cannot silently broadcast
    # against an (N,) target into an (N, N) residual matrix.
    z_phot = np.ravel(estimator.predict(X))
    z_spec = np.ravel(z_spec)
    z_phot = logistic_trans_inv(z_phot, 0, 0.4)
    error = (z_phot - z_spec) / (1 + z_spec)
    sigma_nmad = 1.4826 * np.median(np.abs(error - np.median(error)))
    return -1 * sigma_nmad

In [None]:
# Score magnitude on the unpermuted test set. The scorer returns minus
# sigma_NMAD, hence the abs(); the plotting cell divides the importances by
# this value.
base_level = np.abs(feature_importance_scorer(sklearn_model, caps_test, z_spec_test))

In [None]:
# Permutation importance of each capsule dimension, scored with
# feature_importance_scorer. random_state pins the shuffles so the importances
# are reproducible from one run to the next.
result = permutation_importance(
    sklearn_model,
    caps_test,
    z_spec_test,
    n_repeats=100,
    n_jobs=1,
    scoring=feature_importance_scorer,
    random_state=0,
)

# Plot the Permutation feature Importances
The normalized median absolute deviation, $\sigma_{\mathrm{NMAD}}$, is used as the error metric for the feature importance.

In [None]:
from matplotlib import scale as mscale
from matplotlib import transforms as mtransforms
from matplotlib.ticker import AutoLocator, NullFormatter, NullLocator, ScalarFormatter


class PowerLawScale(mscale.ScaleBase):
    """Axis scale that draws a value ``a`` at ``sign(a) * |a| ** gamma``.

    The sign-preserving form keeps the transform defined for negative values.
    The scale is registered under the name ``"power_law"`` and is selected with
    e.g. ``ax.set_xscale("power_law", gamma=0.3)``.
    """

    name = "power_law"

    def __init__(self, axis, *, gamma, **kwargs):
        """Create the scale.

        Parameters
        ----------
        axis
            Axis object, forwarded to ``ScaleBase.__init__``.
        gamma : float
            Exponent of the power law. Must be finite and non-zero, otherwise
            the transform cannot be inverted.
        **kwargs
            Accepted for interface compatibility and ignored.

        Raises
        ------
        ValueError
            If ``gamma`` is zero or not finite.
        """
        super().__init__(axis)
        gamma = float(gamma)
        # gamma == 0 collapses every value onto its sign and makes the inverse
        # exponent 1 / gamma undefined; inf / nan are equally meaningless.
        if not np.isfinite(gamma) or gamma == 0:
            raise ValueError(f"gamma must be a finite, non-zero number, got {gamma!r}")
        self.gamma = gamma

    def set_default_locators_and_formatters(self, axis):
        """Install automatic major ticks with scalar labels and no minor ticks."""
        axis.set_major_locator(AutoLocator())
        axis.set_major_formatter(ScalarFormatter())
        axis.set_minor_locator(NullLocator())
        axis.set_minor_formatter(NullFormatter())

    def limit_range_for_scale(self, vmin, vmax, minpos):
        """Return the requested limits unchanged; no range restriction is applied."""
        return vmin, vmax

    def get_transform(self):
        """Return the forward data-to-display transform for this scale."""
        return self.PowerLawTransform(self.gamma)

    class PowerLawTransform(mtransforms.Transform):
        """Forward transform: ``a -> sign(a) * |a| ** gamma``."""

        input_dims = output_dims = 1
        # An explicit inverse is provided by inverted().
        has_inverse = True

        def __init__(self, gamma):
            super().__init__()
            self.gamma = gamma

        def transform_non_affine(self, a):
            """Apply the signed power law element-wise."""
            return np.sign(a) * np.power(np.abs(a), self.gamma)

        def inverted(self):
            """Return the transform that undoes this one."""
            return PowerLawScale.InvertedPowerLawTransform(self.gamma)

    class InvertedPowerLawTransform(mtransforms.Transform):
        """Inverse transform: ``a -> sign(a) * |a| ** (1 / gamma)``."""

        input_dims = output_dims = 1
        has_inverse = True

        def __init__(self, gamma):
            super().__init__()
            self.gamma = gamma

        def transform_non_affine(self, a):
            """Apply the signed inverse power law element-wise."""
            return np.sign(a) * np.power(np.abs(a), 1 / self.gamma)

        def inverted(self):
            """Return the forward power-law transform."""
            return PowerLawScale.PowerLawTransform(self.gamma)


# Make the scale available to set_xscale / set_yscale under its `name`.
mscale.register_scale(PowerLawScale)

In [None]:
# Express every permutation importance relative to the unpermuted score
# magnitude (base_level).
importances = result.importances / base_level
median = np.median(importances, axis=-1)
sorted_idx = median.argsort()
# 16th / 84th percentiles over the repeats, converted to the distances below
# and above the median that errorbar expects.
spread = np.percentile(importances, [16, 84], axis=-1)
spread[0, :] = median - spread[0, :]
spread[1, :] = spread[1, :] - median
# 1-based labels for the capsule dimensions.
names = np.arange(1, config["dim_capsule"] + 1).astype(str)

fig, ax = plt.subplots(figsize=(20, 10))


ax.errorbar(
    median[sorted_idx],
    names[sorted_idx],
    xerr=spread[:, sorted_idx],
    fmt="o",
    markersize=15,
    elinewidth=2,
    capsize=10,
    capthick=2,
    ls="",
)

ax.tick_params(axis="both", which="major", labelsize=25)
ax.tick_params(axis="both", which="minor", labelsize=25)
ax.set_ylabel("Capsule Dimension", fontsize=40)
ax.set_xlabel("Permutation Feature Importance", fontsize=40)
ax.grid(ls="--")
# Custom power-law x scale (registered under "power_law") with hand-picked ticks.
ax.set_xscale("power_law", gamma=0.3)
xticklabels = [0,0.01,0.1, 0.2,0.5,1,2,3,4]
ax.set_xticks(xticklabels)
ax.set_xticklabels([str(i) for i in xticklabels])
# savefig does not create missing directories, so make sure the output folder
# exists before writing.
Path("./figs").mkdir(parents=True, exist_ok=True)
fig.savefig("./figs/permutation_feature_importance.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Manual check: shuffle a single input column across the test rows.
# The generator is seeded so the shuffled copy, and every number derived from
# it, is reproducible.
rand = np.random.default_rng(0)
perm = rand.permutation(len(caps_test))
permuted = caps_test.copy()
# Column index 7 is zero-based, i.e. the dimension labelled "8" in the
# importance plot.
permuted[:,7]=permuted[perm,7]

In [None]:
# Predict on the copy with column 7 shuffled. This is the raw output of
# sklearn_model.predict; logistic_trans_inv is not applied here.
pred = sklearn_model.predict(permuted)

In [None]:
from encapzulate.utils.metrics import Metrics

In [None]:
# Build the metrics object from the shuffled-column predictions and the
# reference redshifts.
# NOTE(review): `pred` is passed without logistic_trans_inv, whereas the other
# comparisons with z_spec in this notebook apply it first -- confirm which
# space Metrics expects and what the -4, 1, 0.05 arguments mean.
met = Metrics(pred,z_spec_test,-4,1,0.05)

In [None]:
# Run the diagnostic routine of the Metrics object built from the
# shuffled-column predictions.
met.full_diagnostic()