In [None]:
import os
import sys

# Move the working directory one level up so that the relative paths used later
# in this notebook ("configs/...", "logs/...") are resolved from the parent directory.
# NOTE(review): presumably the notebook lives one level below the repository root -- confirm.
os.chdir('../')

In [None]:
!gpustat

In [None]:
import os
# Expose only the device listed here to CUDA. It has to be set before CUDA is
# initialised; with a single visible device, "cuda:0" later in the notebook
# refers to this device.
os.environ['CUDA_VISIBLE_DEVICES']="3"

In [None]:
# Emulate a command-line invocation: the argument-parsing cell below inspects
# sys.argv and, when it holds exactly one extra entry ending in ".json", reads
# every argument group from that JSON file.
sys.argv=["train_objexplainer.py", "configs/vitbase_imagenette_shapley_objexplainer_newsample_32.json"]

In [None]:
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

import evaluate
import ipdb
import numpy as np
import torch
import transformers
from datasets import load_dataset
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformers import (
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    AutoConfig,
    AutoImageProcessor,
    AutoModelForImageClassification,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

from arguments import DataTrainingArguments, ExplainerArguments, SurrogateArguments
from models import (
    ObjExplainerForImageClassification,
    ObjExplainerForImageClassificationConfig,
    SurrogateForImageClassificationConfig,
)
from utils import (
    MaskDataset,
    configure_dataset,
    generate_mask,
    get_checkpoint,
    get_image_transform,
    load_shapley,
    log_dataset,
    read_eval_results,
    setup_dataset,
)

""" Fine-tuning a 🤗 Transformers model for image classification"""

# Logger named after the current module; configured further down once the
# training arguments are known.
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.32.0.dev0")

# Same kind of guard for the `datasets` package; the second argument is the
# hint text handed to require_version alongside the requirement.
require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/image-classification/requirements.txt",
)

# Config classes registered for image classification and their `model_type` strings.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class OtherArguments:
    """Miscellaneous options parsed together with the model, data and training arguments.

    Every field is optional. The ``*_subsets_cache_path`` fields default to ``None``
    and the ``*_mask_mode`` fields default to ``"incremental,1"``. How these values
    are interpreted is decided by the code that receives this object (it is passed
    to ``setup_dataset`` as ``other_args``).
    """

    # Bearer token for remote files; None defers to the locally stored login token.
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )

    # Cache locations, one per split.
    # NOTE(review): the three help texts are identical copies -- confirm the intended wording.
    train_subsets_cache_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where to load the downloaded dataset.",
        },
    )
    validation_subsets_cache_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where to load the downloaded dataset.",
        },
    )
    test_subsets_cache_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where to load the downloaded dataset.",
        },
    )

    # Mask-generation mode, one per split.
    train_mask_mode: str = field(
        default="incremental,1",
        metadata={
            "help": "mask mode for train",
        },
    )

    validation_mask_mode: str = field(
        default="incremental,1",
        metadata={
            "help": "mask mode for validation",
        },
    )

    test_mask_mode: str = field(
        default="incremental,1",
        metadata={
            "help": "mask mode for test",
        },
    )

In [None]:
########################################################
# Parse arguments
#######################################################
# Each argument group is a dataclass; the parser turns them into one combined
# command-line interface (see src/transformers/training_args.py or pass --help
# for the full list of training options).
argument_groups = (
    SurrogateArguments,
    ExplainerArguments,
    DataTrainingArguments,
    TrainingArguments,
    OtherArguments,
)
parser = HfArgumentParser(argument_groups)

cli_args = sys.argv[1:]
if len(cli_args) == 1 and cli_args[0].endswith(".json"):
    # A single argument that is a path to a JSON file: read every group from it.
    parsed_groups = parser.parse_json_file(json_file=os.path.abspath(cli_args[0]))
else:
    # Otherwise parse the regular command line.
    parsed_groups = parser.parse_args_into_dataclasses()

# One parsed object per argument group, in declaration order.
(surrogate_args, explainer_args, data_args, training_args, other_args) = parsed_groups

########################################################
# Setup logging
#######################################################
# Send log records to stdout with a timestamped format.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

if training_args.should_log:
    # The default of training_args.log_level is passive, so we set log level at info here to have that default.
    transformers.utils.logging.set_verbosity_info()

# Apply the per-process log level to this module's logger and to transformers.
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log on each process the small summary. The two halves of the message are
# joined with ", " so that the n_gpu value and "distributed training" do not
# run into each other.
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

########################################################
# Correct cache dir if necessary
########################################################
# The configured cache dir is split into "<sep><first component>" and the
# remainder. When that leading directory does not exist on this machine, the
# remainder is re-rooted under the first fallback root that does exist.
original_cache_dir = data_args.dataset_cache_dir
cache_dir_parts = original_cache_dir.split(os.sep, 2)

if not os.path.exists(os.sep.join(cache_dir_parts[:2])):
    for fallback_root in ("/data2", "/sdata"):
        if os.path.exists(fallback_root):
            data_args.dataset_cache_dir = os.sep.join(
                [fallback_root] + cache_dir_parts[2:]
            )
            # Report the path that was missing as well as its replacement
            # (the original path is kept in a separate variable for this).
            logger.info(
                f"dataset_cache_dir {original_cache_dir} not found, using {data_args.dataset_cache_dir}"
            )
            break
    else:
        # No fallback root exists either.
        raise ValueError(
            f"dataset_cache_dir {original_cache_dir} not found"
        )

########################################################
# Set seed before initializing model.
########################################################
# Seed value comes from the parsed training arguments.
set_seed(training_args.seed)

########################################################
# Initialize our dataset and prepare it for the 'image-classification' task.
########################################################
# setup_dataset returns the dataset together with the label list and both
# label<->id mappings; these are reused below to size the model heads and,
# later in the notebook, to title plots.
dataset_original, labels, label2id, id2label = setup_dataset(
    data_args=data_args, other_args=other_args
)

########################################################
# Initialize explainer model
########################################################

explainer_config = AutoConfig.from_pretrained(
    explainer_args.explainer_config_name
    or explainer_args.explainer_model_name_or_path,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    finetuning_task="image-classification",
    cache_dir=explainer_args.explainer_cache_dir,
    revision=explainer_args.explainer_model_revision,
    token=other_args.token,
)

# Decide whether `explainer_model_name_or_path` is a local directory holding a
# previously saved ObjExplainer: its config.json then lists that class as the
# first architecture. The file is read inside a `with` block so the handle is
# closed, and a config without an "architectures" entry is treated as
# "not an ObjExplainer checkpoint" instead of raising KeyError.
explainer_config_file = f"{explainer_args.explainer_model_name_or_path}/config.json"
saved_architectures = []
if os.path.isfile(explainer_config_file):
    with open(explainer_config_file, encoding="utf-8") as config_fp:
        saved_architectures = json.load(config_fp).get("architectures") or []

if saved_architectures[:1] == ["ObjExplainerForImageClassification"]:
    # Resume from the saved ObjExplainer checkpoint.
    # NOTE(review): `surrogate_for_image_classification_config` is only created
    # in the `else` branch, yet the RegExplainer cell further down reads it --
    # that cell raises NameError when this branch is taken.
    explainer = ObjExplainerForImageClassification.from_pretrained(
        explainer_args.explainer_model_name_or_path,
        from_tf=bool(".ckpt" in explainer_args.explainer_model_name_or_path),
        config=explainer_config,
        cache_dir=explainer_args.explainer_cache_dir,
        revision=explainer_args.explainer_model_revision,
        token=other_args.token,
        ignore_mismatched_sizes=explainer_args.explainer_ignore_mismatched_sizes,
    )
else:
    # Build a fresh ObjExplainer from a surrogate and an explainer backbone.
    surrogate_config = AutoConfig.from_pretrained(
        surrogate_args.surrogate_config_name
        or surrogate_args.surrogate_model_name_or_path,
        num_labels=len(labels),
        label2id=label2id,
        id2label=id2label,
        finetuning_task="image-classification",
        cache_dir=surrogate_args.surrogate_cache_dir,
        revision=surrogate_args.surrogate_model_revision,
        token=other_args.token,
    )
    surrogate_for_image_classification_config = SurrogateForImageClassificationConfig(
        surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
        surrogate_config=surrogate_config,
        surrogate_from_tf=bool(
            ".ckpt" in surrogate_args.surrogate_model_name_or_path
        ),
        surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
        surrogate_revision=surrogate_args.surrogate_model_revision,
        surrogate_token=other_args.token,
        surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
    )

    explainer_for_image_classification_config = ObjExplainerForImageClassificationConfig(
        surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
        surrogate_config=surrogate_for_image_classification_config,
        surrogate_from_tf=bool(
            ".ckpt" in surrogate_args.surrogate_model_name_or_path
        ),
        surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
        surrogate_revision=surrogate_args.surrogate_model_revision,
        surrogate_token=other_args.token,
        surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
        explainer_pretrained_model_name_or_path=explainer_args.explainer_model_name_or_path,
        explainer_config=explainer_config,
        explainer_from_tf=bool(
            ".ckpt" in explainer_args.explainer_model_name_or_path
        ),
        explainer_cache_dir=explainer_args.explainer_cache_dir,
        explainer_revision=explainer_args.explainer_model_revision,
        explainer_token=other_args.token,
        explainer_ignore_mismatched_sizes=explainer_args.explainer_ignore_mismatched_sizes,
    )

    explainer = ObjExplainerForImageClassification(
        config=explainer_for_image_classification_config,
    )

# Image processor matching the explainer backbone; used below to configure the dataset.
explainer_image_processor = AutoImageProcessor.from_pretrained(
    explainer_args.explainer_image_processor_name
    or explainer_args.explainer_model_name_or_path,
    cache_dir=explainer_args.explainer_cache_dir,
    revision=explainer_args.explainer_model_revision,
    token=other_args.token,
)

########################################################
# Configure dataset (set max samples, transforms, etc.)
########################################################
# Work on a deep copy so that `dataset_original` is left untouched by the
# configuration step.
dataset_explainer = copy.deepcopy(dataset_original)
# Configure the copy with the explainer's image processor; augmentation is
# switched off for the train, validation and test splits alike.
dataset_explainer = configure_dataset(
    dataset=dataset_explainer,
    image_processor=explainer_image_processor,
    training_args=training_args,
    data_args=data_args,
    train_augmentation=False,
    validation_augmentation=False,
    test_augmentation=False,
    logger=logger,
)

In [None]:
from models import (
    RegExplainerForImageClassification,
    RegExplainerForImageClassificationConfig
)

In [None]:
# Build a RegExplainer with the same surrogate/explainer settings that are used
# for the ObjExplainer config above.
# NOTE(review): `surrogate_for_image_classification_config` is assigned only in
# the `else` branch of the explainer-initialisation cell; if the ObjExplainer was
# loaded from a saved checkpoint instead, this cell fails with NameError.
regexplainer_for_image_classification_config = RegExplainerForImageClassificationConfig(
    surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
    surrogate_config=surrogate_for_image_classification_config,
    surrogate_from_tf=bool(".ckpt" in surrogate_args.surrogate_model_name_or_path),
    surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
    surrogate_revision=surrogate_args.surrogate_model_revision,
    surrogate_token=other_args.token,
    surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
    explainer_pretrained_model_name_or_path=explainer_args.explainer_model_name_or_path,
    explainer_config=explainer_config,
    explainer_from_tf=bool(".ckpt" in explainer_args.explainer_model_name_or_path),
    explainer_cache_dir=explainer_args.explainer_cache_dir,
    explainer_revision=explainer_args.explainer_model_revision,
    explainer_token=other_args.token,
    explainer_ignore_mismatched_sizes=explainer_args.explainer_ignore_mismatched_sizes,
)

# Instantiate the model from the config assembled above.
regexplainer = RegExplainerForImageClassification(
    config=regexplainer_for_image_classification_config,
)


# move to GPU

In [None]:
# First CUDA device visible to this process (visibility is narrowed through
# CUDA_VISIBLE_DEVICES near the top of the notebook).
device="cuda:0"

In [None]:
# Move the ObjExplainer to the selected device.
explainer.to(device)

In [None]:
# Move the RegExplainer to the selected device.
regexplainer.to(device)

# visualizing 

In [None]:
import glob
import pandas as pd
import json
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import font_manager
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Look up the installed TTF fonts and resolve "Arial"; both return values are
# discarded, so these two calls only serve as a check when the cell is run.
font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial") # Test with "Special Elite" too
# Use Arial as the sans-serif face for every figure drawn afterwards.
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import matplotlib as mpl

def plot_figure(explainer, dataset, sample_idx_list):
    """Draw the selected samples next to per-class explanation heatmaps from ``explainer``.

    The figure has one row per entry of ``sample_idx_list``. The first column
    shows the de-normalised image, the second is a narrow blank spacer, and each
    remaining column overlays the explanation for one class on a greyscale copy
    of the image. The class columns are the distinct ``"labels"`` values found
    among the selected samples.

    Args:
        explainer: Model called as ``explainer(pixel_values, return_loss=False)``;
            the first element of the returned ``"logits"`` is used as the
            explanation. When that explanation is 2-D its first axis is indexed
            by class; otherwise it is used as is for every class column. Each
            per-class explanation must hold 14 * 14 values. The model must also
            expose ``eval()``, ``device`` and
            ``explainer.config.surrogate_config['id2label']`` (keyed by
            ``str(label)``).
        dataset: Indexable collection whose items provide ``"pixel_values"``
            (a tensor, assumed to be 3 x 224 x 224 -- TODO confirm) and
            ``"labels"``.
        sample_idx_list: Non-empty sequence of dataset indices to plot.

    Returns:
        The matplotlib figure that was created.

    Raises:
        ValueError: If ``sample_idx_list`` is empty.
        AssertionError: If a de-normalised image leaves the open interval (0, 1).
    """
    if len(sample_idx_list) == 0:
        raise ValueError("sample_idx_list must contain at least one sample index")

    plt.rcParams["font.size"] = 8
    # Per-channel statistics used to undo the input normalisation.
    # NOTE(review): these look like CIFAR-10 statistics -- confirm they match
    # the image processor that produced `pixel_values`.
    img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]
    img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]

    # Column index -> label, for the distinct labels among the chosen samples.
    class_list = dict(enumerate(
        np.unique([dataset[sample_idx]["labels"] for sample_idx in sample_idx_list])
    ))
    class_columns = list(class_list.values())
    columns = ["image", "empty"] + class_columns

    fig = plt.figure(figsize=(1.53 * (1 + len(class_columns) + 0.2), 2 * len(sample_idx_list)))
    box1 = gridspec.GridSpec(1, len(columns),
                             wspace=0.06,
                             hspace=0,
                             width_ratios=[1, 0.2] + [1] * len(class_columns))

    # One axes per (sample, column), keyed "<sample_idx>_<column>".
    axd = {}
    for col_idx, plot_type in enumerate(columns):
        column_grid = gridspec.GridSpecFromSubplotSpec(len(sample_idx_list), 1,
                                                       subplot_spec=box1[col_idx], wspace=0, hspace=0.2)
        for row_idx, sample_idx in enumerate(sample_idx_list):
            axd[f"{sample_idx}_{plot_type}"] = fig.add_subplot(column_grid[row_idx])

    print("class_list", class_list)

    # Loop-invariant lookups.
    id2label = explainer.explainer.config.surrogate_config['id2label']
    heatmap_cmap = sns.color_palette("icefire", as_cmap=True)
    grey_cmap = mpl.colormaps['Greys']
    explainer.eval()

    for sample_idx in sample_idx_list:
        dataset_item = dataset[sample_idx]
        image = dataset_item["pixel_values"]
        label = dataset_item["labels"]

        # Channels-last, de-normalised copy of the input.
        image_unnormalized = ((image.numpy() * img_std) + img_mean).transpose(1, 2, 0)
        assert image_unnormalized.min() > 0 and image_unnormalized.max() < 1, \
            "de-normalised image is expected to lie strictly inside (0, 1)"
        image_scaled = (image_unnormalized - image_unnormalized.min()) / (
            image_unnormalized.max() - image_unnormalized.min())

        image_ax = axd[f"{sample_idx}_image"]
        image_ax.imshow(image_scaled)
        image_ax.set_title(f"{id2label[str(label)]}", pad=7, zorder=10)

        # A single forward pass per sample serves every class column.
        with torch.no_grad():
            explanation = explainer(image.unsqueeze(0).to(explainer.device), return_loss=False)["logits"][0]
        explanation = explanation.detach().cpu().numpy()

        # Semi-transparent greyscale backdrop shared by all class columns.
        background = grey_cmap(1 - image_unnormalized.sum(axis=2) / 3)
        background[:, :, 3] = 0.5

        for class_idx in class_columns:
            explanation_class = explanation[class_idx] if explanation.ndim == 2 else explanation

            # Bilinear upsampling of the 14x14 map to 224x224.
            upsampled = torch.nn.functional.interpolate(
                torch.as_tensor(explanation_class, dtype=torch.float32).reshape(1, 1, 14, 14),
                scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)

            # Map values symmetrically around 0 onto [0, 1] for the diverging
            # colormap; an all-zero map is drawn in the neutral mid colour
            # instead of dividing by zero.
            max_abs = np.max(np.abs(upsampled))
            if max_abs > 0:
                normalized = 0.5 + upsampled / max_abs * 0.5
            else:
                normalized = np.full_like(upsampled, 0.5)
            heatmap = heatmap_cmap(normalized)
            heatmap[:, :, 3] = 0.6

            class_ax = axd[f"{sample_idx}_{class_idx}"]
            class_ax.imshow(background, alpha=0.85)
            class_ax.imshow(heatmap, alpha=0.9)
            class_ax.set_title(f"{id2label[str(class_idx)]}")

    # Hide ticks everywhere; spacer axes get no frame, all others a thin one.
    for plot_key, ax in axd.items():
        ax.set_xticks([])
        ax.set_yticks([])
        frame_width = 0 if 'empty' in plot_key else 1
        for side in ('top', 'bottom', 'left', 'right'):
            ax.spines[side].set_linewidth(frame_width)
    return fig

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from matplotlib import cm

def plot_figure_shapley(dataset, sample_idx_list, shapley_value, shapley_value_key):
    """Draw the selected samples next to heatmaps of precomputed per-class values.

    The figure has one row per entry of ``sample_idx_list``. The first column
    shows the de-normalised image, the second is a narrow blank spacer, and each
    remaining column overlays the stored values for one class on a greyscale
    copy of the image. The class columns are the distinct ``"labels"`` values
    found among the selected samples. Titles are looked up in the notebook-level
    ``id2label`` mapping with ``str(label)`` as key.

    Args:
        dataset: Indexable collection whose items provide ``"pixel_values"``
            (a tensor, assumed to be 3 x 224 x 224 -- TODO confirm) and
            ``"labels"``.
        sample_idx_list: Non-empty sequence of dataset indices to plot.
        shapley_value: Mapping from sample index to a dict with parallel
            sequences ``"iters"`` and ``"values"``; every entry of ``"values"``
            is indexed as ``[:, class]`` and must yield 14 * 14 numbers.
        shapley_value_key: The ``"iters"`` entry whose values are plotted.

    Returns:
        The matplotlib figure that was created.

    Raises:
        ValueError: If ``sample_idx_list`` is empty.
        KeyError: If ``shapley_value_key`` is not among a sample's ``"iters"``.
        AssertionError: If a de-normalised image leaves the open interval (0, 1).
    """
    # Imported locally because this cell does not import `mpl` itself.
    import matplotlib as mpl

    if len(sample_idx_list) == 0:
        raise ValueError("sample_idx_list must contain at least one sample index")

    plt.rcParams["font.size"] = 8
    # Per-channel statistics used to undo the input normalisation.
    # NOTE(review): these look like CIFAR-10 statistics -- confirm they match
    # the image processor that produced `pixel_values`.
    img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]
    img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]

    # Column index -> label, for the distinct labels among the chosen samples.
    class_list = dict(enumerate(
        np.unique([dataset[sample_idx]["labels"] for sample_idx in sample_idx_list])
    ))
    class_columns = list(class_list.values())
    columns = ["image", "empty"] + class_columns

    fig = plt.figure(figsize=(1.53 * (1 + len(class_columns) + 0.2), 2 * len(sample_idx_list)))
    box1 = gridspec.GridSpec(1, len(columns),
                             wspace=0.06,
                             hspace=0,
                             width_ratios=[1, 0.2] + [1] * len(class_columns))

    # One axes per (sample, column), keyed "<sample_idx>_<column>".
    axd = {}
    for col_idx, plot_type in enumerate(columns):
        column_grid = gridspec.GridSpecFromSubplotSpec(len(sample_idx_list), 1,
                                                       subplot_spec=box1[col_idx], wspace=0, hspace=0.2)
        for row_idx, sample_idx in enumerate(sample_idx_list):
            axd[f"{sample_idx}_{plot_type}"] = fig.add_subplot(column_grid[row_idx])

    print("class_list", class_list)

    # Loop-invariant colormaps. `cm.get_cmap` is deprecated/removed in recent
    # matplotlib; the registry lookup plus `resampled` gives the same
    # 1000-level Greys map.
    heatmap_cmap = sns.color_palette("icefire", as_cmap=True)
    grey_cmap = mpl.colormaps['Greys'].resampled(1000)

    for sample_idx in sample_idx_list:
        dataset_item = dataset[sample_idx]
        image = dataset_item["pixel_values"]
        label = dataset_item["labels"]

        # Channels-last, de-normalised copy of the input.
        image_unnormalized = ((image.numpy() * img_std) + img_mean).transpose(1, 2, 0)
        assert image_unnormalized.min() > 0 and image_unnormalized.max() < 1, \
            "de-normalised image is expected to lie strictly inside (0, 1)"
        image_scaled = (image_unnormalized - image_unnormalized.min()) / (
            image_unnormalized.max() - image_unnormalized.min())

        image_ax = axd[f"{sample_idx}_image"]
        image_ax.imshow(image_scaled)
        image_ax.set_title(f"{id2label[str(label)]}", pad=7, zorder=10)

        # Pick the stored values that belong to the requested iteration count.
        tracking = shapley_value[sample_idx]
        sample_values = dict(zip(tracking["iters"], tracking["values"]))[shapley_value_key]

        # Semi-transparent greyscale backdrop shared by all class columns.
        background = grey_cmap(1 - image_unnormalized.sum(axis=2) / 3)
        background[:, :, 3] = 0.5

        for class_idx in class_columns:
            explanation_class = sample_values[:, class_idx]

            # Bilinear upsampling of the 14x14 map to 224x224.
            upsampled = torch.nn.functional.interpolate(
                torch.as_tensor(explanation_class, dtype=torch.float32).reshape(1, 1, 14, 14),
                scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)

            # Map values symmetrically around 0 onto [0, 1] for the diverging
            # colormap; an all-zero map is drawn in the neutral mid colour
            # instead of dividing by zero.
            max_abs = np.max(np.abs(upsampled))
            if max_abs > 0:
                normalized = 0.5 + upsampled / max_abs * 0.5
            else:
                normalized = np.full_like(upsampled, 0.5)
            heatmap = heatmap_cmap(normalized)
            heatmap[:, :, 3] = 0.6

            class_ax = axd[f"{sample_idx}_{class_idx}"]
            class_ax.imshow(background, alpha=0.85)
            class_ax.imshow(heatmap, alpha=0.9)
            class_ax.set_title(f"{id2label[str(class_idx)]}")

    # Hide ticks everywhere; spacer axes get no frame, all others a thin one.
    for plot_key, ax in axd.items():
        ax.set_xticks([])
        ax.set_yticks([])
        frame_width = 0 if 'empty' in plot_key else 1
        for side in ('top', 'bottom', 'left', 'right'):
            ax.spines[side].set_linewidth(frame_width)
    return fig

In [None]:
def get_ground_truth_metric_with_explainer(shapley_values, explainer, dataset, iters_ground_truth, meta_info):
    """Compare the explainer's output with stored reference values, sample by sample.

    For every sample in ``shapley_values`` the explainer is run on the sample's
    ``"pixel_values"`` and the squared difference to the reference is summed
    over axis 0, giving one error per class. Note that the ``mse_*`` entries are
    therefore sums of squared errors over axis 0, averaged over classes only.

    Args:
        shapley_values: Mapping from sample index to a dict with parallel
            sequences ``"iters"`` (list or ndarray) and ``"values"``. The input
            is not modified.
        explainer: Model called as
            ``explainer(pixel_values=..., return_loss=False)``; the transpose of
            the first ``"logits"`` element must be broadcast-compatible with the
            reference array.
        dataset: Indexable collection whose items provide ``"pixel_values"``
            and ``"labels"``.
        iters_ground_truth: The ``"iters"`` entry whose values act as reference.
        meta_info: Extra key/value pairs merged into every record (they win on
            key clashes).

    Returns:
        A list with one dict per sample containing ``sample_idx``,
        ``mse_target``, ``mse_nontarget``, ``mse_all`` and the ``meta_info``
        entries.

    Raises:
        AssertionError: If the class with the largest summed value in the first
            stored estimate differs from the sample's label.
        ValueError: If ``iters_ground_truth`` is not among a sample's ``"iters"``.
    """
    record_dict_list = []

    # Switching to eval mode once is enough; it does not change inside the loop.
    explainer.eval()

    for sample_idx, tracking_dict in tqdm(shapley_values.items()):
        data = dataset[sample_idx]

        target_class_idx = np.argmax(tracking_dict["values"][0].sum(axis=0))
        assert data["labels"] == target_class_idx

        with torch.no_grad():
            estimated = explainer(
                pixel_values=data["pixel_values"].unsqueeze(0).to(explainer.device),
                return_loss=False,
            )["logits"][0]

        # Work on a local list so the caller's dict is left untouched.
        iters = tracking_dict["iters"]
        if isinstance(iters, np.ndarray):
            iters = iters.tolist()
        ground_truth = tracking_dict["values"][iters.index(iters_ground_truth)]

        diff = estimated.T.cpu().detach().numpy() - ground_truth
        mse_class = (diff * diff).sum(axis=0)
        is_target = np.arange(len(mse_class)) == target_class_idx

        record = {
            "sample_idx": sample_idx,
            "mse_target": mse_class[is_target].mean(),
            "mse_nontarget": mse_class[~is_target].mean(),
            "mse_all": mse_class.mean(),
        }
        record.update(meta_info)
        record_dict_list.append(record)

    return record_dict_list

In [None]:
def get_ground_truth_metric_with_value(shapley_values_ground_truth, iters_ground_truth,
                                       shapley_values_calculated, iters_calculated,
                                       meta_info):
    """Compare two sets of stored estimates, sample by sample.

    For every sample in ``shapley_values_ground_truth`` the estimate stored
    under ``iters_calculated`` in ``shapley_values_calculated`` is compared with
    the reference stored under ``iters_ground_truth``. The squared difference is
    summed over axis 0, giving one error per class. Note that the ``mse_*``
    entries are therefore sums of squared errors over axis 0, averaged over
    classes only.

    Args:
        shapley_values_ground_truth: Mapping from sample index to a dict with
            parallel sequences ``"iters"`` (list or ndarray) and ``"values"``.
        iters_ground_truth: The ``"iters"`` entry used as reference.
        shapley_values_calculated: Mapping of the same layout; it must contain
            every sample index of the reference mapping.
        iters_calculated: The ``"iters"`` entry of the estimate to evaluate.
        meta_info: Extra key/value pairs merged into every record (they win on
            key clashes).

    Neither input mapping is modified.

    Returns:
        A list with one dict per sample containing ``sample_idx``,
        ``mse_target``, ``mse_nontarget``, ``mse_all`` and the ``meta_info``
        entries.

    Raises:
        AssertionError: If the two mappings disagree on a sample's target class
            (the class with the largest summed value in the first estimate).
        KeyError: If a sample is missing from ``shapley_values_calculated``.
        ValueError: If a requested iteration count is not among ``"iters"``.
    """

    def _values_at(tracking_dict, iters_wanted):
        # Return the stored values for one iteration count without touching the
        # caller's dict (ndarray iters are converted on a local copy).
        iters = tracking_dict["iters"]
        if isinstance(iters, np.ndarray):
            iters = iters.tolist()
        return tracking_dict["values"][iters.index(iters_wanted)]

    record_dict_list = []

    for sample_idx, tracking_dict_ground_truth in tqdm(shapley_values_ground_truth.items()):
        target_class_idx_ground_truth = np.argmax(tracking_dict_ground_truth["values"][0].sum(axis=0))

        tracking_dict_calculated = shapley_values_calculated[sample_idx]
        target_class_idx_calculated = np.argmax(tracking_dict_calculated["values"][0].sum(axis=0))

        assert target_class_idx_ground_truth == target_class_idx_calculated

        ground_truth = _values_at(tracking_dict_ground_truth, iters_ground_truth)
        estimated = _values_at(tracking_dict_calculated, iters_calculated)

        diff = estimated - ground_truth
        mse_class = (diff * diff).sum(axis=0)
        is_target = np.arange(len(mse_class)) == target_class_idx_ground_truth

        record = {
            "sample_idx": sample_idx,
            "mse_target": mse_class[is_target].mean(),
            "mse_nontarget": mse_class[~is_target].mean(),
            "mse_all": mse_class.mean(),
        }
        record.update(meta_info)
        record_dict_list.append(record)

    return record_dict_list

In [None]:
# def get_ground_truth_metric(shapley_values, explainer, dataset, iters_ground_truth, meta_info):
#     record_dict_list=[]
    
#     for sample_idx, tracking_dict in tqdm(shapley_values.items()):
#         data=dataset[sample_idx]
        
#         target_class_idx=np.argmax(tracking_dict["values"][0].sum(axis=0))
#         assert data["labels"]==target_class_idx
        
#         explainer.eval()
#         with torch.no_grad():
#             estimated=explainer(pixel_values=data["pixel_values"].unsqueeze(0).to(explainer.device), return_loss=False)["logits"][0]
            
#         if isinstance(tracking_dict["iters"], np.ndarray):
#             tracking_dict["iters"]=tracking_dict["iters"].tolist()
        
#         ground_truth=tracking_dict["values"][tracking_dict["iters"].index(iters_ground_truth)]
            
#         diff=(estimated.T.cpu().detach().numpy()-ground_truth)
        
#         mse_class=(diff*diff).sum(axis=0)
        
#         record={
#             "sample_idx": sample_idx,
#             "mse_target": mse_class[np.arange(len(mse_class))==target_class_idx].mean(),
#             "mse_nontarget": mse_class[np.arange(len(mse_class))!=target_class_idx].mean(),
#             "mse_all": mse_class[:].mean(),
#         }
#         record.update(meta_info)
        
#         record_dict_list.append(record)
        
#     return record_dict_list

In [None]:
dataset_explainer

In [None]:
dataset_explainer["validation"][-1]

In [None]:
# Position of value 7707 within the reversed first 100 entries of the seed-42
# permutation of range(9469); raises ValueError if it is not among those 100.
np.random.RandomState(seed=42).permutation(list(range(9469))).tolist()[:100][::-1].index(7707)

In [None]:
# Position of value 6051 within the reversed first 100 entries of the seed-42
# permutation of range(9469); raises ValueError if it is not among those 100.
np.random.RandomState(seed=42).permutation(list(range(9469))).tolist()[:100][::-1].index(6051)

In [None]:
# Print, for each hand-picked sample id, its position in the seed-42
# permutation of range(9469). The permutation is generated once and inverted
# into a lookup table instead of being regenerated and scanned per id.
sample_ids = [
    "774", "8336", "8367", "7065", "2362", "3146", "3945", "3577", "7615",
    "6553", "5204", "6673", "4925", "8285", "7724", "683", "6578", "7001",
    "2183", "7758", "9234", "1650", "7593", "4838", "8294", "7290", "3995",
    "6051", "2526", "3798", "7923", "483", "1087", "3019", "1217", "5014",
    "1076", "8250", "5327", "6909", "908", "106", "315", "6177", "7854",
    "4354", "6310", "457", "8606", "7689", "7707",
]
permutation = np.random.RandomState(seed=42).permutation(list(range(9469))).tolist()
position_of = {value: position for position, value in enumerate(permutation)}

permutation_positions = [position_of[int(i)] for i in sample_ids]
for position in permutation_positions:
    print(position)
    


In [None]:
# Position of value 8606 within the reversed first 100 entries of the seed-42
# permutation of range(9469); raises ValueError if it is not among those 100.
np.random.RandomState(seed=42).permutation(list(range(9469))).tolist()[:100][::-1].index(8606)

In [None]:
!ls logs/ -trl

# Training target quality

In [None]:
!ls logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_newsample_196/extract_output/train/0

In [None]:
# Cache of load_shapley results, keyed by the directory each one was read from.
shapley_loaded_dict={}

In [None]:
# Load this run's output and cache it under the directory it came from.
shapley_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
shapley_loaded_dict[shapley_dir] = load_shapley(shapley_dir)

In [None]:
# Load this run's output and cache it under the directory it came from.
shapley_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"
shapley_loaded_dict[shapley_dir] = load_shapley(shapley_dir)

In [None]:
# Load this run's output and cache it under the directory it came from.
shapley_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
shapley_loaded_dict[shapley_dir] = load_shapley(shapley_dir)

In [None]:
# Load this run's output and cache it under the directory it came from.
shapley_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation/extract_output/train"
shapley_loaded_dict[shapley_dir] = load_shapley(shapley_dir)

In [None]:
# Load this run's output (with target_subset_size=196) and cache it under the
# directory it came from.
shapley_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_newsample_196/extract_output/train"
shapley_loaded_dict[shapley_dir] = load_shapley(shapley_dir, target_subset_size=196)

In [None]:
# Load this run's output and cache it under the directory it came from.
shapley_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"
shapley_loaded_dict[shapley_dir] = load_shapley(shapley_dir)

In [None]:
# Load this run's output and cache it under the directory it came from.
shapley_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train"
shapley_loaded_dict[shapley_dir] = load_shapley(shapley_dir)

In [None]:
# Load this run's output and cache it under the directory it came from.
shapley_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_antithetical/extract_output/train"
shapley_loaded_dict[shapley_dir] = load_shapley(shapley_dir)

In [None]:
# Accumulates the metric records produced by the estimator-comparison cells below.
metric_list_value = []

In [None]:
# KernelSHAP runs (plain, then antithetical) evaluated at several values of
# iters_calculated against the train-split regression-antithetical reference.
# For each budget the plain variant is appended before the antithetical one.
for num_subsets in [512, 1024, 2048, 3072]:
    for _calculated_key, _label_suffix in (
        ("logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train", ""),
        (
            "logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train",
            ", antithetical",
        ),
    ):
        metric_list_value.extend(
            get_ground_truth_metric_with_value(
                shapley_values_ground_truth=shapley_loaded_dict[
                    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"
                ],
                iters_ground_truth=999424,
                shapley_values_calculated=shapley_loaded_dict[_calculated_key],
                iters_calculated=num_subsets,
                meta_info={
                    "method": f"KernelSHAP ({num_subsets}{_label_suffix})",
                    "num_subsets": num_subsets,
                    "estimation_method": "KernelSHAP",
                },
            )
        )

In [None]:
# Permutation runs (plain, then antithetical) evaluated at several values of
# iters_calculated against the train-split regression-antithetical reference.
# For each budget the plain variant is appended before the antithetical one.
for num_subsets in [196, 392, 588, 1176, 3136]:
    for _calculated_key, _label_suffix in (
        ("logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation/extract_output/train", ""),
        (
            "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_antithetical/extract_output/train",
            ", antithetical",
        ),
    ):
        metric_list_value.extend(
            get_ground_truth_metric_with_value(
                shapley_values_ground_truth=shapley_loaded_dict[
                    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"
                ],
                iters_ground_truth=999424,
                shapley_values_calculated=shapley_loaded_dict[_calculated_key],
                iters_calculated=num_subsets,
                meta_info={
                    "method": f"Permutation ({num_subsets}{_label_suffix})",
                    "num_subsets": num_subsets,
                    "estimation_method": "Permutation",
                },
            )
        )

In [None]:
# The "newsample" entries hold a sequence of outputs per sample. For each of the first
# 16 positions, pull that position out for every sample and score it on its own.
for i in range(16):
    shapley_loaded_dict_temp = {
        sample_idx: tracking_dict[i]
        for sample_idx, tracking_dict in shapley_loaded_dict[
            "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_newsample_196/extract_output/train"
        ].items()
    }

    metric_list_value.extend(
        get_ground_truth_metric_with_value(
            shapley_values_ground_truth=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"
            ],
            iters_ground_truth=999424,
            shapley_values_calculated=shapley_loaded_dict_temp,
            iters_calculated=196,
            meta_info={
                "method": f"Permutation (196, newsample, {i+1})",
                "num_subsets": 196,
                "estimation_method": "Permutation",
            },
        )
    )


In [None]:
load_shapley??

In [None]:
# Ad-hoc inspection: the return value is not stored, it is only shown as the cell output.
load_shapley("logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train")

In [None]:
def load_shapley(path, target_subset_size=None, map_location=None):
    """Load saved Shapley outputs for every numbered sample directory under ``path``.

    Only sub-directories whose name is a plain decimal integer are treated as
    samples, so stray files or directories that merely start with a digit are
    ignored rather than crashing the load. Samples are processed in ascending
    numeric order so the returned dict has a reproducible iteration order.

    Args:
        path: Directory containing one sub-directory per sample index.
        target_subset_size: When ``None``, each sample contributes the object
            stored in ``shapley_output.pt``. Otherwise each sample contributes
            a list with the contents of every
            ``shapley_output_{target_subset_size}_*.pt`` file, ordered by the
            integer after the last underscore in the file name.
        map_location: Forwarded to ``torch.load``; the default ``None`` keeps
            ``torch.load``'s own default behaviour.

    Returns:
        dict mapping the integer sample index to the loaded object (or to the
        list of loaded objects when ``target_subset_size`` is given). An empty
        dict is returned when no sample directory is found.

    Raises:
        FileNotFoundError: if a sample directory lacks ``shapley_output.pt``
            (only when ``target_subset_size`` is ``None``).
    """
    sample_dirs = sorted(
        (
            candidate
            for candidate in map(Path, glob.glob(str(Path(path) / "[0-9]*")))
            if candidate.is_dir() and candidate.name.isdecimal()
        ),
        key=lambda candidate: int(candidate.name),
    )

    output_dict = {}
    for sample_dir in tqdm(sample_dirs):
        sample_index = int(sample_dir.name)

        if target_subset_size is None:
            output_dict[sample_index] = torch.load(
                sample_dir / "shapley_output.pt", map_location=map_location
            )
            continue

        # Order the per-run files by their trailing run number, not lexicographically
        # (so that e.g. "_10" comes after "_2").
        subset_files = sorted(
            sample_dir.glob(f"shapley_output_{target_subset_size}_*.pt"),
            key=lambda subset_file: int(subset_file.stem.rsplit("_", 1)[-1]),
        )
        output_dict[sample_index] = [
            torch.load(subset_file, map_location=map_location)
            for subset_file in subset_files
        ]

    return output_dict

In [None]:
# Reference path: /homes/gws/chanwkim/vit-shapley/results/3_explanation_generate/ImageNette/vit_base_patch16_224_kernelshap_test.pickle

In [None]:
!ls logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train/*/shapley_output.pt | wc -l

In [None]:
# Single 6x10 axes; it is exposed through a one-entry dict so the plotting code
# below can address it as axd[plot_key].
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 10))
axd = dict(main=ax)
plot_key = "main"


def get_reg_type(x):
    """Return which estimator name occurs in ``x``.

    "KernelSHAP" is checked before "Permutation"; if neither occurs in ``x``
    the string "none" is returned.
    """
    for estimator_name in ("KernelSHAP", "Permutation"):
        if estimator_name in x:
            return estimator_name
    return "none"
    

# One row per metric record collected in metric_list_value.
metric_df = pd.DataFrame(data=metric_list_value)
# metric_df["method"]=metric_df["method"].str.pad(40, side="right", fillchar='-')

# metric_df["explainer"]=metric_df["explainer"].str.replace(
#     "Reg-AO (upfront, regression, 512, antithetical)",
#     "Reg-AO (upfront, regression, antithetical, 512)")\
#     .str.replace(
#     "Obj-AO (newsample, 32, antithetical)",
#     "Obj-AO (newsample, antithetical, 32)",)

# print(metric_df["explainer"].value_counts())


# metric_df["AO type"]=metric_df["explainer"].map(lambda x: x.split('(')[0].strip())
# metric_df["num_subsets"]=metric_df["explainer"].map(lambda x: int(x.split(',')[-1][:-1].strip()))
# metric_df["reg type"]=metric_df["explainer"].map(get_reg_type)

# metric_df=metric_df.sort_values(["AO type", "reg type", "num_subsets"], ascending=True)
# metric_df=metric_df[metric_df["explainer"].str.contains("Obj-AO")]

# Horizontal bar chart: the `method` label on the y axis, `mse_target` on the x axis.
sns.barplot(
    data=metric_df,
    x="mse_target",
    y="method",
    palette="tab10",
    linewidth=3,
    ax=axd[plot_key],
    # hue="method",
    # style="AO type",
    # style_order=["Reg-AO", "Obj-AO"],
)


# Axis titles.
axd[plot_key].set_ylabel("Method", fontsize=20)
axd[plot_key].set_xlabel("MSE", fontsize=20)

# axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
# axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01))    
# axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
# axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
# axd[plot_key].set_ylim(0, 0.1)

# x axis: major ticks every 1, minor ticks every 0.1, with a heavier grid on the majors.
axd[plot_key].xaxis.set_major_locator(MultipleLocator(1))
axd[plot_key].xaxis.set_minor_locator(MultipleLocator(0.1))
for _which, _grid_width, _grid_alpha in (("major", 2, 0.6), ("minor", 1, 0.1)):
    axd[plot_key].xaxis.grid(True, which=_which, linewidth=_grid_width, alpha=_grid_alpha)

# Tick labels: x labels are rotated by -90 degrees.
axd[plot_key].tick_params(axis="x", which="major", rotation=-90, labelsize=20)
axd[plot_key].tick_params(axis="y", which="major", labelsize=20)

# Hide the right and top frame lines.
for _side in ("right", "top"):
    axd[plot_key].spines[_side].set_visible(False)


# yax = axd[plot_key].get_yaxis()
# # find the maximum width of the label on the major ticks
# pad = max(T.label1.get_window_extent().width for T in yax.majorTicks)
# yax.set_tick_params(pad=pad)

# axd[plot_key].axis["left"].major_ticklabels.set_ha("left")

# for label in axd[plot_key].get_yticklabels():
#     label.set_horizontalalignment('right')
#     import matplotlib.transforms as mtrans
#     # Shifting the label by -15 points on the x-axis
#     trans = mtrans.Affine2D().translate(-100, 0)
#     t = axd[plot_key].transData + trans
#     label.set_transform(t)    



# axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.01))
# axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
# axd[plot_key].set_xlim(0, 40)
# axd[plot_key].set_ylim(0, 0.030)

# The bar plot above is drawn without `hue`, so there may be no labelled artists at
# all. Calling legend() unconditionally would then only emit a "no artists with
# labels" warning and leave an empty legend frame below the chart. Build the legend
# only when there is something to show; `leg` stays None otherwise.
leg = None
_legend_handles, _legend_labels = axd[plot_key].get_legend_handles_labels()
if _legend_handles:
    leg = axd[plot_key].legend(loc='best', bbox_to_anchor=(0.0, -1.2, 0.5, 1))

    # Thicken the legend's line samples.
    for line in leg.get_lines():
        line.set_linewidth(3.0)

# Training curve

### Obj-AO (newsample)

In [None]:
# Evaluate up to the first 100 checkpoints (ordered by the integer after the last "-"
# in the directory name) of the Obj-AO newsample-32 run on the test split.
checkpoint_path_list = glob.glob("logs/vitbase_imagenette_shapley_objexplainer_newsample_32/checkpoint-*")
checkpoint_path_list.sort(key=lambda ckpt_dir: int(ckpt_dir.rsplit("-", 1)[-1]))
for checkpoint_path in tqdm(checkpoint_path_list[:100]):
    state_dict = torch.load(f"{checkpoint_path}/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with each metric row.
    with open(f"{checkpoint_path}/trainer_state.json") as f:
        trainer_state = json.load(f)

    explainer.load_state_dict(state_dict)

    metric_list.extend(
        get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
            ],
            explainer=explainer,
            dataset=dataset_explainer["test"],
            iters_ground_truth=999424,
            meta_info={
                "explainer": "Obj-AO (newsample, 32)",
                "epoch": int(trainer_state["epoch"]),
            },
        )
    )

In [None]:
# Evaluate checkpoints 20..39 (ordered by the integer after the last "-" in the
# directory name) of the Obj-AO antithetical newsample-32 run on the test split.
checkpoint_path_list = glob.glob("logs/vitbase_imagenette_shapley_objexplainer_antithetical_newsample_32/checkpoint-*")
checkpoint_path_list.sort(key=lambda ckpt_dir: int(ckpt_dir.rsplit("-", 1)[-1]))
for checkpoint_path in tqdm(checkpoint_path_list[20:40]):
    state_dict = torch.load(f"{checkpoint_path}/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with each metric row.
    with open(f"{checkpoint_path}/trainer_state.json") as f:
        trainer_state = json.load(f)

    explainer.load_state_dict(state_dict)

    metric_list.extend(
        get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
            ],
            explainer=explainer,
            dataset=dataset_explainer["test"],
            iters_ground_truth=999424,
            meta_info={
                "explainer": "Obj-AO (newsample, antithetical, 32)",
                "epoch": int(trainer_state["epoch"]),
            },
        )
    )

### Reg-AO (upfront, regression)

In [None]:
# Evaluate up to the first 20 checkpoints (ordered by the integer after the last "-"
# in the directory name) of the Reg-AO upfront-512 run on the test split.
checkpoint_path_list = glob.glob("logs/vitbase_imagenette_shapley_regexplainer_upfront_512/checkpoint-*")
checkpoint_path_list.sort(key=lambda ckpt_dir: int(ckpt_dir.rsplit("-", 1)[-1]))
for checkpoint_path in tqdm(checkpoint_path_list[:20]):
    state_dict = torch.load(f"{checkpoint_path}/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with each metric row.
    with open(f"{checkpoint_path}/trainer_state.json") as f:
        trainer_state = json.load(f)

    regexplainer.load_state_dict(state_dict)

    metric_list.extend(
        get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
            ],
            explainer=regexplainer,
            dataset=dataset_explainer["test"],
            iters_ground_truth=999424,
            meta_info={
                "explainer": "Reg-AO (upfront, regression, 512)",
                "epoch": int(trainer_state["epoch"]),
            },
        )
    )

In [None]:
# Evaluate up to the first 20 checkpoints (ordered by the integer after the last "-"
# in the directory name) of the Reg-AO upfront-1024 run on the test split.
checkpoint_path_list = glob.glob("logs/vitbase_imagenette_shapley_regexplainer_upfront_1024/checkpoint-*")
checkpoint_path_list.sort(key=lambda ckpt_dir: int(ckpt_dir.rsplit("-", 1)[-1]))
for checkpoint_path in tqdm(checkpoint_path_list[:20]):
    state_dict = torch.load(f"{checkpoint_path}/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with each metric row.
    with open(f"{checkpoint_path}/trainer_state.json") as f:
        trainer_state = json.load(f)

    regexplainer.load_state_dict(state_dict)

    metric_list.extend(
        get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
            ],
            explainer=regexplainer,
            dataset=dataset_explainer["test"],
            iters_ground_truth=999424,
            meta_info={
                "explainer": "Reg-AO (upfront, regression, 1024)",
                "epoch": int(trainer_state["epoch"]),
            },
        )
    )

In [None]:
# Evaluate up to the first 20 checkpoints (ordered by the integer after the last "-"
# in the directory name) of the Reg-AO upfront-2048 run on the test split.
checkpoint_path_list = glob.glob("logs/vitbase_imagenette_shapley_regexplainer_upfront_2048/checkpoint-*")
checkpoint_path_list.sort(key=lambda ckpt_dir: int(ckpt_dir.rsplit("-", 1)[-1]))
for checkpoint_path in tqdm(checkpoint_path_list[:20]):
    state_dict = torch.load(f"{checkpoint_path}/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with each metric row.
    with open(f"{checkpoint_path}/trainer_state.json") as f:
        trainer_state = json.load(f)

    regexplainer.load_state_dict(state_dict)

    metric_list.extend(
        get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
            ],
            explainer=regexplainer,
            dataset=dataset_explainer["test"],
            iters_ground_truth=999424,
            meta_info={
                "explainer": "Reg-AO (upfront, regression, 2048)",
                "epoch": int(trainer_state["epoch"]),
            },
        )
    )

In [None]:
# Evaluate up to the first 20 checkpoints (ordered by the integer after the last "-"
# in the directory name) of the Reg-AO upfront-3072 run on the test split.
checkpoint_path_list = glob.glob("logs/vitbase_imagenette_shapley_regexplainer_upfront_3072/checkpoint-*")
checkpoint_path_list.sort(key=lambda ckpt_dir: int(ckpt_dir.rsplit("-", 1)[-1]))
for checkpoint_path in tqdm(checkpoint_path_list[:20]):
    state_dict = torch.load(f"{checkpoint_path}/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with each metric row.
    with open(f"{checkpoint_path}/trainer_state.json") as f:
        trainer_state = json.load(f)

    regexplainer.load_state_dict(state_dict)

    metric_list.extend(
        get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
            ],
            explainer=regexplainer,
            dataset=dataset_explainer["test"],
            iters_ground_truth=999424,
            meta_info={
                "explainer": "Reg-AO (upfront, regression, 3072)",
                "epoch": int(trainer_state["epoch"]),
            },
        )
    )

In [None]:
# Evaluate up to the first 20 checkpoints of the Reg-AO antithetical upfront-512 run.
# Unlike the neighbouring cells, results go into a separate list (metric_list_) that is
# reset each time this cell runs; the plotting cell concatenates it with metric_list.
checkpoint_path_list = glob.glob("logs/vitbase_imagenette_shapley_regexplainer_antithetical_upfront_512/checkpoint-*")
checkpoint_path_list.sort(key=lambda ckpt_dir: int(ckpt_dir.rsplit("-", 1)[-1]))
metric_list_ = []
for checkpoint_path in tqdm(checkpoint_path_list[:20]):
    state_dict = torch.load(f"{checkpoint_path}/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with each metric row.
    with open(f"{checkpoint_path}/trainer_state.json") as f:
        trainer_state = json.load(f)

    regexplainer.load_state_dict(state_dict)

    metric_list_.extend(
        get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
            ],
            explainer=regexplainer,
            dataset=dataset_explainer["test"],
            iters_ground_truth=999424,
            meta_info={
                "explainer": "Reg-AO (upfront, regression, antithetical, 512)",
                "epoch": int(trainer_state["epoch"]),
            },
        )
    )

### Reg-AO (upfront, permutation)

In [None]:
# Evaluate up to the first 20 checkpoints (ordered by the integer after the last "-"
# in the directory name) of the Reg-AO permutation upfront-196 run on the test split.
# NOTE(review): iters_ground_truth is 999936 here, while the regression/Obj-AO cells
# pass 999424 for the same reference dictionary -- confirm this is intended.
checkpoint_path_list = glob.glob("logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_196/checkpoint-*")
checkpoint_path_list.sort(key=lambda ckpt_dir: int(ckpt_dir.rsplit("-", 1)[-1]))
for checkpoint_path in tqdm(checkpoint_path_list[:20]):
    state_dict = torch.load(f"{checkpoint_path}/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with each metric row.
    with open(f"{checkpoint_path}/trainer_state.json") as f:
        trainer_state = json.load(f)

    regexplainer.load_state_dict(state_dict)

    metric_list.extend(
        get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
            ],
            explainer=regexplainer,
            dataset=dataset_explainer["test"],
            iters_ground_truth=999936,
            meta_info={
                "explainer": "Reg-AO (upfront, permutation, 196)",
                "epoch": int(trainer_state["epoch"]),
            },
        )
    )

In [None]:
# Evaluate up to the first 20 checkpoints (ordered by the integer after the last "-"
# in the directory name) of the Reg-AO permutation upfront-392 run on the test split.
# NOTE(review): iters_ground_truth is 999936 here, while the regression/Obj-AO cells
# pass 999424 for the same reference dictionary -- confirm this is intended.
checkpoint_path_list = glob.glob("logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_392/checkpoint-*")
checkpoint_path_list.sort(key=lambda ckpt_dir: int(ckpt_dir.rsplit("-", 1)[-1]))
for checkpoint_path in tqdm(checkpoint_path_list[:20]):
    state_dict = torch.load(f"{checkpoint_path}/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with each metric row.
    with open(f"{checkpoint_path}/trainer_state.json") as f:
        trainer_state = json.load(f)

    regexplainer.load_state_dict(state_dict)

    metric_list.extend(
        get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
            ],
            explainer=regexplainer,
            dataset=dataset_explainer["test"],
            iters_ground_truth=999936,
            meta_info={
                "explainer": "Reg-AO (upfront, permutation, 392)",
                "epoch": int(trainer_state["epoch"]),
            },
        )
    )

In [None]:
# Evaluate up to the first 20 checkpoints (ordered by the integer after the last "-"
# in the directory name) of the Reg-AO permutation upfront-588 run on the test split.
# NOTE(review): iters_ground_truth is 999936 here, while the regression/Obj-AO cells
# pass 999424 for the same reference dictionary -- confirm this is intended.
checkpoint_path_list = glob.glob("logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_588/checkpoint-*")
checkpoint_path_list.sort(key=lambda ckpt_dir: int(ckpt_dir.rsplit("-", 1)[-1]))
for checkpoint_path in tqdm(checkpoint_path_list[:20]):
    state_dict = torch.load(f"{checkpoint_path}/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with each metric row.
    with open(f"{checkpoint_path}/trainer_state.json") as f:
        trainer_state = json.load(f)

    regexplainer.load_state_dict(state_dict)

    metric_list.extend(
        get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
            ],
            explainer=regexplainer,
            dataset=dataset_explainer["test"],
            iters_ground_truth=999936,
            meta_info={
                "explainer": "Reg-AO (upfront, permutation, 588)",
                "epoch": int(trainer_state["epoch"]),
            },
        )
    )

In [None]:
# Evaluate up to the first 20 checkpoints (ordered by the integer after the last "-"
# in the directory name) of the Reg-AO permutation upfront-1176 run on the test split.
# NOTE(review): iters_ground_truth is 999936 here, while the regression/Obj-AO cells
# pass 999424 for the same reference dictionary -- confirm this is intended.
checkpoint_path_list = glob.glob("logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_1176/checkpoint-*")
checkpoint_path_list.sort(key=lambda ckpt_dir: int(ckpt_dir.rsplit("-", 1)[-1]))
for checkpoint_path in tqdm(checkpoint_path_list[:20]):
    state_dict = torch.load(f"{checkpoint_path}/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with each metric row.
    with open(f"{checkpoint_path}/trainer_state.json") as f:
        trainer_state = json.load(f)

    regexplainer.load_state_dict(state_dict)

    metric_list.extend(
        get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
            ],
            explainer=regexplainer,
            dataset=dataset_explainer["test"],
            iters_ground_truth=999936,
            meta_info={
                "explainer": "Reg-AO (upfront, permutation, 1176)",
                "epoch": int(trainer_state["epoch"]),
            },
        )
    )

In [None]:
# Evaluate up to the first 20 checkpoints (ordered by the integer after the last "-"
# in the directory name) of the Reg-AO permutation upfront-3136 run on the test split.
# NOTE(review): iters_ground_truth is 999936 here, while the regression/Obj-AO cells
# pass 999424 for the same reference dictionary -- confirm this is intended.
checkpoint_path_list = glob.glob("logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_3136/checkpoint-*")
checkpoint_path_list.sort(key=lambda ckpt_dir: int(ckpt_dir.rsplit("-", 1)[-1]))

for checkpoint_path in tqdm(checkpoint_path_list[:20]):
    state_dict = torch.load(f"{checkpoint_path}/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with each metric row.
    with open(f"{checkpoint_path}/trainer_state.json") as f:
        trainer_state = json.load(f)

    regexplainer.load_state_dict(state_dict)

    metric_list.extend(
        get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
            ],
            explainer=regexplainer,
            dataset=dataset_explainer["test"],
            iters_ground_truth=999936,
            meta_info={
                "explainer": "Reg-AO (upfront, permutation, 3136)",
                "epoch": int(trainer_state["epoch"]),
            },
        )
    )

In [None]:
# Drop records stored under the label variant that ends in "512, antithetical".
metric_list = [
    record
    for record in metric_list
    if record["explainer"] != "Reg-AO (upfront, regression, 512, antithetical)"
]




In [None]:
# Ad-hoc inspection: show the checkpoint list left behind by the most recently run cell.
checkpoint_path_list

### Reg-AO (newsample, permutation)

In [None]:
# Evaluate up to the first 20 checkpoints (ordered by the integer after the last "-"
# in the directory name) of the Reg-AO permutation newsample-196 run on the test split.
checkpoint_path_list = glob.glob("logs/vitbase_imagenette_shapley_regexplainer_permutation_newsample_196/checkpoint-*")
checkpoint_path_list.sort(key=lambda ckpt_dir: int(ckpt_dir.rsplit("-", 1)[-1]))

for checkpoint_path in tqdm(checkpoint_path_list[:20]):
    state_dict = torch.load(f"{checkpoint_path}/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with each metric row.
    with open(f"{checkpoint_path}/trainer_state.json") as f:
        trainer_state = json.load(f)

    regexplainer.load_state_dict(state_dict)

    metric_list.extend(
        get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[
                "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
            ],
            explainer=regexplainer,
            dataset=dataset_explainer["test"],
            iters_ground_truth=999424,
            meta_info={
                "explainer": "Reg-AO (newsample, permutation, 196)",
                "epoch": int(trainer_state["epoch"]),
            },
        )
    )

In [None]:
# torch.save(metric_list, "logs/experiment_results/metric_list.pt")

### plotting

In [None]:
# Single 10x5 axes; it is exposed through a one-entry dict so the plotting code
# below can address it as axd[plot_key].
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
axd = dict(main=ax)
plot_key = "main"


def get_reg_type(x):
    """Classify an explainer label by the estimation keyword it contains.

    Returns "regression" if that word occurs in ``x``, otherwise
    "permutation" if that word occurs, otherwise "none".
    """
    # Order matters: "regression" wins when both keywords are present.
    for reg_type in ("regression", "permutation"):
        if reg_type in x:
            return reg_type
    return "none"
    

# Gather all metric records into one frame and rewrite two explainer labels so
# that the subset count becomes the last comma-separated field.
# regex=False is passed explicitly: the labels contain parentheses, which are
# parsed as regex groups (and therefore never match the literal label) on
# pandas versions where Series.str.replace defaults to regex=True.
metric_df=pd.DataFrame(metric_list+metric_list_)
metric_df["explainer"]=metric_df["explainer"].str.replace(
    "Reg-AO (upfront, regression, 512, antithetical)",
    "Reg-AO (upfront, regression, antithetical, 512)",
    regex=False)\
    .str.replace(
    "Obj-AO (newsample, 32, antithetical)",
    "Obj-AO (newsample, antithetical, 32)",
    regex=False)

# Show how many records each explainer label contributes.
print(metric_df["explainer"].value_counts())


# Derive plotting columns from the explainer label:
#   "AO type"     - text before the first "(", stripped
#   "num_subsets" - integer in the last comma-separated field (closing ")" removed)
#   "reg type"    - keyword classification from get_reg_type
explainer_labels = metric_df["explainer"]
metric_df["AO type"] = explainer_labels.map(lambda label: label.partition('(')[0].strip())
metric_df["num_subsets"] = explainer_labels.map(
    lambda label: int(label.rsplit(',', 1)[-1][:-1].strip())
)
metric_df["reg type"] = explainer_labels.map(get_reg_type)

metric_df = metric_df.sort_values(by=["AO type", "reg type", "num_subsets"])

# Keep only the rows whose explainer label mentions "permutation".
permutation_mask = metric_df["explainer"].str.contains("permutation")
metric_df = metric_df[permutation_mask]

# Line plot of mse_target against epoch, one line per explainer label.
target_ax = axd[plot_key]

sns.lineplot(
    data=metric_df,
    x="epoch",
    y="mse_target",
    hue="explainer",
    style="AO type",
    style_order=["Reg-AO", "Obj-AO"],
    palette="tab10",
    linewidth=3,
    ax=target_ax,
)

target_ax.set_ylabel("MSE", fontsize=20)
target_ax.set_xlabel("Epoch", fontsize=20)

# Grid line styling shared by both axes: (which, linewidth, alpha).
grid_styles = (('major', 2, 0.6), ('minor', 1, 0.1))

for which, width, alpha in grid_styles:
    target_ax.yaxis.grid(True, which=which, linewidth=width, alpha=alpha)

# Major ticks every 10 epochs, minor ticks every 5.
target_ax.xaxis.set_major_locator(MultipleLocator(10))
target_ax.xaxis.set_minor_locator(MultipleLocator(5))
for which, width, alpha in grid_styles:
    target_ax.xaxis.grid(True, which=which, linewidth=width, alpha=alpha)

for axis_name in ('x', 'y'):
    target_ax.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    target_ax.spines[side].set_visible(False)

target_ax.set_xlim(0, 40)
target_ax.set_ylim(0, 0.030)

# Anchor the legend with an explicit bounding box and thicken its sample lines.
leg = target_ax.legend(loc='best', bbox_to_anchor=(0.0, -1.2, 0.5, 1))

for line in leg.get_lines():
    line.set_linewidth(3.0)

# Error from prediction vs Error from targets

In [None]:
# Load the test-split "regression_antithetical" Shapley values unless the
# dictionary already has an entry for this path.
shapley_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"
if shapley_dir not in shapley_loaded_dict:
    shapley_loaded_dict[shapley_dir] = load_shapley(shapley_dir)

In [None]:
# Load the train-split "regression_antithetical" Shapley values unless the
# dictionary already has an entry for this path.
shapley_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"
if shapley_dir not in shapley_loaded_dict:
    shapley_loaded_dict[shapley_dir] = load_shapley(shapley_dir)

In [None]:
# Load the test-split Shapley values of the "shapley_eval_test" run unless the
# dictionary already has an entry for this path.
shapley_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_test/extract_output/test"
if shapley_dir not in shapley_loaded_dict:
    shapley_loaded_dict[shapley_dir] = load_shapley(shapley_dir)

In [None]:
# Load the test-split permutation Shapley values unless the dictionary already
# has an entry for this path.
shapley_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_test_permutation/extract_output/test"
if shapley_dir not in shapley_loaded_dict:
    shapley_loaded_dict[shapley_dir] = load_shapley(shapley_dir)

In [None]:
# Scratch cell echoing the dictionary key of the test-split permutation Shapley
# values. The closing quote was missing, which made the cell a SyntaxError.
"logs/vitbase_imagenette_surrogate_shapley_eval_test_permutation/extract_output/test"

In [None]:
# Scratch cell echoing the dictionary key of the "shapley_eval_test" test-split values.
'logs/vitbase_imagenette_surrogate_shapley_eval_test/extract_output/test'

In [None]:
# Build shapley_loaded_dict from three stored Shapley-value directories, keyed
# by their path. The commas between entries were missing (SyntaxError).
# NOTE(review): this rebinds shapley_loaded_dict, so entries added to the
# previous dictionary by other cells are no longer reachable - confirm intent.
shapley_loaded_dict={
    "logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train": load_shapley("logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"),
    "logs/vitbase_imagenette_surrogate_shapley_eval_test/extract_output/test": load_shapley("logs/vitbase_imagenette_surrogate_shapley_eval_test/extract_output/test"),
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train": load_shapley("logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"),
}

In [None]:
# Display which Shapley-value paths are currently loaded.
shapley_loaded_dict.keys()

In [None]:
# Load the train-split permutation Shapley values and store them under their
# path; unlike the guarded loads, this always reloads.
shapley_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation/extract_output/train"
shapley_loaded_dict[shapley_dir] = load_shapley(shapley_dir)

In [None]:
# Start from two independent empty accumulators: one for explainer metrics and
# one for metrics of the stored (target) Shapley estimates.
metric_list_explainer, metric_list_target = [], []

In [None]:
# For each upfront-regression explainer checkpoint (keyed by its subset count),
# record two sets of metrics against the "regression_antithetical" ground truth:
#   * metric_list_explainer - error of the loaded explainer
#   * metric_list_target    - error of the stored KernelSHAP estimate, evaluated
#                             at the same subset count
#
# The split used to be hard-coded to "train", with a commented-out copy of the
# whole body for "test". It is now a parameter: only "train" is active, which
# matches the previous behavior; add "test" to evaluate that split as well.
eval_splits = ("train",)

regression_checkpoints = {
    512: "logs/vitbase_imagenette_shapley_regexplainer_upfront_512/checkpoint-888",
    1024: "logs/vitbase_imagenette_shapley_regexplainer_upfront_1024/checkpoint-1036",
    1536: "logs/vitbase_imagenette_shapley_regexplainer_upfront_1536/checkpoint-1480",
    2048: "logs/vitbase_imagenette_shapley_regexplainer_upfront_2048/checkpoint-1480",
    3072: "logs/vitbase_imagenette_shapley_regexplainer_upfront_3072/checkpoint-1924",
}

for num_subsets, checkpoint_path in regression_checkpoints.items():

    state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")

    # The epoch stored in the trainer state becomes part of the explainer label.
    with open(checkpoint_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)
    num_epoch=int(trainer_state["epoch"])

    regexplainer.load_state_dict(state_dict)
    explainer_label = f"Reg-AO (upfront, regression, {num_subsets}) (epoch={num_epoch})"

    for split in eval_splits:
        # Dictionary keys follow the directory layout of the stored values.
        ground_truth_key = f"logs/vitbase_imagenette_surrogate_shapley_eval_{split}_regression_antithetical/extract_output/{split}"
        calculated_key = f"logs/vitbase_imagenette_surrogate_shapley_eval_{split}/extract_output/{split}"

        metric_list_explainer += get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[ground_truth_key],
            explainer=regexplainer,
            dataset=dataset_explainer[split],
            iters_ground_truth=999424,
            meta_info={"explainer": explainer_label,
                       "num_subsets": num_subsets,
                       "repeat": "upfront",
                       "estimation_method": "KernelSHAP",
                       "split": split,
                       },
        )

        metric_list_target += get_ground_truth_metric_with_value(
            shapley_values_ground_truth=shapley_loaded_dict[ground_truth_key],
            iters_ground_truth=999424,
            shapley_values_calculated=shapley_loaded_dict[calculated_key],
            iters_calculated=num_subsets,
            meta_info={"explainer": explainer_label,
                       "num_subsets": num_subsets,
                       "estimation_method": "KernelSHAP",
                       "split": split,
                       },
        )

In [None]:
# For each upfront-permutation explainer checkpoint (keyed by its subset count),
# record two sets of metrics against the "regression_antithetical" ground truth:
#   * metric_list_explainer - error of the loaded explainer ("target": False)
#   * metric_list_target    - error of the stored permutation estimate, evaluated
#                             at the same subset count ("target": True)
#
# The split used to be hard-coded to "train", with a commented-out copy of the
# whole body for "test". It is now a parameter: only "train" is active, which
# matches the previous behavior; add "test" to evaluate that split as well.
eval_splits = ("train",)

permutation_checkpoints = {
    196: "logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_196/checkpoint-1036",
    392: "logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_392/checkpoint-1332",
    588: "logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_588/checkpoint-1332",
    1176: "logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_1176/checkpoint-1480",
    3136: "logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_3136/checkpoint-2812",
}

for num_subsets, checkpoint_path in permutation_checkpoints.items():

    state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")

    # The epoch stored in the trainer state becomes part of the explainer label.
    with open(checkpoint_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)
    num_epoch=int(trainer_state["epoch"])

    regexplainer.load_state_dict(state_dict)
    explainer_label = f"Reg-AO (upfront, permutation, {num_subsets}) (epoch={num_epoch})"

    for split in eval_splits:
        # Dictionary keys follow the directory layout of the stored values.
        ground_truth_key = f"logs/vitbase_imagenette_surrogate_shapley_eval_{split}_regression_antithetical/extract_output/{split}"
        calculated_key = f"logs/vitbase_imagenette_surrogate_shapley_eval_{split}_permutation/extract_output/{split}"

        metric_list_explainer += get_ground_truth_metric_with_explainer(
            shapley_values=shapley_loaded_dict[ground_truth_key],
            explainer=regexplainer,
            dataset=dataset_explainer[split],
            iters_ground_truth=999424,
            meta_info={"explainer": explainer_label,
                       "num_subsets": num_subsets,
                       "repeat": "upfront",
                       "estimation_method": "Permutation",
                       "split": split,
                       "target": False,
                       },
        )

        metric_list_target += get_ground_truth_metric_with_value(
            shapley_values_ground_truth=shapley_loaded_dict[ground_truth_key],
            iters_ground_truth=999424,
            shapley_values_calculated=shapley_loaded_dict[calculated_key],
            iters_calculated=num_subsets,
            meta_info={"explainer": explainer_label,
                       "num_subsets": num_subsets,
                       "estimation_method": "Permutation",
                       "split": split,
                       "target": True,
                       },
        )
    

In [None]:
# Empty accumulator for the KernelSHAP sample-budget sweep.
metric_list_value_variable = list()

In [None]:
# Sweep the KernelSHAP sample budget and measure how far the estimate at each
# budget is from the same run evaluated at 999424 iterations.
#
# Budgets are 1000..9000, 10000..90000 and 100000..900000 in ascending order.
# The previous hand-written list contained 700000 twice, which repeated that
# evaluation and duplicated its rows; every budget now appears exactly once.
subset_budgets = [k * 10**e for e in (3, 4, 5) for k in range(1, 10)]

for num_subsets in subset_budgets:

    # Move to the next multiple of 512 strictly above the requested budget.
    num_subsets=(int(num_subsets//512)+1)*512

    metric_list_value_variable+=get_ground_truth_metric_with_value(shapley_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"], 
                                       iters_ground_truth=999424, 
                                       shapley_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"],
                                       iters_calculated=num_subsets,
                                       meta_info={"method":f"KernelSHAP ({num_subsets})",
                                                  "num_subsets": num_subsets,
                                                  "estimation_method": "KernelSHAP",
                                                 })

In [None]:
# Display the current metric_df frame.
metric_df

In [None]:
# Display the frame built from metric_list_target.
metric_target_df

In [None]:
# Per-explainer means of the listed columns, restricted to labels that mention
# "permutation"; num_subsets is cast back to int after averaging.
summary_columns = [
    'sample_idx', "num_subsets",
    'mse_target_explainer', 'mse_nontarget_explainer', 'mse_all_explainer',
    'mse_target_target', 'mse_nontarget_target', 'mse_all_target',
]
metric_df_plot = metric_df.groupby("explainer")[summary_columns].mean().reset_index()
metric_df_plot["num_subsets"] = metric_df_plot["num_subsets"].astype(int)

permutation_mask = metric_df_plot["explainer"].str.contains("permutation")
metric_df_plot = metric_df_plot[permutation_mask]

In [None]:
# Display the per-explainer summary ordered by increasing subset count.
metric_df_plot.sort_values(by="num_subsets")

In [None]:
# Per-(explainer, epoch) column means for "Reg-AO (upfront, permutation, 196)" at epoch 7.
metric_df_.groupby(['explainer', 'epoch']).mean().loc["Reg-AO (upfront, permutation, 196)"].loc[7]

In [None]:
# Per-(explainer, epoch) column means for "Reg-AO (upfront, permutation, 392)" at epoch 9.
metric_df_.groupby(['explainer', 'epoch']).mean().loc["Reg-AO (upfront, permutation, 392)"].loc[9]

In [None]:
# Per-(explainer, epoch) column means for "Reg-AO (upfront, permutation, 588)" at epoch 12.
metric_df_.groupby(['explainer', 'epoch']).mean().loc["Reg-AO (upfront, permutation, 588)"].loc[12]

In [None]:
# Per-epoch column means for "Reg-AO (upfront, permutation, 588)", all epochs.
metric_df_.groupby(['explainer', 'epoch']).mean().loc["Reg-AO (upfront, permutation, 588)"]

In [None]:
# Per-(explainer, epoch) column means for "Reg-AO (upfront, permutation, 1176)" at epoch 10.
metric_df_.groupby(['explainer', 'epoch']).mean().loc["Reg-AO (upfront, permutation, 1176)"].loc[10]

In [None]:
# Per-(explainer, epoch) column means for "Reg-AO (upfront, permutation, 3332)" at epoch 19.
# NOTE(review): the other permutation labels visible in this notebook section
# use 3136, not 3332 - confirm this label exists in metric_df_.
metric_df_.groupby(['explainer', 'epoch']).mean().loc["Reg-AO (upfront, permutation, 3332)"].loc[19]

In [None]:
# Gather all metric records into metric_df_ and rewrite two explainer labels so
# that the subset count becomes the last comma-separated field.
# regex=False is passed explicitly: the labels contain parentheses, which are
# parsed as regex groups (and therefore never match the literal label) on
# pandas versions where Series.str.replace defaults to regex=True.
metric_df_=pd.DataFrame(metric_list+metric_list_)
metric_df_["explainer"]=metric_df_["explainer"].str.replace(
    "Reg-AO (upfront, regression, 512, antithetical)",
    "Reg-AO (upfront, regression, antithetical, 512)",
    regex=False)\
    .str.replace(
    "Obj-AO (newsample, 32, antithetical)",
    "Obj-AO (newsample, antithetical, 32)",
    regex=False)

# Show how many records each explainer label contributes.
print(metric_df_["explainer"].value_counts())

In [None]:
# Join explainer metrics with target metrics on the shared identifying columns;
# overlapping metric columns get the "_explainer" / "_target" suffixes.
merge_keys = ["explainer", "sample_idx", "num_subsets", "split"]

metric_explainer_df = pd.DataFrame(metric_list_explainer)
metric_target_df = pd.DataFrame(metric_list_target)

metric_df = metric_explainer_df.merge(
    metric_target_df,
    left_on=merge_keys,
    right_on=merge_keys,
    suffixes=('_explainer', '_target'),
)

# Display the per-explainer means of the train split, one explainer per column.
summary_columns = [
    'sample_idx', "num_subsets",
    'mse_target_explainer', 'mse_nontarget_explainer', 'mse_all_explainer',
    'mse_target_target', 'mse_nontarget_target', 'mse_all_target',
]
train_rows = metric_df[metric_df["split"] == "train"]
train_rows.groupby("explainer")[summary_columns].mean().T

# train permutation

In [None]:
# Display the column names of the KernelSHAP budget-sweep records.
pd.DataFrame(metric_list_value_variable).columns

In [None]:
# Mean MSE columns per subset count for the "KernelSHAP" estimation method.
(
    pd.DataFrame(metric_list_value_variable)
    .groupby(["estimation_method", "num_subsets"])[['mse_target', 'mse_nontarget', 'mse_all']]
    .mean()
    .loc["KernelSHAP"]
)

In [None]:
# Smallest per-explainer mean of mse_target_explainer in the current summary.
metric_df_plot["mse_target_explainer"].min()

In [None]:
# Train split, permutation explainers: scatter of mse_all_explainer (x) against
# mse_all_target (y), one point per explainer, coloured by subset count.
# Vertical lines mark the mean KernelSHAP mse_all values that fall strictly
# between the smallest and largest explainer error.
fig, ax = plt.subplots(1,1, figsize=(8,8))

axd={"main":ax}

plot_key="main"

metric_df_plot=metric_df[metric_df["split"]=="train"].groupby("explainer")[['sample_idx',"num_subsets",
        'mse_target_explainer', 'mse_nontarget_explainer', 'mse_all_explainer',  
       'mse_target_target', 'mse_nontarget_target', 'mse_all_target']].mean().reset_index()
metric_df_plot["num_subsets"]=metric_df_plot["num_subsets"].astype(int)
metric_df_plot=metric_df_plot[metric_df_plot["explainer"].str.contains("permutation")]

sns.scatterplot(
    x="mse_all_explainer", 
    y="mse_all_target", 
    hue="num_subsets",
    data=metric_df_plot,
    s=200,
    palette="Set2",
    ax=axd[plot_key],
)

# Loop invariants: mean KernelSHAP errors per subset count, the explainer error
# range and the colour cycle used for the reference lines.
kernelshap_mse=pd.DataFrame(metric_list_value_variable).groupby(["estimation_method","num_subsets"])\
[['mse_target', 'mse_nontarget', 'mse_all']].mean().loc["KernelSHAP"]
explainer_mse_min=metric_df_plot["mse_all_explainer"].min()
explainer_mse_max=metric_df_plot["mse_all_explainer"].max()
cycle_colors=plt.rcParams['axes.prop_cycle'].by_key()['color']

count=0
for idx, row in kernelshap_mse.iterrows():
    print(explainer_mse_min, row["mse_all"])
    if explainer_mse_min < row["mse_all"] < explainer_mse_max:
        # Wrap around the colour cycle: indexing it directly raised IndexError
        # once more lines qualified than the cycle has colours.
        axd[plot_key].vlines(ymin=0, ymax=1, 
                             x=row["mse_all"], linewidth=2, color=cycle_colors[count % len(cycle_colors)],
                             label=f'KernelSHAP {idx}'
                            )
        count+=1

# y = x reference line.
axd[plot_key].plot(np.linspace(0,10,100), np.linspace(0,10,100))

axd[plot_key].set_ylabel("Error from target", fontsize=20)
axd[plot_key].set_xlabel("Error from prediction", fontsize=20)

axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)

axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)

axd[plot_key].tick_params(axis='x', which='major', labelsize=20)
axd[plot_key].tick_params(axis='y', which='major', labelsize=20)

axd[plot_key].spines['right'].set_visible(False)
axd[plot_key].spines['top'].set_visible(False)

axd[plot_key].set_xscale("log")
axd[plot_key].set_yscale("log")

axd[plot_key].legend(loc="right", fontsize=15)

axd[plot_key].set_xlim(left=1e-3)

In [None]:
# Per-explainer means over the train split, restricted to labels that mention
# "regression"; num_subsets is cast back to int after averaging.
summary_columns = [
    'sample_idx', "num_subsets",
    'mse_target_explainer', 'mse_nontarget_explainer', 'mse_all_explainer',
    'mse_target_target', 'mse_nontarget_target', 'mse_all_target',
]
train_rows = metric_df[metric_df["split"] == "train"]
metric_df_plot = train_rows.groupby("explainer")[summary_columns].mean().reset_index()
metric_df_plot["num_subsets"] = metric_df_plot["num_subsets"].astype(int)

regression_mask = metric_df_plot["explainer"].str.contains("regression")
metric_df_plot = metric_df_plot[regression_mask]

In [None]:
# Train split, regression explainers: scatter of mse_all_explainer (x) against
# mse_all_target (y), one point per explainer, coloured by subset count.
# Vertical lines (drawn first) mark the mean KernelSHAP mse_all values that fall
# strictly between the smallest and largest explainer error.
fig, ax = plt.subplots(1,1, figsize=(8,8))

axd={"main":ax}

plot_key="main"

metric_df_plot=metric_df[metric_df["split"]=="train"].groupby("explainer")[['sample_idx',"num_subsets",
        'mse_target_explainer', 'mse_nontarget_explainer', 'mse_all_explainer',  
       'mse_target_target', 'mse_nontarget_target', 'mse_all_target']].mean().reset_index()
metric_df_plot["num_subsets"]=metric_df_plot["num_subsets"].astype(int)
metric_df_plot=metric_df_plot[metric_df_plot["explainer"].str.contains("regression")]

# Loop invariants: mean KernelSHAP errors per subset count, the explainer error
# range and the colour cycle used for the reference lines.
kernelshap_mse=pd.DataFrame(metric_list_value_variable).groupby(["estimation_method","num_subsets"])\
[['mse_target', 'mse_nontarget', 'mse_all']].mean().loc["KernelSHAP"]
explainer_mse_min=metric_df_plot["mse_all_explainer"].min()
explainer_mse_max=metric_df_plot["mse_all_explainer"].max()
cycle_colors=plt.rcParams['axes.prop_cycle'].by_key()['color']

count=0
for idx, row in kernelshap_mse.iterrows():
    print(explainer_mse_min, row["mse_all"])
    if explainer_mse_min < row["mse_all"] < explainer_mse_max:
        # Wrap around the colour cycle: indexing it directly raised IndexError
        # once more lines qualified than the cycle has colours.
        axd[plot_key].vlines(ymin=0, ymax=1, 
                             x=row["mse_all"], linewidth=2, color=cycle_colors[count % len(cycle_colors)],
                             label=f'KernelSHAP {idx}'
                            )
        count+=1

sns.scatterplot(
    x="mse_all_explainer", 
    y="mse_all_target", 
    hue="num_subsets",
    data=metric_df_plot,
    s=200,
    palette="Set2",
    ax=axd[plot_key],
)

# y = x reference line.
axd[plot_key].plot(np.linspace(0,10,100), np.linspace(0,10,100))

axd[plot_key].set_ylabel("Error from target", fontsize=20)
axd[plot_key].set_xlabel("Error from prediction", fontsize=20)

axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)

axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)

axd[plot_key].tick_params(axis='x', which='major', labelsize=20)
axd[plot_key].tick_params(axis='y', which='major', labelsize=20)

axd[plot_key].spines['right'].set_visible(False)
axd[plot_key].spines['top'].set_visible(False)

axd[plot_key].set_xscale("log")
axd[plot_key].set_yscale("log")

axd[plot_key].legend(loc="right", fontsize=15)

In [None]:
# Test split, permutation explainers: scatter of mse_target_explainer (x)
# against mse_target_target (y), one point per explainer, on log-log axes.
plot_key = "main"
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
axd = {plot_key: ax}

summary_columns = [
    'sample_idx', "num_subsets",
    'mse_target_explainer', 'mse_nontarget_explainer', 'mse_all_explainer',
    'mse_target_target', 'mse_nontarget_target', 'mse_all_target',
]
test_rows = metric_df[metric_df["split"] == "test"]
metric_df_plot = test_rows.groupby("explainer")[summary_columns].mean().reset_index()
metric_df_plot["num_subsets"] = metric_df_plot["num_subsets"].astype(int)
metric_df_plot = metric_df_plot[metric_df_plot["explainer"].str.contains("permutation")]

current_ax = axd[plot_key]

sns.scatterplot(
    data=metric_df_plot,
    x="mse_target_explainer",
    y="mse_target_target",
    hue="num_subsets",
    s=200,
    palette="Set2",
    ax=current_ax,
)

# y = x reference line.
diagonal = np.linspace(0, 10, 100)
current_ax.plot(diagonal, diagonal)

current_ax.set_ylabel("Error from target", fontsize=20)
current_ax.set_xlabel("Error from prediction", fontsize=20)

# Same grid styling on both axes: strong major lines, faint minor lines.
for axis in (current_ax.yaxis, current_ax.xaxis):
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    current_ax.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    current_ax.spines[side].set_visible(False)

current_ax.set_xscale("log")
current_ax.set_yscale("log")

current_ax.legend(loc="right", fontsize=15)

# train regression

# test regression

In [None]:
# Test split, regression explainers: scatter of mse_target_explainer (x)
# against mse_target_target (y), one point per explainer, on log-log axes.
plot_key = "main"
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
axd = {plot_key: ax}

summary_columns = [
    'sample_idx', "num_subsets",
    'mse_target_explainer', 'mse_nontarget_explainer', 'mse_all_explainer',
    'mse_target_target', 'mse_nontarget_target', 'mse_all_target',
]
test_rows = metric_df[metric_df["split"] == "test"]
metric_df_plot = test_rows.groupby("explainer")[summary_columns].mean().reset_index()
metric_df_plot["num_subsets"] = metric_df_plot["num_subsets"].astype(int)
metric_df_plot = metric_df_plot[metric_df_plot["explainer"].str.contains("regression")]

current_ax = axd[plot_key]

sns.scatterplot(
    data=metric_df_plot,
    x="mse_target_explainer",
    y="mse_target_target",
    hue="num_subsets",
    s=200,
    palette="Set2",
    ax=current_ax,
)

# y = x reference line.
diagonal = np.linspace(0, 10, 100)
current_ax.plot(diagonal, diagonal)

current_ax.set_ylabel("Error from target", fontsize=20)
current_ax.set_xlabel("Error from prediction", fontsize=20)

# Same grid styling on both axes: strong major lines, faint minor lines.
for axis in (current_ax.yaxis, current_ax.xaxis):
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    current_ax.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    current_ax.spines[side].set_visible(False)

current_ax.set_xscale("log")
current_ax.set_yscale("log")

current_ax.legend(loc="right", fontsize=15)

In [None]:
# Per-explainer means of four MSE columns, with rows ordered by the integer in
# the last comma-separated field of the explainer label.
# NOTE(review): the "*_calculated" and "epoch" columns are not produced by the
# merge shown elsewhere in this notebook section - confirm which metric_df this
# cell expects.
averaged_columns = [
    'sample_idx', 'mse_target_explainer', 'mse_nontarget_explainer',
    'mse_all_explainer', 'epoch',
    'mse_target_calculated', 'mse_nontarget_calculated',
    'mse_all_calculated',
]
shown_columns = [
    "mse_target_explainer",
    "mse_target_calculated",
    "mse_nontarget_explainer",
    "mse_nontarget_calculated",
]
(
    metric_df.groupby("explainer")[averaged_columns]
    .mean()[shown_columns]
    .loc[
        sorted(
            metric_df["explainer"].unique(),
            key=lambda label: int(label.split(',')[-1].strip().replace(')', '')),
        )
    ]
)

In [None]:
# Regression explainers: scatter of mse_target_explainer (x) against
# mse_target_target (y), one point per explainer label, on linear axes.
plot_key = "main"
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
axd = {plot_key: ax}

current_ax = axd[plot_key]

regression_rows = metric_df[metric_df["explainer"].str.contains("regression")]

sns.scatterplot(
    data=regression_rows.groupby("explainer")[
        [
            'sample_idx',
            'mse_target_explainer', 'mse_nontarget_explainer', 'mse_all_explainer',
            'mse_target_target', 'mse_nontarget_target', 'mse_all_target',
        ]
    ].mean().reset_index(),
    x="mse_target_explainer",
    y="mse_target_target",
    hue="explainer",
    # Legend order follows the integer in the label's last comma-separated field.
    hue_order=sorted(
        regression_rows["explainer"].unique(),
        key=lambda label: int(label.split(',')[-1].strip().replace(')', '')),
    ),
    palette=sns.color_palette("rocket", 3),
    s=200,
    ax=current_ax,
)

# y = x reference line.
diagonal = np.linspace(0, 3, 100)
current_ax.plot(diagonal, diagonal)

current_ax.set_ylabel("Error from target", fontsize=20)
current_ax.set_xlabel("Error from prediction", fontsize=20)

# Same grid styling on both axes: strong major lines, faint minor lines.
for axis in (current_ax.yaxis, current_ax.xaxis):
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    current_ax.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    current_ax.spines[side].set_visible(False)

current_ax.legend(loc="best", fontsize=15)

In [None]:
# Permutation explainers: scatter of mse_target_explainer (x) against
# mse_target_calculated (y), one point per explainer label, on log-log axes.
plot_key = "main"
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
axd = {plot_key: ax}

current_ax = axd[plot_key]

permutation_rows = metric_df[metric_df["explainer"].str.contains("permutation")]

sns.scatterplot(
    data=permutation_rows.groupby("explainer")[
        [
            'sample_idx', 'mse_target_explainer', 'mse_nontarget_explainer',
            'mse_all_explainer', 'epoch',
            'mse_target_calculated', 'mse_nontarget_calculated',
            'mse_all_calculated',
        ]
    ].mean().reset_index(),
    x="mse_target_explainer",
    y="mse_target_calculated",
    hue="explainer",
    # Legend order follows the integer in the label's last comma-separated field.
    hue_order=sorted(
        permutation_rows["explainer"].unique(),
        key=lambda label: int(label.split(',')[-1].strip().replace(')', '')),
    ),
    palette=sns.color_palette("rocket", 4),
    s=200,
    ax=current_ax,
)

# y = x reference line.
diagonal = np.linspace(0, 10, 100)
current_ax.plot(diagonal, diagonal)

current_ax.set_ylabel("Error from target", fontsize=20)
current_ax.set_xlabel("Error from prediction", fontsize=20)

# Same grid styling on both axes: strong major lines, faint minor lines.
for axis in (current_ax.yaxis, current_ax.xaxis):
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    current_ax.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    current_ax.spines[side].set_visible(False)

current_ax.set_xscale("log")
current_ax.set_yscale("log")

current_ax.legend(loc="best", fontsize=15)

# end

In [None]:
metric_df.groupby("explainer")[['sample_idx', 'mse_target_explainer', 'mse_nontarget_explainer',
       'mse_all_explainer',  'epoch', 
       'mse_target_calculated', 'mse_nontarget_calculated',
       'mse_all_calculated']].mean().reset_index()

In [None]:
metric_df.groupby("explainer")[['sample_idx', 'mse_target_explainer', 'mse_nontarget_explainer',
       'mse_all_explainer',  'epoch', 
       'mse_target_calculated', 'mse_nontarget_calculated',
       'mse_all_calculated']].mean()

In [None]:
sns.scatterplot(x="mse_target_explainer", 
                y="mse_target_calculated", 
                hue="explainer",
                data=metric_df)

In [None]:
metric_df

In [None]:
metric_df.groupby("explainer")[['sample_idx', 'mse_target_explainer', 'mse_nontarget_explainer',
       'mse_all_explainer',  'epoch', 
       'mse_target_calculated', 'mse_nontarget_calculated',
       'mse_all_calculated']].apply(lambda x: len(x))

In [None]:
metric_df.groupby("explainer")[['sample_idx', 'mse_target_explainer', 'mse_nontarget_explainer',
       'mse_all_explainer',  'epoch', 
       'mse_target_calculated', 'mse_nontarget_calculated',
       'mse_all_calculated']].mean().T

In [None]:
metric_df.groupby("explainer")[['sample_idx', 'mse_target_explainer', 'mse_nontarget_explainer',
       'mse_all_explainer',  'epoch', 
       'mse_target_calculated', 'mse_nontarget_calculated',
       'mse_all_calculated']].mean().T

In [None]:
metric_df.fillna("None").groupby(["explainer", "sample_idx"]).apply(lambda x: print(x))

In [None]:
pd.DataFrame(get_ground_truth_metric_with_value(
    shapley_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
    iters_ground_truth=199680,
    shapley_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
    iters_calculated=512*3, 
    meta_info={}
)).mean()

In [None]:
pd.DataFrame(get_ground_truth_metric_with_value(
    shapley_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
    iters_ground_truth=199680,
    shapley_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"], 
    iters_calculated=512*1, 
    meta_info={}
)).mean()

In [None]:
pd.DataFrame(get_ground_truth_metric_with_value(
    shapley_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
    iters_ground_truth=199680,
    shapley_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
    iters_calculated=512*2, 
    meta_info={}
)).mean()

In [None]:
199680/512

In [None]:
pd.DataFrame(metric).mean()

In [None]:
pd.DataFrame(metric).mean()

In [None]:
get_ground_truth_metric_with_value(
    shapley_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
    iters_ground_truth=199680,
    shapley_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"], 
    iters_calculated=512, 
    meta_info={}
)    

In [None]:
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"]=\
load_shapley("logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train")

In [None]:
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"][9213]["iters"]

In [None]:
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"][9213]["iters"]

In [None]:
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"]

In [None]:
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_test/extract_output/test"]

In [None]:
metric_list+=get_ground_truth_metric(shapley_values=shapley_loaded_test, 
                        explainer=regexplainer,
                        dataset=dataset_explainer["test"],
                        iters_ground_truth=200192,
                        meta_info={"explainer": "Reg-AO (upfront, 512)",
                                   "epoch":int(trainer_state["epoch"])
                                  })

In [None]:
metric_list

In [None]:
# Evaluate saved checkpoints of the objective-trained explainer against the
# ground-truth Shapley values and accumulate the results in `metric_list`.
#
# The glob must match the individual `checkpoint-<step>` directories: the sort
# key parses the integer after the last '-', and the loop reads
# `pytorch_model.bin` / `trainer_state.json` from inside each match. Globbing
# the run directory itself makes the key raise ValueError.
checkpoint_path_list = sorted(
    glob.glob("logs/vitbase_imagenette_shapley_objexplainer_newsample_32/checkpoint-*"),
    key=lambda x: int(x.split('-')[-1]),
)
# Only the first 100 checkpoints (in step order) are evaluated.
for checkpoint_path in tqdm(checkpoint_path_list[:100]):
    state_dict = torch.load(checkpoint_path + "/pytorch_model.bin", map_location="cpu")

    with open(checkpoint_path + "/trainer_state.json") as f:
        trainer_state = json.load(f)

    explainer.load_state_dict(state_dict)
    metric_list += get_ground_truth_metric(shapley_values=shapley_loaded_test,
                            explainer=explainer,
                            dataset=dataset_explainer["test"],
                            iters_ground_truth=200192,
                            meta_info={"explainer": "Obj-AO (newsample, 32)",
                                       "epoch": int(trainer_state["epoch"])
                                      })

# per_sample

In [None]:
# Key under which each sample stores its reference Shapley values.
num_eval_ground_truth = 99840

record_dict_list = []

# For every test sample, compare the estimate stored under each evaluation
# budget with the reference estimate and record per-class squared errors.
for sample_idx, num_eval_shapley_values in enumerate(shapley_values_dict["test"]):
    reference = num_eval_shapley_values[num_eval_ground_truth]
    # Target class = class with the largest column sum of the reference values.
    target_class_idx = np.argmax(reference.sum(axis=0))

    for num_eval, shapley_values in num_eval_shapley_values.items():
        diff = shapley_values - reference
        mse_class = (diff * diff).sum(axis=0)
        is_target = np.arange(len(mse_class)) == target_class_idx

        record = {
            "sample_idx": sample_idx,
            "mse_target": mse_class[is_target].mean(),
            "mse_nontarget": mse_class[~is_target].mean(),
            "mse_all": mse_class.mean(),
            "num_eval": num_eval,
            "method": "per-sample",
        }
        record_dict_list.append(record)

In [None]:
# Rebuild the per-sample error records from `shapley_loaded_test["test"]`.
# Relies on `num_eval_ground_truth` having been defined by an earlier cell.
record_dict_list = []

for sample_idx, num_eval_shapley_values in enumerate(shapley_loaded_test["test"]):
    reference = num_eval_shapley_values[num_eval_ground_truth]
    # Target class = class with the largest column sum of the reference values.
    target_class_idx = np.argmax(reference.sum(axis=0))

    for num_eval, shapley_values in num_eval_shapley_values.items():
        diff = shapley_values - reference
        mse_class = (diff * diff).sum(axis=0)
        is_target = np.arange(len(mse_class)) == target_class_idx

        record = {
            "sample_idx": sample_idx,
            "mse_target": mse_class[is_target].mean(),
            "mse_nontarget": mse_class[~is_target].mean(),
            "mse_all": mse_class.mean(),
            "num_eval": num_eval,
            "method": "per-sample",
        }
        record_dict_list.append(record)

In [None]:
shapley_values_loaded

In [None]:
shapley_values.shape

In [None]:
99840/512

In [None]:
shapley_estimated.shape

# regression

In [None]:
import glob

# Load each regression explainer (identified by the numeric suffix of its log
# directory) and plot its output for a fixed set of test samples.
for offset in (0, 512, 1024, 1536):
    model_path_reg = f"logs/vitbase_imagenette_explainer_regression_{offset}"
    # The figure title reports the directory suffix plus 512.
    num_eval = offset + 512

    state_dict = torch.load(f"{model_path_reg}/pytorch_model.bin", map_location="cpu")
    regexplainer.load_state_dict(state_dict)

    fig = plot_figure(
        explainer=regexplainer,
        dataset=dataset["test_explainer"],
        sample_idx_list=list(range(0, 80, 10)),
    )
    fig.suptitle(f"Reg-AO {num_eval}")

In [None]:
import glob

# For each regression explainer checkpoint, compare its predicted Shapley
# values with the reference values of every test sample and append one record
# per (checkpoint, sample) to `record_dict_list`.
for model_path_reg in ['logs/vitbase_imagenette_explainer_regression_0',
                       'logs/vitbase_imagenette_explainer_regression_1024',
                       'logs/vitbase_imagenette_explainer_regression_512',
                       'logs/vitbase_imagenette_explainer_regression_1536']:
    # Recorded budget = numeric directory suffix plus 512.
    num_eval = int(model_path_reg.split('_')[-1]) + 512
    state_dict = torch.load(f"{model_path_reg}/checkpoint-1480/pytorch_model.bin", map_location="cpu")
    regexplainer.load_state_dict(state_dict)

    # Switching to eval mode once per checkpoint is enough.
    regexplainer.eval()
    # Run inference on whatever device the model currently lives on, so the
    # cell works whether or not the explainer has been moved to a GPU.
    model_device = next(regexplainer.parameters()).device

    for sample_idx, (num_eval_shapley_values, data) in enumerate(zip(shapley_values_dict["test"], dataset["test_explainer"])):
        ground_truth = num_eval_shapley_values[num_eval_ground_truth]
        target_class_idx = np.argmax(ground_truth.sum(axis=0))
        with torch.no_grad():
            pixel_values = data["pixel_values"].unsqueeze(0).to(model_device)
            shapley_estimated = regexplainer(pixel_values=pixel_values, return_loss=False)["logits"][0]
        # `.cpu()` is required before `.numpy()` when the model runs on a GPU
        # and is a no-op on CPU.
        diff = (shapley_estimated.T.detach().cpu().numpy() - ground_truth)
        mse_class = (diff * diff).sum(axis=0)
        record_dict_list.append({
            "sample_idx": sample_idx,
            "mse_target": mse_class[np.arange(len(mse_class)) == target_class_idx].mean(),
            "mse_nontarget": mse_class[np.arange(len(mse_class)) != target_class_idx].mean(),
            "mse_all": mse_class[:].mean(),
            "num_eval": num_eval,
            "method": "regression_AO",
        })
        

In [None]:
import tqdm

In [None]:
import glob
# Private alias: other cells bind the name `tqdm` either to the module
# (`import tqdm`) or to the class (`from tqdm.auto import tqdm`). Using an
# alias keeps this cell working regardless of which of those ran last.
from tqdm.auto import tqdm as progress_bar

# For a subset of objective-explainer checkpoints, compare predicted Shapley
# values with the reference values of every test sample and append one record
# per (checkpoint, sample) to `record_dict_list`.
for model_path_obj in sorted(glob.glob("logs/vitbase_imagenette_explainer_objective/checkpoint-*"), key=lambda x: int(x.split('-')[-1])):
    checkpoint_step = int(model_path_obj.split('-')[-1])
    # NOTE(review): 148 looks like optimizer steps per epoch and 32 like model
    # evaluations per epoch -- confirm against the training config.
    num_eval = checkpoint_step / 148 * 32
    # Keep only checkpoints whose step/148 quotient ends in 1 (1, 11, 21, ...).
    if int(checkpoint_step / 148) % 10 != 1:
        continue
    state_dict = torch.load(f"{model_path_obj}/pytorch_model.bin", map_location="cpu")
    explainer.load_state_dict(state_dict)
    # Switching to eval mode once per checkpoint is enough.
    explainer.eval()
    for sample_idx, (num_eval_shapley_values, data) in enumerate(zip(progress_bar(shapley_values_dict["test"]), dataset["test_explainer"])):
        ground_truth = num_eval_shapley_values[num_eval_ground_truth]
        target_class_idx = np.argmax(ground_truth.sum(axis=0))
        with torch.no_grad():
            shapley_estimated = explainer(pixel_values=data["pixel_values"].unsqueeze(0).to(explainer.device), return_loss=False)["logits"][0]
        diff = (shapley_estimated.T.cpu().detach().numpy() - ground_truth)
        mse_class = (diff * diff).sum(axis=0)
        record_dict_list.append({
            "sample_idx": sample_idx,
            "mse_target": mse_class[np.arange(len(mse_class)) == target_class_idx].mean(),
            "mse_nontarget": mse_class[np.arange(len(mse_class)) != target_class_idx].mean(),
            "mse_all": mse_class[:].mean(),
            "num_eval": num_eval,
            "method": "objective_AO",
        })

In [None]:
int(int(model_path_obj.split('-')[-1])/148)%10

In [None]:
model_path_obj

In [None]:
444/148

In [None]:
explainer.device

In [None]:
import pandas as pd

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import font_manager
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial") # Test with "Special Elite" too
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'

In [None]:
# Line plot of the target-class error against the number of model
# evaluations, one line per method (rows with num_eval <= 0 are excluded).
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

axd = {"main": ax}
plot_key = "main"
panel = axd[plot_key]

positive_eval_df = record_dict_list_df[record_dict_list_df["num_eval"] > 0]
sns.lineplot(data=positive_eval_df, x="num_eval", y="mse_target", hue="method", ax=ax)

panel.set_ylabel("L2 distance", fontsize=20)
panel.set_xlabel("# model evaluations", fontsize=20)

# Tick spacing and grid styling, y axis first and then x axis.
for axis, major_step, minor_step in ((panel.yaxis, 0.5, 0.1), (panel.xaxis, 50000, 10000)):
    axis.set_major_locator(MultipleLocator(major_step))
    axis.set_minor_locator(MultipleLocator(minor_step))
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    panel.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    panel.spines[side].set_visible(False)

panel.set_title("Target", fontsize=20)

# axd[plot_key].set_yscale('log')

In [None]:
import pandas as pd

In [None]:
# Line plot of the non-target-class error against the number of model
# evaluations, one line per method (all rows are plotted, unfiltered).
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

axd = {"main": ax}
plot_key = "main"
panel = axd[plot_key]

sns.lineplot(data=record_dict_list_df, x="num_eval", y="mse_nontarget", hue="method", ax=ax)

panel.set_ylabel("L2 distance", fontsize=20)
panel.set_xlabel("# model evaluations", fontsize=20)

# Tick spacing and grid styling, y axis first and then x axis.
for axis, major_step, minor_step in ((panel.yaxis, 0.005, 0.001), (panel.xaxis, 50000, 10000)):
    axis.set_major_locator(MultipleLocator(major_step))
    axis.set_minor_locator(MultipleLocator(minor_step))
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    panel.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    panel.spines[side].set_visible(False)

panel.set_title("Non-target", fontsize=20)

# axd[plot_key].set_yscale('log')

In [None]:
# Line plot of the non-target-class error against the number of model
# evaluations, one line per method (rows with num_eval <= 0 are excluded).
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

axd = {"main": ax}
plot_key = "main"
panel = axd[plot_key]

positive_eval_df = record_dict_list_df[record_dict_list_df["num_eval"] > 0]
sns.lineplot(data=positive_eval_df, x="num_eval", y="mse_nontarget", hue="method", ax=ax)

panel.set_ylabel("L2 distance", fontsize=20)
panel.set_xlabel("# model evaluations", fontsize=20)

# Tick spacing and grid styling, y axis first and then x axis.
for axis, major_step, minor_step in ((panel.yaxis, 0.005, 0.001), (panel.xaxis, 50000, 10000)):
    axis.set_major_locator(MultipleLocator(major_step))
    axis.set_minor_locator(MultipleLocator(minor_step))
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    panel.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    panel.spines[side].set_visible(False)

panel.set_title("Non-target", fontsize=20)

# axd[plot_key].set_yscale('log')

In [None]:
# Zoomed line plot of the target-class error: x limited to [0, 3500] and
# y limited to [0, 0.1], one line per method (num_eval <= 0 excluded).
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

axd = {"main": ax}
plot_key = "main"
panel = axd[plot_key]

positive_eval_df = record_dict_list_df[record_dict_list_df["num_eval"] > 0]
sns.lineplot(data=positive_eval_df, x="num_eval", y="mse_target", hue="method", ax=ax)

panel.set_ylabel("L2 distance", fontsize=20)
panel.set_xlabel("# model evaluations", fontsize=20)

# Tick spacing and grid styling, y axis first and then x axis.
for axis, major_step, minor_step in ((panel.yaxis, 0.05, 0.01), (panel.xaxis, 1000, 500)):
    axis.set_major_locator(MultipleLocator(major_step))
    axis.set_minor_locator(MultipleLocator(minor_step))
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    panel.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    panel.spines[side].set_visible(False)

panel.set_xlim(0, 3500)
panel.set_ylim(0, 0.1)

panel.set_title("Target", fontsize=20)

# axd[plot_key].set_yscale('log')

In [None]:
# Zoomed line plot of the non-target-class error: x limited to [0, 3500] and
# y limited to [0, 0.01], one line per method (num_eval <= 0 excluded).
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

axd = {"main": ax}
plot_key = "main"
panel = axd[plot_key]

positive_eval_df = record_dict_list_df[record_dict_list_df["num_eval"] > 0]
sns.lineplot(data=positive_eval_df, x="num_eval", y="mse_nontarget", hue="method", ax=ax)

panel.set_ylabel("L2 distance", fontsize=20)
panel.set_xlabel("# model evaluations", fontsize=20)

# Tick spacing and grid styling, y axis first and then x axis.
for axis, major_step, minor_step in ((panel.yaxis, 0.005, 0.001), (panel.xaxis, 1000, 500)):
    axis.set_major_locator(MultipleLocator(major_step))
    axis.set_minor_locator(MultipleLocator(minor_step))
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    panel.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    panel.spines[side].set_visible(False)

panel.set_xlim(0, 3500)
panel.set_ylim(0, 0.01)

panel.set_title("Non-target", fontsize=20)

# axd[plot_key].set_yscale('log')

In [None]:
500/32

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_explainer_regression_0/checkpoint-1480/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_explainer_regression/checkpoint-1480/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)

In [None]:
explainer_out=explainer.forward(pixel_values=dataset["validation_explainer"][0]['pixel_values'].unsqueeze(0),
                               return_loss=False)

In [None]:
import tqdm

In [None]:
!gpustat

In [None]:
device="cuda:7"

In [None]:
explainer.to(device)
explainer.surrogate_null=explainer.surrogate_null.to(device)

In [None]:
plot_figure(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70])

In [None]:
plot_figure(explainer, dataset["validation_explainer"], [0,  250, 500,1000])

In [None]:
dataset["validation"][0]["image"]

In [None]:
dataset["test"][0]["image"]

In [None]:
dataset["test"] = (

In [None]:
data_args.max_test_samples

In [None]:
from utils import load_shapley

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                        shapley_value={0:{
                'values': [sgd_shapley_output],
                'std': [],
                'iters': [0]}}, shapley_value_key=0)

In [None]:
shapley_loaded_train=load_shapley("logs/vitbase_imagenette_surrogate_eval_train/extract_output/train/")
shapley_loaded_train_permutation=load_shapley("logs/vitbase_imagenette_surrogate_train_validation_permutation/extract_output/train/")

In [None]:
shapley_loaded_validation=load_shapley("logs/vitbase_imagenette_surrogate_eval_validation/extract_output/validation/")
shapley_loaded_validation_permutation=load_shapley("logs/vitbase_imagenette_surrogate_eval_validation_permutation/extract_output/validation/")

In [None]:
shapley_loaded_test=load_shapley("logs/vitbase_imagenette_surrogate_eval_test/extract_output/test/")
shapley_loaded_test_permutation=load_shapley("logs/vitbase_imagenette_surrogate_eval_test_permutation/extract_output/test/")

In [None]:
shapley_loaded_test_permutation=load_shapley("logs/vitbase_imagenette_surrogate_eval_test_permutation/extract_output/test/")

In [None]:
shapley_loaded1=load_shapley("logs/vitbase_imagenette_surrogate_eval_validation/extract_output/validation/")

In [None]:
shapley_loaded1[0]["iters"]

In [None]:
shapley_loaded2[0]["iters"]

In [None]:
len(shapley_loaded1), len(shapley_loaded2)

In [None]:
shapley_loaded=load_shapley("logs/vitbase_imagenette_surrogate_eval_test_permutation/extract_output/test/")

In [None]:
shapley_loaded.keys()

In [None]:
!ls logs/vitbase_imagenette_surrogate_eval_train_permutation/extract_output/train/3 -l

In [None]:
shapley_loaded[0]["values"]

In [None]:
load_shapley??

In [None]:
shapley_loaded[0]["values"][-1].sum(axis=0)

In [None]:
device="cuda:0"
explainer.surrogate.to(device)

In [None]:
import numpy as np
from tqdm.auto import tqdm
from scipy.special import softmax


def ShapleySampling(game,
                    batch_size=512,
                    detect_convergence=True,
                    thresh=0.01,
                    n_samples=None,
                    antithetical=False,
                    return_all=False,
                    bar=True,
                    verbose=False):
    '''
    Estimate Shapley values by sampling random player permutations.

    Args:
      game: cooperative game. Must expose `players` (number of players),
        `null()` (value of the empty coalition) and be callable on a
        (batch, players) 0/1 integer array of coalitions.
      batch_size: number of permutations sampled per outer iteration.
      detect_convergence: stop as soon as the std/range ratio drops below
        `thresh`.
      thresh: convergence threshold; must lie strictly between 0 and 1 when
        convergence detection is on.
      n_samples: upper bound on the number of sampled permutations. None
        removes the bound and forces convergence detection on.
      antithetical: if True, every odd-indexed permutation in a batch is the
        reverse of the permutation before it.
      return_all: return the per-iteration tracking dictionary instead of the
        final estimate.
      bar: show progress bars (the outer bar and the per-player bar).
      verbose: print progress messages.

    Returns:
      (values, std) when `return_all` is False. Otherwise a dict with
      'values', 'std' and 'iters' (plus 'N_est' when convergence detection is
      on), holding one entry per completed outer iteration, including the
      iteration on which convergence was detected.
    '''
    # Only deterministic games are handled in this copy; the stochastic
    # branches below are kept so the code stays close to its origin.
    stochastic = False

    # `bar` is rebound to a tqdm instance below, so remember the caller's
    # choice separately.
    show_bar = bool(bar)

    # Possibly force convergence detection.
    if n_samples is None:
        n_samples = 1e20
        if not detect_convergence:
            detect_convergence = True
            if verbose:
                print('Turning convergence detection on')

    if detect_convergence:
        assert 0 < thresh < 1

    # Calculate null coalition value.
    if stochastic:
        null = game.null(batch_size=batch_size)
    else:
        null = game.null()

    # Set up bar.
    n_loops = int(np.ceil(n_samples / batch_size))
    if show_bar:
        if detect_convergence:
            bar = tqdm(total=1)
        else:
            bar = tqdm(total=n_loops * batch_size)

    # Setup. A vector-valued null coalition adds a trailing output dimension.
    num_players = game.players
    if isinstance(null, np.ndarray):
        values = np.zeros((num_players, len(null)))
        sum_squares = np.zeros((num_players, len(null)))
        deltas = np.zeros((batch_size, num_players, len(null)))
    else:
        values = np.zeros((num_players))
        sum_squares = np.zeros((num_players))
        deltas = np.zeros((batch_size, num_players))
    permutations = np.tile(np.arange(game.players), (batch_size, 1))
    arange = np.arange(batch_size)
    n = 0

    # For tracking progress.
    if return_all:
        N_list = []
        std_list = []
        val_list = []

    # Begin sampling.
    for it in range(n_loops):
        for i in range(batch_size):
            if antithetical and i % 2 == 1:
                permutations[i] = permutations[i - 1][::-1]
            else:
                np.random.shuffle(permutations[i])
        S = np.zeros((batch_size, game.players), dtype=int)

        # Sample exogenous (if applicable).
        if stochastic:
            U = game.sample(batch_size)

        # Unroll permutations: add one player at a time and attribute the
        # change in game value to that player. The per-player bar honours the
        # `bar` flag like the outer one.
        player_steps = range(num_players)
        if show_bar:
            player_steps = tqdm(player_steps)
        prev_value = null
        for i in player_steps:
            S[arange, permutations[:, i]] = 1
            if stochastic:
                next_value = game(S, U)
            else:
                next_value = game(S)
            deltas[arange, permutations[:, i]] = next_value - prev_value
            prev_value = next_value

        # Welford's algorithm.
        n += batch_size
        diff = deltas - values
        values += np.sum(diff, axis=0) / n
        diff2 = deltas - values
        sum_squares += np.sum(diff * diff2, axis=0)

        # Calculate progress.
        var = sum_squares / (n ** 2)
        std = np.sqrt(var)
        ratio = np.max(
            np.max(std, axis=0) / (values.max(axis=0) - values.min(axis=0)))

        # Print progress message.
        if verbose:
            if detect_convergence:
                print(f'StdDev Ratio = {ratio:.4f} (Converge at {thresh:.4f})')
            else:
                print(f'StdDev Ratio = {ratio:.4f}')

        # Forecast number of iterations required.
        if detect_convergence:
            N_est = (it + 1) * (ratio / thresh) ** 2

        # Save intermediate quantities BEFORE the convergence check, so the
        # converged iteration is recorded and the tracking lists stay the
        # same length as `iters` in the returned dictionary.
        if return_all:
            val_list.append(np.copy(values))
            std_list.append(np.copy(std))
            if detect_convergence:
                N_list.append(N_est)

        if detect_convergence:
            # Check for convergence.
            if ratio < thresh:
                if verbose:
                    print('Detected convergence')

                # Skip bar ahead.
                if show_bar:
                    bar.n = bar.total
                    bar.refresh()
                break

            if show_bar and not np.isnan(N_est):
                bar.n = np.around((it + 1) / N_est, 4)
                bar.refresh()
        elif show_bar:
            bar.update(batch_size)

    # Release the outer progress bar.
    if show_bar:
        bar.close()

    # Return results.
    if return_all:
        # Dictionary for progress tracking.
        iters = (np.arange(it + 1) + 1) * batch_size * num_players
        tracking_dict = {
            'values': val_list,
            'std': std_list,
            'iters': iters}
        if detect_convergence:
            tracking_dict['N_est'] = N_list

        return tracking_dict
    else:
        return (values, std)
    
class CooperativeGame:
    '''
    Abstract interface for cooperative games.

    Subclasses are expected to set `self.players` and to implement
    `__call__` for a (batch, players) array of coalitions.
    '''

    def __init__(self):
        raise NotImplementedError

    def __call__(self, S):
        '''Evaluate the game on the coalitions in S.'''
        raise NotImplementedError

    def grand(self):
        '''Value of the coalition that contains every player.'''
        everyone = np.ones((1, self.players), dtype=int)
        return self.__call__(everyone)[0]

    def null(self):
        '''Value of the coalition that contains no player.'''
        nobody = np.zeros((1, self.players), dtype=int)
        return self.__call__(nobody)[0]


class PredictionGame(CooperativeGame):
    '''
    Cooperative game for an individual example's prediction.

    Args:
      surrogate: callable model. It is invoked as
        `surrogate(pixel_values, masks, return_loss=False)` and must return
        an object with a `logits` attribute.
      sample: mapping that holds the example's tensor under "pixel_values"
        (a batch dimension is added at call time).
      players: number of players, i.e. entries per coalition mask. Defaults
        to 196, the value that used to be hard-coded
        (presumably the 14x14 patch grid of the ViT input -- confirm).
      device: device the inputs are moved to. When None, the notebook-level
        `device` global is read at call time, which is the previous behavior.
    '''

    def __init__(self, surrogate, sample, players=196, device=None):
        self.surrogate = surrogate
        self.sample = sample

        # Store feature groups.
        self.players = players
        self.groups_matrix = None

        # Optional explicit device; None defers to the global `device`.
        self.device = device

        # Caching.
        self.sample_repeat = sample

    def __call__(self, S):
        '''
        Evaluate cooperative game.

        Args:
          S: array of player coalitions with size (batch, players).

        Returns:
          numpy array with the softmax (over axis 1) of the surrogate logits
          for the first element of the surrogate's output batch.
        '''
        input_data = self.sample_repeat
        # Fall back to the module-level `device` only when none was given.
        target_device = device if self.device is None else self.device

        # Evaluate.
        with torch.no_grad():
            output = self.surrogate(input_data["pixel_values"].unsqueeze(0).to(target_device),
                                    torch.Tensor(S).unsqueeze(0).to(target_device), return_loss=False)
            logits = output.logits
            return softmax(logits[0].detach().cpu().numpy(), axis=1)
    


# shapley_sampling=ShapleySampling(game,
#                     batch_size=128,
#                     detect_convergence=True,
#                     thresh=0.01,
#                     antithetical=False,
#                     return_all=True,
#                     bar=True,
#                     verbose=False)

In [None]:
# Edited by: Ian Covert and Chanwoo Kim

# Original authors: Simon Grah <simon.grah@thalesgroup.com>
#                   Vincent Thouvenot <vincent.thouvenot@thalesgroup.com>

# MIT License

# Copyright (c) 2020 Thales Six GTS France

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import operator as op
from functools import reduce
from tqdm import tqdm


def ncr(n, r):
    """
    Combinatorial computation: number of subsets of size r among n elements.

    The result is returned as a float (true division), which callers rely on
    when normalising lists of these values with numpy.

    Returns 0.0 when r is negative or larger than n, since no such subset
    exists (previously these inputs fell through to empty products and
    yielded 1.0).
    """
    if r < 0 or r > n:
        return 0.0
    # Use the smaller of r and n - r to keep the products short.
    r = min(r, n - r)
    numer = reduce(op.mul, range(n, n - r, -1), 1)
    denom = reduce(op.mul, range(1, r + 1), 1)
    return numer / denom


class SGDShapleyNew():
    """
    Estimate the Shapley Values using a Projected Stochastic Gradient algorithm.

    Coalition sizes are drawn from the importance-sampling distribution ``q``
    built in ``__init__``. Every gradient step is followed by a projection
    that makes the estimates sum to ``game(all ones) - game(all zeros)``.
    """

    def __init__(self, d, C):
        """
        Calculate internal values for later purposes
        Those elements depend only on the number of features d

        Parameters
        ----------
        d : integer
            Dimension of the problem. The number of features
        C : float
            Constant bounding |y|
        """
        sizes = range(1, d)
        D = C * np.sqrt(d)

        # Per coalition size k: regression weight w_k and L-smooth constant L_k.
        dict_w_k = {}
        dict_L_k = {}
        for k in sizes:
            w_k = (d - 1) / (ncr(d, k) * k * (d - k))
            dict_w_k[k] = w_k
            dict_L_k[k] = w_k * np.sqrt(k) * (np.sqrt(k) * D + C)

        # Sum of L over every coalition (closed formula): ncr(d, k) * L_k
        # simplifies to the expression below.
        sum_L = np.sum([(d - 1) / (np.sqrt(k) * (d - k)) * (np.sqrt(k) * D + C)
                        for k in sizes])

        # 1. Classic SGD proposal over sizes (stored, not used by ``sgd``).
        #    Built as an array up front instead of relying on ``list /= float``.
        p = np.array([ncr(d, k) for k in sizes], dtype=float)
        p /= np.sum(p)

        # 2. Importance-sampling proposal over sizes (used by ``sgd``).
        q = np.array(list(dict_L_k.values())) * p
        q /= np.sum(q)

        # Save internal attributes
        self.d = d
        self.dict_w_k = dict_w_k
        self.dict_L_k = dict_L_k
        self.sum_L = sum_L
        self.p = p
        self.q = q

    def _grad_F_i(self, phi, x_i, y_i, w_i):
        """
        Gradient vector per instance i.

        For an array-valued ``y_i`` the residual is broadcast over the rows of
        ``phi``; for a scalar ``y_i`` the result has the shape of ``x_i``.
        """
        residual = x_i.dot(phi) - y_i
        if isinstance(y_i, np.ndarray):
            return w_i * x_i[:, np.newaxis] * residual[np.newaxis]
        return w_i * x_i * residual

    def _initial_phi(self, phi_0, out_dim):
        """
        Return the starting estimate.

        ``None`` or a falsy scalar (the default ``False``) means "start from
        zeros". Anything else is copied into a float array. The scalar check
        avoids calling ``bool()`` on an array, which raises for arrays with
        more than one element.
        """
        if phi_0 is None or (np.ndim(phi_0) == 0 and not phi_0):
            shape = (self.d,) if out_dim is None else (self.d, out_dim)
            return np.zeros(shape)
        return np.array(phi_0, dtype=float)

    def sgd(self,
            game,
            n_iter=100,
            step=0.1,
            step_type="sqrt",
            phi_0=False):
        """
        Stochastic gradient descent algorithm

        Parameters
        ----------
        game : callable
            Called with a 2-D mask array (one row per coalition); the first
            row of its return value is used.
        n_iter : integer
            Number of sampled coalitions / gradient steps.
        step : float
            Base step size.
        step_type : {"constant", "sqrt", "inverse"}
            Step schedule: ``step``, ``step / sqrt(t)`` or ``step / t``.
        phi_0 : array-like or False
            Optional starting estimate; ``False``/``None`` starts from zeros.

        Returns
        -------
        The average of all projected iterates.

        Raises
        ------
        ValueError
            If ``step_type`` is not one of the supported schedules (previously
            such a value silently performed no updates at all).
        """
        if step_type not in ("constant", "sqrt", "inverse"):
            raise ValueError(
                "step_type must be 'constant', 'sqrt' or 'inverse', got %r"
                % (step_type,))

        d = self.d

        # Value of the full and of the empty coalition.
        grand = game(np.ones((1, d), dtype=bool))[0]
        null = game(np.zeros((1, d), dtype=bool))[0]
        out_dim = len(grand) if isinstance(grand, np.ndarray) else None
        total = grand - null

        # initialize Shapley value estimates, then project onto sum == total
        phi = self._initial_phi(phi_0, out_dim)
        phi = phi - (np.sum(phi, axis=0) - total) / d

        # store for iterate averaging
        iterate_shape = (n_iter, d) if out_dim is None else (n_iter, d, out_dim)
        phi_iterates = np.zeros(iterate_shape)

        # sample coalition sizes
        list_k = np.random.choice(list(range(1, d)), size=n_iter, p=self.q)

        for t in tqdm(range(1, n_iter + 1)):
            # build subset indicator x_i: k features chosen uniformly at random
            k = list_k[t - 1]
            x_i = np.zeros(d)
            x_i[np.random.permutation(d)[:k]] = 1

            # value of the coalition relative to the empty coalition
            y_i = game(x_i.astype(bool)[np.newaxis])[0] - null

            # importance-sampling correction: p_i is the probability of
            # drawing this particular coalition
            p_i = self.dict_L_k[k] / self.sum_L
            grad_i = 1 / p_i * self._grad_F_i(phi, x_i, y_i, self.dict_w_k[k])

            # update phi
            if step_type == "constant":
                rate = step
            elif step_type == "sqrt":
                rate = step / np.sqrt(t)
            else:
                rate = step / t
            phi = phi - rate * grad_i

            # projection step
            phi = phi - (phi.sum(axis=0) - total) / d

            # update iterate history
            phi_iterates[t - 1] = phi

        # Average iterates
        return np.mean(phi_iterates, axis=0)


In [None]:
from collections import OrderedDict
class SGDshapley():
    """
    Estimate the Shapley Values using a Projected Stochastic Gradient algorithm.

    This variant explains a single output dimension of the game
    (``dimension_select``). The Greek identifiers (ω, Φ, Π) had been corrupted
    by a wrong text encoding and are restored here.
    """

    def __init__(self, d, C):
        """
        Calculate internal values for later purposes
        Those elements depend only on the number of features d

        Parameters
        ----------
        d : integer
            Dimension of the problem. The number of features
        C : float
            Constant bounding |y|
        """
        sizes = range(1, d)
        D = C * np.sqrt(d)

        # Per coalition size k: regression weight ω_k and L-smooth constant L_k.
        dict_ω_k = OrderedDict()
        dict_L_k = OrderedDict()
        for k in sizes:
            ω_k = (d - 1) / (ncr(d, k) * k * (d - k))
            dict_ω_k[k] = ω_k
            dict_L_k[k] = ω_k * np.sqrt(k) * (np.sqrt(k) * D + C)

        # Summation of all L per coalition (closed formula)
        sum_L = np.sum([(d - 1) / (np.sqrt(k) * (d - k)) * (np.sqrt(k) * D + C)
                        for k in sizes])

        # Classic SGD proposal over coalition sizes (stored, not used below).
        p = np.array([ncr(d, k) for k in sizes], dtype=float)
        p /= np.sum(p)

        # Importance Sampling proposal q over coalition sizes
        q = np.array(list(dict_L_k.values())) * p
        q /= np.sum(q)

        # Save internal attributes
        self.d = d
        self.n = 2**d - 2  # number of coalitions other than empty and full
        self.dict_ω_k = dict_ω_k
        self.dict_L_k = dict_L_k
        self.sum_L = sum_L
        self.p = p
        self.q = q

    def _F_i(self, Φ, x_i, y_i, ω_i):
        """Function value per instance i"""
        return .5 * self.n * ω_i * (np.dot(x_i, Φ) - y_i)**2

    def _grad_F_i(self, Φ, x_i, y_i, ω_i):
        """
        Gradient vector per instance i.

        Algebraically equal to ``ω_i * x_i x_iᵀ Φ - ω_i * y_i * x_i`` but
        computed without materialising the d-by-d outer product.
        """
        return ω_i * x_i * (x_i.dot(Φ) - y_i)

    def _Π_1(self, x, b):
        """Projection Π on convex set K_1 (vectors whose entries sum to b)"""
        excess = np.sum(x) - b
        if np.abs(excess) <= 1e-6:
            return x
        return x - excess / len(x)

    def _Π_2(self, x, D):
        """Projection Π on convex set K_2 (vectors with Euclidean norm <= D)"""
        norm = np.linalg.norm(x)
        if norm > D:
            return x * D / norm
        return x

    def _Dykstra_proj(self, x, D, b, iter_proj=100, epsilon=1e-6):
        """
        Dykstra's algorithm to find orthogonal projection
        onto intersection of convex sets
        """
        xk = x.copy()
        d = len(x)
        pk, qk = np.zeros(d), np.zeros(d)
        for _ in range(iter_proj):
            yk = self._Π_2(xk + pk, D)
            pk = xk + pk - yk
            candidate = self._Π_1(yk + qk, b)
            # Stop once the iterate moves by no more than epsilon.
            if np.linalg.norm(candidate - xk, 2) <= epsilon:
                break
            qk = yk + qk - candidate
            xk = candidate
        return xk

    def sgd(self, game, dimension_select, n_iter=100, step=.1, step_type="sqrt",
            callback=None, Φ_0=False):
        """
        Stochastic gradient descent algorithm, one coalition per update.

        Equivalent to ``sgd_minibatch`` with ``batch_size=1``; see that method
        for the parameters, return value and exceptions.
        """
        return self.sgd_minibatch(game, 1, dimension_select, n_iter=n_iter,
                                  step=step, step_type=step_type,
                                  callback=callback, Φ_0=Φ_0)

    def sgd_minibatch(self, game, batch_size, dimension_select, n_iter=100, step=.1, step_type="sqrt",
            callback=None, Φ_0=False):
        """
        Stochastic gradient descent algorithm with gradient averaging.

        Parameters
        ----------
        game : callable
            Called with a 2-D mask array (one row per coalition); entry
            ``[0][dimension_select]`` of its return value is used.
        batch_size : integer
            Number of sampled coalitions whose gradients are averaged per
            update. A trailing partial batch is applied too, so no game
            evaluation is wasted and at least one update happens for
            ``n_iter >= 1``.
        dimension_select : index
            Output dimension of the game to explain.
        n_iter : integer
            Number of sampled coalitions (game evaluations).
        step : float
            Base step size.
        step_type : {"constant", "sqrt", "inverse"}
            Step schedule; ``t`` is the running sample count at update time.
        callback : unused
            Accepted for backward compatibility; never called.
        Φ_0 : array-like or False
            Optional starting estimate; ``False``/``None`` starts from zeros.

        Returns
        -------
        The average of the projected estimates recorded after each update
        (the projected starting point if ``n_iter`` is 0).

        Raises
        ------
        ValueError
            If ``batch_size < 1`` or ``step_type`` is not supported.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
        if step_type not in ("constant", "sqrt", "inverse"):
            raise ValueError(
                "step_type must be 'constant', 'sqrt' or 'inverse', got %r"
                % (step_type,))

        d = self.d

        # Value of the full coalition and of the empty (reference) coalition.
        f_x = game(np.ones((1, d), dtype=int))[0][dimension_select]
        f_r = game(np.zeros((1, d), dtype=int))[0][dimension_select]
        v_M = f_x - f_r

        # Starting estimate. The scalar check avoids ``bool(array)``, which
        # raises for arrays with more than one element.
        if Φ_0 is None or (np.ndim(Φ_0) == 0 and not Φ_0):
            Φ = np.zeros(d)
        else:
            Φ = np.array(Φ_0, dtype=float)

        # projection onto the hyperplane sum(Φ) == v_M
        Φ = Φ - (np.sum(Φ) - v_M) / d

        # Sample in advance coalition sizes
        list_k = np.random.choice(list(range(1, d)), size=n_iter, p=self.q)

        Φ_storage = []
        grad_accum = []
        for t in tqdm(range(1, n_iter + 1)):
            # build x_i: indicator of k features chosen uniformly at random
            k = list_k[t - 1]
            x_i = np.zeros(d)
            x_i[np.random.permutation(d)[:k]] = 1

            # coalition value relative to the reference
            y_i = game(x_i[np.newaxis])[0][dimension_select] - f_r

            # importance-sampling correction: p_i is the probability of
            # drawing this particular coalition
            p_i = self.dict_L_k[k] / self.sum_L
            grad_accum.append(
                1 / p_i * self._grad_F_i(Φ, x_i, y_i, self.dict_ω_k[k]))

            # Update on every full batch, and once more for a trailing
            # partial batch at the very end.
            if t % batch_size != 0 and t != n_iter:
                continue

            if step_type == "constant":
                rate = step
            elif step_type == "sqrt":
                rate = step / np.sqrt(t)
            else:
                rate = step / t
            Φ = Φ - rate * np.mean(grad_accum, axis=0)

            # projection onto the hyperplane sum(Φ) == v_M
            Φ = Φ - (Φ.sum() - v_M) / d

            Φ_storage.append(Φ)
            grad_accum = []

        if not Φ_storage:
            # No update was performed (n_iter == 0): avoid the mean of an
            # empty list, which would be NaN.
            return Φ

        # Average all Φ
        return np.mean(np.array(Φ_storage), axis=0)

In [None]:
# Exploratory cells: plot and inspect previously loaded Shapley estimates.
# `plot_figure_shapley`, `dataset_explainer` and the `shapley_loaded*` objects
# are not defined in these cells; they come from elsewhere in the notebook.

# Validation sample 20, entry keyed 3584 of `shapley_loaded1`.
fig=plot_figure_shapley(dataset_explainer["validation"], sample_idx_list=[20], 
                    shapley_value=shapley_loaded1, shapley_value_key=3584)

In [None]:
# Same validation sample, entry keyed 3332 of `shapley_loaded2`.
fig=plot_figure_shapley(dataset_explainer["validation"], sample_idx_list=[20], 
                    shapley_value=shapley_loaded2, shapley_value_key=3332)

In [None]:
# Display the 'iters' entry recorded for sample 0 of the test estimates.
shapley_loaded_test[0]['iters']

In [None]:
# Test sample 5, entry keyed 5120 of `shapley_loaded_test`.
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[5], 
                    shapley_value=shapley_loaded_test, shapley_value_key=5120)

In [None]:
# Test sample 5, entry keyed 3332 of `shapley_loaded_test_permutation`.
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[5], 
                    shapley_value=shapley_loaded_test_permutation, shapley_value_key=3332)

In [None]:
# Test sample 1, same source and key as the previous cell.
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[1], 
                    shapley_value=shapley_loaded_test_permutation, shapley_value_key=3332)

In [None]:
# Display the 'iters' entry recorded for sample 0.
shapley_loaded_test_permutation[0]["iters"]

In [None]:
# 3332 is the key used above and 196 is the `d` passed to the SGD estimators
# in later cells; presumably this is evaluations per feature — TODO confirm.
3332/196

In [None]:
shapley_loaded2[0]['iters']

In [None]:
# Validation sample 110, entry keyed 3332 of `shapley_loaded2`.
fig=plot_figure_shapley(dataset_explainer["validation"], sample_idx_list=[110], 
                    shapley_value=shapley_loaded2, shapley_value_key=3332)

In [None]:
shapley_loaded1[0]["iters"]

In [None]:
# Display the whole loaded object.
shapley_loaded1

In [None]:
# Exploratory cells: run the SGDshapley estimator on `game` for output
# dimension 8 and plot the result.
# NOTE(review): `game` is assigned a few cells further down; the notebook was
# presumably executed out of order — confirm before running top to bottom.

# Plain SGD (one coalition per update), 5000 sampled coalitions.
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd(game,
            n_iter=5000,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
# The estimator returns one value per feature for a single output dimension.
# np.tile copies that column 10 times to give a (d, 10) array, presumably
# because plot_figure_shapley expects one column per class — TODO confirm.
# The result is wrapped in the {sample: {'values', 'std', 'iters'}} layout
# used by the loaded estimates above, under key 0.
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
# Game built from the explainer's surrogate model and test sample 0.
game=PredictionGame(surrogate=explainer.surrogate,
                    sample=dataset_explainer["test"][0]
                    )

In [None]:
# Mini-batch variant: gradients of 32 coalitions are averaged per update.
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd_minibatch(game,
            n_iter=5000,
            batch_size=32,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
# Same settings as the previous run (batch_size=32); presumably repeated to
# look at run-to-run variation of the random sampling.
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd_minibatch(game,
            n_iter=5000,
            batch_size=32,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
# batch_size=64.
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd_minibatch(game,
            n_iter=5000,
            batch_size=64,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
# Same settings as the previous run (batch_size=64).
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd_minibatch(game,
            n_iter=5000,
            batch_size=64,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
# batch_size=512: with n_iter=5000 only 9 full batches fit.
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd_minibatch(game,
            n_iter=5000,
            batch_size=512,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
# Sum of the estimates; the estimator projects every iterate so that it sums
# to game(all ones) - game(all zeros) for the selected dimension.
sgd_shapley_old_output.sum()

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
sgd_shapley_old_output

In [None]:
# Full-coalition value minus empty-coalition value for output dimension 8:
# the reference value to compare against the sum of the estimates.
game(np.ones((1,196)))[0][8]-game(np.zeros((1,196)))[0][8]

In [None]:
sgd_shapley_old_output.sum()

In [None]:
# Long plain-SGD run: 200000 sampled coalitions.
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd(game,
            n_iter=200000,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
sgd_shapley_old_output

In [None]:
# The tiled (d, 10) array that is handed to the plotting helper.
np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))

In [None]:
# Element-wise check that column 0 of the tiled array equals the raw output.
sgd_shapley_old_output==np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))[:,0]

In [None]:
sgd_shapley_old_output

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
# NOTE(review): this cell is an indented fragment that references `self`, so
# it looks like a debugging snippet pasted from inside an `sgd` method rather
# than code meant to run on its own — confirm whether it can be deleted.
# It evaluates the game on the all-ones and all-zeros masks, drops into the
# ipdb debugger, then takes the difference of the two values.
        f_x = game(np.ones((1, self.d), dtype=bool))[0]
        f_r = game(np.zeros((1, self.d), dtype=bool))[0]
        import ipdb
        ipdb.set_trace()
        
        v_M = f_x - f_r


In [None]:
# Game built from the explainer's surrogate model and test sample 0.
game=PredictionGame(surrogate=explainer.surrogate,
                    sample=dataset_explainer["test"][0]
                    )

In [None]:
# ShapleySampling run on the same game with a fixed budget of 32*32 samples
# (convergence detection is switched off). `ShapleySampling` is defined
# outside these cells.
shapley_sampling=ShapleySampling(game,
                    batch_size=32,
                    n_samples=32*32,
                    detect_convergence=False,
                    thresh=0.01,
                    antithetical=False,
                    return_all=True,
                    bar=True,
                    verbose=False)

In [None]:
# Projected-SGD estimate with the multi-output SGDShapleyNew class:
# 200000 sampled coalitions, step 0.1 / sqrt(t), starting from zeros.
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=200000,
            step=0.1,
            step_type="sqrt",
            phi_0=False)

In [None]:
# Edited by: Ian Covert and Chanwoo Kim

# Original authors: Simon Grah <simon.grah@thalesgroup.com>
#                   Vincent Thouvenot <vincent.thouvenot@thalesgroup.com>

# MIT License

# Copyright (c) 2020 Thales Six GTS France

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import operator as op
from functools import reduce
from tqdm import tqdm


def ncr(n, r):
    """
    Number of r-element subsets of an n-element set, as a float.

    Only min(r, n - r) factors are multiplied, using C(n, r) == C(n, n - r).
    """
    terms = min(r, n - r)
    top = 1
    bottom = 1
    for i in range(terms):
        top *= n - i
        bottom *= i + 1
    # True division, so a float is returned even for whole-number results.
    return top / bottom


class SGDShapleyNew():
    """
    Estimate the Shapley Values using a Projected Stochastic Gradient algorithm.

    This version samples all coalitions first, evaluates the game on them in
    batches, then replays the gradient steps and returns the running average
    of the iterates after every step.
    """

    def __init__(self, d, C):
        """
        Calculate internal values for later purposes
        Those elements depend only on the number of features d

        Parameters
        ----------
        d : integer
            Dimension of the problem. The number of features
        C : float
            Constant bounding |y|
        """
        sizes = range(1, d)
        D = C * np.sqrt(d)

        # Per coalition size k: regression weight w_k and L-smooth constant L_k.
        dict_w_k = {}
        dict_L_k = {}
        for k in sizes:
            w_k = (d - 1) / (ncr(d, k) * k * (d - k))
            dict_w_k[k] = w_k
            dict_L_k[k] = w_k * np.sqrt(k) * (np.sqrt(k) * D + C)

        # Summation of all L per coalition (closed formula)
        sum_L = np.sum([(d - 1) / (np.sqrt(k) * (d - k)) * (np.sqrt(k) * D + C)
                        for k in sizes])

        # 1. Classic SGD proposal over sizes (stored, not used by ``sgd``).
        p = np.array([ncr(d, k) for k in sizes], dtype=float)
        p /= np.sum(p)

        # 2. Importance-sampling proposal over sizes (used by ``sgd``).
        q = np.array(list(dict_L_k.values())) * p
        q /= np.sum(q)

        # Save internal attributes
        self.d = d
        self.dict_w_k = dict_w_k
        self.dict_L_k = dict_L_k
        self.sum_L = sum_L
        self.p = p
        self.q = q

    def _grad_F_i(self, phi, x_i, y_i, w_i):
        """
        Gradient vector per instance i.

        For an array-valued ``y_i`` the residual is broadcast over the rows of
        ``phi``; for a scalar ``y_i`` the result has the shape of ``x_i``.
        """
        residual = x_i.dot(phi) - y_i
        if isinstance(y_i, np.ndarray):
            return w_i * x_i[:, np.newaxis] * residual[np.newaxis]
        return w_i * x_i * residual

    def _initial_phi(self, phi_0, out_dim):
        """
        Return the starting estimate.

        ``None`` or a falsy scalar (the default ``False``) means "start from
        zeros". Anything else is copied into a float array. The scalar check
        avoids calling ``bool()`` on an array, which raises for arrays with
        more than one element.
        """
        if phi_0 is None or (np.ndim(phi_0) == 0 and not phi_0):
            shape = (self.d,) if out_dim is None else (self.d, out_dim)
            return np.zeros(shape)
        return np.array(phi_0, dtype=float)

    @staticmethod
    def _running_average(iterates):
        """
        Average of ``iterates[:t + 1]`` along axis 0, for every t.

        The divisor is reshaped to the rank of ``iterates`` so that both
        (n_iter, d) and (n_iter, d, out_dim) histories broadcast correctly.
        """
        counts = np.arange(1, len(iterates) + 1)
        counts = counts.reshape((-1,) + (1,) * (iterates.ndim - 1))
        return np.cumsum(iterates, axis=0) / counts

    def sgd(self,
            game,
            n_iter=100,
            step=0.1,
            step_type="sqrt",
            phi_0=False,
            batch_size=128):
        """
        Stochastic gradient descent algorithm

        Parameters
        ----------
        game : callable
            Called with a 2-D mask array (one row per coalition).
            NOTE(review): assumed to return one value (or one row of values)
            per mask row — confirm against the game implementation.
        n_iter : integer
            Number of sampled coalitions / gradient steps.
        step : float
            Base step size.
        step_type : {"constant", "sqrt", "inverse"}
            Step schedule: ``step``, ``step / sqrt(t)`` or ``step / t``.
        phi_0 : array-like or False
            Optional starting estimate; ``False``/``None`` starts from zeros.
        batch_size : integer
            Number of coalitions passed to ``game`` per call (default 128,
            the value that used to be hard-coded).

        Returns
        -------
        Array whose entry ``t`` is the average of the first ``t + 1``
        projected iterates; its first axis has length ``n_iter``.

        Raises
        ------
        ValueError
            If ``step_type`` is not supported or ``batch_size < 1``.
        """
        if step_type not in ("constant", "sqrt", "inverse"):
            raise ValueError(
                "step_type must be 'constant', 'sqrt' or 'inverse', got %r"
                % (step_type,))
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))

        d = self.d

        # Value of the full and of the empty coalition.
        grand = game(np.ones((1, d), dtype=bool))[0]
        null = game(np.zeros((1, d), dtype=bool))[0]
        out_dim = len(grand) if isinstance(grand, np.ndarray) else None
        total = grand - null

        # initialize Shapley value estimates, then project onto sum == total
        phi = self._initial_phi(phi_0, out_dim)
        phi = phi - (np.sum(phi, axis=0) - total) / d

        # store for iterate averaging
        iterate_shape = (n_iter, d) if out_dim is None else (n_iter, d, out_dim)
        phi_iterates = np.zeros(iterate_shape)

        # sample coalition sizes
        list_k = np.random.choice(list(range(1, d)), size=n_iter, p=self.q)

        # build every subset indicator up front so the game can be batched
        x_i_record = np.zeros((n_iter, d))
        for i in tqdm(range(n_iter)):
            x_i_record[i, np.random.permutation(d)[:list_k[i]]] = 1

        # evaluate the game in batches, relative to the empty coalition
        y_batches = [
            game(x_i_record[start:start + batch_size].astype(int)) - null
            for start in tqdm(range(0, n_iter, batch_size))
        ]
        # concatenate (not vstack) so that 1-D batches from a scalar-valued
        # game stay one value per coalition
        y_i_record = (np.concatenate(y_batches, axis=0)
                      if y_batches else np.zeros(0))

        for t in tqdm(range(1, n_iter + 1)):
            k = list_k[t - 1]
            x_i = x_i_record[t - 1]
            y_i = y_i_record[t - 1]

            # importance-sampling correction: p_i is the probability of
            # drawing this particular coalition
            p_i = self.dict_L_k[k] / self.sum_L
            grad_i = 1 / p_i * self._grad_F_i(phi, x_i, y_i, self.dict_w_k[k])

            # update phi
            if step_type == "constant":
                rate = step
            elif step_type == "sqrt":
                rate = step / np.sqrt(t)
            else:
                rate = step / t
            phi = phi - rate * grad_i

            # projection step
            phi = phi - (phi.sum(axis=0) - total) / d

            # update iterate history
            phi_iterates[t - 1] = phi

        # Running average of the iterates after every step
        return self._running_average(phi_iterates)


In [None]:
# Edited by: Ian Covert and Chanwoo Kim

# Original authors: Simon Grah <simon.grah@thalesgroup.com>
#                   Vincent Thouvenot <vincent.thouvenot@thalesgroup.com>

# MIT License

# Copyright (c) 2020 Thales Six GTS France

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import operator as op
from functools import reduce
from tqdm import tqdm


def ncr(n, r):
    """
    Count the subsets of size r among n elements; the count is a float.

    By symmetry C(n, r) == C(n, n - r), so the smaller of r and n - r decides
    how many factors are multiplied.
    """
    count = min(r, n - r)
    numerator = 1
    denominator = 1
    for step in range(count):
        numerator *= n - step
        denominator *= step + 1
    # True division keeps the float return type of the original.
    return numerator / denominator


class SGDShapleyNew():
    """
    Estimate the Shapley Values using a Projected Stochastic Gradient algorithm.

    Experimental variant: coalition sizes are drawn from ``p`` (proportional
    to the number of coalitions of each size) instead of the importance
    proposal ``q``, and the importance correction ``1 / p_i`` is not applied.
    """

    def __init__(self, d, C):
        """
        Calculate internal values for later purposes
        Those elements depend only on the number of features d

        Parameters
        ----------
        d : integer
            Dimension of the problem. The number of features
        C : float
            Constant bounding |y|
        """
        sizes = range(1, d)
        D = C * np.sqrt(d)

        # Per coalition size k: regression weight w_k and L-smooth constant L_k.
        dict_w_k = {}
        dict_L_k = {}
        for k in sizes:
            w_k = (d - 1) / (ncr(d, k) * k * (d - k))
            dict_w_k[k] = w_k
            dict_L_k[k] = w_k * np.sqrt(k) * (np.sqrt(k) * D + C)

        # Summation of all L per coalition (closed formula)
        sum_L = np.sum([(d - 1) / (np.sqrt(k) * (d - k)) * (np.sqrt(k) * D + C)
                        for k in sizes])

        # 1. Proposal proportional to the number of coalitions per size
        #    (this is the one ``sgd`` samples from in this variant).
        p = np.array([ncr(d, k) for k in sizes], dtype=float)
        p /= np.sum(p)

        # 2. Importance-sampling proposal (stored, not used by ``sgd`` here).
        q = np.array(list(dict_L_k.values())) * p
        q /= np.sum(q)

        # Save internal attributes
        self.d = d
        self.dict_w_k = dict_w_k
        self.dict_L_k = dict_L_k
        self.sum_L = sum_L
        self.p = p
        self.q = q

    def _grad_F_i(self, phi, x_i, y_i, w_i):
        """
        Gradient vector per instance i.

        NOTE(review): the array branch deliberately leaves out the weight
        ``w_i`` (the weighted form was commented out in this experiment),
        while the scalar branch still applies it. Confirm whether that
        asymmetry is intended.
        """
        residual = x_i.dot(phi) - y_i
        if isinstance(y_i, np.ndarray):
            return x_i[:, np.newaxis] * residual[np.newaxis]
        return w_i * x_i * residual

    def _initial_phi(self, phi_0, out_dim):
        """
        Return the starting estimate.

        ``None`` or a falsy scalar (the default ``False``) means "start from
        zeros". Anything else is copied into a float array. The scalar check
        avoids calling ``bool()`` on an array, which raises for arrays with
        more than one element.
        """
        if phi_0 is None or (np.ndim(phi_0) == 0 and not phi_0):
            shape = (self.d,) if out_dim is None else (self.d, out_dim)
            return np.zeros(shape)
        return np.array(phi_0, dtype=float)

    @staticmethod
    def _running_average(iterates):
        """
        Average of ``iterates[:t + 1]`` along axis 0, for every t.

        The divisor is reshaped to the rank of ``iterates`` so that both
        (n_iter, d) and (n_iter, d, out_dim) histories broadcast correctly.
        """
        counts = np.arange(1, len(iterates) + 1)
        counts = counts.reshape((-1,) + (1,) * (iterates.ndim - 1))
        return np.cumsum(iterates, axis=0) / counts

    def sgd(self,
            game,
            n_iter=100,
            step=0.1,
            step_type="sqrt",
            phi_0=False,
            batch_size=128):
        """
        Stochastic gradient descent algorithm

        Parameters
        ----------
        game : callable
            Called with a 2-D mask array (one row per coalition).
            NOTE(review): assumed to return one value (or one row of values)
            per mask row — confirm against the game implementation.
        n_iter : integer
            Number of sampled coalitions / gradient steps.
        step : float
            Base step size.
        step_type : {"constant", "sqrt", "inverse"}
            Step schedule: ``step``, ``step / sqrt(t)`` or ``step / t``.
        phi_0 : array-like or False
            Optional starting estimate; ``False``/``None`` starts from zeros.
        batch_size : integer
            Number of coalitions passed to ``game`` per call (default 128,
            the value that used to be hard-coded).

        Returns
        -------
        Array whose entry ``t`` is the average of the first ``t + 1``
        projected iterates; its first axis has length ``n_iter``.

        Raises
        ------
        ValueError
            If ``step_type`` is not supported or ``batch_size < 1``.
        """
        if step_type not in ("constant", "sqrt", "inverse"):
            raise ValueError(
                "step_type must be 'constant', 'sqrt' or 'inverse', got %r"
                % (step_type,))
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))

        d = self.d

        # Value of the full and of the empty coalition.
        grand = game(np.ones((1, d), dtype=bool))[0]
        null = game(np.zeros((1, d), dtype=bool))[0]
        out_dim = len(grand) if isinstance(grand, np.ndarray) else None
        total = grand - null

        # initialize Shapley value estimates, then project onto sum == total
        phi = self._initial_phi(phi_0, out_dim)
        phi = phi - (np.sum(phi, axis=0) - total) / d

        # store for iterate averaging
        iterate_shape = (n_iter, d) if out_dim is None else (n_iter, d, out_dim)
        phi_iterates = np.zeros(iterate_shape)

        # sample coalition sizes from p (not from the importance proposal q)
        list_k = np.random.choice(list(range(1, d)), size=n_iter, p=self.p)

        # build every subset indicator up front so the game can be batched
        x_i_record = np.zeros((n_iter, d))
        for i in tqdm(range(n_iter)):
            x_i_record[i, np.random.permutation(d)[:list_k[i]]] = 1

        # evaluate the game in batches, relative to the empty coalition
        y_batches = [
            game(x_i_record[start:start + batch_size].astype(int)) - null
            for start in tqdm(range(0, n_iter, batch_size))
        ]
        # concatenate (not vstack) so that 1-D batches from a scalar-valued
        # game stay one value per coalition
        y_i_record = (np.concatenate(y_batches, axis=0)
                      if y_batches else np.zeros(0))

        for t in tqdm(range(1, n_iter + 1)):
            k = list_k[t - 1]
            x_i = x_i_record[t - 1]
            y_i = y_i_record[t - 1]

            # No 1 / p_i importance correction in this variant; w_i is only
            # used by the scalar branch of _grad_F_i.
            grad_i = self._grad_F_i(phi, x_i, y_i, self.dict_w_k[k])

            # update phi
            if step_type == "constant":
                rate = step
            elif step_type == "sqrt":
                rate = step / np.sqrt(t)
            else:
                rate = step / t
            phi = phi - rate * grad_i

            # projection step
            phi = phi - (phi.sum(axis=0) - total) / d

            # update iterate history
            phi_iterates[t - 1] = phi

        # Running average of the iterates after every step
        return self._running_average(phi_iterates)


In [None]:
# Edited by: Ian Covert and Chanwoo Kim

# Original authors: Simon Grah <simon.grah@thalesgroup.com>
#                   Vincent Thouvenot <vincent.thouvenot@thalesgroup.com>

# MIT License

# Copyright (c) 2020 Thales Six GTS France

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import operator as op
from functools import reduce
from tqdm import tqdm


def ncr(n, r):
    """
    Combinatorial computation: number of subsets of size r among n elements.

    The numerator and denominator are accumulated as exact Python integers
    over the shorter of the two symmetric ranges (C(n, r) == C(n, n - r));
    only the final division produces a float.

    Parameters
    ----------
    n : integer
        Size of the ground set
    r : integer
        Size of the subsets

    Returns
    -------
    float
        The binomial coefficient C(n, r), or 0.0 when r < 0 or r > n
        (no subset of that size exists).
    """
    # Without this guard min(r, n - r) becomes negative, both ranges below
    # are empty and the function would wrongly report 1.0.
    if r < 0 or r > n:
        return 0.0

    r = min(r, n-r)
    numer = reduce(op.mul, range(n, n-r, -1), 1)
    denom = reduce(op.mul, range(1, r+1), 1)
    return numer / denom


class SGDShapleyNew():
    """
    Estimate the Shapley Values using a Projected Stochastic Gradient algorithm.

    Coalition sizes are drawn from the importance-sampling proposal ``q``
    built in ``__init__``. After every gradient step the estimate is shifted
    so that its entries sum to ``game(all features) - game(no features)``.
    """

    # Step-size schedules understood by ``sgd``
    _STEP_TYPES = ("constant", "sqrt", "inverse")

    def __init__(self, d, C):
        """
        Calculate internal values for later purposes
        Those elements depend only on the number of features d

        Parameters
        ----------
        d : integer
            Dimension of the problem. The number of features
        C : float
            Constant bounding |y|
        """

        # Store in a dictionary for each size k of coalitions
        dict_w_k = dict()  # weights per size k
        dict_L_k = dict()  # L-smooth constant per size k
        D = C * np.sqrt(d)
        for k in range(1, d):
            w_k = (d - 1) / (ncr(d, k) * k * (d - k))
            L_k = w_k * np.sqrt(k) * (np.sqrt(k) * D + C)
            dict_w_k[k] = w_k
            dict_L_k[k] = L_k

        # Summation of all L per coalition (closed formula):
        # each term equals ncr(d, k) * L_k with the binomial cancelled out
        sum_L = np.sum([(d-1)/(np.sqrt(k)*(d-k)) * (np.sqrt(k)*D + C) for k in range(1, d)])

        # Probability distributions for sampling new instance

        # 1. Classic SGD (not used): size k drawn proportionally to the
        #    number of coalitions of that size
        p = np.array([ncr(d, k) for k in range(1, d)])
        p = p / np.sum(p)

        # 2. Importance Sampling proposal q (used)
        q = np.array(list(dict_L_k.values())) * p
        q /= np.sum(q)

        # Save internal attributes
        self.d = d
        self.dict_w_k = dict_w_k
        self.dict_L_k = dict_L_k
        self.sum_L = sum_L
        self.p = p
        self.q = q

    def _grad_F_i(self, phi, x_i, y_i, w_i):
        """
        Gradient vector per instance i

        Parameters
        ----------
        phi : numpy array
            Current estimate, shape (d,) or (d, out_dim)
        x_i : numpy array
            0/1 indicator of the sampled coalition, shape (d,)
        y_i : float or numpy array
            Game value of the coalition minus the value of the empty one
        w_i : float
            Weight of the coalition size

        Returns
        -------
        numpy array with the same shape as phi
        """
        residual = x_i.dot(phi) - y_i
        if isinstance(y_i, np.ndarray):
            # vector-valued game: outer product of the indicator with the
            # per-output residual gives a (d, out_dim) gradient
            res = w_i * x_i[:, np.newaxis] * residual[np.newaxis]
        else:
            # scalar-valued game
            res = w_i * x_i * residual
        return res

    @staticmethod
    def _initial_phi(phi_0, d, out_dim):
        """Return the starting estimate: a float copy of phi_0, or zeros if none was given."""
        # ``False``/``None`` (and other falsy non-array values) mean "no
        # starting point". An ndarray must not be truth-tested: that raises
        # for arrays with more than one element.
        not_given = phi_0 is None or (
            not isinstance(phi_0, np.ndarray) and not phi_0
        )
        if not_given:
            if out_dim is None:
                return np.zeros(d)
            return np.zeros((d, out_dim))
        return np.array(phi_0, dtype=float)

    @staticmethod
    def _running_average(phi_iterates):
        """Return the running mean over axis 0: row t-1 averages the first t iterates."""
        counts = np.arange(1, len(phi_iterates) + 1)
        # one trailing singleton axis per non-iteration axis, so the divisor
        # broadcasts for both (n_iter, d) and (n_iter, d, out_dim) histories
        counts = counts.reshape((-1,) + (1,) * (phi_iterates.ndim - 1))
        return np.cumsum(phi_iterates, axis=0) / counts

    def sgd(self,
            game,
            n_iter=100,
            step=0.1,
            step_type="sqrt",
            phi_0=False):
        """
        Stochastic gradient descent algorithm

        Parameters
        ----------
        game : callable
            Called with a boolean array of shape (1, d); element [0] of the
            result is used as the value of that coalition (a scalar or a
            1-D numpy array)
        n_iter : integer
            Number of SGD iterations (one game evaluation each)
        step : float
            Base step size
        step_type : {"constant", "sqrt", "inverse"}
            Step schedule: step, step / sqrt(t) or step / t at iteration t
        phi_0 : numpy array or False
            Optional starting estimate; zeros are used when not given

        Returns
        -------
        numpy array of shape (n_iter, d) or (n_iter, d, out_dim)
            Entry t-1 is the average of the first t projected iterates.

        Raises
        ------
        ValueError
            If step_type is not one of the supported schedules.
        """
        # An unknown schedule used to leave phi untouched on every
        # iteration; fail loudly before spending any game evaluations.
        if step_type not in self._STEP_TYPES:
            raise ValueError(
                f"Unknown step_type {step_type!r}; expected one of {self._STEP_TYPES}"
            )

        # Get general information
        grand = game(np.ones((1, self.d), dtype=bool))[0]
        null = game(np.zeros((1, self.d), dtype=bool))[0]
        out_dim = len(grand) if isinstance(grand, np.ndarray) else None
        total = grand - null

        d = self.d
        dict_w_k = self.dict_w_k
        q = self.q
        dict_L_k = self.dict_L_k
        sum_L = self.sum_L

        # initialize Shapley value estimates
        phi = self._initial_phi(phi_0, d, out_dim)

        # projection step: make the entries sum to `total`
        phi = phi - (np.sum(phi, axis=0) - total) / d

        # store for iterate averaging
        if out_dim is None:
            phi_iterates = np.zeros((n_iter, d))
        else:
            phi_iterates = np.zeros((n_iter, d, out_dim))

        # sample coalition sizes
        list_k = np.random.choice(list(range(1, d)), size=n_iter, p=q)

        for t in tqdm(range(1, n_iter+1)):
            # build subset indicator x_i
            k = list_k[t-1]
            indexes = np.random.permutation(d)[:k]
            x_i = np.zeros(d)
            x_i[indexes] = 1

            # Compute y_i
            y_i = game(x_i.astype(bool)[np.newaxis])[0] - null

            # get weight w_i for importance sampling
            w_i = dict_w_k[k]

            # calculate gradient, re-weighted by the inverse of p_i
            p_i = dict_L_k[k] / sum_L
            grad_i = 1/(p_i) * self._grad_F_i(phi, x_i, y_i, w_i)

            # update phi
            if step_type == "constant":
                phi = phi - step * grad_i
            elif step_type == "sqrt":
                phi = phi - (step/np.sqrt(t)) * grad_i
            else:  # "inverse"
                phi = phi - (step/(t)) * grad_i

            # projection step
            phi = phi - (phi.sum(axis=0) - total) / d

            # update iterate history
            phi_iterates[t-1] = phi

        # Average iterates
        return self._running_average(phi_iterates)


In [None]:
# Run the projected-SGD estimator on `game` with d=196 features for
# 500000 iterations, using the step / sqrt(t) schedule.
sgd_shapley = SGDShapleyNew(d=196, C=1)

sgd_shapley_output = sgd_shapley.sgd(
    game,
    n_iter=500000,
    step=0.1,
    step_type="sqrt",
    phi_0=False,
)

In [None]:
1/np.sqrt(50000)

In [None]:
            # print('y is an array')
            res = w_i * x_i[:, np.newaxis] * (x_i.dot(phi) - y_i)[np.newaxis]
            import ipdb
            ipdb.set_trace()
            res = x_i[:, np.newaxis] * (x_i.dot(phi) - y_i)[np.newaxis]
        else:
            # print('y is a scalar')
            res = w_i * x_i * (x_i.dot(phi) - y_i)

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=0.01,
            step_type="sqrt",
            phi_0=False)

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=10,
            step_type="sqrt",
            phi_0=False)

In [None]:
# Plot the averaged estimate stored at iterations 50, 500, 5000 and 50000
# for test sample 0, titling each figure with its iteration count.
for subset in [50, 500, 5000, 50000]:
    fig = plot_figure_shapley(
        dataset_explainer["test"],
        sample_idx_list=[0],
        shapley_value={
            0: {
                "values": [sgd_shapley_output[subset - 1]],
                "std": [],
                "iters": [0],
            }
        },
        shapley_value_key=0,
    )
    fig.suptitle(str(subset))

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=1,
            step_type="sqrt",
            phi_0=False)

In [None]:
for subset in [50,500,5000,50000]:
    fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                        shapley_value={0:{
                'values': [sgd_shapley_output[subset-1]],
                'std': [],
                'iters': [0]}}, shapley_value_key=int(0))
    fig.suptitle(str(subset))

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=0.1,
            step_type="sqrt",
            phi_0=False)

In [None]:
for subset in [50,500,5000,50000]:
    fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                        shapley_value={0:{
                'values': [sgd_shapley_output[subset-1]],
                'std': [],
                'iters': [0]}}, shapley_value_key=int(0))
    fig.suptitle(str(subset))

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=0.01,
            step_type="sqrt",
            phi_0=False)

In [None]:
sgd_shapley_output[-1]

In [None]:
for subset in [50,500,5000,50000]:
    fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                        shapley_value={0:{
                'values': [sgd_shapley_output[subset-1]],
                'std': [],
                'iters': [0]}}, shapley_value_key=int(0))
    fig.suptitle(str(subset))

In [None]:
 
# grad_i = 1/(p_i) * self._grad_F_i(phi, x_i, y_i, w_i)
grad_i = self._grad_F_i(phi, x_i, y_i, w_i)

# res = w_i * x_i[:, np.newaxis] * (x_i.dot(phi) - y_i)[np.newaxis]
res = x_i[:, np.newaxis] * (x_i.dot(phi) - y_i)[np.newaxis]

In [None]:
for subset in [50,500,5000,50000]:
    fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                        shapley_value={0:{
                'values': [sgd_shapley_output[subset-1]],
                'std': [],
                'iters': [0]}}, shapley_value_key=int(0))
    fig.suptitle(str(subset))

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=0.1,
            step_type="sqrt",
            phi_0=False)

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output[50000-1]],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=5000,
            step=0.1,
            step_type="sqrt",
            phi_0=False)

In [None]:
sgd_shapley_output.shape

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output[10000-1]],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output[5000].mean(axis=0)],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
sgd_shapley_output.shape

In [None]:
sgd_shapley_output__[:500].mean(axis=0)

In [None]:
# Sanity check: the vectorised running mean (cumsum / count) must agree with
# explicitly averaging every prefix of the first 100 iterates. The two sides
# sum in different ways, so compare with a floating-point tolerance rather
# than exact equality.
np.allclose(
    np.array([sgd_shapley_output__[:i+1].mean(axis=0) for i in range(len(sgd_shapley_output__[:100]))]),
    np.cumsum(sgd_shapley_output__[:100], axis=0)/(np.arange(len(sgd_shapley_output__[:100]))+1).reshape(-1,1,1),
)

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output__[:50].mean(axis=0)],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output__[:500].mean(axis=0)],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output__[:50].mean(axis=0)],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=100,
            step=0.1,
            step_type="inverse",
            phi_0=False)

In [None]:
# Build the game from the explainer's surrogate model and the first sample
# of the test split.
game = PredictionGame(
    surrogate=explainer.surrogate,
    sample=dataset_explainer["test"][0],
)

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=0.1,
            step_type="inverse",
            phi_0=False)

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=200000,
            step=0.01,
            step_type="inverse",
            phi_0=False)

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output_=sgd_shapley.sgd(game,
            n_iter=200000,
            step=0.1,
            step_type="inverse",
            phi_0=False)

In [None]:
# Sampling-based baseline on the same game: 8 batches of 128 samples,
# no convergence detection, keeping all intermediate results.
shapley_sampling = ShapleySampling(
    game,
    batch_size=128,
    n_samples=8 * 128,
    detect_convergence=False,
    thresh=0.01,
    antithetical=False,
    return_all=True,
    bar=True,
    verbose=False,
)

In [None]:
4*128*196

In [None]:
sgd_shapley_output.shape

In [None]:
shapley_loaded[0]['values'][0].shape

In [None]:
len(shapley_sampling["values"])

In [None]:
sgd_shapley_output.shape

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': sgd_shapley_output,
            'std': [],
            'iters': list(range(1, len(sgd_shapley_output)+1))}}, shapley_value_key=int(200000))

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output],
            'std': [],
            'iters': [50000]}}, shapley_value_key=int(50000))

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output_],
            'std': [],
            'iters': [50000]}}, shapley_value_key=int(50000))

In [None]:
shapley_sampling["iters"]

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0: shapley_sampling}, shapley_value_key=int(50176))

In [None]:
shapley_loaded[0]["iters"]

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value=shapley_loaded, shapley_value_key=int(3332))

In [None]:
fig=plot_figure_shapley(dataset_explainer["train"], sample_idx_list=[0, 1, 2, 3, 4], 
                    shapley_value={0:{
            'values': [sgd_shapley_output],
            'std': [],
            'iters': [50000]}}, shapley_value_key=int(50000))

In [None]:
fig.savefig("aaaa.png")

In [None]:
!pwd

In [None]:
        tracking_dict = 

In [None]:
plot_figure_shapley(dataset_explainer["train"], sample_idx_list=[0, 1, 2, 3, 4], 
                    shapley_value={0:shapley_sampling}, shapley_value_key=int(100352))

In [None]:
plot_figure_shapley(dataset_explainer["train"], sample_idx_list=[0, 1, 2, 3, 4], 
                    shapley_value={0:shapley_sampling}, shapley_value_key=int(100352))

In [None]:
# 4*196*128 = 100352 (~100k)

In [None]:
128*196

In [None]:
shapley_sampling.keys()

In [None]:
shapley_sampling["values"][0]

In [None]:
shapley_sampling["values"][0].sum(axis=0)

In [None]:
shapley_sampling["iters"]

In [None]:
plot_figure_shapley(dataset_explainer["train"], sample_idx_list=[0, 1, 2, 3, 4], 
                    shapley_value={0:shapley_sampling}, shapley_value_key=int(25088))

In [None]:
plot_figure_shapley(dataset_explainer["train"], sample_idx_list=[0, 1, 2, 3, 4], 
                    shapley_value={0:shapley_sampling}, shapley_value_key=int(100352))

In [None]:
plot_figure_shapley?

In [None]:
int(np.ceil(100 / 1))

In [None]:
shapley_sampling.keys()

In [None]:
shapley_sampling["values"]

In [None]:
shapley_sampling["iters"]

In [None]:
!gpustat

In [None]:
shapley_loaded[0]["iters"]

In [None]:
10000*0.2/60

In [None]:
plot_figure_shapley(dataset_explainer["train"], [0, 1, 2, 3, 4], 
                    shapley_loaded, int(320068))

In [None]:
plot_figure_shapley(dataset_explainer["train"], [0, 1, 2, 3, 4], 
                    shapley_loaded, int(32144))

In [None]:
plot_figure_shapley(dataset_explainer["train"], [0, 1, 2, 3, 4], 
                    shapley_loaded, int(32144))

In [None]:
plot_figure_shapley(dataset_explainer["train"], [0, 1, 2, 3, 4], 
                    shapley_loaded, int(32144))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(512))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(3072))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(1536))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(2048))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(3072))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(5120))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(100352))

In [None]:
shapley_loaded=load_shapley("logs/vitbase_imagenette_surrogate_eval_test_permutation/extract_output/test/")

In [None]:
shapley_loaded=load_shapley("logs/vitbase_imagenette_surrogate_eval_test/extract_output/test/")

In [None]:
shapley_loaded[40]

In [None]:
shapley_loaded[0]["iters"]

In [None]:
196*17

In [None]:
shapley_loaded[0]["iters"]

In [None]:
shapley_loaded[0]["iters"]

In [None]:
200116/196

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 31], 
                    shapley_loaded, int(200116))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 31], 
                    shapley_loaded, int(3332))

In [None]:
1036/148

In [None]:
# Load the weights saved at checkpoint-888 (epoch 6) onto the CPU, restore
# them into `regexplainer`, and plot a few test samples.
state_dict = torch.load(
    "logs/vitbase_imagenette_regexplainer_permutation_upfront_196/checkpoint-888/pytorch_model.bin",
    map_location="cpu",
)
regexplainer.load_state_dict(state_dict)
plot_figure(
    regexplainer,
    dataset_explainer["test"],
    [0, 10, 20, 30, 31],
)

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_regexplainer_permutation_upfront_3332/checkpoint-8732/pytorch_model.bin", map_location="cpu")
regexplainer.load_state_dict(state_dict)
plot_figure(regexplainer, dataset_explainer["test"],  [0,  10, 20, 30, 31])
#19

In [None]:
32*5

In [None]:
148*5

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_objexplainer_newsample_32/checkpoint-740/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)
plot_figure(explainer, dataset_explainer["test"], [0,  10, 30, 40])

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_objexplainer_newsample/checkpoint-1480/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)
plot_figure(explainer, dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70])

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_objexplainer_newsample/checkpoint-14800/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)
plot_figure(explainer, dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70])

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_objexplainer_upfront_3200/checkpoint-14800/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)
plot_figure(explainer, dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70])

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_objexplainer_upfront_3200/checkpoint-1480/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)
plot_figure(explainer, dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70])

In [None]:
dataset_explainer["test"][0]

In [None]:
plot_figure?

In [None]:
shapley_loaded[0]["iters"]

In [None]:
shapley_loaded[1]["values"][0].sum(axis=0)

In [None]:
shapley_loaded[1]["values"][-1].sum(axis=0)

In [None]:
dict(shapley_loaded[0], )

In [None]:
shapley_loaded[0]["iters"]

In [None]:
512*10

In [None]:
plot_figure_shapley(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_values_test["test"], 5120)

In [None]:
plot_figure_shapley(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_values_test["test"], 1536)

In [None]:
max(shapley_values_test["test"][0].keys())

In [None]:
99840+512

In [None]:
plot_figure_shapley(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_values_test["test"], 99840)

In [None]:
plot_figure_shapley(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_values_test["test"], 1536)

In [None]:
plot_figure_shapley(explainer, dataset["validation_explainer"], [0,  250, 500,1000],
                    shapley_values["validation"], 1536)

In [None]:
plot_figure(explainer, dataset["test_explainer"], [0,  10,20,30])

In [None]:
plot_figure(explainer, dataset["test"], [0,  250, 500,1000])

In [None]:
shapley_values = torch.load(
    "logs/vitbase_imagenette_surrogate_eval/shapley_train_val.pt",
    map_location="cpu",
)

In [None]:
shapley_values_test = torch.load(
    "logs/vitbase_imagenette_surrogate_eval/shapley.pt",
    map_location="cpu",
)

In [None]:
shapley_values_test.keys()

In [None]:
########################################################
# Initialize the explainer trainer
########################################################
# Load the accuracy metric through the `evaluate` package
metric = evaluate.load("accuracy")

# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
def compute_metrics(p):
    """Return the evaluation metrics for `p`; currently always empty.

    The accuracy computation (argmax over ``p.predictions[:, 0, :]``
    compared against ``p.label_ids`` through ``metric.compute``) is
    disabled, so `p` is ignored and a fresh empty dictionary is returned.
    """
    no_metrics = dict()
    return no_metrics

def collate_fn(examples):
    """Merge a list of example dicts into one batch dict.

    Each example is read through its "pixel_values", "labels" and "masks"
    keys. Pixel values are stacked along a new leading dimension, labels
    become one tensor, and masks are first gathered into a single numpy
    array and then converted to a tensor.
    """

    def gather(key):
        # collect one field across the batch, preserving example order
        return [item[key] for item in examples]

    batch = {}
    batch["pixel_values"] = torch.stack(gather("pixel_values"))
    batch["labels"] = torch.tensor(gather("labels"))
    batch["masks"] = torch.tensor(np.array(gather("masks")))
    return batch

# Trainer for the explainer: the train/eval splits are only attached when the
# corresponding do_train / do_eval flag is set.
explainer_trainer = Trainer(
    model=explainer,
    args=training_args,
    train_dataset=dataset["train_explainer"] if training_args.do_train else None,
    eval_dataset=dataset["validation_explainer"] if training_args.do_eval else None,
    compute_metrics=compute_metrics,
    tokenizer=explainer_image_processor,
    data_collator=collate_fn,
)

########################################################
# Detecting last checkpoint
#######################################################
last_checkpoint = None
if (
    os.path.isdir(training_args.output_dir)
    and training_args.do_train
    and not training_args.overwrite_output_dir
):
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None:
        # No checkpoint found: refuse to write into a non-empty directory.
        if os.listdir(training_args.output_dir):
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
    elif training_args.resume_from_checkpoint is None:
        # A checkpoint exists and none was requested explicitly.
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )

########################################################
# Training
#######################################################
if training_args.do_train:
    # An explicitly requested checkpoint wins over the auto-detected one;
    # when neither exists this is None and training starts from scratch.
    checkpoint = (
        training_args.resume_from_checkpoint
        if training_args.resume_from_checkpoint is not None
        else last_checkpoint
    )
    train_result = explainer_trainer.train(resume_from_checkpoint=checkpoint)
    explainer_trainer.save_model()
    explainer_trainer.log_metrics("train", train_result.metrics)
    explainer_trainer.save_metrics("train", train_result.metrics)
    explainer_trainer.save_state()

########################################################
# Evaluation
#######################################################
if training_args.do_eval:
    metrics = explainer_trainer.evaluate()
    explainer_trainer.log_metrics("eval", metrics)
    explainer_trainer.save_metrics("eval", metrics)

########################################################
# Write model card and (optionally) push to hub
#######################################################
kwargs = dict(
    finetuned_from=explainer_args.explainer_model_name_or_path,
    tasks="image-classification",
    dataset=data_args.dataset_name,
    tags=["image-classification", "vision"],
)
if not training_args.push_to_hub:
    explainer_trainer.create_model_card(**kwargs)
else:
    explainer_trainer.push_to_hub(**kwargs)

In [None]:
    
    
    ########################################################
    # Initalize the explainer trainer
    ########################################################
    # Load the accuracy metric from the datasets package
    metric = evaluate.load("accuracy")

    # Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p):
        """Return the evaluation metrics for `p`; currently always empty.

        The accuracy computation (argmax over ``p.predictions[:, 0, :]``
        compared against ``p.label_ids`` through ``metric.compute``) is
        disabled, so `p` is ignored and a fresh empty dictionary is returned.
        """
        no_metrics = dict()
        return no_metrics

    def collate_fn(examples):
        """Merge a list of example dicts into one batch dict.

        Each example is read through its "pixel_values", "labels" and "masks"
        keys. Pixel values are stacked along a new leading dimension, labels
        become one tensor, and masks are first gathered into a single numpy
        array and then converted to a tensor.
        """

        def gather(key):
            # collect one field across the batch, preserving example order
            return [item[key] for item in examples]

        batch = {}
        batch["pixel_values"] = torch.stack(gather("pixel_values"))
        batch["labels"] = torch.tensor(gather("labels"))
        batch["masks"] = torch.tensor(np.array(gather("masks")))
        return batch

    explainer_trainer = Trainer(
        model=explainer,
        args=training_args,
        train_dataset=dataset["train_explainer"] if training_args.do_train else None,
        eval_dataset=dataset["validation_explainer"] if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=explainer_image_processor,
        data_collator=collate_fn,
    )

    # ipdb.set_trace()
    # print("explainer_trainer.label_names", explainer_trainer.label_names)
    # print(explainer_trainer.evaluate(dataset["validation_explainer"]))

    ########################################################
    # Detecting last checkpoint
    #######################################################
    last_checkpoint = None
    if (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif (
            last_checkpoint is not None and training_args.resume_from_checkpoint is None
        ):
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    ########################################################
    # Training
    #######################################################
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = explainer_trainer.train(resume_from_checkpoint=checkpoint)
        explainer_trainer.save_model()
        explainer_trainer.log_metrics("train", train_result.metrics)
        explainer_trainer.save_metrics("train", train_result.metrics)
        explainer_trainer.save_state()

    ########################################################
    # Evaluation
    #######################################################
    if training_args.do_eval:
        metrics = explainer_trainer.evaluate()
        explainer_trainer.log_metrics("eval", metrics)
        explainer_trainer.save_metrics("eval", metrics)

    ########################################################
    # Write model card and (optionally) push to hub
    #######################################################
    kwargs = {
        "finetuned_from": explainer_args.explainer_model_name_or_path,
        "tasks": "image-classification",
        "dataset": data_args.dataset_name,
        "tags": ["image-classification", "vision"],
    }
    if training_args.push_to_hub:
        explainer_trainer.push_to_hub(**kwargs)
    else:
        explainer_trainer.create_model_card(**kwargs)


if __name__ == "__main__":
    main()