In [1]:
import shap
import numpy as np
import torch
from tqdm import tqdm
np.set_printoptions(formatter={'float': '{: 0.4f}'.format})

In [2]:
encoder = torch.nn.Sequential(
    torch.nn.Linear(3,5),
    torch.nn.ReLU(),
    torch.nn.Linear(5,5),
    torch.nn.ReLU(),
    torch.nn.Linear(5,1),
)

decoder = torch.nn.Sequential(
    torch.nn.Linear(1,5),
    torch.nn.ReLU(),
    torch.nn.Linear(5,5),
    torch.nn.ReLU(),
    torch.nn.Linear(5,3)
)

optim = torch.optim.Adam((encoder + decoder).parameters())

In [3]:
X = torch.zeros((10_000,3), dtype=torch.float32, device="cpu")
X[:,0] = torch.rand((10_000,))
X[:,1] = X[:,0]*2
X[:,2] = X[:,1]*4

In [4]:
bar = tqdm(range(10000))
for epoch in bar:
    optim.zero_grad()
    x = X[torch.randperm(len(X), device="cpu")[:1000]]
    x_hat: torch.Tensor = (encoder + decoder)(x)
    loss: torch.Tensor = torch.mean((x_hat - x)**2)
    bar.set_description(f"Loss: {loss:.5f}")
    loss.backward()
    optim.step()

Loss: 8.44404:   0%|          | 0/10000 [00:00<?, ?it/s]CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)
Loss: 0.00000: 100%|██████████| 10000/10000 [00:28<00:00, 348.92it/s]


In [5]:
idx = torch.randint(0,len(X),(1,))
X[idx], encoder(X[idx]), decoder(encoder(X[idx]))

(tensor([[0.5643, 1.1285, 4.5141]]),
 tensor([[-7.3975]], grad_fn=<AddmmBackward0>),
 tensor([[0.5641, 1.1282, 4.5137]], grad_fn=<AddmmBackward0>))

In [6]:
import shap.maskers


def f(x):
    with torch.no_grad():
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).float()
            return encoder(x).numpy(force=True)
        else:
            return encoder(x).numpy(force=True)
        
explainer = shap.ExactExplainer(f, shap.maskers.Independent(X.numpy(force=True), 1000))

def f_inv(x):
    with torch.no_grad():
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).float()
            return decoder(x).numpy(force=True)
        else:
            return decoder(x).numpy(force=True)
        
explainer_inv = shap.ExactExplainer(f_inv, shap.maskers.Independent(encoder(X).numpy(force=True), 1000))

In [7]:
exp: shap.Explanation = explainer(X[idx:idx+1].numpy())[0]
exp_values = exp.values
exp_base_values = exp.base_values
exp

.values =
array([-0.0691, -0.0207, -0.8062])

.base_values =
array([-6.5015])

.data =
array([ 0.5643,  1.1285,  4.5141], dtype=float32)

In [8]:
inv_exp: shap.Explanation = explainer_inv(encoder(X[idx:idx+1]).numpy(force=True))[0][0]
inv_exp_values = inv_exp.values
inv_exp_base_values = inv_exp.base_values
inv_exp

.values =
array([ 0.0644,  0.1287,  0.5150])

.base_values =
array([ 0.4997,  0.9994,  3.9987])

.data =
-7.397464

In [9]:
exp_values + exp_base_values

array([-6.5706, -6.5222, -7.3077])

In [10]:
inv_exp_values + inv_exp_base_values

array([ 0.5641,  1.1282,  4.5137])

In [13]:
abs(exp_values + exp_base_values)/sum(abs(exp_values + exp_base_values))

array([ 0.3221,  0.3197,  0.3582])

In [14]:
abs(inv_exp_values + inv_exp_base_values)/sum(abs(inv_exp_values + inv_exp_base_values))

array([ 0.0909,  0.1818,  0.7273])