In [None]:
import os
import sys

# Move one directory up (presumably the repository root) so the relative paths
# used below ("configs/...", "logs/...") resolve.
os.chdir('../')

In [None]:
# IPython shell escape: show current GPU usage (presumably to choose the device set below).
!gpustat

In [None]:
import os
# Expose only physical GPU 4 to this process. This must happen before torch
# (imported in a later cell) initialises CUDA.
os.environ['CUDA_VISIBLE_DEVICES']="4"

In [None]:
# Emulate a command-line run: with exactly one ".json" argument the parsing cell
# below reads every argument from this config file.
sys.argv=["train_explainer_regression.py", "configs/vitbase_imagenette_explainer_objective.json"]

In [None]:
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

import evaluate
import ipdb
import numpy as np
import torch
import transformers
from datasets import load_dataset
from PIL import Image
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch.nn import functional as F
from transformers import (
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    AutoConfig,
    AutoImageProcessor,
    AutoModelForImageClassification,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

from models import (
    ExplainerForImageClassification,
    ExplainerForImageClassificationConfig,
    SurrogateForImageClassificationConfig,
)
from utils import generate_mask, get_image_transform

# NOTE(review): this string follows the imports, so it is not the module
# docstring; it is evaluated as a no-op expression.
""" Fine-tuning a 🤗 Transformers model for image classification"""

logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.32.0.dev0")

require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/image-classification/requirements.txt",
)

# Config classes of every architecture registered for image classification, and
# their `model_type` strings (listed in the `*_model_type` argument help texts).
MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
# Debug output of the supported config classes and model types.
print(MODEL_CONFIG_CLASSES)
print(MODEL_TYPES)


def pil_loader(path: str):
    """Read the image file at ``path`` and return an RGB-converted copy.

    The file is opened in binary mode and the conversion happens while it is
    still open; ``convert`` produces a separate image object, which is returned
    after the file has been closed.
    """
    with open(path, "rb") as image_file:
        rgb_image = Image.open(image_file).convert("RGB")
    return rgb_image


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify
    them on the command line.

    Either ``dataset_name`` (a hub dataset) or at least one of ``train_dir`` /
    ``validation_dir`` (local image folders) must be given. The two split
    fractions must lie in [0, 1); ``None`` or 0 disables the corresponding split.

    Raises:
        ValueError: if no data source is given, or a split fraction is out of range.
    """

    dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Name of a dataset from the hub (could be your own, possibly private dataset hosted on the hub)."
        },
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the dataset to use (via the datasets library)."
        },
    )

    dataset_cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the downloaded dataset."},
    )
    train_dir: Optional[str] = field(
        default=None, metadata={"help": "A folder containing the training data."}
    )
    validation_dir: Optional[str] = field(
        default=None, metadata={"help": "A folder containing the validation data."}
    )
    test_dir: Optional[str] = field(
        default=None, metadata={"help": "A folder containing the test data."}
    )
    # Used only when the dataset has no validation split of its own.
    train_validation_split: Optional[float] = field(
        default=0.15,
        metadata={
            "help": "Fraction (0-1) of train to split off for validation when no validation split exists."
        },
    )

    # Used only when the dataset has no test split of its own.
    validation_test_split: Optional[float] = field(
        default=0.5,
        metadata={
            "help": "Fraction (0-1) of validation to split off for test when no test split exists."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )

    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of validation examples to this "
                "value if set."
            )
        },
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of test examples to this "
                "value if set."
            )
        },
    )

    def __post_init__(self):
        if self.dataset_name is None and (
            self.train_dir is None and self.validation_dir is None
        ):
            raise ValueError(
                "You must specify either a dataset name from the hub or a train and/or validation directory."
            )
        # Both fractions are handed to `Dataset.train_test_split` as the size of
        # the split-off part; 1.0 or more would leave the remaining split empty.
        for split_arg in ("train_validation_split", "validation_test_split"):
            fraction = getattr(self, split_arg)
            if fraction is not None and not 0.0 <= fraction < 1.0:
                raise ValueError(
                    f"`{split_arg}` must be a fraction in [0, 1), got {fraction}."
                )


@dataclass
class OtherArguments:
    """
    Hub authentication arguments, forwarded as ``token`` when loading the
    dataset, configs and image processor.
    """

    # Deprecated alias; when set, the parsing cell copies it into `token`.
    # Defaults to None, so the annotation must allow None.
    use_auth_token: Optional[bool] = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
        },
    )

    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )


@dataclass
class SurrogateArguments:
    """
    Arguments selecting the surrogate model: the pretrained checkpoint or
    config to load, where to cache downloads, which revision to use and how to
    treat head-size mismatches.
    """

    surrogate_model_name_or_path: str = field(
        default="google/vit-base-patch16-224-in21k",
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        },
    )
    surrogate_model_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "If training from scratch, pass a model type from the list: "
            + ", ".join(MODEL_TYPES)
        },
    )

    surrogate_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    surrogate_cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from s3"
        },
    )
    surrogate_model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
        },
    )
    # Defaults to None, so the annotation must allow None.
    surrogate_image_processor_name: Optional[str] = field(
        default=None, metadata={"help": "Name or path of preprocessor config."}
    )

    surrogate_ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={
            "help": "Will enable to load a pretrained model whose head dimensions are different."
        },
    )


@dataclass
class ExplainerArguments:
    """
    Arguments selecting the explainer backbone: the pretrained checkpoint or
    config to load, where to cache downloads, which revision and image
    processor to use, and how to treat head-size mismatches.
    """

    explainer_model_name_or_path: str = field(
        default="google/vit-base-patch16-224-in21k",
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        },
    )
    explainer_model_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "If training from scratch, pass a model type from the list: "
            + ", ".join(MODEL_TYPES)
        },
    )

    explainer_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    explainer_cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from s3"
        },
    )
    explainer_model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
        },
    )
    # Defaults to None (falls back to the model path), so the annotation must allow None.
    explainer_image_processor_name: Optional[str] = field(
        default=None, metadata={"help": "Name or path of preprocessor config."}
    )
    explainer_ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={
            "help": "Will enable to load a pretrained model whose head dimensions are different."
        },
    )

In [None]:
########################################################
# Parse arguments
#######################################################
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.

# The parser returns one instance per dataclass, in the order listed here.
parser = HfArgumentParser(
    (
        SurrogateArguments,
        ExplainerArguments,
        DataTrainingArguments,
        TrainingArguments,
        OtherArguments,
    )
)
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    # If we pass only one argument to the script and it's the path to a json file,
    # let's parse it to get our arguments.
    (
        surrogate_args,
        explainer_args,
        data_args,
        training_args,
        other_args,
    ) = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
    # Otherwise parse regular command-line flags.
    (
        surrogate_args,
        explainer_args,
        data_args,
        training_args,
        other_args,
    ) = parser.parse_args_into_dataclasses()

# Backward compatibility: fold the deprecated `use_auth_token` into `token`,
# rejecting the ambiguous case where both are given.
if other_args.use_auth_token is not None:
    warnings.warn(
        "The `use_auth_token` argument is deprecated and will be removed in v4.34.",
        FutureWarning,
    )
    if other_args.token is not None:
        raise ValueError(
            "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
        )
    other_args.token = other_args.use_auth_token

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_image_classification", explainer_args, data_args)

########################################################
# Setup logging
#######################################################
# Log to stdout with timestamps; the effective level comes from the training args.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

if training_args.should_log:
    # The default of training_args.log_level is passive, so we set log level at info here to have that default.
    transformers.utils.logging.set_verbosity_info()

log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log on each process the small summary. The two f-strings are concatenated, so
# the first must end with a separator.
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

########################################################
# Set seed before initializing model.
########################################################
# Seed the RNGs via transformers' set_seed before model initialization and the
# random dataset splits/shuffles below.
set_seed(training_args.seed)

########################################################
# Initialize our dataset and prepare it for the 'image-classification' task.
########################################################
if data_args.dataset_name is not None:
    if data_args.dataset_name == "frgfm/imagenette":
        # Loaded with task=None; its "label" column is renamed below to the
        # "labels" name the rest of this script reads.
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=data_args.dataset_cache_dir,
            task=None,
            token=other_args.token,
        )

        for split in dataset.keys():
            if "label" in dataset[split].features:
                dataset[split] = dataset[split].rename_column("label", "labels")

    else:
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=data_args.dataset_cache_dir,
            task="image-classification",
            token=other_args.token,
        )
else:
    # Local image folders: every directory that was given becomes one split.
    data_files = {}
    if data_args.train_dir is not None:
        data_files["train"] = os.path.join(data_args.train_dir, "**")
    if data_args.validation_dir is not None:
        data_files["validation"] = os.path.join(data_args.validation_dir, "**")
    if data_args.test_dir is not None:
        data_files["test"] = os.path.join(data_args.test_dir, "**")
    # Datasets are cached in the dataset cache dir, like the hub branches above
    # (not in the surrogate model's cache dir).
    dataset = load_dataset(
        "imagefolder",
        data_files=data_files,
        cache_dir=data_args.dataset_cache_dir,
        task="image-classification",
    )

# If we don't have a validation split, split off a fraction of train as validation.
# (An existing validation split disables this by forcing the fraction to None.)
data_args.train_validation_split = (
    None if "validation" in dataset.keys() else data_args.train_validation_split
)
# NOTE(review): train_test_split is called with no explicit seed; the split
# presumably depends on the global RNG state left by set_seed above — confirm.
# Changing it would change which examples end up in "test" and misalign any
# precomputed per-test-example results (e.g. the shapley.pt loaded later).
if (
    isinstance(data_args.train_validation_split, float)
    and data_args.train_validation_split > 0.0
):
    split = dataset["train"].train_test_split(data_args.train_validation_split)
    dataset["train"] = split["train"]
    dataset["validation"] = split["test"]

# Likewise, without a test split carve one out of validation: the "train" part
# stays as validation and the "test" part becomes the test split.
data_args.validation_test_split = (
    None if "test" in dataset.keys() else data_args.validation_test_split
)

if (
    isinstance(data_args.validation_test_split, float)
    and data_args.validation_test_split > 0.0
):
    split = dataset["validation"].train_test_split(data_args.validation_test_split)
    dataset["validation"] = split["train"]
    dataset["test"] = split["test"]

# Prepare label mappings (class index as a string <-> class name).
# They are stored in the model configs so predictions carry human-readable labels.
labels = dataset["train"].features["labels"].names
id2label = {str(class_idx): class_name for class_idx, class_name in enumerate(labels)}
label2id = {class_name: class_idx for class_idx, class_name in id2label.items()}

########################################################
# Initialize explainer model
########################################################
# Surrogate backbone config: an AutoConfig with this dataset's label count and
# label <-> id mappings.
surrogate_config = AutoConfig.from_pretrained(
    surrogate_args.surrogate_config_name
    or surrogate_args.surrogate_model_name_or_path,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    finetuning_task="image-classification",
    cache_dir=surrogate_args.surrogate_cache_dir,
    revision=surrogate_args.surrogate_model_revision,
    token=other_args.token,
)
# Project surrogate config: the backbone config plus its loading options
# (path, TF-checkpoint flag = path contains ".ckpt", cache dir, revision, token).
surrogate_for_image_classification_config = SurrogateForImageClassificationConfig(
    surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
    surrogate_config=surrogate_config,
    surrogate_from_tf=bool(".ckpt" in surrogate_args.surrogate_model_name_or_path),
    surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
    surrogate_revision=surrogate_args.surrogate_model_revision,
    surrogate_token=other_args.token,
    surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
)
# Explainer backbone config, built the same way from the explainer arguments.
explainer_config = AutoConfig.from_pretrained(
    explainer_args.explainer_config_name
    or explainer_args.explainer_model_name_or_path,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    finetuning_task="image-classification",
    cache_dir=explainer_args.explainer_cache_dir,
    revision=explainer_args.explainer_model_revision,
    token=other_args.token,
)
# Combined explainer config: embeds the wrapped surrogate config above together
# with the explainer backbone config and both sets of loading options.
explainer_for_image_classification_config = ExplainerForImageClassificationConfig(
    surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
    surrogate_config=surrogate_for_image_classification_config,
    surrogate_from_tf=bool(".ckpt" in surrogate_args.surrogate_model_name_or_path),
    surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
    surrogate_revision=surrogate_args.surrogate_model_revision,
    surrogate_token=other_args.token,
    surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
    explainer_pretrained_model_name_or_path=explainer_args.explainer_model_name_or_path,
    explainer_config=explainer_config,
    explainer_from_tf=bool(".ckpt" in explainer_args.explainer_model_name_or_path),
    explainer_cache_dir=explainer_args.explainer_cache_dir,
    explainer_revision=explainer_args.explainer_model_revision,
    explainer_token=other_args.token,
    explainer_ignore_mismatched_sizes=explainer_args.explainer_ignore_mismatched_sizes,
)

explainer = ExplainerForImageClassification(
    config=explainer_for_image_classification_config,
)
# Image processor defining the explainer's preprocessing; the dataset
# transforms below are built from it via get_image_transform.
explainer_image_processor = AutoImageProcessor.from_pretrained(
    explainer_args.explainer_image_processor_name
    or explainer_args.explainer_model_name_or_path,
    cache_dir=explainer_args.explainer_cache_dir,
    revision=explainer_args.explainer_model_revision,
    token=other_args.token,
)

########################################################
# Align dataset to model settings
########################################################

# Training needs both a train and a validation split.
if training_args.do_train:
    if "train" not in dataset:
        raise ValueError("--do_train requires a train dataset")
    if "validation" not in dataset:
        raise ValueError("--do_train requires a validation dataset")

# Optional debugging subsets: shuffle with the run seed, then keep at most
# max_*_samples rows. The cap is clamped to the split length because
# `select(range(n))` fails when n exceeds the number of rows.
if data_args.max_train_samples is not None:
    dataset["train"] = (
        dataset["train"]
        .shuffle(seed=training_args.seed)
        .select(range(min(data_args.max_train_samples, len(dataset["train"]))))
    )
if data_args.max_val_samples is not None:
    dataset["validation"] = (
        dataset["validation"]
        .shuffle(seed=training_args.seed)
        .select(range(min(data_args.max_val_samples, len(dataset["validation"]))))
    )

if training_args.do_eval:
    if "test" not in dataset:
        raise ValueError("--do_eval requires a test dataset")

if data_args.max_test_samples is not None:
    dataset["test"] = (
        dataset["test"]
        .shuffle(seed=training_args.seed)
        .select(range(min(data_args.max_test_samples, len(dataset["test"]))))
    )

########################################################
# Add random generator
########################################################
# NOTE: this binds the same Dataset object as dataset["train"], not a copy, so
# any in-place change (e.g. set_transform) to one also applies to the other.
dataset["train_explainer"] = dataset["train"]
# One fixed seed per validation example, so its masks are identical on every
# evaluation (consumed by `tranform_mask` via the "mask_random_seed" column).
validation_mask_seeds = np.random.RandomState(training_args.seed).randint(
    0,
    len(dataset["validation"]),
    size=len(dataset["validation"]),
)
# Pass the array itself: `add_column` is documented to take a list or array,
# not a single-use iterator.
dataset["validation_explainer"] = dataset["validation"].add_column(
    "mask_random_seed", validation_mask_seeds
)

# loaded = torch.load(
#     "logs/vitbase_imagenette_surrogate_eval/extract_output_all.pt",
#     map_location="cpu",
# )

def tranform_mask(example_batch, num_features=14 * 14, num_mask_samples=32):
    """Attach sampled feature masks to a batch of examples.

    For every example, ``generate_mask(mode="shapley", paired_mask_samples=False)``
    is called once and its result (wrapped in ``np.vstack``) is stored as that
    example's entry in ``example_batch["masks"]``.

    If the batch has a ``"mask_random_seed"`` column, example ``i`` draws its
    masks from ``np.random.RandomState(mask_random_seed[i])``, making them
    reproducible (this column is added to the validation split only). Otherwise
    ``random_state=None`` is passed, presumably giving fresh masks on each call.

    Args:
        example_batch: Batched examples (dict-like) containing at least
            ``"labels"``; ``"mask_random_seed"`` is optional.
        num_features: Number of maskable features (default 14 * 14, the value
            previously hard-coded).
        num_mask_samples: Masks drawn per example (default 32, as before).

    Returns:
        ``example_batch`` itself with ``"masks"`` added (mutated in place).
    """
    batch_size = len(example_batch["labels"])
    if "mask_random_seed" in example_batch:
        seeds = example_batch["mask_random_seed"]
        random_states = [np.random.RandomState(seeds[idx]) for idx in range(batch_size)]
    else:
        random_states = [None] * batch_size

    example_batch["masks"] = [
        np.vstack(
            [
                generate_mask(
                    num_features=num_features,
                    num_mask_samples=num_mask_samples,
                    paired_mask_samples=False,
                    mode="shapley",
                    random_state=random_state,
                ),
            ]
        )
        for random_state in random_states
    ]
    return example_batch

# Explainer preprocessing (the eval transform, for both splits) followed by
# per-example mask sampling, applied lazily whenever examples are accessed.
# Built once here instead of on every batch.
explainer_eval_transform = get_image_transform(explainer_image_processor)["eval_transform"]

# `with_transform` returns a new Dataset rather than mutating in place, so
# dataset["train"] (the same object dataset["train_explainer"] was bound to)
# is left without this transform.
dataset["train_explainer"] = dataset["train_explainer"].with_transform(
    lambda batch: tranform_mask(explainer_eval_transform(batch))
)

dataset["validation_explainer"] = dataset["validation_explainer"].with_transform(
    lambda batch: tranform_mask(explainer_eval_transform(batch))
)
# import ipdb

# ipdb.set_trace()

########################################################
# Initalize the explainer trainer
########################################################
# Load the accuracy metric from the datasets package
# NOTE(review): currently unused — compute_metrics below returns an empty dict.
metric = evaluate.load("accuracy")

# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
def compute_metrics(p):
    """Return extra evaluation metrics for the explainer Trainer: currently none.

    Accuracy of ``argmax(p.predictions[:, 0, :])`` against ``p.label_ids``
    (using the module-level ``metric``) used to be computed here and is
    deliberately disabled. No metrics beyond the Trainer's own are reported.

    Args:
        p: An ``EvalPrediction`` (``predictions`` / ``label_ids``); unused.

    Returns:
        An empty dict.
    """
    return {}

def collate_fn(examples):
    """Stack a list of preprocessed examples into a single model batch.

    Each example provides ``"pixel_values"`` (a tensor), ``"labels"`` and
    ``"masks"`` (array-like). The batch holds the stacked pixel tensor, a
    label tensor, and a mask tensor built from the NumPy-stacked masks.
    """
    def gather(key):
        return [example[key] for example in examples]

    return {
        "pixel_values": torch.stack(gather("pixel_values")),
        "labels": torch.tensor(gather("labels")),
        "masks": torch.tensor(np.array(gather("masks"))),
    }

# Trainer for the explainer. Training examples get masks sampled without a
# fixed seed, while validation examples use their "mask_random_seed" column
# (see tranform_mask). Batches are built by collate_fn; compute_metrics adds
# no metrics.
explainer_trainer = Trainer(
    model=explainer,
    args=training_args,
    train_dataset=dataset["train_explainer"] if training_args.do_train else None,
    eval_dataset=dataset["validation_explainer"] if training_args.do_eval else None,
    compute_metrics=compute_metrics,
    tokenizer=explainer_image_processor,
    data_collator=collate_fn,
)

In [None]:
from models import (
    RegExplainerForImageClassification,
    RegExplainerForImageClassificationConfig
)

In [None]:
# Same configuration as the explainer above, but for RegExplainerForImageClassification
# (presumably the regression-trained variant whose logs/..._regression_* checkpoints
# are loaded in later cells).
regexplainer_for_image_classification_config = RegExplainerForImageClassificationConfig(
    surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
    surrogate_config=surrogate_for_image_classification_config,
    surrogate_from_tf=bool(".ckpt" in surrogate_args.surrogate_model_name_or_path),
    surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
    surrogate_revision=surrogate_args.surrogate_model_revision,
    surrogate_token=other_args.token,
    surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
    explainer_pretrained_model_name_or_path=explainer_args.explainer_model_name_or_path,
    explainer_config=explainer_config,
    explainer_from_tf=bool(".ckpt" in explainer_args.explainer_model_name_or_path),
    explainer_cache_dir=explainer_args.explainer_cache_dir,
    explainer_revision=explainer_args.explainer_model_revision,
    explainer_token=other_args.token,
    explainer_ignore_mismatched_sizes=explainer_args.explainer_ignore_mismatched_sizes,
)

regexplainer = RegExplainerForImageClassification(
    config=regexplainer_for_image_classification_config,
)


In [None]:
# MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7
# and will be removed two minor releases later. Use ``matplotlib.colormaps[name]``
# or ``matplotlib.colormaps.get_cmap(obj)`` instead.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import matplotlib as mpl

def _set_frame(ax, linewidth):
    """Remove the ticks of ``ax`` and draw its four spines with ``linewidth``."""
    ax.set_xticks([])
    ax.set_yticks([])
    for side in ["top", "bottom", "left", "right"]:
        ax.spines[side].set_linewidth(linewidth)


def _explanation_heatmap(explanation_class, grid_size=14, patch_size=16):
    """Render one class's per-patch attribution vector as an RGBA heat map.

    The vector is reshaped to ``(grid_size, grid_size)``, bilinearly upsampled
    by ``patch_size`` and mapped into [0, 1] so that 0 lands on the centre of
    seaborn's "icefire" colormap. The alpha channel is set to 0.6.
    """
    side = grid_size * patch_size
    upsampled = (
        torch.nn.functional.interpolate(
            torch.Tensor(explanation_class.reshape(1, 1, grid_size, grid_size)),
            scale_factor=patch_size,
            align_corners=False,
            mode="bilinear",
        )
        .numpy()
        .reshape(side, side)
    )
    max_abs = np.max(np.abs(upsampled))
    # An all-zero attribution would give 0/0 = NaN; show it as the neutral centre colour.
    scale = max_abs if max_abs > 0 else 1.0
    heatmap = sns.color_palette("icefire", as_cmap=True)(0.5 + upsampled / scale * 0.5)
    heatmap[:, :, 3] = 0.6
    return heatmap


def _grayscale_background(image_unnormalized):
    """Return a grey-scale RGBA rendering (alpha 0.5) of an (H, W, C) image array."""
    gray = image_unnormalized.sum(axis=2) / 3
    background = mpl.colormaps["Greys"](1 - gray)
    background[:, :, 3] = 0.5
    return background


def plot_figure(explainer, dataset, sample_idx_list, img_mean=None, img_std=None):
    """Plot input images next to per-class explanation heat maps.

    Layout: one row per entry of ``sample_idx_list``. The columns are the input
    image, a narrow blank spacer, then one column per distinct label among the
    selected samples. Each explanation panel overlays that class's attribution
    heat map on a grey-scale copy of the image.

    Args:
        explainer: Model called as ``explainer(pixel_values, return_loss=False)``.
            ``output["logits"][0]`` is either 2-D, indexed by class label along
            dim 0, or a single map; each map must hold 14 * 14 values. Class names
            come from ``explainer.explainer.config.surrogate_config["id2label"]``.
        dataset: Indexable dataset whose items provide a normalised
            ``"pixel_values"`` tensor (channels first) and a ``"labels"`` value.
        sample_idx_list: Indices of the samples to plot, one row each.
        img_mean, img_std: Per-channel statistics used to undo normalisation.
            ``None`` keeps the constants this function always used.
            NOTE(review): those look like CIFAR-10 statistics — confirm they
            match the image processor that produced ``pixel_values``.

    Returns:
        The created matplotlib figure.
    """
    import warnings

    plt.rcParams["font.size"] = 8
    if img_mean is None:
        img_mean = [0.4914, 0.4822, 0.4465]
    if img_std is None:
        img_std = [0.2023, 0.1994, 0.2010]
    channel_mean = np.asarray(img_mean, dtype=float)[:, np.newaxis, np.newaxis]
    channel_std = np.asarray(img_std, dtype=float)[:, np.newaxis, np.newaxis]
    id2label = explainer.explainer.config.surrogate_config["id2label"]

    label_choice = np.unique([dataset[sample_idx]["labels"] for sample_idx in sample_idx_list])
    class_list = {idx: label for idx, label in enumerate(label_choice)}
    class_labels = list(class_list.values())
    column_types = ["image", "empty"] + class_labels

    fig = plt.figure(
        figsize=(1.53 * (1 + len(class_labels) + 0.2), 2 * len(sample_idx_list))
    )
    outer_grid = gridspec.GridSpec(
        1,
        len(column_types),
        wspace=0.06,
        hspace=0,
        width_ratios=[1, 0.2] + [1] * len(class_labels),
    )

    # Axes are keyed "<sample_idx>_<column type>", created column by column.
    axd = {}
    for column_idx, plot_type in enumerate(column_types):
        column_grid = gridspec.GridSpecFromSubplotSpec(
            len(sample_idx_list), 1, subplot_spec=outer_grid[column_idx], wspace=0, hspace=0.2
        )
        for row_idx, sample_idx in enumerate(sample_idx_list):
            axd[f"{sample_idx}_{plot_type}"] = fig.add_subplot(column_grid[row_idx])

    # The spacer column stays blank: no ticks and invisible spines.
    for sample_idx in sample_idx_list:
        _set_frame(axd[f"{sample_idx}_empty"], linewidth=0)
    print("class_list", class_list)

    explainer.eval()
    for sample_idx in sample_idx_list:
        dataset_item = dataset[sample_idx]
        image = dataset_item["pixel_values"]
        label = dataset_item["labels"]

        image_unnormalized = ((image.numpy() * channel_std) + channel_mean).transpose(1, 2, 0)
        if not (image_unnormalized.min() > 0 and image_unnormalized.max() < 1):
            warnings.warn(
                "Un-normalised pixel values fall outside (0, 1); img_mean/img_std may not "
                "match the preprocessing. The displayed image is min-max rescaled."
            )
        value_range = image_unnormalized.max() - image_unnormalized.min()
        image_scaled = (image_unnormalized - image_unnormalized.min()) / (
            value_range if value_range > 0 else 1.0
        )

        # One forward pass per sample; every class column reuses its output.
        with torch.no_grad():
            explanation = explainer(
                image.unsqueeze(0).to(explainer.device), return_loss=False
            )["logits"][0]
        explanation = explanation.detach().cpu()
        background = _grayscale_background(image_unnormalized)

        image_ax = axd[f"{sample_idx}_image"]
        image_ax.imshow(image_scaled)
        image_ax.set_title(f"{id2label[str(label)]}", pad=7, zorder=10)
        _set_frame(image_ax, linewidth=1)

        for class_label in class_labels:
            class_ax = axd[f"{sample_idx}_{class_label}"]
            # 2-D output holds one attribution row per class; otherwise a single map.
            if explanation.dim() == 2:
                explanation_class = explanation[class_label].numpy()
            else:
                explanation_class = explanation.numpy()
            class_ax.imshow(background, alpha=0.85)
            class_ax.imshow(_explanation_heatmap(explanation_class), alpha=0.9)
            class_ax.set_title(f"{id2label[str(class_label)]}")
            _set_frame(class_ax, linewidth=1)
    return fig

In [None]:
# Test split with only the explainer's eval preprocessing (no masks are added).
# The transform is built once, and `with_transform` returns a new Dataset, so
# dataset["test"] itself is left untransformed.
test_eval_transform = get_image_transform(explainer_image_processor)["eval_transform"]
dataset["test_explainer"] = dataset["test"].with_transform(test_eval_transform)

In [None]:
# Inspect the first preprocessed test example.
dataset["test_explainer"][0]

In [None]:
# Device of the regression explainer (the inputs in the cells below are not moved to it).
regexplainer.device

In [None]:
# Sanity-check one forward pass on the first test example. `data` is otherwise
# only bound by the loop in a later cell, so define it here before use.
data = dataset["test_explainer"][0]
regexplainer(pixel_values=data["pixel_values"].unsqueeze(0), return_loss=False)["logits"][0]

In [None]:
# Forward every test example through the regression explainer in eval mode.
# NOTE(review): outputs are discarded — presumably a smoke/timing run.
for data in dataset["test_explainer"]:
    regexplainer.eval()
    with torch.no_grad():
        regexplainer(pixel_values=data["pixel_values"].unsqueeze(0), return_loss=False)
        

In [None]:
# Precomputed Shapley values from the surrogate evaluation run. Judging by the
# cells below, shapley_values_dict["test"] holds one dict per test example,
# mapping an evaluation budget (num_eval) to that example's values.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
shapley_values_dict=torch.load("logs/vitbase_imagenette_surrogate_eval/shapley.pt")

In [None]:
# Number of test examples with precomputed Shapley values.
len(shapley_values_dict["test"])

In [None]:
# Budget whose estimates serve as ground truth; every budget is compared against it.
num_eval_ground_truth = 99840

# Per budget: lists (one entry per test example) of squared error averaged over
# features, for the target class, the other classes and all classes.
mse_dict = {
    "target": {},
    "non_target": {},
    "all": {}
}

for num_eval_shapley_values in shapley_values_dict["test"]:
    ground_truth = num_eval_shapley_values[num_eval_ground_truth]
    # Target class = class with the largest total ground-truth attribution.
    target_class_idx = np.argmax(ground_truth.sum(axis=0))
    for num_eval, shapley_values in num_eval_shapley_values.items():
        diff = shapley_values - ground_truth
        # NOTE(review): mean over axis 0 here, while the per-sample cell below
        # uses sum over axis 0 — confirm which reduction is intended.
        mse_class = (diff * diff).mean(axis=0)
        is_target = np.arange(len(mse_class)) == target_class_idx

        mse_dict["target"].setdefault(num_eval, []).append(mse_class[is_target].mean())
        mse_dict["non_target"].setdefault(num_eval, []).append(mse_class[~is_target].mean())
        mse_dict["all"].setdefault(num_eval, []).append(mse_class.mean())

# per_sample

In [None]:
# Budget whose estimates serve as ground truth.
num_eval_ground_truth=99840

# One record per (test example, budget) for the per-sample estimator; the
# regression cells below append records with the same keys.
record_dict_list=[]

for sample_idx, num_eval_shapley_values in enumerate(shapley_values_dict["test"]):
    # Target class = class with the largest total ground-truth attribution.
    target_class_idx=np.argmax(num_eval_shapley_values[num_eval_ground_truth].sum(axis=0))
    for num_eval, shapley_values in num_eval_shapley_values.items():
        diff=(shapley_values-num_eval_shapley_values[num_eval_ground_truth])
        # Squared error summed over axis 0 (the mse_dict cell above averages instead).
        mse_class=(diff*diff).sum(axis=0)
        
        record_dict_list.append({
            "sample_idx": sample_idx,
            "mse_target": mse_class[np.arange(len(mse_class))==target_class_idx].mean(),
            "mse_nontarget": mse_class[np.arange(len(mse_class))!=target_class_idx].mean(),
            "mse_all": mse_class[:].mean(),
            "num_eval":num_eval,
            "method": "per-sample",
        })

In [None]:
# Shape of the last `shapley_values` array left over from the per-sample loop.
shapley_values.shape

In [None]:
# Ground-truth budget expressed in multiples of 512 (the step between regression run budgets).
99840/512

In [None]:
# NOTE(review): `shapley_estimated` is only assigned by the regression evaluation
# loop further down; running the notebook top to bottom raises NameError here.
shapley_estimated.shape

# regression

In [None]:
import glob
# Plot explanations of the final weights of each regression run for eight test
# examples. The budget in the title is the run-name suffix plus 512.
# NOTE(review): `glob` is unused in this cell.
for model_path_reg in ['logs/vitbase_imagenette_explainer_regression_0',
                       'logs/vitbase_imagenette_explainer_regression_512',                       
                       'logs/vitbase_imagenette_explainer_regression_1024',
                       'logs/vitbase_imagenette_explainer_regression_1536']:
    num_eval=int(model_path_reg.split('_')[-1])+512
    state_dict = torch.load(f"{model_path_reg}/pytorch_model.bin", map_location="cpu")
    regexplainer.load_state_dict(state_dict)
    fig=plot_figure(explainer=regexplainer, 
                dataset=dataset["test_explainer"], 
                sample_idx_list=[0,  10, 20, 30, 40, 50, 60, 70])
    fig.suptitle(f"Reg-AO {num_eval}")    

In [None]:
import glob

# Score each regression-trained explainer (its checkpoint-1480 weights)
# against the ground-truth Shapley values, one record per test sample.
for model_path_reg in ['logs/vitbase_imagenette_explainer_regression_0',
                       'logs/vitbase_imagenette_explainer_regression_1024',
                       'logs/vitbase_imagenette_explainer_regression_512',
                       'logs/vitbase_imagenette_explainer_regression_1536']:
    # Run-name suffix + 512 is the evaluation count this run is credited with.
    num_eval = int(model_path_reg.split('_')[-1]) + 512
    state_dict = torch.load(f"{model_path_reg}/checkpoint-1480/pytorch_model.bin", map_location="cpu")
    regexplainer.load_state_dict(state_dict)
    regexplainer.eval()
    # Inputs go to the model's device and outputs come back to the CPU, so
    # this works whether the explainer sits on the CPU or a GPU.
    regexplainer_device = next(regexplainer.parameters()).device
    for sample_idx, (num_eval_shapley_values, data) in enumerate(zip(shapley_values_dict["test"], dataset["test_explainer"])):
        ground_truth = num_eval_shapley_values[num_eval_ground_truth]
        target_class_idx = np.argmax(ground_truth.sum(axis=0))
        with torch.no_grad():
            shapley_estimated = regexplainer(
                pixel_values=data["pixel_values"].unsqueeze(0).to(regexplainer_device),
                return_loss=False,
            )["logits"][0]
        # Transposed to line up with the ground-truth layout (presumably
        # logits are (num_classes, num_players) — confirm).
        diff = shapley_estimated.T.detach().cpu().numpy() - ground_truth
        mse_class = (diff * diff).sum(axis=0)
        non_target = np.arange(len(mse_class)) != target_class_idx
        record_dict_list.append({
            "sample_idx": sample_idx,
            "mse_target": mse_class[target_class_idx],
            "mse_nontarget": mse_class[non_target].mean(),
            "mse_all": mse_class.mean(),
            "num_eval": num_eval,
            "method": "regression_AO",
        })

In [None]:
import tqdm

In [None]:
import glob

# Assumed constants of the objective run — confirm against its config:
# 148 appears to be optimizer steps per epoch and 32 the model evaluations
# credited per sample per epoch.
steps_per_epoch = 148
evals_per_epoch = 32


def _checkpoint_step(path):
    """Return the step number encoded in a ``.../checkpoint-<step>`` path."""
    return int(path.split('-')[-1])


for model_path_obj in sorted(glob.glob("logs/vitbase_imagenette_explainer_objective/checkpoint-*"), key=_checkpoint_step):
    step = _checkpoint_step(model_path_obj)
    num_eval = step / steps_per_epoch * evals_per_epoch
    # Only epochs 1, 11, 21, ... are evaluated to keep this loop affordable.
    if (step // steps_per_epoch) % 10 != 1:
        continue
    state_dict = torch.load(f"{model_path_obj}/pytorch_model.bin", map_location="cpu")
    explainer.load_state_dict(state_dict)
    explainer.eval()
    for sample_idx, (num_eval_shapley_values, data) in enumerate(zip(tqdm.tqdm(shapley_values_dict["test"]), dataset["test_explainer"])):
        ground_truth = num_eval_shapley_values[num_eval_ground_truth]
        target_class_idx = np.argmax(ground_truth.sum(axis=0))
        with torch.no_grad():
            shapley_estimated = explainer(
                pixel_values=data["pixel_values"].unsqueeze(0).to(explainer.device),
                return_loss=False,
            )["logits"][0]
        diff = shapley_estimated.T.cpu().detach().numpy() - ground_truth
        mse_class = (diff * diff).sum(axis=0)
        non_target = np.arange(len(mse_class)) != target_class_idx
        record_dict_list.append({
            "sample_idx": sample_idx,
            "mse_target": mse_class[target_class_idx],
            "mse_nontarget": mse_class[non_target].mean(),
            "mse_all": mse_class.mean(),
            "num_eval": num_eval,
            "method": "objective_AO",
        })

In [None]:
# Epoch-filter value (epoch index mod 10) for the last checkpoint visited by
# the objective loop; 148 is presumably the steps-per-epoch count.
int(int(model_path_obj.split('-')[-1])/148)%10

In [None]:
# Last checkpoint path visited by the objective loop.
model_path_obj

In [None]:
# Sanity check of the step -> epoch conversion.
444/148

In [None]:
# Device the objective explainer currently lives on.
explainer.device

In [None]:
import pandas as pd

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import font_manager
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# List installed TrueType fonts (result unused) and resolve "Arial" through
# matplotlib's font lookup; "Special Elite" is another family worth trying.
font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial")
# Render subsequent text in the sans-serif family, preferring Arial.
plt.rcParams.update({"font.family": "sans-serif", "font.sans-serif": "Arial"})

In [None]:
# Target-class error vs. number of model evaluations, one line per method.
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
axd = {"main": ax}
plot_key = "main"

has_budget = record_dict_list_df["num_eval"] > 0
sns.lineplot(data=record_dict_list_df[has_budget], x="num_eval", y="mse_target",
             hue="method", ax=ax)

main_ax = axd[plot_key]
# NOTE(review): mse_* values are sums of squared differences (squared L2).
main_ax.set_ylabel("L2 distance", fontsize=20)
main_ax.set_xlabel("# model evaluations", fontsize=20)

# (major, minor) tick spacing per axis, with a strong major / faint minor grid.
for axis_obj, (major_step, minor_step) in ((main_ax.yaxis, (0.5, 0.1)),
                                           (main_ax.xaxis, (50000, 10000))):
    axis_obj.set_major_locator(MultipleLocator(major_step))
    axis_obj.set_minor_locator(MultipleLocator(minor_step))
    axis_obj.grid(True, which="major", linewidth=2, alpha=0.6)
    axis_obj.grid(True, which="minor", linewidth=1, alpha=0.1)

main_ax.tick_params(axis="both", which="major", labelsize=20)
for side in ("right", "top"):
    main_ax.spines[side].set_visible(False)

main_ax.set_title("Target", fontsize=20)

# main_ax.set_yscale('log')

In [None]:
import pandas as pd

In [None]:
# Non-target-class error vs. number of model evaluations, one line per method.
# NOTE(review): unlike the neighbouring plots, rows with num_eval <= 0 are kept.
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
axd = {"main": ax}
plot_key = "main"

sns.lineplot(data=record_dict_list_df, x="num_eval", y="mse_nontarget",
             hue="method", ax=ax)

main_ax = axd[plot_key]
# NOTE(review): mse_* values are sums of squared differences (squared L2).
main_ax.set_ylabel("L2 distance", fontsize=20)
main_ax.set_xlabel("# model evaluations", fontsize=20)

# (major, minor) tick spacing per axis, with a strong major / faint minor grid.
for axis_obj, (major_step, minor_step) in ((main_ax.yaxis, (0.005, 0.001)),
                                           (main_ax.xaxis, (50000, 10000))):
    axis_obj.set_major_locator(MultipleLocator(major_step))
    axis_obj.set_minor_locator(MultipleLocator(minor_step))
    axis_obj.grid(True, which="major", linewidth=2, alpha=0.6)
    axis_obj.grid(True, which="minor", linewidth=1, alpha=0.1)

main_ax.tick_params(axis="both", which="major", labelsize=20)
for side in ("right", "top"):
    main_ax.spines[side].set_visible(False)

main_ax.set_title("Non-target", fontsize=20)

# main_ax.set_yscale('log')

In [None]:
# Non-target-class error vs. number of model evaluations, one line per method.
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
axd = {"main": ax}
plot_key = "main"

has_budget = record_dict_list_df["num_eval"] > 0
sns.lineplot(data=record_dict_list_df[has_budget], x="num_eval", y="mse_nontarget",
             hue="method", ax=ax)

main_ax = axd[plot_key]
# NOTE(review): mse_* values are sums of squared differences (squared L2).
main_ax.set_ylabel("L2 distance", fontsize=20)
main_ax.set_xlabel("# model evaluations", fontsize=20)

# (major, minor) tick spacing per axis, with a strong major / faint minor grid.
for axis_obj, (major_step, minor_step) in ((main_ax.yaxis, (0.005, 0.001)),
                                           (main_ax.xaxis, (50000, 10000))):
    axis_obj.set_major_locator(MultipleLocator(major_step))
    axis_obj.set_minor_locator(MultipleLocator(minor_step))
    axis_obj.grid(True, which="major", linewidth=2, alpha=0.6)
    axis_obj.grid(True, which="minor", linewidth=1, alpha=0.1)

main_ax.tick_params(axis="both", which="major", labelsize=20)
for side in ("right", "top"):
    main_ax.spines[side].set_visible(False)

main_ax.set_title("Non-target", fontsize=20)

# main_ax.set_yscale('log')

In [None]:
# Target-class error, zoomed to the low-budget region (<= 3500 evaluations).
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
axd = {"main": ax}
plot_key = "main"

has_budget = record_dict_list_df["num_eval"] > 0
sns.lineplot(data=record_dict_list_df[has_budget], x="num_eval", y="mse_target",
             hue="method", ax=ax)

main_ax = axd[plot_key]
# NOTE(review): mse_* values are sums of squared differences (squared L2).
main_ax.set_ylabel("L2 distance", fontsize=20)
main_ax.set_xlabel("# model evaluations", fontsize=20)

# (major, minor) tick spacing per axis, with a strong major / faint minor grid.
for axis_obj, (major_step, minor_step) in ((main_ax.yaxis, (0.05, 0.01)),
                                           (main_ax.xaxis, (1000, 500))):
    axis_obj.set_major_locator(MultipleLocator(major_step))
    axis_obj.set_minor_locator(MultipleLocator(minor_step))
    axis_obj.grid(True, which="major", linewidth=2, alpha=0.6)
    axis_obj.grid(True, which="minor", linewidth=1, alpha=0.1)

main_ax.tick_params(axis="both", which="major", labelsize=20)
for side in ("right", "top"):
    main_ax.spines[side].set_visible(False)

main_ax.set_xlim(0, 3500)
main_ax.set_ylim(0, 0.1)

main_ax.set_title("Target", fontsize=20)

# main_ax.set_yscale('log')

In [None]:
# Non-target-class error, zoomed to the low-budget region (<= 3500 evaluations).
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
axd = {"main": ax}
plot_key = "main"

has_budget = record_dict_list_df["num_eval"] > 0
sns.lineplot(data=record_dict_list_df[has_budget], x="num_eval", y="mse_nontarget",
             hue="method", ax=ax)

main_ax = axd[plot_key]
# NOTE(review): mse_* values are sums of squared differences (squared L2).
main_ax.set_ylabel("L2 distance", fontsize=20)
main_ax.set_xlabel("# model evaluations", fontsize=20)

# (major, minor) tick spacing per axis, with a strong major / faint minor grid.
for axis_obj, (major_step, minor_step) in ((main_ax.yaxis, (0.005, 0.001)),
                                           (main_ax.xaxis, (1000, 500))):
    axis_obj.set_major_locator(MultipleLocator(major_step))
    axis_obj.set_minor_locator(MultipleLocator(minor_step))
    axis_obj.grid(True, which="major", linewidth=2, alpha=0.6)
    axis_obj.grid(True, which="minor", linewidth=1, alpha=0.1)

main_ax.tick_params(axis="both", which="major", labelsize=20)
for side in ("right", "top"):
    main_ax.spines[side].set_visible(False)

main_ax.set_xlim(0, 3500)
main_ax.set_ylim(0, 0.01)

main_ax.set_title("Non-target", fontsize=20)

# main_ax.set_yscale('log')

In [None]:
# Budget 500 expressed in units of 32 (presumably evaluations per epoch in
# the objective run — confirm).
500/32

In [None]:
# Load a regression run's checkpoint into `explainer`, replacing its weights.
state_dict = torch.load("logs/vitbase_imagenette_explainer_regression_0/checkpoint-1480/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_explainer_regression/checkpoint-1480/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)

In [None]:
# NOTE(review): the input is not moved to `explainer.device`; this fails if
# the explainer has already been moved to a GPU.
explainer_out=explainer.forward(pixel_values=dataset["validation_explainer"][0]['pixel_values'].unsqueeze(0),
                               return_loss=False)

In [None]:
import tqdm

In [None]:
# IPython shell escape (not Python): show GPU usage.
!gpustat

In [None]:
# NOTE(review): an earlier cell sets CUDA_VISIBLE_DEVICES="4", which leaves a
# single visible GPU (cuda:0); "cuda:7" would then be an invalid ordinal.
device="cuda:7"

In [None]:
# `surrogate_null` is moved separately, presumably because it is not a
# registered parameter/buffer that `.to()` would move — confirm.
explainer.to(device)
explainer.surrogate_null=explainer.surrogate_null.to(device)

In [None]:
plot_figure(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70])

In [None]:
plot_figure(explainer, dataset["validation_explainer"], [0,  250, 500,1000])

In [None]:
dataset["validation"][0]["image"]

In [None]:
dataset["test"][0]["image"]

In [None]:
# NOTE(review): incomplete statement — this cell raises SyntaxError if run.
dataset["test"] = (

In [None]:
data_args.max_test_samples

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from matplotlib import cm

def _attribution_heatmap(explanation_class, height, width):
    """Turn one class's per-patch attributions into an RGBA heatmap.

    The attributions are reshaped into a square patch grid, bilinearly
    upsampled to ``(height, width)``, rescaled so that 0 maps to 0.5 and the
    largest magnitude maps to 0 or 1, and coloured with seaborn's "icefire"
    colormap with alpha 0.6.
    """
    grid_side = int(round(np.sqrt(len(explanation_class))))
    patch_grid = torch.Tensor(explanation_class.reshape(1, 1, grid_side, grid_side))
    upsampled = torch.nn.functional.interpolate(
        patch_grid, size=(height, width), mode='bilinear', align_corners=False,
    ).numpy().reshape(height, width)

    max_abs = np.max(np.abs(upsampled))
    if max_abs > 0:
        normalized = 0.5 + upsampled / max_abs * 0.5
    else:
        # All-zero attributions: avoid 0/0 (NaN) and draw a neutral map.
        normalized = np.full_like(upsampled, 0.5)

    heatmap = sns.color_palette("icefire", as_cmap=True)(normalized)
    heatmap[:, :, 3] = 0.6
    return heatmap


def plot_figure_shapley(explainer, dataset, sample_idx_list, shapley_value, shapley_value_key):
    """Plot precomputed Shapley values as heatmaps over their input images.

    One row per entry of ``sample_idx_list``: the un-normalised input image, a
    narrow spacer column, then one heatmap per distinct label among the
    selected samples.

    Args:
        explainer: only used for
            ``explainer.explainer.config.surrogate_config['id2label']`` to turn
            class ids into titles.
        dataset: indexable dataset whose items have ``"pixel_values"`` (a
            (C, H, W) tensor) and ``"labels"`` (a class id).
        sample_idx_list: indices into ``dataset`` and ``shapley_value``.
        shapley_value: per-sample mapping ``{key: array}``; each array is
            indexed ``[:, class_id]`` to get one attribution per patch (the
            patch count must be a perfect square).
        shapley_value_key: which entry of ``shapley_value[sample_idx]`` to plot
            (e.g. an evaluation budget such as 99840).

    Side effects: sets ``plt.rcParams["font.size"]`` to 8, prints the class
    mapping, and draws a new figure. Returns None.
    """
    plt.rcParams["font.size"] = 8
    # NOTE(review): these are the CIFAR-10 channel statistics; confirm they
    # match the normalisation used to produce ``pixel_values``.
    img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]
    img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]
    id2label = explainer.explainer.config.surrogate_config['id2label']

    # One heatmap column per distinct label among the selected samples.
    label_choice = np.unique([dataset[sample_idx]["labels"] for sample_idx in sample_idx_list])
    class_list = {idx: label for idx, label in enumerate(label_choice)}
    # Column layout: input image, narrow spacer, then one column per class.
    column_types = ["image", "empty"] + list(class_list.values())

    fig = plt.figure(figsize=(1.53 * (1 + len(class_list) + 0.2), 2 * len(sample_idx_list)))
    box1 = gridspec.GridSpec(1, len(column_types), wspace=0.06, hspace=0,
                             width_ratios=[1, 0.2] + [1] * len(class_list))

    axd = {}
    for col_idx, plot_type in enumerate(column_types):
        column_spec = gridspec.GridSpecFromSubplotSpec(
            len(sample_idx_list), 1, subplot_spec=box1[col_idx], wspace=0, hspace=0.2)
        for row_idx, sample_idx in enumerate(sample_idx_list):
            cell_spec = gridspec.GridSpecFromSubplotSpec(
                1, 1, subplot_spec=column_spec[row_idx], wspace=0, hspace=0)
            axd[f"{sample_idx}_{plot_type}"] = fig.add_subplot(cell_spec[0])

    # Spacer axes carry no content: hide their ticks and frame.
    for plot_key, ax in axd.items():
        if 'empty' in plot_key:
            ax.set_xticks([])
            ax.set_yticks([])
            for side in ['top', 'bottom', 'left', 'right']:
                ax.spines[side].set_linewidth(0)
    print("class_list", class_list)

    for sample_idx in sample_idx_list:
        dataset_item = dataset[sample_idx]
        image = dataset_item["pixel_values"]
        label = dataset_item["labels"]

        # Undo the channel normalisation; channels-last for imshow.
        image_unnormlized = ((image.numpy() * img_std) + img_mean).transpose(1, 2, 0)
        assert image_unnormlized.min() > 0 and image_unnormlized.max() < 1
        pixel_min, pixel_max = image_unnormlized.min(), image_unnormlized.max()
        image_unnormlized_scaled = (image_unnormlized - pixel_min) / (pixel_max - pixel_min)
        height, width = image_unnormlized.shape[:2]

        # Semi-transparent greyscale backdrop, shared by this sample's heatmaps.
        image_gray = image_unnormlized.sum(axis=2) / 3
        image_backdrop = plt.get_cmap('Greys', 1000)(1 - image_gray)
        image_backdrop[:, :, 3] = 0.5

        for plot_type in column_types:
            if plot_type == "empty":
                continue  # spacer axes were already styled above
            if plot_type == "image":
                plot_key = f"{sample_idx}_image"
                axd[plot_key].imshow(image_unnormlized_scaled)
                axd[plot_key].set_title(f"{id2label[str(label)]}", pad=7, zorder=10)
            else:
                plot_key = f"{sample_idx}_{plot_type}"
                explanation_class = shapley_value[sample_idx][shapley_value_key][:, plot_type]
                heatmap = _attribution_heatmap(explanation_class, height, width)
                axd[plot_key].imshow(image_backdrop, alpha=0.85)
                axd[plot_key].imshow(heatmap, alpha=0.9)
                axd[plot_key].set_title(f"{id2label[str(plot_type)]}")

            axd[plot_key].set_xticks([])
            axd[plot_key].set_yticks([])
            for side in ['top', 'bottom', 'left', 'right']:
                axd[plot_key].spines[side].set_linewidth(1)
                    

In [None]:
# Ground-truth (99840-evaluation) Shapley values for eight test samples.
# NOTE(review): `shapley_values_test` is loaded by a cell further down; run it
# first in a fresh kernel.
plot_figure_shapley(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_values_test["test"], 99840)

In [None]:
# 5120 = 10 x 512 evaluations.
512*10

In [None]:
plot_figure_shapley(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_values_test["test"], 5120)

In [None]:
plot_figure_shapley(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_values_test["test"], 1536)

In [None]:
# Largest evaluation budget stored for the first test sample.
max(shapley_values_test["test"][0].keys())

In [None]:
99840+512

In [None]:
plot_figure_shapley(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_values_test["test"], 99840)

In [None]:
plot_figure_shapley(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_values_test["test"], 1536)

In [None]:
# NOTE(review): `shapley_values` is loaded further down, and the per-sample MSE
# loop rebinds that name to an array; reload it before running this cell.
plot_figure_shapley(explainer, dataset["validation_explainer"], [0,  250, 500,1000],
                    shapley_values["validation"], 1536)

In [None]:
plot_figure(explainer, dataset["test_explainer"], [0,  10,20,30])

In [None]:
# NOTE(review): uses dataset["test"] rather than "test_explainer"; confirm its
# items carry the transformed `pixel_values` that plot_figure needs.
plot_figure(explainer, dataset["test"], [0,  250, 500,1000])

In [None]:
# torch.load unpickles the file, so only load trusted files.
# NOTE(review): newer torch versions default to weights_only=True, which may
# reject non-tensor contents — confirm with the installed version.
shapley_values = torch.load(
    "logs/vitbase_imagenette_surrogate_eval/shapley_train_val.pt",
    map_location="cpu",
)

In [None]:
shapley_values_test = torch.load(
    "logs/vitbase_imagenette_surrogate_eval/shapley.pt",
    map_location="cpu",
)

In [None]:
shapley_values_test.keys()

In [None]:
########################################################
# Initialize the explainer trainer
########################################################
# Load the accuracy metric via the `evaluate` package.
# NOTE(review): `compute_metrics` below returns {} and never uses `metric`.
metric = evaluate.load("accuracy")

# Metrics hook for the HF Trainer: it is handed an `EvalPrediction`
# (a namedtuple of `predictions` and `label_ids`) and must return a dict
# mapping metric names to floats.
def compute_metrics(p):
    """Report no evaluation metrics for the explainer.

    The accuracy computation is disabled, so ``p`` is ignored and a fresh
    empty dict is returned on every call.
    """
    no_metrics = dict()
    return no_metrics

def collate_fn(examples):
    """Stack a list of explainer examples into one batch dict.

    ``pixel_values`` are stacked with ``torch.stack``; ``labels`` become a
    tensor; ``masks`` are first combined into one numpy array and then
    converted to a tensor.
    """
    def _column(key):
        # Gather one field across all examples, preserving order.
        return [example[key] for example in examples]

    return {
        "pixel_values": torch.stack(_column("pixel_values")),
        "labels": torch.tensor(_column("labels")),
        "masks": torch.tensor(np.array(_column("masks"))),
    }

# Trainer wiring: each split is attached only when that phase is enabled.
explainer_trainer = Trainer(
    model=explainer,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=explainer_image_processor,
    train_dataset=dataset["train_explainer"] if training_args.do_train else None,
    eval_dataset=dataset["validation_explainer"] if training_args.do_eval else None,
)

# ---- Resume detection ---------------------------------------------------
# Pick up the newest checkpoint in output_dir unless overwriting was
# requested; refuse to train into a non-empty directory with no checkpoint.
last_checkpoint = None
if training_args.do_train and not training_args.overwrite_output_dir and os.path.isdir(training_args.output_dir):
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None:
        if os.listdir(training_args.output_dir):
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
    elif training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )

# ---- Training -------------------------------------------------------------
if training_args.do_train:
    # An explicit resume path wins; otherwise fall back to the detected one.
    checkpoint = (
        training_args.resume_from_checkpoint
        if training_args.resume_from_checkpoint is not None
        else last_checkpoint
    )
    train_result = explainer_trainer.train(resume_from_checkpoint=checkpoint)
    explainer_trainer.save_model()
    for report in (explainer_trainer.log_metrics, explainer_trainer.save_metrics):
        report("train", train_result.metrics)
    explainer_trainer.save_state()

# ---- Evaluation -----------------------------------------------------------
if training_args.do_eval:
    metrics = explainer_trainer.evaluate()
    for report in (explainer_trainer.log_metrics, explainer_trainer.save_metrics):
        report("eval", metrics)

# ---- Model card / Hub -----------------------------------------------------
kwargs = dict(
    finetuned_from=explainer_args.explainer_model_name_or_path,
    tasks="image-classification",
    dataset=data_args.dataset_name,
    tags=["image-classification", "vision"],
)
if not training_args.push_to_hub:
    explainer_trainer.create_model_card(**kwargs)
else:
    explainer_trainer.push_to_hub(**kwargs)

In [None]:
    
    
    # NOTE(review): this cell is indented by 4 spaces (apparently copied from
    # the body of `main()` in the training script) and repeats the previous
    # cell. As plain Python it would raise IndentationError; it presumably runs
    # only because IPython strips a cell's common leading indentation — confirm.
    ########################################################
    # Initialize the explainer trainer
    ########################################################
    # Load the accuracy metric via the `evaluate` package (unused below).
    metric = evaluate.load("accuracy")

    # Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p):
        """Return an empty metrics dict; the accuracy computation is commented out."""
        # import ipdb

        # ipdb.set_trace()
        # print(p.predictions.shape, p.label_ids.shape)
        # return metric.compute(
        #     predictions=np.argmax(p.predictions[:, 0, :], axis=1),
        #     references=p.label_ids,
        # )
        return {}

    # Stack per-example pixel_values / labels / masks into batch tensors.
    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example["labels"] for example in examples])
        masks = torch.tensor(np.array([example["masks"] for example in examples]))

        return {
            "pixel_values": pixel_values,
            "labels": labels,
            "masks": masks,
        }

    explainer_trainer = Trainer(
        model=explainer,
        args=training_args,
        train_dataset=dataset["train_explainer"] if training_args.do_train else None,
        eval_dataset=dataset["validation_explainer"] if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=explainer_image_processor,
        data_collator=collate_fn,
    )

    # ipdb.set_trace()
    # print("explainer_trainer.label_names", explainer_trainer.label_names)
    # print(explainer_trainer.evaluate(dataset["validation_explainer"]))

    ########################################################
    # Detecting last checkpoint
    #######################################################
    # Resume from the newest checkpoint in output_dir unless overwriting;
    # refuse to train into a non-empty directory without a checkpoint.
    last_checkpoint = None
    if (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif (
            last_checkpoint is not None and training_args.resume_from_checkpoint is None
        ):
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    ########################################################
    # Training
    #######################################################
    if training_args.do_train:
        # An explicit resume path takes precedence over the detected checkpoint.
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = explainer_trainer.train(resume_from_checkpoint=checkpoint)
        explainer_trainer.save_model()
        explainer_trainer.log_metrics("train", train_result.metrics)
        explainer_trainer.save_metrics("train", train_result.metrics)
        explainer_trainer.save_state()

    ########################################################
    # Evaluation
    #######################################################
    if training_args.do_eval:
        metrics = explainer_trainer.evaluate()
        explainer_trainer.log_metrics("eval", metrics)
        explainer_trainer.save_metrics("eval", metrics)

    ########################################################
    # Write model card and (optionally) push to hub
    #######################################################
    kwargs = {
        "finetuned_from": explainer_args.explainer_model_name_or_path,
        "tasks": "image-classification",
        "dataset": data_args.dataset_name,
        "tags": ["image-classification", "vision"],
    }
    if training_args.push_to_hub:
        explainer_trainer.push_to_hub(**kwargs)
    else:
        explainer_trainer.create_model_card(**kwargs)


# NOTE(review): `main` is not defined in any visible cell. Inside a notebook
# `__name__` is "__main__", so this presumably raises NameError unless `main`
# was defined elsewhere.
if __name__ == "__main__":
    main()