# Set up dataset and model

Please run all the cells.

In [None]:
import os
import sys

# Move up one directory (presumably the repository root) so relative paths
# such as configs/ and logs/ used below resolve.
os.chdir(os.pardir)

In [None]:
# Show current GPU utilisation (IPython shell escape; requires the gpustat CLI).
!gpustat

In [None]:
import os

# Expose only GPU 0 to this process (only effective if set before CUDA is initialised).
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [None]:
# Emulate the command line of train_objexplainer.py so the argument parsing
# cell below reads this single JSON config.
sys.argv = [
    "train_objexplainer.py",
    "configs/vitbase_imagenette_shapley_objexplainer_newsample_32.json",
]

In [None]:
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

import evaluate
import ipdb
import numpy as np
import torch
import transformers
from datasets import load_dataset
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformers import (
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    AutoConfig,
    AutoImageProcessor,
    AutoModelForImageClassification,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

from arguments import DataTrainingArguments, ExplainerArguments, SurrogateArguments
from models import (
    ObjExplainerForImageClassification,
    ObjExplainerForImageClassificationConfig,
    SurrogateForImageClassificationConfig,
)
from utils import (
    MaskDataset,
    configure_dataset,
    generate_mask,
    get_checkpoint,
    get_image_transform,
    load_shapley,
    log_dataset,
    read_eval_results,
    setup_dataset,
)

""" Fine-tuning a 🤗 Transformers model for image classification"""

logger = logging.getLogger(__name__)

# Fail fast when the installed library versions are older than this script
# expects (remove at your own risk).
check_min_version("4.32.0.dev0")
require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/image-classification/requirements.txt",
)

# Config classes registered for image classification and their model_type names.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(config_cls.model_type for config_cls in MODEL_CONFIG_CLASSES)


@dataclass
class OtherArguments:
    """Extra options parsed alongside the surrogate/explainer/data/training groups.

    All path options default to ``None``, so they are typed ``Optional[str]``.
    """

    # Hub access token; None means the locally stored login token is used (per the help text).
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )

    # Per-split cache locations; this object is passed to setup_dataset
    # (presumably consumed there — TODO confirm).
    train_subsets_cache_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where to load the downloaded dataset.",
        },
    )
    validation_subsets_cache_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where to load the downloaded dataset.",
        },
    )
    test_subsets_cache_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where to load the downloaded dataset.",
        },
    )

    # Mask-generation mode per split; the value looks like "<mode>,<arg>"
    # (e.g. "incremental,1") — NOTE(review): confirm the format with the consumer.
    train_mask_mode: str = field(
        default="incremental,1",
        metadata={
            "help": "mask mode for train",
        },
    )

    validation_mask_mode: str = field(
        default="incremental,1",
        metadata={
            "help": "mask mode for validation",
        },
    )

    test_mask_mode: str = field(
        default="incremental,1",
        metadata={
            "help": "mask mode for test",
        },
    )

In [None]:
########################################################
# Parse arguments
#######################################################
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.

parser = HfArgumentParser(
    (
        SurrogateArguments,
        ExplainerArguments,
        DataTrainingArguments,
        TrainingArguments,
        OtherArguments,
    )
)

# A single ".json" argument is treated as a config file holding every
# argument group; otherwise parse regular command-line flags.
_argv_is_single_json = len(sys.argv) == 2 and sys.argv[1].endswith(".json")
_parsed_arg_groups = (
    parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    if _argv_is_single_json
    else parser.parse_args_into_dataclasses()
)
(
    surrogate_args,
    explainer_args,
    data_args,
    training_args,
    other_args,
) = _parsed_arg_groups

########################################################
# Setup logging
#######################################################
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

if training_args.should_log:
    # The default of training_args.log_level is passive, so we set log level at info here to have that default.
    transformers.utils.logging.set_verbosity_info()

# Apply the per-process log level to both this module's logger and transformers.
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log a short summary on every process. The two f-string pieces are joined
# with ", " so the n_gpu value and the next field do not run together.
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

########################################################
# Correct cache dir if necessary
########################################################
# If the top-level directory of dataset_cache_dir (e.g. "/data" in
# "/data/...") does not exist on this machine, re-root the rest of the path
# under the first fallback directory that does exist.
_original_cache_dir = data_args.dataset_cache_dir
if _original_cache_dir is not None and not os.path.exists(
    os.sep.join(_original_cache_dir.split(os.sep, 2)[:2])
):
    for _fallback_root in ("/data2", "/sdata"):
        if os.path.exists(_fallback_root):
            data_args.dataset_cache_dir = os.sep.join(
                [_fallback_root] + _original_cache_dir.split(os.sep, 2)[2:]
            )
            # Report the path that was missing and the one used instead.
            logger.info(
                f"dataset_cache_dir {_original_cache_dir} not found, using {data_args.dataset_cache_dir}"
            )
            break
    else:
        raise ValueError(
            f"dataset_cache_dir {_original_cache_dir} not found"
        )

########################################################
# Set seed before initializing model.
########################################################
# Seed via transformers.set_seed so model initialisation is reproducible.
set_seed(training_args.seed)

########################################################
# Initialize our dataset and prepare it for the 'image-classification' task.
########################################################
_dataset_setup = setup_dataset(data_args=data_args, other_args=other_args)
dataset_original, labels, label2id, id2label = _dataset_setup

########################################################
# Initialize explainer model
########################################################

explainer_config = AutoConfig.from_pretrained(
    explainer_args.explainer_config_name
    or explainer_args.explainer_model_name_or_path,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    finetuning_task="image-classification",
    cache_dir=explainer_args.explainer_cache_dir,
    revision=explainer_args.explainer_model_revision,
    token=other_args.token,
)

# Reload a saved explainer when explainer_model_name_or_path is a local
# directory whose config.json lists ObjExplainerForImageClassification as its
# first architecture. Otherwise, assemble a new explainer from the surrogate
# and explainer configs.
_explainer_config_path = f"{explainer_args.explainer_model_name_or_path}/config.json"
_is_saved_objexplainer = False
if os.path.isfile(_explainer_config_path):
    # The context manager closes the file handle promptly.
    with open(_explainer_config_path, encoding="utf-8") as _config_file:
        # A missing/empty "architectures" entry means "not a saved ObjExplainer".
        _architectures = json.load(_config_file).get("architectures") or [None]
    _is_saved_objexplainer = _architectures[0] == "ObjExplainerForImageClassification"

# ".ckpt" in the path selects TensorFlow-checkpoint loading.
_explainer_from_tf = ".ckpt" in explainer_args.explainer_model_name_or_path

if _is_saved_objexplainer:
    explainer = ObjExplainerForImageClassification.from_pretrained(
        explainer_args.explainer_model_name_or_path,
        from_tf=_explainer_from_tf,
        config=explainer_config,
        cache_dir=explainer_args.explainer_cache_dir,
        revision=explainer_args.explainer_model_revision,
        token=other_args.token,
        ignore_mismatched_sizes=explainer_args.explainer_ignore_mismatched_sizes,
    )
else:
    _surrogate_from_tf = ".ckpt" in surrogate_args.surrogate_model_name_or_path
    surrogate_config = AutoConfig.from_pretrained(
        surrogate_args.surrogate_config_name
        or surrogate_args.surrogate_model_name_or_path,
        num_labels=len(labels),
        label2id=label2id,
        id2label=id2label,
        finetuning_task="image-classification",
        cache_dir=surrogate_args.surrogate_cache_dir,
        revision=surrogate_args.surrogate_model_revision,
        token=other_args.token,
    )
    surrogate_for_image_classification_config = SurrogateForImageClassificationConfig(
        surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
        surrogate_config=surrogate_config,
        surrogate_from_tf=_surrogate_from_tf,
        surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
        surrogate_revision=surrogate_args.surrogate_model_revision,
        surrogate_token=other_args.token,
        surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
    )

    explainer_for_image_classification_config = ObjExplainerForImageClassificationConfig(
        surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
        surrogate_config=surrogate_for_image_classification_config,
        surrogate_from_tf=_surrogate_from_tf,
        surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
        surrogate_revision=surrogate_args.surrogate_model_revision,
        surrogate_token=other_args.token,
        surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
        explainer_pretrained_model_name_or_path=explainer_args.explainer_model_name_or_path,
        explainer_config=explainer_config,
        explainer_from_tf=_explainer_from_tf,
        explainer_cache_dir=explainer_args.explainer_cache_dir,
        explainer_revision=explainer_args.explainer_model_revision,
        explainer_token=other_args.token,
        explainer_ignore_mismatched_sizes=explainer_args.explainer_ignore_mismatched_sizes,
    )

    explainer = ObjExplainerForImageClassification(
        config=explainer_for_image_classification_config,
    )
explainer_image_processor = AutoImageProcessor.from_pretrained(
    explainer_args.explainer_image_processor_name
    or explainer_args.explainer_model_name_or_path,
    cache_dir=explainer_args.explainer_cache_dir,
    revision=explainer_args.explainer_model_revision,
    token=other_args.token,
)

########################################################
# Configure dataset (set max samples, transforms, etc.)
########################################################
# Configure a deep copy so that whatever configure_dataset does to its input
# cannot affect dataset_original. All augmentations are disabled.
dataset_explainer = configure_dataset(
    dataset=copy.deepcopy(dataset_original),
    image_processor=explainer_image_processor,
    training_args=training_args,
    data_args=data_args,
    logger=logger,
    train_augmentation=False,
    validation_augmentation=False,
    test_augmentation=False,
)

In [None]:
# Run the explainer's surrogate model on the first visible GPU.
device = "cuda:0"
model = explainer.surrogate.to(device)

# Test split, preprocessed for the explainer.
dataset = dataset_explainer["test"]

# label info

In [None]:
# Class-index -> label-name mapping returned by setup_dataset (shown as the cell output).
id2label

# Model inference example

In [None]:
# First example of the preprocessed test split.
sample=dataset[0]

In [None]:
# Fields available on a preprocessed example.
sample.keys()

In [None]:
# Original image of the example (presumably a PIL image rendered inline).
sample["image"]

In [None]:
# Shape of the unbatched model input; a batch dimension is added with unsqueeze(0) before inference.
sample["pixel_values"].shape

In [None]:
# Target class of the example (presumably an index into id2label).
sample["labels"]

pixel_values: (batch_size, channels, height, width)

masks: (batch_size, num_mask_samples, num_players)

## Get the grand-coalition value

In [None]:
# Grand-coalition prediction: every one of the 196 players is kept (mask of ones).
model.eval()
with torch.no_grad():
    batched_pixels = sample["pixel_values"].unsqueeze(0).to(device)
    all_players = torch.ones((1, 1, 196), device=device)
    output = model(pixel_values=batched_pixels, masks=all_players, return_loss=False)

In [None]:
# Raw logits, then the class probabilities for batch 0 / mask sample 0.
print(output.logits)
grand_probs = output.logits.softmax(dim=-1)
print(grand_probs[0, 0].cpu().numpy())

## Get the grand and null coalition values in a single forward pass

In [None]:
# Evaluate two coalitions of the same image in one call:
# mask sample 0 keeps all 196 players (grand), mask sample 1 keeps none (null).
model.eval()
coalitions = torch.ones((1, 2, 196), device=device)
coalitions[:, 1] = 0
with torch.no_grad():
    output = model(
        pixel_values=sample["pixel_values"].unsqueeze(0).to(device),
        masks=coalitions,
        return_loss=False,
    )

In [None]:
# Logits and probabilities for both coalitions (row 0 = grand, row 1 = null).
first_logits = output.logits[0]
print(first_logits.cpu().numpy())
print(first_logits.softmax(-1).cpu().numpy())

# Example `Game` class for the model.

In [None]:
class CooperativeGame:
    '''Base class for cooperative games.

    Subclasses set ``self.players`` (number of players) and implement
    ``__call__``, which maps a (batch, players) array of coalitions to one
    value per row.
    '''

    def __init__(self):
        raise NotImplementedError

    def __call__(self, S):
        '''Evaluate cooperative game.'''
        raise NotImplementedError

    def grand(self):
        '''Get grand coalition value (all players present).'''
        return self.__call__(np.ones((1, self.players), dtype=int))[0]

    def null(self):
        '''Get null coalition value (no players present).'''
        return self.__call__(np.zeros((1, self.players), dtype=int))[0]


class PredictionGame(CooperativeGame):
    '''
    Cooperative game for an individual example's prediction.

    The value of a coalition is the surrogate's softmax class-probability
    vector for ``sample`` when only the coalition's players are kept.

    Args:
      surrogate: model called as ``surrogate(pixel_values=..., masks=...,
        return_loss=False)`` whose output has a ``logits`` attribute.
      sample: single example; ``sample["pixel_values"]`` is an unbatched tensor.
      players: number of players (default 196, the value used in this notebook).
      device: torch device used for inference; ``None`` falls back to the
        notebook-level ``device`` variable.
    '''

    def __init__(self, surrogate, sample, players=196, device=None):
        # The base __init__ is intentionally not called (it raises).
        self.surrogate = surrogate
        self.sample = sample
        self.players = players
        self.device = device

    def __call__(self, S):
        '''
        Evaluate cooperative game.

        Args:
          S: array of player coalitions with size (batch, players).

        Returns:
          numpy array of softmax probabilities, one row per coalition.
        '''
        target = self.device if self.device is not None else device
        pixel_values = self.sample["pixel_values"].unsqueeze(0).to(target)
        # Coalitions become a float mask with a leading batch dim: (1, batch, players).
        masks = torch.as_tensor(np.asarray(S), dtype=torch.float32).unsqueeze(0).to(target)
        with torch.no_grad():
            output = self.surrogate(
                pixel_values=pixel_values,
                masks=masks,
                return_loss=False,
            )
        return output.logits[0].softmax(dim=-1).cpu().numpy()

In [None]:
# Game for the first test example, evaluated with the explainer's surrogate.
first_test_sample = dataset_explainer["test"][0]
game = PredictionGame(surrogate=explainer.surrogate, sample=first_test_sample)

In [None]:
# Shape checks, in order: grand coalition, null coalition, and a batch of 4
# random coalitions.
for evaluate in (
    game.grand,
    game.null,
    lambda: game(np.random.choice([0, 1], size=(4, 196), replace=True)),
):
    print(evaluate().shape)

# SGD Shapley implementation

Todo list:
- (Done) Implementing minibatches
- (Done) Make importance sampling optional
- (Done) Make paired sampling optional
- (Done) Change to function rather than class

In [None]:
import numpy as np
import operator as op
from functools import reduce
from tqdm import tqdm


def ncr(n, r):
    """
    Return the number of subsets of size ``r`` among ``n`` elements, as a float.

    Values of ``r`` outside ``[0, n]`` yield 0.0 (there are no such subsets).
    """
    from math import comb

    if r < 0 or r > n:
        return 0.0
    # math.comb is exact; float() rounds the exact integer once.
    return float(comb(n, r))


def projection_step(phi, total):
    """
    Project ``phi`` onto the constraint ``np.sum(phi, axis=0) == total``.

    The violation of the constraint is subtracted evenly from each of the
    ``len(phi)`` rows (players).
    """
    n_rows = len(phi)
    violation = np.sum(phi, axis=0) - total
    correction = violation / n_rows
    return phi - correction


def SGDShapley(game,
               n_iter=100,
               mbsize=32,
               step=0.001,
               step_type="constant",
               sampling="importance",
               averaging="uniform",
               C=1,
               phi_0=None,
               progress=True):
    """
    Estimate the Shapley values using projected stochastic gradient descent.

    Args:
      game: callable with a ``players`` attribute. Called with a boolean
        (batch, players) coalition array, it must return a numpy array with
        one row per coalition.
      n_iter: number of SGD iterations.
      mbsize: number of coalitions sampled per iteration.
      step: base step size.
      step_type: "constant", "sqrt" (step / sqrt(t + 1)) or "inverse"
        (step / (t + 1)).
      sampling: subset-sampling scheme.
        - "default": coalition size k is drawn with probability
          proportional to 1 / (k (d - k)).
        - "paired": the same, but every odd row is the complement of the
          previous row.
        - "importance": k is drawn with probability proportional to
          C(d, k) * L_k, and gradients are reweighted by w_k / p(S).
      averaging: "none" (raw iterates), "uniform" (running mean of the
        iterates) or "tail" (running mean weighted linearly by iteration).
      C: constant used in the per-size smoothness bounds L_k.
      phi_0: optional initial estimate of shape (players, out_dim). ``None``
        (or the legacy ``False``) starts from zeros.
      progress: wrap the iteration loop in a tqdm progress bar.

    Returns:
      Array of shape (n_iter, players, out_dim) holding the (averaged)
      estimate after every iteration.

    Raises:
      ValueError: unknown option, or a game with fewer than 2 players.
      TypeError: the game does not return a numpy array.
    """
    # Explicit validation (asserts are stripped under `python -O`).
    if sampling not in ("default", "paired", "importance"):
        raise ValueError(f"Unknown sampling: {sampling!r}")
    if step_type not in ("constant", "sqrt", "inverse"):
        raise ValueError(f"Unknown step_type: {step_type!r}")
    if averaging not in ("none", "uniform", "tail"):
        raise ValueError(f"Unknown averaging: {averaging!r}")
    d = game.players
    if d < 2:
        raise ValueError(f"SGDShapley needs at least 2 players, got {d}")

    # Per-size quantities for k = 1, ..., d - 1 (array index k - 1).
    sizes = np.arange(1, d)
    n_subsets = np.array([ncr(d, int(k)) for k in sizes])
    D = C * np.sqrt(d)
    w_k = (d - 1) / (n_subsets * sizes * (d - sizes))       # weight per coalition of size k
    L_k = w_k * np.sqrt(sizes) * (np.sqrt(sizes) * D + C)    # L-smooth constant per size k

    # Summation of L over all coalitions, sum_k C(d, k) * L_k (closed formula).
    sum_L = np.sum((d - 1) / (np.sqrt(sizes) * (d - sizes)) * (np.sqrt(sizes) * D + C))

    # Subset-size distributions.
    # 1. Importance sampling: P(size = k) proportional to C(d, k) * L_k.
    p_importance = L_k * (n_subsets / n_subsets.sum())
    p_importance /= p_importance.sum()
    # 2. Default / paired sampling: P(size = k) proportional to 1 / (k (d - k)).
    p_default = 1 / (sizes * (d - sizes))
    p_default /= p_default.sum()
    size_probs = p_importance if sampling == "importance" else p_default

    # Null/grand values and the efficiency target sum(phi) == grand - null.
    grand = game(np.ones((1, d), dtype=bool))[0]
    null = game(np.zeros((1, d), dtype=bool))[0]
    if not isinstance(grand, np.ndarray):
        raise TypeError("game must return a numpy array with one row per coalition")
    out_dim = len(grand)
    total = grand - null

    # Initial estimate; `phi_0 is False` is accepted for backward compatibility.
    # Identity checks avoid the ambiguous truth value of a numpy array.
    if phi_0 is None or phi_0 is False:
        phi = np.zeros((d, out_dim))
    else:
        phi = np.array(phi_0, dtype=float)
    phi = projection_step(phi, total)

    phi_iterates = np.zeros((n_iter, d, out_dim))
    iterations = tqdm(range(n_iter)) if progress else range(n_iter)
    for t in iterations:
        # Sample coalition sizes, then a uniformly random coalition of each size.
        k_list = np.random.choice(sizes, size=mbsize, p=size_probs)
        indices = [np.random.permutation(d)[:k] for k in k_list]
        x = np.zeros((mbsize, d))
        for i in range(mbsize):
            if sampling == "paired" and i % 2 == 1:
                # Odd rows are the complement of the previous row.
                x[i] = 1 - x[i - 1]
            else:
                x[i, indices[i]] = 1

        # Per-coalition gradient of 0.5 * (x . phi - y)^2 with respect to phi.
        y = game(x.astype(bool)) - null
        residual = x.dot(phi) - y
        grad = x[:, :, None] * residual[:, None, :]
        if sampling == "importance":
            # Reweight each coalition S of size k by w_k / p(S), with p(S) = L_k / sum_L.
            row_sizes = x.sum(axis=1).astype(int)
            coalition_w = w_k[row_sizes - 1]
            coalition_p = L_k[row_sizes - 1] / sum_L
            grad *= np.expand_dims(coalition_w / coalition_p, (1, 2))
        grad = np.mean(grad, axis=0)

        if step_type == "constant":
            lr = step
        elif step_type == "sqrt":
            lr = step / np.sqrt(t + 1)
        else:  # "inverse"
            lr = step / (t + 1)

        # Gradient step followed by projection onto the efficiency constraint.
        phi = projection_step(phi - lr * grad, total)
        phi_iterates[t] = phi

    if averaging == "none":
        return phi_iterates
    counts = np.arange(1, n_iter + 1)[:, None, None]
    if averaging == "uniform":
        return np.cumsum(phi_iterates, axis=0) / counts
    # "tail": average of iterates weighted by their (1-based) iteration index.
    return np.cumsum(2 * phi_iterates * counts, axis=0) / (counts * (counts + 1))

# SGD Shapley ground truth tests

@chanwoo

In [None]:
from utils import load_shapley
import matplotlib.pyplot as plt

In [None]:
# Precomputed ground-truth Shapley values for the test split; the SGD
# estimates below are compared against them.
shapley_test_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_test/extract_output/test/"
shapley_loaded_test = load_shapley(shapley_test_dir)

In [None]:
# Setup: run the SGD estimator on the first n_examples test images
# with a shared configuration.
n_examples = 10
sgd_results = {}
sgd_config = dict(
    mbsize=2,
    n_iter=5000,
    step=0.0005,
    sampling="paired",
    step_type="constant",
    averaging="uniform",
)

for i in range(n_examples):
    sample = dataset_explainer["test"][i]
    game = PredictionGame(surrogate=explainer.surrogate, sample=sample)
    phi = SGDShapley(game, **sgd_config)
    sgd_results[i] = {"estimates": phi, "label": sample["labels"]}

In [None]:
# Individual curves: per-example L2 distance between the running SGD estimate
# and the last stored ground-truth values, restricted to the target class.
dist_list = []

plt.figure()
for i in range(n_examples):
    label = sgd_results[i]["label"]
    estimates = sgd_results[i]["estimates"][:, :, label]
    ground_truth = shapley_loaded_test[i]["values"][-1][:, label]

    squared_error = (estimates - ground_truth) ** 2
    dist = np.sqrt(squared_error.sum(axis=1))
    dist_list.append(dist)
    plt.plot(np.arange(dist.shape[0]), dist, color="C0")

plt.title("SGD Shapley Convergence")
plt.xlabel("# Steps")
plt.ylabel("L2 Distance")
plt.show()

In [None]:
# Averaged curve: mean L2 distance across the examples at every step.
plt.figure()
mean_dist = np.stack(dist_list).mean(axis=0)
plt.plot(np.arange(mean_dist.shape[0]), mean_dist)
plt.title("SGD Shapley Convergence")
plt.xlabel("# Steps")
plt.ylabel("L2 Distance")
plt.show()

# Parameter tuning

- Inverse step schedule leads to very slow progress, constant often works better
- Step size = 0.001 worked best for one sample, but leads to divergence for others
- Tail averaging works better when objective improves monotonically, uniform is better in noisy cases
- Default subset distribution seems to work better than importance sampling. Setting mbsize = 2 with paired sampling works even better

In [None]:
# Setup
n_examples = 1
sgd_results = {}

for i in range(n_examples):
    # Cooperative game for this test example.
    sample = dataset_explainer["test"][i]
    game = PredictionGame(surrogate=explainer.surrogate, sample=sample)

    for step_type in ("constant", "sqrt", "inverse"):
        # Keep every raw iterate (averaging="none") so that all averaging
        # schemes can be compared on the same SGD run.
        phi = SGDShapley(
            game,
            mbsize=1,
            n_iter=10000,
            step=0.001,
            sampling="importance",
            step_type=step_type,
            averaging="none",
        )
        sgd_results[i] = {"estimates": phi, "label": sample["labels"]}

        ### Plot results ###
        # Restrict estimates and ground truth to the target class.
        label = sgd_results[i]["label"]
        estimates = sgd_results[i]["estimates"][:, :, label]
        ground_truth = shapley_loaded_test[i]["values"][-1][:, label]

        # Step index t = 1..T as a column so it broadcasts over players.
        steps = np.arange(1, len(estimates) + 1)
        t = steps[:, None]
        curves = {
            "No Averaging": estimates,
            "Uniform Average": np.cumsum(estimates, axis=0) / t,
            "Tail Average": np.cumsum(2 * estimates * t, axis=0) / (t * (t + 1)),
        }

        plt.figure()
        for curve_name, curve in curves.items():
            dist = np.sqrt(np.sum((curve - ground_truth) ** 2, axis=1))
            plt.plot(steps, dist, label=curve_name)

        plt.title(f'SGD Shapley Convergence (index={i} step_type={step_type})')
        plt.ylabel('L2 Distance to Ground Truth')
        plt.xlabel('# Steps')
        plt.legend()
        plt.show()

In [None]:
# Setup
n_examples = 1
sgd_results = {}

for i in range(n_examples):
    # Cooperative game for this test example.
    sample = dataset_explainer["test"][i]
    game = PredictionGame(surrogate=explainer.surrogate, sample=sample)

    for step_type in ("constant", "sqrt", "inverse"):
        # Keep every raw iterate (averaging="none") so that all averaging
        # schemes can be compared on the same SGD run.
        phi = SGDShapley(
            game,
            mbsize=1,
            n_iter=10000,
            step=0.0003,
            sampling="importance",
            step_type=step_type,
            averaging="none",
        )
        sgd_results[i] = {"estimates": phi, "label": sample["labels"]}

        ### Plot results ###
        # Restrict estimates and ground truth to the target class.
        label = sgd_results[i]["label"]
        estimates = sgd_results[i]["estimates"][:, :, label]
        ground_truth = shapley_loaded_test[i]["values"][-1][:, label]

        # Step index t = 1..T as a column so it broadcasts over players.
        steps = np.arange(1, len(estimates) + 1)
        t = steps[:, None]
        curves = {
            "No Averaging": estimates,
            "Uniform Average": np.cumsum(estimates, axis=0) / t,
            "Tail Average": np.cumsum(2 * estimates * t, axis=0) / (t * (t + 1)),
        }

        plt.figure()
        for curve_name, curve in curves.items():
            dist = np.sqrt(np.sum((curve - ground_truth) ** 2, axis=1))
            plt.plot(steps, dist, label=curve_name)

        plt.title(f'SGD Shapley Convergence (index={i} step_type={step_type})')
        plt.ylabel('L2 Distance to Ground Truth')
        plt.xlabel('# Steps')
        plt.legend()
        plt.show()

In [None]:
# Setup
n_examples = 1
sgd_results = {}

for i in range(n_examples):
    # Cooperative game for this test example.
    sample = dataset_explainer["test"][i]
    game = PredictionGame(surrogate=explainer.surrogate, sample=sample)

    for step_type in ("constant", "sqrt", "inverse"):
        # Keep every raw iterate (averaging="none") so that all averaging
        # schemes can be compared on the same SGD run.
        phi = SGDShapley(
            game,
            mbsize=1,
            n_iter=10000,
            step=0.0003,
            sampling="default",
            step_type=step_type,
            averaging="none",
        )
        sgd_results[i] = {"estimates": phi, "label": sample["labels"]}

        ### Plot results ###
        # Restrict estimates and ground truth to the target class.
        label = sgd_results[i]["label"]
        estimates = sgd_results[i]["estimates"][:, :, label]
        ground_truth = shapley_loaded_test[i]["values"][-1][:, label]

        # Step index t = 1..T as a column so it broadcasts over players.
        steps = np.arange(1, len(estimates) + 1)
        t = steps[:, None]
        curves = {
            "No Averaging": estimates,
            "Uniform Average": np.cumsum(estimates, axis=0) / t,
            "Tail Average": np.cumsum(2 * estimates * t, axis=0) / (t * (t + 1)),
        }

        plt.figure()
        for curve_name, curve in curves.items():
            dist = np.sqrt(np.sum((curve - ground_truth) ** 2, axis=1))
            plt.plot(steps, dist, label=curve_name)

        plt.title(f'SGD Shapley Convergence (index={i} step_type={step_type})')
        plt.ylabel('L2 Distance to Ground Truth')
        plt.xlabel('# Steps')
        plt.legend()
        plt.show()

In [None]:
# Setup
n_examples = 1
sgd_results = {}

for i in range(n_examples):
    # Cooperative game for this test example.
    sample = dataset_explainer["test"][i]
    game = PredictionGame(surrogate=explainer.surrogate, sample=sample)

    for step_type in ("constant", "sqrt"):
        # Keep every raw iterate (averaging="none") so that all averaging
        # schemes can be compared on the same SGD run.
        phi = SGDShapley(
            game,
            mbsize=1,
            n_iter=10000,
            step=0.0001,
            sampling="default",
            step_type=step_type,
            averaging="none",
        )
        sgd_results[i] = {"estimates": phi, "label": sample["labels"]}

        ### Plot results ###
        # Restrict estimates and ground truth to the target class.
        label = sgd_results[i]["label"]
        estimates = sgd_results[i]["estimates"][:, :, label]
        ground_truth = shapley_loaded_test[i]["values"][-1][:, label]

        # Step index t = 1..T as a column so it broadcasts over players.
        steps = np.arange(1, len(estimates) + 1)
        t = steps[:, None]
        curves = {
            "No Averaging": estimates,
            "Uniform Average": np.cumsum(estimates, axis=0) / t,
            "Tail Average": np.cumsum(2 * estimates * t, axis=0) / (t * (t + 1)),
        }

        plt.figure()
        for curve_name, curve in curves.items():
            dist = np.sqrt(np.sum((curve - ground_truth) ** 2, axis=1))
            plt.plot(steps, dist, label=curve_name)

        plt.title(f'SGD Shapley Convergence (index={i} step_type={step_type})')
        plt.ylabel('L2 Distance to Ground Truth')
        plt.xlabel('# Steps')
        plt.legend()
        plt.show()

In [None]:
# Setup
n_examples = 1
sgd_results = {}

for i in range(n_examples):
    # Cooperative game for this test example.
    sample = dataset_explainer["test"][i]
    game = PredictionGame(surrogate=explainer.surrogate, sample=sample)

    for step_type in ("constant", "sqrt"):
        # Keep every raw iterate (averaging="none") so that all averaging
        # schemes can be compared on the same SGD run.
        phi = SGDShapley(
            game,
            mbsize=1,
            n_iter=10000,
            step=0.001,
            sampling="default",
            step_type=step_type,
            averaging="none",
        )
        sgd_results[i] = {"estimates": phi, "label": sample["labels"]}

        ### Plot results ###
        # Restrict estimates and ground truth to the target class.
        label = sgd_results[i]["label"]
        estimates = sgd_results[i]["estimates"][:, :, label]
        ground_truth = shapley_loaded_test[i]["values"][-1][:, label]

        # Step index t = 1..T as a column so it broadcasts over players.
        steps = np.arange(1, len(estimates) + 1)
        t = steps[:, None]
        curves = {
            "No Averaging": estimates,
            "Uniform Average": np.cumsum(estimates, axis=0) / t,
            "Tail Average": np.cumsum(2 * estimates * t, axis=0) / (t * (t + 1)),
        }

        plt.figure()
        for curve_name, curve in curves.items():
            dist = np.sqrt(np.sum((curve - ground_truth) ** 2, axis=1))
            plt.plot(steps, dist, label=curve_name)

        plt.title(f'SGD Shapley Convergence (index={i} step_type={step_type})')
        plt.ylabel('L2 Distance to Ground Truth')
        plt.xlabel('# Steps')
        plt.legend()
        plt.show()

In [None]:
# Setup
n_examples = 1
sgd_results = {}

for i in range(n_examples):
    # Cooperative game for this test example.
    sample = dataset_explainer["test"][i]
    game = PredictionGame(surrogate=explainer.surrogate, sample=sample)

    for step_type in ("constant", "sqrt"):
        # Keep every raw iterate (averaging="none") so that all averaging
        # schemes can be compared on the same SGD run.
        phi = SGDShapley(
            game,
            mbsize=2,
            n_iter=5000,
            step=0.0003,
            sampling="paired",
            step_type=step_type,
            averaging="none",
        )
        sgd_results[i] = {"estimates": phi, "label": sample["labels"]}

        ### Plot results ###
        # Restrict estimates and ground truth to the target class.
        label = sgd_results[i]["label"]
        estimates = sgd_results[i]["estimates"][:, :, label]
        ground_truth = shapley_loaded_test[i]["values"][-1][:, label]

        # Step index t = 1..T as a column so it broadcasts over players.
        steps = np.arange(1, len(estimates) + 1)
        t = steps[:, None]
        curves = {
            "No Averaging": estimates,
            "Uniform Average": np.cumsum(estimates, axis=0) / t,
            "Tail Average": np.cumsum(2 * estimates * t, axis=0) / (t * (t + 1)),
        }

        plt.figure()
        for curve_name, curve in curves.items():
            dist = np.sqrt(np.sum((curve - ground_truth) ** 2, axis=1))
            plt.plot(steps, dist, label=curve_name)

        plt.title(f'SGD Shapley Convergence (index={i} step_type={step_type})')
        plt.ylabel('L2 Distance to Ground Truth')
        plt.xlabel('# Steps')
        plt.legend()
        plt.show()

In [None]:
# Setup: step-size sweep for SGDShapley (constant schedule, step=0.0002).
n_examples = 1
step_type = "constant"
step_size = 0.0002
sgd_results = {}


def _plot_sgd_convergence(estimates, ground_truth, title):
    """Plot the L2 distance from SGD iterates to a reference, with and without averaging.

    Draws three curves on a new figure: the raw iterates, their running
    uniform average, and a tail average whose weights 2t / (T(T+1)) grow
    linearly with the step index t (they sum to one over t = 1..T).

    Args:
        estimates: Target-class estimates, one row per SGD step (rows are
            plotted against "# Steps").
        ground_truth: Reference values, broadcast against each row.
        title: Figure title.
    """
    steps = np.arange(1, len(estimates) + 1)
    t = steps[:, None]  # column vector so it broadcasts across each row
    uniform = np.cumsum(estimates, axis=0) / t
    tail = np.cumsum(2 * estimates * t, axis=0) / (t * (t + 1))

    plt.figure()
    for series, name in (
        (estimates, "No Averaging"),
        (uniform, "Uniform Average"),
        (tail, "Tail Average"),
    ):
        plt.plot(steps, np.linalg.norm(series - ground_truth, axis=1), label=name)
    plt.title(title)
    plt.ylabel('L2 Distance to Ground Truth')
    plt.xlabel('# Steps')
    plt.legend()
    plt.show()


for i in range(n_examples):
    # Set up game
    sample = dataset_explainer["test"][i]
    game = PredictionGame(
        surrogate=explainer.surrogate,
        sample=sample
    )

    # Run SGD estimator
    phi = SGDShapley(
        game,
        mbsize=2,
        n_iter=5000,
        step=step_size,
        sampling="paired",
        step_type=step_type,
        averaging="none",
    )

    # Store values
    sgd_results[i] = {
        "estimates": phi,
        "label": sample["labels"]
    }

    ### Plot results ###
    # phi is sliced as (step, player, class) here. That layout is inferred
    # from this indexing only; TODO confirm against SGDShapley.
    label = sgd_results[i]["label"]
    estimates = sgd_results[i]["estimates"][:, :, label]
    # Reference: last entry of the stored "values" for this test example.
    ground_truth = shapley_loaded_test[i]["values"][-1][:, label]

    # The step size goes in the title so the sweep's plots can be told apart.
    _plot_sgd_convergence(
        estimates,
        ground_truth,
        f'SGD Shapley Convergence (index={i} step_type={step_type} step={step_size})',
    )

In [None]:
# Setup: step-size sweep for SGDShapley (constant schedule, step=0.0005).
n_examples = 1
step_type = "constant"
step_size = 0.0005
sgd_results = {}


def _plot_sgd_convergence(estimates, ground_truth, title):
    """Plot the L2 distance from SGD iterates to a reference, with and without averaging.

    Draws three curves on a new figure: the raw iterates, their running
    uniform average, and a tail average whose weights 2t / (T(T+1)) grow
    linearly with the step index t (they sum to one over t = 1..T).

    Args:
        estimates: Target-class estimates, one row per SGD step (rows are
            plotted against "# Steps").
        ground_truth: Reference values, broadcast against each row.
        title: Figure title.
    """
    steps = np.arange(1, len(estimates) + 1)
    t = steps[:, None]  # column vector so it broadcasts across each row
    uniform = np.cumsum(estimates, axis=0) / t
    tail = np.cumsum(2 * estimates * t, axis=0) / (t * (t + 1))

    plt.figure()
    for series, name in (
        (estimates, "No Averaging"),
        (uniform, "Uniform Average"),
        (tail, "Tail Average"),
    ):
        plt.plot(steps, np.linalg.norm(series - ground_truth, axis=1), label=name)
    plt.title(title)
    plt.ylabel('L2 Distance to Ground Truth')
    plt.xlabel('# Steps')
    plt.legend()
    plt.show()


for i in range(n_examples):
    # Set up game
    sample = dataset_explainer["test"][i]
    game = PredictionGame(
        surrogate=explainer.surrogate,
        sample=sample
    )

    # Run SGD estimator
    phi = SGDShapley(
        game,
        mbsize=2,
        n_iter=5000,
        step=step_size,
        sampling="paired",
        step_type=step_type,
        averaging="none",
    )

    # Store values
    sgd_results[i] = {
        "estimates": phi,
        "label": sample["labels"]
    }

    ### Plot results ###
    # phi is sliced as (step, player, class) here. That layout is inferred
    # from this indexing only; TODO confirm against SGDShapley.
    label = sgd_results[i]["label"]
    estimates = sgd_results[i]["estimates"][:, :, label]
    # Reference: last entry of the stored "values" for this test example.
    ground_truth = shapley_loaded_test[i]["values"][-1][:, label]

    # The step size goes in the title so the sweep's plots can be told apart.
    _plot_sgd_convergence(
        estimates,
        ground_truth,
        f'SGD Shapley Convergence (index={i} step_type={step_type} step={step_size})',
    )

In [None]:
# Setup: step-size sweep for SGDShapley (constant schedule, step=0.0007).
n_examples = 1
step_type = "constant"
step_size = 0.0007
sgd_results = {}


def _plot_sgd_convergence(estimates, ground_truth, title):
    """Plot the L2 distance from SGD iterates to a reference, with and without averaging.

    Draws three curves on a new figure: the raw iterates, their running
    uniform average, and a tail average whose weights 2t / (T(T+1)) grow
    linearly with the step index t (they sum to one over t = 1..T).

    Args:
        estimates: Target-class estimates, one row per SGD step (rows are
            plotted against "# Steps").
        ground_truth: Reference values, broadcast against each row.
        title: Figure title.
    """
    steps = np.arange(1, len(estimates) + 1)
    t = steps[:, None]  # column vector so it broadcasts across each row
    uniform = np.cumsum(estimates, axis=0) / t
    tail = np.cumsum(2 * estimates * t, axis=0) / (t * (t + 1))

    plt.figure()
    for series, name in (
        (estimates, "No Averaging"),
        (uniform, "Uniform Average"),
        (tail, "Tail Average"),
    ):
        plt.plot(steps, np.linalg.norm(series - ground_truth, axis=1), label=name)
    plt.title(title)
    plt.ylabel('L2 Distance to Ground Truth')
    plt.xlabel('# Steps')
    plt.legend()
    plt.show()


for i in range(n_examples):
    # Set up game
    sample = dataset_explainer["test"][i]
    game = PredictionGame(
        surrogate=explainer.surrogate,
        sample=sample
    )

    # Run SGD estimator
    phi = SGDShapley(
        game,
        mbsize=2,
        n_iter=5000,
        step=step_size,
        sampling="paired",
        step_type=step_type,
        averaging="none",
    )

    # Store values
    sgd_results[i] = {
        "estimates": phi,
        "label": sample["labels"]
    }

    ### Plot results ###
    # phi is sliced as (step, player, class) here. That layout is inferred
    # from this indexing only; TODO confirm against SGDShapley.
    label = sgd_results[i]["label"]
    estimates = sgd_results[i]["estimates"][:, :, label]
    # Reference: last entry of the stored "values" for this test example.
    ground_truth = shapley_loaded_test[i]["values"][-1][:, label]

    # The step size goes in the title so the sweep's plots can be told apart.
    _plot_sgd_convergence(
        estimates,
        ground_truth,
        f'SGD Shapley Convergence (index={i} step_type={step_type} step={step_size})',
    )

In [None]:
# Setup: step-size sweep for SGDShapley (constant schedule, step=0.001).
n_examples = 1
step_type = "constant"
step_size = 0.001
sgd_results = {}


def _plot_sgd_convergence(estimates, ground_truth, title):
    """Plot the L2 distance from SGD iterates to a reference, with and without averaging.

    Draws three curves on a new figure: the raw iterates, their running
    uniform average, and a tail average whose weights 2t / (T(T+1)) grow
    linearly with the step index t (they sum to one over t = 1..T).

    Args:
        estimates: Target-class estimates, one row per SGD step (rows are
            plotted against "# Steps").
        ground_truth: Reference values, broadcast against each row.
        title: Figure title.
    """
    steps = np.arange(1, len(estimates) + 1)
    t = steps[:, None]  # column vector so it broadcasts across each row
    uniform = np.cumsum(estimates, axis=0) / t
    tail = np.cumsum(2 * estimates * t, axis=0) / (t * (t + 1))

    plt.figure()
    for series, name in (
        (estimates, "No Averaging"),
        (uniform, "Uniform Average"),
        (tail, "Tail Average"),
    ):
        plt.plot(steps, np.linalg.norm(series - ground_truth, axis=1), label=name)
    plt.title(title)
    plt.ylabel('L2 Distance to Ground Truth')
    plt.xlabel('# Steps')
    plt.legend()
    plt.show()


for i in range(n_examples):
    # Set up game
    sample = dataset_explainer["test"][i]
    game = PredictionGame(
        surrogate=explainer.surrogate,
        sample=sample
    )

    # Run SGD estimator
    phi = SGDShapley(
        game,
        mbsize=2,
        n_iter=5000,
        step=step_size,
        sampling="paired",
        step_type=step_type,
        averaging="none",
    )

    # Store values
    sgd_results[i] = {
        "estimates": phi,
        "label": sample["labels"]
    }

    ### Plot results ###
    # phi is sliced as (step, player, class) here. That layout is inferred
    # from this indexing only; TODO confirm against SGDShapley.
    label = sgd_results[i]["label"]
    estimates = sgd_results[i]["estimates"][:, :, label]
    # Reference: last entry of the stored "values" for this test example.
    ground_truth = shapley_loaded_test[i]["values"][-1][:, label]

    # The step size goes in the title so the sweep's plots can be told apart.
    _plot_sgd_convergence(
        estimates,
        ground_truth,
        f'SGD Shapley Convergence (index={i} step_type={step_type} step={step_size})',
    )

In [None]:
# Setup: step-size sweep for SGDShapley (constant schedule, step=0.0015).
n_examples = 1
step_type = "constant"
step_size = 0.0015
sgd_results = {}


def _plot_sgd_convergence(estimates, ground_truth, title):
    """Plot the L2 distance from SGD iterates to a reference, with and without averaging.

    Draws three curves on a new figure: the raw iterates, their running
    uniform average, and a tail average whose weights 2t / (T(T+1)) grow
    linearly with the step index t (they sum to one over t = 1..T).

    Args:
        estimates: Target-class estimates, one row per SGD step (rows are
            plotted against "# Steps").
        ground_truth: Reference values, broadcast against each row.
        title: Figure title.
    """
    steps = np.arange(1, len(estimates) + 1)
    t = steps[:, None]  # column vector so it broadcasts across each row
    uniform = np.cumsum(estimates, axis=0) / t
    tail = np.cumsum(2 * estimates * t, axis=0) / (t * (t + 1))

    plt.figure()
    for series, name in (
        (estimates, "No Averaging"),
        (uniform, "Uniform Average"),
        (tail, "Tail Average"),
    ):
        plt.plot(steps, np.linalg.norm(series - ground_truth, axis=1), label=name)
    plt.title(title)
    plt.ylabel('L2 Distance to Ground Truth')
    plt.xlabel('# Steps')
    plt.legend()
    plt.show()


for i in range(n_examples):
    # Set up game
    sample = dataset_explainer["test"][i]
    game = PredictionGame(
        surrogate=explainer.surrogate,
        sample=sample
    )

    # Run SGD estimator
    phi = SGDShapley(
        game,
        mbsize=2,
        n_iter=5000,
        step=step_size,
        sampling="paired",
        step_type=step_type,
        averaging="none",
    )

    # Store values
    sgd_results[i] = {
        "estimates": phi,
        "label": sample["labels"]
    }

    ### Plot results ###
    # phi is sliced as (step, player, class) here. That layout is inferred
    # from this indexing only; TODO confirm against SGDShapley.
    label = sgd_results[i]["label"]
    estimates = sgd_results[i]["estimates"][:, :, label]
    # Reference: last entry of the stored "values" for this test example.
    ground_truth = shapley_loaded_test[i]["values"][-1][:, label]

    # The step size goes in the title so the sweep's plots can be told apart.
    _plot_sgd_convergence(
        estimates,
        ground_truth,
        f'SGD Shapley Convergence (index={i} step_type={step_type} step={step_size})',
    )

In [None]:
# Run SGDShapley (constant schedule, step 0.001) on the first `n_examples`
# test images and plot how each run approaches the stored reference values.
# `n_examples` and `sgd_results` are read by the sanity-check cell below.
n_examples = 10
sgd_results = {}

for i in range(n_examples):
    sample = dataset_explainer["test"][i]
    game = PredictionGame(surrogate=explainer.surrogate, sample=sample)

    phi = SGDShapley(
        game,
        mbsize=2,
        n_iter=5000,
        step=0.001,
        sampling="paired",
        step_type="constant",
        averaging="none",
    )
    label = sample["labels"]
    sgd_results[i] = {"estimates": phi, "label": label}

    # Target class only. phi is sliced as (step, player, class); that layout
    # is inferred from this indexing. TODO confirm against SGDShapley.
    estimates = phi[:, :, label]
    # Reference: last entry of the stored "values" for this test example.
    ground_truth = shapley_loaded_test[i]["values"][-1][:, label]

    step_index = 1 + np.arange(len(estimates))
    weight = step_index[:, np.newaxis]
    # Raw iterates, running mean, and tail average (weights 2t / (T(T+1))).
    curves = {
        "No Averaging": estimates,
        "Uniform Average": np.cumsum(estimates, axis=0) / weight,
        "Tail Average": np.cumsum(2 * estimates * weight, axis=0) / (weight * (weight + 1)),
    }

    plt.figure()
    for curve_name, iterates in curves.items():
        plt.plot(step_index, np.linalg.norm(iterates - ground_truth, axis=1), label=curve_name)
    plt.title(f'SGD Shapley Convergence (index={i})')
    plt.ylabel('L2 Distance to Ground Truth')
    plt.xlabel('# Steps')
    plt.legend()
    plt.show()

In [None]:
# Visual sanity checks: for each example from the previous cell, show the
# test image next to the final SGD Shapley estimates for its label, laid out
# on a square patch grid. Relies on `n_examples` and `sgd_results` from above.
for i in range(n_examples):
    _fig, axarr = plt.subplots(1, 2)

    axarr[0].imshow(dataset_explainer["test"][i]["image"])
    axarr[0].set_xticks([])
    axarr[0].set_yticks([])

    # Last row along axis 0 (the final SGD step) for the target class.
    estimates = sgd_results[i]["estimates"][-1, :, sgd_results[i]["label"]]

    # Infer the square patch grid instead of hard-coding 14x14 (196 values),
    # so the cell also works for other patch counts.
    n_patches = len(estimates)
    grid_side = int(round(np.sqrt(n_patches)))
    if grid_side * grid_side != n_patches:
        raise ValueError(
            f"Expected a square number of patch values, got {n_patches}"
        )

    # Symmetric color limits put zero at the midpoint of the diverging colormap.
    max_abs = np.absolute(estimates).max()
    axarr[1].imshow(
        estimates.reshape(grid_side, grid_side),
        vmin=-max_abs,
        vmax=max_abs,
        cmap="seismic",
    )
    axarr[1].set_xticks([])
    axarr[1].set_yticks([])

    plt.tight_layout()
    plt.show()