## Sentence Transformers Explainer

This notebook explains our own implementation of an explainability method for 'Sentence Transformers' models.
Our work is based on two main sources:
* The 'Integrated Gradients' method, a numerical method that establishes a quantitative relation between the inputs and the outputs of deep learning models. For further context, refer to the [original paper](https://arxiv.org/pdf/1703.01365).

* Captum, an open-source Python library that implements multiple explainability methods for PyTorch models. In particular, we take the [BERT tutorial](https://captum.ai/tutorials/Bert_SQUAD_Interpret) as our reference.



We now briefly review the mathematics of the Integrated Gradients method, which we will need later to understand every component of our method:

As stated in the paper itself, this method aims to *attribute the output (prediction) of a network to its inputs*. In other words, it tries to decompose the predicted value of the model in terms of the input variables, similar to how each coefficient of a linear regression model weights the contribution of its feature. To obtain these attributions, the method uses a reference point in the input space, called the baseline, and computes a *path integral of the gradients along a straightline from the baseline x' to the input x* [(1)](https://arxiv.org/pdf/1703.01365). For the $i$-th input feature:

$$
\text{IntegratedGrads}_i(x) := (x_i - x'_i) \times \int_{\alpha=0}^{1} \frac{\partial F(x' + \alpha \times (x - x'))}{\partial x_i} \, d\alpha
$$
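In practice the integral is not computed exactly but approximated by a sum over a finite number of points on the path (Captum does the same, with `n_steps=50` points and a Gauss-Legendre rule by default). A minimal sketch in plain Python, assuming a hypothetical two-input toy function `F` with a hand-written gradient `grad_F` and a midpoint Riemann sum in the hypothetical helper `integrated_gradients`:

```python
import math


def F(x):
    # Toy differentiable "model": F(x) = sigmoid(2*x0 - x1)
    return 1.0 / (1.0 + math.exp(-(2.0 * x[0] - x[1])))


def grad_F(x):
    # Analytic gradient of F: sigmoid'(z) * dz/dx_i, with z = 2*x0 - x1
    s = F(x)
    ds = s * (1.0 - s)
    return [2.0 * ds, -1.0 * ds]


def integrated_gradients(x, baseline, steps=100):
    """Approximate IG_i(x) with a midpoint Riemann sum over alpha in [0, 1]."""
    n = len(x)
    grad_sum = [0.0] * n
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th sub-interval
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]  # x' + alpha * (x - x')
        g = grad_F(point)
        for i in range(n):
            grad_sum[i] += g[i]
    # (x_i - x'_i) times the average gradient along the straight-line path
    return [(xi - b) * gs / steps for xi, b, gs in zip(x, baseline, grad_sum)]


x = [1.0, 0.5]
baseline = [0.0, 0.0]
attributions = integrated_gradients(x, baseline)
print(attributions)
```

Increasing `steps` makes the sum converge to the integral; the first input gets a positive attribution and the second a negative one, matching the signs of their weights inside the sigmoid.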


One of the most important consequences of Integrated Gradients is that, because the gradient of $F$ is integrated along a path (fundamental theorem of calculus for path integrals), the attributions of all the inputs must add up to the difference between the output at the real input and the output at the chosen baseline (the *completeness* property in the paper).

$$
\sum_{i=1}^{n} \text{IntegratedGrads}_i(x) = F(x) - F(x')
$$
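The linear regression analogy can be made exact: for a linear model $F(x) = w \cdot x + b$ the gradient is the constant $w$ along the whole path, so $\text{IntegratedGrads}_i(x) = w_i (x_i - x'_i)$. A small sketch (the helpers `linear_model` and `ig_linear` and all numbers are hypothetical) checking that these attributions satisfy the completeness equation above:

```python
def linear_model(x, w, b):
    # F(x) = w . x + b, a plain linear regression model
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def ig_linear(x, baseline, w):
    # The gradient is constant (= w) along the path, so the integral collapses:
    # IG_i = w_i * (x_i - x'_i)
    return [wi * (xi - bi) for wi, xi, bi in zip(w, x, baseline)]


w, b = [0.5, -2.0, 3.0], 1.0
x = [2.0, 1.0, 0.5]
baseline = [0.0, 0.0, 0.0]

attr = ig_linear(x, baseline, w)
gap = linear_model(x, w, b) - linear_model(baseline, w, b)  # F(x) - F(x')
print(attr, sum(attr), gap)
```

The attributions are `[1.0, -2.0, 1.5]` and add up to `0.5 = F(x) - F(x')`; the intercept `b` cancels in the difference, so it never receives any attribution.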



Therefore, the choice of baseline strongly influences the method: the attributions only explain the gap $F(x) - F(x')$. Ideally we should pick a baseline that satisfies $F(x') \approx 0$, but, as we will see later, this is not feasible for most 'Sentence Transformers' models.
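To see the effect of the baseline, the sketch below computes Integrated Gradients for the hypothetical toy model $F(x) = x_0 x_1$ at the same input with two different baselines (the helper `ig_product` is illustrative only, using a midpoint Riemann sum):

```python
def F(x):
    # Toy non-linear model: F(x) = x0 * x1, with gradient (x1, x0)
    return x[0] * x[1]


def ig_product(x, baseline, steps=1000):
    """Midpoint Riemann-sum Integrated Gradients for F(x) = x0 * x1."""
    avg_grad = [0.0, 0.0]
    for k in range(steps):
        a = (k + 0.5) / steps
        p0 = baseline[0] + a * (x[0] - baseline[0])  # point on the straight-line path
        p1 = baseline[1] + a * (x[1] - baseline[1])
        avg_grad[0] += p1 / steps  # dF/dx0 evaluated on the path
        avg_grad[1] += p0 / steps  # dF/dx1 evaluated on the path
    return [(x[0] - baseline[0]) * avg_grad[0], (x[1] - baseline[1]) * avg_grad[1]]


x = [2.0, 3.0]
results = {}
for baseline in ([0.0, 0.0], [1.0, 1.0]):
    attr = ig_product(x, baseline)
    results[tuple(baseline)] = attr
    print(baseline, [round(a, 6) for a in attr], "F(x) - F(x') =", F(x) - F(baseline))
```

With the zero baseline the attributions are (3, 3) and add up to $F(x) = 6$; with the baseline (1, 1) they become (2, 3) and add up to 5, because $F(x') = 1$ is no longer zero. The same input gets a different explanation depending on the reference point.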





## Define our model

The first step is to define our model. Although we are using a Sentence Transformers model, we access it through its PyTorch backbone (loaded with the Hugging Face `transformers` library), because we need to feed the tokens to the model directly as inputs.



```python
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel
from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load a pre-trained sentence-transformers model (its Hugging Face PyTorch backbone)
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model.to(device)
model.eval()       # inference mode: dropout disabled, so attributions are deterministic
model.zero_grad()

# Load the matching tokenizer from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
```

Next, we define a `predict` function that generates the output of our model (the cosine similarity of each sentence pair) and takes exactly the model's inputs. These inputs must be the tokens of our sentences, not the raw sentences, because tokens are what the model actually receives. Captum will later call this `predict` function to compute the gradients of the model's output with respect to its (embedded) inputs.



```python
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sentence, ignoring padded positions."""
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def predict(input_ids, token_type_ids=None, attention_mask=None):
    """
    Predicts the cosine similarity between pairs of sentences.

    Parameters:
    - input_ids: A tensor of shape [number_of_pairs, 2, seq_len] containing pairs of tokenized sentences.
    - token_type_ids: Optional token type ids (same shape as input_ids).
    - attention_mask: Optional attention mask (same shape as input_ids). If None, every token is attended.

    Returns:
    - similarities: A tensor of shape [number_of_pairs, 1] containing the similarity score of each pair.
    """
    num_pairs = input_ids.shape[0]  # Number of sentence pairs
    seq_len = input_ids.shape[-1]   # Sequence length

    # mean_pooling needs a mask, so default to attending every token
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)

    # Flatten the pairs into one batch: rows 2p and 2p + 1 hold the two sentences of pair p
    input_ids = input_ids.reshape(num_pairs * 2, seq_len)
    attention_mask = attention_mask.reshape(num_pairs * 2, seq_len)
    if token_type_ids is not None:
        token_type_ids = token_type_ids.reshape(num_pairs * 2, seq_len)

    # Step 1: Compute token embeddings for all sentences in the batch
    model_output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

    # Mean-pool the token embeddings into one embedding per sentence
    sentence_embeddings = mean_pooling(model_output, attention_mask)

    # L2-normalize the embeddings (the cosine similarity below does not depend on this)
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    # Step 2: Reshape embeddings back into pairs: [number_of_pairs, 2, embedding_dim]
    sentence_embeddings = sentence_embeddings.view(num_pairs, 2, -1)

    # Step 3: Compute the cosine similarity within each pair
    similarities = F.cosine_similarity(sentence_embeddings[:, 0, :], sentence_embeddings[:, 1, :], dim=1)

    # Return similarity scores with shape [number_of_pairs, 1]
    return similarities.view(-1, 1)
```
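The reshaping in `predict` relies on the flattened batch keeping the two sentences of pair `p` in rows `2p` and `2p + 1`, and on `mean_pooling` ignoring padded positions. The sketch below checks both on random tensors; `token_embeddings` is a hypothetical stand-in for `model(...)[0]`, so it runs without downloading the model:

```python
import torch
import torch.nn.functional as F


def mean_pooling(model_output, attention_mask):
    # Same masked mean as mean_pooling above: padded positions get weight 0
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)


torch.manual_seed(0)
num_pairs, seq_len, dim = 3, 5, 4

# Attention mask in the [num_pairs, 2, seq_len] layout of predict(); last 2 positions are padding
attention_mask = torch.ones(num_pairs, 2, seq_len, dtype=torch.long)
attention_mask[..., -2:] = 0
flat_mask = attention_mask.reshape(num_pairs * 2, seq_len)  # row 2p + s = sentence s of pair p

# Hypothetical stand-in for model(...)[0]: random token embeddings for the flattened batch
token_embeddings = torch.randn(num_pairs * 2, seq_len, dim)
pooled = mean_pooling((token_embeddings,), flat_mask)  # [num_pairs * 2, dim]

# Changing the embeddings at padded positions must not change the pooled vectors
noisy = token_embeddings.clone()
noisy[:, -2:, :] += 100.0
pooled_noisy = mean_pooling((noisy,), flat_mask)

# Back to pairs, exactly as in predict(): [num_pairs, 2, dim]
pairs = F.normalize(pooled, p=2, dim=1).view(num_pairs, 2, -1)
sims = F.cosine_similarity(pairs[:, 0, :], pairs[:, 1, :], dim=1).view(-1, 1)

print(torch.allclose(pooled, pooled_noisy))                        # padding is ignored
print(torch.allclose(pairs[1, 1], F.normalize(pooled[3], dim=0)))  # flat row 3 = pair 1, sentence 1
print(sims.shape)                                                  # one score per pair
```

Because the mask multiplies padded embeddings by zero before summing, their values never reach the pooled vector, which is also why their attributions should stay close to zero.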


```python
def sent_trans_pos_forward_func(input_ids, token_type_ids, attention_mask):
    """Forward function passed to Captum: token ids in, similarity scores out."""
    print("Method calling the predict function for ids", input_ids.shape)
    pred = predict(input_ids, token_type_ids, attention_mask)
    print("Value predicted--->", pred)
    return pred
```

```python
# Example input tensor: 2 sentence pairs, 2 sentences per pair, seq_len=16 (0 = padding)
input_ids = torch.tensor(
    [
        [
            [101, 1045, 2293, 5983, 2003, 2026, 5440, 7570, 10322, 2666, 1012, 102, 0, 0, 0, 0],
            [101, 1045, 2293, 5983, 2003, 2026, 5440, 7570, 10322, 2600, 1012, 102, 0, 0, 0, 0],
        ],
        [
            [101, 1045, 2293, 5983, 2003, 2026, 5440, 7570, 10322, 2666, 1012, 102, 0, 0, 0, 0],
            [101, 1045, 2293, 5983, 2003, 2026, 5440, 7570, 10322, 2600, 1012, 102, 0, 0, 0, 0],
        ],
    ]
).to(device)  # same device as the model

input_ids.shape
```

```
torch.Size([2, 2, 16])
```

```python
# token_type_ids and attention_mask share the [num_pairs, 2, seq_len] layout of input_ids
token_type_ids = torch.zeros_like(input_ids)  # single-segment inputs: all zeros
attention_mask = (input_ids != 0).long()      # 1 for real tokens, 0 for padding (pad id 0)

# Uses the model, mean_pooling and predict defined above
cosine_similarities = predict(input_ids, token_type_ids, attention_mask)

# Print the output
print("Cosine similarities:", cosine_similarities)
```

```
Cosine similarities: tensor([[0.9198],
        [0.9198]], grad_fn=<ViewBackward0>)
```


## Building the reference (baseline) input

```python
ref_token_id = tokenizer.pad_token_id  # Token used to build the reference (baseline) input
sep_token_id = tokenizer.sep_token_id  # [SEP] token, appended at the end of each tokenized sentence
cls_token_id = tokenizer.cls_token_id  # [CLS] token, prepended to each tokenized sentence
```

```
Input IDs: tensor([[[ 101, 1045, 2293, 3698, 4083, 1012,  102],
         [ 101, 3698, 4083, 2003, 2307,  999,  102]]])
Reference Input IDs: tensor([[[0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0]]])
```
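The reference printed above replaces every position, including [CLS] and [SEP], with the padding id. The Captum BERT tutorial instead keeps [CLS] and [SEP] in the reference and replaces only the ordinary tokens, so the baseline is still a well-formed sequence. A sketch of that variant, using a hypothetical helper `build_reference_ids` on the input pair printed above:

```python
import torch


def build_reference_ids(input_ids, ref_token_id, cls_token_id, sep_token_id):
    """Hypothetical helper: replace every token except [CLS] and [SEP] with ref_token_id."""
    keep = (input_ids == cls_token_id) | (input_ids == sep_token_id)
    return torch.where(keep, input_ids, torch.full_like(input_ids, ref_token_id))


# The input pair printed above
input_ids = torch.tensor([[[101, 1045, 2293, 3698, 4083, 1012, 102],
                           [101, 3698, 4083, 2003, 2307, 999, 102]]])

# 0, 101 and 102 are the [PAD], [CLS] and [SEP] ids seen in the outputs above; in the notebook
# they would come from tokenizer.pad_token_id, tokenizer.cls_token_id and tokenizer.sep_token_id
ref_input_ids = build_reference_ids(input_ids, ref_token_id=0, cls_token_id=101, sep_token_id=102)
print(ref_input_ids)
```

Each reference sentence becomes `[101, 0, 0, 0, 0, 0, 102]`. Padding positions stay at 0 because the reference token is the pad id. Which baseline works better here is an empirical question, since neither is guaranteed to give $F(x') \approx 0$.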
