## Sentence Transformers Explainer

This notebook explains our own implementation of an explainability method for Sentence Transformers models, focused on the sentence similarity task.
For this we base our work on two main sources:
* The 'Integrated Gradients' method, an attribution method that establishes a quantitative relation between the inputs and the outputs of deep learning models. For further context, refer to the [original paper](https://arxiv.org/pdf/1703.01365).

* Captum, an open-source Python library that implements multiple explainability methods on top of PyTorch models. In particular, we use the [BERT tutorial](https://captum.ai/tutorials/Bert_SQUAD_Interpret) as a reference.



Let's start with a short mathematical review of the Integrated Gradients method, which we will need later to understand every component of our implementation:

As stated in the paper itself, this method aims to *attribute the output (prediction) of a network to its inputs*. In other words, it tries to decompose the value predicted by the model in terms of the input variables, much like the contribution of each term in a linear regression model: for $F(x) = w^\top x$ with a zero baseline, the attribution of feature $i$ is exactly $w_i x_i$. To obtain these attributions, the method uses a reference point in the input space, called the baseline $x'$, and computes a *path integral of the gradients along a straightline from the baseline x' to the input x* [(1)](https://arxiv.org/pdf/1703.01365). For the $i$-th input feature:

$$
\text{IntegratedGrads}_i(x) := (x_i - x'_i) \times \int_{\alpha=0}^{1} \frac{\partial F(x' + \alpha \times (x - x'))}{\partial x_i} \, d\alpha
$$


One of the most important consequences of Integrated Gradients is the *completeness* property: by the fundamental theorem of calculus for path integrals, the attributions of all the inputs add up to the difference between the model output at the real input and the output at the chosen baseline.

$$
\sum_{i=1}^{n} \text{IntegratedGrads}_i(x) = F(x) - F(x')
$$



Therefore, the choice of baseline strongly influences the method. Ideally we should pick one that satisfies $F(x') \approx 0$, so that the attributions explain the whole prediction, but as we will see later, this is not feasible for most Sentence Transformers models.
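To make the formula and the completeness property concrete, here is a minimal sketch in plain Python on a hypothetical toy function with hand-written gradients (not a Sentence Transformers model). The integral is approximated with a midpoint Riemann sum, and the baseline is deliberately chosen so that $F(x') \neq 0$, which is the situation we will face later: the attributions then add up to $F(x) - F(x')$, not to $F(x)$.

```python
import math

def toy_model(x):
    # Hypothetical differentiable "model": F(x) = x0^2 * x1 + sin(x2)
    return x[0] ** 2 * x[1] + math.sin(x[2])

def toy_grad(x):
    # Analytic partial derivatives dF/dx_i
    return [2 * x[0] * x[1], x[0] ** 2, math.cos(x[2])]

def integrated_gradients(x, baseline, n_steps=1000):
    # Average the gradient over points on the straight line baseline -> x (midpoint rule)
    avg_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        avg_grad = [a + g / n_steps for a, g in zip(avg_grad, toy_grad(point))]
    # Scale by (x_i - x'_i) to get one attribution per input feature
    return [(xi - b) * a for xi, b, a in zip(x, baseline, avg_grad)]

x = [1.0, 2.0, 0.5]
baseline = [0.5, 1.0, 0.0]  # chosen on purpose so that F(baseline) != 0
attributions = integrated_gradients(x, baseline)

print("attributions:       ", [round(a, 4) for a in attributions])
print("sum of attributions:", round(sum(attributions), 4))
print("F(x) - F(x'):       ", round(toy_model(x) - toy_model(baseline), 4))
```

Function names are illustrative and deliberately avoid `F`, which the notebook already uses for `torch.nn.functional`.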





## Define our model

The first step is to define our model. Although we are using a Sentence Transformers model, we load it through its PyTorch (Hugging Face `transformers`) backbone, because we need to feed the tokens to the model directly.



In [202]:
import logging
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel
from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Set up logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load a pre-trained sentence-transformers model from the Hugging Face Hub
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model.to(device)
# Inference mode: disables dropout so repeated forward passes are deterministic
model.eval()

# Load the matching tokenizer from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')



Now we need to define a forward function that produces the output of our model and receives exactly the inputs the model consumes. These inputs must therefore be the tokens of our sentences, not the raw sentences, because the tokens are the real inputs of the model. Captum will call this forward function to compute the gradients of the model output with respect to its inputs.

Although token IDs are numbers, they have no real numerical meaning: they are just indices into the vocabulary, so interpolating between two IDs (as Integrated Gradients would do) is meaningless, and the embedding lookup is not differentiable with respect to them. Therefore we use `LayerIntegratedGradients`, which computes the attributions for the inputs or outputs of a specified layer of the model. Here we are interested in the attributions of the embeddings layer, so that we can obtain the relevance of each word (token) by aggregating the attributions of each dimension of its embedding.
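The idea of attributing at the embeddings layer can be sketched without any deep learning library. This toy version assumes a hypothetical three-token vocabulary with hand-made embeddings and a one-layer `tanh` "model". It interpolates between the embeddings of the input and of the reference tokens (conceptually what `LayerIntegratedGradients` does with the outputs of `model.embeddings`), computes one attribution per (token, dimension), and then sums over the dimensions to get one relevance score per token.

```python
import math

# Hypothetical toy vocabulary: token id -> 3-dimensional embedding vector
EMBEDDINGS = {
    0: [0.0, 0.1, 0.0],   # stands in for the pad token used as reference (not a zero vector)
    7: [0.9, -0.2, 0.4],
    8: [0.1, 0.8, -0.5],
}
WEIGHTS = [0.5, -1.0, 2.0]  # hypothetical weights of a one-layer toy "model"

def model_on_embeddings(embs):
    # F(e) = tanh(sum over tokens t and dims d of w_d * e[t][d])
    return math.tanh(sum(w * v for e in embs for w, v in zip(WEIGHTS, e)))

def grad_on_embeddings(embs):
    # dF/de[t][d] = (1 - F^2) * w_d
    y = model_on_embeddings(embs)
    return [[(1 - y ** 2) * w for w in WEIGHTS] for _ in embs]

def embedding_integrated_gradients(token_ids, ref_ids, n_steps=500):
    x = [EMBEDDINGS[t] for t in token_ids]
    x_ref = [EMBEDDINGS[t] for t in ref_ids]
    avg_grad = [[0.0] * len(WEIGHTS) for _ in x]
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        # Interpolate in embedding space, never between token ids
        point = [[b + alpha * (a - b) for a, b in zip(e, er)] for e, er in zip(x, x_ref)]
        for t, grads in enumerate(grad_on_embeddings(point)):
            for d, g in enumerate(grads):
                avg_grad[t][d] += g / n_steps
    # One attribution per (token, embedding dimension)
    return [[(a - b) * g for a, b, g in zip(e, er, gs)]
            for e, er, gs in zip(x, x_ref, avg_grad)]

toy_input_ids = [7, 8, 0]  # the last token equals its reference
toy_ref_ids = [0, 0, 0]
attr = embedding_integrated_gradients(toy_input_ids, toy_ref_ids)
per_token = [sum(dims) for dims in attr]  # token relevance = sum over embedding dimensions

f_input = model_on_embeddings([EMBEDDINGS[t] for t in toy_input_ids])
f_ref = model_on_embeddings([EMBEDDINGS[t] for t in toy_ref_ids])
print("per-token attributions:", [round(s, 4) for s in per_token])
print("sum:", round(sum(per_token), 4), " F(x) - F(x'):", round(f_input - f_ref, 4))
```

Two details carry over to the real model: a token identical to its reference gets exactly zero attribution, and because the reference embedding is not the zero vector, the baseline output `f_ref` is not zero either.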





In [235]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def predict(input_ids, token_type_ids=None, attention_mask=None):
    """
    Predicts the cosine similarity between pairs of sentences.
    
    Parameters:
    - input_ids: A tensor of shape [number_of_pairs, 2, seq_len] with the token ids of each pair of sentences.
    - token_type_ids: Optional token type ids (same shape as input_ids).
    - attention_mask: Optional attention mask (same shape as input_ids). If omitted, every position is attended.
    
    Returns:
    - similarities: A tensor of shape [number_of_pairs, 1] containing the similarity score for each pair.
    """
    
    num_pairs = input_ids.shape[0]
    seq_len = input_ids.shape[-1]

    # mean_pooling needs a mask, so default to attending every position
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    
    # Flatten the pairs into a single batch of 2 * num_pairs sentences
    input_ids = input_ids.reshape(num_pairs * 2, seq_len)
    attention_mask = attention_mask.reshape(num_pairs * 2, seq_len)
    if token_type_ids is not None:
        token_type_ids = token_type_ids.reshape(num_pairs * 2, seq_len)

    # Embed every sentence, mean-pool over the non-padding tokens and L2-normalize
    model_output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
    sentence_embeddings = mean_pooling(model_output, attention_mask)
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    # Restore the pair structure: [num_pairs, 2, embedding_dim]
    sentence_embeddings = sentence_embeddings.view(num_pairs, 2, -1)
    similarities = F.cosine_similarity(sentence_embeddings[:, 0, :], sentence_embeddings[:, 1, :], dim=1)

    return similarities.view(-1, 1)
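The pooling and similarity steps of `predict` can be checked on tiny hand-made vectors. The following plain-Python sketch (hypothetical helper names, no model involved) mirrors `mean_pooling`, `F.normalize` and the cosine similarity: padded positions are excluded by the attention mask, so even an extreme value in a padded row has no effect on the sentence embedding.

```python
import math

def mean_pool(token_embeddings, attention_mask):
    # Average only over positions where attention_mask == 1 (clamped like torch.clamp(min=1e-9))
    total = [0.0] * len(token_embeddings[0])
    count = 0
    for emb, m in zip(token_embeddings, attention_mask):
        if m:
            total = [t + v for t, v in zip(total, emb)]
            count += 1
    return [t / max(count, 1e-9) for t in total]

def l2_normalize(v):
    norm = math.sqrt(sum(c * c for c in v))
    return [c / max(norm, 1e-12) for c in v]

def cosine(a, b):
    # After L2 normalization, cosine similarity is just a dot product
    return sum(p * q for p, q in zip(l2_normalize(a), l2_normalize(b)))

sent_a = [[1.0, 0.0], [0.0, 1.0], [50.0, -50.0]]  # last row is a padded position
mask_a = [1, 1, 0]
sent_b = [[2.0, 2.0]]
mask_b = [1]

emb_a = mean_pool(sent_a, mask_a)
emb_b = mean_pool(sent_b, mask_b)
print("pooled A:", emb_a, " pooled B:", emb_b)
print("cosine similarity:", round(cosine(emb_a, emb_b), 6))
```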


In [232]:
def sent_trans_pos_forward_func(input_ids,token_type_ids,attention_mask):
    logging.info(f"Method calling the predict function with shape {input_ids.shape}")
    pred = predict(input_ids,token_type_ids,attention_mask)
    logging.info(f"Value predicted---> {pred}")
    return pred

## Inputs for the forward function and baselines

Now that we have configured our forward function, we are going to build several helper functions that prepare its inputs.
They are adapted from the Captum tutorial and slightly modified for our use case. Let's go over them:


## Input IDs and Ref Input IDs

Before looking at this function, let's recall the objective and introduce some special tokens that will help us achieve it. As we said in the introduction, the method attributes the output of a model to its inputs by computing a path integral from a baseline. Since we are using Sentence Transformers models, there are two inherent problems:

* There is no token whose embedding is the zero vector.

* We are computing the similarity between two sentences, so we cannot replace both of them with the baseline: two identical baseline sentences would always produce a similarity of 1.

For the first point, the usual approach is to use the tokenizer's pad token, which fills the gaps when sentences of different lengths are batched together. It is the closest thing to an 'empty' token. There are other feasible options that we will see later.

For the second problem, we keep one sentence (normally the query, i.e. the phrase used to search for or trigger similarities in a bigger set) as a static reference and study which words/tokens of the other sentence caused the score. In the baseline sentence, we replace every token with the reference token, except the SEP token, which marks the end of a sentence, and the CLS token, which is commonly used to aggregate the sequence representation.




In [None]:
ref_token_id = tokenizer.pad_token_id
sep_token_id = tokenizer.sep_token_id
cls_token_id = tokenizer.cls_token_id

In [223]:
def construct_input_ref_pair(sentences, ref_token_id=0, process_both=True):
    """
    Construct input_ids and ref_input_ids for a pair of sentences, optionally leaving
    the first sentence unchanged in the reference.
    
    Input:
    - sentences: A list of 2 sentences.
    - ref_token_id: Token to be used as reference.
    - process_both: Flag to determine if processing both phrases or just the second
    
    Output:
    - input_ids: Tensor of shape [1, 2, max_length]
    - ref_input_ids: Tensor of shape [1, 2, max_length]
    """

    # We could generalize this method to any dimension of pairs of sentences
    assert len(sentences) == 2

    # Tokenize the pair and move it to the same device as the model
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(device)
    input_ids = encoded_input['input_ids']  # shape: [2, max_length]
    
    # Create a reference tensor filled with the ref token, on the same device as the inputs
    ref_input_ids = torch.full(input_ids.shape, ref_token_id, device=input_ids.device)
    
    # Reintroduce the classification and sep tokens
    for i, input_id_sequence in enumerate(input_ids):
        if not process_both and i == 0:  # Skip processing for the first sentence if process_both is False
            ref_input_ids[i] = input_id_sequence
        else:
            ref_input_ids[i, input_id_sequence == cls_token_id] = cls_token_id
            ref_input_ids[i, input_id_sequence == sep_token_id] = sep_token_id

    # Reshape input_ids and ref_input_ids to add an additional level of depth [1, 2, max_length]
    input_ids = input_ids.unsqueeze(0)  # shape: [1, 2, max_length]
    ref_input_ids = ref_input_ids.unsqueeze(0)  # shape: [1, 2, max_length]

    return input_ids, ref_input_ids

### Let's tokenize a couple of sentences

In [254]:
sentences = ['What are sentence transformers?', 'A deep learning model that leverages the Transformers arquitecture to generate sentence embeddings and related tasks']

input_ids, ref_input_ids = construct_input_ref_pair(sentences, ref_token_id=ref_token_id, process_both=False)

In [246]:
print("Inputs Id's---->",input_ids)

print("Ref Inputs Id's---->",ref_input_ids)

print("Inputs Id's shape---->",input_ids.shape)

Inputs Id's----> tensor([[[  101,  2054,  2024,  6251, 19081,  1029,   102,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0],
         [  101,  1037,  2784,  4083,  2944,  2008, 21155,  2015,  1996, 19081,
          12098, 15549, 26557, 11244,  2000,  9699,  6251,  7861,  8270,  4667,
           2015,  1998,  3141,  8518,   102]]])
Ref Inputs Id's----> tensor([[[  101,  2054,  2024,  6251, 19081,  1029,   102,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0],
         [  101,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,   102]]])
Inputs Id's shape----> torch.Size([1, 2, 25])


The input IDs are simply the tokenized sentences, while the reference input IDs define our baseline. Since `process_both=False`, the first sentence (the query) is kept unchanged in the baseline, and in the second one only the CLS (101) and SEP (102) tokens survive.
It is important to notice that our inputs have shape [num_of_pairs, 2 (sentences), length of sentences], because the input of our model is the pair of sentences we want to compare.


## Token Type IDs

In our use case each sentence is encoded on its own rather than as a concatenated sentence pair, so the token type IDs are simply a tensor of zeros with the same shape as the input IDs.

In [247]:
def construct_token_type_ids(input_ids):
    # Assume no token type differentiation if not needed
    token_type_ids = torch.zeros_like(input_ids)
    ref_token_type_ids = torch.zeros_like(input_ids)
    
    return token_type_ids, ref_token_type_ids

In [248]:
token_type_ids, ref_token_type_ids = construct_token_type_ids(input_ids)

In [249]:
print("Token type ids--->",token_type_ids)

print("Ref Token type ids--->",ref_token_type_ids)

Token type ids---> tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0]]])
Ref Token type ids---> tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0]]])


## Attention Mask

Lastly, we need to provide the attention mask as input. It specifies whether each input token should be taken into account by the model, assigning 0 only to the pad token positions. The same mask is reused for the baseline, so the pad tokens we placed in the baseline sentence are still attended to.

In [250]:
def construct_attention_mask(input_ids):
    return (input_ids != tokenizer.pad_token_id).long()

In [251]:
attention_mask = construct_attention_mask(input_ids)


print("Attention Mask--->",attention_mask)

Attention Mask---> tensor([[[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1]]])


Now that we have every input our forward function requires, both for our sentences and for the baseline, let's test the forward function before defining the attribution method.

In [267]:
# Assuming `model` and `mean_pooling` are defined, and predict function is implemented
cosine_similarities = sent_trans_pos_forward_func(input_ids, token_type_ids, attention_mask)

# Print the output
print("Cosine similarities:", cosine_similarities)

baseline_cos_similarities = sent_trans_pos_forward_func(ref_input_ids, token_type_ids, attention_mask)

2024-10-05 20:33:41,603 - INFO - Method calling the predict function with shape torch.Size([1, 2, 25])
2024-10-05 20:33:41,694 - INFO - Value predicted---> tensor([[0.5704]], grad_fn=<ViewBackward0>)
2024-10-05 20:33:41,699 - INFO - Method calling the predict function with shape torch.Size([1, 2, 25])
2024-10-05 20:33:41,723 - INFO - Value predicted---> tensor([[0.0650]], grad_fn=<ViewBackward0>)


Cosine similarities: tensor([[0.5704]], grad_fn=<ViewBackward0>)


## Integrated Gradients

Let's break down the different arguments:

* `sent_trans_pos_forward_func`: the forward function used to compute the gradients.

* `model.embeddings`: the layer with respect to whose outputs we compute the attributions.

* `inputs`: the real input token IDs.

* `baselines`: the initial (neutral) point from which the path integrals start.

* `additional_forward_args`: the extra arguments (token type IDs and attention mask) passed to the forward function after the inputs.

* `n_steps`: the number of steps used in the numerical approximation of the integral.
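`n_steps` controls how many points are used to approximate the integral, and therefore how closely the attributions satisfy completeness. The sketch below uses a hypothetical one-dimensional toy function and a midpoint Riemann sum to show the gap between $F(x) - F(x')$ and the attribution shrinking as `n_steps` grows. Captum's default integration method is Gauss-Legendre rather than this midpoint rule, and `attribute` can report this gap directly via `return_convergence_delta=True`.

```python
import math

def toy_model(a):
    # Hypothetical toy "model" with a single scalar input
    return math.tanh(3 * a) + a ** 2

def toy_grad(a):
    # Its analytic derivative
    return 3 * (1 - math.tanh(3 * a) ** 2) + 2 * a

def integrated_gradient_1d(x, x_ref, n_steps):
    # Midpoint Riemann sum of the gradient along x_ref -> x, scaled by (x - x_ref)
    step = 1.0 / n_steps
    avg_grad = sum(toy_grad(x_ref + (k + 0.5) * step * (x - x_ref)) for k in range(n_steps)) * step
    return (x - x_ref) * avg_grad

x, x_ref = 1.0, 0.0
deltas = {}
for n_steps in (2, 10, 50, 100):
    attribution = integrated_gradient_1d(x, x_ref, n_steps)
    # Convergence delta: distance from exact completeness
    deltas[n_steps] = (toy_model(x) - toy_model(x_ref)) - attribution
    print(f"n_steps={n_steps:>3}  attribution={attribution:.5f}  delta={deltas[n_steps]:.2e}")
```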



In [255]:
lig = LayerIntegratedGradients(sent_trans_pos_forward_func, model.embeddings)

attributions_start = lig.attribute(inputs = input_ids,
                                  baselines = ref_input_ids,
                                  additional_forward_args=(token_type_ids, attention_mask),
                                  n_steps =100)


2024-10-05 20:01:41,677 - INFO - Method calling the predict function with shape torch.Size([1, 2, 25])
2024-10-05 20:01:41,710 - INFO - Value predicted---> tensor([[0.5704]])
2024-10-05 20:01:41,714 - INFO - Method calling the predict function with shape torch.Size([1, 2, 25])
2024-10-05 20:01:41,736 - INFO - Value predicted---> tensor([[0.0650]])
2024-10-05 20:01:41,768 - INFO - Method calling the predict function with shape torch.Size([100, 2, 25])
2024-10-05 20:01:42,396 - INFO - Value predicted---> tensor([[0.0649],
        [0.0649],
        [0.0649],
        [0.0649],
        [0.0648],
        [0.0648],
        [0.0647],
        [0.0647],
        [0.0646],
        [0.0645],
        [0.0645],
        [0.0644],
        [0.0643],
        [0.0642],
        [0.0641],
        [0.0640],
        [0.0639],
        [0.0638],
        [0.0638],
        [0.0638],
        [0.0637],
        [0.0638],
        [0.0639],
        [0.0640],
        [0.0642],
        [0.0644],
        [0.0646],
      

## Let's analyze the result

In [258]:
attributions_start.shape

torch.Size([2, 25, 384])

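Starting with the shape of the attributions, we can see that there is one value for each dimension of the embeddings (384), for each token (25), for each sentence (2), as expected. The pair dimension is gone because `predict` flattens the pairs before calling the model, so the embeddings layer sees a batch of shape [2, 25]. Now let's look at the values for each token: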

In [259]:
attributions_start.sum(-1)

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.0000,  0.0232,  0.0210,  0.0397,  0.0149,  0.0105,  0.0105,  0.0087,
          0.0079,  0.1634,  0.0075,  0.0087, -0.0011,  0.0111,  0.0013,  0.0146,
          0.0956,  0.0078,  0.0110,  0.0098,  0.0116,  0.0044,  0.0172,  0.0061,
          0.0000]], dtype=torch.float64)

For the first sentence all attributions are zero, because its baseline and its real input are identical, so the factor $(x_i - x'_i)$ in the formula vanishes. The same happens for the CLS and SEP tokens of the second sentence. Let's see what happens when we aggregate the attributions over the full sentence:


In [260]:
attributions_start.sum(-1).sum(-1)

tensor([0.0000, 0.5055], dtype=torch.float64)

In [262]:
attributions_start_sum = attributions_start.sum(-1)


The total attribution of the second sentence (0.5055), as explained in the introduction, matches the difference between the score of the real input and the score of the baseline (0.5704 - 0.0650 = 0.5054), up to rounding and the error of the numerical integration. In other words, the gap between the similarity of our query with a 'neutral' sentence and the real similarity score has been distributed among the tokens of the second sentence. This is quite powerful, as it tells us which words contributed most to moving from a low score to the model's real prediction.

## Let's visualize the attributions

First we need to map the token IDs and their positions back to the original words (tokens):


In [265]:
indices_0 = input_ids[0][0].detach().tolist()
indices_1 = input_ids[0][1].detach().tolist()
all_tokens_0 = tokenizer.convert_ids_to_tokens(indices_0)
all_tokens_1 = tokenizer.convert_ids_to_tokens(indices_1)
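WordPiece splits some words into several tokens (`leverage ##s`, `em ##bed ##ding ##s`), so a per-word score needs an aggregation rule. The sketch below assumes that summing the attributions of the pieces is a sensible rule (averaging would be an alternative) and uses a hypothetical helper `merge_wordpieces`. The tokens and per-token sums of the second sentence are copied from the outputs above (rounded to four decimals), so no model is needed to run it.

```python
# Tokens and per-token attribution sums of the second sentence, copied from the outputs above
tokens = ["[CLS]", "a", "deep", "learning", "model", "that", "leverage", "##s", "the",
          "transformers", "ar", "##qui", "##tec", "##ture", "to", "generate", "sentence",
          "em", "##bed", "##ding", "##s", "and", "related", "tasks", "[SEP]"]
scores = [0.0000, 0.0232, 0.0210, 0.0397, 0.0149, 0.0105, 0.0105, 0.0087,
          0.0079, 0.1634, 0.0075, 0.0087, -0.0011, 0.0111, 0.0013, 0.0146,
          0.0956, 0.0078, 0.0110, 0.0098, 0.0116, 0.0044, 0.0172, 0.0061,
          0.0000]

def merge_wordpieces(tokens, scores, skip=("[CLS]", "[SEP]", "[PAD]")):
    # Glue "##" continuation pieces onto the previous word and add up their scores
    words, word_scores = [], []
    for tok, s in zip(tokens, scores):
        if tok in skip:
            continue
        if tok.startswith("##") and words:
            words[-1] += tok[2:]
            word_scores[-1] += s
        else:
            words.append(tok)
            word_scores.append(s)
    return words, word_scores

words, word_scores = merge_wordpieces(tokens, scores)
ranking = sorted(zip(words, word_scores), key=lambda p: p[1], reverse=True)
for word, score in ranking[:3]:
    print(f"{word:>12}  {score:.4f}")
```

With these numbers, 'transformers' and 'sentence', the two content words shared with the query, receive the largest attributions.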

In [261]:
from explainability_visual_utils import visualize_text_v2

In [273]:
## Let's calculate the delta value
# Difference in score between the real input and the baseline
diff = cosine_similarities.item() - baseline_cos_similarities.item()

# Sum of all token attributions for the second sentence
attrib = attributions_start.sum(-1).sum(-1)[1].item()

In [275]:
# Convergence delta: how far the attributions are from exact completeness
delta = diff - attrib

In [283]:
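first_position_vis = viz.VisualizationDataRecord(
                        attributions_start_sum[1],             # word attributions (second sentence)
                        cosine_similarities[0].item(),         # predicted score
                        round(cosine_similarities[0].item()),  # predicted "class"
                        1.0,                                   # true label
                        'Positive',                            # attribution class
                        attributions_start_sum.sum(),          # total attribution score
                        all_tokens_1,                          # tokens to display
                        convergence_score = delta)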
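Because the raw token attributions are small (the largest is about 0.16), colours driven directly by them are very faint. The Captum BERT tutorial divides the summed attributions by their L2 norm before visualizing them. A plain-Python sketch of that rescaling, applied to the rounded per-token values shown above, could look like this (`normalize_for_display` is a hypothetical helper). In the notebook itself the equivalent would be `attributions_start_sum[1] / torch.norm(attributions_start_sum[1])`.

```python
import math

# Per-token attribution sums of the second sentence, copied from the outputs above
scores = [0.0000, 0.0232, 0.0210, 0.0397, 0.0149, 0.0105, 0.0105, 0.0087,
          0.0079, 0.1634, 0.0075, 0.0087, -0.0011, 0.0111, 0.0013, 0.0146,
          0.0956, 0.0078, 0.0110, 0.0098, 0.0116, 0.0044, 0.0172, 0.0061,
          0.0000]

def normalize_for_display(values):
    # Rescale to unit L2 norm: keeps signs and ranking, stretches the colour range
    norm = math.sqrt(sum(v * v for v in values))
    return [v / norm for v in values] if norm > 0 else list(values)

display = normalize_for_display(scores)
print("largest raw score:       ", max(scores))
print("largest rescaled score:  ", round(max(display), 3))
```

Keep the raw values for any completeness or delta check: after rescaling, the attributions no longer add up to $F(x) - F(x')$.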

In [284]:
visualize_text_v2([first_position_vis])


| True Label | Predicted Label | Word Importance |
| --- | --- | --- |
| 1.0 | 1 (0.57) | [CLS] a deep learning model that leverage ##s the transformers ar ##qui ##tec ##ture to generate sentence em ##bed ##ding ##s and related tasks [SEP] |


