## Late Chunking with Weaviate

Notebook author: Danny Williams @ weaviate (Developer Growth)

This notebook implements [late chunking](https://jina.ai/news/late-chunking-in-long-context-embedding-models/) with Weaviate. Late chunking is a change in the classical chunking framework where chunking happens _after_ token embeddings are output from the full document. This preserves contextual information from one chunk to another.



### Setup

First we install all required packages. We are using

In [1]:
# !pip install  torch numpy spacy transformers  

Then we load the packages and connect to the Weaviate client. Important, you need some API keys within a `.env` file:
- your Weaviate REST endpoint saved as `WEAVIATE_URL`
- your Weaviate API key saved as `WEAVIATE_KEY`
- if you want to run the final comparison in this notebook, an OpenAI API key saved as `OPENAI_API_KEY`, otherwise delete the `headers` argument in the `weaviate.connect_to_weaviate_cloud` function.


In [3]:
%%capture
# imports
import weaviate
import weaviate.classes as wvc
import weaviate.classes.config as wvcc

import os
import torch
import numpy as np 

import spacy
from spacy.tokens import Doc
from spacy.language import Language

import transformers
from transformers import AutoModel
from transformers import AutoTokenizer

# connect to weaviate
import weaviate

client = weaviate.connect_to_local()

print(client.is_ready())

Finally just for future-proofing, the versions of these packages are:

In [5]:
print(f"Weaviate version {weaviate.__version__}")
print(f"Pytorch version {torch.__version__}")
print(f"Numpy version {np.__version__}")
print(f"Spacy version {spacy.__version__}")
print(f"Transformers version {transformers.__version__}")

Weaviate version 0.1.dev3117+gae1bb03
Pytorch version 2.4.1+cu121
Numpy version 2.2.1
Spacy version 3.8.3
Transformers version 4.47.1


### Functions

Below are some general functions for chunking text into sentences, as well as the bulk of the operations behind late chunking.

Late chunking is simply the same chunks we would have on the naively chunked text, but the chunk embedding is taken from the pooling of the token embeddings, rather than an independently embedded chunk.

In [6]:
from sentence_transformers import SentenceTransformer


def sentence_chunker(document, batch_size=None):
    """
    Given a document (string), return the sentences as chunks and span annotations (start and end indices of chunks).  
    Using spacy to do this sentence chunking.
    """

    if batch_size is None:
        batch_size = 10000 # no of characters

    # Batch with spacy
    nlp = spacy.blank("en")
    nlp.add_pipe("sentencizer", config={"punct_chars": None})
    doc = nlp(document)

    docs = []
    for i in range(0, len(document), batch_size):
        batch = document[i : i + batch_size]
        docs.append(nlp(batch))

    doc = Doc.from_docs(docs)

    span_annotations = []
    chunks = []
    for i, sent in enumerate(doc.sents):
        span_annotations.append((sent.start, sent.end))
        chunks.append(sent.text)

    return chunks, span_annotations


def document_to_token_embeddings(model, tokenizer, document, batch_size=8192):
    """
    Given a model and tokenizer from HuggingFace, return token embeddings of the input text document.
    """

    if batch_size > 8192: # no of tokens
        raise ValueError("Batch size is too large. Please use a batch size of 8192 or less.")

    tokenized_document = tokenizer(document, return_tensors="pt")
    tokens = tokenized_document.tokens()
    
    # Batch in sizes of batch_size
    outputs = []
    for i in range(0, len(tokens), batch_size):
        
        start = i
        end   = min(i + batch_size, len(tokens))

        # subset huggingface tokenizer outputs to i : i + batch_size
        batch_inputs = {k: v[:, start:end] for k, v in tokenized_document.items()}

        with torch.no_grad():
            model_output = model(**batch_inputs)

        outputs.append(model_output.last_hidden_state)

    model_output = torch.cat(outputs, dim=1)
    return model_output

def late_chunking(token_embeddings, span_annotation, max_length=None):
    """
    Given the token-level embeddings of document and their corresponding span annotations (start and end indices of chunks in terms of tokens),
    late chunking pools the token embeddings for each chunk.
    """
    outputs = []
    for embeddings, annotations in zip(token_embeddings, span_annotation):
        if (
            max_length is not None
        ):  # remove annotations which go beyond the max-length of the model
            annotations = [
                (start, min(end, max_length - 1))
                for (start, end) in annotations
                if start < (max_length - 1)
            ]
        pooled_embeddings = []
        for start, end in annotations:
            
            if (end - start) >= 1:
                # print(f"start: {start}, end: {end}")
                # print(f"{[e[:5] for e in embeddings[start:end]]}")
                pooled_embeddings.append(
                    embeddings[start:end].sum(dim=0) / (end - start)
                )
                    
        pooled_embeddings = [
            embedding.detach().cpu().numpy() for embedding in pooled_embeddings
        ]
        outputs.append(pooled_embeddings)

    return outputs
  

### Import into Weaviate

We aim to perform late chunking, obtain the contextually-aware embeddings, and then import these into a Weaviate collection.

First, create a Weaviate collection called `test_late_chunking`.

In [7]:
if client.collections.exists("test_late_chunking"):
    client.collections.delete("test_late_chunking")

# important to specify the config as none here, because we will be supplying our own vector embeddings in the form of the late chunking embeddings
late_chunking_collection = client.collections.create(
    name="test_late_chunking",
    vectorizer_config=wvc.config.Configure.Vectorizer.none(),
)

Now let's use a test document - the wikipedia page for Berlin (saved in a separate text file). We will later query this text using late chunking/naive chunking.

In [8]:
with open("berlin.txt", "r", encoding="utf-8") as f:
    document = f.read()

print(f"First 50 characters of the document:\n{document[:150]}...")


First 50 characters of the document:
Berlin[a] is the capital and largest city of Germany, both by area and by population.[11] Its more than 3.85 million inhabitants[12] make it the Europ...


Now, load the  jinaai/jina-embeddings-v3  model from Huggingface. Other embedding models can be used, but Jina's model has up to 8192 token length documents, which is important for late chunking as we want to encode large documents and separate them later.

In [None]:
tokenizer = AutoTokenizer.from_pretrained('/mnt/ceph/develop/jiawei/model_checkpoint/jina-embeddings-v3', trust_remote_code=True)
model     = AutoModel.from_pretrained('/mnt/ceph/develop/jiawei/model_checkpoint/jina-embeddings-v3', trust_remote_code=True)

We call our functions we defined earlier: First chunk the text as normal, to obtain the beginning and end points of the chunks. Then embed the full document. Then perform the late chunking step - take the average over all token embeddings that correspond to each chunk (based on the beginning/end points of the chunks). These form as our embeddings for the chunks.

In [13]:
chunks, span_annotations = sentence_chunker(document)
token_embeddings = document_to_token_embeddings(model, tokenizer, document)
chunk_embeddings = late_chunking(token_embeddings, [span_annotations])[0]

TypeError: Got unsupported ScalarType BFloat16

In [None]:

# 找到字符长度为1700的内容
target_chunks = [chunk for chunk in chunks if '1700' in chunk]
target_indexes = [index for index, chunk in enumerate(chunks)  if '1700' in chunk]

# 输出结果
print(target_chunks)
print(target_indexes)
 

Finally, we can add this to our Weaviate collection by supplying our own vector embedding for each chunk.

In [69]:
# add data with manual embeddings
data = []
for i in range(len(chunks)):
    data.append(wvc.data.DataObject(
            properties={
                "content": chunks[i]
            },
            vector = chunk_embeddings[i].tolist()
    )
)

late_chunking_collection.data.insert_many(data);

### Example Query

First, define two functions to process queries. One using our Weaviate collection, and a different, slower search using cosine similarity running locally that we will use for comparison.

In [70]:
def late_chunking_query_function_weaviate(query, k = 3):
    query_vector = model(**tokenizer(query, return_tensors="pt")).last_hidden_state.mean(1).detach().cpu().numpy().flatten()

    results = late_chunking_collection.query.near_vector(
        near_vector=query_vector.tolist(),
        limit = k
    )

    return [res.properties["content"] for res in results.objects]

def late_chunking_query_function_cosine_sim(query, k = 3):

    cos_sim = lambda x, y: np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))

    query_vector = model(**tokenizer(query, return_tensors="pt")).last_hidden_state.mean(1).detach().cpu().numpy().flatten()

    results = np.empty(len(chunk_embeddings))
    for i, (chunk, embedding) in enumerate(zip(chunks, chunk_embeddings)):
        results[i] = cos_sim(query_vector, embedding)

    results_order = results.argsort()[::-1]
    return np.array(chunks)[results_order].tolist()[:k]

Test both search functions.

In [71]:
late_chunking_query_function_weaviate("17th to 19th centuries at 1700 year's Berlin's residents were French Proportion?", 10)

["1920s Berlin was the third-largest city in the world by population.[18]\n\nAfter World War II and following Berlin's occupation, the city was split into West Berlin and East Berlin, divided by the Berlin Wall.[19] East Berlin was declared the capital of East Germany, while Bonn became the West German capital.",
 'During the Gründerzeit, an industrialization-induced economic boom triggered a rapid population increase in Berlin.',
 'Berlin has served as a scientific, artistic, and philosophical hub during the Age of Enlightenment, Neoclassicism, and the German revolutions of 1848–1849.',
 'Following German reunification in 1990, Berlin once again became the capital of all of Germany.',
 'Due to its geographic location and history, Berlin has been called "the heart of Europe".[20][21][22]\n\nThe economy of Berlin is based on high tech and the service sector, encompassing a diverse range of creative industries, startup companies, research facilities, and media corporations.[23][24] Berli

In [72]:
late_chunking_query_function_cosine_sim("17th to 19th centuries at 1700 year's Berlin's residents were French Proportion?", 10)

  cos_sim = lambda x, y: np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))


['Its name commemorates the uprisings in East Berlin of 17 June 1953.',
 'Approximately halfway from the Brandenburg Gate is the Großer Stern, a circular traffic island on which the Siegessäule (Victory Column) is situated.',
 "This monument, built to commemorate Prussia's victories, was relocated in 1938–39 from its previous position in front of the Reichstag.",
 "\n\nThe Kurfürstendamm is home to some of Berlin's luxurious stores with the Kaiser Wilhelm Memorial Church at its eastern end on Breitscheidplatz.",
 'The church was destroyed in the Second World War and left in ruins.',
 "Nearby on Tauentzienstraße is KaDeWe, claimed to be continental Europe's largest department store.",
 'The Rathaus Schöneberg, where John F. Kennedy made his famous "Ich bin ein Berliner!"',
 'speech, is in Tempelhof-Schöneberg.',
 '\n\nWest of the center, Bellevue Palace is the residence of the German President.',
 'Charlottenburg Palace, which was burnt out in the Second World War, is the largest histor

Both give the same results so we are confident that our vector search for late chunking works! We would expect something slightly different as Weaviate uses HNSW for a speedy search, and we have directly used cosine similarity, but in this case, they are the same.

For comparison, let's look at what a naive chunking method implemented with Weaviate's search would give us.

In [36]:
# create the weaviate collection chunked by sentences
if client.collections.exists("test_naive_chunking"):
    client.collections.delete("test_naive_chunking")

naive_chunking_collection = client.collections.create(
    name="test_naive_chunking",
    vectorizer_config=wvcc.Configure.Vectorizer.text2vec_transformers(),
            properties=[
                    wvcc.Property(name="content", data_type=wvcc.DataType.TEXT)
            ]
)
 

In [37]:
# add data with manual embeddings
data1 = []
for i in range(len(chunks)):
    data1.append(wvc.data.DataObject(
            properties={
                "content": chunks[i]
            },
            vector = chunk_embeddings[i].tolist()
    )
)

naive_chunking_collection.data.insert_many(data1);

In [38]:
def naive_chunking_query_function_weaviate(query, k=3):
    results = naive_chunking_collection.query.near_text(
        query = query,
        limit = k
    )

    return [res.properties["content"] for res in results.objects]


We can see that the naive chunking query still gives us good results - it matches more specifically with the question. Whereas the late chunking example skips straight to the chunks it _knows_ to be relevant, because they contain contextual information within the embeddings themselves!

In [40]:
naive_chunking_query_function_weaviate("Percentage of 1700 year's Berlin's residents were French?", 5)

["Berlin[a] is the capital and largest city of Germany, both by area and by population.[11] Its more than 3.85 million inhabitants[12] make it the European Union's most populous city, as measured by population within city limits.[13] The city is also one of the states of Germany, and is the third smallest state in the country in terms of area.",
 "Other gardens in the city include the Britzer Garten, and the Gärten der Welt (Gardens of the World) in Marzahn.[274]\n\n\nThe Victory Column in Tiergarten\nThe Tiergarten park in Mitte, with landscape design by Peter Joseph Lenné, is one of Berlin's largest and most popular parks.[275] In Kreuzberg, the Viktoriapark provides a viewing point over the southern part of inner-city Berlin.",
 'Temperatures can be 4 °C (7 °F) higher in the city than in the surrounding areas.[89] Annual precipitation is 570 millimeters (22 in) with moderate rainfall throughout the year.',
 'Berlin is the most populous city proper in the European Union.',
 "The Volk