In [2]:
import shap
import numpy as np
import torch
from tqdm import tqdm
# Print every float with four decimals and a leading sign slot so array columns line up.
np.set_printoptions(formatter=dict(float='{: 0.4f}'.format))

In [101]:
# Index of the background row that the later cells inspect and explain.
i = 0

# Coefficient matrix read by f and f_inv (both only touch the entries on or above
# the diagonal). Here only the diagonal is non-zero.
c = np.array([[2, 0, 0], [0, 80, 0], [0, 0, 70]])

def f(x: np.ndarray):
    """Apply the upper-triangular part of the module-level matrix ``c`` to 3-vectors.

    ``x`` is reshaped into rows of three values, so a single point of shape
    ``(3,)`` and a batch of shape ``(n, 3)`` are both accepted. Entries of ``c``
    below the diagonal are never read.

    Returns an ``(n, 3)`` array whose columns are ``y0, y1, y2``.
    """
    # Transposing the (n, 3) view yields one 1-D array per input coordinate.
    u, v, w = x.reshape((-1, 3)).T

    out0 = c[0, 0] * u + c[0, 1] * v + c[0, 2] * w
    out1 = c[1, 1] * v + c[1, 2] * w
    out2 = c[2, 2] * w
    return np.column_stack((out0, out1, out2))

def f_inv(y: np.ndarray):
    """Invert ``f`` by back-substitution through the upper-triangular part of ``c``.

    The forward equations are solved from the last one upwards:

        y2 = c[2,2]*x2                          =>  x2 = y2 / c[2,2]
        y1 = c[1,1]*x1 + c[1,2]*x2              =>  x1 = (y1 - c[1,2]*x2) / c[1,1]
        y0 = c[0,0]*x0 + c[0,1]*x1 + c[0,2]*x2  =>  x0 = (y0 - c[0,1]*x1 - c[0,2]*x2) / c[0,0]

    ``y`` is reshaped into rows of three values; the result has shape ``(n, 3)``
    with columns ``x0, x1, x2``.
    """
    # Transposing the (n, 3) view yields one 1-D array per output coordinate.
    t0, t1, t2 = y.reshape((-1, 3)).T

    # Each unknown only depends on the ones solved before it.
    sol2 = (t2) / c[2, 2]
    sol1 = (t1 - c[1, 2] * sol2) / c[1, 1]
    sol0 = (t0 - c[0, 1] * sol1 - c[0, 2] * sol2) / c[0, 0]

    return np.column_stack((sol0, sol1, sol2))

# Reference data for the masker: 1000 points drawn uniformly from [0, 1)^3.
# A fixed seed keeps these samples, and every SHAP number derived from them,
# identical across Restart & Run All.
background = np.random.default_rng(seed=0).random((1000, 3))

# max_samples is set to the full background size so the masker is not limited to
# fewer rows than were supplied.
explainer = shap.ExactExplainer(
    f, shap.maskers.Independent(background, max_samples=len(background))
)
explainer

<shap.explainers._exact.ExactExplainer at 0x7e2d6a3833d0>

In [102]:
# Sanity check: f_inv must undo f. The assertion makes a fresh run fail loudly if the
# two ever drift apart, instead of relying on comparing printed digits by eye.
# f_inv returns shape (1, 3) for a single point, hence the [0] before comparing.
roundtrip = f_inv(f(background[i]))
np.testing.assert_allclose(roundtrip[0], background[i])
background[i], roundtrip

(array([ 0.2677,  0.2396,  0.5502]), array([[ 0.2677,  0.2396,  0.5502]]))

In [103]:
import shap.maskers


# Reference data for explaining f_inv: 1000 points drawn uniformly from [0, 1)^3.
# Seeded so the inverse explanation is reproducible across Restart & Run All.
# NOTE(review): these points are drawn independently of `background` rather than taken
# from f(background), so they are not the outputs f produces for the forward samples.
# Confirm this is intended before comparing the forward and inverse explanations.
inv_background = np.random.default_rng(seed=1).random((1000, 3))

# max_samples is set to the full background size so the masker is not limited to
# fewer rows than were supplied.
inv_explainer = shap.ExactExplainer(
    f_inv, shap.maskers.Independent(inv_background, max_samples=len(inv_background))
)
inv_explainer

<shap.explainers._exact.ExactExplainer at 0x7e2d6a3827a0>

In [104]:
# Explain only row i of the background. The slice keeps a leading batch axis for the
# explainer and [0] pulls the single per-sample Explanation back out.
exp: shap.Explanation = explainer(background[i:i + 1])[0]

# Keep the raw arrays under their own names for the arithmetic in later cells.
exp_values, exp_base_values = exp.values, exp.base_values
exp

.values =
array([[-0.4564,  0.0000, -0.0000],
       [ 0.0000, -20.5554,  0.0000],
       [ 0.0000,  0.0000,  3.4384]])

.base_values =
array([ 0.9918,  39.7222,  35.0742])

.data =
array([ 0.2677,  0.2396,  0.5502])

In [105]:
# Explain only row i of the inverse background. The slice keeps a leading batch axis
# for the explainer and [0] pulls the single per-sample Explanation back out.
inv_exp: shap.Explanation = inv_explainer(inv_background[i:i + 1])[0]

# Keep the raw arrays under their own names for the arithmetic in later cells.
inv_exp_values, inv_exp_base_values = inv_exp.values, inv_exp.base_values
inv_exp

.values =
array([[ 0.1495, -0.0000,  0.0000],
       [-0.0000,  0.0058,  0.0000],
       [ 0.0000,  0.0000,  0.0012]])

.base_values =
array([ 0.2452,  0.0064,  0.0071])

.data =
array([ 0.7894,  0.9746,  0.5758])

In [111]:
# Show the SHAP value matrix taken from the f_inv explanation on its own.
inv_exp_values

array([[ 0.1495, -0.0000,  0.0000],
       [-0.0000,  0.0058,  0.0000],
       [ 0.0000,  0.0000,  0.0012]])

In [112]:
# Share of each attribution relative to the summed |value + base value| of its own row.
# keepdims=True keeps the row totals as an (n_rows, 1) column, so row k is divided by
# the total computed from row k. Without it the 1-D totals broadcast along the last
# axis and column k gets divided by row k's total, which only gives the same numbers
# while exp_values is diagonal.
# NOTE(review): if a per-column normalisation was intended, reduce over axis=0
# instead. Confirm which one is meant.
exp_weights = np.abs(exp_values) / np.sum(
    np.abs(exp_values + exp_base_values), axis=1, keepdims=True
)
exp_weights

array([[ 0.0061,  0.0000,  0.0000],
       [ 0.0000,  0.3722,  0.0000],
       [ 0.0000,  0.0000,  0.0434]])

In [113]:
# Share of each attribution relative to the summed |value + base value| of its own row.
# keepdims=True keeps the row totals as an (n_rows, 1) column, so row k is divided by
# the total computed from row k. Without it the 1-D totals broadcast along the last
# axis and column k gets divided by row k's total, which only gives the same numbers
# while inv_exp_values is diagonal.
# NOTE(review): if a per-column normalisation was intended, reduce over axis=0
# instead. Confirm which one is meant.
inv_exp_weights = np.abs(inv_exp_values) / np.sum(
    np.abs(inv_exp_values + inv_exp_base_values), axis=1, keepdims=True
)
inv_exp_weights

array([[ 0.3662,  0.0000,  0.0000],
       [ 0.0000,  0.0218,  0.0000],
       [ 0.0000,  0.0000,  0.0045]])

In [114]:
# Display the explanation of f (values, base values, data) again for reference.
exp

.values =
array([[-0.4564,  0.0000, -0.0000],
       [ 0.0000, -20.5554,  0.0000],
       [ 0.0000,  0.0000,  3.4384]])

.base_values =
array([ 0.9918,  39.7222,  35.0742])

.data =
array([ 0.2677,  0.2396,  0.5502])

In [115]:
# Per-output gap between f at the explained point and the mean of f over the background.
efficiency_range = f(background[i:i+1])[0] - np.mean(f(background), axis=0)

# efficiency_range has one entry per output of f. SHAP stores multi-output values as
# (features, outputs), so the outputs run along the columns of exp_weights. Broadcast
# the range along that last axis rather than reshaping it into a column, which would
# scale each feature row by an unrelated output's range. The two only agree while the
# weight matrix is diagonal.
exp_weights * efficiency_range[np.newaxis, :]

array([[-0.0028, -0.0000, -0.0000],
       [-0.0000, -7.6499, -0.0000],
       [ 0.0000,  0.0000,  0.1492]])