In [None]:
import torch
import pandas as pd
import numpy as np

# --- Configuration ------------------------------------------------------------
# Setups, explainers and metrics whose saved tensors are loaded below.
# Alternative selections kept for reference:
#   setups = [0, 1, 2]
#   explainers = ["eg", "ig", "ixg", "sal", "tsgh", "tsgl"]
setups = [0]
explainers = ["sg/highest", "sg/lowest"]
metrics = ["comp", "inf", "loc"]

# --- Load every metric tensor into one long-format DataFrame -------------------
# One frame is built per (setup, explainer) pair. The frames are collected in a
# list and concatenated once after the loops, instead of re-copying a growing
# DataFrame on every iteration.
frames = []

for setup in setups:
    for explainer in explainers:
        metric_df = pd.DataFrame()

        for metric in metrics:
            path = f"../output/model/setup{setup}/metrics/{explainer}/{metric}.pt"

            # map_location="cpu" lets a tensor that was saved from a GPU load on
            # a CPU-only machine and keeps the .numpy() call below valid.
            data = torch.load(path, map_location="cpu")

            # Replace NaN entries with the mean of the non-NaN entries
            # (in place, on the freshly loaded tensor).
            nan_mask = torch.isnan(data)
            if nan_mask.any():
                data[nan_mask] = data[~nan_mask].mean()

            # One column per metric.
            metric_df[metric] = data.numpy()

        # Tag every row with the pair it was loaded for.
        metric_df["setup"] = setup
        metric_df["explainer"] = explainer
        frames.append(metric_df)

# pd.concat raises on an empty list, so fall back to an empty frame when
# nothing was selected for loading.
results_df = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

results_df.head()





Unnamed: 0,comp,inf,loc,lle,setup,explainer
0,5.945839,0.682777,1.791071,0.048755,0,eg
1,5.49688,0.122622,5.909897,0.037516,0,eg
2,5.663014,0.681049,5.070124,0.418933,0,eg
3,5.726206,0.288789,2.256105,0.057358,0,eg
4,5.898883,0.20186,3.321312,0.249442,0,eg


In [None]:
# For each setup, draw one boxplot per metric comparing the different explainers.

import seaborn as sns
import matplotlib.pyplot as plt

# Grid of axes: one row per setup, one column per metric.
# squeeze=False makes plt.subplots always return a 2-D array, so flatten()
# also works when the grid is 1x1 (a bare Axes object has no flatten()).
fig, axs = plt.subplots(len(setups), len(metrics), figsize=(15, 10), squeeze=False)
axs = axs.flatten()

n_metrics = len(metrics)

for i, setup in enumerate(setups):
    # Rows of this setup; the selection does not depend on the metric,
    # so it is done once per setup.
    metric_df = results_df[results_df["setup"] == setup]

    for j, metric in enumerate(metrics):
        # Row-major position of cell (i, j) in the flattened grid. Indexing
        # with i + j would make different (setup, metric) pairs share an axes
        # as soon as there is more than one setup.
        ax = axs[i * n_metrics + j]

        sns.boxplot(x="explainer", y=metric, data=metric_df, ax=ax, palette="Set3")
        ax.set_title(f"Setup {setup}, Metric {metric}")

        # Pin the y-axis to [0, 1] for the inf and lle metrics.
        if metric in ("inf", "lle"):
            ax.set_ylim(0, 1)


In [None]:
# Summary table: mean of every metric per (setup, explainer) pair,
# rounded to 2 decimals and ordered by setup, then explainer.
latex_table = (
    results_df
    .groupby(by=["setup", "explainer"])
    .mean()
    .reset_index()
    .round(2)
    .sort_values(by=["setup", "explainer"])
)

# Export the table as LaTeX (no index column, floats printed with 2 decimals).
latex_table.to_latex("latex_table.tex", index=False, float_format="%.2f")

# Show up to the first 20 rows.
latex_table.head(n=20)


Unnamed: 0,setup,explainer,comp,inf,loc
0,0,eg,5.91,0.31,2.76
1,0,ig,5.04,0.31,3.41
2,0,ixg,5.53,0.33,2.95
3,0,sal,5.77,0.34,2.79
4,0,tsgh,4.93,0.31,4.08
5,0,tsgl,4.9,0.31,4.08
6,1,eg,5.93,0.29,2.8
7,1,ig,5.14,0.29,3.36
8,1,ixg,5.81,0.35,2.44
9,1,sal,6.03,0.36,2.28
