In [None]:
###CODE FOR ALL TOKENS
from __future__ import annotations

import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, PreTrainedTokenizer


class SobolPerturbator:
    """
    Build Monte Carlo token-masking perturbations for Sobol-style attribution.

    Every position of the tokenized sequence (special tokens included) is treated as a
    binary variable: 1 keeps the original token id, 0 replaces it with the baseline id.
    """

    _VALID_MODES = ("first order", "total")

    def __init__(self, tokenizer: PreTrainedTokenizer, baseline="[MASK]", n_perturbations=1000, proba=0.5):
        """
        - tokenizer: Hugging Face tokenizer associated with the model
        - baseline: replacement token (e.g. "[MASK]")
        - n_perturbations: number of Monte Carlo samples
        - proba: probability of keeping a token (i.e. putting 1 in the mask)
        """
        self.tokenizer = tokenizer
        self.baseline = baseline
        self.n_perturbations = n_perturbations
        self.proba = proba

    @staticmethod
    def _apply_masks(input_ids, masks, baseline_id):
        """Return one row per mask: input_ids where the mask is 1, baseline_id where it is 0."""
        # input_ids (seq_len,) broadcasts against masks (n, seq_len).
        return input_ids * masks + baseline_id * (1 - masks)

    def perturb(self, text, Sobol_indices="first order", device=None):
        """
        Generates a mask matrix of size (n_perturbations x seq_len) and associated perturbations.

        Parameters:
          - text : original sentence (str)
          - Sobol_indices : "first order" (flip only bit i of each base mask) or
            "total" (flip every bit except bit i) (str); any other value raises ValueError
          - device : where to place the returned tensors; defaults to CUDA when available,
            otherwise CPU

        Returns a dictionary containing:
          - "origin perturbated inputs": dictionary with the input_ids and attention_mask of the
            base perturbations, each of shape (n_perturbations, seq_len).
          - "list of perturbated inputs for each token": list of seq_len such dictionaries; entry i
            holds the hybrids obtained from the base masks by flipping at position i.

        Raises:
          - ValueError: if Sobol_indices is not recognised or n_perturbations < 1.

        Side effect: prints the token count and the sequence length.
        """
        if Sobol_indices not in self._VALID_MODES:
            raise ValueError(f"Sobol_indices must be one of {self._VALID_MODES}, got {Sobol_indices!r}")
        if self.n_perturbations < 1:
            raise ValueError(f"n_perturbations must be >= 1, got {self.n_perturbations}")
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Offsets are not needed here; requesting them would fail on slow (Python) tokenizers.
        inputs_model = self.tokenizer(text, return_tensors="pt")
        # Drop the batch dimension: shapes become (seq_len,)
        input_ids = inputs_model["input_ids"].squeeze(0)
        attention_mask = inputs_model["attention_mask"].squeeze(0)
        baseline_id = self.tokenizer.convert_tokens_to_ids(self.baseline)

        seq_len = input_ids.shape[0]
        print("Number of tokens:", len(self.tokenizer.tokenize(text)), "Sequence length:", seq_len)

        n = self.n_perturbations
        # One Bernoulli draw per sample, in the same order as drawing row by row, so a given
        # seed produces the same masks as before. Shape: (n_perturbations, seq_len).
        origin_masks = torch.stack(
            [torch.bernoulli(torch.full((seq_len,), self.proba)) for _ in range(n)]
        ).long()

        origin_inputs_model = {
            "input_ids": self._apply_masks(input_ids, origin_masks, baseline_id).to(device),
            # Every sample reuses the unperturbed attention mask.
            "attention_mask": attention_mask.unsqueeze(0).repeat(n, 1).to(device),
        }

        # For each position, derive hybrids from all base masks at once.
        pert_inputs_model_per_token = []
        for i in range(seq_len):
            pert_masks = origin_masks.clone()
            # Flip bit i of every base mask.
            pert_masks[:, i] = 1 - pert_masks[:, i]
            if Sobol_indices == "total":
                # Flip everything again: bit i is back to its base value, all others flipped.
                pert_masks = 1 - pert_masks
            pert_inputs_model_per_token.append(
                {
                    "input_ids": self._apply_masks(input_ids, pert_masks, baseline_id).to(device),
                    "attention_mask": attention_mask.unsqueeze(0).repeat(n, 1).to(device),
                }
            )

        return {
            "origin perturbated inputs": origin_inputs_model,
            "list of perturbated inputs for each token": pert_inputs_model_per_token,
        }


def inference(model, inputs_model):
    """
    Forward all of `inputs_model` through `model` in one call, with autograd disabled.

    `inputs_model` is unpacked as keyword arguments; the model's return value must expose
    a `.logits` attribute, which is returned as-is.
    """
    with torch.no_grad():
        forward_result = model(**inputs_model)
    return forward_result.logits


def batched_inference(model, inputs_model, batch_size=32):
    """
    Forward `inputs_model` through `model` in consecutive chunks and stack the logits.

    Parameters:
      - model: callable taking `input_ids=` / `attention_mask=` and returning an object with `.logits`.
      - inputs_model: dictionary with keys "input_ids" and "attention_mask", each of shape (n_samples, seq_len)
      - batch_size: number of rows sent to the model per call (the last chunk may be smaller).

    Only "input_ids" and "attention_mask" are forwarded; any other keys are ignored.
    Runs under torch.no_grad() and returns the chunk logits concatenated along dim 0.
    """
    all_ids = inputs_model["input_ids"]
    all_masks = inputs_model["attention_mask"]
    total_rows = all_ids.shape[0]

    with torch.no_grad():
        pieces = [
            model(
                input_ids=all_ids[start : start + batch_size],
                attention_mask=all_masks[start : start + batch_size],
            ).logits
            for start in range(0, total_rows, batch_size)
        ]
    return torch.cat(pieces, dim=0)


class SobolAggregator:
    """
    Turn model outputs on perturbed inputs into one Sobol-style score per token position.
    """

    @staticmethod
    def aggregate(f_orig, l_f_hybrid):
        """
        Score each token position as mean((f_orig - f_hybrid) ** 2) / var(f_orig).

        Parameters:
          - f_orig: outputs on the base perturbations (torch tensor or numpy array).
          - l_f_hybrid: indexable sequence with one output per token position, each aligned
            element-wise with f_orig (torch tensors or numpy arrays).

        Returns a float64 numpy array with len(l_f_hybrid) scores.

        Notes:
          - The variance is taken over all elements of f_orig at once (no axis), so for
            multi-class logits every class is pooled together.
          - A zero variance is replaced by 1e-6 to avoid dividing by zero.
          - NOTE(review): which Sobol index this ratio estimates depends on how l_f_hybrid was
            produced, and it has no 1/2 factor as in the usual Jansen-style estimator; confirm
            the intended normalisation before comparing with textbook Sobol indices.
        """

        def as_numpy(values):
            # Tensors (possibly on GPU) are copied to host numpy arrays; anything else passes through.
            if torch.is_tensor(values):
                return values.cpu().detach().numpy()
            return values

        reference = as_numpy(f_orig)
        variance = np.var(reference)
        if variance == 0:
            variance = 1e-6

        n_positions = len(l_f_hybrid)
        scores = [np.mean((reference - as_numpy(l_f_hybrid[k])) ** 2) / variance for k in range(n_positions)]
        return np.array(scores, dtype=np.float64)


# Example usage:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

text = "This is an example sentence for Sobol attribution in NLP."
baseline = "[MASK]"
num_pert = 200  # Number of Monte Carlo samples
p = 0.5  # Probability of keeping each token
batch_size = 32  # Rows per model forward pass

perturbator = SobolPerturbator(tokenizer, baseline, num_pert, p)
perturbation_outputs = perturbator.perturb(text, Sobol_indices="first order")

# Model outputs on the base (origin) perturbations.
f_orig = batched_inference(model, inputs_model=perturbation_outputs["origin perturbated inputs"], batch_size=batch_size)

# Model outputs on each token's hybrid perturbations. Batched the same way as f_orig so a
# large num_pert does not push every sample through the model in a single forward pass.
l_f_hybrid = [
    batched_inference(model, inputs_model=token_inputs, batch_size=batch_size)
    for token_inputs in perturbation_outputs["list of perturbated inputs for each token"]
]

aggregator = SobolAggregator()
sobol_attribution = aggregator.aggregate(f_orig, l_f_hybrid)

print("Sobol attribution indices:", sobol_attribution)

Number of tokens: 15 Sequence length: 17
Sobol attribution indices: [0.00215907 0.0124188  0.0163542  0.02872035 0.03049477 0.15114236
 0.02000404 0.05688819 0.00708974 0.01939227 0.00886717 0.04474352
 0.00648187 0.0184624  0.01511224 0.0068806  0.19568759]


In [None]:
###CODE FOR REAL TOKENS
from __future__ import annotations  # noqa: F404

import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, PreTrainedTokenizer


class SobolPerturbator:
    """
    Build Monte Carlo token-masking perturbations, producing hybrids only for "real" tokens.

    A token is "real" when its character span in the tokenizer's offset mapping is non-empty
    (in the recorded run this excluded [CLS] and [SEP]). Offset mappings are only supported by
    Hugging Face fast tokenizers.
    """

    _VALID_MODES = ("first order", "total")

    def __init__(self, tokenizer: PreTrainedTokenizer, baseline="[MASK]", n_perturbations=1000, proba=0.5):
        """
        - tokenizer: Hugging Face tokenizer associated with the model
        - baseline: replacement token (e.g. "[MASK]")
        - n_perturbations: number of Monte Carlo samples
        - proba: probability of keeping a token (i.e. putting 1 in the mask)
        """
        self.tokenizer = tokenizer
        self.baseline = baseline
        self.n_perturbations = n_perturbations
        self.proba = proba

    @staticmethod
    def _apply_masks(input_ids, masks, baseline_id):
        """Return one row per mask: input_ids where the mask is 1, baseline_id where it is 0."""
        # input_ids (seq_len,) broadcasts against masks (n, seq_len).
        return input_ids * masks + baseline_id * (1 - masks)

    def perturb(self, text, Sobol_indices="first order", device=None, perturb_special_tokens=True):
        """
        Generates perturbations for the entire input and for each "real" token position only.

        Parameters:
          - text : original sentence (str)
          - Sobol_indices : "first order" (flip only bit i of each base mask) or
            "total" (flip every bit except bit i) (str); any other value raises ValueError
          - device : where to place the returned tensors; defaults to CUDA when available,
            otherwise CPU
          - perturb_special_tokens : if True (default), non-real positions (e.g. [CLS]/[SEP])
            are randomly masked like every other position. If False, they always keep their
            original id in both the base masks and the hybrids, so all random variation comes
            from the real tokens being attributed.

        Returns a dictionary containing:
          - "origin perturbated inputs": dictionary with input_ids and attention_mask for the
            full-sequence perturbations, each of shape (n_perturbations, seq_len).
          - "list of perturbated inputs for each token": a dict mapping each real token's position
            to its own perturbation inputs (same structure and shapes).
          - "real_tokens": a dict mapping each real token position to its token string.

        Raises:
          - ValueError: if Sobol_indices is not recognised or n_perturbations < 1.

        Side effect: prints the token counts, all tokens and the real tokens.
        """
        if Sobol_indices not in self._VALID_MODES:
            raise ValueError(f"Sobol_indices must be one of {self._VALID_MODES}, got {Sobol_indices!r}")
        if self.n_perturbations < 1:
            raise ValueError(f"n_perturbations must be >= 1, got {self.n_perturbations}")
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Tokenize text with offsets (includes special tokens)
        inputs_model = self.tokenizer(text, return_offsets_mapping=True, return_tensors="pt")
        offset_mapping = inputs_model["offset_mapping"].squeeze(0).tolist()  # (start, end) per token
        input_ids = inputs_model["input_ids"].squeeze(0)  # Shape: (seq_len,)
        attention_mask = inputs_model["attention_mask"].squeeze(0)  # Shape: (seq_len,)
        all_tokens = self.tokenizer.convert_ids_to_tokens(input_ids)

        # Identify "real" tokens: those with a nonzero span in the offset mapping.
        real_indices = [i for i, (start, end) in enumerate(offset_mapping) if (end - start) > 0]
        real_tokens = {i: all_tokens[i] for i in real_indices}

        baseline_id = self.tokenizer.convert_tokens_to_ids(self.baseline)
        seq_len = input_ids.shape[0]
        print(
            "Number of all tokens:",
            len(all_tokens),
            "number of real tokens:",
            len(real_tokens),
            "Sequence length:",
            seq_len,
        )
        print("All tokens:", all_tokens)
        print("Real tokens:", real_tokens)

        n = self.n_perturbations
        # One Bernoulli draw per sample, in the same order as drawing row by row, so a given
        # seed produces the same masks as before. Shape: (n_perturbations, seq_len).
        origin_masks = torch.stack(
            [torch.bernoulli(torch.full((seq_len,), self.proba)) for _ in range(n)]
        ).long()

        # Boolean column selector for non-real positions (only used when they must stay fixed).
        real_set = set(real_indices)
        pinned = torch.tensor([i not in real_set for i in range(seq_len)], dtype=torch.bool)
        if not perturb_special_tokens:
            origin_masks[:, pinned] = 1

        origin_inputs_model = {
            "input_ids": self._apply_masks(input_ids, origin_masks, baseline_id).to(device),
            # Every sample reuses the unperturbed attention mask.
            "attention_mask": attention_mask.unsqueeze(0).repeat(n, 1).to(device),
        }

        # For each real token position, derive hybrids from all base masks at once.
        pert_inputs_model_per_token = {}
        for i in real_indices:
            pert_masks = origin_masks.clone()
            # Flip bit i of every base mask.
            pert_masks[:, i] = 1 - pert_masks[:, i]
            if Sobol_indices == "total":
                # Flip everything again: bit i is back to its base value, all others flipped.
                pert_masks = 1 - pert_masks
            if not perturb_special_tokens:
                # The global flip above also hits pinned positions; restore them.
                pert_masks[:, pinned] = 1
            pert_inputs_model_per_token[i] = {
                "input_ids": self._apply_masks(input_ids, pert_masks, baseline_id).to(device),
                "attention_mask": attention_mask.unsqueeze(0).repeat(n, 1).to(device),
            }

        return {
            "origin perturbated inputs": origin_inputs_model,
            "list of perturbated inputs for each token": pert_inputs_model_per_token,
            "real_tokens": real_tokens,
        }


class SobolInferenceWrapper:
    """
    Run a wrapped model without gradients and hand back its logits.

    The model is called with keyword tensors and must return an object exposing `.logits`.
    """

    def __init__(self, model):
        self.model = model

    def inference(self, inputs_model):
        """
        Forward all of `inputs_model` (unpacked as keyword arguments) in one call and return the logits.
        """
        with torch.no_grad():
            forward_result = self.model(**inputs_model)
        return forward_result.logits

    def batched_inference(self, inputs_model, batch_size=32):
        """
        Forward `inputs_model` in consecutive chunks of `batch_size` rows and stack the logits.

        Parameters:
        - inputs_model: dictionary with keys "input_ids" and "attention_mask", each of shape (n_samples, seq_len)
        - batch_size: number of rows per forward call (the final chunk may be smaller).

        Only "input_ids" and "attention_mask" are forwarded; results are concatenated along dim 0.
        """
        all_ids = inputs_model["input_ids"]
        all_masks = inputs_model["attention_mask"]
        total_rows = all_ids.shape[0]

        pieces = []
        with torch.no_grad():
            start = 0
            while start < total_rows:
                stop = start + batch_size
                result = self.model(input_ids=all_ids[start:stop], attention_mask=all_masks[start:stop])
                pieces.append(result.logits)
                start = stop
        return torch.cat(pieces, dim=0)


class SobolAggregator:
    """
    Turn model outputs on perturbed inputs into Sobol-style scores keyed by token position.
    """

    @staticmethod
    def aggregate(f_orig, dict_f_hybrid):
        """
        Score each token position as mean((f_orig - f_hybrid) ** 2) / var(f_orig).

        Parameters:
          - f_orig: outputs on the base perturbations (torch tensor or numpy array).
          - dict_f_hybrid: mapping token position -> outputs on that token's hybrids, aligned
            element-wise with f_orig (torch tensors or numpy arrays).

        Returns a dict {token position: score} in the key order of dict_f_hybrid.

        The variance is taken over all elements of f_orig at once (classes pooled); a zero
        variance is replaced by 1e-6 to avoid dividing by zero.
        """

        def as_numpy(values):
            # Tensors (possibly on GPU) are copied to host numpy arrays; anything else passes through.
            return values.cpu().detach().numpy() if torch.is_tensor(values) else values

        reference = as_numpy(f_orig)
        variance = np.var(reference)
        if variance == 0:
            variance = 1e-6

        return {
            position: np.mean((reference - as_numpy(outputs)) ** 2) / variance
            for position, outputs in dict_f_hybrid.items()
        }


# Example usage:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

text = "This is an example sentence for Sobol attribution in NLP."
baseline = "[MASK]"
num_pert = 200  # Number of Monte Carlo samples
p = 0.5  # Probability of keeping each token
batch_size = 32  # Rows per model forward pass

# Create perturbations.
perturbator = SobolPerturbator(tokenizer, baseline, num_pert, p)
perturbation_outputs = perturbator.perturb(text, Sobol_indices="first order")

# Run inference on the full-sequence perturbations.
model_wrapper = SobolInferenceWrapper(model)
f_orig = model_wrapper.batched_inference(
    inputs_model=perturbation_outputs["origin perturbated inputs"], batch_size=batch_size
)

# Run inference for each real token's hybrid perturbations. Batched the same way as f_orig so a
# large num_pert does not push every sample through the model in a single forward pass.
l_f_hybrid = {
    token_idx: model_wrapper.batched_inference(inputs_model=pert_inputs, batch_size=batch_size)
    for token_idx, pert_inputs in perturbation_outputs["list of perturbated inputs for each token"].items()
}

aggregator = SobolAggregator()
sobol_attribution = aggregator.aggregate(f_orig, l_f_hybrid)

# Display the Sobol attribution for each "real" token.
real_tokens = perturbation_outputs["real_tokens"]
print("Sobol attribution indices for real tokens:")
for idx, attr in sobol_attribution.items():
    print(f"Token {idx}: '{real_tokens[idx]}' -> {attr}")

Number of all tokens: 17 number of real tokens: 15 Sequence length: 17
All tokens: ['[CLS]', 'this', 'is', 'an', 'example', 'sentence', 'for', 'sob', '##ol', 'at', '##tri', '##bution', 'in', 'nl', '##p', '.', '[SEP]']
Real tokens: {1: 'this', 2: 'is', 3: 'an', 4: 'example', 5: 'sentence', 6: 'for', 7: 'sob', 8: '##ol', 9: 'at', 10: '##tri', 11: '##bution', 12: 'in', 13: 'nl', 14: '##p', 15: '.'}
Sobol attribution indices for real tokens:
Token 1: 'this' -> 0.012475554831326008
Token 2: 'is' -> 0.01789218559861183
Token 3: 'an' -> 0.02246691845357418
Token 4: 'example' -> 0.030135653913021088
Token 5: 'sentence' -> 0.1560068279504776
Token 6: 'for' -> 0.021225513890385628
Token 7: 'sob' -> 0.05978991463780403
Token 8: '##ol' -> 0.010796360671520233
Token 9: 'at' -> 0.019393419846892357
Token 10: '##tri' -> 0.008098594844341278
Token 11: '##bution' -> 0.04226751998066902
Token 12: 'in' -> 0.00629765959456563
Token 13: 'nl' -> 0.019798938184976578
Token 14: '##p' -> 0.015558471903204918
T