In [None]:
import os
import sys

# Work from the parent directory so relative paths used below (e.g. `configs/...`)
# resolve; presumably this also lets the local `arguments` / `models` / `utils`
# imports be found. The guard makes the cell idempotent: re-running it no longer
# climbs one more directory each time.
if "WORK_DIR" not in globals():
    os.chdir('../')
    WORK_DIR = os.getcwd()

# Set up explainers

In [None]:
# Show current GPU usage (presumably to choose which device to expose via CUDA_VISIBLE_DEVICES).
!gpustat

In [None]:
import os

# Expose only physical GPU 7 to this kernel. CUDA reads this variable when it
# initializes, so it only takes effect if set before any CUDA work happens.
os.environ.update(CUDA_VISIBLE_DEVICES="7")

In [None]:
# Fake the command line of train_regexplainer.py: the argument-parsing cell treats a
# single `.json` argument as a config file to load.
CONFIG_PATH = "configs/vitbase_imagenette_shapley_regexplainer_upfront_512.json"
sys.argv = ["train_regexplainer.py", CONFIG_PATH]

In [None]:
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

import evaluate
import ipdb
import numpy as np
import torch
import transformers
from datasets import load_dataset
from torch.nn import functional as F
from tqdm import tqdm
from transformers import (
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    AutoConfig,
    AutoImageProcessor,
    AutoModelForImageClassification,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

from arguments import DataTrainingArguments, ExplainerArguments, SurrogateArguments
from models import (
    RegExplainerForImageClassification,
    RegExplainerForImageClassificationConfig,
    SurrogateForImageClassificationConfig,
)
from utils import (
    MaskDataset,
    configure_dataset,
    generate_mask,
    get_checkpoint,
    get_image_transform,
    load_attribution,
    load_shapley,
    log_dataset,
    read_eval_results,
    setup_dataset,
)

# Setup adapted from the 🤗 Transformers image-classification fine-tuning example.
# In this notebook it is used to build the explainer models whose FLOPs are
# estimated further down. (This note used to be a bare string literal placed after
# the imports, where it is not the module docstring and was a no-op expression.)

logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.32.0.dev0")

require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/image-classification/requirements.txt",
)

# Config classes / model-type names that transformers maps to image classification.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class OtherArguments:
    """Extra command-line / JSON arguments not covered by the surrogate, explainer,
    data, or training argument dataclasses.

    Optional path/token fields default to None; the per-split mask modes default to
    the string "incremental,1" (their format is interpreted elsewhere).
    """

    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )

    # Per-split locations of cached Shapley values; None when not provided.
    train_shapley_cache_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path of the cached Shapley values for the train split.",
        },
    )
    validation_shapley_cache_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path of the cached Shapley values for the validation split.",
        },
    )
    test_shapley_cache_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path of the cached Shapley values for the test split.",
        },
    )

    # Per-split mask-generation modes.
    train_mask_mode: str = field(
        default="incremental,1",
        metadata={
            "help": "mask mode for train",
        },
    )

    validation_mask_mode: str = field(
        default="incremental,1",
        metadata={
            "help": "mask mode for validation",
        },
    )

    test_mask_mode: str = field(
        default="incremental,1",
        metadata={
            "help": "mask mode for test",
        },
    )

In [None]:

########################################################
# Parse arguments
#######################################################
# All training options live in transformers' TrainingArguments (or run the script
# with --help). Each concern gets its own argument dataclass.

ARGUMENT_DATACLASSES = (
    SurrogateArguments,
    ExplainerArguments,
    DataTrainingArguments,
    TrainingArguments,
    OtherArguments,
)
parser = HfArgumentParser(ARGUMENT_DATACLASSES)

# A lone `*.json` argument is a config file holding every argument; otherwise the
# regular command line is parsed.
use_json_config = len(sys.argv) == 2 and sys.argv[1].endswith(".json")
parsed_arguments = (
    parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    if use_json_config
    else parser.parse_args_into_dataclasses()
)
(
    surrogate_args,
    explainer_args,
    data_args,
    training_args,
    other_args,
) = parsed_arguments

########################################################
# Setup logging
#######################################################
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

if training_args.should_log:
    # The default of training_args.log_level is passive, so we set log level at info here to have that default.
    transformers.utils.logging.set_verbosity_info()

log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log on each process the small summary. The first f-string ends with ", " so the
# n_gpu value is not glued onto "distributed training".
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

########################################################
# Correct cache dir if necessary
########################################################
def _relocate_cache_dir(cache_dir, fallback_roots=("/data2", "/sdata")):
    """Return `cache_dir`, or a copy re-rooted under the first existing fallback root.

    The top-level component of `cache_dir` (e.g. "/data" for "/data/hf/cache") is
    checked first; if it exists, `cache_dir` is returned unchanged. Otherwise the
    remainder of the path is placed under the first existing entry of
    `fallback_roots` ("/data/hf/cache" -> "/data2/hf/cache"). Returns None when no
    candidate exists.
    """
    parts = cache_dir.split(os.sep, 2)
    if os.path.exists(os.sep.join(parts[:2])):
        return cache_dir
    for root in fallback_roots:
        if os.path.exists(root):
            return os.sep.join([root] + parts[2:])
    return None


_requested_cache_dir = data_args.dataset_cache_dir
_resolved_cache_dir = _relocate_cache_dir(_requested_cache_dir)
if _resolved_cache_dir is None:
    raise ValueError(f"dataset_cache_dir {_requested_cache_dir} not found")
if _resolved_cache_dir != _requested_cache_dir:
    data_args.dataset_cache_dir = _resolved_cache_dir
    # Report the originally requested path; formatting after the reassignment
    # used to print the new path twice.
    logger.info(
        f"dataset_cache_dir {_requested_cache_dir} not found, using {_resolved_cache_dir}"
    )

########################################################
# Set seed before initializing model.
########################################################
set_seed(seed=training_args.seed)

########################################################
# Initialize our dataset and prepare it for the 'image-classification' task.
########################################################
# Returns the dataset plus the label list / mappings consumed by the model configs below.
(
    dataset_original,
    labels,
    label2id,
    id2label,
) = setup_dataset(other_args=other_args, data_args=data_args)

########################################################
# Initialize explainer model
########################################################
explainer_config = AutoConfig.from_pretrained(
    explainer_args.explainer_config_name
    or explainer_args.explainer_model_name_or_path,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    finetuning_task="image-classification",
    cache_dir=explainer_args.explainer_cache_dir,
    revision=explainer_args.explainer_model_revision,
    token=other_args.token,
)

# The surrogate configs are built unconditionally: later cells reuse
# `surrogate_for_image_classification_config` for the normalized and Obj explainer
# variants, which raised NameError when the RegExplainer came from a checkpoint.
surrogate_config = AutoConfig.from_pretrained(
    surrogate_args.surrogate_config_name
    or surrogate_args.surrogate_model_name_or_path,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    finetuning_task="image-classification",
    cache_dir=surrogate_args.surrogate_cache_dir,
    revision=surrogate_args.surrogate_model_revision,
    token=other_args.token,
)
surrogate_for_image_classification_config = SurrogateForImageClassificationConfig(
    surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
    surrogate_config=surrogate_config,
    surrogate_from_tf=bool(
        ".ckpt" in surrogate_args.surrogate_model_name_or_path
    ),
    surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
    surrogate_revision=surrogate_args.surrogate_model_revision,
    surrogate_token=other_args.token,
    surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
)


def _is_regexplainer_checkpoint(model_name_or_path):
    """Return True if `model_name_or_path` contains a config.json whose first listed
    architecture is "RegExplainerForImageClassification".

    A missing file, or a config without an "architectures" entry, counts as "not a
    RegExplainer checkpoint" instead of raising KeyError/TypeError.
    """
    config_path = f"{model_name_or_path}/config.json"
    if not os.path.isfile(config_path):
        return False
    # The context manager closes the handle (the old inline open(...).read() leaked it).
    with open(config_path) as config_file:
        architectures = json.load(config_file).get("architectures") or []
    return bool(architectures) and architectures[0] == "RegExplainerForImageClassification"


if _is_regexplainer_checkpoint(explainer_args.explainer_model_name_or_path):
    # Restore a previously saved RegExplainer.
    explainer = RegExplainerForImageClassification.from_pretrained(
        explainer_args.explainer_model_name_or_path,
        from_tf=bool(".ckpt" in explainer_args.explainer_model_name_or_path),
        config=explainer_config,
        cache_dir=explainer_args.explainer_cache_dir,
        revision=explainer_args.explainer_model_revision,
        token=other_args.token,
        ignore_mismatched_sizes=explainer_args.explainer_ignore_mismatched_sizes,
    )
else:
    # Build a fresh RegExplainer from the surrogate and explainer settings.
    explainer_for_image_classification_config = RegExplainerForImageClassificationConfig(
        surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
        surrogate_config=surrogate_for_image_classification_config,
        surrogate_from_tf=bool(
            ".ckpt" in surrogate_args.surrogate_model_name_or_path
        ),
        surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
        surrogate_revision=surrogate_args.surrogate_model_revision,
        surrogate_token=other_args.token,
        surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
        explainer_pretrained_model_name_or_path=explainer_args.explainer_model_name_or_path,
        explainer_config=explainer_config,
        explainer_from_tf=bool(
            ".ckpt" in explainer_args.explainer_model_name_or_path
        ),
        explainer_cache_dir=explainer_args.explainer_cache_dir,
        explainer_revision=explainer_args.explainer_model_revision,
        explainer_token=other_args.token,
        explainer_ignore_mismatched_sizes=explainer_args.explainer_ignore_mismatched_sizes,
    )

    explainer = RegExplainerForImageClassification(
        config=explainer_for_image_classification_config,
    )

explainer_image_processor = AutoImageProcessor.from_pretrained(
    explainer_args.explainer_image_processor_name
    or explainer_args.explainer_model_name_or_path,
    cache_dir=explainer_args.explainer_cache_dir,
    revision=explainer_args.explainer_model_revision,
    token=other_args.token,
)

########################################################
# Configure dataset (set max samples, transforms, etc.)
########################################################
# Hand configure_dataset a deep copy so `dataset_original` is unaffected in case it
# is modified in place; augmentation is disabled for every split.
dataset_explainer = configure_dataset(
    dataset=copy.deepcopy(dataset_original),
    image_processor=explainer_image_processor,
    training_args=training_args,
    data_args=data_args,
    train_augmentation=False,
    validation_augmentation=False,
    test_augmentation=False,
    logger=logger,
)

In [None]:
# Alias: the RegExplainer built above is referred to as `regexplainer` from here on.
regexplainer=explainer

In [None]:
from models import (
    RegExplainerNormalizeForImageClassification,
    RegExplainerNormalizeForImageClassificationConfig
)

In [None]:
# The normalized RegExplainer variant mirrors the keyword arguments used above to
# build a fresh RegExplainer; only the config/model classes differ.
regexplainer_normalize_config_kwargs = dict(
    surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
    surrogate_config=surrogate_for_image_classification_config,
    surrogate_from_tf=".ckpt" in surrogate_args.surrogate_model_name_or_path,
    surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
    surrogate_revision=surrogate_args.surrogate_model_revision,
    surrogate_token=other_args.token,
    surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
    explainer_pretrained_model_name_or_path=explainer_args.explainer_model_name_or_path,
    explainer_config=explainer_config,
    explainer_from_tf=".ckpt" in explainer_args.explainer_model_name_or_path,
    explainer_cache_dir=explainer_args.explainer_cache_dir,
    explainer_revision=explainer_args.explainer_model_revision,
    explainer_token=other_args.token,
    explainer_ignore_mismatched_sizes=explainer_args.explainer_ignore_mismatched_sizes,
)
regexplainer_normalize_for_image_classification_config = RegExplainerNormalizeForImageClassificationConfig(
    **regexplainer_normalize_config_kwargs
)
regexplainer_normalize = RegExplainerNormalizeForImageClassification(
    config=regexplainer_normalize_for_image_classification_config
)


In [None]:
from models import (
    ObjExplainerForImageClassification,
    ObjExplainerForImageClassificationConfig
)

In [None]:
# Display the surrogate config that the explainer-variant configs in this notebook reuse (inspection only).
surrogate_for_image_classification_config

In [None]:
# The ObjExplainer is configured with the same surrogate/explainer keyword arguments
# used above for a freshly built RegExplainer; only the config/model classes differ.
objexplainer_config_kwargs = {
    "surrogate_pretrained_model_name_or_path": surrogate_args.surrogate_model_name_or_path,
    "surrogate_config": surrogate_for_image_classification_config,
    "surrogate_from_tf": ".ckpt" in surrogate_args.surrogate_model_name_or_path,
    "surrogate_cache_dir": surrogate_args.surrogate_cache_dir,
    "surrogate_revision": surrogate_args.surrogate_model_revision,
    "surrogate_token": other_args.token,
    "surrogate_ignore_mismatched_sizes": surrogate_args.surrogate_ignore_mismatched_sizes,
    "explainer_pretrained_model_name_or_path": explainer_args.explainer_model_name_or_path,
    "explainer_config": explainer_config,
    "explainer_from_tf": ".ckpt" in explainer_args.explainer_model_name_or_path,
    "explainer_cache_dir": explainer_args.explainer_cache_dir,
    "explainer_revision": explainer_args.explainer_model_revision,
    "explainer_token": other_args.token,
    "explainer_ignore_mismatched_sizes": explainer_args.explainer_ignore_mismatched_sizes,
}
objexplainer_for_image_classification_config = ObjExplainerForImageClassificationConfig(
    **objexplainer_config_kwargs
)
objexplainer = ObjExplainerForImageClassification(
    config=objexplainer_for_image_classification_config
)

# Estimate FLOPS

In [None]:
import deepspeed

In [None]:
from torch.utils.data import DataLoader

In [None]:
from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler

In [None]:
# The profiling loops below run unprofiled iterations until index `profile_step`,
# which is the one they measure; profiling happens on the CPU at this point.
profile_step, device = 3, "cpu"

In [None]:
from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count_str


def report_fvcore_flops(model, inputs, max_depth=4):
    """Count FLOPs of ``model(inputs)`` with fvcore, print the reports, return the total.

    Replaces a pasted standalone example that referenced an undefined ``ViTB32``
    model and an un-imported ``cv2``; since ``__name__ == '__main__'`` inside a
    notebook, that example always ran and raised NameError.

    Args:
        model: a ``torch.nn.Module`` that fvcore can trace.
        inputs: the input tensor (or tuple of inputs) passed to ``model``.
        max_depth: module depth shown by ``flop_count_table``.

    Returns:
        The total reported by ``FlopCountAnalysis.total()``.
    """
    flop = FlopCountAnalysis(model, inputs)
    print(flop_count_table(flop, max_depth=max_depth))
    print(flop_count_str(flop))
    total = flop.total()
    print(total)
    return total

In [None]:
# Profiler totals divided by 8 and scaled to giga-units.
# NOTE(review): `flops`/`macs` are only set by the DeepSpeed profiling cells further
# down, so this raises NameError on a fresh kernel; no visible cell profiles a batch
# of 8 either — confirm the divisor.
flops/8/1e+9, macs/8/1e+9

In [None]:
class surrogate_warpper(torch.nn.Module):
    """Expose the surrogate as a single-input module so FLOP counters can call it.

    ``FlopCountAnalysis(module, inputs)`` calls the module with the images only,
    while the surrogate also needs a ``masks`` tensor; this wrapper supplies it.

    Args:
        model: the surrogate, called as ``model(x, masks, return_loss=False)``.
        masks: optional fixed masks. When None (the default, matching the previous
            behaviour), ``forward`` reads the notebook-global ``masks`` at call time,
            which the profiling loop reassigns on every step.
    """

    def __init__(self, model, masks=None):
        super().__init__()
        self.model = model
        self.masks = masks

    def forward(self, x):
        # Debug prints of the input/mask types were removed; they printed on every call.
        step_masks = masks if self.masks is None else self.masks
        return self.model(x, step_masks, return_loss=False)
    
# Single-input wrapper around the RegExplainer's surrogate, used by the fvcore cells below.
surrogate_warpped = surrogate_warpper(model=explainer.surrogate)

In [None]:
# fvcore FLOP count of the wrapped surrogate.
# NOTE(review): `pixel_values` (and the `masks` global read by the wrapper) are first
# created by the profiling loop in the next cell; run that first on a fresh kernel.
flop = FlopCountAnalysis(surrogate_warpped, pixel_values.to(device))
table_report = flop_count_table(flop, max_depth=4)
print(table_report)
print(flop_count_str(flop))

In [None]:
batch_size = 1

items = [dataset_explainer["train"][idx] for idx in range(batch_size)]

# Run unprofiled forward passes until `profile_step`, then count FLOPs with fvcore
# on that step. `masks` is the global the wrapper reads inside its forward.
with torch.no_grad():
    for step in range(5):
        print(step)
        pixel_values = torch.stack([item["pixel_values"] for item in items])
        masks = torch.Tensor(
            np.random.choice([0, 1], size=(batch_size, 256, 196), replace=True)
        )

        if step != profile_step:
            loss = surrogate_warpped(pixel_values.to(device))
            continue

        flop = FlopCountAnalysis(surrogate_warpped, pixel_values.to(device))
        print(flop_count_table(flop, max_depth=4))
        print(flop_count_str(flop))
        print(flop.total())
        break

### FLOPs-surrogate

In [None]:
# Step index measured by the profiling loops in this section (same value as the profiling config cell).
profile_step=3

In [None]:
from ptflops import get_model_complexity_info

In [None]:
class surrogate_warpper(torch.nn.Module):
    """Single-input adapter around the surrogate for ptflops' complexity counter.

    ptflops calls the module with one input tensor; the surrogate also needs a
    ``masks`` tensor, which this wrapper supplies. (The previous body called
    ``self.flatten`` / ``self.linear_relu_stack``, which are never defined here, so
    every forward raised AttributeError.)

    Args:
        model: the surrogate, called as ``model(x, masks, return_loss=False)``.
        masks: optional fixed masks. When None (the default), ``forward`` reads the
            notebook-global ``masks`` at call time, as before.
    """

    def __init__(self, model, masks=None):
        super().__init__()
        self.model = model
        self.masks = masks

    def forward(self, x):
        step_masks = masks if self.masks is None else self.masks
        return self.model(x, step_masks, return_loss=False)

In [None]:
# ptflops MAC count for the surrogate on a batch of 2 images, each with a
# (256, 196) random 0/1 mask tensor.
batch_size = 2

items = [dataset_explainer["train"][i] for i in range(batch_size)]

with torch.no_grad():
    for step in range(5):
        print(step)
        pixel_values = torch.cat([item["pixel_values"].unsqueeze(0) for item in items])
        masks = torch.Tensor(
            np.random.choice([0, 1], size=(batch_size, 256, 196), replace=True)
        )

        if step == profile_step:
            # ptflops treats `input_res` as a per-sample shape and builds its own
            # batched input, so passing (2, 3, 224, 224) produced a 5-D tensor.
            # Feed the real batch through `input_constructor` instead; this also keeps
            # the image batch aligned with the batch dimension of `masks`.
            macs, params = get_model_complexity_info(
                surrogate_warpper(explainer.surrogate),
                tuple(pixel_values.shape[1:]),
                input_constructor=lambda _input_res: pixel_values,
                as_strings=True,
                print_per_layer_stat=True,
                verbose=True,
            )
            break
        else:
            loss = explainer.surrogate(pixel_values=pixel_values,
                                       masks=masks,
                                       return_loss=False)

In [None]:
# Create the DeepSpeed profiler for the surrogate only once, so re-running this cell
# keeps the existing profiler. Only a missing name means "not created yet"; the old
# bare `except:` also swallowed KeyboardInterrupt and unrelated errors.
try:
    surrogate_prof
except NameError:
    surrogate_prof = FlopsProfiler(explainer.surrogate)

In [None]:
# Profile one no-grad surrogate forward pass (1 image with a (32, 196) random 0/1
# mask tensor) at step `profile_step`; earlier steps run unprofiled.
batch_size = 1
items = [dataset_explainer["train"][i] for i in range(batch_size)]

with torch.no_grad():
    for step in range(5):
        pixel_values = torch.cat([item["pixel_values"].unsqueeze(0) for item in items])
        masks = torch.Tensor(
            np.random.choice([0, 1], size=(batch_size, 32, 196), replace=True)
        )

        if step == profile_step:
            surrogate_prof.start_profile()

        # Move masks to `device` together with the images so the cell keeps working
        # once `device` points at a GPU.
        loss = explainer.surrogate(
            pixel_values=pixel_values.to(device),
            masks=masks.to(device),
            return_loss=False,
        )

        if step == profile_step:
            flops = surrogate_prof.get_total_flops(as_string=False)
            params = surrogate_prof.get_total_params(as_string=False)
            macs = surrogate_prof.get_total_macs(as_string=False)
            duration = surrogate_prof.get_total_duration(as_string=False)

            surrogate_prof.print_model_profile(profile_step=profile_step)
            surrogate_prof.end_profile()
            break

In [None]:
# Totals recorded by the most recent profiling cell, scaled by 1e12 (tera-FLOPs / tera-MACs).
flops/1e+12, macs/1e+12

In [None]:
# Compare raw FLOPs with 2 x MACs (one multiply-accumulate is commonly counted as two FLOPs).
flops,macs*2

In [None]:
# MACs divided by 32. NOTE(review): presumably "per mask", given the (batch, 32, 196)
# masks used when profiling the surrogate — confirm which cell last set `macs`.
macs/32

In [None]:
# Create the DeepSpeed profiler for the explainer only once, so re-running this cell
# keeps the existing profiler. Only a missing name means "not created yet"; the old
# bare `except:` also swallowed KeyboardInterrupt and unrelated errors.
try:
    explainer_prof
except NameError:
    explainer_prof = FlopsProfiler(explainer)

In [None]:
# Profile one no-grad explainer forward pass on a batch of 32 training images.
batch_size = 32

items = [dataset_explainer["train"][idx] for idx in range(batch_size)]

with torch.no_grad():
    for step in range(5):
        is_profiled_step = step == profile_step
        if is_profiled_step:
            explainer_prof.start_profile()

        pixel_values = torch.stack([item["pixel_values"] for item in items])
        loss = explainer(pixel_values=pixel_values.to(device), return_loss=False)

        if not is_profiled_step:
            continue

        flops = explainer_prof.get_total_flops(as_string=False)
        params = explainer_prof.get_total_params(as_string=False)
        macs = explainer_prof.get_total_macs(as_string=False)
        duration = explainer_prof.get_total_duration(as_string=False)
        explainer_prof.print_model_profile(profile_step=profile_step)
        explainer_prof.end_profile()
        break

In [None]:
# Profile one explainer training step: a batch of 32 images, a (32, 196) random 0/1
# mask tensor per image, and random stand-in model outputs as targets.
# NOTE(review): this call signature (masks / model_outputs) differs from the
# `shapley_values` one used later for `regexplainer`, which is the same object —
# confirm which one the explainer actually accepts.
batch_size = 32

items = [dataset_explainer["train"][i] for i in range(batch_size)]

optimizer = torch.optim.AdamW(explainer.parameters())

for step in range(5):
    pixel_values = torch.cat([item["pixel_values"].unsqueeze(0) for item in items])
    masks = torch.Tensor(
        np.random.choice([0, 1], size=(batch_size, 32, 196), replace=True)
    )
    model_outputs = torch.randn((batch_size, 32, 10))

    # Start profiling BEFORE the forward pass; it used to start afterwards, so the
    # forward of the measured step was not captured.
    if step == profile_step:
        explainer_prof.start_profile()

    optimizer.zero_grad()
    loss = explainer(
        pixel_values=pixel_values.to(device),
        masks=masks.to(device),
        model_outputs=model_outputs.to(device),
        return_loss=True,
    )
    loss.loss.backward()
    optimizer.step()

    if step == profile_step:
        flops = explainer_prof.get_total_flops(as_string=False)
        params = explainer_prof.get_total_params(as_string=False)
        explainer_prof.print_model_profile(profile_step=profile_step)
        explainer_prof.end_profile()
        break

### FLOPs-reg explainer

In [None]:
# From here on, run on the first visible CUDA device (physical GPU 7 if CUDA_VISIBLE_DEVICES was set as above).
device="cuda:0"

In [None]:
# Move the RegExplainer to `device`; Module.to returns the module, so the cell displays it.
regexplainer.to(device)

In [None]:
# Create the DeepSpeed profiler for the RegExplainer only once, so re-running this
# cell keeps the existing profiler. Only a missing name means "not created yet";
# the old bare `except:` also swallowed KeyboardInterrupt and unrelated errors.
try:
    regexplainer_prof
except NameError:
    regexplainer_prof = FlopsProfiler(regexplainer)
    

In [None]:
# Profile one no-grad RegExplainer forward pass.
# NOTE(review): `batch_size` is inherited from an earlier cell (32 when run in order).
print(batch_size)
items = [dataset_explainer["train"][idx] for idx in range(batch_size)]

with torch.no_grad():
    for step in range(5):
        pixel_values = torch.stack([item["pixel_values"] for item in items])

        profiling_now = step == profile_step
        if profiling_now:
            regexplainer_prof.start_profile()

        loss = regexplainer(pixel_values=pixel_values.to(device), return_loss=False)

        if profiling_now:
            flops = regexplainer_prof.get_total_flops(as_string=False)
            params = regexplainer_prof.get_total_params(as_string=False)
            macs = regexplainer_prof.get_total_macs(as_string=False)
            duration = regexplainer_prof.get_total_duration(as_string=False)
            regexplainer_prof.print_model_profile(profile_step=profile_step)
            regexplainer_prof.end_profile()
            break

In [None]:
# AdamW with default hyperparameters over all RegExplainer parameters, for the training-step profiling cell.
optimizer=torch.optim.AdamW(regexplainer.parameters())

In [None]:
# Profile one RegExplainer training step on random Shapley-value targets.
# Uses `regexplainer_prof`, the RegExplainer profiler created earlier; `prof` was
# only defined by a later cell, so this raised NameError on a fresh kernel.
for step in tqdm(range(10)):
    pixel_values = torch.cat([item["pixel_values"].unsqueeze(0) for item in items])
    # Random targets sized to the actual batch instead of a hard-coded 32.
    shapley_ground_truth = torch.randn(len(items), 196, 10)

    if step == profile_step:
        regexplainer_prof.start_profile()

    # Clear gradients from the previous step; they used to accumulate.
    optimizer.zero_grad()
    loss = regexplainer(
        pixel_values=pixel_values.to(device),
        shapley_values=shapley_ground_truth.to(device),
        return_loss=True,
    )

    if step == profile_step:
        flops = regexplainer_prof.get_total_flops(as_string=True)
        params = regexplainer_prof.get_total_params(as_string=True)
        regexplainer_prof.print_model_profile(profile_step=profile_step)
        regexplainer_prof.end_profile()
        break

    # The model returns an output object; the scalar loss is its `.loss` field (as in
    # the explainer training cell above), so backpropagate through that.
    loss.loss.backward()
    optimizer.step()

In [None]:
# Inference-only profiling of the RegExplainer forward pass.
# Fixes:
# - the `for` loop was not indented under `with torch.no_grad():` (a SyntaxError);
# - a stray `sdds` line raised NameError;
# - `prof` was only defined by a later cell;
# - `loss.backward()` / `optimizer.step()` cannot work under no_grad, so they are
#   dropped here;
# - inputs now go to `device` like the other inference cells.
with torch.no_grad():
    for step in range(100):
        pixel_values = torch.cat([item["pixel_values"].unsqueeze(0) for item in items])

        if step == profile_step:
            regexplainer_prof.start_profile()

        loss = regexplainer(pixel_values=pixel_values.to(device), return_loss=False)

        if step == profile_step:
            flops = regexplainer_prof.get_total_flops(as_string=True)
            params = regexplainer_prof.get_total_params(as_string=True)
            regexplainer_prof.print_model_profile(profile_step=profile_step)
            regexplainer_prof.end_profile()
            # Nothing else to measure after the profiled step.
            break

In [None]:
# DeepSpeed FLOPs profiler attached to the RegExplainer.
prof = FlopsProfiler(regexplainer)

In [None]:
# DeepSpeed signature: FlopsProfiler(model, ds_engine=None, recompute_fwd_factor=0.0).
# The pasted call passed `model`, which no working cell in the visible notebook
# defines, so it raised NameError; build the profiler for the RegExplainer instead.
prof = deepspeed.profiling.flops_profiler.profiler.FlopsProfiler(
    regexplainer, ds_engine=None, recompute_fwd_factor=0.0
)

# Visualizing results

In [None]:
# Use the first visible CUDA device for the visualization section.
device="cuda:0"

In [None]:
# Move the RegExplainer to `device`; Module.to returns the module, so the cell displays it.
regexplainer.to(device)

In [None]:
# Move the normalized RegExplainer variant to `device` (the returned module is displayed).
regexplainer_normalize.to(device)

In [None]:
# Move the ObjExplainer to `device` (the returned module is displayed).
objexplainer.to(device)

In [None]:
import glob
import pandas as pd
import json
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import cm

In [None]:
# !cp /System/Library/Fonts/Supplemental ~/.local/share/fonts/
# rm -fr ~/.cache/matplotlib
from matplotlib import font_manager
from matplotlib.lines import Line2D
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Use Arial as the sans-serif font. The two font_manager calls are kept only for
# their side effects / checks; their return values are unused.
font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial")  # Test with "Special Elite" too
plt.rcParams.update({
    "font.family": "sans-serif",
    "font.sans-serif": "Arial",
    # Rounded, light-grey, slightly translucent legend frames.
    "legend.fancybox": True,
    "legend.edgecolor": "0.8",
    "legend.framealpha": 0.8,
})

sns.set_theme(style="whitegrid")
sns.set_context("paper", font_scale=1.2)

# ColorBrewer qualitative palettes as 0-255 RGB triples, keyed by number of colors.
# https://github.com/dsc/colorbrewer-python/blob/master/colorbrewer.py
# Every n-color palette is a prefix of the largest one, so both dicts are generated
# from a single base list. Each entry gets its own inner lists, so mutating one
# palette cannot leak into another. All colors are lists (Paired[3][0] used to be a
# lone tuple).
_SET1_BASE = [
    [228, 26, 28], [55, 126, 184], [77, 175, 74], [152, 78, 163], [255, 127, 0],
    [255, 255, 51], [166, 86, 40], [247, 129, 191], [153, 153, 153],
]
_PAIRED_BASE = [
    [166, 206, 227], [31, 120, 180], [178, 223, 138], [51, 160, 44], [251, 154, 153],
    [227, 26, 28], [253, 191, 111], [255, 127, 0], [202, 178, 214], [106, 61, 154],
    [255, 255, 153], [177, 89, 40],
]

Set1 = {n: [list(rgb) for rgb in _SET1_BASE[:n]] for n in range(3, len(_SET1_BASE) + 1)}
Paired = {n: [list(rgb) for rgb in _PAIRED_BASE[:n]] for n in range(3, len(_PAIRED_BASE) + 1)}

# Custom 7-color qualitative palette: hex strings plus one RGB float triple.
color_qual_7 = [
    '#F53345',
    '#87D303',
    '#04CBCC',
    '#8650CD',
    # 0-255 channels map onto [0, 1] by dividing by 255 (dividing by 256 slightly darkened it).
    (160/255, 95/255, 0),
    '#F5A637',
    '#DBD783',
]

pd.set_option('display.max_rows', 500)

In [None]:
# Switch to square, fully opaque legend frames with a white edge (overrides the earlier legend settings).
plt.rcParams.update({
    'legend.fancybox': False,
    'legend.edgecolor': '1.0',
    'legend.framealpha': 1,
})

In [None]:
from scipy import stats

In [None]:
def _normalize_ground_truth(ground_truth, transform_mode):
    """Rescale ground-truth attributions according to `transform_mode`.

    None, "global" and "sqrt" return the values unchanged (NOTE(review): "global"
    and "sqrt" are accepted but are no-ops here — confirm whether that transform is
    applied elsewhere). "perinstance" divides by the norm over axes (0, 1);
    "perinstanceperclass" divides each column by its norm over axis 0.

    Raises:
        ValueError: for any other `transform_mode`.
    """
    if transform_mode is None or transform_mode in ("global", "sqrt"):
        return ground_truth
    if transform_mode == "perinstance":
        return ground_truth / np.linalg.norm(ground_truth, axis=(0, 1), keepdims=True)
    if transform_mode == "perinstanceperclass":
        return ground_truth / np.linalg.norm(ground_truth, axis=0, keepdims=True)
    raise ValueError(transform_mode)


def _attribution_agreement_metrics(estimated, ground_truth):
    """Per-sample agreement between estimated and ground-truth attribution arrays.

    Both arrays are indexed ``[:, class_idx]`` (columns are classes). Returns the
    metric dict (keys in the column order used by the result records) and
    ``mse_class``, the squared error summed over rows for each class.
    """
    diff = estimated - ground_truth
    mse_class = (diff * diff).sum(axis=0)
    class_indices = np.arange(ground_truth.shape[1])
    estimated_flat = estimated.flatten()
    ground_truth_flat = ground_truth.flatten()
    metrics = {
        "mse_all": mse_class.mean(),
        "pearsonr_all": stats.pearsonr(estimated_flat, ground_truth_flat)[0],
        "pearsonr_all_per_class": np.mean([
            stats.pearsonr(estimated[:, class_idx], ground_truth[:, class_idx])[0]
            for class_idx in class_indices
        ]),
        "spearmanr_all": stats.spearmanr(estimated_flat, ground_truth_flat)[0],
        "spearmanr_all_per_class": np.mean([
            stats.spearmanr(estimated[:, class_idx], ground_truth[:, class_idx])[0]
            for class_idx in class_indices
        ]),
        "sign_agreement_all": ((estimated > 0) == (ground_truth > 0)).astype(int).mean(),
    }
    return metrics, mse_class


def get_ground_truth_metric_with_explainer(
    attribution_values,
    explainer,
    dataset,
    iters_ground_truth,
    meta_info,
    check_class_efficiency=True,
    ground_truth_key_select=None,
    transform_mode=None
):
    """Score an explainer against reference attributions, one record per sample.

    For each selected sample, the explainer's ``"logits"`` for
    ``dataset[sample_idx]`` are transposed to numpy and compared with the reference
    values stored at iteration ``iters_ground_truth`` of
    ``attribution_values[sample_idx]``.

    Args:
        attribution_values: dict ``sample_idx -> {"iters": list | ndarray, "values": sequence}``.
            It is not modified (ndarray ``iters`` used to be converted in place).
        explainer: model exposing ``eval()``, ``device`` and
            ``__call__(pixel_values=..., return_loss=False)`` returning a mapping with ``"logits"``.
        dataset: indexable by ``sample_idx``; items provide ``"pixel_values"`` (and
            ``"labels"`` when ``check_class_efficiency`` is True).
        iters_ground_truth: entry of ``iters`` whose values serve as ground truth.
        meta_info: dict merged into every record.
        check_class_efficiency: also report target-class metrics, asserting that the
            class with the largest summed ``values[0]`` equals the sample's label.
        ground_truth_key_select: optional subset of sample indices to evaluate.
        transform_mode: None, "global", "sqrt", "perinstance" or "perinstanceperclass".

    Returns:
        list of per-sample dicts; each also carries ``pearson_global`` /
        ``spearman_global`` computed over all samples. Empty when nothing is selected.

    Raises:
        ValueError: unknown ``transform_mode`` or ``iters_ground_truth`` not in ``iters``.
        AssertionError: when the efficiency-derived target class disagrees with the label.
    """
    record_dict_list = []
    ground_truth_list = []
    estimated_list = []
    explainer.eval()

    selected = (
        attribution_values
        if ground_truth_key_select is None
        else {key: attribution_values[key] for key in ground_truth_key_select}
    )
    for sample_idx, tracking_dict in tqdm(selected.items()):
        # Explainer prediction for this single image; [0] drops the batch dimension.
        data = dataset[sample_idx]
        with torch.no_grad():
            output = explainer(
                pixel_values=data["pixel_values"].unsqueeze(0).to(explainer.device),
                return_loss=False,
            )
        estimated = output["logits"][0].T.cpu().detach().numpy()

        # Reference values at the requested iteration, read without mutating the caller's dict.
        iters = tracking_dict["iters"]
        if isinstance(iters, np.ndarray):
            iters = iters.tolist()
        ground_truth = _normalize_ground_truth(
            tracking_dict["values"][iters.index(iters_ground_truth)], transform_mode
        )
        ground_truth_list.append(ground_truth)
        estimated_list.append(estimated)

        metrics, mse_class = _attribution_agreement_metrics(estimated, ground_truth)
        record = {"sample_idx": sample_idx, **metrics}

        if check_class_efficiency:
            # Target class = argmax of the per-class sums of the first stored values
            # (presumably relying on Shapley efficiency) — NOTE(review): confirm values[0].
            target_class_idx = np.argmax(tracking_dict["values"][0].sum(axis=0))
            assert data["labels"] == target_class_idx
            is_target = np.arange(len(mse_class)) == target_class_idx
            record.update({
                "mse_target": mse_class[is_target].mean(),
                "mse_nontarget": mse_class[~is_target].mean(),
                "pearsonr_target": stats.pearsonr(estimated[:, target_class_idx], ground_truth[:, target_class_idx])[0],
                "spearmanr_target": stats.spearmanr(estimated[:, target_class_idx], ground_truth[:, target_class_idx])[0],
            })

        record.update(meta_info)
        record_dict_list.append(record)

    # Nothing selected: the global correlations below would fail on empty arrays.
    if not record_dict_list:
        return record_dict_list

    # Global metrics over all samples, attached to every record.
    ground_truth_all = np.array(ground_truth_list)
    estimated_all = np.array(estimated_list)
    pearson = stats.pearsonr(estimated_all.flatten(), ground_truth_all.flatten())[0]
    spearman = stats.spearmanr(estimated_all.flatten(), ground_truth_all.flatten())[0]
    for record in record_dict_list:
        record['pearson_global'] = pearson
        record['spearman_global'] = spearman

    return record_dict_list

In [None]:
def _attribution_at_iteration(tracking_dict, n_iters, sample_idx, role):
    """Return the attribution array recorded at iteration ``n_iters``.

    ``tracking_dict["iters"]`` may be a list or an ndarray; it is read without
    being modified, so the caller's loaded data is left untouched.

    Raises:
        ValueError: if ``n_iters`` was not recorded for this sample.
    """
    recorded_iters = list(tracking_dict["iters"])
    try:
        position = recorded_iters.index(n_iters)
    except ValueError:
        raise ValueError(
            f"{role} attribution for sample {sample_idx!r} was not recorded at iteration {n_iters}"
        ) from None
    return tracking_dict["values"][position]


def _mean_per_class_correlation(correlation_fn, estimated, ground_truth):
    """Average a scipy.stats correlation (first element of its result) over the columns of axis 1."""
    return np.mean([
        correlation_fn(estimated[:, class_idx], ground_truth[:, class_idx])[0]
        for class_idx in range(ground_truth.shape[1])
    ])


def get_ground_truth_metric_with_value(
    attribution_values_ground_truth,
    iters_ground_truth,
    attribution_values_calculated,
    iters_calculated,
    meta_info,
    check_class_efficiency=True,
    ground_truth_key_select=None,
):
    """Compare precomputed attribution estimates against ground-truth attributions.

    For every selected sample, the ground-truth attribution recorded at
    ``iters_ground_truth`` is compared with the estimate recorded at
    ``iters_calculated``. Attribution arrays are indexed as ``[:, class_idx]``,
    so axis 1 is the class axis (axis 0 is presumably players/patches — TODO confirm).

    Args:
        attribution_values_ground_truth: dict mapping sample index -> tracking
            dict with ``"iters"`` (list or ndarray of iteration counts) and
            ``"values"`` (one attribution array per entry of ``"iters"``).
        iters_ground_truth: iteration count to read from the ground-truth dicts.
        attribution_values_calculated: same structure for the estimates; must
            contain every evaluated sample index.
        iters_calculated: iteration count to read from the estimate dicts.
        meta_info: dict merged into every per-sample record (may override keys).
        check_class_efficiency: if True, also report target-class metrics, where
            the target class is the argmax of the per-class sums of the FIRST
            recorded values, and assert both sources agree on it.
        ground_truth_key_select: optional iterable of sample indices to evaluate
            instead of every key of ``attribution_values_ground_truth``.

    Returns:
        List of per-sample record dicts. Each also carries ``pearson_global`` /
        ``spearman_global`` computed over all selected samples; an empty list is
        returned when no sample is selected.

    Raises:
        ValueError: if a requested iteration count was not recorded for a sample.
        AssertionError: if ``check_class_efficiency`` and the two sources
            disagree on the target class.
    """
    if ground_truth_key_select is None:
        selected_ground_truth = attribution_values_ground_truth
    else:
        # A dict keeps first-seen order and drops duplicate keys, as before.
        selected_ground_truth = {
            key: attribution_values_ground_truth[key] for key in ground_truth_key_select
        }

    record_dict_list = []
    ground_truth_list = []
    estimated_list = []

    for sample_idx, tracking_dict_ground_truth in tqdm(selected_ground_truth.items()):
        tracking_dict_calculated = attribution_values_calculated[sample_idx]

        # Prepare ground truth values and the estimates being evaluated.
        ground_truth = _attribution_at_iteration(
            tracking_dict_ground_truth, iters_ground_truth, sample_idx, "ground-truth"
        )
        estimated = _attribution_at_iteration(
            tracking_dict_calculated, iters_calculated, sample_idx, "estimated"
        )
        ground_truth_list.append(ground_truth)
        estimated_list.append(estimated)

        # Squared error summed over axis 0, i.e. one value per class (axis 1).
        diff = estimated - ground_truth
        mse_class = (diff * diff).sum(axis=0)

        record = {
            "sample_idx": sample_idx,
            "mse_all": mse_class.mean(),
            "pearsonr_all": stats.pearsonr(estimated.flatten(), ground_truth.flatten())[0],
            "pearsonr_all_per_class": _mean_per_class_correlation(stats.pearsonr, estimated, ground_truth),
            "spearmanr_all": stats.spearmanr(estimated.flatten(), ground_truth.flatten())[0],
            "spearmanr_all_per_class": _mean_per_class_correlation(stats.spearmanr, estimated, ground_truth),
            "sign_agreement_all": ((estimated > 0) == (ground_truth > 0)).astype(int).mean(),
        }

        if check_class_efficiency:
            # Target class = argmax over classes of the first recorded values summed over axis 0.
            target_class_idx_ground_truth = np.argmax(tracking_dict_ground_truth["values"][0].sum(axis=0))
            target_class_idx_calculated = np.argmax(tracking_dict_calculated["values"][0].sum(axis=0))
            assert target_class_idx_ground_truth == target_class_idx_calculated, (
                f"sample {sample_idx!r}: target class mismatch "
                f"({target_class_idx_ground_truth} vs {target_class_idx_calculated})"
            )

            is_target = np.arange(len(mse_class)) == target_class_idx_ground_truth
            record.update({
                "mse_target": mse_class[is_target].mean(),
                "mse_nontarget": mse_class[~is_target].mean(),
                "pearsonr_target": stats.pearsonr(
                    estimated[:, target_class_idx_ground_truth],
                    ground_truth[:, target_class_idx_ground_truth],
                )[0],
                "spearmanr_target": stats.spearmanr(
                    estimated[:, target_class_idx_ground_truth],
                    ground_truth[:, target_class_idx_ground_truth],
                )[0],
            })

        # Append result for this image.
        record.update(meta_info)
        record_dict_list.append(record)

    if not record_dict_list:
        # Nothing selected: the global correlations below would fail on empty arrays.
        return record_dict_list

    # Calculate global metrics over every selected sample at once.
    ground_truth_all = np.array(ground_truth_list)
    estimated_all = np.array(estimated_list)
    pearson = stats.pearsonr(estimated_all.flatten(), ground_truth_all.flatten())[0]
    spearman = stats.spearmanr(estimated_all.flatten(), ground_truth_all.flatten())[0]
    for record in record_dict_list:
        record['pearson_global'] = pearson
        record['spearman_global'] = spearman

    return record_dict_list

# Load data

In [None]:
# Maps each Shapley extraction directory to its loaded attributions.
shapley_loaded_dict = dict()

In [None]:
# 1M ground truth (train): reference run for the Shapley training-target metrics.
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train",
    attribution_name="shapley",
)

In [None]:
# 1M ground truth (test)
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test",
    attribution_name="shapley",
)

In [None]:
# Reg-AO targets permutation (train), restricted to a fixed 100-index subset (seed 42).
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation/extract_output/train"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation/extract_output/train",
    attribution_name="shapley",
    sample_select=np.random.RandomState(seed=42).permutation(9469)[:100],
)

In [None]:
# Reg-AO targets KernelSHAP (train), restricted to a fixed 100-index subset (seed 42).
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train",
    attribution_name="shapley",
    sample_select=np.random.RandomState(seed=42).permutation(9469)[:100],
)

In [None]:
# Reg-AO targets SGD-Shapley (train), restricted to a fixed 100-index subset (seed 42).
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train",
    attribution_name="shapley",
    sample_select=np.random.RandomState(seed=42).permutation(9469)[:100],
)

In [None]:
# Reg-AO targets KernelSHAP (antithetical), restricted to a fixed 100-index subset (seed 42).
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train",
    attribution_name="shapley",
    sample_select=np.random.RandomState(seed=42).permutation(9469)[:100],
)

In [None]:
# Maps each Banzhaf extraction directory to its loaded attributions.
banzhaf_loaded_dict = dict()

In [None]:
# Banzhaf ground truth (train, long sampling run).
banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train",
    attribution_name="banzhaf",
)

In [None]:
# Banzhaf long sampling run on the test split.
banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test",
    attribution_name="banzhaf",
)

In [None]:
# Banzhaf training-target estimates (sampling run), fixed 100-index subset (seed 42).
banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train",
    attribution_name="banzhaf",
    sample_select=np.random.RandomState(seed=42).permutation(9469)[:100],
)

In [None]:
# Banzhaf training-target estimates (short sampling run), fixed 100-index subset (seed 42).
banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_short/extract_output/train"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_short/extract_output/train",
    attribution_name="banzhaf",
    sample_select=np.random.RandomState(seed=42).permutation(9469)[:100],
)

In [None]:
# Report how many samples each loaded Banzhaf run contains.
for key, loaded_attributions in banzhaf_loaded_dict.items():
    print(key, len(loaded_attributions))

In [None]:
# Maps each LIME extraction directory to its loaded attributions.
lime_loaded_dict = dict()

In [None]:
# LIME ground truth (train, long regression run).
lime_loaded_dict["logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train",
    attribution_name="lime",
)

In [None]:
# LIME long regression run on the test split.
lime_loaded_dict["logs/vitbase_imagenette_surrogate_lime_eval_test_regression_long/extract_output/test"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_lime_eval_test_regression_long/extract_output/test",
    attribution_name="lime",
)

In [None]:
# LIME training-target estimates (binomial run), fixed 100-index subset (seed 42).
lime_loaded_dict["logs/vitbase_imagenette_surrogate_binomial_eval_train/extract_output/train"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_binomial_eval_train/extract_output/train",
    attribution_name="lime",
    sample_select=np.random.RandomState(seed=42).permutation(9469)[:100],
)

In [None]:
# Disabled: the binomial LIME run on the validation split is not loaded in this notebook.
# Uncomment to load it into lime_loaded_dict.
# lime_loaded_dict["logs/vitbase_imagenette_surrogate_binomial_eval_validation/extract_output/validation"]\
# =load_attribution("logs/vitbase_imagenette_surrogate_binomial_eval_validation/extract_output/validation",
#                  attribution_name="lime")

# Comparison of the distribution of norms between different feature attribution methods (Figure 14)

In [None]:
# Figure 14: L2 norms of the high-precision (~1M iteration) attributions per method,
# on the same fixed 100-index subset (seed 42) of the 9469 training-sample indices.
NORM_REFERENCE_ITERS = 1000000

norm_sample_indices = np.random.RandomState(seed=42).permutation(list(range(9469)))[:100].tolist()
norm_attribution_sources = [
    ("KernelSHAP", shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"]),
    ("BanzhafMSR", banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"]),
    ("LIME", lime_loaded_dict["logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"]),
]

metric_list_plot_norm = []
for method_type, attribution_dict in norm_attribution_sources:
    for idx in norm_sample_indices:
        tracking = attribution_dict[idx]
        recorded_iters = list(tracking["iters"])
        # Latest recorded iteration not exceeding 1M: exactly 1M when it was recorded,
        # otherwise the closest earlier one (the Shapley metric cells use 999424 for
        # this KernelSHAP run, so a hard-coded 1000000 lookup may not exist there).
        reference_iters = max(i for i in recorded_iters if i <= NORM_REFERENCE_ITERS)
        attribution = tracking["values"][recorded_iters.index(reference_iters)]

        metric_list_plot_norm.append({
            "sample_idx": idx,
            "norm": np.linalg.norm(attribution),
            "method_type": method_type,
        })

In [None]:
# One row per (sample, method) with its attribution norm, used by the histogram cell.
metric_list_plot_norm_df = pd.DataFrame(data=metric_list_plot_norm)

In [None]:
# Figure 14: one histogram panel of attribution L2 norms per method.
fig = plt.figure(figsize=(3 * 4.3, 3))
box1 = gridspec.GridSpec(1, 3, hspace=0.3)

panel_titles = {"KernelSHAP": "Shapley values", "BanzhafMSR": "Banzhaf values", "LIME": "LIME"}

axd = {}
for idx1, method_type in enumerate(["KernelSHAP", "BanzhafMSR", "LIME"]):
    ax = fig.add_subplot(box1[idx1])
    axd[method_type] = ax

    method_norms = metric_list_plot_norm_df.loc[
        metric_list_plot_norm_df["method_type"] == method_type, "norm"
    ]
    method_norms.hist(ax=ax, bins=10)

    ax.set_title(panel_titles[method_type])
    ax.set_xlabel("L2 norm")
    ax.set_ylabel("Count")

In [None]:
# Save Figure 14 as PNG and PDF. savefig does not create missing directories,
# so make sure the output folder exists on a fresh checkout.
os.makedirs("logs/plots", exist_ok=True)
for extension in ("png", "pdf"):
    fig.savefig(f"logs/plots/feature_attribution_distribution.{extension}", bbox_inches='tight')

# Compare training targets against ground truth

In [None]:
# Accumulates per-sample metric records for the Shapley training targets.
metric_list_value_shapley = list()

In [None]:
# Compare each Shapley training-target estimator against the 1M ground truth
# (recorded at iteration 999424) across several subset budgets.
shapley_ground_truth_name = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
shapley_estimator_budgets = [
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train",
     [512, 1024, 2048, 3072]),
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train",
     [512, 1024, 2048, 3072]),
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation/extract_output/train",
     [196, 392, 588, 1176, 3136]),
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train",
     [258, 514, 1026, 2050, 4098, 5122, 9986]),
]

for estimated_name, subset_budgets in shapley_estimator_budgets:
    for num_subsets in subset_budgets:
        metric_list_value_shapley += get_ground_truth_metric_with_value(
            attribution_values_ground_truth=shapley_loaded_dict[shapley_ground_truth_name],
            iters_ground_truth=999424,
            attribution_values_calculated=shapley_loaded_dict[estimated_name],
            iters_calculated=num_subsets,
            meta_info={
                "num_subsets": num_subsets,
                "true_name": shapley_ground_truth_name,
                "estimated_name": estimated_name,
            },
        )

In [None]:
# Set to True to persist the Shapley training-target metrics.
SAVE_METRIC_LIST_VALUE_SHAPLEY = False
if SAVE_METRIC_LIST_VALUE_SHAPLEY:
    # torch.save does not create missing parent directories.
    os.makedirs("logs/experiment_results", exist_ok=True)
    torch.save(metric_list_value_shapley, "logs/experiment_results/metric_list_value_shapley.pt")

In [None]:
# Accumulates per-sample metric records for the Banzhaf training targets.
metric_list_value_banzhaf = list()

In [None]:
# Compare Banzhaf estimates against the long sampling run at 1M iterations.
# The last entry evaluates the long run itself at intermediate iteration counts.
# NOTE: an antithetical variant
# ("logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_antithetical/extract_output/train",
# budgets 100..5000) was previously evaluated here; add it to the list below to re-enable it.
banzhaf_ground_truth_name = "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"
banzhaf_estimator_budgets = [
    ("logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_short/extract_output/train",
     [5, 10, 20, 30, 40, 50, 60, 70, 80, 90]),
    ("logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train",
     [100, 200, 300, 400, 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000]),
    (banzhaf_ground_truth_name,
     list(range(6000, 100000 + 1000, 10000))),
]

for estimated_name, subset_budgets in banzhaf_estimator_budgets:
    for num_subsets in subset_budgets:
        metric_list_value_banzhaf += get_ground_truth_metric_with_value(
            attribution_values_ground_truth=banzhaf_loaded_dict[banzhaf_ground_truth_name],
            iters_ground_truth=1000000,
            attribution_values_calculated=banzhaf_loaded_dict[estimated_name],
            iters_calculated=num_subsets,
            meta_info={
                "num_subsets": num_subsets,
                "true_name": banzhaf_ground_truth_name,
                "estimated_name": estimated_name,
            },
            check_class_efficiency=False,
        )

In [None]:
# Disabled: convergence of the Banzhaf long sampling run on the test split,
# comparing the run against itself at intermediate iteration counts (500 .. 96000).
# for num_subsets in [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000]+list(range(6000, 100000+1000, 10000)):
#     metric_list_value_banzhaf+=get_ground_truth_metric_with_value(attribution_values_ground_truth=banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test"], 
#                                        iters_ground_truth=1000000, 
#                                        attribution_values_calculated=banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test"],
#                                        iters_calculated=num_subsets,
#                                        meta_info={"num_subsets": num_subsets,
#                                                   "true_name": "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test",
#                                                    "estimated_name": "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test",
#                                                  },
#                                        check_class_efficiency=False)

In [None]:
# Set to True to persist the Banzhaf training-target metrics.
SAVE_METRIC_LIST_VALUE_BANZHAF = False
if SAVE_METRIC_LIST_VALUE_BANZHAF:
    # torch.save does not create missing parent directories.
    os.makedirs("logs/experiment_results", exist_ok=True)
    torch.save(metric_list_value_banzhaf, "logs/experiment_results/metric_list_value_banzhaf.pt")

In [None]:
# Accumulates per-sample metric records for the LIME training targets.
metric_list_value_lime = list()

In [None]:
# Compare LIME estimates against the long LIME regression run at 1M iterations.
# The last entry evaluates the long run itself at intermediate iteration counts.
lime_ground_truth_name = "logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"
lime_estimator_budgets = [
    ("logs/vitbase_imagenette_surrogate_binomial_eval_train/extract_output/train",
     list(range(128, 3200, 128))),
    (lime_ground_truth_name,
     list(range(100000, 1000000 + 100000, 100000))),
]

for estimated_name, subset_budgets in lime_estimator_budgets:
    for num_subsets in subset_budgets:
        metric_list_value_lime += get_ground_truth_metric_with_value(
            attribution_values_ground_truth=lime_loaded_dict[lime_ground_truth_name],
            iters_ground_truth=1000000,
            attribution_values_calculated=lime_loaded_dict[estimated_name],
            iters_calculated=num_subsets,
            meta_info={
                "num_subsets": num_subsets,
                "true_name": lime_ground_truth_name,
                "estimated_name": estimated_name,
            },
            check_class_efficiency=False,
        )

In [None]:
# Set to True to persist the LIME training-target metrics.
SAVE_METRIC_LIST_VALUE_LIME = False
if SAVE_METRIC_LIST_VALUE_LIME:
    # torch.save does not create missing parent directories.
    os.makedirs("logs/experiment_results", exist_ok=True)
    torch.save(metric_list_value_lime, "logs/experiment_results/metric_list_value_lime.pt")

# Evaluate explainer

In [None]:
def compare_checkpoint_value(current_checkpoint, best_checkpoint):
    """Label a checkpoint relative to the best checkpoint.

    Returns "before" when current < best, "best" when equal, and "after" when
    current > best. Raises ValueError when none of these comparisons holds
    (for example when an input is NaN).
    """
    if current_checkpoint < best_checkpoint:
        return "before"
    if current_checkpoint == best_checkpoint:
        return "best"
    if current_checkpoint > best_checkpoint:
        return "after"
    # Unordered inputs satisfy none of the three comparisons.
    raise ValueError

In [None]:
def get_best_model_checkpoint(model_path):
    """Return the ``best_model_checkpoint`` stored in a run's trainer state.

    Reads ``<model_path>/trainer_state.json`` when it exists; otherwise reads the
    ``trainer_state.json`` inside the latest ``checkpoint-<step>`` directory,
    ordered numerically by step (so checkpoint-10 sorts after checkpoint-5).

    Args:
        model_path: run output directory.

    Returns:
        The stored ``best_model_checkpoint`` value (None if it was stored as null).

    Raises:
        FileNotFoundError: if there is neither a top-level trainer_state.json nor
            any checkpoint-* directory under ``model_path``.
    """
    top_level_state_path = os.path.join(model_path, "trainer_state.json")
    if os.path.exists(top_level_state_path):
        state_path = top_level_state_path
    else:
        checkpoint_path_list = sorted(
            glob.glob(os.path.join(model_path, "checkpoint-*")),
            key=lambda x: int(x.split('-')[-1]),
        )
        if not checkpoint_path_list:
            raise FileNotFoundError(
                f"No trainer_state.json or checkpoint-* directories under {model_path!r}"
            )
        state_path = os.path.join(checkpoint_path_list[-1], "trainer_state.json")

    with open(state_path) as f:
        trainer_state = json.load(f)
    return trainer_state["best_model_checkpoint"]

In [None]:
from shutil import rmtree

# Prune checkpoints saved long after each run's best checkpoint under logs/.
# A checkpoint is a deletion candidate when its step exceeds the best checkpoint's
# step by more than CHECKPOINT_KEEP_INTERVALS save intervals (the smallest gap
# between saved steps). With DRY_RUN_CHECKPOINT_CLEANUP = True candidates are only
# listed; set it to False to actually delete them.
DRY_RUN_CHECKPOINT_CLEANUP = True
CHECKPOINT_KEEP_INTERVALS = 10


def _checkpoint_step(checkpoint_path):
    """Return the training step encoded in a ``.../checkpoint-<step>`` path."""
    return int(checkpoint_path.split('-')[-1])


for model_path in glob.glob("logs/*"):
    print(model_path)
    checkpoint_path_list = sorted(glob.glob(model_path + "/checkpoint-*"), key=_checkpoint_step)
    if len(checkpoint_path_list) < 2:
        # No checkpoints, or too few to infer the save interval: nothing to prune.
        continue

    best_checkpoint_path = get_best_model_checkpoint(model_path)
    if best_checkpoint_path is None:
        print("no best checkpoint recorded; skipping")
        continue
    best_step = _checkpoint_step(best_checkpoint_path)

    steps = [_checkpoint_step(path) for path in checkpoint_path_list]
    epoch_step_count = min(later - earlier for earlier, later in zip(steps, steps[1:]))
    checkpoint_to_delete = [
        path
        for path, step in zip(checkpoint_path_list, steps)
        if step - best_step > epoch_step_count * CHECKPOINT_KEEP_INTERVALS
    ]
    print(len(checkpoint_to_delete))

    if checkpoint_to_delete:
        print(best_checkpoint_path)
        for checkpoint_path in tqdm(checkpoint_to_delete):
            if DRY_RUN_CHECKPOINT_CLEANUP:
                print("would delete", checkpoint_path)
            else:
                rmtree(checkpoint_path)
                print(checkpoint_path)

## Shapley: evaluate explainer checkpoints against ground truth

In [None]:
# Accumulates the entries returned by get_ground_truth_metric_with_explainer for every Shapley explainer checkpoint.
metric_list_shapley = list()

### Reg-AO (upfront, regression)

In [None]:
# Reg-AO (upfront, regression) runs: score each of the first 20 checkpoints
# against the reference Shapley attributions, on the test split and then the train split.
for num_subsets in [512, 1024, 2048, 3072]:
    model_path = f"logs/vitbase_imagenette_shapley_regexplainer_upfront_{num_subsets}"
    with open(model_path + "/trainer_state.json") as f:
        trainer_state = json.load(f)

    # Order checkpoints by their numeric step suffix ("checkpoint-<step>").
    checkpoint_path_list = sorted(
        glob.glob(model_path + "/checkpoint-*"),
        key=lambda path: int(path.split('-')[-1]),
    )
    for checkpoint_path in tqdm(checkpoint_path_list[:20]):
        checkpoint_state_dict = torch.load(checkpoint_path + "/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path + "/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)
        regexplainer.load_state_dict(checkpoint_state_dict)

        # Metadata shared by both splits of this checkpoint.
        checkpoint_epoch = int(checkpoint_trainer_state["epoch"])
        checkpoint_position = compare_checkpoint_value(
            current_checkpoint=int(checkpoint_path.split('-')[-1]),
            best_checkpoint=int(trainer_state["best_model_checkpoint"].split('-')[-1]),
        )
        for eval_split, eval_true_name in (
            ("test", "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"),
            ("train", "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"),
        ):
            metric_list_shapley += get_ground_truth_metric_with_explainer(
                attribution_values=shapley_loaded_dict[eval_true_name],
                explainer=regexplainer,
                dataset=dataset_explainer[eval_split],
                iters_ground_truth=999424,
                meta_info={
                    "true_name": eval_true_name,
                    "model_path": model_path,
                    "epoch": checkpoint_epoch,
                    "is_best_checkpoint": checkpoint_position,
                },
            )

In [None]:
# Antithetical Reg-AO (upfront) runs: score each of the first 20 checkpoints
# against the reference Shapley attributions, on the test split and then the train split.
for num_subsets in [512, 3072]:
    model_path = f"logs/vitbase_imagenette_shapley_regexplainer_antithetical_upfront_{num_subsets}"
    with open(model_path + "/trainer_state.json") as f:
        trainer_state = json.load(f)

    # Order checkpoints by their numeric step suffix ("checkpoint-<step>").
    checkpoint_path_list = sorted(
        glob.glob(model_path + "/checkpoint-*"),
        key=lambda path: int(path.split('-')[-1]),
    )
    for checkpoint_path in tqdm(checkpoint_path_list[:20]):
        checkpoint_state_dict = torch.load(checkpoint_path + "/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path + "/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)
        regexplainer.load_state_dict(checkpoint_state_dict)

        # Metadata shared by both splits of this checkpoint.
        checkpoint_epoch = int(checkpoint_trainer_state["epoch"])
        checkpoint_position = compare_checkpoint_value(
            current_checkpoint=int(checkpoint_path.split('-')[-1]),
            best_checkpoint=int(trainer_state["best_model_checkpoint"].split('-')[-1]),
        )
        for eval_split, eval_true_name in (
            ("test", "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"),
            ("train", "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"),
        ):
            metric_list_shapley += get_ground_truth_metric_with_explainer(
                attribution_values=shapley_loaded_dict[eval_true_name],
                explainer=regexplainer,
                dataset=dataset_explainer[eval_split],
                iters_ground_truth=999424,
                meta_info={
                    "true_name": eval_true_name,
                    "model_path": model_path,
                    "epoch": checkpoint_epoch,
                    "is_best_checkpoint": checkpoint_position,
                },
            )

In [None]:
# Reg-AO (upfront) runs with a numtrain_<n> suffix: score each of the first 50
# checkpoints against the reference Shapley attributions, on the test split and
# then the train split.
for model_path in [
    "logs/vitbase_imagenette_shapley_regexplainer_upfront_1024_numtrain_4735",
    "logs/vitbase_imagenette_shapley_regexplainer_upfront_2048_numtrain_2367",
    "logs/vitbase_imagenette_shapley_regexplainer_upfront_3072_numtrain_1578",
]:
    with open(model_path + "/trainer_state.json") as f:
        trainer_state = json.load(f)

    # Order checkpoints by their numeric step suffix ("checkpoint-<step>").
    checkpoint_path_list = sorted(
        glob.glob(model_path + "/checkpoint-*"),
        key=lambda path: int(path.split('-')[-1]),
    )
    for checkpoint_path in tqdm(checkpoint_path_list[:50]):
        checkpoint_state_dict = torch.load(checkpoint_path + "/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path + "/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)
        regexplainer.load_state_dict(checkpoint_state_dict)

        # Metadata shared by both splits of this checkpoint.
        checkpoint_epoch = int(checkpoint_trainer_state["epoch"])
        checkpoint_position = compare_checkpoint_value(
            current_checkpoint=int(checkpoint_path.split('-')[-1]),
            best_checkpoint=int(trainer_state["best_model_checkpoint"].split('-')[-1]),
        )
        for eval_split, eval_true_name in (
            ("test", "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"),
            ("train", "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"),
        ):
            metric_list_shapley += get_ground_truth_metric_with_explainer(
                attribution_values=shapley_loaded_dict[eval_true_name],
                explainer=regexplainer,
                dataset=dataset_explainer[eval_split],
                iters_ground_truth=999424,
                meta_info={
                    "true_name": eval_true_name,
                    "model_path": model_path,
                    "epoch": checkpoint_epoch,
                    "is_best_checkpoint": checkpoint_position,
                },
            )

### Reg-AO (upfront, permutation)

In [None]:
# Reg-AO (upfront, permutation) runs: score each of the first 20 checkpoints
# against the reference Shapley attributions, on the test split and then the train split.
for num_subsets in [196, 392, 588, 1176, 3136]:
    model_path = f"logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_{num_subsets}"
    with open(model_path + "/trainer_state.json") as f:
        trainer_state = json.load(f)

    # Order checkpoints by their numeric step suffix ("checkpoint-<step>").
    checkpoint_path_list = sorted(
        glob.glob(model_path + "/checkpoint-*"),
        key=lambda path: int(path.split('-')[-1]),
    )
    for checkpoint_path in tqdm(checkpoint_path_list[:20]):
        checkpoint_state_dict = torch.load(checkpoint_path + "/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path + "/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)
        regexplainer.load_state_dict(checkpoint_state_dict)

        # Metadata shared by both splits of this checkpoint.
        checkpoint_epoch = int(checkpoint_trainer_state["epoch"])
        checkpoint_position = compare_checkpoint_value(
            current_checkpoint=int(checkpoint_path.split('-')[-1]),
            best_checkpoint=int(trainer_state["best_model_checkpoint"].split('-')[-1]),
        )
        for eval_split, eval_true_name in (
            ("test", "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"),
            ("train", "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"),
        ):
            metric_list_shapley += get_ground_truth_metric_with_explainer(
                attribution_values=shapley_loaded_dict[eval_true_name],
                explainer=regexplainer,
                dataset=dataset_explainer[eval_split],
                iters_ground_truth=999424,
                meta_info={
                    "true_name": eval_true_name,
                    "model_path": model_path,
                    "epoch": checkpoint_epoch,
                    "is_best_checkpoint": checkpoint_position,
                },
            )

### SGD-Shapley

In [None]:
# SGD-Shapley (antithetical, upfront) run: score each of the first 20 checkpoints
# against the reference Shapley attributions, on the test split and then the train split.
for num_subsets in [9986]:
    model_path = f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_shapley_regexplainer_SGD_antithetical_upfront_{num_subsets}"

    # Order checkpoints by their numeric step suffix ("checkpoint-<step>").
    checkpoint_path_list = sorted(
        glob.glob(model_path + "/checkpoint-*"),
        key=lambda path: int(path.split('-')[-1]),
    )
    if not checkpoint_path_list:
        print(f"No checkpoints found under {model_path}; skipping")
        continue
    # Resolve the best checkpoint once per run. get_best_model_checkpoint reads
    # trainer_state.json from disk on every call, and its answer depends only on model_path.
    best_step = int(get_best_model_checkpoint(model_path).split('-')[-1])

    for checkpoint_path in tqdm(checkpoint_path_list[:20]):
        checkpoint_state_dict = torch.load(checkpoint_path + "/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path + "/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)
        regexplainer.load_state_dict(checkpoint_state_dict)

        # Metadata shared by both splits of this checkpoint.
        checkpoint_epoch = int(checkpoint_trainer_state["epoch"])
        checkpoint_position = compare_checkpoint_value(
            current_checkpoint=int(checkpoint_path.split('-')[-1]),
            best_checkpoint=best_step,
        )
        for eval_split, eval_true_name in (
            ("test", "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"),
            ("train", "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"),
        ):
            metric_list_shapley += get_ground_truth_metric_with_explainer(
                attribution_values=shapley_loaded_dict[eval_true_name],
                explainer=regexplainer,
                dataset=dataset_explainer[eval_split],
                iters_ground_truth=999424,
                meta_info={
                    "true_name": eval_true_name,
                    "model_path": model_path,
                    "epoch": checkpoint_epoch,
                    "is_best_checkpoint": checkpoint_position,
                },
            )

In [None]:
# Disabled cache step: change the guard to True to write the Shapley metrics to
# disk and read them straight back.
if False:
    shapley_metrics_path = "logs/experiment_results/metric_list_shapley.pt"
    torch.save(metric_list_shapley, shapley_metrics_path)
    metric_list_shapley = torch.load(shapley_metrics_path)

In [None]:
# Objective-based Shapley explainer: score every checkpoint against the
# reference Shapley attributions, on the test split and then the train split.
metric_list_shapley_obj = list()
for model_path in ["logs/vitbase_imagenette_shapley_objexplainer_newsample_32"]:
    with open(model_path + "/trainer_state.json") as f:
        trainer_state = json.load(f)

    # Order checkpoints by their numeric step suffix ("checkpoint-<step>").
    checkpoint_path_list = sorted(
        glob.glob(model_path + "/checkpoint-*"),
        key=lambda path: int(path.split('-')[-1]),
    )
    for checkpoint_path in tqdm(list(checkpoint_path_list)):
        checkpoint_state_dict = torch.load(checkpoint_path + "/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path + "/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)
        objexplainer.load_state_dict(checkpoint_state_dict)

        # Metadata shared by both splits of this checkpoint.
        checkpoint_epoch = int(checkpoint_trainer_state["epoch"])
        checkpoint_position = compare_checkpoint_value(
            current_checkpoint=int(checkpoint_path.split('-')[-1]),
            best_checkpoint=int(trainer_state["best_model_checkpoint"].split('-')[-1]),
        )
        for eval_split, eval_true_name in (
            ("test", "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"),
            ("train", "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"),
        ):
            metric_list_shapley_obj += get_ground_truth_metric_with_explainer(
                attribution_values=shapley_loaded_dict[eval_true_name],
                explainer=objexplainer,
                dataset=dataset_explainer[eval_split],
                iters_ground_truth=999424,
                meta_info={
                    "true_name": eval_true_name,
                    "model_path": model_path,
                    "epoch": checkpoint_epoch,
                    "is_best_checkpoint": checkpoint_position,
                },
            )

In [None]:
# Disabled cache step: change the guard to True to write the objective-explainer
# Shapley metrics to disk and read them straight back.
if False:
    shapley_obj_metrics_path = "logs/experiment_results/metric_list_shapley_obj.pt"
    torch.save(metric_list_shapley_obj, shapley_obj_metrics_path)
    metric_list_shapley_obj = torch.load(shapley_obj_metrics_path)

## Banzhaf: evaluate explainer checkpoints against ground truth

In [None]:
# Show which Banzhaf attribution outputs were loaded; later cells index this dict by these log-path keys.
banzhaf_loaded_dict.keys()

In [None]:
# Accumulates the entries returned by get_ground_truth_metric_with_explainer for every Banzhaf explainer checkpoint.
metric_list_banzhaf = list()

In [None]:
# Banzhaf Reg-AO (upfront) runs, grouped by target transform mode: score every
# checkpoint against the reference Banzhaf attributions, on the test split and
# then the train split.
for target_transform_mode, num_subsets_list in zip(["global", "perinstanceperclass"],
                                                   [[10, 100, 500], [10, 100, 500]]):
    # Set the explainer's target transform mode to match the runs evaluated in this group.
    regexplainer_normalize.config.target_transform_mode = target_transform_mode
    for num_subsets in num_subsets_list:
        model_path = f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_banzhaf_regexplainer_upfront_{target_transform_mode}_{num_subsets}"

        # Order checkpoints by their numeric step suffix ("checkpoint-<step>").
        checkpoint_path_list = sorted(
            glob.glob(model_path + "/checkpoint-*"),
            key=lambda path: int(path.split('-')[-1]),
        )
        if not checkpoint_path_list:
            print(f"No checkpoints found under {model_path}; skipping")
            continue
        # Resolve the best checkpoint once per run. get_best_model_checkpoint reads
        # trainer_state.json from disk on every call, and its answer depends only on model_path.
        best_step = int(get_best_model_checkpoint(model_path).split('-')[-1])

        for checkpoint_path in tqdm(checkpoint_path_list):
            checkpoint_state_dict = torch.load(checkpoint_path + "/pytorch_model.bin", map_location="cpu")
            with open(checkpoint_path + "/trainer_state.json") as f:
                checkpoint_trainer_state = json.load(f)
            regexplainer_normalize.load_state_dict(checkpoint_state_dict)

            # Metadata shared by both splits of this checkpoint.
            checkpoint_epoch = int(checkpoint_trainer_state["epoch"])
            checkpoint_position = compare_checkpoint_value(
                current_checkpoint=int(checkpoint_path.split('-')[-1]),
                best_checkpoint=best_step,
            )
            for eval_split, eval_true_name in (
                ("test", "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test"),
                ("train", "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"),
            ):
                metric_list_banzhaf += get_ground_truth_metric_with_explainer(
                    attribution_values=banzhaf_loaded_dict[eval_true_name],
                    explainer=regexplainer_normalize,
                    dataset=dataset_explainer[eval_split],
                    iters_ground_truth=1000000,
                    meta_info={
                        "true_name": eval_true_name,
                        "model_path": model_path,
                        "epoch": checkpoint_epoch,
                        "is_best_checkpoint": checkpoint_position,
                    },
                    check_class_efficiency=False,
                    transform_mode=target_transform_mode,
                )

In [None]:
# Disabled cache step: change the guard to True to write the Banzhaf metrics to
# disk and read them straight back.
if False:
    banzhaf_metrics_path = "logs/experiment_results/metric_list_banzhaf.pt"
    torch.save(metric_list_banzhaf, banzhaf_metrics_path)
    metric_list_banzhaf = torch.load(banzhaf_metrics_path)

## LIME: evaluate explainer checkpoints against ground truth

In [None]:
# Accumulates the entries returned by get_ground_truth_metric_with_explainer for every LIME explainer checkpoint.
metric_list_lime = list()

In [None]:
# LIME Reg-AO (upfront) runs, grouped by target transform mode: score every
# checkpoint against the reference LIME attributions, on the test split and
# then the train split.
for target_transform_mode, num_subsets_list in zip(["global", "perinstanceperclass"],
                                                   [[256, 512], [256, 512]]):
    # Set the explainer's target transform mode to match the runs evaluated in this group.
    regexplainer_normalize.config.target_transform_mode = target_transform_mode
    for num_subsets in num_subsets_list:
        model_path = f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_lime_regexplainer_upfront_{target_transform_mode}_{num_subsets}"

        # Order checkpoints by their numeric step suffix ("checkpoint-<step>").
        checkpoint_path_list = sorted(
            glob.glob(model_path + "/checkpoint-*"),
            key=lambda path: int(path.split('-')[-1]),
        )
        if not checkpoint_path_list:
            print(f"No checkpoints found under {model_path}; skipping")
            continue
        # Resolve the best checkpoint once per run. get_best_model_checkpoint reads
        # trainer_state.json from disk on every call, and its answer depends only on model_path.
        best_step = int(get_best_model_checkpoint(model_path).split('-')[-1])

        for checkpoint_path in tqdm(checkpoint_path_list):
            checkpoint_state_dict = torch.load(checkpoint_path + "/pytorch_model.bin", map_location="cpu")
            with open(checkpoint_path + "/trainer_state.json") as f:
                checkpoint_trainer_state = json.load(f)
            regexplainer_normalize.load_state_dict(checkpoint_state_dict)

            # Metadata shared by both splits of this checkpoint.
            checkpoint_epoch = int(checkpoint_trainer_state["epoch"])
            checkpoint_position = compare_checkpoint_value(
                current_checkpoint=int(checkpoint_path.split('-')[-1]),
                best_checkpoint=best_step,
            )
            for eval_split, eval_true_name in (
                ("test", "logs/vitbase_imagenette_surrogate_lime_eval_test_regression_long/extract_output/test"),
                ("train", "logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"),
            ):
                metric_list_lime += get_ground_truth_metric_with_explainer(
                    attribution_values=lime_loaded_dict[eval_true_name],
                    explainer=regexplainer_normalize,
                    dataset=dataset_explainer[eval_split],
                    iters_ground_truth=1000000,
                    meta_info={
                        "true_name": eval_true_name,
                        "model_path": model_path,
                        "epoch": checkpoint_epoch,
                        "is_best_checkpoint": checkpoint_position,
                    },
                    check_class_efficiency=False,
                    transform_mode=target_transform_mode,
                )

In [None]:
# Banzhaf summary: mean/std of mse_all per run, restricted to each run's best
# checkpoint scored against the train-split reference attributions.
df_temp = pd.DataFrame(metric_list_banzhaf)
df_temp.query(
    'is_best_checkpoint == "best" and '
    'true_name == "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"'
).groupby(["model_path"])["mse_all"].agg(["mean", "std"])

In [None]:
# Disabled cache step: change the guard to True to write the LIME metrics to
# disk and read them straight back.
if False:
    lime_metrics_path = "logs/experiment_results/metric_list_lime.pt"
    torch.save(metric_list_lime, lime_metrics_path)
    metric_list_lime = torch.load(lime_metrics_path)

# Comparison of the estimation error between different per-example estimators for Shapley value feature attributions across varying numbers of samples.

In [None]:
import copy

In [None]:
# Attach display labels to each per-example Shapley estimate that was scored
# against the KernelSHAP regression run. Antithetical KernelSHAP estimates are
# intentionally left out of this comparison.
metric_list_plot=[]
for metric in metric_list_value_shapley:
    metric_temp=copy.copy(metric)
    estimated_name=metric_temp["estimated_name"]
    against_regression=(metric_temp["true_name"]=="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train")

    if against_regression and estimated_name=="logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train":
        metric_temp.update(
            {"method_name": f'KernelSHAP ({metric_temp["num_subsets"]})',
             "method_type": 'KernelSHAP',
             "antithetical": False
            }
        )

    elif against_regression and estimated_name=="logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train":
        # Deliberately excluded from the plot.
        continue

    elif against_regression and estimated_name=="logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation/extract_output/train":
        metric_temp.update(
            {"method_name": f'Permutation ({metric_temp["num_subsets"]})',
             "method_type": 'Permutation',
             "antithetical": False
            }
        )

    elif against_regression and estimated_name=="logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_antithetical/extract_output/train":
        metric_temp.update(
            {"method_name": f'Permutation ({metric_temp["num_subsets"]}, antithetical)',
             "method_type": 'Permutation',
             "antithetical": True
            }
        )

    elif against_regression and estimated_name=="logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_newsample_196/extract_output/train":
        metric_temp.update(
            {"method_name": f'Permutation ({metric_temp["num_subsets"]}, newsample, {metric_temp["nth"]})',
             "method_type": 'Permutation',
             "antithetical": False
            }
        )

    elif against_regression and estimated_name=="logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train":
        metric_temp.update(
            {"method_name": f'SGD-Shapley ({metric_temp["num_subsets"]}, antithetical)',
             "method_type": 'SGD-Shapley',
             "antithetical": True
            }
        )

    else:
        # Previously `raise RuntimError()` (typo) surfaced as a confusing NameError.
        raise RuntimeError(f"Unrecognized estimator/ground-truth pair: {metric_temp}")

    metric_list_plot.append(metric_temp)

metric_list_plot_df=pd.DataFrame(metric_list_plot)
metric_list_plot_df

In [None]:
# Average correlation metrics per estimator and sample budget, ordered by budget.
(
    metric_list_plot_df
    .groupby(["method_name", "num_subsets", "true_name", "estimated_name"])
    [["pearsonr_all_per_class", "spearmanr_all_per_class", "pearsonr_all", "spearmanr_all"]]
    .mean()
    .reset_index()
    .sort_values("num_subsets")
)

In [None]:
# Number of metric rows per (method family, antithetical) combination.
metric_list_plot_df.value_counts(subset=["method_type", "antithetical"])

In [None]:
# Four panels (error, correlation, rank correlation, sign agreement) comparing
# per-example Shapley estimators as a function of the number of samples per
# point. The "newsample" permutation variant is excluded from every panel.

def _plot_target_quality_panel(ax, data, y, title, ylabel, xlim_max, ylim_max,
                               show_legend, minor_grid, **lineplot_kwargs):
    """Draw one ``y``-vs-``num_subsets`` line panel, one line per ``method_type``.

    Args:
        ax: Matplotlib axes to draw on.
        data: DataFrame with ``num_subsets``, ``method_type`` and ``y`` columns.
        y: Metric column plotted on the y-axis.
        title: Panel title.
        ylabel: Y-axis label.
        xlim_max: Upper x-axis limit (lower limit is 0).
        ylim_max: Upper y-axis limit (lower limit is 0).
        show_legend: Keep the legend on this panel; otherwise it is removed.
        minor_grid: Also request minor gridlines on both axes.
        **lineplot_kwargs: Extra keyword arguments for ``sns.lineplot``
            (e.g. ``errorbar=None``).
    """
    sns.lineplot(
        y=y,
        x="num_subsets",
        hue="method_type",
        marker='o',
        markeredgecolor=None,
        data=data,
        ax=ax,
        **lineplot_kwargs,
    )

    ax.set_title(title)
    ax.set_xlabel("# Samples / Point")
    ax.set_ylabel(ylabel)

    ax.xaxis.set_major_locator(MultipleLocator(500))
    ax.xaxis.grid(True, which='major')
    if minor_grid:
        ax.xaxis.grid(True, which='minor')
    ax.set_xlim(0, xlim_max)

    ax.yaxis.set_major_locator(MultipleLocator(0.2))
    ax.yaxis.grid(True, which='major')
    if minor_grid:
        ax.yaxis.grid(True, which='minor')
    ax.set_ylim(0, ylim_max)

    ax.tick_params(axis='x', which='major', rotation=0)
    ax.tick_params(axis='y', which='major', rotation=0)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    leg = ax.legend(loc='best')
    if not show_legend:
        leg.remove()


_target_quality_data = metric_list_plot_df[~metric_list_plot_df["method_name"].str.contains("newsample")]

# One entry per panel, left to right. The error panel keeps seaborn's default
# error band and the legend; the others drop both.
_target_quality_panels = [
    dict(y="mse_all", title="Error", ylabel="Error",
         xlim_max=3100, ylim_max=0.45, show_legend=True, minor_grid=False),
    dict(y="pearsonr_all", title="Correlation", ylabel="Pearson Correlation",
         xlim_max=3200, ylim_max=1.05, show_legend=False, minor_grid=True, errorbar=None),
    dict(y="spearmanr_all", title="Rank correlation", ylabel="Spearman Correlation",
         xlim_max=3200, ylim_max=1.05, show_legend=False, minor_grid=True, errorbar=None),
    dict(y="sign_agreement_all", title="Sign Agreement", ylabel="Sign Agreement",
         xlim_max=3200, ylim_max=1.05, show_legend=False, minor_grid=True, errorbar=None),
]

fig = plt.figure(figsize=(4*(4.3), 3))

box1 = gridspec.GridSpec(1, 4, hspace=0.3)

axd = {}
for idx1 in range(len(_target_quality_panels)):
    axd[idx1] = fig.add_subplot(box1[idx1])

for idx1, panel_kwargs in enumerate(_target_quality_panels):
    _plot_target_quality_panel(axd[idx1], _target_quality_data, **panel_kwargs)

In [None]:
# Write the target-quality figure as PNG, then as PDF.
for _figure_ext in ("png", "pdf"):
    fig.savefig(f"logs/plots/training_target_quality_shapley.{_figure_ext}", bbox_inches='tight')

In [None]:
# Inspect the "iters" field of the first entry loaded for the Banzhaf sampling run.
banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"][0]["iters"]

In [None]:
def prettify_metric_name(metric_name):
    """Return the display label used in plots for a metric column name.

    Raises:
        ValueError: if ``metric_name`` is not a known metric column; the
            offending name is the exception's only argument.
    """
    labels = (
        ("mse_all", "Error"),
        ("pearsonr_all", "Correlation"),
        ("spearmanr_all", "Rank Correlation"),
        ("sign_agreement_all", "Sign Agreement"),
        ("pearsonr_all_per_class", "Pearson corr. (Per classes)"),
        ("spearmanr_all_per_class", "Spearman corr. (Per classes)"),
    )
    for column, label in labels:
        if metric_name == column:
            return label
    raise ValueError(metric_name)
         
        
def prettify_method_type(method_type):
    """Return the display label used in plots for an estimator family name.

    Raises:
        ValueError: if ``method_type`` is not a known estimator family; the
            offending name is the exception's only argument.
    """
    labels = (
        ("KernelSHAP", "KernelSHAP"),
        ("Permutation", "Permutation Sampling"),
        ("SGD-Shapley", "SGD-Shapley"),
        ("BanzhafMSR", "Banzhaf (MSR)"),
        ("LIME", "LIME"),
    )
    for family, label in labels:
        if method_type == family:
            return label
    raise ValueError(method_type)

In [None]:
def prettify_transform_mode(transform_mode):
    """Return the display label for a target-transform mode.

    The "global" mode maps to an empty string, so no suffix is shown for it.

    Raises:
        ValueError: if ``transform_mode`` is not a known mode; the offending
            value is the exception's only argument.
    """
    labels = (
        ("global", ""),
        ("sqrt", "Sqrt"),
        ("perinstance", "Per-sample Norm."),
        ("perinstanceperclass", "Per-Label Scaling"),
    )
    for mode, label in labels:
        if transform_mode == mode:
            return label
    raise ValueError(transform_mode)

In [None]:
# KernelSHAP reference line: score the KernelSHAP regression estimates at
# increasing numbers of sampled subsets against the same run evaluated at
# 999424 iterations.
shapley_regression_name="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
shapley_regression_values=shapley_loaded_dict[shapley_regression_name]

metric_list_ground_truth_shapley=[]
# 512, 1024, ..., 5120, then 10240, 20480, ..., 194560 subsets.
for num_subsets in [512*i for i in range(1, 11)]+list(range(512*20, 512*20*20, 512*20)):
    metric_list_ground_truth_shapley+=get_ground_truth_metric_with_value(
        attribution_values_ground_truth=shapley_regression_values,
        iters_ground_truth=999424,
        attribution_values_calculated=shapley_regression_values,
        iters_calculated=num_subsets,
        meta_info={"num_subsets": num_subsets,
                   "true_name": shapley_regression_name,
                   "estimated_name": shapley_regression_name,
                  },
        # Estimate and reference are the same run, so every key is shared
        # (the old self-intersection of the key set was redundant).
        ground_truth_key_select=set(shapley_regression_values.keys()),
    )


metric_list_plot_reference=[]
for metric in metric_list_ground_truth_shapley:
    metric_temp=copy.copy(metric)

    if metric_temp["estimated_name"]==shapley_regression_name and\
       metric_temp["true_name"]==shapley_regression_name:
        metric_temp.update(
            {"method_name": 'KernelSHAP',
             "method_type": 'KernelSHAP',
             "antithetical": False,
             "split": "train",
            }
        )

    elif metric_temp["estimated_name"]=="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train" and\
       metric_temp["true_name"]==shapley_regression_name:
        # Antithetical reference is deliberately excluded.
        continue

    else:
        raise RuntimeError(f"Unrecognized estimator/ground-truth pair: {metric_temp}")

    metric_list_plot_reference.append(metric_temp)

In [None]:
# Label the per-example (non-amortized) Shapley estimates on the train split so
# they can be joined with the amortized-explainer metrics later in this cell.
# Only KernelSHAP, permutation sampling and SGD-Shapley are kept.
metric_list_plot_target=[]
for metric in metric_list_value_shapley:
    metric_temp=copy.copy(metric)
    estimated_name=metric_temp["estimated_name"]
    against_regression=(metric_temp["true_name"]=="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train")

    if against_regression and estimated_name=="logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train":
        metric_temp.update(
            {"method_name": f'KernelSHAP ({metric_temp["num_subsets"]})',
             "method_type": 'KernelSHAP',
             "antithetical": False,
             "split": "train"
            }
        )

    elif against_regression and estimated_name=="logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train":
        continue

    elif against_regression and estimated_name=="logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation/extract_output/train":
        metric_temp.update(
            {"method_name": f'Permutation ({metric_temp["num_subsets"]})',
             "method_type": 'Permutation',
             "antithetical": False,
             "split": "train"
            }
        )

    elif estimated_name in ("logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_antithetical/extract_output/train",
                            "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_newsample_196/extract_output/train"):
        # Skipped regardless of ground truth. The earlier labeling cell pairs
        # these estimates with the plain regression run, so requiring the
        # antithetical run here made them fall through to the error branch.
        continue

    elif against_regression and estimated_name=="logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train":
        metric_temp.update(
            {"method_name": f'SGD-Shapley ({metric_temp["num_subsets"]})',
             "method_type": 'SGD-Shapley',
             "antithetical": True,
             "split": "train"
            }
        )

    else:
        raise RuntimeError(f"Unrecognized estimator/ground-truth pair: {metric_temp}")

    metric_list_plot_target.append(metric_temp)
    
def _split_from_true_name(true_name):
    """Return "train" or "test" for a KernelSHAP-regression ground-truth path.

    Raises:
        RuntimeError: if ``true_name`` is neither regression ground-truth path.
    """
    if true_name=="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train":
        return "train"
    if true_name=="logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test":
        return "test"
    raise RuntimeError(f"Unexpected true_name: {true_name}")


_regression_true_names=["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train",
                        "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test",
                       ]
_regression_antithetical_true_names=["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
                                     "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test",
                                    ]

# Label the amortized (Reg-AO) explainers trained on KernelSHAP, permutation
# and SGD-Shapley targets. Every other known explainer/ground-truth pair is
# skipped, and unknown pairs raise.
metric_list_plot_explainer=[]
for metric in metric_list_shapley:
    metric_temp=copy.copy(metric)
    explainer_path=metric_temp["model_path"]
    true_name=metric_temp["true_name"]

    if explainer_path in [f"logs/vitbase_imagenette_shapley_regexplainer_upfront_{n}" for n in [512, 1024, 2048, 3072]] and\
       true_name in _regression_true_names:
        explainer_label, explainer_method_type, explainer_antithetical="KernelSHAP", "KernelSHAP", False

    elif explainer_path in [f"logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_{n}" for n in [196, 392, 588, 1176, 3136]] and\
       true_name in _regression_true_names:
        explainer_label, explainer_method_type, explainer_antithetical="Permutation", "Permutation", False

    elif explainer_path in [f"logs/vitbase_imagenette_shapley_regexplainer_permutation_newsample_{n}" for n in [196, 392, 588, 1176, 3136]] and\
       true_name in _regression_antithetical_true_names:
        continue

    # NOTE(review): absolute, machine-specific path; move it under logs/ (or a
    # configurable root) so the notebook runs elsewhere.
    elif explainer_path in [f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_shapley_regexplainer_SGD_antithetical_upfront_{n}" for n in [9986]] and\
       true_name in _regression_true_names:
        explainer_label, explainer_method_type, explainer_antithetical="SGD-Shapley", "SGD-Shapley", True

    elif explainer_path in ["logs/vitbase_imagenette_shapley_regexplainer_upfront_1024_numtrain_4735",
                            "logs/vitbase_imagenette_shapley_regexplainer_upfront_2048_numtrain_2367",
                            "logs/vitbase_imagenette_shapley_regexplainer_upfront_3072_numtrain_1578"] and\
       true_name in _regression_true_names:
        continue

    elif explainer_path in [f"logs/vitbase_imagenette_shapley_regexplainer_antithetical_upfront_{n}" for n in [512, 1024, 2048, 3072]] and\
       true_name in _regression_true_names:
        continue

    elif explainer_path in [f"logs/vitbase_imagenette_shapley_objexplainer_newsample_{n}" for n in [32]] and\
       true_name in _regression_antithetical_true_names:
        continue

    elif explainer_path in [f"logs/vitbase_imagenette_shapley_objexplainer_antithetical_newsample_{n}" for n in [32]] and\
       true_name in _regression_antithetical_true_names:
        continue

    else:
        raise RuntimeError(f"Unrecognized explainer/ground-truth pair: {metric_temp}")

    # The sample budget is the trailing "_<int>" of the explainer's path.
    num_subsets=int(explainer_path.split('_')[-1])
    split=_split_from_true_name(true_name)

    metric_temp.update(
        {"method_name": f'Reg-AO ({explainer_label}, {num_subsets})',
         "method_type": explainer_method_type,
         "antithetical": explainer_antithetical,
         "num_subsets": num_subsets,
         "split": split,
        }
    )
    metric_list_plot_explainer.append(metric_temp)


# Keep only the best checkpoint of each explainer.
# NOTE(review): SGD-Shapley explainers are forced to "best" at epoch 20 —
# presumably no best-checkpoint marker exists for them; confirm.
metric_list_plot_explainer_df=pd.DataFrame(metric_list_plot_explainer)
is_sgd_epoch20=((metric_list_plot_explainer_df['method_type']=="SGD-Shapley")&
                (metric_list_plot_explainer_df['epoch']==20))
# .loc writes into the frame itself; the chained form `df[col][mask]=...` is
# silently discarded under pandas Copy-on-Write.
metric_list_plot_explainer_df.loc[is_sgd_epoch20, "is_best_checkpoint"]="best"
metric_list_plot_explainer_df=metric_list_plot_explainer_df[metric_list_plot_explainer_df["is_best_checkpoint"]=="best"]


# The per-example targets are labeled "train" only. Duplicate them as a "test"
# copy whose sample_idx is renumbered 0..99, following the first 100 entries
# of a seed-42 permutation of range(9469).
# NOTE(review): presumably this mirrors how the test samples were selected and
# 9469 is the train-set size — confirm. Any target sample_idx outside those
# 100 indices raises KeyError below.
metric_list_plot_target_df=pd.DataFrame(metric_list_plot_target)
metric_list_plot_target_df_=metric_list_plot_target_df.copy()
metric_list_plot_target_df_["split"]="test"
idx_mapping=dict(zip(np.random.RandomState(seed=42).permutation(list(range(9469)))[:100],
                     list(range(100))))
metric_list_plot_target_df_["sample_idx"]=metric_list_plot_target_df_["sample_idx"].map(lambda x: idx_mapping[x])
metric_list_plot_target_df=pd.concat([metric_list_plot_target_df, metric_list_plot_target_df_])

# Pair each explainer row with the per-example estimate of the same method,
# sample, sample budget and split; overlapping metric columns are suffixed.
metric_list_plot_df=metric_list_plot_explainer_df.merge(
    right=metric_list_plot_target_df,
    on=["method_type", "sample_idx", "num_subsets", "split"],
    suffixes=('_explainer', '_target'),
)

# Side-by-side explainer vs. per-example-estimator metrics on the train split.
metric_list_plot_df[metric_list_plot_df["split"]=="train"].groupby(["method_type", "num_subsets"])\
[['sample_idx', "num_subsets",
'mse_target_explainer', 'mse_target_target',
'mse_nontarget_explainer', 'mse_nontarget_target',
'mse_all_explainer', 'mse_all_target',
'pearsonr_target_explainer', 'pearsonr_target_target',
'pearsonr_all_explainer', 'pearsonr_all_target',
'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
'spearmanr_target_explainer', 'spearmanr_target_target',
'spearmanr_all_explainer', 'spearmanr_all_target',
'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
"sign_agreement_all_explainer", "sign_agreement_all_target"]].mean()

In [None]:
# Per-example (non-amortized) estimator metrics, labeled "train", as a DataFrame for ad-hoc inspection.
df_temp=pd.DataFrame(metric_list_plot_target)

In [None]:
# NOTE(review): `method_type` is not assigned at module level before this cell;
# it is presumably left over from a loop run earlier in the session, so this
# fails on a fresh kernel. Assign it explicitly before filtering.
metric_list_plot_df[metric_list_plot_df["method_type"]==method_type]

In [None]:
# NOTE(review): duplicate of the previous cell; same reliance on a leftover
# `method_type` global. Consider deleting one of the two.
metric_list_plot_df[metric_list_plot_df["method_type"]==method_type]#["method_name"].value_counts()

In [None]:
# Mean of one metric per KernelSHAP reference budget, ordered by budget.
# NOTE(review): `metric_name` is not defined at module level in this section;
# it is presumably left over from an interactive loop — set it explicitly.
pd.DataFrame(metric_list_plot_reference)\
.groupby(["method_name","num_subsets"])[metric_name].mean().reset_index().sort_values("num_subsets")    

In [None]:
# Sample the Blues scale (normalized over [-500, 3136]) at value 22.
# The old form `Blue_scalar_color_mapping(i, ...)(22)` called the returned RGBA
# tuple (TypeError) and read an undefined leftover `i`.
# NOTE: `Blue_scalar_color_mapping` is defined in the next cell; run that first.
Blue_scalar_color_mapping(22, color_map=plt.cm.Blues, data_min=-500, data_max=3136)

In [None]:
import matplotlib as mpl
import matplotlib.colors as mcolors


def Blue_scalar_color_mapping(value, color_map="Blues", data_min=-500, data_max=3136):
    """Map a scalar to an RGBA color on a linearly normalized colormap.

    Used in this notebook to color results by ``num_subsets``.

    Args:
        value: Scalar to color.
        color_map: Matplotlib colormap instance or registered colormap name.
            Defaults to ``"Blues"``, which every visible call site passes explicitly.
        data_min: Value mapped to the low end of the colormap. The default (-500) is the
            value used at the visible call sites; presumably chosen below the smallest
            ``num_subsets`` so small values are not drawn near-white — TODO confirm.
        data_max: Value mapped to the high end of the colormap (default 3136, as used at
            the visible call sites).

    Returns:
        The RGBA tuple produced by ``ScalarMappable.to_rgba`` for ``value``.
    """
    norm = mcolors.Normalize(vmin=data_min, vmax=data_max)
    scalar_map = plt.cm.ScalarMappable(norm=norm, cmap=color_map)
    return scalar_map.to_rgba(value)

# Comparison of the estimation accuracy between KernelSHAP and amortized predictions

In [None]:
# Figure layout: one column per agreement metric, one row per `method_type`.
# Each panel overlays
#   * the reference results (`metric_list_plot_reference`) as a line over `num_subsets`, and
#   * the explainer's mean metric for this `method_type`, drawn as one horizontal line per
#     `num_subsets` value (darker blue = larger value).
metric_order = ["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]
method_type_order = ["KernelSHAP", "Permutation", "SGD-Shapley"]

# Per-metric panel settings; everything not listed here is shared by all panels.
panel_specs = {
    "mse_all": dict(title="Error", ylabel="Error", ylim=(0.0001, 0.015),
                    log_y=True, minor_grid=False, keep_legend=True),
    "pearsonr_all": dict(title="Correlation", ylabel="Pearson Correlation", ylim=(0, 1.05),
                         log_y=False, minor_grid=True, keep_legend=False),
    "spearmanr_all": dict(title="Rank correlation", ylabel="Spearman Correlation", ylim=(0, 1.05),
                          log_y=False, minor_grid=True, keep_legend=False),
    "sign_agreement_all": dict(title="Sign Agreement", ylabel="Sign Agreement", ylim=(0, 1.05),
                               log_y=False, minor_grid=True, keep_legend=False),
}
reference_line_color = (0.8666666666666667, 0.5176470588235295, 0.3215686274509804)
# Built once; the same reference frame is drawn in every panel.
reference_plot_df = pd.DataFrame(metric_list_plot_reference)

fig = plt.figure(figsize=(4*4+3*0.3, 3*4+2*0.8))
box1 = gridspec.GridSpec(1, 4, hspace=0.3)

axd = {}
for idx1, metric in enumerate(metric_order):
    # The 3-row sub-grid depends only on the metric column, so build it once per column.
    box2 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.4)
    for idx2, method_type in enumerate(method_type_order):
        axd[(metric, method_type)] = fig.add_subplot(box2[idx2])

for metric in metric_order:
    spec = panel_specs[metric]
    for method_type in method_type_order:
        plot_key = (metric, method_type)
        ax = axd[plot_key]

        # Reference results: metric vs. num_subsets, no error band.
        sns.lineplot(
            y=metric,
            x="num_subsets",
            hue="method_type",
            palette=[reference_line_color],
            markeredgecolor=None,
            errorbar=None,
            data=reference_plot_df,
            ax=ax,
        )

        # Explainer: mean metric per (method_name, num_subsets) for this method_type.
        metric_list_plot_explainer_df_mean = (
            metric_list_plot_explainer_df[metric_list_plot_explainer_df["method_type"] == method_type]
            .groupby(["method_name", "num_subsets"])[metric].mean()
            .reset_index()
            .sort_values("num_subsets")
            .set_index("num_subsets")
        )
        for num_subsets in metric_list_plot_explainer_df_mean.index:
            row = metric_list_plot_explainer_df_mean.loc[num_subsets]
            ax.hlines(
                xmin=0, xmax=1000000,
                y=row[metric], linewidth=2,
                color=Blue_scalar_color_mapping(num_subsets, color_map=plt.cm.Blues, data_min=-500, data_max=3136),
                label=f'{num_subsets}',
            )

        ax.set_title(spec["title"])
        ax.set_xlabel("# Samples / Point (KernelSHAP)")
        ax.set_ylabel(spec["ylabel"])

        # x axis: sample counts with thousands separators. The formatter is built per
        # panel because matplotlib formatters must not be shared between axes.
        ax.xaxis.set_major_locator(MultipleLocator(50000))
        ax.get_xaxis().set_major_formatter(
            mpl.ticker.FuncFormatter(lambda x, p: format(int(x), ',')))
        ax.xaxis.grid(True, which='major')
        if spec["minor_grid"]:
            ax.xaxis.grid(True, which='minor')
        ax.set_xlim(0, 110001)

        # y axis
        ax.yaxis.set_major_locator(MultipleLocator(0.2))
        ax.yaxis.grid(True, which='major')
        if spec["minor_grid"]:
            ax.yaxis.grid(True, which='minor')
        ax.set_ylim(*spec["ylim"])
        if spec["log_y"]:
            # Changing the scale resets the tick locators to the log-scale defaults.
            ax.set_yscale("log")

        ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
        ax.tick_params(axis='y', which='major', rotation=0)

        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        # Only the MSE panels keep a legend.
        leg = ax.legend(loc='best')
        if not spec["keep_legend"]:
            leg.remove()

# Comparison of the estimation error between noisy labels and amortized predictions for Shapley value feature attributions.

In [None]:
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor'] = '0.8'
plt.rcParams['legend.framealpha'] = 0.8

# Training-split scatter: for each method_type and num_subsets, the mean metric of the
# prediction (x, `*_explainer` columns) against the mean metric of the label
# (y, `*_target` columns). Points on the dashed y = x line score equally.
metric_order = ["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]
method_type_order = ["KernelSHAP", "Permutation", "SGD-Shapley"]
summary_metric_columns = [
    'sample_idx',
    'mse_target_explainer', 'mse_target_target',
    'mse_nontarget_explainer', 'mse_nontarget_target',
    'mse_all_explainer', 'mse_all_target',
    'pearsonr_target_explainer', 'pearsonr_target_target',
    'pearsonr_all_explainer', 'pearsonr_all_target',
    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
    'spearmanr_target_explainer', 'spearmanr_target_target',
    'spearmanr_all_explainer', 'spearmanr_all_target',
    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
    "sign_agreement_all_explainer", "sign_agreement_all_target",
]
# Per metric: axis-label stem, shared x/y limits, log-log axes, and whether the panel keeps a legend.
scatter_panel_specs = {
    "mse_all": dict(label="Error", lim=(1e-3, 1.1), log=True, keep_legend=True),
    "pearsonr_all": dict(label="Pearson Corr.", lim=(0, 1.01), log=False, keep_legend=False),
    "spearmanr_all": dict(label="Spearman Corr.", lim=(0, 1.01), log=False, keep_legend=False),
    "sign_agreement_all": dict(label="Sign Agreement", lim=(0, 1.01), log=False, keep_legend=False),
}

# The per-method aggregate does not depend on the metric, so compute it once per method_type.
train_means_by_method = {
    mt: (
        metric_list_plot_df[(metric_list_plot_df["split"] == "train")
                            & (metric_list_plot_df["method_type"] == mt)]
        .groupby(["method_type", "num_subsets"])[summary_metric_columns]
        .mean()
        .reset_index()
    )
    for mt in method_type_order
}

fig = plt.figure(figsize=(4*4+3*0.3, 3*4+2*0.8))
box1 = gridspec.GridSpec(1, 4, hspace=0.3)

axd = {}
for idx1, metric_name in enumerate(metric_order):
    # The 3-row sub-grid depends only on the metric column, so build it once per column.
    box2 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.4)
    for idx2, method_type in enumerate(method_type_order):
        axd[(metric_name, method_type)] = fig.add_subplot(box2[idx2])

for metric_name in metric_order:
    spec = scatter_panel_specs[metric_name]
    for method_type in method_type_order:
        metric_list_plot_df_select = train_means_by_method[method_type]
        plot_key = (metric_name, method_type)
        ax = axd[plot_key]

        sns.scatterplot(
            x=metric_name+"_explainer",
            y=metric_name+"_target",
            hue="num_subsets",
            data=metric_list_plot_df_select,
            s=200,
            # One color per row; rows come out of the groupby sorted by num_subsets.
            palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136)
                     for i in metric_list_plot_df_select["num_subsets"]],
            alpha=0.9,
            ax=ax,
        )
        ax.plot(np.linspace(0, 1, 100), np.linspace(0, 1, 100), color="grey", alpha=0.5, linestyle='--')

        ax.set_ylabel(f"{spec['label']} (Label)")
        ax.set_xlabel(f"{spec['label']} (Prediction)")

        ax.xaxis.grid(True, which='major')
        ax.set_xlim(*spec["lim"])
        if spec["log"]:
            ax.set_xscale("log")

        ax.yaxis.grid(True, which='major')
        ax.set_ylim(*spec["lim"])
        if spec["log"]:
            ax.set_yscale("log")

        ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
        ax.tick_params(axis='y', which='major', rotation=0)

        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        ax.set_title(f"{prettify_metric_name(metric_name)} ({prettify_method_type(method_type)})")

        if not spec["keep_legend"]:
            # Drop the legend seaborn attached to this panel.
            ax.legend().remove()
        else:
            leg = ax.legend(loc='lower right', bbox_to_anchor=(0.97, 0.03))
            for legend_text in leg.get_texts():
                label = legend_text.get_text()
                try:
                    label_as_int = int(label)
                except ValueError:
                    # Non-integer entries are relabelled as "<method> (<last 5 chars>)".
                    legend_text.set_text(f"{prettify_method_type(method_type)} ({label[-5:]})")
                else:
                    legend_text.set_text(f"{label_as_int}")
            
            

In [None]:
# Save the figure from the previous cell as PNG and PDF; create the output directory if
# it is missing, since savefig does not create parent directories.
plot_dir = os.path.join("logs", "plots")
os.makedirs(plot_dir, exist_ok=True)
for extension in ("png", "pdf"):
    fig.savefig(os.path.join(plot_dir, f"training_target_prediction_shapley_appendix.{extension}"),
                bbox_inches='tight')

In [None]:
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor'] = '0.8'
plt.rcParams['legend.framealpha'] = 0.8

fig = plt.figure(figsize=(4*4+3*0.3, 3*4+2*0.8))
box1 = gridspec.GridSpec(1, 4, hspace=0.3)

# Panel grid: 4 metric columns x 3 method_type rows, keyed by (metric_name, method_type).
axd = {}
for idx1, metric_name in enumerate(["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]):
    # The 3-row sub-grid depends only on the metric column, so build it once per column.
    box2 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.4)
    for idx2, method_type in enumerate(["KernelSHAP", "Permutation", "SGD-Shapley"]):
        ax = fig.add_subplot(box2[idx2])
        plot_key = (metric_name, method_type)
        axd[plot_key] = ax
        

for idx1, metric_name in enumerate(["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all",
                              ]):
    for idx2, method_type in enumerate(["KernelSHAP", "Permutation", "SGD-Shapley"]):

        metric_list_plot_df_select=metric_list_plot_df[(metric_list_plot_df["split"]=="test")&(metric_list_plot_df["method_type"]==method_type)].groupby(["method_type", "num_subsets"])\
                                    [['sample_idx', # "num_subsets", 
                                    'mse_target_explainer', 'mse_target_target',
                                    'mse_nontarget_explainer', 'mse_nontarget_target',
                                    'mse_all_explainer', 'mse_all_target',
                                    'pearsonr_target_explainer', 'pearsonr_target_target', 
                                    'pearsonr_all_explainer', 'pearsonr_all_target',
                                    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target', 
                                    'spearmanr_target_explainer', 'spearmanr_target_target',
                                    'spearmanr_all_explainer', 'spearmanr_all_target',
                                    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
                                     "sign_agreement_all_explainer", "sign_agreement_all_target"]].mean().reset_index()          

        plot_key=(metric_name, method_type)
        
        if metric_name=="mse_all":
            
            
            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136) 
                         for i in metric_list_plot_df_select["num_subsets"]],                
                alpha=0.9,
                ax=axd[plot_key],
            )
            
            

            
            reference_df=pd.DataFrame(metric_list_plot_reference).groupby(["method_name","num_subsets"])[metric_name].mean().reset_index().sort_values("num_subsets")    
            reference_df_idx_list=[]              

    
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')
        
        
        


            axd[plot_key].set_ylabel("Error (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Error (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor')#, linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(left=1e-3, right=1.1)
            axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(1e-3, 1.1)
            axd[plot_key].set_yscale("log")

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            
            axd[plot_key].get_legend().remove()

            for line in leg.get_lines():
                line.set_linewidth(3.0) 
                
            #axd[plot_key].set_title(prettify_method_type(method_type)+ " - " + prettify_metric_name(metric_name))#, fontsize=20)
            axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_method_type(method_type)})")#, fontsize=20)
            
            
            
            leg=axd[plot_key].legend(loc='lower right', bbox_to_anchor=(0.97, 0.03))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            #leg.set_title("# Samples / Point")   
            
            for legend_text in leg.get_texts():
                try:
                    int(legend_text.get_text())
                except:
                    legend_text.set_text(f"{prettify_method_type(method_type)} ({legend_text.get_text()[-5:]})")         
                else:
                    legend_text.set_text(f"{int(legend_text.get_text())}")
                        

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
            #leg.remove()


        elif metric_name=="pearsonr_all":
            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136) 
                         for i in metric_list_plot_df_select["num_subsets"]],                
                alpha=0.9,
                ax=axd[plot_key],
            )
            
            reference_df=pd.DataFrame(metric_list_plot_reference).groupby(["method_name","num_subsets"])[metric_name].mean().reset_index().sort_values("num_subsets")    
            reference_df_idx_list=[]
            
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')


            axd[plot_key].set_ylabel("Pearson Corr. (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Pearson Corr. (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0,1.01)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            #axd[plot_key].set_yscale("log")
            axd[plot_key].set_ylim(0,1.01)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
                
            axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_method_type(method_type)})")#, fontsize=20)
            
            #axd[plot_key].text(x=1.1, y=1.1, s=method_type,  ha='center')
            


        elif metric_name=="spearmanr_all":

            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136) 
                         for i in metric_list_plot_df_select["num_subsets"]],                
                alpha=0.9,
                ax=axd[plot_key],
            )
                           
    
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')


            axd[plot_key].set_ylabel("Spearman Corr. (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Spearman Corr. (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0,1.01)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            #axd[plot_key].set_yscale("log")
            axd[plot_key].set_ylim(0,1.01)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 
            
            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()            
            
            axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_method_type(method_type)})")#, fontsize=20)
            
        elif metric_name=="sign_agreement_all":

            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136) 
                         for i in metric_list_plot_df_select["num_subsets"]],                
                alpha=0.9,
                ax=axd[plot_key],
            )
                          
    
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')


            axd[plot_key].set_ylabel("Sign Agreement (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Sign Agreement (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0,1.01)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            #axd[plot_key].set_yscale("log")
            axd[plot_key].set_ylim(0,1.01)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False)             

            leg=axd[plot_key].legend(loc='center', bbox_to_anchor=(1.1, 0, 0.5, 1))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            leg.remove()

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
            #leg.remove()
    
            axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_method_type(method_type)})")#, fontsize=20)
            
            

In [None]:
# Export the current figure both as a raster (PNG) and a vector (PDF) file.
for output_path in ("logs/plots/training_target_prediction_shapley_external_appendix.png",
                    "logs/plots/training_target_prediction_shapley_external_appendix.pdf"):
    fig.savefig(output_path, bbox_inches='tight')

In [None]:
# Compact 1x4 figure (train split): for KernelSHAP and Permutation, the error and
# the Pearson correlation computed against the labels (y axis, `*_target`
# columns) versus against the amortized explainer's predictions (x axis,
# `*_explainer` columns); one point per `num_subsets` setting.
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor'] = '0.8'
plt.rcParams['legend.framealpha'] = 0.8

two_panel_metric_names = ["mse_all", "pearsonr_all"]
two_panel_method_types = ["KernelSHAP", "Permutation"]

# Columns averaged per (method_type, num_subsets) group before plotting.
metric_summary_columns = [
    'sample_idx',
    'mse_target_explainer', 'mse_target_target',
    'mse_nontarget_explainer', 'mse_nontarget_target',
    'mse_all_explainer', 'mse_all_target',
    'pearsonr_target_explainer', 'pearsonr_target_target',
    'pearsonr_all_explainer', 'pearsonr_all_target',
    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
    'spearmanr_target_explainer', 'spearmanr_target_target',
    'spearmanr_all_explainer', 'spearmanr_all_target',
    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
    "sign_agreement_all_explainer", "sign_agreement_all_target",
]

# Axis-label stem per metric; "(Label)" / "(Prediction)" is appended per axis.
axis_label_by_metric = {
    "mse_all": "Error",
    "pearsonr_all": "Pearson Corr.",
    "spearmanr_all": "Spearman Corr.",
    "sign_agreement_all": "Sign Agreement",
}

fig = plt.figure(figsize=(4*4 + 3*0.2, 1*3))
box1 = gridspec.GridSpec(1, 4, wspace=0.3)

# Column order: KernelSHAP error, KernelSHAP Pearson, Permutation error, Permutation Pearson.
axd = {}
for idx1, method_type in enumerate(two_panel_method_types):
    for idx2, metric_name in enumerate(two_panel_metric_names):
        axd[(metric_name, method_type)] = fig.add_subplot(box1[idx1*2 + idx2])

for metric_name in two_panel_metric_names:
    axis_label = axis_label_by_metric[metric_name]
    for method_type in two_panel_method_types:
        plot_key = (metric_name, method_type)
        ax = axd[plot_key]

        is_train_split = metric_list_plot_df["split"] == "train"
        is_this_method = metric_list_plot_df["method_type"] == method_type
        metric_list_plot_df_select = (
            metric_list_plot_df[is_train_split & is_this_method]
            .groupby(["method_type", "num_subsets"])[metric_summary_columns]
            .mean()
            .reset_index()
        )

        sns.scatterplot(
            x=metric_name + "_explainer",
            y=metric_name + "_target",
            hue="num_subsets",
            data=metric_list_plot_df_select,
            s=200,
            # One blue shade per num_subsets value via Blue_scalar_color_mapping.
            palette=[
                Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136)
                for i in metric_list_plot_df_select["num_subsets"]
            ],
            alpha=0.9,
            ax=ax,
        )

        # y = x reference line: points on it score the same against labels and predictions.
        ax.plot(np.linspace(0, 1, 100), np.linspace(0, 1, 100), color="grey", alpha=0.5, linestyle='--')

        ax.set_ylabel(f"{axis_label} (Label)")
        ax.set_xlabel(f"{axis_label} (Prediction)")

        ax.xaxis.grid(True, which='major')
        ax.yaxis.grid(True, which='major')
        if metric_name == "mse_all":
            # Error panels use log-log axes over [1e-3, 1.1].
            ax.set_xlim(left=1e-3, right=1.1)
            ax.set_xscale("log")
            ax.set_ylim(1e-3, 1.1)
            ax.set_yscale("log")
        else:
            ax.set_xlim(0, 1.01)
            ax.set_ylim(0, 1.01)

        ax.tick_params(axis='x', which='major', rotation=0)
        ax.tick_params(axis='y', which='major', rotation=0)

        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        # Panels are titled by estimator only; the axis labels name the metric.
        ax.set_title(f"{prettify_method_type(method_type)}")

        if metric_name == "mse_all":
            # Only the error panels keep a legend; entries are the num_subsets levels.
            leg = ax.legend(loc='lower right', bbox_to_anchor=(0.97, 0.03))
            for legend_text in leg.get_texts():
                raw_text = legend_text.get_text()
                try:
                    subset_count = int(raw_text)
                except ValueError:
                    # Non-integer entry: prefix with the estimator name, keep the last 5 chars.
                    legend_text.set_text(f"{prettify_method_type(method_type)} ({raw_text[-5:]})")
                else:
                    legend_text.set_text(f"{subset_count}")
        else:
            # Replace seaborn's automatic legend and drop it: these panels have no legend.
            ax.legend().remove()

In [None]:
# Export the current figure both as a raster (PNG) and a vector (PDF) file.
for output_path in ("logs/plots/training_target_prediction_shapley_two.png",
                    "logs/plots/training_target_prediction_shapley_two.pdf"):
    fig.savefig(output_path, bbox_inches='tight')

# Comparison of the estimation error between noisy labels and amortized predictions for Banzhaf value feature attributions. 

In [None]:
# Reference series for Banzhaf: the long train sampling run is used both as the
# ground truth (at 1,000,000 iterations) and as the estimate (at `num_subsets`
# iterations). Presumably this measures how the same sampler converges — confirm
# against get_ground_truth_metric_with_value.
banzhaf_long_train_name = "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"
banzhaf_num_subsets_grid = [500 * multiplier
                            for multiplier in (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 40, 80, 100, 200, 400, 800, 1000)]

metric_list_ground_truth_banzahf = []
for num_subsets in banzhaf_num_subsets_grid:
    metric_list_ground_truth_banzahf.extend(
        get_ground_truth_metric_with_value(
            attribution_values_ground_truth=banzhaf_loaded_dict[banzhaf_long_train_name],
            iters_ground_truth=1000000,
            attribution_values_calculated=banzhaf_loaded_dict[banzhaf_long_train_name],
            iters_calculated=num_subsets,
            meta_info={
                "num_subsets": num_subsets,
                "true_name": banzhaf_long_train_name,
                "estimated_name": banzhaf_long_train_name,
            },
        )
    )

In [None]:
# Tag every Banzhaf ground-truth reference metric with plotting metadata.
# Only entries that compare the long train sampling run against itself are
# expected here; anything else indicates a mismatched input list and aborts.
banzhaf_long_train_name = "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"

metric_list_plot_reference = []
for metric in metric_list_ground_truth_banzahf:
    # Shallow copy so the source metric dicts are not modified by update().
    metric_temp = copy.copy(metric)

    is_expected_pair = (metric_temp["estimated_name"] == banzhaf_long_train_name
                        and metric_temp["true_name"] == banzhaf_long_train_name)
    if not is_expected_pair:
        # Previously `raise RuntimError()` (typo), which surfaced as a NameError.
        raise RuntimeError(f"Unexpected Banzhaf reference metric entry: {metric_temp}")

    metric_temp.update(
        {"method_name": 'BanzhafMSR',
         "method_type": 'BanzhafMSR',
         "antithetical": False,
         "split": "train",
        }
    )
    metric_list_plot_reference.append(metric_temp)

    

# Tag each Banzhaf sampling-estimate metric with plotting metadata. The split is
# determined by the (estimated_name, true_name) pair; unknown pairs abort.
banzhaf_target_split_by_names = {
    ("logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train",
     "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"): "train",
    ("logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_short/extract_output/train",
     "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"): "train",
    ("logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test",
     "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test"): "test",
}
# The long train run compared with itself is not a target entry (that pair is
# what metric_list_plot_reference is built from), so it is skipped here.
banzhaf_target_skipped_names = {
    ("logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train",
     "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"),
}

metric_list_plot_target = []
for metric in metric_list_value_banzhaf:
    # Shallow copy so the source metric dicts are not modified by update().
    metric_temp = copy.copy(metric)
    name_pair = (metric_temp["estimated_name"], metric_temp["true_name"])

    if name_pair in banzhaf_target_skipped_names:
        continue
    if name_pair not in banzhaf_target_split_by_names:
        # Previously `raise RuntimError()` (typo), which surfaced as a NameError.
        raise RuntimeError(f"Unexpected Banzhaf estimate metric entry: {metric_temp}")

    metric_temp.update(
        {"method_name": f'BanzhafMSR ({metric_temp["num_subsets"]})',
         "method_type": 'BanzhafMSR',
         "antithetical": False,
         "split": banzhaf_target_split_by_names[name_pair],
        }
    )
    metric_list_plot_target.append(metric_temp)
    
metric_list_plot_explainer=[]    
for metric in metric_list_banzhaf:
    metric_temp=copy.copy(metric)
    
    if metric_temp["model_path"] in [f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_banzhaf_regexplainer_upfront_global_{num_subsets}" for num_subsets in [10, 100, 500]] and\
       metric_temp["true_name"] in ["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train",
                                    "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test",
                                   ]:
        num_subsets=int(metric_temp["model_path"].split('_')[-1])
        
        if metric_temp["true_name"]=="logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train":
            split="train"
        elif metric_temp["true_name"]=="logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test":
            split="test"
        else:
            print(metric_temp)
            raise RuntimError()

        metric_temp.update(
            {"method_name": f'Reg-AO (BanzhafMSR, global, {num_subsets})',
             "method_type": 'BanzhafMSR',
             "transform_mode": "global",
             "antithetical": False,
             "num_subsets":num_subsets,             
             "split": split,
            }
        )  
        
    elif metric_temp["model_path"] in [f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_banzhaf_regexplainer_upfront_sqrt_{num_subsets}" for num_subsets in [10, 100, 500]] and\
       metric_temp["true_name"] in ["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train",
                                    "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test",
                                   ]:
        num_subsets=int(metric_temp["model_path"].split('_')[-1])

        if metric_temp["true_name"]=="logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train":
            split="train"
        elif metric_temp["true_name"]=="logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test":
            split="test"
        else:
            print(metric_temp)
            raise RuntimError()

        metric_temp.update(
            {"method_name": f'Reg-AO (BanzhafMSR, sqrt, {num_subsets})',
             "method_type": 'BanzhafMSR',
             "transform_mode": "sqrt",
             "antithetical": False,
             "num_subsets":num_subsets,             
             "split": split,
            }
        )  
        
    elif metric_temp["model_path"] in [f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_banzhaf_regexplainer_upfront_perinstance_{num_subsets}" for num_subsets in [10, 100, 500]] and\
       metric_temp["true_name"] in ["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train",
                                    "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test",
                                   ]:
        num_subsets=int(metric_temp["model_path"].split('_')[-1])
        
        if metric_temp["true_name"]=="logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train":
            split="train"
        elif metric_temp["true_name"]=="logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test":
            split="test"
        else:
            print(metric_temp)
            raise RuntimError()

        metric_temp.update(
            {"method_name": f'Reg-AO (BanzhafMSR, perinstance, {num_subsets})',
             "method_type": 'BanzhafMSR',
             "transform_mode": "perinstance",
             "antithetical": False,
             "num_subsets":num_subsets,             
             "split": split,
            }
        )   
        
    elif metric_temp["model_path"] in [f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_banzhaf_regexplainer_upfront_perinstanceperclass_{num_subsets}" for num_subsets in [10, 100, 500]] and\
       metric_temp["true_name"] in ["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train",
                                    "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test",
                                   ]:
        num_subsets=int(metric_temp["model_path"].split('_')[-1])
        
        if metric_temp["true_name"]=="logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train":
            split="train"
        elif metric_temp["true_name"]=="logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test":
            split="test"
        else:
            print(metric_temp)
            raise RuntimError()

        metric_temp.update(
            {"method_name": f'Reg-AO (BanzhafMSR, perinstanceperclass, {num_subsets})',
             "method_type": 'BanzhafMSR',
             "transform_mode": "perinstanceperclass",
             "antithetical": False,
             "num_subsets":num_subsets,             
             "split": split,
            }
        )           

    else:
        print(metric_temp)
        raise RuntimError()        
        
    metric_list_plot_explainer.append(metric_temp)


metric_list_plot_explainer_df=pd.DataFrame(metric_list_plot_explainer)
metric_list_plot_explainer_df=metric_list_plot_explainer_df[(metric_list_plot_explainer_df["is_best_checkpoint"]=="best")
                                                           ]


metric_list_plot_target_df=pd.DataFrame(metric_list_plot_target)
metric_list_plot_target_df_=metric_list_plot_target_df.copy()
metric_list_plot_target_df_["split"]="test"
idx_mapping=dict(zip(np.random.RandomState(seed=42).permutation(list(range(9469)))[:100],
list(range(100))))
metric_list_plot_target_df_["sample_idx"]=metric_list_plot_target_df_["sample_idx"].map(lambda x: idx_mapping[x])
metric_list_plot_target_df=pd.concat([metric_list_plot_target_df, metric_list_plot_target_df_])

metric_list_plot_df=metric_list_plot_explainer_df.merge(right=metric_list_plot_target_df, 
                          left_on=["method_type", "sample_idx", "num_subsets", "split"],
                          right_on=["method_type", "sample_idx", "num_subsets", "split"],
                          suffixes=('_explainer', '_target')
                         )
# sdsd
# metric_list_plot_df[metric_list_plot_df["split"]=="train"].groupby(["method_type", "num_subsets"])\
# [['sample_idx',  "num_subsets", 
# 'mse_target_explainer', 'mse_nontarget_explainer', 'mse_all_explainer',  
# 'mse_target_target', 'mse_nontarget_target', 'mse_all_target']].mean().T

metric_list_plot_df[metric_list_plot_df["split"]=="train"].groupby(["method_type", "transform_mode" , "num_subsets"])\
[['sample_idx', "num_subsets", 
'mse_all_explainer', 'mse_all_target',
'pearsonr_all_explainer', 'pearsonr_all_target',
'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target', 
'spearmanr_all_explainer', 'spearmanr_all_target',
'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target']].mean()#.reset_index()

In [None]:
# Rows per split among the best-checkpoint explainer metrics (already a DataFrame; no re-wrapping needed).
metric_list_plot_explainer_df["split"].value_counts()

In [None]:
# Rows per split among the target (sampled-estimator) metrics, including the relabelled "test" copies.
metric_list_plot_target_df.loc[:, "split"].value_counts()

In [None]:
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor'] = '0.8'
plt.rcParams['legend.framealpha'] = 0.8

# Per-metric panel settings: axis-label prefix, extent of the dashed y = x reference line,
# shared x/y limits, axis scale (None = leave linear) and whether the first row shows a legend.
TARGET_PREDICTION_PANELS = {
    "mse_all": {"label": "Error", "diagonal_max": 10, "limits": (1e-5, 10.1), "scale": "log", "legend": True},
    "pearsonr_all": {"label": "Pearson Corr.", "diagonal_max": 1, "limits": (0, 1.01), "scale": None, "legend": False},
    "spearmanr_all": {"label": "Spearman Corr.", "diagonal_max": 1, "limits": (0, 1.01), "scale": None, "legend": False},
    "sign_agreement_all": {"label": "Sign Agreement", "diagonal_max": 1, "limits": (0, 1.01), "scale": None, "legend": False},
}
# One panel row per explainer transform mode.
TARGET_PREDICTION_TRANSFORM_MODES = ["global", "perinstanceperclass"]
# Columns averaged per (method_type, num_subsets) group before plotting.
TARGET_PREDICTION_COLUMNS = [
    'sample_idx',
    'mse_all_explainer', 'mse_all_target',
    'pearsonr_all_explainer', 'pearsonr_all_target',
    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
    'spearmanr_all_explainer', 'spearmanr_all_target',
    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
    "sign_agreement_all_explainer", "sign_agreement_all_target",
]


def _draw_target_vs_prediction_panel(ax, select, metric_name, transform_mode, panel, show_legend):
    """Draw one scatter of ``<metric>_explainer`` (x) against ``<metric>_target`` (y) onto ``ax``.

    Args:
        ax: Axes to draw on.
        select: group-mean frame with ``method_type``, ``num_subsets`` and the metric columns.
        metric_name: metric prefix, a key of ``TARGET_PREDICTION_PANELS``.
        transform_mode: explainer transform mode shown in this panel (used in the title).
        panel: the ``TARGET_PREDICTION_PANELS`` entry for ``metric_name``.
        show_legend: draw a num_subsets legend in the lower-right corner.
    """
    sns.scatterplot(
        x=metric_name + "_explainer",
        y=metric_name + "_target",
        hue="num_subsets",
        data=select,
        s=200,
        palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-200, data_max=500)
                 for i in select["num_subsets"]],
        alpha=0.9,
        ax=ax,
    )

    diagonal = np.linspace(0, panel["diagonal_max"], 100)
    ax.plot(diagonal, diagonal, color="grey", alpha=0.5, linestyle='--')

    ax.set_ylabel(f"{panel['label']} (Label)")
    ax.set_xlabel(f"{panel['label']} (Prediction)")

    # Limits are fixed before the scale so the log switch keeps them.
    low, high = panel["limits"]
    ax.xaxis.grid(True, which='major')
    ax.set_xlim(low, high)
    if panel["scale"] is not None:
        ax.set_xscale(panel["scale"])
    ax.yaxis.grid(True, which='major')
    ax.set_ylim(low, high)
    if panel["scale"] is not None:
        ax.set_yscale(panel["scale"])

    ax.tick_params(axis='x', which='major', rotation=0)
    ax.tick_params(axis='y', which='major', rotation=0)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # Drop seaborn's automatic hue legend; only the requested panel gets one back.
    seaborn_legend = ax.get_legend()
    if seaborn_legend is not None:
        seaborn_legend.remove()
    if show_legend:
        legend = ax.legend(loc='lower right', bbox_to_anchor=(0.97, 0.03))
        for legend_text in legend.get_texts():
            text = legend_text.get_text()
            try:
                legend_text.set_text(str(int(text)))
            except ValueError:
                # Non-numeric entry: label it with the method type(s) actually plotted here
                # (previously this read a `method_type` global left over from another cell).
                method_label = ", ".join(prettify_method_type(m) for m in select["method_type"].unique())
                legend_text.set_text(f"{method_label} ({text[-5:]})")

    metric_title = f"{prettify_metric_name(metric_name)}"
    mode_label = prettify_transform_mode(transform_mode)
    if mode_label == "":
        ax.set_title(metric_title)
    else:
        ax.set_title(f"{metric_title} ({mode_label})")


def plot_target_vs_prediction(plot_df, split):
    """Plot explainer-vs-target metric scatters for one split.

    One column per metric in ``TARGET_PREDICTION_PANELS``, one row per mode in
    ``TARGET_PREDICTION_TRANSFORM_MODES``. Each point is the mean of the ``plot_df`` rows of
    ``split`` sharing a (method_type, num_subsets) pair, coloured by ``num_subsets``.

    Args:
        plot_df: merged frame with ``split``, ``transform_mode``, ``method_type``, ``num_subsets``
            and the ``TARGET_PREDICTION_COLUMNS`` columns.
        split: value of the ``split`` column to plot (e.g. "train").

    Returns:
        (fig, axd): the figure and a dict mapping (metric_name, transform_mode) to its Axes.
    """
    n_columns = len(TARGET_PREDICTION_PANELS)
    fig = plt.figure(figsize=(4 * n_columns + (n_columns - 1) * 0.3, 3 * 4 + 2 * 0.8))
    box1 = gridspec.GridSpec(1, n_columns, hspace=0.3)

    # The group means depend only on the transform mode, not the metric: compute each once.
    in_split = plot_df[plot_df["split"] == split]
    selections = {
        transform_mode: (
            in_split[in_split["transform_mode"] == transform_mode]
            .groupby(["method_type", "num_subsets"])[TARGET_PREDICTION_COLUMNS]
            .mean()
            .reset_index()
        )
        for transform_mode in TARGET_PREDICTION_TRANSFORM_MODES
    }

    axd = {}
    for idx1, (metric_name, panel) in enumerate(TARGET_PREDICTION_PANELS.items()):
        # Each column is divided into 3 rows; the transform modes fill them from the top.
        box2 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.4)
        for idx2, transform_mode in enumerate(TARGET_PREDICTION_TRANSFORM_MODES):
            ax = fig.add_subplot(box2[idx2])
            axd[(metric_name, transform_mode)] = ax
            _draw_target_vs_prediction_panel(
                ax, selections[transform_mode], metric_name, transform_mode, panel,
                show_legend=panel["legend"] and idx2 == 0,
            )
    return fig, axd


fig, axd = plot_target_vs_prediction(metric_list_plot_df, split="train")
            
            

In [None]:
# Save `fig` as PNG and PDF; savefig does not create missing directories, so ensure it exists.
os.makedirs("logs/plots", exist_ok=True)
for plot_ext in ("png", "pdf"):
    fig.savefig(f"logs/plots/training_target_prediction_banzhaf.{plot_ext}", bbox_inches='tight')

In [None]:
# Legend frame styling for every panel of this figure.
plt.rcParams.update({
    'legend.fancybox': True,
    'legend.edgecolor': '0.8',
    'legend.framealpha': 0.8,
})

fig = plt.figure(figsize=(4 * 4 + 3 * 0.3, 3 * 4 + 2 * 0.8))

# One column per metric; each column is split into 3 rows, of which the first two
# hold the "global" and "perinstanceperclass" panels.
box1 = gridspec.GridSpec(1, 4, hspace=0.3)

axd = {}
for idx1, metric_name in enumerate(["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]):
    box2 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.4)
    for idx2, transform_mode in enumerate(["global", "perinstanceperclass"]):
        ax = fig.add_subplot(box2[idx2])
        plot_key = (metric_name, transform_mode)
        axd[plot_key] = ax
        

for idx1, metric_name in enumerate(["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all",
                              ]):
    for idx2, transform_mode in enumerate(["global", "perinstanceperclass"]):       
        
        metric_list_plot_df_select=metric_list_plot_df[(metric_list_plot_df["split"]=="test")&(metric_list_plot_df["transform_mode"]==transform_mode)].groupby(["method_type", "num_subsets"])\
                                    [['sample_idx', # "num_subsets", 
                                    'mse_all_explainer', 'mse_all_target',
                                    'pearsonr_all_explainer', 'pearsonr_all_target',
                                    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target', 
                                    'spearmanr_all_explainer', 'spearmanr_all_target',
                                    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
                                     "sign_agreement_all_explainer", "sign_agreement_all_target"]].mean().reset_index()                  

        plot_key=(metric_name, transform_mode)
        
        if metric_name=="mse_all":
            
            
            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-200, data_max=500) 
                         for i in metric_list_plot_df_select["num_subsets"]],                
                alpha=0.9,
                ax=axd[plot_key],
            )
            
            

            
            reference_df=pd.DataFrame(metric_list_plot_reference).groupby(["method_name","num_subsets"])[metric_name].mean().reset_index().sort_values("num_subsets")    
            reference_df_idx_list=[]               

    
            axd[plot_key].plot(np.linspace(0,10,100), np.linspace(0,10,100), color="grey", alpha=0.5, linestyle='--')
        
        
        


            axd[plot_key].set_ylabel("Error (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Error (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor')#, linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(left=1e-5, right=10.1)
            axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(1e-5, 10.1)
            axd[plot_key].set_yscale("log")

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            
            axd[plot_key].get_legend().remove()

            for line in leg.get_lines():
                line.set_linewidth(3.0) 
                
            if idx2==0:
                leg=axd[plot_key].legend(loc='lower right', bbox_to_anchor=(0.97, 0.03))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
                #leg.set_title("# Samples / Point")   

                for legend_text in leg.get_texts():
                    try:
                        int(legend_text.get_text())
                    except:
                        legend_text.set_text(f"{prettify_method_type(method_type)} ({legend_text.get_text()[-5:]})")         
                    else:
                        legend_text.set_text(f"{int(legend_text.get_text())}")


    #             for line in leg.get_lines():
    #                 line.set_linewidth(3.0) 
                #leg.remove()

                
            #axd[plot_key].set_title(prettify_method_type(method_type)+ " - " + prettify_metric_name(metric_name))#, fontsize=20)
            if prettify_transform_mode(transform_mode)=="":
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)}")#, fontsize=20)
            else:
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_transform_mode(transform_mode)})")#, fontsize=20)


        elif metric_name=="pearsonr_all":
            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-200, data_max=500) 
                         for i in metric_list_plot_df_select["num_subsets"]],                
                alpha=0.9,
                ax=axd[plot_key],
            )
            
            reference_df=pd.DataFrame(metric_list_plot_reference).groupby(["method_name","num_subsets"])[metric_name].mean().reset_index().sort_values("num_subsets")    
            reference_df_idx_list=[]
#             for i, metric_value in metric_list_plot_df_select[metric_name+"_explainer"].items():
#                 reference_df_idx_list.append((reference_df[metric_name]-metric_value)[(reference_df[metric_name]-metric_value)>0].idxmin())
#                 reference_df_idx_list.append((metric_value-reference_df[metric_name])[(metric_value-reference_df[metric_name])>0].idxmin())   
#             for idx, row in reference_df.iterrows():
#                 if row["num_subsets"] in [10240, 20480, 40960]:
#                     reference_df_idx_list.append(idx)                
#             count=0
#             for idx in sorted(list(set(reference_df_idx_list))):
#                 row=reference_df.loc[idx]
#                 axd[plot_key].vlines(ymin=0, ymax=1, 
#                                      x=row[metric_name], linewidth=2, color=plt.rcParams['axes.prop_cycle'].by_key()['color'][count],
#                                      label=f'{row["method_name"]} {row["num_subsets"]}')
#                 count+=1                              
    
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')


            axd[plot_key].set_ylabel("Pearson Corr. (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Pearson Corr. (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0,1.01)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            #axd[plot_key].set_yscale("log")
            axd[plot_key].set_ylim(0,1.01)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
                
            if prettify_transform_mode(transform_mode)=="":
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)}")#, fontsize=20)
            else:
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_transform_mode(transform_mode)})")#, fontsize=20)
            
            #axd[plot_key].text(x=1.1, y=1.1, s=method_type,  ha='center')
            


        elif metric_name=="spearmanr_all":

            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-200, data_max=500) 
                         for i in metric_list_plot_df_select["num_subsets"]],                
                alpha=0.9,
                ax=axd[plot_key],
            )
                           
    
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')


            axd[plot_key].set_ylabel("Spearman Corr. (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Spearman Corr. (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0,1.01)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            #axd[plot_key].set_yscale("log")
            axd[plot_key].set_ylim(0,1.01)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 
            
            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()            
            
            if prettify_transform_mode(transform_mode)=="":
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)}")#, fontsize=20)
            else:
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_transform_mode(transform_mode)})")#, fontsize=20)
            
        elif metric_name=="sign_agreement_all":

            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-200, data_max=500) 
                         for i in metric_list_plot_df_select["num_subsets"]],                
                alpha=0.9,
                ax=axd[plot_key],
            )
                            
    
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')


            axd[plot_key].set_ylabel("Sign Agreement (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Sign Agreement (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0,1.01)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            #axd[plot_key].set_yscale("log")
            axd[plot_key].set_ylim(0,1.01)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False)             

            leg=axd[plot_key].legend(loc='center', bbox_to_anchor=(1.1, 0, 0.5, 1))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            leg.set_title("# Samples / Point")   
            leg.remove()
    
            if prettify_transform_mode(transform_mode)=="":
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)}")#, fontsize=20)
            else:
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_transform_mode(transform_mode)})")#, fontsize=20)
            
            

In [None]:
# Export the current figure in both raster (PNG) and vector (PDF) form.
for extension in ("png", "pdf"):
    fig.savefig(f"logs/plots/training_target_prediction_banzhaf_external.{extension}", bbox_inches='tight')

# Comparison of the estimation error between noisy labels and amortized predictions for LIME feature attributions. 

In [None]:
# Reference metrics for LIME: the long regression run is scored against
# itself, using its 1,000,000-iteration estimate as ground truth and the
# estimate at `num_subsets` iterations as the calculated value
# (presumably a convergence curve; see get_ground_truth_metric_with_value).
LIME_TRAIN_NAME = "logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"
lime_long_train_values = lime_loaded_dict[LIME_TRAIN_NAME]

metric_list_ground_truth_lime = []
# 100k, 200k, ..., 1M subsets.
for num_subsets in range(100000, 1000001, 100000):
    metric_list_ground_truth_lime.extend(
        get_ground_truth_metric_with_value(
            attribution_values_ground_truth=lime_long_train_values,
            iters_ground_truth=1000000,
            attribution_values_calculated=lime_long_train_values,
            iters_calculated=num_subsets,
            meta_info={
                "num_subsets": num_subsets,
                "true_name": LIME_TRAIN_NAME,
                "estimated_name": LIME_TRAIN_NAME,
            },
        )
    )

In [None]:
# Attach plotting metadata to the raw LIME metric records and join explainer
# predictions with their regression targets:
#   * metric_list_ground_truth_lime -> metric_list_plot_reference
#   * metric_list_value_lime        -> metric_list_plot_target   (noisy labels)
#   * metric_list_lime              -> metric_list_plot_explainer (amortized predictions)
# A record whose paths match none of the expected combinations raises a
# RuntimeError, so unexpected inputs are never silently plotted.

LIME_TRAIN_NAME = "logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"
LIME_TEST_NAME = "logs/vitbase_imagenette_surrogate_lime_eval_test_regression_long/extract_output/test"
BINOMIAL_TRAIN_NAME = "logs/vitbase_imagenette_surrogate_binomial_eval_train/extract_output/train"
EXPLAINER_LOG_DIR = "/sdata/chanwkim/xai-amortization/logs_0901"

# Ground-truth file an explainer was evaluated against -> dataset split.
SPLIT_BY_TRUE_NAME = {LIME_TRAIN_NAME: "train", LIME_TEST_NAME: "test"}

# Trained explainer runs: transform mode -> subset budgets that exist on disk.
EXPLAINER_NUM_SUBSETS = {
    "global": [128, 256, 512],
    "sqrt": [128, 256, 512],
    "perinstance": [256, 512],
    "perinstanceperclass": [256, 512],
}
# Full model_path -> (transform_mode, num_subsets).
EXPLAINER_RUNS = {
    f"{EXPLAINER_LOG_DIR}/vitbase_imagenette_lime_regexplainer_upfront_{mode}_{n}": (mode, n)
    for mode, budgets in EXPLAINER_NUM_SUBSETS.items()
    for n in budgets
}


def _unexpected_record(record):
    """Abort on a metric record that matches none of the expected path combinations.

    Raises:
        RuntimeError: always, with the offending record in the message.
    """
    raise RuntimeError(f"Unexpected metric record: {record}")


metric_list_plot_reference = []
for metric in metric_list_ground_truth_lime:
    if not (metric["estimated_name"] == LIME_TRAIN_NAME and metric["true_name"] == LIME_TRAIN_NAME):
        _unexpected_record(metric)
    metric_temp = copy.copy(metric)
    metric_temp.update(
        {"method_name": 'LIME',
         "method_type": 'LIME',
         "antithetical": False,
         "split": "train",
        }
    )
    metric_list_plot_reference.append(metric_temp)


metric_list_plot_target = []
for metric in metric_list_value_lime:
    # Targets are accepted from either estimated source; both are compared
    # against the long LIME training run.
    if not (metric["estimated_name"] in (BINOMIAL_TRAIN_NAME, LIME_TRAIN_NAME)
            and metric["true_name"] == LIME_TRAIN_NAME):
        _unexpected_record(metric)
    metric_temp = copy.copy(metric)
    metric_temp.update(
        {"method_name": f'LIME ({metric_temp["num_subsets"]})',
         "method_type": 'LIME',
         "antithetical": False,
         "split": "train"
        }
    )
    metric_list_plot_target.append(metric_temp)


metric_list_plot_explainer = []
for metric in metric_list_lime:
    run = EXPLAINER_RUNS.get(metric["model_path"])
    split = SPLIT_BY_TRUE_NAME.get(metric["true_name"])
    if run is None or split is None:
        _unexpected_record(metric)
    transform_mode, num_subsets = run
    metric_temp = copy.copy(metric)
    metric_temp.update(
        {"method_name": f'Reg-AO (LIME, {transform_mode}, {num_subsets})',
         "method_type": 'LIME',
         "transform_mode": transform_mode,
         "antithetical": False,
         "num_subsets": num_subsets,
         "split": split,
        }
    )
    metric_list_plot_explainer.append(metric_temp)


metric_list_plot_explainer_df = pd.DataFrame(metric_list_plot_explainer)
metric_list_plot_explainer_df = metric_list_plot_explainer_df[metric_list_plot_explainer_df["is_best_checkpoint"] == "best"]

# Duplicate the training targets as "test" rows (sample_idx remapped to 0..99)
# so they can be joined with test-split explainer metrics.
# NOTE(review): assumes the test evaluation used the first N_EVAL_SAMPLES
# indices of a seed-42 permutation of N_TRAIN_SAMPLES, and that every target
# sample_idx is one of them (otherwise the lookup below raises KeyError) - confirm.
N_TRAIN_SAMPLES = 9469
N_EVAL_SAMPLES = 100
EVAL_SEED = 42
idx_mapping = dict(zip(np.random.RandomState(seed=EVAL_SEED).permutation(list(range(N_TRAIN_SAMPLES)))[:N_EVAL_SAMPLES],
                       range(N_EVAL_SAMPLES)))

metric_list_plot_target_df = pd.DataFrame(metric_list_plot_target)
metric_list_plot_target_df_ = metric_list_plot_target_df.copy()
metric_list_plot_target_df_["split"] = "test"
metric_list_plot_target_df_["sample_idx"] = metric_list_plot_target_df_["sample_idx"].map(lambda x: idx_mapping[x])
metric_list_plot_target_df = pd.concat([metric_list_plot_target_df, metric_list_plot_target_df_])

# One row per (explainer record, matching target record); overlapping metric
# columns get _explainer / _target suffixes.
metric_list_plot_df = metric_list_plot_explainer_df.merge(
    right=metric_list_plot_target_df,
    on=["method_type", "sample_idx", "num_subsets", "split"],
    suffixes=('_explainer', '_target'),
)

SUMMARY_COLUMNS = [
    'sample_idx', "num_subsets",
    'mse_all_explainer', 'mse_all_target',
    'pearsonr_all_explainer', 'pearsonr_all_target',
    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
    'spearmanr_all_explainer', 'spearmanr_all_target',
    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
]
# Mean training-split metrics per explainer configuration (displayed).
metric_list_plot_df[metric_list_plot_df["split"] == "train"].groupby(["method_type", "transform_mode", "num_subsets"])[SUMMARY_COLUMNS].mean()

In [None]:
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor'] = '0.8'
plt.rcParams['legend.framealpha'] = 0.8

# Figure layout: one column per metric, one row per transform mode. Each panel
# plots the explainer-prediction metric (x) against the noisy-label metric (y).
METRIC_NAMES = ["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]
TRANSFORM_MODES = ["global", "perinstanceperclass"]

# Every record in metric_list_plot_df is tagged method_type 'LIME' by the
# metric-tagging cell; used to label any legend entry that is not an integer.
# (Previously this read a `method_type` global leaked from another cell.)
PLOT_METHOD_TYPE = "LIME"

# Columns averaged per (method_type, num_subsets) before plotting.
PLOT_COLUMNS = [
    'sample_idx',
    'mse_all_explainer', 'mse_all_target',
    'pearsonr_all_explainer', 'pearsonr_all_target',
    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
    'spearmanr_all_explainer', 'spearmanr_all_target',
    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
    "sign_agreement_all_explainer", "sign_agreement_all_target",
]

# Per-metric axis-label stem, extent of the y = x reference line, shared x/y
# limits, and whether both axes use a log scale.
METRIC_AXIS_STYLE = {
    "mse_all": {"label": "Error", "diag_max": 10, "lim": (1e-5, 1.1), "log": True},
    "pearsonr_all": {"label": "Pearson Corr.", "diag_max": 1, "lim": (0, 1.01), "log": False},
    "spearmanr_all": {"label": "Spearman Corr.", "diag_max": 1, "lim": (0, 1.01), "log": False},
    "sign_agreement_all": {"label": "Sign Agreement", "diag_max": 1, "lim": (0, 1.01), "log": False},
}


def _relabel_legend(legend):
    """Rewrite legend entry text: integer labels are re-rendered as ints; any
    other label becomes '<pretty method name> (<last 5 chars>)'."""
    for legend_text in legend.get_texts():
        text = legend_text.get_text()
        try:
            value = int(text)
        except ValueError:
            legend_text.set_text(f"{prettify_method_type(PLOT_METHOD_TYPE)} ({text[-5:]})")
        else:
            legend_text.set_text(f"{value}")


def _draw_target_vs_prediction(ax, data, metric_name, transform_mode, show_legend):
    """Scatter `<metric_name>_explainer` (x) against `<metric_name>_target` (y) on `ax`.

    Args:
        ax: matplotlib Axes to draw on.
        data: per-(method_type, num_subsets) means containing the two metric
            columns and a `num_subsets` column used for hue.
        metric_name: key into METRIC_AXIS_STYLE.
        transform_mode: used only in the panel title.
        show_legend: keep a num_subsets legend in the lower-right corner;
            otherwise the panel has no legend.
    """
    style = METRIC_AXIS_STYLE[metric_name]

    sns.scatterplot(
        x=metric_name + "_explainer",
        y=metric_name + "_target",
        hue="num_subsets",
        data=data,
        s=200,
        # Colour keyed by subset budget (0..512 on the Blues map) so that the
        # mapping does not depend on row order.
        palette={n: Blue_scalar_color_mapping(n, color_map=plt.cm.Blues, data_min=0, data_max=512)
                 for n in data["num_subsets"]},
        alpha=0.9,
        ax=ax,
    )

    # Dashed y = x reference line.
    diagonal = np.linspace(0, style["diag_max"], 100)
    ax.plot(diagonal, diagonal, color="grey", alpha=0.5, linestyle='--')

    ax.set_ylabel(f"{style['label']} (Label)")
    ax.set_xlabel(f"{style['label']} (Prediction)")

    ax.xaxis.grid(True, which='major')
    ax.set_xlim(*style["lim"])
    if style["log"]:
        ax.set_xscale("log")

    ax.yaxis.grid(True, which='major')
    ax.set_ylim(*style["lim"])
    if style["log"]:
        ax.set_yscale("log")

    ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
    ax.tick_params(axis='y', which='major', rotation=0)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # Drop seaborn's automatic legend; optionally rebuild a compact one.
    existing_legend = ax.get_legend()
    if existing_legend is not None:
        existing_legend.remove()
    if show_legend:
        _relabel_legend(ax.legend(loc='lower right', bbox_to_anchor=(0.97, 0.03)))

    pretty_mode = prettify_transform_mode(transform_mode)
    if pretty_mode == "":
        ax.set_title(f"{prettify_metric_name(metric_name)}")
    else:
        ax.set_title(f"{prettify_metric_name(metric_name)} ({pretty_mode})")


fig = plt.figure(figsize=(4*4+3*0.3, 3*4+2*0.8))

outer_grid = gridspec.GridSpec(1, len(METRIC_NAMES), hspace=0.3)

axd = {}
for idx1, metric_name in enumerate(METRIC_NAMES):
    # Three rows are reserved per column; only the first len(TRANSFORM_MODES) are used.
    column_grid = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=outer_grid[idx1], wspace=0.8, hspace=0.4)
    for idx2, transform_mode in enumerate(TRANSFORM_MODES):
        axd[(metric_name, transform_mode)] = fig.add_subplot(column_grid[idx2])

for metric_name in METRIC_NAMES:
    for idx2, transform_mode in enumerate(TRANSFORM_MODES):
        metric_list_plot_df_select = (
            metric_list_plot_df[(metric_list_plot_df["split"] == "train")
                                & (metric_list_plot_df["transform_mode"] == transform_mode)]
            .groupby(["method_type", "num_subsets"])[PLOT_COLUMNS]
            .mean()
            .reset_index()
        )
        _draw_target_vs_prediction(
            axd[(metric_name, transform_mode)],
            metric_list_plot_df_select,
            metric_name,
            transform_mode,
            # Only the top-left MSE panel carries the num_subsets legend.
            show_legend=(metric_name == "mse_all" and idx2 == 0),
        )

In [None]:
# Save the figure built in the previous cell as PNG and PDF. The output directory is
# created first so a fresh checkout does not fail with FileNotFoundError.
os.makedirs("logs/plots", exist_ok=True)
fig.savefig("logs/plots/"+f"training_target_prediction_lime.png", bbox_inches='tight')
fig.savefig("logs/plots/"+f"training_target_prediction_lime.pdf", bbox_inches='tight')

In [None]:
# Explainer-vs-target scatter panels on the test split: one column per metric and
# one row per attribution transform mode. Each point is a (method_type, num_subsets)
# group averaged over the test records. x = the metric's "_explainer" column,
# y = its "_target" column, and the dashed grey line is y = x.
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor'] = '0.8'
plt.rcParams['legend.framealpha'] = 0.8

PANEL_METRICS = ["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]
PANEL_TRANSFORM_MODES = ["global", "perinstanceperclass"]

# Per-metric axis styling:
#   label    - prefix for both axis labels ("<label> (Prediction)" / "<label> (Label)")
#   diag_max - upper end of the dashed y = x guide line
#   lim      - limits applied to both axes (set before the scale, as originally)
#   log      - whether both axes are switched to log scale
PANEL_STYLE = {
    "mse_all": {"label": "Error", "diag_max": 10, "lim": (1e-5, 1.1), "log": True},
    "pearsonr_all": {"label": "Pearson Corr.", "diag_max": 1, "lim": (0, 1.01), "log": False},
    "spearmanr_all": {"label": "Spearman Corr.", "diag_max": 1, "lim": (0, 1.01), "log": False},
    "sign_agreement_all": {"label": "Sign Agreement", "diag_max": 1, "lim": (0, 1.01), "log": False},
}

# Columns averaged within each (method_type, num_subsets) group.
PANEL_SUMMARY_COLUMNS = [
    'sample_idx',
    'mse_all_explainer', 'mse_all_target',
    'pearsonr_all_explainer', 'pearsonr_all_target',
    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
    'spearmanr_all_explainer', 'spearmanr_all_target',
    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
    "sign_agreement_all_explainer", "sign_agreement_all_target",
]


def _format_subset_label(text):
    """Render a numeric legend label as a plain integer ("512" or "512.0" -> "512").

    Non-numeric labels (e.g. a hue title entry) are returned unchanged. The old code
    rewrote them using a `method_type` global that this cell never defines.
    """
    try:
        return f"{int(float(text))}"
    except (ValueError, OverflowError):
        return text


def _remove_legend(ax):
    """Remove the legend currently attached to `ax` (e.g. seaborn's automatic one), if any."""
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()


fig = plt.figure(figsize=(4*4+3*0.3, 3*4+2*0.8))
box1 = gridspec.GridSpec(1, len(PANEL_METRICS), hspace=0.3)

axd = {}
for idx1, metric_name in enumerate(PANEL_METRICS):
    # Three rows are reserved per metric column; only the first len(PANEL_TRANSFORM_MODES) are used.
    box2 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.4)
    for idx2, transform_mode in enumerate(PANEL_TRANSFORM_MODES):
        axd[(metric_name, transform_mode)] = fig.add_subplot(box2[idx2])

# The per-group summary depends only on the transform mode, so build it once per mode
# instead of once per (metric, mode) panel.
panel_test_metrics = metric_list_plot_df[metric_list_plot_df["split"] == "test"]
panel_summary_by_mode = {
    mode: (panel_test_metrics[panel_test_metrics["transform_mode"] == mode]
           .groupby(["method_type", "num_subsets"])[PANEL_SUMMARY_COLUMNS]
           .mean()
           .reset_index())
    for mode in PANEL_TRANSFORM_MODES
}

for idx1, metric_name in enumerate(PANEL_METRICS):
    style = PANEL_STYLE[metric_name]
    for idx2, transform_mode in enumerate(PANEL_TRANSFORM_MODES):
        metric_list_plot_df_select = panel_summary_by_mode[transform_mode]
        plot_key = (metric_name, transform_mode)
        ax = axd[plot_key]

        sns.scatterplot(
            x=metric_name+"_explainer",
            y=metric_name+"_target",
            hue="num_subsets",
            data=metric_list_plot_df_select,
            s=200,
            # One colour per row, shaded from num_subsets on the Blues colormap
            # (Blue_scalar_color_mapping is called with a 0-512 data range).
            palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=0, data_max=512)
                     for i in metric_list_plot_df_select["num_subsets"]],
            alpha=0.9,
            ax=ax,
        )

        # Dashed y = x guide: points on it have equal prediction and label metrics.
        diagonal = np.linspace(0, style["diag_max"], 100)
        ax.plot(diagonal, diagonal, color="grey", alpha=0.5, linestyle='--')

        ax.set_ylabel(f"{style['label']} (Label)")
        ax.set_xlabel(f"{style['label']} (Prediction)")

        ax.xaxis.grid(True, which='major')
        ax.set_xlim(*style["lim"])
        if style["log"]:
            ax.set_xscale("log")

        ax.yaxis.grid(True, which='major')
        ax.set_ylim(*style["lim"])
        if style["log"]:
            ax.set_yscale("log")

        ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
        ax.tick_params(axis='y', which='major', rotation=0)

        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        # Only the top-left (MSE, first transform mode) panel keeps a legend. It
        # doubles as the num_subsets colour key for the whole figure.
        if metric_name == "mse_all" and idx2 == 0:
            leg = ax.legend(loc='lower right', bbox_to_anchor=(0.97, 0.03))
            for legend_text in leg.get_texts():
                legend_text.set_text(_format_subset_label(legend_text.get_text()))
        else:
            _remove_legend(ax)

        pretty_metric = prettify_metric_name(metric_name)
        pretty_mode = prettify_transform_mode(transform_mode)
        ax.set_title(pretty_metric if pretty_mode == "" else f"{pretty_metric} ({pretty_mode})")
            
            

In [None]:
# Save the figure built in the previous cell as PNG and PDF. The output directory is
# created first so a fresh checkout does not fail with FileNotFoundError.
os.makedirs("logs/plots", exist_ok=True)
fig.savefig("logs/plots/"+f"training_target_prediction_lime_external.png", bbox_inches='tight')
fig.savefig("logs/plots/"+f"training_target_prediction_lime_external.pdf", bbox_inches='tight')

# Estimation accuracy for amortization and KernelSHAP with different dataset sizes given equivalent compute.

In [None]:
# Accumulator for the KernelSHAP-vs-ground-truth metric records built in the next cell.
metric_list_ground_truth_flops = list()

In [None]:
# KernelSHAP baseline: score the train-split attribution values at increasing subset
# budgets (multiples of 512) against the same run evaluated at 999,424 iterations.
kernelshap_train_name = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
kernelshap_train_values = shapley_loaded_dict[kernelshap_train_name]

# Rebuilt here rather than only in the previous cell, so re-running this cell does
# not append duplicate records.
metric_list_ground_truth_flops = []
for num_subsets in [512*i for i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 40, 80, 100, 200, 400, 800, 1000]]:
    metric_list_ground_truth_flops += get_ground_truth_metric_with_value(
        attribution_values_ground_truth=kernelshap_train_values,
        iters_ground_truth=999424,
        attribution_values_calculated=kernelshap_train_values,
        iters_calculated=num_subsets,
        meta_info={"num_subsets": num_subsets,
                   "true_name": kernelshap_train_name,
                   "estimated_name": kernelshap_train_name,
                  },
        # Ground truth and estimate share one dict, so the key intersection is the key set itself.
        ground_truth_key_select=set(kernelshap_train_values.keys()),
    )

In [None]:
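# FLOP cost model used in the next cell to put every method on a common compute axis
# (kept for reference):
#   - subset evaluation: cost per (training sample, subset) x training samples x subsets per sample
#   - explainer forward / backward: per-sample forward cost x epochs x training samples,
#     with the backward pass counted as 2x the forward pass
#   - parameter update: per-step cost x epochs x (num_train_sample // 64) steps per epoch
#     (presumably batch size 64) x (2+3+4+3+3+4) -- still to be verified
# NOTE(review): the forward/backward constant below (38_898_221_568) differs from the one
# the next cell actually uses (21_335_153_664); confirm which value is current.
# flops_subset_eval = 17_563_067_904 * num_train_sample * num_subsets
# flops_forward = 38_898_221_568 * 1 * metric_temp["epoch"] * num_train_sample
# flops_backward = 38_898_221_568 * 2 * metric_temp["epoch"] * num_train_sample    
# flops_parameter_update = 104_730_000 * metric_temp["epoch"] * (num_train_sample//64) * (2+3+4+3+3+4) # need to verify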

In [None]:
# Build the accuracy-vs-compute table. It holds the KernelSHAP baseline (cost = subset
# evaluations only) plus amortized explainers (cost = label generation + explainer
# training). Every (model_path, true_name) pair must be either plotted or explicitly
# excluded below; anything else raises RuntimeError so new runs can't slip in unnoticed.
num_train_sample = 9469

# FLOP constants. The formulas below keep the same multiplication/addition order as before.
SURROGATE_EVAL_FLOPS = 17_563_067_904      # per (training sample, subset) evaluation
EXPLAINER_FORWARD_FLOPS = 21_335_153_664   # per (training sample, epoch) forward; backward counted as 2x
PARAMETER_UPDATE_FLOPS = 104_730_000       # per optimizer step -- need to verify
PARAMETER_UPDATE_FACTOR = (2+3+4+3+3+4)    # need to verify
TRAIN_BATCH_SIZE = 64                      # steps per epoch = num_train_sample // 64 (presumably the batch size)

TRAIN_TRUE_NAME = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
TEST_TRUE_NAME = "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"
SPLIT_BY_TRUE_NAME = {TRAIN_TRUE_NAME: "train", TEST_TRUE_NAME: "test"}
REGRESSION_TRUE_NAMES = set(SPLIT_BY_TRUE_NAME)
ANTITHETICAL_TRUE_NAMES = {
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
    "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test",
}

UPFRONT_KERNELSHAP_PATHS = {f"logs/vitbase_imagenette_shapley_regexplainer_upfront_{n}" for n in [512, 1024, 2048, 3072]}
UPFRONT_PERMUTATION_PATHS = {f"logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_{n}" for n in [196, 392, 588, 1176, 3136]}

# (model paths, true names) combinations that are known but intentionally left out of this plot.
EXCLUDED_RUNS = [
    ({f"logs/vitbase_imagenette_shapley_regexplainer_antithetical_upfront_{n}" for n in [512, 3072]},
     REGRESSION_TRUE_NAMES),
    ({f"logs/vitbase_imagenette_shapley_regexplainer_permutation_newsample_{n}" for n in [196, 392, 588, 1176, 3136]},
     ANTITHETICAL_TRUE_NAMES),
    ({f"logs/vitbase_imagenette_shapley_objexplainer_newsample_{n}" for n in [32]},
     ANTITHETICAL_TRUE_NAMES),
    ({f"logs/vitbase_imagenette_shapley_objexplainer_antithetical_newsample_{n}" for n in [32]},
     ANTITHETICAL_TRUE_NAMES),
    ({f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_shapley_regexplainer_SGD_antithetical_upfront_{n}" for n in [9986]},
     REGRESSION_TRUE_NAMES),
    ({"logs/vitbase_imagenette_shapley_regexplainer_upfront_1024_numtrain_4735",
      "logs/vitbase_imagenette_shapley_regexplainer_upfront_2048_numtrain_2367",
      "logs/vitbase_imagenette_shapley_regexplainer_upfront_3072_numtrain_1578"},
     REGRESSION_TRUE_NAMES),
]


def _amortized_flops(num_subsets, epoch, num_train_sample):
    """Total FLOPs for an amortized explainer.

    The total is label generation (num_subsets subset evaluations per training sample)
    plus `epoch` epochs of explainer training (forward + 2x backward + parameter updates).
    """
    flops_subset_eval = SURROGATE_EVAL_FLOPS * num_train_sample * num_subsets
    flops_forward = EXPLAINER_FORWARD_FLOPS * 1 * epoch * num_train_sample
    flops_backward = EXPLAINER_FORWARD_FLOPS * 2 * epoch * num_train_sample
    flops_parameter_update = PARAMETER_UPDATE_FLOPS * epoch * (num_train_sample//TRAIN_BATCH_SIZE) * PARAMETER_UPDATE_FACTOR
    return flops_subset_eval + flops_forward + flops_backward + flops_parameter_update


def _amortized_plot_columns(metric_temp, num_train_sample):
    """Classify one explainer-evaluation record.

    Returns the plotting columns to merge into the record, or None for a deliberately
    excluded run.

    Raises:
        RuntimeError: if the (model_path, true_name) pair is not recognised.
    """
    run_path = metric_temp["model_path"]
    run_true_name = metric_temp["true_name"]

    if run_true_name in SPLIT_BY_TRUE_NAME and run_path in UPFRONT_KERNELSHAP_PATHS:
        num_subsets = int(run_path.split('_')[-1])
        return {"method_name": f'{num_subsets}',
                "method_type": 'KernelSHAP',
                "antithetical": False,
                "split": SPLIT_BY_TRUE_NAME[run_true_name],
                "flops": _amortized_flops(num_subsets, metric_temp["epoch"], num_train_sample),
               }

    if run_true_name in SPLIT_BY_TRUE_NAME and run_path in UPFRONT_PERMUTATION_PATHS:
        num_subsets = int(run_path.split('_')[-1])
        return {"method_name": f'Reg-AO (Permutation, {num_subsets})',
                "method_type": 'Permutation',
                "antithetical": False,
                "split": SPLIT_BY_TRUE_NAME[run_true_name],
                "flops": _amortized_flops(num_subsets, metric_temp["epoch"], num_train_sample),
               }

    if any(run_path in paths and run_true_name in true_names for paths, true_names in EXCLUDED_RUNS):
        return None

    print(metric_temp)
    # Previously `raise RuntimError()`: a typo that surfaced as NameError with no context.
    raise RuntimeError(f"Unrecognised explainer run: model_path={run_path!r}, true_name={run_true_name!r}")


metric_list_plot = []

# KernelSHAP baseline: its cost is the subset evaluations alone.
for metric in metric_list_ground_truth_flops:
    metric_temp = copy.copy(metric)
    if not (metric_temp["estimated_name"] == TRAIN_TRUE_NAME and metric_temp["true_name"] == TRAIN_TRUE_NAME):
        print(metric_temp)
        raise RuntimeError(f"Unexpected KernelSHAP baseline record: estimated_name={metric_temp['estimated_name']!r}, "
                           f"true_name={metric_temp['true_name']!r}")
    metric_temp.update(
        {"method_name": 'KernelSHAP',
         "method_type": 'KernelSHAP',
         "antithetical": False,
         "split": "train",
         "flops": SURROGATE_EVAL_FLOPS * metric_temp["num_subsets"] * num_train_sample,
        }
    )
    metric_list_plot.append(metric_temp)

# Amortized explainers.
for metric in metric_list_shapley:
    metric_temp = copy.copy(metric)
    plot_columns = _amortized_plot_columns(metric_temp, num_train_sample)
    if plot_columns is None:
        continue
    metric_temp.update(plot_columns)
    metric_list_plot.append(metric_temp)

metric_list_plot_df = pd.DataFrame(metric_list_plot)
metric_list_plot_df_epoch = metric_list_plot_df
metric_list_plot_df.head()

In [None]:
# Inspect which attribution-value sets are loaded. The cells below index this dict by
# extract_output paths such as ".../extract_output/train" and ".../extract_output/test".
shapley_loaded_dict.keys()

In [None]:
# Score the best checkpoint of each explainer trained with a reduced training set
# (runs "..._upfront_2257_numtrain_{num_train}") against the 999,424-iteration
# ground-truth attributions, on the test split and then the train split.
NUMTRAIN_SWEEP = [100, 250, 500, 1000, 2000, 5000]
NUMTRAIN_GROUND_TRUTH_ITERS = 999424
# Evaluation order (test, then train) matches the original cell.
NUMTRAIN_TRUE_NAME_BY_SPLIT = {
    "test": "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test",
    "train": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train",
}

metric_list_flops = []

for num_train in NUMTRAIN_SWEEP:

    model_path = f"logs/vitbase_imagenette_shapley_regexplainer_upfront_2257_numtrain_{num_train}"

    # Resolve the best checkpoint once per run. Previously this was recomputed three times per checkpoint.
    best_checkpoint_step = int(get_best_model_checkpoint(model_path).split('-')[-1])

    # The glob pattern names the best checkpoint exactly, so this list has at most one entry.
    checkpoint_path_list = sorted(glob.glob(model_path+f"/checkpoint-{best_checkpoint_step}"),
                                  key=lambda x: int(x.split('-')[-1]))

    for checkpoint_path in tqdm(checkpoint_path_list[:50]):
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        regexplainer.load_state_dict(checkpoint_state_dict)

        checkpoint_epoch = int(checkpoint_trainer_state["epoch"])
        is_best_checkpoint = compare_checkpoint_value(current_checkpoint=int(checkpoint_path.split('-')[-1]),
                                                      best_checkpoint=best_checkpoint_step)

        for split_name, true_name in NUMTRAIN_TRUE_NAME_BY_SPLIT.items():
            metric_list_flops += get_ground_truth_metric_with_explainer(
                attribution_values=shapley_loaded_dict[true_name],
                explainer=regexplainer,
                dataset=dataset_explainer[split_name],
                iters_ground_truth=NUMTRAIN_GROUND_TRUTH_ITERS,
                meta_info={
                    "true_name": true_name,
                    "model_path": model_path,
                    "epoch": checkpoint_epoch,
                    "is_best_checkpoint": is_best_checkpoint,
                })

In [None]:
# ## temp

# model_path=f"logs/vitbase_imagenette_shapley_regexplainer_upfront_2048"

# checkpoint_path_list=sorted(glob.glob(model_path+f"/checkpoint-{int(get_best_model_checkpoint(model_path).split('-')[-1])}"), key=lambda x: int(x.split('-')[-1]))


# for checkpoint_path in tqdm(checkpoint_path_list[:50]):
#     checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
#     with open(checkpoint_path+"/trainer_state.json") as f:
#         checkpoint_trainer_state = json.load(f)

#     regexplainer.load_state_dict(checkpoint_state_dict)

#     metric_list_flops+=get_ground_truth_metric_with_explainer(attribution_values=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"], 
#                             explainer=regexplainer,
#                             dataset=dataset_explainer["test"],
#                             iters_ground_truth=999424,
#                             meta_info={
#                                        "true_name": "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test",
#                                        "model_path": model_path,
#                                        "epoch": int(checkpoint_trainer_state["epoch"]),
#                                        "is_best_checkpoint": 
#                                         compare_checkpoint_value(current_checkpoint=int(checkpoint_path.split('-')[-1]), 
#                                                                  best_checkpoint=int(get_best_model_checkpoint(model_path).split('-')[-1]))
#                                       })


#     metric_list_flops+=get_ground_truth_metric_with_explainer(attribution_values=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
#                             explainer=regexplainer,
#                             dataset=dataset_explainer["train"],
#                             iters_ground_truth=999424,
#                             meta_info={
#                                        "true_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train",
#                                        "model_path": model_path,
#                                        "epoch": int(checkpoint_trainer_state["epoch"]),
#                                        "is_best_checkpoint": 
#                                         compare_checkpoint_value(current_checkpoint=int(checkpoint_path.split('-')[-1]), 
#                                                                  best_checkpoint=int(get_best_model_checkpoint(model_path).split('-')[-1]))
#                                       })        

In [None]:
# Attach plotting metadata (method name/type, training-set size, subset budget,
# data split) to every amortized-explainer metric row in `metric_list_flops`.
# Rows from runs not listed below raise instead of being silently mislabelled.

# Ground-truth attribution name -> data split it was evaluated on.
EXPLAINER_TRUE_NAME_TO_SPLIT = {
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train": "train",
    "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test": "test",
}
# Runs trained on a varying number of training images; in these paths the last
# '_'-separated token is num_train and the third-from-last is num_subsets.
EXPLAINER_NUMTRAIN_MODEL_PATHS = [
    f"logs/vitbase_imagenette_shapley_regexplainer_upfront_2257_numtrain_{num_train}"
    for num_train in [100, 250, 500, 1000, 2000, 5000]
]
# Run labelled with num_train=9469 (presumably the full training set — confirm);
# here the last '_'-separated token is num_subsets.
EXPLAINER_FULL_MODEL_PATHS = ["logs/vitbase_imagenette_shapley_regexplainer_upfront_2048"]
EXPLAINER_FULL_NUM_TRAIN = 9469
# num_subsets parsed from model_path -> value stored as num_subsets_per_sample
# (later merged against the KernelSHAP rows' num_subsets).
EXPLAINER_NUM_SUBSETS_PER_SAMPLE = {2257: 2440, 2048: 2440}
# Runs with fewer training images than this are left out of the plots.
EXPLAINER_MIN_NUM_TRAIN = 250

metric_list_plot_explainer = []
for metric in metric_list_flops:
    metric_temp = copy.copy(metric)
    split = EXPLAINER_TRUE_NAME_TO_SPLIT.get(metric_temp["true_name"])
    path_tokens = metric_temp["model_path"].split('_')

    if split is not None and metric_temp["model_path"] in EXPLAINER_NUMTRAIN_MODEL_PATHS:
        num_train = int(path_tokens[-1])
        if num_train < EXPLAINER_MIN_NUM_TRAIN:
            continue
        num_subsets = int(path_tokens[-3])
    elif split is not None and metric_temp["model_path"] in EXPLAINER_FULL_MODEL_PATHS:
        num_train = EXPLAINER_FULL_NUM_TRAIN
        num_subsets = int(path_tokens[-1])
    else:
        raise RuntimeError(f"Unexpected explainer metric entry: {metric_temp}")

    num_subsets_per_sample = EXPLAINER_NUM_SUBSETS_PER_SAMPLE[num_subsets]

    metric_temp.update(
        {"method_name": f'Reg-AO (KernelSHAP, {num_subsets})',
         "method_type": 'KernelSHAP',
         "antithetical": False,
         "num_train": num_train,
         "num_subsets": num_subsets,
         "num_subsets_per_sample": num_subsets_per_sample,
         "split": split,
        }
    )
    metric_list_plot_explainer.append(metric_temp)

In [None]:
# Load the KernelSHAP train-split estimates at a target subset size of 2440 for
# 100 images drawn from indices 0..9468 with a fixed seed (42), so the same
# images are selected on every run.
_train_2440_sample_select = np.random.RandomState(seed=42).permutation(list(range(9469)))[:100]
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train_2440"] = load_attribution(
    "logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train",
    target_subset_size=2440,
    attribution_name="shapley",
    sample_select=_train_2440_sample_select,
)

In [None]:
# Score plain KernelSHAP (2440 subsets) on the train split against the
# ground-truth regression Shapley values, then label each resulting row so it
# can be merged with the amortized-explainer rows for plotting.
kernelshap_true_name = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
kernelshap_estimated_name = "logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train_2440"

metric_list_value_flops_matched = []
num_subsets = 2440
# Only entry 0 of each sample's tracking_dict is evaluated.
# NOTE(review): presumably tracking_dict holds several tracked estimates — confirm.
for i in range(1):
    shapley_loaded_dict_temp = {
        sample_idx: tracking_dict[i]
        for sample_idx, tracking_dict in shapley_loaded_dict[kernelshap_estimated_name].items()
    }

    metric_list_value_flops_matched += get_ground_truth_metric_with_value(
        attribution_values_ground_truth=shapley_loaded_dict[kernelshap_true_name],
        iters_ground_truth=999424,
        attribution_values_calculated=shapley_loaded_dict_temp,
        iters_calculated=num_subsets,
        meta_info={"num_subsets": num_subsets,
                   "true_name": kernelshap_true_name,
                   "estimated_name": kernelshap_estimated_name,
                  },
    )

metric_list_plot_target = []
for metric in metric_list_value_flops_matched:
    metric_temp = copy.copy(metric)

    if (metric_temp["estimated_name"] != kernelshap_estimated_name
            or metric_temp["true_name"] != kernelshap_true_name):
        raise RuntimeError(f"Unexpected KernelSHAP metric entry: {metric_temp}")

    metric_temp.update(
        {"method_name": f'KernelSHAP ({metric_temp["num_subsets"]})',
         "method_type": 'KernelSHAP',
         "antithetical": False,
         "split": "train"
        }
    )
    metric_list_plot_target.append(metric_temp)

In [None]:
# Tabulate explainer rows, keeping best-checkpoint rows plus every row labelled
# with method_type "KernelSHAP".
# NOTE(review): every explainer row is currently labelled "KernelSHAP", so this
# filter keeps all rows — confirm whether `&` was intended.
metric_list_plot_explainer_df = pd.DataFrame(metric_list_plot_explainer)
metric_list_plot_explainer_df = metric_list_plot_explainer_df.loc[
    metric_list_plot_explainer_df["is_best_checkpoint"].eq("best")
    | metric_list_plot_explainer_df["method_type"].eq("KernelSHAP")
]

metric_list_plot_target_df = pd.DataFrame(metric_list_plot_target)

# Pair each explainer row with the KernelSHAP row for the same sample and split
# whose num_subsets equals the explainer's num_subsets_per_sample.
metric_list_plot_df = pd.merge(
    metric_list_plot_explainer_df,
    metric_list_plot_target_df,
    left_on=["method_type", "sample_idx", "num_subsets_per_sample", "split"],
    right_on=["method_type", "sample_idx", "num_subsets", "split"],
    suffixes=('_explainer', '_target'),
)

_paired_summary_columns = [
    'sample_idx', "num_subsets_per_sample",
    'mse_target_explainer', 'mse_nontarget_explainer', 'mse_all_explainer',
    'mse_target_target', 'mse_nontarget_target', 'mse_all_target',
]
_paired_train_rows = metric_list_plot_df.loc[metric_list_plot_df["split"].eq("train")]
_paired_train_rows.groupby(["method_type", "num_subsets_per_sample", "num_train"])[_paired_summary_columns].mean().T

In [None]:
# Mean explainer metrics per training-set size on the train split.
# Both conditions are built from `metric_list_plot_explainer_df` itself so the
# boolean mask shares the index of the frame being filtered. The method type is
# spelled out rather than read from a loop variable defined in a later cell.
_explainer_train_rows = metric_list_plot_explainer_df[
    (metric_list_plot_explainer_df["split"] == "train")
    & (metric_list_plot_explainer_df["method_type"] == "KernelSHAP")
]
_explainer_train_rows.groupby(["method_type", "num_train"])[
    ['sample_idx',
     'mse_target',
     'mse_nontarget',
     'mse_all',
     'pearsonr_target',
     'pearsonr_all',
     'pearsonr_all_per_class',
     'spearmanr_target',
     'spearmanr_all',
     'spearmanr_all_per_class']
].mean().reset_index()

In [None]:
# Export `fig` as PNG and PDF.
# NOTE(review): `fig` here is whatever figure an earlier cell left in the
# kernel — confirm it is the intended (linear-scale) figure.
for _plot_path in ("logs/plots/shapley_compute_trainsamples.png",
                   "logs/plots/shapley_compute_trainsamples.pdf"):
    fig.savefig(_plot_path, bbox_inches='tight')

In [None]:
from matplotlib.lines import Line2D

# Figure: amortized-explainer quality vs. number of training datapoints on the
# train split, one panel per metric, with the mean KernelSHAP score drawn as a
# horizontal reference line.

# Set the seaborn theme/context *before* drawing so the figure's appearance does
# not depend on which cells already ran in the kernel.
sns.set_theme(style='whitegrid')
sns.set_context('paper', font_scale=1.2)

plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor'] = '0.8'
plt.rcParams['legend.framealpha'] = 0.8

# Colour cycle: [0] amortized explainer curve, [1] KernelSHAP reference line.
PLOT_COLORS = [
    (0.2980392156862745, 0.4470588235294118, 0.6901960784313725),
    (0.8666666666666667, 0.5176470588235295, 0.3215686274509804),
    (0.3333333333333333, 0.6588235294117647, 0.40784313725490196),
    (0.7686274509803922, 0.3058823529411765, 0.3215686274509804),
    (0, 0, 0),
]
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=PLOT_COLORS)

PANEL_METRICS = ["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]
PANEL_YLABELS = {
    "mse_all": "Error",
    "pearsonr_all": "Pearson Correlation",
    "spearmanr_all": "Spearman Correlation",
    "sign_agreement_all": "Sign Agreement",
}
# Panels drawn on log-log axes.
LOG_SCALE_METRICS = {"mse_all"}
# Minimum right end of the KernelSHAP reference line; it is extended to the
# largest plotted num_train so it spans the whole amortized curve.
KERNELSHAP_REFERENCE_MIN_XMAX = 5000


def _amortized_curve(explainer_df, metric_name, method_type, split):
    """Return per-``num_train`` means of ``metric_name`` over best-checkpoint rows.

    Keeps rows of ``explainer_df`` with the given ``split`` and ``method_type`` whose
    ``is_best_checkpoint`` equals ``"best"``, then averages over samples.
    """
    selected = explainer_df[
        (explainer_df["split"] == split)
        & (explainer_df["method_type"] == method_type)
        & (explainer_df["is_best_checkpoint"] == "best")
    ]
    return selected.groupby(["method_type", "num_train"])[[metric_name]].mean().reset_index()


def _draw_metric_panel(ax, metric_name, method_type, explainer_df, target_df,
                       split="train", show_legend=False):
    """Draw one panel: amortized curve vs. num_train plus the KernelSHAP mean.

    Args:
        ax: matplotlib Axes to draw on.
        metric_name: metric column plotted on the y axis.
        method_type: method_type value selecting explainer rows.
        explainer_df: amortized-explainer metric rows.
        target_df: KernelSHAP metric rows; their column mean is the reference line.
        split: data split of the explainer rows to plot.
        show_legend: add the "Amortized"/"KernelSHAP" legend to this panel.
    """
    is_log = metric_name in LOG_SCALE_METRICS
    curve = _amortized_curve(explainer_df, metric_name, method_type, split)

    sns.lineplot(
        x="num_train",
        y=metric_name,
        marker='o',
        markersize=6,
        alpha=0.8,
        data=curve,
        ax=ax,
    )

    reference_xmax = KERNELSHAP_REFERENCE_MIN_XMAX
    if len(curve):
        reference_xmax = max(reference_xmax, curve["num_train"].max())
    ax.hlines(
        y=target_df[metric_name].mean(),
        xmin=0,
        xmax=reference_xmax,
        color=PLOT_COLORS[1],
        label='KernelSHAP',
    )

    ax.set_xlabel("# Training Datapoints (Amortized)")
    ax.set_ylabel(PANEL_YLABELS[metric_name])

    if is_log:
        ax.set_xscale("log")
        ax.set_yscale("log")
    else:
        ax.yaxis.set_major_locator(MultipleLocator(0.2))

    for axis in (ax.xaxis, ax.yaxis):
        axis.grid(True, which='major', alpha=0.6)
        axis.grid(True, which='minor', alpha=0.1)

    if is_log:
        # Non-positive limits are invalid on log axes (matplotlib ignores them
        # with a warning), so only the upper y-limit is pinned.
        ax.set_ylim(top=0.045)
    else:
        ax.set_xlim(left=-1e-3)
        ax.set_ylim(0, 1.01)

    ax.tick_params(axis='both', which='major', rotation=0)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_title(prettify_metric_name(metric_name))

    if show_legend:
        legend_handles = [
            Line2D([0], [0], color=PLOT_COLORS[0], lw=1),
            Line2D([0], [0], color=PLOT_COLORS[1], lw=1, linestyle="-"),
        ]
        ax.legend(legend_handles, ['Amortized', "KernelSHAP"],
                  loc='best',
                  bbox_to_anchor=(0, 0, 1, 0.85))


fig = plt.figure(figsize=(4 * 4.3, 3))
box1 = gridspec.GridSpec(1, 4, hspace=0.3)

axd = {}
for idx1, metric_name in enumerate(PANEL_METRICS):
    for method_type in ["KernelSHAP"]:
        axd[(metric_name, method_type)] = fig.add_subplot(box1[idx1])

for metric_name in PANEL_METRICS:
    for method_type in ["KernelSHAP"]:
        _draw_metric_panel(
            axd[(metric_name, method_type)],
            metric_name,
            method_type,
            explainer_df=metric_list_plot_explainer_df,
            target_df=metric_list_plot_target_df,
            split="train",
            show_legend=(metric_name == "mse_all"),
        )

In [None]:
# Export the log-scale train-samples figure: a 600-dpi PNG and a PDF.
_logscale_exports = {
    "logs/plots/shapley_compute_trainsamples_logscale.png": {"bbox_inches": 'tight', "dpi": 600},
    "logs/plots/shapley_compute_trainsamples_logscale.pdf": {"bbox_inches": 'tight'},
}
for _plot_path, _save_kwargs in _logscale_exports.items():
    fig.savefig(_plot_path, **_save_kwargs)

In [None]:
# Amortized explainer vs. KernelSHAP on the test split: one panel per metric, plotted
# against the number of training datapoints used to fit the amortized explainer.
# The amortized curve is the per-`num_train` mean of the best checkpoint's test metrics;
# the KernelSHAP reference is a horizontal line at the metric's mean over
# `metric_list_plot_target_df`.
from matplotlib.lines import Line2D

plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor'] = '0.8'
plt.rcParams['legend.framealpha'] = 0.8

AMORTIZED_COLOR = (0.2980392156862745, 0.4470588235294118, 0.6901960784313725)
KERNELSHAP_COLOR = (0.8666666666666667, 0.5176470588235295, 0.3215686274509804)

# Set before the axes are created: the amortized curve is drawn without an explicit
# color, so it is expected to take the first color of this cycle (AMORTIZED_COLOR).
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=[
    AMORTIZED_COLOR,
    KERNELSHAP_COLOR,
    (0.3333333333333333, 0.6588235294117647, 0.40784313725490196),
    (0.7686274509803922, 0.3058823529411765, 0.3215686274509804),
    (0, 0, 0),
])

METHOD_TYPE = "KernelSHAP"
MAX_NUM_TRAIN = 5000  # right edge of every panel and of the KernelSHAP reference line

# Per-panel configuration, in left-to-right order. The MSE panel uses log-log axes;
# a log axis cannot take a non-positive limit (matplotlib ignores it with a warning),
# so only its upper limits are pinned and the lower ones stay data-driven.
PANEL_SPECS = {
    "mse_all": {
        "ylabel": "Error",
        "log_axes": True,
        "xlim": {"right": MAX_NUM_TRAIN},
        "ylim": {"top": 0.045},
        "y_major_step": None,
    },
    "pearsonr_all": {
        "ylabel": "Pearson Correlation",
        "log_axes": False,
        "xlim": {"left": -1e-3, "right": MAX_NUM_TRAIN},
        "ylim": {"bottom": 0, "top": 1.01},
        "y_major_step": 0.2,
    },
    "spearmanr_all": {
        "ylabel": "Spearman Correlation",
        "log_axes": False,
        "xlim": {"left": -1e-3, "right": MAX_NUM_TRAIN},
        "ylim": {"bottom": 0, "top": 1.01},
        "y_major_step": 0.2,
    },
    "sign_agreement_all": {
        "ylabel": "Sign Agreement",
        "log_axes": False,
        "xlim": {"left": -1e-3, "right": MAX_NUM_TRAIN},
        "ylim": {"bottom": 0, "top": 1.01},
        "y_major_step": 0.2,
    },
}


def _amortized_metric_curve(explainer_df, metric_name, method_type):
    """Mean test-split `metric_name` per (method_type, num_train) for the best checkpoints.

    Only the plotted column is averaged, so unrelated (possibly non-numeric)
    columns in `explainer_df` cannot break the aggregation.

    Returns:
        DataFrame with columns ``method_type``, ``num_train`` and ``metric_name``.
    """
    selected = explainer_df[
        (explainer_df["split"] == "test")
        & (explainer_df["method_type"] == method_type)
        & (explainer_df["is_best_checkpoint"] == "best")
    ]
    return selected.groupby(["method_type", "num_train"], as_index=False)[metric_name].mean()


fig = plt.figure(figsize=(len(PANEL_SPECS) * 4.3, 3))
outer_grid = gridspec.GridSpec(1, len(PANEL_SPECS), hspace=0.3)

axd = {}
for panel_idx, metric_name in enumerate(PANEL_SPECS):
    axd[(metric_name, METHOD_TYPE)] = fig.add_subplot(outer_grid[panel_idx])

for (metric_name, method_type), ax in axd.items():
    spec = PANEL_SPECS[metric_name]

    # No `palette`: without a `hue` variable seaborn ignores it (and warns).
    sns.lineplot(
        x="num_train",
        y=metric_name,
        marker='o',
        markersize=6,
        alpha=0.8,
        data=_amortized_metric_curve(metric_list_plot_explainer_df, metric_name, method_type),
        ax=ax,
    )
    ax.hlines(
        y=metric_list_plot_target_df[metric_name].mean(),
        xmin=0,
        xmax=MAX_NUM_TRAIN,
        color=KERNELSHAP_COLOR,
        label='KernelSHAP',
    )

    ax.set_xlabel("# Training Datapoints (Amortized)")
    ax.set_ylabel(spec["ylabel"])
    if spec["log_axes"]:
        ax.set_xscale("log")
        ax.set_yscale("log")

    ax.xaxis.grid(True, which='major', alpha=0.6)
    ax.xaxis.grid(True, which='minor', alpha=0.1)
    ax.set_xlim(**spec["xlim"])

    if spec["y_major_step"] is not None:
        ax.yaxis.set_major_locator(MultipleLocator(spec["y_major_step"]))
    ax.yaxis.grid(True, which='major', alpha=0.6)
    ax.yaxis.grid(True, which='minor', alpha=0.1)
    ax.set_ylim(**spec["ylim"])

    ax.tick_params(axis='both', which='major', rotation=0)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_title(prettify_metric_name(metric_name))

# Only the MSE panel carries a legend; proxy handles make its entries explicit.
axd[("mse_all", METHOD_TYPE)].legend(
    [
        Line2D([0], [0], color=AMORTIZED_COLOR, lw=1),
        Line2D([0], [0], color=KERNELSHAP_COLOR, lw=1, linestyle="-"),
    ],
    ['Amortized', 'KernelSHAP'],
    loc='best',
    bbox_to_anchor=(0, 0, 1, 0.85),
)

# NOTE(review): these global style resets run after the panels are built but before the
# inline backend renders the figure, so they persist into later cells and may also
# affect lazily-created parts of this figure (e.g. tick labels). Kept here to preserve
# the existing output.
sns.set_theme(style='whitegrid')
sns.set_context('paper', font_scale=1.2)

In [None]:
# Export the figure built above in raster (PNG) and vector (PDF) form, tightly cropped.
fig.savefig("logs/plots/shapley_compute_trainsamples_external_logscale.png", bbox_inches="tight")
fig.savefig("logs/plots/shapley_compute_trainsamples_external_logscale.pdf", bbox_inches="tight")

# Error, correlation, and sign agreement of the amortized explainers and KernelSHAP as a function of FLOPs.

In [None]:
# plt.rcParams['legend.fancybox'] = False
# plt.rcParams['legend.edgecolor']='1.0'
# plt.rcParams['legend.framealpha']=1
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor']='0.8'
plt.rcParams['legend.framealpha']=0.8

plt.rcParams['axes.prop_cycle']=plt.cycler(color=[(0.2980392156862745, 0.4470588235294118, 0.6901960784313725),
  (0.8666666666666667, 0.5176470588235295, 0.3215686274509804),
  (0.3333333333333333, 0.6588235294117647, 0.40784313725490196),
  (0.7686274509803922, 0.3058823529411765, 0.3215686274509804),
  (0,0,0)]) 

plt.rcParams['axes.prop_cycle']=plt.cycler(color=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136) 
                                                 for i in [512, 1024, 2048, 3072]]+[(0.8666666666666667, 0.5176470588235295, 0.3215686274509804)])

fig = plt.figure(figsize=(4*(4.3), 3)
                )

box1 = gridspec.GridSpec(1, 4, hspace=0.3)

axd={}
for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"
                              ]):
     
    for idx2, method_type in enumerate(["KernelSHAP"]):
        box2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.0)    

        ax=plt.Subplot(fig, box2[idx2])
        fig.add_subplot(ax)

        plot_key=(metric, method_type)
        axd[plot_key]=ax          
        

for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"
                              ]):
    for idx2, method_type in enumerate(["KernelSHAP"]):

        plot_key=(metric, method_type)
        
        if metric=="MSE_all":
            metric_list_plot_df_select=metric_list_plot_df_epoch[(metric_list_plot_df_epoch["split"]=="train")&\
                                                           (metric_list_plot_df_epoch["method_type"]==method_type)&\
                                                           (metric_list_plot_df_epoch["is_best_checkpoint"].fillna("before")=="before")\
                                                          ]

            sns.scatterplot(
            x="flops",
            y="mse_all",
            hue="method_name",
            hue_order=[
                     '512',
                     '1024',
                     '2048',
                     '3072',
                        'KernelSHAP',],      
            data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["mse_all"].mean().sort_index().reset_index().iloc[-1])
                                                                     ).reset_index(),
            ax=axd[plot_key],
                s=40,

            )     
            axd[plot_key].get_legend().remove()
            

                        
            
            sns.lineplot(
                x="flops",
                y="mse_all",
                hue="method_name",
                hue_order=[
                         '512',
                         '1024',
                         '2048',
                         '3072', "KernelSHAP"],    
                #style="antithetical",
                #palette="tab10",
                errorbar=None,                
                alpha=0.8,            
                linewidth=1.5,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )
            
            
           



            axd[plot_key].set_xlabel("FLOPs") #, fontsize=20
            axd[plot_key].set_ylabel("Error") #fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) # linewidth=2, 
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].ticklabel_format(axis='x',style='sci',useOffset=True)            
            axd[plot_key].set_xlim(1e+16, 2e+19)
            axd[plot_key].set_xscale('log')

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) # linewidth=2, 
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_ylim(top=0.08, bottom=0.00008)
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)   #labelsize=20
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_yscale('log')
            
            axd[plot_key].set_title('Error')

            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                
            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            #leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()
#             for line in leg.get_lines():
#                 line.set_linewidth(3.0)

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0)            
            
            leg=axd[plot_key].legend(loc='center', 
                                     bbox_to_anchor=(-3.0, -0.35, 3, 0),
                                     ncols=4,
                                     #bbox_to_anchor=(1.0, 0, 0.5, 1)
                                    )#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
        
            handles, labels = axd[plot_key].get_legend_handles_labels()
            
            axd[plot_key].add_artist(leg)
            
            leg=axd[plot_key].legend(handles=[handles[labels.index(i)] for i in ['512', '2048', 'KernelSHAP', '1024','3072']],  #512 2048, KernelSHAP, 1024, 3072
                                     #labels=['512 Samples', '1024 Samples', '2048 Samples', '3072 Samples','KernelSHAP'], 
                                     labels=['512', '2048', 'KernelSHAP', '1024', '3072'], 
                                 loc='upper left', 
                                 bbox_to_anchor=(0, 0.05, 0.5, 0.3))
#                                  columnspacing=1,
#                                  #loc='best',
#                                  #bbox_to_anchor=(0, 0, 1, 1),                                     
#                                  ncols=2,)
            leg.remove()
        
        


            
            
            from matplotlib.lines import Line2D
            import matplotlib.lines as mlines
            from matplotlib.legend_handler import HandlerBase

            class CustomHandler(HandlerBase):
                def create_artists(self, legend, orig_handle, x0, y0, width, height, fontsize, trans):
                    if orig_handle.get_color()==(0.8666666666666667, 0.5176470588235295, 0.3215686274509804):
                        line_o = mlines.Line2D([x0, x0 + width], [y0 + height/2., y0 + height/2.], 
                                               linestyle='-', color=orig_handle.get_color())                        
                        return [line_o]
                    else:
                        # Create a line with the 'o' marker
                        line_o = mlines.Line2D([x0, x0 + width], [y0 + height/2., y0 + height/2.], 
                                               linewidth=0.8,
                                               linestyle='-', color=orig_handle.get_color())
                        # Create a line with the 'X' marker
                        marker_o = mlines.Line2D([x0 + width], [y0 + height/2.], 
                                               linestyle='', color=orig_handle.get_color(), marker='o', markersize=5)
#                         marker_x = mlines.Line2D([x0], [y0 + height/2.], 
#                                                linestyle='', color=orig_handle.get_color(), marker='X', markersize=5)                    
                        return [line_o, marker_o]            

            custom_lines = [Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][0], lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][2], lw=1),
                            Line2D([0], [0], color=(0.8666666666666667, 0.5176470588235295, 0.3215686274509804), lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1], lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][3], lw=1),
                            ]

            axd[plot_key].legend(custom_lines, ['512', '2048','KernelSHAP', '1024', '3072', ], 
                                 ncols=2,
                                 loc='upper left', 
                                 handler_map={mlines.Line2D: CustomHandler()}, 
                                 #bbox_to_anchor=(-0.6, -1.3, 0.5, 1)
                                 bbox_to_anchor=(0, 0.05, 0.5, 0.3),
                                 columnspacing=0.5,
                                )   
        
        
        
        elif metric=="pearsonr_all":
            
            metric_list_plot_df_select=metric_list_plot_df_epoch[(metric_list_plot_df_epoch["split"]=="train")&\
                                                           (metric_list_plot_df_epoch["method_type"]==method_type)&\
                                                           (metric_list_plot_df_epoch["is_best_checkpoint"].fillna("before").isin(["before", "best"]))\
                                                          ]
            
            
            sns.scatterplot(
            x="flops",
            y="pearsonr_all",
            hue="method_name",
            hue_order=[
                     '512',
                     '1024',
                     '2048',
                     '3072',
                        'KernelSHAP',],      
            data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["pearsonr_all"].mean().sort_index().reset_index().iloc[-1])
                                                                     ).reset_index(),
            ax=axd[plot_key],
                s=40,

            )     
            axd[plot_key].get_legend().remove()
            
           

            sns.lineplot(
                x="flops",
                y="pearsonr_all",
                hue="method_name",
                hue_order=[
                         '512',
                         '1024',
                         '2048',
                         '3072',
                            'KernelSHAP',],                  
                #style="antithetical",
                #palette="tab10",
                errorbar=None,                
                alpha=0.8,            
                linewidth=1.5,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )
            



            axd[plot_key].set_xlabel("FLOPs") #, fontsize=20
            axd[plot_key].set_ylabel("Pearson Correlation") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].ticklabel_format(axis='x',style='sci',useOffset=True)            
            axd[plot_key].set_xlim(1e+16, 2e+19)
            axd[plot_key].set_xscale('log')

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #, labelsize=20
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, 1.01)
            #axd[plot_key].set_yscale('log')
            
            axd[plot_key].set_title('Correlation')
            
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                
            

            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            #leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()
            
            

            

        elif metric=="spearmanr_all":
            
            metric_list_plot_df_select=metric_list_plot_df_epoch[(metric_list_plot_df_epoch["split"]=="train")&\
                                                           (metric_list_plot_df_epoch["method_type"]==method_type)&\
                                                           (metric_list_plot_df_epoch["is_best_checkpoint"].fillna("before").isin(["before", "best"]))\
                                                          ]
            
            
            sns.scatterplot(
            x="flops",
            y="spearmanr_all",
            hue="method_name",
            hue_order=[
                     '512',
                     '1024',
                     '2048',
                     '3072',
                        'KernelSHAP',],      
            data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["spearmanr_all"].mean().sort_index().reset_index().iloc[-1])
                                                                     ).reset_index(),
            ax=axd[plot_key],
                s=40,

            )     
            axd[plot_key].get_legend().remove()
                      

            sns.lineplot(
                x="flops",
                y="spearmanr_all",
                hue="method_name",
                hue_order=[
                         '512',
                         '1024',
                         '2048',
                         '3072',
                            'KernelSHAP',],                  
                #style="antithetical",
                #palette="tab10",
                errorbar=None,                
                alpha=0.8,            
                linewidth=1.5,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )
            



            axd[plot_key].set_xlabel("FLOPs") #, fontsize=20
            axd[plot_key].set_ylabel("Spearman Correlation") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].ticklabel_format(axis='x',style='sci',useOffset=True)            
            axd[plot_key].set_xlim(1e+16, 2e+19)
            axd[plot_key].set_xscale('log')

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #, labelsize=20
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, 1.01)
            #axd[plot_key].set_yscale('log')
            
            axd[plot_key].set_title('Rank Correlation')
            
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                
            

            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            #leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()
            
            
            
        elif metric=="sign_agreement_all":
            
            metric_list_plot_df_select=metric_list_plot_df_epoch[(metric_list_plot_df_epoch["split"]=="train")&\
                                                           (metric_list_plot_df_epoch["method_type"]==method_type)&\
                                                           (metric_list_plot_df_epoch["is_best_checkpoint"].fillna("before").isin(["before", "best"]))\
                                                          ]
            
            
            sns.scatterplot(
            x="flops",
            y="sign_agreement_all",
            hue="method_name",
            hue_order=[
                     '512',
                     '1024',
                     '2048',
                     '3072',
                        'KernelSHAP',],      
            data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["sign_agreement_all"].mean().sort_index().reset_index().iloc[-1])
                                                                     ).reset_index(),
            ax=axd[plot_key],
                s=40,

            )     
            axd[plot_key].get_legend().remove()
            
          

            sns.lineplot(
                x="flops",
                y="sign_agreement_all",
                hue="method_name",
                hue_order=[
                         '512',
                         '1024',
                         '2048',
                         '3072',
                            'KernelSHAP',],                  
                #style="antithetical",
                #palette="tab10",
                errorbar=None,                
                alpha=0.8,            
                linewidth=1.5,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )
            



            axd[plot_key].set_xlabel("FLOPs") #, fontsize=20
            axd[plot_key].set_ylabel("Sign Agreement") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].ticklabel_format(axis='x',style='sci',useOffset=True)            
            axd[plot_key].set_xlim(1e+16, 2e+19)
            axd[plot_key].set_xscale('log')

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #, labelsize=20
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, 1.01)
            #axd[plot_key].set_yscale('log')
            
            axd[plot_key].set_title('Sign Agrement')
            
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                
            

            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            #leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()            



            
#             for line in leg.get_lines():
#                 line.set_linewidth(3.0)   
# Apply seaborn's 'whitegrid' style, then the 'paper' plotting context with fonts scaled by 1.2, globally.
# NOTE(review): this runs after the plotting loop above has already drawn its axes, so it presumably
# only affects figures created in later cells, not the one just drawn — confirm this placement is intended.
sns.set_theme(style='whitegrid')
sns.set_context('paper', font_scale=1.2)

In [None]:
# Export the FLOPs-matched appendix figure drawn above as both PNG and PDF.
# `fig` is whatever figure the preceding plotting cell left bound to that name.
PLOT_DIR = "logs/plots"
# savefig does not create missing parent directories, so make sure the output folder
# exists (keeps Restart & Run All working from a fresh checkout).
os.makedirs(PLOT_DIR, exist_ok=True)
for ext in ("png", "pdf"):
    fig.savefig(f"{PLOT_DIR}/flops_matched_appendix.{ext}", bbox_inches='tight')

In [None]:
# Legend frame styling shared by every panel of this figure.
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor'] = '0.8'
plt.rcParams['legend.framealpha'] = 0.8

# Orange used for the KernelSHAP entry (last position in the hue order / color cycle).
KERNELSHAP_COLOR = (0.8666666666666667, 0.5176470588235295, 0.3215686274509804)

# Color cycle: one shade of Blues per amortized sample budget (512, 1024, 2048, 3072),
# followed by the KernelSHAP orange. Earlier drafts assigned this cycle several times in
# a row; only the last assignment ever took effect, so only that one is kept.
plt.rcParams['axes.prop_cycle'] = plt.cycler(
    color=[
        Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136)
        for i in [512, 1024, 2048, 3072]
    ] + [KERNELSHAP_COLOR]
)

fig = plt.figure(figsize=(4 * 4.3, 3))

# One column per metric panel; each column holds a 1x1 sub-grid with one axes per method type.
box1 = gridspec.GridSpec(1, 4, hspace=0.3)

axd = {}
for idx1, metric in enumerate(["epoch_MSE_all", "epoch_pearsonr_all", "trainsamples_MSE_all", "trainsamples_pearsonr_all"
                              ]):
    # The sub-grid depends only on the metric column, so build it once per column.
    box2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.0)

    for idx2, method_type in enumerate(["KernelSHAP"]):
        # fig.add_subplot(SubplotSpec) creates and registers the Axes in one step.
        ax = fig.add_subplot(box2[idx2])

        plot_key = (metric, method_type)
        axd[plot_key] = ax

for idx1, metric in enumerate(["epoch_MSE_all", "epoch_pearsonr_all", "trainsamples_MSE_all", "trainsamples_pearsonr_all"
                              ]):
    for idx2, method_type in enumerate(["KernelSHAP"]):

        plot_key=(metric, method_type)
        
        if metric=="epoch_MSE_all":
            metric_list_plot_df_select=metric_list_plot_df_epoch[(metric_list_plot_df_epoch["split"]=="train")&\
                                                           (metric_list_plot_df_epoch["method_type"]==method_type)&\
                                                           (metric_list_plot_df_epoch["is_best_checkpoint"].fillna("before")=="before")\
                                                          ]

            sns.scatterplot(
            x="flops",
            y="mse_all",
            hue="method_name",
            hue_order=[
                     '512',
                     '1024',
                     '2048',
                     '3072',
                        'KernelSHAP',],      
            data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["mse_all"].mean().sort_index().reset_index().iloc[-1])
                                                                     ).reset_index(),
            ax=axd[plot_key],
                s=40,

            )     
            axd[plot_key].get_legend().remove()
                        
            
            sns.lineplot(
                x="flops",
                y="mse_all",
                hue="method_name",
                hue_order=[
                         '512',
                         '1024',
                         '2048',
                         '3072', "KernelSHAP"],    
                #style="antithetical",
                #palette="tab10",
                errorbar=None,                
                alpha=0.8,            
                linewidth=1.5,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )
            
            
           



            axd[plot_key].set_xlabel("FLOPs") #, fontsize=20
            axd[plot_key].set_ylabel("Error") #fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) # linewidth=2, 
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].ticklabel_format(axis='x',style='sci',useOffset=True)            
            axd[plot_key].set_xlim(1e+16, 2e+19)
            axd[plot_key].set_xscale('log')

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) # linewidth=2, 
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_ylim(top=0.08, bottom=0.00008)
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)   #labelsize=20
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_yscale('log')
            
            axd[plot_key].set_title('Error')

            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                
            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            #leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()
#             for line in leg.get_lines():
#                 line.set_linewidth(3.0)

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0)            
            
            leg=axd[plot_key].legend(loc='center', 
                                     bbox_to_anchor=(-3.0, -0.35, 3, 0),
                                     ncols=4,
                                     #bbox_to_anchor=(1.0, 0, 0.5, 1)
                                    )#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
             
            handles, labels = axd[plot_key].get_legend_handles_labels()
            
#             leg=axd[plot_key].legend(handles=[handles[labels.index(i)] for i in ['512', '1024', '2048', '3072']], 
#                                      labels=['512', '1024', '2048', '3072'], 
#                                  loc='upper left', 
#                                  bbox_to_anchor=(0,0,1,1),
#                                  ncols=1,) 
            #leg.set_title("Amortized (# Samples / Point)")
            axd[plot_key].add_artist(leg)
            # Adding the text to the left of the first legend
#             x_offset = -2.6  # Adjust this value as needed to position the text
#             y_offset = -0.25  # Adjust this value as needed for vertical positioning
#             axd[plot_key].text(x_offset, y_offset, "Amortized (# Samples / Point)", transform=axd[plot_key].transAxes, 
#                                verticalalignment='top', horizontalalignment='left')
            
            leg=axd[plot_key].legend(handles=[handles[labels.index(i)] for i in ['512', '2048', 'KernelSHAP', '1024','3072']],  #512 2048, KernelSHAP, 1024, 3072
                                     #labels=['512 Samples', '1024 Samples', '2048 Samples', '3072 Samples','KernelSHAP'], 
                                     labels=['512', '2048', 'KernelSHAP', '1024', '3072'], 
                                 loc='upper left', 
                                 bbox_to_anchor=(0, 0.05, 0.5, 0.3))
#                                  columnspacing=1,
#                                  #loc='best',
#                                  #bbox_to_anchor=(0, 0, 1, 1),                                     
#                                  ncols=2,)
            leg.remove()
        
        


            
            
            from matplotlib.lines import Line2D
            import matplotlib.lines as mlines
            from matplotlib.legend_handler import HandlerBase

            class CustomHandler(HandlerBase):
                """Legend handler that draws KernelSHAP as a plain line and other entries as a line plus dot.

                A handle whose color matches ``kernelshap_color`` is drawn as a plain solid line
                at the default line width. Every other handle is drawn as a thin (0.8) solid line
                with a size-5 ``'o'`` marker at its right end.
                """

                # Color that identifies the KernelSHAP legend entry.
                kernelshap_color = (0.8666666666666667, 0.5176470588235295, 0.3215686274509804)

                def create_artists(self, legend, orig_handle, x0, y0, width, height, fontsize, trans):
                    # same_color compares after RGBA conversion, so hex/named/RGBA spellings of
                    # the same color also match (plain tuple equality would not).
                    from matplotlib.colors import same_color

                    color = orig_handle.get_color()
                    y_mid = y0 + height / 2.

                    if same_color(color, self.kernelshap_color):
                        # KernelSHAP: plain horizontal line, no marker.
                        line = mlines.Line2D([x0, x0 + width], [y_mid, y_mid],
                                             linestyle='-', color=color)
                        return [line]

                    # Every other entry: thin horizontal line across the handle box...
                    line = mlines.Line2D([x0, x0 + width], [y_mid, y_mid],
                                         linewidth=0.8,
                                         linestyle='-', color=color)
                    # ...plus an 'o' marker at its right end.
                    marker = mlines.Line2D([x0 + width], [y_mid],
                                           linestyle='', color=color, marker='o', markersize=5)
                    return [line, marker]

            custom_lines = [Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][0], lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][2], lw=1),
                            Line2D([0], [0], color=(0.8666666666666667, 0.5176470588235295, 0.3215686274509804), lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1], lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][3], lw=1),
                            ]

            axd[plot_key].legend(custom_lines, ['512', '2048','KernelSHAP', '1024', '3072', ], 
                                 ncols=2,
                                 loc='upper left', 
                                 handler_map={mlines.Line2D: CustomHandler()}, 
                                 #bbox_to_anchor=(-0.6, -1.3, 0.5, 1)
                                 bbox_to_anchor=(0, 0.05, 0.5, 0.3),
                                )            

        elif metric=="epoch_pearsonr_all":
            
            metric_list_plot_df_select=metric_list_plot_df_epoch[(metric_list_plot_df_epoch["split"]=="train")&\
                                                           (metric_list_plot_df_epoch["method_type"]==method_type)&\
                                                           (metric_list_plot_df_epoch["is_best_checkpoint"].fillna("before").isin(["before", "best"]))\
                                                          ]
            
            
            sns.scatterplot(
            x="flops",
            y="pearsonr_all",
            hue="method_name",
            hue_order=[
                     '512',
                     '1024',
                     '2048',
                     '3072',
                        'KernelSHAP',],      
            data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["pearsonr_all"].mean().sort_index().reset_index().iloc[-1])
                                                                     ).reset_index(),
            ax=axd[plot_key],
                s=40,

            )     
            axd[plot_key].get_legend().remove()
                   

            sns.lineplot(
                x="flops",
                y="pearsonr_all",
                hue="method_name",
                hue_order=[
                         '512',
                         '1024',
                         '2048',
                         '3072',
                            'KernelSHAP',],                  
                #style="antithetical",
                #palette="tab10",
                errorbar=None,                
                alpha=0.8,            
                linewidth=1.5,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )
            



            axd[plot_key].set_xlabel("FLOPs") #, fontsize=20
            axd[plot_key].set_ylabel("Pearson Correlation") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].ticklabel_format(axis='x',style='sci',useOffset=True)            
            axd[plot_key].set_xlim(1e+16, 2e+19)
            axd[plot_key].set_xscale('log')

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #, labelsize=20
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, 1.01)
            #axd[plot_key].set_yscale('log')
            
            axd[plot_key].set_title('Correlation')
            
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                
            

            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            #leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()



        elif metric=="trainsamples_MSE_all":
#             plt.rcParams['axes.prop_cycle']=plt.cycler(color=\
#                 [(0,0,1),
#                  (0,0,0)])        
            plt.rcParams['axes.prop_cycle']=plt.cycler(color=[sns.color_palette("Blues")[i] for i in [4]]+\
                                                       #[(0,0,0)]
                                                       [(0.8666666666666667, 0.5176470588235295, 0.3215686274509804)]
                                                      ) 
    
            plt.rcParams['axes.prop_cycle']=plt.cycler(color=[Blue_scalar_color_mapping(2257, color_map=plt.cm.Blues, data_min=-500, data_max=3136)]+\
                                                       [(0.8666666666666667, 0.5176470588235295, 0.3215686274509804)]
                                                      )     
    
    
            
            
            metric_list_plot_df_select=metric_list_plot_explainer_df[(metric_list_plot_explainer_df["split"]=="train")\
                              &(metric_list_plot_explainer_df["method_type"]==method_type)\
                              &(metric_list_plot_explainer_df["is_best_checkpoint"]=="best")
                             ].groupby(["method_type", "num_train"])\
                                                    [['sample_idx',
                                                    'mse_target',
                                                    'mse_nontarget',
                                                    'mse_all',
                                                    'pearsonr_target',
                                                    'pearsonr_all',
                                                    'pearsonr_all_per_class',
                                                    'spearmanr_target',
                                                    'spearmanr_all',
                                                    'spearmanr_all_per_class',]
                                                    ].mean().reset_index()
            sns.lineplot(
                x="num_train", 
                y="mse_all", 
                # hue="num_subsets",
                marker='o', 
                markersize=6,  
                alpha=0.8,
                color=plt.rcParams['axes.prop_cycle'].by_key()['color'][0],
                #linewidth=3,
                #palette="Set2",
                data=metric_list_plot_df_select,
                ax=axd[plot_key],
            )
            
            
            axd[plot_key].hlines(xmin=0, xmax=5000, 
                                 y=metric_list_plot_target_df[['sample_idx',
                                    'mse_target',
                                    'mse_nontarget',
                                    'mse_all',
                                    'pearsonr_target',
                                    'pearsonr_all',
                                    'pearsonr_all_per_class',
                                    'spearmanr_target',
                                    'spearmanr_all',
                                    'spearmanr_all_per_class']].mean().loc["mse_all"], 
                                 linestyle="-",
                                 #linewidth=3, 
                                 color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1],
                                 label=f'KernelSHAP'
                                )            
            

            axd[plot_key].set_xlabel("# Training Datapoints (Amortized)") #, fontsize=20
            axd[plot_key].set_ylabel("Error") #, fontsize=20
            
            axd[plot_key].set_xscale("log")
            axd[plot_key].set_yscale("log")

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].xaxis.grid(True, which='minor', alpha=0.1) #linewidth=1, 
            axd[plot_key].set_xlim(left=-1e-3)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].yaxis.grid(True, which='minor', alpha=0.1) #linewidth=1, 
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, .045)
            #axd[plot_key].set_yscale("log")

            axd[plot_key].tick_params(axis='x', which='major', rotation=0,  labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  #, labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)

            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            leg.remove()

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
                
            axd[plot_key].set_title(prettify_metric_name("mse_all"))#, fontsize=20)
        
            from matplotlib.lines import Line2D

            custom_lines = [Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][0], lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1], lw=1, linestyle="-")]

            axd[plot_key].legend(custom_lines, ['Amortized', 'KernelSHAP'], 
                                 #ncols=2,
                                 loc='best', 
                                 #bbox_to_anchor=(-0.6, -1.3, 0.5, 1)
                                 bbox_to_anchor=(0, 0, 1, 0.85)
                                )
            #bbox_to_anchor=(-0.8, -0.32, 3, 0),           


        elif metric=="trainsamples_pearsonr_all":
            metric_list_plot_df_select=metric_list_plot_explainer_df[(metric_list_plot_explainer_df["split"]=="train")\
                              &(metric_list_plot_explainer_df["method_type"]==method_type)\
                              &(metric_list_plot_explainer_df["is_best_checkpoint"]=="best")
                             ].groupby(["method_type", "num_train"])\
                                                    [['sample_idx',
                                                    'mse_target',
                                                    'mse_nontarget',
                                                    'mse_all',
                                                    'pearsonr_target',
                                                    'pearsonr_all',
                                                    'pearsonr_all_per_class',
                                                    'spearmanr_target',
                                                    'spearmanr_all',
                                                    'spearmanr_all_per_class']
                                                    ].mean().reset_index()            
        

            sns.lineplot(
                x="num_train", 
                y="pearsonr_all", 
                # hue="num_subsets",
                marker='o', 
                markersize=6,  
                alpha=0.8,
                #linewidth=3,
                #palette="Set2",
                color=plt.rcParams['axes.prop_cycle'].by_key()['color'][0],
                data=metric_list_plot_df_select,
                ax=axd[plot_key],
            )

            
            axd[plot_key].hlines(xmin=0, xmax=5000, 
                                 y=metric_list_plot_target_df[['sample_idx',
                                    'mse_target',
                                    'mse_nontarget',
                                    'mse_all',
                                    'pearsonr_target',
                                    'pearsonr_all',
                                    'pearsonr_all_per_class',
                                    'spearmanr_target',
                                    'spearmanr_all',
                                    'spearmanr_all_per_class']].mean().loc["pearsonr_all"], 
                                 linestyle="-",                                 
                                 #linewidth=3, 
                                 color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1],
                                 label=f'KernelSHAP'
                                )               

            axd[plot_key].set_xlabel("# Training Datapoints (Amortized)") #fontsize=20
            axd[plot_key].set_ylabel("Pearson Correlation") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].xaxis.grid(True, which='minor',alpha=0.1) # linewidth=1, 
            axd[plot_key].set_xlim(left=-1e-3)
            #axd[plot_key].set_xscale("log")

            axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].yaxis.grid(True, which='minor', alpha=0.1) #linewidth=2, 
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, 1.01)
            #axd[plot_key].set_yscale("log")

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            axd[plot_key].set_title(prettify_metric_name("pearsonr_all"))#, fontsize=20)
        
         #
            
#             for line in leg.get_lines():
#                 line.set_linewidth(3.0)   
# Set the global seaborn style (white grid, 'paper' context scaled by 1.2) for plotting code
# that runs after this point. NOTE(review): set_theme also resets the palette/prop_cycle to
# seaborn's defaults, so later cells that rely on a custom axes.prop_cycle must set it again.
sns.set_theme(style='whitegrid')
sns.set_context('paper', font_scale=1.2)

In [None]:
# Save the current figure as both PNG and PDF. The output folder is created first
# so the cell also works on a fresh checkout where logs/plots does not exist yet.
os.makedirs("logs/plots", exist_ok=True)
for ext in ("png", "pdf"):
    fig.savefig(f"logs/plots/shapley_compute.{ext}", bbox_inches='tight')

# Comparison of stochastic amortization, FastSHAP, and KernelSHAP as a function of total FLOPs

In [None]:
# KernelSHAP convergence baseline on the training split: the same saved KernelSHAP run is
# passed as both the reference (iters_ground_truth=999424) and the estimate
# (iters_calculated=num_subsets), sweeping num_subsets from 512 up to 512 * 1000.
metric_list_ground_truth_flops = []

kernelshap_train_path = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
kernelshap_train_values = shapley_loaded_dict[kernelshap_train_path]

for num_subsets in (512 * multiple for multiple in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 40, 80, 100, 200, 400, 800, 1000]):
    metric_list_ground_truth_flops.extend(
        get_ground_truth_metric_with_value(
            attribution_values_ground_truth=kernelshap_train_values,
            iters_ground_truth=999424,
            attribution_values_calculated=kernelshap_train_values,
            iters_calculated=num_subsets,
            meta_info={
                "num_subsets": num_subsets,
                "true_name": kernelshap_train_path,
                "estimated_name": kernelshap_train_path,
            },
            # Reference and estimate share one key set, so the "intersection" is just that set.
            ground_truth_key_select=set(kernelshap_train_values.keys()),
        )
    )

In [None]:
# FLOP cost model applied in the next cell. The Amortized (permutation) runs use
# 21_335_153_664 instead of 38_898_221_568 for the explainer forward/backward passes.
# flops_subset_eval = 17_563_067_904 * num_train_sample * num_subsets
# flops_forward = 38_898_221_568 * 1 * metric_temp["epoch"] * num_train_sample
# flops_backward = 38_898_221_568 * 2 * metric_temp["epoch"] * num_train_sample    
# flops_parameter_update = 104_730_000 * metric_temp["epoch"] * (num_train_sample//64) * (2+3+4+3+3+4) # need to verify

In [None]:
# Attach a method label, split and total FLOPs to every metric record, then collect all
# records in one DataFrame for the FLOPs-comparison figure. Every record must either be
# handled or match an explicit skip rule; anything else is printed and raises.
num_train_sample=9469

# FLOP constants of the cost model (see the previous cell).
FLOPS_SUBSET_EVAL = 17_563_067_904           # cost of one subset evaluation per training sample
FLOPS_REG_EXPLAINER_PASS = 21_335_153_664    # explainer pass cost used for the Amortized runs
FLOPS_OBJ_EXPLAINER_PASS = 38_898_221_568    # explainer pass cost used for the Obj-AO run
FLOPS_PARAM_UPDATE = 104_730_000
PARAM_UPDATE_MULTIPLIER = 2 + 3 + 4 + 3 + 3 + 4  # need to verify
UPDATE_BATCH_SIZE = 64

TRUE_NAME_TRAIN = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
TRUE_NAME_TEST = "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"
TRUE_NAMES_REGULAR = [TRUE_NAME_TRAIN, TRUE_NAME_TEST]
TRUE_NAMES_ANTITHETICAL = [
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
    "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test",
]

AMORTIZED_MODEL_PATHS = [f"logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_{n}"
                         for n in [196, 392, 588, 1176, 3136]]
OBJ_AO_MODEL_PATH = "logs/vitbase_imagenette_shapley_objexplainer_newsample_32"

# Known runs deliberately left out of this figure, as (model paths, accepted true_name values).
SKIPPED_RUNS = [
    ([f"logs/vitbase_imagenette_shapley_regexplainer_antithetical_upfront_{n}" for n in [512, 3072]],
     TRUE_NAMES_REGULAR),
    ([f"logs/vitbase_imagenette_shapley_regexplainer_permutation_newsample_{n}" for n in [196, 392, 588, 1176, 3136]],
     TRUE_NAMES_ANTITHETICAL),
    ([OBJ_AO_MODEL_PATH], TRUE_NAMES_ANTITHETICAL),
    (["logs/vitbase_imagenette_shapley_objexplainer_antithetical_newsample_32"], TRUE_NAMES_ANTITHETICAL),
    (["/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_shapley_regexplainer_SGD_antithetical_upfront_9986"],
     TRUE_NAMES_REGULAR),
    (["logs/vitbase_imagenette_shapley_regexplainer_upfront_1024_numtrain_4735",
      "logs/vitbase_imagenette_shapley_regexplainer_upfront_2048_numtrain_2367",
      "logs/vitbase_imagenette_shapley_regexplainer_upfront_3072_numtrain_1578"],
     TRUE_NAMES_REGULAR),
]
# Reg-AO runs are also excluded from this figure; they are only tolerated in
# metric_list_shapley (in metric_list_shapley_obj they are treated as unexpected).
SKIPPED_RUNS_REG_AO = [
    ([f"logs/vitbase_imagenette_shapley_regexplainer_upfront_{n}" for n in [512, 1024, 2048, 3072]],
     TRUE_NAMES_REGULAR),
]


def _matches_any(record, rules):
    """Return True if the record's (model_path, true_name) pair is covered by one of `rules`."""
    return any(record["model_path"] in paths and record["true_name"] in true_names
               for paths, true_names in rules)


def _fail_unexpected(record):
    """Print the offending record and raise: every record must be handled or explicitly skipped."""
    print(record)
    raise RuntimeError("unexpected metric record (printed above); add a handler or a skip rule")


def _split_from_true_name(record):
    """Return "train"/"test" according to the ground-truth output the record was scored against."""
    if record["true_name"] == TRUE_NAME_TRAIN:
        return "train"
    if record["true_name"] == TRUE_NAME_TEST:
        return "test"
    _fail_unexpected(record)


def _amortized_flops(num_subsets, epoch, explainer_pass_flops, subset_eval_epochs=1):
    """Total FLOPs = subset evaluations + explainer forward (1x) and backward (2x) passes
    + parameter updates, all over `num_train_sample` training samples.

    `subset_eval_epochs` multiplies the subset-evaluation term (1 when subsets are
    evaluated once up front).
    """
    flops_subset_eval = FLOPS_SUBSET_EVAL * num_train_sample * num_subsets * subset_eval_epochs
    flops_forward = explainer_pass_flops * 1 * epoch * num_train_sample
    flops_backward = explainer_pass_flops * 2 * epoch * num_train_sample
    flops_parameter_update = FLOPS_PARAM_UPDATE * epoch * (num_train_sample//UPDATE_BATCH_SIZE) * PARAM_UPDATE_MULTIPLIER
    return flops_subset_eval + flops_forward + flops_backward + flops_parameter_update


metric_list_plot=[]

# KernelSHAP baseline: its only cost is evaluating the sampled subsets.
for metric in metric_list_ground_truth_flops:
    metric_temp=copy.copy(metric)
    if not (metric_temp["estimated_name"]==TRUE_NAME_TRAIN and metric_temp["true_name"]==TRUE_NAME_TRAIN):
        _fail_unexpected(metric_temp)
    metric_temp.update(
        {"method_name": 'KernelSHAP',
         "method_type": 'KernelSHAP',
         "antithetical": False,
         "split": "train",
         "flops": FLOPS_SUBSET_EVAL * metric_temp["num_subsets"] * num_train_sample,
        }
    )
    metric_list_plot.append(metric_temp)

# Amortized explainers (method_type 'Permutation'), one run per number of subsets.
for metric in metric_list_shapley:
    metric_temp=copy.copy(metric)
    if metric_temp["model_path"] in AMORTIZED_MODEL_PATHS and metric_temp["true_name"] in TRUE_NAMES_REGULAR:
        num_subsets=int(metric_temp["model_path"].split('_')[-1])
        split=_split_from_true_name(metric_temp)
        metric_temp.update(
            {"method_name": f'Amortized ({num_subsets})',
             "method_type": 'Permutation',
             "antithetical": False,
             "split": split,
             "flops": _amortized_flops(num_subsets, metric_temp["epoch"], FLOPS_REG_EXPLAINER_PASS),
            }
        )
    elif _matches_any(metric_temp, SKIPPED_RUNS + SKIPPED_RUNS_REG_AO):
        continue
    else:
        _fail_unexpected(metric_temp)
    metric_list_plot.append(metric_temp)

# Obj-AO explainer.
for metric in metric_list_shapley_obj:
    metric_temp=copy.copy(metric)
    if metric_temp["model_path"]==OBJ_AO_MODEL_PATH and metric_temp["true_name"] in TRUE_NAMES_REGULAR:
        num_subsets=int(metric_temp["model_path"].split('_')[-1])
        split=_split_from_true_name(metric_temp)
        metric_temp.update(
            {"method_name": 'Obj-AO (KernelSHAP)',
             "method_type": 'KernelSHAP',
             "antithetical": False,
             "split": split,
             # Subset-evaluation cost is multiplied by epochs here (run name says "newsample",
             # presumably fresh subsets every epoch).
             "flops": _amortized_flops(num_subsets, metric_temp["epoch"], FLOPS_OBJ_EXPLAINER_PASS,
                                       subset_eval_epochs=metric_temp["epoch"]),
            }
        )
    elif _matches_any(metric_temp, SKIPPED_RUNS):
        continue
    else:
        _fail_unexpected(metric_temp)
    metric_list_plot.append(metric_temp)


metric_list_plot_df=pd.DataFrame(metric_list_plot)
metric_list_plot_df_epoch_reg_obj=metric_list_plot_df

In [None]:
# Number of metric records contributed by each method label.
metric_list_plot_df.loc[:, "method_name"].value_counts()

In [None]:
# Records from the Obj-AO run; display only the rows flagged as the best checkpoint.
is_obj_ao = metric_list_plot_df["method_name"].eq("Obj-AO (KernelSHAP)")
metric_list_plot_df_filtered = metric_list_plot_df.loc[is_obj_ao]

metric_list_plot_df_filtered.loc[metric_list_plot_df_filtered["is_best_checkpoint"].eq("best")]

In [None]:
# Preview the blue shades for the Amortized subset counts. Evaluated explicitly over the
# counts instead of relying on a loop variable left over from another cell, so the cell
# also works on a fresh kernel.
[Blue_scalar_color_mapping(n, color_map=plt.cm.Blues, data_min=-500, data_max=3136)
 for n in [196, 392, 588, 1176, 3136]]

In [None]:
# Legend styling shared by every panel of the FLOPs figure.
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor']='0.8'
plt.rcParams['legend.framealpha']=0.8

# Color cycle: one blue shade per Amortized run (196 ... 3136 subsets), followed by the
# orange used for the KernelSHAP curve.
plt.rcParams['axes.prop_cycle']=plt.cycler(
    color=[Blue_scalar_color_mapping(n, color_map=plt.cm.Blues, data_min=-500, data_max=3500)
           for n in [196, 392, 588, 1176, 3136]]
          + [(0.8666666666666667, 0.5176470588235295, 0.3215686274509804)]
)

# One panel per metric, laid out in a single row; axd maps (metric, method_type) -> Axes.
fig = plt.figure(figsize=(4*(4.3), 3))
box1 = gridspec.GridSpec(1, 4, hspace=0.3)

axd={}
for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]):
    # The nested spec depends only on the panel, so build it once per panel.
    box2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.0)
    for idx2, method_type in enumerate(["Permutation"]):
        # add_subplot accepts a SubplotSpec directly; no need to construct the Axes by hand.
        ax=fig.add_subplot(box2[idx2])
        plot_key=(metric, method_type)
        axd[plot_key]=ax

for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"
                              ]):
    for idx2, method_type in enumerate(["Permutation"]):

        plot_key=(metric, method_type)
        
        if metric=="MSE_all":
            metric_list_plot_df_select=metric_list_plot_df_epoch_reg_obj[(metric_list_plot_df_epoch_reg_obj["split"]=="train")&\
                                                           (metric_list_plot_df_epoch_reg_obj["method_type"]==method_type)&\
                                                           (metric_list_plot_df_epoch_reg_obj["is_best_checkpoint"].fillna("before")=="best")\
                                                          ]
            
                    

            sns.scatterplot(
            x="flops",
            y="mse_all",
            hue="method_name",
            hue_order=[
                     'Amortized (196)',
                     'Amortized (392)',
                     'Amortized (588)',
                     'Amortized (1176)',
                    'Amortized (3136)'],      
            data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["mse_all"].mean().sort_index().reset_index().iloc[-1])
                                                                     ).reset_index(),
            ax=axd[plot_key],
                s=40,

            )     
            axd[plot_key].get_legend().remove()
            
         
            
            metric_list_plot_df_select=metric_list_plot_df_epoch_reg_obj[(metric_list_plot_df_epoch_reg_obj["split"]=="train")&\
                                                           (metric_list_plot_df_epoch_reg_obj["method_name"]=="KernelSHAP") \
                                                           #(metric_list_plot_df_epoch_reg_obj["is_best_checkpoint"].fillna("before")"before")\
                                                          ] 
            
            
            
            sns.lineplot(
                x="flops",
                y="mse_all",
#                 hue="method_name",
#                 hue_order=["KernelSHAP"],  
                color=(0.8666666666666667, 0.5176470588235295, 0.3215686274509804),
                #style="antithetical",
                #palette="tab10",
                errorbar=None,                
                alpha=0.8,            
                linewidth=1.5,
                data=metric_list_plot_df_select[metric_list_plot_df_select["flops"]<1e+18],
                ax=axd[plot_key]
            )    
            
            metric_list_plot_df_select=metric_list_plot_df_epoch_reg_obj[(metric_list_plot_df_epoch_reg_obj["split"]=="train")&\
                                                           (metric_list_plot_df_epoch_reg_obj["method_name"]=="Obj-AO (KernelSHAP)")
                                                           #(metric_list_plot_df_epoch_reg_obj["is_best_checkpoint"].fillna("before")=="before")\
                                                          ] 
            
            
            sns.lineplot(
                x="flops",
                y="mse_all",
#                 hue="method_name",
#                 hue_order=["Obj-AO (KernelSHAP)"],    
                color=(0.7,0.7, 0.7),
                #style="antithetical",
                #palette="tab10",
                errorbar=None,                
                alpha=0.8,            
                linewidth=1.5,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )                
            

            axd[plot_key].set_xlabel("FLOPs") #, fontsize=20
            axd[plot_key].set_ylabel("Error") #fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) # linewidth=2, 
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].ticklabel_format(axis='x',style='sci',useOffset=True)            
            #axd[plot_key].set_xlim(1e+16, 1e+18)
            axd[plot_key].set_xlim(1e+16, 2e+18)
            axd[plot_key].set_xscale('log')

#             axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
#             axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) # linewidth=2, 
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            #axd[plot_key].set_ylim(top=0.01, bottom=0.0005)
            #axd[plot_key].set_ylim(top=0.08, bottom=0.00008)
            #axd[plot_key].set_ylim(top=0.45, bottom=0.00008)
            axd[plot_key].set_ylim(top=0.45, bottom=0.000035)
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)   #labelsize=20
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_yscale('log')
            
            axd[plot_key].set_title('Error')

            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                
            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False)              
            
            
            from matplotlib.lines import Line2D
            import matplotlib.lines as mlines
            from matplotlib.legend_handler import HandlerBase

            class CustomHandler(HandlerBase):
                def create_artists(self, legend, orig_handle, x0, y0, width, height, fontsize, trans):
                    print(orig_handle.get_color())
                    
                    if orig_handle.get_color()==(0.8666666666666667, 0.5176470588235295, 0.3215686274509804) or\
                    orig_handle.get_color()==(0.7, 0.7, 0.7):
                        line_o = mlines.Line2D([x0, x0 + width], [y0 + height/2., y0 + height/2.], 
                                               linestyle='-', color=orig_handle.get_color())                        
                        return [line_o]
                    else:
                        print('marker')
                        marker_o = mlines.Line2D([x0 + width], [y0 + height/2.], 
                                               linestyle='', color=orig_handle.get_color(), marker='o', markersize=5)
#                         marker_x = mlines.Line2D([x0], [y0 + height/2.], 
#                                                linestyle='', color=orig_handle.get_color(), marker='X', markersize=5)                    
                        return [marker_o]
                        


#             plt.rcParams['axes.prop_cycle']=plt.cycler(color=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3500) 
#                                                              for i in [196, 392, 588, 1176, 3136]]+[(0.8666666666666667, 0.5176470588235295, 0.3215686274509804)])



#                 [Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3500) 
#                                                                              for i in [196, 392, 588, 1176, 3136]]

            custom_lines = [Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][0], lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1], lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][2], lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][3], lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][4], lw=1),                            
                            Line2D([0], [0], color=(0.8666666666666667, 0.5176470588235295, 0.3215686274509804), lw=1),
                            Line2D([0], [0], color=(0.7,0.7,0.7), lw=1),
                            ]
    
#             196 KernelSHAP 392 FastSHAP 588 1176 3136
#             np.array(custom_lines)[[0,5,1,6,2,3,4]].tolist()
            axd[plot_key].legend(np.array(custom_lines)[
                #[0,5,1,6,2,3,4]
                [0,1,2,3,4,5,6]
                                                       ].tolist(), 
                                 np.array(['Amortized (196)', 
                                                'Amortized (392)',
                                                'Amortized (588)',
                                                'Amortized (1176)', 
                                                'Amortized (3136)', 
                                                "KernelSHAP",
                                                "FastSHAP"])[
                                     
                                     #[0,5,1,6,2,3,4]
                                 [0,1,2,3,4,5,6]
                                 ].tolist(), 
                                 ncols=2,
                                 loc='upper left', 
                                 handler_map={mlines.Line2D: CustomHandler()}, 
                                 #bbox_to_anchor=(-0.6, -1.3, 0.5, 1)
                                 bbox_to_anchor=(0, 0.05, 0.5, 0.3),
                                 fontsize="x-small",
                                 columnspacing=0.5,
                                )   
        
        
        
        elif metric=="pearsonr_all":
            
            metric_list_plot_df_select=metric_list_plot_df_epoch_reg_obj[(metric_list_plot_df_epoch_reg_obj["split"]=="train")&\
                                                           (metric_list_plot_df_epoch_reg_obj["method_type"]==method_type) &
                                                           (metric_list_plot_df_epoch_reg_obj["is_best_checkpoint"].fillna("before")=="best")\
                                                          ]
            
                    

            sns.scatterplot(
            x="flops",
            y="pearsonr_all",
            hue="method_name",
            hue_order=[
                     'Amortized (196)',
                     'Amortized (392)',
                     'Amortized (588)',
                     'Amortized (1176)',
                    'Amortized (3136)',
            "KernelSHAP"],      
            data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["pearsonr_all"].mean().sort_index().reset_index().iloc[-1])
                                                                     ).reset_index(),
            ax=axd[plot_key],
                s=40,

            )     
            axd[plot_key].get_legend().remove()
            

   
            metric_list_plot_df_select=metric_list_plot_df_epoch_reg_obj[(metric_list_plot_df_epoch_reg_obj["split"]=="train")&\
                                                           (metric_list_plot_df_epoch_reg_obj["method_name"]=="KernelSHAP")
                                                           #(metric_list_plot_df_epoch_reg_obj["is_best_checkpoint"].fillna("before")=="before")\
                                                          ] 
            
            
            
            sns.lineplot(
                x="flops",
                y="pearsonr_all",
#                 hue="method_name",
#                 hue_order=["KernelSHAP"],  
                color=(0.8666666666666667, 0.5176470588235295, 0.3215686274509804),
                #style="antithetical",
                #palette="tab10",
                errorbar=None,                
                alpha=0.8,            
                linewidth=1.5,
                data=metric_list_plot_df_select[metric_list_plot_df_select["flops"]<1e+18],
                ax=axd[plot_key]
            )    
            
            metric_list_plot_df_select=metric_list_plot_df_epoch_reg_obj[(metric_list_plot_df_epoch_reg_obj["split"]=="train")&\
                                                           (metric_list_plot_df_epoch_reg_obj["method_name"]=="Obj-AO (KernelSHAP)")
                                                           #(metric_list_plot_df_epoch_reg_obj["is_best_checkpoint"].fillna("before")=="before")\
                                                          ] 
            
            
            sns.lineplot(
                x="flops",
                y="pearsonr_all",
#                 hue="method_name",
#                 hue_order=["Obj-AO (KernelSHAP)"],    
                color=(0.7,0.7, 0.7),
                #style="antithetical",
                #palette="tab10",
                errorbar=None,                
                alpha=0.8,            
                linewidth=1.5,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )                



            axd[plot_key].set_xlabel("FLOPs") #, fontsize=20
            axd[plot_key].set_ylabel("Pearson Correlation") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].ticklabel_format(axis='x',style='sci',useOffset=True)            
            axd[plot_key].set_xlim(1e+16, 1e+18)
            axd[plot_key].set_xscale('log')

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #, labelsize=20
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, 1.01)
            #axd[plot_key].set_yscale('log')
            
            axd[plot_key].set_title('Correlation')
            
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                
            

            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            #leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()
            
            

            

        elif metric=="spearmanr_all":
            
            metric_list_plot_df_select=metric_list_plot_df_epoch_reg_obj[(metric_list_plot_df_epoch_reg_obj["split"]=="train")&\
                                                           (metric_list_plot_df_epoch_reg_obj["method_type"]==method_type)&\
                                                           (metric_list_plot_df_epoch_reg_obj["is_best_checkpoint"].fillna("before")=="before")\
                                                          ]
            
                    

            sns.scatterplot(
            x="flops",
            y="spearmanr_all",
            hue="method_name",
            hue_order=[
                     'Reg-AO (196)',
                     'Reg-AO (392)',
                     'Reg-AO (588)',
                     'Reg-AO (1176)',
                    'Reg-AO (3136)'],      
            data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["spearmanr_all"].mean().sort_index().reset_index().iloc[-1])
                                                                     ).reset_index(),
            ax=axd[plot_key],
                s=40,

            )     
            axd[plot_key].get_legend().remove()
            
            
            metric_list_plot_df_select=metric_list_plot_df_epoch_reg_obj[(metric_list_plot_df_epoch_reg_obj["split"]=="train")&\
                                                           (metric_list_plot_df_epoch_reg_obj["method_name"]=="Obj-AO (KernelSHAP)")&\
                                                           (metric_list_plot_df_epoch_reg_obj["is_best_checkpoint"].fillna("before")=="before")\
                                                          ]            
            sns.lineplot(
            x="flops",
            y="spearmanr_all",
                alpha=0.8,     
                linewidth=1.5,
                                color=(0.7,0.7, 0.7),
            data=metric_list_plot_df_select.groupby(["epoch"]).apply(lambda x: (x.groupby("flops")["spearmanr_all"].mean().sort_index().reset_index().iloc[-1])).reset_index(),
            ax=axd[plot_key],
            ) 

            



            axd[plot_key].set_xlabel("FLOPs") #, fontsize=20
            axd[plot_key].set_ylabel("Spearman Correlation") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].ticklabel_format(axis='x',style='sci',useOffset=True)            
            axd[plot_key].set_xlim(1e+16, 1e+18)
            axd[plot_key].set_xscale('log')

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #, labelsize=20
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, 1.01)
            #axd[plot_key].set_yscale('log')
            
            axd[plot_key].set_title('Rank Correlation')
            
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                
            

            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            #leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()
            
            
            
        elif metric=="sign_agreement_all":
            
            metric_list_plot_df_select=metric_list_plot_df_epoch_reg_obj[(metric_list_plot_df_epoch_reg_obj["split"]=="train")&\
                                                           (metric_list_plot_df_epoch_reg_obj["method_type"]==method_type)&\
                                                           (metric_list_plot_df_epoch_reg_obj["is_best_checkpoint"].fillna("before")=="before")\
                                                          ]
            
                    

            sns.scatterplot(
            x="flops",
            y="sign_agreement_all",
            hue="method_name",
            hue_order=[
                     'Reg-AO (196)',
                     'Reg-AO (392)',
                     'Reg-AO (588)',
                     'Reg-AO (1176)',
                    'Reg-AO (3136)'],      
            data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["sign_agreement_all"].mean().sort_index().reset_index().iloc[-1])
                                                                     ).reset_index(),
            ax=axd[plot_key],
                s=40,

            )     
            axd[plot_key].get_legend().remove()
            
            
            metric_list_plot_df_select=metric_list_plot_df_epoch_reg_obj[(metric_list_plot_df_epoch_reg_obj["split"]=="train")&\
                                                           (metric_list_plot_df_epoch_reg_obj["method_name"]=="Obj-AO (KernelSHAP)")&\
                                                           (metric_list_plot_df_epoch_reg_obj["is_best_checkpoint"].fillna("before")=="before")\
                                                          ]            
            sns.lineplot(
            x="flops",
            y="sign_agreement_all",
                alpha=0.8,     
                linewidth=1.5,
                                color=(0.7,0.7, 0.7),
            data=metric_list_plot_df_select.groupby(["epoch"]).apply(lambda x: (x.groupby("flops")["sign_agreement_all"].mean().sort_index().reset_index().iloc[-1])).reset_index(),
            ax=axd[plot_key],
            ) 
            



            axd[plot_key].set_xlabel("FLOPs") #, fontsize=20
            axd[plot_key].set_ylabel("Sign Agreement") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].ticklabel_format(axis='x',style='sci',useOffset=True)            
            axd[plot_key].set_xlim(1e+16, 1e+18)
            axd[plot_key].set_xscale('log')

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #, labelsize=20
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, 1.01)
            #axd[plot_key].set_yscale('log')
            
            axd[plot_key].set_title('Sign Agrement')
            
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                
            

            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            #leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()            



            
#             for line in leg.get_lines():
#                 line.set_linewidth(3.0)   
# Global seaborn styling: 'whitegrid' style, 'paper' context, fonts scaled by 1.2.
# NOTE(review): these calls run after the panels above have already been drawn, so they
# presumably only affect figures created afterwards -- confirm whether they were meant to
# precede the plotting loop.
sns.set_theme(style='whitegrid')
sns.set_context('paper', font_scale=1.2)

In [None]:
# Save the compute-comparison figure as PNG (raster) and PDF (vector).
# Create the output directory first so a fresh checkout does not fail with FileNotFoundError.
os.makedirs("logs/plots", exist_ok=True)
for extension in ("png", "pdf"):
    fig.savefig(os.path.join("logs/plots", f"reg_vs_obj_compute.{extension}"), bbox_inches='tight')

In [None]:
# Reorder the entries of `custom_lines` (defined in an earlier cell; presumably legend
# handles) into the display order 0, 5, 1, 6, 2, 3, 4. Requires at least 7 entries.
np.array(custom_lines)[[0,5,1,6,2,3,4]].tolist()

# Qualitative evaluation plots

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import matplotlib as mpl

def plot_figure(explainer, dataset, sample_idx_list):
    """Plot each sample next to the explainer's heatmap for every selected class.

    Layout: one row per entry of ``sample_idx_list``. Columns are the image
    (titled with its label name), a narrow spacer, then one heatmap column per
    distinct label found among the selected samples.

    Args:
        explainer: model called as ``explainer(pixel_values, return_loss=False)``;
            ``output["logits"][0]`` is used as the attribution. Must expose
            ``.device``, ``.eval()`` and
            ``.explainer.config.surrogate_config['id2label']`` (str keys).
        dataset: indexable dataset whose items provide ``"pixel_values"`` (a
            channel-first image tensor) and ``"labels"``.
        sample_idx_list: dataset indices to plot, one row each.

    Returns:
        ``(fig, explainer_output)``: the figure, and a list holding the explainer
        output (numpy array) for each sample in ``sample_idx_list`` order.
    """
    plt.rcParams["font.size"] = 8
    # NOTE(review): CIFAR-10 normalisation statistics; confirm they match the image
    # processor that produced `pixel_values` (the range assertion below relies on it).
    img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]
    img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]

    # One heatmap column per distinct label among the selected samples.
    label_choice = np.unique([dataset[sample_idx]["labels"] for sample_idx in sample_idx_list])
    class_list = {idx: label for idx, label in enumerate(label_choice)}
    column_types = ["image", "empty"] + list(class_list.values())

    fig = plt.figure(figsize=(1.53 * (1 + len(class_list) + 0.2), 2 * len(sample_idx_list)))
    box1 = gridspec.GridSpec(1, len(column_types),
                             wspace=0.06,
                             hspace=0,
                             width_ratios=[1, 0.2] + [1] * len(class_list))

    # axd maps "<sample_idx>_<column type>" to its Axes.
    axd = {}
    for col_idx, plot_type in enumerate(column_types):
        column = gridspec.GridSpecFromSubplotSpec(len(sample_idx_list), 1,
                                                  subplot_spec=box1[col_idx], wspace=0, hspace=0.2)
        for row_idx, sample_idx in enumerate(sample_idx_list):
            axd[f"{sample_idx}_{plot_type}"] = fig.add_subplot(column[row_idx])

    # The spacer column is invisible: no ticks, no frame.
    for plot_key, ax in axd.items():
        if 'empty' in plot_key:
            ax.set_xticks([])
            ax.set_yticks([])
            for axis in ['top', 'bottom', 'left', 'right']:
                ax.spines[axis].set_linewidth(0)
    print("class_list", class_list)

    heat_cmap = sns.color_palette("icefire", as_cmap=True)
    grey_cmap = mpl.colormaps['Greys']

    def _draw_overlay(ax, image_unnormlized, explanation_class):
        """Draw a faded grayscale copy of the image with the attribution heatmap on top."""
        # 196 patch values -> 14x14 grid, bilinearly upsampled 16x to 224x224 pixels.
        heat = torch.nn.functional.interpolate(
            torch.Tensor(explanation_class.reshape(1, 1, 14, 14)),
            scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)
        # Map [-max|a|, +max|a|] onto [0, 1] so that zero sits at the colormap midpoint.
        max_abs = np.max(np.abs(heat))
        if max_abs == 0:
            max_abs = 1.0  # all-zero attribution: avoid 0/0 -> NaN, render as midpoint
        heatmap = heat_cmap(0.5 + heat / max_abs * 0.5)
        heatmap[:, :, 3] = 0.6

        background = grey_cmap(1 - image_unnormlized.sum(axis=2) / 3)
        background[:, :, 3] = 0.5

        ax.imshow(background, alpha=0.85)
        ax.imshow(heatmap, alpha=0.9)

    explainer.eval()
    explainer_output = []
    for sample_idx in sample_idx_list:
        dataset_item = dataset[sample_idx]
        image = dataset_item["pixel_values"]
        label = dataset_item["labels"]

        image_unnormlized = ((image.numpy() * img_std) + img_mean).transpose(1, 2, 0)
        assert image_unnormlized.min() > 0 and image_unnormlized.max() < 1
        image_unnormlized_scaled = (image_unnormlized - image_unnormlized.min()) / (image_unnormlized.max() - image_unnormlized.min())

        # One forward pass per sample; every class column below reuses this output.
        with torch.no_grad():
            explanation = explainer(image.unsqueeze(0).to(explainer.device), return_loss=False)["logits"][0]
        explanation = explanation.detach().cpu().numpy()
        explainer_output.append(explanation)

        for plot_type in column_types:
            if plot_type == "empty":
                continue
            ax = axd[f"{sample_idx}_{plot_type}"]
            if plot_type == "image":
                ax.imshow(image_unnormlized_scaled)
                ax.set_title(f"{explainer.explainer.config.surrogate_config['id2label'][str(label)]}", pad=7, zorder=10)
            else:
                # A 2-D output is indexed by class; otherwise one map is shared by all columns.
                explanation_class = explanation[plot_type] if explanation.ndim == 2 else explanation
                _draw_overlay(ax, image_unnormlized, explanation_class)
                ax.set_title(f"{explainer.explainer.config.surrogate_config['id2label'][str(plot_type)]}")

            ax.set_xticks([])
            ax.set_yticks([])
            for axis in ['top', 'bottom', 'left', 'right']:
                ax.spines[axis].set_linewidth(1)
    return fig, explainer_output

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from matplotlib import cm

def plot_figure_attribution(dataset, sample_idx_list, attribution_value, attribution_value_key, id2label=None):
    """Plot each sample next to precomputed attribution heatmaps for every selected class.

    Layout: one row per entry of ``sample_idx_list``. Columns are the image
    (titled with its label name), a narrow spacer, then one heatmap column per
    distinct label found among the selected samples.

    Args:
        dataset: indexable dataset whose items provide ``"pixel_values"`` (a
            channel-first image tensor) and ``"labels"``.
        sample_idx_list: dataset indices to plot, one row each.
        attribution_value: mapping ``sample_idx -> record`` where ``record`` has
            parallel ``"iters"`` and ``"values"`` sequences. The value paired with
            ``attribution_value_key`` is indexed as ``[:, class]``; presumably
            shaped (196 patches, num_classes) -- TODO confirm.
        attribution_value_key: the ``"iters"`` entry whose values are plotted.
        id2label: mapping from ``str(label)`` to a display name. Defaults to the
            notebook-level ``id2label`` variable (the previous implicit behaviour).

    Returns:
        The matplotlib figure.
    """
    if id2label is None:
        # Backward compatibility: fall back to the notebook-level `id2label` mapping.
        id2label = globals()["id2label"]

    plt.rcParams["font.size"] = 8
    # NOTE(review): CIFAR-10 normalisation statistics; confirm they match the image
    # processor that produced `pixel_values` (the range assertion below relies on it).
    img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]
    img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]

    # One heatmap column per distinct label among the selected samples.
    label_choice = np.unique([dataset[sample_idx]["labels"] for sample_idx in sample_idx_list])
    class_list = {idx: label for idx, label in enumerate(label_choice)}
    column_types = ["image", "empty"] + list(class_list.values())

    fig = plt.figure(figsize=(1.53 * (1 + len(class_list) + 0.2), 2 * len(sample_idx_list)))
    box1 = gridspec.GridSpec(1, len(column_types),
                             wspace=0.06,
                             hspace=0,
                             width_ratios=[1, 0.2] + [1] * len(class_list))

    # axd maps "<sample_idx>_<column type>" to its Axes.
    axd = {}
    for col_idx, plot_type in enumerate(column_types):
        column = gridspec.GridSpecFromSubplotSpec(len(sample_idx_list), 1,
                                                  subplot_spec=box1[col_idx], wspace=0, hspace=0.2)
        for row_idx, sample_idx in enumerate(sample_idx_list):
            axd[f"{sample_idx}_{plot_type}"] = fig.add_subplot(column[row_idx])

    # The spacer column is invisible: no ticks, no frame.
    for plot_key, ax in axd.items():
        if 'empty' in plot_key:
            ax.set_xticks([])
            ax.set_yticks([])
            for axis in ['top', 'bottom', 'left', 'right']:
                ax.spines[axis].set_linewidth(0)
    print("class_list", class_list)

    heat_cmap = sns.color_palette("icefire", as_cmap=True)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap keeps the lut argument.
    grey_cmap = plt.get_cmap('Greys', 1000)

    def _draw_overlay(ax, image_unnormlized, explanation_class):
        """Draw a faded grayscale copy of the image with the attribution heatmap on top."""
        # 196 patch values -> 14x14 grid, bilinearly upsampled 16x to 224x224 pixels.
        heat = torch.nn.functional.interpolate(
            torch.Tensor(explanation_class.reshape(1, 1, 14, 14)),
            scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)
        # Map [-max|a|, +max|a|] onto [0, 1] so that zero sits at the colormap midpoint.
        max_abs = np.max(np.abs(heat))
        if max_abs == 0:
            max_abs = 1.0  # all-zero attribution: avoid 0/0 -> NaN, render as midpoint
        heatmap = heat_cmap(0.5 + heat / max_abs * 0.5)
        heatmap[:, :, 3] = 0.6

        background = grey_cmap(1 - image_unnormlized.sum(axis=2) / 3)
        background[:, :, 3] = 0.5

        ax.imshow(background, alpha=0.85)
        ax.imshow(heatmap, alpha=0.9)

    for sample_idx in sample_idx_list:
        dataset_item = dataset[sample_idx]
        image = dataset_item["pixel_values"]
        label = dataset_item["labels"]

        image_unnormlized = ((image.numpy() * img_std) + img_mean).transpose(1, 2, 0)
        assert image_unnormlized.min() > 0 and image_unnormlized.max() < 1
        image_unnormlized_scaled = (image_unnormlized - image_unnormlized.min()) / (image_unnormlized.max() - image_unnormlized.min())

        # Look up this sample's stored attribution once; class columns slice it.
        record = attribution_value[sample_idx]
        sample_attribution = dict(zip(record["iters"], record["values"]))[attribution_value_key]

        for plot_type in column_types:
            if plot_type == "empty":
                continue
            ax = axd[f"{sample_idx}_{plot_type}"]
            if plot_type == "image":
                ax.imshow(image_unnormlized_scaled)
                ax.set_title(f"{id2label[str(label)]}", pad=7, zorder=10)
            else:
                _draw_overlay(ax, image_unnormlized, sample_attribution[:, plot_type])
                ax.set_title(f"{id2label[str(plot_type)]}")

            ax.set_xticks([])
            ax.set_yticks([])
            for axis in ['top', 'bottom', 'left', 'right']:
                ax.spines[axis].set_linewidth(1)
    return fig

In [None]:
import matplotlib as mpl

# Render $...$ text with Matplotlib's built-in mathtext rather than an external LaTeX install.
mpl.rc('text', usetex=False)
# `text.latex.preamble` is a string rcParam: list values were deprecated in Matplotlib 3.3
# and are rejected by recent releases. It is only read when text.usetex is True.
mpl.rcParams['text.latex.preamble'] = r"\usepackage{amsmath}"

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from matplotlib import cm

def plot_target_prediction_groundtruth(dataset,
                                       sample_idx_list,
                                       attribution_value,
                                       attribution_value_key,
                                       explainer,
                                       attribution_value_groundtruth,
                                       attribution_value_key_groundtruth
                                       ):
    """Compare noisy attribution targets, explainer predictions and ground truth per sample.

    Each row is one sample. Columns are:

    * the image, labelled with its class name;
    * a thin spacer;
    * the noisy target from ``attribution_value`` at ``attribution_value_key``;
    * the explainer's prediction;
    * the reference from ``attribution_value_groundtruth`` at ``attribution_value_key_groundtruth``.

    Every heatmap shows the attribution for the sample's own label.

    Args:
        dataset: indexable dataset whose items provide ``"pixel_values"`` (a
            channel-first image tensor) and ``"labels"`` (an id in 0-9).
        sample_idx_list: dataset indices to plot, one row each.
        attribution_value: mapping ``sample_idx -> record`` where ``record`` has
            parallel ``"iters"`` and ``"values"`` sequences. The value paired with
            the requested key is indexed as ``[:, label]``; presumably shaped
            (196 patches, num_classes) -- TODO confirm.
        attribution_value_key: ``"iters"`` entry used for the noisy target column.
        explainer: model called as ``explainer(pixel_values, return_loss=False)``;
            ``output["logits"][0]`` is the prediction. Must expose ``.device`` and ``.eval()``.
        attribution_value_groundtruth: same structure as ``attribution_value``,
            used for the ground-truth column.
        attribution_value_key_groundtruth: ``"iters"`` entry used for the ground truth.

    Returns:
        The matplotlib figure.
    """
    # NOTE(review): CIFAR-10 normalisation statistics; confirm they match the image
    # processor that produced `pixel_values` (the range assertion below relies on it).
    img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]
    img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]

    # Display names for label ids (the ten Imagenette classes).
    class_names = {'0': 'Tench',
                   '1': 'English springer',
                   '2': 'Cassette player',
                   '3': 'Chain saw',
                   '4': 'Church',
                   '5': 'French horn',
                   '6': 'Garbage truck',
                   '7': 'Gas pump',
                   '8': 'Golf ball',
                   '9': 'Parachute'}

    column_types = ["image", "empty", "target", "prediction", "groundtruth"]
    n_rows = len(sample_idx_list)

    fig = plt.figure(figsize=(1.53 * (4 + 0.05), 1.53 * n_rows + 0.06 * (n_rows - 1)))
    box1 = gridspec.GridSpec(1, len(column_types),
                             wspace=0.06,
                             hspace=0,
                             width_ratios=[1, 0.05, 1, 1, 1])

    # axd maps "<sample_idx>_<column type>" to its Axes.
    axd = {}
    for col_idx, plot_type in enumerate(column_types):
        column = gridspec.GridSpecFromSubplotSpec(n_rows, 1,
                                                  subplot_spec=box1[col_idx], wspace=0, hspace=0.0)
        for row_idx, sample_idx in enumerate(sample_idx_list):
            axd[f"{sample_idx}_{plot_type}"] = fig.add_subplot(column[row_idx])

    # The spacer column is invisible: no ticks, no frame.
    for plot_key, ax in axd.items():
        if 'empty' in plot_key:
            ax.set_xticks([])
            ax.set_yticks([])
            for axis in ['top', 'bottom', 'left', 'right']:
                ax.spines[axis].set_linewidth(0)

    heat_cmap = sns.color_palette("icefire", as_cmap=True)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap keeps the lut argument.
    grey_cmap = plt.get_cmap('Greys', 1000)

    def _attribution_at(attribution_dict, sample_idx, key):
        """Return the stored values of ``sample_idx`` whose paired "iters" entry equals ``key``."""
        record = attribution_dict[sample_idx]
        return dict(zip(record["iters"], record["values"]))[key]

    def _draw_overlay(ax, image_unnormlized, explanation_class):
        """Draw a faded grayscale copy of the image with the attribution heatmap on top."""
        # 196 patch values -> 14x14 grid, bilinearly upsampled 16x to 224x224 pixels.
        heat = torch.nn.functional.interpolate(
            torch.Tensor(explanation_class.reshape(1, 1, 14, 14)),
            scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)
        # Map [-max|a|, +max|a|] onto [0, 1] so that zero sits at the colormap midpoint.
        max_abs = np.max(np.abs(heat))
        if max_abs == 0:
            max_abs = 1.0  # all-zero attribution: avoid 0/0 -> NaN, render as midpoint
        heatmap = heat_cmap(0.5 + heat / max_abs * 0.5)
        heatmap[:, :, 3] = 0.6

        background = grey_cmap(1 - image_unnormlized.sum(axis=2) / 3)
        background[:, :, 3] = 0.5

        ax.imshow(background, alpha=0.85)
        ax.imshow(heatmap, alpha=0.9)

    explainer.eval()
    for row_idx, sample_idx in enumerate(sample_idx_list):
        dataset_item = dataset[sample_idx]
        image = dataset_item["pixel_values"]
        label = dataset_item["labels"]

        image_unnormlized = ((image.numpy() * img_std) + img_mean).transpose(1, 2, 0)
        assert image_unnormlized.min() > 0 and image_unnormlized.max() < 1
        image_unnormlized_scaled = (image_unnormlized - image_unnormlized.min()) / (image_unnormlized.max() - image_unnormlized.min())

        for plot_type in column_types:
            if plot_type == "empty":
                continue
            ax = axd[f"{sample_idx}_{plot_type}"]
            # Column titles are only drawn on the first row.
            if plot_type == "image":
                ax.imshow(image_unnormlized_scaled)
                if row_idx == 0:
                    ax.set_title("Context\n$b$", pad=7, zorder=10)
                ax.set_ylabel(class_names[str(label)])
            elif plot_type == "target":
                noisy = _attribution_at(attribution_value, sample_idx, attribution_value_key)
                _draw_overlay(ax, image_unnormlized, noisy[:, label])
                if row_idx == 0:
                    ax.set_title("Noisy label\n" + r'$\tilde{a}(b)$')
            elif plot_type == "prediction":
                with torch.no_grad():
                    explanation = explainer(image.unsqueeze(0).to(explainer.device), return_loss=False)["logits"][0]
                explanation = explanation.detach().cpu().numpy()
                # A 2-D output is indexed by class; otherwise the single map is used directly.
                explanation_class = explanation[label] if explanation.ndim == 2 else explanation
                _draw_overlay(ax, image_unnormlized, explanation_class)
                if row_idx == 0:
                    ax.set_title("Prediction\n" + r'$a(b;\theta)$')
            else:  # "groundtruth"
                # Both "iters" and "values" must come from the ground-truth record.
                reference = _attribution_at(attribution_value_groundtruth, sample_idx,
                                            attribution_value_key_groundtruth)
                _draw_overlay(ax, image_unnormlized, reference[:, label])
                if row_idx == 0:
                    ax.set_title("Ground Truth\n" + r'$a(b)$')

            ax.set_xticks([])
            ax.set_yticks([])
            for axis in ['top', 'bottom', 'left', 'right']:
                ax.spines[axis].set_linewidth(1)
    return fig

In [None]:
# Noisy targets and ground truth come from the same Shapley extraction record;
# they differ only in which "iters" entry is plotted (512 vs 1,000,000).
shapley_train_extract = shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train']
qualitative_samples = [774, 3772, 3418, 7132, 2183, 683, 6310]

fig = plot_target_prediction_groundtruth(
    dataset=dataset_explainer["train"],
    sample_idx_list=qualitative_samples,
    explainer=regexplainer,
    attribution_value=shapley_train_extract,
    attribution_value_key=512,
    attribution_value_groundtruth=shapley_train_extract,
    attribution_value_key_groundtruth=1000000,
)

In [None]:
# Save the qualitative figure as PNG (raster) and PDF (vector).
# Create the output directory first so a fresh checkout does not fail with FileNotFoundError.
os.makedirs("logs/plots", exist_ok=True)
for extension in ("png", "pdf"):
    fig.savefig(os.path.join("logs/plots", f"shapley_qualitative.{extension}"), bbox_inches='tight')