In [16]:
import shap
import numpy as np
import torch
from tqdm import tqdm
# Print every float inside a NumPy array with four decimals. The blank in the
# format spec reserves a sign column, so positive and negative entries line up.
float_formatter = '{: 0.4f}'.format
np.set_printoptions(formatter=dict(float=float_formatter))

In [17]:
# --- Configuration -----------------------------------------------------------
SEED = 0       # seed for torch's global RNG (see torch.manual_seed below)
features = 3   # width of each input row
classes = 1    # width of the bottleneck between encoder and decoder
hidden = 5     # width of every hidden layer

# Seed before any layer is created: torch.nn.Linear draws its initial weights
# from torch's global RNG, and the later torch.rand / randperm / randint calls
# in this notebook draw from the same generator. With the seed fixed here, a
# fresh "Restart Kernel -> Run All" repeats the same run.
torch.manual_seed(SEED)

# Encoder: features -> hidden -> hidden -> classes.
encoder = torch.nn.Sequential(
    torch.nn.Linear(features, hidden),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden, hidden),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden, classes),
)

# Decoder: mirror image of the encoder, classes -> hidden -> hidden -> features.
decoder = torch.nn.Sequential(
    torch.nn.Linear(classes, hidden),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden, hidden),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden, features),
)

# `encoder + decoder` chains the two Sequentials; the chained model reuses the
# same layer objects, so one optimizer over its parameters trains both halves.
autoencoder = encoder + decoder
optim = torch.optim.Adam(autoencoder.parameters())

In [18]:
# Toy dataset with one degree of freedom: column 0 is uniform noise in [0, 1),
# column 1 is twice column 0, and column 2 is twice column 1.
n_samples = 10_000
base = torch.rand((n_samples,)).to(device="cpu", dtype=torch.float32)
doubled = base * 2
X = torch.stack((base, doubled, doubled * 2), dim=1)

In [19]:
# Train encoder and decoder jointly to reconstruct X (mean squared error).
n_steps = 10_000     # number of optimizer updates
batch_size = 1000    # rows drawn (without replacement) from X per update

# Chaining the two halves is loop-invariant: evaluating `encoder + decoder`
# builds a new Sequential wrapper around the same layers every time, so build
# it once instead of once per step.
autoencoder = encoder + decoder

bar = tqdm(range(n_steps))
for epoch in bar:
    optim.zero_grad()
    # First `batch_size` entries of a fresh permutation = a random mini-batch.
    x = X[torch.randperm(len(X), device="cpu")[:batch_size]]
    x_hat: torch.Tensor = autoencoder(x)
    loss: torch.Tensor = torch.mean((x_hat - x)**2)
    bar.set_description(f"Loss: {loss:.5f}")
    loss.backward()
    optim.step()

Loss: 0.00000: 100%|██████████| 10000/10000 [01:10<00:00, 141.47it/s]


In [20]:
import shap.maskers


def f(x):
    """Run the encoder on ``x`` and return the result as a NumPy array.

    A NumPy input is first converted to a float32 tensor; any other input is
    handed to ``encoder`` unchanged. The forward pass runs without gradient
    tracking.
    """
    inputs = torch.from_numpy(x).float() if isinstance(x, np.ndarray) else x
    with torch.no_grad():
        latent = encoder(inputs)
    return latent.numpy(force=True)

        
# Exact SHAP explainer for the encoder. The Independent masker takes the
# training rows as background data, capped at 1000 samples.
encoder_masker = shap.maskers.Independent(X.numpy(force=True), max_samples=1000)
explainer = shap.ExactExplainer(f, encoder_masker)

def f_inv(x):
    """Run the decoder on ``x`` and return the result as a NumPy array.

    A NumPy input is first converted to a float32 tensor; any other input is
    handed to ``decoder`` unchanged. The forward pass runs without gradient
    tracking.
    """
    latent = torch.from_numpy(x).float() if isinstance(x, np.ndarray) else x
    with torch.no_grad():
        reconstruction = decoder(latent)
    return reconstruction.numpy(force=True)
        
# Exact SHAP explainer for the decoder. Its background data is the encoder's
# output for every training row, capped at 1000 samples by the masker.
latent_background = encoder(X).numpy(force=True)
decoder_masker = shap.maskers.Independent(latent_background, max_samples=1000)
explainer_inv = shap.ExactExplainer(f_inv, decoder_masker)

In [21]:
# Pick one random row of X and show it next to its latent code and its
# reconstruction. idx is a 1-element tensor, so X[idx] keeps a leading
# batch dimension of size 1.
idx = torch.randint(low=0, high=len(X), size=(1,))
sample = X[idx]
latent = encoder(sample)
sample, latent, decoder(latent)

(tensor([[0.3834, 0.7668, 1.5335]]),
 tensor([[3.7061]], grad_fn=<AddmmBackward0>),
 tensor([[0.3837, 0.7671, 1.5340]], grad_fn=<AddmmBackward0>))

In [65]:
# Explain the encoder's output for the selected row.
sample_np = X[idx:idx+1].numpy(force=True)
exp: shap.Explanation = explainer(sample_np)
# One attribution per (input feature, latent dimension) pair, and one base
# value per latent dimension.
exp_values = exp.values.reshape(features, classes)
exp_base_values = exp.base_values.reshape(classes)
exp

.values =
array([[-0.0674, -0.4815, -0.5104]])

.base_values =
array([[ 4.7655]])

.data =
array([[ 0.3834,  0.7668,  1.5335]], dtype=float32)

In [60]:
# Explain the decoder's output at the latent code of the selected row.
inv_exp: shap.Explanation = explainer_inv(encoder(X[idx:idx+1]).numpy(force=True))
# inv_exp.values comes back as (sample, latent input, reconstructed output),
# i.e. (1, classes, features) for this single row. Reshape in that native
# order and then transpose to get (features, classes). Reshaping straight to
# (features, classes) pairs the entries correctly only while classes == 1.
inv_exp_values = inv_exp.values.reshape((classes, features)).T
# One base value per reconstructed feature.
inv_exp_base_values = inv_exp.base_values.reshape((features))
inv_exp

.values =
array([[[-0.1059, -0.2115, -0.4230]]])

.base_values =
array([[ 0.4896,  0.9786,  1.9570]])

.data =
array([[ 3.7061]], dtype=float32)

In [61]:
# Total encoder attribution per latent dimension (summed over input features).
np.sum(exp_values, axis=0)

array([-1.0594])

In [128]:
# Combine both explanations for the selected row, step by step:
#   1. decoder attributions minus the true row plus the decoder base values
#      (presumably the decoder's reconstruction error, if SHAP additivity
#      holds for explainer_inv -- confirm),
#   2. plus the encoder output, minus the encoder base value
#      (presumably the summed encoder attributions, by the same property).
# The additions are applied in the same left-to-right order as a single
# chained expression would use, so floating-point results are unchanged.
x_row = X[idx].numpy(force=True)
acc = inv_exp_values.T - x_row
acc = acc + inv_exp_base_values
acc = acc + f(X[idx:idx+1])
acc = acc - exp_base_values
n = acc.sum(0)

In [127]:
# Show n (one entry per feature) so it can be compared by eye with the summed
# encoder attributions, exp_values.sum(0).
n

array([-1.0591, -1.0591, -1.0589])

In [129]:
# Scale the encoder attributions so that, for each explained row, their
# absolute values sum to 1 (signs are kept).
# keepdims=True keeps the reduced feature axis, so the division lines up row
# by row. Without it the normaliser loses that axis and broadcasts correctly
# only when exactly one row was explained.
encoder_abs_total = np.abs(exp.values).sum(axis=1, keepdims=True)
encoder_importance = (exp.values / encoder_abs_total)[0]  # first explained row
encoder_importance

array([-0.0636, -0.4545, -0.4818])

In [130]:
# Scale the decoder attributions so that, for each explained row and latent
# dimension, their absolute values over the reconstructed features sum to 1
# (signs are kept).
# keepdims=True keeps the reduced last axis, so the division lines up entry
# by entry. Without it the normaliser is one rank short and broadcasts
# correctly only for a single row with a single latent dimension.
decoder_abs_total = np.abs(inv_exp.values).sum(axis=2, keepdims=True)
decoder_importance = (inv_exp.values / decoder_abs_total)[0]  # first explained row
decoder_importance

array([[-0.1430, -0.2857, -0.5713]])

In [31]:
# Compare the normalised importance of the first feature on both sides.
# encoder_importance is 1-D (one entry per input feature), so it takes a
# single index; decoder_importance is 2-D (latent dimension x reconstructed
# feature) and takes two. Indexing the 1-D array with two subscripts raised
# "IndexError: too many indices for array".
encoder_importance[0], decoder_importance[0, 0]