In [1]:
import pandas as pd
import numpy as np
import numpy.typing as npt
import seaborn as sns
import matplotlib.pyplot as plt
import os
import json
from experiment_utils import load_data, get_closest_to_optimal_point, get_pareto_optimal_mask, get_ideal_point

dataset = 'german'
dates = ['2023-03-12', '2023-03-14', '2023-03-15']

only_valid = True

#path = os.path.join(os.getcwd(), 'experiments', dates, 'scores')

In [2]:
# Load scores
list_of_scores_df , scores_df_all, scores_test_set_indices = load_data('scores', dates, dataset)
print(f'Gathered scores for {len(list_of_scores_df)} instances')

Gathered scores for 100 instances


In [3]:
# Load valid_scores
list_of_valid_scores_df , valid_scores_df_all, valid_scores_test_set_indices = load_data('valid_scores', dates, dataset)
print(f'Gathered valid_scores for {len(list_of_valid_scores_df)} instances')

Gathered valid_scores for 100 instances


In [4]:
# Load counterfactuals
list_of_counterfactuals_df , counterfactuals_df_all, cf_test_set_indices = load_data('counterfactuals', dates, dataset)
print(f'Gathered counterfactuals for {len(list_of_counterfactuals_df)} instances')

Gathered counterfactuals for 100 instances


In [5]:
# Load valid_counterfactuals
list_of_valid_counterfactuals_df , valid_counterfactuals_df_all, valid_cf_test_set_indices = load_data('valid_counterfactuals', dates, dataset)
print(f'Gathered valid_counterfactuals for {len(list_of_valid_counterfactuals_df)} instances')

Gathered valid_counterfactuals for 100 instances


In [6]:
# Load test data - original x instances
if 'experiments' in os.getcwd():
    test_data_path = os.path.join(os.path.pardir, 'data', f'{dataset}_test.csv')
else:
    test_data_path = os.path.join(os.getcwd(), 'data', f'{dataset}_test.csv')


test_dataset = pd.read_csv(test_data_path).iloc[scores_test_set_indices]
print(f'Loaded test data for {len(test_dataset)} instances')

Loaded test data for 100 instances


In [7]:
# Load constraints for the dataset
with open(os.path.join(os.path.pardir, 'data', f'{dataset}_constraints.json'), 'r') as f:
    constraints = json.load(f)
print(f'Loaded constraints for: {constraints["dataset_shortname"]}')

Loaded constraints for: german


In [8]:
assert scores_test_set_indices == cf_test_set_indices
assert len(list_of_scores_df) == len(list_of_counterfactuals_df) == len(test_dataset)

In [9]:
from typing import List


def get_ranges(test_data: pd.DataFrame, constraints: dict) -> npt.NDArray:
    '''
    Get ranges for continous variables.
    '''
    mins = test_data[constraints['continuous_features_nonsplit']].to_numpy().min(axis=0)
    maxes = test_data[constraints['continuous_features_nonsplit']].to_numpy().max(axis=0)
    feature_ranges = maxes - mins
    return feature_ranges


def heom(x: npt.NDArray, y: npt.NDArray, ranges: npt.NDArray, continous_indices: npt.NDArray, categorical_indices: npt.NDArray) -> float:
    '''
    Calculate HEOM distance between x and y. 
    X and Y should not be normalized. 
    X should be (n, m) dimensional.
    Y should be 1-D array.
    Ranges is max-min on each continous variables (order matters). 
    '''
    distance = np.zeros(x.shape[0])

    # Continous |x-y| / range
    distance += np.sum(np.abs(x[:, continous_indices].astype('float64') - y[continous_indices].astype('float64')) / ranges, axis=1)

    # Categorical - overlap
    distance += np.sum(~np.equal(x[:, categorical_indices], y[categorical_indices]), axis=1)

    return distance

def instability(test_data: pd.DataFrame, 
                 x_index: int, 
                 counterfactual: pd.DataFrame | pd.Series, 
                 list_of_counterfactuals_df: List[pd.DataFrame], 
                 ranges: npt.NDArray, 
                 continous_indices: npt.NDArray | List[float], 
                 categorical_indices: npt.NDArray | List[float]
                 ):
    # Find closest instance to original_x in test_data
    n = len(test_data)
    x = test_data.iloc[0:n+1].to_numpy()
    y = test_data.iloc[x_index].to_numpy()

    all_distances = heom(x, y, ranges, continous_indices, categorical_indices)
    # find closest instance to original_x in test_data
    sorting_indices = np.argsort(all_distances)
    # we do not take 0 because it is the same instance as original_x
    closest_index = np.array(list(zip(range(n), all_distances)))[sorting_indices][1][0].astype(int)
    # counterfactuals of closest x' to x
    closest_counterfactuals = list_of_counterfactuals_df[closest_index].to_numpy()
    
    # x_counterfactuals = list_of_counterfactuals_df[x_index].to_numpy()
    # # calculate all pairs of distances between counterfactuals from x and x'
    # sum_of_distances = .0
    # for x_cf in x_counterfactuals:
    #     mean_distance = np.mean(heom(closest_counterfactuals, x_cf, ranges, continous_indices, categorical_indices))
    #     sum_of_distances += mean_distance
    # return sum_of_distances / len(x_counterfactuals)
    
    instability_score = np.min(heom(closest_counterfactuals, counterfactual.to_numpy(), ranges, continous_indices, categorical_indices))
    return instability_score
    
    
    

continous_indices = [test_dataset.columns.get_loc(c) for c in constraints['continuous_features_nonsplit']]
categorical_indices = [test_dataset.columns.get_loc(c) for c in constraints['categorical_features_nonsplit']]
ranges = get_ranges(test_dataset, constraints)

print(f'Continous indices: {continous_indices}')
print(f'Categorical indices: {categorical_indices}')
print(f'Ranges: {ranges}')

test_plaus = instability(test_dataset, 0, list_of_counterfactuals_df[0].iloc[0], list_of_counterfactuals_df, ranges, continous_indices, categorical_indices)
# Calculate example instability score
print(f'Test instability: {test_plaus:.2f}')

Continous indices: [0, 1, 2, 3, 4, 5, 6]
Categorical indices: [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
Ranges: [   56 15052     3     3    47     1     1]
Test instability: 6.12


In [10]:
def sparsity(x_instance: npt.NDArray, cf_instance: npt.NDArray, continous_indices, categorical_indices) -> int:
    _sparsity = 0
    
    # Continous
    _sparsity += np.sum(~np.isclose(x_instance[continous_indices].astype('float64'), cf_instance[continous_indices].astype('float64'), atol=1e-05))
    
    # Categorical
    _sparsity += np.sum(~np.equal(x_instance[categorical_indices].astype('str'), cf_instance[categorical_indices].astype('str')))
    
    return _sparsity

In [11]:
def is_actionable(x_instance: npt.NDArray, cf_instance: npt.NDArray, continous_indices, categorical_indices, freeze_indices) -> bool:
    for freeze_index in freeze_indices:
        if freeze_index in continous_indices \
            and not np.isclose(x_instance[freeze_index:freeze_index+1].astype('float64'), cf_instance[freeze_index:freeze_index+1].astype('float64'), atol=1e-05):
            return False
        if freeze_index in categorical_indices \
            and not np.equal(x_instance.astype('str')[freeze_index], cf_instance.astype('str')[freeze_index]):
            return False
    return True

freeze_indices = [test_dataset.columns.get_loc(c) for c in constraints['non_actionable_features']]

In [12]:
def get_actionable_indices(x_instance: pd.DataFrame | pd.Series, cf_instances: pd.DataFrame, continous_indices, categorical_indices, freeze_indices) -> npt.NDArray:
    actionability = []
    for _, _cf in cf_instances.iterrows():
        actionability.append(is_actionable(x_instance.to_numpy(), _cf.to_numpy(), continous_indices, categorical_indices, freeze_indices))
    return cf_instances[actionability].index

## Combine experiment metrics

In [51]:
all_explainer_names = counterfactuals_df_all['explainer'].unique().tolist() + ['ideal_point_eucli', 'ideal_point_cheby', 'random_choice']

experiment_scores = {
    'proximity': {k: [] for k in all_explainer_names},
    'k_feasibility_3': {k: [] for k in all_explainer_names},
    'discriminative_power_9': {k: [] for k in all_explainer_names},
    'sparsity': {k: [] for k in all_explainer_names},
    'instability': {k: [] for k in all_explainer_names},
    'coverage': {k: 0 for k in all_explainer_names},
    'actionable': {k: 0 for k in all_explainer_names},
}

In [52]:
from tqdm import tqdm

experiment2_list_of_scores = []

# Calculate instability for all counterfactuals
for i in tqdm(range(len(test_dataset))):
    if only_valid:
        i_counterfactuals = list_of_valid_counterfactuals_df[i]
        i_scores = list_of_valid_scores_df[i]
    else:
        i_counterfactuals = list_of_counterfactuals_df[i]
        i_scores = list_of_scores_df[i]
    
    experiment2scores = pd.DataFrame(columns=['Proximity', 'K_Feasibility(3)', 'DiscriminativePower(9)', 'explainer'])
    
    for explainer_name in all_explainer_names:
        
        _i_counterfactuals = i_counterfactuals.copy(deep=True)
        _i_scores = i_scores.copy(deep=True)
        
        if 'ideal_point' in explainer_name:
            
            # Filter counterfactuals to include only actionable
            actionable_indices = get_actionable_indices(test_dataset.iloc[i], _i_counterfactuals, continous_indices, categorical_indices, freeze_indices)
            
            _i_counterfactuals = _i_counterfactuals.iloc[actionable_indices]
            _i_scores = _i_scores.iloc[actionable_indices]
            
            # Get counterfactual closest to ideal point
            iscores = _i_scores[['Proximity', 'K_Feasibility(3)', 'DiscriminativePower(9)']].to_numpy()
            
            # Apply normalization in each feature
            iscores = (iscores - iscores.min(axis=0)) / (iscores.max(axis=0) - iscores.min(axis=0))
            
            pareto_mask = get_pareto_optimal_mask(iscores, ['min', 'min', 'max'])
            ideal_point = get_ideal_point(iscores, ['min', 'min', 'max'], pareto_mask)
            
            distance_metric = 'euclidean' if 'eucli' in explainer_name else 'chebyshev'
            
            closest_idx = get_closest_to_optimal_point(iscores, ['min', 'min', 'max'], pareto_mask, ideal_point, distance_metric)
            #print(closest_idx)
            _index = closest_idx
        elif explainer_name == 'random_choice':
            # Get random counterfactual from all counterfactuals
            _index = np.random.permutation(_i_scores.index)[0]
        elif explainer_name not in _i_scores['explainer'].unique():
            continue
        else:
            #print(explainer_name)
            # Get random counterfactual from particular explainer
            _index = np.random.permutation(_i_scores[_i_counterfactuals['explainer'] == explainer_name].index)[0]
            
        _cf = _i_counterfactuals.iloc[_index]
        _instability = instability(test_dataset, i, _cf, list_of_counterfactuals_df, ranges, continous_indices, categorical_indices)
        experiment_scores['instability'][explainer_name].append(_instability)
        
        _sparsity = sparsity(test_dataset.iloc[i].to_numpy(), _cf.to_numpy(), continous_indices, categorical_indices)
        experiment_scores['sparsity'][explainer_name].append(_sparsity)
        
        _score = _i_scores.iloc[_index]
        experiment_scores['proximity'][explainer_name].append(_score['Proximity'])
        experiment_scores['k_feasibility_3'][explainer_name].append(_score['K_Feasibility(3)'])
        experiment_scores['discriminative_power_9'][explainer_name].append(_score['DiscriminativePower(9)'])
        experiment_scores['coverage'][explainer_name] += 1
        
        actionable = is_actionable(test_dataset.iloc[i].to_numpy(), _cf.to_numpy(), continous_indices, categorical_indices, freeze_indices)
        experiment_scores['actionable'][explainer_name] += int(actionable)
        
        # Create list of scores for experiment 2 by adding scores for idael point and random choice
        new_record = pd.DataFrame({
            'DiscriminativePower(9)': [_score['DiscriminativePower(9)']],
            'K_Feasibility(3)': [_score['K_Feasibility(3)']],
            'Proximity': [_score['Proximity']],
            'explainer': [explainer_name],
        })
        experiment2scores = pd.concat([experiment2scores, new_record], axis=0)
    experiment2_list_of_scores.append(experiment2scores)

100%|██████████| 100/100 [00:05<00:00, 17.43it/s]


In [53]:
# average experiment scores
for metric_name, v in experiment_scores.items():
    for explainer_name, scores in v.items():
        if metric_name in ['coverage', 'actionable']:
            experiment_scores[metric_name][explainer_name] = experiment_scores[metric_name][explainer_name] / len(test_dataset)
        else:
            experiment_scores[metric_name][explainer_name] = np.mean(scores)
#print(f'{metric_name} {explainer_name}: {experiment_scores[metric_name][explainer_name]:.2f}')
print(experiment_scores)

{'proximity': {'dice': 1.845199516254174, 'cadex': 1.3755489424040659, 'fimap': 6.856590921161852, 'wachter': 1.2910160812426654, 'cem': 0.6177871433003198, 'cfproto': 4.290421580626075, 'growing-spheres': 7.407465861055548, 'actionable-recourse': 1.0051174818471176, 'face': 5.143013794986465, 'ideal_point_eucli': 3.2114239048210895, 'ideal_point_cheby': 2.898142832405329, 'random_choice': 3.731790716121377}, 'k_feasibility_3': {'dice': 4.0074806739026645, 'cadex': 3.72579780835402, 'fimap': 3.0809473929286435, 'wachter': 3.7320889239375736, 'cem': 4.184359725554645, 'cfproto': 4.644736414607237, 'growing-spheres': 5.801672900369567, 'actionable-recourse': 3.546955785007771, 'face': 1.8690380365307784, 'ideal_point_eucli': 2.4624324403471083, 'ideal_point_cheby': 2.7018608932513253, 'random_choice': 3.960707430875949}, 'discriminative_power_9': {'dice': 0.43111111111111117, 'cadex': 0.42840778923253153, 'fimap': 0.5521191294387171, 'wachter': 0.32592592592592584, 'cem': 0.3076923076923

In [54]:
# build dataframe from experiment scores
experiment1_df = pd.DataFrame(experiment_scores).round(2)
if only_valid:
    experiment1_df.to_csv(f'results/experiment1_{dataset}_valid.csv')
else:
    experiment1_df.to_csv(f'results/experiment1_{dataset}.csv')
experiment1_df

Unnamed: 0,proximity,k_feasibility_3,discriminative_power_9,sparsity,instability,coverage,actionable
dice,1.85,4.01,0.43,2.16,4.19,1.0,1.0
cadex,1.38,3.73,0.43,2.49,3.89,0.97,0.97
fimap,6.86,3.08,0.55,9.94,3.68,0.97,0.97
wachter,1.29,3.73,0.33,3.93,3.94,0.3,0.3
cem,0.62,4.18,0.31,2.15,3.99,0.13,0.13
cfproto,4.29,4.64,0.48,5.78,4.74,0.99,0.88
growing-spheres,7.41,5.8,0.56,10.5,5.46,1.0,1.0
actionable-recourse,1.01,3.55,0.44,1.39,3.6,0.23,0.23
face,5.14,1.87,0.6,8.22,3.83,1.0,0.98
ideal_point_eucli,3.21,2.46,0.8,4.99,3.68,1.0,1.0


In [55]:
max_metric = ['discriminative_power_9', 'coverage', 'actionable']

def highlight_top3(s):
    #print(s)
    if s.name in max_metric:
        top = sorted(s, reverse=True)[:3]
    else:
        top = sorted(s)[:3]
    return ['font-weight: bold' if v  in top else '' for v in s]

# bold top 3 in each metric
res = experiment1_df.style.apply(highlight_top3, axis=0)
# Round to 2 decimals
res = res.format(precision=2)
res

Unnamed: 0,proximity,k_feasibility_3,discriminative_power_9,sparsity,instability,coverage,actionable
dice,1.85,4.01,0.43,2.16,4.19,1.0,1.0
cadex,1.38,3.73,0.43,2.49,3.89,0.97,0.97
fimap,6.86,3.08,0.55,9.94,3.68,0.97,0.97
wachter,1.29,3.73,0.33,3.93,3.94,0.3,0.3
cem,0.62,4.18,0.31,2.15,3.99,0.13,0.13
cfproto,4.29,4.64,0.48,5.78,4.74,0.99,0.88
growing-spheres,7.41,5.8,0.56,10.5,5.46,1.0,1.0
actionable-recourse,1.01,3.55,0.44,1.39,3.6,0.23,0.23
face,5.14,1.87,0.6,8.22,3.83,1.0,0.98
ideal_point_eucli,3.21,2.46,0.8,4.99,3.68,1.0,1.0


In [56]:
# pandas dataframe to latex table
def pandas_to_latex(df: pd.DataFrame, keep_formatting: bool = True) -> str:
    """Converts a pandas dataframe to a latex table.
    Args:
        df: The dataframe to convert.
        keep_formatting: Whether to keep the formatting of the dataframe.
    Returns:
        The latex table as a string.
    """
    latex = df.to_latex()
    # Replace \font-weightbold with proper latexbf formatting
    latex = latex.replace(r"\font-weightbold", r"\bfseries")
    # Insert \hline after each newline 
    latex = latex.replace(r"\\", r"\\ \hline")
    # Insert \hline at the top after first newline \n
    latex = latex.replace(r"rrr}", r"rrr} \hline")
    # Replace undersores with dashes
    latex = latex.replace("_", "-")
    # Insert bold line before ideal-point-eucli
    latex = latex.replace("ideal-point-", r"\bfseries ideal-point-")
    # Rename columns according to dictionary
    shortnames = {
        'proximity': 'prox',
        'k-feasibility-3': 'feas-3',
        'discriminative-power-9': 'discrpow-9',
        'sparsity': 'spars',
        'instability': 'plausib',
        'coverage': 'cover',
        'actionable': 'actionab',
    }
    uparrow = ['discrpow-9', 'actionab', 'cover']
    for k, v in shortnames.items():
        latex = latex.replace(f'{k} ', rf'{v} $\uparrow$' if v in uparrow else rf'{v} $\downarrow$')
        
    
    
    return latex
 
    
print(pandas_to_latex(res, keep_formatting=True))

\begin{tabular}{lrrrrrrr} \hline
 & prox $\downarrow$& feas-3 $\downarrow$& discrpow-9 $\uparrow$& spars $\downarrow$& plausib $\downarrow$& cover $\uparrow$& actionab $\uparrow$\\ \hline
dice & 1.85 & 4.01 & 0.43 & \bfseries 2.16 & 4.19 & \bfseries 1.00 & \bfseries 1.00 \\ \hline
cadex & 1.38 & 3.73 & 0.43 & 2.49 & 3.89 & 0.97 & 0.97 \\ \hline
fimap & 6.86 & 3.08 & 0.55 & 9.94 & \bfseries 3.68 & 0.97 & 0.97 \\ \hline
wachter & \bfseries 1.29 & 3.73 & 0.33 & 3.93 & 3.94 & 0.30 & 0.30 \\ \hline
cem & \bfseries 0.62 & 4.18 & 0.31 & \bfseries 2.15 & 3.99 & 0.13 & 0.13 \\ \hline
cfproto & 4.29 & 4.64 & 0.48 & 5.78 & 4.74 & 0.99 & 0.88 \\ \hline
growing-spheres & 7.41 & 5.80 & 0.56 & 10.50 & 5.46 & \bfseries 1.00 & \bfseries 1.00 \\ \hline
actionable-recourse & \bfseries 1.01 & 3.55 & 0.44 & \bfseries 1.39 & \bfseries 3.60 & 0.23 & 0.23 \\ \hline
face & 5.14 & \bfseries 1.87 & \bfseries 0.60 & 8.22 & 3.83 & \bfseries 1.00 & 0.98 \\ \hline
\bfseries ideal-point-eucli & 3.21 & \bfseries 2.46 

## Experiment 2 - preferences of different simulated users

In [57]:
experiment2_list_of_scores[0]['explainer'].value_counts()

dice                 1
cadex                1
fimap                1
cfproto              1
growing-spheres      1
face                 1
ideal_point_eucli    1
ideal_point_cheby    1
random_choice        1
Name: explainer, dtype: int64

In [65]:
from collections import defaultdict
resolution = 20
metrics_to_consider = ['DiscriminativePower(9)', 'K_Feasibility(3)', 'Proximity']

combinations = [(i/resolution, (resolution-i-k)/resolution, k/resolution) for i in range(0,resolution+1) for k in range(0,resolution-i+1)]

results = []

for i, j, k in tqdm(combinations):
    counts = defaultdict(lambda: 0)
    
    for scores in experiment2_list_of_scores:
        # normalize scores in columns
        scores = scores.copy().reset_index(drop=True)
        x = scores[metrics_to_consider]
        x = (x - x.min()) / (x.max() - x.min()) 
        scores[metrics_to_consider] = x
        scores['weighted_score'] = np.sum(x * [i, -j, -k], axis=1) #[discr, feas, prox]
        
        #print(x)
        
        #break
    

        idxmax = np.argmax(scores['weighted_score'])
        counts[scores['explainer'].iloc[idxmax]] += 1
    #break
     
    # best explainer
    best = max(counts.items(), key=lambda x: x[1])
    results.append((i, j, k, best[0], best[1] / sum(counts.values())))
    
        
print(results)

100%|██████████| 231/231 [01:20<00:00,  2.86it/s]

[(0.0, 1.0, 0.0, 'face', 0.6), (0.0, 0.95, 0.05, 'face', 0.62), (0.0, 0.9, 0.1, 'face', 0.58), (0.0, 0.85, 0.15, 'face', 0.59), (0.0, 0.8, 0.2, 'face', 0.57), (0.0, 0.75, 0.25, 'face', 0.52), (0.0, 0.7, 0.3, 'face', 0.49), (0.0, 0.65, 0.35, 'face', 0.4), (0.0, 0.6, 0.4, 'ideal_point_eucli', 0.33), (0.0, 0.55, 0.45, 'ideal_point_eucli', 0.33), (0.0, 0.5, 0.5, 'ideal_point_eucli', 0.28), (0.0, 0.45, 0.55, 'cadex', 0.31), (0.0, 0.4, 0.6, 'cadex', 0.33), (0.0, 0.35, 0.65, 'cadex', 0.35), (0.0, 0.3, 0.7, 'cadex', 0.37), (0.0, 0.25, 0.75, 'cadex', 0.4), (0.0, 0.2, 0.8, 'cadex', 0.42), (0.0, 0.15, 0.85, 'cadex', 0.43), (0.0, 0.1, 0.9, 'cadex', 0.46), (0.0, 0.05, 0.95, 'cadex', 0.45), (0.0, 0.0, 1.0, 'cadex', 0.48), (0.05, 0.95, 0.0, 'face', 0.58), (0.05, 0.9, 0.05, 'face', 0.57), (0.05, 0.85, 0.1, 'face', 0.55), (0.05, 0.8, 0.15, 'face', 0.54), (0.05, 0.75, 0.2, 'face', 0.52), (0.05, 0.7, 0.25, 'face', 0.5), (0.05, 0.65, 0.3, 'face', 0.44), (0.05, 0.6, 0.35, 'ideal_point_eucli', 0.4), (0.05, 




In [66]:
experiment2_list_of_scores[0]

Unnamed: 0,Proximity,K_Feasibility(3),DiscriminativePower(9),explainer
0,2.0,5.226987,0.444444,dice
0,0.336688,4.646552,0.333333,cadex
0,6.39131,2.818776,1.0,fimap
0,1.000385,4.777732,0.666667,cfproto
0,6.357379,4.974154,0.444444,growing-spheres
0,5.524585,2.822099,0.555556,face
0,2.0,3.077862,0.777778,ideal_point_eucli
0,2.0,3.077862,0.777778,ideal_point_cheby
0,2.0,5.226987,0.444444,random_choice


In [77]:
import matplotlib.pyplot as plt
import plotly.express as px
# Plot plotly.figure_factory.create_ternary_contour
from plotly.figure_factory import create_ternary_contour
# Display plotly in browser (not in notebook)
import plotly.io as pio
pio.renderers.default = "browser"

x,y,z,explainer,percentage = zip(*results)
x = np.array(x)
y = np.array(y)
z = np.array(z)
percentage = np.array(percentage)

df = pd.DataFrame({
    'DiscriminativePower(9)': x,
    'K_Feasibility(3)': y,
    'Proximity': z,
    'explainer': explainer,
    'percentage': percentage,
})

fig = px.scatter_ternary(df, 
                        a="DiscriminativePower(9)", 
                        b="K_Feasibility(3)", 
                        c="Proximity", 
                        color="explainer", 
                        size="percentage", 
                        size_max=30, 
                        hover_name="explainer", 
                        color_continuous_scale=px.colors.sequential.Plasma,
                        # Title
                        title=f"Best Explainer for {dataset} dataset" + (" (only valid)" if only_valid else ""),
                        # Show axes arrow    
                        )


# Save plot to png
if only_valid:
    fig.write_image(f"results/experiment2_{dataset}_valid.png")
else:
    fig.write_image(f"results/experiment2_{dataset}.png")
fig.show()