## Packages: install and import

In [1]:
# Mount Google Drive; every model and explanation path used below lives under /content/drive.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
%%capture
!pip install transformers datasets captum

In [4]:
from collections import defaultdict
from copy import copy, deepcopy
import gc
import json
from pathlib import Path
from tqdm.notebook import trange, tqdm
from typing import Any, Callable, Dict, Iterable, List, Tuple


from captum.attr import DeepLift, InputXGradient, IntegratedGradients, NoiseTunnel, Saliency
from datasets import load_dataset
import numpy as np
from numpy.typing import NDArray
import torch
from torch.nn.functional import softmax as torch_softmax
from transformers import BertTokenizer, BertForSequenceClassification

## Prepare

### Load models

In [5]:
# Checkpoint directory on the mounted Drive; its `pytorch_model.bin` is loaded into `model` below.
model_path = Path("/content/drive/MyDrive/models_new/bert_imdb/full_trainer_20230912_custom_training_params_3/_final")

In [6]:
# Instantiate the architecture (two-label head) and tokenizer from `bert-base-cased`.
# The classifier head is newly initialised at this point (see the warning in the output);
# the weights are replaced by `load_state_dict` in the next cell.
model = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

In [7]:
# Overwrite the freshly initialised weights with the checkpoint stored at `model_path`.
# NOTE(review): `torch.load` is called without `map_location`; presumably the checkpoint
# is only ever loaded on a GPU runtime — confirm.
model.load_state_dict(torch.load(model_path / "pytorch_model.bin"))

<All keys matched successfully>

In [8]:
# Move the model to the GPU. NOTE(review): the device string is hard-coded, so this
# notebook presumably requires a CUDA runtime — confirm before running on CPU.
model.to("cuda")

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

### Classes and functions definitions

In [9]:
class BERTModelWrapper:
  """Pairs a BERT sequence classifier with its tokenizer and exposes embedding-level helpers.

  The helpers tokenize text, look up input embeddings and run the classifier on
  embeddings (`inputs_embeds`) rather than on token ids.
  """

  def __init__(
      self, model: "BertForSequenceClassification", tokenizer: "BertTokenizer"
    ) -> None:
    self.model = model
    self.tokenizer = tokenizer
    # `__postinit__` is a plain method: nothing in Python calls it automatically
    # (the dataclass hook is spelled `__post_init__`), so invoke it explicitly.
    self.__postinit__()

  def __postinit__(self) -> None:
    """Put the wrapped model in eval mode and clear any accumulated gradients."""
    self.model.eval()
    self.model.zero_grad()

  def tokenize(self, text: str) -> Dict[str, torch.Tensor]:
    """Tokenize `text` (with truncation) into tensors placed on the model's device."""
    return self.tokenizer(text, return_tensors='pt', truncation=True).to(self.model.device)

  def tokenize_to_list(self, text: str) -> Dict[str, List]:
    """Tokenize `text` (with truncation) without converting the result to tensors."""
    return self.tokenizer(text, truncation=True)

  def get_input_embeds(self, text: str) -> torch.Tensor:
    """Return the input embeddings of `text` with a leading batch dimension of 1."""
    # `tokenize` already moves the encoding to the model's device.
    item = self.tokenize(text)
    input_embeds = self._get_input_embeds_from_ids(item["input_ids"][0])
    return input_embeds.unsqueeze(0)

  def _get_input_embeds_from_ids(self, ids: torch.Tensor) -> torch.Tensor:
    """Look up the model's input embeddings for a tensor of token ids."""
    return self.model.get_input_embeddings()(ids)

  def _get_input_embeds_from_ids_numpy(self, ids: NDArray) -> torch.Tensor:
    """Embed token ids given as a NumPy array (or tensor); 2-D results get a batch dimension."""
    if isinstance(ids, np.ndarray):
      ids = torch.as_tensor(ids, device=self.model.device)
    embeds = self.model.get_input_embeddings()(ids)
    if embeds.dim() == 2:
      # (seq, dim) -> (1, seq, dim)
      embeds = embeds.unsqueeze(0)
    return embeds

  def get_outputs(
      self,
      input_embeds: torch.Tensor,
      attention_mask: torch.Tensor=None
    ) -> torch.Tensor:
    """Run the classifier on embeddings and return its logits."""
    outputs = self.model(
        inputs_embeds=input_embeds, attention_mask=attention_mask
    )
    return outputs.logits

  def get_outputs_pred(
      self,
      input_embeds: torch.Tensor,
      attention_mask: torch.Tensor=None
    ) -> torch.Tensor:
    """Alias of `get_outputs`, kept for existing callers."""
    return self.get_outputs(input_embeds, attention_mask=attention_mask)

  def get_output_probabilities(
      self, input_embeds: torch.Tensor, attention_mask: torch.Tensor=None
    ) -> torch.Tensor:
    """Return softmax probabilities (over dim 1) of the logits for `input_embeds`."""
    # Forward the caller's mask; it used to be replaced by None and silently ignored.
    output_logits = self.get_outputs(input_embeds, attention_mask=attention_mask)
    return torch_softmax(output_logits, dim=1)

  def get_outputs_from_text(self, text: str) -> torch.Tensor:
    """Tokenize `text`, embed it and return the classifier logits."""
    input_embeds, attention_mask = self.get_input_embeds_and_attention_mask(text)
    return self.get_outputs(
        input_embeds=input_embeds, attention_mask=attention_mask
    )

  def get_input_embeds_and_attention_mask(
      self, text: str
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return `(input_embeds, attention_mask)` for `text`; embeddings carry a batch dimension."""
    # `tokenize` already moves the encoding to the model's device.
    item = self.tokenize(text)
    input_embeds = self._get_input_embeds_from_ids(item["input_ids"][0])
    return input_embeds.unsqueeze(0), item["attention_mask"]

In [10]:
def squeeze_attributions(attributions: torch.Tensor) -> NDArray:
    """Collapse embedding-level attributions to one L2-normalised score per position.

    The last dimension is summed, a leading dimension of size 1 is squeezed away
    and the result is divided by its L2 norm. An all-zero attribution is returned
    as zeros instead of NaN (0 / 0).

    Args:
        attributions: attribution tensor whose last dimension is summed over.

    Returns:
        The per-position scores as a NumPy array.
    """
    token_scores = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(token_scores)
    if norm > 0:
        token_scores = token_scores / norm
    # `.cpu()` is a no-op for CPU tensors and makes the conversion safe for GPU ones.
    return token_scores.detach().cpu().numpy()

In [11]:
class Explanation:
  """Container for one attribution result together with the input it explains.

  Instances are normally built with `from_dict` and serialised with `to_json`.
  """

  def __init__(
      self,
      text_id: int,
      text: str,
      input_tokens: Dict[str, NDArray],
      label: int,
      model_outputs: NDArray,
      attr: NDArray,
      delta: float,
      attr_kwargs: Dict[str, Any],
      explainer_name: str,
      perturbation_info: Dict[str, Any]
    ) -> None:
    self.text_id = text_id
    self.text = text
    self.input_tokens = input_tokens
    self.label = label
    self.model_outputs = model_outputs
    self.attr = attr
    self.delta = delta
    self.attr_kwargs = attr_kwargs
    self.explainer_name = explainer_name
    self.perturbation_info = perturbation_info

  @property
  def normalized_attr(self) -> NDArray:
    """Attributions divided by their L1 norm (returned unchanged when the norm is 0)."""
    l1_norm = np.linalg.norm(self.attr, ord=1)
    if l1_norm == 0:
      return self.attr
    return self.attr / l1_norm

  @property
  def pred_label(self) -> int:
    """Index of the largest model output."""
    return int(np.argmax(self.model_outputs))

  @property
  def model_outputs_prob(self) -> NDArray:
    """Softmax of `model_outputs`."""
    # Subtracting the maximum leaves the softmax unchanged but avoids exp() overflow.
    shifted = np.exp(self.model_outputs - np.max(self.model_outputs))
    return shifted / shifted.sum()

  @property
  def n_tokens(self) -> int:
    """Number of entries in `input_tokens["input_ids"]`."""
    return len(self.input_tokens["input_ids"])

  @property
  def n_tokens_non_baseline(self) -> int:
    """Number of input ids greater than 0 (0 is the baseline id used in this notebook)."""
    return int(np.sum(self.input_tokens["input_ids"] > 0))

  @property
  def n_tokens_baseline(self) -> int:
    """Number of input ids equal to 0."""
    return int(np.sum(self.input_tokens["input_ids"] == 0))

  def to_json(self) -> Dict[str, Any]:
    """Return a JSON-serialisable dict; numeric arrays are written as lists of strings.

    `delta` is not part of the output; `from_dict` restores it as None.
    """
    return {
        "text_id": self.text_id,
        "text": self.text,
        "label": self.label,
        "input_tokens": {k: v.astype(str).tolist() for k, v in self.input_tokens.items()},
        "model_outputs": [str(i) for i in self.model_outputs],
        "attr": [str(i) for i in self.attr],
        "attr_kwargs": self.attr_kwargs,
        "explainer_name": self.explainer_name,
        "perturbation_info": self.perturbation_info
    }

  @classmethod
  def from_dict(cls, exp: Dict[str, Any]) -> "Explanation":
    """Build an instance from a dict such as the one produced by `to_json`.

    The input dict is left untouched; numeric fields are converted on a shallow copy.
    """
    fields = dict(exp)
    fields["text_id"] = int(fields["text_id"])
    fields["attr"] = np.array(fields["attr"]).astype('float32')
    fields["model_outputs"] = np.array(fields["model_outputs"]).astype('float32')
    fields["input_tokens"] = {k: np.array(v).astype(int) for k, v in fields["input_tokens"].items()}
    fields["delta"] = fields.get("delta")
    return cls(**fields)

#### Get explanations: regular

In [12]:
def get_filename(filename: str, explainer_name: str, attr_kwargs: Dict[str, Any]) -> str:
  """Derive the output filename for one explainer / attribution-kwargs combination.

  The literal "default" in `filename` is expanded with the kwargs that distinguish
  the run, so that different settings never share a file.

  Args:
    filename: template containing the marker "default".
    explainer_name: key identifying the explainer.
    attr_kwargs: keyword arguments passed to the explainer's `attribute` call.

  Returns:
    The filename with "default" expanded; unchanged when no distinguishing kwarg is set.
  """
  # Every integrated-gradients variant is run with several `n_steps` values; the
  # XActivation variant must be listed too, otherwise all of its runs share one file.
  integrated_gradients_explainers = (
      "LayerIntegratedGradients",
      "LayerIntegratedGradientsXActivation",
      "LayerIntegratedGradientsGlobal",
  )
  if explainer_name in integrated_gradients_explainers:
    if "n_steps" in attr_kwargs:
      filename = filename.replace("default", f"n_steps_{attr_kwargs['n_steps']}_default")
  elif explainer_name == "Saliency":
    if attr_kwargs.get("abs"):
      filename = filename.replace("default", "abs_default")
  if "nt_type" in attr_kwargs:
    filename = filename.replace("default", f"nt_type_{attr_kwargs['nt_type']}_nt_samples_{attr_kwargs['nt_samples']}")
  return filename

In [13]:
def get_baseline(input_ids: torch.Tensor, sep_token_id: int = 102, cls_token_id: int = 101, pad_token_id: int = 0) -> torch.Tensor:
  """Build a baseline for `input_ids`: first and last position kept, the rest padded.

  Args:
    input_ids: 2-D tensor of token ids (rows are sequences).
    sep_token_id: unused; kept for interface compatibility.
    cls_token_id: unused; kept for interface compatibility.
    pad_token_id: id written to every position except the first and the last.

  Returns:
    A new int64 tensor; `input_ids` itself is not modified.
  """
  baseline = input_ids.clone()
  # Applies to every row (previously only row 0 was rewritten and 0 was hard-coded).
  baseline[:, 1:-1] = pad_token_id
  return baseline.long()

In [14]:
def get_single_attribution(
    inputs: torch.Tensor,
    additional_forward_args: Tuple[torch.Tensor],
    label: int,
    attr_callable: Callable,
    attr_kwargs: Dict[str, Any],
    explainer_name: str
  ) -> Dict[str, Any]:
  """Run one attribution call and package the result.

  Args:
    inputs: token-id tensor handed to `attr_callable`.
    additional_forward_args: extra forward arguments forwarded to `attr_callable`.
    label: target class index for the attribution.
    attr_callable: the explainer's `attribute` method.
    attr_kwargs: extra keyword arguments for `attr_callable`.
    explainer_name: key identifying the explainer; selects whether a baseline is passed.

  Returns:
    Dict with keys "label", "attr", "attr_kwargs", "delta" (always None) and "explainer_name".
  """
  # Explainers that receive an explicit baseline. The XActivation variant of integrated
  # gradients is listed alongside its sibling so that both use the same reference input.
  baseline_explainers = (
      "LayerIntegratedGradients",
      "LayerIntegratedGradientsXActivation",
      "LayerDeepLift",
      "LayerDeepLiftXActivation",
  )
  call_kwargs = dict(attr_kwargs)
  if explainer_name in baseline_explainers:
    call_kwargs["baselines"] = get_baseline(inputs)
  attr = attr_callable(inputs, target=label, additional_forward_args=additional_forward_args, **call_kwargs)
  return {
        "label": label,
        "attr": squeeze_attributions(attr.detach().cpu()),
        "attr_kwargs": attr_kwargs,
        "delta": None,
        "explainer_name": explainer_name,
    }

In [15]:
def get_single_explanation(
    input_id: int,
    input: Dict[str, Any],
    explainer: Any,
    explainer_name: str,
    attr_kwargs: Dict[str, Any],
    input_tokens: Dict[str, torch.Tensor],
    model_outputs: torch.Tensor,
    perturbation_info: Dict[str, Any]=None
) -> Explanation:
  """Attribute one tokenized input and wrap the result in an `Explanation`.

  `input` supplies the "text" and "label"; `input_tokens` supplies "input_ids",
  "token_type_ids" and "attention_mask". Tensors are stored as the first row of
  their NumPy conversion.
  """
  forward_args = (input_tokens["token_type_ids"], input_tokens["attention_mask"])
  attribution = get_single_attribution(
      inputs=input_tokens["input_ids"],
      additional_forward_args=forward_args,
      label=input["label"],
      attr_callable=explainer.attribute,
      attr_kwargs=attr_kwargs,
      explainer_name=explainer_name,
  )

  tokens_as_numpy = {}
  for key, value in input_tokens.items():
    if type(value) == torch.Tensor:
      value = value.detach().cpu().numpy()[0]
    tokens_as_numpy[key] = value

  fields = {
      "text_id": input_id,
      "text": input["text"],
      "input_tokens": tokens_as_numpy,
      "model_outputs": model_outputs.detach().cpu().numpy()[0],
  }
  fields.update(attribution)
  fields["perturbation_info"] = perturbation_info
  return Explanation.from_dict(fields)

In [16]:
def get_multiple_explanations(
    input_ids: Iterable[int],
    inputs: Iterable[Dict[str, Any]],
    explainer: Any,
    bert_wrapper: BERTModelWrapper,
    explainer_name: str,
    attr_kwargs: Dict[str, Any]
  ) -> List[Dict[str, Any]]:
  """Explain every input with one explainer and return the serialised explanations.

  Args:
    input_ids: identifier of each input, paired positionally with `inputs`.
    inputs: dicts with at least "text" and "label".
    explainer: object exposing an `attribute` method.
    bert_wrapper: wrapper used to tokenize the text.
    explainer_name: key identifying the explainer.
    attr_kwargs: keyword arguments for the `attribute` call.

  Returns:
    One `Explanation.to_json()` dict per input, in input order.
  """
  # Give the progress bar a total whenever both iterables are sized.
  try:
    total = min(len(input_ids), len(inputs))
  except TypeError:
    total = None

  explanations = []
  for input_id, item in tqdm(zip(input_ids, inputs), total=total):
    input_tokens = bert_wrapper.tokenize(item["text"])
    bert_wrapper.model.eval()
    bert_wrapper.model.zero_grad()
    # These logits are only detached and stored, so no autograd graph is needed.
    with torch.no_grad():
      model_outputs = predict(**input_tokens)
    _explanation = get_single_explanation(
        input_id=input_id,
        input=item,
        explainer=explainer,
        explainer_name=explainer_name,
        attr_kwargs=attr_kwargs,
        input_tokens=input_tokens,
        model_outputs=model_outputs
    )
    explanations.append(_explanation.to_json())
  return explanations

In [17]:
def save_explanations(explanations: Any, explanation_path: Path, filename: str) -> None:
  """Write `explanations` as JSON to `explanation_path / filename`.

  The target directory (including missing parents) is created first.
  """
  explanation_path.mkdir(parents=True, exist_ok=True)
  target = explanation_path / filename
  with target.open("w") as handle:
    json.dump(explanations, handle)

In [18]:
def create_explanations_multiple_explainers(
    explainers: Dict[str, Any],
    filenames: Dict[str, str],
    inputs: Iterable[Dict[str, Any]],
    input_ids: Iterable[int],
    bert_wrapper: BERTModelWrapper,
    attr_kwargs: Dict[str, List[Dict[str, Any]]],
    explanation_path: Path
  ) -> None:
  """Generate and save explanations for every explainer and every kwargs variant.

  Args:
    explainers: mapping from explainer name to explainer object.
    filenames: mapping from explainer name to a filename template (see `get_filename`).
    inputs: dicts with "text" and "label".
    input_ids: identifier of each input, paired positionally with `inputs`.
    bert_wrapper: wrapper used for tokenization.
    attr_kwargs: per-explainer list of kwargs dicts; explainers without an entry
      are run once with no extra kwargs.
    explanation_path: directory the JSON files are written to.
  """
  for explainer_name, explainer in tqdm(explainers.items()):
    for _attr_kwargs in attr_kwargs.get(explainer_name, [{}]):
      explanations = get_multiple_explanations(
        input_ids=input_ids,
        inputs=inputs,
        explainer=explainer,
        bert_wrapper=bert_wrapper,
        explainer_name=explainer_name,
        attr_kwargs=_attr_kwargs,
        )
      filename = get_filename(filenames[explainer_name], explainer_name, _attr_kwargs)
      save_explanations(explanations, explanation_path, filename)
    # `explainers` still references the explainer, so deleting the loop variable
    # frees nothing; only collect the garbage left behind by the attribution runs.
    gc.collect()

#### Get explanations: perturbed

In [19]:
def get_random_token_indices(input_len: int, n_samples: int, n_tokens: int, random_seed: int = 123) -> NDArray:
  """Draw `n_samples` random sets of token positions, excluding the first and last position.

  Each sample holds `min(n_tokens, input_len - 3)` distinct positions drawn from
  `1 .. input_len - 2`.

  A private `RandomState` seeded with `random_seed` is used. It yields the same
  draws as seeding the global NumPy generator, but leaves the global state alone.

  Args:
    input_len: total number of token positions.
    n_samples: number of independent samples to draw.
    n_tokens: requested number of positions per sample.
    random_seed: seed for the private generator.

  Returns:
    Array with one row of positions per sample.
  """
  rng = np.random.RandomState(random_seed)
  candidates = np.arange(1, input_len-1)
  sample_size = min(n_tokens, input_len-3)
  return np.array([rng.choice(candidates, size=sample_size, replace=False) for _ in range(n_samples)])

In [20]:
# NOTE(review): scratch cell — binds a module-level `input_tokens` from a hard-coded sentence.
# The functions visible in this notebook all take or build their own `input_tokens`, so this
# global looks unused; consider deleting the cell.
input_tokens = {k: np.array(v) for k, v in tokenizer("women are weaker smaller and less intelligent").items()}

In [21]:
def replace_tokens_with_baseline(input_tokens: Dict[str, NDArray[Any]], token_indices: Iterable[int], baseline_value: int=0) -> Dict[str, NDArray[Any]]:
  """Overwrite the given positions with `baseline_value` and return batched int64 tensors.

  Every array in `input_tokens` is copied, the positions in `token_indices` are set
  to `baseline_value`, and the copy becomes a `(1, seq_len)` int64 tensor on the
  device of the notebook-level `model`. The input arrays are not modified.

  NOTE(review): the replacement is applied to every key, so with the default
  baseline of 0 the attention mask is zeroed at those positions as well — confirm
  that this is intended.

  Args:
    input_tokens: mapping from encoding key to a 1-D array.
    token_indices: positions to overwrite.
    baseline_value: value written at those positions.

  Returns:
    Mapping with the same keys and tensor values (despite the NDArray annotation,
    which is kept for interface compatibility).
  """
  positions = np.asarray(list(token_indices), dtype=int)
  new_input_tokens = {}
  for key, values in input_tokens.items():
    replaced = np.array(values, copy=True)
    replaced[positions] = baseline_value
    # Build the int64 tensor directly instead of going through a float32 tensor.
    new_input_tokens[key] = torch.as_tensor(replaced, dtype=torch.long, device=model.device).unsqueeze(dim=0)
  return new_input_tokens

In [22]:
def get_token_indices_completement(input_len: int, token_indices_all_samples: NDArray) -> List[NDArray]:
  """For each sample of positions, return the positions that are NOT in it.

  From `0 .. input_len - 1` the sampled positions are removed first; then the
  elements at index 0 and at index `input_len - len(sample) - 1` of what remains
  are removed too.
  """
  all_positions = np.arange(0, input_len)

  def _complement(sample):
    remaining = np.delete(all_positions, sample)
    return np.delete(remaining, [0, input_len-len(sample)-1])

  return [_complement(sample) for sample in token_indices_all_samples]

In [23]:
def random_n_tokens_baseline_replacement_with_completement(
    explanation: Explanation,
    bert_wrapper: BERTModelWrapper,
    explainers: Dict[str, Any],
    attr_kwargs: Dict[str, Any],
    n_samples=5,
    n_tokens_to_replace=1,
    random_seed=123,
    completement=True
  ) -> Tuple[Dict[str, List[Dict[str, Any]]], Dict[str, List[Dict[str, Any]]]]:
  """Explain randomly perturbed copies of one input, and optionally their complements.

  For each of `n_samples` random draws, the drawn positions are replaced with the
  baseline and every explainer / kwargs variant is run on the perturbed input.
  With `completement=True` the same is done for the input in which all *other*
  inner positions are replaced instead.

  Args:
    explanation: explanation of the unperturbed input; supplies text, label, tokens and id.
    bert_wrapper: unused; kept for interface compatibility.
    explainers: mapping from explainer name to explainer object.
    attr_kwargs: per-explainer list of kwargs dicts (explainers without an entry run once).
    n_samples: number of random draws.
    n_tokens_to_replace: number of positions replaced per draw.
    random_seed: multiplied by the text id to seed the draws.
    completement: whether complement explanations are produced as well.

  Returns:
    `(explanations, compl_explanations)`: two mappings from explainer name to a
    list of `Explanation.to_json()` dicts.
  """
  explanations, compl_explanations = defaultdict(list), defaultdict(list)

  input_len = len(explanation.input_tokens["input_ids"])
  all_token_indices_to_replace = get_random_token_indices(input_len, n_samples, n_tokens_to_replace, random_seed * explanation.text_id)
  all_token_indices_completements = get_token_indices_completement(input_len, all_token_indices_to_replace)
  text_and_label = {"text": explanation.text, "label": explanation.label}

  for i, token_indices in enumerate(all_token_indices_to_replace):
    input_tokens = replace_tokens_with_baseline(explanation.input_tokens, token_indices)
    # Predictions are only stored, so no autograd graph is needed.
    with torch.no_grad():
      model_outputs = predict(**input_tokens)

    # The complement input and its prediction do not depend on the explainer:
    # compute them once per draw instead of once per explainer / kwargs variant.
    if completement:
      completement_token_indices = all_token_indices_completements[i]
      compl_input_tokens = replace_tokens_with_baseline(explanation.input_tokens, completement_token_indices)
      with torch.no_grad():
        compl_model_outputs = predict(**compl_input_tokens)

    for explainer_name, explainer in explainers.items():
      for _attr_kwargs in attr_kwargs.get(explainer_name, [{}]):
        _explanation = get_single_explanation(
            input_id=explanation.text_id,
            input=text_and_label,
            explainer=explainer,
            explainer_name=explainer_name,
            attr_kwargs=_attr_kwargs,
            input_tokens=input_tokens,
            model_outputs=model_outputs,
            perturbation_info={
                "name": "random_token_baseline_replacement",
                "lvl": str(i), "n_tokens_to_replace":
                str(n_tokens_to_replace),
                "attr_kwargs": _attr_kwargs
            }
        )
        explanations[explainer_name].append(_explanation.to_json())
        if completement:
          compl_explanation = get_single_explanation(
              input_id=explanation.text_id,
              input=text_and_label,
              explainer=explainer,
              explainer_name=explainer_name,
              attr_kwargs=_attr_kwargs,
              input_tokens=compl_input_tokens,
              model_outputs=compl_model_outputs,
              perturbation_info={
                  "name": "random_token_baseline_replacement",
                  "lvl": str(i),
                  # Number of positions actually replaced; the first and last
                  # position are never part of the complement.
                  "n_tokens_to_replace": str(len(completement_token_indices)),
                  "attr_kwargs": _attr_kwargs
              }
          )
          compl_explanations[explainer_name].append(compl_explanation.to_json())
  return explanations, compl_explanations

In [24]:
def random_n_tokens_baseline_replacement_with_completement_multiple_explanations(
    explanations: List[Explanation],
    explainers: Dict[str, Any],
    bert_wrapper: BERTModelWrapper,
    n_samples: int,
    attr_kwargs: Dict[str, List[Dict]],
    n_tokens_to_replace: int,
    completement=True
):
  """Run the random baseline-replacement perturbation for a list of explanations.

  Args:
    explanations: serialised explanations (dicts accepted by `Explanation.from_dict`).
    explainers: mapping from explainer name to explainer object.
    bert_wrapper: forwarded to the per-explanation routine.
    n_samples: number of random draws per text.
    attr_kwargs: per-explainer list of kwargs dicts.
    n_tokens_to_replace: number of positions replaced per draw.
    completement: whether complement explanations are produced as well.

  Returns:
    One dict per processed text with keys "explanations" and "completements".
    Texts with fewer than `n_tokens_to_replace + 3` tokens are skipped.
  """
  perturbed_exp = []
  for exp_dict in tqdm(explanations):
    exp = Explanation.from_dict(exp_dict)
    if n_tokens_to_replace > len(exp.input_tokens["input_ids"])-3:
      # Skip only this text; a `break` here would drop every remaining text too.
      continue
    # Distinct local names: the loop must not rebind the `explanations` argument.
    text_explanations, text_compl_explanations = random_n_tokens_baseline_replacement_with_completement(
        explanation=exp,
        bert_wrapper=bert_wrapper,
        explainers=explainers,
        attr_kwargs=attr_kwargs,
        n_samples=n_samples,
        n_tokens_to_replace=n_tokens_to_replace,
        completement=completement
      )
    perturbed_exp.append(
      {
        "explanations": text_explanations,
        "completements": text_compl_explanations
      }
    )
  return perturbed_exp

In [25]:
def get_most_relevant_attribution(attr: NDArray, label: int):
  """Return the position with the largest absolute attribution, ignoring both ends.

  The first and last entries of `attr` (the [CLS] and [SEP] slots) are excluded;
  the returned index refers to the full `attr` array. `label` is not used.
  """
  inner_magnitudes = np.abs(attr[1:-1])
  # Shift by one to account for the dropped first entry.
  return np.argmax(inner_magnitudes) + 1

In [26]:
class WrongExplainerError(Exception):
  """Raised when an explanation was produced by a different explainer than requested."""

def recurrent_token_baseline_replacement(
    explanation: Explanation,
    bert_wrapper: BERTModelWrapper,
    explainer_name: str,
    explainer: Any,
    attr_kwargs: Dict[str, Any],
    max_n_samples: int,
    while_same_pred_label=True
  ) -> List[Dict[str, Any]]:
  """Iteratively replace the most relevant token with the baseline and re-explain.

  At every step the position with the largest absolute attribution in the
  *current* explanation is added to the replaced positions, the original tokens
  are perturbed at all positions collected so far, and the perturbed input is
  explained again.

  NOTE(review): the selection does not exclude positions that were already
  replaced, so the same position can be picked again — confirm that this is intended.

  Args:
    explanation: explanation of the unperturbed input.
    bert_wrapper: unused; kept for interface compatibility.
    explainer_name: key identifying the explainer; must match `explanation.explainer_name`.
    explainer: object exposing an `attribute` method.
    attr_kwargs: keyword arguments for the `attribute` call.
    max_n_samples: maximum number of replacement steps.
    while_same_pred_label: stop as soon as the predicted label differs from the original one.

  Returns:
    One `Explanation.to_json()` dict per executed step.

  Raises:
    WrongExplainerError: if `explanation` was produced by another explainer.
  """
  if explanation.explainer_name != explainer_name:
    raise WrongExplainerError(
        f"explanation was produced by '{explanation.explainer_name}', not by '{explainer_name}'"
    )
  # `curr_exp` is only read, so no defensive copy of `explanation` is needed.
  curr_exp = explanation
  # Number of positions between the first and the last token.
  input_token_len = len(explanation.input_tokens["input_ids"])-2
  explanations, max_attr_indices = [], []
  for i in range(min(max_n_samples, input_token_len)):
    if while_same_pred_label and np.argmax(curr_exp.model_outputs) != explanation.pred_label:
      break
    curr_exp_max_attr_ind = get_most_relevant_attribution(attr=curr_exp.attr, label=explanation.label)
    max_attr_indices.append(curr_exp_max_attr_ind)
    # Always perturb the ORIGINAL tokens at every position collected so far.
    input_tokens = replace_tokens_with_baseline(explanation.input_tokens, max_attr_indices)
    # The prediction is only stored, so no autograd graph is needed.
    with torch.no_grad():
      model_outputs = predict(**input_tokens)
    _explanation = get_single_explanation(
        input_id=curr_exp.text_id,
        input={"text": curr_exp.text, "label": curr_exp.label},
        explainer=explainer,
        explainer_name=explainer_name,
        attr_kwargs=attr_kwargs,
        input_tokens=input_tokens,
        model_outputs=model_outputs,
        perturbation_info={
            "name": "recurrent_token_baseline_replacement",
            "lvl": str(i),
            "max_n_samples": str(max_n_samples),
            "while_same_pred_label": while_same_pred_label
        }
    )
    curr_exp = _explanation
    explanations.append(_explanation.to_json())
  return explanations

### Create helper objects

In [27]:
# Version tag embedded in the explanation folder path and in the generated filenames.
version_date = "20230924"

In [28]:
# Root folder on the mounted Drive for everything this notebook reads and writes; created if missing.
explanation_path = Path(f"/content/drive/MyDrive/explanations/bert_imdb/model_20230912/{version_date}/")
explanation_path.mkdir(parents=True, exist_ok=True)

In [29]:
# Wrapper around the notebook-level `model` / `tokenizer`; the explanation helpers use it for tokenization.
bert_wrapper = BERTModelWrapper(model, tokenizer)

The commented-out cells below should be run only once: they select the inputs and save them, so that the entire dataset does not have to be loaded on every run.

In [30]:
# reviews = load_dataset("imdb")["test"]

In [31]:
# np.random.seed(123)
# selected_input_ids = np.random.choice(25000, 1000, replace=False)
# selected_inputs = reviews.select(selected_input_ids)

In [32]:
# with open(explanation_path / "selected_input_ids.json", "w") as file:
#   json.dump(list(selected_input_ids.astype(str)), file)

In [35]:
# with open(explanation_path / "selected_inputs.json", "w") as file:
#   json.dump(selected_inputs.to_dict(), file)

In [36]:
# Load the previously saved ids of the selected inputs and convert them to an integer array.
with open(explanation_path / "selected_input_ids.json", "r") as file:
  selected_input_ids = np.array(json.load(file)).astype(int)

In [37]:
# Load the previously saved inputs; the next cell reads the "text" and "label" entries of this dict.
with open(explanation_path / "selected_inputs.json", "r") as file:
  selected_inputs = json.load(file)

In [38]:
# Turn the column-oriented dict into a list of {"text", "label"} records and drop "<br />" tags.
# NOTE(review): this rebinds `selected_inputs` from a dict to a list, so the cell cannot be re-run
# without re-running the loading cell first. The tags are also removed without inserting a space,
# which joins the neighbouring words — confirm that this is intended.
selected_inputs = [{"text": text.replace("<br />", ""), "label": label} for text, label in zip(selected_inputs["text"], selected_inputs["label"])]

## Generate explanations

#### Helper functions

In [40]:
def forward_func(inputs, token_type_ids=None, attention_mask=None):
  """Forward function handed to the Captum explainers; delegates to `predict`."""
  logits = predict(inputs, token_type_ids=token_type_ids, attention_mask=attention_mask)
  return logits

def predict(input_ids, token_type_ids=None, attention_mask=None):
  """Return the logits of the notebook-level `model` for the given encoding.

  The model is switched to eval mode and its gradients are cleared before every call.
  """
  model.eval()
  model.zero_grad()
  return model(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask).logits

In [41]:
class WrapperModel(torch.nn.Module):
  """`torch.nn.Module` adapter whose forward pass returns the wrapped model's logits."""

  def __init__(self, model):
    super().__init__()
    # Wrapped model; it is registered as a submodule under the name `m`.
    self.m = model

  def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Call the wrapped model and return only the `.logits` of its output."""
    output = self.m(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
    return output.logits

### Regular

In [42]:
from captum.attr import LayerDeepLift, LayerGradientXActivation, LayerIntegratedGradients

In [43]:
# Module wrapper around `model`; passed to the LayerDeepLift explainers below in place of `forward_func`.
dl_wrapper = WrapperModel(model)

In [44]:
# One Captum layer explainer per method, all attributing to `model.bert.embeddings`.
# Each method appears twice: with `multiply_by_inputs=False` and, under the
# "...XActivation" key, with `multiply_by_inputs=True`.
explainers_regular = {
    "LayerDeepLift": LayerDeepLift(dl_wrapper, model.bert.embeddings, multiply_by_inputs=False),
    "LayerDeepLiftXActivation": LayerDeepLift(dl_wrapper, model.bert.embeddings, multiply_by_inputs=True),
    "LayerGradient": LayerGradientXActivation(forward_func, model.bert.embeddings, device_ids=[0], multiply_by_inputs=False),
    "LayerGradientXActivation": LayerGradientXActivation(forward_func, model.bert.embeddings, device_ids=[0], multiply_by_inputs=True),
    "LayerIntegratedGradients": LayerIntegratedGradients(forward_func, model.bert.embeddings, device_ids=[0], multiply_by_inputs=False),
    "LayerIntegratedGradientsXActivation": LayerIntegratedGradients(forward_func, model.bert.embeddings, device_ids=[0], multiply_by_inputs=True),

}

In [45]:
# Per-explainer list of `attribute` kwargs; each entry triggers one run.
# Explainers without an entry are run once with no extra kwargs.
attr_kwargs = {
    "LayerIntegratedGradientsXActivation": [{"n_steps": n_steps} for n_steps in [5, 10, 15]],
    "LayerIntegratedGradients": [{"n_steps": n_steps} for n_steps in [5, 10, 15]],
}

In [46]:
# Filename template per explainer: lower-cased name without "Layer", the "default" marker
# (expanded by `get_filename`) and the version date.
filenames_regular = {explainer_name: f"{explainer_name.replace('Layer','').lower()}_default_{version_date}.json" for explainer_name in explainers_regular.keys()}

#### Example explanations

In [47]:
# Four hand-written example reviews with their labels; negative ids are used as their identifiers.
example_input_ids = [-1, -2, -3, -4]
example_inputs = [
    {"text": "It was a great movie, I loved it!", "label": 1},
    {"text": "It was a terrible movie, I hated it!", "label": 0},
    {"text": "This movie was not great, I did not like it.", "label": 0},
    {"text": "This movie was not terrible, I did not hate it.", "label": 1}
    ]

In [48]:
# Same template as `filenames_regular`, with an extra "examples" tag before the version date.
filenames_examples = {explainer_name: f"{explainer_name.replace('Layer','').lower()}_default_examples_{version_date}.json" for explainer_name in explainers_regular.keys()}

In [49]:
# Explain the hand-written examples with every explainer; results are written to "examples_extra".
create_explanations_multiple_explainers(
    explainers=explainers_regular,
    filenames=filenames_examples,
    inputs=example_inputs,
    input_ids=example_input_ids,
    bert_wrapper=bert_wrapper,
    attr_kwargs=attr_kwargs,
    explanation_path=explanation_path / "examples_extra"
)

  0%|          | 0/6 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

#### Backpropagation-based

In [55]:
# NOTE(review): identical to the earlier `filenames_regular` definition in this notebook; one of the two is redundant.
filenames_regular = {explainer_name: f"{explainer_name.replace('Layer','').lower()}_default_{version_date}.json" for explainer_name in explainers_regular.keys()}

In [56]:
# Explain the selected inputs with every explainer; results are written to "regular".
create_explanations_multiple_explainers(
    explainers=explainers_regular,
    filenames=filenames_regular,
    inputs=selected_inputs,
    input_ids=selected_input_ids,
    bert_wrapper=bert_wrapper,
    attr_kwargs=attr_kwargs,
    explanation_path=explanation_path / "regular"
)

  0%|          | 0/6 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

### Perturbed

#### Random baseline replacement

In [57]:
# The gradient and DeepLift explainers (no integrated gradients), run without extra kwargs.
explainers_simple = {
    explainer_name: explainer for explainer_name, explainer in explainers_regular.items()
    if explainer_name in ["LayerGradient", "LayerDeepLift", "LayerDeepLiftXActivation", "LayerGradientXActivation"]
}
attr_kwargs_simple = {}

In [58]:
# `n_steps` values for the integrated-gradients runs, as one kwargs list per IG explainer.
ig_n_steps = [5, 10, 15]
attr_kwargs_igxa = {"LayerIntegratedGradientsXActivation": [{"n_steps": n_steps} for n_steps in ig_n_steps]}
attr_kwargs_ig = {"LayerIntegratedGradients": [{"n_steps": n_steps} for n_steps in ig_n_steps]}

In [61]:
# The saved DeepLift explanations supply the texts, labels and tokens for the random perturbation run below.
with open(explanation_path / f"regular/deeplift_default_{version_date}.json", "r") as file:
    default_explanations = json.load(file)

In [62]:
# Explainer groups processed (and saved) separately by the random perturbation run.
exp_sets = {
    "simple": {"explainers": explainers_simple, "attr_kwargs": attr_kwargs_simple},
    "igxa": {"explainers": {"LayerIntegratedGradientsXActivation": explainers_regular["LayerIntegratedGradientsXActivation"]}, "attr_kwargs": attr_kwargs_igxa},
    "ig": {"explainers": {"LayerIntegratedGradients": explainers_regular["LayerIntegratedGradients"]}, "attr_kwargs": attr_kwargs_ig}
    }

In [63]:
# Random baseline replacement: for 1 .. max_n_tokens_to_replace replaced tokens and for every
# explainer group, perturb each default explanation `n_samples` times (without complements),
# keep the result in `final_exps_random_n` and save it to "random_baseline".
max_n_tokens_to_replace = 2
n_samples=5
prefix = "random_token_baseline_replacement"
final_exps_random_n = defaultdict(dict)
for n_tokens_to_replace in tqdm(range(1, max_n_tokens_to_replace+1)):
  for exp_set_name, exp_set in tqdm(exp_sets.items()):
    explanations = random_n_tokens_baseline_replacement_with_completement_multiple_explanations(
        explanations=default_explanations,
        explainers=exp_set["explainers"],
        bert_wrapper=bert_wrapper,
        n_samples=n_samples,
        attr_kwargs=exp_set["attr_kwargs"],
        n_tokens_to_replace=n_tokens_to_replace,
        completement=False
    )
    final_exps_random_n[n_tokens_to_replace][exp_set_name] = explanations
    filename = f"{prefix}_{exp_set_name}_n_samples_{n_samples}_n_tokens_{n_tokens_to_replace}_{version_date}.json"
    save_explanations(explanations, explanation_path / "random_baseline", filename)

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [66]:
# NOTE(review): `exps` is not defined by any earlier cell visible in this notebook (it is only
# bound inside the validation loop of the next cell), so this cell presumably raises NameError
# after "Restart & Run All" — confirm and remove, or move it below the validation cell.
len(exps)

15

In [67]:
# Sanity checks for the random baseline replacement results:
#   * every requested number of replaced tokens is present,
#   * no complement explanations were produced,
#   * each explainer produced the expected number of explanations,
#   * each explanation has exactly `n` baseline tokens (ids are stored as strings, hence '0'),
#   * for a given sample index and text, every explainer saw the same replaced positions.
assert set(final_exps_random_n.keys()) == set(range(1, max_n_tokens_to_replace+1))
ig_explainer_names = ("LayerIntegratedGradients", "LayerIntegratedGradientsXActivation")
for n, exp_set_dict in final_exps_random_n.items():
  random_tokens_n = {i: defaultdict(list) for i in range(n_samples)}
  for exp_set_name, exp_set in exp_set_dict.items():
    for exp_dict in exp_set:
      assert all([len(v) == 0 for v in exp_dict["completements"].values()])
      for explainer_name, exps in exp_dict["explanations"].items():
        # Integrated-gradients explainers run once per `n_steps` value for every sample;
        # those runs are stored consecutively, so `i // runs_per_sample` is the sample index.
        runs_per_sample = len(ig_n_steps) if explainer_name in ig_explainer_names else 1
        assert len(exps) == runs_per_sample * n_samples
        for i, exp in enumerate(exps):
          baselines = np.argwhere(np.array(exp["input_tokens"]["input_ids"]) == '0')[:, 0]
          assert len(baselines) == n
          random_tokens_n[i // runs_per_sample][exp["text_id"]].append(baselines)
  for i, v in random_tokens_n.items():
    for text_id, vv in v.items():
      assert all([all(vv[0] == p) for p in vv])

### Recurrent token baseline replacement

In [68]:
def load_explanations(explanation_path: Path, filename: str):
  """Read `explanation_path / filename` (JSON) and return a list of `Explanation` objects."""
  with (explanation_path / filename).open("r") as handle:
    raw_explanations = json.load(handle)
  loaded = []
  for raw in raw_explanations:
    loaded.append(Explanation.from_dict(raw))
  return loaded

#### Regular

In [69]:
# Filenames of the example explanations per explainer.
# NOTE(review): `explainer_info` is not referenced by any later cell visible here — possibly dead code.
explainer_info = {explainer_name: f"{explainer_name.replace('Layer','').lower()}_default_examples_{version_date}.json" for explainer_name in explainers_regular.keys()}

In [70]:
# Recurrent baseline replacement: for every explainer and kwargs variant, load the matching
# regular explanations, perturb each of them for up to `max_n_samples` steps (not stopping
# when the predicted label changes), keep the result in `final_exps` and save it to `folder`.
max_n_samples = 20

prefix = "recurrent_token_baseline_replacement"
folder = "recurrent_baseline"

final_exps = defaultdict(list)
for explainer_name, explainer in tqdm(explainers_regular.items()):
  for _attr_kwargs in attr_kwargs.get(explainer_name, [{}]):
    # Same filename logic as when the regular explanations were saved.
    filename = get_filename(filenames_regular[explainer_name], explainer_name, _attr_kwargs)
    explanations = load_explanations(explanation_path / "regular", filename)
    perturbed_explanations = []
    for exp in tqdm(explanations):
      perturbed_explanations.append(recurrent_token_baseline_replacement(
          explanation=exp,
          bert_wrapper=bert_wrapper,
          explainer=explainer,
          explainer_name=explainer_name,
          attr_kwargs=_attr_kwargs,
          max_n_samples=max_n_samples,
          while_same_pred_label=False
      ))
    final_exps[explainer_name].append(perturbed_explanations)
    # Output name: lower-cased explainer name without "layer", plus `n_steps` for
    # explainers that have an entry in `attr_kwargs`.
    explainer_name_info = explainer_name.lower().replace('layer','').lower()
    explainer_name_info = f"{explainer_name_info}_n_steps_{_attr_kwargs['n_steps']}" if explainer_name in attr_kwargs.keys() else explainer_name_info
    filename_perturbed = f"{prefix}_{explainer_name_info}_n_samples_{max_n_samples}_diff_pred_{version_date}.json"
    save_explanations(perturbed_explanations, explanation_path / folder, filename_perturbed)

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [72]:
# Sanity checks for the recurrent baseline replacement results: one result per selected input,
# `max_n_samples` steps per input, and at every step the baseline positions (ids are stored as
# strings, hence '0') are covered by the first step's baseline positions plus the positions of
# largest absolute attribution of the steps seen so far.
for explainer_name, perturbed_explanations_list in final_exps.items():
  for perturbed_explanations in perturbed_explanations_list:
    assert len(perturbed_explanations) == len(selected_inputs)
    for exps in perturbed_explanations:
      assert len(exps) == max_n_samples
      # Positions already at baseline in the first step.
      diffs = list(np.argwhere(np.array(exps[0]["input_tokens"]["input_ids"]) == '0')[:, 0])
      for i in exps:
        a = set(np.argwhere(np.array(i["input_tokens"]["input_ids"]) == '0')[:, 0])
        # Position the next step is expected to replace (first and last token excluded).
        attr_max = np.argmax(np.abs(np.array(i["attr"][1:-1]).astype(float))) + 1
        diffs.append(attr_max)
        assert a.difference(set(diffs)) == set()