In [None]:
import xai
import torch
import shap
import numpy as np
# Compact, fixed-point display of the small SHAP arrays printed below.
np.set_printoptions(suppress=True, precision=4)

In [None]:
# Reproducibility: seed torch's global RNG so the permutation below (and any torch
# weight initialisation drawing from the global RNG — confirm for xai.Network) is
# identical on every Restart & Run All.
SEED = 0
TRAIN_FRACTION = 0.8
torch.manual_seed(SEED)

X, Y = shap.datasets.california()
N = len(X)
X = X.to_numpy(np.float32)                   # DataFrame -> (N, n_features) float32 array
Y = Y.reshape((-1, 1)).astype(np.float32)    # column vector of regression targets

# Random train/validation split of the row indices.
indices = torch.randperm(N)
train_portion = int(N * TRAIN_FRACTION)
train_idx = indices[:train_portion].numpy()  # index numpy arrays with numpy indices
val_idx = indices[train_portion:].numpy()

X_train, Y_train = X[train_idx], Y[train_idx]
X_val, Y_val = X[val_idx], Y[val_idx]
assert len(X_train) + len(X_val) == N, "split must cover every row exactly once"

In [None]:
# n1 maps the 8 input features to a 5-dim hidden representation;
# n2 maps that 5-dim representation to the single regression target.
n1 = xai.Network.dense(input_dim=(8,), output_dim=(5,))
n2 = xai.Network.dense(input_dim=(5,), output_dim=(1,))

# NOTE(review): `n1 + n2` presumably chains n1 into n2 (8 -> 5 -> 1), so fitting n3
# trains both sub-networks together, as the `info` label says — confirm in xai.
n3 = n1 + n2

fit_options = dict(
    batch_size=256,
    epochs=10_000,
    loss_criterion="MSELoss",
    early_stop_count=1000,  # presumably epochs without validation improvement before stopping — confirm in xai
    verbose=True,
    info="Train california housing (both networks)",
)
n3.adam().fit(
    X_train=X_train, Y_train=Y_train,
    X_val=X_val, Y_val=Y_val,
    **fit_options,
).plot_loss()

In [None]:
# One "exact" explainer per network. The second argument is presumably the background
# data: raw validation features for n1 and n3, and n1's outputs on those same rows for
# n2, because n2's inputs live in n1's 5-dim output space.
n1_explainer = n1.explainer("exact", X_val)
latent_background = n1(X_val).output()
n2_explainer = n2.explainer("exact", latent_background)
n3_explainer = n3.explainer("exact", X_val)

In [None]:
SAMPLE_INDEX = 570  # validation row explained throughout the rest of the notebook

x1 = X_val[SAMPLE_INDEX]   # raw features: input to n1
x2 = n1(x1).output()       # n1's output for x1: input to n2
x3 = x1                    # n3 consumes the same raw features as n1

In [None]:
def norm(array, axis=None):
    """Scale ``array`` so that its largest absolute value becomes 1.

    Parameters
    ----------
    array : array_like
        Values to normalise (e.g. a vector or matrix of SHAP values).
    axis : int or tuple of int, optional
        Axis (or axes) along which the max-abs scale is taken. ``None`` (default)
        uses one scale for the whole array, matching the original behaviour.

    Returns
    -------
    numpy.ndarray
        ``array`` divided by its max absolute value (per ``axis``). Slices whose
        max absolute value is 0 come back as zeros instead of NaN, and an empty
        input comes back empty instead of raising.
    """
    values = np.asarray(array)
    if values.size == 0:
        # np.max raises on empty input; there is nothing to scale.
        return values / 1
    peak = np.max(np.abs(values), axis=axis, keepdims=True)
    # Divide all-zero slices by 1 (leaving them 0) rather than producing 0/0 = NaN.
    safe_peak = np.where(peak == 0, 1, peak)
    return values / safe_peak

In [None]:
# Explain n1 on the chosen sample; keep the first (presumably only) explanation returned.
x1_explanation = n1_explainer.explain(x1)[0]
x1_shap, x1_base = x1_explanation.shap_values, x1_explanation.base_values
x1_shap, x1_base

In [None]:
# Explain n2 on n1's output for the sample; keep the first (presumably only) explanation.
x2_explanation = n2_explainer.explain(x2)[0]
x2_shap, x2_base = x2_explanation.shap_values, x2_explanation.base_values
x2_shap, x2_base

In [None]:
# Explain the combined network n3 directly on the raw sample, as a reference for the chained result.
x3_explanation = n3_explainer.explain(x3)[0]
x3_shap, x3_base = x3_explanation.shap_values, x3_explanation.base_values
x3_shap, x3_base

In [None]:
# SHAP efficiency check: base value plus all SHAP values should reproduce n3's
# prediction for x3 (compare with n3(x3).output()).
x3_reconstructed = x3_shap.sum() + x3_base
x3_reconstructed

In [None]:
# Chain n1's and n2's explanations into a per-feature attribution for n3's output.
#
# SHAP efficiency: each explainer's values sum to (its output - its base value), where
# the base value is presumably the mean output over that explainer's background data:
#   sum_f s1[f, l] = n1(x)[l]     - E_bg[n1(X)[l]]
#   sum_l s2[l, o] = n2(n1(x))[o] - E_bg[n2(n1(X))[o]]
#   sum_f s3[f, o] = n3(x)[o]     - E_bg[n3(X)[o]]
# If n3 = n1 + n2 applies n1 then n2 (the 8 -> 5 -> 1 dims suggest so), then s2 and s3
# share the same total.
#
# x1_shap is indexed [feature, latent] and x2_shap [latent, output]. Each row is scaled
# by its own max absolute value, and the two are chained through the latent dimension:
#   combined[f, o] = sum_l norm1[f, l] * norm2[l, o]
# which is a matrix product. Sizes come from the arrays themselves instead of being
# hard-coded. An all-zero row contributes 0 instead of the NaN a 0/0 would produce.
def _normalize_rows(values):
    """Divide each row of a 2-D array by its max absolute value (all-zero rows stay 0)."""
    values = np.asarray(values, dtype=np.float32)
    row_peak = np.max(np.abs(values), axis=1, keepdims=True)
    return np.divide(values, row_peak, out=np.zeros_like(values), where=row_peak != 0)

combined_shap = _normalize_rows(x1_shap) @ _normalize_rows(x2_shap)  # (n_features, n_outputs)
combined_shap