## Late Chunking with Weaviate

Notebook author: Danny Williams @ weaviate (Developer Growth)

This notebook implements [late chunking](https://jina.ai/news/late-chunking-in-long-context-embedding-models/) with Weaviate. Late chunking is a change in the classical chunking framework where chunking happens _after_ token embeddings are output from the full document. This preserves contextual information from one chunk to another.



### Setup

First we install all required packages. We are using

In [1]:
!pip install  torch numpy spacy transformers

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting torch
  Using cached https://pypi.tuna.tsinghua.edu.cn/packages/2a/ef/834af4a885b31a0b32fff2d80e1e40f771e1566ea8ded55347502440786a/torch-2.5.1-cp310-cp310-manylinux1_x86_64.whl (906.4 MB)
Collecting numpy
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/f7/b6/d8110985501ca8912dfc1c3bbef99d66e62d487f72e46b2337494df77364/numpy-2.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.4 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.4/16.4 MB[0m [31m113.5 MB/s[0m eta [36m0:00:00[0m
Collecting spacy
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a6/76/216e55d51b37037760682efdff8b34765b2b879cfcc63c06872072c36db7/spacy-3.8.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.1 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.1/29.1 MB[0m [31m92.5 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m
Collec

Then we load the packages and connect to the Weaviate client. Important, you need some API keys within a `.env` file:
- your Weaviate REST endpoint saved as `WEAVIATE_URL`
- your Weaviate API key saved as `WEAVIATE_KEY`
- if you want to run the final comparison in this notebook, an OpenAI API key saved as `OPENAI_API_KEY`, otherwise delete the `headers` argument in the `weaviate.connect_to_weaviate_cloud` function.


In [1]:
%%capture
# imports
import weaviate
import weaviate.classes as wvc
import weaviate.classes.config as wvcc

import os
import torch
import numpy as np

import spacy
from spacy.tokens import Doc
from spacy.language import Language

import transformers
from transformers import AutoModel
from transformers import AutoTokenizer

# connect to weaviate
import weaviate

client = weaviate.connect_to_local()

print(client.is_ready())

Finally just for future-proofing, the versions of these packages are:

In [2]:
print(f"Weaviate version {weaviate.__version__}")
print(f"Pytorch version {torch.__version__}")
print(f"Numpy version {np.__version__}")
print(f"Spacy version {spacy.__version__}")
print(f"Transformers version {transformers.__version__}")

Weaviate version 0.1.dev3117+gae1bb03
Pytorch version 2.4.1+cu121
Numpy version 2.2.1
Spacy version 3.8.3
Transformers version 4.47.1


### Functions

Below are some general functions for chunking text into sentences, as well as the bulk of the operations behind late chunking.

Late chunking is simply the same chunks we would have on the naively chunked text, but the chunk embedding is taken from the pooling of the token embeddings, rather than an independently embedded chunk.

In [3]:
def sentence_chunker(document, batch_size=None):
    """
    Given a document (string), return the sentences as chunks and span annotations (start and end indices of chunks).
    Using spacy to do this sentence chunking.
    """

    if batch_size is None:
        batch_size = 10000 # no of characters

    # Batch with spacy
    nlp = spacy.blank("en")
    nlp.add_pipe("sentencizer", config={"punct_chars": None})
    doc = nlp(document)

    docs = []
    for i in range(0, len(document), batch_size):
        batch = document[i : i + batch_size]
        docs.append(nlp(batch))

    doc = Doc.from_docs(docs)

    span_annotations = []
    chunks = []
    for i, sent in enumerate(doc.sents):
        span_annotations.append((sent.start, sent.end))
        chunks.append(sent.text)

    return chunks, span_annotations


def document_to_token_embeddings(model, tokenizer, document, batch_size=4096):
    """
    Given a model and tokenizer from HuggingFace, return token embeddings of the input text document.
    """

    if batch_size > 8192: # no of tokens
        raise ValueError("Batch size is too large. Please use a batch size of 8192 or less.")

    tokenized_document = tokenizer(document, return_tensors="pt")
    tokens = tokenized_document.tokens()
    
    # Batch in sizes of batch_size
    outputs = []
    for i in range(0, len(tokens), batch_size):
        
        start = i
        end   = min(i + batch_size, len(tokens))

        # subset huggingface tokenizer outputs to i : i + batch_size
        batch_inputs = {k: v[:, start:end] for k, v in tokenized_document.items()}

        with torch.no_grad():
            model_output = model(**batch_inputs)

        outputs.append(model_output.last_hidden_state)

    model_output = torch.cat(outputs, dim=1)
    return model_output

def late_chunking(token_embeddings, span_annotation, max_length=None):
    """
    Given the token-level embeddings of document and their corresponding span annotations (start and end indices of chunks in terms of tokens),
    late chunking pools the token embeddings for each chunk.
    """
    outputs = []
    for embeddings, annotations in zip(token_embeddings, span_annotation):
        if (
            max_length is not None
        ):  # remove annotations which go beyond the max-length of the model
            annotations = [
                (start, min(end, max_length - 1))
                for (start, end) in annotations
                if start < (max_length - 1)
            ]
        pooled_embeddings = []
        for start, end in annotations:
            
            if (end - start) >= 1:
                # print(f"start: {start}, end: {end}")
                # print(f"{[e[:5] for e in embeddings[start:end]]}")
                pooled_embeddings.append(
                    embeddings[start:end].sum(dim=0) / (end - start)
                )
                    
        pooled_embeddings = [
            embedding.detach().cpu().numpy() for embedding in pooled_embeddings
        ]
        outputs.append(pooled_embeddings)

    return outputs

### Import into Weaviate

We aim to perform late chunking, obtain the contextually-aware embeddings, and then import these into a Weaviate collection.

First, create a Weaviate collection called `test_late_chunking`.

In [4]:
if client.collections.exists("test_late_chunking"):
    client.collections.delete("test_late_chunking")

# important to specify the config as none here, because we will be supplying our own vector embeddings in the form of the late chunking embeddings
late_chunking_collection = client.collections.create(
    name="test_late_chunking",
    vectorizer_config=wvc.config.Configure.Vectorizer.none(),
)

Now let's use a test document - the wikipedia page for Berlin (saved in a separate text file). We will later query this text using late chunking/naive chunking.

In [5]:
with open("berlin.txt", "r", encoding="utf-8") as f:
    document = f.read()

print(f"First 50 characters of the document:\n{document[:150]}...")


First 50 characters of the document:
Berlin[a] is the capital and largest city of Germany, both by area and by population.[11] Its more than 3.85 million inhabitants[12] make it the Europ...


Now, load the `jina-embeddings-v2-base-en` model from Huggingface. Other embedding models can be used, but Jina's model has up to 8192 token length documents, which is important for late chunking as we want to encode large documents and separate them later.

In [6]:
tokenizer = AutoTokenizer.from_pretrained('/mnt/ceph/develop/jiawei/model_checkpoint/jina-embeddings-v2-base-en', trust_remote_code=True)
model     = AutoModel.from_pretrained('/mnt/ceph/develop/jiawei/model_checkpoint/jina-embeddings-v2-base-en', trust_remote_code=True)

We call our functions we defined earlier: First chunk the text as normal, to obtain the beginning and end points of the chunks. Then embed the full document. Then perform the late chunking step - take the average over all token embeddings that correspond to each chunk (based on the beginning/end points of the chunks). These form as our embeddings for the chunks.

In [7]:
chunks, span_annotations = sentence_chunker(document)
token_embeddings = document_to_token_embeddings(model, tokenizer, document)
chunk_embeddings = late_chunking(token_embeddings, [span_annotations])[0]

Finally, we can add this to our Weaviate collection by supplying our own vector embedding for each chunk.

In [8]:
# add data with manual embeddings
data = []
for i in range(len(chunks)):
    data.append(wvc.data.DataObject(
            properties={
                "content": chunks[i]
            },
            vector = chunk_embeddings[i].tolist()
    )
)

late_chunking_collection.data.insert_many(data);

### Example Query

First, define two functions to process queries. One using our Weaviate collection, and a different, slower search using cosine similarity running locally that we will use for comparison.

In [9]:
def late_chunking_query_function_weaviate(query, k = 3):
    query_vector = model(**tokenizer(query, return_tensors="pt")).last_hidden_state.mean(1).detach().cpu().numpy().flatten()

    results = late_chunking_collection.query.near_vector(
        near_vector=query_vector.tolist(),
        limit = k
    )

    return [res.properties["content"] for res in results.objects]

def late_chunking_query_function_cosine_sim(query, k = 3):

    cos_sim = lambda x, y: np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))

    query_vector = model(**tokenizer(query, return_tensors="pt")).last_hidden_state.mean(1).detach().cpu().numpy().flatten()

    results = np.empty(len(chunk_embeddings))
    for i, (chunk, embedding) in enumerate(zip(chunks, chunk_embeddings)):
        results[i] = cos_sim(query_vector, embedding)

    results_order = results.argsort()[::-1]
    return np.array(chunks)[results_order].tolist()[:k]

Test both search functions.

In [11]:
late_chunking_query_function_weaviate("What are the different demographics of Berlin?", 5)

['The Independent Evangelical Lutheran Church has eight parishes of different sizes in Berlin.[131] There are 36 Baptist congregations (within Union of Evangelical Free Church Congregations in Germany), 29 New Apostolic Churches, 15 United Methodist churches, eight Free Evangelical Congregations, four Churches of Christ, Scientist (1st, 2nd, 3rd, and 11th), six congregations of the Church of Jesus Christ of Latter-day Saints, an Old Catholic church, and an Anglican church in Berlin.',
 'Each borough has several subdistricts or neighborhoods (Ortsteile), which have roots in much older municipalities that predate the formation of Greater Berlin on 1 October 1920.',
 'The Senate consists of the Governing Mayor of Berlin (Regierender Bürgermeister), and up to ten senators holding ministerial positions, two of them holding the title of "Mayor" (Bürgermeister) as deputy to the Governing Mayor.[134]\n\n\nCharlottenburg Town Hall\n\nRathaus Spandau\nThe total annual budget of Berlin in 2015 ex

In [12]:
late_chunking_query_function_cosine_sim("What are the different demographics of Berlin?", 5)

['The Independent Evangelical Lutheran Church has eight parishes of different sizes in Berlin.[131] There are 36 Baptist congregations (within Union of Evangelical Free Church Congregations in Germany), 29 New Apostolic Churches, 15 United Methodist churches, eight Free Evangelical Congregations, four Churches of Christ, Scientist (1st, 2nd, 3rd, and 11th), six congregations of the Church of Jesus Christ of Latter-day Saints, an Old Catholic church, and an Anglican church in Berlin.',
 'Each borough has several subdistricts or neighborhoods (Ortsteile), which have roots in much older municipalities that predate the formation of Greater Berlin on 1 October 1920.',
 'The Senate consists of the Governing Mayor of Berlin (Regierender Bürgermeister), and up to ten senators holding ministerial positions, two of them holding the title of "Mayor" (Bürgermeister) as deputy to the Governing Mayor.[134]\n\n\nCharlottenburg Town Hall\n\nRathaus Spandau\nThe total annual budget of Berlin in 2015 ex

In [21]:
late_chunking_query_function_weaviate("Percentage of 1700 year's Berlin's residents were French?", 10)

['\n\nIn 2009, approximately 249,000 Muslims were reported by the Office of Statistics to be members of mosques and Islamic religious organizations in Berlin,[126] while in 2016, the newspaper Der Tagesspiegel estimated that about 350,000 Muslims observed Ramadan in Berlin.[127] In 2019, about 437,000 registered residents, 11.6% of the total, reported having a migration background from one of the Member states of the Organization of Islamic Cooperation.[103][128] Between 1992 and 2011 the Muslim population almost doubled.[129]\n\nAbout 0.9% of Berliners belong to other religions.',
 'It does not keep records of members of other religious organizations which may collect their own church tax, in this way.',
 'In 1861, neighboring suburbs including Wedding, Moabit and several others were incorporated into Berlin.[47] In 1871, Berlin became capital of the newly founded German Empire.[48] In 1881, it became a city district separate from Brandenburg.[49]\n\n20th to 21st centuries\nFurther in

Both give the same results so we are confident that our vector search for late chunking works! We would expect something slightly different as Weaviate uses HNSW for a speedy search, and we have directly used cosine similarity, but in this case, they are the same.

For comparison, let's look at what a naive chunking method implemented with Weaviate's search would give us.

In [27]:
# create the weaviate collection chunked by sentences
if client.collections.exists("test_naive_chunking"):
    client.collections.delete("test_naive_chunking")

naive_chunking_collection = client.collections.create(
    name="test_naive_chunking",
    vectorizer_config=wvcc.Configure.Vectorizer.text2vec_contextionary( vectorize_collection_name=False),
            properties=[
                    wvcc.Property(name="content", data_type=wvcc.DataType.TEXT)
            ]
)
 

In [28]:
# add data with manual embeddings
data1 = []
for i in range(len(chunks)):
    data1.append(wvc.data.DataObject(
            properties={
                "content": chunks[i]
            },
            vector = chunk_embeddings[i].tolist()
    )
)

naive_chunking_collection.data.insert_many(data1);

In [29]:
def naive_chunking_query_function_weaviate(query, k=3):
    results = naive_chunking_collection.query.near_text(
        query = query,
        limit = k
    )

    return [res.properties["content"] for res in results.objects]


In [15]:
naive_chunking_query_function_weaviate("What are the different demographics of Berlin?", 5)

["\n\nDemographics\nMain article: Demographics of Berlin\n\nBerlin population pyramid in 2022\n\nBerlin's population, 1880–2012\nAt the end of 2018, the city-state of Berlin had 3.75 million registered inhabitants[103] in an area of 891.1 km2 (344.1 sq mi).[3] The city's population density was 4,206 inhabitants per km2.",
 'Foreign residents of Berlin originate from about 190 countries.[112] 48 percent of the residents under the age of 15 have a migration background in 2017.[113] Berlin in 2009 was estimated to have 100,000 to 250,000 unregistered inhabitants.[114] Boroughs of Berlin with a significant number of migrants or foreign born population are Mitte, Neukölln and Friedrichshain-Kreuzberg.[115] The number of Arabic speakers in Berlin could be higher than 150,000.',
 'Around 130,000 jobs were added in this period.[150]\n\nImportant economic sectors in Berlin include life sciences, transportation, information and communication technologies, media and music, advertising and design,

We can see that the naive chunking query still gives us good results - it matches more specifically with the question. Whereas the late chunking example skips straight to the chunks it _knows_ to be relevant, because they contain contextual information within the embeddings themselves!

In [30]:
naive_chunking_query_function_weaviate("Percentage of 1700 year's Berlin's residents were French?", 5)

WeaviateQueryError: Query call with protocol GRPC search failed with message <AioRpcError of RPC that terminated with:
	status = StatusCode.UNKNOWN
	details = "explorer: get class: concurrentTargetVectorSearch): explorer: get class: vector search: object vector search at index test_naive_chunking: shard test_naive_chunking_jDqzKUjbvpgZ: vector search: knn search: distance between entrypoint and query node: 768 vs 300: vector lengths don't match"
	debug_error_string = "UNKNOWN:Error received from peer  {grpc_message:"explorer: get class: concurrentTargetVectorSearch): explorer: get class: vector search: object vector search at index test_naive_chunking: shard test_naive_chunking_jDqzKUjbvpgZ: vector search: knn search: distance between entrypoint and query node: 768 vs 300: vector lengths don\'t match", grpc_status:2, created_time:"2024-12-24T00:15:01.766223798+08:00"}"
>.