In [None]:
from encoder_decoder_correlations import EncoderDecoderComparison
import torch

from captum.attr import GradientShap, IntegratedGradients, Saliency
import pandas as pd
import numpy as np

In [None]:
# Global configuration: use the CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Gradient SHAP

In [None]:
# Encoder/decoder comparison experiment that uses captum's GradientShap as the attributer factory.
gradient_shap_experiment = EncoderDecoderComparison(
    dataset='MNIST',
    attributer_factory=GradientShap,
    model_name="",
    device=DEVICE,
)

In [None]:
# Collect the result of get_all_model_pearsons for the Gradient SHAP experiment
# (presumably Pearson correlations for every model — confirm against EncoderDecoderComparison).
# NOTE(review): the meaning of the positional `True` argument is defined in
# encoder_decoder_correlations and is not visible here — confirm before changing it.
gs_pearsons = gradient_shap_experiment.get_all_model_pearsons(True)

In [None]:
# Average the Gradient SHAP correlations along axis 1 and display the result,
# mirroring the Integrated Gradients and Saliency cells of this notebook.
gs_pearsons_mean = gs_pearsons.mean(axis=1)
gs_pearsons_mean

## Integrated Gradients

In [None]:
# Encoder/decoder comparison experiment that uses captum's IntegratedGradients as the attributer factory.
ig_experiment = EncoderDecoderComparison(
    dataset='MNIST',
    attributer_factory=IntegratedGradients,
    model_name="",
    device=DEVICE,
)

In [None]:
# Collect the result of get_all_model_pearsons for the Integrated Gradients experiment.
# NOTE(review): the meaning of the positional `True` argument is defined in
# encoder_decoder_correlations and is not visible here — confirm before changing it.
ig_pearsons = ig_experiment.get_all_model_pearsons(True)
# Average along axis 1 (presumably yielding one value per run — TODO confirm the layout).
ig_pearsons_mean = ig_pearsons.mean(axis=1)
# Bare expression so the notebook shows the averaged values as the cell output.
ig_pearsons_mean

## Saliency

In [None]:
# Encoder/decoder comparison experiment that uses captum's Saliency as the attributer factory.
sal_experiment = EncoderDecoderComparison(
    dataset='MNIST',
    attributer_factory=Saliency,
    model_name="",
    device=DEVICE,
)

In [None]:
# Collect the result of get_all_model_pearsons for the Saliency experiment.
# NOTE(review): the meaning of the positional `True` argument is defined in
# encoder_decoder_correlations and is not visible here — confirm before changing it.
sal_pearsons = sal_experiment.get_all_model_pearsons(True)
# Average along axis 1 (presumably yielding one value per run — TODO confirm the layout).
sal_pearsons_mean = sal_pearsons.mean(axis=1)
# Bare expression so the notebook shows the averaged values as the cell output.
sal_pearsons_mean

### Plot results for paper

In [None]:
# Cached mean Pearson correlations, so the plots can be reproduced without rerunning
# the experiment cells above.
# Set USE_CACHED_RESULTS to False when the experiments were just run from scratch;
# otherwise the freshly computed *_pearsons_mean values are overwritten here.
USE_CACHED_RESULTS = True

if USE_CACHED_RESULTS:
    ig_pearsons_mean = np.array([0.39079937, 0.47364313, 0.45926487, 0.42462469, 0.41964185])
    sal_pearsons_mean = np.array([0.21530681, 0.13549975, 0.16179986, 0.16258512, 0.20697079])
    gs_pearsons_mean = np.array([0.3503069, 0.32580192, 0.32778726, 0.36071626, 0.32253383])

In [None]:
# Metric to plot: the raw Pearson correlation coefficient (False) or
# R-squared, the proportion of variance explained (True).
PLOT_R_SQUARED = False

# One column per attribution method.
data = {
    'Integrated Gradients': ig_pearsons_mean,
    'Saliency': sal_pearsons_mean,
    'Gradient Shap': gs_pearsons_mean,
}
df = pd.DataFrame(data)

# Squaring the correlations converts them to R-squared values.
df = df**2 if PLOT_R_SQUARED else df
df.index.name = 'Run'

In [None]:
# Compute means and error bars
agg_df = df.agg(['mean', 'sem']).T
agg_df['ci_width'] =+1.96*agg_df['sem']
agg_df['upper_ci'] = agg_df['mean'] + agg_df['ci_width']
agg_df['lower_ci'] = agg_df['mean'] - agg_df['ci_width']

# agg_df = agg_df.T

In [None]:
# Label the y-axis according to the metric selected by PLOT_R_SQUARED.
if PLOT_R_SQUARED:
    ylabel = 'R-squared'
else:
    ylabel = 'Pearson Correlation Coefficient'

# Bar chart of the per-method mean with the CI half-width as symmetric error bars.
# Pass ylim=(0, 1) as well to pin the y-axis to the unit range.
agg_df.plot.bar(
    y='mean',
    yerr=agg_df['ci_width'],
    ylabel=ylabel,
    rot=0,
    legend=False,
)

In [None]:
# Bare expression so the notebook shows the summary table (mean, sem and CI columns) as the cell output.
agg_df