<a href="https://colab.research.google.com/github/jaideep11061982/GenAINotebooks/blob/main/Late_Chunking_in_Long_Context_Embedding_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Late Chunking

This notebook explains how "late chunking" can be implemented. First, install the requirements:

In [None]:
!pip install transformers==4.43.4



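Next, we load the model we want to use for the embeddings. We use `jinaai/jina-embeddings-v2-base-en` (which accepts up to 8192 tokens), but any other model that supports mean pooling works as well. Models with a long maximum context length are preferred, because late chunking encodes the whole document in a single forward pass before splitting it into chunks.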

In [None]:
from transformers import AutoModel
from transformers import AutoTokenizer

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
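Mean pooling, mentioned above, averages the per-token output vectors into a single embedding; late chunking depends on it because the same average can be taken over any sub-span of tokens. The following is a minimal NumPy sketch, assuming token embeddings of shape `(batch, seq_len, dim)` and a 0/1 attention mask; `mean_pool` is a hypothetical helper, not part of `transformers`:

```python
import numpy as np

def mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token embeddings over non-padding positions (hypothetical helper).

    token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1.
    """
    mask = attention_mask[..., None].astype(token_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)                   # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                   # avoid division by zero
    return summed / counts

# Toy batch: one sequence of 3 tokens, the last one being padding.
emb = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
print(mean_pool(emb, mask))  # [[2. 3.]]
```

The padding vector `[100, 100]` is ignored, so the result is the average of the first two tokens only.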

Next, we define a function that splits the text into chunks. Besides the chunks, `chunk_by_sentences` returns span annotations: the start and end token indices of each chunk, which are needed for the chunked pooling.

In [None]:
def chunk_by_sentences(input_text: str, tokenizer: callable):
    """
    Split the input text into sentences using the tokenizer
    :param input_text: The text snippet to split into sentences
    :param tokenizer: The tokenizer to use
    :return: A tuple containing the list of text chunks and their corresponding token spans
    """
    inputs = tokenizer(input_text, return_tensors='pt', return_offsets_mapping=True)
    punctuation_mark_id = tokenizer.convert_tokens_to_ids('.')
    sep_id = tokenizer.convert_tokens_to_ids('[SEP]')
    token_offsets = inputs['offset_mapping'][0]
    token_ids = inputs['input_ids'][0]
    chunk_positions = [
        (i, int(start + 1))
        for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets))
        if token_id == punctuation_mark_id
        and (
            token_offsets[i + 1][0] - token_offsets[i][1] > 0
            or token_ids[i + 1] == sep_id
        )
    ]
    chunks = [
        input_text[x[1] : y[1]]
        for x, y in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
    ]
    span_annotations = [
        (x[0], y[0]) for (x, y) in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
    ]
    return chunks, span_annotations

This function performs simple rule-based sentence chunking: it splits the input text at periods, using the tokenizer to locate them. Here's a breakdown of what happens:

- Tokenization: It tokenizes the input text with the provided tokenizer and retrieves the character offsets of each token, so tokens can be mapped back to the original text.

- Sentence Boundary Detection: It finds the period tokens and keeps only those that are followed by a gap (whitespace) before the next token, or whose next token is [SEP], which marks the end of the sequence. This avoids splitting inside strings such as "2.0".

- Chunking: It slices the original text between consecutive boundaries, so each chunk is one sentence including its period (and any leading space).

- Span Annotations: It records the start and end token index of each chunk; these are the spans used for pooling later.
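To see the boundary rule in isolation, here is a tokenizer-free sketch that applies the same logic to hand-written tokens and offsets. The token list is hypothetical (it only imitates what a BERT-style tokenizer returns with `return_offsets_mapping=True`), and `sentence_chunks` is an illustrative name. It shows that the period in "2.0" is not treated as a boundary because no whitespace follows it:

```python
def sentence_chunks(text, tokens, offsets, period=".", sep="[SEP]"):
    """Tokenizer-free mirror of the boundary rule in chunk_by_sentences.

    tokens: token strings including [CLS] and [SEP]
    offsets: (start, end) character offsets per token, (0, 0) for special tokens
    """
    # A period ends a sentence if whitespace follows it or it is the last real token.
    boundaries = [
        (i, start + 1)  # (token index of the period, character index after it)
        for i, (tok, (start, end)) in enumerate(zip(tokens, offsets))
        if tok == period
        and (offsets[i + 1][0] - end > 0 or tokens[i + 1] == sep)
    ]
    previous = [(1, 0)] + boundaries[:-1]  # the first chunk starts after [CLS]
    chunks = [text[p[1]:b[1]] for p, b in zip(previous, boundaries)]
    spans = [(p[0], b[0]) for p, b in zip(previous, boundaries)]
    return chunks, spans

# Hypothetical BERT-style tokenization of a short text.
toy_text = "Hi there. Version 2.0 is out."
toy_tokens = ["[CLS]", "hi", "there", ".", "version", "2", ".", "0", "is", "out", ".", "[SEP]"]
toy_offsets = [(0, 0), (0, 2), (3, 8), (8, 9), (10, 17), (18, 19), (19, 20),
               (20, 21), (22, 24), (25, 28), (28, 29), (0, 0)]

toy_chunks, toy_spans = sentence_chunks(toy_text, toy_tokens, toy_offsets)
print(toy_chunks)  # ['Hi there.', ' Version 2.0 is out.']
print(toy_spans)   # [(1, 3), (3, 10)]
```

As in the original function, each span ends just before its period token and the next span starts at it, so a sentence's final period is pooled into the following chunk.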

In production, you should use a more advanced and robust segmentation method, such as the Jina AI Tokenizer API (https://jina.ai/tokenizer#apiform).

In [None]:
import requests

def chunk_by_tokenizer_api(input_text: str, tokenizer: callable):
    # `tokenizer` is unused; it is kept so the signature matches chunk_by_sentences
    # Define the API endpoint and payload
    url = 'https://tokenize.jina.ai/'
    payload = {
        "content": input_text,
        "return_chunks": "true",
        "max_chunk_length": "1024"
    }

    # Make the API request and fail loudly on HTTP errors
    response = requests.post(url, json=payload, timeout=30)
    response.raise_for_status()
    response_data = response.json()

    # Extract chunks and positions from the response
    chunks = response_data.get("chunks", [])
    chunk_positions = response_data.get("chunk_positions", [])

    # Convert each [start, end] pair into a (start, end) tuple
    span_annotations = [(start, end) for start, end in chunk_positions]

    return chunks, span_annotations

This function sends the input text to an external API that tokenizes and chunks it, and returns both the chunks and their positions within the original text.



Explanation of the Code:

**API Endpoint:** The code uses the Jina AI Tokenization API (https://tokenize.jina.ai/) to tokenize and chunk the input text.

**Payload:** The payload consists of:

- content: The input text that you want to chunk.
- return_chunks: A flag to indicate that you want to receive the chunks of text.
- max_chunk_length: Specifies the maximum length of each chunk (1024 characters in this case).

**API Request:** It sends a POST request to the API with the payload. The response is assumed to be in JSON format.

**Chunk and Position Extraction:**

The response is expected to contain two keys:
- chunks: The actual text chunks.
- chunk_positions: The start and end positions of each chunk in the original text.

The function converts each position pair into a `(start, end)` tuple. Note that `late_chunking` below indexes token embeddings, so the spans it receives must be token indices; if the API returns character offsets, they have to be mapped to token positions first.
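Because `response_data.get(..., [])` silently returns empty lists when a key is missing (for example, on an error response), a slightly stricter parser can help. This is a sketch with a hypothetical `parse_chunk_response` helper; the sample payload is made up to match the two keys described above, not captured from the real API:

```python
def parse_chunk_response(response_data: dict):
    """Validate a segmenter response and return (chunks, spans) (hypothetical helper)."""
    chunks = response_data.get("chunks")
    positions = response_data.get("chunk_positions")
    if chunks is None or positions is None:
        raise ValueError("response is missing 'chunks' or 'chunk_positions'")
    if len(chunks) != len(positions):
        raise ValueError("number of chunks and number of positions differ")
    spans = [(int(start), int(end)) for start, end in positions]
    return chunks, spans

# Made-up response for the text "Hi there. Bye."
sample = {"chunks": ["Hi there. ", "Bye."], "chunk_positions": [[0, 10], [10, 14]]}
print(parse_chunk_response(sample))  # (['Hi there. ', 'Bye.'], [(0, 10), (10, 14)])
```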

**Return:**

The function returns two lists:
- chunks: The segmented text chunks.
- span_annotations: The positions of these chunks in the original text, as `(start, end)` tuples.
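`late_chunking` slices token embeddings, so it needs token indices. If the API's `chunk_positions` are character offsets (check this against the API documentation), they can be mapped to token spans with the tokenizer's offset mapping. This is a sketch with a hypothetical helper `char_spans_to_token_spans`; the offsets are hand-written to imitate a BERT-style tokenizer:

```python
def char_spans_to_token_spans(offsets, char_spans):
    """Map character (start, end) spans to token (start, end) spans, end exclusive.

    offsets: per-token (start, end) character offsets; special tokens such as
    [CLS] and [SEP] have (0, 0) and are never assigned to a chunk.
    A token belongs to the span in which it starts.
    """
    token_spans = []
    for c_start, c_end in char_spans:
        inside = [
            i for i, (t_start, t_end) in enumerate(offsets)
            if t_end > t_start and c_start <= t_start < c_end
        ]
        if inside:  # chunks without any token are skipped
            token_spans.append((inside[0], inside[-1] + 1))
    return token_spans

# Hypothetical offsets for "Hi there. Bye.": [CLS] hi there . bye . [SEP]
toy_offsets = [(0, 0), (0, 2), (3, 8), (8, 9), (10, 13), (13, 14), (0, 0)]
print(char_spans_to_token_spans(toy_offsets, [(0, 10), (10, 14)]))  # [(1, 4), (4, 6)]
```

With the real tokenizer, the offsets would come from `tokenizer(input_text, return_offsets_mapping=True)['offset_mapping']`. Unlike `chunk_by_sentences`, this mapping assigns each period to the sentence it ends. Since empty chunks are skipped, it is worth checking that the result has as many spans as there are chunks.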


Now let's segment the toy example used in the blog post.

In [None]:
input_text = "Lara prepared her gear before setting out on the expedition. After hours of trekking, she discovered an ancient ruin hidden deep in the jungle. She stepped inside and found inscriptions that detailed an old civilization."

# determine chunks
chunks, span_annotations = chunk_by_sentences(input_text, tokenizer)
print('Chunks:\n- "' + '"\n- "'.join(chunks) + '"')


Chunks:
- "Lara prepared her gear before setting out on the expedition."
- " After hours of trekking, she discovered an ancient ruin hidden deep in the jungle."
- " She stepped inside and found inscriptions that detailed an old civilization."


Now we encode the chunks with both the traditional method and the context-sensitive late chunking method:

In [None]:
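def late_chunking(model_output, span_annotation: list, max_length=None):
    """
    Mean-pool token embeddings over token spans
    :param model_output: The model output; model_output[0] holds the token embeddings (batch, seq_len, dim)
    :param span_annotation: One list of (start, end) token spans per batch item
    :param max_length: Optional maximum context length of the model; spans beyond it are clipped
    :return: One list of pooled NumPy embeddings per batch item
    """
    token_embeddings = model_output[0]
    outputs = []
    for embeddings, annotations in zip(token_embeddings, span_annotation):
        if max_length is not None:
            # remove annotations which go beyond the max length of the model
            annotations = [
                (start, min(end, max_length - 1))
                for (start, end) in annotations
                if start < (max_length - 1)
            ]
        # mean of the token embeddings inside each non-empty span
        pooled_embeddings = [
            embeddings[start:end].sum(dim=0) / (end - start)
            for start, end in annotations
            if (end - start) >= 1
        ]
        pooled_embeddings = [
            embedding.detach().cpu().numpy() for embedding in pooled_embeddings
        ]
        outputs.append(pooled_embeddings)

    return outputs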
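The pooling step inside `late_chunking` is ordinary mean pooling restricted to each span. Here is a NumPy sketch of the same logic for a single sequence, on made-up embeddings, including the `max_length` clipping; `pool_spans` is a hypothetical name:

```python
import numpy as np

def pool_spans(token_embeddings, spans, max_length=None):
    """NumPy mirror of late_chunking for one sequence: mean-pool each (start, end) span."""
    if max_length is not None:
        # Keep only spans that start before position max_length - 1 and clip their ends to it.
        spans = [(s, min(e, max_length - 1)) for s, e in spans if s < max_length - 1]
    # end is exclusive; empty spans are skipped
    return [token_embeddings[s:e].mean(axis=0) for s, e in spans if e - s >= 1]

# Made-up embeddings for 6 tokens ([CLS], four content tokens, [SEP]), 2 dims each.
toy_token_embeddings = np.arange(12, dtype=float).reshape(6, 2)

print(pool_spans(toy_token_embeddings, [(1, 3), (3, 5)]))                # means [3, 4] and [7, 8]
print(pool_spans(toy_token_embeddings, [(1, 3), (3, 5)], max_length=4))  # second span dropped
```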

In [None]:
import torch

# chunk before: encode each chunk separately
embeddings_traditional_chunking = model.encode(chunks)

# chunk afterwards (context-sensitive chunked pooling)
inputs = tokenizer(input_text, return_tensors='pt')
with torch.no_grad():  # inference only, no gradients needed
    model_output = model(**inputs)
embeddings = late_chunking(model_output, [span_annotations])[0]
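Why should the late-chunked embeddings differ? A toy illustration: the "encoder" below is a made-up stand-in for a transformer (each output is the word's own vector plus the mean of all input vectors), not how the Jina model works. It only shows that, with late chunking, a chunk's pooled embedding carries information from outside the chunk, while with traditional chunking it cannot:

```python
import numpy as np

# Hypothetical 3-d "word vectors"; only chunk 1 mentions Lara.
vocab = {
    "lara": np.array([1.0, 0.0, 0.0]),
    "she": np.array([0.0, 1.0, 0.0]),
    "won": np.array([0.0, 0.0, 1.0]),
}

def toy_encoder(words):
    """Stand-in for a transformer: every output vector is the word's own vector
    plus the mean of all input vectors, so each output depends on its context."""
    x = np.stack([vocab[w] for w in words])
    return x + x.mean(axis=0, keepdims=True)

def cos(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

toy_doc = ["lara", "she", "won"]
toy_spans = [(0, 1), (1, 3)]  # chunk 1: "lara", chunk 2: "she won"

# Traditional chunking: encode each chunk separately, then mean-pool it.
toy_traditional = [toy_encoder(toy_doc[s:e]).mean(axis=0) for s, e in toy_spans]

# Late chunking: encode the whole document once, then mean-pool each span.
toy_tokens_out = toy_encoder(toy_doc)
toy_late = [toy_tokens_out[s:e].mean(axis=0) for s, e in toy_spans]

print("traditional:", round(cos(vocab["lara"], toy_traditional[1]), 3))  # 0.0
print("late:", round(cos(vocab["lara"], toy_late[1]), 3))                # 0.272
```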

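Finally, we compare the similarity of the word "Lara" with each chunk. For the chunks that refer to Lara only indirectly (through "she"), the similarity should be higher with the context-sensitive chunked pooling method, because their embeddings were computed with the first sentence in context: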

In [None]:
import numpy as np

cos_sim = lambda x, y: np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))

Lara_embedding = model.encode('Lara')

for chunk, new_embedding, trad_embeddings in zip(chunks, embeddings, embeddings_traditional_chunking):
    print(f'similarity_new("Lara", "{chunk}"):', cos_sim(Lara_embedding, new_embedding))
    print(f'similarity_trad("Lara", "{chunk}"):', cos_sim(Lara_embedding, trad_embeddings))

similarity_new("Lara", "Lara prepared her gear before setting out on the expedition."): 0.73207825
similarity_trad("Lara", "Lara prepared her gear before setting out on the expedition."): 0.79933035
similarity_new("Lara", " After hours of trekking, she discovered an ancient ruin hidden deep in the jungle."): 0.73201466
similarity_trad("Lara", " After hours of trekking, she discovered an ancient ruin hidden deep in the jungle."): 0.6700105
similarity_new("Lara", " She stepped inside and found inscriptions that detailed an old civilization."): 0.7380304
similarity_trad("Lara", " She stepped inside and found inscriptions that detailed an old civilization."): 0.6695348
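A quick way to read these numbers is to look at the per-chunk difference between the two methods. Here is a small sketch using the values printed above:

```python
# Similarities copied from the output above, in chunk order.
late_sims = [0.73207825, 0.73201466, 0.7380304]
traditional_sims = [0.79933035, 0.6700105, 0.6695348]

deltas = [late - trad for late, trad in zip(late_sims, traditional_sims)]
for i, d in enumerate(deltas, start=1):
    print(f"chunk {i}: late - traditional = {d:+.4f}")
# chunk 1: -0.0673, chunk 2: +0.0620, chunk 3: +0.0685
```

Late chunking scores the first chunk, which names Lara explicitly, about 0.067 lower, and the two chunks that refer to her only as "she" about 0.06 to 0.07 higher. This is consistent with context from the first sentence carrying over into the later chunk embeddings.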
