In [1]:
import warnings
warnings.filterwarnings("ignore")

from xai_agg import *

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.ensemble import RandomForestClassifier

import pandas as pd
import numpy as np

import dill

2025-01-25 08:52:11.367376: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-25 08:52:11.650755: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
raw_data = pd.read_csv("../data/Student_Depression_Dataset.csv")
display(raw_data)
display(raw_data["Profession"].value_counts())

Unnamed: 0,id,Gender,Age,City,Profession,Academic Pressure,Work Pressure,CGPA,Study Satisfaction,Job Satisfaction,Sleep Duration,Dietary Habits,Degree,Have you ever had suicidal thoughts ?,Work/Study Hours,Financial Stress,Family History of Mental Illness,Depression
0,2,Male,33.0,Visakhapatnam,Student,5.0,0.0,8.97,2.0,0.0,5-6 hours,Healthy,B.Pharm,Yes,3.0,1.0,No,1
1,8,Female,24.0,Bangalore,Student,2.0,0.0,5.90,5.0,0.0,5-6 hours,Moderate,BSc,No,3.0,2.0,Yes,0
2,26,Male,31.0,Srinagar,Student,3.0,0.0,7.03,5.0,0.0,Less than 5 hours,Healthy,BA,No,9.0,1.0,Yes,0
3,30,Female,28.0,Varanasi,Student,3.0,0.0,5.59,2.0,0.0,7-8 hours,Moderate,BCA,Yes,4.0,5.0,Yes,1
4,32,Female,25.0,Jaipur,Student,4.0,0.0,8.13,3.0,0.0,5-6 hours,Moderate,M.Tech,Yes,1.0,1.0,No,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
27896,140685,Female,27.0,Surat,Student,5.0,0.0,5.75,5.0,0.0,5-6 hours,Unhealthy,Class 12,Yes,7.0,1.0,Yes,0
27897,140686,Male,27.0,Ludhiana,Student,2.0,0.0,9.40,3.0,0.0,Less than 5 hours,Healthy,MSc,No,0.0,3.0,Yes,0
27898,140689,Male,31.0,Faridabad,Student,3.0,0.0,6.61,4.0,0.0,5-6 hours,Unhealthy,MD,No,12.0,2.0,No,0
27899,140690,Female,18.0,Ludhiana,Student,5.0,0.0,6.88,2.0,0.0,Less than 5 hours,Healthy,Class 12,Yes,10.0,5.0,No,1


Profession
Student                   27870
Architect                     8
Teacher                       6
Digital Marketer              3
Content Writer                2
Chef                          2
Doctor                        2
Pharmacist                    2
Civil Engineer                1
UX/UI Designer                1
Educational Consultant        1
Manager                       1
Lawyer                        1
Entrepreneur                  1
Name: count, dtype: int64

In [3]:
preprocessed_data = raw_data.drop(columns=["id", "City", "Profession", "Degree"])

# Treat column names
preprocessed_data.rename(columns={
    'Have you ever had suicidal thoughts ?': 'Suicidal',
    "Family History of Mental Illness": "Family_History",
    "Gender": "GenderMale"
}, inplace=True)
preprocessed_data.columns = preprocessed_data.columns.str.replace(' ', '_')
preprocessed_data.columns = preprocessed_data.columns.str.replace('/', '_')

# Encoding categorical variables
categorical_features = []
preprocessed_data["GenderMale"] = preprocessed_data["GenderMale"].map({"Male": 1, "Female": 0})
preprocessed_data["Dietary_Habits"] = preprocessed_data["Dietary_Habits"].map({"Others": 0, "Unhealthy": 1, "Moderate": 2, "Healthy": 3})
preprocessed_data["Suicidal"] = preprocessed_data["Suicidal"].map({"Yes": 1, "No": 0})
preprocessed_data["Family_History"] = preprocessed_data["Family_History"].map({"Yes": 1, "No": 0})
preprocessed_data["Sleep_Duration"] = preprocessed_data["Sleep_Duration"].map({"Others": 0, "Less than 5 hours": 1, "5-6 hours": 2, "7-8 hours": 3, "More than 8 hours": 4})

preprocessed_data, _ = train_test_split(preprocessed_data, test_size=0.875, stratify=preprocessed_data['Depression'], random_state=42)
preprocessed_data.dropna(inplace=True)

display(preprocessed_data)

Unnamed: 0,GenderMale,Age,Academic_Pressure,Work_Pressure,CGPA,Study_Satisfaction,Job_Satisfaction,Sleep_Duration,Dietary_Habits,Suicidal,Work_Study_Hours,Financial_Stress,Family_History,Depression
22551,0,33.0,3.0,0.0,5.86,5.0,0.0,2,1,1,8.0,1.0,1,0
15512,1,32.0,3.0,0.0,6.16,4.0,0.0,4,2,0,0.0,3.0,0,0
21944,1,19.0,3.0,0.0,7.80,4.0,0.0,4,2,0,2.0,1.0,0,0
15462,1,32.0,4.0,0.0,6.89,3.0,0.0,2,2,1,5.0,1.0,0,1
21760,0,33.0,2.0,0.0,8.17,2.0,0.0,3,3,0,6.0,4.0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10927,1,22.0,5.0,0.0,5.27,1.0,0.0,4,1,0,9.0,5.0,0,1
12611,0,23.0,5.0,0.0,6.27,1.0,0.0,1,1,0,1.0,4.0,1,1
14848,1,21.0,1.0,0.0,8.59,1.0,0.0,3,2,0,3.0,2.0,1,0
13476,0,29.0,1.0,0.0,9.71,3.0,0.0,4,2,0,6.0,4.0,0,0


In [4]:
y = preprocessed_data["Depression"]
X = preprocessed_data.drop(columns=["Depression"])

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [5]:
clf = RandomForestClassifier(random_state=42)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)

print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
print(f"ROC AUC: {roc_auc_score(y_test, y_pred)}")

Accuracy: 0.8151862464183381
ROC AUC: 0.8017150225660865


# Experiments

In [6]:
results, metadata = evaluate_aggregate_explainer(
    clf, X_train, X_test, categorical_features,
    aggregation_algs=["wsum", "w_bordafuse", "w_condorcet"],
    n_instances=5
)

metadata["description"] = "compares wsum, w_bordafuse, w_condorcet aggregation algorithms"
metadata["dataset"] = "student_depression"

with open('pickles/student_depression/COMPARE_wsum-w_bordafuse-w_condorcet-allrank.pkl', 'wb') as f:
    dill.dump(ExperimentRun(metadata, results), f)

Selected indexes: [ 8273 27289 11440 22644 13791]
Epoch 1/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 1.1068 - val_loss: 1.0731
Epoch 2/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 581us/step - loss: 1.0482 - val_loss: 1.0184
Epoch 3/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 622us/step - loss: 0.9943 - val_loss: 0.9642
Epoch 4/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 619us/step - loss: 0.9457 - val_loss: 0.9123
Epoch 5/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 701us/step - loss: 0.8904 - val_loss: 0.8666
Epoch 6/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 655us/step - loss: 0.8523 - val_loss: 0.8298
Epoch 7/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 0.8136 - val_loss: 0.8017
Epoch 8/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 567us/step - loss: 0.7940 - val

Could not find an anchor satisfying the 0.95 precision constraint. Now returning the best non-eligible result. The desired precision threshold might not be achieved due to the quantile-based discretisation of the numerical features. The resolution of the bins may be too large to find an anchor of required precision. Consider increasing the number of bins in `disc_perc`, but note that for some numerical distribution (e.g. skewed distribution) it may not help.
Could not find an anchor satisfying the 0.95 precision constraint. Now returning the best non-eligible result. The desired precision threshold might not be achieved due to the quantile-based discretisation of the numerical features. The resolution of the bins may be too large to find an anchor of required precision. Consider increasing the number of bins in `disc_perc`, but note that for some numerical distribution (e.g. skewed distribution) it may not help.


	 Running instance 13791
Running evaluation for settings 2/3
Explainer components: [<class 'xai_agg.explainers.LimeWrapper'>, <class 'xai_agg.explainers.ShapTabularTreeWrapper'>, <class 'xai_agg.explainers.AnchorWrapper'>], Metrics: ['nrc', 'sensitivity_spearman', 'rb_faithfulness_corr'], MCDM algorithm: <pymcdm.methods.topsis.TOPSIS object at 0x70fbef3cc1f0>, Aggregation algorithm: w_bordafuse
	 Running instance 8273
	 Running instance 27289
	 Running instance 11440
	 Running instance 22644


Could not find an anchor satisfying the 0.95 precision constraint. Now returning the best non-eligible result. The desired precision threshold might not be achieved due to the quantile-based discretisation of the numerical features. The resolution of the bins may be too large to find an anchor of required precision. Consider increasing the number of bins in `disc_perc`, but note that for some numerical distribution (e.g. skewed distribution) it may not help.


	 Running instance 13791
Running evaluation for settings 3/3
Explainer components: [<class 'xai_agg.explainers.LimeWrapper'>, <class 'xai_agg.explainers.ShapTabularTreeWrapper'>, <class 'xai_agg.explainers.AnchorWrapper'>], Metrics: ['nrc', 'sensitivity_spearman', 'rb_faithfulness_corr'], MCDM algorithm: <pymcdm.methods.topsis.TOPSIS object at 0x70fbef3cc1f0>, Aggregation algorithm: w_condorcet
	 Running instance 8273
	 Running instance 27289
	 Running instance 11440
	 Running instance 22644


Could not find an anchor satisfying the 0.95 precision constraint. Now returning the best non-eligible result. The desired precision threshold might not be achieved due to the quantile-based discretisation of the numerical features. The resolution of the bins may be too large to find an anchor of required precision. Consider increasing the number of bins in `disc_perc`, but note that for some numerical distribution (e.g. skewed distribution) it may not help.
Could not find an anchor satisfying the 0.95 precision constraint. Now returning the best non-eligible result. The desired precision threshold might not be achieved due to the quantile-based discretisation of the numerical features. The resolution of the bins may be too large to find an anchor of required precision. Consider increasing the number of bins in `disc_perc`, but note that for some numerical distribution (e.g. skewed distribution) it may not help.
Could not find an anchor satisfying the 0.95 precision constraint. Now ret

	 Running instance 13791


In [9]:
with open('pickles/student_depression/COMPARE_wsum-w_bordafuse-w_condorcet-allrank.pkl', 'rb') as f:
    exp = dill.load(f)

In [12]:
methods = ["wsum", "w_bordafuse", "w_condorcet"]
present_experiment_run(exp, labels=methods)

wsum:



[                              nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             21.582348              0.978022              0.271817
 ShapTabularTreeWrapper  21.582348              1.000000              0.062507
 AnchorWrapper           16.750557              0.539413              0.386522
 AggregateExplainer      17.631316              0.947802              0.257086,
                               nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             21.582348              0.971978              0.936640
 ShapTabularTreeWrapper  21.474080              0.968595              0.514708
 AnchorWrapper           18.161386              0.664333              0.032407
 AggregateExplainer      18.838809              0.953846              0.836047,
                               nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             21.074611              0.976923              0.714512
 ShapTabularTreeWrapper  20.061509              0.

Worst case avoidances:
	- for all metrics: 4
	- for 2/3 metrics: 4
AVG:


Unnamed: 0,nrc,sensitivity_spearman,rb_faithfulness_corr
AggregateExplainer,18.127712,0.949451,0.541257
AnchorWrapper,17.420596,0.765694,0.401127
LimeWrapper,20.975416,0.971758,0.545214
ShapTabularTreeWrapper,21.019081,0.988981,0.405099




Avg rank:


Unnamed: 0,nrc,sensitivity_spearman,rb_faithfulness_corr
AggregateExplainer,1.8,3.2,2.6
AnchorWrapper,1.2,3.4,2.4
LimeWrapper,3.5,2.0,2.2
ShapTabularTreeWrapper,3.5,1.4,2.8


w_bordafuse:



[                              nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             21.713983              0.962088              0.099972
 ShapTabularTreeWrapper  21.582348              0.996694              0.206998
 AnchorWrapper           16.750557              0.747684              0.552542
 AggregateExplainer      24.093589              0.921978              0.185343,
                               nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             22.416107              0.973626              0.767939
 ShapTabularTreeWrapper  21.474080              0.964187              0.424649
 AnchorWrapper           16.960760              0.640947              0.009282
 AggregateExplainer      21.950258              0.941758              0.610274,
                               nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             20.008172              0.962637              0.711779
 ShapTabularTreeWrapper  20.061509              0.

Worst case avoidances:
	- for all metrics: 1
	- for 2/3 metrics: 3
AVG:


Unnamed: 0,nrc,sensitivity_spearman,rb_faithfulness_corr
AggregateExplainer,23.129082,0.94044,0.3408
AnchorWrapper,17.494821,0.783573,0.415339
LimeWrapper,20.99906,0.967473,0.473984
ShapTabularTreeWrapper,21.019081,0.988099,0.37128




Avg rank:


Unnamed: 0,nrc,sensitivity_spearman,rb_faithfulness_corr
AggregateExplainer,3.8,3.0,3.2
AnchorWrapper,1.0,4.0,2.0
LimeWrapper,2.6,1.8,2.6
ShapTabularTreeWrapper,2.6,1.2,2.2


w_condorcet:



[                              nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             20.487456              0.970330              0.013180
 ShapTabularTreeWrapper  21.582348              0.995041              0.056476
 AnchorWrapper           16.750557              0.891042              0.588887
 AggregateExplainer      24.093589              0.983516              0.354785,
                               nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             21.038750              0.964835              0.712946
 ShapTabularTreeWrapper  21.474080              0.971901              0.328682
 AnchorWrapper           16.960760              0.780590              0.202681
 AggregateExplainer      24.093589              0.925824              0.650700,
                               nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             20.629599              0.974725              0.710572
 ShapTabularTreeWrapper  20.061509              0.

Worst case avoidances:
	- for all metrics: 0
	- for 2/3 metrics: 4
AVG:


Unnamed: 0,nrc,sensitivity_spearman,rb_faithfulness_corr
AggregateExplainer,24.093589,0.868791,0.395318
AnchorWrapper,17.375055,0.868906,0.427223
LimeWrapper,20.82249,0.971429,0.433554
ShapTabularTreeWrapper,21.019081,0.989421,0.350756




Avg rank:


Unnamed: 0,nrc,sensitivity_spearman,rb_faithfulness_corr
AggregateExplainer,4.0,3.0,2.4
AnchorWrapper,1.0,3.8,2.2
LimeWrapper,2.4,2.2,2.6
ShapTabularTreeWrapper,2.6,1.0,2.8


# Evaluating Rank aggregation algorithms
### Execution

# RAE-T vs. RAE-E | 10 samples
### Execution

In [6]:
results, metadata = evaluate_aggregate_explainer(
    clf, X_train, X_test, categorical_features,
    metrics_sets=[['nrc', 'sensitivity_spearman', 'rb_faithfulness_corr']],
    mcdm_algs=[pymcdm.methods.TOPSIS(), pymcdm.methods.EDAS()],
    n_instances=10
)

metadata["description"] = "RAE-T vs RAE-S, 10 samples"

with open('pickles/student_depression/RAE-T_vs_RAE-S_10-allrank.pkl', 'wb') as f:
    dill.dump(ExperimentRun(metadata, results), f)

Selected indexes: [ 2374  5094 14399   354  3879 14730  9360 17960  7977 19307]
Epoch 1/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 1.1097 - val_loss: 1.0697
Epoch 2/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 591us/step - loss: 1.0333 - val_loss: 1.0140
Epoch 3/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 667us/step - loss: 0.9945 - val_loss: 0.9659
Epoch 4/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 681us/step - loss: 0.9447 - val_loss: 0.9231
Epoch 5/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - loss: 0.9064 - val_loss: 0.8854
Epoch 6/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 815us/step - loss: 0.8609 - val_loss: 0.8522
Epoch 7/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 884us/step - loss: 0.8304 - val_loss: 0.8239
Epoch 8/500
[1m70/70[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7

### Analysis

In [7]:
with open('pickles/student_depression/RAE-T_vs_RAE-S_10-allrank.pkl', 'rb') as f:
    exp = dill.load(f)

In [8]:
methods = ["RAE-T", "RAE-E"]
present_experiment_run(exp, labels=methods)

RAE-T:



[                              nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             21.211467              0.983516              0.880819
 ShapTabularTreeWrapper  20.741480              0.994490              0.920241
 AnchorWrapper           17.559590              1.000000              0.747207
 AggregateExplainer      17.886925              0.976374              0.743415,
                               nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             20.629599              0.956044              0.220449
 ShapTabularTreeWrapper  20.192111              0.965840              0.232545
 AnchorWrapper           16.960760              0.835722              0.231393
 AggregateExplainer      18.905932              0.826374              0.290000,
                               nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             21.688702              0.974725              0.403972
 ShapTabularTreeWrapper  21.454425              0.

Worst case avoidances:
	- for all metrics: 5
	- for 2/3 metrics: 9
AVG:


Unnamed: 0,nrc,sensitivity_spearman,rb_faithfulness_corr
AggregateExplainer,18.214384,0.946429,0.559827
AnchorWrapper,17.197114,0.870992,0.567401
LimeWrapper,20.443514,0.960989,0.441468
ShapTabularTreeWrapper,20.97076,0.990193,0.55668




Avg rank:


Unnamed: 0,nrc,sensitivity_spearman,rb_faithfulness_corr
AggregateExplainer,2.0,3.0,2.6
AnchorWrapper,1.0,3.3,2.0
LimeWrapper,3.5,2.6,3.1
ShapTabularTreeWrapper,3.5,1.1,2.3


RAE-E:



[                              nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             19.249676              0.939560              0.879739
 ShapTabularTreeWrapper  20.741480              0.998347              0.887028
 AnchorWrapper           17.559590              0.970131              0.691835
 AggregateExplainer      17.762886              0.935165              0.741337,
                               nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             21.211467              0.973626              0.178879
 ShapTabularTreeWrapper  20.192111              0.972452              0.302388
 AnchorWrapper           18.161386              0.843967              0.013994
 AggregateExplainer      18.745550              0.861538              0.276032,
                               nrc  sensitivity_spearman  rb_faithfulness_corr
 LimeWrapper             22.938190              0.966484              0.208088
 ShapTabularTreeWrapper  21.454425              0.

Worst case avoidances:
	- for all metrics: 7
	- for 2/3 metrics: 9
AVG:


Unnamed: 0,nrc,sensitivity_spearman,rb_faithfulness_corr
AggregateExplainer,18.000209,0.921154,0.575966
AnchorWrapper,18.385744,0.845828,0.477188
LimeWrapper,20.802944,0.97022,0.419113
ShapTabularTreeWrapper,20.97076,0.989917,0.565473




Avg rank:


Unnamed: 0,nrc,sensitivity_spearman,rb_faithfulness_corr
AggregateExplainer,1.8,3.3,2.1
AnchorWrapper,1.4,3.5,2.6
LimeWrapper,3.3,2.1,3.1
ShapTabularTreeWrapper,3.5,1.1,2.2
