In [658]:
import shap
import numpy as np
import torch
from tqdm import tqdm
np.set_printoptions(formatter={'float': '{: 0.4f}'.format})

In [659]:
i = 0
c = np.array([
    [3,8],
    [0,2],
])

def f(x: np.ndarray):
    x = x.reshape((-1,2))
    x0,x1 = x[:,0], x[:,1]

    y0 = c[0,0]*x0 + c[0,1]*x1
    y1 =             c[1,1]*x1

    y = np.column_stack((y0,y1))
    return y

def f_inv(y: np.ndarray):
    y = y.reshape((-1,2))
    y0,y1 = y[:,0], y[:,1]

    #y0 = c[0,0]*x0 + c[0,1]*x1 <=> x0 = (y0 - c[0,1]*x1)/c[0,0]
    #y1 =             c[1,1]*x1 <=> x1 = (y1)/c[1,1]

    x1 = (y1)/c[1,1]
    x0 = (y0 - c[0,1]*x1)/c[0,0]
    

    x = np.column_stack((x0,x1))
    return x

background = np.random.random((1000,2))
explainer = shap.ExactExplainer(f, shap.maskers.Independent(background, len(background)))
explainer

<shap.explainers._exact.ExactExplainer at 0x7f4eab85ffd0>

In [660]:
background[i], f_inv(f(background[i]))

(array([ 0.0522,  0.3675]), array([[ 0.0522,  0.3675]]))

In [661]:
import shap.maskers


inv_background = f(background)
inv_explainer = shap.ExactExplainer(f_inv, shap.maskers.Independent(inv_background, len(inv_background)))
inv_explainer

<shap.explainers._exact.ExactExplainer at 0x7f4eab8abad0>

\begin{align}
    \phi^{-1}_{c,f} = g^{-1}(y_f) - E|g^{-1}(y_{i,f})| &\Leftrightarrow \phi^{-1}_{c,f} - g^{-1}(y_f) + E|g^{-1}(y_{i,f})| = 0 \\
    \phi_{c,f} = g(x_f) - E|g(x_{i,f})| &\Leftrightarrow \phi_{c,f} - g(x_f) + E|g(x_{i,f})| = 0
\end{align}

\begin{align}
    \phi^{-1}_{c,f} - g^{-1}(y_f) + E|g^{-1}(y_{i,f})| &= \phi_{c,f} - g(x_f) + E|g(x_{i,f})|\\
    &\Leftrightarrow\\
    \phi_{c,f} &= \phi^{-1}_{c,f} - g^{-1}(y_f) + E|g^{-1}(y_{i,f})| + g(x_f) - E|g(x_{i,f})|
\end{align}

In [662]:
exp: shap.Explanation = explainer(background[i:i+1])[0]
exp_values = exp.values
exp_base_values = exp.base_values
exp

.values =
array([[-1.2912,  0.0000],
       [-1.1026, -0.2757]])

.base_values =
array([ 5.4906,  1.0107])

.data =
array([ 0.0522,  0.3675])

In [663]:
inv_exp: shap.Explanation = inv_explainer(inv_background[i:i+1])[0]
inv_exp_values = inv_exp.values
inv_exp_base_values = inv_exp.base_values
inv_exp

.values =
array([[-0.7979,  0.0000],
       [ 0.3675, -0.1378]])

.base_values =
array([ 0.4826,  0.5053])

.data =
array([ 3.0967,  0.7350])

In [698]:
exp.values

array([[-1.2912,  0.0000],
       [-1.1026, -0.2757]])

In [700]:
n = (inv_exp_values - f_inv(inv_background[i:i+1]) + inv_exp_base_values + f(background[i:i+1]) - exp_base_values)
n

array([[-2.7614, -0.1378],
       [-1.5959, -0.2757]])

In [703]:
n*0.7

array([[-1.9330, -0.0965],
       [-1.1171, -0.1930]])

In [665]:
exp_values/inv_exp_values

invalid value encountered in divide


array([[ 1.6181,  nan],
       [-3.0000,  2.0000]])

In [666]:
exp_values / np.max(np.abs(exp_values))

array([[-1.0000,  0.0000],
       [-0.8540, -0.2135]])

In [667]:
inv_exp_values / np.max(np.abs(inv_exp_values))

array([[-1.0000,  0.0000],
       [ 0.4606, -0.1727]])

In [668]:
exp_values.sum(0), f(background[i:i+1]) - exp_base_values

(array([-2.3938, -0.2757]), array([[-2.3938, -0.2757]]))

In [669]:
exp_values + exp_base_values

array([[ 4.1994,  1.0107],
       [ 4.3879,  0.7350]])

In [670]:
inv_exp_values + inv_exp_base_values

array([[-0.3153,  0.5053],
       [ 0.8502,  0.3675]])

In [671]:
(exp_values)/np.sum(np.abs(exp_values), axis=0)

array([[-0.5394,  0.0000],
       [-0.4606, -1.0000]])

In [672]:
(inv_exp_values)/np.sum(np.abs(inv_exp_values), axis=0)

array([[-0.6846,  0.0000],
       [ 0.3154, -1.0000]])

In [673]:
inv_exp_values

array([[-0.7979,  0.0000],
       [ 0.3675, -0.1378]])

In [674]:
exp_weights = abs(exp_values)/np.sum(abs(exp_values), axis=0)
exp_weights

array([[ 0.5394,  0.0000],
       [ 0.4606,  1.0000]])

In [675]:
inv_exp_weights = abs(inv_exp_values)/np.sum(abs(inv_exp_values), axis=0)
inv_exp_weights

array([[ 0.6846,  0.0000],
       [ 0.3154,  1.0000]])

In [676]:
exp

.values =
array([[-1.2912,  0.0000],
       [-1.1026, -0.2757]])

.base_values =
array([ 5.4906,  1.0107])

.data =
array([ 0.0522,  0.3675])

In [677]:
efficiency_range = f(background[i:i+1])[0] - np.mean(f(background), axis=0)

exp_weights*efficiency_range.reshape((-1,1))

array([[-1.2912, -0.0000],
       [-0.1270, -0.2757]])