In [335]:
from xai import *
import numpy as np
np.set_printoptions(precision=8, suppress=True)
import matplotlib.pyplot as plt
import random
import numpy as np
import torch
import shap
import pandas as pd
import plotly.express as px
from dateutil import parser as dateutil
from tqdm import tqdm
from sklearn.datasets import make_regression

In [462]:
# Reproducibility: seed every RNG the notebook relies on (data generation here,
# plus numpy / torch / stdlib `random` for weight init and any shuffling done
# during training) so Restart & Run All reproduces the same numbers.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

n_features = 15
latent_dim = 10
n_targets = 3

# Synthetic regression problem: only `n_informative` of the features are used
# to build the targets; the rest are noise inputs.
X, Y = make_regression(
    n_samples=10_000,
    n_features=n_features,
    n_informative=7,
    n_targets=n_targets,
    random_state=SEED,
)
X = X.astype(np.float32).reshape((-1, n_features))
# NOTE(review): targets are left unscaled, which is why predictor MSE is in the
# ~1e4 range; consider standardizing Y if training quality matters.
Y = Y.astype(np.float32).reshape((-1, n_targets))

# 80/20 train/validation split (make_regression shuffles rows by default).
N_train = int(len(X) * 0.8)

# Global min-max scaling to [0, 1] fitted on the training rows only, so the
# validation rows do not leak into preprocessing. Validation values may land
# slightly outside [0, 1] as a result.
X_min, X_max = X[:N_train].min(), X[:N_train].max()
X = (X - X_min) / (X_max - X_min)

X_train, Y_train = X[:N_train], Y[:N_train]
X_val, Y_val = X[N_train:], Y[N_train:]

In [463]:
# Feature and target matrices should share the sample axis.
tuple(arr.shape for arr in (X, Y))

((10000, 15), (10000, 3))

In [464]:
# Autoencoder maps n_features -> latent_dim -> n_features; the predictor maps the
# latent code (latent_dim) to the n_targets regression outputs.
HIDDEN_LAYERS = 3

autoencoder = AutoEncoder(
    data_shape=(n_features,),
    latent_shape=(latent_dim,),
    hidden_layers=HIDDEN_LAYERS,
)
predictor = Network.dense(
    input_dim=autoencoder.latent_shape,
    output_dim=(n_targets,),
    hidden_layers=HIDDEN_LAYERS,
)
(autoencoder, predictor)

(Sequential(
   (0): Reshape((15,))
   (1): Linear(in_features=15, out_features=13, bias=True)
   (2): ReLU()
   (3): Linear(in_features=13, out_features=12, bias=True)
   (4): ReLU()
   (5): Linear(in_features=12, out_features=11, bias=True)
   (6): ReLU()
   (7): Linear(in_features=11, out_features=10, bias=True)
   (8): Reshape((10,))
   (9): Reshape((10,))
   (10): Linear(in_features=10, out_features=11, bias=True)
   (11): ReLU()
   (12): Linear(in_features=11, out_features=12, bias=True)
   (13): ReLU()
   (14): Linear(in_features=12, out_features=13, bias=True)
   (15): ReLU()
   (16): Linear(in_features=13, out_features=15, bias=True)
   (17): Reshape((15,))
 ),
 Sequential(
   (0): Reshape((10,))
   (1): Linear(in_features=10, out_features=8, bias=True)
   (2): ReLU()
   (3): Linear(in_features=8, out_features=6, bias=True)
   (4): ReLU()
   (5): Linear(in_features=6, out_features=4, bias=True)
   (6): ReLU()
   (7): Linear(in_features=4, out_features=3, bias=True)
   (8): Res

In [465]:
# Train the autoencoder to reconstruct its input (inputs double as targets).
# The data is the synthetic make_regression set, so the plot is titled as such.
# NOTE(review): the slice presumably drops the first AE_PLOT_SKIP entries of the
# returned loss history so the large initial losses don't squash the plot's
# scale; confirm against the return type of xai's fit().
AE_PLOT_SKIP = 70
autoencoder.adam().fit(
    X_train=X_train,
    Y_train=X_train,
    epochs=100_000,  # upper bound; early stopping ends the run well before this
    batch_size=256,
    X_val=X_val,
    Y_val=X_val,
    early_stop_count=1500,
    loss_criterion="MSELoss",
    verbose=True,
)[AE_PLOT_SKIP:].plot_loss("Synthetic regression autoencoder training")

Early stopping! Train-loss: 0.009688, Val-loss: 0.009387:   7%|▋         | 6936/100000 [01:25<19:01, 81.54it/s]


In [466]:
# Train the predictor on the autoencoder's latent codes instead of raw features.
# Encode each split once into a named variable.
# NOTE(review): .output() appears to return tensors that still carry autograd
# history (grad_fn shows up in later outputs); confirm fit() detaches its inputs
# so predictor training does not backpropagate into the encoder graph.
latent_train = autoencoder.encoder(X_train).output()
latent_val = autoencoder.encoder(X_val).output()

predictor.adam().fit(
    X_train=latent_train,
    Y_train=Y_train,
    epochs=100_000,  # upper bound; early stopping ends the run well before this
    batch_size=256,
    X_val=latent_val,
    Y_val=Y_val,
    early_stop_count=3500,
    loss_criterion="MSELoss",
    verbose=True,
).plot_loss("Synthetic regression predictor training")

Early stopping! Train-loss: 12809.958984, Val-loss: 14715.342773:   6%|▌         | 6231/100000 [01:10<17:41, 88.37it/s] 


In [524]:
# Explain one validation sample three ways with the "exact" SHAP explainer:
#   normal     - raw features -> targets, through encoder + predictor
#   recon      - latent code  -> reconstructed features, through the decoder
#   prediction - latent code  -> targets, through the predictor alone
# A dedicated seeded RNG makes the chosen sample reproducible on re-run.
sample_rng = np.random.default_rng(0)
sample = X_val[sample_rng.integers(len(X_val))]

pipeline = autoencoder.encoder + predictor
output = pipeline(sample).output()

normal_explanation = pipeline(sample).explain("exact", X_val)
base_values = normal_explanation.base_values
normal_explanation = normal_explanation.shap_values

# Encode the validation background once; both latent-space explanations share it.
latent_background = autoencoder.encoder(X_val).output()
recon_explanation = autoencoder.decoder(autoencoder.encoder(sample)).explain("exact", latent_background).shap_values
prediction_explanation = predictor(autoencoder.encoder(sample)).explain("exact", latent_background).shap_values

# TODO: a `combo_explanation` (left unfinished as `= prediction_explanatio`, which
# raised NameError) was meant to go here; decide how to combine the recon and
# prediction attributions before adding it back.

normal_explanation.shape, recon_explanation.shape, prediction_explanation.shape

((3, 15), (15, 10), (3, 10))

In [542]:
# SHAP local-accuracy check for target 0: the feature attributions should sum to
# (model output - explainer base value). Enforce it instead of eyeballing it.
# The tolerance is a chosen threshold to absorb float32/float64 rounding.
additivity_gap = normal_explanation[0].sum() - (
    (autoencoder.encoder + predictor)(sample).output()[0] - base_values[0][0]
)
assert abs(float(additivity_gap)) < 1e-3, f"SHAP additivity violated: gap={float(additivity_gap):.6g}"
additivity_gap

tensor(0., device='cuda:0', grad_fn=<RsubBackward1>)

In [550]:
# Model prediction for target 0 on the explained sample.
predicted_target_0 = output[0]
predicted_target_0

tensor(6.4664, device='cuda:0', grad_fn=<SelectBackward0>)

In [546]:
# SHAP attributions of each input feature to target 0 (full encoder + predictor path).
feature_attributions_target_0 = normal_explanation[0]
feature_attributions_target_0

array([ 1.12637332, -1.42533462,  3.64555311,  6.88121965,  2.20741014,
        0.21469427, -2.61651772, 14.81911403,  2.50657745, -0.82808478,
        5.51275918,  5.10052773,  1.43461775,  0.31553278,  3.28429309])

In [549]:
# SHAP attributions of each latent dimension to target 0 (predictor alone).
latent_attributions_target_0 = prediction_explanation[0]
latent_attributions_target_0

array([-2.49465949,  4.21066291,  0.50123789,  8.23481736,  2.30698916,
       25.48089017, -0.32623369,  4.33500657, -9.66361584,  9.59363467])