In [9]:
import wandb 
import dill
import pandas as pd
import os 

In [10]:
def download_runs(project_name):
    if not os.path.exists(
        f"./results_data/data_{project_name}.pkl"
    ):
        project_details = wandb.Api().runs(f"lucacorbucci/{project_name}")
        project_data = {}
        for run in project_details:
            print("Downloading run ", run.id)
            run_df = pd.DataFrame(
                wandb.Api().run(f"lucacorbucci/{project_name}/{run.id}").scan_history()
            )
            if run.name not in project_data:
                project_data[run.name] = []
            project_data[run.name].append(run_df)
        with open(
            f"./results_data/data_{project_name}.pkl", "wb"
        ) as f:
            dill.dump(project_data, f)
    else:
        with open(
            f"./results_data/data_{project_name}.pkl", "rb"
        ) as f:
            project_data = dill.load(f)
    return project_data

# Explanation Metrics

In [11]:
project_data = download_runs(project_name="tango_explanation_metrics")
project_name = "tango_explanation_metrics"

In [17]:
methods = ["dt", "svm", "logistic"]
datasets = ["dutch", "adult", "letter"]

In [36]:
project_data["logistic_dutch"][0]["faithfulness"]

0    0.04187
Name: faithfulness, dtype: float64

In [73]:
metrics = {}

for dataset in datasets:
    metrics[dataset] = {}   
    for method in methods: 
        metrics[dataset][method] = {}
        results = project_data[f"{method}_{dataset}"][0]
        if "faithfulness" in results.columns:
            faithfulness = round(float(results["faithfulness"]), 2)
            faithfulness_std = round(float(results["faithfulness_std"]), 2)
            metrics[dataset][method]["Faithfulness"] = f"{faithfulness} $\pm$ {faithfulness_std}"
        robustness = round(float(results["robustness"]), 2)
        robustness_std = round(float(results["robustness_std"]), 2)
        stability = round(float(results["stability"]), 2)
        stability_std = round(float(results["stability_std"]), 2)
        metrics[dataset][method]["robustness"] = f"{robustness} $\pm$ {robustness_std}"
        metrics[dataset][method]["stability"] = f"{stability} $\pm$ {stability_std}"

  robustness = round(float(results["robustness"]), 2)
  robustness_std = round(float(results["robustness_std"]), 2)
  stability = round(float(results["stability"]), 2)
  stability_std = round(float(results["stability_std"]), 2)
  faithfulness = round(float(results["faithfulness"]), 2)
  faithfulness_std = round(float(results["faithfulness_std"]), 2)


In [74]:
import pandas as pd

# Initialize an empty list to store the rows
rows = []

# Iterate over the datasets and methods to extract the metrics
for dataset in datasets:
    for method in methods:
        row = {
            'Dataset': dataset,
            'Method': method,
            'Stability': metrics[dataset][method].get('stability', '-'),
            'Robustness': metrics[dataset][method].get('robustness', '-'),
            'Faithfulness': metrics[dataset][method].get('Faithfulness', '-')
        }
        rows.append(row)

# Create a dataframe from the rows
df_metrics = pd.DataFrame(rows)
df_metrics.head(9)

Unnamed: 0,Dataset,Method,Stability,Robustness,Faithfulness
0,dutch,dt,0.97 $\pm$ 0.1,0.42 $\pm$ 0.2,-
1,dutch,svm,0.84 $\pm$ 0.26,0.17 $\pm$ 0.08,0.02 $\pm$ 0.34
2,dutch,logistic,0.91 $\pm$ 0.2,0.17 $\pm$ 0.09,0.04 $\pm$ 0.31
3,adult,dt,0.64 $\pm$ 0.45,0.27 $\pm$ 0.26,-
4,adult,svm,0.85 $\pm$ 0.26,0.26 $\pm$ 0.11,0.01 $\pm$ 0.12
5,adult,logistic,0.41 $\pm$ 0.29,0.11 $\pm$ 0.07,-0.03 $\pm$ 0.21
6,letter,dt,0.93 $\pm$ 0.15,0.39 $\pm$ 0.15,-
7,letter,svm,0.84 $\pm$ 0.31,0.07 $\pm$ 0.04,0.03 $\pm$ 0.22
8,letter,logistic,0.65 $\pm$ 0.32,0.07 $\pm$ 0.04,0.01 $\pm$ 0.19


In [76]:
print(df_metrics.to_latex())

\begin{tabular}{llllll}
\toprule
 & Dataset & Method & Stability & Robustness & Faithfulness \\
\midrule
0 & dutch & dt & 0.97 $\pm$ 0.1 & 0.42 $\pm$ 0.2 & - \\
1 & dutch & svm & 0.84 $\pm$ 0.26 & 0.17 $\pm$ 0.08 & 0.02 $\pm$ 0.34 \\
2 & dutch & logistic & 0.91 $\pm$ 0.2 & 0.17 $\pm$ 0.09 & 0.04 $\pm$ 0.31 \\
3 & adult & dt & 0.64 $\pm$ 0.45 & 0.27 $\pm$ 0.26 & - \\
4 & adult & svm & 0.85 $\pm$ 0.26 & 0.26 $\pm$ 0.11 & 0.01 $\pm$ 0.12 \\
5 & adult & logistic & 0.41 $\pm$ 0.29 & 0.11 $\pm$ 0.07 & -0.03 $\pm$ 0.21 \\
6 & letter & dt & 0.93 $\pm$ 0.15 & 0.39 $\pm$ 0.15 & - \\
7 & letter & svm & 0.84 $\pm$ 0.31 & 0.07 $\pm$ 0.04 & 0.03 $\pm$ 0.22 \\
8 & letter & logistic & 0.65 $\pm$ 0.32 & 0.07 $\pm$ 0.04 & 0.01 $\pm$ 0.19 \\
\bottomrule
\end{tabular}

