# Tabular data

In [None]:
import copy
import json
import os
import random

import attr_functions
import numpy as np
import pandas as pd
import torch
import wandb
from captum import attr
from lime.lime_tabular import LimeTabularExplainer
from matplotlib import pyplot as plt
from numpy import random as np_rand
from scipy import stats
from sklearn.metrics import accuracy_score
from sklearn.metrics import mean_absolute_error as mae
from sklearn.preprocessing import MinMaxScaler
from torch import nn
from tqdm.auto import tqdm
from xailib.explainers.lore_explainer import LoreTabularExplainer
from xailib.models.pytorch_classifier_wrapper import pytorch_classifier_wrapper

In [None]:
SEED = 42
SUBSAMPLING = True

# Seed every RNG the notebook draws from: Python's `random` (test-set
# subsampling), NumPy (class balancing, permutation, AUC random sampling) and
# PyTorch (network weight initialisation). Without the torch seed the initial
# weights -- and therefore every downstream explanation -- change between runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Load data

In [None]:
# Train/test CSV files both live in the same tabular data directory.
_DATA_DIR = "./data/Tabular"
train_path = f"{_DATA_DIR}/daixi_train.csv"
test_path = f"{_DATA_DIR}/daixi_test.csv"

In [None]:
def _ssin_label(features):
    """Binary target: 1.0 where ssin(x1, x2, x3) > 0.6, else 0.0."""
    score = attr_functions.ssin(features[:, 0], features[:, 1], features[:, 2])
    return (score > 0.6).astype(np.float64)


# Read both ';'-separated, header-less CSVs and round features to 4 decimals.
X_train = np.round(pd.read_csv(train_path, sep=";", header=None).values, 4)
X_test = np.round(pd.read_csv(test_path, sep=";", header=None).values, 4)

if SUBSAMPLING:
    # Keep a random 1% of the test rows (drawn with Python's `random`).
    idxs = random.sample(range(0, len(X_test)), int(len(X_test) * 0.01))
    X_test = X_test[idxs, :]

y_train = _ssin_label(X_train)
y_test = _ssin_label(X_test)

In [None]:
# Display how many test samples remain after the optional subsampling above.
len(X_test)

### Balancing

In [None]:
# Balance the training set by class: keep the same number of samples from
# every label value (the size of the smallest class), then shuffle.
#
# The labels are 0.0/1.0, so grouping by unique value is exact. (A default
# 10-bin np.histogram leaves 8 empty bins, which made the per-group size 0
# and emptied the training set.)
y_train = y_train.flatten()
classes, class_counts = np.unique(y_train, return_counts=True)
mida = class_counts.min()

grups_x = []
grups_y = []
for label in classes:
    selection = y_train == label

    grup_y = y_train[selection]
    grup_x = X_train[selection]

    # Every class has at least `mida` rows, so sampling without replacement is valid.
    sub_select = np.random.choice(np.arange(len(grup_y)), mida, replace=False)

    grups_y.append(grup_y[sub_select])
    grups_x.append(grup_x[sub_select])

X_train = np.vstack(grups_x)
y_train = np.concatenate(grups_y)

# Shuffle so the classes are interleaved rather than stored in blocks.
p = np.random.permutation(len(y_train))

X_train = X_train[p]
y_train = y_train[p]

### To tensor

In [None]:
y_train_df, y_test_df = y_train, y_test

# DataFrame copies of the features (columns x1..x3) plus the 0/1 "target"
# column; X_train_df is what the LORE explainer is fitted on further down.
X_train_df = pd.DataFrame(X_train, columns=["x1", "x2", "x3"])
X_train_df["target"] = y_train.flatten()

X_test_df = pd.DataFrame(X_test, columns=["x1", "x2", "x3"])
X_test_df["target"] = y_test.flatten()

In [None]:
# Features become float32 tensors; targets become float32 column vectors
# (N, 1) so they line up with the network's single output unit.
X_train, X_test = (torch.tensor(arr, dtype=torch.float32) for arr in (X_train, X_test))
y_train, y_test = (
    torch.tensor(arr, dtype=torch.float32).reshape(-1, 1) for arr in (y_train, y_test)
)

# Model

In [None]:
input_ftrs = X_train.shape[1]

# Two hidden layers of 128 units; a single output unit.
net = nn.Sequential(
    nn.Linear(input_ftrs, 128),
    nn.ReLU(inplace=True),
    nn.Linear(128, 128),
    nn.ReLU(inplace=True),
    nn.Linear(128, 1),
    # No final Sigmoid: the raw output is fed to BCEWithLogitsLoss and
    # predictions are obtained by thresholding it at 0.
)

# Use the first GPU when one is available, otherwise fall back to the CPU so
# the notebook also runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = net.to(device)

In [None]:
# Display the network architecture.
net

## Train

In [None]:
def _to_numpy(value):
    """Return `value` as a NumPy array if it is a torch tensor, else unchanged.

    Tensors are detached first, so outputs that still require grad can be
    converted too, and then moved to the CPU.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return value


def regression_metric(a, b):
    """Mean absolute error between `a` and `b` (torch tensors or array-likes)."""
    return mae(_to_numpy(a), _to_numpy(b))


def cls_metric(a, b):
    """Classification accuracy between `a` and `b` (torch tensors or array-likes).

    Accuracy is the fraction of matching positions, so the argument order
    (labels vs predictions) does not change the result.
    """
    return accuracy_score(_to_numpy(a), _to_numpy(b))

In [None]:
EPOCHS = 5000
LR = 0.0001
GAMMA = 0.85
STEP_SIZE = 150

# criterion = nn.L1Loss()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=LR)


pbar = tqdm(range(EPOCHS), desc="Time, he's waiting in the wings")

# Start below any attainable accuracy so the first epoch always stores a
# checkpoint and best_model can never remain None.
best_val = float("-inf")
best_model = None

wandb.init(
    project="daixi",
    config={
        "epochs": EPOCHS,
        "lr": LR,
        "step": STEP_SIZE,
        "gamma": GAMMA,
    },
)

from torch.optim.lr_scheduler import StepLR

# Multiply the learning rate by GAMMA every STEP_SIZE epochs.
scheduler = StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

# Full-batch training: move the data to the device once, not every epoch.
X_train_dev = X_train.to(device)
X_test_dev = X_test.to(device)
y_train = y_train.to(device)

for epoch in pbar:
    net.train()
    # PyTorch accumulates gradients in .grad; clear the previous epoch's values
    # first, otherwise every step would use the running sum of all past gradients.
    optimizer.zero_grad()
    output = net(X_train_dev)
    loss = criterion(output, y_train)

    loss.backward()
    optimizer.step()

    # Evaluate without building an autograd graph.
    net.eval()
    with torch.no_grad():
        output = net(X_test_dev).cpu().numpy()
    res_val = cls_metric((output > 0), y_test)

    # NOTE(review): the keys say "MAE" but the values are the BCE training loss
    # and the accuracy on X_test; kept unchanged so existing W&B charts line up.
    wandb.log({"Train MAE": loss.item(), "Val MAE": res_val})

    # NOTE(review): the checkpoint is selected on X_test, the same set used for
    # the final report and the XAI experiments -- consider a separate split.
    if res_val > best_val:
        best_val = res_val
        best_model = copy.deepcopy(net.state_dict())

    scheduler.step()  # update learning rate
    pbar.set_description(
        f"Epoch {epoch}/{EPOCHS} - Train Loss {round(loss.item(), 2)} - Val. Perform.: {round(res_val, 2)}"
    )

wandb.finish()

net.load_state_dict(best_model)

In [None]:
# Switch to inference mode, restore the best checkpoint found during training
# and report its test accuracy.
net.eval()
net.load_state_dict(best_model)
print(best_val)

In [None]:
# Create the output directory if needed so torch.save does not fail on a
# fresh checkout.
os.makedirs("./output", exist_ok=True)
torch.save(net.state_dict(), "./output/ssin_cls_tabular.pt")

# XAI

In [None]:
# map_location maps the stored tensors onto this session's device, so a
# checkpoint saved on GPU also loads on a CPU-only machine. The trailing ';'
# suppresses the notebook's display of the return value.
net.load_state_dict(
    torch.load("./output/ssin_cls_tabular.pt", map_location=device, weights_only=True)
);

In [None]:
from sklearn.metrics import classification_report

# Inference only: no_grad avoids building an autograd graph for the test set.
# `pred` (boolean, still on `device`) is reused by the metric loop below.
with torch.no_grad():
    pred = net(X_test.to(device)) > 0

print(classification_report(y_test.cpu().numpy(), pred.cpu().numpy()))

## Metrics

In [None]:
# Small constant used by the metrics below: added to normalisation
# denominators (avoid division by zero) and to probabilities in `kl`
# (avoid log(0)).
epsilon = 1e-5


def _to_probability(info):
    """Min-max rescale `info` and normalise it so its entries sum to about 1.

    Args:
        info: NumPy array (or torch tensor, which is detached and moved to the
            CPU first) of any shape.

    Returns:
        NumPy array with the same shape as `info`: the values rescaled with a
        `MinMaxScaler` fitted on all entries, divided by their sum plus
        `epsilon`.
    """
    if isinstance(info, torch.Tensor):
        info = info.cpu().detach().numpy()
    values = np.copy(info)
    original_shape = values.shape

    # Treat every entry as one sample of a single feature so the scaler uses
    # the global min/max.
    as_column = values.reshape(-1, 1)
    rescaled = MinMaxScaler().fit_transform(as_column).reshape(original_shape)

    return rescaled / (np.sum(rescaled) + epsilon)


def kl(sal_map_gt, sal_map):
    """Kullback-Leibler divergence KL(gt || map) between two saliency maps.

    Both maps are first turned into distributions with `_to_probability`;
    `epsilon` is then added to every entry so the log ratio stays finite where
    a map is zero (the additions create new arrays; inputs are not modified).

    Args:
        sal_map_gt: NumPy array with the ground-truth saliency map.
        sal_map: NumPy array with the saliency map to compare.

    Returns:
        Float with the divergence; 0 when both distributions coincide.
    """
    p_gt = _to_probability(sal_map_gt) + epsilon
    p_map = _to_probability(sal_map) + epsilon
    return np.sum(p_gt * np.log(p_gt / p_map))


def emd(sal_map_gt, sal_map):
    """Earth Mover's (1-D Wasserstein) distance between two saliency maps.

    Each map is converted with `_to_probability` and then divided by its
    maximum (when positive), so every value lies in [0, 1]. The flattened
    values are passed to `scipy.stats.wasserstein_distance`, so the result also
    lies in [0, 1].

    NOTE(review): the map entries are given to `wasserstein_distance` as
    *samples*, so this compares the distributions of attribution values and
    ignores which feature/location each value belongs to. Confirm that is
    intended before using it as a spatial EMD (it is currently disabled in
    `metrics`).

    Args:
        sal_map_gt: NumPy array with the ground-truth saliency map.
        sal_map: NumPy array with the saliency map to compare.

    Returns:
        Float between 0 and 1.
    """
    gt = _to_probability(sal_map_gt)
    other = _to_probability(sal_map)

    gt_peak = gt.max()
    if gt_peak > 0:
        gt /= gt_peak
    other_peak = other.max()
    if other_peak > 0:
        other /= other_peak

    return stats.wasserstein_distance(other.flatten(), gt.flatten())


def _to_zero_one(info):
    """Min-max rescale `info` to [0, 1].

    Torch tensors are detached and moved to the CPU first. A constant input
    has no range to rescale; instead of dividing by zero (which produced an
    all-NaN result) it is mapped to all zeros.
    """
    if isinstance(info, torch.Tensor):
        info = info.detach().cpu().numpy()
    info = np.asarray(info)
    value_range = info.max() - info.min()
    if value_range == 0:
        return np.zeros(info.shape)
    return (info - info.min()) / value_range


def AUC_Borji(sal_map_gt, sal_map, n_rep=100, step_size=0.1, rand_sampler=None):
    """AUC-Borji: how well `sal_map` predicts the salient locations of `sal_map_gt`.

    Both maps are min-max rescaled to [0, 1] and the ground truth is binarised
    at 0.5 to obtain the "fixation" locations. A ROC curve is built by sweeping
    thresholds from 0 in steps of `step_size` up to (excluding) the largest
    sampled saliency value:
        - true-positive rate: fraction of fixation locations whose saliency is
          at or above the threshold;
        - false-positive rate: fraction of randomly chosen locations (as many
          as there are fixations, drawn uniformly from ALL locations, fixated
          ones included) whose saliency is at or above the threshold.
    The AUC is averaged over `n_rep` random draws.

    Parameters
    ----------
    sal_map_gt : array-like
        Ground-truth importance map; entries above 0.5 after rescaling are
        treated as fixations.
    sal_map : array-like
        Saliency map to evaluate; must have as many elements as `sal_map_gt`.
    n_rep : int, optional
        Number of random draws used to estimate the false-positive rate.
    step_size : float, optional
        Threshold increment for the ROC sweep.
    rand_sampler : callable, optional
        ``S_rand = rand_sampler(S, F, n_rep, n_fix)``; must return sampled
        saliency values with shape ``(n_fix, n_rep)``.

    Returns
    -------
    float
        AUC in [0, 1], or NaN when the ground truth has no fixation or the
        rescaled saliency map contains non-finite values.
    """
    sal_map_gt = _to_zero_one(sal_map_gt)
    sal_map = _to_zero_one(sal_map)

    saliency_map = np.asarray(sal_map)
    fixation_map = np.asarray(sal_map_gt) > 0.5
    # If there are no fixations to predict, the score is undefined.
    if not np.any(fixation_map):
        print("no fixation to predict")
        return np.nan

    S = saliency_map.ravel()
    # A NaN maximum makes the threshold range below (np.r_[0:nan:step])
    # raise, so report the score as undefined instead of crashing.
    if not np.all(np.isfinite(S)):
        print("saliency map is not finite")
        return np.nan

    F = fixation_map.ravel()
    S_fix = S[F]  # saliency values at fixation locations
    n_fix = len(S_fix)
    n_pixels = len(S)
    if rand_sampler is None:
        # For each fixation, sample n_rep locations uniformly from the whole
        # map. Fixated locations can be drawn too, which biases the AUC downward.
        r = np_rand.randint(0, n_pixels, [n_fix, n_rep])
        S_rand = S[r]
    else:
        S_rand = rand_sampler(S, F, n_rep, n_fix)

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

    auc = np.full(n_rep, np.nan)
    for rep in range(n_rep):
        thresholds = np.r_[0 : np.max(np.r_[S_fix, S_rand[:, rep]]) : step_size][::-1]
        tp = np.zeros(len(thresholds) + 2)
        fp = np.zeros(len(thresholds) + 2)
        # The curve is anchored at (0, 0) and (1, 1).
        tp[-1] = 1
        fp[-1] = 1
        for k, thresh in enumerate(thresholds):
            tp[k + 1] = np.sum(S_fix >= thresh) / float(n_fix)
            fp[k + 1] = np.sum(S_rand[:, rep] >= thresh) / float(n_fix)
        auc[rep] = trapezoid(tp, fp)
    return np.mean(auc)  # average across random draws


def sim(sal_map_gt, sal_map):
    """Similarity (histogram intersection) between two saliency maps.

    Both maps are converted with `_to_probability` (entries summing to just
    under 1), and the score is the sum of their element-wise minimum: higher
    means more overlap, with a maximum of about 1 for identical maps.

    Args:
        sal_map_gt: NumPy array with the ground-truth saliency map.
        sal_map: NumPy array with the saliency map to compare (same shape).

    Returns:
        Float with the similarity score.
    """
    p_map = _to_probability(sal_map)
    p_gt = _to_probability(sal_map_gt)
    overlap = np.stack([p_map, p_gt]).min(axis=0)
    return np.sum(overlap)


# Metric name -> callable(ground_truth_map, explanation_map) -> float.
# EMD is currently disabled (add emd=emd to re-enable it).
metrics = dict(
    kl=kl,
    auc=AUC_Borji,
    sim=sim,
)

## Get ground-truth importance (occlusion)

In [None]:
# Occlusion ground truth: for every test sample and feature i,
#   importance[:, i] = net(x) - net(x with feature i set to 0).
with torch.no_grad():
    x_test_dev = X_test.to(device)
    org_output = net(x_test_dev).cpu()
    importance = torch.zeros_like(X_test)

    # One occlusion pass per feature column (no longer hard-coded to 3).
    for i in range(X_test.shape[1]):
        x_prime = x_test_dev.clone()
        x_prime[:, i] = 0
        aux_output = net(x_prime).cpu()

        importance[:, i] = (org_output - aux_output)[:, 0]

## XAI Methods

### LIME

In [None]:
def get_lime(explainer, data):
    """Explain one instance with LIME and return one weight per feature.

    Args:
        explainer: LIME tabular explainer. Its `explain_instance` is called
            with the flattened instance and a prediction function that runs
            the global `net` on `device`.
        data: torch.Tensor or np.ndarray holding a single instance (any shape;
            it is flattened before being explained).

    Returns:
        torch.Tensor of per-feature weights, ordered by feature index and
        reshaped to the original shape of `data`.

    Raises:
        TypeError: if `data` is neither a torch.Tensor nor a NumPy array.
    """
    # np.array is a factory function, not a type; the check must use np.ndarray
    # (isinstance(..., np.array) raised TypeError for every NumPy input).
    if isinstance(data, torch.Tensor):
        org_shape = data.shape
        data = data.detach().cpu().numpy()
    elif isinstance(data, np.ndarray):
        org_shape = data.shape
    else:
        raise TypeError(
            f"Input data must be either a Pytorch Tensor or a Numpy array, instead {type(data)}."
        )
    data = data.reshape(-1)

    with torch.no_grad():
        exp = explainer.explain_instance(
            data,
            lambda x: net(torch.Tensor(x).to(device)).cpu().numpy(),
            num_features=len(data),  # request a weight for every feature
            top_labels=1,
        )

        # as_map() yields {label: [(feature_index, weight), ...]}; take the
        # single explained label and order its weights by feature index.
        weights = sorted(next(iter(exp.as_map().values())), key=lambda item: item[0])
        expl = torch.Tensor([float(weight) for _, weight in weights])

        return expl.reshape(org_shape)


# LIME explainer built from the training features. "regression" mode because
# the network has a single real-valued output unit rather than class probabilities.
explainer = LimeTabularExplainer(
    X_train.cpu().numpy(),
    mode="regression",
    discretize_continuous=True,
)

### Grad

In [None]:
def grad_fn(x, xai):
    """Run a Captum attribution method on `x` for output index 0.

    Args:
        x: input tensor; moved to the global `device` before attribution.
        xai: Captum attribution object exposing `attribute(inputs, target=...)`.

    Returns:
        Whatever `xai.attribute` returns for `x`.
    """
    inputs_on_device = x.to(device)
    return xai.attribute(inputs_on_device, target=0)


# Captum gradient-saliency explainer over `net`; invoked through grad_fn.
sal = attr.Saliency(net)

### DeepLift

In [None]:
# Captum DeepLift explainer over `net`; invoked through grad_fn.
deep_lift = attr.DeepLift(net)

### IG

In [None]:
# Captum Integrated Gradients explainer over `net`; invoked through grad_fn.
ig = attr.IntegratedGradients(net)

### LORE

In [None]:
# Wrap the PyTorch model so xailib's LORE explainer can query it as a
# black-box classifier over 3 input features.
# NOTE(review): `net` emits a single logit per sample (no sigmoid/softmax);
# confirm pytorch_classifier_wrapper handles single-output models as intended.
bbox = pytorch_classifier_wrapper(net, device=device, n_features=3)
explainer_lore = LoreTabularExplainer(bbox)

In [None]:
# LORE neighbourhood settings (presumably the genetic neighbourhood generator's
# parameters -- see the xailib docs); the explainer is fitted on the training
# DataFrame, with "target" as the label column.
config = dict(neigh_type="geneticp", size=1000, ocr=0.1, ngen=10)
explainer_lore.fit(X_train_df, "target", config)

In [None]:
def attribute_dt(estimator, instance):
    """Attribute a decision tree's prediction for `instance` to its features.

    Walks the decision path from the root to the leaf reached by `instance`.
    At every split, the split feature is credited with the impurity of the
    node minus the impurity of the child taken. The credits are then divided
    by their total.

    Args:
        estimator: fitted tree model exposing `tree_` with `children_left`,
            `children_right`, `feature`, `threshold` and `impurity` arrays.
        instance: (np.array) 1-D instance to explain.

    Returns:
        np.array (float64) with one attribution per feature of `instance`.
        Features not on the path get 0. If the path carries no impurity
        change at all (e.g. the root is a leaf), all zeros are returned
        instead of the former 0/0 NaNs.
    """
    tree = estimator.tree_
    children_left = tree.children_left
    children_right = tree.children_right
    split_feature = tree.feature
    threshold = tree.threshold
    impurity = tree.impurity

    importance = {}

    node_id = 0
    # A node is a leaf when both of its child pointers are equal.
    while children_left[node_id] != children_right[node_id]:
        feat = split_feature[node_id]

        if instance[feat] <= threshold[node_id]:
            child_id = children_left[node_id]
        else:
            child_id = children_right[node_id]

        importance[feat] = importance.get(feat, 0) + (
            impurity[node_id] - impurity[child_id]
        )
        node_id = child_id

    attribution = np.zeros_like(instance).astype(np.float64)
    for feat, value in importance.items():
        attribution[feat] = value

    total = sum(importance.values())
    # Nothing to distribute: avoid 0/0 -> NaN.
    if total != 0:
        attribution /= total

    return attribution

In [None]:
def get_lore(inst):
    """Explain one instance with LORE and convert its tree into attributions.

    Args:
        inst: torch.Tensor (detached and moved to the CPU) or NumPy array
            with one instance; it is flattened before use.

    Returns:
        np.array from `attribute_dt` applied to the decision tree of the
        LORE explanation (`exp.exp.dt`).
    """
    as_array = inst.cpu().detach().numpy() if isinstance(inst, torch.Tensor) else inst
    flat_instance = as_array.flatten()

    lore_explanation = explainer_lore.explain(flat_instance)
    return attribute_dt(lore_explanation.exp.dt, flat_instance)

### SHAP

In [None]:
# Captum KernelSHAP explainer over `net`; the experiment calls it with
# target=0 and n_samples=200.
kernel_shap = attr.KernelShap(net)

# Experimentation

In [None]:
# Explanation method name -> callable taking one sample (a single-row batch)
# and returning its attribution. Lambdas look their explainer up when called.
methods = dict(
    lime=lambda x: get_lime(explainer, x),
    grad=lambda x: grad_fn(x, sal),
    deep_lift=lambda x: grad_fn(x, deep_lift),
    shap=lambda x: kernel_shap.attribute(x.to(device), target=0, n_samples=200),
    lore=get_lore,
    ig=lambda x: grad_fn(x, ig),
)

In [None]:
# JSON file where the per-method metric lists are written and later reloaded.
RESULTS_PATH = "./results_tabular_data.json"

#### Obtain the explanations for each method

In [None]:
# Compute each method's explanation for every test sample and store them as
# results/tabular/<method>.npy (one row per sample).
explanations_dir = os.path.join("results", "tabular")
# Create the directory up front so the first save does not fail on a fresh checkout.
os.makedirs(explanations_dir, exist_ok=True)

for method_name, method in methods.items():
    raw_results = []
    for x in tqdm(X_test, desc=method_name):
        # Methods expect a batch containing a single sample: (1, n_features).
        explanation = method(x.reshape(1, -1))
        if isinstance(explanation, torch.Tensor):
            explanation = explanation.cpu().detach().numpy()

        raw_results.append(explanation)
    raw_results = np.vstack(raw_results)

    with open(os.path.join(explanations_dir, f"{method_name}.npy"), "wb") as f:
        np.save(f, raw_results)

#### Obtain the metrics for each method

In [None]:
results = dict()

for method_name in methods.keys():
    results_method = {k: [] for k in metrics.keys()}

    raw_results = np.load(os.path.join("results", "tabular", f"{method_name}.npy"))

    # Only correctly classified test samples are scored.
    for explanation, y_pred, y_true, gt_xai in zip(
        tqdm(raw_results, desc=method_name), pred, y_test, importance
    ):
        # .item() turns the single-element tensors into Python scalars
        # (int() on a 1-element ndarray is deprecated since NumPy 1.25).
        if int(y_pred.item()) != int(y_true.item()):
            continue

        for metric_name, metric_fn in metrics.items():
            res = metric_fn(gt_xai.flatten(), explanation.flatten())
            results_method[metric_name].append(float(res))

    results[method_name] = results_method

with open(RESULTS_PATH, "w") as f:
    json.dump(results, f)

In [None]:
with open(RESULTS_PATH) as f:
    results = json.load(f)

# For every method print each metric's mean and standard deviation, ignoring
# NaN scores, followed by a separator line.
for method_name, method_info in results.items():
    print(method_name.upper())
    for metric_name, scores in method_info.items():
        mean_score = np.nanmean(scores)
        std_score = np.nanstd(scores)
        print(f"{metric_name}: {mean_score} - {std_score}")
    print("-" * 25)