# How to use Late Chunking in RAG

This guide is based on the [Late Chunking: Contextual Chunk Embeddings Using Long-Context Embedding Models](https://arxiv.org/abs/2409.04701) paper and the [Late Chunking](https://jina.ai/news/late-chunking-in-long-context-embedding-models) blog post.

This notebook explains how to use `Late Chunking` embeddings with `LangChain`.

**Notes:**

- Late chunking can be combined with any text splitter available in LangChain, or you can implement a custom splitter based on the [chunking strategies](https://github.com/jina-ai/late-chunking/blob/main/chunked_pooling/chunking.py) used in the paper.

- Late chunking is supported by the `LateChunkQdrant` vector store.

# Setup

In [None]:
%pip install -qU langchain langchain-community qdrant-client

In [None]:
import os

from qdrant_client import QdrantClient
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.embeddings import JinaLateChunkEmbeddings
from langchain_community.vectorstores import LateChunkQdrant

# Implement the Custom Text Splitter

In [None]:
import re
import copy
from typing import List, Optional, Tuple, Iterable, Sequence, Any
from langchain_core.documents import BaseDocumentTransformer, Document


class JinaTextSplitter(BaseDocumentTransformer):
    """Split text into chunks by sentences or by tokens of a Jina tokenizer.

    Two strategies are supported:

    * ``"sentences"`` - split on ``sentence_split_regex`` and group every
      ``number_of_sentences`` consecutive sentences into one chunk, joined by
      a single space.
    * ``"tokenize"`` - tokenize with the Hugging Face tokenizer of
      ``jina_huggingface_model_name`` and cut the text into spans of at most
      ``chunk_size`` tokens, sliced from the original text.

    Args:
        jina_huggingface_model_name: Hugging Face model id whose tokenizer is
            loaded (e.g. ``"jinaai/jina-embeddings-v3"``).
        strategies: Chunking strategy, ``"sentences"`` or ``"tokenize"``.
        chunk_size: Maximum tokens per chunk (``"tokenize"`` strategy only).
        number_of_sentences: Sentences per chunk (``"sentences"`` strategy
            only).
        sentence_split_regex: Regex whose matches separate sentences.
        add_start_index: If True, store each chunk's character offset in its
            source text under ``metadata["start_index"]``.

    Raises:
        ValueError: If the ``transformers`` package cannot be imported.
    """

    def __init__(
        self,
        jina_huggingface_model_name: str,
        strategies: str = 'sentences',
        chunk_size: int = 1024,
        number_of_sentences: int = 2,
        sentence_split_regex: str = r"(?<=[.?!])\s+",
        add_start_index: bool = False,
    ):
        try:
            from transformers import AutoTokenizer
        except ImportError as e:
            # ValueError (rather than ImportError) is kept for backward
            # compatibility with callers that already catch it.
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            ) from e
        self.tokenizer = AutoTokenizer.from_pretrained(
            jina_huggingface_model_name, trust_remote_code=True
        )
        self.strategies = strategies
        self.chunk_size = chunk_size
        self.number_of_sentences = number_of_sentences
        self.sentence_split_regex = sentence_split_regex
        self._add_start_index = add_start_index

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts.

        Args:
            texts: Source texts to split.
            metadatas: Optional per-text metadata. Each dict is deep-copied
                into every chunk produced from the corresponding text.

        Returns:
            One ``Document`` per chunk, in input order. When
            ``add_start_index`` is enabled, ``metadata["start_index"]`` is the
            chunk's character offset in its source text, or -1 if the chunk
            does not occur verbatim there (sentence chunks are re-joined with
            single spaces, so this happens when the original separator was
            e.g. a newline).
        """
        _metadatas = metadatas or [{}] * len(texts)
        documents = []

        for i, text in enumerate(texts):
            # Search position for the next chunk, so repeated chunks map to
            # successive occurrences rather than all to the first one.
            search_from = 0
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, search_from)
                    metadata["start_index"] = index
                    if index != -1:
                        search_from = index + len(chunk)
                documents.append(Document(page_content=chunk, metadata=metadata))
        return documents

    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents, carrying each document's metadata to its chunks."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    def split_text(
        self,
        text: str,
    ) -> List[str]:
        """Split ``text`` with the configured strategy.

        Raises:
            ValueError: If ``chunk_size`` < 4 with the ``"tokenize"`` strategy,
                if ``number_of_sentences`` < 1 with the ``"sentences"``
                strategy, or if the strategy is unknown.
        """
        if self.strategies == "tokenize":
            if self.chunk_size < 4:
                raise ValueError("Chunk size must be >= 4.")
            return self._chunk_by_tokens(text, self.chunk_size)
        elif self.strategies == "sentences":
            return self._chunk_by_sentences(text, self.number_of_sentences)
        else:
            raise ValueError(f"Unsupported chunking strategy {self.strategies}")

    def _get_single_sentences_list(self, text: str) -> List[str]:
        """Split ``text`` on ``sentence_split_regex`` (raw ``re.split`` output)."""
        return re.split(self.sentence_split_regex, text)

    def _chunk_by_sentences(
        self,
        text: str,
        n_sentences: int
    ) -> List[str]:
        """Group consecutive sentences of ``text`` into chunks.

        Args:
            text: Text to split.
            n_sentences: Sentences per chunk; the last chunk holds the
                remainder.

        Returns:
            Chunks of up to ``n_sentences`` sentences joined by a single space.
            Empty or whitespace-only fragments (``re.split`` yields a trailing
            ``""`` when the text ends with whitespace after the last sentence)
            are dropped, so an empty text yields ``[]``.

        Raises:
            ValueError: If ``n_sentences`` is less than 1.
        """
        if n_sentences < 1:
            raise ValueError("number_of_sentences must be >= 1.")

        sentences = [
            s for s in self._get_single_sentences_list(text) if s.strip()
        ]
        return [
            ' '.join(sentences[start:start + n_sentences])
            for start in range(0, len(sentences), n_sentences)
        ]

    def _chunk_by_tokens(
        self,
        text: str,
        chunk_size: int,
    ) -> List[str]:
        """Split ``text`` into spans of at most ``chunk_size`` tokens.

        Returns:
            For each span, the substring of ``text`` from the start of its
            first token to the end of its last token. An input that produces
            no tokens yields ``[]``.
        """
        # Tokenize once and slice by character offsets, so every chunk keeps
        # the original text exactly.
        token_offsets = self.tokenizer.encode_plus(
            text, return_offsets_mapping=True, add_special_tokens=False
        ).offset_mapping
        n_tokens = len(token_offsets)

        chunks = []
        for start in range(0, n_tokens, chunk_size):
            end = min(start + chunk_size, n_tokens)
            chunks.append(text[token_offsets[start][0]:token_offsets[end - 1][1]])
        return chunks

    def _tokens_to_text(
        self, text: str, annotations: List[Tuple[int, int]]
    ) -> List[str]:
        """Map ``(start, end)`` token-index spans (end exclusive) to substrings.

        ``text`` is re-tokenized without special tokens, so the indices must
        refer to that tokenization.
        """
        tokens = self.tokenizer.encode_plus(
            text, return_offsets_mapping=True, add_special_tokens=False
        )
        token_offsets = tokens.offset_mapping
        chunks = []
        for start, end in annotations:
            chunk = text[token_offsets[start][0]:token_offsets[end - 1][1]]
            chunks.append(chunk)
        return chunks

    def _get_token_length_of_text(self, text: str) -> int:
        """Return the number of input ids the tokenizer produces for ``text``.

        Uses the tokenizer's default call, so any special tokens it adds by
        default are counted too.
        """
        input_ids = self.tokenizer(text)['input_ids']

        return len(input_ids)

In [None]:
# Split Text: group every 3 sentences into one chunk.
# (chunk_size only applies to the "tokenize" strategy.)
text_splitter = JinaTextSplitter(
    jina_huggingface_model_name='jinaai/jina-embeddings-v3',
    strategies='sentences',
    chunk_size=1024,
    number_of_sentences=3,
)

# Text embedding. Prefer the JINA_API_KEY environment variable over pasting
# the key into the notebook; "jina_*" is only a placeholder fallback.
text_embeddings = JinaLateChunkEmbeddings(
    jina_api_key=os.environ.get("JINA_API_KEY", "jina_*"),
    model_name="jina-embeddings-v3",
)

# Initialize the vector store

In [None]:
ROOT = 'demo-qdrant'

# Plain-text copy of the demo corpus (the GitHub "blob" URL serves an HTML page).
STATE_OF_THE_UNION_URL = (
    "https://raw.githubusercontent.com/hwchase17/chroma-langchain/"
    "master/state_of_the_union.txt"
)


def init_vectorstore(text_embeddings, text_splitter, collection_name='latechunk', topK=5):
    """Open or build a local late-chunking Qdrant store and return a retriever.

    If the directory ``ROOT/collection/<collection_name>`` exists, the
    collection is reopened from ``ROOT``. Otherwise the State of the Union
    text is downloaded, split with ``text_splitter``, embedded with
    ``text_embeddings`` and stored under ``ROOT``.

    Args:
        text_embeddings: Embedding model passed to ``LateChunkQdrant``.
        text_splitter: Splitter passed to ``LateChunkQdrant``.
        collection_name: Name of the Qdrant collection.
        topK: Number of documents the retriever returns per query.

    Returns:
        A retriever over the collection configured with ``k=topK``.
    """
    # NOTE(review): assumes Qdrant's local (path-based) mode keeps each
    # collection in <path>/collection/<name>; this check relies on that layout.
    collection_dir = os.path.join(ROOT, 'collection', collection_name)

    # from_existing_collection / from_documents build the store themselves, so
    # they are called on the class; no separate client or instance is needed.
    if os.path.isdir(collection_dir):
        print(f"===== Load existing collection: {collection_name} ======")
        vectorstore = LateChunkQdrant.from_existing_collection(
            embedding=text_embeddings, path=ROOT,
            collection_name=collection_name, text_splitter=text_splitter
        )
    else:
        print(f"===== Create new collection: {collection_name} ======")

        loader = WebBaseLoader(STATE_OF_THE_UNION_URL)
        data = loader.load()

        vectorstore = LateChunkQdrant.from_documents(
            documents=data, embedding=text_embeddings, text_splitter=text_splitter,
            path=ROOT, collection_name=collection_name
        )

    return vectorstore.as_retriever(search_kwargs={"k": topK})

vectordb = init_vectorstore(text_embeddings, text_splitter, collection_name='test')

In [None]:
query = "what did the president say about ketanji brown jackson?"

# Retrieve up to topK chunks (init_vectorstore default: 5) relevant to the query.
docs = vectordb.invoke(query)
# Number of retrieved documents, shown as the cell's output.
len(docs)