In [1]:
from synth_xai.utils import (
    prepare_dutch,
)
from pathlib import Path
from synth_xai.explanations.explanation_utils import (
    load_bb,
)
import copy
import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from scipy.io import arff
from sklearn.preprocessing import (
    MinMaxScaler,
)
from synth_xai.explanations.explanation_utils import (
    evaluate_bb,
    find_top_closest_rows,
    get_test_data,
    is_explainer_supported,
    label_synthetic_data,
    load_bb,
    load_synthetic_data,
    make_predictions,
    prepare_neighbours,
    setup_wandb,
    transform_input_data,
)
from synth_xai.explanations.explainer_model import ExplainerModel


  from .autonotebook import tqdm as notebook_tqdm


# Useful Functions

In [2]:
def prepare_dutch(
    sweep: bool,
    seed: int,
    current_path: Path,
    validation_seed: int | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, pd.DataFrame, pd.DataFrame]:
    file_path = current_path
    data = arff.loadarff(file_path)
    dutch_df = pd.DataFrame(data[0]).astype("int32")

    # check the columns with missign values:
    missing_values_columns = dutch_df.columns[dutch_df.isna().any()].tolist()
    for column in missing_values_columns:
        dutch_df[column] = dutch_df[column].fillna(dutch_df[column].mode()[0])

    if len(dutch_df.columns[dutch_df.isna().any()].tolist()) != 0:
        error_message = "There are still missing values in the dataset"
        raise ValueError(error_message)

    dutch_df["sex_binary"] = np.where(dutch_df["sex"] == 1, 1, 0)
    dutch_df["occupation_binary"] = np.where(dutch_df["occupation"] >= 300, 1, 0)

    del dutch_df["sex"]
    del dutch_df["occupation"]

    y = dutch_df["occupation_binary"].astype(int).values
    del dutch_df["occupation_binary"]

    dutch_df = pd.get_dummies(dutch_df, columns=None, drop_first=False)

    x_train, x_test, y_train, y_test = train_test_split(dutch_df, y, test_size=0.2, random_state=seed, stratify=y)

    train_df = copy.copy(x_train)
    train_df["occupation_binary"] = y_train

    test_df = copy.copy(x_test)
    test_df["occupation_binary"] = y_test

    scaler = MinMaxScaler()
    x_train = scaler.fit_transform(x_train)
    x_test = scaler.transform(x_test)

    if sweep:
        x_train, x_val, y_train, y_val = train_test_split(
            x_train,
            y_train,
            test_size=0.2,
            random_state=validation_seed,
            stratify=y_train,
        )
    else:
        x_val = None
        y_val = None

    return x_train, x_val, x_test, y_train, y_val, y_test, train_df, test_df


# Load BB Model and Dataset

In [3]:
bb_path = "./example_bb/dutch_BB.pth"
bb = load_bb(bb_path)

bb.to("cuda" if torch.cuda.is_available() else "cpu")

SimpleModel(
  (layer1): Linear(in_features=11, out_features=32, bias=True)
  (layer2): Linear(in_features=32, out_features=2, bias=True)
)

In [4]:
class Args:
    def __init__(self):
        self.dataset_name = "dutch"
        self.seed = 42
        self.synthetic_dataset_path = "./example_synthetic_data/synthetic_data.csv"
args = Args()

In [5]:
train_data, test_data, outcome_variable = get_test_data(args)
x, y, scaler = transform_input_data(train_data=train_data, test_data=test_data, outcome_variable=outcome_variable)

In [6]:
predictions = evaluate_bb(x, y, bb)

[32m2025-03-13 15:58:01.250[0m | [1mINFO    [0m | [36msynth_xai.explanations.explanation_utils[0m:[36mevaluate_bb[0m:[36m217[0m - [1mAccuracy: 0.832092022509103 - F1: 0.8321849875595729[0m


# Load the Synthetic Dataset

In [7]:
synthetic_data_path = Path(args.synthetic_dataset_path)

synthetic_data = load_synthetic_data(synthetic_data_path)

synthetic_data_labels = label_synthetic_data(
    synthetic_data=synthetic_data, outcome_variable=outcome_variable, bb=bb, scaler=scaler
)
synthetic_data[outcome_variable] = synthetic_data_labels

# Select the sample we want to explain

In [106]:
index = 5
sample = test_data.iloc[[index]]
print(sample)
x_sample = torch.tensor([x[index]])
y_sample = torch.tensor([y[index]])

sample_pred_bb = make_predictions(x_sample, y_sample, bb)

       age  household_position  household_size  prev_residence_place  \
25279    5                1110             112                     1   

       citizenship  country_birth  edu_level  economic_status  \
25279            1              1          2              111   

       cur_eco_activity  Marital_status  sex_binary  occupation_binary  
25279               131               1           1                  1  


# Create the Neighbourhood

In [107]:
top_k = 1000
top_k_samples = find_top_closest_rows(
    synthetic_data=synthetic_data,
    sample=sample,
    k=top_k,
    y_name=outcome_variable,
)

X, Y, old_x = prepare_neighbours(top_k_samples=top_k_samples, y_name=outcome_variable)
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

# Explanation with Decision Tree

In [108]:
explanation_type = "dt"
validation_seed = 1

In [109]:
explainer_model = ExplainerModel(explainer_type=explanation_type)

In [110]:
explainer_model.grid_search(x_train=x_train, y_train=y_train, seed=validation_seed)

sample_pred, explanation, threshold, feature = explainer_model.extract_explanation(
    clf=explainer_model.best_model, y_name=outcome_variable, sample=sample
)

y_pred = explainer_model.predict(x_test)
accuracy_val = accuracy_score(y_test, y_pred)
f1_val = f1_score(y_test, y_pred, average="weighted")

In [111]:
print("Fidelity: ", 1 if (y_pred[0] == sample_pred_bb[0]) else 0)

Fidelity:  1


In [112]:
print(f"Accuracy: {accuracy_val}")
print(f"F1: {f1_val}")
print(f"Explanation: {explanation}")

Accuracy: 1.0
F1: 1.0
Explanation: ['(household_position = 1110) <= 1115.5', '(prev_residence_place = 1) <= 1.5', 'Leaf node 2 reached, prediction: 1']


# Explanation with Logistic Regression

In [113]:
explanation_type = "logistic"
validation_seed = 1

In [114]:
explainer_model = ExplainerModel(explainer_type=explanation_type)

In [115]:
explainer_model.grid_search(x_train=x_train, y_train=y_train, seed=validation_seed)

sample_pred, explanation, threshold, feature = explainer_model.extract_explanation(
    clf=explainer_model.best_model, y_name=outcome_variable, sample=sample
)


y_pred = explainer_model.predict(x_test)
accuracy_val = accuracy_score(y_test, y_pred)
f1_val = f1_score(y_test, y_pred, average="weighted")

In [116]:
print("Fidelity: ", 1 if (y_pred[0] == sample_pred_bb[0]) else 0)
print(f"Accuracy: {accuracy_val}")
print(f"F1: {f1_val}")
print(f"Explanation: {explanation}")

Fidelity:  1
Accuracy: 0.995
F1: 0.9950129672006102
Explanation: ['Marital_status: coefficient=-7.647274793987985, value=1', 'age: coefficient=-5.314648352877113, value=5', 'prev_residence_place: coefficient=-2.4425235761689956, value=1', 'edu_level: coefficient=-1.6403407080407693, value=2', 'household_size: coefficient=0.12280868684771122, value=112', 'economic_status: coefficient=0.12168432610281686, value=111', 'cur_eco_activity: coefficient=0.1039957805705147, value=131', 'household_position: coefficient=0.007978501769350919, value=1110', 'citizenship: coefficient=0.0, value=1', 'country_birth: coefficient=0.0, value=1', 'sex_binary: coefficient=0.0, value=1']


In [117]:
test_data.columns

Index(['age', 'household_position', 'household_size', 'prev_residence_place',
       'citizenship', 'country_birth', 'edu_level', 'economic_status',
       'cur_eco_activity', 'Marital_status', 'sex_binary',
       'occupation_binary'],
      dtype='object')

# Explanation with SVM

In [118]:
explanation_type = "svm"
validation_seed = 1

In [119]:
explainer_model = ExplainerModel(explainer_type=explanation_type)

In [120]:
explainer_model.grid_search(x_train=x_train, y_train=y_train, seed=validation_seed)

sample_pred, explanation, threshold, feature = explainer_model.extract_explanation(
    clf=explainer_model.best_model, y_name=outcome_variable, sample=sample
)


y_pred = explainer_model.predict(x_test)
accuracy_val = accuracy_score(y_test, y_pred)
f1_val = f1_score(y_test, y_pred, average="weighted")

In [121]:
print("Fidelity: ", 1 if (y_pred[0] == sample_pred_bb[0]) else 0)
print(f"Accuracy: {accuracy_val}")
print(f"F1: {f1_val}")
print(f"Explanation: {explanation}")

Fidelity:  1
Accuracy: 1.0
F1: 1.0
Explanation: ['prev_residence_place: coefficient=-1.3115770976584527, value=1', 'age: coefficient=-1.3115600558866127, value=5', 'edu_level: coefficient=-1.311551535022013, value=2', 'Marital_status: coefficient=-0.5302355040847307, value=1', 'household_size: coefficient=-0.15822148192455643, value=112', 'cur_eco_activity: coefficient=-0.15820444018793867, value=131', 'household_position: coefficient=-0.14808844333992965, value=1110', 'sex_binary: coefficient=6.38378239159465e-16, value=1', 'citizenship: coefficient=0.0, value=1', 'country_birth: coefficient=0.0, value=1', 'economic_status: coefficient=0.0, value=111']


# Explanation with KNN

In [122]:
explanation_type = "knn"
validation_seed = 1

In [123]:
explainer_model = ExplainerModel(explainer_type=explanation_type)

In [124]:
explainer_model.grid_search(x_train=x_train, y_train=y_train, seed=validation_seed)

sample_pred, explanation, threshold, feature = explainer_model.extract_explanation(
    clf=explainer_model.best_model, y_name=outcome_variable, sample=sample
)


y_pred = explainer_model.predict(x_test)
accuracy_val = accuracy_score(y_test, y_pred)
f1_val = f1_score(y_test, y_pred, average="weighted")

In [125]:
print("Fidelity: ", 1 if (y_pred[0] == sample_pred_bb[0]) else 0)
print(f"Accuracy: {accuracy_val}")
print(f"F1: {f1_val}")
print(f"Explanation: {explanation}")

Fidelity:  1
Accuracy: 1.0
F1: 1.0
Explanation: ['KNN prediction: 1', 'Nearest neighbors (index, distance, label):', 'Index: 25, distance: 0.0000, label: 1, sample: [   5 1110  112    1    1    1    2  111  131    1    1]', 'Index: 134, distance: 0.0000, label: 1, sample: [   5 1110  112    1    1    1    2  111  131    1    1]', 'Index: 59, distance: 0.0000, label: 1, sample: [   5 1110  112    1    1    1    2  111  131    1    1]']


In [126]:
similar_1 = test_data.iloc[[25]]
print(similar_1)

      age  household_position  household_size  prev_residence_place  \
8720    8                1122             114                     1   

      citizenship  country_birth  edu_level  economic_status  \
8720            1              1          3              111   

      cur_eco_activity  Marital_status  sex_binary  occupation_binary  
8720               132               2           0                  1  


In [127]:
similar_2 = test_data.iloc[[134]]
print(similar_2)

     age  household_position  household_size  prev_residence_place  \
781   11                1121             112                     1   

     citizenship  country_birth  edu_level  economic_status  cur_eco_activity  \
781            1              1          3              111               135   

     Marital_status  sex_binary  occupation_binary  
781               2           0                  1  


In [128]:
similar_3 = test_data.iloc[[59]]
print(similar_3)

       age  household_position  household_size  prev_residence_place  \
13364   12                1121             112                     1   

       citizenship  country_birth  edu_level  economic_status  \
13364            1              1          2              111   

       cur_eco_activity  Marital_status  sex_binary  occupation_binary  
13364               131               2           0                  1  
