## Late Chunking with Weaviate

Notebook author: Danny Williams @ weaviate (Developer Growth)

This notebook implements [late chunking](https://jina.ai/news/late-chunking-in-long-context-embedding-models/) with Weaviate. Late chunking is a change in the classical chunking framework where chunking happens _after_ token embeddings are output from the full document. This preserves contextual information from one chunk to another.



### Setup

First we install all required packages. We are using

In [1]:
# !pip install  torch numpy spacy transformers  

Then we load the packages and connect to the Weaviate client. Important, you need some API keys within a `.env` file:
- your Weaviate REST endpoint saved as `WEAVIATE_URL`
- your Weaviate API key saved as `WEAVIATE_KEY`
- if you want to run the final comparison in this notebook, an OpenAI API key saved as `OPENAI_API_KEY`, otherwise delete the `headers` argument in the `weaviate.connect_to_weaviate_cloud` function.


In [1]:
%%capture
# imports
import weaviate
import weaviate.classes as wvc
import weaviate.classes.config as wvcc

import os
import torch
import numpy as np 

import spacy
from spacy.tokens import Doc
from spacy.language import Language

import transformers
from transformers import AutoModel
from transformers import AutoTokenizer

# connect to weaviate
import weaviate

client = weaviate.connect_to_local()

print(client.is_ready())

Finally just for future-proofing, the versions of these packages are:

In [2]:
print(f"Weaviate version {weaviate.__version__}")
print(f"Pytorch version {torch.__version__}")
print(f"Numpy version {np.__version__}")
print(f"Spacy version {spacy.__version__}")
print(f"Transformers version {transformers.__version__}")

Weaviate version 0.1.dev3117+gae1bb03
Pytorch version 2.4.1+cu121
Numpy version 2.2.1
Spacy version 3.8.3
Transformers version 4.48.0.dev0


### Functions

Below are some general functions for chunking text into sentences, as well as the bulk of the operations behind late chunking.

Late chunking is simply the same chunks we would have on the naively chunked text, but the chunk embedding is taken from the pooling of the token embeddings, rather than an independently embedded chunk.

In [43]:
def chunk_by_sentences(input_text: str, tokenizer: callable):
    """
    Split the input text into sentences using the tokenizer
    :param input_text: The text snippet to split into sentences
    :param tokenizer: The tokenizer to use
    :return: A tuple containing the list of text chunks and their corresponding token spans
    """
    inputs = tokenizer(input_text, return_tensors='pt', return_offsets_mapping=True)
    punctuation_mark_id = tokenizer.convert_tokens_to_ids('###')
    sep_id = tokenizer.convert_tokens_to_ids('[SEP]')
    token_offsets = inputs['offset_mapping'][0]
    token_ids = inputs['input_ids'][0]
    chunk_positions = [
        (i, int(start + 1))
        for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets))
        if token_id == punctuation_mark_id
        and (
            token_offsets[i + 1][0] - token_offsets[i][1] > 0
            or token_ids[i + 1] == sep_id
        )
    ]
    chunks = [
        input_text[x[1] : y[1]]
        for x, y in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
    ]
    span_annotations = [
        (x[0], y[0]) for (x, y) in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
    ]
    return chunks, span_annotations

def late_chunking(token_embeddings, span_annotation, max_length=None):
    """
    Given the token-level embeddings of document and their corresponding span annotations (start and end indices of chunks in terms of tokens),
    late chunking pools the token embeddings for each chunk.
    """
    outputs = []
    for embeddings, annotations in zip(token_embeddings, span_annotation):
        if (
            max_length is not None
        ):  # remove annotations which go beyond the max-length of the model
            annotations = [
                (start, min(end, max_length - 1))
                for (start, end) in annotations
                if start < (max_length - 1)
            ]
        pooled_embeddings = []
        for start, end in annotations:
            
            if (end - start) >= 1:
                # print(f"start: {start}, end: {end}")
                # print(f"{[e[:5] for e in embeddings[start:end]]}")
                pooled_embeddings.append(
                    embeddings[start:end].sum(dim=0) / (end - start)
                )
                    
        pooled_embeddings = [
            embedding.detach().cpu().to(torch.float64).numpy() for embedding in pooled_embeddings
        ]
        outputs.append(pooled_embeddings)

    return outputs
  
def late_chunking(
    model_output: 'BatchEncoding', span_annotation: list, max_length=None
):
    token_embeddings = model_output[0]
    outputs = []
    for embeddings, annotations in zip(token_embeddings, span_annotation):
        if (
            max_length is not None
        ):  # remove annotations which go bejond the max-length of the model
            annotations = [
                (start, min(end, max_length - 1))
                for (start, end) in annotations
                if start < (max_length - 1)
            ]
        pooled_embeddings = [
            embeddings[start:end].sum(dim=0) / (end - start)
            for start, end in annotations
            if (end - start) >= 1
        ]
        pooled_embeddings = [
            embedding.detach().cpu().numpy() for embedding in pooled_embeddings
        ]
        outputs.append(pooled_embeddings)

    return outputs


### Import into Weaviate

We aim to perform late chunking, obtain the contextually-aware embeddings, and then import these into a Weaviate collection.

First, create a Weaviate collection called `test_late_chunking`.

In [44]:
if client.collections.exists("test_late_chunking"):
    client.collections.delete("test_late_chunking")

# important to specify the config as none here, because we will be supplying our own vector embeddings in the form of the late chunking embeddings
late_chunking_collection = client.collections.create(
    name="test_late_chunking",
    vectorizer_config=wvc.config.Configure.Vectorizer.none(),
)

Now let's use a test document - the wikipedia page for Berlin (saved in a separate text file). We will later query this text using late chunking/naive chunking.

In [45]:
with open("RoleLLM  安全指令方案.md", "r", encoding="utf-8") as f:
    document = f.read()

print(f"First 50 characters of the document:\n{document[:150]}...")


First 50 characters of the document:
### RoleLLM:  安全指令方案

本文讨论一个从角色扮演中关于安全策略问题，通过概述角色扮演中信息的处理阶段、上下文的交互来指定安全策略

目标：角色扮演旨在使 LLM 能够或定制 LLM， 来模拟具有不同属性和会话风格的各种角色或人物角色，在这中间会涉及到一些安全性的问题，在对话角色时，...


Now, load the  jinaai/jina-embeddings-v3  model from Huggingface. Other embedding models can be used, but Jina's model has up to 8192 token length documents, which is important for late chunking as we want to encode large documents and separate them later.

In [7]:
# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('/mnt/ceph/develop/jiawei/model_checkpoint/jina-embeddings-v3', trust_remote_code=True)
model = AutoModel.from_pretrained('/mnt/ceph/develop/jiawei/model_checkpoint/jina-embeddings-v3', trust_remote_code=True).to(dtype=torch.float16, device='cuda:0') 

  def forward(
  def backward(ctx, dout, *args):


We call our functions we defined earlier: First chunk the text as normal, to obtain the beginning and end points of the chunks. Then embed the full document. Then perform the late chunking step - take the average over all token embeddings that correspond to each chunk (based on the beginning/end points of the chunks). These form as our embeddings for the chunks.

In [48]:
chunks, span_annotations = chunk_by_sentences(document, tokenizer)
print(f'Chunks:{len(chunks)}\n- "' + '"\n- "'.join(chunks[0:2]) + '"')

Chunks:4
- "#"
- "## RoleLLM:  安全指令方案

本文讨论一个从角色扮演中关于安全策略问题，通过概述角色扮演中信息的处理阶段、上下文的交互来指定安全策略

目标：角色扮演旨在使 LLM 能够或定制 LLM， 来模拟具有不同属性和会话风格的各种角色或人物角色，在这中间会涉及到一些安全性的问题，在对话角色时，不同的场景对应的规则也许不同，不同角色的效果一般由这几个基准来评分，说话风格模仿、回答准确性和特定角色知识捕获 



#"


In [49]:
inputs = tokenizer(document, return_tensors='pt').to(device='cuda:0')
model_output = model(**inputs)
chunk_embeddings = late_chunking(model_output, [span_annotations])[0] 

Finally, we can add this to our Weaviate collection by supplying our own vector embedding for each chunk.

In [51]:
# add data with manual embeddings
data = []
for i in range(len(chunks)):
    data.append(wvc.data.DataObject(
            properties={
                "content": chunks[i]
            },
            vector = chunk_embeddings[i].tolist()
    )
)

late_chunking_collection.data.insert_many(data);

### Example Query

First, define two functions to process queries. One using our Weaviate collection, and a different, slower search using cosine similarity running locally that we will use for comparison.

In [52]:
def late_chunking_query_function_weaviate(query, k = 3): 
   
    berlin_embedding = model.encode(query)

    results = late_chunking_collection.query.near_vector(
        near_vector=berlin_embedding.tolist(),
        limit = k
    )

    return [res.properties["content"] for res in results.objects]

def late_chunking_query_function_cosine_sim(query, k = 3):

    cos_sim = lambda x, y: np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
 
    query_vector =  model.encode(query)

    results = np.empty(len(chunk_embeddings))
    for i, (chunk, embedding) in enumerate(zip(chunks, chunk_embeddings)):
        results[i] = cos_sim(query_vector, embedding)

    results_order = results.argsort()[::-1]
    return np.array(chunks)[results_order].tolist()[:k]

Test both search functions.

In [53]:
late_chunking_query_function_weaviate("角色扮演会话的系统指令定制有哪些", 10)

['## 会话过程\n\n通过添加前置的信息在提示上下文中，要求模型在会话中使用要求的风格对话\n\n- 预处理数据\n\n在完成了数据预处理之后,一般会获得一个角色详细的事实数据，加入对话之前还需要进行角色预演，这个过程通常为了让角色能很好的遵从指令，防止一些事实内容的问题\n\n\u200b\t(1）分割角色档案\n\n\u200b\t(2）生成场景问题-内容-角色认定的答案数据集\n\n\u200b\t(3）对低质量数据进行过滤和后处理\n\n我们获得了一个角色事实数据集，载入结构化对话\n\n- 对话模板\n\n \t\t系统指令\n \t\t\t扮演的角色标签卡\n \t\t用户指令\n \t\t（1）角色描述和流行语，一直存在上下文中的指令\n \t\n \t\t（2）结构化对话 。基于对话事件指令\n\n- 生成问题-置信度-答案（对话事件指令描述 ）\n\n```\n\t为特定角色的训练数据生成数据集的过程，考虑了三个元素: \n\n\t1 与给定部分(即上下文)相关的问题(Q) ，\n\n\t2 相应的答案(A)\n\n\t3 具有基本原理的置信评分(C)\n\n\n```\n\n\n\n- 系统指令定制 !!!!!**(目前特定角色指令根据用户意图动态生成)**\n\n   ```\n   将系统指令与RoleGPT中的角色名称、描述、流行语和角色扮演任务指令一起准备到输入。 \n   在推理过程中，用户可以通过系统指令轻松修改LLM的角色，与检索增强相比，最大限度地减少了上下文窗口的消耗\n   ```\n\n\n\n\n\n##',
 '## RoleLLM:  安全指令方案\n\n本文讨论一个从角色扮演中关于安全策略问题，通过概述角色扮演中信息的处理阶段、上下文的交互来指定安全策略\n\n目标：角色扮演旨在使 LLM 能够或定制 LLM， 来模拟具有不同属性和会话风格的各种角色或人物角色，在这中间会涉及到一些安全性的问题，在对话角色时，不同的场景对应的规则也许不同，不同角色的效果一般由这几个基准来评分，说话风格模仿、回答准确性和特定角色知识捕获 \n\n\n\n#',
 '## 会话指令\n\n在整个会话过程中，用户对话验证可以插入到不同的阶段中，目前的系统对话被分为如下阶段，1、数据预处理阶段；2、系统指令生成阶段；3、上下文会话检索阶段

In [54]:
late_chunking_query_function_cosine_sim("角色扮演会话的系统指令定制有哪些", 10)

['## 会话过程\n\n通过添加前置的信息在提示上下文中，要求模型在会话中使用要求的风格对话\n\n- 预处理数据\n\n在完成了数据预处理之后,一般会获得一个角色详细的事实数据，加入对话之前还需要进行角色预演，这个过程通常为了让角色能很好的遵从指令，防止一些事实内容的问题\n\n\u200b\t(1）分割角色档案\n\n\u200b\t(2）生成场景问题-内容-角色认定的答案数据集\n\n\u200b\t(3）对低质量数据进行过滤和后处理\n\n我们获得了一个角色事实数据集，载入结构化对话\n\n- 对话模板\n\n \t\t系统指令\n \t\t\t扮演的角色标签卡\n \t\t用户指令\n \t\t（1）角色描述和流行语，一直存在上下文中的指令\n \t\n \t\t（2）结构化对话 。基于对话事件指令\n\n- 生成问题-置信度-答案（对话事件指令描述 ）\n\n```\n\t为特定角色的训练数据生成数据集的过程，考虑了三个元素: \n\n\t1 与给定部分(即上下文)相关的问题(Q) ，\n\n\t2 相应的答案(A)\n\n\t3 具有基本原理的置信评分(C)\n\n\n```\n\n\n\n- 系统指令定制 !!!!!**(目前特定角色指令根据用户意图动态生成)**\n\n   ```\n   将系统指令与RoleGPT中的角色名称、描述、流行语和角色扮演任务指令一起准备到输入。 \n   在推理过程中，用户可以通过系统指令轻松修改LLM的角色，与检索增强相比，最大限度地减少了上下文窗口的消耗\n   ```\n\n\n\n\n\n##',
 '## RoleLLM:  安全指令方案\n\n本文讨论一个从角色扮演中关于安全策略问题，通过概述角色扮演中信息的处理阶段、上下文的交互来指定安全策略\n\n目标：角色扮演旨在使 LLM 能够或定制 LLM， 来模拟具有不同属性和会话风格的各种角色或人物角色，在这中间会涉及到一些安全性的问题，在对话角色时，不同的场景对应的规则也许不同，不同角色的效果一般由这几个基准来评分，说话风格模仿、回答准确性和特定角色知识捕获 \n\n\n\n#',
 '## 会话指令\n\n在整个会话过程中，用户对话验证可以插入到不同的阶段中，目前的系统对话被分为如下阶段，1、数据预处理阶段；2、系统指令生成阶段；3、上下文会话检索阶段

Both give the same results so we are confident that our vector search for late chunking works! We would expect something slightly different as Weaviate uses HNSW for a speedy search, and we have directly used cosine similarity, but in this case, they are the same.

For comparison, let's look at what a naive chunking method implemented with Weaviate's search would give us.

In [55]:
# create the weaviate collection chunked by sentences
if client.collections.exists("test_naive_chunking"):
    client.collections.delete("test_naive_chunking")

naive_chunking_collection = client.collections.create(
    name="test_naive_chunking",
    vectorizer_config=wvcc.Configure.Vectorizer.text2vec_transformers(),
            properties=[
                    wvcc.Property(name="content", data_type=wvcc.DataType.TEXT)
            ]
)
 

In [56]:
# add data with manual embeddings
data1 = []
for i in range(len(chunks)):
    data1.append(wvc.data.DataObject(
            properties={
                "content": chunks[i]
            },
            vector = chunk_embeddings[i].tolist()
    )
)

naive_chunking_collection.data.insert_many(data1);

In [69]:
from weaviate.classes.query import MetadataQuery
def naive_chunking_query_function_weaviate(query, k=3):
 
    response = naive_chunking_collection.query.near_text(
        query=query,   
        limit=k,
        return_metadata=MetadataQuery(score=True, explain_score=True),
    )
    for o in response.objects:
        print("111111111111111111") 
        print(o.properties['content']) 
        print(o.metadata)

We can see that the naive chunking query still gives us good results - it matches more specifically with the question. Whereas the late chunking example skips straight to the chunks it _knows_ to be relevant, because they contain contextual information within the embeddings themselves!

In [70]:
naive_chunking_query_function_weaviate("角色扮演会话的系统指令定制有哪些", 5)

111111111111111111
## RoleLLM:  安全指令方案

本文讨论一个从角色扮演中关于安全策略问题，通过概述角色扮演中信息的处理阶段、上下文的交互来指定安全策略

目标：角色扮演旨在使 LLM 能够或定制 LLM， 来模拟具有不同属性和会话风格的各种角色或人物角色，在这中间会涉及到一些安全性的问题，在对话角色时，不同的场景对应的规则也许不同，不同角色的效果一般由这几个基准来评分，说话风格模仿、回答准确性和特定角色知识捕获 



#
MetadataReturn(creation_time=None, last_update_time=None, distance=None, certainty=None, score=0.0, explain_score='', is_consistent=None, rerank_score=None)
111111111111111111
#
MetadataReturn(creation_time=None, last_update_time=None, distance=None, certainty=None, score=0.0, explain_score='', is_consistent=None, rerank_score=None)
111111111111111111
## 会话过程

通过添加前置的信息在提示上下文中，要求模型在会话中使用要求的风格对话

- 预处理数据

在完成了数据预处理之后,一般会获得一个角色详细的事实数据，加入对话之前还需要进行角色预演，这个过程通常为了让角色能很好的遵从指令，防止一些事实内容的问题

​	(1）分割角色档案

​	(2）生成场景问题-内容-角色认定的答案数据集

​	(3）对低质量数据进行过滤和后处理

我们获得了一个角色事实数据集，载入结构化对话

- 对话模板

 		系统指令
 			扮演的角色标签卡
 		用户指令
 		（1）角色描述和流行语，一直存在上下文中的指令
 	
 		（2）结构化对话 。基于对话事件指令

- 生成问题-置信度-答案（对话事件指令描述 ）

```
	为特定角色的训练数据生成数据集的过程，考虑了三个元素: 

	1 与给定部分(即上下文)相关的问题(Q) ，

	2 相应的答案(A)

	3 具有基本原理