In [None]:
import os
import sys

# Move the working directory one level up so that the relative paths used
# below ("configs/...", "logs/...") resolve from the parent directory.
os.chdir('../')

In [None]:
# Emulate a command-line invocation: the argument-parsing code below reads
# sys.argv and treats a single ".json" argument as a configuration file.
sys.argv=["train_surrogate.py", "configs/vitbase_imagenette_surrogate_eval.json"]

In [None]:
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

import evaluate
import ipdb
import numpy as np
import torch
import tqdm
import transformers
from datasets import load_dataset
from PIL import Image
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch.nn import functional as F
from transformers import (
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    AutoConfig,
    AutoImageProcessor,
    AutoModelForImageClassification,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

from models import (
    SurrogateForImageClassification,
    SurrogateForImageClassificationConfig,
)
from utils import generate_mask, get_image_transform

""" Fine-tuning a 🤗 Transformers model for image classification"""

# Module-level logger; its level is configured further down in the script.
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.32.0.dev0")

# Same kind of guard for the `datasets` package; the second string is the
# remediation hint passed along with the requirement.
require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/image-classification/requirements.txt",
)

# Model types taken from the image-classification auto-mapping; MODEL_TYPES is
# interpolated into the help text of the `*_model_type` arguments below.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
# Debug output: show the available configuration classes and model types.
print(MODEL_CONFIG_CLASSES)
print(MODEL_TYPES)


def pil_loader(path: str):
    """Open the image file at ``path`` and return it converted to RGB mode."""
    with open(path, "rb") as handle:
        return Image.open(handle).convert("RGB")


@dataclass
class DataTrainingArguments:
    """
    Options describing the data used for training and evaluation.

    The class is handed to `HfArgumentParser`, which exposes every field as a
    command-line / JSON argument. Either `dataset_name` or at least one of
    `train_dir` / `validation_dir` has to be provided.
    """

    # --- Hub dataset -------------------------------------------------------
    dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Name of a dataset from the hub (could be your own, possibly private dataset hosted on the hub)."
        },
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the dataset to use (via the datasets library)."
        },
    )
    dataset_cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where to store the downloaded dataset."}
    )

    # --- Local image folders -----------------------------------------------
    train_dir: Optional[str] = field(
        default=None, metadata={"help": "A folder containing the training data."}
    )
    validation_dir: Optional[str] = field(
        default=None, metadata={"help": "A folder containing the validation data."}
    )
    test_dir: Optional[str] = field(
        default=None, metadata={"help": "A folder containing the test data."}
    )

    # --- Split fractions used when a split is missing ----------------------
    train_validation_split: Optional[float] = field(
        default=0.15, metadata={"help": "Percent to split off of train for validation."}
    )
    validation_test_split: Optional[float] = field(
        default=0.5, metadata={"help": "Percent to split off of val for test."}
    )

    # --- Optional per-split sample caps ------------------------------------
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this value if set."
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of test examples to this value if set."
        },
    )

    def __post_init__(self):
        """Reject configurations that name no data source at all."""
        has_local_dir = self.train_dir is not None or self.validation_dir is not None
        if self.dataset_name is not None or has_local_dir:
            return
        raise ValueError(
            "You must specify either a dataset name from the hub or a train and/or validation directory."
        )


@dataclass
class OtherArguments:
    """
    Miscellaneous options: output extraction, mask sampling and hub authentication.
    """

    # Read elsewhere in the project; the help text is the only description visible here.
    extract_output: Optional[str] = field(
        default=None,
        metadata={
            "help": "Extract output from the model. If None, will not extract output with N masks."
        },
    )

    # Forwarded to `generate_mask` as `num_mask_samples` by `transform_mask`.
    num_mask_samples: Optional[int] = field(
        default=1, metadata={"help": "Number of masks to use for extracting output."}
    )

    # Deprecated alias of `token`; the script copies it into `token` when set.
    # The annotations below are intentionally left as-is because the argument
    # parser inspects them.
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
        },
    )

    token: str = field(
        default=None,
        metadata={
            "help": "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
        },
    )


@dataclass
class ClassifierArguments:
    """
    Options selecting the classifier model, its config and its image processor.

    Every field is prefixed with `classifier_` so it can be parsed side by side
    with the surrogate options.
    """

    classifier_model_name_or_path: str = field(
        default="google/vit-base-patch16-224-in21k",
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        },
    )
    classifier_model_type: Optional[str] = field(
        default=None,
        metadata={
            "help": f"If training from scratch, pass a model type from the list: {', '.join(MODEL_TYPES)}"
        },
    )
    classifier_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    classifier_cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from s3"
        },
    )
    classifier_model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
        },
    )
    classifier_image_processor_name: str = field(
        default=None, metadata={"help": "Name or path of preprocessor config."}
    )
    classifier_ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={
            "help": "Will enable to load a pretrained model whose head dimensions are different."
        },
    )


@dataclass
class SurrogateArguments:
    """
    Options selecting the surrogate model, its config and its image processor.

    Every field is prefixed with `surrogate_` so it can be parsed side by side
    with the classifier options.
    """

    surrogate_model_name_or_path: str = field(
        default="google/vit-base-patch16-224-in21k",
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        },
    )
    surrogate_model_type: Optional[str] = field(
        default=None,
        metadata={
            "help": f"If training from scratch, pass a model type from the list: {', '.join(MODEL_TYPES)}"
        },
    )
    surrogate_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    surrogate_cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from s3"
        },
    )
    surrogate_model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
        },
    )
    surrogate_image_processor_name: str = field(
        default=None, metadata={"help": "Name or path of preprocessor config."}
    )
    surrogate_ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={
            "help": "Will enable to load a pretrained model whose head dimensions are different."
        },
    )

In [None]:
########################################################
# Parse arguments
########################################################
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.

# One parser for all five dataclasses; both parse calls below return one
# instance per dataclass, in the order the classes are listed here.
parser = HfArgumentParser(
    (
        ClassifierArguments,
        SurrogateArguments,
        DataTrainingArguments,
        TrainingArguments,
        OtherArguments,
    )
)
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    # If we pass only one argument to the script and it's the path to a json file,
    # let's parse it to get our arguments.
    (
        classifier_args,
        surrogate_args,
        data_args,
        training_args,
        other_args,
    ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
    # Otherwise parse regular command-line flags.
    (
        classifier_args,
        surrogate_args,
        data_args,
        training_args,
        other_args,
    ) = parser.parse_args_into_dataclasses()

# Backward compatibility: accept the deprecated `use_auth_token`, warn about
# it, refuse the ambiguous case where both are given, then fold it into `token`.
if other_args.use_auth_token is not None:
    warnings.warn(
        "The `use_auth_token` argument is deprecated and will be removed in v4.34.",
        FutureWarning,
    )
    if other_args.token is not None:
        raise ValueError(
            "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
        )
    other_args.token = other_args.use_auth_token

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
# Only the surrogate and data arguments are passed to the telemetry call.
send_example_telemetry("run_image_classification", surrogate_args, data_args)

########################################################
# Setup logging
########################################################
# Route log records to stdout with a timestamped format.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

if training_args.should_log:
    # The default of training_args.log_level is passive, so we set log level at info here to have that default.
    transformers.utils.logging.set_verbosity_info()

# Apply the per-process log level to this module's logger and to transformers.
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log on each process the small summary. The first fragment ends with ", " so
# that the n_gpu value and "distributed training" do not run together.
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

########################################################
# Set seed before initializing model.
########################################################
set_seed(training_args.seed)

########################################################
# Initialize our dataset and prepare it for the 'image-classification' task.
########################################################
if data_args.dataset_name is not None:
    if data_args.dataset_name == "frgfm/imagenette":
        # Special case: this dataset is loaded without a task template, so its
        # "label" column is renamed by hand to the "labels" name used below.
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=data_args.dataset_cache_dir,
            task=None,
            token=other_args.token,
        )

        for split in dataset.keys():
            if "label" in dataset[split].features:
                dataset[split] = dataset[split].rename_column("label", "labels")

    else:
        # Any other hub dataset is loaded with the "image-classification" task.
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=data_args.dataset_cache_dir,
            task="image-classification",
            token=other_args.token,
        )
else:
    # Local folders: build one glob pattern per split that was provided.
    data_files = {}
    if data_args.train_dir is not None:
        data_files["train"] = os.path.join(data_args.train_dir, "**")
    if data_args.validation_dir is not None:
        data_files["validation"] = os.path.join(data_args.validation_dir, "**")
    if data_args.test_dir is not None:
        data_files["test"] = os.path.join(data_args.test_dir, "**")
    # NOTE(review): this branch caches under the surrogate model cache dir,
    # whereas the hub branches use data_args.dataset_cache_dir -- confirm
    # whether that is intended.
    dataset = load_dataset(
        "imagefolder",
        data_files=data_files,
        cache_dir=surrogate_args.surrogate_cache_dir,
        task="image-classification",
    )

# If we don't have a validation split, split off a percentage of train as validation.
# The setting is overwritten with None when a validation split already exists.
data_args.train_validation_split = (
    None if "validation" in dataset.keys() else data_args.train_validation_split
)
if (
    isinstance(data_args.train_validation_split, float)
    and data_args.train_validation_split > 0.0
):
    split = dataset["train"].train_test_split(data_args.train_validation_split)
    dataset["train"] = split["train"]
    dataset["validation"] = split["test"]

# Same approach for the test split: when none exists, carve it out of validation.
data_args.validation_test_split = (
    None if "test" in dataset.keys() else data_args.validation_test_split
)

if (
    isinstance(data_args.validation_test_split, float)
    and data_args.validation_test_split > 0.0
):
    split = dataset["validation"].train_test_split(data_args.validation_test_split)
    dataset["validation"] = split["train"]
    dataset["test"] = split["test"]

# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
# Ids are stored as strings in both directions.
labels = dataset["train"].features["labels"].names
label2id, id2label = {}, {}
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

########################################################
# Initialize classifier model
########################################################
# Config: an explicit config name takes precedence over the model path; the
# label maps built above are attached to it.
classifier_config = AutoConfig.from_pretrained(
    classifier_args.classifier_config_name
    or classifier_args.classifier_model_name_or_path,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    finetuning_task="image-classification",
    cache_dir=classifier_args.classifier_cache_dir,
    revision=classifier_args.classifier_model_revision,
    token=other_args.token,
)
# Model weights; `from_tf` is switched on when the path contains ".ckpt".
classifier = AutoModelForImageClassification.from_pretrained(
    classifier_args.classifier_model_name_or_path,
    from_tf=bool(".ckpt" in classifier_args.classifier_model_name_or_path),
    config=classifier_config,
    cache_dir=classifier_args.classifier_cache_dir,
    revision=classifier_args.classifier_model_revision,
    token=other_args.token,
    ignore_mismatched_sizes=classifier_args.classifier_ignore_mismatched_sizes,
)
# Image processor: an explicit processor name takes precedence over the model path.
classifier_image_processor = AutoImageProcessor.from_pretrained(
    classifier_args.classifier_image_processor_name
    or classifier_args.classifier_model_name_or_path,
    cache_dir=classifier_args.classifier_cache_dir,
    revision=classifier_args.classifier_model_revision,
    token=other_args.token,
)

########################################################
# Initialize surrogate model
########################################################
# Config: an explicit config name takes precedence over the model path.
surrogate_config = AutoConfig.from_pretrained(
    surrogate_args.surrogate_config_name
    or surrogate_args.surrogate_model_name_or_path,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    finetuning_task="image-classification",
    cache_dir=surrogate_args.surrogate_cache_dir,
    revision=surrogate_args.surrogate_model_revision,
    token=other_args.token,
)

# Decide whether the path is a local checkpoint that was saved by
# SurrogateForImageClassification. The config file is read inside a `with`
# block so the handle is always closed, and a config without an
# "architectures" entry simply counts as "not a surrogate checkpoint".
surrogate_config_path = os.path.join(
    surrogate_args.surrogate_model_name_or_path, "config.json"
)
is_surrogate_checkpoint = False
if os.path.isfile(surrogate_config_path):
    with open(surrogate_config_path, encoding="utf-8") as surrogate_config_file:
        saved_architectures = json.load(surrogate_config_file).get("architectures") or []
    is_surrogate_checkpoint = (
        saved_architectures[:1] == ["SurrogateForImageClassification"]
    )

if is_surrogate_checkpoint:
    # Resume from a previously saved surrogate checkpoint.
    surrogate = SurrogateForImageClassification.from_pretrained(
        surrogate_args.surrogate_model_name_or_path,
        from_tf=bool(".ckpt" in surrogate_args.surrogate_model_name_or_path),
        config=surrogate_config,
        cache_dir=surrogate_args.surrogate_cache_dir,
        revision=surrogate_args.surrogate_model_revision,
        token=other_args.token,
        ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
    )
else:
    # Build a fresh surrogate from the classifier and surrogate backbone settings.
    surrogate_for_image_classification_config = SurrogateForImageClassificationConfig(
        classifier_pretrained_model_name_or_path=classifier_args.classifier_model_name_or_path,
        classifier_config=classifier_config,
        classifier_from_tf=bool(
            ".ckpt" in classifier_args.classifier_model_name_or_path
        ),
        classifier_cache_dir=classifier_args.classifier_cache_dir,
        classifier_revision=classifier_args.classifier_model_revision,
        classifier_token=other_args.token,
        classifier_ignore_mismatched_sizes=classifier_args.classifier_ignore_mismatched_sizes,
        surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
        surrogate_config=surrogate_config,
        surrogate_from_tf=bool(
            ".ckpt" in surrogate_args.surrogate_model_name_or_path
        ),
        surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
        surrogate_revision=surrogate_args.surrogate_model_revision,
        surrogate_token=other_args.token,
        surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
    )

    surrogate = SurrogateForImageClassification(
        config=surrogate_for_image_classification_config,
    )

# Image processor: an explicit processor name takes precedence over the model path.
surrogate_image_processor = AutoImageProcessor.from_pretrained(
    surrogate_args.surrogate_image_processor_name
    or surrogate_args.surrogate_model_name_or_path,
    cache_dir=surrogate_args.surrogate_cache_dir,
    revision=surrogate_args.surrogate_model_revision,
    token=other_args.token,
)

########################################################
# Check required splits and optionally subsample them
########################################################

if training_args.do_train:
    if "train" not in dataset:
        raise ValueError("--do_train requires a train dataset")
    if "validation" not in dataset:
        raise ValueError("--do_train requires a validation dataset")

# Each cap is clamped to the split size: selecting more rows than a split
# holds would otherwise raise instead of simply keeping the whole split.
if data_args.max_train_samples is not None:
    dataset["train"] = (
        dataset["train"]
        .shuffle(seed=training_args.seed)
        .select(range(min(len(dataset["train"]), data_args.max_train_samples)))
    )
if data_args.max_val_samples is not None:
    dataset["validation"] = (
        dataset["validation"]
        .shuffle(seed=training_args.seed)
        .select(range(min(len(dataset["validation"]), data_args.max_val_samples)))
    )

if training_args.do_eval:
    if "test" not in dataset:
        raise ValueError("--do_eval requires a test dataset")

if data_args.max_test_samples is not None:
    dataset["test"] = (
        dataset["test"]
        .shuffle(seed=training_args.seed)
        .select(range(min(len(dataset["test"]), data_args.max_test_samples)))
    )

########################################################
# Align dataset to model settings
########################################################
# Set the training transforms
# if training_args.do_train:
#     dataset["train_classifier"] = dataset["train"]
#     dataset["train_classifier"].set_transform(
#         get_image_transform(classifier_image_processor)["train_transform"]
#     )
#     dataset["validation_classifier"] = dataset["validation"]
#     dataset["validation_classifier"].set_transform(
#         get_image_transform(classifier_image_processor)["eval_transform"]
#     )

# # Set the validation transforms
# if training_args.do_eval:
#     dataset["test_classifier"] = dataset["test"]
#     dataset["test_classifier"].set_transform(
#         get_image_transform(classifier_image_processor)["eval_transform"]
#     )

########################################################
# Evaluate the original model
########################################################

# def collate_fn(examples):
#     pixel_values = torch.stack([example["pixel_values"] for example in examples])
#     labels = torch.tensor([example["labels"] for example in examples])
#     return {"pixel_values": pixel_values, "labels": labels}

# classifier_trainer = Trainer(
#     model=classifier,
#     args=training_args,
#     train_dataset=None,
#     eval_dataset=None,
#     compute_metrics=None,
#     tokenizer=classifier_image_processor,
#     data_collator=collate_fn,
# )
# print("classifier_trainer.label_names", classifier_trainer.label_names)
# print(classifier_trainer.evaluate(dataset["validation_classifier"]))
########################################################
# Add random generator
########################################################

def transform_mask(example_batch):
    """
    Attach a "masks" entry to a batch, one `generate_mask` result per example.

    When the batch carries a "mask_random_seed" column, each example's mask is
    drawn from a `RandomState` seeded with its own seed; otherwise
    `random_state=None` is passed to `generate_mask`. The batch is modified in
    place and returned.
    """
    seeded = "mask_random_seed" in example_batch
    generated = []
    for idx in range(len(example_batch["labels"])):
        if seeded:
            rng = np.random.RandomState(example_batch["mask_random_seed"][idx])
        else:
            rng = None
        generated.append(
            generate_mask(
                num_features=14 * 14,
                num_mask_samples=other_args.num_mask_samples,
                paired_mask_samples=False,
                mode="uniform",
                random_state=rng,
            )
        )
    example_batch["masks"] = generated
    return example_batch

# Work on a copy of the split mapping so that replacing a split below does not
# touch `dataset` itself.
# NOTE(review): `.copy()` is presumably shallow, in which case the untouched
# "train" split object is shared with `dataset` -- confirm.
dataset_surrogate = dataset.copy()
# Give every validation example a fixed seed (drawn from a RandomState seeded
# with training_args.seed) so that `transform_mask` builds its mask from a
# per-example seeded RandomState.
dataset_surrogate["validation"] = dataset_surrogate["validation"].add_column(
    "mask_random_seed",
    iter(
        np.random.RandomState(training_args.seed).randint(
            0,
            len(dataset_surrogate["validation"]),
            size=len(dataset_surrogate["validation"]),
        )
    ),
)

# Training batches: image train transform followed by mask generation. The
# train split has no "mask_random_seed" column, so masks are drawn unseeded.
# The transform dictionary is rebuilt on every call of the lambda.
dataset_surrogate["train"].set_transform(
    lambda x: transform_mask(
        get_image_transform(surrogate_image_processor)["train_transform"](x)
    )
)

# Validation batches: image eval transform followed by (seeded) mask generation.
dataset_surrogate["validation"].set_transform(
    lambda x: transform_mask(
        get_image_transform(surrogate_image_processor)["eval_transform"](x)
    )
)

########################################################
# Initialize the surrogate trainer
########################################################
# Load the accuracy metric through the `evaluate` package
metric = evaluate.load("accuracy")

# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
def compute_metrics(p):
    """
    Compute accuracy for an evaluation result.

    Only the first element of `p.predictions` is used, and within it only
    index 0 along the second axis; the arg-max over the last axis is compared
    against `p.label_ids`.
    """
    first_output = p.predictions[0]
    predicted_classes = np.argmax(first_output[:, 0, :], axis=1)
    return metric.compute(predictions=predicted_classes, references=p.label_ids)

def collate_fn(examples):
    """
    Merge a list of per-example dicts into one batch dict of tensors.

    "pixel_values" entries are stacked, "labels" become a 1-D tensor and
    "masks" are combined through a NumPy array before conversion to a tensor.
    """
    images, targets, mask_arrays = [], [], []
    for example in examples:
        images.append(example["pixel_values"])
        targets.append(example["labels"])
        mask_arrays.append(example["masks"])

    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
        "masks": torch.tensor(np.array(mask_arrays)),
    }

# Trainer for the surrogate model. Datasets are only attached when the
# corresponding --do_train / --do_eval flag is set.
# NOTE(review): evaluation uses the "validation" split here, whereas the
# --do_eval check earlier in the script requires a "test" split -- confirm
# which split is intended.
surrogate_trainer = Trainer(
    model=surrogate,
    args=training_args,
    train_dataset=dataset_surrogate["train"] if training_args.do_train else None,
    eval_dataset=dataset_surrogate["validation"] if training_args.do_eval else None,
    compute_metrics=compute_metrics,
    tokenizer=surrogate_image_processor,
    data_collator=collate_fn,
)

# ipdb.set_trace()
# print("surrogate_trainer.label_names", surrogate_trainer.label_names)
# print(surrogate_trainer.evaluate(dataset["validation_surrogate"]))

In [None]:
# from shapreg import removal, games, shapley

In [None]:
class ShapleyValues:
    '''Container for Shapley value estimates and their standard deviations.'''

    def __init__(self, values, std):
        self.values, self.std = values, std

    def plot(
        self,
        feature_names=None,
        sort_features=True,
        max_features=np.inf,
        orientation='horizontal',
        error_bars=True,
        color='C0',
        title='Feature Importance',
        title_size=20,
        tick_size=16,
        tick_rotation=None,
        axis_label='',
        label_size=16,
        figsize=(10, 7),
        return_fig=False,
    ):
        '''
        Draw the stored Shapley values by delegating to `plotting.plot`.

        Args:
          feature_names: list of feature names.
          sort_features: whether to sort features by their Shapley values.
          max_features: number of features to display.
          orientation: horizontal (default) or vertical.
          error_bars: whether to include standard deviation error bars.
          color: bar chart color.
          title: plot title.
          title_size: font size for title.
          tick_size: font size for feature names and numerical values.
          tick_rotation: tick rotation for feature names (vertical plots only).
          axis_label: passed through to `plotting.plot` unchanged.
          label_size: font size for label.
          figsize: figure size (if fig is None).
          return_fig: whether to return matplotlib figure object.

        Returns:
          Whatever `plotting.plot` returns.
        '''
        # NOTE(review): `plotting` is not imported or defined anywhere in the
        # visible source, so this call fails with NameError unless the name is
        # provided elsewhere -- confirm.
        forwarded = (
            feature_names, sort_features, max_features, orientation,
            error_bars, color, title, title_size, tick_size, tick_rotation,
            axis_label, label_size, figsize, return_fig,
        )
        return plotting.plot(self, *forwarded)

def default_min_variance_samples(game):
    '''
    Return the default `min_variance_samples`.

    The `game` argument is accepted but not inspected; the result is always 5.
    '''
    minimum_samples = 5
    return minimum_samples

def default_variance_batches(num_players, batch_size):
    '''
    Return the default `variance_batches`.

    The value is the number of batches needed to reach ten samples per player
    (rounded up), which tries to ensure that enough samples are included to
    make the A approximation non-singular.
    '''
    batches_needed = 10 * num_players / batch_size
    return int(np.ceil(batches_needed))

def calculate_result(A, b, total):
    '''
    Calculate the regression coefficients.

    Solves the linear systems `A x = 1` and `A x = b` and combines the two
    solutions so that the returned values sum (over axis 0) to `total`.

    Args:
      A: square matrix with one row/column per player.
      b: right-hand side, either 1-D (one value per player) or 2-D (one
        column per output).
      total: target sum of the returned values along axis 0.

    Returns:
      Array of the same shape as `b`.

    Raises:
      ValueError: if `A` is singular. The original `LinAlgError` is attached
        as the cause.
    '''
    num_players = A.shape[1]
    # The ones vector is a column for 2-D `b` so that it broadcasts across the
    # output columns, and flat for 1-D `b`.
    if len(b.shape) == 2:
        ones = np.ones((num_players, 1))
    else:
        ones = np.ones(num_players)

    # Only the solves can hit a singular matrix, so only they are guarded.
    try:
        A_inv_one = np.linalg.solve(A, ones)
        A_inv_vec = np.linalg.solve(A, b)
    except np.linalg.LinAlgError as err:
        raise ValueError('singular matrix inversion. Consider using larger '
                         'variance_batches') from err

    # Shift the unconstrained solution so that its sum equals `total`.
    excess = np.sum(A_inv_vec, axis=0, keepdims=True) - total
    values = A_inv_vec - A_inv_one * excess / np.sum(A_inv_one)
    return values

def ShapleyRegressionPrecomputed(
    grand_value,
    null_value,
    model_outputs,
    masks,
    num_players,
    batch_size=512,
    detect_convergence=True,
    thresh=0.01,
    n_samples=None,
    paired_sampling=True,
    return_all=False,
    min_variance_samples=None,
    variance_batches=None,
    bar=True,
    verbose=False,
):
    '''
    Estimate Shapley values by regression from precomputed masked outputs.

    Instead of sampling subsets and querying a model, the function walks
    through `masks` / `model_outputs` in chunks of `batch_size` rows and
    updates running averages of the regression statistics after each chunk.

    Args:
      grand_value: output for the full coalition.
      null_value: output for the empty coalition.
      model_outputs: outputs for the subsets in `masks`; sliced along the
        first axis with the same indices as `masks`.
      masks: 2-D array-like of subset indicators, one row per subset and one
        column per player.
      num_players: number of players (columns of `masks`).
      batch_size: number of rows consumed per iteration.
      detect_convergence: whether to stop once the convergence ratio drops
        below `thresh`. Forced on when `n_samples` is None.
      thresh: convergence threshold, must lie strictly between 0 and 1.
      n_samples: upper bound on the number of rows to process.
      paired_sampling: only used to scale the reported `iters`; no pairing is
        performed on the precomputed rows.
      return_all: whether to also return a dictionary of intermediate values.
      min_variance_samples: number of intermediate estimates required before
        a variance is computed (defaults to 5).
      variance_batches: number of batches aggregated per intermediate
        estimate (defaults to `default_variance_batches`).
      bar: whether to show a progress bar.
      verbose: whether to print progress messages.

    Returns:
      A `ShapleyValues` object, or a `(ShapleyValues, tracking_dict)` tuple
      when `return_all` is True.

    Raises:
      ValueError: if `masks` is empty, or if the regression matrix is
        singular (raised by `calculate_result`).
    '''
    # Imported locally; inside this function the name `tqdm` is the progress
    # bar class rather than the `tqdm` module.
    from tqdm.auto import tqdm

    # Verify arguments.
    if min_variance_samples is None:
        min_variance_samples = 5
    else:
        assert isinstance(min_variance_samples, int)
        assert min_variance_samples > 1

    if variance_batches is None:
        variance_batches = default_variance_batches(num_players, batch_size)
    else:
        assert isinstance(variance_batches, int)
        assert variance_batches >= 1

    # Possibly force convergence detection.
    if n_samples is None:
        n_samples = 1e20
        if not detect_convergence:
            detect_convergence = True
            if verbose:
                print('Turning convergence detection on')

    if detect_convergence:
        assert 0 < thresh < 1

    # Accept any array-like input (lists, tensors exposing the array
    # protocol, ...); NumPy arrays pass through without a copy.
    masks = np.asarray(masks)
    model_outputs = np.asarray(model_outputs)
    num_available = len(masks)
    if num_available == 0:
        raise ValueError('masks is empty; at least one precomputed subset '
                         'is required')

    # Null and grand coalitions for the constraints.
    null = np.asarray(null_value)
    grand = np.asarray(grand_value)

    # Calculate difference between grand and null coalitions.
    total = grand - null

    # Set up bar.
    n_loops = int(np.ceil(n_samples / batch_size))
    progress = None
    if bar:
        if detect_convergence:
            progress = tqdm(total=1)
        else:
            progress = tqdm(total=n_loops * batch_size)

    # Setup.
    n = 0
    b = 0
    A = 0
    estimate_list = []

    # For variance estimation.
    A_sample_list = []
    b_sample_list = []

    # For tracking progress.
    var = np.nan * np.ones(num_players)
    if return_all:
        N_list = []
        std_list = []
        val_list = []

    # Walk through the precomputed subsets chunk by chunk.
    for it in range(n_loops):
        start = batch_size * it
        stop = start + batch_size
        S = masks[start:stop]
        game_S = model_outputs[start:stop]

        S_float = S.astype(float)
        # Per-row outer products, one matrix per subset.
        A_sample = np.matmul(S_float[:, :, np.newaxis],
                             S_float[:, np.newaxis, :])

        # Per-row product of the indicator with the (game - null) output; the
        # transposes let the product broadcast over trailing output axes.
        b_sample = (S_float.T
                    * (game_S - null)[:, np.newaxis].T).T

        # Welford's algorithm. The count grows by the number of rows actually
        # present, because the final chunk can be shorter than batch_size.
        n += len(S)
        b += np.sum(b_sample - b, axis=0) / n
        A += np.sum(A_sample - A, axis=0) / n

        # Calculate progress.
        values = calculate_result(A, b, total)
        A_sample_list.append(A_sample)
        b_sample_list.append(b_sample)
        if len(A_sample_list) == variance_batches:
            # Aggregate samples for intermediate estimate.
            A_sample = np.concatenate(A_sample_list, axis=0).mean(axis=0)
            b_sample = np.concatenate(b_sample_list, axis=0).mean(axis=0)
            A_sample_list = []
            b_sample_list = []

            # Add new estimate.
            estimate_list.append(calculate_result(A_sample, b_sample, total))

            # Estimate current var.
            if len(estimate_list) >= min_variance_samples:
                var = np.array(estimate_list).var(axis=0)

        # Convergence ratio.
        std = np.sqrt(var * variance_batches / (it + 1))
        ratio = np.max(
            np.max(std, axis=0) / (values.max(axis=0) - values.min(axis=0)))

        # Print progress message.
        if verbose:
            if detect_convergence:
                print(f'StdDev Ratio = {ratio:.4f} (Converge at {thresh:.4f})')
            else:
                print(f'StdDev Ratio = {ratio:.4f}')

        # Check for convergence.
        if detect_convergence:
            if ratio < thresh:
                if verbose:
                    print('Detected convergence')

                # Skip bar ahead.
                if progress is not None:
                    progress.n = progress.total
                    progress.refresh()
                break

        # Forecast number of iterations required.
        if detect_convergence:
            N_est = (it + 1) * (ratio / thresh) ** 2
            if progress is not None and not np.isnan(N_est):
                progress.n = np.around((it + 1) / N_est, 4)
                progress.refresh()
        elif progress is not None:
            progress.update(batch_size)

        # Save intermediate quantities.
        if return_all:
            val_list.append(values)
            std_list.append(std)
            if detect_convergence:
                N_list.append(N_est)

        # Stop once every precomputed subset has been consumed.
        if stop >= num_available:
            break

    # Release the progress bar.
    if progress is not None:
        progress.close()

    # Final ratio is diagnostic output, so it follows the `verbose` flag like
    # every other message in this function.
    if verbose:
        print(ratio)

    # Return results.
    if return_all:
        # Dictionary for progress tracking.
        # NOTE(review): `iters` is doubled when `paired_sampling` is True even
        # though no pairing happens here -- confirm this is what callers expect.
        iters = (
            (np.arange(it + 1) + 1) * batch_size *
            (1 + int(paired_sampling)))
        tracking_dict = {
            'values': val_list,
            'std': std_list,
            'iters': iters}
        if detect_convergence:
            tracking_dict['N_est'] = N_list

        return ShapleyValues(values, std), tracking_dict
    else:
        return ShapleyValues(values, std)

In [None]:
# game = games.PredictionGame_torchimagetensor(surrogate_SHAP_wrapped, x)
# explanation = shapley.ShapleyRegression(game, batch_size=batch_size, thresh=thresh, variance_batches=variance_batches)
# return explanation

In [None]:
# Load previously extracted outputs onto the CPU. The printed keys and the
# merge below treat both objects as dictionaries.
# NOTE(review): depending on the PyTorch version and `weights_only` setting,
# torch.load may unpickle arbitrary objects -- only load trusted files.
train_val_eval=torch.load("logs/vitbase_imagenette_surrogate_eval/extract_output_concat.pt", map_location="cpu")
grandnull=torch.load("logs/vitbase_imagenette_surrogate_eval3/extract_output_concat.pt", map_location="cpu")

In [None]:
# Load the extracted outputs of the second evaluation run onto the CPU.
# NOTE(review): depending on the PyTorch version and `weights_only` setting,
# torch.load may unpickle arbitrary objects -- only load trusted files.
test_eval=torch.load("logs/vitbase_imagenette_surrogate_eval2/extract_output_concat.pt", map_location="cpu")

In [None]:
# Accumulator that will hold the merged contents of all loaded output files.
surrogate_eval_dict={}

In [None]:
# Show which entries each loaded file provides, in load order.
for _loaded_outputs in (train_val_eval, test_eval, grandnull):
    print(_loaded_outputs.keys())

In [None]:
# Merge the loaded files into one dictionary; for a key present in several
# files, the file merged last wins.
for _loaded_outputs in (train_val_eval, test_eval, grandnull):
    surrogate_eval_dict.update(_loaded_outputs)

In [None]:
# Notebook display: list the keys of the merged dictionary.
surrogate_eval_dict.keys()

In [None]:
# Persist the merged dictionary next to the first input file.
torch.save(surrogate_eval_dict, "logs/vitbase_imagenette_surrogate_eval/extract_output_all.pt")

In [None]:
def prepare_input(phase, idx):
    """Collect the precomputed surrogate outputs of one sample.

    All values are read from the module-level ``surrogate_eval_dict``.

    Args:
        phase: Split prefix used to build the dictionary keys (the notebook
            uses "train", "validation" and "test").
        idx: Position of the sample inside that split's lists.

    Returns:
        A dict with four entries:
            "grand":  element 1 of ``{phase}_grand_null_logits[idx]``
            "null":   element 0 of ``{phase}_grand_null_logits[idx]``
            "logits": ``{phase}_logits[idx]``
            "masks":  ``{phase}_masks[idx]``

    Raises:
        KeyError: if ``surrogate_eval_dict`` has no entries for ``phase``.
    """
    grand_null = surrogate_eval_dict[f"{phase}_grand_null_logits"][idx]
    # Every split uses the same key layout, so a single lookup covers all of
    # them and an unexpected phase fails with a clear KeyError.
    return {
        "grand": grand_null[1],
        "null": grand_null[0],
        "logits": surrogate_eval_dict[f"{phase}_logits"][idx],
        "masks": surrogate_eval_dict[f"{phase}_masks"][idx],
    }

In [None]:
from scipy.special import softmax

batch_size = 512

# phase -> list with one entry per sample; each entry maps a key
# `batch_size * it` to the `it`-th intermediate estimate returned by the solver.
shapley_values_dict = {}

for phase in ["test"]:
    shapley_values_dict.setdefault(phase, [])
    num_samples = len(surrogate_eval_dict[f"{phase}_logits"])
    for idx in tqdm.tqdm(range(num_samples)):
        surrogate_eval = prepare_input(phase, idx)

        # Logits are turned into probabilities before the regression:
        # axis 0 for the single grand/null vectors, axis 1 for the stacked outputs.
        regression_kwargs = dict(
            grand_value=softmax(surrogate_eval["grand"], axis=0),
            null_value=softmax(surrogate_eval["null"], axis=0),
            model_outputs=softmax(surrogate_eval["logits"], axis=1),
            masks=surrogate_eval["masks"],
            batch_size=batch_size,
            num_players=196,
            variance_batches=2,
            min_variance_samples=2,
            return_all=True,
            bar=False,
        )
        explanation = ShapleyRegressionPrecomputed(**regression_kwargs)

        # NOTE(review): snapshot `it` is keyed `batch_size * it`, so the first
        # snapshot gets key 0 — confirm this labelling is the intended one.
        num_mask_shapley = {}
        for it, snapshot in enumerate(explanation[1]["values"]):
            num_mask_shapley[batch_size * it] = snapshot
        shapley_values_dict[phase].append(num_mask_shapley)

In [None]:
# 0.5 vs 0.2

In [None]:
# Key of the snapshot that serves as the reference ("ground truth") estimate.
num_eval_ground_truth = 99840

# One record per (sample, num_eval) pair, later turned into a DataFrame.
record_dict_list = []

for sample_idx, num_eval_shapley_values in enumerate(shapley_values_dict["test"]):
    reference = num_eval_shapley_values[num_eval_ground_truth]
    # The target class is the one with the largest reference total along axis 0.
    # NOTE(review): assumes axis 0 indexes players and axis 1 classes — confirm.
    target_class_idx = np.argmax(reference.sum(axis=0))
    for num_eval, shapley_values in num_eval_shapley_values.items():
        diff = shapley_values - reference
        # Squared error summed along axis 0, leaving one value per class.
        mse_class = (diff * diff).sum(axis=0)
        is_target = np.arange(len(mse_class)) == target_class_idx

        record_dict_list.append({
            "sample_idx": sample_idx,
            "mse_target": mse_class[is_target].mean(),
            "mse_nontarget": mse_class[~is_target].mean(),
            "mse_all": mse_class.mean(),
            "num_eval": num_eval,
        })

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import font_manager
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Query the system TTF fonts and resolve "Arial"; both results are discarded.
font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial") # Test with "Special Elite" too
# Use Arial as the sans-serif face for all following figures.
plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
})

In [None]:
# Imported here because this cell sits above the notebook's pandas import cell.
import pandas as pd

record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

axd = {"main": ax}
plot_key = "main"

# Target-class error against the number of evaluations; rows keyed
# `num_eval == 0` are excluded from the plot.
sns.lineplot(
    x="num_eval",
    y="mse_target",
    data=record_dict_list_df[record_dict_list_df["num_eval"] > 0],
    ax=ax,
)

axd[plot_key].set_ylabel("L2 distance", fontsize=20)
axd[plot_key].set_xlabel("# model evaluations", fontsize=20)

# y axis: major ticks every 0.5, minor every 0.1, with a fainter minor grid.
axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.5))
axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))
axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)

# x axis: major ticks every 50000, minor every 10000.
axd[plot_key].xaxis.set_major_locator(MultipleLocator(50000))
axd[plot_key].xaxis.set_minor_locator(MultipleLocator(10000))
axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)

axd[plot_key].tick_params(axis='x', which='major', labelsize=20)
axd[plot_key].tick_params(axis='y', which='major', labelsize=20)

# Hide the right and top borders of the axes.
axd[plot_key].spines['right'].set_visible(False)
axd[plot_key].spines['top'].set_visible(False)

In [None]:
# Key of the snapshot that serves as the reference ("ground truth") estimate.
num_eval_ground_truth = 99840

# For each class group: num_eval key -> list with one error value per sample.
mse_dict = {
    "target": {},
    "non_target": {},
    "all": {}
}

for num_eval_shapley_values in shapley_values_dict["test"]:
    reference = num_eval_shapley_values[num_eval_ground_truth]
    # The target class is the one with the largest reference total along axis 0.
    # NOTE(review): assumes axis 0 indexes players and axis 1 classes — confirm.
    target_class_idx = np.argmax(reference.sum(axis=0))
    for num_eval, shapley_values in num_eval_shapley_values.items():
        diff = shapley_values - reference
        # Squared error averaged along axis 0, leaving one value per class.
        mse_class = (diff * diff).mean(axis=0)
        is_target = np.arange(len(mse_class)) == target_class_idx

        mse_dict["target"].setdefault(num_eval, []).append(mse_class[is_target].mean())
        mse_dict["non_target"].setdefault(num_eval, []).append(mse_class[~is_target].mean())
        mse_dict["all"].setdefault(num_eval, []).append(mse_class.mean())

In [None]:
# Tabulate the target-class errors: one column per num_eval key, built from
# the per-key lists in mse_dict["target"].
pd.DataFrame(mse_dict["target"])

In [None]:
import pandas as pd

In [None]:
# Display the collected error lists.
mse_dict

In [None]:
# Persist the per-sample Shapley snapshots.
torch.save(shapley_values_dict, "logs/vitbase_imagenette_surrogate_eval/shapley.pt")

In [None]:
# Writes `shapley_values_dict` to the same path as the preceding cell.
torch.save(shapley_values_dict, "logs/vitbase_imagenette_surrogate_eval/shapley.pt")

In [None]:
# Inspect the snapshot at index 1 of the "values" list in the second element of `explanation`.
explanation[1]["values"][1]

In [None]:
# Inspect the `values` attribute of the first element of `explanation`.
explanation[0].values

In [None]:
# NOTE(review): `train_eval` is not assigned in the cells visible here (the
# load cells bind `train_val_eval`) — confirm the intended name.
train_eval.keys()

In [None]:
# Load the concatenated extraction output of the `..._eval` log directory onto the CPU.
output0=torch.load("logs/vitbase_imagenette_surrogate_eval/extract_output_concat.pt", map_location="cpu")

In [None]:
# Load the raw (not yet concatenated) extraction output of the `..._eval2` log directory onto the CPU.
output2=torch.load("logs/vitbase_imagenette_surrogate_eval2/extract_output.pt", map_location="cpu")

In [None]:
# Load the raw (not yet concatenated) extraction output of the `..._eval3` log directory onto the CPU.
output3=torch.load("logs/vitbase_imagenette_surrogate_eval3/extract_output.pt", map_location="cpu")

In [None]:
# For every entry: join its chunks along axis 1, then store the joined array
# as a list of its axis-0 slices. Only existing keys are reassigned, so the
# dictionary keeps its size while it is being iterated.
for key, value in output3.items():
    output3[key] = list(np.concatenate(value, axis=1))
    print(key, len(output3[key]))

In [None]:
# Save `output3` under the `..._eval3` log directory as the "concat" file.
torch.save(output3, "logs/vitbase_imagenette_surrogate_eval3/extract_output_concat.pt")

In [None]:
# Scratch arithmetic: evaluates to 3200.
32*100

In [None]:
# Inspect whatever `value` is currently bound to (it is a loop variable in
# other cells): its length and the shape of its first element.
len(value),value[0].shape

In [None]:
# Display `output0`.
output0

In [None]:
# Save `output0` to the "concat" file of the `..._eval` log directory.
torch.save(output0, "logs/vitbase_imagenette_surrogate_eval/extract_output_concat.pt")

In [None]:
# Scratch cell: prints 1.
print(1)

In [None]:
import pickle

# pickle writes bytes, so the target file has to be opened in binary mode;
# the context manager closes the handle as soon as the dump is done.
# NOTE(review): `output` is not assigned in the cells visible here — confirm
# which dictionary is meant.
with open("logs/vitbase_imagenette_surrogate_eval/extract_output_all.pt", "wb") as f:
    pickle.dump(output, f, protocol=4)

In [None]:
# Save `output` with pickle protocol 4.
# NOTE(review): `output` is not assigned in the cells visible here — confirm.
torch.save(output, "logs/vitbase_imagenette_surrogate_eval/extract_output_all.pt", pickle_protocol=4)

In [None]:
# Display the keys of `output`.
output.keys()

In [None]:
# Display the keys of `output2`.
output2.keys()

In [None]:
# Print, for every entry of `output0`, the shape obtained by joining its chunks along axis 1.
for key,value in output0.items():
    print(key, np.concatenate(value, axis=1).shape)

In [None]:
# Print, for every entry of `output2`, the shape obtained by joining its chunks along axis 1.
for key,value in output2.items():
    print(key, np.concatenate(value, axis=1).shape)

In [None]:
# Print, for every entry of `output3`, the shape obtained by joining its chunks along axis 1.
for key,value in output3.items():
    print(key, np.concatenate(value, axis=1).shape)

In [None]:
########################################################
# Detecting last checkpoint
#######################################################
# `last_checkpoint` stays None unless an existing output directory is found
# while training is requested and overwriting is not.
last_checkpoint = None
may_resume = (
    os.path.isdir(training_args.output_dir)
    and training_args.do_train
    and not training_args.overwrite_output_dir
)
if may_resume:
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is not None:
        # A checkpoint exists; only announce the resume when the user did not
        # name a checkpoint explicitly.
        if training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
    elif len(os.listdir(training_args.output_dir)) > 0:
        # No checkpoint, yet the directory holds files: refuse to continue.
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

if other_args.extract_output:
    # Requested number of mask samples per split: either a single value used
    # for all splits, or a comma-separated "train,validation,test" triple.
    if (
        isinstance(other_args.extract_output, str)
        and "," in other_args.extract_output
    ):
        requested = other_args.extract_output.split(",")
        extract_output_key = {
            "train": int(requested[0]),
            "validation": int(requested[1]),
            "test": int(requested[2]),
        }
    else:
        extract_output_key = dict.fromkeys(
            ("train", "validation", "test"), int(other_args.extract_output)
        )

    def transform_mask_grand_null(example_batch):
        """Attach two paired masks generated in "empty" mode to every example.

        NOTE(review): presumably the pair is the null / grand coalition, as
        the results are stored under "*_grand_null_*" keys — confirm against
        `generate_mask`.
        """
        example_batch["masks"] = [
            generate_mask(
                num_features=14 * 14,
                num_mask_samples=2,
                paired_mask_samples=True,
                mode="empty",
                random_state=None,
            )
            for _ in range(len(example_batch["labels"]))
        ]
        return example_batch

    def transform_mask(example_batch):
        """Attach `other_args.num_mask_samples` "shapley"-mode masks to every example.

        If the batch has a "mask_random_seed" column, each example's masks are
        drawn from a `RandomState` seeded with its own entry of that column;
        otherwise `random_state=None` is passed to `generate_mask`.
        """
        seeds = (
            example_batch["mask_random_seed"]
            if "mask_random_seed" in example_batch
            else None
        )
        example_batch["masks"] = [
            generate_mask(
                num_features=14 * 14,
                num_mask_samples=other_args.num_mask_samples,
                paired_mask_samples=False,
                mode="shapley",
                random_state=(
                    None if seeds is None else np.random.RandomState(seeds[idx])
                ),
            )
            for idx in range(len(example_batch["labels"]))
        ]
        return example_batch

    import copy

    # Work on a deep copy so the transforms installed below leave `dataset` as is.
    dataset_extract = copy.deepcopy(dataset)
    # "<split>_<name>" -> list with one prediction array per predict() pass.
    save_dict = {}
    for key in dataset_extract.keys():
        # Pass 1: one prediction run with the two "empty"-mode masks.
        dataset_extract[key].set_transform(
            lambda x: transform_mask_grand_null(
                get_image_transform(surrogate_image_processor)["eval_transform"](x)
            )
        )
        predict_output = surrogate_trainer.predict(dataset_extract[key])
        # Sanity check: predictions come back in dataset order.
        assert all(
            predict_output.label_ids
            == dataset_extract[key].with_transform(lambda x: x)["labels"]
        )
        save_dict.setdefault(key + "_grand_null_logits", []).append(
            predict_output.predictions[0]
        )
        save_dict.setdefault(key + "_grand_null_masks", []).append(
            predict_output.predictions[1]
        )

        # Pass 2: repeated runs with "shapley"-mode masks. Each run yields
        # `num_mask_samples` masks per example, so the number of runs is the
        # requested count divided by that, rounded up.
        dataset_extract[key].set_transform(
            lambda x: transform_mask(
                get_image_transform(surrogate_image_processor)["eval_transform"](x)
            )
        )
        num_passes = (
            extract_output_key[key] + other_args.num_mask_samples - 1
        ) // other_args.num_mask_samples
        for _ in tqdm.tqdm(range(num_passes)):
            predict_output = surrogate_trainer.predict(dataset_extract[key])
            assert all(
                predict_output.label_ids
                == dataset_extract[key].with_transform(lambda x: x)["labels"]
            )
            save_dict.setdefault(key + "_logits", []).append(
                predict_output.predictions[0]
            )
            save_dict.setdefault(key + "_masks", []).append(
                predict_output.predictions[1]
            )

    # Write everything collected above into the run's output directory.
    torch.save(
        save_dict, os.path.join(training_args.output_dir, "extract_output.pt")
    )

########################################################
# Training
#######################################################
if training_args.do_train:
    # An explicitly requested checkpoint wins; otherwise fall back to the
    # detected one, which is None when nothing was found.
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    else:
        checkpoint = last_checkpoint
    train_result = surrogate_trainer.train(resume_from_checkpoint=checkpoint)
    surrogate_trainer.save_model()
    # Log the training metrics and write them to disk, then save the trainer state.
    train_metrics = train_result.metrics
    surrogate_trainer.log_metrics("train", train_metrics)
    surrogate_trainer.save_metrics("train", train_metrics)
    surrogate_trainer.save_state()

########################################################
# Evaluation
#######################################################
if training_args.do_eval:
    metrics = surrogate_trainer.evaluate()
    # Log first, then save, both under the "eval" prefix.
    for report in (surrogate_trainer.log_metrics, surrogate_trainer.save_metrics):
        report("eval", metrics)

########################################################
# Write model card and (optionally) push to hub
#######################################################
# Metadata handed to whichever publishing call is selected below.
kwargs = dict(
    finetuned_from=surrogate_args.surrogate_model_name_or_path,
    tasks="image-classification",
    dataset=data_args.dataset_name,
    tags=["image-classification", "vision"],
)
# Push to the hub when requested; otherwise only create the model card.
publish = (
    surrogate_trainer.push_to_hub
    if training_args.push_to_hub
    else surrogate_trainer.create_model_card
)
publish(**kwargs)


# Entry point: call `main()` only when this code runs as the top-level module.
# NOTE(review): `main` is not defined in the cells visible here — confirm it exists.
if __name__ == "__main__":
    main()

In [None]:
from utils.shapreg import removal, games, shapley

In [None]:
# Run `shapley.ShapleyRegression` on `game` and show the result as the cell
# output (a `return` is only valid inside a function).
# NOTE(review): `game`, `thresh` and `variance_batches` are not assigned in
# the cells visible here — confirm they are defined before running this.
explanation = shapley.ShapleyRegression(
    game,
    batch_size=batch_size,
    thresh=thresh,
    variance_batches=variance_batches,
)
explanation