## Sentence Transformers Explainer

This notebook documents our own implementation of an explainability method for 'Sentence Transformers' models. 
It builds on two main sources:
* The 'Integrated Gradients' method, a numerical method that establishes a quantitative relation between the inputs and outputs of deep learning models. For further context, refer to the [original paper](https://arxiv.org/pdf/1703.01365).

* Captum, an open-source Python library that implements multiple explainability methods on top of PyTorch models. In particular, we take the [Bert Tutorial](https://captum.ai/tutorials/Bert_SQUAD_Interpret) as a reference.



We start with a short mathematical review of the Integrated Gradients method, which is needed later to understand every component of our method:

As stated in the paper itself, the method aims to *attribute the output (prediction) of a network to its inputs*. In other words, it tries to decompose the predicted value of the model in terms of the input variables, so each attribution plays a role similar to that of a coefficient in a linear regression model. The idea is to obtain these 'attributions' by choosing a reference point in the input space, called the baseline, and computing a *path integral of the gradients along a straightline from the baseline x' to the input x* [(1)](https://arxiv.org/pdf/1703.01365):

$$
\text{IntegratedGrads}_i(x) := (x_i - x'_i) \times \int_{\alpha=0}^{1} \frac{\partial F(x' + \alpha \times (x - x'))}{\partial x_i} d\alpha
$$


One of the most important consequences of Integrated Gradients is that, due to the properties of path integrals, the attributions of all the inputs must add up to the difference between the model output at the real input and the model output at the chosen baseline:

$$
\sum_{i=1}^{n} \text{IntegratedGrads}_i(x) = F(x) - F(x')
$$



Therefore, the baseline we choose strongly influences the method. Ideally we should pick one that satisfies $F(x') \approx 0$, but we will see later that for most 'Sentence Transformers' models this is not feasible.
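
The definition and the sum property above can be checked numerically on a toy function. The sketch below approximates the path integral with a midpoint Riemann sum in plain Python. The names `toy_f`, `toy_grad` and `integrated_gradients` are illustrative assumptions, not part of Captum or of the method built in this notebook.

```python
def integrated_gradients(grad_f, x, baseline, n_steps=100):
    """Midpoint Riemann-sum approximation of Integrated Gradients."""
    n = len(x)
    avg_grads = [0.0] * n
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps
        # Point on the straight line from the baseline to the input
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        grads = grad_f(point)
        for i in range(n):
            avg_grads[i] += grads[i] / n_steps
    # Scale the averaged gradient by the distance to the baseline
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grads)]


def toy_f(p):
    # Hypothetical model: F(p) = p0^2 + 3 * p0 * p1
    return p[0] ** 2 + 3.0 * p[0] * p[1]


def toy_grad(p):
    # Analytic gradient of toy_f
    return [2.0 * p[0] + 3.0 * p[1], 3.0 * p[0]]


x = [1.0, 2.0]
baseline = [0.0, 0.0]
attributions = integrated_gradients(toy_grad, x, baseline)  # approximately [4.0, 3.0]

# The attributions should add up to F(x) - F(x')
completeness_gap = sum(attributions) - (toy_f(x) - toy_f(baseline))
```

Here the attributions add up to 7, which is exactly `toy_f(x) - toy_f(baseline)`. With a different baseline both the individual attributions and their sum would change.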





## Define our model

The first step is to define our model. Although we are using a Sentence Transformers model, we will access it through its PyTorch backbone, because we need to be able to use the tokens directly as inputs.



In [202]:
# import pandas as pd
import logging
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel
from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



# Set up logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load a pre-trained sentence-transformers model
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model.to(device)
# model.eval()
# model.zero_grad()

# Load the matching tokenizer from the HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')



Now we need to define a 'forward' function that generates the output of our model and takes as input the exact inputs of the model. These inputs must be the tokens of our sentences, because the tokens, not the raw sentences, are the real inputs of the model. Captum will later use this forward function to compute the gradients of the model's transformation.

Even though the token ids are numbers, they have no real numerical meaning: they are just indices into the vocabulary. Therefore, we will use LayerIntegratedGradients, which computes the attributions for the inputs or outputs of a specified layer of the model. In this case we are interested in the attributions of the embeddings layer, so that we can obtain the relevance of each word (token) from the attributions of each dimension of its embedding.
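
To make the last step concrete, the sketch below collapses per-dimension attributions into one score per token by summing over the embedding dimension and dividing by the L2 norm. This is similar in spirit to the summarization step of the Captum BERT tutorial. It uses plain Python lists with made-up numbers, and the helper name `summarize_attributions` is an assumption for illustration.

```python
import math


def summarize_attributions(token_attributions):
    """Collapse per-dimension attributions into one normalized score per token.

    token_attributions: one row per token, one value per embedding dimension
    (a hypothetical stand-in for the tensor returned by Captum).
    """
    per_token = [sum(row) for row in token_attributions]
    norm = math.sqrt(sum(score * score for score in per_token))
    if norm == 0.0:
        return per_token  # nothing to normalize
    return [score / norm for score in per_token]


# Toy example: 3 tokens with 4 embedding dimensions each (made-up numbers)
toy = [
    [0.0, 0.0, 0.0, 0.0],   # e.g. a special token identical to its baseline
    [0.1, 0.2, -0.1, 0.1],  # sums to 0.3
    [0.2, 0.1, 0.1, 0.0],   # sums to 0.4
]
scores = summarize_attributions(toy)  # approximately [0.0, 0.6, 0.8]
```

The normalization only rescales the scores for visualization. The raw sums are the values that relate to the difference between the output at the input and at the baseline.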





In [203]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def predict(input_ids, token_type_ids=None, attention_mask=None):
    """
    Predicts the similarity between pairs of sentences.
    
    Parameters:
    - input_ids: A tensor of shape [number_of_pairs, 2, seq_len] containing pairs of sentences.
    - token_type_ids: Optional token type ids (same shape as input_ids).
    - attention_mask: Optional attention mask (same shape as input_ids).
    
    Returns:
    - similarities: A tensor of shape [number_of_pairs, 1] containing similarity scores for each pair.
    """
    
    num_pairs = input_ids.shape[0]
    seq_len = input_ids.shape[-1]
    
    # Flatten the input for batch processing
    input_ids = input_ids.view(num_pairs * 2, seq_len)
    
    if token_type_ids is not None:
        token_type_ids = token_type_ids.view(num_pairs * 2, seq_len)
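    if attention_mask is not None:
        attention_mask = attention_mask.view(num_pairs * 2, seq_len)
    else:
        # mean_pooling needs a mask, so treat every position as a real token
        attention_mask = torch.ones_like(input_ids)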


    # First we get the embeddings for all the input ids that have been flattened
    model_output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
    sentence_embeddings = mean_pooling(model_output, attention_mask)
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    
    # Then we restore the pair dimension: (num_pairs, 2, hidden_dim)
    sentence_embeddings = sentence_embeddings.view(num_pairs, 2, -1)
    similarities = torch.nn.functional.cosine_similarity(sentence_embeddings[:, 0, :], sentence_embeddings[:, 1, :], dim=1)

    return similarities.view(-1, 1)
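
As a rough illustration of what `mean_pooling` and the final cosine similarity compute, the sketch below repeats both steps on tiny hand-written vectors in plain Python. The helper names `masked_mean` and `cosine` and all the numbers are assumptions chosen so the result can be checked by hand.

```python
import math


def masked_mean(token_vectors, mask):
    """Average the token vectors, ignoring positions where mask is 0."""
    dim = len(token_vectors[0])
    total = [0.0] * dim
    for vec, m in zip(token_vectors, mask):
        for j in range(dim):
            total[j] += vec[j] * m
    count = max(sum(mask), 1e-9)  # same guard as torch.clamp(..., min=1e-9)
    return [t / count for t in total]


def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Two "sentences" of 3 tokens each; the last position is padding (mask = 0)
sent_a = masked_mean([[1.0, 0.0], [3.0, 0.0], [9.0, 9.0]], [1, 1, 0])  # [2.0, 0.0]
sent_b = masked_mean([[0.0, 2.0], [2.0, 0.0], [0.0, 0.0]], [1, 1, 0])  # [1.0, 1.0]
similarity = cosine(sent_a, sent_b)  # about 0.7071
```

The padded position in `sent_a` holds large values on purpose: because its mask is 0, it has no effect on the pooled vector.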


In [221]:
def sent_trans_pos_forward_func(input_ids,token_type_ids,attention_mask):
    logging.info(f"Method calling the predict function with shape {input_ids.shape}")
    pred = predict(input_ids,token_type_ids,attention_mask)
    logging.info(f"Value predicted---> {pred}")
    return pred

## Input for the function and Baselines

Now that we have configured our forward function, we are going to build several helper functions that produce its inputs.
They were taken from the Captum tutorial and slightly modified for our use case. Let's go over them:


## Input IDs and Ref Input IDs

In [None]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator; the tokenizer appends it to the end of each sentence.
cls_token_id = tokenizer.cls_token_id

In [None]:
def construct_input_ref_pair(sentences, ref_token_id=0, process_both=True):
    """
    Construct input_ids and ref_input_ids for a pair of sentences, excluding the first one
    if specified.
    
    Input:
    - sentences: A list of 2 sentences.
    - ref_token_id: Token to be used as reference.
    - process_both: Flag to determine if processing both phrases or just the second.
    
    Output:
    - input_ids: Tensor of shape [1, 2, max_length]
    - ref_input_ids: Tensor of shape [1, 2, max_length]
    """

    # We could generalize this method to any dimension of pairs of sentences
    assert len(sentences) == 2

    # Tokenize the pair
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    input_ids = encoded_input['input_ids']  # shape: [2, max_length]
    
    # Create a reference tensor using the ref token
    ref_input_ids = torch.full(input_ids.shape, ref_token_id)
    
    # Reintroduce the classification and sep tokens
    for i, input_id_sequence in enumerate(input_ids):
        if not process_both and i == 0:  # Skip processing for the first sentence if process_both is False
            ref_input_ids[i] = input_id_sequence
        else:
            ref_input_ids[i, input_id_sequence == cls_token_id] = cls_token_id
            ref_input_ids[i, input_id_sequence == sep_token_id] = sep_token_id

    # Reshape input_ids and ref_input_ids to add an additional level of depth [1, 2, max_length]
    input_ids = input_ids.unsqueeze(0)  # shape: [1, 2, max_length]
    ref_input_ids = ref_input_ids.unsqueeze(0)  # shape: [1, 2, max_length]

    return input_ids, ref_input_ids
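
The core of the baseline construction can be shown on a single list of token ids. In the sketch below, every token is replaced by the reference (padding) id while the classification and separator tokens stay in place. The helper `build_reference` is a hypothetical pure-Python equivalent of the masking done above. The ids 101, 102 and 0 for `[CLS]`, `[SEP]` and `[PAD]` are assumptions that hold for BERT-style vocabularies, and the remaining ids are only illustrative.

```python
def build_reference(ids, cls_id, sep_id, pad_id):
    """Replace every token by pad_id, keeping [CLS] and [SEP] in place."""
    return [tok if tok in (cls_id, sep_id) else pad_id for tok in ids]


# A padded sentence of two word tokens between [CLS] and [SEP]
second_sentence = [101, 2062, 11746, 102, 0, 0, 0, 0, 0]
reference = build_reference(second_sentence, cls_id=101, sep_id=102, pad_id=0)
# reference is [101, 0, 0, 102, 0, 0, 0, 0, 0]
```

Only the positions that differ between the input and the reference (here positions 1 and 2) can receive a non-zero attribution, because the factor $(x_i - x'_i)$ is zero everywhere else.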

In [206]:
# Example input tensor (2 pairs, 2 sentences per pair, seq_length=16)
input_ids = torch.tensor(
    [
        [
        [  101,  1045,  2293,  5983,  2003,  2026,  5440,  7570, 10322,  2666,  1012,   102,     0,     0,     0,     0],
        [  101,  1045,  2293,  5983,  2003,  2026,  5440,  7570, 10322,  2600,  1012,   102,  0,     0,     0,     0]
        ],
        [
        [  101,  1045,  2293,  5983,  2003,  2026,  5440,  7570, 10322,  2666,  1012,   102,     0,     0,     0,     0],
        [  101,  1045,  2293,  5983,  2,  2026,  5440,  7570, 10322,  2600,  1012,   102,  0,     0,     0,     0]
        ]
    
    ]
)


input_ids.shape

torch.Size([2, 2, 16])

In [207]:

# Assuming token_type_ids and attention_mask are similar in structure
token_type_ids = torch.zeros_like(input_ids)  # For simplicity, assume all zeros
attention_mask = (input_ids != 0).long()  # Attention mask where 0s are padding

# Assuming `model` and `mean_pooling` are defined, and predict function is implemented
cosine_similarities = predict(input_ids, token_type_ids, attention_mask)

# Print the output
print("Cosine similarities:", cosine_similarities)

Cosine similarities: tensor([[0.9198],
        [0.8362]], grad_fn=<ViewBackward0>)


## Checking new things

In [209]:
sentences = ["This are the sentences to be compared", "more sentences"]

In [210]:


def construct_token_type_ids(input_ids):
    # Assume no token type differentiation if not needed
    token_type_ids = torch.zeros_like(input_ids)
    ref_token_type_ids = torch.zeros_like(input_ids)
    
    return token_type_ids, ref_token_type_ids

def construct_position_ids_pair(input_ids):
    batch_size, pair_size, seq_length = input_ids.size()  # Get the size of each dimension
    
    # Create position IDs for each sequence in the pair (0, 1, ..., seq_length - 1)
    position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).unsqueeze(0)
    
    # Expand position_ids to match the shape of input_ids: [batch_size, pair_size, seq_length]
    position_ids = position_ids.expand(batch_size, pair_size, seq_length)
    
    # Create reference position IDs filled with zeros (same shape as position_ids)
    ref_position_ids = torch.zeros_like(position_ids)
    
    return position_ids, ref_position_ids

    
def construct_attention_mask(input_ids):
    return (input_ids != tokenizer.pad_token_id).long()


def construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                    token_type_ids=None, ref_token_type_ids=None, \
                                    position_ids=None, ref_position_ids=None):
    # `model` is loaded with AutoModel above, so its embeddings layer is `model.embeddings`
    input_embeddings = model.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = model.embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)
    
    return input_embeddings, ref_input_embeddings


In [212]:
input_ids, ref_input_ids = construct_input_ref_pair(sentences,process_both=False)

input_ids

tensor([[[  101,  2023,  2024,  1996, 11746,  2000,  2022,  4102,   102],
         [  101,  2062, 11746,   102,     0,     0,     0,     0,     0]]])

In [213]:

token_type_ids, ref_token_type_ids = construct_token_type_ids(input_ids)

position_ids, ref_position_ids = construct_position_ids_pair(input_ids)

attention_mask = construct_attention_mask(input_ids)

In [214]:
attention_mask

tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 0, 0, 0, 0, 0]]])

In [215]:
ref_input_ids


tensor([[[  101,  2023,  2024,  1996, 11746,  2000,  2022,  4102,   102],
         [  101,     0,     0,   102,     0,     0,     0,     0,     0]]])

In [216]:
reference_sentence = "I do love tenis"

In [222]:
lig = LayerIntegratedGradients(sent_trans_pos_forward_func, model.embeddings)

attributions_start = lig.attribute(inputs = input_ids,
                                  baselines = ref_input_ids,
                                  additional_forward_args=(token_type_ids, attention_mask),
                                  return_convergence_delta=False,
                                  n_steps =100)


2024-10-05 18:25:06,821 - INFO - Method calling the predict function with shape torch.Size([1, 2, 9])
2024-10-05 18:25:06,880 - INFO - Value predicted---> tensor([[0.5634]])
2024-10-05 18:25:06,883 - INFO - Method calling the predict function with shape torch.Size([1, 2, 9])
2024-10-05 18:25:06,894 - INFO - Value predicted---> tensor([[0.1246]])
2024-10-05 18:25:06,930 - INFO - Method calling the predict function with shape torch.Size([100, 2, 9])
2024-10-05 18:25:07,185 - INFO - Value predicted---> tensor([[0.1246],
        [0.1246],
        [0.1246],
        [0.1246],
        [0.1246],
        [0.1246],
        [0.1246],
        [0.1246],
        [0.1246],
        [0.1247],
        [0.1247],
        [0.1247],
        [0.1248],
        [0.1249],
        [0.1250],
        [0.1251],
        [0.1252],
        [0.1254],
        [0.1256],
        [0.1258],
        [0.1261],
        [0.1264],
        [0.1268],
        [0.1272],
        [0.1277],
        [0.1283],
        [0.1290],
        [

In [161]:
attributions_start.sum(-1)

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0898, 0.3490, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
       dtype=torch.float64)
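
These numbers can be checked against the sum property of Integrated Gradients. The sketch below copies the rounded values from the outputs above and compares the sum of the token attributions with the difference between the two similarities. Reading the first two logged predictions as the similarity for the real input and for the baseline is an interpretation of the log order, not something the log states explicitly.

```python
# Values copied from the outputs above (rounded to 4 decimals)
f_input = 0.5634     # first logged similarity, taken here as F(x)
f_baseline = 0.1246  # second logged similarity, taken here as F(x')

# Per-token attributions of the second sentence; the first sentence is all zeros
token_attributions = [0.0, 0.0898, 0.3490, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

attribution_sum = sum(token_attributions)  # about 0.4388
expected_delta = f_input - f_baseline      # about 0.4388
gap = abs(attribution_sum - expected_delta)
```

Under that reading the two quantities agree to the displayed precision. The first sentence receives no attribution because, with `process_both=False`, its baseline is identical to its input.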