![quanda_benchmarks_demo.png](attachment:quanda_benchmarks_demo.png)

In this notebook, we go through the different ways of creating an evaluation benchmark and using it to compare attributors. First, we take the most straightforward route: downloading a precomputed quanda benchmark for data attribution evaluation. This lets you quickly write a quanda wrapper for your explainer and evaluate it against the existing explainers in the controlled setups we have prepared for you.

Afterwards, we go through the steps of assembling a benchmark from existing components. This option lets you create your own controlled setup and use quanda benchmarks to evaluate different data attributors in it.

Finally, we summarize how to create your own setup with quanda benchmarks, including managing datasets, training models and running evaluations.

Throughout this tutorial, we will be using a LeNet model trained on the MNIST dataset.

We start with the imports.

In [None]:
import os
import sys

import torch
import torchvision
from quanda.benchmarks.downstream_eval import ShortcutDetection, MislabelingDetection, SubclassDetection
from quanda.explainers.wrappers import (
    TRAK,
    CaptumArnoldi,
    CaptumSimilarity,
    CaptumTracInCPFast,
    RepresenterPoints,
)

In [None]:
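torch.set_float32_matmul_precision("medium")

# Convert a normalized image tensor back into a viewable PIL image.
# The two Normalize steps compute x / 2 + 0.5, which inverts a
# normalization with mean 0.5 and std 0.5.
to_img = torchvision.transforms.Compose([
    torchvision.transforms.Normalize(mean=0.0, std=2.0),
    torchvision.transforms.Normalize(mean=-0.5, std=1.0),
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.Resize((224, 224)),
])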
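Since `Normalize` computes `(x - mean) / std`, the two chained steps in `to_img` map `x` to `x / 2 + 0.5`. The scalar sketch below checks this arithmetic. It assumes, which the tutorial does not state, that the benchmark images were normalized with mean 0.5 and std 0.5; under that assumption `to_img` maps pixels back to `[0, 1]` before `ToPILImage`.

```python
def normalize(x, mean, std):
    """Scalar version of torchvision's Normalize: (x - mean) / std."""
    return (x - mean) / std


def undo_tutorial_normalization(x):
    """Apply the two Normalize steps of `to_img` one after the other."""
    x = normalize(x, mean=0.0, std=2.0)      # x / 2
    return normalize(x, mean=-0.5, std=1.0)  # x / 2 + 0.5


# Assumption: pixels in [0, 1] were normalized with mean 0.5 and std 0.5
for pixel in [0.0, 0.25, 0.5, 1.0]:
    normalized = normalize(pixel, mean=0.5, std=0.5)
    print(pixel, normalized, undo_tutorial_normalization(normalized))
```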

# Downloading Precomputed Benchmarks
In this part of the tutorial, we will use the Shortcut Detection metric.

We will use the benchmark corresponding to this metric to evaluate all data attributors currently included in quanda in terms of their ability to detect when the model is using a shortcut.

We will download the precomputed MNIST benchmark. It includes an MNIST dataset in which a subset of the class 0 samples carries a shortcut feature (an 8-by-8 white box at a specific location), together with a model trained on this dataset. The model has learned to classify images containing this feature as class 0, and we will measure the extent to which this is reflected in the attributions of different methods.
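To make the manipulation concrete, here is a minimal pure-Python sketch of pasting an 8-by-8 white box into a 28-by-28 image. The function name `add_shortcut_box`, the box location (top-left corner at row 2, column 2) and the nested-list image format are hypothetical; quanda applies the shortcut internally when it generates the benchmark.

```python
def add_shortcut_box(image, top, left, size=8, value=1.0):
    """Return a copy of `image` (a list of rows) with a size x size box set to `value`."""
    patched = [row[:] for row in image]  # copy so the original image is untouched
    for r in range(top, top + size):
        for c in range(left, left + size):
            patched[r][c] = value
    return patched


blank = [[0.0] * 28 for _ in range(28)]           # stand-in for a 28x28 image in [0, 1]
patched = add_shortcut_box(blank, top=2, left=2)  # hypothetical box location
print(sum(map(sum, patched)))                     # 64.0: the box covers 8 * 8 pixels
```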

In [None]:
cache_dir = os.path.join(os.getcwd(), "quanda_benchmark_tutorial_cache")
device = "cpu"
benchmark = ShortcutDetection.download(
    name="mnist_shortcut_detection",
    cache_dir=cache_dir,
    device=device,
)

The benchmark object contains all information about the controlled evaluation setup. Let's look at a sample with the shortcut feature, using `benchmark.shortcut_dataset` and `benchmark.shortcut_indices`.

In [None]:
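# Pick one sample that carries the shortcut feature
shortcut_img = benchmark.shortcut_dataset[benchmark.shortcut_indices[15]][0]
# Repeat the single grayscale channel three times to obtain an RGB image
tensor_img = torch.concat([shortcut_img, shortcut_img, shortcut_img], dim=0)
img = to_img(tensor_img)
img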

In [None]:
# Fraction of shortcut samples that the model assigns to the shortcut class
predictions = []
with torch.no_grad():
    for i in benchmark.shortcut_indices:
        x, _ = benchmark.shortcut_dataset[i]
        x = x.to(device)
        # a single forward pass per sample; x[None] adds the batch dimension
        predictions.append(benchmark.model(x[None]).argmax().item())
predictions = torch.tensor(predictions)
shortcut_rate = (predictions == benchmark.shortcut_cls).float().mean()
shortcut_rate

## Prepare initialization parameters for TDA methods

We now prepare the initialization parameters of the attributors: hyperparameters, plus components from the benchmark where needed. Note that we do not provide the model and the dataset to use for attribution: the benchmark object supplies these components when it initializes the attributor during evaluation.
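The division of labour can be pictured with the minimal sketch below: the benchmark passes its own model and training data, and the user-supplied dictionary fills in everything else. This only illustrates the idea and is not quanda's actual code; `ToyExplainer` and `init_explainer` are hypothetical names.

```python
class ToyExplainer:
    """Hypothetical stand-in for a quanda explainer wrapper."""

    def __init__(self, model, train_dataset, batch_size=1, **kwargs):
        self.model = model
        self.train_dataset = train_dataset
        self.batch_size = batch_size
        self.extra = kwargs  # any remaining hyperparameters


def init_explainer(explainer_cls, expl_kwargs, model, train_dataset):
    """Sketch: the benchmark supplies model and data, the user supplies the rest."""
    return explainer_cls(model=model, train_dataset=train_dataset, **expl_kwargs)


explainer = init_explainer(
    ToyExplainer,
    {"batch_size": 8, "cache_dir": "./cache"},
    model="benchmark-model",
    train_dataset=["x0", "x1"],
)
print(explainer.batch_size, explainer.extra)  # 8 {'cache_dir': './cache'}
```

One consequence of this pattern: if the kwargs dictionary also contained a `model` key, Python would raise a `TypeError` for receiving the same keyword argument twice.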

### Similarity Influence

In [None]:
captum_similarity_args = {
    "model_id": "mnist_shortcut_detection_tutorial",
    "layers": "model.fc_2",
    "cache_dir": os.path.join(cache_dir, "captum_similarity"),
}

### Arnoldi Influence Functions

Notice that the trained checkpoints were saved to `cache_dir` while downloading the benchmark. We can access the paths of these checkpoints through `benchmark.checkpoint_paths`.

In [None]:
hessian_num_samples = 500  # number of samples used for Hessian estimation
# Draw a random subset of the training data without replacement
hessian_indices = torch.randperm(len(benchmark.shortcut_dataset))[:hessian_num_samples].tolist()
hessian_ds = torch.utils.data.Subset(benchmark.shortcut_dataset, hessian_indices)

captum_influence_args = {
    "checkpoint": benchmark.checkpoint_paths[-1],  # use the final checkpoint
    "layers": ["model.fc_3"],
    "batch_size": 8,
    "hessian_dataset": hessian_ds,
    "projection_dim": 5,
}
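As a rough guide to what these parameters control (a summary of the Arnoldi influence approach, not of quanda's code): influence functions score a training sample $z$ for a test sample $z_{\text{test}}$ through the inverse Hessian $H$ of the training loss, with gradients taken with respect to the parameters of the layers listed in `layers`. The Arnoldi variant replaces $H$ with its top-$k$ eigenpairs, where $k$ corresponds to `projection_dim` and $H$ is estimated on `hessian_dataset`. Up to the sign convention:

$$
\mathcal{I}(z_{\text{test}}, z) \approx \nabla_\theta L(z_{\text{test}})^\top \, V_k \Lambda_k^{-1} V_k^\top \, \nabla_\theta L(z), \qquad H \approx V_k \Lambda_k V_k^\top
$$

A larger `projection_dim` or a larger Hessian dataset gives a better approximation at a higher computational cost.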

### TracInCP

In [None]:
captum_tracin_args = {
    "final_fc_layer": "model.fc_3",
    "loss_fn": torch.nn.CrossEntropyLoss(reduction="mean"),
    "checkpoints": benchmark.checkpoint_paths,
    "batch_size": 8,
}
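TracInCP scores a training sample against a test sample by summing, over saved checkpoints, the learning rate times the dot product of the two samples' loss gradients; this is why `checkpoints` receives the full list of checkpoint paths. The "fast" variant only uses gradients with respect to the final fully connected layer named by `final_fc_layer`. The pure-Python sketch below shows just this scoring rule on made-up gradient vectors (`tracin_cp_score` is a hypothetical helper); it is not how Captum computes the gradients.

```python
def tracin_cp_score(train_grads, test_grads, learning_rates):
    """TracInCP-style score: sum over checkpoints of lr * <grad_train, grad_test>."""
    score = 0.0
    for g_train, g_test, lr in zip(train_grads, test_grads, learning_rates):
        score += lr * sum(a * b for a, b in zip(g_train, g_test))
    return score


# Two hypothetical checkpoints with 3-dimensional gradients
train_grads = [[1.0, 0.0, 2.0], [0.5, 1.0, 0.0]]
test_grads = [[1.0, 1.0, 1.0], [2.0, 0.0, 1.0]]
print(tracin_cp_score(train_grads, test_grads, learning_rates=[1.0, 1.0]))  # 4.0
```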

### TRAK

In [None]:
trak_args = {
    "model_id": "mnist_shortcut_detection",
    "cache_dir": os.path.join(cache_dir, "trak"),
    "batch_size": 8,
    "proj_dim": 5,
}

### Representer Points Selection

In [None]:
representer_points_args = {
    "model_id": "mnist_shortcut_detection",
    "cache_dir": os.path.join(cache_dir, "representer_points"),
    "batch_size": 8,
    "features_layer": "model.relu_4",
    "classifier_layer": "model.fc_3",
}

## Run the benchmark evaluation on the attributors
Note that some attributors take a long time to initialize or to compute attributions. For a proof of concept, we recommend using `CaptumSimilarity` or `RepresenterPoints`, or lowering the parameter values given above (e.g. using a low `proj_dim` for TRAK or a small Hessian dataset for Arnoldi Influence).

In [None]:
attributors = {
    # uncomment the explainers you want to evaluate
    "captum_similarity": (CaptumSimilarity, captum_similarity_args),
    # "captum_arnoldi": (CaptumArnoldi, captum_influence_args),
    # "captum_tracin": (CaptumTracInCPFast, captum_tracin_args),
    # "trak": (TRAK, trak_args),
    # "representer": (RepresenterPoints, representer_points_args),
}

In [None]:
results = {}
for name, (cls, kwargs) in attributors.items():
    # evaluate each attributor on the benchmark and keep only its score
    results[name] = benchmark.evaluate(
        explainer_cls=cls,
        expl_kwargs=kwargs,
        batch_size=128,
    )["score"]

The `results` dictionary contains the results of the evaluation: the keys are the names of the explainers and the values are their benchmark scores.

In [None]:
results

# Assembling a benchmark from existing components

You may want to create each component differently, using other datasets, architectures or training paradigms, or a higher or lower percentage of manipulated samples. We now show how to create a quanda `Benchmark` object from such components and use it in the evaluation process.

To showcase a different benchmark, we now switch to `MislabelingDetection`. This benchmark evaluates the ability of data attribution methods to identify mislabeled samples in the training dataset. To this end, a model is trained on a dataset that contains a substantial number of mislabeled samples, and the local data attribution methods are then used to rank the training data. The original papers propose either using self-influence (i.e. the attribution of each training sample on itself) or a method-specific strategy (e.g. the global coefficients of the surrogate model in Representer Points). Quanda includes efficient implementations of self-influence, or of the other strategies proposed in the original papers, whenever possible.

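The training samples are then inspected in the order given by this ranking. Quanda computes the cumulative mislabeling detection curve, i.e. the number of mislabeled samples found as a function of the number of inspected samples, and returns the area under this curve (AUC) as the score.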
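The sketch below walks through both steps on a toy example. It assumes a full train-by-train attribution matrix is available, whose diagonal holds the self-influence scores (quanda computes self-influence more efficiently than this). Normalizing the area by that of a perfect ranking is also our assumption; quanda's exact AUC normalization may differ. All function names are hypothetical.

```python
def self_influence_ranking(attributions):
    """Rank training indices by self-influence (attributions[i][i]), highest first."""
    scores = [attributions[i][i] for i in range(len(attributions))]
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)


def detection_curve(ranking, mislabeled):
    """Number of mislabeled samples found after inspecting the first k + 1 ranked samples."""
    mislabeled = set(mislabeled)
    curve, found = [], 0
    for idx in ranking:
        if idx in mislabeled:
            found += 1
        curve.append(found)
    return curve


def normalized_auc(ranking, mislabeled):
    """Area under the detection curve divided by the area for a perfect ranking."""
    curve = detection_curve(ranking, mislabeled)
    n_mislabeled = len(set(mislabeled))
    perfect = [min(k + 1, n_mislabeled) for k in range(len(ranking))]
    return sum(curve) / sum(perfect)


# Toy 4-sample example; samples 1 and 3 are assumed to be mislabeled
attributions = [
    [0.2, 0.1, 0.0, 0.3],
    [0.1, 1.4, 0.2, 0.0],
    [0.0, 0.3, 0.5, 0.1],
    [0.2, 0.0, 0.1, 0.9],
]
ranking = self_influence_ranking(attributions)
print(ranking)                                 # [1, 3, 2, 0]
print(detection_curve(ranking, [1, 3]))        # [1, 2, 2, 2]
print(normalized_auc(ranking, [1, 3]))         # 1.0, a perfect ranking
print(normalized_auc([3, 0, 2, 1], [1, 3]))    # 5 / 7, a worse ranking
```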

Instead of creating the components from scratch, we will again download the benchmark and collect the prepared components from it. We will then use the `MislabelingDetection.assemble` method to create the benchmark. Note that this is exactly what happens under the hood when we create a benchmark with the `download` method.

In [None]:
temp_benchmark = MislabelingDetection.download(
    name="mnist_mislabeling_detection",
    cache_dir=cache_dir,
    device=device,
)

## Required Components

In order to assemble a `MislabelingDetection` benchmark, we require the following components:
- A base training dataset with correct labels
- A dictionary containing the mislabeling information: the integer keys are the indices of the samples whose labels were changed, and the values are the new (wrong) labels that were used to train the model
- A model trained on the mislabeled dataset
- Number of classes in the dataset
- Dataset transform that was used during training, applied to samples before feeding them to the model
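To illustrate how the mislabeling dictionary relates to the base dataset, here is a minimal sketch of a wrapper that overrides the labels of the listed indices. `RelabeledDataset` is a hypothetical class for illustration only; quanda handles this internally when the benchmark is assembled.

```python
class RelabeledDataset:
    """Hypothetical wrapper: replaces labels of selected samples via an {index: new_label} dict."""

    def __init__(self, base_dataset, mislabeling_labels):
        self.base_dataset = base_dataset
        self.mislabeling_labels = mislabeling_labels

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        x, y = self.base_dataset[idx]
        # keep the original label unless this index was mislabeled
        return x, self.mislabeling_labels.get(idx, y)


base = [("img0", 0), ("img1", 1), ("img2", 2)]
noisy = RelabeledDataset(base, mislabeling_labels={1: 7})
print([noisy[i][1] for i in range(len(noisy))])  # [0, 7, 2]
```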

Let's collect these components from the downloaded benchmark. We then assemble the benchmark and evaluate the `CaptumSimilarity` attributor with it.

In [None]:
model=temp_benchmark.model
base_dataset=temp_benchmark.base_dataset
mislabeling_labels=temp_benchmark.mislabeling_labels
dataset_transform=temp_benchmark.dataset_transform


In [None]:
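benchmark = MislabelingDetection.assemble(
    model=model,
    base_dataset=base_dataset,
    n_classes=10,
    mislabeling_labels=mislabeling_labels,
    dataset_transform=dataset_transform,
    device=device,
)
# Use a new model_id and cache_dir so that any activations cached for the
# shortcut detection model cannot be picked up for this model
mislabeling_similarity_args = {
    **captum_similarity_args,
    "model_id": "mnist_mislabeling_detection_tutorial",
    "cache_dir": os.path.join(cache_dir, "captum_similarity_mislabeling"),
}
benchmark.evaluate(
    explainer_cls=CaptumSimilarity,
    expl_kwargs=mislabeling_similarity_args,
    batch_size=128,
)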