In [31]:
import xai
import torch
import shap
import numpy as np
np.set_printoptions(precision=4, suppress=True)

In [4]:
X,Y = shap.datasets.california()
N = len(X)
X = X.to_numpy(np.float32)
Y = Y.reshape((-1,1)).astype(np.float32)
indices = torch.randperm(N)
train_portion = int(N*0.8)

X_train, Y_train = X[indices[:train_portion]], Y[indices[:train_portion]]
X_val, Y_val = X[indices[train_portion:]], Y[indices[train_portion:]]

In [40]:
n1 = xai.Network.dense(
    input_dim=(8,),
    output_dim=(5,)
)


n2 = xai.Network.dense(
    input_dim=(5,),
    output_dim=(1,)
)

n3 = n1 + n2

n3.adam().fit(
    X_train=X_train,
    Y_train=Y_train,
    X_val=X_val,
    Y_val=Y_val,
    batch_size=256,
    epochs=10_000,
    loss_criterion="MSELoss",
    early_stop_count=1000,
    verbose=True,
    info="Train california housing (both networks)"
).plot_loss()

Train-loss: 5.545177, Val-loss: 4.296268:   0%|          | 1/10000 [00:00<04:32, 36.72it/s]

Early stopping! Train-loss: 0.491502, Val-loss: 0.546340:  52%|█████▏    | 5183/10000 [00:58<00:54, 87.95it/s]


In [84]:
n1_explainer = n1.explainer("exact", X_val)
n2_explainer = n2.explainer("exact", n1(X_val).output())
n3_explainer = n3.explainer("exact", X_val)

In [85]:
x1 = X_val[570]
x2 = n1(x1).output()
x3 = x1

In [43]:
def norm(array):
    return array/np.max(np.abs(array))

In [86]:
x1_explanation = n1_explainer.explain(x1)[0]
x1_shap = x1_explanation.shap_values
x1_base = x1_explanation.base_values
x1_shap, x1_base

(array([[ 0.8049, -0.367 ,  0.3823,  0.4066,  0.4811],
        [ 0.259 , -0.2552,  0.3578, -0.1323,  0.2671],
        [ 0.1102,  0.3295, -0.0988,  0.8404,  0.0669],
        [ 0.0691, -0.0084,  0.0034,  0.0849,  0.0328],
        [ 2.7651,  6.7562,  4.5705,  3.7003, -7.8579],
        [-0.0856,  0.0226, -0.0439, -0.0867, -0.0649],
        [ 0.1401,  0.1928, -0.0229,  0.6023,  0.0921],
        [-0.0057, -0.0008, -0.0081, -0.0127, -0.0094]]),
 array([[ -8.314 ,  -3.4106, -14.5153,  -1.5314,   4.409 ]]))

In [87]:
x2_explanation = n2_explainer.explain(x2)[0]
x2_shap = x2_explanation.shap_values
x2_base = x2_explanation.base_values
x2_shap, x2_base

(array([[-0.4528],
        [ 0.0152],
        [-0.197 ],
        [-1.0512],
        [ 0.8866]]),
 array([[1.8947]]))

In [99]:
x3_explanation = n3_explainer.explain(x3)[0]
x3_shap = x3_explanation.shap_values
x3_base = x3_explanation.base_values
x3_shap, x3_base

(array([[-0.3732],
        [-0.061 ],
        [-0.1829],
        [-0.0287],
        [-0.0253],
        [ 0.0156],
        [-0.1482],
        [ 0.0044]]),
 array([[1.8947]]))

In [105]:
x3_shap.sum() + x3_base

array([[1.0955]])

In [104]:
combined_shap = np.zeros((8,1), dtype=np.float32)
# We want 10 shap values for every 784 pixel, i.e. of shape: (10, 784).

# sum(s1) = (y1 - E(x))
# sum(s2) = (y2 - E(y1i))
# sum(s3) = (y2 - E(y2i))

# y1 = sum(s1) + E(y1i)
# y2 = sum(s2) + E(y2i)
# y2 = sum(s3) + E(y2i)

for feature in range(8):
    for latent in range(5):
        x1_shap_norm = np.max(np.abs(x1_shap[feature]))
        shap1 = x1_shap[feature,latent]
        for cls in range(1):
            shap2 = x2_shap[latent,cls]
            x2_shap_norm = np.max(np.abs(x2_shap[latent]))
            combined_shap[feature,cls] += (shap1/x1_shap_norm)*(shap2/x2_shap_norm)
# Done!

combined_shap

array([[-1.8384],
       [-1.3209],
       [-0.5419],
       [-1.5663],
       [-1.5446],
       [ 2.0041],
       [-0.7215],
       [ 1.2848]], dtype=float32)