In [None]:
# Standard-library modules: `os` for the directory change below, `sys` for the argv emulation in a later cell.
import os
import sys

# Move the working directory one level up, presumably so relative paths such as "configs/..." resolve.
# NOTE(review): not idempotent — every re-run of this cell climbs one more directory.
os.chdir('../')

In [None]:
!gpustat

In [None]:
import os
# Expose only the GPU with index 3 to CUDA-based libraries imported after this cell.
# NOTE(review): the index is hard-coded; adjust it per machine (see the `gpustat` cell above).
os.environ['CUDA_VISIBLE_DEVICES']="3"

In [None]:
# Emulate launching train_objexplainer.py with a single JSON config argument; the argument-parsing
# cell below detects this case via `len(sys.argv) == 2 and sys.argv[1].endswith(".json")`.
sys.argv=["train_objexplainer.py", "configs/vitbase_imagenette_shapley_objexplainer_newsample_32.json"]

In [None]:
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

import evaluate
import ipdb
import numpy as np
import torch
import transformers
from datasets import load_dataset
from torch.nn import functional as F
from tqdm import tqdm
from transformers import (
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    AutoConfig,
    AutoImageProcessor,
    AutoModelForImageClassification,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version

from arguments import DataTrainingArguments, ExplainerArguments, SurrogateArguments
from models import (
    ObjExplainerForImageClassification,
    ObjExplainerForImageClassificationConfig,
    SurrogateForImageClassificationConfig,
)
from utils import (
    MaskDataset,
    configure_dataset,
    generate_mask,
    get_checkpoint,
    get_image_transform,
    load_attribution,
    log_dataset,
    read_eval_results,
    setup_dataset,
)

""" Fine-tuning a 🤗 Transformers model for image classification"""

# Module-level logger; its level is set in the logging-setup section further down.
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.32.0.dev0")

# Same kind of guard for `datasets`; the second argument is the hint text handed to `require_version`.
require_version(
    "datasets>=1.8.0",
    "To fix: pip install -r examples/pytorch/image-classification/requirements.txt",
)

# Config classes registered for image classification and their `model_type` strings.
# NOTE(review): neither constant is referenced in the part of the notebook visible here.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class OtherArguments:
    """
    Miscellaneous arguments parsed alongside the surrogate, explainer, data and training argument groups.

    Every field has a default, so the dataclass can be instantiated without arguments. Fields that
    default to ``None`` are annotated ``Optional[str]`` so the annotation matches the default.
    """

    # Bearer token forwarded to the `from_pretrained` calls; ``None`` leaves authentication to the library.
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )

    # Per-split cache paths.
    # NOTE(review): all three help texts are identical and mention "the downloaded dataset";
    # presumably each points at cached mask subsets for its split — confirm and refine the help strings.
    train_subsets_cache_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where to load the downloaded dataset.",
        },
    )
    validation_subsets_cache_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where to load the downloaded dataset.",
        },
    )
    test_subsets_cache_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where to load the downloaded dataset.",
        },
    )

    # Per-split mask-mode strings; the default is "incremental,1".
    # NOTE(review): the "<mode>,<number>" format is interpreted elsewhere in the project — not visible here.
    train_mask_mode: str = field(
        default="incremental,1",
        metadata={
            "help": "mask mode for train",
        },
    )

    validation_mask_mode: str = field(
        default="incremental,1",
        metadata={
            "help": "mask mode for validation",
        },
    )

    test_mask_mode: str = field(
        default="incremental,1",
        metadata={
            "help": "mask mode for test",
        },
    )

In [None]:
########################################################
# Parse arguments
#######################################################
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.

# One parser for all argument dataclasses; the order here fixes the order of the parsed tuple below.
parser = HfArgumentParser(
    (SurrogateArguments, ExplainerArguments, DataTrainingArguments, TrainingArguments, OtherArguments)
)

# A single argument ending in ".json" is treated as a config file;
# anything else is parsed as regular command-line flags.
launched_with_json_config = len(sys.argv) == 2 and sys.argv[1].endswith(".json")
(
    surrogate_args,
    explainer_args,
    data_args,
    training_args,
    other_args,
) = (
    parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    if launched_with_json_config
    else parser.parse_args_into_dataclasses()
)

########################################################
# Setup logging
#######################################################
# Send log records to stdout with a timestamped format.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

if training_args.should_log:
    # The default of training_args.log_level is passive, so we set log level at info here to have that default.
    transformers.utils.logging.set_verbosity_info()

# Apply the per-process log level to both this module's logger and the transformers loggers.
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log on each process the small summary.
# The trailing ", " on the first part keeps "n_gpu: N" from running straight into "distributed training".
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

########################################################
# Correct cache dir if necessary
########################################################
# If the top-level directory of the configured cache path does not exist on this machine,
# re-root the path under the first fallback mount that does exist.
# The configured value is kept separately so the messages can report what was missing.
configured_cache_dir = data_args.dataset_cache_dir
configured_cache_root = os.sep.join(configured_cache_dir.split(os.sep, 2)[:2])

if not os.path.exists(configured_cache_root):
    for fallback_root in ("/data2", "/sdata"):
        if os.path.exists(fallback_root):
            # Swap only the leading directory; everything after it is preserved.
            data_args.dataset_cache_dir = os.sep.join(
                [fallback_root] + configured_cache_dir.split(os.sep, 2)[2:]
            )
            logger.info(
                f"dataset_cache_dir {configured_cache_dir} not found, using {data_args.dataset_cache_dir}"
            )
            break
    else:
        # No fallback mount exists either.
        raise ValueError(
            f"dataset_cache_dir {configured_cache_dir} not found"
        )

########################################################
# Set seed before initializing model.
########################################################
# Seed via transformers' `set_seed` using the seed from the training arguments.
set_seed(training_args.seed)

########################################################
# Initialize our dataset and prepare it for the 'image-classification' task.
########################################################
# `setup_dataset` is a project helper from `utils`. The label list and the two label/id mappings it
# returns are passed to the model configs below.
dataset_original, labels, label2id, id2label = setup_dataset(
    data_args=data_args, other_args=other_args
)

########################################################
# Initialize explainer model
########################################################

# Configuration of the explainer backbone; falls back to the model path when no config name is given.
explainer_config = AutoConfig.from_pretrained(
    explainer_args.explainer_config_name
    or explainer_args.explainer_model_name_or_path,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
    finetuning_task="image-classification",
    cache_dir=explainer_args.explainer_cache_dir,
    revision=explainer_args.explainer_model_revision,
    token=other_args.token,
)

# A local directory whose config.json lists "ObjExplainerForImageClassification" as its first
# architecture is treated as a saved explainer and reloaded directly.
explainer_checkpoint_config_file = (
    f"{explainer_args.explainer_model_name_or_path}/config.json"
)
is_saved_objexplainer = False
if os.path.isfile(explainer_checkpoint_config_file):
    # Context manager so the file handle is closed.
    with open(explainer_checkpoint_config_file) as config_fp:
        is_saved_objexplainer = (
            json.load(config_fp)["architectures"][0]
            == "ObjExplainerForImageClassification"
        )

if is_saved_objexplainer:
    explainer = ObjExplainerForImageClassification.from_pretrained(
        explainer_args.explainer_model_name_or_path,
        from_tf=bool(".ckpt" in explainer_args.explainer_model_name_or_path),
        config=explainer_config,
        cache_dir=explainer_args.explainer_cache_dir,
        revision=explainer_args.explainer_model_revision,
        token=other_args.token,
        ignore_mismatched_sizes=explainer_args.explainer_ignore_mismatched_sizes,
    )
else:
    # Otherwise assemble a new explainer from a surrogate backbone plus an explainer backbone.
    # NOTE(review): `surrogate_for_image_classification_config` is only bound in this branch,
    # but later cells (the RegExplainer configs) use it unconditionally.
    surrogate_config = AutoConfig.from_pretrained(
        surrogate_args.surrogate_config_name
        or surrogate_args.surrogate_model_name_or_path,
        num_labels=len(labels),
        label2id=label2id,
        id2label=id2label,
        finetuning_task="image-classification",
        cache_dir=surrogate_args.surrogate_cache_dir,
        revision=surrogate_args.surrogate_model_revision,
        token=other_args.token,
    )
    surrogate_for_image_classification_config = SurrogateForImageClassificationConfig(
        surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
        surrogate_config=surrogate_config,
        surrogate_from_tf=bool(
            ".ckpt" in surrogate_args.surrogate_model_name_or_path
        ),
        surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
        surrogate_revision=surrogate_args.surrogate_model_revision,
        surrogate_token=other_args.token,
        surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
    )

    explainer_for_image_classification_config = ObjExplainerForImageClassificationConfig(
        surrogate_pretrained_model_name_or_path=surrogate_args.surrogate_model_name_or_path,
        surrogate_config=surrogate_for_image_classification_config,
        surrogate_from_tf=bool(
            ".ckpt" in surrogate_args.surrogate_model_name_or_path
        ),
        surrogate_cache_dir=surrogate_args.surrogate_cache_dir,
        surrogate_revision=surrogate_args.surrogate_model_revision,
        surrogate_token=other_args.token,
        surrogate_ignore_mismatched_sizes=surrogate_args.surrogate_ignore_mismatched_sizes,
        explainer_pretrained_model_name_or_path=explainer_args.explainer_model_name_or_path,
        explainer_config=explainer_config,
        explainer_from_tf=bool(
            ".ckpt" in explainer_args.explainer_model_name_or_path
        ),
        explainer_cache_dir=explainer_args.explainer_cache_dir,
        explainer_revision=explainer_args.explainer_model_revision,
        explainer_token=other_args.token,
        explainer_ignore_mismatched_sizes=explainer_args.explainer_ignore_mismatched_sizes,
    )

    explainer = ObjExplainerForImageClassification(
        config=explainer_for_image_classification_config,
    )

# Image processor matching the explainer backbone; used when configuring the dataset transforms.
explainer_image_processor = AutoImageProcessor.from_pretrained(
    explainer_args.explainer_image_processor_name
    or explainer_args.explainer_model_name_or_path,
    cache_dir=explainer_args.explainer_cache_dir,
    revision=explainer_args.explainer_model_revision,
    token=other_args.token,
)

########################################################
# Configure dataset (set max samples, transforms, etc.)
########################################################
# Work on a deep copy so `dataset_original` is not modified by the configuration step.
dataset_explainer = copy.deepcopy(dataset_original)
# `configure_dataset` is a project helper from `utils`; here it is called with the explainer's
# image processor and with augmentation turned off for all three splits.
dataset_explainer = configure_dataset(
    dataset=dataset_explainer,
    image_processor=explainer_image_processor,
    training_args=training_args,
    data_args=data_args,
    train_augmentation=False,
    validation_augmentation=False,
    test_augmentation=False,
    logger=logger,
)

In [None]:
from models import (
    RegExplainerForImageClassification,
    RegExplainerForImageClassificationConfig
)

In [None]:
# Loading options for the surrogate and explainer backbones wrapped by the regression explainer.
# NOTE(review): `surrogate_for_image_classification_config` is only assigned in the `else` branch of the
# explainer-initialisation cell, so this cell relies on that branch having run.
regexplainer_config_kwargs = {
    "surrogate_pretrained_model_name_or_path": surrogate_args.surrogate_model_name_or_path,
    "surrogate_config": surrogate_for_image_classification_config,
    "surrogate_from_tf": ".ckpt" in surrogate_args.surrogate_model_name_or_path,
    "surrogate_cache_dir": surrogate_args.surrogate_cache_dir,
    "surrogate_revision": surrogate_args.surrogate_model_revision,
    "surrogate_token": other_args.token,
    "surrogate_ignore_mismatched_sizes": surrogate_args.surrogate_ignore_mismatched_sizes,
    "explainer_pretrained_model_name_or_path": explainer_args.explainer_model_name_or_path,
    "explainer_config": explainer_config,
    "explainer_from_tf": ".ckpt" in explainer_args.explainer_model_name_or_path,
    "explainer_cache_dir": explainer_args.explainer_cache_dir,
    "explainer_revision": explainer_args.explainer_model_revision,
    "explainer_token": other_args.token,
    "explainer_ignore_mismatched_sizes": explainer_args.explainer_ignore_mismatched_sizes,
}
regexplainer_for_image_classification_config = RegExplainerForImageClassificationConfig(
    **regexplainer_config_kwargs
)

# Build the model from the composite config.
regexplainer = RegExplainerForImageClassification(
    config=regexplainer_for_image_classification_config
)


In [None]:
from models import (
    RegExplainerNormalizeForImageClassification,
    RegExplainerNormalizeForImageClassificationConfig
)

In [None]:
# Same loading options as the plain regression explainer, applied to the "normalize" variant.
# NOTE(review): `surrogate_for_image_classification_config` is only assigned in the `else` branch of the
# explainer-initialisation cell, so this cell relies on that branch having run.
regexplainer_normalize_config_kwargs = {
    "surrogate_pretrained_model_name_or_path": surrogate_args.surrogate_model_name_or_path,
    "surrogate_config": surrogate_for_image_classification_config,
    "surrogate_from_tf": ".ckpt" in surrogate_args.surrogate_model_name_or_path,
    "surrogate_cache_dir": surrogate_args.surrogate_cache_dir,
    "surrogate_revision": surrogate_args.surrogate_model_revision,
    "surrogate_token": other_args.token,
    "surrogate_ignore_mismatched_sizes": surrogate_args.surrogate_ignore_mismatched_sizes,
    "explainer_pretrained_model_name_or_path": explainer_args.explainer_model_name_or_path,
    "explainer_config": explainer_config,
    "explainer_from_tf": ".ckpt" in explainer_args.explainer_model_name_or_path,
    "explainer_cache_dir": explainer_args.explainer_cache_dir,
    "explainer_revision": explainer_args.explainer_model_revision,
    "explainer_token": other_args.token,
    "explainer_ignore_mismatched_sizes": explainer_args.explainer_ignore_mismatched_sizes,
}
regexplainer_normalize_for_image_classification_config = RegExplainerNormalizeForImageClassificationConfig(
    **regexplainer_normalize_config_kwargs
)

# Build the model from the composite config.
regexplainer_normalize = RegExplainerNormalizeForImageClassification(
    config=regexplainer_normalize_for_image_classification_config
)


# FLOPS calculation

In [None]:
import deepspeed

In [None]:
from torch.utils.data import DataLoader

In [None]:
from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler

In [None]:
# Loop iteration at which the profiling cells below take their measurement.
profile_step=3
# Target of the `.to(device)` calls in the profiling cells; later cells rebind this to "cuda:0".
device="cpu"

In [None]:
from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count_str


# Stand-alone fvcore usage example: count the FLOPs of a model on one 224x224 image.
# NOTE(review): `ViTB32` and `cv2` are neither defined nor imported in the part of the notebook shown
# here, and `__name__ == '__main__'` is true inside a notebook kernel, so this cell would raise
# NameError as written — looks like a pasted reference snippet; confirm whether it should be kept.
if __name__ == '__main__':
    model = ViTB32()
    model.eval()

    # Read the image, resize it, and reorder HWC -> CHW before adding a batch dimension.
    input = cv2.imread("../input.jpg")
    input = cv2.resize(input, (224, 224))
    input = torch.from_numpy(input).permute(2, 0, 1)
    input = input[None,:,:,:].float()

    # Per-module table (depth 4), the annotated model string, and the total count.
    flop = FlopCountAnalysis(model, input)
    print(flop_count_table(flop, max_depth=4))
    print(flop_count_str(flop))
    print(flop.total())

In [None]:
# Show `flops` and `macs` divided by 8 and scaled to units of 1e9 — presumably per-sample values for a batch of 8; TODO confirm.
# NOTE(review): both names are only assigned by profiling cells further down, so this fails on a fresh top-to-bottom run.
flops/8/1e+9, macs/8/1e+9

In [None]:
class surrogate_warpper(torch.nn.Module):
    """
    Adapter that lets single-input profilers call a surrogate model that also needs `masks`.

    The wrapped model is invoked as ``model(x, masks, return_loss=False)``.

    Args:
        model: The module to wrap. It must accept ``(x, masks, return_loss=...)``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, masks=None):
        """
        Run the wrapped model on `x` with the given masks.

        Args:
            x: Input passed as the first positional argument of the wrapped model.
            masks: Masks passed as the second positional argument. When omitted, the
                module-level global named ``masks`` is used, which keeps the original
                single-argument call style of this notebook working.

        Returns:
            Whatever the wrapped model returns for ``return_loss=False``.

        Raises:
            NameError: If `masks` is omitted and no global ``masks`` has been defined.
        """
        if masks is None:
            # Backward-compatible fallback to the notebook-global tensor set by the profiling cells.
            try:
                masks = globals()["masks"]
            except KeyError:
                raise NameError(
                    "surrogate_warpper.forward needs `masks`: pass it explicitly or define a global `masks`"
                ) from None
        return self.model(x, masks, return_loss=False)
    
# Wrap the explainer's surrogate sub-model so the profilers below can call it with a single tensor argument.
surrogate_warpped=surrogate_warpper(explainer.surrogate)

In [None]:
# Count FLOPs of the wrapped surrogate and print the per-module table and annotated model string.
# NOTE(review): `pixel_values` (and the global `masks` the wrapper reads) are first assigned in the
# loop cell below, so this cell depends on out-of-order execution.
flop = FlopCountAnalysis(surrogate_warpped, pixel_values.to(device))
print(flop_count_table(flop, max_depth=4))
print(flop_count_str(flop))

In [None]:
# Number of images per forward pass in this cell.
batch_size=1

# First `batch_size` training samples; each item provides a "pixel_values" tensor.
items=[dataset_explainer["train"][i] for i in range(batch_size)]

with torch.no_grad():
    for step in range(5):
        print(step)
        # Stack the per-sample tensors along a new leading batch dimension.
        pixel_values=torch.cat([item["pixel_values"].unsqueeze(0) for item in items])
        # Random 0/1 masks of shape (batch_size, 256, 196). `masks` must stay a notebook global:
        # the wrapper's forward is called with `pixel_values` only.
        # NOTE(review): presumably 256 masks per image over 196 patches — confirm against the surrogate.
        masks=torch.Tensor(
                    np.random.choice([0,1], 
                         size=(batch_size, 256, 196), 
                         replace=True)
        )        
        
        
        if step == profile_step:
            # Measured iteration: count FLOPs with fvcore, print the reports, and stop.
            flop = FlopCountAnalysis(surrogate_warpped, pixel_values.to(device))
            print(flop_count_table(flop, max_depth=4))
            print(flop_count_str(flop))
            print(flop.total())            
            break
        else:
            # Iterations before `profile_step` only run a plain forward pass (presumably warm-up).
            loss = surrogate_warpped(pixel_values.to(device)
                           )

# surrogate

In [None]:
# Loop iteration at which the profiling cells in this section measure (same value as set earlier).
profile_step=3

In [None]:
from ptflops import get_model_complexity_info

In [None]:
class surrogate_warpper(torch.nn.Module):
    """
    Adapter that lets single-input profilers call a surrogate model that also needs `masks`.

    The wrapped model is invoked as ``model(x, masks, return_loss=False)``.

    Args:
        model: The module to wrap. It must accept ``(x, masks, return_loss=...)``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, masks=None):
        """
        Run the wrapped model on `x` with the given masks.

        The previous body first called ``self.flatten`` and ``self.linear_relu_stack``, which this
        class never defines, so every call raised AttributeError; those two lines are removed.

        Args:
            x: Input passed as the first positional argument of the wrapped model.
            masks: Masks passed as the second positional argument. When omitted, the
                module-level global named ``masks`` is used, as in the original code.

        Returns:
            Whatever the wrapped model returns for ``return_loss=False``.

        Raises:
            NameError: If `masks` is omitted and no global ``masks`` has been defined.
        """
        if masks is None:
            # Backward-compatible fallback to the notebook-global tensor set by the profiling cells.
            try:
                masks = globals()["masks"]
            except KeyError:
                raise NameError(
                    "surrogate_warpper.forward needs `masks`: pass it explicitly or define a global `masks`"
                ) from None
        return self.model(x, masks, return_loss=False)

In [None]:
# Number of images per forward pass in this cell.
batch_size=2

# First `batch_size` training samples; each item provides a "pixel_values" tensor.
items=[dataset_explainer["train"][i] for i in range(batch_size)]

with torch.no_grad():
    for step in range(5):
        print(step)
        # Stack the per-sample tensors along a new leading batch dimension.
        pixel_values=torch.cat([item["pixel_values"].unsqueeze(0) for item in items])
        # Random 0/1 masks of shape (batch_size, 256, 196); read as a global by `surrogate_warpper.forward`.
        masks=torch.Tensor(
                    np.random.choice([0,1],
                         size=(batch_size, 256, 196),
                         replace=True)
        )

        if step==profile_step:
            # ptflops builds its own dummy input of shape (1, *input_res) unless an `input_constructor`
            # is given. The previous `(2,3,224,224)` therefore produced a 5-D tensor whose batch size
            # did not match `masks`. Feed the real batch instead, as keyword `x` of the wrapper.
            macs, params = get_model_complexity_info(
                surrogate_warpper(explainer.surrogate),
                tuple(pixel_values.shape[1:]),
                input_constructor=lambda input_res: {"x": pixel_values},
                as_strings=True,
                print_per_layer_stat=True,
                verbose=True,
            )
            break
        else:
            # Iterations before `profile_step` only run a plain forward pass (presumably warm-up).
            loss = explainer.surrogate(pixel_values=pixel_values,
                                       masks=masks,
                                       return_loss=False)

In [None]:
from thop import profile

In [None]:
# Number of images per forward pass in this cell.
batch_size=256

# First `batch_size` training samples; each item provides a "pixel_values" tensor.
items=[dataset_explainer["train"][i] for i in range(batch_size)]

with torch.no_grad():
    for step in range(5):
        print(step)
        # Stack the per-sample tensors along a new leading batch dimension.
        pixel_values=torch.cat([item["pixel_values"].unsqueeze(0) for item in items])
        # Random 0/1 masks of shape (batch_size, 128, 196).
        masks=torch.Tensor(
                    np.random.choice([0,1], 
                         size=(batch_size, 128, 196), 
                         replace=True)
        )
        if step==profile_step:
            # Measured iteration: thop receives the model inputs as a positional tuple.
            # NOTE(review): assumes the surrogate's forward takes (pixel_values, masks, <third arg>, return_loss)
            # in that order — confirm against the model definition.
            macs, params = profile(explainer.surrogate, (pixel_values, masks, None, False))            
            break
        else:
            # Iterations before `profile_step` only run a plain forward pass (presumably warm-up).
            loss = explainer.surrogate(pixel_values=pixel_values,
                                       masks=masks,
                                       return_loss=False)        

In [None]:
# Create the surrogate profiler only once, so re-running this cell does not wrap the model again.
# Only the "name not defined yet" case is caught; any other error should surface.
try:
    surrogate_prof
except NameError:
    surrogate_prof=FlopsProfiler(explainer.surrogate)

In [None]:
# Number of images per forward pass in this cell.
batch_size=1
# First `batch_size` training samples; each item provides a "pixel_values" tensor.
items=[dataset_explainer["train"][i] for i in range(batch_size)]

with torch.no_grad():
    for step in range(5):


        # Stack the per-sample tensors along a new leading batch dimension.
        pixel_values=torch.cat([item["pixel_values"].unsqueeze(0) for item in items])
        # Random 0/1 masks of shape (batch_size, 32, 196).
        # NOTE(review): `masks` is not moved to `device`, unlike `pixel_values` — fine while device is "cpu".
        masks=torch.Tensor(
                    np.random.choice([0,1], 
                         size=(batch_size, 32, 196), 
                         replace=True)
        )
        
        # Start measuring just before the forward pass of the chosen iteration.
        if step == profile_step:
            surrogate_prof.start_profile()        

        loss = explainer.surrogate(pixel_values=pixel_values.to(device), 
                                   masks=masks,
                                   return_loss=False)
        
        if step == profile_step:
            # Read the raw totals (numbers, not formatted strings) before ending the profile.
            flops = surrogate_prof.get_total_flops(as_string=False)
            params = surrogate_prof.get_total_params(as_string=False)
            macs = surrogate_prof.get_total_macs(as_string=False)
            duration = surrogate_prof.get_total_duration(as_string=False)
            
            surrogate_prof.print_model_profile(profile_step=profile_step)
            surrogate_prof.end_profile() 
            break

In [None]:
# Display the parameter count stored by the most recently executed profiling cell.
params

In [None]:
# Totals from the last profiling run, scaled to units of 1e12.
flops/1e+12, macs/1e+12

In [None]:
# Side-by-side comparison of the reported FLOPs with twice the reported MACs.
flops,macs*2

In [None]:
# MACs divided by 32 — presumably per mask or per sample; TODO confirm which 32 this refers to.
macs/32

In [None]:
# Create the explainer profiler only once, so re-running this cell does not wrap the model again.
# Only the "name not defined yet" case is caught; any other error should surface.
try:
    explainer_prof
except NameError:
    explainer_prof=FlopsProfiler(explainer)

In [None]:
explainer.forward??

In [None]:
# NOTE(review): the body only references the name `explainer`; no forward pass is executed here.
with torch.no_grad():
    explainer

In [None]:
# Call the explainer on an all-ones batch of 64 tensors shaped 3x224x224 and show `.shape` of the result.
# NOTE(review): runs with gradients enabled and assumes the return value exposes `.shape` — confirm.
explainer(pixel_values=torch.ones((64,3,224,224))).shape

In [None]:
# Number of images per forward pass in this cell.
batch_size=32

# First `batch_size` training samples; each item provides a "pixel_values" tensor.
items=[dataset_explainer["train"][i] for i in range(batch_size)]

with torch.no_grad():
    for step in range(5):
        # Start measuring at the chosen iteration; here the batch assembly below is inside the measured span.
        if step == profile_step:
            explainer_prof.start_profile()

        # Stack the per-sample tensors along a new leading batch dimension.
        pixel_values=torch.cat([item["pixel_values"].unsqueeze(0) for item in items])

        loss = explainer(pixel_values=pixel_values.to(device), 
                                   return_loss=False)
        
        if step == profile_step:
            # Read the raw totals (numbers, not formatted strings) before ending the profile.
            flops = explainer_prof.get_total_flops(as_string=False)
            params = explainer_prof.get_total_params(as_string=False)
            macs = explainer_prof.get_total_macs(as_string=False)
            duration = explainer_prof.get_total_duration(as_string=False)
            #explainer_prof.print_model_profile(profile_step=profile_step, output_file='profiler_log.txt')
            explainer_prof.print_model_profile(profile_step=profile_step)
            explainer_prof.end_profile() 
            break

In [None]:
# Number of images per training step in this cell.
batch_size=32

# First `batch_size` training samples; each item provides a "pixel_values" tensor.
items=[dataset_explainer["train"][i] for i in range(batch_size)]

# AdamW with default hyper-parameters over all explainer parameters.
optimizer=torch.optim.AdamW(explainer.parameters())

for step in range(5):


    # Stack the per-sample tensors along a new leading batch dimension.
    pixel_values=torch.cat([item["pixel_values"].unsqueeze(0) for item in items])
    # Random 0/1 masks of shape (batch_size, 32, 196).
    masks=torch.Tensor(
                np.random.choice([0,1], 
                     size=(batch_size, 32, 196), 
                     replace=True)
    )        
    # Random stand-in targets of shape (batch_size, 32, 10) — presumably one 10-class output per mask; TODO confirm.
    model_outputs=torch.randn((batch_size, 32, 10))

    optimizer.zero_grad()
    loss = explainer(pixel_values=pixel_values.to(device),
                     masks=masks,
                     model_outputs=model_outputs,
                     return_loss=True)

    # NOTE(review): profiling starts after the forward pass, so only backward + optimizer step fall
    # inside the measured span — confirm this is intended.
    if step == profile_step:
        explainer_prof.start_profile() 
        
    # `loss` is the model output object; the scalar loss is its `.loss` attribute.
    loss.loss.backward()
    optimizer.step()        
        
    if step == profile_step:
        flops = explainer_prof.get_total_flops(as_string=False)
        params = explainer_prof.get_total_params(as_string=False)
        #explainer_prof.print_model_profile(profile_step=profile_step, output_file='profiler_log.txt')
        explainer_prof.print_model_profile(profile_step=profile_step)
        explainer_prof.end_profile() 
        break

# reg explainer

In [None]:
# Switch the target of subsequent `.to(device)` calls to the first visible CUDA device.
device="cuda:0"

In [None]:
# Move the regression explainer to `device`.
regexplainer.to(device)

In [None]:
# Create the regression-explainer profiler only once, so re-running this cell does not wrap the model again.
# Only the "name not defined yet" case is caught; any other error should surface.
try:
    regexplainer_prof
except NameError:
    regexplainer_prof=FlopsProfiler(regexplainer)
    

In [None]:
# NOTE(review): `batch_size` is whatever the most recently executed cell set it to.
print(batch_size)
# First `batch_size` training samples; each item provides a "pixel_values" tensor.
items=[dataset_explainer["train"][i] for i in range(batch_size)]

with torch.no_grad():
    for step in range(5):
        # Stack the per-sample tensors along a new leading batch dimension.
        pixel_values=torch.cat([item["pixel_values"].unsqueeze(0) for item in items])
        
        # Start measuring just before the forward pass of the chosen iteration.
        if step == profile_step:
            regexplainer_prof.start_profile()        

        loss = regexplainer(pixel_values=pixel_values.to(device), return_loss=False)
        
        if step == profile_step:
            # Read the raw totals (numbers, not formatted strings) before ending the profile.
            flops = regexplainer_prof.get_total_flops(as_string=False)
            params = regexplainer_prof.get_total_params(as_string=False)
            macs = regexplainer_prof.get_total_macs(as_string=False)
            duration = regexplainer_prof.get_total_duration(as_string=False)            
            regexplainer_prof.print_model_profile(profile_step=profile_step)
            regexplainer_prof.end_profile() 
            break

In [None]:
# MACs divided by the profiler-reported duration, scaled to units of 1e12, for the last profiling run.
(macs/duration)/1e+12

In [None]:
batch size 
8 170.68 GMACs 0.341 T
16 341.36 GMACs 0.683 T
32 682.72 GMACs 1.37 T
64 1.37 TMACs 2.73 T 
128 2.73 TMACs 5.47 T
256 5.46 TMACs 10.93 T

In [None]:
batch size 
8 170.68 GMACs 0.341 T
16 341.36 GMACs 0.683 T
32 682.72 GMACs 1.37 T
64 1.37 TMACs 2.73 T 
128 2.73 TMACs 5.47 T
256 5.46 TMACs 10.93 T

In [None]:
Notations:
data parallel size (dp_size), model parallel size(mp_size),
number of parameters (params), number of multiply-accumulate operations(MACs),
number of floating-point operations (flops), floating-point operations per second (FLOPS),
fwd latency (forward propagation latency), bwd latency (backward propagation latency),
step (weights update latency), iter latency (sum of fwd, bwd and step latency)

params per GPU:                                                         171.61 M
params of model = params per GPU * mp_size:                             0       
fwd MACs per GPU:                                                       22.02 TMACs
fwd flops per GPU:                                                      44.08 T 
fwd flops of model = fwd flops per GPU * mp_size:                       44.08 T 
fwd latency:                                                            1.92 s  
fwd FLOPS per GPU = fwd flops per GPU / fwd latency:                    22.97 TFLOPS

In [None]:
3. The fwd latency listed in the top module's profile is directly captured at the module forward function in PyTorch, thus it's less than the fwd latency shown above which is captured in DeepSpeed.

SurrogateForImageClassification(
  171.61 M = 100% Params, 22.02 TMACs = 100% MACs, 1.92 s = 100% latency, 22.97 TFLOPS
  (surrogate): ViTForImageClassification(
    85.81 M = 50% Params, 17.6 TMACs = 79.9% MACs, 1.43 s = 74.78% latency, 24.55 TFLOPS
    (vit): ViTModel(
      85.8 M = 50% Params, 17.6 TMACs = 79.9% MACs, 1.43 s = 74.69% latency, 24.58 TFLOPS
      (embeddings): ViTEmbeddings(
        742.66 K = 0.43% Params, 115.84 GMACs = 0.53% MACs, 40.14 ms = 2.09% latency, 5.78 TFLOPS
        (patch_embeddings): ViTPatchEmbeddings(
          590.59 K = 0.34% Params, 115.84 GMACs = 0.53% MACs, 34.19 ms = 1.78% latency, 6.78 TFLOPS
          (projection): Conv2d(590.59 K = 0.34% Params, 115.84 GMACs = 0.53% MACs, 33.59 ms = 1.75% latency, 6.9 TFLOPS, 3, 768, kernel_size=(16, 16), stride=(16, 16))

In [None]:
# surrogate 22.04


In [None]:
prof.print_model_profile?

In [None]:
prof.print_model_profile??

In [None]:
# Shape of the "logits" entry of the most recent model output stored in `loss`.
loss["logits"].shape

In [None]:
# AdamW with default hyper-parameters over the regression explainer's parameters.
# Rebinding `optimizer` replaces the one created earlier for `explainer`.
optimizer=torch.optim.AdamW(regexplainer.parameters())

In [None]:
# NOTE(review): `shapley_ground_truth` is first assigned inside the loop cell below, so this depends on out-of-order execution.
shapley_ground_truth.shape

In [None]:
# Train the regression explainer on random targets for a few steps, then profile one forward pass.
# NOTE(review): `prof` is created in a later cell (`prof = FlopsProfiler(regexplainer)`); run that cell first.
for step in tqdm(range(10)):
    if step == profile_step:
        prof.start_profile()

    # Stack the per-sample tensors along a new leading batch dimension.
    pixel_values=torch.cat([item["pixel_values"].unsqueeze(0) for item in items])
    # Random stand-in targets; the batch dimension follows `pixel_values` instead of a hard-coded 32.
    # NOTE(review): presumably (batch, 196 patches, 10 classes) — confirm against the model.
    shapley_ground_truth=torch.randn(pixel_values.shape[0],196,10)

    # Clear gradients left over from the previous step before computing new ones.
    optimizer.zero_grad()
    loss = regexplainer(pixel_values=pixel_values.to(device),
                        shapley_values=shapley_ground_truth.to(device),
                        return_loss=True)

    if step == profile_step:
        # Formatted (string) totals for the measured forward pass; no backward pass on this iteration.
        flops = prof.get_total_flops(as_string=True)
        params = prof.get_total_params(as_string=True)
        prof.print_model_profile(profile_step=profile_step)
        prof.end_profile()
        break

    # `loss` holds the model output (other cells index it with "logits"), so the scalar loss
    # has to be taken from its "loss" entry before calling backward.
    loss["loss"].backward()
    optimizer.step()

In [None]:
# Forward-only profiling of the regression explainer with gradients disabled.
# The original cell could not run: the loop was not indented under `with`, it contained a stray
# undefined name, and it called backward()/optimizer.step() although no_grad disables gradients.
# NOTE(review): `prof` is created in a later cell (`prof = FlopsProfiler(regexplainer)`); run that cell first.
with torch.no_grad():
    for step in range(100):
        if step == profile_step:
            prof.start_profile()

        # Stack the per-sample tensors along a new leading batch dimension.
        pixel_values=torch.cat([item["pixel_values"].unsqueeze(0) for item in items])

        # Same call form as the other forward-only regexplainer profiling cell.
        loss = regexplainer(pixel_values=pixel_values.to(device), return_loss=False)

        if step == profile_step:
            # Formatted (string) totals for the measured forward pass.
            flops = prof.get_total_flops(as_string=True)
            params = prof.get_total_params(as_string=True)
            prof.print_model_profile(profile_step=profile_step)
            prof.end_profile()
            # Nothing is measured after the profiled step, so stop here.
            break

In [None]:
# Profiler for the regression explainer, bound to the generic name `prof` used by the two loop cells above.
# NOTE(review): defined after its first use, so those cells depend on out-of-order execution.
prof = FlopsProfiler(regexplainer)

In [None]:
# NOTE(review): this looks like a pasted constructor signature kept for reference; `model` is not
# defined in the visible notebook, so executing the cell would raise NameError.
deepspeed.profiling.flops_profiler.profiler.FlopsProfiler(model, ds_engine=None, recompute_fwd_factor=0.0)

# move to GPU

In [None]:
# Target device for the model moves below (same value as set in the "reg explainer" section).
device="cuda:0"

In [None]:
# Move the object explainer to `device`.
explainer.to(device)

In [None]:
# Move the regression explainer to `device`.
regexplainer.to(device)

In [None]:
# Move the normalized regression explainer to `device`.
regexplainer_normalize.to(device)

# visualizing 

In [None]:
import glob
import pandas as pd
import json
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import cm

In [None]:
# !cp /System/Library/Fonts/Supplemental ~/.local/share/fonts/
# rm -fr ~/.cache/matplotlib
from matplotlib import font_manager
from matplotlib.lines import Line2D
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

# Look up the system TTF fonts and resolve "Arial"; both return values are discarded here.
font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial") # Test with "Special Elite" too
# Use Arial as the sans-serif family for all figures.
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'

# Legend frame: the commented-out lines are an alternative set of values, the active lines are used.
# plt.rcParams['legend.fancybox'] = False
# plt.rcParams['legend.edgecolor']='1.0'
# plt.rcParams['legend.framealpha']=1
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor']='0.8'
plt.rcParams['legend.framealpha']=0.8

# Seaborn theme and context are applied after the rcParams above.
# NOTE(review): `sns.set_theme` updates rcParams itself — confirm the font settings survive this call.
sns.set_theme(style='whitegrid')
sns.set_context('paper', font_scale=1.2)

# https://github.com/dsc/colorbrewer-python/blob/master/colorbrewer.py

# ColorBrewer qualitative palettes as 0-255 RGB triples.
# Each n-colour palette is the first n entries of the longest one, so the tables are generated from a
# single master list. Every colour is a fresh list; the old `Paired[3]` mixed one tuple in with lists.
_SET1_RGB = (
    (228, 26, 28),
    (55, 126, 184),
    (77, 175, 74),
    (152, 78, 163),
    (255, 127, 0),
    (255, 255, 51),
    (166, 86, 40),
    (247, 129, 191),
    (153, 153, 153),
)

_PAIRED_RGB = (
    (166, 206, 227),
    (31, 120, 180),
    (178, 223, 138),
    (51, 160, 44),
    (251, 154, 153),
    (227, 26, 28),
    (253, 191, 111),
    (255, 127, 0),
    (202, 178, 214),
    (106, 61, 154),
    (255, 255, 153),
    (177, 89, 40),
)

# Keys are the palette sizes: 3..9 for Set1 and 3..12 for Paired.
Set1 = {
    n_colors: [list(rgb) for rgb in _SET1_RGB[:n_colors]]
    for n_colors in range(3, len(_SET1_RGB) + 1)
}

Paired = {
    n_colors: [list(rgb) for rgb in _PAIRED_RGB[:n_colors]]
    for n_colors in range(3, len(_PAIRED_RGB) + 1)
}

# Hand-picked 7-colour qualitative palette (hex strings plus one RGB float tuple).
# NOTE(review): the tuple divides by 256 rather than 255 — left unchanged; confirm whether that is intended.
color_qual_7=['#F53345',
            '#87D303',
            '#04CBCC',
            '#8650CD',
            (160/256, 95/256, 0),
            '#F5A637',
            '#DBD783',
             ]

# Let pandas show up to 500 rows before truncating a displayed frame.
pd.set_option('display.max_rows', 500)

In [None]:
# Override the legend frame settings from the style cell above: square box, edge colour '1.0', fully opaque.
plt.rcParams['legend.fancybox'] = False
plt.rcParams['legend.edgecolor']='1.0'
plt.rcParams['legend.framealpha']=1

In [None]:
from scipy import stats

In [None]:
def get_ground_truth_metric_with_explainer(
    attribution_values,
    explainer,
    dataset,
    iters_ground_truth,
    meta_info,
    check_class_efficiency=True,
    ground_truth_key_select=None,
    transform_mode=None
):
    """Score an explainer's predicted attributions against stored reference values.

    For every selected sample the explainer is run once (in eval mode, without
    gradients) and its output is compared with the reference array recorded at
    ``iters_ground_truth``.

    Args:
        attribution_values: Mapping ``sample_idx -> tracking dict``. Each
            tracking dict has an ``"iters"`` sequence and a parallel
            ``"values"`` sequence of arrays. The arrays are indexed as
            ``[:, class_idx]`` below, so axis 1 is the class axis (axis 0 is
            presumably the feature axis - confirm against ``load_attribution``).
        explainer: Callable model exposing ``eval()`` and ``device``; called as
            ``explainer(pixel_values=..., return_loss=False)`` and expected to
            return a mapping with a ``"logits"`` tensor. ``logits[0]`` is
            transposed before comparison, so it must be 2-D.
        dataset: Indexable by ``sample_idx``; items provide ``"pixel_values"``
            (and ``"labels"`` when ``check_class_efficiency`` is true).
        iters_ground_truth: Entry of ``"iters"`` whose values act as reference.
        meta_info: Dict merged into every record (it overrides metric keys of
            the same name).
        check_class_efficiency: When true, the target class is taken as the
            argmax of the column sums of ``values[0]``, asserted to equal the
            dataset label, and target / non-target metrics are added.
        ground_truth_key_select: Optional iterable of sample indices to
            evaluate; ``None`` evaluates every key of ``attribution_values``.
        transform_mode: ``None``, ``"global"`` and ``"sqrt"`` leave the
            reference untouched here; ``"perinstance"`` divides it by its norm
            over both axes; ``"perinstanceperclass"`` divides each class column
            by its own norm.

    Returns:
        One dict per sample. The ``"mse_*"`` entries are built from per-class
        *sums* of squared errors over axis 0 (``"mse_all"`` is their mean over
        classes). ``"pearson_global"`` / ``"spearman_global"`` are computed over
        all samples pooled together and copied into every record. An empty
        selection returns ``[]``.

    Raises:
        ValueError: Unknown ``transform_mode``, or ``iters_ground_truth`` is not
            present in a sample's ``"iters"``.
        AssertionError: The inferred target class differs from the label.

    Side effects:
        Calls ``explainer.eval()`` and converts ndarray ``"iters"`` entries of
        the input dicts to lists in place.
    """
    if transform_mode not in (None, "global", "sqrt", "perinstance", "perinstanceperclass"):
        raise ValueError(transform_mode)

    if ground_truth_key_select is None:
        selected_samples = attribution_values.items()
    else:
        selected_samples = {
            key: attribution_values[key] for key in ground_truth_key_select
        }.items()

    record_dict_list = []
    ground_truth_list = []
    estimated_list = []
    explainer.eval()

    for sample_idx, tracking_dict in tqdm(selected_samples):
        # Prepare input and make prediction.
        data = dataset[sample_idx]
        with torch.no_grad():
            estimated = explainer(
                pixel_values=data["pixel_values"].unsqueeze(0).to(explainer.device),
                return_loss=False,
            )["logits"][0]

        # Convert "iters" to a list once (stored back on the dict) so that
        # ``.index`` is available here and for any later reader of the dict.
        if isinstance(tracking_dict["iters"], np.ndarray):
            tracking_dict["iters"] = tracking_dict["iters"].tolist()

        ground_truth = tracking_dict["values"][tracking_dict["iters"].index(iters_ground_truth)]

        if transform_mode == "perinstance":
            ground_truth = ground_truth / np.linalg.norm(ground_truth, axis=(0, 1), keepdims=True)
        elif transform_mode == "perinstanceperclass":
            ground_truth = ground_truth / np.linalg.norm(ground_truth, axis=0, keepdims=True)

        # Transpose so the estimate has the same axis order as the reference.
        estimated = estimated.T.cpu().detach().numpy()
        ground_truth_list.append(ground_truth)
        estimated_list.append(estimated)

        # Per-class sum of squared errors (summed over axis 0).
        squared_error_per_class = ((estimated - ground_truth) ** 2).sum(axis=0)
        class_indices = range(ground_truth.shape[1])

        record = {
            "sample_idx": sample_idx,
            "mse_all": squared_error_per_class.mean(),
            "pearsonr_all": stats.pearsonr(estimated.flatten(), ground_truth.flatten())[0],
            "pearsonr_all_per_class": np.mean([
                stats.pearsonr(estimated[:, class_idx], ground_truth[:, class_idx])[0]
                for class_idx in class_indices
            ]),
            "spearmanr_all": stats.spearmanr(estimated.flatten(), ground_truth.flatten())[0],
            "spearmanr_all_per_class": np.mean([
                stats.spearmanr(estimated[:, class_idx], ground_truth[:, class_idx])[0]
                for class_idx in class_indices
            ]),
            "sign_agreement_all": ((estimated > 0) == (ground_truth > 0)).astype(int).mean(),
        }

        if check_class_efficiency:
            # Target class = class whose attributions have the largest total in
            # the first recorded entry of "values".
            target_class_idx = np.argmax(tracking_dict["values"][0].sum(axis=0))
            assert data["labels"] == target_class_idx

            record.update({
                "mse_target": squared_error_per_class[target_class_idx],
                "mse_nontarget": np.delete(squared_error_per_class, target_class_idx).mean(),
                "pearsonr_target": stats.pearsonr(
                    estimated[:, target_class_idx], ground_truth[:, target_class_idx]
                )[0],
                "spearmanr_target": stats.spearmanr(
                    estimated[:, target_class_idx], ground_truth[:, target_class_idx]
                )[0],
            })

        # Append result for this image.
        record.update(meta_info)
        record_dict_list.append(record)

    # Nothing was evaluated: the pooled correlations are undefined, so return
    # the empty list instead of letting scipy raise on empty arrays.
    if not record_dict_list:
        return record_dict_list

    # Calculate global metrics over all samples pooled together.
    estimated_all = np.concatenate([np.ravel(values) for values in estimated_list])
    ground_truth_all = np.concatenate([np.ravel(values) for values in ground_truth_list])
    pearson = stats.pearsonr(estimated_all, ground_truth_all)[0]
    spearman = stats.spearmanr(estimated_all, ground_truth_all)[0]
    for record in record_dict_list:
        record['pearson_global'] = pearson
        record['spearman_global'] = spearman

    return record_dict_list

In [None]:
def get_ground_truth_metric_with_value(
    attribution_values_ground_truth,
    iters_ground_truth,
    attribution_values_calculated,
    iters_calculated,
    meta_info,
    check_class_efficiency=True,
    ground_truth_key_select=None,
):
    """Compare stored attribution estimates against stored reference attributions.

    For every selected sample, the reference array recorded at
    ``iters_ground_truth`` is compared with the estimate recorded at
    ``iters_calculated`` and one record of agreement metrics is produced.

    Args:
        attribution_values_ground_truth: Mapping ``sample_idx -> tracking dict``
            with an ``"iters"`` sequence and a parallel ``"values"`` sequence of
            arrays. Arrays are indexed as ``[:, class_idx]`` below, so axis 1 is
            the class axis (axis 0 is presumably the feature axis - confirm
            against ``load_attribution``).
        iters_ground_truth: Entry of the reference ``"iters"`` to use.
        attribution_values_calculated: Mapping with the same layout holding the
            estimates; it must contain every selected ``sample_idx``.
        iters_calculated: Entry of the estimate's ``"iters"`` to use.
        meta_info: Dict merged into every record (it overrides metric keys of
            the same name).
        check_class_efficiency: When true, the target class is taken as the
            argmax of the column sums of ``values[0]``, asserted to agree
            between reference and estimate, and target / non-target metrics are
            added.
        ground_truth_key_select: Optional iterable of sample indices to
            evaluate; ``None`` evaluates every key of the reference mapping.

    Returns:
        One dict per sample. The ``"mse_*"`` entries are built from per-class
        *sums* of squared errors over axis 0 (``"mse_all"`` is their mean over
        classes). ``"pearson_global"`` / ``"spearman_global"`` are computed over
        all samples pooled together and copied into every record. An empty
        selection returns ``[]``.

    Raises:
        KeyError: A selected sample is missing from either mapping.
        ValueError: A requested iteration count is not present in ``"iters"``.
        AssertionError: Reference and estimate disagree on the target class.

    Side effects:
        ndarray ``"iters"`` entries of the input dicts are converted to lists
        in place.
    """
    if ground_truth_key_select is None:
        selected_samples = attribution_values_ground_truth.items()
    else:
        selected_samples = {
            key: attribution_values_ground_truth[key] for key in ground_truth_key_select
        }.items()

    record_dict_list = []
    ground_truth_list = []
    estimated_list = []

    for sample_idx, tracking_dict_ground_truth in tqdm(selected_samples):
        tracking_dict_calculated = attribution_values_calculated[sample_idx]

        # Convert "iters" to a list once (stored back on the dict) so that
        # ``.index`` is available here and for any later reader of the dict.
        for tracking_dict in (tracking_dict_ground_truth, tracking_dict_calculated):
            if isinstance(tracking_dict["iters"], np.ndarray):
                tracking_dict["iters"] = tracking_dict["iters"].tolist()

        # Prepare ground truth values and the estimates to be scored.
        ground_truth = tracking_dict_ground_truth["values"][
            tracking_dict_ground_truth["iters"].index(iters_ground_truth)
        ]
        estimated = tracking_dict_calculated["values"][
            tracking_dict_calculated["iters"].index(iters_calculated)
        ]
        ground_truth_list.append(ground_truth)
        estimated_list.append(estimated)

        # Per-class sum of squared errors (summed over axis 0).
        squared_error_per_class = ((estimated - ground_truth) ** 2).sum(axis=0)
        class_indices = range(ground_truth.shape[1])

        record = {
            "sample_idx": sample_idx,
            "mse_all": squared_error_per_class.mean(),
            "pearsonr_all": stats.pearsonr(estimated.flatten(), ground_truth.flatten())[0],
            "pearsonr_all_per_class": np.mean([
                stats.pearsonr(estimated[:, class_idx], ground_truth[:, class_idx])[0]
                for class_idx in class_indices
            ]),
            "spearmanr_all": stats.spearmanr(estimated.flatten(), ground_truth.flatten())[0],
            "spearmanr_all_per_class": np.mean([
                stats.spearmanr(estimated[:, class_idx], ground_truth[:, class_idx])[0]
                for class_idx in class_indices
            ]),
            "sign_agreement_all": ((estimated > 0) == (ground_truth > 0)).astype(int).mean(),
        }

        if check_class_efficiency:
            # Target class = class whose attributions have the largest total in
            # the first recorded entry of "values"; both sources must agree.
            target_class_idx_ground_truth = np.argmax(tracking_dict_ground_truth["values"][0].sum(axis=0))
            target_class_idx_calculated = np.argmax(tracking_dict_calculated["values"][0].sum(axis=0))
            assert target_class_idx_ground_truth == target_class_idx_calculated

            record.update({
                "mse_target": squared_error_per_class[target_class_idx_ground_truth],
                "mse_nontarget": np.delete(
                    squared_error_per_class, target_class_idx_ground_truth
                ).mean(),
                "pearsonr_target": stats.pearsonr(
                    estimated[:, target_class_idx_ground_truth],
                    ground_truth[:, target_class_idx_ground_truth],
                )[0],
                "spearmanr_target": stats.spearmanr(
                    estimated[:, target_class_idx_ground_truth],
                    ground_truth[:, target_class_idx_ground_truth],
                )[0],
            })

        # Append result for this image.
        record.update(meta_info)
        record_dict_list.append(record)

    # Nothing was evaluated: the pooled correlations are undefined, so return
    # the empty list instead of letting scipy raise on empty arrays.
    if not record_dict_list:
        return record_dict_list

    # Calculate global metrics over all samples pooled together.
    estimated_all = np.concatenate([np.ravel(values) for values in estimated_list])
    ground_truth_all = np.concatenate([np.ravel(values) for values in ground_truth_list])
    pearson = stats.pearsonr(estimated_all, ground_truth_all)[0]
    spearman = stats.spearmanr(estimated_all, ground_truth_all)[0]
    for record in record_dict_list:
        record['pearson_global'] = pearson
        record['spearman_global'] = spearman

    return record_dict_list

# Load data

In [None]:
dataset_explainer

In [None]:
shapley_loaded_dict={}

In [None]:
# shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"]\
# =load_attribution("logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test", attribution_name="shapley")
# Key the loaded Shapley attributions by the directory they are read from.
attribution_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"
shapley_loaded_dict[attribution_dir] = load_attribution(attribution_dir, attribution_name="shapley")

In [None]:
# Key the loaded Shapley attributions by the directory they are read from.
attribution_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
shapley_loaded_dict[attribution_dir] = load_attribution(attribution_dir, attribution_name="shapley")

In [None]:
# Load only the fixed 100-sample subset (first 100 entries of a seed-42
# permutation of range(9469)), keyed by the source directory.
attribution_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation/extract_output/train"
shapley_loaded_dict[attribution_dir] = load_attribution(
    attribution_dir,
    attribution_name="shapley",
    sample_select=np.random.RandomState(seed=42).permutation(list(range(9469)))[:100],
)

In [None]:
# shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_newsample_196/extract_output/train"]\
# =load_attribution("logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_newsample_196/extract_output/train",
#              target_subset_size=196, attribution_name="shapley",
# sample_select=np.random.RandomState(seed=42).permutation(list(range(9469)))[:100])                  

In [None]:
# Load only the fixed 100-sample subset (first 100 entries of a seed-42
# permutation of range(9469)), keyed by the source directory.
attribution_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"
shapley_loaded_dict[attribution_dir] = load_attribution(
    attribution_dir,
    attribution_name="shapley",
    sample_select=np.random.RandomState(seed=42).permutation(list(range(9469)))[:100],
)

In [None]:
# shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_antithetical/extract_output/train"]\
# =load_attribution("logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_antithetical/extract_output/train",
#                      attribution_name="shapley",
# sample_select=np.random.RandomState(seed=42).permutation(list(range(9469)))[:100])                  

In [None]:
# Load only the fixed 100-sample subset (first 100 entries of a seed-42
# permutation of range(9469)), keyed by the source directory.
attribution_dir = "logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train"
shapley_loaded_dict[attribution_dir] = load_attribution(
    attribution_dir,
    attribution_name="shapley",
    sample_select=np.random.RandomState(seed=42).permutation(list(range(9469)))[:100],
)

In [None]:
# shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train"]\
# =load_attribution("logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train",
# attribution_name="shapley",                 
# sample_select=np.random.RandomState(seed=42).permutation(list(range(9469)))[:100]                 
# )

In [None]:
# Size of each loaded Shapley entry, in insertion order.
for key, loaded_attributions in shapley_loaded_dict.items():
    print(len(loaded_attributions))

In [None]:
banzhaf_loaded_dict={}

In [None]:
# Key the loaded Banzhaf attributions by the directory they are read from.
attribution_dir = "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"
banzhaf_loaded_dict[attribution_dir] = load_attribution(attribution_dir, attribution_name="banzhaf")

In [None]:
# Key the loaded Banzhaf attributions by the directory they are read from.
attribution_dir = "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test"
banzhaf_loaded_dict[attribution_dir] = load_attribution(attribution_dir, attribution_name="banzhaf")

In [None]:
# Load only the fixed 100-sample subset (first 100 entries of a seed-42
# permutation of range(9469)), keyed by the source directory.
attribution_dir = "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"
banzhaf_loaded_dict[attribution_dir] = load_attribution(
    attribution_dir,
    attribution_name="banzhaf",
    sample_select=np.random.RandomState(seed=42).permutation(list(range(9469)))[:100],
)

In [None]:
# Load only the fixed 100-sample subset (first 100 entries of a seed-42
# permutation of range(9469)), keyed by the source directory.
attribution_dir = "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_short/extract_output/train"
banzhaf_loaded_dict[attribution_dir] = load_attribution(
    attribution_dir,
    attribution_name="banzhaf",
    sample_select=np.random.RandomState(seed=42).permutation(list(range(9469)))[:100],
)

In [None]:
# from shutil import rmtree
# for path_temp in glob.glob("logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train/*"):
#     if int(path_temp.split('/')[-1]) in np.random.RandomState(seed=42).permutation(list(range(9469)))[:100]:
#         pass
#     else:
#         rmtree(path_temp)

In [None]:
# Directory name and size of each loaded Banzhaf entry, in insertion order.
for key, loaded_attributions in banzhaf_loaded_dict.items():
    print(key, len(loaded_attributions))

In [None]:
!ls logs/vitbase_imagenette_surrogate_binomial*/

In [None]:
lime_loaded_dict={}

In [None]:
# Key the loaded LIME attributions by the directory they are read from.
attribution_dir = "logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"
lime_loaded_dict[attribution_dir] = load_attribution(attribution_dir, attribution_name="lime")

In [None]:
# Key the loaded LIME attributions by the directory they are read from.
attribution_dir = "logs/vitbase_imagenette_surrogate_lime_eval_test_regression_long/extract_output/test"
lime_loaded_dict[attribution_dir] = load_attribution(attribution_dir, attribution_name="lime")

In [None]:
# Load only the fixed 100-sample subset (first 100 entries of a seed-42
# permutation of range(9469)), keyed by the source directory.
attribution_dir = "logs/vitbase_imagenette_surrogate_binomial_eval_train/extract_output/train"
lime_loaded_dict[attribution_dir] = load_attribution(
    attribution_dir,
    attribution_name="lime",
    sample_select=np.random.RandomState(seed=42).permutation(list(range(9469)))[:100],
)

In [None]:
# lime_loaded_dict["logs/vitbase_imagenette_surrogate_binomial_eval_validation/extract_output/validation"]\
# =load_attribution("logs/vitbase_imagenette_surrogate_binomial_eval_validation/extract_output/validation",
#                  attribution_name="lime")

# Distribution check

In [None]:
type(i)

In [None]:
dataset_explainer["train"][774]

In [None]:
shapley_sample["iters"][-10:], banzhaf_sample["iters"][-10:]

In [None]:
shapley_sample["values"][shapley_sample["iters"].index(1000000)]

In [None]:
lime_loaded_dict.keys()

In [None]:
# shapley_target=[]
# shapley_nontarget=[]

# banzhaf_target=[]
# banzhaf_nontarget=[]

# for i in range(5000):
#     label_target=dataset_explainer["train"][i]['labels']
#     shapley_target_nontarget=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"][i]["values"][-2]
#     banzhaf_target_nontarget=banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"][i]["values"][-5]
    
    
    
    
#     shapley_target+=shapley_target_nontarget[:,label_target].tolist()
#     shapley_nontarget+=shapley_target_nontarget[:,np.arange(10)!=label_target].flatten().tolist()
    
    
#     banzhaf_target+=banzhaf_target_nontarget[:,label_target].tolist()
#     banzhaf_nontarget+=banzhaf_target_nontarget[:,np.arange(10)!=label_target].flatten().tolist()    

In [None]:
# Flattened attribution values (at 1,000,000 iterations) pooled over samples,
# one list per attribution method.
shapley_sample_list=[]
banzhaf_sample_list=[]
lime_sample_list=[]

# The loop header was missing, which left the indented body as a syntax error.
# It is restored with the fixed 100-sample subset (seed 42) that the loading
# cells pass as ``sample_select``.
# NOTE(review): confirm this is the intended set of sample indices.
for idx in np.random.RandomState(seed=42).permutation(list(range(9469)))[:100].tolist():
    # Not used below; kept because other cells read this global.
    # NOTE(review): remove if nothing depends on it.
    label_target=dataset_explainer["train"][idx]['labels']

    shapley_sample=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"][idx]
    shapley_sample=shapley_sample["values"][shapley_sample["iters"].index(1000000)]

    banzhaf_sample=banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"][idx]
    banzhaf_sample=banzhaf_sample["values"][banzhaf_sample["iters"].index(1000000)]

    lime_sample=lime_loaded_dict["logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"][idx]
    lime_sample=lime_sample["values"][lime_sample["iters"].index(1000000)]

    shapley_sample_list+=shapley_sample.flatten().tolist()
    banzhaf_sample_list+=banzhaf_sample.flatten().tolist()
    lime_sample_list+=lime_sample.flatten().tolist()

In [None]:
# One record per (method, sample): the norm of the attribution array recorded
# at 1,000,000 iterations. Methods are processed in the order listed, each over
# the same seed-42 subset of 100 sample indices.
metric_list_plot_norm=[]
for method_type, loaded_dict, attribution_dir in [
    ("KernelSHAP", shapley_loaded_dict,
     "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"),
    ("BanzhafMSR", banzhaf_loaded_dict,
     "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"),
    ("LIME", lime_loaded_dict,
     "logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"),
]:
    for idx in np.random.RandomState(seed=42).permutation(list(range(9469)))[:100].tolist():
        tracking_dict = loaded_dict[attribution_dir][idx]
        attribution_array = tracking_dict["values"][tracking_dict["iters"].index(1000000)]
        metric_list_plot_norm.append({
            "sample_idx": idx,
            "norm": np.linalg.norm(attribution_array),
            "method_type": method_type,
        })

In [None]:
metric_list_plot_norm_df=pd.DataFrame(metric_list_plot_norm)

In [None]:
# One histogram panel per attribution method, showing the distribution of the
# per-sample norms collected in metric_list_plot_norm_df.
fig = plt.figure(figsize=(3*(4.3), 3))
box1 = gridspec.GridSpec(1, 3, hspace=0.3)

panel_titles = {"KernelSHAP": "Shapley values", "BanzhafMSR": "Banzhaf values", "LIME": "LIME"}

axd = {}
for idx1, method_type in enumerate(["KernelSHAP", "BanzhafMSR", "LIME"]):
    ax = plt.Subplot(fig, box1[idx1])
    fig.add_subplot(ax)
    axd[method_type] = ax

    is_method = metric_list_plot_norm_df["method_type"] == method_type
    metric_list_plot_norm_df[is_method]["norm"].hist(ax=ax, bins=10)

    ax.set_title(panel_titles[method_type])
    ax.set_xlabel("L2 norm")
    ax.set_ylabel("Count")

In [None]:
# Save the figure as PNG and PDF. The output directory is created first so
# that savefig does not fail when "logs/plots" does not exist yet.
plot_dir = "logs/plots"
os.makedirs(plot_dir, exist_ok=True)
for extension in ("png", "pdf"):
    fig.savefig(
        os.path.join(plot_dir, "feature_attribution_distribution." + extension),
        bbox_inches='tight',
    )

In [None]:
# One axis per metric. The grid is sized to the metric list: indexing the
# 1x3 grid of the norm figure with four metrics raised IndexError, and drew
# onto a figure that had already been saved.
# NOTE(review): "MSE_all" differs in case from the "mse_all" record key -
# confirm which spelling the consumers of ``axd`` expect.
metric_names = ["MSE_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]

fig = plt.figure(figsize=(len(metric_names)*(4.3), 3))
box1 = gridspec.GridSpec(1, len(metric_names), hspace=0.3)

axd={}
for idx1, metric in enumerate(metric_names):
    ax=plt.Subplot(fig, box1[idx1])
    fig.add_subplot(ax)

    plot_key=(metric)
    axd[plot_key]=ax

In [None]:
# 2x3 grid of Shapley value histograms: columns are all / target / non-target
# classes; the second row repeats the first with a log-scaled count axis.
fig, axes=plt.subplots(2,3, figsize=(20,10))

shapley_panels = [
    (shapley_nontarget+shapley_target, "All classes"),
    (shapley_target, "Target classes"),
    (shapley_nontarget, "Non-target classes"),
]

for row_idx, use_log_scale in enumerate([False, True]):
    for col_idx, (panel_values, panel_title) in enumerate(shapley_panels):
        panel_ax = axes[row_idx][col_idx]
        panel_ax.hist(panel_values, bins=100)
        panel_ax.set_xlim([-0.3,0.3])
        if use_log_scale:
            panel_ax.set_title(panel_title+"–log scale")
            panel_ax.set_yscale("log")
        else:
            panel_ax.set_title(panel_title)

In [None]:
# 2x3 grid of Banzhaf value histograms: columns are all / target / non-target
# classes; the second row repeats the first with a log-scaled count axis.
fig, axes=plt.subplots(2,3, figsize=(20,10))

banzhaf_panels = [
    (banzhaf_nontarget+banzhaf_target, "All classes"),
    (banzhaf_target, "Target classes"),
    (banzhaf_nontarget, "Non-target classes"),
]

for row_idx, use_log_scale in enumerate([False, True]):
    for col_idx, (panel_values, panel_title) in enumerate(banzhaf_panels):
        panel_ax = axes[row_idx][col_idx]
        panel_ax.hist(panel_values, bins=100)
        panel_ax.set_xlim([-0.3,0.3])
        if use_log_scale:
            panel_ax.set_title(panel_title+"–log scale")
            panel_ax.set_yscale("log")
        else:
            panel_ax.set_title(panel_title)

In [None]:
# Same 2x3 layout as the plain Banzhaf histograms, but every value is first
# passed through a sign-preserving power transform, sign(x) * |x| ** 0.35.
def _signed_power(x):
    return np.sign(x) * np.power(np.abs(x), 0.35)


fig, axes=plt.subplots(2,3, figsize=(20,10))

banzhaf_power_panels = [
    (banzhaf_nontarget+banzhaf_target, "All classes"),
    (banzhaf_target, "Target classes"),
    (banzhaf_nontarget, "Non-target classes"),
]

for row_idx, use_log_scale in enumerate([False, True]):
    for col_idx, (panel_values, panel_title) in enumerate(banzhaf_power_panels):
        panel_ax = axes[row_idx][col_idx]
        panel_ax.hist(list(map(_signed_power, panel_values)), bins=100)
        panel_ax.set_xlim([-0.6,0.6])
        if use_log_scale:
            panel_ax.set_title(panel_title+"–log scale")
            panel_ax.set_yscale("log")
        else:
            panel_ax.set_title(panel_title)

In [None]:
def summarize_value(vec):
    """Print the minimum, mean, maximum and standard deviation of ``vec``.

    Each statistic is computed with the corresponding NumPy reduction over the
    whole input and printed on its own line as ``"<name>: <value>"``.
    """
    reducers = (
        ("min:", np.min),
        ("mean:", np.mean),
        ("max:", np.max),
        ("std:", np.std),
    )
    for label, reducer in reducers:
        print(label, reducer(vec))

In [None]:
# For the first 1000 training indices, split each sample's attribution array
# into the labelled class column ("target") and the remaining columns
# ("nontarget"). One list is appended per sample.
# NOTE(review): the [-2] / [-5] positions in "values" and the class count of 10
# are hard-coded - confirm they match the loaded data.
shapley_target, shapley_nontarget = [], []
banzhaf_target, banzhaf_nontarget = [], []

shapley_train = shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"]
banzhaf_train = banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"]

for i in range(1000):
    label_target = dataset_explainer["train"][i]['labels']
    shapley_target_nontarget = shapley_train[i]["values"][-2]
    banzhaf_target_nontarget = banzhaf_train[i]["values"][-5]

    nontarget_columns = np.arange(10) != label_target

    shapley_target.append(shapley_target_nontarget[:, label_target].tolist())
    shapley_nontarget.append(shapley_target_nontarget[:, nontarget_columns].flatten().tolist())

    banzhaf_target.append(banzhaf_target_nontarget[:, label_target].tolist())
    banzhaf_nontarget.append(banzhaf_target_nontarget[:, nontarget_columns].flatten().tolist())

In [None]:
# Summary statistics of the target-class attributions, per method.
for heading, target_values in (
    ("Shapley target", shapley_target),
    ("Banzhaf target", banzhaf_target),
):
    print(heading)
    summarize_value(target_values)

In [None]:
# Summary statistics of the non-target-class attributions, per method.
for heading, nontarget_values in (
    ("Shapley nontarget", shapley_nontarget),
    ("Banzhaf nontarget", banzhaf_nontarget),
):
    print(heading)
    summarize_value(nontarget_values)

In [None]:
# Histogram (log-scaled counts) of the largest target-class value per sample.
shapley_max_per_sample = list(map(np.max, shapley_target))
plt.hist(shapley_max_per_sample)
plt.yscale("log")
plt.title("max of each sample (Shapley)")

In [None]:
# Histogram (log-scaled counts) of the largest target-class value per sample.
banzhaf_max_per_sample = list(map(np.max, banzhaf_target))
plt.hist(banzhaf_max_per_sample)
plt.yscale("log")
plt.title("max of each sample (Banzhaf)")

In [None]:
shapley_target_nontarget[:,label_target].tolist()

In [None]:
# metric_list_value=[]
# for num_subsets in range(500, 100000, 500):
#     metric_list_value+=get_ground_truth_metric_with_value(attribution_values_ground_truth=banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"], 
#                                        iters_ground_truth=1000000, 
#                                        attribution_values_calculated=banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"],
#                                        iters_calculated=num_subsets,
#                                        meta_info={"num_subsets": num_subsets,
#                                                   "true_name": "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train",
#                                                    "estimated_name": "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train",
#                                                  })
    

# for num_subsets in range(500, 100000, 500):
#     metric_list_value+=get_ground_truth_metric_with_value(attribution_values_ground_truth=banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"], 
#                                        iters_ground_truth=1000000, 
#                                        attribution_values_calculated=banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"],
#                                        iters_calculated=num_subsets,
#                                        meta_info={"num_subsets": num_subsets,
#                                                   "true_name": "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train",
#                                                    "estimated_name": "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train",
#                                                  })    

In [None]:
# Error against the number of sampled subsets, one line per estimator.
# ``hue=""`` named no column and made seaborn raise; the records carry an
# "estimated_name" entry (added through ``meta_info``), which is used instead.
# NOTE(review): ``metric_list_value`` is only assigned in a commented-out cell;
# confirm where it is expected to come from, and confirm the hue column.
plot_df=pd.DataFrame(metric_list_value)

sns.lineplot(
    x="num_subsets",
    y="mse_all",
    hue="estimated_name",
    data=plot_df
)




# Training target quality

In [None]:
metric_list_value_shapley=[]

In [None]:
# Score the estimates at several subset budgets against the reference run
# (entry 999424 of its "iters") and append the records to the shared list.
true_name = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
estimated_name = "logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"

for num_subsets in [512, 1024, 2048, 3072]:
    metric_list_value_shapley.extend(
        get_ground_truth_metric_with_value(
            attribution_values_ground_truth=shapley_loaded_dict[true_name],
            iters_ground_truth=999424,
            attribution_values_calculated=shapley_loaded_dict[estimated_name],
            iters_calculated=num_subsets,
            meta_info={
                "num_subsets": num_subsets,
                "true_name": true_name,
                "estimated_name": estimated_name,
            },
        )
    )

In [None]:
# Score the estimates at several subset budgets against the reference run
# (entry 999424 of its "iters") and append the records to the shared list.
true_name = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
estimated_name = "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation/extract_output/train"

for num_subsets in [196, 392, 588, 1176, 3136]:
    metric_list_value_shapley.extend(
        get_ground_truth_metric_with_value(
            attribution_values_ground_truth=shapley_loaded_dict[true_name],
            iters_ground_truth=999424,
            attribution_values_calculated=shapley_loaded_dict[estimated_name],
            iters_calculated=num_subsets,
            meta_info={
                "num_subsets": num_subsets,
                "true_name": true_name,
                "estimated_name": estimated_name,
            },
        )
    )

In [None]:
# num_subsets=196
# for i in range(16):
#     shapley_loaded_dict_temp={}
#     for sample_idx, tracking_dict in shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_newsample_196/extract_output/train"].items():
#         shapley_loaded_dict_temp[sample_idx]=tracking_dict[i]

#     metric_list_value_shapley+=get_ground_truth_metric_with_value(attribution_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"], 
#                                        iters_ground_truth=999424, 
#                                        attribution_values_calculated=shapley_loaded_dict_temp,
#                                        iters_calculated=num_subsets,
#                                        meta_info={"num_subsets": num_subsets,
#                                                   "true_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
#                                                   "estimated_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_newsample_196/extract_output/train",
#                                                   "nth": i+1,
#                                                  }) 

In [None]:
# Score the estimates at several subset budgets against the reference run
# (entry 999424 of its "iters") and append the records to the shared list.
true_name = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
estimated_name = "logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train"

for num_subsets in [258, 514, 1026, 2050, 4098, 5122, 9986]:
    metric_list_value_shapley.extend(
        get_ground_truth_metric_with_value(
            attribution_values_ground_truth=shapley_loaded_dict[true_name],
            iters_ground_truth=999424,
            attribution_values_calculated=shapley_loaded_dict[estimated_name],
            iters_calculated=num_subsets,
            meta_info={
                "num_subsets": num_subsets,
                "true_name": true_name,
                "estimated_name": estimated_name,
            },
        )
    )

In [None]:
if False:
    torch.save(metric_list_value_shapley, "logs/experiment_results/metric_list_value_shapley.pt")

In [None]:
metric_list_value_banzhaf=[]

In [None]:
# Reference entry shared by every sweep in this cell (passed with iters_ground_truth=1000000).
banzhaf_value_truth_name = "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"

# (estimated run, subset budgets) pairs, evaluated top to bottom.
banzhaf_value_sweeps = (
    (
        "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_short/extract_output/train",
        [5, 10, 20, 30, 40, 50, 60, 70, 80, 90],
    ),
    (
        "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train",
        [100, 200, 300, 400, 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000],
    ),
    (
        # Same key as the reference: the long run is compared with itself at
        # iters_calculated = 6000, 16000, ..., 96000.
        "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train",
        list(range(6000, 100000+1000, 10000)),
    ),
)

for banzhaf_value_estimate_name, banzhaf_value_budgets in banzhaf_value_sweeps:
    for num_subsets in banzhaf_value_budgets:
        metric_list_value_banzhaf += get_ground_truth_metric_with_value(
            attribution_values_ground_truth=banzhaf_loaded_dict[banzhaf_value_truth_name],
            iters_ground_truth=1000000,
            attribution_values_calculated=banzhaf_loaded_dict[banzhaf_value_estimate_name],
            iters_calculated=num_subsets,
            meta_info={
                "num_subsets": num_subsets,
                "true_name": banzhaf_value_truth_name,
                "estimated_name": banzhaf_value_estimate_name,
            },
            check_class_efficiency=False,
        )

# Disabled extra argument that sat next to the `..._sampling` sweep (kept for reference):
#                                        ground_truth_key_select=\
#                                     banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"].keys()
#                                                          )

# for num_subsets in [100, 200, 300, 400, 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000]:
#     metric_list_value+=get_ground_truth_metric_with_value(attribution_values_ground_truth=banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"], 
#                                        iters_ground_truth=1000000, 
#                                        attribution_values_calculated=banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_antithetical/extract_output/train"],
#                                        iters_calculated=num_subsets,
#                                        meta_info={"num_subsets": num_subsets,
#                                                   "true_name": "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train",
#                                                    "estimated_name": "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_antithetical/extract_output/train",
#                                                  },
#                                         check_class_efficiency=False,
#                                        ground_truth_key_select=\
#                                     banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"].keys()                                                         
#                                                          )                                                          

In [None]:
# for num_subsets in [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000]+list(range(6000, 100000+1000, 10000)):
#     metric_list_value_banzhaf+=get_ground_truth_metric_with_value(attribution_values_ground_truth=banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test"], 
#                                        iters_ground_truth=1000000, 
#                                        attribution_values_calculated=banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test"],
#                                        iters_calculated=num_subsets,
#                                        meta_info={"num_subsets": num_subsets,
#                                                   "true_name": "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test",
#                                                    "estimated_name": "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test",
#                                                  },
#                                        check_class_efficiency=False)

In [None]:
# Manual switch: disabled by default. Change `False` to `True` to write
# metric_list_value_banzhaf to disk with torch.save.
if False:
    torch.save(metric_list_value_banzhaf, "logs/experiment_results/metric_list_value_banzhaf.pt")

In [None]:
# Accumulator extended (+=) by the lime get_ground_truth_metric_with_value sweeps;
# re-running this cell resets it to an empty list.
metric_list_value_lime=[]

In [None]:
# Reference entry shared by both sweeps in this cell (passed with iters_ground_truth=1000000).
lime_value_truth_name = "logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"

# (estimated run, subset budgets) pairs, evaluated top to bottom.
lime_value_sweeps = (
    (
        # Budgets 128, 256, ..., 3072.
        "logs/vitbase_imagenette_surrogate_binomial_eval_train/extract_output/train",
        list(range(128, 3200, 128)),
    ),
    (
        # Same key as the reference: the long run is compared with itself at
        # iters_calculated = 100000, 200000, ..., 1000000.
        "logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train",
        list(range(100000, 1000000+100000, 100000)),
    ),
)

for lime_value_estimate_name, lime_value_budgets in lime_value_sweeps:
    for num_subsets in lime_value_budgets:
        metric_list_value_lime += get_ground_truth_metric_with_value(
            attribution_values_ground_truth=lime_loaded_dict[lime_value_truth_name],
            iters_ground_truth=1000000,
            attribution_values_calculated=lime_loaded_dict[lime_value_estimate_name],
            iters_calculated=num_subsets,
            meta_info={
                "num_subsets": num_subsets,
                "true_name": lime_value_truth_name,
                "estimated_name": lime_value_estimate_name,
            },
            check_class_efficiency=False,
        )

# Disabled extra argument that sat next to the first sweep (kept for reference):
#                                        ground_truth_key_select=\
#                                     banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"].keys()
#                                                          )

In [None]:
# Manual switch: disabled by default. Change `False` to `True` to write
# metric_list_value_lime to disk with torch.save.
if False:
    torch.save(metric_list_value_lime, "logs/experiment_results/metric_list_value_lime.pt")

In [None]:
# Display the best checkpoint recorded for the objexplainer run.
# NOTE(review): in the visible cell order, get_best_model_checkpoint is defined in a
# later cell, so on a fresh kernel this call raises NameError unless that cell
# (or an earlier definition elsewhere in the notebook) has already been run.
get_best_model_checkpoint(model_path="logs/vitbase_imagenette_shapley_objexplainer_newsample_32")    

# evaluate explainer

In [None]:
def compare_checkpoint_value(current_checkpoint, best_checkpoint):
    """Label a checkpoint relative to the best one.

    Args:
        current_checkpoint: Value for the checkpoint being examined (callers in
            this notebook pass the integer step parsed from ``checkpoint-<step>``).
        best_checkpoint: Value for the best checkpoint, in the same units.

    Returns:
        ``"before"`` if ``current_checkpoint`` is smaller, ``"best"`` if equal,
        ``"after"`` if larger.

    Raises:
        ValueError: If none of the three comparisons holds (e.g. a NaN operand).
    """
    if current_checkpoint < best_checkpoint:
        return "before"
    if current_checkpoint == best_checkpoint:
        return "best"
    if current_checkpoint > best_checkpoint:
        return "after"
    # Unordered operands: no comparison was true.
    raise ValueError

In [None]:
def get_best_model_checkpoint(model_path):
    """Return the ``best_model_checkpoint`` entry recorded for a training run.

    Reads ``<model_path>/trainer_state.json`` when it exists. Otherwise falls
    back to the ``trainer_state.json`` inside the ``checkpoint-*`` directory
    with the numerically largest suffix.

    Args:
        model_path: Run directory holding ``trainer_state.json`` and/or
            ``checkpoint-<step>`` sub-directories.

    Returns:
        The value stored under ``"best_model_checkpoint"`` in the JSON file
        that was read.

    Raises:
        FileNotFoundError: If ``model_path`` has neither a top-level
            ``trainer_state.json`` nor any ``checkpoint-*`` directory
            (previously this surfaced as a bare ``IndexError``).
    """
    state_path = model_path + "/trainer_state.json"
    if not os.path.exists(state_path):
        # Sort on the integer step so that e.g. checkpoint-100 ranks after checkpoint-20.
        checkpoint_path_list = sorted(
            glob.glob(model_path + "/checkpoint-*"),
            key=lambda x: int(x.split('-')[-1]),
        )
        if not checkpoint_path_list:
            raise FileNotFoundError(
                f"No trainer_state.json and no checkpoint-* directories found under {model_path!r}"
            )
        state_path = checkpoint_path_list[-1] + "/trainer_state.json"

    with open(state_path) as f:
        trainer_state = json.load(f)
    return trainer_state["best_model_checkpoint"]

In [None]:
from shutil import rmtree

# WARNING: destructive. For each run directory, permanently deletes every
# checkpoint whose step lies more than 10 * (smallest gap between saved steps)
# beyond the best checkpoint's step.
for model_path in glob.glob("/sdata/chanwkim/xai-amortization/logs_0901/*"):
    print(model_path)
    if len(glob.glob(model_path+"/checkpoint-*"))==0:
        # No checkpoints in this directory, nothing to prune.
        continue

    best_checkpoint_path=get_best_model_checkpoint(model_path)
    best_checkpoint_step=int(best_checkpoint_path.split('-')[-1])
    checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
    checkpoint_step_list=[int(i.split('-')[-1]) for i in checkpoint_path_list]

    # Smallest difference between consecutive saved steps (NaN when there is a single
    # checkpoint, in which case the comparison below is False and nothing is deleted).
    # NOTE(review): presumably one epoch's worth of steps -- confirm against the save strategy.
    epoch_step_count=pd.Series(checkpoint_step_list).sort_values().diff().min()
    checkpoint_to_delete=[path for path, step in zip(checkpoint_path_list, checkpoint_step_list)
                          if step-best_checkpoint_step>epoch_step_count*10]
    print(len(checkpoint_to_delete))
    if len(checkpoint_to_delete)!=0:
        print(best_checkpoint_path)
        for checkpoint_path in tqdm(checkpoint_to_delete):
            rmtree(checkpoint_path)
            print(checkpoint_path)

        # Stop after the first run directory that had checkpoints removed so the
        # output can be inspected before pruning more. This replaces an undefined
        # name that halted the cell by raising NameError.
        break

In [None]:
# Accumulator extended (+=) by the get_ground_truth_metric_with_explainer cells below;
# re-running this cell resets it to an empty list.
metric_list_shapley=[]

### Reg-AO (upfront, regression)

In [None]:
# (dataset split, key into shapley_loaded_dict) pairs; for every checkpoint the
# test split is evaluated first, then the train split.
shapley_reference_by_split = (
    ("test", "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"),
    ("train", "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"),
)

for num_subsets in [512, 1024, 2048, 3072]:
# for num_subsets in [2048, 3072]:
    model_path = f"logs/vitbase_imagenette_shapley_regexplainer_upfront_{num_subsets}"
    # Run-level trainer state supplies the best checkpoint for this run.
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    checkpoint_path_list = sorted(
        glob.glob(model_path+"/checkpoint-*"),
        key=lambda path: int(path.split('-')[-1]),
    )
    # Only the first 20 checkpoints (lowest step numbers) are evaluated.
    for checkpoint_path in tqdm(checkpoint_path_list[:20]):
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        # The same regexplainer object is reused; its weights are overwritten per checkpoint.
        regexplainer.load_state_dict(checkpoint_state_dict)

        for eval_split, reference_name in shapley_reference_by_split:
            metric_list_shapley += get_ground_truth_metric_with_explainer(
                attribution_values=shapley_loaded_dict[reference_name],
                explainer=regexplainer,
                dataset=dataset_explainer[eval_split],
                iters_ground_truth=999424,
                meta_info={
                    "true_name": reference_name,
                    "model_path": model_path,
                    "epoch": int(checkpoint_trainer_state["epoch"]),
                    "is_best_checkpoint": compare_checkpoint_value(
                        current_checkpoint=int(checkpoint_path.split('-')[-1]),
                        best_checkpoint=int(trainer_state["best_model_checkpoint"].split('-')[-1]),
                    ),
                },
            )

In [None]:
# for num_subsets in [512]:
# # for num_subsets in [2048, 3072]:
#     model_path=f"logs/vitbase_imagenette_shapley_regexplainer_antithetical_upfront_{num_subsets}"
#     with open(model_path+"/trainer_state.json") as f:
#         trainer_state = json.load(f)

#     checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
#     for checkpoint_path in tqdm(checkpoint_path_list[:20]):
#         checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
#         with open(checkpoint_path+"/trainer_state.json") as f:
#             checkpoint_trainer_state = json.load(f)

#         regexplainer.load_state_dict(checkpoint_state_dict)

#         metric_list+=get_ground_truth_metric_with_explainer(shapley_values=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"], 
#                                 explainer=regexplainer,
#                                 dataset=dataset_explainer["test"],
#                                 iters_ground_truth=999424,
#                                 meta_info={
#                                            "true_name": "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test",
#                                            "model_path": model_path,
#                                            "epoch": int(checkpoint_trainer_state["epoch"]),
#                                            "is_best_checkpoint": 
#                                             compare_checkpoint_value(current_checkpoint=int(checkpoint_path.split('-')[-1]), 
#                                                                      best_checkpoint=int(trainer_state["best_model_checkpoint"].split('-')[-1]))
#                                           })


#         metric_list+=get_ground_truth_metric_with_explainer(shapley_values=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"], 
#                                 explainer=regexplainer,
#                                 dataset=dataset_explainer["train"],
#                                 iters_ground_truth=999424,
#                                 meta_info={
#                                            "true_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
#                                            "model_path": model_path,
#                                            "epoch": int(checkpoint_trainer_state["epoch"]),
#                                            "is_best_checkpoint": 
#                                             compare_checkpoint_value(current_checkpoint=int(checkpoint_path.split('-')[-1]), 
#                                                                      best_checkpoint=int(trainer_state["best_model_checkpoint"].split('-')[-1]))
#                                           })        

### Reg-AO (upfront, permutation)

In [None]:
# (dataset split, key into shapley_loaded_dict) pairs; for every checkpoint the
# test split is evaluated first, then the train split.
shapley_reference_by_split = (
    ("test", "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"),
    ("train", "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"),
)

for num_subsets in [196, 392, 588, 1176, 3136]:
    model_path = f"logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_{num_subsets}"
    # Run-level trainer state supplies the best checkpoint for this run.
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    checkpoint_path_list = sorted(
        glob.glob(model_path+"/checkpoint-*"),
        key=lambda path: int(path.split('-')[-1]),
    )
    # Only the first 20 checkpoints (lowest step numbers) are evaluated.
    for checkpoint_path in tqdm(checkpoint_path_list[:20]):
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        # The same regexplainer object is reused; its weights are overwritten per checkpoint.
        regexplainer.load_state_dict(checkpoint_state_dict)

        for eval_split, reference_name in shapley_reference_by_split:
            metric_list_shapley += get_ground_truth_metric_with_explainer(
                attribution_values=shapley_loaded_dict[reference_name],
                explainer=regexplainer,
                dataset=dataset_explainer[eval_split],
                iters_ground_truth=999424,
                meta_info={
                    "true_name": reference_name,
                    "model_path": model_path,
                    "epoch": int(checkpoint_trainer_state["epoch"]),
                    "is_best_checkpoint": compare_checkpoint_value(
                        current_checkpoint=int(checkpoint_path.split('-')[-1]),
                        best_checkpoint=int(trainer_state["best_model_checkpoint"].split('-')[-1]),
                    ),
                },
            )

### Reg-AO (newsample, permutation)

In [None]:
# for num_subsets in [196]:
#     model_path=f"logs/vitbase_imagenette_shapley_regexplainer_permutation_newsample_{num_subsets}"
# #     with open(model_path+"/trainer_state.json") as f:
# #         trainer_state = json.load(f)

#     checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
#     for checkpoint_path in tqdm(checkpoint_path_list[:20]):
#         checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
#         with open(checkpoint_path+"/trainer_state.json") as f:
#             checkpoint_trainer_state = json.load(f)

#         regexplainer.load_state_dict(checkpoint_state_dict)

#         metric_list_shapley+=get_ground_truth_metric_with_explainer(attribution_values=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test"], 
#                                 explainer=regexplainer,
#                                 dataset=dataset_explainer["test"],
#                                 iters_ground_truth=999424,
#                                 meta_info={
#                                            "true_name": "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test",
#                                            "model_path": model_path,
#                                            "epoch": int(checkpoint_trainer_state["epoch"]),
#                                            "is_best_checkpoint": 
#                                             compare_checkpoint_value(current_checkpoint=int(checkpoint_path.split('-')[-1]), 
#                                                                      best_checkpoint=1184)
#                                           })


#         metric_list_shapley+=get_ground_truth_metric_with_explainer(attribution_values=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"], 
#                                 explainer=regexplainer,
#                                 dataset=dataset_explainer["train"],
#                                 iters_ground_truth=999424,
#                                 meta_info={
#                                            "true_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
#                                            "model_path": model_path,
#                                            "epoch": int(checkpoint_trainer_state["epoch"]),
#                                            "is_best_checkpoint": 
#                                             compare_checkpoint_value(current_checkpoint=int(checkpoint_path.split('-')[-1]), 
#                                                                      best_checkpoint=1184)
#                                           })        

### SGD-shapley

In [None]:
# (dataset split, key into shapley_loaded_dict) pairs; for every checkpoint the
# test split is evaluated first, then the train split.
shapley_reference_by_split = (
    ("test", "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"),
    ("train", "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"),
)

for num_subsets in [9986]:
    model_path = f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_shapley_regexplainer_SGD_antithetical_upfront_{num_subsets}"

    checkpoint_path_list = sorted(
        glob.glob(model_path+"/checkpoint-*"),
        key=lambda path: int(path.split('-')[-1]),
    )
    # Only the first 20 checkpoints (lowest step numbers) are evaluated.
    for checkpoint_path in tqdm(checkpoint_path_list[:20]):
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        # The same regexplainer object is reused; its weights are overwritten per checkpoint.
        regexplainer.load_state_dict(checkpoint_state_dict)

        for eval_split, reference_name in shapley_reference_by_split:
            metric_list_shapley += get_ground_truth_metric_with_explainer(
                attribution_values=shapley_loaded_dict[reference_name],
                explainer=regexplainer,
                dataset=dataset_explainer[eval_split],
                iters_ground_truth=999424,
                meta_info={
                    "true_name": reference_name,
                    "model_path": model_path,
                    "epoch": int(checkpoint_trainer_state["epoch"]),
                    # Best step is looked up through get_best_model_checkpoint for this run.
                    "is_best_checkpoint": compare_checkpoint_value(
                        current_checkpoint=int(checkpoint_path.split('-')[-1]),
                        best_checkpoint=int(get_best_model_checkpoint(model_path).split('-')[-1]),
                    ),
                },
            )

In [None]:
# WARNING: destructive shell command -- recursively deletes this checkpoint directory from disk.
!rm -r /sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_shapley_regexplainer_SGD_antithetical_upfront_9986/checkpoint-9324/

In [None]:
# List the files inside checkpoint-148 of the SGD antithetical run.
!ls /sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_shapley_regexplainer_SGD_antithetical_upfront_9986/checkpoint-148/

## load banzhaf

In [None]:
# Show which keys are available in banzhaf_loaded_dict (the cells below index it
# with run-directory path strings).
banzhaf_loaded_dict.keys()

In [None]:
# Accumulator extended (+=) by the banzhaf get_ground_truth_metric_with_explainer cell;
# re-running this cell resets it to an empty list.
metric_list_banzhaf=[]

In [None]:
# for target_transform_mode, num_subsets_list in zip(["global","sqrt", "perinstance", "perinstanceperclass"],
#                                   [[10,100,500], [10,100,500], [100,500], [10, 100,500]]):

# (dataset split, key into banzhaf_loaded_dict) pairs; for every checkpoint the
# test split is evaluated first, then the train split.
banzhaf_reference_by_split = (
    ("test", "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test"),
    ("train", "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"),
)

for target_transform_mode, num_subsets_list in (
    ("global", [10,100,500]),
    ("perinstanceperclass", [10, 100,500]),
):
    # Set the mode on the shared explainer's config before evaluating this group of runs.
    regexplainer_normalize.config.target_transform_mode = target_transform_mode
    for num_subsets in num_subsets_list:
        model_path = f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_banzhaf_regexplainer_upfront_{target_transform_mode}_{num_subsets}"

        checkpoint_path_list = sorted(
            glob.glob(model_path+"/checkpoint-*"),
            key=lambda path: int(path.split('-')[-1]),
        )
        # Every checkpoint of the run is evaluated (no [:20] cap here).
        for checkpoint_path in tqdm(checkpoint_path_list[:]):
            checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
            with open(checkpoint_path+"/trainer_state.json") as f:
                checkpoint_trainer_state = json.load(f)

            regexplainer_normalize.load_state_dict(checkpoint_state_dict)

            for eval_split, reference_name in banzhaf_reference_by_split:
                metric_list_banzhaf += get_ground_truth_metric_with_explainer(
                    attribution_values=banzhaf_loaded_dict[reference_name],
                    explainer=regexplainer_normalize,
                    dataset=dataset_explainer[eval_split],
                    iters_ground_truth=1000000,
                    meta_info={
                        "true_name": reference_name,
                        "model_path": model_path,
                        "epoch": int(checkpoint_trainer_state["epoch"]),
                        "is_best_checkpoint": compare_checkpoint_value(
                            current_checkpoint=int(checkpoint_path.split('-')[-1]),
                            best_checkpoint=int(get_best_model_checkpoint(model_path).split('-')[-1]),
                        ),
                    },
                    check_class_efficiency=False,
                    transform_mode=target_transform_mode,
                )

In [None]:
# Manual switch: disabled by default. When enabled, writes metric_list_banzhaf to
# disk and then immediately reloads it from the same file.
if False:
    torch.save(metric_list_banzhaf, "logs/experiment_results/metric_list_banzhaf.pt")
    metric_list_banzhaf=torch.load("logs/experiment_results/metric_list_banzhaf.pt")

### load lime

In [None]:
# Accumulator extended (+=) by the lime get_ground_truth_metric_with_explainer cell;
# re-running this cell resets it to an empty list.
metric_list_lime=[]

In [None]:
# for target_transform_mode, num_subsets_list in zip(["global","sqrt", "perinstance", "perinstanceperclass"],
#                                   [[128,256,512], [128,256,512], [256,512], [256,512]]):

# (dataset split, key into lime_loaded_dict) pairs; for every checkpoint the
# test split is evaluated first, then the train split.
lime_reference_by_split = (
    ("test", "logs/vitbase_imagenette_surrogate_lime_eval_test_regression_long/extract_output/test"),
    ("train", "logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"),
)

for target_transform_mode, num_subsets_list in (
    ("global", [256,512]),
    ("perinstanceperclass", [256,512]),
):
    # Set the mode on the shared explainer's config before evaluating this group of runs.
    regexplainer_normalize.config.target_transform_mode = target_transform_mode
    for num_subsets in num_subsets_list:
        model_path = f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_lime_regexplainer_upfront_{target_transform_mode}_{num_subsets}"

        checkpoint_path_list = sorted(
            glob.glob(model_path+"/checkpoint-*"),
            key=lambda path: int(path.split('-')[-1]),
        )
        # Every checkpoint of the run is evaluated (no [:20] cap here).
        for checkpoint_path in tqdm(checkpoint_path_list[:]):
            checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
            with open(checkpoint_path+"/trainer_state.json") as f:
                checkpoint_trainer_state = json.load(f)

            regexplainer_normalize.load_state_dict(checkpoint_state_dict)

            for eval_split, reference_name in lime_reference_by_split:
                metric_list_lime += get_ground_truth_metric_with_explainer(
                    attribution_values=lime_loaded_dict[reference_name],
                    explainer=regexplainer_normalize,
                    dataset=dataset_explainer[eval_split],
                    iters_ground_truth=1000000,
                    meta_info={
                        "true_name": reference_name,
                        "model_path": model_path,
                        "epoch": int(checkpoint_trainer_state["epoch"]),
                        "is_best_checkpoint": compare_checkpoint_value(
                            current_checkpoint=int(checkpoint_path.split('-')[-1]),
                            best_checkpoint=int(get_best_model_checkpoint(model_path).split('-')[-1]),
                        ),
                    },
                    check_class_efficiency=False,
                    transform_mode=target_transform_mode,
                )

In [None]:
# Mean and standard deviation of mse_all per model_path, restricted to rows flagged
# "best" whose true_name is the train-split long sampling entry.
df_temp = pd.DataFrame(metric_list_banzhaf)
best_rows = df_temp["is_best_checkpoint"] == "best"
train_reference_rows = (
    df_temp["true_name"]
    == "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"
)
(
    df_temp.loc[best_rows & train_reference_rows]
    .groupby(["model_path"])["mse_all"]
    .agg(["mean", "std"])
)

In [None]:
# Manual switch: disabled by default. When enabled, writes metric_list_lime to
# disk and then immediately reloads it from the same file.
if False:
    torch.save(metric_list_lime, "logs/experiment_results/metric_list_lime.pt")
    metric_list_lime=torch.load("logs/experiment_results/metric_list_lime.pt")

# Training target quality

In [None]:
import copy

In [None]:
# Label every Shapley target-quality record with a display name, an estimator
# family and an antithetical flag, then collect the records into a DataFrame.

# The only reference ("true") run accepted by this cell.
_SHAPLEY_TRUE_NAME="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"

# estimated_name -> (method_name template, method_type, antithetical).
# Templates are filled from the record itself with str.format_map, so
# "{num_subsets}" / "{nth}" read the record's own fields.
_SHAPLEY_METHOD_LABELS={
    "logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train":
        ("KernelSHAP ({num_subsets})", "KernelSHAP", False),
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train":
        ("KernelSHAP ({num_subsets}, antithetical)", "KernelSHAP", True),
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation/extract_output/train":
        ("Permutation ({num_subsets})", "Permutation", False),
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_antithetical/extract_output/train":
        ("Permutation ({num_subsets}, antithetical)", "Permutation", True),
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_newsample_196/extract_output/train":
        ("Permutation ({num_subsets}, newsample, {nth})", "Permutation", False),
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train":
        ("SGD-Shapley ({num_subsets}, antithetical)", "SGD-Shapley", True),
}

metric_list_plot=[]
for metric in metric_list_value_shapley:
    # Shallow copy so the labels are not written back into the source records.
    metric_temp=copy.copy(metric)

    label_spec=_SHAPLEY_METHOD_LABELS.get(metric_temp["estimated_name"])
    if label_spec is None or metric_temp["true_name"]!=_SHAPLEY_TRUE_NAME:
        # Fail loudly on an unexpected (estimate, reference) pair rather than
        # letting an unlabeled record reach the plots.
        raise RuntimeError(f"Unrecognized estimated/true pair in metric record: {metric_temp}")

    name_template, method_type, antithetical=label_spec
    metric_temp.update(
        {"method_name": name_template.format_map(metric_temp),
         "method_type": method_type,
         "antithetical": antithetical
        }
    )
    metric_list_plot.append(metric_temp)

metric_list_plot_df=pd.DataFrame(metric_list_plot)
metric_list_plot_df

In [None]:
# Mean correlation metrics for each (method, budget, reference, estimate)
# combination, ordered by the number of sampled subsets.
(
    metric_list_plot_df
    .groupby(["method_name", "num_subsets", "true_name", "estimated_name"])[
        ["pearsonr_all_per_class", "spearmanr_all_per_class", "pearsonr_all", "spearmanr_all"]
    ]
    .mean()
    .reset_index()
    .sort_values(by="num_subsets")
)

In [None]:
# Number of records per (method_type, antithetical) combination.
metric_list_plot_df.value_counts(subset=["method_type", "antithetical"])

In [None]:
# One-row, four-panel figure: each panel plots one quality metric of the labelled
# Shapley records against the number of sampled subsets, coloured by method type.
# `fig`, `box1` and `axd` stay at cell scope (the save cell below uses `fig`).
fig = plt.figure(figsize=(4*(4.3), 3))

box1 = gridspec.GridSpec(1, 4, hspace=0.3)

# Per-panel settings. The error panel differs from the three agreement panels:
# it does not pass errorbar=None (so seaborn's default error bars are drawn),
# uses tighter axis limits, enables no minor grid, and is the only panel that
# keeps its legend.
panel_specs=[
    {"metric": "mse_all", "title": "Error", "ylabel": "Error",
     "lineplot_kwargs": {}, "xlim": (0, 3100), "ylim": (0, 0.45),
     "minor_grid": False, "keep_legend": True},
    {"metric": "pearsonr_all", "title": "Correlation", "ylabel": "Pearson Correlation",
     "lineplot_kwargs": {"errorbar": None}, "xlim": (0, 3200), "ylim": (0, 1.05),
     "minor_grid": True, "keep_legend": False},
    {"metric": "spearmanr_all", "title": "Rank correlation", "ylabel": "Spearman Correlation",
     "lineplot_kwargs": {"errorbar": None}, "xlim": (0, 3200), "ylim": (0, 1.05),
     "minor_grid": True, "keep_legend": False},
    {"metric": "sign_agreement_all", "title": "Sign Agreement", "ylabel": "Sign Agreement",
     "lineplot_kwargs": {"errorbar": None}, "xlim": (0, 3200), "ylim": (0, 1.05),
     "minor_grid": True, "keep_legend": False},
]

# Records whose method name contains "newsample" are left out of every panel.
plot_data=metric_list_plot_df[~metric_list_plot_df["method_name"].str.contains("newsample")]

axd={}
for idx1, spec in enumerate(panel_specs):
    ax=plt.Subplot(fig, box1[idx1])
    fig.add_subplot(ax)
    axd[idx1]=ax

    sns.lineplot(
        y=spec["metric"],
        x="num_subsets",
        hue="method_type",
        marker='o',
        markeredgecolor=None,
        data=plot_data,
        ax=ax,
        **spec["lineplot_kwargs"]
    )

    ax.set_title(spec["title"])
    ax.set_xlabel("# Samples / Point")
    ax.set_ylabel(spec["ylabel"])

    # x axis: a major tick every 500 samples
    ax.xaxis.set_major_locator(MultipleLocator(500))
    ax.xaxis.grid(True, which='major')
    if spec["minor_grid"]:
        ax.xaxis.grid(True, which='minor')
    ax.set_xlim(*spec["xlim"])

    # y axis: a major tick every 0.2
    ax.yaxis.set_major_locator(MultipleLocator(0.2))
    ax.yaxis.grid(True, which='major')
    if spec["minor_grid"]:
        ax.yaxis.grid(True, which='minor')
    ax.set_ylim(*spec["ylim"])

    ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
    ax.tick_params(axis='y', which='major', rotation=0)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # A single legend (on the error panel) serves the whole row.
    leg=ax.legend(loc='best')
    if not spec["keep_legend"]:
        leg.remove()

In [None]:
# Write the current figure in both raster and vector form.
for plot_path in (
    "logs/plots/training_target_quality_shapley.png",
    "logs/plots/training_target_quality_shapley.pdf",
):
    fig.savefig(plot_path, bbox_inches='tight')

In [None]:
# IPython-only introspection (`??` shows the object's source); not valid plain Python.
get_ground_truth_metric_with_value??

In [None]:
# Number of entries loaded for the Banzhaf sampling run.
len(banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"])

In [None]:
# Inspect the "iters" field stored under key 0 of the Banzhaf sampling run.
banzhaf_loaded_dict["logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"][0]["iters"]

In [None]:
# Label the Banzhaf target-quality records for plotting. Records that compare a
# "sampling_long" run with itself are skipped; any other unexpected
# (estimate, reference) pair is an error.
_BANZHAF_LONG_TRAIN="logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"
_BANZHAF_LONG_TEST="logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test"

# Estimated runs that are plotted against the long training-set reference.
_BANZHAF_ESTIMATED_NAMES=(
    "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train",
    "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_short/extract_output/train",
)

# (estimated_name, true_name) pairs where the reference is compared with itself.
_BANZHAF_SELF_PAIRS=(
    (_BANZHAF_LONG_TRAIN, _BANZHAF_LONG_TRAIN),
    (_BANZHAF_LONG_TEST, _BANZHAF_LONG_TEST),
)

metric_list_plot=[]
for metric in metric_list_value_banzhaf:
    # Shallow copy so the labels are not written back into the source records.
    metric_temp=copy.copy(metric)
    estimated_name=metric_temp["estimated_name"]
    true_name=metric_temp["true_name"]

    if (estimated_name, true_name) in _BANZHAF_SELF_PAIRS:
        continue

    if estimated_name not in _BANZHAF_ESTIMATED_NAMES or true_name!=_BANZHAF_LONG_TRAIN:
        # Fail loudly rather than letting an unlabeled record reach the plots.
        raise RuntimeError(f"Unrecognized estimated/true pair in metric record: {metric_temp}")

    metric_temp.update(
        {"method_name": f'Banzhaf MSR ({metric_temp["num_subsets"]})',
         "method_type": 'BanzhafMSR',
         "antithetical": False
        }
    )
    metric_list_plot.append(metric_temp)

metric_list_plot_df=pd.DataFrame(metric_list_plot)
metric_list_plot_df

In [None]:
# Mean correlation metrics for each (method, budget, reference, estimate)
# combination, ordered by the number of sampled subsets.
(
    metric_list_plot_df
    .groupby(["method_name", "num_subsets", "true_name", "estimated_name"])[
        ["pearsonr_all_per_class", "spearmanr_all_per_class", "pearsonr_all", "spearmanr_all"]
    ]
    .mean()
    .reset_index()
    .sort_values(by="num_subsets")
)

In [None]:
# Five stacked panels: each plots one quality metric of the labelled records
# against the number of sampled subsets (hue = method type, line style =
# antithetical flag). `fig`, `box1` and `axd` stay at cell scope (the save cell
# below uses `fig`).
fig = plt.figure(figsize=(6, 25))

box1 = gridspec.GridSpec(5, 1, hspace=0.4)

# Per-panel settings. The MSE panel uses a log y-axis with no fixed limits and
# no minor grid; the correlation panels share a linear 0-1.05 axis.
# Fix: the last panel now plots "spearmanr_all_per_class"; it previously plotted
# "pearsonr_all_per_class" again under the Spearman label.
panel_specs=[
    {"metric": "mse_all", "title": "Error",
     "ylabel": "MSE (all classes)", "log_y": True},
    {"metric": "pearsonr_all", "title": "Correlation",
     "ylabel": "Pearson corr. (all classes)", "log_y": False},
    {"metric": "spearmanr_all", "title": "Rank correlation",
     "ylabel": "Spearman corr. (all classes)", "log_y": False},
    {"metric": "pearsonr_all_per_class", "title": "Correlation",
     "ylabel": "Pearson correlation (Per class)", "log_y": False},
    {"metric": "spearmanr_all_per_class", "title": "Rank correlation",
     "ylabel": "Spearman corr. (Per class)", "log_y": False},
]

# Records whose method name contains "newsample" are left out of every panel.
plot_data=metric_list_plot_df[~metric_list_plot_df["method_name"].str.contains("newsample")]

axd={}
for idx1, spec in enumerate(panel_specs):
    ax=plt.Subplot(fig, box1[idx1])
    fig.add_subplot(ax)
    axd[idx1]=ax

    sns.lineplot(
        y=spec["metric"],
        x="num_subsets",
        style="antithetical",
        hue="method_type",
        marker='o',
        markeredgecolor=None,
        data=plot_data,
        ax=ax
    )

    ax.set_title(spec["title"])
    ax.set_xlabel("# Samples / Point")
    ax.set_ylabel(spec["ylabel"])

    # x axis: a major tick every 500 samples
    ax.xaxis.set_major_locator(MultipleLocator(500))
    ax.xaxis.grid(True, which='major')
    ax.set_xlim(0, 5100)

    # y axis
    ax.yaxis.grid(True, which='major')
    if spec["log_y"]:
        # No explicit tick locator here: switching to the log scale installs
        # the scale's own default locators.
        ax.set_yscale("log")
    else:
        ax.yaxis.set_major_locator(MultipleLocator(0.2))
        ax.xaxis.grid(True, which='minor')
        ax.yaxis.grid(True, which='minor')
        ax.set_ylim(0, 1.05)

    ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
    ax.tick_params(axis='y', which='major', rotation=0)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # Thicken the legend's sample lines so they stay readable.
    leg=ax.legend(loc='best')
    for line in leg.get_lines():
        line.set_linewidth(3.0)

In [None]:
# Write the current figure in both raster and vector form.
for plot_path in (
    "logs/plots/training_target_quality_banzhaf.png",
    "logs/plots/training_target_quality_banzhaf.pdf",
):
    fig.savefig(plot_path, bbox_inches='tight')

In [None]:
# Label the LIME target-quality records for plotting. The record that compares
# the "regression_long" run with itself is skipped; any other unexpected
# (estimate, reference) pair is an error.
_LIME_TRUE_NAME="logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"
_LIME_ESTIMATED_NAME="logs/vitbase_imagenette_surrogate_binomial_eval_train/extract_output/train"

metric_list_plot=[]
for metric in metric_list_value_lime:
    # Shallow copy so the labels are not written back into the source records.
    metric_temp=copy.copy(metric)
    estimated_name=metric_temp["estimated_name"]
    true_name=metric_temp["true_name"]

    if estimated_name==_LIME_TRUE_NAME and true_name==_LIME_TRUE_NAME:
        continue

    if estimated_name!=_LIME_ESTIMATED_NAME or true_name!=_LIME_TRUE_NAME:
        # Fail loudly rather than letting an unlabeled record reach the plots.
        raise RuntimeError(f"Unrecognized estimated/true pair in metric record: {metric_temp}")

    metric_temp.update(
        {"method_name": f'LIME ({metric_temp["num_subsets"]})',
         "method_type": 'LIME',
         "antithetical": False
        }
    )
    metric_list_plot.append(metric_temp)

metric_list_plot_df=pd.DataFrame(metric_list_plot)
metric_list_plot_df

In [None]:
# LIME target-quality figure: one vertically stacked panel per metric, each showing
# the metric (y) against the number of samples per point (x).
#
# Panel specification. Keys:
#   metric     -- panel identifier
#   column     -- column of `metric_list_plot_df` drawn on the y axis
#   title      -- axes title
#   ylabel     -- y-axis label
#   y_step     -- spacing of the major y ticks
#   ylim       -- (low, high) y limits, or None to leave the limits automatic
#   yscale     -- y-axis scale name, or None to keep the default scale
#   minor_grid -- whether minor grid lines are requested on both axes
quality_panels = [
    dict(metric="MSE_all", column="mse_all",
         title="Error", ylabel="MSE (all classes)",
         y_step=0.1, ylim=None, yscale="log", minor_grid=False),
    dict(metric="pearsonr_all", column="pearsonr_all",
         title="Correlation", ylabel="Pearson corr. (all classes)",
         y_step=0.2, ylim=(0, 1.05), yscale=None, minor_grid=True),
    dict(metric="spearmanr_all", column="spearmanr_all",
         title="Rank correlation", ylabel="Spearman corr. (all classes)",
         y_step=0.2, ylim=(0, 1.05), yscale=None, minor_grid=True),
    dict(metric="pearsonr_all_per_class", column="pearsonr_all_per_class",
         title="Correlation", ylabel="Pearson correlation (Per class)",
         y_step=0.2, ylim=(0, 1.05), yscale=None, minor_grid=True),
    # Fixed: this panel previously plotted "pearsonr_all_per_class" under a
    # Spearman label (copy-paste error).
    dict(metric="spearmanr_all_per_class", column="spearmanr_all_per_class",
         title="Rank correlation", ylabel="Spearman corr. (Per class)",
         y_step=0.2, ylim=(0, 1.05), yscale=None, minor_grid=True),
]

fig = plt.figure(figsize=(6, 25))
box1 = gridspec.GridSpec(len(quality_panels), 1, hspace=0.4)

# Axes are keyed by panel index (0 = top panel).
axd={}
for idx1, _ in enumerate(quality_panels):
    ax=plt.Subplot(fig, box1[idx1])
    fig.add_subplot(ax)
    axd[idx1]=ax

# Rows whose method name contains "newsample" are excluded from every panel.
quality_plot_data=metric_list_plot_df[~metric_list_plot_df["method_name"].str.contains("newsample")]

for idx1, panel in enumerate(quality_panels):
    ax=axd[idx1]

    sns.lineplot(
        y=panel["column"],
        x="num_subsets",
        style="antithetical",
        hue="method_type",
        marker='o',
        markeredgecolor=None,
        data=quality_plot_data,
        ax=ax
    )

    ax.set_title(panel["title"])
    ax.set_xlabel("# Samples / Point")
    ax.set_ylabel(panel["ylabel"])

    # x axis
    ax.xaxis.set_major_locator(MultipleLocator(500))
    ax.xaxis.grid(True, which='major')
    if panel["minor_grid"]:
        ax.xaxis.grid(True, which='minor')
    ax.set_xlim(0, 3400)

    # y axis
    ax.yaxis.set_major_locator(MultipleLocator(panel["y_step"]))
    ax.yaxis.grid(True, which='major')
    if panel["minor_grid"]:
        ax.yaxis.grid(True, which='minor')
    if panel["ylim"] is not None:
        ax.set_ylim(*panel["ylim"])
    if panel["yscale"] is not None:
        # Applied after the locator, in the same order as before.
        # NOTE(review): switching the scale presumably installs that scale's default
        # tick locators, so `y_step` may have no visible effect on this panel.
        ax.set_yscale(panel["yscale"])

    ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
    ax.tick_params(axis='y', which='major', rotation=0)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # Thicker legend handles so the line styles remain distinguishable.
    leg=ax.legend(loc='best')
    for line in leg.get_lines():
        line.set_linewidth(3.0)

In [None]:
# Persist the current figure both as a raster image (PNG) and as a vector file (PDF).
for plot_path in (
    "logs/plots/training_target_quality_LIME.png",
    "logs/plots/training_target_quality_LIME.pdf",
):
    fig.savefig(plot_path, bbox_inches='tight')

# Training curve

In [None]:
# Scratch inspection: list the contents of one explainer run directory.
!ls logs/vitbase_imagenette_shapley_regexplainer_permutation_newsample_196/

In [None]:
# Scratch arithmetic (evaluates to 2368).
# NOTE(review): the meaning of the two operands is not documented here -- confirm or remove.
148*16

In [None]:
# Scratch inspection: print the trainer state file stored with checkpoint-296 of this run.
!cat logs/vitbase_imagenette_shapley_regexplainer_permutation_newsample_196/checkpoint-296/trainer_state.json

In [None]:
# torch.save(metric_list, "logs/experiment_results/metric_list.pt")

In [None]:
# Count how many metric entries exist for each model_path.
# NOTE(review): this reads `metric_list`, while the following cell iterates
# `metric_list_shapley` -- confirm this is the intended list.
pd.DataFrame(metric_list)["model_path"].value_counts()

In [None]:
# Build the table behind the Shapley training-curve figure.
#
# Every entry of `metric_list_shapley` is labelled from two of its fields:
#   * "true_name"  -> which split the reference attributions belong to
#   * "model_path" -> which explainer run produced the estimate
# Entries that match neither table raise RuntimeError so that unexpected runs are
# noticed instead of being silently dropped or mislabelled.

# Reference-attribution directory -> split name.
shapley_split_by_true_name = {
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train": "train",
    "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test": "test",
}

# (model_path prefix, num_subsets values, model_name template, method_type, antithetical)
# The full model_path is the prefix followed by the num_subsets value.
# A template of None marks runs that are recognised but left out of the plot table.
shapley_model_rules = [
    ("logs/vitbase_imagenette_shapley_regexplainer_upfront_",
     [512, 1024, 2048, 3072], "Reg-AO (KernelSHAP, {num_subsets})", "KernelSHAP", False),
    ("logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_",
     [196, 392, 588, 1176, 3136], "Reg-AO (Permutation, {num_subsets})", "Permutation", False),
    ("logs/vitbase_imagenette_shapley_regexplainer_antithetical_upfront_",
     [512, 1024, 2048, 3072], "Reg-AO (KernelSHAP, {num_subsets})", "KernelSHAP", True),
    ("logs/vitbase_imagenette_shapley_regexplainer_permutation_newsample_",
     [196, 392, 588, 1176, 3136], None, None, None),
    ("/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_shapley_regexplainer_SGD_antithetical_upfront_",
     [9986], "Reg-AO (SGD-Shapley, {num_subsets})", "SGD-Shapley", True),
    ("logs/vitbase_imagenette_shapley_objexplainer_newsample_",
     [32], "Obj-AO", "Obj", False),
    ("logs/vitbase_imagenette_shapley_objexplainer_antithetical_newsample_",
     [32], "Obj-AO", "Obj", True),
]

# Expand the rules into an exact-match lookup: model_path -> label columns (or None to skip).
shapley_labels_by_model_path = {}
for path_prefix, subset_counts, name_template, method_type, antithetical in shapley_model_rules:
    for num_subsets in subset_counts:
        if name_template is None:
            labels = None
        else:
            labels = {
                "model_name": name_template.format(num_subsets=num_subsets),
                "method_type": method_type,
                "antithetical": antithetical,
            }
        shapley_labels_by_model_path[f"{path_prefix}{num_subsets}"] = labels

metric_list_plot=[]
for metric in metric_list_shapley:
    model_path = metric["model_path"]
    true_name = metric["true_name"]

    if model_path not in shapley_labels_by_model_path or true_name not in shapley_split_by_true_name:
        raise RuntimeError(f"Unexpected Shapley metric entry: {metric}")

    labels = shapley_labels_by_model_path[model_path]
    if labels is None:
        # Recognised run that is intentionally excluded from the plot table.
        continue

    # Shallow copy so the source list is not modified by the added label columns.
    metric_temp=copy.copy(metric)
    metric_temp.update(labels)
    metric_temp["split"] = shapley_split_by_true_name[true_name]
    metric_list_plot.append(metric_temp)

metric_list_plot_df=pd.DataFrame(metric_list_plot)
metric_list_plot_df

In [None]:
# Number of rows flagged as antithetical for each model_path.
metric_list_plot_df.loc[metric_list_plot_df["antithetical"], "model_path"].value_counts()

In [None]:
# Shapley training-curve figure: a grid with one row per metric and one column per
# method type; each panel shows the metric over training epochs on the test split,
# with one line per model_name.

# (metric key, column of `metric_list_plot_df` on the y axis, y-axis label)
curve_panels = [
    ("MSE_all", "mse_all", "MSE (all classes)"),
    ("pearsonr_all", "pearsonr_all", "Pearson corr. (all classes)"),
    ("spearmanr_all", "spearmanr_all", "Spearman corr. (all classes)"),
    ("pearsonr_all_per_class", "pearsonr_all_per_class", "Pearson corr. (Per classes)"),
    ("spearmanr_all_per_class", "spearmanr_all_per_class", "Spearman corr. (Per classes)"),
]
curve_method_types = ["KernelSHAP", "Permutation", "SGD-Shapley"]

fig = plt.figure(figsize=(27, 30))
box1 = gridspec.GridSpec(len(curve_panels), 1, hspace=0.4)

# Axes are keyed by (metric key, method type).
axd={}
for idx1, (metric, _, _) in enumerate(curve_panels):
    # One nested 1 x N grid per metric row; it does not depend on the method type.
    box2 = gridspec.GridSpecFromSubplotSpec(
        1, len(curve_method_types), subplot_spec=box1[idx1], wspace=0.8, hspace=0.0
    )
    for idx2, method_type in enumerate(curve_method_types):
        ax=plt.Subplot(fig, box2[idx2])
        fig.add_subplot(ax)
        axd[(metric, method_type)]=ax

for metric, column, ylabel in curve_panels:
    for method_type in curve_method_types:
        ax=axd[(metric, method_type)]

        metric_list_plot_df_select=metric_list_plot_df[
            (metric_list_plot_df["split"]=="test")&(metric_list_plot_df["method_type"]==method_type)
        ]

        sns.lineplot(
            x="epoch",
            y=column,
            hue="model_name",
            data=metric_list_plot_df_select,
            ax=ax
        )

        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)

        ax.xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
        ax.xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)

        ax.yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
        ax.yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)

        if metric=="MSE_all":
            # Zoom in on the converged region: the upper limit is three times the
            # lowest per-(model, epoch) mean error.
            lowest_mean_mse=metric_list_plot_df_select.groupby(["model_name", "epoch"])[column].mean().min()
            # With no rows for this method type the minimum is NaN, which is not a
            # valid axis limit; keep the automatic limits in that case.
            if pd.notna(lowest_mean_mse):
                ax.set_ylim(0, 3 * lowest_mean_mse)
        else:
            ax.set_ylim(0, 1.05)

        ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
        ax.tick_params(axis='y', which='major', rotation=0)

        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        # Legend is anchored to the right of the axes; thicker handles for readability.
        leg=ax.legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))
        for line in leg.get_lines():
            line.set_linewidth(3.0)

In [None]:
# Write the current figure to disk as both a raster (PNG) and a vector (PDF) file.
# NOTE(review): the plotting code above selects split == "test" rows while the
# file name ends in "_train" -- confirm the name is intended.
for plot_path in (
    "logs/plots/"+f"training_curve_shapley_train.png",
    "logs/plots/"+f"training_curve_shapley_train.pdf",
):
    fig.savefig(plot_path, bbox_inches='tight')

In [None]:
# Label the Banzhaf evaluation records for plotting.
#
# Every record in `metric_list_banzhaf` is shallow-copied and tagged with
# "model_name", "method_type", "antithetical" and "split". The tags are derived
# from the record's "model_path" (sampling variant + number of subsets) and its
# "true_name" (train or test reference outputs). A record that matches none of
# the known combinations is printed and aborts the cell with a RuntimeError.
banzhaf_model_path_template = "/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_banzhaf_regexplainer_upfront_{variant}_{num_subsets}"

# Sampling variant -> subset counts for which an explainer is expected.
banzhaf_subsets_by_variant = {
    "global": [10, 100, 500],
    "sqrt": [10, 100, 500],
    "perinstance": [10, 100, 500],
    "perinstanceperclass": [10, 100, 500],
}

# Reference-output location -> data split it belongs to.
banzhaf_split_by_true_name = {
    "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train": "train",
    "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test": "test",
}

# Exact model_path -> sampling variant; only exact matches are accepted.
banzhaf_variant_by_model_path = {
    banzhaf_model_path_template.format(variant=variant_name, num_subsets=subset_count): variant_name
    for variant_name, subset_counts in banzhaf_subsets_by_variant.items()
    for subset_count in subset_counts
}

metric_list_plot=[]
for metric in metric_list_banzhaf:
    # Shallow copy so the source record is not modified by the update below.
    metric_temp=copy.copy(metric)

    variant=banzhaf_variant_by_model_path.get(metric_temp["model_path"])
    split=banzhaf_split_by_true_name.get(metric_temp["true_name"])
    if variant is None or split is None:
        print(metric_temp)
        raise RuntimeError(
            "Unrecognized Banzhaf record: "
            f"model_path={metric_temp['model_path']!r}, true_name={metric_temp['true_name']!r}"
        )

    # The number of subsets is the trailing "_<int>" component of the model path.
    num_subsets=int(metric_temp["model_path"].split('_')[-1])

    metric_temp.update(
        {"model_name": f'Reg-AO (BanzhafMSR, {variant} {num_subsets})',
         "method_type": f'BanzhafMSR ({variant})',
         "antithetical": False,
         "split": split
        }
    )

    metric_list_plot.append(metric_temp)

metric_list_plot_df=pd.DataFrame(metric_list_plot)
metric_list_plot_df

In [None]:
# Number of rows that repeat an earlier (true_name, model_path, epoch, sample_idx) key.
metric_list_plot_df.duplicated(subset=["true_name", "model_path", "epoch", "sample_idx"]).sum()

In [None]:
# Largest "pearsonr_all" value across all labelled records.
metric_list_plot_df.loc[:, "pearsonr_all"].max()

In [None]:
# Number of records per "method_type" label.
metric_list_plot_df.loc[:, "method_type"].value_counts()

In [None]:
# Training-curve grid for the Banzhaf explainers: one row per metric, one column
# per sampling variant, drawn from the split == "test" rows of `metric_list_plot_df`.
# Presumably earlier figsize candidates (kept from the original cell):
#(4*9, 6*5)
#(4*7, 6*4)

# (panel key, dataframe column plotted on y, y-axis label)
banzhaf_metric_panels = [
    ("MSE_all", "mse_all", "MSE (all classes)"),
    ("pearsonr_all", "pearsonr_all", "Pearson corr. (all classes)"),
    ("spearmanr_all", "spearmanr_all", "Spearman corr. (all classes)"),
    ("pearsonr_all_per_class", "pearsonr_all_per_class", "Pearson corr. (Per classes)"),
    ("spearmanr_all_per_class", "spearmanr_all_per_class", "Spearman corr. (Per classes)"),
]
banzhaf_method_types = ["BanzhafMSR (sqrt)", "BanzhafMSR (global)", "BanzhafMSR (perinstance)", "BanzhafMSR (perinstanceperclass)"]


def _draw_training_curve_panel(ax, panel_df, y_column, y_label, y_upper):
    """Draw one metric-vs-epoch panel and apply the shared axis styling.

    Args:
        ax: Matplotlib axes to draw on.
        panel_df: Rows to plot; must provide "epoch", "model_name" and `y_column`.
        y_column: Name of the column plotted on the y axis.
        y_label: Text of the y-axis label.
        y_upper: Upper y-axis limit (the lower limit is always 0).
    """
    sns.lineplot(
        x="epoch",
        y=y_column,
        hue="model_name",
        data=panel_df,
        ax=ax
    )

    ax.set_xlabel("Epoch")
    ax.set_ylabel(y_label)

    # Strong major grid lines, faint minor grid lines on both axes.
    ax.xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
    ax.xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
    ax.yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
    ax.yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
    ax.set_ylim(0, y_upper)

    # NOTE(review): labelright=True is kept from the original call; it likely has
    # no effect for axis='x' -- confirm before removing.
    ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
    ax.tick_params(axis='y', which='major', rotation=0)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # Anchor box starts at x=1 in axes coordinates, i.e. at the right edge of the panel.
    leg = ax.legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))
    for line in leg.get_lines():
        line.set_linewidth(3.0)


fig = plt.figure(figsize=(35, 24))

# Outer grid: one row per metric. Inner grid: one column per sampling variant.
box1 = gridspec.GridSpec(len(banzhaf_metric_panels), 1, hspace=0.4)

axd={}
for idx1, (metric, _, _) in enumerate(banzhaf_metric_panels):
    # The inner grid only depends on the row, so build it once per metric.
    box2 = gridspec.GridSpecFromSubplotSpec(1, len(banzhaf_method_types), subplot_spec=box1[idx1], wspace=0.8, hspace=0.0)
    for idx2, method_type in enumerate(banzhaf_method_types):
        axd[(metric, method_type)] = fig.add_subplot(box2[idx2])


for metric, y_column, y_label in banzhaf_metric_panels:
    for method_type in banzhaf_method_types:
        plot_key=(metric, method_type)

        metric_list_plot_df_select=metric_list_plot_df[
            (metric_list_plot_df["split"]=="test")&(metric_list_plot_df["method_type"]==method_type)
        ]

        if metric=="MSE_all":
            # Upper limit: 3x the smallest mean MSE over (model_name, epoch) groups.
            y_upper = 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min()
        else:
            # Correlation panels share a fixed upper limit.
            y_upper = 1.05

        _draw_training_curve_panel(axd[plot_key], metric_list_plot_df_select, y_column, y_label, y_upper)

In [None]:
# Write the current figure to disk as both a raster (PNG) and a vector (PDF) file.
# NOTE(review): the plotting code above selects split == "test" rows while the
# file name ends in "_train" -- confirm the name is intended.
for plot_path in (
    "logs/plots/"+f"training_curve_banzhaf_train.png",
    "logs/plots/"+f"training_curve_banzhaf_train.pdf",
):
    fig.savefig(plot_path, bbox_inches='tight')

In [None]:
# Label the LIME evaluation records for plotting.
#
# Every record in `metric_list_lime` is shallow-copied and tagged with
# "model_name", "method_type", "antithetical" and "split". The tags are derived
# from the record's "model_path" (sampling variant + number of subsets) and its
# "true_name" (train or test reference outputs). A record that matches none of
# the known combinations is printed and aborts the cell with a RuntimeError.
lime_model_path_template = "/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_lime_regexplainer_upfront_{variant}_{num_subsets}"

# Sampling variant -> subset counts for which an explainer is expected.
# The per-instance variants are only accepted for 256 and 512 subsets.
lime_subsets_by_variant = {
    "global": [128, 256, 512],
    "sqrt": [128, 256, 512],
    "perinstance": [256, 512],
    "perinstanceperclass": [256, 512],
}

# Reference-output location -> data split it belongs to.
lime_split_by_true_name = {
    "logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train": "train",
    "logs/vitbase_imagenette_surrogate_lime_eval_test_regression_long/extract_output/test": "test",
}

# Exact model_path -> sampling variant; only exact matches are accepted.
lime_variant_by_model_path = {
    lime_model_path_template.format(variant=variant_name, num_subsets=subset_count): variant_name
    for variant_name, subset_counts in lime_subsets_by_variant.items()
    for subset_count in subset_counts
}

metric_list_plot=[]
for metric in metric_list_lime:
    # Shallow copy so the source record is not modified by the update below.
    metric_temp=copy.copy(metric)

    variant=lime_variant_by_model_path.get(metric_temp["model_path"])
    split=lime_split_by_true_name.get(metric_temp["true_name"])
    if variant is None or split is None:
        print(metric_temp)
        raise RuntimeError(
            "Unrecognized LIME record: "
            f"model_path={metric_temp['model_path']!r}, true_name={metric_temp['true_name']!r}"
        )

    # The number of subsets is the trailing "_<int>" component of the model path.
    num_subsets=int(metric_temp["model_path"].split('_')[-1])

    metric_temp.update(
        {"model_name": f'Reg-AO (LIME, {variant} {num_subsets})',
         "method_type": f'LIME ({variant})',
         "antithetical": False,
         "split": split
        }
    )

    metric_list_plot.append(metric_temp)

metric_list_plot_df=pd.DataFrame(metric_list_plot)
metric_list_plot_df

In [None]:
#(4*9, 6*5)
#(4*7, 6*4)

fig = plt.figure(figsize=(35, 24)
                )

box1 = gridspec.GridSpec(5, 1, hspace=0.4)

axd={}
for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", "pearsonr_all_per_class", "spearmanr_all_per_class"
                              ]):
     
    for idx2, method_type in enumerate(["LIME (sqrt)", "LIME (global)", "LIME (perinstance)", "LIME (perinstanceperclass)"]):
        box2 = gridspec.GridSpecFromSubplotSpec(1, 4, subplot_spec=box1[idx1], wspace=0.8, hspace=0.0)    

        ax=plt.Subplot(fig, box2[idx2])
        fig.add_subplot(ax)

        plot_key=(metric, method_type)
        axd[plot_key]=ax          
        

for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", "pearsonr_all_per_class", "spearmanr_all_per_class"
                              ]):
    for idx2, method_type in enumerate(["LIME (sqrt)", "LIME (global)", "LIME (perinstance)", "LIME (perinstanceperclass)"]):

        plot_key=(metric, method_type)
        
        if metric=="MSE_all":
            metric_list_plot_df_select=metric_list_plot_df[(metric_list_plot_df["split"]=="test")&(metric_list_plot_df["method_type"]==method_type)]
            
            sns.lineplot(
                x="epoch",
                y="mse_all",
                hue="model_name",
#                 style="antithetical",
#                 palette="tab10",
#                 alpha=0.8,            
#                 linewidth=3,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )



            axd[plot_key].set_xlabel("Epoch") #, fontsize=20
            axd[plot_key].set_ylabel("MSE (all classes)") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_xlim(0, 40)

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) # labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))

            for line in leg.get_lines():
                line.set_linewidth(3.0) 


        elif metric=="pearsonr_all":
            metric_list_plot_df_select=metric_list_plot_df[(metric_list_plot_df["split"]=="test")&(metric_list_plot_df["method_type"]==method_type)]
            
            sns.lineplot(
                x="epoch",
                y="pearsonr_all",
                hue="model_name",
#                 style="antithetical",
#                 palette="tab10",
#                 alpha=0.8,            
#                 linewidth=3,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )



            axd[plot_key].set_xlabel("Epoch") #, fontsize=20
            axd[plot_key].set_ylabel("Pearson corr. (all classes)") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_xlim(0, 40)

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_ylim(0, 1.05)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  #, labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))

            for line in leg.get_lines():
                line.set_linewidth(3.0)              
            


        elif metric=="spearmanr_all":
            
            metric_list_plot_df_select=metric_list_plot_df[(metric_list_plot_df["split"]=="test")&(metric_list_plot_df["method_type"]==method_type)]
            
            sns.lineplot(
                x="epoch",
                y="spearmanr_all",
                hue="model_name",
#                 style="antithetical",
#                 palette="tab10",
#                 alpha=0.8,            
#                 linewidth=3,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )



            axd[plot_key].set_xlabel("Epoch") #, fontsize=20
            axd[plot_key].set_ylabel("Spearman corr. (all classes)") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_xlim(0, 40)

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_ylim(0, 1.05)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  #, labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))

            for line in leg.get_lines():
                line.set_linewidth(3.0)   
                
                
                  

        elif metric=="pearsonr_all_per_class":
            metric_list_plot_df_select=metric_list_plot_df[(metric_list_plot_df["split"]=="test")&(metric_list_plot_df["method_type"]==method_type)]
            
            sns.lineplot(
                x="epoch",
                y="pearsonr_all_per_class",
                hue="model_name",
#                 style="antithetical",
#                 palette="tab10",
#                 alpha=0.8,            
#                 linewidth=3,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )



            axd[plot_key].set_xlabel("Epoch") #, fontsize=20
            axd[plot_key].set_ylabel("Pearson corr. (Per classes)") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_xlim(0, 40)

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_ylim(0, 1.05)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  #, labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))

            for line in leg.get_lines():
                line.set_linewidth(3.0)   


        elif metric=="spearmanr_all_per_class":
            
            metric_list_plot_df_select=metric_list_plot_df[(metric_list_plot_df["split"]=="test")&(metric_list_plot_df["method_type"]==method_type)]
            
            sns.lineplot(
                x="epoch",
                y="spearmanr_all_per_class",
                hue="model_name",
#                 style="antithetical",
#                 palette="tab10",
#                 alpha=0.8,            
#                 linewidth=3,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )



            axd[plot_key].set_xlabel("Epoch")#, fontsize=20)
            axd[plot_key].set_ylabel("Spearman corr. (Per classes)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_xlim(0, 40)

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_ylim(0, 1.05)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  #, labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))

            for line in leg.get_lines():
                line.set_linewidth(3.0)             

In [None]:
# Export the current figure twice: once as a raster image and once as a PDF.
for plot_suffix in (".png", ".pdf"):
    fig.savefig("logs/plots/training_curve_lime_train" + plot_suffix, bbox_inches='tight')

In [None]:
# Display the filtered frame that the plotting loop above assigned last
# (useful for sanity-checking which rows ended up in the final panel).
metric_list_plot_df_select

In [None]:
# fig, ax = plt.subplots(1,1, figsize=(10,5))

# axd={"main":ax}

# plot_key="main"


# def get_reg_type(x):
#     if "regression" in x:
#         return "regression"
#     elif "permutation" in x:
#         return "permutation"
#     else:
#         return "none"
    

# metric_df=pd.DataFrame(metric_list+metric_list_)
# metric_df["explainer"]=metric_df["explainer"].str.replace(
#     "Reg-AO (upfront, regression, 512, antithetical)",
#     "Reg-AO (upfront, regression, antithetical, 512)")\
#     .str.replace(
#     "Obj-AO (newsample, 32, antithetical)",
#     "Obj-AO (newsample, antithetical, 32)",)

# print(metric_df["explainer"].value_counts())


# metric_df["AO type"]=metric_df["explainer"].map(lambda x: x.split('(')[0].strip())
# metric_df["num_subsets"]=metric_df["explainer"].map(lambda x: int(x.split(',')[-1][:-1].strip()))
# metric_df["reg type"]=metric_df["explainer"].map(get_reg_type)

# metric_df=metric_df.sort_values(["AO type", "reg type", "num_subsets"], ascending=True)
# # metric_df=metric_df[metric_df["explainer"].str.contains("Obj-AO")]
# metric_df=metric_df[metric_df["explainer"].str.contains("permutation")]

# sns.lineplot(
#     x="epoch",
#     y="mse_target",
#     hue="explainer",
#     style="AO type",
#     style_order=["Reg-AO", "Obj-AO"],
#     palette="tab10",
#     linewidth=3,
#     data=metric_df,
#     ax=axd[plot_key]
# )


# axd[plot_key].set_ylabel("MSE", fontsize=20)
# axd[plot_key].set_xlabel("Epoch", fontsize=20)

          
# axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
# axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)

# axd[plot_key].xaxis.set_major_locator(MultipleLocator(10))
# axd[plot_key].xaxis.set_minor_locator(MultipleLocator(5))            
# axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
# axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)

# axd[plot_key].tick_params(axis='x', which='major', labelsize=20)
# axd[plot_key].tick_params(axis='y', which='major', labelsize=20)   

# axd[plot_key].spines['right'].set_visible(False)
# axd[plot_key].spines['top'].set_visible(False) 

# # axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.05))
# # axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.01))    
# # axd[plot_key].set_ylim(0, 0.1)

# # axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.01))
# # axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
# axd[plot_key].set_xlim(0, 40)
# axd[plot_key].set_ylim(0, 0.030)

# leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(0.0, -1.2, 0.5, 1))

# for line in leg.get_lines():
#     line.set_linewidth(3.0)

# FLOPs

In [None]:
# Peek at the first entry of metric_list_value to recall its layout.
# NOTE(review): metric_list_value is built in an earlier cell not shown here.
metric_list_value[0]

In [None]:
# Accumulator for the metric records appended by the loop in the next cell.
metric_list_ground_truth_flops=[]

In [None]:
# Reset the accumulator here so that re-running this cell does not append a
# second copy of every record (the list is extended in place below).
metric_list_ground_truth_flops=[]

# The same extracted run serves as both the reference and the estimate; only
# the number of iterations read from it (iters_calculated) changes per step.
kernelshap_train_name="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
kernelshap_train_attributions=shapley_loaded_dict[kernelshap_train_name]

for num_subsets in [512*i for i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 40, 80, 100, 200, 400, 800, 1000]]:
    metric_list_ground_truth_flops+=get_ground_truth_metric_with_value(
        attribution_values_ground_truth=kernelshap_train_attributions,
        iters_ground_truth=999424,
        attribution_values_calculated=kernelshap_train_attributions,
        iters_calculated=num_subsets,
        meta_info={"num_subsets": num_subsets,
                   "true_name": kernelshap_train_name,
                   "estimated_name": kernelshap_train_name,
                  },
        # Reference and estimate share one dict, so the common keys are simply all of its keys.
        ground_truth_key_select=set(kernelshap_train_attributions.keys()),
    )

In [None]:
# flops_subset_eval = 17_563_067_904 * num_train_sample * num_subsets
# flops_forward = 38_898_221_568 * 1 * metric_temp["epoch"] * num_train_sample
# flops_backward = 38_898_221_568 * 2 * metric_temp["epoch"] * num_train_sample    
# flops_parameter_update = 104_730_000 * metric_temp["epoch"] * (num_train_sample//64) * (2+3+4+3+3+4) # need to verify

In [None]:
# Build the table behind the "metric vs. FLOPs" figure: one row per metric
# record, annotated with method name/type, split and an estimated FLOPs count.
num_train_sample=9469

# Reference-attribution keys (as stored in "true_name") and the split each one evaluates.
split_by_true_name={
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train": "train",
    "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test": "test",
}
antithetical_true_names=[
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
    "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test",
]
kernelshap_train_name="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"

# Model checkpoints that are plotted ...
regexplainer_upfront_paths=[f"logs/vitbase_imagenette_shapley_regexplainer_upfront_{n}" for n in [512, 1024, 2048, 3072]]
permutation_upfront_paths=[f"logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront_{n}" for n in [196, 392, 588, 1176, 3136]]
# ... and checkpoints that are known but deliberately left out of this figure.
skipped_antithetical_paths=(
    [f"logs/vitbase_imagenette_shapley_regexplainer_permutation_newsample_{n}" for n in [196, 392, 588, 1176, 3136]]
    + [f"logs/vitbase_imagenette_shapley_objexplainer_newsample_{n}" for n in [32]]
    + [f"logs/vitbase_imagenette_shapley_objexplainer_antithetical_newsample_{n}" for n in [32]]
)
skipped_sgd_paths=[f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_shapley_regexplainer_SGD_antithetical_upfront_{n}" for n in [9986]]

# Hard-coded cost constants.
# NOTE(review): presumably FLOPs of one masked evaluation, of one explainer
# forward pass, and the per-step optimizer cost; confirm where these were measured.
flops_per_subset_eval=17_563_067_904
flops_per_forward=21_335_153_664
flops_per_parameter_update=104_730_000

metric_list_plot=[]

# Non-amortized KernelSHAP baseline: cost grows only with the number of subsets.
for metric in metric_list_ground_truth_flops:
    metric_temp=copy.copy(metric)

    if metric_temp["estimated_name"]!=kernelshap_train_name or metric_temp["true_name"]!=kernelshap_train_name:
        # Fail loudly on records this cell was not written for.
        raise RuntimeError(f"Unexpected ground-truth metric entry: {metric_temp}")

    flops_subset_eval = flops_per_subset_eval * metric_temp["num_subsets"] * num_train_sample

    metric_temp.update(
        {"method_name": f'KernelSHAP',
         "method_type": 'KernelSHAP',
         "antithetical": False,
         "split": "train",
         "flops": flops_subset_eval,
        }
    )

    metric_list_plot.append(metric_temp)

# Amortized explainers: cost of producing the training targets plus the cost
# of training for metric_temp["epoch"] epochs.
for metric in metric_list_shapley:
    metric_temp=copy.copy(metric)

    if metric_temp["model_path"] in regexplainer_upfront_paths and metric_temp["true_name"] in split_by_true_name:
        # The trailing path component encodes the number of subsets per sample.
        num_subsets=int(metric_temp["model_path"].split('_')[-1])
        method_fields={"method_name": f'{num_subsets}',
                       "method_type": 'KernelSHAP'}

    elif metric_temp["model_path"] in permutation_upfront_paths and metric_temp["true_name"] in split_by_true_name:
        num_subsets=int(metric_temp["model_path"].split('_')[-1])
        method_fields={"method_name": f'Reg-AO (Permutation, {num_subsets})',
                       "method_type": 'Permutation'}

    elif (metric_temp["model_path"] in skipped_antithetical_paths and metric_temp["true_name"] in antithetical_true_names) or\
         (metric_temp["model_path"] in skipped_sgd_paths and metric_temp["true_name"] in split_by_true_name):
        continue

    else:
        raise RuntimeError(f"Unexpected amortized metric entry: {metric_temp}")

    split=split_by_true_name[metric_temp["true_name"]]

    flops_subset_eval = flops_per_subset_eval * num_train_sample * num_subsets
    flops_forward = flops_per_forward * 1 * metric_temp["epoch"] * num_train_sample
    # Backward pass is counted as twice the forward cost.
    flops_backward = flops_per_forward * 2 * metric_temp["epoch"] * num_train_sample
    # NOTE(review): 64 looks like the batch size; the (2+3+4+3+3+4) factor is unverified.
    flops_parameter_update = flops_per_parameter_update * metric_temp["epoch"] * (num_train_sample//64) * (2+3+4+3+3+4) # need to verify

    metric_temp.update(
        {**method_fields,
         "antithetical": False,
         "split": split,
         "flops":  flops_subset_eval + flops_forward + flops_backward + flops_parameter_update,
        }
    )

    metric_list_plot.append(metric_temp)


metric_list_plot_df=pd.DataFrame(metric_list_plot)
metric_list_plot_df

In [None]:
# Count how many metric records each model_path contributes to metric_list_shapley.
pd.DataFrame(metric_list_shapley)["model_path"].value_counts()

In [None]:
# Count rows of the plot table per (model_path, antithetical) combination.
metric_list_plot_df[["model_path", "antithetical"]].value_counts()

In [None]:
# Display the most recently assigned selection.
# NOTE(review): at this point the variable still holds whatever an earlier
# plotting cell left behind; the FLOPs figure below reassigns it.
metric_list_plot_df_select

In [None]:
# Colour cycle for this figure.
# NOTE(review): presumably one colour per hue level below, with black last for KernelSHAP.
plt.rcParams['axes.prop_cycle']=plt.cycler(color=[(0.2980392156862745, 0.4470588235294118, 0.6901960784313725),
  (0.8666666666666667, 0.5176470588235295, 0.3215686274509804),
  (0.3333333333333333, 0.6588235294117647, 0.40784313725490196),
  (0.7686274509803922, 0.3058823529411765, 0.3215686274509804),
  (0,0,0)])

# One panel per metric. Each entry gives the column to plot, the axis label,
# the panel title, which "is_best_checkpoint" values to keep (missing values
# count as "before") and whether the y axis is logarithmic.
# NOTE(review): the error panel keeps only "before" rows while the other
# panels also keep "best" rows; confirm that this asymmetry is intended.
flops_panel_specs={
    "MSE_all": {"y": "mse_all", "ylabel": "Error", "title": 'Error',
                "checkpoints": ["before"], "log_y": True},
    "pearsonr_all": {"y": "pearsonr_all", "ylabel": "Pearson corr.", "title": 'Correlation',
                     "checkpoints": ["before", "best"], "log_y": False},
    "spearmanr_all": {"y": "spearmanr_all", "ylabel": "Spearman corr.", "title": 'Rank Correlation',
                      "checkpoints": ["before", "best"], "log_y": False},
    "sign_agreement_all": {"y": "sign_agreement_all", "ylabel": "Sign agreement", "title": 'Sign agreement',
                           "checkpoints": ["before", "best"], "log_y": False},
}
flops_panel_metrics=["MSE_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]
flops_method_types=["KernelSHAP"]
flops_amortized_labels=['512', '1024', '2048', '3072']
flops_hue_order=flops_amortized_labels+['KernelSHAP']


def style_flops_panel(ax, ylabel, title, log_y=False):
    """Apply the axis styling shared by every panel of the FLOPs figure.

    Args:
        ax: matplotlib Axes that already contains the plotted lines.
        ylabel: text for the y-axis label.
        title: panel title.
        log_y: if True the y axis is switched to a log scale; otherwise the
            y limits are fixed to (0, 1.01).
    """
    ax.set_xlabel("FLOPs")
    ax.set_ylabel(ylabel)

    # x axis: fixed FLOPs window on a log scale.
    ax.xaxis.grid(True, which='major', alpha=0.6)
    ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
    # NOTE(review): kept from the original; the log scale set two lines below
    # probably replaces the formatter this call configures.
    ax.ticklabel_format(axis='x',style='sci',useOffset=True)
    ax.set_xlim(1e+16, 2e+19)
    ax.set_xscale('log')

    # y axis.
    ax.yaxis.grid(True, which='major', alpha=0.6)
    ax.tick_params(axis='y', which='major', rotation=0)
    if log_y:
        ax.set_yscale('log')
    else:
        ax.set_ylim(0, 1.01)

    ax.set_title(title)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)


fig = plt.figure(figsize=(4*(4.3), 3))

box1 = gridspec.GridSpec(1, 4, hspace=0.3)

# Create one Axes per (metric, method_type) and index them by that pair.
axd={}
for idx1, metric in enumerate(flops_panel_metrics):
    for idx2, method_type in enumerate(flops_method_types):
        box2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.0)

        ax=fig.add_subplot(box2[idx2])

        plot_key=(metric, method_type)
        axd[plot_key]=ax


for metric in flops_panel_metrics:
    for method_type in flops_method_types:

        plot_key=(metric, method_type)
        panel_spec=flops_panel_specs[metric]

        # Training-split rows of this method type at the selected checkpoints.
        metric_list_plot_df_select=metric_list_plot_df[(metric_list_plot_df["split"]=="train")&\
                                                       (metric_list_plot_df["method_type"]==method_type)&\
                                                       (metric_list_plot_df["is_best_checkpoint"].fillna("before").isin(panel_spec["checkpoints"]))\
                                                      ]

        sns.lineplot(
            x="flops",
            y=panel_spec["y"],
            hue="method_name",
            hue_order=flops_hue_order,
            errorbar=None,
            alpha=0.8,
            linewidth=1.5,
            data=metric_list_plot_df_select,
            ax=axd[plot_key]
        )

        style_flops_panel(axd[plot_key],
                          ylabel=panel_spec["ylabel"],
                          title=panel_spec["title"],
                          log_y=panel_spec["log_y"])

        if metric=="sign_agreement_all":
            # Only this panel carries legends; they are anchored far to the
            # left (negative axes coordinates) so they sit under the whole row.
            handles, labels = axd[plot_key].get_legend_handles_labels()

            leg=axd[plot_key].legend(handles=[handles[labels.index(i)] for i in flops_amortized_labels],
                                     labels=flops_amortized_labels,
                                     loc='upper left',
                                     bbox_to_anchor=(-2.3, -0.22, 3, 0),
                                     ncols=4,)
            leg.set_title("Amortized (# Samples / Point)")
            # Re-attach the first legend: the next legend() call would otherwise replace it.
            axd[plot_key].add_artist(leg)

            leg=axd[plot_key].legend(handles=[handles[labels.index(i)] for i in ['KernelSHAP']],
                                     labels=['KernelSHAP'],
                                     loc='upper left',
                                     bbox_to_anchor=(-0.8, -0.32, 3, 0),
                                     ncols=4,)
        else:
            # Drop the legend that seaborn adds automatically.
            axd[plot_key].get_legend().remove()

# NOTE(review): the theme/context are set after drawing, so they only affect
# figures created later (including a re-run of this cell).
sns.set_theme(style='whitegrid')
sns.set_context('paper', font_scale=1.2)

In [None]:
# Keep the current plot table reachable under a dedicated name.
# This binds a second name to the same DataFrame object; it is not a copy.
metric_list_plot_df_epoch=metric_list_plot_df

In [None]:
# Export the FLOPs figure twice: once as a raster image and once as a PDF.
for plot_suffix in (".png", ".pdf"):
    fig.savefig("logs/plots/shapley_compute_epoch_appendix" + plot_suffix, bbox_inches='tight')

In [None]:
# List the keys currently available in shapley_loaded_dict.
shapley_loaded_dict.keys()

# Error from prediction vs Error from targets

In [None]:
# metric_df[metric_df["split"]=="train"].groupby(["method_type", "num_subsets"])[['sample_idx',  "num_subsets", 
#         'mse_target_explainer', 'mse_nontarget_explainer', 'mse_all_explainer',  
#        'mse_target_target', 'mse_nontarget_target', 'mse_all_target']].mean().T

In [None]:
# def prettify_metric_name(metric_name):
#     if metric_name=="mse_all":
#         return "MSE (all classes)"
#     elif metric_name=="pearsonr_all":
#         return "Pearson corr. (all classes)"
#     elif metric_name=="spearmanr_all":
#         return "Spearman corr. (all classes)"
#     elif metric_name=="pearsonr_all_per_class":
#         return "Pearson corr. (Per classes)"
#     elif metric_name=="spearmanr_all_per_class":
#         return "Spearman corr. (Per classes)"
#     else:
#         raise ValueError(metric_name)
        
def prettify_metric_name(metric_name):
    """Map an internal metric key to its human-readable label.

    Args:
        metric_name: one of "mse_all", "pearsonr_all", "spearmanr_all",
            "sign_agreement_all", "pearsonr_all_per_class",
            "spearmanr_all_per_class".

    Returns:
        The display label for that metric.

    Raises:
        ValueError: if ``metric_name`` is not a known key; the offending
            value is passed as the exception argument.
    """
    display_labels = (
        ("mse_all", "Error"),
        ("pearsonr_all", "Correlation"),
        ("spearmanr_all", "Rank Correlation"),
        ("sign_agreement_all", "Sign Agreement"),
        ("pearsonr_all_per_class", "Pearson corr. (Per classes)"),
        ("spearmanr_all_per_class", "Spearman corr. (Per classes)"),
    )
    # Compare with == in declaration order.
    for known_name, label in display_labels:
        if metric_name == known_name:
            return label
    raise ValueError(metric_name)
         
        
def prettify_method_type(method_type):
    """Map an internal estimator/method key to its human-readable label.

    Args:
        method_type: one of "KernelSHAP", "Permutation", "SGD-Shapley",
            "BanzhafMSR", "LIME".

    Returns:
        The display label for that method.

    Raises:
        ValueError: if ``method_type`` is not a known key; the offending
            value is passed as the exception argument.
    """
    known_types = ("KernelSHAP", "Permutation", "SGD-Shapley", "BanzhafMSR", "LIME")
    labels = ("KernelSHAP", "Permutation Sampling", "SGD-Shapley", "Banzhaf (MSR)", "LIME")
    # First label whose key equals the argument; None when nothing matches
    # (no real label is None, so it is a safe sentinel).
    label = next(
        (candidate for known, candidate in zip(known_types, labels) if method_type == known),
        None,
    )
    if label is None:
        raise ValueError(method_type)
    return label

In [None]:
def prettify_transform_mode(transform_mode):
    """Map a transform-mode key to its human-readable label.

    Args:
        transform_mode: one of "global", "sqrt", "perinstance",
            "perinstanceperclass".

    Returns:
        The display label; the empty string for "global".

    Raises:
        ValueError: if ``transform_mode`` is not a known key; the offending
            value is passed as the exception argument.
    """
    labels_by_mode = {
        "global": "",
        "sqrt": "Sqrt",
        "perinstance": "Per-sample Norm.",
        # An earlier label for this mode was "Per-sample/Per-class Norm."
        "perinstanceperclass": "Per-Label Scaling",
    }
    for mode, label in labels_by_mode.items():
        if transform_mode == mode:
            return label
    raise ValueError(transform_mode)

In [None]:
# metric_list_value_shapley_=[]
# for num_subsets in [512, 1024, 2048, 3072]:
#     metric_list_value_shapley_+=get_ground_truth_metric_with_value(attribution_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"], 
#                                        iters_ground_truth=999424, 
#                                        attribution_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"],
#                                        iters_calculated=num_subsets,
#                                        meta_info={"num_subsets": num_subsets,
#                                                   "true_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
#                                                    "estimated_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train",
#                                                  })
    
#     metric_list_value_shapley_+=get_ground_truth_metric_with_value(attribution_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"], 
#                                        iters_ground_truth=999424, 
#                                        attribution_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train"],
#                                        iters_calculated=num_subsets,
#                                        meta_info={"num_subsets": num_subsets,
#                                                   "true_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
#                                                    "estimated_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train",
#                                                  })  

# metric_list_plot_=[]
# for metric in metric_list_value_shapley_:
#     metric_temp=copy.copy(metric)
    
#     if metric_temp["estimated_name"]=="logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train" and\
#        metric_temp["true_name"]=="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train":
#         metric_temp.update(
#             {"method_name": f'KernelSHAP ({metric_temp["num_subsets"]})',
#              "method_type": 'KernelSHAP',
#              "antithetical": False
#             }
#         )
        
#     elif metric_temp["estimated_name"]=="logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train" and\
#        metric_temp["true_name"]=="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train":
#         metric_temp.update(
#             {"method_name": f'KernelSHAP ({metric_temp["num_subsets"]}, antithetical)',
#              "method_type": 'KernelSHAP',
#              "antithetical": True
#             }
#         ) 
#     metric_list_plot_.append(metric_temp)

In [None]:
# Build `metric_list_ground_truth_`: for each estimated run and each budget,
# compare the estimated run (iters_calculated=num_subsets) with the
# regression_antithetical run (iters_ground_truth=999424).
# The first estimated run is the reference run itself; the second is the plain
# regression run. Only keys present in both dicts are selected.
reference_run="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"
estimated_runs=[
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train",
]
subset_budgets=[512*i for i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 40, 80, 100, 200, 400, 800, 1000]]

metric_list_ground_truth_=[]
for estimated_run in estimated_runs:
    for num_subsets in subset_budgets:
        metric_list_ground_truth_+=get_ground_truth_metric_with_value(
            attribution_values_ground_truth=shapley_loaded_dict[reference_run],
            iters_ground_truth=999424,
            attribution_values_calculated=shapley_loaded_dict[estimated_run],
            iters_calculated=num_subsets,
            meta_info={"num_subsets": num_subsets,
                       "true_name": reference_run,
                       "estimated_name": estimated_run,
                      },
            ground_truth_key_select=set(shapley_loaded_dict[estimated_run].keys()).intersection(
                shapley_loaded_dict[reference_run].keys()
            ),
        )
    
    
    
    
# for num_subsets in [512*i for i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 40, 80, 100, 200, 400, 800, 1000]]:
#     metric_list_ground_truth_+=get_ground_truth_metric_with_value(attribution_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical_/extract_output/train"], 
#                                        iters_ground_truth=999424, 
#                                        attribution_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical_/extract_output/train"],
#                                        iters_calculated=num_subsets,
#                                        meta_info={"num_subsets": num_subsets,
#                                                   "true_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical_/extract_output/train",
#                                                   "estimated_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical_/extract_output/train",
#                                                  },
                                                                
#                                       ground_truth_key_select=set(shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical_/extract_output/train"].keys()).intersection(
#                                       shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical_/extract_output/train"].keys()
                                      
#                                       ))      
    

In [None]:
# Append records comparing the `eval_train` run (iters_calculated=num_subsets)
# with the regression_antithetical run (iters_ground_truth=999424).
reference_run="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"
estimated_run="logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"

for num_subsets in [512*i for i in [1, 2, 3, 4, 5, 6]]:
    metric_list_ground_truth_+=get_ground_truth_metric_with_value(
        attribution_values_ground_truth=shapley_loaded_dict[reference_run],
        iters_ground_truth=999424,
        attribution_values_calculated=shapley_loaded_dict[estimated_run],
        iters_calculated=num_subsets,
        meta_info={"num_subsets": num_subsets,
                   "true_name": reference_run,
                   "estimated_name": estimated_run,
                  },
        # Select keys present in BOTH dicts that are actually compared. The
        # previous code intersected with the `eval_train_antithetical` run's
        # keys, which is not the run passed as attribution_values_calculated.
        ground_truth_key_select=set(shapley_loaded_dict[reference_run].keys()).intersection(
            shapley_loaded_dict[estimated_run].keys()
        ),
    )

In [None]:
# Append records comparing the `eval_train_antithetical` run
# (iters_calculated=num_subsets) with the regression_antithetical run
# (iters_ground_truth=999424), restricted to keys present in both dicts.
reference_run="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"
estimated_run="logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train"

for num_subsets in [512*i for i in [1, 2, 3, 4, 5, 6]]:
    metric_list_ground_truth_+=get_ground_truth_metric_with_value(
        attribution_values_ground_truth=shapley_loaded_dict[reference_run],
        iters_ground_truth=999424,
        attribution_values_calculated=shapley_loaded_dict[estimated_run],
        iters_calculated=num_subsets,
        meta_info={"num_subsets": num_subsets,
                   "true_name": reference_run,
                   "estimated_name": estimated_run,
                  },
        ground_truth_key_select=set(shapley_loaded_dict[reference_run].keys()).intersection(
            shapley_loaded_dict[estimated_run].keys()
        ),
    )

In [None]:
# Append records comparing the `eval_train` run (iters_calculated=num_subsets)
# with the plain regression run (iters_ground_truth=999424).
reference_run="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
estimated_run="logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"

for num_subsets in [512*i for i in [1, 2, 3, 4, 5, 6]]:
    metric_list_ground_truth_+=get_ground_truth_metric_with_value(
        attribution_values_ground_truth=shapley_loaded_dict[reference_run],
        iters_ground_truth=999424,
        attribution_values_calculated=shapley_loaded_dict[estimated_run],
        iters_calculated=num_subsets,
        meta_info={"num_subsets": num_subsets,
                   "true_name": reference_run,
                   "estimated_name": estimated_run,
                  },
        # Select keys present in BOTH dicts that are actually compared. The
        # previous code intersected with the `eval_train_antithetical` run's
        # keys, which is not the run passed as attribution_values_calculated.
        ground_truth_key_select=set(shapley_loaded_dict[reference_run].keys()).intersection(
            shapley_loaded_dict[estimated_run].keys()
        ),
    )

In [None]:
# Annotate every record of `metric_list_ground_truth_` with plotting metadata
# (method_name / method_type / antithetical / split) and collect shallow copies
# in `metric_list_plot_reference_`.
#
# Key: (estimated_name, true_name). Value: the `antithetical` flag, which follows
# the estimated run's name (True when the name ends in `_antithetical`).
# Any pair not listed here is a hard error.
antithetical_by_run_pair={
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train",
     "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"): False,
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
     "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"): True,
    # Previously flagged True even though this run is the non-antithetical one.
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train",
     "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"): False,
    # NOTE(review): run name has a trailing underscore; its flag is kept as it was (False) — confirm.
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical_/extract_output/train",
     "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical_/extract_output/train"): False,
    # Previously flagged False even though this run is the antithetical one.
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train",
     "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"): True,
}

metric_list_plot_reference_=[]
for metric in metric_list_ground_truth_:
    metric_temp=copy.copy(metric)  # shallow copy so the source record is not modified

    run_pair=(metric_temp["estimated_name"], metric_temp["true_name"])
    if run_pair not in antithetical_by_run_pair:
        print(metric_temp)
        # Was `raise RuntimError()`, which itself failed with NameError.
        raise RuntimeError(f"unexpected (estimated_name, true_name) pair: {run_pair}")

    metric_temp.update(
        {"method_name": 'KernelSHAP',
         "method_type": 'KernelSHAP',
         "antithetical": antithetical_by_run_pair[run_pair],
         "split": "train",
        }
    )
    metric_list_plot_reference_.append(metric_temp)

In [None]:
# Mean `mse_all` for every (true_name, estimated_name, num_subsets) combination,
# shown as a one-column DataFrame.
(
    pd.DataFrame(metric_list_plot_reference_)
    .groupby(["true_name", "estimated_name", "num_subsets"])["mse_all"]
    .mean()
    .to_frame()
)

In [None]:
# 200k
# NOTE(review): scratch note kept as a comment because the bare token `200k` is
# not valid Python and made this cell fail. Its meaning is not evident from the
# surrounding cells — confirm or delete.

In [None]:
# Mean `mse_all` for every (true_name, estimated_name, num_subsets) combination, as a Series.
# NOTE(review): `metric_list_plot_reference_temp` is not assigned in the surrounding
# cells (only `metric_list_plot_reference_` and `metric_list_plot_reference` are);
# presumably it is left over from an earlier kernel session — confirm it still exists.
pd.DataFrame(metric_list_plot_reference_temp).groupby(["true_name", "estimated_name", "num_subsets"])["mse_all"].mean()



In [None]:
# Accumulator for the records returned by get_ground_truth_metric_with_value in
# the loop at the end of this cell (the earlier variants below are commented out).
metric_list_ground_truth_shapley=[]

# for num_subsets in [512*i for i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 40, 80, 100, 200, 400, 800, 1000]]:
#     metric_list_ground_truth+=get_ground_truth_metric_with_value(attribution_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"], 
#                                        iters_ground_truth=999424, 
#                                        attribution_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"],
#                                        iters_calculated=num_subsets,
#                                        meta_info={"num_subsets": num_subsets,
#                                                   "true_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
#                                                   "estimated_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
#                                                  },
                                                                
#                                       ground_truth_key_select=set(shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"].keys()).intersection(
#                                       shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"].keys()
                                      
#                                       ))    

# for num_subsets in [512*i for i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 40, 80, 100, 200, 400, 800, 1000]]:
#     metric_list_ground_truth_shapley+=get_ground_truth_metric_with_value(attribution_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
#                                        iters_ground_truth=999424, 
#                                        attribution_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"],
#                                        iters_calculated=num_subsets,
#                                        meta_info={"num_subsets": num_subsets,
#                                                   "true_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train",
#                                                   "estimated_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train",
#                                                  },
                                                                
#                                       ground_truth_key_select=set(shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"].keys()).intersection(
#                                       shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"].keys()
                                      
#                                       ))

# Compare the plain regression run with itself: the same dict is passed as the
# reference (iters_ground_truth=999424) and as the estimate
# (iters_calculated=num_subsets).
# Budgets: 512..5120 in steps of 512, then 10240..194560 in steps of 10240.
regression_run="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
regression_values=shapley_loaded_dict[regression_run]
subset_budgets=[512*i for i in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]+list(range(512*20, 512*20*20, 512*20))

for num_subsets in subset_budgets:
    metric_list_ground_truth_shapley+=get_ground_truth_metric_with_value(
        attribution_values_ground_truth=regression_values,
        iters_ground_truth=999424,
        attribution_values_calculated=regression_values,
        iters_calculated=num_subsets,
        meta_info={"num_subsets": num_subsets,
                   "true_name": regression_run,
                   "estimated_name": regression_run,
                  },
        # Intersecting a dict's keys with themselves simply yields all of its keys.
        ground_truth_key_select=set(regression_values.keys()).intersection(
            regression_values.keys()
        ),
    )

In [None]:
# Annotate the records of `metric_list_ground_truth_shapley` for plotting.
# Records whose estimated and true run are both the plain regression run are
# kept and tagged as non-antithetical KernelSHAP on the train split.
# Records estimated from the regression_antithetical run are dropped.
# Anything else is a hard error.
regression_run="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
regression_antithetical_run="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"

metric_list_plot_reference=[]
for metric in metric_list_ground_truth_shapley:
    metric_temp=copy.copy(metric)  # shallow copy so the source record is not modified

    run_pair=(metric_temp["estimated_name"], metric_temp["true_name"])
    if run_pair==(regression_antithetical_run, regression_run):
        continue
    if run_pair!=(regression_run, regression_run):
        print(metric_temp)
        # Was `raise RuntimError()`, which itself failed with NameError.
        raise RuntimeError(f"unexpected (estimated_name, true_name) pair: {run_pair}")

    metric_temp.update(
        {"method_name": 'KernelSHAP',
         "method_type": 'KernelSHAP',
         "antithetical": False,
         "split": "train",
        }
    )
    metric_list_plot_reference.append(metric_temp)

# Annotate the records of `metric_list_value_shapley` for plotting.
#
# Key: (estimated_name, true_name).
# Value: (method_type, antithetical) for records that are kept, or None for
# records that are recognised but dropped. Any pair not listed is a hard error.
target_rules={
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train",
     "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"): ('KernelSHAP', False),
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train_antithetical/extract_output/train",
     "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"): None,
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation/extract_output/train",
     "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"): ('Permutation', False),
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_antithetical/extract_output/train",
     "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"): None,
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train_permutation_newsample_196/extract_output/train",
     "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train"): None,
    ("logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train",
     "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"): ('SGD-Shapley', True),
}

metric_list_plot_target=[]
for metric in metric_list_value_shapley:
    metric_temp=copy.copy(metric)  # shallow copy so the source record is not modified

    run_pair=(metric_temp["estimated_name"], metric_temp["true_name"])
    if run_pair not in target_rules:
        print(metric_temp)
        # Was `raise RuntimError()`, which itself failed with NameError.
        raise RuntimeError(f"unexpected (estimated_name, true_name) pair: {run_pair}")

    target_rule=target_rules[run_pair]
    if target_rule is None:
        continue

    target_method, target_antithetical=target_rule
    metric_temp.update(
        {"method_name": f'{target_method} ({metric_temp["num_subsets"]})',
         "method_type": target_method,
         "antithetical": target_antithetical,
         "split": "train"
        }
    )
    metric_list_plot_target.append(metric_temp)
    
# Annotate the records of `metric_list_shapley` (explainer checkpoints) for plotting.

# Reference runs whose records are kept, mapped to the data split they describe.
split_by_true_name={
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train": "train",
    "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test": "test",
}
# Reference runs whose records are recognised but deliberately dropped.
skipped_true_names=[
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train",
    "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test",
]
# One rule per explainer family:
#   (model_path without its trailing "_<num_subsets>", allowed num_subsets,
#    method_type or None when the family is dropped, antithetical flag)
# Kept families must be paired with a key of `split_by_true_name`; dropped
# families with an entry of `skipped_true_names`.
explainer_rules=[
    ("logs/vitbase_imagenette_shapley_regexplainer_upfront",
     [512, 1024, 2048, 3072], 'KernelSHAP', False),
    ("logs/vitbase_imagenette_shapley_regexplainer_permutation_upfront",
     [196, 392, 588, 1176, 3136], 'Permutation', False),
    ("logs/vitbase_imagenette_shapley_regexplainer_permutation_newsample",
     [196, 392, 588, 1176, 3136], None, None),
    # NOTE(review): absolute, machine-specific path; it must match the
    # `model_path` stored in the records, so it is kept verbatim.
    ("/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_shapley_regexplainer_SGD_antithetical_upfront",
     [9986], 'SGD-Shapley', True),
    ("logs/vitbase_imagenette_shapley_objexplainer_newsample",
     [32], None, None),
    ("logs/vitbase_imagenette_shapley_objexplainer_antithetical_newsample",
     [32], None, None),
]

metric_list_plot_explainer=[]
for metric in metric_list_shapley:
    metric_temp=copy.copy(metric)  # shallow copy so the source record is not modified

    # Find the first rule matching both the model path and the reference run.
    for rule_prefix, rule_budgets, rule_method, rule_antithetical in explainer_rules:
        rule_true_names=skipped_true_names if rule_method is None else split_by_true_name
        if metric_temp["model_path"] in [f"{rule_prefix}_{budget}" for budget in rule_budgets] and\
           metric_temp["true_name"] in rule_true_names:
            break
    else:
        print(metric_temp)
        # Was `raise RuntimError()`, which itself failed with NameError.
        raise RuntimeError("no rule matches this explainer record (printed above)")

    if rule_method is None:
        continue

    # The sample budget is the last "_"-separated token of the model path.
    num_subsets=int(metric_temp["model_path"].split('_')[-1])
    split=split_by_true_name[metric_temp["true_name"]]

    metric_temp.update(
        {"method_name": f'Reg-AO ({rule_method}, {num_subsets})',
         "method_type": rule_method,
         "antithetical": rule_antithetical,
         "num_subsets":num_subsets,
         "split": split,
        }
    )
    metric_list_plot_explainer.append(metric_temp)


# Build the explainer frame and keep only rows marked as the best checkpoint.
metric_list_plot_explainer_df=pd.DataFrame(metric_list_plot_explainer)

# SGD-Shapley rows at epoch 20 are force-marked as "best" before filtering.
# Use a single .loc assignment: the previous chained form
# `df["is_best_checkpoint"][mask]="best"` writes through an intermediate object
# and is not guaranteed to modify the frame.
sgd_epoch20_rows=(
    (metric_list_plot_explainer_df['method_type']=="SGD-Shapley")
    &(metric_list_plot_explainer_df['epoch']==20)
)
metric_list_plot_explainer_df.loc[sgd_epoch20_rows, "is_best_checkpoint"]="best"

metric_list_plot_explainer_df=metric_list_plot_explainer_df[
    metric_list_plot_explainer_df["is_best_checkpoint"]=="best"
]


# Frame of the per-sample target metrics.
metric_list_plot_target_df=pd.DataFrame(metric_list_plot_target)

# Map original sample indices to positions 0..99: the first 100 entries of a
# seeded (42) permutation of range(9469) are numbered in order.
# NOTE(review): 9469 and 100 are hard-coded; presumably the dataset size and
# the number of evaluated samples — confirm.
shuffled_indices=np.random.RandomState(seed=42).permutation(list(range(9469)))
idx_mapping={original_idx: position for position, original_idx in enumerate(shuffled_indices[:100])}

# Copy of the target frame relabelled as split "test" with remapped sample_idx.
# A sample_idx outside the mapping raises KeyError.
metric_list_plot_target_df_=metric_list_plot_target_df.copy()
metric_list_plot_target_df_["split"]="test"
metric_list_plot_target_df_["sample_idx"]=metric_list_plot_target_df_["sample_idx"].map(
    lambda original_idx: idx_mapping[original_idx]
)

metric_list_plot_df=metric_list_plot_explainer_df.merge(right=metric_list_plot_target_df, 
                          left_on=["method_type", "sample_idx", "num_subsets", "split"],
                          right_on=["method_type", "sample_idx", "num_subsets", "split"],
                          suffixes=('_explainer', '_target')
                         )
# sdsd
# metric_list_plot_df[metric_list_plot_df["split"]=="train"].groupby(["method_type", "num_subsets"])\
# [['sample_idx',  "num_subsets", 
# 'mse_target_explainer', 'mse_nontarget_explainer', 'mse_all_explainer',  
# 'mse_target_target', 'mse_nontarget_target', 'mse_all_target']].mean().T

metric_list_plot_df[metric_list_plot_df["split"]=="train"].groupby(["method_type", "num_subsets"])\
[['sample_idx', "num_subsets", 
'mse_target_explainer', 'mse_target_target',
'mse_nontarget_explainer', 'mse_nontarget_target',
'mse_all_explainer', 'mse_all_target',
'pearsonr_target_explainer', 'pearsonr_target_target', 
'pearsonr_all_explainer', 'pearsonr_all_target',
'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target', 
'spearmanr_target_explainer', 'spearmanr_target_target',
'spearmanr_all_explainer', 'spearmanr_all_target',
'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target', 
"sign_agreement_all_explainer", "sign_agreement_all_target"]].mean()#.reset_index()

In [None]:
# Scratch frame holding the raw per-sample target metrics for inspection.
df_temp = pd.DataFrame(data=metric_list_plot_target)

In [None]:
# metric_list_plot_explainer_df=pd.DataFrame(metric_list_plot_explainer)
# metric_list_plot_explainer_df["is_best_checkpoint"][(metric_list_plot_explainer_df['method_type']=="SGD-Shapley")&
#                               (metric_list_plot_explainer_df['epoch']==20)
#                              ]="best"
# metric_list_plot_explainer_df=metric_list_plot_explainer_df[(metric_list_plot_explainer_df["is_best_checkpoint"]=="best")
#                                                            ]



# metric_list_plot_target_df=pd.DataFrame(metric_list_plot_target)
# metric_list_plot_target_df_=metric_list_plot_target_df.copy()
# metric_list_plot_target_df_["split"]="test"
# idx_mapping=dict(zip(np.random.RandomState(seed=42).permutation(list(range(9469)))[:100],
# list(range(100))))
# metric_list_plot_target_df_["sample_idx"]=metric_list_plot_target_df_["sample_idx"].map(lambda x: idx_mapping[x])
# metric_list_plot_target_df=pd.concat([metric_list_plot_target_df, metric_list_plot_target_df_])

# metric_list_plot_df=metric_list_plot_explainer_df.merge(right=metric_list_plot_target_df, 
#                           left_on=["method_type", "sample_idx", "num_subsets", "split"],
#                           right_on=["method_type", "sample_idx", "num_subsets", "split"],
#                           suffixes=('_explainer', '_target')
#                          )
# # sdsd

In [None]:
# Merged rows for the current `method_type`.
# NOTE: `method_type` is not assigned in this cell; it is whatever value a
# previously executed loop left behind in the kernel.
metric_list_plot_df.loc[metric_list_plot_df["method_type"] == method_type]

In [None]:
# Same selection as the previous cell: merged rows whose method type equals the
# `method_type` left over from an earlier loop.
metric_list_plot_df.loc[metric_list_plot_df["method_type"] == method_type]

In [None]:
# fig = plt.figure(figsize=(5, 15)
#                 )

# box1 = gridspec.GridSpec(3, 1, hspace=0.3)

# axd={}
# for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", 
#                               ]):
     
#     for idx2, method_type in enumerate(["KernelSHAP"]):
#         box2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.0)    

#         ax=plt.Subplot(fig, box2[idx2])
#         fig.add_subplot(ax)

#         plot_key=(metric, method_type)
#         axd[plot_key]=ax          
        

# for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", 
#                               ]):
#     for idx2, method_type in enumerate(["KernelSHAP"]):

#         plot_key=(metric, method_type)
        

# Canvas for the target-vs-prediction scatter plots: one column per metric and,
# inside every column, one stacked panel per explanation method.
fig = plt.figure(figsize=(12, 7))
box1 = gridspec.GridSpec(1, 3, hspace=0.3)

# (metric_name, method_type) -> Axes, used by the plotting loop to find its panel.
axd = {}
for idx1, metric_name in enumerate(["mse_all", "pearsonr_all", "spearmanr_all"]):
    # Split this metric's column into two rows, one per method type.
    box2 = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.4)
    for idx2, method_type in enumerate(["KernelSHAP", "Permutation"]):
        ax = fig.add_subplot(box2[idx2])
        plot_key = (metric_name, method_type)
        axd[plot_key] = ax
        

# ------------------------------------------------------------------------------
# Target-vs-prediction scatter panels (training split).
#
# Every panel plots, for one metric and one explanation method, the explainer's
# score against the prediction (x) versus its score against the target (y);
# each point is the mean over training samples for one `num_subsets` value.
# KernelSHAP panels additionally mark reference values as vertical lines.
# ------------------------------------------------------------------------------

scatter_metric_names = ["mse_all", "pearsonr_all", "spearmanr_all"]
scatter_method_types = ["KernelSHAP", "Permutation"]

# Columns averaged per (method_type, num_subsets) for the scatter panels.
scatter_mean_columns = [
    'sample_idx',
    'mse_target_explainer', 'mse_target_target',
    'mse_nontarget_explainer', 'mse_nontarget_target',
    'mse_all_explainer', 'mse_all_target',
    'pearsonr_target_explainer', 'pearsonr_target_target',
    'pearsonr_all_explainer', 'pearsonr_all_target',
    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
    'spearmanr_target_explainer', 'spearmanr_target_target',
    'spearmanr_all_explainer', 'spearmanr_all_target',
    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
]

# Per-metric panel settings. Only the error panel uses log axes, and only the
# last column keeps its legend (placed to the right of the axes).
scatter_panel_specs = {
    "mse_all": {
        "xlabel": "Error from prediction",
        "ylabel": "Error from target",
        "log_axes": True,
        "show_legend": False,
    },
    "pearsonr_all": {
        "xlabel": "Pearson corr. with prediction",
        "ylabel": "Pearson corr. with target",
        "log_axes": False,
        "show_legend": False,
    },
    "spearmanr_all": {
        "xlabel": "Spearman corr. with prediction",
        "ylabel": "Spearman corr. with target",
        "log_axes": False,
        "show_legend": True,
    },
}


def _mean_train_metrics(method_type):
    """Return the training-split mean of every scatter column per num_subsets.

    Rows of `metric_list_plot_df` are restricted to `split == "train"` and the
    given `method_type`, grouped by (method_type, num_subsets) and averaged.
    """
    is_selected = (
        (metric_list_plot_df["split"] == "train")
        & (metric_list_plot_df["method_type"] == method_type)
    )
    return (
        metric_list_plot_df[is_selected]
        .groupby(["method_type", "num_subsets"])[scatter_mean_columns]
        .mean()
        .reset_index()
    )


def _draw_reference_lines(ax, metric_name):
    """Draw one vertical line per selected reference run at its mean `metric_name`.

    The reference metrics are averaged per (method_name, num_subsets); only the
    rows with 10240, 20480 or 40960 subsets are drawn, each in its own "Set2"
    colour and labelled "<method_name> <num_subsets>".
    """
    reference_df = (
        pd.DataFrame(metric_list_plot_reference)
        .groupby(["method_name", "num_subsets"])[metric_name]
        .mean()
        .reset_index()
        .sort_values("num_subsets")
    )
    selected_rows = reference_df[reference_df["num_subsets"].isin([10240, 20480, 40960])].sort_index()
    reference_colors = list(sns.color_palette("Set2"))
    for color_idx, (_, row) in enumerate(selected_rows.iterrows()):
        ax.vlines(ymin=0, ymax=1,
                  x=row[metric_name], linewidth=1, color=reference_colors[color_idx],
                  label=f'{row["method_name"]} {row["num_subsets"]}')


def _relabel_legend_entries(leg):
    """Rewrite legend texts: integer entries stay integers, others become references.

    An entry whose text is not an integer is shown as "Reference (<last five
    characters of the text>)".
    """
    for legend_text in leg.get_texts():
        raw_label = legend_text.get_text()
        try:
            subset_count = int(raw_label)
        except ValueError:
            # Not a plain integer: treat the entry as a reference line.
            legend_text.set_text(f"Reference ({raw_label[-5:]})")
        else:
            legend_text.set_text(f"{subset_count}")


def _draw_target_vs_prediction_panel(ax, metric_name, method_type, panel_data):
    """Fill one scatter panel for `metric_name` / `method_type` on `ax`."""
    spec = scatter_panel_specs[metric_name]

    sns.scatterplot(
        x=metric_name+"_explainer",
        y=metric_name+"_target",
        hue="num_subsets",
        data=panel_data,
        s=200,
        palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
        alpha=0.9,
        ax=ax,
    )

    if method_type == "KernelSHAP":
        _draw_reference_lines(ax, metric_name)

    # Diagonal y = x as a visual guide.
    ax.plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')

    ax.set_ylabel(spec["ylabel"])
    ax.set_xlabel(spec["xlabel"])

    # x axis
    ax.xaxis.grid(True, which='major')
    if spec["log_axes"]:
        ax.set_xlim(left=1e-3, right=1)
        ax.set_xscale("log")
    else:
        ax.set_xlim(0,1)

    # y axis
    ax.yaxis.grid(True, which='major')
    if spec["log_axes"]:
        ax.set_ylim(1e-3, 1)
        ax.set_yscale("log")
    else:
        ax.set_ylim(0,1)

    ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
    ax.tick_params(axis='y', which='major', rotation=0)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    if spec["show_legend"]:
        leg = ax.legend(loc='center', bbox_to_anchor=(1.2, 0, 0.5, 1))
        leg.set_title("# Samples / Point")
        _relabel_legend_entries(leg)
    else:
        # These panels carry no legend: drop the one seaborn attached, if any.
        existing_legend = ax.get_legend()
        if existing_legend is not None:
            existing_legend.remove()

    ax.set_title(prettify_metric_name(metric_name))


# The per-method averages do not depend on the metric, so compute them once.
scatter_means_by_method = {
    scatter_method: _mean_train_metrics(scatter_method) for scatter_method in scatter_method_types
}

for idx1, metric_name in enumerate(scatter_metric_names):
    for idx2, method_type in enumerate(scatter_method_types):
        metric_list_plot_df_select = scatter_means_by_method[method_type]
        plot_key = (metric_name, method_type)
        _draw_target_vs_prediction_panel(axd[plot_key], metric_name, method_type, metric_list_plot_df_select)
            
            

In [None]:
# Save the current figure as PNG and PDF under logs/plots/.
plot_dir = "logs/plots/"
# savefig does not create missing directories, so make sure the target exists.
os.makedirs(plot_dir, exist_ok=True)
for plot_format in ("png", "pdf"):
    fig.savefig(plot_dir + f"training_target_prediction_shapley.{plot_format}", bbox_inches='tight')

In [None]:
# Inspect which columns the filtered explainer metrics frame provides.
metric_list_plot_explainer_df.columns

In [None]:
# Mean reference value of the current `metric_name` for every
# (method_name, num_subsets) pair, ordered by num_subsets.
# NOTE: `metric_name` is the value left behind by the last plotting loop run.
(
    pd.DataFrame(metric_list_plot_reference)
    .groupby(["method_name", "num_subsets"])[metric_name]
    .mean()
    .reset_index()
    .sort_values("num_subsets")
)

In [None]:
# Display the raw reference metrics as a DataFrame (one row per record in
# `metric_list_plot_reference`).
pd.DataFrame(metric_list_plot_reference)

In [None]:
# Scratch check of the colour helper used for the horizontal lines in the
# plotting cells.
# NOTE(review): `i` is not assigned in this cell, so this relies on leftover
# kernel state. The plotting cells pass the helper's return value directly as
# `color=`, whereas here the return value is additionally called with 22;
# confirm which usage matches the helper's definition.
Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136)(22)

In [None]:
# The RGB triple used as the line `palette` in the plotting cells, with each
# channel multiplied by 256.
(0.8666666666666667*256, 0.5176470588235295*256, 0.3215686274509804*256)

In [None]:
import matplotlib as mpl

In [None]:
# fig = plt.figure(figsize=(4*(4.3), 3)
#                 )

# box1 = gridspec.GridSpec(1, 4, hspace=0.3)
# axd={}
# for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"
#                               ]):
#     ax=plt.Subplot(fig, box1[idx1])
#     fig.add_subplot(ax)

#     plot_key=(idx1)
#     axd[plot_key]=ax  

# Canvas for the metric-vs-sample-budget line plots: one column per metric and,
# inside every column, one stacked panel per explanation method.
fig = plt.figure(figsize=(4*4+3*0.3, 3*4+2*0.8))
box1 = gridspec.GridSpec(1, 4, hspace=0.3)

# (metric, method_type) -> Axes for that panel.
axd = {}
for idx1, metric in enumerate(["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]):
    # Split this metric's column into three rows, one per method type.
    box2 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.4)
    for idx2, method_type in enumerate(["KernelSHAP", "Permutation", "SGD-Shapley"]):
        ax = fig.add_subplot(box2[idx2])
        plot_key = (metric, method_type)
        axd[plot_key] = ax



# axd={}
# for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"
#                               ]):

#     for idx2, method_type in enumerate(["KernelSHAP"]):
#         box2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.0)    

#         ax=plt.Subplot(fig, box2[idx2])
#         fig.add_subplot(ax)

#         plot_key=(metric, method_type)
#         axd[plot_key]=ax    


for idx1, metric in enumerate(["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"
                              ]):
    for idx2, method_type in enumerate(["KernelSHAP", "Permutation", "SGD-Shapley"]):

        plot_key=(metric, method_type)


        if metric=="each":
            sns.barplot(
                x="mse_all",
                y="method_name",
            #     hue="method",
            #     style="AO type",
            #     style_order=["Reg-AO", "Obj-AO"],
                #palette="tab10",
                #marker='o',    
                markeredgecolor=None,
                #markersize=10,   
                #alpha=0.8,            
                #linewidth=3,
                data=metric_list_plot_df[~metric_list_plot_df["method_name"].str.contains("newsample")],
                ax=axd[plot_key]
            )


            axd[plot_key].set_ylabel("Method")#, fontsize=20)
            axd[plot_key].set_xlabel("MSE (all classes)")#, fontsize=20)

            # xaxis
            # axd[plot_key].xaxis.set_major_locator(MultipleLocator(1))
            # axd[plot_key].xaxis.set_minor_locator(MultipleLocator(0.1))            
            # axd[plot_key].xaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            # axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_xlim(0, 40)

            axd[plot_key].yaxis.set_major_locator(MultipleLocator(1))
            axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))  
            axd[plot_key].yaxis.grid(True, which='major', linewidth=2, alpha=0.6)
            axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
    #         axd[plot_key].set_ylim(0, 0.1)

            axd[plot_key].tick_params(axis='x', which='major', rotation=-90, labelsize=20, labelright=True)
            axd[plot_key].tick_params(axis='y', which='major', rotation=0, labelsize=10)  
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

    #         leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(0.0, -1.2, 0.5, 1))

    #         for line in leg.get_lines():
    #             line.set_linewidth(3.0)  

        elif metric=="mse_all":
            sns.lineplot(
                y=metric,
                x="num_subsets",
                #style="antithetical",
                hue="method_type",
                palette=[(0.8666666666666667, 0.5176470588235295, 0.3215686274509804)],
                #palette="tab10",
                #marker='o',    
                markeredgecolor=None,
                errorbar=None,
                #markersize=10,   
                #alpha=0.8,            
                #linewidth=3,
                data=pd.DataFrame(metric_list_plot_reference),
                ax=axd[plot_key]
            )

            metric_list_plot_explainer_df_mean=metric_list_plot_explainer_df[metric_list_plot_explainer_df["method_type"]==method_type]\
            .groupby(["method_name", "num_subsets"])[metric].mean().reset_index().sort_values("num_subsets").set_index("num_subsets")
            count=0
            for num_subsets in metric_list_plot_explainer_df_mean.index:
                row=metric_list_plot_explainer_df_mean.loc[num_subsets]
                #print(row)
                axd[plot_key].hlines(xmin=0, xmax=1000000, 
                                     y=row[metric], linewidth=2, color=Blue_scalar_color_mapping(num_subsets, color_map=plt.cm.Blues, data_min=-500, data_max=3136),
                                     label=f'{num_subsets}')
                count+=1                





            axd[plot_key].set_title(f"Error")# ({prettify_method_type(method_type)})")


            axd[plot_key].set_xlabel("# Samples / Point (KernelSHAP)")#, fontsize=20)
            axd[plot_key].set_ylabel("Error")#, fontsize=20)

            # xaxis
            axd[plot_key].xaxis.set_major_locator(MultipleLocator(50000))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(10000))  
            axd[plot_key].get_xaxis().set_major_formatter(
                mpl.ticker.FuncFormatter(lambda x, p: format(int(x), ',')))                        
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=0.5)#, linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0, 110001)

            axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))              
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_ylim(0.0001, 0.015)
            axd[plot_key].set_yscale("log")


            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  # labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 
            


            leg=axd[plot_key].legend(loc='best')#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))

    #         for line in leg.get_lines():
    #             line.set_linewidth(3.0) 

        elif metric=="pearsonr_all":

            sns.lineplot(
                y=metric,
                x="num_subsets",
                #style="antithetical",
                hue="method_type",
                palette=[(0.8666666666666667, 0.5176470588235295, 0.3215686274509804)],
                #marker='o',    
                markeredgecolor=None,
                errorbar=None,
                #markersize=10,   
                #alpha=0.8,            
                #linewidth=3,
                data=pd.DataFrame(metric_list_plot_reference),
                ax=axd[plot_key]
            )

            metric_list_plot_explainer_df_mean=metric_list_plot_explainer_df[metric_list_plot_explainer_df["method_type"]==method_type]\
            .groupby(["method_name", "num_subsets"])[metric].mean().reset_index().sort_values("num_subsets").set_index("num_subsets")
            count=0
            for num_subsets in metric_list_plot_explainer_df_mean.index:
                row=metric_list_plot_explainer_df_mean.loc[num_subsets]
                #print(row)
                axd[plot_key].hlines(xmin=0, xmax=1000000, 
                                     y=row[metric], linewidth=2, color=Blue_scalar_color_mapping(num_subsets, color_map=plt.cm.Blues, data_min=-500, data_max=3136),
                                     label=f'{num_subsets}')
                count+=1                  



            axd[plot_key].set_title(f"Correlation")# ({prettify_method_type(method_type)})")

            axd[plot_key].set_xlabel("# Samples / Point (KernelSHAP)")#, fontsize=20)
            axd[plot_key].set_ylabel("Pearson Correlation")#, fontsize=20)

            # xaxis
            axd[plot_key].xaxis.set_major_locator(MultipleLocator(50000))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100)) 
            axd[plot_key].get_xaxis().set_major_formatter(
                mpl.ticker.FuncFormatter(lambda x, p: format(int(x), ',')))                        
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=0.8, alpha=0.6)
            axd[plot_key].xaxis.grid(True, which='minor')#, linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0, 110001)

            # yaxis
            axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=0.8, alpha=0.6)
            axd[plot_key].yaxis.grid(True, which='minor')#, linewidth=1, alpha=0.1)
            axd[plot_key].set_ylim(0, 1.05)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  # , labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best')#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))

            leg.remove()
    #         for line in leg.get_lines():
    #             line.set_linewidth(3.0)  

        elif metric=="spearmanr_all":

            sns.lineplot(
                y=metric,
                x="num_subsets",
                #style="antithetical",
                hue="method_type",
                palette=[(0.8666666666666667, 0.5176470588235295, 0.3215686274509804)],
                #marker='o',    
                markeredgecolor=None,
                errorbar=None,
                #markersize=10,   
                #alpha=0.8,            
                #linewidth=3,
                data=pd.DataFrame(metric_list_plot_reference),
                ax=axd[plot_key]
            )

            metric_list_plot_explainer_df_mean=metric_list_plot_explainer_df[metric_list_plot_explainer_df["method_type"]==method_type]\
            .groupby(["method_name", "num_subsets"])[metric].mean().reset_index().sort_values("num_subsets").set_index("num_subsets")
            count=0
            for num_subsets in metric_list_plot_explainer_df_mean.index:
                row=metric_list_plot_explainer_df_mean.loc[num_subsets]
                #print(row)
                axd[plot_key].hlines(xmin=0, xmax=1000000, 
                                     y=row[metric], linewidth=2, color=Blue_scalar_color_mapping(num_subsets, color_map=plt.cm.Blues, data_min=-500, data_max=3136),
                                     label=f'{num_subsets}')
                count+=1                 


            axd[plot_key].set_title(f"Rank correlation")# ({prettify_method_type(method_type)})")


            axd[plot_key].set_xlabel("# Samples / Point (KernelSHAP)")#, fontsize=20)
            axd[plot_key].set_ylabel("Spearman Correlation")#, fontsize=20)

            # xaxis
            axd[plot_key].xaxis.set_major_locator(MultipleLocator(50000))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100)) 
            axd[plot_key].get_xaxis().set_major_formatter(
                mpl.ticker.FuncFormatter(lambda x, p: format(int(x), ',')))                        
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=0.8, alpha=0.6)
            axd[plot_key].xaxis.grid(True, which='minor')#, linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0, 110001)

            # yaxis
            axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=0.8, alpha=0.6)
            axd[plot_key].yaxis.grid(True, which='minor')#, linewidth=1, alpha=0.1)
            axd[plot_key].set_ylim(0, 1.05)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  # , labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best')#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            leg.remove()
    #         for line in leg.get_lines():
    #             line.set_linewidth(3.0)  


        elif metric=="sign_agreement_all":

            sns.lineplot(
                y=metric,
                x="num_subsets",
                #style="antithetical",
                hue="method_type",
                palette=[(0.8666666666666667, 0.5176470588235295, 0.3215686274509804)],
                #marker='o',    
                markeredgecolor=None,
                errorbar=None,
                #markersize=10,   
                #alpha=0.8,            
                #linewidth=3,
                data=pd.DataFrame(metric_list_plot_reference),
                ax=axd[plot_key]
            )

            metric_list_plot_explainer_df_mean=metric_list_plot_explainer_df[metric_list_plot_explainer_df["method_type"]==method_type]\
            .groupby(["method_name", "num_subsets"])[metric].mean().reset_index().sort_values("num_subsets").set_index("num_subsets")
            count=0
            for num_subsets in metric_list_plot_explainer_df_mean.index:
                row=metric_list_plot_explainer_df_mean.loc[num_subsets]
                #print(row)
                axd[plot_key].hlines(xmin=0, xmax=1000000, 
                                     y=row[metric], linewidth=2, color=Blue_scalar_color_mapping(num_subsets, color_map=plt.cm.Blues, data_min=-500, data_max=3136),
                                     label=f'{num_subsets}')
                count+=1                 


            axd[plot_key].set_title(f"Sign Agreement")# ({prettify_method_type(method_type)})")#


            axd[plot_key].set_xlabel("# Samples / Point (KernelSHAP)")#, fontsize=20)
            axd[plot_key].set_ylabel("Sign Agreement")#, fontsize=20)

            # xaxis
            axd[plot_key].xaxis.set_major_locator(MultipleLocator(50000))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100)) 
            axd[plot_key].get_xaxis().set_major_formatter(
                mpl.ticker.FuncFormatter(lambda x, p: format(int(x), ',')))                        
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=0.8, alpha=0.6)
            axd[plot_key].xaxis.grid(True, which='minor')#, linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0, 110001)

            # yaxis
            axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.1))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=0.8, alpha=0.6)
            axd[plot_key].yaxis.grid(True, which='minor')#, linewidth=1, alpha=0.1)
            axd[plot_key].set_ylim(0, 1.05)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  # , labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best')#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            leg.remove()            

    #         for line in leg.get_lines():
    #             line.set_linewidth(3.0)             

In [None]:
# Write the current figure as PNG and PDF. The output folder is created first
# so that savefig does not fail when logs/plots has not been created yet.
os.makedirs("logs/plots", exist_ok=True)
fig.savefig("logs/plots/shapley_prediction_contextualize.png", bbox_inches='tight')
fig.savefig("logs/plots/shapley_prediction_contextualize.pdf", bbox_inches='tight')

In [None]:
# Bare expression: renders the figure object as this cell's output.
fig

In [None]:
            # NOTE(review): these three lines are the same as the ones inside the
            # plotting loop of the cell above and still carry that loop's indentation.
            # A cell whose first statement is indented does not parse, so dedent it
            # before running. It also relies on `metric` and `method_type` being left
            # over from that loop.
            # Purpose: mean of `metric` per (method_name, num_subsets) for the rows of
            # the current `method_type`, sorted by and indexed on num_subsets.
            metric_list_plot_explainer_df_mean=metric_list_plot_explainer_df[metric_list_plot_explainer_df["method_type"]==method_type]\
            .groupby(["method_name", "num_subsets"])[metric].mean().reset_index().sort_values("num_subsets").set_index("num_subsets")
            # Counter reset copied from the loop; nothing in this cell reads it.
            count=0

In [None]:
# Show the explainer-metric rows whose method_type equals the value currently
# bound to `method_type`.
metric_list_plot_explainer_df.loc[
    metric_list_plot_explainer_df["method_type"] == method_type
]

In [None]:
# Debug view of `row`, which the plotting loop above assigns from
# metric_list_plot_explainer_df_mean.loc[num_subsets]; shows whatever the last iteration left.
row

In [None]:
# Debug view of the per-num_subsets mean table built in the plotting loop above.
metric_list_plot_explainer_df_mean

In [None]:
# Debug view of the full explainer metrics table (defined in an earlier cell).
metric_list_plot_explainer_df

In [None]:
# Debug view of the first positional row of `metric_list_`.
# NOTE(review): `metric_list_` is not defined in the nearby cells; presumably a
# DataFrame/Series left over from an earlier cell - confirm it exists on a fresh run.
metric_list_.iloc[0]

In [None]:
# Debug view of `method_name_target`.
# NOTE(review): defined outside the nearby cells - confirm it exists on a fresh run.
method_name_target

In [None]:
# Mean of the column named by `metric_name` for every (method_name, num_subsets)
# pair, restricted to rows of the current `method_type` and ordered by num_subsets.
(
    metric_list_plot_df[metric_list_plot_df["method_type"] == method_type]
    .groupby(["method_name", "num_subsets"])[metric_name]
    .mean()
    .reset_index()
    .sort_values("num_subsets")
)

In [None]:
# Column index of the metrics table after filtering to the current `method_type`.
(
    metric_list_plot_df[metric_list_plot_df["method_type"] == method_type]
    .columns
)

In [None]:
# Debug view of the metrics table that the scatter-grid cells below filter by split and method type.
metric_list_plot_df

In [None]:
# Scatter grid for the "train" split of `metric_list_plot_df`.
#
# Layout: one column of panels per metric, one row per method type. Each panel
# plots the per-num_subsets mean of `<metric>_explainer` (x) against
# `<metric>_target` (y), coloured by num_subsets, with a dashed grey y = x
# diagonal as a visual reference.
#
# The four metrics only differ in axis label, axis limits, axis scale and
# whether the legend is kept, so those differences live in `panel_specs` and a
# single code path draws every panel.
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor'] = '0.8'
plt.rcParams['legend.framealpha'] = 0.8

# Per-metric panel settings. Insertion order defines the column order.
panel_specs = {
    "mse_all": {"axis_label": "Error", "limits": (1e-3, 1.1), "log_scale": True, "show_legend": True},
    "pearsonr_all": {"axis_label": "Pearson Corr.", "limits": (0, 1.01), "log_scale": False, "show_legend": False},
    "spearmanr_all": {"axis_label": "Spearman Corr.", "limits": (0, 1.01), "log_scale": False, "show_legend": False},
    "sign_agreement_all": {"axis_label": "Sign Agreement", "limits": (0, 1.01), "log_scale": False, "show_legend": False},
}
# Row order of the grid.
method_types = ["KernelSHAP", "Permutation", "SGD-Shapley"]

# Columns averaged per (method_type, num_subsets) group.
mean_columns = [
    'sample_idx',
    'mse_target_explainer', 'mse_target_target',
    'mse_nontarget_explainer', 'mse_nontarget_target',
    'mse_all_explainer', 'mse_all_target',
    'pearsonr_target_explainer', 'pearsonr_target_target',
    'pearsonr_all_explainer', 'pearsonr_all_target',
    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
    'spearmanr_target_explainer', 'spearmanr_target_target',
    'spearmanr_all_explainer', 'spearmanr_all_target',
    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
    "sign_agreement_all_explainer", "sign_agreement_all_target",
]

fig = plt.figure(figsize=(4*4+3*0.3, 3*4+2*0.8))

box1 = gridspec.GridSpec(1, len(panel_specs), hspace=0.3)

# Build the axes: the nested grid depends only on the metric column, so it is
# created once per column rather than once per panel.
axd = {}
for idx1, metric_name in enumerate(panel_specs):
    box2 = gridspec.GridSpecFromSubplotSpec(
        len(method_types), 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.4
    )
    for idx2, method_type in enumerate(method_types):
        axd[(metric_name, method_type)] = fig.add_subplot(box2[idx2])

# The aggregated table depends only on the method type, so compute it once per
# method instead of once per (metric, method) panel.
train_rows = metric_list_plot_df[metric_list_plot_df["split"] == "train"]
mean_by_method = {
    method_type: (
        train_rows[train_rows["method_type"] == method_type]
        .groupby(["method_type", "num_subsets"])[mean_columns]
        .mean()
        .reset_index()
    )
    for method_type in method_types
}

for metric_name, spec in panel_specs.items():
    for method_type in method_types:
        metric_list_plot_df_select = mean_by_method[method_type]
        plot_key = (metric_name, method_type)
        ax = axd[plot_key]

        sns.scatterplot(
            x=metric_name + "_explainer",
            y=metric_name + "_target",
            hue="num_subsets",
            data=metric_list_plot_df_select,
            s=200,
            # One colour per row of the aggregated table, taken from the Blues colormap.
            palette=[
                Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136)
                for i in metric_list_plot_df_select["num_subsets"]
            ],
            alpha=0.9,
            ax=ax,
        )

        # y = x reference diagonal.
        ax.plot(np.linspace(0, 1, 100), np.linspace(0, 1, 100), color="grey", alpha=0.5, linestyle='--')

        ax.set_ylabel(f"{spec['axis_label']} (Label)")
        ax.set_xlabel(f"{spec['axis_label']} (Prediction)")

        lower, upper = spec["limits"]

        # x axis: limits are set before the scale, as in the original cell.
        ax.xaxis.grid(True, which='major')
        ax.set_xlim(left=lower, right=upper)
        if spec["log_scale"]:
            ax.set_xscale("log")

        # y axis
        ax.yaxis.grid(True, which='major')
        ax.set_ylim(lower, upper)
        if spec["log_scale"]:
            ax.set_yscale("log")

        ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
        ax.tick_params(axis='y', which='major', rotation=0)

        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        ax.set_title(f"{prettify_metric_name(metric_name)} ({prettify_method_type(method_type)})")

        if spec["show_legend"]:
            # Replace the legend seaborn attached with one anchored inside the
            # lower-right corner of the panel.
            leg = ax.legend(loc='lower right', bbox_to_anchor=(0.97, 0.03))
            for legend_text in leg.get_texts():
                raw_label = legend_text.get_text()
                try:
                    subset_count = int(raw_label)
                except ValueError:
                    # Label is not a plain integer: show the method type plus the
                    # last five characters of the original label.
                    legend_text.set_text(f"{prettify_method_type(method_type)} ({raw_label[-5:]})")
                else:
                    legend_text.set_text(f"{subset_count}")
        else:
            # Drop the legend that seaborn attached for the hue variable.
            leg = ax.get_legend()
            if leg is not None:
                leg.remove()
            
            

In [None]:
# Write the current figure as PNG and PDF. The output folder is created first
# so that savefig does not fail when logs/plots has not been created yet.
os.makedirs("logs/plots", exist_ok=True)
fig.savefig("logs/plots/training_target_prediction_shapley_appendix.png", bbox_inches='tight')
fig.savefig("logs/plots/training_target_prediction_shapley_appendix.pdf", bbox_inches='tight')

In [None]:
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor']='0.8'
plt.rcParams['legend.framealpha']=0.8

# fig = plt.figure(figsize=(5, 15)
#                 )

# box1 = gridspec.GridSpec(3, 1, hspace=0.3)

# axd={}
# for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", 
#                               ]):
     
#     for idx2, method_type in enumerate(["KernelSHAP"]):
#         box2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.0)    

#         ax=plt.Subplot(fig, box2[idx2])
#         fig.add_subplot(ax)

#         plot_key=(metric, method_type)
#         axd[plot_key]=ax          
        

# for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", 
#                               ]):
#     for idx2, method_type in enumerate(["KernelSHAP"]):

#         plot_key=(metric, method_type)
        

fig = plt.figure(figsize=(4*4+3*0.3, 3*4+2*0.8)
                )

box1 = gridspec.GridSpec(1, 4, hspace=0.3)

axd={}
for idx1, metric_name in enumerate(["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"
                              ]):
     
    for idx2, method_type in enumerate(["KernelSHAP", "Permutation", "SGD-Shapley"]):
        box2 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.4)    

        ax=plt.Subplot(fig, box2[idx2])
        fig.add_subplot(ax)

        plot_key=(metric_name, method_type)
        axd[plot_key]=ax          
        

for idx1, metric_name in enumerate(["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all",
                              ]):
    for idx2, method_type in enumerate(["KernelSHAP", "Permutation", "SGD-Shapley"]):

        metric_list_plot_df_select=metric_list_plot_df[(metric_list_plot_df["split"]=="test")&(metric_list_plot_df["method_type"]==method_type)].groupby(["method_type", "num_subsets"])\
                                    [['sample_idx', # "num_subsets", 
                                    'mse_target_explainer', 'mse_target_target',
                                    'mse_nontarget_explainer', 'mse_nontarget_target',
                                    'mse_all_explainer', 'mse_all_target',
                                    'pearsonr_target_explainer', 'pearsonr_target_target', 
                                    'pearsonr_all_explainer', 'pearsonr_all_target',
                                    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target', 
                                    'spearmanr_target_explainer', 'spearmanr_target_target',
                                    'spearmanr_all_explainer', 'spearmanr_all_target',
                                    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
                                     "sign_agreement_all_explainer", "sign_agreement_all_target"]].mean().reset_index()          

        plot_key=(metric_name, method_type)
        
        if metric_name=="mse_all":
            
            
            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136) 
                         for i in metric_list_plot_df_select["num_subsets"]],                
                alpha=0.9,
                ax=axd[plot_key],
            )
            
            

            
            reference_df=pd.DataFrame(metric_list_plot_reference).groupby(["method_name","num_subsets"])[metric_name].mean().reset_index().sort_values("num_subsets")    
            reference_df_idx_list=[]
#             for i, metric_value in metric_list_plot_df_select[metric_name+"_explainer"].items():
#                 reference_df_idx_list.append((reference_df[metric_name]-metric_value)[(reference_df[metric_name]-metric_value)>0].idxmin())
#                 reference_df_idx_list.append((metric_value-reference_df[metric_name])[(metric_value-reference_df[metric_name])>0].idxmin())   
#             for idx, row in reference_df.iterrows():
#                 if row["num_subsets"] in [10240, 20480, 40960]:
#                     reference_df_idx_list.append(idx)
                
#             count=0
#             for idx in sorted(list(set(reference_df_idx_list))):
#                 row=reference_df.loc[idx]
#                 axd[plot_key].vlines(ymin=0, ymax=1, 
#                                      x=row[metric_name], linewidth=2, color=plt.rcParams['axes.prop_cycle'].by_key()['color'][count],
#                                      label=f'{row["method_name"]} {row["num_subsets"]}')
#                 count+=1                

    
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')
        
        
        


            axd[plot_key].set_ylabel("Error (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Error (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor')#, linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(left=1e-3, right=1.1)
            axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(1e-3, 1.1)
            axd[plot_key].set_yscale("log")

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            
            axd[plot_key].get_legend().remove()

            for line in leg.get_lines():
                line.set_linewidth(3.0) 
                
            #axd[plot_key].set_title(prettify_method_type(method_type)+ " - " + prettify_metric_name(metric_name))#, fontsize=20)
            axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_method_type(method_type)})")#, fontsize=20)
            
            
            
            leg=axd[plot_key].legend(loc='lower right', bbox_to_anchor=(0.97, 0.03))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            #leg.set_title("# Samples / Point")   
            
            for legend_text in leg.get_texts():
                try:
                    int(legend_text.get_text())
                except:
                    legend_text.set_text(f"{prettify_method_type(method_type)} ({legend_text.get_text()[-5:]})")         
                else:
                    legend_text.set_text(f"{int(legend_text.get_text())}")
                        

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
            #leg.remove()


        elif metric_name=="pearsonr_all":
            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136) 
                         for i in metric_list_plot_df_select["num_subsets"]],                
                alpha=0.9,
                ax=axd[plot_key],
            )
            
            reference_df=pd.DataFrame(metric_list_plot_reference).groupby(["method_name","num_subsets"])[metric_name].mean().reset_index().sort_values("num_subsets")    
            reference_df_idx_list=[]
#             for i, metric_value in metric_list_plot_df_select[metric_name+"_explainer"].items():
#                 reference_df_idx_list.append((reference_df[metric_name]-metric_value)[(reference_df[metric_name]-metric_value)>0].idxmin())
#                 reference_df_idx_list.append((metric_value-reference_df[metric_name])[(metric_value-reference_df[metric_name])>0].idxmin())   
#             for idx, row in reference_df.iterrows():
#                 if row["num_subsets"] in [10240, 20480, 40960]:
#                     reference_df_idx_list.append(idx)                
#             count=0
#             for idx in sorted(list(set(reference_df_idx_list))):
#                 row=reference_df.loc[idx]
#                 axd[plot_key].vlines(ymin=0, ymax=1, 
#                                      x=row[metric_name], linewidth=2, color=plt.rcParams['axes.prop_cycle'].by_key()['color'][count],
#                                      label=f'{row["method_name"]} {row["num_subsets"]}')
#                 count+=1                              
    
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')


            axd[plot_key].set_ylabel("Pearson Corr. (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Pearson Corr. (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0,1.01)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            #axd[plot_key].set_yscale("log")
            axd[plot_key].set_ylim(0,1.01)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
                
            axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_method_type(method_type)})")#, fontsize=20)
            
            #axd[plot_key].text(x=1.1, y=1.1, s=method_type,  ha='center')
            


        elif metric_name=="spearmanr_all":

            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136) 
                         for i in metric_list_plot_df_select["num_subsets"]],                
                alpha=0.9,
                ax=axd[plot_key],
            )
            
#             reference_df=pd.DataFrame(metric_list_plot_reference).groupby(["method_name","num_subsets"])[metric_name].mean().reset_index().sort_values("num_subsets")    
#             reference_df_idx_list=[]
#             for i, metric_value in metric_list_plot_df_select[metric_name+"_explainer"].items():
#                 reference_df_idx_list.append((reference_df[metric_name]-metric_value)[(reference_df[metric_name]-metric_value)>0].idxmin())
#                 reference_df_idx_list.append((metric_value-reference_df[metric_name])[(metric_value-reference_df[metric_name])>0].idxmin())   
#             for idx, row in reference_df.iterrows():
#                 if row["num_subsets"] in [10240, 20480, 40960]:
#                     reference_df_idx_list.append(idx)
#             count=0
#             for idx in sorted(list(set(reference_df_idx_list))):
#                 row=reference_df.loc[idx]
#                 axd[plot_key].vlines(ymin=0, ymax=1, 
#                                      x=row[metric_name], linewidth=2, color=plt.rcParams['axes.prop_cycle'].by_key()['color'][count],
#                                      label=f'{row["method_name"]} {row["num_subsets"]}')
#                 count+=1                
    
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')


            axd[plot_key].set_ylabel("Spearman Corr. (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Spearman Corr. (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0,1.01)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            #axd[plot_key].set_yscale("log")
            axd[plot_key].set_ylim(0,1.01)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 
            
            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()            
            
            axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_method_type(method_type)})")#, fontsize=20)
            
        elif metric_name=="sign_agreement_all":

            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136) 
                         for i in metric_list_plot_df_select["num_subsets"]],                
                alpha=0.9,
                ax=axd[plot_key],
            )
            
#             reference_df=pd.DataFrame(metric_list_plot_reference).groupby(["method_name","num_subsets"])[metric_name].mean().reset_index().sort_values("num_subsets")    
#             reference_df_idx_list=[]
#             for i, metric_value in metric_list_plot_df_select[metric_name+"_explainer"].items():
#                 reference_df_idx_list.append((reference_df[metric_name]-metric_value)[(reference_df[metric_name]-metric_value)>0].idxmin())
#                 reference_df_idx_list.append((metric_value-reference_df[metric_name])[(metric_value-reference_df[metric_name])>0].idxmin())   
#             for idx, row in reference_df.iterrows():
#                 if row["num_subsets"] in [10240, 20480, 40960]:
#                     reference_df_idx_list.append(idx)
#             count=0
#             for idx in sorted(list(set(reference_df_idx_list))):
#                 row=reference_df.loc[idx]
#                 axd[plot_key].vlines(ymin=0, ymax=1, 
#                                      x=row[metric_name], linewidth=2, color=plt.rcParams['axes.prop_cycle'].by_key()['color'][count],
#                                      label=f'{row["method_name"]} {row["num_subsets"]}')
#                 count+=1                
    
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')


            axd[plot_key].set_ylabel("Sign Agreement (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Sign Agreement (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0,1.01)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            #axd[plot_key].set_yscale("log")
            axd[plot_key].set_ylim(0,1.01)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False)             

            leg=axd[plot_key].legend(loc='center', bbox_to_anchor=(1.1, 0, 0.5, 1))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            leg.remove()

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
            #leg.remove()
    
            axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_method_type(method_type)})")#, fontsize=20)
            
            

In [None]:
# Export the current `fig` twice under the same stem: once as PNG, once as PDF.
for plot_path in (
    "logs/plots/"+f"training_target_prediction_shapley_external_appendix.png",
    "logs/plots/"+f"training_target_prediction_shapley_external_appendix.pdf",
):
    fig.savefig(plot_path, bbox_inches='tight')

In [None]:
import matplotlib.colors as mcolors
def Blue_scalar_color_mapping(value, color_map, data_min, data_max):
    """Return the RGBA colour that `color_map` assigns to `value`.

    `value` is first normalised linearly with `data_min` -> 0 and `data_max` -> 1,
    then looked up in `color_map` through a matplotlib ScalarMappable.

    Args:
        value: scalar (or array of scalars) to convert to a colour.
        color_map: matplotlib colormap used for the lookup (callers pass `plt.cm.Blues`).
        data_min: data value mapped to the low end of the colormap.
        data_max: data value mapped to the high end of the colormap.

    Returns:
        Whatever `ScalarMappable.to_rgba(value)` returns for the input.
    """
    return plt.cm.ScalarMappable(
        norm=mcolors.Normalize(vmin=data_min, vmax=data_max),
        cmap=color_map,
    ).to_rgba(value)

In [None]:
# Legend frame appearance, set globally through rcParams (affects every legend drawn afterwards).
plt.rcParams.update({
    'legend.fancybox': True,
    'legend.edgecolor': '0.8',
    'legend.framealpha': 0.8,
})

# A single row of four panels. Slots are filled method by method:
#   0: KernelSHAP / mse_all      1: KernelSHAP / pearsonr_all
#   2: Permutation / mse_all     3: Permutation / pearsonr_all
fig = plt.figure(figsize=(4*4 + 3*0.2, 1*3))
box1 = gridspec.GridSpec(1, 4, wspace=0.3)

# axd maps (metric_name, method_type) -> Axes so later code can look a panel up by key.
axd = {}
for idx1, method_type in enumerate(["KernelSHAP", "Permutation"]):
    for idx2, metric_name in enumerate(["mse_all", "pearsonr_all"]):
        plot_key = (metric_name, method_type)

        ax = plt.Subplot(fig, box1[2*idx1 + idx2])
        fig.add_subplot(ax)
        axd[plot_key] = ax
        

# Fill the panels in `axd` for the training split.
#
# Every panel is a scatter plot with one point per `num_subsets` value of one attribution
# method: x is the `<metric>_explainer` column (axis label "... (Prediction)") and y is the
# `<metric>_target` column (axis label "... (Label)"), both averaged over the rows of each
# (method_type, num_subsets) group. A dashed y = x line is drawn for reference.
#
# Only "mse_all" and "pearsonr_all" have axes in `axd`, so only these two metrics are styled
# here; asking for any other metric fails loudly with a KeyError on `panel_style_by_metric`.

# Columns averaged within each (method_type, num_subsets) group.
averaged_metric_columns = [
    'sample_idx',
    'mse_target_explainer', 'mse_target_target',
    'mse_nontarget_explainer', 'mse_nontarget_target',
    'mse_all_explainer', 'mse_all_target',
    'pearsonr_target_explainer', 'pearsonr_target_target',
    'pearsonr_all_explainer', 'pearsonr_all_target',
    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
    'spearmanr_target_explainer', 'spearmanr_target_target',
    'spearmanr_all_explainer', 'spearmanr_all_target',
    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
    "sign_agreement_all_explainer", "sign_agreement_all_target",
]

# Per-metric differences between panels; everything else is shared.
#   label       - prefix of both axis labels
#   limits      - (low, high) applied to both the x and the y axis
#   log_scale   - whether both axes are switched to a log scale
#   show_legend - whether the panel keeps a (relabelled) legend
panel_style_by_metric = {
    "mse_all": {"label": "Error", "limits": (1e-3, 1.1), "log_scale": True, "show_legend": True},
    "pearsonr_all": {"label": "Pearson Corr.", "limits": (0, 1.01), "log_scale": False, "show_legend": False},
}

for metric_name in ["mse_all", "pearsonr_all"]:
    panel_style = panel_style_by_metric[metric_name]

    # NOTE(review): nothing in this cell reads `reference_df` / `reference_df_idx_list`
    # (the overlay that used them is disabled). They are still computed, once per metric,
    # in case a later cell relies on these globals - confirm and delete if not.
    reference_df = (
        pd.DataFrame(metric_list_plot_reference)
        .groupby(["method_name", "num_subsets"])[metric_name]
        .mean()
        .reset_index()
        .sort_values("num_subsets")
    )
    reference_df_idx_list = []

    for method_type in ["KernelSHAP", "Permutation"]:
        # Mean of every metric column per num_subsets, restricted to this method's training rows.
        in_train_split = metric_list_plot_df["split"] == "train"
        uses_method = metric_list_plot_df["method_type"] == method_type
        metric_list_plot_df_select = (
            metric_list_plot_df[in_train_split & uses_method]
            .groupby(["method_type", "num_subsets"])[averaged_metric_columns]
            .mean()
            .reset_index()
        )

        plot_key = (metric_name, method_type)
        panel_ax = axd[plot_key]

        sns.scatterplot(
            x=metric_name+"_explainer",
            y=metric_name+"_target",
            hue="num_subsets",
            data=metric_list_plot_df_select,
            s=200,
            # One shade of blue per plotted num_subsets value.
            # NOTE(review): data_min=-500 / data_max=3136 are unexplained constants;
            # presumably chosen so the smallest counts are not drawn near-white - confirm.
            palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136)
                     for i in metric_list_plot_df_select["num_subsets"]],
            alpha=0.9,
            ax=panel_ax,
        )

        # Dashed y = x reference line over [0, 1].
        panel_ax.plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')

        panel_ax.set_ylabel(f"{panel_style['label']} (Label)")
        panel_ax.set_xlabel(f"{panel_style['label']} (Prediction)")

        # Limits are set before the scale, as in the per-metric code this replaces.
        panel_ax.xaxis.grid(True, which='major')
        panel_ax.set_xlim(*panel_style["limits"])
        if panel_style["log_scale"]:
            panel_ax.set_xscale("log")

        panel_ax.yaxis.grid(True, which='major')
        panel_ax.set_ylim(*panel_style["limits"])
        if panel_style["log_scale"]:
            panel_ax.set_yscale("log")

        panel_ax.tick_params(axis='both', which='major', rotation=0)

        panel_ax.spines['right'].set_visible(False)
        panel_ax.spines['top'].set_visible(False)

        if panel_style["show_legend"]:
            leg = panel_ax.legend(loc='lower right', bbox_to_anchor=(0.97, 0.03))

            # Entries whose text parses as an integer are re-rendered as that integer; any
            # other entry becomes "<method> (<last five characters of the original text>)".
            for legend_text in leg.get_texts():
                legend_label = legend_text.get_text()
                try:
                    subset_count = int(legend_label)
                except ValueError:
                    legend_text.set_text(f"{prettify_method_type(method_type)} ({legend_label[-5:]})")
                else:
                    legend_text.set_text(f"{subset_count}")
        else:
            # Drop the legend that the hue mapping added to this panel, if there is one.
            hue_legend = panel_ax.get_legend()
            if hue_legend is not None:
                hue_legend.remove()

        panel_ax.set_title(f"{prettify_method_type(method_type)}")

In [None]:
# Export the current `fig` twice under the same stem: once as PNG, once as PDF.
for plot_path in (
    "logs/plots/"+f"training_target_prediction_shapley_two.png",
    "logs/plots/"+f"training_target_prediction_shapley_two.pdf",
):
    fig.savefig(plot_path, bbox_inches='tight')

In [None]:
# Length of the "iters" entry stored under index 3126 of the long-sampling train run.
banzhaf_train_long_entry = banzhaf_loaded_dict[
    "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"
][3126]
len(banzhaf_train_long_entry["iters"])

In [None]:
# Show the "iters" entry stored under index 3126 of the long-sampling train run.
banzhaf_train_long_entry = banzhaf_loaded_dict[
    "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"
][3126]
banzhaf_train_long_entry["iters"]

In [None]:
# List the keys of banzhaf_loaded_dict (the cells around this one index it by
# "logs/.../extract_output/<split>" run names).
banzhaf_loaded_dict.keys()

In [None]:
# Compare the long-sampling train run against itself: the values at 1,000,000
# iterations are passed as ground truth and the values at smaller iteration
# counts as the estimate. The results of every call are collected in one flat list.
banzhaf_train_long_name = "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"
subset_multipliers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 40, 80, 100, 200, 400, 800, 1000]

metric_list_ground_truth_banzahf = []
for subset_multiplier in subset_multipliers:
    num_subsets = 500 * subset_multiplier
    metric_list_ground_truth_banzahf.extend(
        get_ground_truth_metric_with_value(
            attribution_values_ground_truth=banzhaf_loaded_dict[banzhaf_train_long_name],
            iters_ground_truth=1000000,
            attribution_values_calculated=banzhaf_loaded_dict[banzhaf_train_long_name],
            iters_calculated=num_subsets,
            meta_info={
                "num_subsets": num_subsets,
                "true_name": banzhaf_train_long_name,
                "estimated_name": banzhaf_train_long_name,
            },
        )
    )

In [None]:
# Build the tables behind the Banzhaf "target vs. explainer" comparison:
#   1. metric_list_plot_reference  - long run compared against itself
#   2. metric_list_plot_target     - sampled estimates compared against the long run
#   3. metric_list_plot_explainer  - trained explainers compared against the long run
# Each metric dict is shallow-copied and annotated with plotting metadata
# (method_name, method_type, split, ...). Any entry that matches none of the
# known run combinations aborts the cell with a RuntimeError.
# (The original raised the misspelled `RuntimError`, which surfaced as a NameError.)

banzhaf_train_long_name = "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train"
banzhaf_train_sampling_name = "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train"
banzhaf_train_short_name = "logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_short/extract_output/train"
banzhaf_test_long_name = "logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test"

# --- 1. reference: long run evaluated against itself -------------------------
long_vs_long = (banzhaf_train_long_name, banzhaf_train_long_name)

metric_list_plot_reference = []
for metric in metric_list_ground_truth_banzahf:
    metric_temp = copy.copy(metric)

    if (metric_temp["estimated_name"], metric_temp["true_name"]) != long_vs_long:
        raise RuntimeError(f"Unexpected reference metric entry: {metric_temp}")

    metric_temp.update(
        {"method_name": 'BanzhafMSR',
         "method_type": 'BanzhafMSR',
         "antithetical": False,
         "split": "train",
        }
    )
    metric_list_plot_reference.append(metric_temp)


# --- 2. target: sampled estimates evaluated against the long run -------------
# (estimated_name, true_name) -> split label attached to the entry.
target_split_by_run_pair = {
    (banzhaf_train_sampling_name, banzhaf_train_long_name): "train",
    (banzhaf_train_short_name, banzhaf_train_long_name): "train",
    (banzhaf_test_long_name, banzhaf_test_long_name): "test",
}

metric_list_plot_target = []
for metric in metric_list_value_banzhaf:
    metric_temp = copy.copy(metric)
    run_pair = (metric_temp["estimated_name"], metric_temp["true_name"])

    if run_pair == long_vs_long:
        # The train long-vs-long comparison is excluded from the target table.
        continue
    if run_pair not in target_split_by_run_pair:
        raise RuntimeError(f"Unexpected target metric entry: {metric_temp}")

    metric_temp.update(
        {"method_name": f'BanzhafMSR ({metric_temp["num_subsets"]})',
         "method_type": 'BanzhafMSR',
         "antithetical": False,
         "split": target_split_by_run_pair[run_pair],
        }
    )
    metric_list_plot_target.append(metric_temp)


# --- 3. explainer: trained explainers evaluated against the long run ---------
# Checkpoint directory -> (transform_mode, num_subsets). The trailing "_<n>" of
# the directory name is the subset count the original code parsed out of the path.
explainer_run_info = {
    f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_banzhaf_regexplainer_upfront_{mode}_{subset_count}": (mode, subset_count)
    for mode in ["global", "sqrt", "perinstance", "perinstanceperclass"]
    for subset_count in [10, 100, 500]
}
explainer_split_by_true_name = {
    banzhaf_train_long_name: "train",
    banzhaf_test_long_name: "test",
}

metric_list_plot_explainer = []
for metric in metric_list_banzhaf:
    metric_temp = copy.copy(metric)

    run_info = explainer_run_info.get(metric_temp["model_path"])
    split = explainer_split_by_true_name.get(metric_temp["true_name"])
    if run_info is None or split is None:
        raise RuntimeError(f"Unexpected explainer metric entry: {metric_temp}")
    transform_mode, num_subsets = run_info

    metric_temp.update(
        {"method_name": f'Reg-AO (BanzhafMSR, {transform_mode}, {num_subsets})',
         "method_type": 'BanzhafMSR',
         "transform_mode": transform_mode,
         "antithetical": False,
         "num_subsets": num_subsets,
         "split": split,
        }
    )
    metric_list_plot_explainer.append(metric_temp)


# Keep only the rows flagged as coming from the best checkpoint.
metric_list_plot_explainer_df = pd.DataFrame(metric_list_plot_explainer)
metric_list_plot_explainer_df = metric_list_plot_explainer_df[
    metric_list_plot_explainer_df["is_best_checkpoint"] == "best"
]

# Duplicate the target rows under split="test", rewriting sample_idx to its
# position within the first 100 entries of a seeded permutation of range(9469).
# NOTE(review): presumably 9469 is the size of the training set and the "test"
# explainer metrics index that 100-sample subset by position - confirm.
# A sample_idx outside the subset raises KeyError on purpose (no silent NaNs).
metric_list_plot_target_df = pd.DataFrame(metric_list_plot_target)
metric_list_plot_target_df_ = metric_list_plot_target_df.copy()
metric_list_plot_target_df_["split"] = "test"
idx_mapping = dict(zip(np.random.RandomState(seed=42).permutation(list(range(9469)))[:100],
                       list(range(100))))
metric_list_plot_target_df_["sample_idx"] = metric_list_plot_target_df_["sample_idx"].map(lambda x: idx_mapping[x])
metric_list_plot_target_df = pd.concat([metric_list_plot_target_df, metric_list_plot_target_df_])

# Pair every explainer row with the sampled-estimate row of the same
# method / sample / subset count / split; shared metric columns receive the
# "_explainer" and "_target" suffixes.
metric_list_plot_df = metric_list_plot_explainer_df.merge(
    right=metric_list_plot_target_df,
    left_on=["method_type", "sample_idx", "num_subsets", "split"],
    right_on=["method_type", "sample_idx", "num_subsets", "split"],
    suffixes=('_explainer', '_target'),
)

# Cell output: mean of each metric on the train split per configuration.
metric_list_plot_df[metric_list_plot_df["split"]=="train"].groupby(["method_type", "transform_mode" , "num_subsets"])\
[['sample_idx', "num_subsets",
'mse_all_explainer', 'mse_all_target',
'pearsonr_all_explainer', 'pearsonr_all_target',
'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
'spearmanr_all_explainer', 'spearmanr_all_target',
'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target']].mean()

In [None]:
# Row count per split in the explainer metric table (already a DataFrame, so no re-wrapping needed).
metric_list_plot_explainer_df["split"].value_counts()

In [None]:
# Row count per split in the target metric table.
metric_list_plot_target_df["split"].value_counts()

In [None]:
# Scatter grid for the Banzhaf comparison: one column per metric, one row per
# transform mode. Each point is the train-split mean for one num_subsets value,
# with the explainer's metric on the x axis and the sampled target's metric on
# the y axis. The dashed grey line marks x == y.

plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor'] = '0.8'
plt.rcParams['legend.framealpha'] = 0.8

# Per-metric panel settings (insertion order = column order in the figure).
#   axis_label   : prefix of the axis labels ("<prefix> (Label)" / "<prefix> (Prediction)")
#   axis_limits  : limits applied to both axes
#   diagonal_end : upper end of the x == y guide line
#   log_axes     : whether both axes use a log scale
banzhaf_panel_settings = {
    "mse_all": {"axis_label": "Error", "axis_limits": (1e-5, 10.1), "diagonal_end": 10, "log_axes": True},
    "pearsonr_all": {"axis_label": "Pearson Corr.", "axis_limits": (0, 1.01), "diagonal_end": 1, "log_axes": False},
    "spearmanr_all": {"axis_label": "Spearman Corr.", "axis_limits": (0, 1.01), "diagonal_end": 1, "log_axes": False},
    "sign_agreement_all": {"axis_label": "Sign Agreement", "axis_limits": (0, 1.01), "diagonal_end": 1, "log_axes": False},
}
banzhaf_panel_transform_modes = ["global", "perinstanceperclass"]

# Columns averaged per (method_type, num_subsets) before plotting.
banzhaf_summary_columns = [
    'sample_idx',
    'mse_all_explainer', 'mse_all_target',
    'pearsonr_all_explainer', 'pearsonr_all_target',
    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
    'spearmanr_all_explainer', 'spearmanr_all_target',
    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
    "sign_agreement_all_explainer", "sign_agreement_all_target",
]


def draw_banzhaf_agreement_panel(ax, summary_df, metric_name, transform_mode, show_legend):
    """Draw one explainer-vs-target scatter panel on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    summary_df : DataFrame with one row per (method_type, num_subsets) holding
        the ``<metric_name>_explainer`` and ``<metric_name>_target`` columns.
    metric_name : key of ``banzhaf_panel_settings`` selecting columns and styling.
    transform_mode : transform mode shown in the panel title.
    show_legend : keep a num_subsets legend on this panel; otherwise the legend
        that seaborn adds automatically is removed.
    """
    settings = banzhaf_panel_settings[metric_name]

    sns.scatterplot(
        x=metric_name+"_explainer",
        y=metric_name+"_target",
        hue="num_subsets",
        data=summary_df,
        s=200,
        palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-200, data_max=500)
                 for i in summary_df["num_subsets"]],
        alpha=0.9,
        ax=ax,
    )

    # x == y guide: points above it have a larger target value than explainer value.
    diagonal = np.linspace(0, settings["diagonal_end"], 100)
    ax.plot(diagonal, diagonal, color="grey", alpha=0.5, linestyle='--')

    ax.set_ylabel(f'{settings["axis_label"]} (Label)')
    ax.set_xlabel(f'{settings["axis_label"]} (Prediction)')

    lower, upper = settings["axis_limits"]
    ax.xaxis.grid(True, which='major')
    ax.set_xlim(lower, upper)
    ax.yaxis.grid(True, which='major')
    ax.set_ylim(lower, upper)
    if settings["log_axes"]:
        # Limits are set first so switching the scale does not re-autoscale them.
        ax.set_xscale("log")
        ax.set_yscale("log")

    ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
    ax.tick_params(axis='y', which='major', rotation=0)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # seaborn attaches a legend automatically; drop it unless this panel keeps one.
    automatic_legend = ax.get_legend()
    if automatic_legend is not None:
        automatic_legend.remove()

    if show_legend:
        leg = ax.legend(loc='lower right', bbox_to_anchor=(0.97, 0.03))
        for legend_text in leg.get_texts():
            label = legend_text.get_text()
            try:
                subset_count = int(label)
            except ValueError:
                # Non-numeric entry: prefix it with the method shown in this panel.
                # The method type is taken from the plotted data instead of a
                # variable left over from an earlier cell.
                method_type = summary_df["method_type"].iloc[0]
                legend_text.set_text(f"{prettify_method_type(method_type)} ({label[-5:]})")
            else:
                legend_text.set_text(f"{subset_count}")

    metric_title = prettify_metric_name(metric_name)
    mode_title = prettify_transform_mode(transform_mode)
    if mode_title == "":
        ax.set_title(f"{metric_title}")
    else:
        ax.set_title(f"{metric_title} ({mode_title})")


fig = plt.figure(figsize=(4*4+3*0.3, 3*4+2*0.8))

box1 = gridspec.GridSpec(1, 4, hspace=0.3)

# One column per metric; each column is split into three rows, of which the
# first len(banzhaf_panel_transform_modes) are used.
axd = {}
for idx1, metric_name in enumerate(banzhaf_panel_settings):
    box2 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.4)
    for idx2, transform_mode in enumerate(banzhaf_panel_transform_modes):
        plot_key = (metric_name, transform_mode)
        axd[plot_key] = fig.add_subplot(box2[idx2])


for idx2, transform_mode in enumerate(banzhaf_panel_transform_modes):
    # The summary depends only on the transform mode, so compute it once per row.
    train_rows = metric_list_plot_df[
        (metric_list_plot_df["split"]=="train")&(metric_list_plot_df["transform_mode"]==transform_mode)
    ]
    metric_list_plot_df_select = (
        train_rows.groupby(["method_type", "num_subsets"])[banzhaf_summary_columns].mean().reset_index()
    )

    for metric_name in banzhaf_panel_settings:
        plot_key = (metric_name, transform_mode)
        draw_banzhaf_agreement_panel(
            ax=axd[plot_key],
            summary_df=metric_list_plot_df_select,
            metric_name=metric_name,
            transform_mode=transform_mode,
            # Only the first-row error panel carries the num_subsets legend.
            show_legend=(metric_name == "mse_all" and idx2 == 0),
        )

In [None]:
# Write the current figure in both raster (PNG) and vector (PDF) form,
# cropping the surrounding whitespace.
for plot_path in (
    "logs/plots/training_target_prediction_banzhaf.png",
    "logs/plots/training_target_prediction_banzhaf.pdf",
):
    fig.savefig(plot_path, bbox_inches='tight')

In [None]:
# Display the column index of metric_list_plot_df, the table the plotting
# cells in this section filter and aggregate.
metric_list_plot_df.columns

In [None]:
# Count the rows per value of the "split" column; the plotting cells select
# rows by comparing this column against a single split name.
metric_list_plot_df["split"].value_counts()

In [None]:
# Legend appearance shared by every panel drawn in this cell.
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor']='0.8'
plt.rcParams['legend.framealpha']=0.8

# Rows of metric_list_plot_df used for this figure.
PLOT_SPLIT = "test"
# One grid column per metric, one grid row per transform mode.
PLOT_METRIC_NAMES = ["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]
PLOT_TRANSFORM_MODES = ["global", "perinstanceperclass"]
# Columns averaged per (method_type, num_subsets) before plotting.
PLOT_SUMMARY_COLUMNS = [
    'sample_idx',
    'mse_all_explainer', 'mse_all_target',
    'pearsonr_all_explainer', 'pearsonr_all_target',
    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
    'spearmanr_all_explainer', 'spearmanr_all_target',
    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
    "sign_agreement_all_explainer", "sign_agreement_all_target",
]
# Per-metric axis settings: label stem, shared x/y limits, upper end of the
# y = x guide line, and whether both axes are logarithmic.
PANEL_STYLE_BY_METRIC = {
    "mse_all": {"axis_label": "Error", "limits": (1e-5, 10.1), "diagonal_max": 10, "log_scale": True},
    "pearsonr_all": {"axis_label": "Pearson Corr.", "limits": (0, 1.01), "diagonal_max": 1, "log_scale": False},
    "spearmanr_all": {"axis_label": "Spearman Corr.", "limits": (0, 1.01), "diagonal_max": 1, "log_scale": False},
    "sign_agreement_all": {"axis_label": "Sign Agreement", "limits": (0, 1.01), "diagonal_max": 1, "log_scale": False},
}


def _summarize_metrics_for_panel(metric_df, split, transform_mode):
    """Return the mean of PLOT_SUMMARY_COLUMNS per (method_type, num_subsets).

    Only rows of ``metric_df`` whose "split" and "transform_mode" columns equal
    the given values are included. The grouping keys come back as ordinary
    columns.
    """
    selected = metric_df[(metric_df["split"] == split) & (metric_df["transform_mode"] == transform_mode)]
    return selected.groupby(["method_type", "num_subsets"])[PLOT_SUMMARY_COLUMNS].mean().reset_index()


def _relabel_sample_count_legend(legend, summary_df):
    """Rewrite legend entries so integer-like labels are shown as plain integers.

    Labels that do not parse as an integer get the fallback format
    "<method> (<last five characters>)". The method name is taken from the
    plotted data rather than from a leftover ``method_type`` kernel variable,
    which this cell never defines.
    """
    for legend_text in legend.get_texts():
        label = legend_text.get_text()
        try:
            sample_count = int(label)
        except ValueError:
            # Resolved lazily so prettify_method_type is only called on this path.
            method_label = "/".join(
                f"{prettify_method_type(method_type)}" for method_type in summary_df["method_type"].unique()
            )
            legend_text.set_text(f"{method_label} ({label[-5:]})")
        else:
            legend_text.set_text(f"{sample_count}")


def _draw_prediction_vs_label_panel(ax, summary_df, metric_name, transform_mode, show_legend):
    """Scatter the explainer's metric (x) against the training target's metric (y).

    Parameters:
        ax: Matplotlib axes to draw on.
        summary_df: Output of ``_summarize_metrics_for_panel``; one point per row,
            coloured by "num_subsets".
        metric_name: Key of PANEL_STYLE_BY_METRIC; also the column-name stem
            (``<metric_name>_explainer`` / ``<metric_name>_target``).
        transform_mode: Used for the panel title only.
        show_legend: When True, a legend is placed in the lower right of the
            panel; otherwise the panel has no legend.
    """
    style = PANEL_STYLE_BY_METRIC[metric_name]

    sns.scatterplot(
        x=metric_name+"_explainer",
        y=metric_name+"_target",
        hue="num_subsets",
        data=summary_df,
        s=200,
        # One colour per plotted row, derived from that row's num_subsets.
        palette=[Blue_scalar_color_mapping(num_subsets, color_map=plt.cm.Blues, data_min=-200, data_max=500)
                 for num_subsets in summary_df["num_subsets"]],
        alpha=0.9,
        ax=ax,
    )

    # y = x guide line: points on it have equal explainer and target values.
    diagonal = np.linspace(0, style["diagonal_max"], 100)
    ax.plot(diagonal, diagonal, color="grey", alpha=0.5, linestyle='--')

    ax.set_ylabel(f'{style["axis_label"]} (Label)')
    ax.set_xlabel(f'{style["axis_label"]} (Prediction)')

    lower, upper = style["limits"]

    # Limits are set before the scale, in the same order as the original cell.
    ax.xaxis.grid(True, which='major')
    ax.set_xlim(left=lower, right=upper)
    if style["log_scale"]:
        ax.set_xscale("log")

    ax.yaxis.grid(True, which='major')
    ax.set_ylim(lower, upper)
    if style["log_scale"]:
        ax.set_yscale("log")

    # Arguments kept exactly as in the original calls.
    ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
    ax.tick_params(axis='y', which='major', rotation=0)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # Drop the legend seaborn attached; only the requested panel gets one back.
    attached_legend = ax.get_legend()
    if attached_legend is not None:
        attached_legend.remove()
    if show_legend:
        legend = ax.legend(loc='lower right', bbox_to_anchor=(0.97, 0.03))
        _relabel_sample_count_legend(legend, summary_df)

    pretty_transform_mode = prettify_transform_mode(transform_mode)
    if pretty_transform_mode == "":
        ax.set_title(f"{prettify_metric_name(metric_name)}")
    else:
        ax.set_title(f"{prettify_metric_name(metric_name)} ({pretty_transform_mode})")


fig = plt.figure(figsize=(4*4+3*0.3, 3*4+2*0.8)
                )

box1 = gridspec.GridSpec(1, 4, hspace=0.3)

# Build the grid of axes first: each metric column is split into three rows,
# of which one per transform mode is populated.
axd={}
for column_idx, metric_name in enumerate(PLOT_METRIC_NAMES):
    box2 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=box1[column_idx], wspace=0.8, hspace=0.4)
    for row_idx, transform_mode in enumerate(PLOT_TRANSFORM_MODES):
        ax=plt.Subplot(fig, box2[row_idx])
        fig.add_subplot(ax)
        axd[(metric_name, transform_mode)]=ax

# Fill every panel; the legend is shown once, on the first-row error panel.
for metric_name in PLOT_METRIC_NAMES:
    for row_idx, transform_mode in enumerate(PLOT_TRANSFORM_MODES):
        metric_list_plot_df_select=_summarize_metrics_for_panel(metric_list_plot_df, PLOT_SPLIT, transform_mode)
        _draw_prediction_vs_label_panel(
            axd[(metric_name, transform_mode)],
            metric_list_plot_df_select,
            metric_name,
            transform_mode,
            show_legend=(metric_name=="mse_all" and row_idx==0),
        )
            
            

In [None]:
# Write the current figure in both raster (PNG) and vector (PDF) form,
# cropping the surrounding whitespace.
for plot_path in (
    "logs/plots/training_target_prediction_banzhaf_external.png",
    "logs/plots/training_target_prediction_banzhaf_external.pdf",
):
    fig.savefig(plot_path, bbox_inches='tight')

In [None]:
# Collect metrics for the LIME training run evaluated against itself: the same
# loaded attribution values are passed as both ground truth and estimate, with
# the ground truth fixed at 1000000 iterations and the estimate's iteration
# count swept from 100000 to 1000000 in steps of 100000.
metric_list_ground_truth_lime=[]

lime_train_run="logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"
lime_train_values=lime_loaded_dict[lime_train_run]

for num_subsets in range(100000, 1000001, 100000):
    metric_list_ground_truth_lime.extend(
        get_ground_truth_metric_with_value(
            attribution_values_ground_truth=lime_train_values,
            iters_ground_truth=1000000,
            attribution_values_calculated=lime_train_values,
            iters_calculated=num_subsets,
            # Copied onto the returned records so later cells can tell runs apart
            # -- presumably via meta_info passthrough; confirm in the helper.
            meta_info={
                "num_subsets": num_subsets,
                "true_name": lime_train_run,
                "estimated_name": lime_train_run,
            },
        )
    )

In [None]:
# Run identifiers that the metric records are matched against.
LIME_TRAIN_RUN = "logs/vitbase_imagenette_surrogate_lime_eval_train_regression_long/extract_output/train"
LIME_TEST_RUN = "logs/vitbase_imagenette_surrogate_lime_eval_test_regression_long/extract_output/test"
BINOMIAL_TRAIN_RUN = "logs/vitbase_imagenette_surrogate_binomial_eval_train/extract_output/train"

# Data split implied by the ground-truth run a record was scored against.
LIME_SPLIT_BY_TRUE_NAME = {LIME_TRAIN_RUN: "train", LIME_TEST_RUN: "test"}

# Explainer checkpoints are named "<prefix>_<transform_mode>_<num_subsets>";
# only the combinations listed here are accepted.
LIME_EXPLAINER_PREFIX = "/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_lime_regexplainer_upfront"
LIME_EXPLAINER_BUDGETS = {
    "global": (128, 256, 512),
    "sqrt": (128, 256, 512),
    "perinstance": (256, 512),
    "perinstanceperclass": (256, 512),
}
LIME_EXPLAINER_RUNS = {
    f"{LIME_EXPLAINER_PREFIX}_{transform_mode}_{num_subsets}": (transform_mode, num_subsets)
    for transform_mode, budgets in LIME_EXPLAINER_BUDGETS.items()
    for num_subsets in budgets
}


def _label_lime_reference(metric):
    """Return a copy of a ground-truth metric record tagged as the LIME reference.

    Raises:
        RuntimeError: if the record was not produced by comparing the LIME
            training run against itself.
    """
    record = copy.copy(metric)
    if not (record["estimated_name"] == LIME_TRAIN_RUN and record["true_name"] == LIME_TRAIN_RUN):
        raise RuntimeError(f"Unexpected LIME reference metric record: {record!r}")
    record.update(
        {"method_name": 'LIME',
         "method_type": 'LIME',
         "antithetical": False,
         "split": "train",
        }
    )
    return record


def _label_lime_target(metric):
    """Return a copy of a training-target metric record tagged with LIME plot labels.

    Accepts estimates from either the binomial or the LIME training run, both
    scored against the LIME training run.

    Raises:
        RuntimeError: for any other (estimated_name, true_name) combination.
    """
    record = copy.copy(metric)
    if not (record["estimated_name"] in (BINOMIAL_TRAIN_RUN, LIME_TRAIN_RUN)
            and record["true_name"] == LIME_TRAIN_RUN):
        raise RuntimeError(f"Unexpected LIME target metric record: {record!r}")
    record.update(
        {"method_name": f'LIME ({record["num_subsets"]})',
         "method_type": 'LIME',
         "antithetical": False,
         "split": "train"
        }
    )
    return record


def _label_lime_explainer(metric):
    """Return a copy of an explainer metric record tagged with its run settings.

    The transform mode and sample budget are recovered from "model_path" and
    the split from "true_name".

    Raises:
        RuntimeError: if the checkpoint path or the ground-truth run is not one
            of the known combinations.
    """
    record = copy.copy(metric)
    run_settings = LIME_EXPLAINER_RUNS.get(record["model_path"])
    split = LIME_SPLIT_BY_TRUE_NAME.get(record["true_name"])
    if run_settings is None or split is None:
        raise RuntimeError(f"Unexpected LIME explainer metric record: {record!r}")
    transform_mode, num_subsets = run_settings
    record.update(
        {"method_name": f'Reg-AO (LIME, {transform_mode}, {num_subsets})',
         "method_type": 'LIME',
         "transform_mode": transform_mode,
         "antithetical": False,
         "num_subsets": num_subsets,
         "split": split,
        }
    )
    return record


metric_list_plot_reference=[_label_lime_reference(metric) for metric in metric_list_ground_truth_lime]
metric_list_plot_target=[_label_lime_target(metric) for metric in metric_list_value_lime]
metric_list_plot_explainer=[_label_lime_explainer(metric) for metric in metric_list_lime]

# Keep only the rows flagged as the best checkpoint of each explainer run.
metric_list_plot_explainer_df=pd.DataFrame(metric_list_plot_explainer)
metric_list_plot_explainer_df=metric_list_plot_explainer_df[(metric_list_plot_explainer_df["is_best_checkpoint"]=="best")
                                                           ]

# The target metrics only exist for the train split. A second copy is relabelled
# as "test" so test-split explainer rows find a partner in the merge below.
# Its sample_idx is translated through idx_mapping, which sends the first 100
# entries of a seed-42 permutation of range(9469) to 0..99; any other index
# raises KeyError.
# NOTE(review): presumably this mirrors how the evaluated subset was drawn, and
# 9469 is the dataset size -- confirm against the extraction code.
metric_list_plot_target_df=pd.DataFrame(metric_list_plot_target)
metric_list_plot_target_df_=metric_list_plot_target_df.copy()
metric_list_plot_target_df_["split"]="test"
idx_mapping=dict(zip(np.random.RandomState(seed=42).permutation(list(range(9469)))[:100],
list(range(100))))
metric_list_plot_target_df_["sample_idx"]=metric_list_plot_target_df_["sample_idx"].map(lambda x: idx_mapping[x])
metric_list_plot_target_df=pd.concat([metric_list_plot_target_df, metric_list_plot_target_df_])

# Pair each explainer row with the target row for the same method, sample,
# budget and split; overlapping metric columns get _explainer / _target suffixes.
metric_list_plot_df=metric_list_plot_explainer_df.merge(right=metric_list_plot_target_df,
                          on=["method_type", "sample_idx", "num_subsets", "split"],
                          suffixes=('_explainer', '_target')
                         )

# Cell output: train-split means per method, transform mode and budget.
metric_list_plot_df[metric_list_plot_df["split"]=="train"].groupby(["method_type", "transform_mode" , "num_subsets"])\
[['sample_idx', "num_subsets", 
'mse_all_explainer', 'mse_all_target',
'pearsonr_all_explainer', 'pearsonr_all_target',
'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target', 
'spearmanr_all_explainer', 'spearmanr_all_target',
'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target']].mean()#.reset_index()

In [None]:
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor']='0.8'
plt.rcParams['legend.framealpha']=0.8

# Panels of the figure: one grid column per metric, one grid row per transform mode.
metric_names = ["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]
transform_modes = ["global", "perinstanceperclass"]

# Columns averaged per (method_type, num_subsets) before plotting.
summary_columns = ['sample_idx',
                   'mse_all_explainer', 'mse_all_target',
                   'pearsonr_all_explainer', 'pearsonr_all_target',
                   'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target',
                   'spearmanr_all_explainer', 'spearmanr_all_target',
                   'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
                   "sign_agreement_all_explainer", "sign_agreement_all_target"]

# Per-metric presentation: axis-label stem, upper end of the dashed y = x line,
# shared (x and y) axis limits, and whether both axes are log-scaled.
panel_specs = {
    "mse_all": dict(axis_label="Error", diag_max=10, lim=(1e-5, 1.1), log_scale=True),
    "pearsonr_all": dict(axis_label="Pearson Corr.", diag_max=1, lim=(0, 1.01), log_scale=False),
    "spearmanr_all": dict(axis_label="Spearman Corr.", diag_max=1, lim=(0, 1.01), log_scale=False),
    "sign_agreement_all": dict(axis_label="Sign Agreement", diag_max=1, lim=(0, 1.01), log_scale=False),
}


def draw_target_vs_prediction_panel(ax, summary_df, metric_name, transform_mode, show_legend=False):
    """Draw one scatter panel of `<metric_name>_explainer` against `<metric_name>_target`.

    Every row of `summary_df` becomes one point coloured by its `num_subsets`
    value, and a dashed grey y = x line is added as a visual reference.

    Args:
        ax: Matplotlib axes to draw on.
        summary_df: Frame holding `method_type`, `num_subsets` and the
            `<metric_name>_explainer` / `<metric_name>_target` columns.
        metric_name: Key of `panel_specs`; selects labels, limits and scale.
        transform_mode: Forwarded to `prettify_transform_mode` for the title.
        show_legend: If True, keep a `num_subsets` legend in the lower-right
            corner; otherwise the panel ends up without a legend.
    """
    spec = panel_specs[metric_name]

    # Key the palette by hue level so each num_subsets value always gets its own
    # colour, even when several method types contribute rows with the same value.
    subset_palette = {
        num_subsets: Blue_scalar_color_mapping(num_subsets, color_map=plt.cm.Blues, data_min=0, data_max=512)
        for num_subsets in sorted(summary_df["num_subsets"].unique())
    }
    sns.scatterplot(
        x=metric_name+"_explainer",
        y=metric_name+"_target",
        hue="num_subsets",
        data=summary_df,
        s=200,
        palette=subset_palette,
        alpha=0.9,
        ax=ax,
    )

    diagonal = np.linspace(0, spec["diag_max"], 100)
    ax.plot(diagonal, diagonal, color="grey", alpha=0.5, linestyle='--')

    ax.set_ylabel(f'{spec["axis_label"]} (Label)')
    ax.set_xlabel(f'{spec["axis_label"]} (Prediction)')

    ax.xaxis.grid(True, which='major')
    ax.set_xlim(*spec["lim"])
    if spec["log_scale"]:
        ax.set_xscale("log")

    ax.yaxis.grid(True, which='major')
    ax.set_ylim(*spec["lim"])
    if spec["log_scale"]:
        ax.set_yscale("log")

    ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
    ax.tick_params(axis='y', which='major', rotation=0)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # seaborn attaches its own legend; drop it and only rebuild one on request.
    auto_legend = ax.get_legend()
    if auto_legend is not None:
        auto_legend.remove()

    if show_legend:
        leg = ax.legend(loc='lower right', bbox_to_anchor=(0.97, 0.03))
        for legend_text in leg.get_texts():
            raw_label = legend_text.get_text()
            try:
                subset_count = int(raw_label)
            except ValueError:
                # Non-integer entry: prefix it with the method type(s) that are
                # actually plotted instead of relying on a variable left over
                # from another cell.
                method_label = ", ".join(
                    str(prettify_method_type(plotted_method))
                    for plotted_method in summary_df["method_type"].unique()
                )
                legend_text.set_text(f"{method_label} ({raw_label[-5:]})")
            else:
                legend_text.set_text(f"{subset_count}")

    metric_label = prettify_metric_name(metric_name)
    mode_label = prettify_transform_mode(transform_mode)
    if mode_label == "":
        ax.set_title(f"{metric_label}")
    else:
        ax.set_title(f"{metric_label} ({mode_label})")


fig = plt.figure(figsize=(4*4+3*0.3, 3*4+2*0.8))

box1 = gridspec.GridSpec(1, 4, hspace=0.3)

# Create the axes: each metric column is split into three rows, of which the
# first len(transform_modes) are populated.
axd={}
for idx1, metric_name in enumerate(metric_names):
    box2 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.4)
    for idx2, transform_mode in enumerate(transform_modes):
        axd[(metric_name, transform_mode)] = fig.add_subplot(box2[idx2])

# The aggregated frame depends only on the transform mode, so build it once per
# mode rather than once per (metric, mode) panel.
train_rows = metric_list_plot_df[metric_list_plot_df["split"]=="train"]
summary_by_mode = {
    mode: train_rows[train_rows["transform_mode"]==mode]
          .groupby(["method_type", "num_subsets"])[summary_columns].mean().reset_index()
    for mode in transform_modes
}

for metric_name in metric_names:
    for idx2, transform_mode in enumerate(transform_modes):
        plot_key=(metric_name, transform_mode)
        metric_list_plot_df_select = summary_by_mode[transform_mode]
        draw_target_vs_prediction_panel(
            axd[plot_key],
            metric_list_plot_df_select,
            metric_name,
            transform_mode,
            # Only the first MSE panel carries the num_subsets legend.
            show_legend=(metric_name=="mse_all" and idx2==0),
        )
            

In [None]:
# Export the figure built in the previous cell as PNG first, then PDF, both tightly cropped.
for plot_path in ("logs/plots/training_target_prediction_lime.png",
                  "logs/plots/training_target_prediction_lime.pdf"):
    fig.savefig(plot_path, bbox_inches='tight')

In [None]:
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor']='0.8'
plt.rcParams['legend.framealpha']=0.8

# fig = plt.figure(figsize=(5, 15)
#                 )

# box1 = gridspec.GridSpec(3, 1, hspace=0.3)

# axd={}
# for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", 
#                               ]):
     
#     for idx2, method_type in enumerate(["KernelSHAP"]):
#         box2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.0)    

#         ax=plt.Subplot(fig, box2[idx2])
#         fig.add_subplot(ax)

#         plot_key=(metric, method_type)
#         axd[plot_key]=ax          
        

# for idx1, metric in enumerate(["MSE_all", "pearsonr_all", "spearmanr_all", 
#                               ]):
#     for idx2, method_type in enumerate(["KernelSHAP"]):

#         plot_key=(metric, method_type)
        

fig = plt.figure(figsize=(4*4+3*0.3, 3*4+2*0.8)
                )

box1 = gridspec.GridSpec(1, 4, hspace=0.3)

axd={}
for idx1, metric_name in enumerate(["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"
                              ]):
     
    for idx2, transform_mode in enumerate(["global", "perinstanceperclass"]):
        box2 = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.4)    

        ax=plt.Subplot(fig, box2[idx2])
        fig.add_subplot(ax)

        plot_key=(metric_name, transform_mode)
        axd[plot_key]=ax          
        

for idx1, metric_name in enumerate(["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all",
                              ]):
    for idx2, transform_mode in enumerate(["global", "perinstanceperclass"]):       
        
        metric_list_plot_df_select=metric_list_plot_df[(metric_list_plot_df["split"]=="test")&(metric_list_plot_df["transform_mode"]==transform_mode)].groupby(["method_type", "num_subsets"])\
                                    [['sample_idx', # "num_subsets", 
                                    'mse_all_explainer', 'mse_all_target',
                                    'pearsonr_all_explainer', 'pearsonr_all_target',
                                    'pearsonr_all_per_class_explainer', 'pearsonr_all_per_class_target', 
                                    'spearmanr_all_explainer', 'spearmanr_all_target',
                                    'spearmanr_all_per_class_explainer', 'spearmanr_all_per_class_target',
                                     "sign_agreement_all_explainer", "sign_agreement_all_target"]].mean().reset_index()                  

        plot_key=(metric_name, transform_mode)
        
        if metric_name=="mse_all":
            
            
            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=0, data_max=512) 
                         for i in metric_list_plot_df_select["num_subsets"]],                                
                alpha=0.9,
                ax=axd[plot_key],
            )
            
            

            
            reference_df=pd.DataFrame(metric_list_plot_reference).groupby(["method_name","num_subsets"])[metric_name].mean().reset_index().sort_values("num_subsets")    
            reference_df_idx_list=[]
#             for i, metric_value in metric_list_plot_df_select[metric_name+"_explainer"].items():
#                 reference_df_idx_list.append((reference_df[metric_name]-metric_value)[(reference_df[metric_name]-metric_value)>0].idxmin())
#                 reference_df_idx_list.append((metric_value-reference_df[metric_name])[(metric_value-reference_df[metric_name])>0].idxmin())   
#             for idx, row in reference_df.iterrows():
#                 if row["num_subsets"] in [10240, 20480, 40960]:
#                     reference_df_idx_list.append(idx)
                
#             count=0
#             for idx in sorted(list(set(reference_df_idx_list))):
#                 row=reference_df.loc[idx]
#                 axd[plot_key].vlines(ymin=0, ymax=1, 
#                                      x=row[metric_name], linewidth=2, color=plt.rcParams['axes.prop_cycle'].by_key()['color'][count],
#                                      label=f'{row["method_name"]} {row["num_subsets"]}')
#                 count+=1                

    
            axd[plot_key].plot(np.linspace(0,10,100), np.linspace(0,10,100), color="grey", alpha=0.5, linestyle='--')
        
        
        


            axd[plot_key].set_ylabel("Error (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Error (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor')#, linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(left=1e-5, right=1.1)
            axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(1e-5, 1.1)
            axd[plot_key].set_yscale("log")

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            
            axd[plot_key].get_legend().remove()

            for line in leg.get_lines():
                line.set_linewidth(3.0) 
                
            if idx2==0:                
                leg=axd[plot_key].legend(loc='lower right', bbox_to_anchor=(0.97, 0.03))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
                #leg.set_title("# Samples / Point")   

                for legend_text in leg.get_texts():
                    try:
                        int(legend_text.get_text())
                    except:
                        legend_text.set_text(f"{prettify_method_type(method_type)} ({legend_text.get_text()[-5:]})")         
                    else:
                        legend_text.set_text(f"{int(legend_text.get_text())}")


    #             for line in leg.get_lines():
    #                 line.set_linewidth(3.0) 
                #leg.remove()                
                
            #axd[plot_key].set_title(prettify_method_type(method_type)+ " - " + prettify_metric_name(metric_name))#, fontsize=20)
            if prettify_transform_mode(transform_mode)=="":
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)}")#, fontsize=20)
            else:
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_transform_mode(transform_mode)})")#, fontsize=20)


        elif metric_name=="pearsonr_all":
            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=0, data_max=512) 
                         for i in metric_list_plot_df_select["num_subsets"]],                 
                alpha=0.9,
                ax=axd[plot_key],
            )
            
            reference_df=pd.DataFrame(metric_list_plot_reference).groupby(["method_name","num_subsets"])[metric_name].mean().reset_index().sort_values("num_subsets")    
            reference_df_idx_list=[]
#             for i, metric_value in metric_list_plot_df_select[metric_name+"_explainer"].items():
#                 reference_df_idx_list.append((reference_df[metric_name]-metric_value)[(reference_df[metric_name]-metric_value)>0].idxmin())
#                 reference_df_idx_list.append((metric_value-reference_df[metric_name])[(metric_value-reference_df[metric_name])>0].idxmin())   
#             for idx, row in reference_df.iterrows():
#                 if row["num_subsets"] in [10240, 20480, 40960]:
#                     reference_df_idx_list.append(idx)                
#             count=0
#             for idx in sorted(list(set(reference_df_idx_list))):
#                 row=reference_df.loc[idx]
#                 axd[plot_key].vlines(ymin=0, ymax=1, 
#                                      x=row[metric_name], linewidth=2, color=plt.rcParams['axes.prop_cycle'].by_key()['color'][count],
#                                      label=f'{row["method_name"]} {row["num_subsets"]}')
#                 count+=1                              
    
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')


            axd[plot_key].set_ylabel("Pearson Corr. (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Pearson Corr. (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0,1.01)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            #axd[plot_key].set_yscale("log")
            axd[plot_key].set_ylim(0,1.01)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
                
            if prettify_transform_mode(transform_mode)=="":
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)}")#, fontsize=20)
            else:
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_transform_mode(transform_mode)})")#, fontsize=20)
            
            #axd[plot_key].text(x=1.1, y=1.1, s=method_type,  ha='center')
            


        elif metric_name=="spearmanr_all":

            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=0, data_max=512) 
                         for i in metric_list_plot_df_select["num_subsets"]],                 
                alpha=0.9,
                ax=axd[plot_key],
            )
            
#             reference_df=pd.DataFrame(metric_list_plot_reference).groupby(["method_name","num_subsets"])[metric_name].mean().reset_index().sort_values("num_subsets")    
#             reference_df_idx_list=[]
#             for i, metric_value in metric_list_plot_df_select[metric_name+"_explainer"].items():
#                 reference_df_idx_list.append((reference_df[metric_name]-metric_value)[(reference_df[metric_name]-metric_value)>0].idxmin())
#                 reference_df_idx_list.append((metric_value-reference_df[metric_name])[(metric_value-reference_df[metric_name])>0].idxmin())   
#             for idx, row in reference_df.iterrows():
#                 if row["num_subsets"] in [10240, 20480, 40960]:
#                     reference_df_idx_list.append(idx)
#             count=0
#             for idx in sorted(list(set(reference_df_idx_list))):
#                 row=reference_df.loc[idx]
#                 axd[plot_key].vlines(ymin=0, ymax=1, 
#                                      x=row[metric_name], linewidth=2, color=plt.rcParams['axes.prop_cycle'].by_key()['color'][count],
#                                      label=f'{row["method_name"]} {row["num_subsets"]}')
#                 count+=1                
    
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')


            axd[plot_key].set_ylabel("Spearman Corr. (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Spearman Corr. (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0,1.01)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            #axd[plot_key].set_yscale("log")
            axd[plot_key].set_ylim(0,1.01)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 
            
            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()            
            
            if prettify_transform_mode(transform_mode)=="":
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)}")#, fontsize=20)
            else:
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_transform_mode(transform_mode)})")#, fontsize=20)
            
        elif metric_name=="sign_agreement_all":

            sns.scatterplot(
                x=metric_name+"_explainer", 
                y=metric_name+"_target", 
                hue="num_subsets",
                data=metric_list_plot_df_select,
                s=200,
                #palette=[i["color"] for i in list(plt.rcParams['axes.prop_cycle'])],
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=0, data_max=512) 
                         for i in metric_list_plot_df_select["num_subsets"]],                 
                alpha=0.9,
                ax=axd[plot_key],
            )
            
#             reference_df=pd.DataFrame(metric_list_plot_reference).groupby(["method_name","num_subsets"])[metric_name].mean().reset_index().sort_values("num_subsets")    
#             reference_df_idx_list=[]
#             for i, metric_value in metric_list_plot_df_select[metric_name+"_explainer"].items():
#                 reference_df_idx_list.append((reference_df[metric_name]-metric_value)[(reference_df[metric_name]-metric_value)>0].idxmin())
#                 reference_df_idx_list.append((metric_value-reference_df[metric_name])[(metric_value-reference_df[metric_name])>0].idxmin())   
#             for idx, row in reference_df.iterrows():
#                 if row["num_subsets"] in [10240, 20480, 40960]:
#                     reference_df_idx_list.append(idx)
#             count=0
#             for idx in sorted(list(set(reference_df_idx_list))):
#                 row=reference_df.loc[idx]
#                 axd[plot_key].vlines(ymin=0, ymax=1, 
#                                      x=row[metric_name], linewidth=2, color=plt.rcParams['axes.prop_cycle'].by_key()['color'][count],
#                                      label=f'{row["method_name"]} {row["num_subsets"]}')
#                 count+=1                
    
            axd[plot_key].plot(np.linspace(0,1,100), np.linspace(0,1,100), color="grey", alpha=0.5, linestyle='--')


            axd[plot_key].set_ylabel("Sign Agreement (Label)")#, fontsize=20)
            axd[plot_key].set_xlabel("Sign Agreement (Prediction)")#, fontsize=20)

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major')#, linewidth=2, alpha=0.6)
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_xlim(0,1.01)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major')#, linewidth=2, alpha=0.6
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            #axd[plot_key].set_yscale("log")
            axd[plot_key].set_ylim(0,1.01)

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False)             

            leg=axd[plot_key].legend(loc='center', bbox_to_anchor=(1.1, 0, 0.5, 1))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            leg.remove()

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
            #leg.remove()
    
            if prettify_transform_mode(transform_mode)=="":
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)}")#, fontsize=20)
            else:
                axd[plot_key].set_title(f"{prettify_metric_name(metric_name)} ({prettify_transform_mode(transform_mode)})")#, fontsize=20)
            
            

In [None]:
# Write the current figure `fig` to disk twice: once as PNG and once as PDF.
for plot_file in ("logs/plots/"+f"training_target_prediction_lime_external.png",
                  "logs/plots/"+f"training_target_prediction_lime_external.pdf"):
    fig.savefig(plot_file, bbox_inches='tight')

In [None]:
# Wrap metric_list_plot_explainer in a DataFrame for ad-hoc inspection.
temp_df = pd.DataFrame(data=metric_list_plot_explainer)

In [None]:
# Count the (model_path, epoch) combinations among rows whose is_best_checkpoint equals "best".
is_best_row = temp_df["is_best_checkpoint"] == "best"
temp_df.loc[is_best_row, ["model_path", "epoch"]].value_counts()

In [None]:
# Print the "best_model_checkpoint" entry stored in the trainer state of the `..._upfront_512` run directory.
trainer_state_path = "logs/vitbase_imagenette_shapley_regexplainer_upfront_512"+"/trainer_state.json"
with open(trainer_state_path) as f:
    trainer_state = json.load(f)
print(trainer_state["best_model_checkpoint"])

In [None]:
# for num_subsets, checkpoint_path in {
#     512: "logs/vitbase_imagenette_shapley_regexplainer_upfront_512/checkpoint-888",
#     1024: "logs/vitbase_imagenette_shapley_regexplainer_upfront_1024/checkpoint-1036",
#     1536: "logs/vitbase_imagenette_shapley_regexplainer_upfront_1536/checkpoint-1480",
#     2048: "logs/vitbase_imagenette_shapley_regexplainer_upfront_2048/checkpoint-1480",
#     3072: "logs/vitbase_imagenette_shapley_regexplainer_upfront_3072/checkpoint-1924",
# }.items():
    

#     with open(checkpoint_path+"/trainer_state.json") as f:
#         trainer_state = json.load(f)
#     num_epoch=int(trainer_state["epoch"])
    
#     print(checkpoint_path, num_epoch)

# compute-match

In [None]:
# Show which keys are currently present in shapley_loaded_dict.
shapley_loaded_dict.keys()

In [None]:
metric_list_flops=[]

# Ground-truth attribution keys in shapley_loaded_dict, paired with the matching key of dataset_explainer.
# The order fixes the layout of metric_list_flops: for every checkpoint, test is evaluated before train.
ground_truth_targets=(
    ("test", "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test"),
    ("train", "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"),
)

for num_train in [100, 250, 500, 1000, 2000, 5000]:

    model_path=f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_shapley_regexplainer_upfront_2257_numtrain_{num_train}"

    # The glob pattern contains no wildcard, so it yields at most the best checkpoint's directory.
    best_checkpoint_step=int(get_best_model_checkpoint(model_path).split('-')[-1])
    checkpoint_path_list=sorted(
        glob.glob(model_path+f"/checkpoint-{best_checkpoint_step}"),
        key=lambda x: int(x.split('-')[-1]),
    )

    for checkpoint_path in tqdm(checkpoint_path_list[:50]):
        # Load the saved weights and the trainer state that belong to this checkpoint.
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        regexplainer.load_state_dict(checkpoint_state_dict)

        # Score the restored explainer against each ground-truth attribution set.
        for split_name, true_name in ground_truth_targets:
            metric_list_flops+=get_ground_truth_metric_with_explainer(
                attribution_values=shapley_loaded_dict[true_name],
                explainer=regexplainer,
                dataset=dataset_explainer[split_name],
                iters_ground_truth=999424,
                meta_info={
                    "true_name": true_name,
                    "model_path": model_path,
                    "epoch": int(checkpoint_trainer_state["epoch"]),
                    "is_best_checkpoint": compare_checkpoint_value(
                        current_checkpoint=int(checkpoint_path.split('-')[-1]),
                        best_checkpoint=int(get_best_model_checkpoint(model_path).split('-')[-1]),
                    ),
                },
            )

In [None]:
# Load "shapley" attributions from the `.../extract_output/train` directory for a fixed subset of samples
# and store them in shapley_loaded_dict under a dedicated `..._2440` key.
attribution_key = "logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train_2440"

# Seeded permutation of range(9469); the first 100 entries are passed on as sample_select.
sample_select_idx = np.random.RandomState(seed=42).permutation(list(range(9469)))[:100]

shapley_loaded_dict[attribution_key] = load_attribution(
    "logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train",
    target_subset_size=2440,
    attribution_name="shapley",
    sample_select=sample_select_idx,
)

In [None]:
metric_list_value_flops_matched=[]
num_subsets=2440

# Keys of shapley_loaded_dict: the reference attributions and the estimate being scored against them.
ground_truth_key = "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"
estimated_key = "logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train_2440"

for i in range(1):
    # For every sample keep only entry `i` of its tracking_dict.
    shapley_loaded_dict_temp={
        sample_idx: tracking_dict[i]
        for sample_idx, tracking_dict in shapley_loaded_dict[estimated_key].items()
    }

    metric_list_value_flops_matched+=get_ground_truth_metric_with_value(
        attribution_values_ground_truth=shapley_loaded_dict[ground_truth_key],
        iters_ground_truth=999424,
        attribution_values_calculated=shapley_loaded_dict_temp,
        iters_calculated=num_subsets,
        meta_info={
            "num_subsets": num_subsets,
            "true_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train",
            "estimated_name": "logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train_2440",
        },
    )

In [None]:
# Annotate every compute-matched KernelSHAP metric record with the plotting metadata
# (method name/type, antithetical flag, split) and collect the copies in metric_list_plot_target.
metric_list_plot_target=[]
for metric in metric_list_value_flops_matched:
    # Shallow copy so the records in metric_list_value_flops_matched are left untouched.
    metric_temp=copy.copy(metric)

    # Only one (estimated_name, true_name) combination is expected here; anything else is an error.
    if metric_temp["estimated_name"]=="logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train_2440" and\
       metric_temp["true_name"]=="logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train":
        metric_temp.update(
            {"method_name": f'KernelSHAP ({metric_temp["num_subsets"]})',
             "method_type": 'KernelSHAP',
             "antithetical": False,
             "split": "train"
            }
        )

    else:
        # Fail loudly with the offending record in the message.
        raise RuntimeError(f"Unexpected metric record: {metric_temp}")

    metric_list_plot_target.append(metric_temp)

In [None]:
# Count the records per "estimated_name" value.
# NOTE(review): `metric_list_value_flops` is not assigned in the neighbouring cells; the list built a few
# cells earlier is `metric_list_value_flops_matched` - confirm which one is meant here.
pd.DataFrame(metric_list_value_flops)["estimated_name"].value_counts()

In [None]:
# Count how many metric records were collected for each model_path.
flops_records_df = pd.DataFrame(metric_list_flops)
flops_records_df["model_path"].value_counts()

In [None]:
# Annotate every explainer metric record in metric_list_flops with the plotting metadata
# (method name/type, training-set size, subset counts, split) and collect the copies in
# metric_list_plot_explainer. Records of runs with fewer than 250 training samples are skipped.
metric_list_plot_explainer=[]

# Run directories whose records are expected in metric_list_flops (built once, not per record).
explainer_model_paths=[f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_shapley_regexplainer_upfront_2257_numtrain_{num_train}" for num_train in [100, 250, 500, 1000, 2000, 5000]]

# Expected ground-truth names and the split label each one maps to.
split_by_true_name={
    "logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train": "train",
    "logs/vitbase_imagenette_surrogate_shapley_eval_test_regression/extract_output/test": "test",
}

for metric in metric_list_flops:
    # Shallow copy so the records in metric_list_flops are left untouched.
    metric_temp=copy.copy(metric)

    # Any record from an unknown run or with an unknown ground truth is an error.
    if metric_temp["model_path"] not in explainer_model_paths or\
       metric_temp["true_name"] not in split_by_true_name:
        raise RuntimeError(f"Unexpected metric record: {metric_temp}")

    # model_path ends in "..._upfront_<num_subsets>_numtrain_<num_train>".
    num_train=int(metric_temp["model_path"].split('_')[-1])
    if num_train<250:
        continue

    num_subsets=int(metric_temp["model_path"].split('_')[-3])
    # Lookup table from num_subsets to num_subsets_per_sample; only 2257 is known (KeyError otherwise).
    num_subsets_per_sample={2257:2440}[num_subsets]

    split=split_by_true_name[metric_temp["true_name"]]

    metric_temp.update(
        {"method_name": f'Reg-AO (KernelSHAP, {num_subsets})',
         "method_type": 'KernelSHAP',
         "antithetical": False,
         "num_train":num_train,
         "num_subsets": num_subsets,
         "num_subsets_per_sample":num_subsets_per_sample,
         "split": split,
        }
    )

    metric_list_plot_explainer.append(metric_temp)

In [None]:
# Explainer records: keep rows flagged as best checkpoint or whose method_type is "KernelSHAP".
explainer_records_df=pd.DataFrame(metric_list_plot_explainer)
keep_explainer_row=(explainer_records_df["is_best_checkpoint"]=="best")|\
                   (explainer_records_df["method_type"]=="KernelSHAP")
metric_list_plot_explainer_df=explainer_records_df[keep_explainer_row]

metric_list_plot_target_df=pd.DataFrame(metric_list_plot_target)

# Pair each explainer row with the target row of the same method type, sample and split, matching the
# explainer's num_subsets_per_sample against the target's num_subsets. Columns present in both frames
# that are not join keys receive the "_explainer" / "_target" suffix.
metric_list_plot_df=pd.merge(
    left=metric_list_plot_explainer_df,
    right=metric_list_plot_target_df,
    left_on=["method_type", "sample_idx", "num_subsets_per_sample", "split"],
    right_on=["method_type", "sample_idx", "num_subsets", "split"],
    suffixes=('_explainer', '_target'),
)

# Mean of the selected columns on the train split, one column per (method_type, num_subsets_per_sample, num_train).
summary_columns=[
    'sample_idx',  "num_subsets_per_sample",
    'mse_target_explainer', 'mse_nontarget_explainer', 'mse_all_explainer',
    'mse_target_target', 'mse_nontarget_target', 'mse_all_target',
]
train_merged_df=metric_list_plot_df[metric_list_plot_df["split"]=="train"]
train_merged_df.groupby(["method_type", "num_subsets_per_sample", "num_train"])[summary_columns].mean().T

In [None]:
# Number of target rows per split.
metric_list_plot_target_df["split"].value_counts()

In [None]:
# Number of explainer rows per split.
metric_list_plot_explainer_df["split"].value_counts()

In [None]:
# Sanity check: redo the explainer/target join and count the joined rows per split.
merge_check_df=pd.merge(
    left=metric_list_plot_explainer_df,
    right=metric_list_plot_target_df,
    left_on=["method_type", "sample_idx", "num_subsets_per_sample", "split"],
    right_on=["method_type", "sample_idx", "num_subsets", "split"],
    suffixes=('_explainer', '_target'),
)
merge_check_df["split"].value_counts()

In [None]:
# Number of merged rows per split.
metric_list_plot_df["split"].value_counts()

In [None]:
# Number of merged rows per split, restricted to the KernelSHAP method type.
# The filter uses "method_type": it is a join key and therefore keeps its plain name after the merge,
# whereas "method_name" exists in both inputs and only survives with the _explainer/_target suffixes.
metric_list_plot_df[metric_list_plot_df["method_type"]=="KernelSHAP"]["split"].value_counts()

In [None]:
# Display the filtered explainer frame.
metric_list_plot_explainer_df

In [None]:
# Per-(method_type, num_train) mean of the metric columns for train-split explainer rows.
# Both halves of the mask are built from metric_list_plot_explainer_df itself, so the boolean
# index is aligned with the frame it selects from.
# NOTE(review): `method_type` is not assigned in this cell; it must already exist in the kernel namespace.
train_explainer_mask=(metric_list_plot_explainer_df["split"]=="train")&\
                     (metric_list_plot_explainer_df["method_type"]==method_type)

metric_list_plot_explainer_df[train_explainer_mask].groupby(["method_type", "num_train"])\
                                                    [['sample_idx',
                                                    'mse_target',
                                                    'mse_nontarget',
                                                    'mse_all',
                                                    'pearsonr_target',
                                                    'pearsonr_all',
                                                    'pearsonr_all_per_class',
                                                    'spearmanr_target',
                                                    'spearmanr_all',
                                                    'spearmanr_all_per_class']
                                                    ].mean().reset_index()

In [None]:
# plt.rcParams['legend.fancybox'] = False
# plt.rcParams['legend.edgecolor']='1.0'
# plt.rcParams['legend.framealpha']=1
plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor']='0.8'
plt.rcParams['legend.framealpha']=0.8

plt.rcParams['axes.prop_cycle']=plt.cycler(color=[(0.2980392156862745, 0.4470588235294118, 0.6901960784313725),
  (0.8666666666666667, 0.5176470588235295, 0.3215686274509804),
  (0.3333333333333333, 0.6588235294117647, 0.40784313725490196),
  (0.7686274509803922, 0.3058823529411765, 0.3215686274509804),
  (0,0,0)]) 


fig = plt.figure(figsize=(4*(4.3), 3)
                )

box1 = gridspec.GridSpec(1, 4, hspace=0.3)

axd={}
for idx1, metric_name in enumerate(["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"
                              ]):
     
    for idx2, method_type in enumerate(["KernelSHAP"]):
        box2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.0)    

        ax=plt.Subplot(fig, box2[idx2])
        fig.add_subplot(ax)

        plot_key=(metric_name, method_type)
        axd[plot_key]=ax          
        
        
plt.rcParams['axes.prop_cycle']=plt.cycler(color=[(0.2980392156862745, 0.4470588235294118, 0.6901960784313725),
  (0.8666666666666667, 0.5176470588235295, 0.3215686274509804)])          

for idx1, metric_name in enumerate(["mse_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"
                              ]):
    for idx2, method_type in enumerate(["KernelSHAP"]):

        plot_key=(metric_name, method_type)
        
        if metric_name=="mse_all":
            metric_list_plot_df_select=metric_list_plot_explainer_df[(metric_list_plot_explainer_df["split"]=="train")\
                              &(metric_list_plot_explainer_df["method_type"]==method_type)\
                              &(metric_list_plot_explainer_df["is_best_checkpoint"]=="best")
                             ].groupby(["method_type", "num_train"])\
                                                    [['sample_idx',
                                                    'mse_target',
                                                    'mse_nontarget',
                                                    'mse_all',
                                                    'pearsonr_target',
                                                    'pearsonr_all',
                                                    'pearsonr_all_per_class',
                                                    'spearmanr_target',
                                                    'spearmanr_all',
                                                    'spearmanr_all_per_class',]
                                                    ].mean().reset_index()

            sns.lineplot(
                x="num_train", 
                y=metric_name, 
                # hue="num_subsets",
                marker='o', 
                markersize=6,  
                alpha=0.8,
                #linewidth=3,
                palette="Set2",
                data=metric_list_plot_df_select,
                ax=axd[plot_key],
            )
            
            
            axd[plot_key].hlines(xmin=0, xmax=5000, 
                                 y=metric_list_plot_target_df[['sample_idx',
                                    'mse_target',
                                    'mse_nontarget',
                                    'mse_all',
                                    'pearsonr_target',
                                    'pearsonr_all',
                                    'pearsonr_all_per_class',
                                    'spearmanr_target',
                                    'spearmanr_all',
                                    'spearmanr_all_per_class']].mean().loc[metric_name], 
                                 
                                 #linewidth=3, 
                                 color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1],
                                 label=f'KernelSHAP'
                                )            
            

            axd[plot_key].set_xlabel("# Training Datapoints (Amortized)") #, fontsize=20
            axd[plot_key].set_ylabel("Error") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].xaxis.grid(True, which='minor', alpha=0.1) #linewidth=1, 
            axd[plot_key].set_xlim(left=-1e-3)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].yaxis.grid(True, which='minor', alpha=0.1) #linewidth=1, 
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, .045)
            #axd[plot_key].set_yscale("log")

            axd[plot_key].tick_params(axis='x', which='major', rotation=0,  labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  #, labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)

            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            leg.remove()

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
                
            axd[plot_key].set_title(prettify_metric_name(metric_name))#, fontsize=20)
        
        
        
            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            leg.remove()

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
                
            axd[plot_key].set_title(prettify_metric_name("mse_all"))#, fontsize=20)
        
            from matplotlib.lines import Line2D

            custom_lines = [Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][0], lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1], lw=1, linestyle="-")]

            axd[plot_key].legend(custom_lines, ['Amortized', 'KernelSHAP'], 
                                 #ncols=2,
                                 loc='best', 
                                 #bbox_to_anchor=(-0.6, -1.3, 0.5, 1)
                                 bbox_to_anchor=(0, 0, 1, 1)
                                )
            #bbox_to_anchor=(-0.8, -0.32, 3, 0),             


        elif metric_name=="pearsonr_all":
            metric_list_plot_df_select=metric_list_plot_explainer_df[(metric_list_plot_explainer_df["split"]=="train")\
                              &(metric_list_plot_explainer_df["method_type"]==method_type)\
                              &(metric_list_plot_explainer_df["is_best_checkpoint"]=="best")
                             ].groupby(["method_type", "num_train"])\
                                                    [['sample_idx',
                                                    'mse_target',
                                                    'mse_nontarget',
                                                    'mse_all',
                                                    'pearsonr_target',
                                                    'pearsonr_all',
                                                    'pearsonr_all_per_class',
                                                    'spearmanr_target',
                                                    'spearmanr_all',
                                                    'spearmanr_all_per_class']
                                                    ].mean().reset_index()            
        

            sns.lineplot(
                x="num_train", 
                y=metric_name, 
                # hue="num_subsets",
                marker='o', 
                markersize=6,  
                alpha=0.8,
                #linewidth=3,
                palette="Set2",
                data=metric_list_plot_df_select,
                ax=axd[plot_key],
            )

            
            axd[plot_key].hlines(xmin=0, xmax=5000, 
                                 y=metric_list_plot_target_df[['sample_idx',
                                    'mse_target',
                                    'mse_nontarget',
                                    'mse_all',
                                    'pearsonr_target',
                                    'pearsonr_all',
                                    'pearsonr_all_per_class',
                                    'spearmanr_target',
                                    'spearmanr_all',
                                    'spearmanr_all_per_class']].mean().loc[metric_name], 
                                 
                                 #linewidth=3, 
                                 color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1],
                                 label=f'KernelSHAP'
                                )               

            axd[plot_key].set_xlabel("# Training Datapoints (Amortized)", ) #fontsize=20
            axd[plot_key].set_ylabel("Pearson Correlation") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].xaxis.grid(True, which='minor',alpha=0.1) # linewidth=1, 
            axd[plot_key].set_xlim(left=-1e-3)
            #axd[plot_key].set_xscale("log")

            axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].yaxis.grid(True, which='minor', alpha=0.1) #linewidth=2, 
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, 1.01)
            #axd[plot_key].set_yscale("log")

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            leg.remove()
            
#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
                
            axd[plot_key].set_title(prettify_metric_name(metric_name))#, fontsize=20)


        elif metric_name=="spearmanr_all":
            
            metric_list_plot_df_select=metric_list_plot_explainer_df[(metric_list_plot_explainer_df["split"]=="train")\
                              &(metric_list_plot_explainer_df["method_type"]==method_type)\
                              &(metric_list_plot_explainer_df["is_best_checkpoint"]=="best")
                             ].groupby(["method_type", "num_train"])\
                                                    [['sample_idx',
                                                    'mse_target',
                                                    'mse_nontarget',
                                                    'mse_all',
                                                    'pearsonr_target',
                                                    'pearsonr_all',
                                                    'pearsonr_all_per_class',
                                                    'spearmanr_target',
                                                    'spearmanr_all',
                                                    'spearmanr_all_per_class']
                                                    ].mean().reset_index()

            sns.lineplot(
                x="num_train", 
                y=metric_name, 
                # hue="num_subsets",
                marker='o', 
                markersize=6,  
                alpha=0.8,
                #linewidth=3,
                palette="Set2",
                data=metric_list_plot_df_select,
                ax=axd[plot_key],
            )
            #leg=axd[plot_key].legend(loc='center', bbox_to_anchor=(1.0, 0, 0.5, 1))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            #handles_, labels_ = axd[plot_key].get_legend_handles_labels()
            #print(handles_)
            
            
            axd[plot_key].hlines(xmin=0, xmax=5000, 
                                 y=metric_list_plot_target_df[['sample_idx',
                                    'mse_target',
                                    'mse_nontarget',
                                    'mse_all',
                                    'pearsonr_target',
                                    'pearsonr_all',
                                    'pearsonr_all_per_class',
                                    'spearmanr_target',
                                    'spearmanr_all',
                                    'spearmanr_all_per_class']].mean().loc[metric_name], 
                                 
                                 #linewidth=3, 
                                 color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1],
                                 label=f'KernelSHAP'
                                )               


            axd[plot_key].set_xlabel("# Training Datapoints (Amortized)") #fontsize=20
            axd[plot_key].set_ylabel("Spearman Correlation") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].xaxis.grid(True, which='minor', alpha=0.1) #linewidth=1, 
            axd[plot_key].set_xlim(left=-1e-3)
            #axd[plot_key].set_xscale("log")

            axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].yaxis.grid(True, which='minor', alpha=0.1) #linewidth=1, 
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, 1.01)
            #axd[plot_key].set_yscale("log")

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) # labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

#             leg=axd[plot_key].legend(loc='center', bbox_to_anchor=(1.0, 0, 0.5, 1))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
#             handles, labels = axd[plot_key].get_legend_handles_labels()
#             handles.append(handles_[0])
#             labels.append(labels_[0])
#             leg._legend_box = None
#             leg._init_legend_box(handles, labels)
#             leg._set_loc(leg._loc)   

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            leg.remove()
    
            
            
#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
                
            axd[plot_key].set_title(prettify_metric_name(metric_name))#, fontsize=20)

        
        elif metric_name=="sign_agreement_all":
            
            metric_list_plot_df_select=metric_list_plot_explainer_df[(metric_list_plot_explainer_df["split"]=="train")\
                              &(metric_list_plot_explainer_df["method_type"]==method_type)\
                              &(metric_list_plot_explainer_df["is_best_checkpoint"]=="best")
                             ].groupby(["method_type", "num_train"])\
                                                    [['sample_idx',
                                                    'mse_target',
                                                    'mse_nontarget',
                                                    'mse_all',
                                                    'pearsonr_target',
                                                    'pearsonr_all',
                                                    'pearsonr_all_per_class',
                                                    'spearmanr_target',
                                                    'spearmanr_all',
                                                    'spearmanr_all_per_class',
                                                    "sign_agreement_all"]
                                                    ].mean().reset_index()

            sns.lineplot(
                x="num_train", 
                y=metric_name, 
                # hue="num_subsets",
                marker='o', 
                markersize=6,  
                alpha=0.8,
                #linewidth=3,
                palette="Set2",
                data=metric_list_plot_df_select,
                ax=axd[plot_key],
            )
            #leg=axd[plot_key].legend(loc='center', bbox_to_anchor=(1.0, 0, 0.5, 1))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            #handles_, labels_ = axd[plot_key].get_legend_handles_labels()
            #print(handles_)
            
            
            axd[plot_key].hlines(xmin=0, xmax=5000, 
                                 y=metric_list_plot_target_df[['sample_idx',
                                    'mse_target',
                                    'mse_nontarget',
                                    'mse_all',
                                    'pearsonr_target',
                                    'pearsonr_all',
                                    'pearsonr_all_per_class',
                                    'spearmanr_target',
                                    'spearmanr_all',
                                    'spearmanr_all_per_class',
                                    'sign_agreement_all']].mean().loc[metric_name], 
                                 
                                 #linewidth=3, 
                                 color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1],
                                 label=f'KernelSHAP'
                                )               


            axd[plot_key].set_xlabel("# Training Datapoints (Amortized)") #fontsize=20
            axd[plot_key].set_ylabel("Sign Agreement") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].xaxis.grid(True, which='minor', alpha=0.1) # linewidth=1, 
            axd[plot_key].set_xlim(left=-1e-3)
            #axd[plot_key].set_xscale("log")

            axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].yaxis.grid(True, which='minor',  alpha=0.1) #linewidth=1,
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, 1.01)
            #axd[plot_key].set_yscale("log")

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) # labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

#             leg=axd[plot_key].legend(loc='center', bbox_to_anchor=(1.0, 0, 0.5, 1))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
#             handles, labels = axd[plot_key].get_legend_handles_labels()
#             handles.append(handles_[0])
#             labels.append(labels_[0])
#             leg._legend_box = None
#             leg._init_legend_box(handles, labels)
#             leg._set_loc(leg._loc)   
            
            
#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
                
            axd[plot_key].set_title(prettify_metric_name(metric_name))#, fontsize=20)
        
        

            
#             for line in leg.get_lines():
#                 line.set_linewidth(3.0)   
sns.set_theme(style='whitegrid')
sns.set_context('paper', font_scale=1.2)        

In [None]:
# Persist the current `fig` as both a raster (PNG) and a vector (PDF) file.
plots_dir = "logs/plots"
# savefig does not create missing parent directories, so make sure it exists.
os.makedirs(plots_dir, exist_ok=True)
for ext in ("png", "pdf"):
    fig.savefig(os.path.join(plots_dir, f"shapley_compute_trainsamples.{ext}"), bbox_inches='tight')

In [None]:
# Figure: one panel per metric showing the amortized explainer's mean metric on the
# "test" split as a function of the number of training datapoints, with the mean of
# the same metric over `metric_list_plot_target_df` drawn as a horizontal
# "KernelSHAP" reference line.
from matplotlib.lines import Line2D

plt.rcParams['legend.fancybox'] = True
plt.rcParams['legend.edgecolor'] = '0.8'
plt.rcParams['legend.framealpha'] = 0.8

# Colour cycle in effect while the axes are created.
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=[
    (0.2980392156862745, 0.4470588235294118, 0.6901960784313725),
    (0.8666666666666667, 0.5176470588235295, 0.3215686274509804),
    (0.3333333333333333, 0.6588235294117647, 0.40784313725490196),
    (0.7686274509803922, 0.3058823529411765, 0.3215686274509804),
    (0, 0, 0),
])

# Per-panel settings; these are the only things that differ between the metrics.
# `y_major_step=None` leaves matplotlib's default y tick locator in place.
panel_specs = {
    "mse_all": {"ylabel": "Error", "ylim": (0, .045), "y_major_step": None},
    "pearsonr_all": {"ylabel": "Pearson Correlation", "ylim": (0, 1.01), "y_major_step": 0.2},
    "spearmanr_all": {"ylabel": "Spearman Correlation", "ylim": (0, 1.01), "y_major_step": 0.2},
    "sign_agreement_all": {"ylabel": "Sign Agreement", "ylim": (0, 1.01), "y_major_step": 0.2},
}
method_types = ["KernelSHAP"]


def _draw_metric_panel(ax, metric_name, method_type, split, ylabel, ylim,
                       baseline_color, y_major_step=None, baseline_xmax=5000):
    """Draw one "metric vs. number of training datapoints" panel on `ax`.

    Args:
        ax: matplotlib Axes to draw on.
        metric_name: column of `metric_list_plot_explainer_df` /
            `metric_list_plot_target_df` to plot.
        method_type: value the "method_type" column must equal.
        split: value the "split" column must equal.
        ylabel: y-axis label.
        ylim: (bottom, top) y-axis limits.
        baseline_color: colour of the horizontal reference line.
        y_major_step: spacing of major y ticks; None keeps the default locator.
        baseline_xmax: right end of the horizontal reference line (starts at x=0).
    """
    explainer_df = metric_list_plot_explainer_df
    keep = (
        (explainer_df["split"] == split)
        & (explainer_df["method_type"] == method_type)
        & (explainer_df["is_best_checkpoint"] == "best")
    )
    # Only the plotted metric is aggregated; the other metric columns were never
    # used by the plot, and requiring them made the cell fail if one was absent.
    mean_per_num_train = (
        explainer_df[keep]
        .groupby(["method_type", "num_train"])[[metric_name]]
        .mean()
        .reset_index()
    )

    # No `palette` here: without a `hue` variable seaborn ignores it (and warns in
    # recent versions); the line takes the first colour of the axes' cycle.
    sns.lineplot(
        x="num_train",
        y=metric_name,
        marker='o',
        markersize=6,
        alpha=0.8,
        data=mean_per_num_train,
        ax=ax,
    )

    # NOTE(review): the reference value is averaged over every row of
    # metric_list_plot_target_df (no split / method filter) -- confirm that frame
    # is already restricted to the rows intended for this figure.
    ax.hlines(
        xmin=0,
        xmax=baseline_xmax,
        y=metric_list_plot_target_df[metric_name].mean(),
        color=baseline_color,
        label='KernelSHAP',
    )

    ax.set_xlabel("# Training Datapoints (Amortized)")
    ax.set_ylabel(ylabel)

    ax.xaxis.grid(True, which='major', alpha=0.6)
    ax.xaxis.grid(True, which='minor', alpha=0.1)
    ax.set_xlim(left=-1e-3)

    if y_major_step is not None:
        ax.yaxis.set_major_locator(MultipleLocator(y_major_step))
    ax.yaxis.grid(True, which='major', alpha=0.6)
    ax.yaxis.grid(True, which='minor', alpha=0.1)
    ax.set_ylim(*ylim)

    ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
    ax.tick_params(axis='y', which='major', rotation=0)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    ax.set_title(prettify_metric_name(metric_name))


n_panels = len(panel_specs)
fig = plt.figure(figsize=(n_panels * (4.3), 3))
box1 = gridspec.GridSpec(1, n_panels, hspace=0.3)

# One axes per (metric, method_type); each metric gets its own outer grid slot.
axd = {}
for idx1, metric_name in enumerate(panel_specs):
    box2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box1[idx1], wspace=0.8, hspace=0.0)
    for idx2, method_type in enumerate(method_types):
        axd[(metric_name, method_type)] = fig.add_subplot(box2[idx2])

# Two-colour cycle used for explicit colour look-ups below:
# index 0 = amortized explainer, index 1 = KernelSHAP reference.
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=[
    (0.2980392156862745, 0.4470588235294118, 0.6901960784313725),
    (0.8666666666666667, 0.5176470588235295, 0.3215686274509804),
])
amortized_color, kernelshap_color = plt.rcParams['axes.prop_cycle'].by_key()['color'][:2]

for (metric_name, method_type), ax in axd.items():
    _draw_metric_panel(
        ax,
        metric_name=metric_name,
        method_type=method_type,
        split="test",
        baseline_color=kernelshap_color,
        **panel_specs[metric_name],
    )

    if metric_name == "mse_all":
        # Only the first (MSE) panel carries a legend; it is built from proxy
        # lines so that both series are listed.
        legend_handles = [
            Line2D([0], [0], color=amortized_color, lw=1),
            Line2D([0], [0], color=kernelshap_color, lw=1, linestyle="-"),
        ]
        ax.legend(
            legend_handles,
            ['Amortized', 'KernelSHAP'],
            loc='best',
            bbox_to_anchor=(0, 0, 1, 1),
        )

# These run after the figure has been drawn, so they only change the global
# seaborn/matplotlib state seen by later cells.
# NOTE(review): consider moving them to a single setup cell near the top.
sns.set_theme(style='whitegrid')
sns.set_context('paper', font_scale=1.2)

In [None]:
# Persist the current `fig` as both a raster (PNG) and a vector (PDF) file.
plots_dir = "logs/plots"
# savefig does not create missing parent directories, so make sure it exists.
os.makedirs(plots_dir, exist_ok=True)
for ext in ("png", "pdf"):
    fig.savefig(os.path.join(plots_dir, f"shapley_compute_trainsamples_external.{ext}"), bbox_inches='tight')

In [None]:
# Legend frame: rounded box, light grey edge, slightly transparent background.
plt.rcParams.update({
    'legend.fancybox': True,
    'legend.edgecolor': '0.8',
    'legend.framealpha': 0.8,
})

# Colour cycle used by the plots below: four shades taken from the Blues
# colormap for the 512 / 1024 / 2048 / 3072 series, followed by one orange
# entry for the fifth series.
# (An earlier assignment of a different five-colour cycle was overwritten by
# this one before anything was drawn, so it has been dropped.)
plt.rcParams['axes.prop_cycle'] = plt.cycler(
    color=[
        Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136)
        for i in [512, 1024, 2048, 3072]
    ] + [(0.8666666666666667, 0.5176470588235295, 0.3215686274509804)]
)

# One row of four panels, one per metric; each panel is 4.3 in wide and 3 in tall.
panel_metrics = ["MSE_all", "pearsonr_all", "spearmanr_all", "sign_agreement_all"]
panel_method_types = ["KernelSHAP"]

fig = plt.figure(figsize=(4 * 4.3, 3))
box1 = gridspec.GridSpec(1, 4, hspace=0.3)

# axd maps (metric, method_type) -> Axes so the plotting loop can look panels up by key.
axd = {}
for column, metric in enumerate(panel_metrics):
    # Each outer cell holds a 1x1 nested grid, i.e. a single Axes per metric.
    box2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box1[column], wspace=0.8, hspace=0.0)
    for slot, method_type in enumerate(panel_method_types):
        ax = plt.Subplot(fig, box2[slot])
        fig.add_subplot(ax)
        axd[(metric, method_type)] = ax
        

from matplotlib.lines import Line2D
import matplotlib.lines as mlines
from matplotlib.legend_handler import HandlerBase

# Hue levels shared by every panel. The position of a level in this list is
# also its position in plt.rcParams['axes.prop_cycle'], which the hand-built
# legend below relies on to reproduce the plotted colours.
FLOPS_HUE_ORDER = ['512', '1024', '2048', '3072', 'KernelSHAP']
FLOPS_BASELINE = 'KernelSHAP'

# Per-panel settings, keyed by the metric name used in the `axd` keys.
#   y            column of metric_list_plot_df_epoch plotted on the y axis
#   checkpoints  accepted values of "is_best_checkpoint" (NaN counts as "before")
#   ylim         (bottom, top) passed to set_ylim
#   log_y        whether the y axis is switched to a log scale
FLOPS_PANEL_SPECS = {
    "MSE_all": dict(y="mse_all", checkpoints=["before"],
                    ylabel="Error", title='Error',
                    ylim=(0.00008, 0.08), log_y=True),
    "pearsonr_all": dict(y="pearsonr_all", checkpoints=["before", "best"],
                         ylabel="Pearson Correlation", title='Correlation',
                         ylim=(0, 1.01), log_y=False),
    "spearmanr_all": dict(y="spearmanr_all", checkpoints=["before", "best"],
                          ylabel="Spearman Correlation", title='Rank Correlation',
                          ylim=(0, 1.01), log_y=False),
    "sign_agreement_all": dict(y="sign_agreement_all", checkpoints=["before", "best"],
                               ylabel="Sign Agreement", title='Sign Agreement',  # was misspelled 'Sign Agrement'
                               ylim=(0, 1.01), log_y=False),
}


class CustomHandler(HandlerBase):
    """Legend handler that draws a plain line for the baseline series and a
    thin line ending in an 'o' marker for every other series.

    A handle is treated as the baseline when its label is ``FLOPS_BASELINE``.
    Handles coloured with the literal string "black" are also accepted, which
    is how the baseline used to be identified.
    """

    def create_artists(self, legend, orig_handle, x0, y0, width, height, fontsize, trans):
        color = orig_handle.get_color()
        y_mid = y0 + height / 2.
        if orig_handle.get_label() == FLOPS_BASELINE or color == "black":
            return [mlines.Line2D([x0, x0 + width], [y_mid, y_mid],
                                  linestyle='-', color=color)]
        line_o = mlines.Line2D([x0, x0 + width], [y_mid, y_mid],
                               linewidth=0.8, linestyle='-', color=color)
        # Marker sits at the right end of the line.
        marker_o = mlines.Line2D([x0 + width], [y_mid],
                                 linestyle='', color=color, marker='o', markersize=5)
        return [line_o, marker_o]


def draw_flops_panel(ax, panel_df, y, ylabel, title, ylim, log_y=False):
    """Draw one metric-vs-FLOPs panel on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    panel_df : DataFrame with at least the columns "flops", "method_name" and ``y``.
    y : name of the column plotted on the y axis.
    ylabel, title : y-axis label and panel title.
    ylim : (bottom, top) limits for the y axis.
    log_y : if True, the y axis is switched to a log scale after the limits are set.

    Both seaborn legends are removed again; the caller decides which legend, if
    any, the panel carries.
    """
    # One point per method: the mean of `y` at that method's largest "flops" value.
    last_points = panel_df.groupby(["method_name"]).apply(
        lambda rows: rows.groupby("flops")[y].mean().sort_index().reset_index().iloc[-1]
    ).reset_index()

    sns.scatterplot(x="flops", y=y, hue="method_name", hue_order=FLOPS_HUE_ORDER,
                    data=last_points, ax=ax, s=40)
    ax.get_legend().remove()

    sns.lineplot(x="flops", y=y, hue="method_name", hue_order=FLOPS_HUE_ORDER,
                 errorbar=None, alpha=0.8, linewidth=1.5,
                 data=panel_df, ax=ax)

    ax.set_xlabel("FLOPs")
    ax.set_ylabel(ylabel)

    # x axis
    ax.xaxis.grid(True, which='major', alpha=0.6)
    ax.tick_params(axis='x', which='major', rotation=0, labelright=True)
    # Keep this call ahead of set_xscale('log'): ticklabel_format is only
    # accepted while the axis still has its default (scalar) formatter.
    ax.ticklabel_format(axis='x', style='sci', useOffset=True)
    ax.set_xlim(1e+16, 2e+19)
    ax.set_xscale('log')

    # y axis
    ax.yaxis.grid(True, which='major', alpha=0.6)
    ax.tick_params(axis='y', which='major', rotation=0)
    ax.set_ylim(*ylim)
    if log_y:
        ax.set_yscale('log')

    ax.set_title(title)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    ax.get_legend().remove()


def add_flops_legends(ax):
    """Attach the two legends carried by the error panel.

    1. An automatic legend of every labelled artist, kept on the axes through
       ``add_artist`` so the next ``legend`` call does not replace it.
    2. A hand-built two-column legend inside the panel whose handles are drawn
       by ``CustomHandler``.
    """
    # NOTE(review): this legend is anchored at x=-3.0, y=-0.35 relative to the
    # axes, i.e. to the left of and below this panel - confirm that placement
    # is still intended for a figure where this is the first panel.
    leg = ax.legend(loc='center',
                    bbox_to_anchor=(-3.0, -0.35, 3, 0),
                    ncols=4)
    ax.add_artist(leg)

    # Entry order chosen so that, with two columns, the left column reads
    # 512 / 2048 / KernelSHAP and the right column 1024 / 3072.
    legend_order = ['512', '2048', 'KernelSHAP', '1024', '3072']
    cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    # Every entry, including the baseline, takes its colour from the active
    # cycle so the legend matches the plotted series. The baseline used to be
    # hard-coded to 'black', which is not the fifth colour of the cycle.
    custom_lines = [
        Line2D([0], [0], color=cycle_colors[FLOPS_HUE_ORDER.index(name)], lw=1, label=name)
        for name in legend_order
    ]
    ax.legend(custom_lines, legend_order,
              ncols=2,
              loc='upper left',
              handler_map={mlines.Line2D: CustomHandler()},
              bbox_to_anchor=(0, 0.05, 0.5, 0.3),
              columnspacing=0.5)


for metric, panel_spec in FLOPS_PANEL_SPECS.items():
    for method_type in ["KernelSHAP"]:
        plot_key = (metric, method_type)

        # Training-split rows of this method type whose checkpoint state is
        # accepted for the panel; a missing state is treated as "before".
        checkpoint_state = metric_list_plot_df_epoch["is_best_checkpoint"].fillna("before")
        metric_list_plot_df_select = metric_list_plot_df_epoch[
            (metric_list_plot_df_epoch["split"] == "train")
            & (metric_list_plot_df_epoch["method_type"] == method_type)
            & checkpoint_state.isin(panel_spec["checkpoints"])
        ]

        draw_flops_panel(
            axd[plot_key],
            metric_list_plot_df_select,
            y=panel_spec["y"],
            ylabel=panel_spec["ylabel"],
            title=panel_spec["title"],
            ylim=panel_spec["ylim"],
            log_y=panel_spec["log_y"],
        )

        # Only the error panel carries legends.
        if metric == "MSE_all":
            add_flops_legends(axd[plot_key])
# Restore the seaborn whitegrid theme at 'paper' context scale once the figure
# above has been built, so the per-figure rcParams overrides do not carry over.
sns.set_theme(context='paper', style='whitegrid', font_scale=1.2)

In [None]:
# Export the current `fig` as both a raster (PNG) and a vector (PDF) file.
# bbox_inches='tight' crops the canvas to the drawn artists.
for ext in ("png", "pdf"):
    fig.savefig(f"logs/plots/flops_matched_appendix.{ext}", bbox_inches='tight')

In [None]:
# Legend frame: rounded box, light grey edge, slightly transparent background.
plt.rcParams.update({
    'legend.fancybox': True,
    'legend.edgecolor': '0.8',
    'legend.framealpha': 0.8,
})

# Colour cycle used by the plots below: four shades taken from the Blues
# colormap for the 512 / 1024 / 2048 / 3072 series, followed by one orange
# entry for the fifth series.
# (Two earlier assignments of other five-colour cycles were overwritten by this
# one before anything was drawn, so they have been dropped.)
plt.rcParams['axes.prop_cycle'] = plt.cycler(
    color=[
        Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136)
        for i in [512, 1024, 2048, 3072]
    ] + [(0.8666666666666667, 0.5176470588235295, 0.3215686274509804)]
)

# One row of four panels, one per metric; each panel is 4.3 in wide and 3 in tall.
panel_metrics = ["epoch_MSE_all", "epoch_pearsonr_all", "trainsamples_MSE_all", "trainsamples_pearsonr_all"]
panel_method_types = ["KernelSHAP"]

fig = plt.figure(figsize=(4 * 4.3, 3))
box1 = gridspec.GridSpec(1, 4, hspace=0.3)

# axd maps (metric, method_type) -> Axes so the plotting loop can look panels up by key.
axd = {}
for column, metric in enumerate(panel_metrics):
    # Each outer cell holds a 1x1 nested grid, i.e. a single Axes per metric.
    box2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=box1[column], wspace=0.8, hspace=0.0)
    for slot, method_type in enumerate(panel_method_types):
        ax = plt.Subplot(fig, box2[slot])
        fig.add_subplot(ax)
        axd[(metric, method_type)] = ax
        

for idx1, metric in enumerate(["epoch_MSE_all", "epoch_pearsonr_all", "trainsamples_MSE_all", "trainsamples_pearsonr_all"
                              ]):
    for idx2, method_type in enumerate(["KernelSHAP"]):

        plot_key=(metric, method_type)
        
        if metric=="epoch_MSE_all":
            metric_list_plot_df_select=metric_list_plot_df_epoch[(metric_list_plot_df_epoch["split"]=="train")&\
                                                           (metric_list_plot_df_epoch["method_type"]==method_type)&\
                                                           (metric_list_plot_df_epoch["is_best_checkpoint"].fillna("before")=="before")\
                                                          ]

            sns.scatterplot(
            x="flops",
            y="mse_all",
            hue="method_name",
            hue_order=[
                     '512',
                     '1024',
                     '2048',
                     '3072',
                        'KernelSHAP',],      
            data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["mse_all"].mean().sort_index().reset_index().iloc[-1])
                                                                     ).reset_index(),
            ax=axd[plot_key],
                s=40,

            )     
            axd[plot_key].get_legend().remove()
            
#             sns.scatterplot(
#             x="flops",
#             y="mse_all",
#             hue="method_name",
#             hue_order=[
#                      '512',
#                      '1024',
#                      '2048',
#                      '3072',],      
#             data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["mse_all"].mean().sort_index().reset_index().iloc[0])
#                                                                      ).reset_index(),
#             ax=axd[plot_key],
#                 s=40,
#                 style=True,
#                 markers=["X"]

#             )     
#             axd[plot_key].get_legend().remove()

                        
            
            sns.lineplot(
                x="flops",
                y="mse_all",
                hue="method_name",
                hue_order=[
                         '512',
                         '1024',
                         '2048',
                         '3072', "KernelSHAP"],    
                #style="antithetical",
                #palette="tab10",
                errorbar=None,                
                alpha=0.8,            
                linewidth=1.5,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )
            
            
           



            axd[plot_key].set_xlabel("FLOPs") #, fontsize=20
            axd[plot_key].set_ylabel("Error") #fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) # linewidth=2, 
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].ticklabel_format(axis='x',style='sci',useOffset=True)            
            axd[plot_key].set_xlim(1e+16, 2e+19)
            axd[plot_key].set_xscale('log')

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) # linewidth=2, 
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].set_ylim(top=0.08, bottom=0.00008)
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)   #labelsize=20
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_yscale('log')
            
            axd[plot_key].set_title('Error')

            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                
            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            #leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()
#             for line in leg.get_lines():
#                 line.set_linewidth(3.0)

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0)            
            
            leg=axd[plot_key].legend(loc='center', 
                                     bbox_to_anchor=(-3.0, -0.35, 3, 0),
                                     ncols=4,
                                     #bbox_to_anchor=(1.0, 0, 0.5, 1)
                                    )#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
#             import matplotlib.patches as mpatches
#             handles, labels = axd[plot_key].get_legend_handles_labels()
#             empty_handle = mpatches.Patch(color='none', label='Empty Label')
#             labels.append('')
#             leg=axd[plot_key].legend(handles=[empty_handle]+handles, labels=[""]+labels, 
#                                  loc='center', 
#                                  bbox_to_anchor=(-3.0, -0.35, 3, 0),
#                                  ncols=6,)              
            handles, labels = axd[plot_key].get_legend_handles_labels()
            
#             leg=axd[plot_key].legend(handles=[handles[labels.index(i)] for i in ['512', '1024', '2048', '3072']], 
#                                      labels=['512', '1024', '2048', '3072'], 
#                                  loc='upper left', 
#                                  bbox_to_anchor=(0,0,1,1),
#                                  ncols=1,) 
            #leg.set_title("Amortized (# Samples / Point)")
            axd[plot_key].add_artist(leg)
            # Adding the text to the left of the first legend
#             x_offset = -2.6  # Adjust this value as needed to position the text
#             y_offset = -0.25  # Adjust this value as needed for vertical positioning
#             axd[plot_key].text(x_offset, y_offset, "Amortized (# Samples / Point)", transform=axd[plot_key].transAxes, 
#                                verticalalignment='top', horizontalalignment='left')
            
            leg=axd[plot_key].legend(handles=[handles[labels.index(i)] for i in ['512', '2048', 'KernelSHAP', '1024','3072']],  #512 2048, KernelSHAP, 1024, 3072
                                     #labels=['512 Samples', '1024 Samples', '2048 Samples', '3072 Samples','KernelSHAP'], 
                                     labels=['512', '2048', 'KernelSHAP', '1024', '3072'], 
                                 loc='upper left', 
                                 bbox_to_anchor=(0, 0.05, 0.5, 0.3))
#                                  columnspacing=1,
#                                  #loc='best',
#                                  #bbox_to_anchor=(0, 0, 1, 1),                                     
#                                  ncols=2,)
            leg.remove()
        
        


            
            
            from matplotlib.lines import Line2D
            import matplotlib.lines as mlines
            from matplotlib.legend_handler import HandlerBase

            class CustomHandler(HandlerBase):
                def create_artists(self, legend, orig_handle, x0, y0, width, height, fontsize, trans):
                    if orig_handle.get_color()=="black":
                        line_o = mlines.Line2D([x0, x0 + width], [y0 + height/2., y0 + height/2.], 
                                               linestyle='-', color=orig_handle.get_color())                        
                        return [line_o]
                    else:
                        # Create a line with the 'o' marker
                        line_o = mlines.Line2D([x0, x0 + width], [y0 + height/2., y0 + height/2.], 
                                               linewidth=0.8,
                                               linestyle='-', color=orig_handle.get_color())
                        # Create a line with the 'X' marker
                        marker_o = mlines.Line2D([x0 + width], [y0 + height/2.], 
                                               linestyle='', color=orig_handle.get_color(), marker='o', markersize=5)
#                         marker_x = mlines.Line2D([x0], [y0 + height/2.], 
#                                                linestyle='', color=orig_handle.get_color(), marker='X', markersize=5)                    
                        return [line_o, marker_o]            

            custom_lines = [Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][0], lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][2], lw=1),
                            Line2D([0], [0], color='black', lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1], lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][3], lw=1),
                            ]

            axd[plot_key].legend(custom_lines, ['512', '2048','KernelSHAP', '1024', '3072', ], 
                                 ncols=2,
                                 loc='upper left', 
                                 handler_map={mlines.Line2D: CustomHandler()}, 
                                 #bbox_to_anchor=(-0.6, -1.3, 0.5, 1)
                                 bbox_to_anchor=(0, 0.05, 0.5, 0.3),
                                )            

        elif metric=="epoch_pearsonr_all":
            
            metric_list_plot_df_select=metric_list_plot_df_epoch[(metric_list_plot_df_epoch["split"]=="train")&\
                                                           (metric_list_plot_df_epoch["method_type"]==method_type)&\
                                                           (metric_list_plot_df_epoch["is_best_checkpoint"].fillna("before").isin(["before", "best"]))\
                                                          ]
            
            
            sns.scatterplot(
            x="flops",
            y="pearsonr_all",
            hue="method_name",
            hue_order=[
                     '512',
                     '1024',
                     '2048',
                     '3072',
                        'KernelSHAP',],      
            data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["pearsonr_all"].mean().sort_index().reset_index().iloc[-1])
                                                                     ).reset_index(),
            ax=axd[plot_key],
                s=40,

            )     
            axd[plot_key].get_legend().remove()
            
#             sns.scatterplot(
#             x="flops",
#             y="pearsonr_all",
#             hue="method_name",
#             hue_order=[
#                      '512',
#                      '1024',
#                      '2048',
#                      '3072',],      
#             data=metric_list_plot_df_select.groupby(["method_name"]).apply(lambda x: (x.groupby("flops")["pearsonr_all"].mean().sort_index().reset_index().iloc[0])
#                                                                      ).reset_index(),
#             ax=axd[plot_key],
#                 s=40,
#                 style=True,
#                 markers=["X"]

#             )     
#             axd[plot_key].get_legend().remove()            

            sns.lineplot(
                x="flops",
                y="pearsonr_all",
                hue="method_name",
                hue_order=[
                         '512',
                         '1024',
                         '2048',
                         '3072',
                            'KernelSHAP',],                  
                #style="antithetical",
                #palette="tab10",
                errorbar=None,                
                alpha=0.8,            
                linewidth=1.5,
                data=metric_list_plot_df_select,
                ax=axd[plot_key]
            )
            



            axd[plot_key].set_xlabel("FLOPs") #, fontsize=20
            axd[plot_key].set_ylabel("Pearson Correlation") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].xaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].ticklabel_format(axis='x',style='sci',useOffset=True)            
            axd[plot_key].set_xlim(1e+16, 2e+19)
            axd[plot_key].set_xscale('log')

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            #axd[plot_key].yaxis.grid(True, which='minor', linewidth=1, alpha=0.1)
            axd[plot_key].tick_params(axis='y', which='major', rotation=0) #, labelsize=20
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, 1.01)
            #axd[plot_key].set_yscale('log')
            
            axd[plot_key].set_title('Correlation')
            
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
                
            

            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            #leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.5, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            axd[plot_key].get_legend().remove()



        elif metric=="trainsamples_MSE_all":
#             plt.rcParams['axes.prop_cycle']=plt.cycler(color=\
#                 [(0,0,1),
#                  (0,0,0)])        
            plt.rcParams['axes.prop_cycle']=plt.cycler(color=[sns.color_palette("Blues")[i] for i in [4]]+\
                                                       #[(0,0,0)]
                                                       [(0.8666666666666667, 0.5176470588235295, 0.3215686274509804)]
                                                      ) 
    
            plt.rcParams['axes.prop_cycle']=plt.cycler(color=[Blue_scalar_color_mapping(2257, color_map=plt.cm.Blues, data_min=-500, data_max=3136)]+\
                                                       [(0.8666666666666667, 0.5176470588235295, 0.3215686274509804)]
                                                      )     
    
    
            
            
            metric_list_plot_df_select=metric_list_plot_explainer_df[(metric_list_plot_explainer_df["split"]=="train")\
                              &(metric_list_plot_explainer_df["method_type"]==method_type)\
                              &(metric_list_plot_explainer_df["is_best_checkpoint"]=="best")
                             ].groupby(["method_type", "num_train"])\
                                                    [['sample_idx',
                                                    'mse_target',
                                                    'mse_nontarget',
                                                    'mse_all',
                                                    'pearsonr_target',
                                                    'pearsonr_all',
                                                    'pearsonr_all_per_class',
                                                    'spearmanr_target',
                                                    'spearmanr_all',
                                                    'spearmanr_all_per_class',]
                                                    ].mean().reset_index()
            sns.lineplot(
                x="num_train", 
                y="mse_all", 
                # hue="num_subsets",
                marker='o', 
                markersize=6,  
                alpha=0.8,
                color=plt.rcParams['axes.prop_cycle'].by_key()['color'][0],
                #linewidth=3,
                #palette="Set2",
                data=metric_list_plot_df_select,
                ax=axd[plot_key],
            )
            
            
            axd[plot_key].hlines(xmin=0, xmax=5000, 
                                 y=metric_list_plot_target_df[['sample_idx',
                                    'mse_target',
                                    'mse_nontarget',
                                    'mse_all',
                                    'pearsonr_target',
                                    'pearsonr_all',
                                    'pearsonr_all_per_class',
                                    'spearmanr_target',
                                    'spearmanr_all',
                                    'spearmanr_all_per_class']].mean().loc["mse_all"], 
                                 linestyle="-",
                                 #linewidth=3, 
                                 color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1],
                                 label=f'KernelSHAP'
                                )            
            

            axd[plot_key].set_xlabel("# Training Datapoints (Amortized)") #, fontsize=20
            axd[plot_key].set_ylabel("Error") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].xaxis.grid(True, which='minor', alpha=0.1) #linewidth=1, 
            axd[plot_key].set_xlim(left=-1e-3)
            #axd[plot_key].set_xscale("log")

            #axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.005))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].yaxis.grid(True, which='minor', alpha=0.1) #linewidth=1, 
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, .045)
            #axd[plot_key].set_yscale("log")

            axd[plot_key].tick_params(axis='x', which='major', rotation=0,  labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  #, labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)

            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            leg=axd[plot_key].legend(loc='best', bbox_to_anchor=(1, 0, 0.3, 0.5))#, bbox_to_anchor=(0.0, -1.2, 0.5, 1))
            leg.remove()

#             for line in leg.get_lines():
#                 line.set_linewidth(3.0) 
                
            axd[plot_key].set_title(prettify_metric_name("mse_all"))#, fontsize=20)
        
            from matplotlib.lines import Line2D

            custom_lines = [Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][0], lw=1),
                            Line2D([0], [0], color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1], lw=1, linestyle="-")]

            axd[plot_key].legend(custom_lines, ['Amortized', 'KernelSHAP'], 
                                 #ncols=2,
                                 loc='best', 
                                 #bbox_to_anchor=(-0.6, -1.3, 0.5, 1)
                                 bbox_to_anchor=(0, 0, 1, 1)
                                )
            #bbox_to_anchor=(-0.8, -0.32, 3, 0),           


        elif metric=="trainsamples_pearsonr_all":
            metric_list_plot_df_select=metric_list_plot_explainer_df[(metric_list_plot_explainer_df["split"]=="train")\
                              &(metric_list_plot_explainer_df["method_type"]==method_type)\
                              &(metric_list_plot_explainer_df["is_best_checkpoint"]=="best")
                             ].groupby(["method_type", "num_train"])\
                                                    [['sample_idx',
                                                    'mse_target',
                                                    'mse_nontarget',
                                                    'mse_all',
                                                    'pearsonr_target',
                                                    'pearsonr_all',
                                                    'pearsonr_all_per_class',
                                                    'spearmanr_target',
                                                    'spearmanr_all',
                                                    'spearmanr_all_per_class']
                                                    ].mean().reset_index()            
        

            sns.lineplot(
                x="num_train", 
                y="pearsonr_all", 
                # hue="num_subsets",
                marker='o', 
                markersize=6,  
                alpha=0.8,
                #linewidth=3,
                #palette="Set2",
                color=plt.rcParams['axes.prop_cycle'].by_key()['color'][0],
                data=metric_list_plot_df_select,
                ax=axd[plot_key],
            )

            
            axd[plot_key].hlines(xmin=0, xmax=5000, 
                                 y=metric_list_plot_target_df[['sample_idx',
                                    'mse_target',
                                    'mse_nontarget',
                                    'mse_all',
                                    'pearsonr_target',
                                    'pearsonr_all',
                                    'pearsonr_all_per_class',
                                    'spearmanr_target',
                                    'spearmanr_all',
                                    'spearmanr_all_per_class']].mean().loc["pearsonr_all"], 
                                 linestyle="-",                                 
                                 #linewidth=3, 
                                 color=plt.rcParams['axes.prop_cycle'].by_key()['color'][1],
                                 label=f'KernelSHAP'
                                )               

            axd[plot_key].set_xlabel("# Training Datapoints (Amortized)") #fontsize=20
            axd[plot_key].set_ylabel("Pearson Correlation") #, fontsize=20

            # xaxis
            #axd[plot_key].xaxis.set_major_locator(MultipleLocator(500))
            #axd[plot_key].xaxis.set_minor_locator(MultipleLocator(100))            
            axd[plot_key].xaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].xaxis.grid(True, which='minor',alpha=0.1) # linewidth=1, 
            axd[plot_key].set_xlim(left=-1e-3)
            #axd[plot_key].set_xscale("log")

            axd[plot_key].yaxis.set_major_locator(MultipleLocator(0.2))
            #axd[plot_key].yaxis.set_minor_locator(MultipleLocator(0.005))  
            axd[plot_key].yaxis.grid(True, which='major', alpha=0.6) #linewidth=2, 
            axd[plot_key].yaxis.grid(True, which='minor', alpha=0.1) #linewidth=2, 
            # axd[plot_key].set_ylim(0, 3 * metric_list_plot_df_select.groupby(["model_name", "epoch"])["mse_all"].mean().min())
            axd[plot_key].set_ylim(0, 1.01)
            #axd[plot_key].set_yscale("log")

            axd[plot_key].tick_params(axis='x', which='major', rotation=0, labelright=True) #labelsize=20, 
            axd[plot_key].tick_params(axis='y', which='major', rotation=0)  #labelsize=20
            # axd[plot_key].tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)


            axd[plot_key].spines['right'].set_visible(False)
            axd[plot_key].spines['top'].set_visible(False) 

            axd[plot_key].set_title(prettify_metric_name("pearsonr_all"))#, fontsize=20)
        
         #
            
#             for line in leg.get_lines():
#                 line.set_linewidth(3.0)   
# Switch seaborn/matplotlib styling to the 'whitegrid' theme with the 'paper'
# context and fonts scaled by 1.2.
# NOTE(review): these calls come after the plotting code above, so they presumably
# only influence figures drawn afterwards (or on a re-run of this cell) — confirm
# that this ordering is intended.
sns.set_theme(style='whitegrid')
sns.set_context('paper', font_scale=1.2)

In [None]:
# Export the current `fig` twice: once as PNG and once as PDF, both with a tight
# bounding box so that artists placed outside the axes are not cropped.
for plot_path in ("logs/plots/"+f"shapley_compute.png",
                  "logs/plots/"+f"shapley_compute.pdf"):
    fig.savefig(plot_path, bbox_inches='tight')

In [None]:
                # NOTE(review): orphaned fragment — on its own this cell is not valid Python
                # (unexpected indent, dangling trailing comma). It looks like a `palette=`
                # keyword argument lifted out of a plotting call: one colour per row of
                # `metric_list_plot_df_select["num_subsets"]`, produced by
                # `Blue_scalar_color_mapping` (defined elsewhere in the notebook).
                # Confirm whether it belongs back inside a call or should be deleted.
                palette=[Blue_scalar_color_mapping(i, color_map=plt.cm.Blues, data_min=-500, data_max=3136) 
                         for i in metric_list_plot_df_select["num_subsets"]],

# qualitative evaluation plot

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import matplotlib as mpl

def plot_figure(explainer, dataset, sample_idx_list):
    """Plot every sample next to per-class explanation heatmaps from ``explainer``.

    The figure has one row per entry of ``sample_idx_list`` and the columns
    ``image | narrow spacer | one heatmap per distinct label among the samples``.

    Args:
        explainer: Model that is switched to eval mode and called as
            ``explainer(batch, return_loss=False)``; ``output["logits"][0]`` is used
            as the explanation. ``explainer.device`` and
            ``explainer.explainer.config.surrogate_config['id2label']`` (looked up
            with ``str(label)``) are read as well. Each plotted explanation must
            contain 14*14 values because it is reshaped to a 14x14 grid.
        dataset: Indexable collection; ``dataset[i]`` must provide
            ``"pixel_values"`` (a channel-first tensor) and ``"labels"``.
        sample_idx_list: Dataset indices to draw, one figure row each.

    Returns:
        ``(fig, explainer_output)`` — the matplotlib figure and a list holding the
        raw explanation of each sample as a NumPy array, in ``sample_idx_list`` order.

    Side effects:
        Sets ``plt.rcParams["font.size"]`` to 8 and prints the class mapping.
    """
    plt.rcParams["font.size"] = 8
    # NOTE(review): these constants are used to undo the dataset normalisation; they
    # presumably have to match the image transform used upstream — confirm, since the
    # range assertion further down depends on it.
    img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]
    img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]

    # Only the distinct labels present among the requested samples get a column.
    label_choice = np.unique([dataset[sample_idx]["labels"] for sample_idx in sample_idx_list])
    class_list = {idx: label for idx, label in enumerate(label_choice)}
    class_ids = list(class_list.values())
    column_types = ["image"] + ["empty"] + class_ids

    fig = plt.figure(figsize=(1.53*(len(["image"]+class_ids)+0.2*len(["empty"])), 2*len(sample_idx_list)))
    box1 = gridspec.GridSpec(1, len(column_types),
                             wspace=0.06,
                             hspace=0,
                             width_ratios=[1]+[0.2]+[1]*len(class_ids))

    # One axes per (sample, column), addressed as "<sample_idx>_<column>".
    axd = {}
    for col_idx, plot_type in enumerate(column_types):
        box2 = gridspec.GridSpecFromSubplotSpec(len(sample_idx_list), 1,
                                                subplot_spec=box1[col_idx], wspace=0, hspace=0.2)
        for row_idx, sample_idx in enumerate(sample_idx_list):
            box3 = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                    subplot_spec=box2[row_idx], wspace=0, hspace=0)
            ax = plt.Subplot(fig, box3[0])
            fig.add_subplot(ax)
            axd[f"{sample_idx}_{plot_type}"] = ax

    # The spacer column is left blank: no ticks, invisible frame.
    for plot_key in axd.keys():
        if 'empty' in plot_key:
            axd[plot_key].set_xticks([])
            axd[plot_key].set_yticks([])
            for axis in ['top', 'bottom', 'left', 'right']:
                axd[plot_key].spines[axis].set_linewidth(0)
    print("class_list", class_list)

    heatmap_cmap = sns.color_palette("icefire", as_cmap=True)
    id2label = explainer.explainer.config.surrogate_config['id2label']

    explainer_output = []
    for sample_idx in sample_idx_list:
        dataset_item = dataset[sample_idx]

        image = dataset_item["pixel_values"]
        label = dataset_item["labels"]

        # Undo the normalisation and move channels last for imshow.
        image_unnormlized = ((image.numpy() * img_std) + img_mean).transpose(1, 2, 0)
        assert image_unnormlized.min()>0 and image_unnormlized.max()<1
        image_unnormlized_scaled = (image_unnormlized-image_unnormlized.min())/(image_unnormlized.max()-image_unnormlized.min())

        # A single forward pass per sample; its result is shared by all class columns
        # (previously the identical pass was repeated once per column).
        explainer.eval()
        with torch.no_grad():
            explanation = explainer(image.unsqueeze(0).to(explainer.device), return_loss=False)
            explanation = explanation["logits"][0]
        explainer_output.append(explanation.detach().cpu().numpy())

        # Semi-transparent greyscale version of the image, reused behind every heatmap.
        background = mpl.colormaps['Greys'](1-(image_unnormlized.sum(axis=2))/3)
        background[:, :, 3] = 0.5

        for plot_type in column_types:
            if plot_type=="image":
                plot_key = f"{sample_idx}_image"
                axd[plot_key].imshow(image_unnormlized_scaled)
                axd[plot_key].set_title(f"{id2label[str(label)]}", pad=7, zorder=10)
            elif plot_type=="empty":
                # Spacer axes were already styled above; nothing to draw.
                continue
            else:
                plot_key = f"{sample_idx}_{plot_type}"
                # A 2-D explanation is indexed by class; otherwise it is used as is.
                if len(explanation.shape)==2:
                    explanation_class = explanation[plot_type].detach().cpu().numpy()
                else:
                    explanation_class = explanation.detach().cpu().numpy()

                # Upsample the 14x14 grid by a factor of 16 to 224x224 (bilinear).
                explanation_class_expanded = torch.nn.functional.interpolate(
                    torch.Tensor(explanation_class.reshape(1, 1, 14, 14)),
                    scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)

                # Map values symmetrically around 0.5 so that zero attribution lands on
                # the colormap centre. An all-zero map would divide by zero, so it is
                # drawn as a uniform neutral map instead.
                peak = np.max(np.abs(explanation_class_expanded))
                if peak == 0:
                    explanation_normalized = np.full_like(explanation_class_expanded, 0.5)
                else:
                    explanation_normalized = 0.5+(explanation_class_expanded)/peak*0.5
                heatmap = heatmap_cmap(explanation_normalized)
                heatmap[:, :, 3] = 0.6

                axd[plot_key].imshow(background, alpha=0.85)
                axd[plot_key].imshow(heatmap, alpha=0.9)
                axd[plot_key].set_title(f"{id2label[str(plot_type)]}")

            axd[plot_key].set_xticks([])
            axd[plot_key].set_yticks([])
            for axis in ['top', 'bottom', 'left', 'right']:
                axd[plot_key].spines[axis].set_linewidth(1)
    return fig, explainer_output

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from matplotlib import cm

def plot_figure_attribution(dataset, sample_idx_list, attribution_value, attribution_value_key):
    """Plot every sample next to per-class heatmaps of precomputed attributions.

    The figure has one row per entry of ``sample_idx_list`` and the columns
    ``image | narrow spacer | one heatmap per distinct label among the samples``.

    Args:
        dataset: Indexable collection; ``dataset[i]`` must provide
            ``"pixel_values"`` (a channel-first tensor) and ``"labels"``.
        sample_idx_list: Dataset indices to draw, one figure row each.
        attribution_value: Indexed by sample index; each entry provides parallel
            sequences ``"iters"`` and ``"values"`` that are zipped into a lookup
            ``iters[k] -> values[k]``.
        attribution_value_key: The ``"iters"`` entry whose values are plotted. The
            selected array is indexed as ``[:, class_id]`` and must yield 14*14
            values because it is reshaped to a 14x14 grid.

    Returns:
        The matplotlib figure.

    Notes:
        Titles are read from a notebook-level ``id2label`` mapping (keyed by
        ``str(label)``) that is not a parameter of this function.
        Sets ``plt.rcParams["font.size"]`` to 8 and prints the class mapping.
    """
    plt.rcParams["font.size"] = 8
    # NOTE(review): these constants are used to undo the dataset normalisation; they
    # presumably have to match the image transform used upstream — confirm, since the
    # range assertion further down depends on it.
    img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]
    img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis]

    # Only the distinct labels present among the requested samples get a column.
    label_choice = np.unique([dataset[sample_idx]["labels"] for sample_idx in sample_idx_list])
    class_list = {idx: label for idx, label in enumerate(label_choice)}
    class_ids = list(class_list.values())
    column_types = ["image"] + ["empty"] + class_ids

    fig = plt.figure(figsize=(1.53*(len(["image"]+class_ids)+0.2*len(["empty"])), 2*len(sample_idx_list)))
    box1 = gridspec.GridSpec(1, len(column_types),
                             wspace=0.06,
                             hspace=0,
                             width_ratios=[1]+[0.2]+[1]*len(class_ids))

    # One axes per (sample, column), addressed as "<sample_idx>_<column>".
    axd = {}
    for col_idx, plot_type in enumerate(column_types):
        box2 = gridspec.GridSpecFromSubplotSpec(len(sample_idx_list), 1,
                                                subplot_spec=box1[col_idx], wspace=0, hspace=0.2)
        for row_idx, sample_idx in enumerate(sample_idx_list):
            box3 = gridspec.GridSpecFromSubplotSpec(1, 1,
                                                    subplot_spec=box2[row_idx], wspace=0, hspace=0)
            ax = plt.Subplot(fig, box3[0])
            fig.add_subplot(ax)
            axd[f"{sample_idx}_{plot_type}"] = ax

    # The spacer column is left blank: no ticks, invisible frame.
    for plot_key in axd.keys():
        if 'empty' in plot_key:
            axd[plot_key].set_xticks([])
            axd[plot_key].set_yticks([])
            for axis in ['top', 'bottom', 'left', 'right']:
                axd[plot_key].spines[axis].set_linewidth(0)
    print("class_list", class_list)

    heatmap_cmap = sns.color_palette("icefire", as_cmap=True)
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which newer matplotlib
    # releases no longer provide; the 1000-entry resampling is kept.
    grey_cmap = plt.get_cmap('Greys', 1000)

    for sample_idx in sample_idx_list:
        dataset_item = dataset[sample_idx]

        image = dataset_item["pixel_values"]
        label = dataset_item["labels"]

        # Undo the normalisation and move channels last for imshow.
        image_unnormlized = ((image.numpy() * img_std) + img_mean).transpose(1, 2, 0)
        assert image_unnormlized.min()>0 and image_unnormlized.max()<1
        image_unnormlized_scaled = (image_unnormlized-image_unnormlized.min())/(image_unnormlized.max()-image_unnormlized.min())

        # Resolve the requested attribution once per sample instead of once per column.
        values_by_iter = {n_samples: values
                          for n_samples, values in zip(attribution_value[sample_idx]["iters"],
                                                       attribution_value[sample_idx]["values"])}
        sample_attribution = values_by_iter[attribution_value_key]

        # Semi-transparent greyscale version of the image, reused behind every heatmap.
        background = grey_cmap(1-(image_unnormlized.sum(axis=2))/3)
        background[:, :, 3] = 0.5

        for plot_type in column_types:
            if plot_type=="image":
                plot_key = f"{sample_idx}_image"
                axd[plot_key].imshow(image_unnormlized_scaled)
                axd[plot_key].set_title(f"{id2label[str(label)]}", pad=7, zorder=10)
            elif plot_type=="empty":
                # Spacer axes were already styled above; nothing to draw.
                continue
            else:
                plot_key = f"{sample_idx}_{plot_type}"
                explanation_class = sample_attribution[:, plot_type]

                # Upsample the 14x14 grid by a factor of 16 to 224x224 (bilinear).
                explanation_class_expanded = torch.nn.functional.interpolate(
                    torch.Tensor(explanation_class.reshape(1, 1, 14, 14)),
                    scale_factor=16, align_corners=False, mode='bilinear').numpy().reshape(224, 224)

                # Map values symmetrically around 0.5 so that zero attribution lands on
                # the colormap centre. An all-zero map would divide by zero, so it is
                # drawn as a uniform neutral map instead.
                peak = np.max(np.abs(explanation_class_expanded))
                if peak == 0:
                    explanation_normalized = np.full_like(explanation_class_expanded, 0.5)
                else:
                    explanation_normalized = 0.5+(explanation_class_expanded)/peak*0.5
                heatmap = heatmap_cmap(explanation_normalized)
                heatmap[:, :, 3] = 0.6

                axd[plot_key].imshow(background, alpha=0.85)
                axd[plot_key].imshow(heatmap, alpha=0.9)
                axd[plot_key].set_title(f"{id2label[str(plot_type)]}")

            axd[plot_key].set_xticks([])
            axd[plot_key].set_yticks([])
            for axis in ['top', 'bottom', 'left', 'right']:
                axd[plot_key].spines[axis].set_linewidth(1)
    return fig

In [None]:
import matplotlib as mpl
# Render text with matplotlib's built-in engine rather than an external LaTeX run.
mpl.rc('text', usetex=False)
# `text.latex.preamble` is a string-valued rcParam: current matplotlib rejects a
# list here, while a plain string is accepted by old and new releases alike.
# The preamble is only consulted when `text.usetex` is enabled.
mpl.rcParams['text.latex.preamble']=r"\usepackage{amsmath}"

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from matplotlib import cm

def _select_attribution(attribution_record, n_samples_key, label):
    """Return the attribution column for ``label`` recorded at ``n_samples_key``.

    ``attribution_record["iters"]`` and ``attribution_record["values"]`` are
    parallel sequences; the entry of ``values`` whose ``iters`` counterpart
    equals ``n_samples_key`` is selected and sliced with ``[:, label]``.

    Raises:
        KeyError: if ``n_samples_key`` does not occur in ``iters``.
    """
    values_by_n_samples = dict(zip(attribution_record["iters"], attribution_record["values"]))
    return values_by_n_samples[n_samples_key][:, label]


def _draw_attribution_overlay(ax, image_unnormalized, patch_attribution):
    """Draw a greyscale image with a diverging attribution heatmap on top of it.

    ``patch_attribution`` must hold 196 values; it is reshaped to a 14x14 grid
    and bilinearly upsampled by a factor of 16 to 224x224 before colouring.
    ``image_unnormalized`` is a channel-last array whose channel mean is used
    as the grey background.
    # assumes image_unnormalized is 224x224 so it lines up with the heatmap — TODO confirm
    """
    patch_grid = torch.Tensor(patch_attribution.reshape(1, 1, 14, 14))
    heat = torch.nn.functional.interpolate(
        patch_grid, scale_factor=16, align_corners=False, mode='bilinear'
    ).numpy().reshape(224, 224)

    # Map [-max|heat|, +max|heat|] onto [0, 1] so that zero attribution sits at
    # the centre of the diverging colormap.
    max_abs = np.max(np.abs(heat))
    if max_abs == 0:
        # An all-zero map would divide by zero; show it as uniformly neutral.
        heat_normalized = np.full_like(heat, 0.5)
    else:
        heat_normalized = 0.5 + heat / max_abs * 0.5
    heatmap_rgba = sns.color_palette("icefire", as_cmap=True)(heat_normalized)
    heatmap_rgba[:, :, 3] = 0.6

    grey = image_unnormalized.sum(axis=2) / 3
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in recent
    # matplotlib releases; both return the same resampled colormap.
    grey_rgba = plt.get_cmap('Greys', 1000)(1 - grey)
    grey_rgba[:, :, 3] = 0.5

    ax.imshow(grey_rgba, alpha=0.85)
    ax.imshow(heatmap_rgba, alpha=0.9)


def plot_target_prediction_groundtruth(dataset, 
                                       sample_idx_list,
                                       attribution_value,
                                       attribution_value_key,
                                       explainer,
                                       attribution_value_groundtruth,
                                       attribution_value_key_groundtruth                                 
                                      
                                      ):
    """Plot image / noisy target / explainer prediction / ground truth per sample.

    One row is drawn for every entry of ``sample_idx_list`` with four image
    columns (plus a thin empty spacer column after the first one).

    Args:
        dataset: indexable by sample index; each item provides ``"pixel_values"``
            (a channel-first tensor) and ``"labels"``.
        sample_idx_list: sample indices to plot, one figure row each.
        attribution_value: mapping ``sample_idx -> {"iters": [...], "values": [...]}``
            used for the "Noisy label" column.
        attribution_value_key: entry of ``"iters"`` selecting which noisy estimate to show.
        explainer: model called as ``explainer(batch, return_loss=False)`` whose
            output has a ``"logits"`` entry; must expose ``eval()`` and ``device``.
        attribution_value_groundtruth: same layout as ``attribution_value``; used
            for the "Ground Truth" column.
        attribution_value_key_groundtruth: entry of ``"iters"`` selecting the
            ground-truth estimate.

    Returns:
        The matplotlib figure.

    Raises:
        KeyError: if a requested key is missing from the corresponding ``"iters"``.
        AssertionError: if the un-normalised image leaves the open interval (0, 1).
    """
    # NOTE(review): these look like CIFAR-10 channel statistics — confirm they
    # match the normalisation applied when `dataset` was built.
    img_mean = np.array([0.4914, 0.4822, 0.4465])[:, np.newaxis, np.newaxis]
    img_std = np.array([0.2023, 0.1994, 0.2010])[:, np.newaxis, np.newaxis] 

    class_names = {'0': 'Tench',
                   '1': 'English springer',
                   '2': 'Cassette player',
                   '3': 'Chain saw',
                   '4': 'Church',
                   '5': 'French horn',
                   '6': 'Garbage truck',
                   '7': 'Gas pump',
                   '8': 'Golf ball',
                   '9': 'Parachute'}
    column_titles = {"target": "Noisy label\n"+r'$\tilde{a}(b)$',
                     "prediction": "Prediction\n"+r'$a(b;\theta)$',
                     "groundtruth": "Ground Truth\n"+r'$a(b)$'}

    image_columns = ["image", "target", "prediction", "groundtruth"]
    column_types = ["image", "empty", "target", "prediction", "groundtruth"]
    n_rows = len(sample_idx_list)

    fig = plt.figure(figsize=(1.53*(len(image_columns)+0.05*len(["empty"])), 
                              1.53*n_rows + 0.06*(n_rows-1)))

    outer_grid = gridspec.GridSpec(1, len(column_types), 
                                   wspace=0.06, 
                                   hspace=0,
                                   width_ratios=[1]+[0.05]+[1, 1, 1])  

    # One axis per (sample, column), keyed as "<sample_idx>_<column type>".
    axd = {}
    for col_idx, plot_type in enumerate(column_types):
        column_grid = gridspec.GridSpecFromSubplotSpec(n_rows, 1, 
                                                       subplot_spec=outer_grid[col_idx], wspace=0, hspace=0.0)
        for row_idx, sample_idx in enumerate(sample_idx_list):
            axd[f"{sample_idx}_{plot_type}"] = fig.add_subplot(column_grid[row_idx])

    # The spacer column is made invisible: no ticks and zero-width spines.
    for plot_key, ax in axd.items():
        if 'empty' in plot_key:
            ax.set_xticks([])
            ax.set_yticks([])
            for side in ['top', 'bottom', 'left', 'right']:
                ax.spines[side].set_linewidth(0) 

    for row_idx, sample_idx in enumerate(sample_idx_list):
        dataset_item = dataset[sample_idx]
        image = dataset_item["pixel_values"]
        label = dataset_item["labels"]
        is_first_row = row_idx == 0  # column titles are only drawn on the top row

        image_unnormlized = ((image.numpy() * img_std) + img_mean).transpose(1, 2, 0)
        assert image_unnormlized.min() > 0 and image_unnormlized.max() < 1, \
            "un-normalised image is expected to lie strictly inside (0, 1)"
        image_unnormlized_scaled = (image_unnormlized-image_unnormlized.min())/(image_unnormlized.max()-image_unnormlized.min())

        # Column 1: the input image, labelled with its class name.
        image_ax = axd[f"{sample_idx}_image"]
        image_ax.imshow(image_unnormlized_scaled)
        if is_first_row:
            image_ax.set_title("Context\n$b$", pad=7, zorder=10)
        image_ax.set_ylabel(class_names[str(label)])

        # Column 2: the noisy attribution estimate.
        target_attribution = _select_attribution(
            attribution_value[sample_idx], attribution_value_key, label)

        # Column 3: the explainer's own output for this image.
        explainer.eval()
        with torch.no_grad():
            explanation = explainer(image.unsqueeze(0).to(explainer.device), return_loss=False)
            explanation = explanation["logits"][0]
        if len(explanation.shape) == 2:
            predicted_attribution = explanation[label].detach().cpu().numpy()
        else:
            predicted_attribution = explanation.detach().cpu().numpy()

        # Column 4: the reference attribution. Both "iters" and "values" come from
        # attribution_value_groundtruth (the values were previously taken from
        # attribution_value, mixing two sources).
        groundtruth_attribution = _select_attribution(
            attribution_value_groundtruth[sample_idx], attribution_value_key_groundtruth, label)

        attribution_by_column = {"target": target_attribution,
                                 "prediction": predicted_attribution,
                                 "groundtruth": groundtruth_attribution}
        for plot_type, patch_attribution in attribution_by_column.items():
            ax = axd[f"{sample_idx}_{plot_type}"]
            _draw_attribution_overlay(ax, image_unnormlized, patch_attribution)
            if is_first_row:
                ax.set_title(column_titles[plot_type])

        # Uniform frame for every image-bearing axis of this row.
        for plot_type in image_columns:
            ax = axd[f"{sample_idx}_{plot_type}"]
            ax.set_xticks([])
            ax.set_yticks([])
            for side in ['top', 'bottom', 'left', 'right']:
                ax.spines[side].set_linewidth(1)     
    return fig

In [None]:
# Qualitative comparison figure on training samples.
# The noisy target and the ground truth are read from the same loaded run; they differ
# only in the "iters" key that is selected (512 vs. 1000000).
fig=plot_target_prediction_groundtruth(
    dataset=dataset_explainer["train"],
    sample_idx_list=[774, 3772, 3418, 7132, 2183, 683, 6310], # 6310
    attribution_value=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train'],
    attribution_value_key=512,
    explainer=regexplainer,
    attribution_value_groundtruth=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train'],
    attribution_value_key_groundtruth=1000000,
)

In [None]:
# Create the output directory first so savefig does not fail when it is missing.
os.makedirs("logs/plots", exist_ok=True)
# Save the current `fig` both as a raster (PNG) and as a vector (PDF) file.
fig.savefig("logs/plots/shapley_qualitative.png", bbox_inches='tight')
fig.savefig("logs/plots/shapley_qualitative.pdf", bbox_inches='tight')

In [None]:
num_subsets=512

model_path=f"logs/vitbase_imagenette_shapley_regexplainer_upfront_{num_subsets}"

# Checkpoint directories ordered by the integer after the last '-' in the path.
checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))

# Resolve the best checkpoint directory once and use it for both the weights and the
# trainer state. Previously the trainer state was read from `checkpoint_path`, a name
# this cell never assigned, so it came from whichever cell happened to run before.
best_checkpoint_step=int(get_best_model_checkpoint(model_path).split('-')[-1])
checkpoint_path=model_path+f"/checkpoint-{best_checkpoint_step}"

checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
with open(checkpoint_path+"/trainer_state.json") as f:
    checkpoint_trainer_state = json.load(f)

regexplainer.load_state_dict(checkpoint_state_dict)

In [None]:
shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train'][6440]["iters"][-1]



In [None]:
banzhaf_loaded_dict['logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train']

In [None]:
shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train'].keys()

In [None]:
plot_figure_attribution

In [None]:
plot_figure_shapley

# banzhaf

In [None]:
fig=plot_figure_attribution(
            dataset=dataset_explainer["train"], 
            sample_idx_list=[1087, 1076,
                             4354, 4513,
                             7065, 6673,
                             2523, 2210,],
            attribution_value=banzhaf_loaded_dict['logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_long/extract_output/train'],
            attribution_value_key=100,
)

In [None]:
plt.hist(banzhaf_loaded_dict['logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train'][4513]\
["values"][-1][:,:8].flatten())

In [None]:
id2label

In [None]:
fig=plot_figure_attribution(
            dataset=dataset_explainer["train"], 
            sample_idx_list=[1087, 1076,
                             4354, 4513,
                             7065, 6673,
                             2523, 2210,],
            attribution_value=banzhaf_loaded_dict['logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling/extract_output/train'],
            attribution_value_key=100,
)

In [None]:
# For each `..._banzhaf_regexplainer_upfront_{num_subsets}` run: load one checkpoint into
# `regexplainer` and plot its output for a fixed list of training samples.
for num_subsets in [100, 500]:
    model_path=f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_banzhaf_regexplainer_upfront_{num_subsets}"
    # NOTE(review): trainer_state is loaded here but not used in this cell.
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    # Checkpoint directories ordered by the integer after the last '-' in the path.
    checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
    
    # One-element slice: only the checkpoint at position 89 is loaded. If the run has fewer
    # than 90 checkpoints the slice is empty and regexplainer keeps its previous weights.
    for checkpoint_path in checkpoint_path_list[89:89+1]:
        print(checkpoint_path)
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        regexplainer.load_state_dict(checkpoint_state_dict)
        
        
    # plot_figure is unpacked into (figure, explainer outputs) here.
    fig, explainer_output=plot_figure(explainer=regexplainer, 
                dataset=dataset_explainer["train"], 
                sample_idx_list=[1087, 1076,
                             4354, 4513,
                             7065, 6673,
                             2523, 2210,])
    fig.suptitle(f"Reg-AO {num_subsets}") 

In [None]:
fig=plot_figure_attribution(
            dataset=dataset_explainer["test"], 
            sample_idx_list=[0, 10, 11, 17, 18],
            attribution_value=banzhaf_loaded_dict['logs/vitbase_imagenette_surrogate_banzhaf_eval_test_sampling_long/extract_output/test'],
            attribution_value_key=100,
)

In [None]:
# For each `..._banzhaf_regexplainer_upfront_{num_subsets}` run: load the checkpoint recorded
# as "best_model_checkpoint" in trainer_state.json and plot test samples 0, 10 and 11.
for num_subsets in [100, 500]:
    model_path=f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_banzhaf_regexplainer_upfront_{num_subsets}"
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    # NOTE(review): checkpoint_path_list is computed but not used in this cell.
    checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
    
    # Single-element loop over the best checkpoint path stored by the trainer.
    for checkpoint_path in [trainer_state["best_model_checkpoint"]]:
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        regexplainer.load_state_dict(checkpoint_state_dict)
        
        
    # plot_figure is unpacked into (figure, explainer outputs) here.
    fig, explainer_output=plot_figure(explainer=regexplainer, 
                dataset=dataset_explainer["test"], 
                sample_idx_list=[0, 10, 11])
    fig.suptitle(f"Reg-AO {num_subsets}") 

In [None]:
# Print the "best_model_checkpoint" entry of trainer_state.json for each run.
for num_subsets in [100, 500]:
    model_path=f"/sdata/chanwkim/xai-amortization/logs_0901/vitbase_imagenette_banzhaf_regexplainer_upfront_{num_subsets}"
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    # NOTE(review): checkpoint_path_list is computed but not used in this cell.
    checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
    
    print([trainer_state["best_model_checkpoint"]])

In [None]:
banzhaf_loaded_dict['logs/vitbase_imagenette_surrogate_banzhaf_eval_train_sampling_antithetical/extract_output/train'].keys()

# other

In [None]:
import glob
# Load the final weights of four `..._explainer_regression_<N>` runs in turn and plot each.
for model_path_reg in ['logs/vitbase_imagenette_explainer_regression_0',
                       'logs/vitbase_imagenette_explainer_regression_512',                       
                       'logs/vitbase_imagenette_explainer_regression_1024',
                       'logs/vitbase_imagenette_explainer_regression_1536']:
    # The figure title shows the numeric suffix of the run directory plus 512.
    num_eval=int(model_path_reg.split('_')[-1])+512
    state_dict = torch.load(f"{model_path_reg}/pytorch_model.bin", map_location="cpu")
    regexplainer.load_state_dict(state_dict)
    # NOTE(review): this cell treats the return value of plot_figure as a figure and reads
    # dataset["test_explainer"], while other cells unpack (fig, explainer_output) and use
    # dataset_explainer[...] — confirm which definitions are active.
    fig=plot_figure(explainer=regexplainer, 
                dataset=dataset["test_explainer"], 
                sample_idx_list=[0,  10, 20, 30, 40, 50, 60, 70])
    fig.suptitle(f"Reg-AO {num_eval}")  

In [None]:
checkpoint_path_list

In [None]:
fig=plot_figure_shapley(
            dataset=dataset_explainer["train"], 
            sample_idx_list=[4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639],
            shapley_value=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train'],
            shapley_value_key=1000000,
)

In [None]:
fig=plot_figure_shapley(
            dataset=dataset_explainer["test"], 
            sample_idx_list=[27, 50, 62, 15, 68, 86, 49, 84],
            shapley_value=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test'],
            shapley_value_key=1000000,
)

In [None]:
# Load one checkpoint of the `..._regexplainer_SGD_antithetical_upfront_9986` run, plot the
# explainer output for eight training samples, then plot the stored attribution values
# (key 9986) for the same samples.
for num_subsets in [9986]:
    model_path=f"logs/vitbase_imagenette_shapley_regexplainer_SGD_antithetical_upfront_{num_subsets}"
    # NOTE(review): trainer_state is loaded here but not used in this cell.
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    # Checkpoint directories ordered by the integer after the last '-' in the path.
    checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
    
    # One-element slice: only the checkpoint at position 20 is loaded; an empty slice
    # leaves regexplainer with its previous weights.
    for checkpoint_path in tqdm(checkpoint_path_list[20:20+1]):
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        regexplainer.load_state_dict(checkpoint_state_dict)
        
        
    # plot_figure is unpacked into (figure, explainer outputs) here.
    fig, explainer_output=plot_figure(explainer=regexplainer, 
                dataset=dataset_explainer["train"], 
                sample_idx_list=[4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639])
    fig.suptitle(f"Reg-AO {num_subsets}") 
    
    # `fig` is rebound to the second figure; both figures get the same suptitle.
    fig=plot_figure_shapley(
                dataset=dataset_explainer["train"], 
                sample_idx_list=[4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639],
                shapley_value=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train'],
                shapley_value_key=9986,
    )
    fig.suptitle(f"Reg-AO {num_subsets}")      

In [None]:
shapley_loaded_dict.keys()

In [None]:
# Load one checkpoint of the `..._upfront_2330_numtrain_250` run and plot explainer output
# and stored attribution values for eight training samples.
# The loop variable is the empty string and only appears in the figure titles.
for num_subsets in [""]:
    model_path=f"logs/vitbase_imagenette_shapley_regexplainer_upfront_2330_numtrain_250"
    # NOTE(review): trainer_state is loaded here but not used in this cell.
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    # Checkpoint directories ordered by the integer after the last '-' in the path.
    checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
    
    # One-element slice: only the checkpoint at position 46 is loaded; an empty slice
    # leaves regexplainer with its previous weights.
    for checkpoint_path in tqdm(checkpoint_path_list[46:46+1]):
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        regexplainer.load_state_dict(checkpoint_state_dict)
        
        
    # plot_figure is unpacked into (figure, explainer outputs) here.
    fig, explainer_output=plot_figure(explainer=regexplainer, 
                dataset=dataset_explainer["train"], 
                sample_idx_list=[4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639])
    fig.suptitle(f"Reg-AO {num_subsets}") 
    
    # Each stored entry is indexed with [0] before being handed to plot_figure_shapley.
    fig=plot_figure_shapley(
                dataset=dataset_explainer["train"], 
                sample_idx_list=[4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639],
                shapley_value={key:value[0] for key,value in shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train_2440'].items()},
                shapley_value_key=2440,
    )
    fig.suptitle(f"Reg-AO {num_subsets}")      

In [None]:
!ls logs/vitbase_imagenette_shapley_regexplainer_upfront_2330_numtrain_1000/trainer_state.json

In [None]:
!cat logs/vitbase_imagenette_shapley_regexplainer_upfront_2330_numtrain_1000/checkpoint-512/trainer_state.json

In [None]:
# Load one checkpoint of the `..._upfront_2330_numtrain_1000` run and plot explainer output
# and stored attribution values for eight training samples.
# The loop variable is the empty string and only appears in the figure titles.
for num_subsets in [""]:
    model_path=f"logs/vitbase_imagenette_shapley_regexplainer_upfront_2330_numtrain_1000"
    # NOTE(review): trainer_state is loaded here but not used in this cell.
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    # Checkpoint directories ordered by the integer after the last '-' in the path.
    checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
    
    # One-element slice: only the checkpoint at position 32 is loaded; an empty slice
    # leaves regexplainer with its previous weights.
    for checkpoint_path in tqdm(checkpoint_path_list[32:32+1]):
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        regexplainer.load_state_dict(checkpoint_state_dict)
        
        
    # plot_figure is unpacked into (figure, explainer outputs) here.
    fig, explainer_output=plot_figure(explainer=regexplainer, 
                dataset=dataset_explainer["train"], 
                sample_idx_list=[4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639])
    fig.suptitle(f"Reg-AO {num_subsets}") 
    
    # Each stored entry is indexed with [0] before being handed to plot_figure_shapley.
    fig=plot_figure_shapley(
                dataset=dataset_explainer["train"], 
                sample_idx_list=[4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639],
                shapley_value={key:value[0] for key,value in shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train_2440'].items()},
                shapley_value_key=2440,
    )
    fig.suptitle(f"Reg-AO {num_subsets}")      

In [None]:
# Scatter training targets (x) against explainer predictions (y), one point cloud per sample.
# `explainer_output` is not defined in this cell; it is whatever an earlier
# `fig, explainer_output = plot_figure(...)` cell left behind, indexed by position `i`.
for i, sample_idx in enumerate([4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639]):
    # Last stored entry of "values" from the regression_antithetical run.
    ground_truth=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train'][sample_idx]["values"][-1]
    
    # Last stored entry of "values" from the SGD_antithetical run.
    targets=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train'][sample_idx]["values"][-1]
    
    predictions=explainer_output
    
    # The compared column is the argmax of ground_truth summed over axis 0. Note the
    # different layouts: targets is indexed [:, c] while predictions[i] is indexed [c, :].
    plt.scatter(
        targets[:, np.argmax(ground_truth.sum(axis=0))],
        predictions[i][np.argmax(ground_truth.sum(axis=0)),:]
    )
    

In [None]:
# Scatter training targets (x) against ground-truth values (y), one point cloud per sample.
for i, sample_idx in enumerate([4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639]):
    # Last stored entry of "values" from the regression_antithetical run.
    ground_truth=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train'][sample_idx]["values"][-1]
    
    # Last stored entry of "values" from the SGD_antithetical run.
    targets=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train'][sample_idx]["values"][-1]
    
    # NOTE(review): predictions is assigned but not used in this cell.
    predictions=explainer_output
    

    # Both series use the column given by the argmax of ground_truth summed over axis 0.
    plt.scatter(
        targets[:, np.argmax(ground_truth.sum(axis=0))],
        ground_truth[:, np.argmax(ground_truth.sum(axis=0))]
    )    

In [None]:
# Plot explainer outputs per sample, ordered by ascending ground-truth value.
# `explainer_output` comes from an earlier `fig, explainer_output = plot_figure(...)` cell.
for i, sample_idx in enumerate([4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639]):
    # Last stored entry of "values" from the regression_antithetical run.
    ground_truth=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train'][sample_idx]["values"][-1]
    
    # NOTE(review): targets is assigned but not used in this cell.
    targets=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train'][sample_idx]["values"][-1]
    
    predictions=explainer_output
    

    # x is a fixed range of 196 positions; y is the prediction row for the argmax column of
    # ground_truth.sum(axis=0), reordered by the argsort of the matching ground-truth column.
    plt.scatter(
        np.arange(196),
        predictions[i][np.argmax(ground_truth.sum(axis=0)),:][np.argsort(ground_truth[:, np.argmax(ground_truth.sum(axis=0))])],
        alpha=0.8
    )    
    plt.title('explainer output')    

In [None]:
# Plot training targets per sample, ordered by ascending ground-truth value.
for i, sample_idx in enumerate([4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639]):
    # Last stored entry of "values" from the regression_antithetical run.
    ground_truth=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train'][sample_idx]["values"][-1]
    
    # Last stored entry of "values" from the SGD_antithetical run.
    targets=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train'][sample_idx]["values"][-1]
    
    # NOTE(review): predictions is assigned but not used in this cell.
    predictions=explainer_output
    

    # x is a fixed range of 196 positions; y is the target column for the argmax of
    # ground_truth.sum(axis=0), reordered by the argsort of the matching ground-truth column.
    plt.scatter(
        np.arange(196),
        targets[:, np.argmax(ground_truth.sum(axis=0))][np.argsort(ground_truth[:, np.argmax(ground_truth.sum(axis=0))])],
                alpha=0.8
    )    
    plt.title('training targets')

In [None]:
# Plot the ground-truth values per sample in ascending order.
for i, sample_idx in enumerate([4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639]):
    # Last stored entry of "values" from the regression_antithetical run.
    ground_truth=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_regression_antithetical/extract_output/train'][sample_idx]["values"][-1]
    
    # NOTE(review): targets and predictions are assigned but not used in this cell.
    targets=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train'][sample_idx]["values"][-1]
    
    predictions=explainer_output
    

    # x is a fixed range of 196 positions; y is the ground-truth column for the argmax of
    # ground_truth.sum(axis=0), sorted by its own argsort.
    plt.scatter(
        np.arange(196),
        ground_truth[:, np.argmax(ground_truth.sum(axis=0))][np.argsort(ground_truth[:, np.argmax(ground_truth.sum(axis=0))])],
    )   
    plt.title('ground truth')

In [None]:
targets

In [None]:
sample_idx_list=[4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639]
# No trailing comma here: a trailing comma would bind a 1-tuple wrapping the dict instead
# of the loaded dict itself.
shapley_value=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train']

In [None]:
explainer_output[0].sum(axis=1)

In [None]:
shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train_SGD_antithetical/extract_output/train'][4832]["values"][-1]\
.sum(axis=0)

In [None]:
shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test'].keys()

In [None]:
plot_figure??

In [None]:
# Load one checkpoint of the `..._upfront_2330_numtrain_250` run and plot eight test samples.
# The loop variable is the empty string and only appears in the figure title.
for num_subsets in [""]:
    model_path=f"logs/vitbase_imagenette_shapley_regexplainer_upfront_2330_numtrain_250"
    # NOTE(review): trainer_state is loaded here but not used in this cell.
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    # Checkpoint directories ordered by the integer after the last '-' in the path.
    checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
    
    # One-element slice: only the checkpoint at position 46 is loaded; an empty slice
    # leaves regexplainer with its previous weights.
    for checkpoint_path in tqdm(checkpoint_path_list[46:46+1]):
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        regexplainer.load_state_dict(checkpoint_state_dict)
        
        
    # NOTE(review): the return value of plot_figure is used as a figure here, while other
    # cells unpack it as (fig, explainer_output) — confirm which definition is active.
    fig=plot_figure(explainer=regexplainer, 
                dataset=dataset_explainer["test"], 
                sample_idx_list=[27, 50, 62, 15, 68, 86, 49, 84])
    fig.suptitle(f"Reg-AO {num_subsets}")          

In [None]:
!cat ./logs/vitbase_imagenette_shapley_regexplainer_upfront_2330_numtrain_5000/checkpoint-1099/trainer_state.json

In [None]:
# Load one checkpoint of the `..._upfront_2330_numtrain_5000` run and plot eight test samples.
# The loop variable is the empty string and only appears in the figure title.
for num_subsets in [""]:
    model_path=f"logs/vitbase_imagenette_shapley_regexplainer_upfront_2330_numtrain_5000"
    # NOTE(review): trainer_state is loaded here but not used in this cell.
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    # Checkpoint directories ordered by the integer after the last '-' in the path.
    checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
    
    # One-element slice: only the checkpoint at position 14 is loaded; an empty slice
    # leaves regexplainer with its previous weights.
    for checkpoint_path in tqdm(checkpoint_path_list[14:14+1]):
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        regexplainer.load_state_dict(checkpoint_state_dict)
        
        
    # NOTE(review): the return value of plot_figure is used as a figure here, while other
    # cells unpack it as (fig, explainer_output) — confirm which definition is active.
    fig=plot_figure(explainer=regexplainer, 
                dataset=dataset_explainer["test"], 
                sample_idx_list=[27, 50, 62, 15, 68, 86, 49, 84])
    fig.suptitle(f"Reg-AO {num_subsets}")          

In [None]:
# Load one checkpoint of the `..._upfront_2330_numtrain_1000` run and plot eight test samples.
# The loop variable is the empty string and only appears in the figure title.
for num_subsets in [""]:
    model_path=f"logs/vitbase_imagenette_shapley_regexplainer_upfront_2330_numtrain_1000"
    # NOTE(review): trainer_state is loaded here but not used in this cell.
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    # Checkpoint directories ordered by the integer after the last '-' in the path.
    checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
    
    # One-element slice: only the checkpoint at position 32 is loaded; an empty slice
    # leaves regexplainer with its previous weights.
    for checkpoint_path in tqdm(checkpoint_path_list[32:32+1]):
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        regexplainer.load_state_dict(checkpoint_state_dict)
        
        
    # NOTE(review): the return value of plot_figure is used as a figure here, while other
    # cells unpack it as (fig, explainer_output) — confirm which definition is active.
    fig=plot_figure(explainer=regexplainer, 
                dataset=dataset_explainer["test"], 
                sample_idx_list=[27, 50, 62, 15, 68, 86, 49, 84])
    fig.suptitle(f"Reg-AO {num_subsets}")          

In [None]:
# Load one checkpoint of the `..._regexplainer_SGD_antithetical_upfront_9986` run and plot
# eight test samples.
for num_subsets in [9986]:
    model_path=f"logs/vitbase_imagenette_shapley_regexplainer_SGD_antithetical_upfront_{num_subsets}"
    # NOTE(review): trainer_state is loaded here but not used in this cell.
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    # Checkpoint directories ordered by the integer after the last '-' in the path.
    checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
    
    # One-element slice: only the checkpoint at position 20 is loaded; an empty slice
    # leaves regexplainer with its previous weights.
    for checkpoint_path in tqdm(checkpoint_path_list[20:20+1]):
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        regexplainer.load_state_dict(checkpoint_state_dict)
        
        
    # NOTE(review): the return value of plot_figure is used as a figure here, while other
    # cells unpack it as (fig, explainer_output) — confirm which definition is active.
    fig=plot_figure(explainer=regexplainer, 
                dataset=dataset_explainer["test"], 
                sample_idx_list=[27, 50, 62, 15, 68, 86, 49, 84])
    fig.suptitle(f"Reg-AO {num_subsets}")          

In [None]:
list(shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_test_regression_antithetical/extract_output/test'].keys())[:8]







### kernelshap

In [None]:
# Load the best checkpoint of the `..._shapley_regexplainer_upfront_512` run, plot the
# explainer output for eight training samples, then plot the stored attribution values
# (key 512) for the same samples.
for num_subsets in [512]:
# for num_subsets in [2048, 3072]:
    model_path=f"logs/vitbase_imagenette_shapley_regexplainer_upfront_{num_subsets}"
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    # NOTE(review): checkpoint_path_list is only referenced by the commented-out loop below.
    checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
    
#     for checkpoint_path in tqdm(checkpoint_path_list[20:20+1]):
    # Single-element loop over the best checkpoint path stored by the trainer.
    for checkpoint_path in [trainer_state["best_model_checkpoint"]]:
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        regexplainer.load_state_dict(checkpoint_state_dict)
        
        
    # NOTE(review): the return value of plot_figure is used as a figure here, while other
    # cells unpack it as (fig, explainer_output) — confirm which definition is active.
    fig=plot_figure(explainer=regexplainer, 
                dataset=dataset_explainer["train"], 
                sample_idx_list=[4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639])
    fig.suptitle(f"Reg-AO {num_subsets}")
    
    # `fig` is rebound to the second figure; both figures get the same suptitle.
    fig=plot_figure_shapley(
                dataset=dataset_explainer["train"], 
                sample_idx_list=[4832, 1928, 2523, 2997, 4838, 2210, 9286, 3639],
                shapley_value=shapley_loaded_dict['logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train'],
                shapley_value_key=512,
    )
    fig.suptitle(f"Reg-AO {num_subsets}")    

In [None]:
# Load the best checkpoint of the `..._shapley_regexplainer_upfront_512` run and plot the
# explainer output for eight test samples.
for num_subsets in [512]:
# for num_subsets in [2048, 3072]:
    model_path=f"logs/vitbase_imagenette_shapley_regexplainer_upfront_{num_subsets}"
    with open(model_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    # NOTE(review): checkpoint_path_list is only referenced by the commented-out loop below.
    checkpoint_path_list=sorted(glob.glob(model_path+"/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
    
#     for checkpoint_path in tqdm(checkpoint_path_list[20:20+1]):
    # Single-element loop over the best checkpoint path stored by the trainer.
    for checkpoint_path in [trainer_state["best_model_checkpoint"]]:
        checkpoint_state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")
        with open(checkpoint_path+"/trainer_state.json") as f:
            checkpoint_trainer_state = json.load(f)

        regexplainer.load_state_dict(checkpoint_state_dict)
        
        
    # NOTE(review): the return value of plot_figure is used as a figure here, while other
    # cells unpack it as (fig, explainer_output) — confirm which definition is active.
    fig=plot_figure(explainer=regexplainer, 
                dataset=dataset_explainer["test"], 
                sample_idx_list=[27, 50, 62, 15, 68, 86, 49, 84])
    fig.suptitle(f"Reg-AO {num_subsets}")

In [None]:
len(dataset_explainer["test"])

In [None]:
dataset_explainer

# end

In [None]:
metric_df.groupby("explainer")[['sample_idx', 'mse_target_explainer', 'mse_nontarget_explainer',
       'mse_all_explainer',  'epoch', 
       'mse_target_calculated', 'mse_nontarget_calculated',
       'mse_all_calculated']].mean().reset_index()

In [None]:
metric_df.groupby("explainer")[['sample_idx', 'mse_target_explainer', 'mse_nontarget_explainer',
       'mse_all_explainer',  'epoch', 
       'mse_target_calculated', 'mse_nontarget_calculated',
       'mse_all_calculated']].mean()

In [None]:
sns.scatterplot(x="mse_target_explainer", 
                y="mse_target_calculated", 
                hue="explainer",
                data=metric_df)

In [None]:
metric_df

In [None]:
metric_df.groupby("explainer")[['sample_idx', 'mse_target_explainer', 'mse_nontarget_explainer',
       'mse_all_explainer',  'epoch', 
       'mse_target_calculated', 'mse_nontarget_calculated',
       'mse_all_calculated']].apply(lambda x: len(x))

In [None]:
metric_df.groupby("explainer")[['sample_idx', 'mse_target_explainer', 'mse_nontarget_explainer',
       'mse_all_explainer',  'epoch', 
       'mse_target_calculated', 'mse_nontarget_calculated',
       'mse_all_calculated']].mean().T

In [None]:
metric_df.groupby("explainer")[['sample_idx', 'mse_target_explainer', 'mse_nontarget_explainer',
       'mse_all_explainer',  'epoch', 
       'mse_target_calculated', 'mse_nontarget_calculated',
       'mse_all_calculated']].mean().T

In [None]:
metric_df.fillna("None").groupby(["explainer", "sample_idx"]).apply(lambda x: print(x))

In [None]:
pd.DataFrame(get_ground_truth_metric_with_value(
    shapley_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
    iters_ground_truth=199680,
    shapley_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
    iters_calculated=512*3, 
    meta_info={}
)).mean()

In [None]:
pd.DataFrame(get_ground_truth_metric_with_value(
    shapley_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
    iters_ground_truth=199680,
    shapley_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"], 
    iters_calculated=512*1, 
    meta_info={}
)).mean()

In [None]:
pd.DataFrame(get_ground_truth_metric_with_value(
    shapley_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
    iters_ground_truth=199680,
    shapley_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
    iters_calculated=512*2, 
    meta_info={}
)).mean()

In [None]:
199680/512

In [None]:
pd.DataFrame(metric).mean()

In [None]:
pd.DataFrame(metric).mean()

In [None]:
get_ground_truth_metric_with_value(
    shapley_values_ground_truth=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"], 
    iters_ground_truth=199680,
    shapley_values_calculated=shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"], 
    iters_calculated=512, 
    meta_info={}
)    

In [None]:
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"]=\
load_shapley("logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train")

In [None]:
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"][9213]["iters"]

In [None]:
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train_regression/extract_output/train"][9213]["iters"]

In [None]:
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_train/extract_output/train"]

In [None]:
shapley_loaded_dict["logs/vitbase_imagenette_surrogate_shapley_eval_test/extract_output/test"]

In [None]:
metric_list+=get_ground_truth_metric(shapley_values=shapley_loaded_test, 
                        explainer=regexplainer,
                        dataset=dataset_explainer["test"],
                        iters_ground_truth=200192,
                        meta_info={"explainer": "Reg-AO (upfront, 512)",
                                   "epoch":int(trainer_state["epoch"])
                                  })

In [None]:
metric_list

In [None]:
# Evaluate every saved Obj-AO checkpoint (first 100, ordered by global step)
# against the ground-truth Shapley values of the test split.
# The pattern must match the checkpoint-<step> sub-directories: globbing the
# bare run directory yields a path without a numeric suffix, on which the
# sort key's int(...) raises ValueError.
checkpoint_path_list=sorted(glob.glob("logs/vitbase_imagenette_shapley_objexplainer_newsample_32/checkpoint-*"), key=lambda x: int(x.split('-')[-1]))
for checkpoint_path in tqdm(checkpoint_path_list[:100]):
    state_dict = torch.load(checkpoint_path+"/pytorch_model.bin", map_location="cpu")

    # trainer_state.json supplies the epoch recorded with this checkpoint.
    with open(checkpoint_path+"/trainer_state.json") as f:
        trainer_state = json.load(f)

    explainer.load_state_dict(state_dict)
    # NOTE: metric_list is extended in place, so re-running this cell appends duplicates.
    metric_list+=get_ground_truth_metric(shapley_values=shapley_loaded_test,
                            explainer=explainer,
                            dataset=dataset_explainer["test"],
                            iters_ground_truth=200192,
                            meta_info={"explainer": "Obj-AO (newsample, 32)",
                                       "epoch":int(trainer_state["epoch"])
                                      })

# per_sample

In [None]:
# Key of the estimate that is treated as ground truth for every sample.
num_eval_ground_truth = 99840

# One record per (sample, num_eval): squared error of that estimate against the
# ground-truth estimate, summed over axis 0 and split into target / non-target classes.
record_dict_list = []

for sample_idx, num_eval_shapley_values in enumerate(shapley_values_dict["test"]):
    ground_truth_values = num_eval_shapley_values[num_eval_ground_truth]
    # Target class = class with the largest summed ground-truth attribution.
    target_class_idx = np.argmax(ground_truth_values.sum(axis=0))

    for num_eval, shapley_values in num_eval_shapley_values.items():
        diff = shapley_values - ground_truth_values
        mse_class = (diff * diff).sum(axis=0)
        is_target_class = np.arange(len(mse_class)) == target_class_idx

        record = {
            "sample_idx": sample_idx,
            "mse_target": mse_class[is_target_class].mean(),
            "mse_nontarget": mse_class[~is_target_class].mean(),
            "mse_all": mse_class[:].mean(),
            "num_eval": num_eval,
            "method": "per-sample",
        }
        record_dict_list.append(record)

In [None]:
# Same per-sample error table as above, built from `shapley_loaded_test`.
# Relies on `num_eval_ground_truth` defined in an earlier cell.
# NOTE(review): other cells index load_shapley results by sample id
# (e.g. shapley_loaded[0]["values"]); confirm that the "test" key exists here.
record_dict_list = []

for sample_idx, num_eval_shapley_values in enumerate(shapley_loaded_test["test"]):
    ground_truth_values = num_eval_shapley_values[num_eval_ground_truth]
    target_class_idx = np.argmax(ground_truth_values.sum(axis=0))

    for num_eval, shapley_values in num_eval_shapley_values.items():
        diff = shapley_values - ground_truth_values
        mse_class = (diff * diff).sum(axis=0)
        # Boolean mask that is True only at the target class position.
        is_target_class = np.arange(len(mse_class)) == target_class_idx

        record_dict_list.append(dict(
            sample_idx=sample_idx,
            mse_target=mse_class[is_target_class].mean(),
            mse_nontarget=mse_class[~is_target_class].mean(),
            mse_all=mse_class[:].mean(),
            num_eval=num_eval,
            method="per-sample",
        ))

In [None]:
shapley_values_loaded

In [None]:
shapley_values.shape

In [None]:
99840/512

In [None]:
shapley_estimated.shape

# regression

In [None]:
import glob
# Plot Reg-AO attributions for each regression-explainer run on eight test images.
# The numeric suffix of the run directory plus 512 is used as the figure label.
regression_run_dirs = [
    f"logs/vitbase_imagenette_explainer_regression_{suffix}"
    for suffix in (0, 512, 1024, 1536)
]
for model_path_reg in regression_run_dirs:
    num_eval = int(model_path_reg.rsplit('_', 1)[-1]) + 512
    state_dict = torch.load(model_path_reg + "/pytorch_model.bin", map_location="cpu")
    regexplainer.load_state_dict(state_dict)
    fig = plot_figure(
        explainer=regexplainer,
        dataset=dataset["test_explainer"],
        sample_idx_list=[0,  10, 20, 30, 40, 50, 60, 70],
    )
    fig.suptitle(f"Reg-AO {num_eval}")

In [None]:
import glob
# Append Reg-AO error records (method="regression_AO") for each regression run,
# comparing the explainer output with the ground-truth Shapley estimate.
# Relies on `num_eval_ground_truth` and `record_dict_list` from earlier cells;
# records are appended in place, so re-running this cell duplicates them.
for model_path_reg in ['logs/vitbase_imagenette_explainer_regression_0',
                       'logs/vitbase_imagenette_explainer_regression_1024',
                       'logs/vitbase_imagenette_explainer_regression_512',
                       'logs/vitbase_imagenette_explainer_regression_1536']:
    # Run-directory suffix plus 512 is the x-axis value used in the plots below.
    num_eval=int(model_path_reg.split('_')[-1])+512
    state_dict = torch.load(f"{model_path_reg}/checkpoint-1480/pytorch_model.bin", map_location="cpu")
    regexplainer.load_state_dict(state_dict)
    # Switching to eval mode is loop-invariant: do it once per checkpoint, not once per sample.
    regexplainer.eval()

    for sample_idx, (num_eval_shapley_values, data) in enumerate(zip(shapley_values_dict["test"], dataset["test_explainer"])):
        ground_truth_values = num_eval_shapley_values[num_eval_ground_truth]
        target_class_idx=np.argmax(ground_truth_values.sum(axis=0))
        # NOTE(review): unlike the objective_AO cell, the input is not moved to the
        # model's device and the output is not moved to CPU -- this assumes the
        # regression explainer lives on the CPU.
        with torch.no_grad():
            shapley_estimated=regexplainer(pixel_values=data["pixel_values"].unsqueeze(0), return_loss=False)["logits"][0]
        # The explainer output is transposed before it is compared with the stored values.
        diff=(shapley_estimated.T.detach().numpy()-ground_truth_values)
        mse_class=(diff*diff).sum(axis=0)
        is_target_class = np.arange(len(mse_class))==target_class_idx
        record_dict_list.append({
            "sample_idx": sample_idx,
            "mse_target": mse_class[is_target_class].mean(),
            "mse_nontarget": mse_class[~is_target_class].mean(),
            "mse_all": mse_class[:].mean(),
            "num_eval":num_eval,
            "method": "regression_AO",
        })
        

In [None]:
import tqdm

In [None]:
import glob
# Append Obj-AO error records (method="objective_AO") for a subset of the saved
# checkpoints of the objective explainer.
# Relies on `num_eval_ground_truth` and `record_dict_list` from earlier cells;
# records are appended in place, so re-running this cell duplicates them.

# Import the progress-bar class under a private alias: the bare name `tqdm` is
# bound to the module in one cell and to the function in others, so
# `tqdm.tqdm(...)` only works for one execution order.
from tqdm import tqdm as _tqdm_bar

# NOTE(review): 148 looks like the number of optimizer steps per epoch and 32 the
# number of subsets sampled per image per epoch -- confirm against the training config.
STEPS_PER_EPOCH = 148
SUBSETS_PER_EPOCH = 32

for model_path_obj in sorted(glob.glob("logs/vitbase_imagenette_explainer_objective/checkpoint-*"), key=lambda x: int(x.split('-')[-1])):
    # Global step encoded in the checkpoint directory name, parsed once.
    global_step = int(model_path_obj.split('-')[-1])
    num_eval = global_step/STEPS_PER_EPOCH*SUBSETS_PER_EPOCH
    # Keep only checkpoints whose epoch index ends in 1 (1, 11, 21, ...).
    if int(global_step/STEPS_PER_EPOCH)%10!=1:
        continue
    state_dict = torch.load(f"{model_path_obj}/pytorch_model.bin", map_location="cpu")
    explainer.load_state_dict(state_dict)
    # Switching to eval mode is loop-invariant: do it once per checkpoint.
    explainer.eval()

    for sample_idx, (num_eval_shapley_values, data) in enumerate(zip(_tqdm_bar(shapley_values_dict["test"]), dataset["test_explainer"])):
        ground_truth_values = num_eval_shapley_values[num_eval_ground_truth]
        target_class_idx=np.argmax(ground_truth_values.sum(axis=0))
        with torch.no_grad():
            shapley_estimated=explainer(pixel_values=data["pixel_values"].unsqueeze(0).to(explainer.device), return_loss=False)["logits"][0]
        # The explainer output is transposed before it is compared with the stored values.
        diff=(shapley_estimated.T.cpu().detach().numpy()-ground_truth_values)
        mse_class=(diff*diff).sum(axis=0)
        is_target_class = np.arange(len(mse_class))==target_class_idx
        record_dict_list.append({
            "sample_idx": sample_idx,
            "mse_target": mse_class[is_target_class].mean(),
            "mse_nontarget": mse_class[~is_target_class].mean(),
            "mse_all": mse_class[:].mean(),
            "num_eval":num_eval,
            "method": "objective_AO",
        })

In [None]:
int(int(model_path_obj.split('-')[-1])/148)%10

In [None]:
model_path_obj

In [None]:
444/148

In [None]:
explainer.device

In [None]:
import pandas as pd

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import font_manager
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from cycler import cycler
from matplotlib.ticker import MultipleLocator, AutoMinorLocator

font_manager.findSystemFonts(fontpaths=None, fontext="ttf")
font_manager.findfont("Arial") # Test with "Special Elite" too
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'

In [None]:
# Target-class error versus number of model evaluations, full x-range.
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

axd = {"main": ax}
plot_key = "main"
panel = axd[plot_key]

# Rows with num_eval == 0 are excluded from this plot.
sns.lineplot(
    data=record_dict_list_df[record_dict_list_df["num_eval"] > 0],
    x="num_eval",
    y="mse_target",
    hue="method",
    ax=ax,
)

panel.set_ylabel("L2 distance", fontsize=20)
panel.set_xlabel("# model evaluations", fontsize=20)

# (axis, major tick spacing, minor tick spacing): y first, then x.
for axis, major_step, minor_step in ((panel.yaxis, 0.5, 0.1),
                                     (panel.xaxis, 50000, 10000)):
    axis.set_major_locator(MultipleLocator(major_step))
    axis.set_minor_locator(MultipleLocator(minor_step))
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    panel.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    panel.spines[side].set_visible(False)

panel.set_title("Target", fontsize=20)

# axd[plot_key].set_yscale('log')

In [None]:
import pandas as pd

In [None]:
# Non-target-class error versus number of model evaluations, all rows included.
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

axd = {"main": ax}
plot_key = "main"
panel = axd[plot_key]

sns.lineplot(
    data=record_dict_list_df,
    x="num_eval",
    y="mse_nontarget",
    hue="method",
    ax=ax,
)

panel.set_ylabel("L2 distance", fontsize=20)
panel.set_xlabel("# model evaluations", fontsize=20)

# (axis, major tick spacing, minor tick spacing): y first, then x.
for axis, major_step, minor_step in ((panel.yaxis, 0.005, 0.001),
                                     (panel.xaxis, 50000, 10000)):
    axis.set_major_locator(MultipleLocator(major_step))
    axis.set_minor_locator(MultipleLocator(minor_step))
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    panel.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    panel.spines[side].set_visible(False)

panel.set_title("Non-target", fontsize=20)

# axd[plot_key].set_yscale('log')

In [None]:
# Non-target-class error versus number of model evaluations, full x-range.
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

axd = {"main": ax}
plot_key = "main"
panel = axd[plot_key]

# Rows with num_eval == 0 are excluded from this plot.
sns.lineplot(
    data=record_dict_list_df[record_dict_list_df["num_eval"] > 0],
    x="num_eval",
    y="mse_nontarget",
    hue="method",
    ax=ax,
)

panel.set_ylabel("L2 distance", fontsize=20)
panel.set_xlabel("# model evaluations", fontsize=20)

# (axis, major tick spacing, minor tick spacing): y first, then x.
for axis, major_step, minor_step in ((panel.yaxis, 0.005, 0.001),
                                     (panel.xaxis, 50000, 10000)):
    axis.set_major_locator(MultipleLocator(major_step))
    axis.set_minor_locator(MultipleLocator(minor_step))
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    panel.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    panel.spines[side].set_visible(False)

panel.set_title("Non-target", fontsize=20)

# axd[plot_key].set_yscale('log')

In [None]:
# Target-class error versus number of model evaluations, zoomed to the
# low-budget region (x in [0, 3500], y in [0, 0.1]).
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

axd = {"main": ax}
plot_key = "main"
panel = axd[plot_key]

# Rows with num_eval == 0 are excluded from this plot.
sns.lineplot(
    data=record_dict_list_df[record_dict_list_df["num_eval"] > 0],
    x="num_eval",
    y="mse_target",
    hue="method",
    ax=ax,
)

panel.set_ylabel("L2 distance", fontsize=20)
panel.set_xlabel("# model evaluations", fontsize=20)

# (axis, major tick spacing, minor tick spacing): y first, then x.
for axis, major_step, minor_step in ((panel.yaxis, 0.05, 0.01),
                                     (panel.xaxis, 1000, 500)):
    axis.set_major_locator(MultipleLocator(major_step))
    axis.set_minor_locator(MultipleLocator(minor_step))
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    panel.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    panel.spines[side].set_visible(False)

panel.set_xlim(0, 3500)
panel.set_ylim(0, 0.1)

panel.set_title("Target", fontsize=20)

# axd[plot_key].set_yscale('log')

In [None]:
# Non-target-class error versus number of model evaluations, zoomed to the
# low-budget region (x in [0, 3500], y in [0, 0.01]).
record_dict_list_df = pd.DataFrame(record_dict_list)

fig, ax = plt.subplots(1, 1, figsize=(10, 5))

axd = {"main": ax}
plot_key = "main"
panel = axd[plot_key]

# Rows with num_eval == 0 are excluded from this plot.
sns.lineplot(
    data=record_dict_list_df[record_dict_list_df["num_eval"] > 0],
    x="num_eval",
    y="mse_nontarget",
    hue="method",
    ax=ax,
)

panel.set_ylabel("L2 distance", fontsize=20)
panel.set_xlabel("# model evaluations", fontsize=20)

# (axis, major tick spacing, minor tick spacing): y first, then x.
for axis, major_step, minor_step in ((panel.yaxis, 0.005, 0.001),
                                     (panel.xaxis, 1000, 500)):
    axis.set_major_locator(MultipleLocator(major_step))
    axis.set_minor_locator(MultipleLocator(minor_step))
    axis.grid(True, which='major', linewidth=2, alpha=0.6)
    axis.grid(True, which='minor', linewidth=1, alpha=0.1)

for axis_name in ('x', 'y'):
    panel.tick_params(axis=axis_name, which='major', labelsize=20)

for side in ('right', 'top'):
    panel.spines[side].set_visible(False)

panel.set_xlim(0, 3500)
panel.set_ylim(0, 0.01)

panel.set_title("Non-target", fontsize=20)

# axd[plot_key].set_yscale('log')

In [None]:
500/32

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_explainer_regression_0/checkpoint-1480/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_explainer_regression/checkpoint-1480/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)

In [None]:
explainer_out=explainer.forward(pixel_values=dataset["validation_explainer"][0]['pixel_values'].unsqueeze(0),
                               return_loss=False)

In [None]:
import tqdm

In [None]:
!gpustat

In [None]:
device="cuda:7"

In [None]:
explainer.to(device)
explainer.surrogate_null=explainer.surrogate_null.to(device)

In [None]:
plot_figure(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70])

In [None]:
plot_figure(explainer, dataset["validation_explainer"], [0,  250, 500,1000])

In [None]:
dataset["validation"][0]["image"]

In [None]:
dataset["test"][0]["image"]

In [None]:
dataset["test"] = (

In [None]:
data_args.max_test_samples

In [None]:
from utils import load_shapley

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                        shapley_value={0:{
                'values': [sgd_shapley_output],
                'std': [],
                'iters': [0]}}, shapley_value_key=0)

In [None]:
shapley_loaded_train=load_shapley("logs/vitbase_imagenette_surrogate_eval_train/extract_output/train/")
shapley_loaded_train_permutation=load_shapley("logs/vitbase_imagenette_surrogate_train_validation_permutation/extract_output/train/")

In [None]:
shapley_loaded_validation=load_shapley("logs/vitbase_imagenette_surrogate_eval_validation/extract_output/validation/")
shapley_loaded_validation_permutation=load_shapley("logs/vitbase_imagenette_surrogate_eval_validation_permutation/extract_output/validation/")

In [None]:
shapley_loaded_test=load_shapley("logs/vitbase_imagenette_surrogate_eval_test/extract_output/test/")
shapley_loaded_test_permutation=load_shapley("logs/vitbase_imagenette_surrogate_eval_test_permutation/extract_output/test/")

In [None]:
shapley_loaded_test_permutation=load_shapley("logs/vitbase_imagenette_surrogate_eval_test_permutation/extract_output/test/")

In [None]:
shapley_loaded1=load_shapley("logs/vitbase_imagenette_surrogate_eval_validation/extract_output/validation/")

In [None]:
shapley_loaded1[0]["iters"]

In [None]:
shapley_loaded2[0]["iters"]

In [None]:
len(shapley_loaded1), len(shapley_loaded2)

In [None]:
shapley_loaded=load_shapley("logs/vitbase_imagenette_surrogate_eval_test_permutation/extract_output/test/")

In [None]:
shapley_loaded.keys()

In [None]:
!ls logs/vitbase_imagenette_surrogate_eval_train_permutation/extract_output/train/3 -l

In [None]:
shapley_loaded[0]["values"]

In [None]:
load_shapley??

In [None]:
shapley_loaded[0]["values"][-1].sum(axis=0)

In [None]:
device="cuda:0"
explainer.surrogate.to(device)

In [None]:
import numpy as np
from tqdm.auto import tqdm
from scipy.special import softmax


def ShapleySampling(game,
                    batch_size=512,
                    detect_convergence=True,
                    thresh=0.01,
                    n_samples=None,
                    antithetical=False,
                    return_all=False,
                    bar=True,
                    verbose=False):
    '''
    Estimate Shapley values by averaging marginal contributions over randomly
    sampled player permutations.

    Args:
      game: object exposing `players` (number of players), `null()` (value of
        the empty coalition) and `__call__(S)`, where S is an integer array of
        shape (batch_size, players) and the result has one row per coalition.
      batch_size: number of permutations evaluated per loop.
      detect_convergence: stop once the std-dev ratio drops below `thresh`.
      thresh: convergence threshold, must satisfy 0 < thresh < 1.
      n_samples: maximum number of permutations; if None, sampling runs until
        convergence (convergence detection is forced on).
      antithetical: if True, every odd-indexed permutation in a batch is the
        reverse of the preceding one.
      return_all: if True, return a tracking dictionary instead of a tuple.
      bar: show progress bars (both the outer bar and the per-player bar).
      verbose: print the std-dev ratio after every loop.

    Returns:
      (values, std) when `return_all` is False. Otherwise a dict with keys
      'values', 'std', 'iters' (cumulative batch_size * players per loop) and,
      when convergence detection is on, 'N_est'. All lists have one entry per
      completed loop, including the loop on which convergence was detected.
    '''
    # Verify arguments.
    # Only deterministic games are supported in this copy; the stochastic
    # branches below are kept for parity with the upstream implementation.
    stochastic = False

    # Possibly force convergence detection.
    if n_samples is None:
        n_samples = 1e20
        if not detect_convergence:
            detect_convergence = True
            if verbose:
                print('Turning convergence detection on')

    if detect_convergence:
        assert 0 < thresh < 1

    # Calculate null coalition value.
    if stochastic:
        null = game.null(batch_size=batch_size)
    else:
        null = game.null()

    # Set up bar. Remember the caller's choice separately, because `bar` is
    # rebound to the tqdm object below.
    n_loops = int(np.ceil(n_samples / batch_size))
    show_bar = bool(bar)
    if show_bar:
        if detect_convergence:
            bar = tqdm(total=1)
        else:
            bar = tqdm(total=n_loops * batch_size)

    # Setup. Vector-valued games get one column per output dimension.
    num_players = game.players
    if isinstance(null, np.ndarray):
        values = np.zeros((num_players, len(null)))
        sum_squares = np.zeros((num_players, len(null)))
        deltas = np.zeros((batch_size, num_players, len(null)))
    else:
        values = np.zeros((num_players))
        sum_squares = np.zeros((num_players))
        deltas = np.zeros((batch_size, num_players))
    permutations = np.tile(np.arange(game.players), (batch_size, 1))
    arange = np.arange(batch_size)
    n = 0

    # For tracking progress.
    if return_all:
        N_list = []
        std_list = []
        val_list = []

    # Begin sampling.
    for it in range(n_loops):
        for i in range(batch_size):
            if antithetical and i % 2 == 1:
                permutations[i] = permutations[i - 1][::-1]
            else:
                np.random.shuffle(permutations[i])
        S = np.zeros((batch_size, game.players), dtype=int)

        # Sample exogenous (if applicable).
        if stochastic:
            U = game.sample(batch_size)

        # Unroll permutations: add one player at a time to every coalition in
        # the batch and record the resulting change in value.
        prev_value = null
        # The per-player bar honours `bar=False` just like the outer bar.
        player_steps = tqdm(range(num_players)) if show_bar else range(num_players)
        for i in player_steps:
            S[arange, permutations[:, i]] = 1
            if stochastic:
                next_value = game(S, U)
            else:
                next_value = game(S)
            deltas[arange, permutations[:, i]] = next_value - prev_value
            prev_value = next_value

        # Welford's algorithm.
        n += batch_size
        diff = deltas - values
        values += np.sum(diff, axis=0) / n
        diff2 = deltas - values
        sum_squares += np.sum(diff * diff2, axis=0)

        # Calculate progress.
        var = sum_squares / (n ** 2)
        std = np.sqrt(var)
        ratio = np.max(
            np.max(std, axis=0) / (values.max(axis=0) - values.min(axis=0)))

        # Print progress message.
        if verbose:
            if detect_convergence:
                print(f'StdDev Ratio = {ratio:.4f} (Converge at {thresh:.4f})')
            else:
                print(f'StdDev Ratio = {ratio:.4f}')

        # Forecast number of iterations required.
        if detect_convergence:
            N_est = (it + 1) * (ratio / thresh) ** 2

        # Save intermediate quantities BEFORE the convergence check, so the
        # converged estimate is part of the tracking lists and their length
        # matches `iters` computed after the loop.
        if return_all:
            val_list.append(np.copy(values))
            std_list.append(np.copy(std))
            if detect_convergence:
                N_list.append(N_est)

        # Check for convergence.
        if detect_convergence and ratio < thresh:
            if verbose:
                print('Detected convergence')

            # Skip bar ahead.
            if show_bar:
                bar.n = bar.total
                bar.refresh()
            break

        # Update the progress bar.
        if detect_convergence:
            if show_bar and not np.isnan(N_est):
                bar.n = np.around((it + 1) / N_est, 4)
                bar.refresh()
        elif show_bar:
            bar.update(batch_size)

    # Return results.
    if return_all:
        # Dictionary for progress tracking. Every loop calls the game once per
        # player on a batch of `batch_size` coalitions.
        iters = (np.arange(it + 1) + 1) * batch_size * num_players
        tracking_dict = {
            'values': val_list,
            'std': std_list,
            'iters': iters}
        if detect_convergence:
            tracking_dict['N_est'] = N_list

        return tracking_dict
    else:
        return (values, std)
    
class CooperativeGame:
    '''
    Abstract interface for a cooperative game.

    Subclasses must provide their own constructor, a `players` attribute and a
    `__call__(S)` that maps a (batch, players) coalition array to values.
    '''

    def __init__(self):
        # The base class is not meant to be instantiated directly.
        raise NotImplementedError

    def __call__(self, S):
        '''Return the value of each coalition in S (implemented by subclasses).'''
        raise NotImplementedError

    def _value_of_uniform_coalition(self, make_array):
        '''Evaluate a single coalition built by `make_array` and unwrap the batch axis.'''
        coalition = make_array((1, self.players), dtype=int)
        return self.__call__(coalition)[0]

    def grand(self):
        '''Value of the coalition that contains every player.'''
        return self._value_of_uniform_coalition(np.ones)

    def null(self):
        '''Value of the coalition that contains no player.'''
        return self._value_of_uniform_coalition(np.zeros)


class PredictionGame(CooperativeGame):
    '''
    Cooperative game built from a surrogate model's prediction on one example.

    Args:
      surrogate: callable invoked as `surrogate(pixel_values, masks,
        return_loss=False)`; its result must expose `.logits`.
      sample: mapping with a "pixel_values" tensor for a single example
        (the batch dimension is added at call time).

    Note: `__call__` reads the module-level name `device`, which must be
    defined before the game is evaluated.
    '''

    def __init__(self, surrogate, sample):
        self.surrogate = surrogate
        self.sample = sample

        # Number of players is fixed.
        # NOTE(review): presumably 14 x 14 image patches -- confirm against the model config.
        self.players = 196
        self.groups_matrix = None

        # The example is reused unchanged for every evaluation.
        self.sample_repeat = sample

    def __call__(self, S):
        '''
        Evaluate the game on a batch of coalitions.

        Args:
          S: array of player coalitions with size (batch, players).

        Returns:
          numpy array: softmax (over axis 1) of the first element of the
          surrogate's logits.
        '''
        example = self.sample_repeat

        with torch.no_grad():
            # One example and the whole coalition batch, each wrapped in a leading axis.
            pixel_batch = example["pixel_values"].unsqueeze(0).to(device)
            coalition_batch = torch.Tensor(S).unsqueeze(0).to(device)
            surrogate_output = self.surrogate(pixel_batch, coalition_batch, return_loss=False)

        first_logits = surrogate_output.logits[0].detach().cpu().numpy()
        return softmax(first_logits, axis=1)
    


# shapley_sampling=ShapleySampling(game,
#                     batch_size=128,
#                     detect_convergence=True,
#                     thresh=0.01,
#                     antithetical=False,
#                     return_all=True,
#                     bar=True,
#                     verbose=False)

In [None]:
# Edited by: Ian Covert and Chanwoo Kim

# Original authors: Simon Grah <simon.grah@thalesgroup.com>
#                   Vincent Thouvenot <vincent.thouvenot@thalesgroup.com>

# MIT License

# Copyright (c) 2020 Thales Six GTS France

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import operator as op
from functools import reduce
from tqdm import tqdm


def ncr(n, r):
    """
    Combinatorial computation: number of subsets of size r among n elements.

    Parameters
    ----------
    n : integer
        Size of the ground set.
    r : integer
        Size of the subsets.

    Returns
    -------
    float
        The binomial coefficient C(n, r); 0.0 when r is outside [0, n],
        because no such subset exists.
    """
    # Without this guard min(r, n-r) becomes negative, both products are
    # empty and the function would wrongly report 1 subset.
    if r < 0 or r > n:
        return 0.0
    # Use the symmetry C(n, r) == C(n, n-r) to multiply as few terms as possible.
    r = min(r, n-r)
    numer = reduce(op.mul, range(n, n-r, -1), 1)
    denom = reduce(op.mul, range(1, r+1), 1)
    # True division keeps the float return type callers rely on.
    return numer / denom


class SGDShapleyNew():
    """
    Estimate the Shapley Values using a Projected Stochastic Gradient algorithm.

    One coalition is evaluated per iteration. Coalition sizes are drawn from
    the importance-sampling distribution ``q``, and after every gradient step
    the estimate is shifted so that its entries sum to
    ``game(all ones) - game(all zeros)``. ``sgd`` returns the average of all
    iterates.
    """

    # Step-size schedules understood by ``sgd``.
    _STEP_TYPES = ("constant", "sqrt", "inverse")

    def __init__(self, d, C):
        """
        Calculate internal values for later purposes
        Those elements depend only on the number of features d

        Parameters
        ----------
        d : integer
            Dimension of the problem. The number of features
        C : float
            Constant bounding |y|
        """

        # Per coalition size k: weight w_k and L-smooth constant L_k.
        dict_w_k = dict()
        dict_L_k = dict()
        D = C * np.sqrt(d)
        for k in range(1, d):
            w_k = (d - 1) / (ncr(d, k) * k * (d - k))
            L_k = w_k * np.sqrt(k) * (np.sqrt(k) * D + C)
            dict_w_k[k] = w_k
            dict_L_k[k] = L_k

        # Summation of L over every coalition (closed formula): ncr(d, k) * L_k
        # simplifies so that the binomial coefficient cancels.
        sum_L = np.sum([(d-1)/(np.sqrt(k)*(d-k)) * (np.sqrt(k)*D + C) for k in range(1, d)])

        # Probability distributions over the coalition size k = 1 .. d-1.

        # 1. Classic SGD (not used): proportional to the number of coalitions.
        p = np.array([ncr(d, k) for k in range(1, d)])
        p = p / np.sum(p)

        # 2. Importance Sampling proposal q (used): proportional to L_k * p_k.
        q = np.array(list(dict_L_k.values())) * p
        q = q / np.sum(q)

        # Save internal attributes
        self.d = d
        self.dict_w_k = dict_w_k
        self.dict_L_k = dict_L_k
        self.sum_L = sum_L
        self.p = p
        self.q = q

    def _grad_F_i(self, phi, x_i, y_i, w_i):
        """
        Gradient vector per instance i.

        This is the gradient with respect to ``phi`` of
        ``0.5 * w_i * (x_i . phi - y_i) ** 2``.

        Parameters
        ----------
        phi : ndarray
            Current estimate, shape (d,) or (d, out_dim).
        x_i : ndarray
            0/1 coalition indicator of length d.
        y_i : float or ndarray
            Game value of the coalition minus the null value.
        w_i : float
            Weight of the coalition size.
        """
        residual = x_i.dot(phi) - y_i
        if isinstance(y_i, np.ndarray):
            # One residual per output: outer product of indicator and residuals.
            return w_i * x_i[:, np.newaxis] * residual[np.newaxis]
        return w_i * x_i * residual

    def sgd(self,
            game,
            n_iter=100,
            step=0.1,
            step_type="sqrt",
            phi_0=False):
        """
        Stochastic gradient descent algorithm

        Parameters
        ----------
        game : callable
            Called with a boolean array of shape (1, d); element ``[0]`` of
            the result is used as the coalition value (a scalar or a 1-D
            array with one entry per output).
        n_iter : integer
            Number of iterations, i.e. of sampled coalitions.
        step : float
            Base step size.
        step_type : str
            "constant" (step), "sqrt" (step / sqrt(t)) or "inverse" (step / t).
        phi_0 : ndarray, None or False
            Optional starting point; ``None``/``False`` start from zeros.

        Returns
        -------
        ndarray
            Average of the projected iterates, shape (d,) or (d, out_dim).

        Raises
        ------
        ValueError
            If ``step_type`` is not one of the supported schedules.
        """
        # An unknown schedule used to be ignored silently, so no gradient step
        # was ever applied; fail loudly instead.
        if step_type not in self._STEP_TYPES:
            raise ValueError(
                "step_type must be one of %r, got %r" % (self._STEP_TYPES, step_type))

        # Values of the full and of the empty coalition.
        grand = game(np.ones((1, self.d), dtype=bool))[0]
        null = game(np.zeros((1, self.d), dtype=bool))[0]
        out_dim = len(grand) if isinstance(grand, np.ndarray) else None
        total = grand - null

        d = self.d
        dict_w_k = self.dict_w_k
        dict_L_k = self.dict_L_k
        sum_L = self.sum_L

        # initialize Shapley value estimates
        # (identity checks: ``if phi_0:`` raises for multi-element arrays)
        if phi_0 is None or phi_0 is False:
            phi = np.zeros(d) if out_dim is None else np.zeros((d, out_dim))
        else:
            phi = np.array(phi_0, dtype=float)  # copy, caller's array untouched

        # projection step: make the entries sum to ``total``
        phi = phi - (np.sum(phi, axis=0) - total) / d

        # store for iterate averaging
        if out_dim is None:
            phi_iterates = np.zeros((n_iter, d))
        else:
            phi_iterates = np.zeros((n_iter, d, out_dim))

        # sample coalition sizes
        list_k = np.random.choice(list(range(1, d)), size=n_iter, p=self.q)

        for t in tqdm(range(1, n_iter+1)):
            # build subset indicator x_i: k features chosen uniformly at random
            k = list_k[t-1]
            x_i = np.zeros(d)
            x_i[np.random.permutation(d)[:k]] = 1

            # Compute y_i
            y_i = game(x_i.astype(bool)[np.newaxis])[0] - null

            # importance-sampling correction: divide by the probability of
            # drawing this particular coalition
            p_i = dict_L_k[k] / sum_L
            grad_i = 1/(p_i) * self._grad_F_i(phi, x_i, y_i, dict_w_k[k])

            # update phi
            if step_type == "constant":
                rate = step
            elif step_type == "sqrt":
                rate = step/np.sqrt(t)
            else:  # "inverse"
                rate = step/(t)
            phi = phi - rate * grad_i

            # projection step
            phi = phi - (phi.sum(axis=0) - total) / d

            # update iterate history
            phi_iterates[t-1] = phi

        # Average iterates
        return np.mean(phi_iterates, axis=0)


In [None]:
from collections import OrderedDict
class SGDshapley():
    """
    Estimate the Shapley Values using a Projected Stochastic Gradient algorithm.

    This variant explains one output of the game at a time: every game value
    is read as ``game(mask)[0][dimension_select]``.
    """

    # Step-size schedules understood by ``sgd`` and ``sgd_minibatch``.
    _STEP_TYPES = ("constant", "sqrt", "inverse")

    def __init__(self, d, C):
        """
        Calculate internal values for later purposes
        Those elements depend only on the number of features d

        Parameters
        ----------
        d : integer
            Dimension of the problem. The number of features
        C : float
            Constant entering the smoothness constants L_k (``D = C * sqrt(d)``).
        """

        # Store in a dictionary for each size k of coalitions
        dict_ω_k = OrderedDict() # weights per size k
        dict_L_k = OrderedDict() # L-smooth constant per size k
        D = C * np.sqrt(d)
        for k in range(1, d):
            ω_k = (d - 1) / (ncr(d, k) * k * (d - k))
            L_k = ω_k * np.sqrt(k) * (np.sqrt(k) * D + C)
            dict_ω_k[k] = ω_k
            dict_L_k[k] = L_k
        # Summation of all L per coalition (closed formula)
        sum_L = np.sum([(d-1)/(np.sqrt(k)*(d-k)) * (np.sqrt(k)*D + C) for k in range(1, d)])
        # Probability distributions for sampling new instance
        # Classic SGD: proportional to the number of coalitions of each size
        p = np.array([ncr(d,k) for k in range(1,d)])
        p = p / np.sum(p)
        # Importance Sampling proposal q: proportional to L_k * p_k
        q = np.array(list(dict_L_k.values())) * p
        q = q / np.sum(q)

        # Save internal attributes
        self.d = d
        self.n = 2**d - 2
        self.dict_ω_k = dict_ω_k
        self.dict_L_k = dict_L_k
        self.sum_L = sum_L
        self.p = p
        self.q = q

    def _F_i(self, Φ, x_i, y_i, ω_i):
        """Function value per instance i"""
        res = .5 * self.n * ω_i * (np.dot(x_i, Φ) - y_i)**2
        return res

    def _grad_F_i(self, Φ, x_i, y_i, ω_i):
        """
        Gradient vector per instance i

        Equal to ``ω_i * x_i x_i^T Φ - ω_i * y_i * x_i``; written as
        ``ω_i * x_i * (x_i . Φ - y_i)`` so that no d x d outer product has to
        be materialised on every iteration.
        """
        return ω_i * x_i * (x_i.dot(Φ) - y_i)

    def _Π_1(self, x, b):
        """Projection Π on convex set K_1 (entries of x sum to b)"""
        if np.abs((np.sum(x) - b)) <= 1e-6:
            return x
        else:
            return x - (np.sum(x) - b)/len(x)

    def _Π_2(self, x, D):
        """Projection Π on convex set K_2 (Euclidean norm of x at most D)"""
        if np.linalg.norm(x) > D:
            return x * D / np.linalg.norm(x)
        else:
            return x

    def _Dykstra_proj(self, x, D, b, iter_proj=100, epsilon=1e-6):
        """
        Dykstra's algorithm to find orthogonal projection
        onto intersection of convex sets

        Stops after ``iter_proj`` rounds, or earlier once two successive
        iterates are within ``epsilon`` of each other.
        """
        xk = x.copy()
        d = len(x)
        pk, qk = np.zeros(d), np.zeros(d)
        for k in range(iter_proj):
            yk = self._Π_2(xk + pk, D)
            pk = xk + pk - yk
            # Computed once; it serves both the stopping test and the update.
            x_next = self._Π_1(yk + qk, b)
            if np.linalg.norm(x_next - xk, 2) <= epsilon:
                break
            qk = yk + qk - x_next
            xk = x_next
        return xk

    def _check_step_type(self, step_type):
        """Reject schedules that would otherwise silently skip every update."""
        if step_type not in self._STEP_TYPES:
            raise ValueError(
                "step_type must be one of %r, got %r" % (self._STEP_TYPES, step_type))

    def _step_size(self, step, step_type, t):
        """Step size at (1-based) sample count t for the chosen schedule."""
        if step_type == "constant":
            return step
        if step_type == "sqrt":
            return step/np.sqrt(t)
        return step/(t)  # "inverse"

    def _initial_estimate(self, Φ_0, v_M):
        """Starting point (zeros unless Φ_0 is given), shifted to sum to v_M."""
        # Identity checks: ``if Φ_0:`` raises for multi-element arrays.
        if Φ_0 is None or Φ_0 is False:
            Φ = np.zeros(self.d)
        else:
            Φ = np.array(Φ_0, dtype=float)  # copy, caller's array untouched
        # projection onto convex set K by using a simple algorithm
        return Φ - (np.sum(Φ) - v_M) / self.d

    def _sampled_gradient(self, game, dimension_select, f_r, k, Φ):
        """Draw one coalition of size k and return its reweighted gradient."""
        # build x_i: k features chosen uniformly at random
        x_i = np.zeros(self.d)
        x_i[np.random.permutation(self.d)[:k]] = 1
        # Compute y_i
        f_S = game(x_i[np.newaxis])[0][dimension_select]
        y_i = f_S - f_r
        # get weight ω_i
        ω_i = self.dict_ω_k[k]
        # importance-sampling correction: divide by the probability of
        # drawing this particular coalition
        p_i = self.dict_L_k[k] / self.sum_L
        return 1/(p_i) * self._grad_F_i(Φ, x_i, y_i, ω_i)

    def sgd(self, game, dimension_select, n_iter=100, step=.1, step_type="sqrt",
            callback=None, Φ_0=False):
        """
        Stochastic gradient descent algorithm

        Parameters
        ----------
        game : callable
            Called with a mask of shape (1, d); ``[0][dimension_select]`` of
            the result is used as the coalition value.
        dimension_select : integer
            Index of the output to explain.
        n_iter : integer
            Number of sampled coalitions (one update each).
        step : float
            Base step size.
        step_type : str
            "constant", "sqrt" (step / sqrt(t)) or "inverse" (step / t).
        callback : object
            Accepted for interface compatibility; not used.
        Φ_0 : ndarray, None or False
            Optional starting point; ``None``/``False`` start from zeros.

        Returns
        -------
        ndarray
            Average of the projected iterates, shape (d,).

        Raises
        ------
        ValueError
            If ``step_type`` is not one of the supported schedules.
        """
        self._check_step_type(step_type)

        # Values of the full and of the empty coalition.
        f_x = game(np.ones((1, self.d), dtype=int))[0][dimension_select]
        f_r = game(np.zeros((1, self.d), dtype=int))[0][dimension_select]
        v_M = f_x - f_r

        d = self.d

        # Current estimate and the history of iterates (numpy arrays)
        Φ = self._initial_estimate(Φ_0, v_M)
        Φ_storage = np.zeros((n_iter,d))

        # Sample in advance coalition sizes
        list_k = np.random.choice(list(range(1, d)), size=n_iter, p=self.q)

        for t in tqdm(range(1, n_iter+1)):
            grad_i = self._sampled_gradient(game, dimension_select, f_r, list_k[t-1], Φ)
            # update Φ
            Φ = Φ - self._step_size(step, step_type, t) * grad_i
            # projection onto convex set K
            Φ = Φ - (Φ.sum() - v_M) / d
            # update storage of Φ
            Φ_storage[t-1,:] = Φ

        # Average all Φ
        return np.mean(Φ_storage,axis=0)

    def sgd_minibatch(self, game, batch_size, dimension_select, n_iter=100, step=.1, step_type="sqrt",
            callback=None, Φ_0=False):
        """
        Mini-batch variant of ``sgd``: gradients of ``batch_size`` sampled
        coalitions are averaged before each update.

        The step-size schedule uses the number of coalitions sampled so far
        (not the number of updates). A trailing partial batch is applied as a
        final update, so no evaluated coalition is discarded and the result
        is defined even when ``n_iter < batch_size``.

        Parameters
        ----------
        game, dimension_select, n_iter, step, step_type, callback, Φ_0
            As for ``sgd``.
        batch_size : integer
            Number of sampled coalitions per update.

        Returns
        -------
        ndarray
            Average of the projected iterates (one per update), shape (d,).

        Raises
        ------
        ValueError
            If ``step_type`` is not one of the supported schedules.
        """
        self._check_step_type(step_type)

        # Values of the full and of the empty coalition.
        f_x = game(np.ones((1, self.d), dtype=int))[0][dimension_select]
        f_r = game(np.zeros((1, self.d), dtype=int))[0][dimension_select]
        v_M = f_x - f_r

        d = self.d

        # Current estimate and the history of iterates
        Φ = self._initial_estimate(Φ_0, v_M)
        Φ_storage = []

        # Sample in advance coalition sizes
        list_k = np.random.choice(list(range(1, d)), size=n_iter, p=self.q)

        grad_i_accum=[]

        for t in tqdm(range(1, n_iter+1)):
            grad_i_accum.append(
                self._sampled_gradient(game, dimension_select, f_r, list_k[t-1], Φ))

            # Update on every full batch, and once more on the last sample so
            # that a trailing partial batch is not thrown away.
            if t%batch_size==0 or t == n_iter:
                # update Φ
                Φ = Φ - self._step_size(step, step_type, t) * np.mean(grad_i_accum, axis=0)
                # projection onto convex set K
                Φ = Φ - (Φ.sum() - v_M) / d
                # update storage of Φ
                Φ_storage.append(Φ)
                grad_i_accum=[]

        # Average all Φ
        return np.mean(np.array(Φ_storage), axis=0)

In [None]:
# Plot validation sample 20 with plot_figure_shapley (defined elsewhere in this
# notebook), using the `shapley_loaded1` results stored under key 3584.
# NOTE(review): 3584 is presumably an evaluation/iteration count -- confirm.
fig=plot_figure_shapley(dataset_explainer["validation"], sample_idx_list=[20], 
                    shapley_value=shapley_loaded1, shapley_value_key=3584)

In [None]:
# Same validation sample (20) as plotted from `shapley_loaded1`, but using the
# `shapley_loaded2` results stored under key 3332, for a side-by-side comparison.
fig=plot_figure_shapley(dataset_explainer["validation"], sample_idx_list=[20], 
                    shapley_value=shapley_loaded2, shapley_value_key=3332)

In [None]:
# Show the 'iters' field of entry 0 of `shapley_loaded_test`
# (presumably the keys usable as `shapley_value_key` -- confirm).
shapley_loaded_test[0]['iters']

In [None]:
# Plot test sample 5 using the `shapley_loaded_test` results under key 5120.
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[5], 
                    shapley_value=shapley_loaded_test, shapley_value_key=5120)

In [None]:
# Plot test sample 5 using the `shapley_loaded_test_permutation` results
# under key 3332.
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[5], 
                    shapley_value=shapley_loaded_test_permutation, shapley_value_key=3332)

In [None]:
# Plot test sample 1 using the `shapley_loaded_test_permutation` results
# under key 3332.
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[1], 
                    shapley_value=shapley_loaded_test_permutation, shapley_value_key=3332)

In [None]:
# Show the 'iters' field of entry 0 of `shapley_loaded_test_permutation`.
shapley_loaded_test_permutation[0]["iters"]

In [None]:
# 3332 (a results key used in the plots) divided by 196 (the feature count
# `d` used for the estimators) is exactly 17.
# NOTE(review): presumably evaluations per feature -- confirm.
3332/196

In [None]:
# Show the 'iters' field of entry 0 of `shapley_loaded2`.
shapley_loaded2[0]['iters']

In [None]:
# Plot validation sample 110 using the `shapley_loaded2` results under key 3332.
fig=plot_figure_shapley(dataset_explainer["validation"], sample_idx_list=[110], 
                    shapley_value=shapley_loaded2, shapley_value_key=3332)

In [None]:
# Show the 'iters' field of entry 0 of `shapley_loaded1`.
shapley_loaded1[0]["iters"]

In [None]:
# Display the whole `shapley_loaded1` object.
# NOTE(review): this may print a very large structure; consider showing a slice.
shapley_loaded1

In [None]:
# Run the one-coalition-per-step estimator (SGDshapley.sgd) with 196 features
# for output index 8: 5000 iterations, step 0.1 decayed as 1/sqrt(t).
# NOTE(review): `game` is assigned in a later cell of this section
# (PredictionGame(...)); confirm it is defined before this cell on a fresh
# Restart & Run All.
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd(game,
            n_iter=5000,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
# Plot the SGD estimate for test sample 0. The length-196 vector is reshaped to
# a column and tiled into 10 identical columns, then wrapped in a dict with
# 'values'/'std'/'iters' entries under key 0.
# NOTE(review): the 10 columns presumably stand for the classes that
# plot_figure_shapley expects -- confirm.
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
# Build the cooperative game for test sample 0 from the explainer's surrogate
# model (PredictionGame is defined elsewhere in this notebook).
game=PredictionGame(surrogate=explainer.surrogate,
                    sample=dataset_explainer["test"][0]
                    )

In [None]:
# Mini-batch estimator for output index 8: 5000 sampled coalitions, gradients
# averaged over batches of 32, step 0.1 decayed as 1/sqrt(t).
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd_minibatch(game,
            n_iter=5000,
            batch_size=32,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
# Plot the latest `sgd_shapley_old_output` for test sample 0 (vector tiled
# into 10 identical columns, stored under key 0).
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
# Repeat of the batch_size=32 mini-batch run with identical settings.
# NOTE(review): no random seed is set in this cell, so the result differs from
# the earlier run; presumably a variance check -- confirm.
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd_minibatch(game,
            n_iter=5000,
            batch_size=32,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
# Plot the latest `sgd_shapley_old_output` for test sample 0 (vector tiled
# into 10 identical columns, stored under key 0).
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
# Mini-batch estimator for output index 8 with a larger batch: 5000 sampled
# coalitions, batches of 64, step 0.1 decayed as 1/sqrt(t).
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd_minibatch(game,
            n_iter=5000,
            batch_size=64,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
# Plot the latest `sgd_shapley_old_output` for test sample 0 (vector tiled
# into 10 identical columns, stored under key 0).
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
# Repeat of the batch_size=64 mini-batch run with identical settings
# (unseeded, so the result differs between runs).
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd_minibatch(game,
            n_iter=5000,
            batch_size=64,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
# Plot the latest `sgd_shapley_old_output` for test sample 0 (vector tiled
# into 10 identical columns, stored under key 0).
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
# Mini-batch estimator for output index 8 with batches of 512.
# With n_iter=5000 only 9 full batches fit (9 * 512 = 4608); sgd_minibatch as
# defined in this notebook updates only when t is a multiple of batch_size,
# so the remaining 392 sampled coalitions do not contribute to the result.
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd_minibatch(game,
            n_iter=5000,
            batch_size=512,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
# Sum of the estimated values; every iterate is shifted to sum to
# game(all ones) - game(all zeros) for the selected output, so this should
# match that difference.
sgd_shapley_old_output.sum()

In [None]:
# Plot the latest `sgd_shapley_old_output` for test sample 0 (vector tiled
# into 10 identical columns, stored under key 0).
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
# Display the estimated values (one per feature).
sgd_shapley_old_output

In [None]:
# Game value of the all-ones mask minus the all-zeros mask at output index 8:
# the total that the SGDshapley projection step forces the estimates to sum to.
# Note the masks here are float arrays, whereas SGDshapley passes dtype=int.
game(np.ones((1,196)))[0][8]-game(np.zeros((1,196)))[0][8]

In [None]:
# Sum of the estimated values, to compare against the all-ones minus all-zeros
# game difference for output index 8.
sgd_shapley_old_output.sum()

In [None]:
# Long run of the one-coalition-per-step estimator for output index 8:
# 200000 iterations, each of which calls `game` once.
sgd_shapley_old=SGDshapley(d=196, C=1)

sgd_shapley_old_output=sgd_shapley_old.sgd(game,
            n_iter=200000,
            dimension_select=8,
            step=0.1,
            step_type="sqrt")

In [None]:
# Display the estimated values from the 200000-iteration run.
sgd_shapley_old_output

In [None]:
# Inspect the array handed to the plotting helper: the estimate as a column,
# repeated 10 times along the second axis.
np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))

In [None]:
# Element-wise check that column 0 of the tiled array equals the original
# estimate (yields a boolean array, one entry per feature).
sgd_shapley_old_output==np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))[:,0]

In [None]:
# Display the estimated values again.
sgd_shapley_old_output

In [None]:
# Plot the latest `sgd_shapley_old_output` for test sample 0 (vector tiled
# into 10 identical columns, stored under key 0).
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [np.tile(sgd_shapley_old_output.reshape(-1,1), (1,10))],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
        # NOTE(review): leftover debugging fragment copied out of a method body.
        # The lines are indented and reference `self`, so running this cell
        # as-is at top level fails (unexpected indent), and it would drop into
        # the ipdb debugger at the set_trace() call.
        # TODO: delete this cell, or dedent it and replace `self.d`.
        f_x = game(np.ones((1, self.d), dtype=bool))[0]
        f_r = game(np.zeros((1, self.d), dtype=bool))[0]
        import ipdb
        ipdb.set_trace()
        
        v_M = f_x - f_r


In [None]:
# Rebuild the cooperative game for test sample 0 from the explainer's
# surrogate model (same construction as the earlier PredictionGame cell).
game=PredictionGame(surrogate=explainer.surrogate,
                    sample=dataset_explainer["test"][0]
                    )

In [None]:
# Call ShapleySampling (defined elsewhere in this notebook) on `game` with
# batches of 32 and 32*32 = 1024 samples, convergence detection switched off.
# NOTE(review): the name `shapley_sampling` holds the call's return value,
# not the function -- confirm what `return_all=True` makes it return.
shapley_sampling=ShapleySampling(game,
                    batch_size=32,
                    n_samples=32*32,
                    detect_convergence=False,
                    thresh=0.01,
                    antithetical=False,
                    return_all=True,
                    bar=True,
                    verbose=False)

In [None]:
# Run SGDShapleyNew on `game` with 196 features for 200000 iterations.
# NOTE(review): this notebook defines SGDShapleyNew in several cells with
# different return values (final average vs. running averages); which one is
# used here depends on the cell executed last -- confirm before relying on
# the shape of `sgd_shapley_output`.
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=200000,
            step=0.1,
            step_type="sqrt",
            phi_0=False)

In [None]:
# Edited by: Ian Covert and Chanwoo Kim

# Original authors: Simon Grah <simon.grah@thalesgroup.com>
#                   Vincent Thouvenot <vincent.thouvenot@thalesgroup.com>

# MIT License

# Copyright (c) 2020 Thales Six GTS France

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import operator as op
from functools import reduce
from tqdm import tqdm


def ncr(n, r):
    """
    Binomial coefficient: number of subsets of size r among n elements.

    Parameters
    ----------
    n : integer
        Size of the ground set.
    r : integer
        Size of the subsets to count.

    Returns
    -------
    float
        ``n`` choose ``r`` obtained by true division of two exact integer
        products, or 0.0 when ``r`` lies outside ``[0, n]``.
    """
    # No subset has a negative size or more elements than the ground set.
    # Without this guard both products below run over empty ranges and the
    # function would wrongly report 1.
    if r < 0 or r > n:
        return 0.0
    # Symmetry C(n, r) == C(n, n - r) keeps both products as short as possible.
    r = min(r, n-r)
    numer = reduce(op.mul, range(n, n-r, -1), 1)
    denom = reduce(op.mul, range(1, r+1), 1)
    # True division is kept so the return type (float) is unchanged.
    return numer / denom


class SGDShapleyNew():
    """
    Estimate the Shapley Values using a Projected Stochastic Gradient algorithm.

    Batched variant: all coalitions are sampled up front, the game is
    evaluated on them in batches, and the SGD updates are then replayed on the
    stored values. ``sgd`` returns the running average of the iterates after
    every step rather than only the final average.
    """

    # Step-size schedules understood by ``sgd``.
    _STEP_TYPES = ("constant", "sqrt", "inverse")

    def __init__(self, d, C):
        """
        Calculate internal values for later purposes
        Those elements depend only on the number of features d

        Parameters
        ----------
        d : integer
            Dimension of the problem. The number of features
        C : float
            Constant bounding |y|
        """

        # Per coalition size k: weight w_k and L-smooth constant L_k.
        dict_w_k = dict()
        dict_L_k = dict()
        D = C * np.sqrt(d)
        for k in range(1, d):
            w_k = (d - 1) / (ncr(d, k) * k * (d - k))
            L_k = w_k * np.sqrt(k) * (np.sqrt(k) * D + C)
            dict_w_k[k] = w_k
            dict_L_k[k] = L_k

        # Summation of L over every coalition (closed formula): ncr(d, k) * L_k
        # simplifies so that the binomial coefficient cancels.
        sum_L = np.sum([(d-1)/(np.sqrt(k)*(d-k)) * (np.sqrt(k)*D + C) for k in range(1, d)])

        # Probability distributions over the coalition size k = 1 .. d-1.

        # 1. Classic SGD (not used): proportional to the number of coalitions.
        p = np.array([ncr(d, k) for k in range(1, d)])
        p = p / np.sum(p)

        # 2. Importance Sampling proposal q (used): proportional to L_k * p_k.
        q = np.array(list(dict_L_k.values())) * p
        q = q / np.sum(q)

        # Save internal attributes
        self.d = d
        self.dict_w_k = dict_w_k
        self.dict_L_k = dict_L_k
        self.sum_L = sum_L
        self.p = p
        self.q = q

    def _grad_F_i(self, phi, x_i, y_i, w_i):
        """
        Gradient vector per instance i.

        This is the gradient with respect to ``phi`` of
        ``0.5 * w_i * (x_i . phi - y_i) ** 2``.
        """
        residual = x_i.dot(phi) - y_i
        if isinstance(y_i, np.ndarray):
            # One residual per output: outer product of indicator and residuals.
            return w_i * x_i[:, np.newaxis] * residual[np.newaxis]
        return w_i * x_i * residual

    def sgd(self,
            game,
            n_iter=100,
            step=0.1,
            step_type="sqrt",
            phi_0=False,
            batch_size=128):
        """
        Stochastic gradient descent algorithm

        Parameters
        ----------
        game : callable
            Called with an integer 0/1 array of shape (batch, d) (boolean for
            the two reference calls); must return one value, or one row of
            values, per mask.
        n_iter : integer
            Number of sampled coalitions (one update each).
        step : float
            Base step size.
        step_type : str
            "constant", "sqrt" (step / sqrt(t)) or "inverse" (step / t).
        phi_0 : ndarray, None or False
            Optional starting point; ``None``/``False`` start from zeros.
        batch_size : integer
            Number of coalitions passed to ``game`` per call.

        Returns
        -------
        ndarray
            Running averages of the projected iterates: entry ``t - 1`` is the
            mean of the first ``t`` iterates. Shape (n_iter, d) for scalar
            games and (n_iter, d, out_dim) otherwise. Note that the full
            iterate history is kept in memory.

        Raises
        ------
        ValueError
            If ``step_type`` is not one of the supported schedules.
        """
        # An unknown schedule used to be ignored silently, so no gradient step
        # was ever applied; fail loudly instead.
        if step_type not in self._STEP_TYPES:
            raise ValueError(
                "step_type must be one of %r, got %r" % (self._STEP_TYPES, step_type))

        # Values of the full and of the empty coalition.
        grand = game(np.ones((1, self.d), dtype=bool))[0]
        null = game(np.zeros((1, self.d), dtype=bool))[0]
        out_dim = len(grand) if isinstance(grand, np.ndarray) else None
        total = grand - null

        d = self.d
        dict_w_k = self.dict_w_k
        dict_L_k = self.dict_L_k
        sum_L = self.sum_L

        # initialize Shapley value estimates
        # (identity checks: ``if phi_0:`` raises for multi-element arrays)
        if phi_0 is None or phi_0 is False:
            phi = np.zeros(d) if out_dim is None else np.zeros((d, out_dim))
        else:
            phi = np.array(phi_0, dtype=float)  # copy, caller's array untouched

        # projection step: make the entries sum to ``total``
        phi = phi - (np.sum(phi, axis=0) - total) / d

        # store for iterate averaging
        if out_dim is None:
            phi_iterates = np.zeros((n_iter, d))
        else:
            phi_iterates = np.zeros((n_iter, d, out_dim))

        # sample coalition sizes
        list_k = np.random.choice(list(range(1, d)), size=n_iter, p=self.q)

        # build all subset indicators up front (row t belongs to iteration t+1)
        x_i_record = np.zeros((n_iter, d))
        for t in tqdm(range(n_iter)):
            x_i_record[t, np.random.permutation(d)[:list_k[t]]] = 1

        # evaluate the game batch by batch
        y_i_record = []
        for start in tqdm(range(0, n_iter, batch_size)):
            y_i_record.append(game(x_i_record[start:start+batch_size].astype(int)) - null)
        # Concatenating along axis 0 works for (batch,) and (batch, out_dim)
        # outputs alike; stacking rows would garble the 1-D case.
        y_i_record = np.concatenate(y_i_record, axis=0)

        for t in tqdm(range(1, n_iter+1)):
            k = list_k[t-1]
            x_i = x_i_record[t-1]
            y_i = y_i_record[t-1]

            # importance-sampling correction: divide by the probability of
            # drawing this particular coalition
            p_i = dict_L_k[k] / sum_L
            grad_i = 1/(p_i) * self._grad_F_i(phi, x_i, y_i, dict_w_k[k])

            # update phi
            if step_type == "constant":
                rate = step
            elif step_type == "sqrt":
                rate = step/np.sqrt(t)
            else:  # "inverse"
                rate = step/(t)
            phi = phi - rate * grad_i

            # projection step
            phi = phi - (phi.sum(axis=0) - total) / d

            # update iterate history
            phi_iterates[t-1] = phi

        # Running average of the iterates. The divisor gets one trailing
        # singleton axis per non-leading axis of the history so that it
        # broadcasts correctly for both 2-D and 3-D histories.
        counts = np.arange(1, n_iter + 1).reshape((-1,) + (1,) * (phi_iterates.ndim - 1))
        return np.cumsum(phi_iterates, axis=0) / counts


In [None]:
# Edited by: Ian Covert and Chanwoo Kim

# Original authors: Simon Grah <simon.grah@thalesgroup.com>
#                   Vincent Thouvenot <vincent.thouvenot@thalesgroup.com>

# MIT License

# Copyright (c) 2020 Thales Six GTS France

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import operator as op
from functools import reduce
from tqdm import tqdm


def ncr(n, r):
    """
    Binomial coefficient: number of subsets of size r among n elements.

    Parameters
    ----------
    n : integer
        Size of the ground set.
    r : integer
        Size of the subsets to count.

    Returns
    -------
    float
        ``n`` choose ``r`` obtained by true division of two exact integer
        products, or 0.0 when ``r`` lies outside ``[0, n]``.
    """
    # No subset has a negative size or more elements than the ground set.
    # Without this guard both products below run over empty ranges and the
    # function would wrongly report 1.
    if r < 0 or r > n:
        return 0.0
    # Symmetry C(n, r) == C(n, n - r) keeps both products as short as possible.
    r = min(r, n-r)
    numer = reduce(op.mul, range(n, n-r, -1), 1)
    denom = reduce(op.mul, range(1, r+1), 1)
    # True division is kept so the return type (float) is unchanged.
    return numer / denom


class SGDShapleyNew():
    """
    Estimate the Shapley Values using a Projected Stochastic Gradient algorithm.

    Experimental batched variant: coalition sizes are drawn from ``p`` (not
    from the importance-sampling proposal ``q``), the gradient is not rescaled
    by the sampling probability, and the multi-output gradient omits the
    per-size weight. ``sgd`` returns the running average of the iterates
    after every step.
    """

    # Step-size schedules understood by ``sgd``.
    _STEP_TYPES = ("constant", "sqrt", "inverse")

    def __init__(self, d, C):
        """
        Calculate internal values for later purposes
        Those elements depend only on the number of features d

        Parameters
        ----------
        d : integer
            Dimension of the problem. The number of features
        C : float
            Constant bounding |y|
        """

        # Per coalition size k: weight w_k and L-smooth constant L_k.
        dict_w_k = dict()
        dict_L_k = dict()
        D = C * np.sqrt(d)
        for k in range(1, d):
            w_k = (d - 1) / (ncr(d, k) * k * (d - k))
            L_k = w_k * np.sqrt(k) * (np.sqrt(k) * D + C)
            dict_w_k[k] = w_k
            dict_L_k[k] = L_k

        # Summation of L over every coalition (closed formula): ncr(d, k) * L_k
        # simplifies so that the binomial coefficient cancels.
        sum_L = np.sum([(d-1)/(np.sqrt(k)*(d-k)) * (np.sqrt(k)*D + C) for k in range(1, d)])

        # Probability distributions over the coalition size k = 1 .. d-1.

        # 1. Proportional to the number of coalitions of each size
        #    (the distribution ``sgd`` samples from in this variant).
        p = np.array([ncr(d, k) for k in range(1, d)])
        p = p / np.sum(p)

        # 2. Importance Sampling proposal q: proportional to L_k * p_k
        #    (stored, but not used by ``sgd`` in this variant).
        q = np.array(list(dict_L_k.values())) * p
        q = q / np.sum(q)

        # Save internal attributes
        self.d = d
        self.dict_w_k = dict_w_k
        self.dict_L_k = dict_L_k
        self.sum_L = sum_L
        self.p = p
        self.q = q

    def _grad_F_i(self, phi, x_i, y_i, w_i):
        """
        Gradient vector per instance i.

        For array-valued ``y_i`` the weight ``w_i`` is deliberately left out
        (unweighted residual gradient); for scalar ``y_i`` it is applied.
        """
        residual = x_i.dot(phi) - y_i
        if isinstance(y_i, np.ndarray):
            # Unweighted on purpose in this experimental variant.
            return x_i[:, np.newaxis] * residual[np.newaxis]
        # NOTE(review): the scalar branch still multiplies by w_i while the
        # array branch does not -- confirm which weighting is intended.
        return w_i * x_i * residual

    def sgd(self,
            game,
            n_iter=100,
            step=0.1,
            step_type="sqrt",
            phi_0=False,
            batch_size=128):
        """
        Stochastic gradient descent algorithm

        Parameters
        ----------
        game : callable
            Called with an integer 0/1 array of shape (batch, d) (boolean for
            the two reference calls); must return one value, or one row of
            values, per mask.
        n_iter : integer
            Number of sampled coalitions (one update each).
        step : float
            Base step size.
        step_type : str
            "constant", "sqrt" (step / sqrt(t)) or "inverse" (step / t).
        phi_0 : ndarray, None or False
            Optional starting point; ``None``/``False`` start from zeros.
        batch_size : integer
            Number of coalitions passed to ``game`` per call.

        Returns
        -------
        ndarray
            Running averages of the projected iterates: entry ``t - 1`` is the
            mean of the first ``t`` iterates. Shape (n_iter, d) for scalar
            games and (n_iter, d, out_dim) otherwise. Note that the full
            iterate history is kept in memory.

        Raises
        ------
        ValueError
            If ``step_type`` is not one of the supported schedules.
        """
        # An unknown schedule used to be ignored silently, so no gradient step
        # was ever applied; fail loudly instead.
        if step_type not in self._STEP_TYPES:
            raise ValueError(
                "step_type must be one of %r, got %r" % (self._STEP_TYPES, step_type))

        # Values of the full and of the empty coalition.
        grand = game(np.ones((1, self.d), dtype=bool))[0]
        null = game(np.zeros((1, self.d), dtype=bool))[0]
        out_dim = len(grand) if isinstance(grand, np.ndarray) else None
        total = grand - null

        d = self.d
        dict_w_k = self.dict_w_k

        # initialize Shapley value estimates
        # (identity checks: ``if phi_0:`` raises for multi-element arrays)
        if phi_0 is None or phi_0 is False:
            phi = np.zeros(d) if out_dim is None else np.zeros((d, out_dim))
        else:
            phi = np.array(phi_0, dtype=float)  # copy, caller's array untouched

        # projection step: make the entries sum to ``total``
        phi = phi - (np.sum(phi, axis=0) - total) / d

        # store for iterate averaging
        if out_dim is None:
            phi_iterates = np.zeros((n_iter, d))
        else:
            phi_iterates = np.zeros((n_iter, d, out_dim))

        # sample coalition sizes from p (not from the proposal q)
        list_k = np.random.choice(list(range(1, d)), size=n_iter, p=self.p)

        # build all subset indicators up front (row t belongs to iteration t+1)
        x_i_record = np.zeros((n_iter, d))
        for t in tqdm(range(n_iter)):
            x_i_record[t, np.random.permutation(d)[:list_k[t]]] = 1

        # evaluate the game batch by batch
        y_i_record = []
        for start in tqdm(range(0, n_iter, batch_size)):
            y_i_record.append(game(x_i_record[start:start+batch_size].astype(int)) - null)
        # Concatenating along axis 0 works for (batch,) and (batch, out_dim)
        # outputs alike; stacking rows would garble the 1-D case.
        y_i_record = np.concatenate(y_i_record, axis=0)

        for t in tqdm(range(1, n_iter+1)):
            k = list_k[t-1]
            x_i = x_i_record[t-1]
            y_i = y_i_record[t-1]

            # No 1 / p_i importance rescaling in this variant.
            grad_i = self._grad_F_i(phi, x_i, y_i, dict_w_k[k])

            # update phi
            if step_type == "constant":
                rate = step
            elif step_type == "sqrt":
                rate = step/np.sqrt(t)
            else:  # "inverse"
                rate = step/(t)
            phi = phi - rate * grad_i

            # projection step
            phi = phi - (phi.sum(axis=0) - total) / d

            # update iterate history
            phi_iterates[t-1] = phi

        # Running average of the iterates. The divisor gets one trailing
        # singleton axis per non-leading axis of the history so that it
        # broadcasts correctly for both 2-D and 3-D histories.
        counts = np.arange(1, n_iter + 1).reshape((-1,) + (1,) * (phi_iterates.ndim - 1))
        return np.cumsum(phi_iterates, axis=0) / counts


In [None]:
# Edited by: Ian Covert and Chanwoo Kim

# Original authors: Simon Grah <simon.grah@thalesgroup.com>
#                   Vincent Thouvenot <vincent.thouvenot@thalesgroup.com>

# MIT License

# Copyright (c) 2020 Thales Six GTS France

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
import operator as op
from functools import reduce
from tqdm import tqdm


def ncr(n, r):
    """
    Combinatorial computation: number of subsets of size r among n elements.

    Parameters
    ----------
    n : integer
        Size of the ground set.
    r : integer
        Size of the subsets to count.

    Returns
    -------
    float
        The binomial coefficient "n choose r". The value is returned as a
        float (true division) because callers normalise lists of these values
        with numpy. It is 0.0 when ``r`` is negative or larger than ``n``,
        since no such subset exists.
    """
    # Without this guard both products below run over empty ranges and the
    # function would wrongly report one subset for an impossible size.
    if r < 0 or r > n:
        return 0.0
    # Use the symmetry C(n, r) == C(n, n - r) to keep the products short.
    r = min(r, n-r)
    numer = reduce(op.mul, range(n, n-r, -1), 1)
    denom = reduce(op.mul, range(1, r+1), 1)
    return numer / denom


class SGDShapleyNew():
    """
    Estimate the Shapley Values using a Projected Stochastic Gradient algorithm.

    Coalition sizes are drawn from an importance-sampling proposal built from
    per-size smoothness constants, and after every update the estimate is
    projected so that its entries sum to ``game(all on) - game(all off)``.
    """

    # Step-size schedules understood by `sgd`.
    _STEP_TYPES = ("constant", "sqrt", "inverse")

    def __init__(self, d, C):
        """
        Calculate internal values for later purposes
        Those elements depend only on the number of features d

        Parameters
        ----------
        d : integer
            Dimension of the problem. The number of features (must be >= 2)
        C : float
            Constant bounding |y|

        Raises
        ------
        ValueError
            If ``d < 2``: there is then no coalition size in ``1..d-1`` to
            sample from.
        """
        if d < 2:
            raise ValueError(f"d must be at least 2, got {d}")

        # Store in a dictionary for each size k of coalitions
        dict_w_k = dict()  # weights per size k
        dict_L_k = dict()  # L-smooth constant per size k
        D = C * np.sqrt(d)
        for k in range(1, d):
            w_k = (d - 1) / (ncr(d, k) * k * (d - k))
            L_k = w_k * np.sqrt(k) * (np.sqrt(k) * D + C)
            dict_w_k[k] = w_k
            dict_L_k[k] = L_k

        # Summation of all L per coalition (closed formula): this equals
        # sum_k ncr(d, k) * L_k with the binomial factor cancelled analytically.
        sum_L = np.sum([(d-1)/(np.sqrt(k)*(d-k)) * (np.sqrt(k)*D + C) for k in range(1, d)])

        # Probability distributions for sampling new instance

        # 1. Classic SGD (not used): coalition size k in proportion to the
        #    number of coalitions of that size.
        p = np.array([ncr(d, k) for k in range(1, d)], dtype=float)
        p /= np.sum(p)

        # 2. Importance Sampling proposal q (used): size k in proportion to
        #    ncr(d, k) * L_k.
        q = np.array(list(dict_L_k.values())) * p
        q /= np.sum(q)

        # Save internal attributes
        self.d = d
        self.dict_w_k = dict_w_k
        self.dict_L_k = dict_L_k
        self.sum_L = sum_L
        self.p = p
        self.q = q

    def _grad_F_i(self, phi, x_i, y_i, w_i):
        """
        Gradient vector per instance i.

        ``x_i`` is the 0/1 coalition indicator of length d, ``y_i`` the game
        value of that coalition (relative to the empty coalition) and ``w_i``
        the weight of the coalition size. When ``y_i`` is an array the result
        has one column per output, otherwise it is a vector of length d.
        """
        if isinstance(y_i, np.ndarray):
            # Vector-valued game: outer product of the indicator and the residuals.
            res = w_i * x_i[:, np.newaxis] * (x_i.dot(phi) - y_i)[np.newaxis]
        else:
            # Scalar-valued game.
            res = w_i * x_i * (x_i.dot(phi) - y_i)
        return res

    def sgd(self,
            game,
            n_iter=100,
            step=0.1,
            step_type="sqrt",
            phi_0=False):
        """
        Stochastic gradient descent algorithm

        Parameters
        ----------
        game : callable
            Called with a boolean array of shape ``(batch, d)``; element ``[0]``
            of the result is used as the value of the first coalition. That
            value may be a scalar or a 1-D numpy array (one entry per output).
        n_iter : integer
            Number of SGD iterations (one game evaluation each, plus two
            evaluations for the full and the empty coalition).
        step : float
            Base step size.
        step_type : {"constant", "sqrt", "inverse"}
            Schedule: ``step``, ``step / sqrt(t)`` or ``step / t`` at iteration t.
        phi_0 : array-like, None or False
            Optional starting point. ``False``/``None`` starts from zeros.

        Returns
        -------
        numpy.ndarray
            Running averages of the projected iterates: entry ``[t-1]`` is the
            mean of the first ``t`` iterates. Shape ``(n_iter, d)`` for a
            scalar-valued game and ``(n_iter, d, out_dim)`` otherwise.

        Raises
        ------
        ValueError
            If ``step_type`` is not one of the supported schedules.
        """
        # An unknown schedule would otherwise silently skip every update.
        if step_type not in self._STEP_TYPES:
            raise ValueError(
                f"Unknown step_type {step_type!r}; expected one of {self._STEP_TYPES}"
            )

        # Get general information
        grand = game(np.ones((1, self.d), dtype=bool))[0]
        null = game(np.zeros((1, self.d), dtype=bool))[0]
        out_dim = len(grand) if isinstance(grand, np.ndarray) else None
        total = grand - null

        d = self.d
        dict_w_k = self.dict_w_k
        q = self.q
        dict_L_k = self.dict_L_k
        sum_L = self.sum_L

        # initialize Shapley value estimates
        # (identity checks: truth-testing a numpy array raises ValueError)
        if phi_0 is None or phi_0 is False:
            phi = np.zeros(d) if out_dim is None else np.zeros((d, out_dim))
        else:
            phi = np.array(phi_0, dtype=float)  # copies, caller's array is untouched

        # projection step: shift so that the estimates sum to `total`
        phi = phi - (np.sum(phi, axis=0) - total) / d

        # store for iterate averaging
        if out_dim is None:
            phi_iterates = np.zeros((n_iter, d))
        else:
            phi_iterates = np.zeros((n_iter, d, out_dim))

        # sample coalition sizes
        list_k = np.random.choice(list(range(1, d)), size=n_iter, p=q)

        for t in tqdm(range(1, n_iter+1)):
            # build subset indicator x_i: k features chosen uniformly at random
            k = list_k[t-1]
            indexes = np.random.permutation(d)[:k]
            x_i = np.zeros(d)
            x_i[indexes] = 1

            # Compute y_i (value of the coalition relative to the empty one)
            y_i = game(x_i.astype(bool)[np.newaxis])[0] - null

            # get weight w_i for importance sampling
            w_i = dict_w_k[k]

            # calculate gradient, reweighted by 1 / p_i where
            # p_i = L_k / sum_L is the sampling probability of this coalition
            p_i = dict_L_k[k] / sum_L
            grad_i = 1/(p_i) * self._grad_F_i(phi, x_i, y_i, w_i)

            # update phi
            if step_type == "constant":
                lr = step
            elif step_type == "sqrt":
                lr = step/np.sqrt(t)
            else:  # "inverse"
                lr = step/(t)
            phi = phi - lr * grad_i

            # projection step
            phi = phi - (phi.sum(axis=0) - total) / d

            # update iterate history
            phi_iterates[t-1] = phi

        # Average iterates. The divisor gets one trailing singleton axis per
        # non-iteration axis, so it broadcasts correctly for both the scalar
        # (n_iter, d) and the vector (n_iter, d, out_dim) case.
        counts = np.arange(1, n_iter + 1).reshape((-1,) + (1,) * (phi_iterates.ndim - 1))
        return np.cumsum(phi_iterates, axis=0) / counts


In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=500000,
            step=0.1,
            step_type="sqrt",
            phi_0=False)

In [None]:
1/np.sqrt(50000)

In [None]:
            # print('y is an array')
            res = w_i * x_i[:, np.newaxis] * (x_i.dot(phi) - y_i)[np.newaxis]
            import ipdb
            ipdb.set_trace()
            res = x_i[:, np.newaxis] * (x_i.dot(phi) - y_i)[np.newaxis]
        else:
            # print('y is a scalar')
            res = w_i * x_i * (x_i.dot(phi) - y_i)

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=0.01,
            step_type="sqrt",
            phi_0=False)

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=10,
            step_type="sqrt",
            phi_0=False)

In [None]:
for subset in [50,500,5000,50000]:
    fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                        shapley_value={0:{
                'values': [sgd_shapley_output[subset-1]],
                'std': [],
                'iters': [0]}}, shapley_value_key=int(0))
    fig.suptitle(str(subset))

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=1,
            step_type="sqrt",
            phi_0=False)

In [None]:
for subset in [50,500,5000,50000]:
    fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                        shapley_value={0:{
                'values': [sgd_shapley_output[subset-1]],
                'std': [],
                'iters': [0]}}, shapley_value_key=int(0))
    fig.suptitle(str(subset))

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=0.1,
            step_type="sqrt",
            phi_0=False)

In [None]:
for subset in [50,500,5000,50000]:
    fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                        shapley_value={0:{
                'values': [sgd_shapley_output[subset-1]],
                'std': [],
                'iters': [0]}}, shapley_value_key=int(0))
    fig.suptitle(str(subset))

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=0.01,
            step_type="sqrt",
            phi_0=False)

In [None]:
sgd_shapley_output[-1]

In [None]:
for subset in [50,500,5000,50000]:
    fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                        shapley_value={0:{
                'values': [sgd_shapley_output[subset-1]],
                'std': [],
                'iters': [0]}}, shapley_value_key=int(0))
    fig.suptitle(str(subset))

In [None]:
 
# grad_i = 1/(p_i) * self._grad_F_i(phi, x_i, y_i, w_i)
grad_i = self._grad_F_i(phi, x_i, y_i, w_i)

# res = w_i * x_i[:, np.newaxis] * (x_i.dot(phi) - y_i)[np.newaxis]
res = x_i[:, np.newaxis] * (x_i.dot(phi) - y_i)[np.newaxis]

In [None]:
for subset in [50,500,5000,50000]:
    fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                        shapley_value={0:{
                'values': [sgd_shapley_output[subset-1]],
                'std': [],
                'iters': [0]}}, shapley_value_key=int(0))
    fig.suptitle(str(subset))

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=0.1,
            step_type="sqrt",
            phi_0=False)

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output[50000-1]],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=5000,
            step=0.1,
            step_type="sqrt",
            phi_0=False)

In [None]:
sgd_shapley_output.shape

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output[10000-1]],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output[5000].mean(axis=0)],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
sgd_shapley_output.shape

In [None]:
sgd_shapley_output__[:500].mean(axis=0)

In [None]:
(np.array([sgd_shapley_output__[:i+1].mean(axis=0) for i in range(len(sgd_shapley_output__[:100]))])==\
np.cumsum(sgd_shapley_output__[:100], axis=0)/(np.arange(len(sgd_shapley_output__[:100]))+1).reshape(-1,1,1)).all()

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output__[:50].mean(axis=0)],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output__[:500].mean(axis=0)],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output__[:50].mean(axis=0)],
            'std': [],
            'iters': [0]}}, shapley_value_key=int(0))

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=100,
            step=0.1,
            step_type="inverse",
            phi_0=False)

In [None]:
game=PredictionGame(surrogate=explainer.surrogate,
                    sample=dataset_explainer["test"][0]
                    )

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=50000,
            step=0.1,
            step_type="inverse",
            phi_0=False)

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output=sgd_shapley.sgd(game,
            n_iter=200000,
            step=0.01,
            step_type="inverse",
            phi_0=False)

In [None]:
sgd_shapley=SGDShapleyNew(d=196, C=1)

sgd_shapley_output_=sgd_shapley.sgd(game,
            n_iter=200000,
            step=0.1,
            step_type="inverse",
            phi_0=False)

In [None]:
shapley_sampling=ShapleySampling(game,
                    batch_size=128,
                    n_samples=8*128,
                    detect_convergence=False,
                    thresh=0.01,
                    antithetical=False,
                    return_all=True,
                    bar=True,
                    verbose=False)

In [None]:
4*128*196

In [None]:
sgd_shapley_output.shape

In [None]:
shapley_loaded[0]['values'][0].shape

In [None]:
len(shapley_sampling["values"])

In [None]:
sgd_shapley_output.shape

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': sgd_shapley_output,
            'std': [],
            'iters': list(range(1, len(sgd_shapley_output)+1))}}, shapley_value_key=int(200000))

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output],
            'std': [],
            'iters': [50000]}}, shapley_value_key=int(50000))

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0:{
            'values': [sgd_shapley_output_],
            'std': [],
            'iters': [50000]}}, shapley_value_key=int(50000))

In [None]:
shapley_sampling["iters"]

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value={0: shapley_sampling}, shapley_value_key=int(50176))

In [None]:
shapley_loaded[0]["iters"]

In [None]:
fig=plot_figure_shapley(dataset_explainer["test"], sample_idx_list=[0], 
                    shapley_value=shapley_loaded, shapley_value_key=int(3332))

In [None]:
fig=plot_figure_shapley(dataset_explainer["train"], sample_idx_list=[0, 1, 2, 3, 4], 
                    shapley_value={0:{
            'values': [sgd_shapley_output],
            'std': [],
            'iters': [50000]}}, shapley_value_key=int(50000))

In [None]:
fig.savefig("aaaa.png")

In [None]:
!pwd

In [None]:
        tracking_dict = 

In [None]:
plot_figure_shapley(dataset_explainer["train"], sample_idx_list=[0, 1, 2, 3, 4], 
                    shapley_value={0:shapley_sampling}, shapley_value_key=int(100352))

In [None]:
plot_figure_shapley(dataset_explainer["train"], sample_idx_list=[0, 1, 2, 3, 4], 
                    shapley_value={0:shapley_sampling}, shapley_value_key=int(100352))

In [None]:
4*196*128  # = 100352 (~100k, not 10k)

In [None]:
128*196

In [None]:
shapley_sampling.keys()

In [None]:
shapley_sampling["values"][0]

In [None]:
shapley_sampling["values"][0].sum(axis=0)

In [None]:
shapley_sampling["iters"]

In [None]:
plot_figure_shapley(dataset_explainer["train"], sample_idx_list=[0, 1, 2, 3, 4], 
                    shapley_value={0:shapley_sampling}, shapley_value_key=int(25088))

In [None]:
plot_figure_shapley(dataset_explainer["train"], sample_idx_list=[0, 1, 2, 3, 4], 
                    shapley_value={0:shapley_sampling}, shapley_value_key=int(100352))

In [None]:
plot_figure_shapley?

In [None]:
int(np.ceil(100 / 1))

In [None]:
shapley_sampling.keys()

In [None]:
shapley_sampling["values"]

In [None]:
shapley_sampling["iters"]

In [None]:
!gpustat

In [None]:
shapley_loaded[0]["iters"]

In [None]:
10000*0.2/60

In [None]:
plot_figure_shapley(dataset_explainer["train"], [0, 1, 2, 3, 4], 
                    shapley_loaded, int(320068))

In [None]:
plot_figure_shapley(dataset_explainer["train"], [0, 1, 2, 3, 4], 
                    shapley_loaded, int(32144))

In [None]:
plot_figure_shapley(dataset_explainer["train"], [0, 1, 2, 3, 4], 
                    shapley_loaded, int(32144))

In [None]:
plot_figure_shapley(dataset_explainer["train"], [0, 1, 2, 3, 4], 
                    shapley_loaded, int(32144))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(512))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(3072))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(1536))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(2048))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(3072))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(5120))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_loaded, int(100352))

In [None]:
shapley_loaded=load_shapley("logs/vitbase_imagenette_surrogate_eval_test_permutation/extract_output/test/")

In [None]:
shapley_loaded=load_shapley("logs/vitbase_imagenette_surrogate_eval_test/extract_output/test/")

In [None]:
shapley_loaded[40]

In [None]:
shapley_loaded[0]["iters"]

In [None]:
196*17

In [None]:
shapley_loaded[0]["iters"]

In [None]:
shapley_loaded[0]["iters"]

In [None]:
200116/196

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 31], 
                    shapley_loaded, int(200116))

In [None]:
plot_figure_shapley(dataset_explainer["test"], [0,  10, 20, 30, 31], 
                    shapley_loaded, int(3332))

In [None]:
1036/148

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_regexplainer_permutation_upfront_196/checkpoint-888/pytorch_model.bin", map_location="cpu")
regexplainer.load_state_dict(state_dict)
plot_figure(regexplainer, dataset_explainer["test"],  [0,  10, 20, 30, 31])
# epoch 6

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_regexplainer_permutation_upfront_3332/checkpoint-8732/pytorch_model.bin", map_location="cpu")
regexplainer.load_state_dict(state_dict)
plot_figure(regexplainer, dataset_explainer["test"],  [0,  10, 20, 30, 31])
#19

In [None]:
32*5

In [None]:
148*5

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_objexplainer_newsample_32/checkpoint-740/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)
plot_figure(explainer, dataset_explainer["test"], [0,  10, 30, 40])

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_objexplainer_newsample/checkpoint-1480/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)
plot_figure(explainer, dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70])

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_objexplainer_newsample/checkpoint-14800/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)
plot_figure(explainer, dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70])

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_objexplainer_upfront_3200/checkpoint-14800/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)
plot_figure(explainer, dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70])

In [None]:
state_dict = torch.load("logs/vitbase_imagenette_objexplainer_upfront_3200/checkpoint-1480/pytorch_model.bin", map_location="cpu")
explainer.load_state_dict(state_dict)
plot_figure(explainer, dataset_explainer["test"], [0,  10, 20, 30, 40, 50, 60, 70])

In [None]:
dataset_explainer["test"][0]

In [None]:
plot_figure?

In [None]:
shapley_loaded[0]["iters"]

In [None]:
shapley_loaded[1]["values"][0].sum(axis=0)

In [None]:
shapley_loaded[1]["values"][-1].sum(axis=0)

In [None]:
dict(shapley_loaded[0], )

In [None]:
shapley_loaded[0]["iters"]

In [None]:
512*10

In [None]:
plot_figure_shapley(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_values_test["test"], 5120)

In [None]:
plot_figure_shapley(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_values_test["test"], 1536)

In [None]:
max(shapley_values_test["test"][0].keys())

In [None]:
99840+512

In [None]:
plot_figure_shapley(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_values_test["test"], 99840)

In [None]:
plot_figure_shapley(explainer, dataset["test_explainer"], [0,  10, 20, 30, 40, 50, 60, 70], 
                    shapley_values_test["test"], 1536)

In [None]:
plot_figure_shapley(explainer, dataset["validation_explainer"], [0,  250, 500,1000],
                    shapley_values["validation"], 1536)

In [None]:
plot_figure(explainer, dataset["test_explainer"], [0,  10,20,30])

In [None]:
plot_figure(explainer, dataset["test"], [0,  250, 500,1000])

In [None]:
shapley_values = torch.load(
    "logs/vitbase_imagenette_surrogate_eval/shapley_train_val.pt",
    map_location="cpu",
)

In [None]:
shapley_values_test = torch.load(
    "logs/vitbase_imagenette_surrogate_eval/shapley.pt",
    map_location="cpu",
)

In [None]:
shapley_values_test.keys()

In [None]:
########################################################
# Initialize the explainer trainer
########################################################
# Load the accuracy metric from the datasets package
metric = evaluate.load("accuracy")

# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
def compute_metrics(p):
    """Return an empty metrics dict; the accuracy computation is disabled.

    The argument ``p`` is accepted so the function can be passed as
    ``compute_metrics`` but it is not read. The original accuracy code is
    kept below, commented out, for reference.
    """
    # import ipdb

    # ipdb.set_trace()
    # print(p.predictions.shape, p.label_ids.shape)
    # return metric.compute(
    #     predictions=np.argmax(p.predictions[:, 0, :], axis=1),
    #     references=p.label_ids,
    # )
    return {}

def collate_fn(examples):
    """Merge a list of per-example dicts into a single batch dict.

    Each example must provide ``"pixel_values"`` (a tensor), ``"labels"`` and
    ``"masks"``. Pixel values are stacked along a new leading axis, labels are
    turned into one tensor, and masks are first gathered into a single numpy
    array and then converted to a tensor.
    """
    images, targets, mask_arrays = [], [], []
    for item in examples:
        images.append(item["pixel_values"])
        targets.append(item["labels"])
        mask_arrays.append(item["masks"])

    batch = {}
    batch["pixel_values"] = torch.stack(images)
    batch["labels"] = torch.tensor(targets)
    # Going through one numpy array avoids building a tensor from a list of arrays.
    batch["masks"] = torch.tensor(np.array(mask_arrays))
    return batch

# Build the Trainer for the explainer model. A split is only handed over when
# the matching do_train / do_eval flag is set; otherwise None is passed.
explainer_trainer = Trainer(
    model=explainer,
    args=training_args,
    train_dataset=dataset["train_explainer"] if training_args.do_train else None,
    eval_dataset=dataset["validation_explainer"] if training_args.do_eval else None,
    compute_metrics=compute_metrics,
    # NOTE(review): the image processor is passed through the `tokenizer`
    # argument, presumably so it is saved with the model - confirm against
    # the installed transformers version.
    tokenizer=explainer_image_processor,
    data_collator=collate_fn,
)

# ipdb.set_trace()
# print("explainer_trainer.label_names", explainer_trainer.label_names)
# print(explainer_trainer.evaluate(dataset["validation_explainer"]))

########################################################
# Detecting last checkpoint
#######################################################
# Only look for an existing checkpoint when we are about to train into an
# output directory that already exists and overwriting was not requested.
last_checkpoint = None
if (
    os.path.isdir(training_args.output_dir)
    and training_args.do_train
    and not training_args.overwrite_output_dir
):
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    # Non-empty directory without a checkpoint: refuse to continue instead of
    # mixing new outputs with whatever is already there.
    if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )
    # A checkpoint was found and no explicit resume path was given: log that
    # the detected checkpoint is the one training will resume from.
    elif (
        last_checkpoint is not None and training_args.resume_from_checkpoint is None
    ):
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )

########################################################
# Training
#######################################################
if training_args.do_train:
    # An explicitly requested resume path takes precedence over the
    # checkpoint detected in the output directory; None means no resume path.
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = explainer_trainer.train(resume_from_checkpoint=checkpoint)
    # Persist the model, then log and save the training metrics and trainer state.
    explainer_trainer.save_model()
    explainer_trainer.log_metrics("train", train_result.metrics)
    explainer_trainer.save_metrics("train", train_result.metrics)
    explainer_trainer.save_state()

########################################################
# Evaluation
#######################################################
if training_args.do_eval:
    # Evaluate on the eval_dataset given to the Trainer, then log and save
    # the returned metrics under the "eval" prefix.
    metrics = explainer_trainer.evaluate()
    explainer_trainer.log_metrics("eval", metrics)
    explainer_trainer.save_metrics("eval", metrics)

########################################################
# Write model card and (optionally) push to hub
#######################################################
# Metadata forwarded unchanged to either push_to_hub or create_model_card.
kwargs = {
    "finetuned_from": explainer_args.explainer_model_name_or_path,
    "tasks": "image-classification",
    "dataset": data_args.dataset_name,
    "tags": ["image-classification", "vision"],
}
# Push to the hub when requested; otherwise only create the model card.
if training_args.push_to_hub:
    explainer_trainer.push_to_hub(**kwargs)
else:
    explainer_trainer.create_model_card(**kwargs)

In [None]:
    
    
    ########################################################
    # Initialize the explainer trainer
    ########################################################
    # Load the accuracy metric from the datasets package
    metric = evaluate.load("accuracy")

    # Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p):
        """Computes accuracy on a batch of predictions"""
        # import ipdb

        # ipdb.set_trace()
        # print(p.predictions.shape, p.label_ids.shape)
        # return metric.compute(
        #     predictions=np.argmax(p.predictions[:, 0, :], axis=1),
        #     references=p.label_ids,
        # )
        return {}

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example["labels"] for example in examples])
        masks = torch.tensor(np.array([example["masks"] for example in examples]))

        return {
            "pixel_values": pixel_values,
            "labels": labels,
            "masks": masks,
        }

    explainer_trainer = Trainer(
        model=explainer,
        args=training_args,
        train_dataset=dataset["train_explainer"] if training_args.do_train else None,
        eval_dataset=dataset["validation_explainer"] if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=explainer_image_processor,
        data_collator=collate_fn,
    )

    # ipdb.set_trace()
    # print("explainer_trainer.label_names", explainer_trainer.label_names)
    # print(explainer_trainer.evaluate(dataset["validation_explainer"]))

    ########################################################
    # Detecting last checkpoint
    #######################################################
    last_checkpoint = None
    if (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif (
            last_checkpoint is not None and training_args.resume_from_checkpoint is None
        ):
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    ########################################################
    # Training
    #######################################################
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = explainer_trainer.train(resume_from_checkpoint=checkpoint)
        explainer_trainer.save_model()
        explainer_trainer.log_metrics("train", train_result.metrics)
        explainer_trainer.save_metrics("train", train_result.metrics)
        explainer_trainer.save_state()

    ########################################################
    # Evaluation
    #######################################################
    if training_args.do_eval:
        metrics = explainer_trainer.evaluate()
        explainer_trainer.log_metrics("eval", metrics)
        explainer_trainer.save_metrics("eval", metrics)

    ########################################################
    # Write model card and (optionally) push to hub
    #######################################################
    kwargs = {
        "finetuned_from": explainer_args.explainer_model_name_or_path,
        "tasks": "image-classification",
        "dataset": data_args.dataset_name,
        "tags": ["image-classification", "vision"],
    }
    if training_args.push_to_hub:
        explainer_trainer.push_to_hub(**kwargs)
    else:
        explainer_trainer.create_model_card(**kwargs)


if __name__ == "__main__":
    main()

In [None]:
dataset_original, labels, label2id, id2label = setup_dataset(
    data_args=data_args, other_args=other_args
)

In [None]:
setup_dataset??

In [None]:
data__test = load_dataset(
        data_args.dataset_name,
        data_args.dataset_config_name,
        cache_dir=data_args.dataset_cache_dir,
        task=None,
        token=other_args.token,
    )

In [None]:
data__test

In [None]:
dataset_original