In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
import os
import sys
sys.path.append('..')

In [3]:
from pathlib import Path
from fastcore.xtras import *
import torch
import torch.nn.functional as F

In [4]:
from framework.documents import load_docling
from docling_core.types.doc import ImageRefMode

# Jina Late Chunking


In [5]:
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import numpy as np

# Embedding checkpoint under test; the commented-out names are alternatives.
# model_name = "Alibaba-NLP/gte-modernbert-base"
model_name = "nomic-ai/modernbert-embed-base"
# model_name = "jinaai/jina-embeddings-v2-base-en"

# Token budget used for truncation whenever the document is tokenized below.
MAX_LEN = 8192

# The same checkpoint is loaded twice: raw transformers (tokenizer + model) for
# token-level embeddings, and SentenceTransformer as a reference encoder.
# NOTE(review): trust_remote_code=True allows the hub repository to supply its
# own Python code — confirm this checkpoint actually needs it.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
enc_model = SentenceTransformer(model_name)

In [6]:
# Load in the document
doc = load_docling(Path("scratch/sample_doc.json"))

In [7]:
# NOTE: removing image placeholders in the document for now
# text = doc.export_to_markdown(image_mode=ImageRefMode.PLACEHOLDER, image_placeholder="</IMAGE>")
text = doc.export_to_text()

Parameter `strict_text` has been deprecated and will be ignored.


In [8]:
type(doc)

docling_core.types.doc.document.DoclingDocument

In [9]:
# Tokenize the document, keeping the character offsets of every token.
# `len(inputs)` below counts the entries of the returned encoding, not the
# number of tokens in the document.
inputs = tokenizer(
    text,
    return_tensors='pt',
    return_offsets_mapping=True,

    # NOTE: only the first MAX_LEN tokens are kept; the rest of the document is truncated away
    truncation=True,
    max_length=MAX_LEN,
); len(inputs)

3

In [10]:
inputs['input_ids'].shape

torch.Size([1, 8192])

# Sanity checking that transformers and sentence transformers are giving the same results

In [11]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions selected by the attention mask.

    Args:
        model_output: Sequence whose first element holds the token embeddings,
            indexed as (batch, token, hidden).
        attention_mask: (batch, token) mask; positions with 0 are excluded
            from the average.

    Returns:
        A (batch, hidden) tensor of mask-weighted mean embeddings.
    """
    hidden_states = model_output[0]

    # broadcast the mask across the hidden dimension so it can weight each vector
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()

    weighted_sum = torch.sum(hidden_states * mask, 1)
    # clamp so a fully-masked row divides by a tiny number instead of zero
    token_counts = torch.clamp(mask.sum(1), min=1e-9)
    return weighted_sum / token_counts

In [12]:
e1 = enc_model.encode([text[:100]])
e1.shape, type(e1)


((1, 768), numpy.ndarray)

In [13]:
# Embed the same snippet with raw transformers + our own mean pooling.
t2 = tokenizer(text[:100], return_tensors='pt')
# inference only: no autograd graph needed for a sanity check
with torch.no_grad():
    t2_output = model(**t2)
# keep the raw model output and the pooled vector under separate names
e2 = mean_pooling(t2_output, t2['attention_mask'])
e2.shape, type(e2)

(torch.Size([1, 768]), torch.Tensor)

In [14]:
# e1 (sentence-transformers) and e2 (transformers + mean_pooling) embed the
# same snippet, so their cosine similarity should be ~1. Fail loudly if not.
pipeline_similarity = cos_sim(e1, e2)
assert pipeline_similarity.item() > 0.999, (
    "transformers and sentence-transformers embeddings disagree"
)
pipeline_similarity


tensor([[1.0000]], grad_fn=<MmBackward0>)

In [15]:
# Mapping back from token ids in the inputs to the original text, so we know what's what

In [16]:
# pull out a sample token from the encoded inputs
token_ids = inputs['input_ids'][0]
_sample_id = 100
_token_one_id = token_ids[_sample_id]
_token_one_id # <- This is the token id at this place in encoded inputs

tensor(656)

In [17]:
# using the mappings, we can find out where in the text this token is
token_offsets = inputs['offset_mapping'][0]
_token_offset = token_offsets[_sample_id]
_token_offset

tensor([340, 342])

In [18]:
# sure enough, we can index into the original text with the offset and be anchored in our context
region = text[
    _token_offset[0]:_token_offset[1]
]; region

'ys'

In [19]:
# Round-trip check: looking the recovered text up in the vocabulary must give
# back the same id that sits at this position in the encoded inputs.
region_vocab_id = tokenizer.vocab[region]
assert region_vocab_id == _token_one_id.item()

# Doing the embedding

In [20]:
# Re-tokenize the same (truncated) text, this time without
# `return_offsets_mapping`, so `inputs` only holds what the tokenizer returns
# by default.
# NOTE(review): presumably because the model's forward pass does not accept
# `offset_mapping` — confirm.
# NOTE: this rebinds `inputs`; `token_ids` / `token_offsets` pulled earlier
# still refer to the previous encoding.
inputs = tokenizer(
    text,
    return_tensors='pt',

    # NOTE: we are passing in the first chunk
    truncation=True,
    max_length=MAX_LEN,
); len(inputs)

2

In [21]:
inputs['input_ids'].shape

torch.Size([1, 8192])

In [22]:
# Now, we can embed each token

In [23]:
# Single forward pass over the whole (truncated) document.
# Inference only: disable autograd so no graph is retained for thousands of tokens.
with torch.no_grad():
    model_output = model(**inputs)
# the first element of the output holds one contextual vector per input token
token_embeddings = model_output[0]
token_embeddings.shape

torch.Size([1, 8192, 768])

> Each of these vectors represents a token's meaning in the context of this document, not just the isolated meaning of the token itself. That's the genius of this model.

# Chunking the doc-level embeddings

In [24]:
# Chunk on periods in token space. Every token embedding came out of a forward
# pass over the whole (truncated) document, so each chunk we cut out keeps
# document-level context.
punctuation_mark_id = tokenizer.convert_tokens_to_ids('.')
punctuation_mark_id

15

We can use a period followed by a space (or a newline) to mark the end of a chunk. We're basically doing sentence-level chunking here.

In [25]:
# Sentence-style chunking: walk the tokens and close a chunk at every period
# that is immediately followed by a space or a newline in the original text.
chunk_pos, token_spans = [], []  # (char_start, char_end) and (tok_start, tok_end)
span_start_char, span_start_token = 0, 0
last_token_index = len(token_ids) - 1  # the final token never closes a chunk

for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets)):
    if i >= last_token_index:
        continue

    next_char = text[end:end+1]
    if not (token_id == punctuation_mark_id and next_char in [' ', '\n']):
        continue

    # record the finished chunk in both character space and token space
    boundary_char, boundary_token = int(end), i + 1
    chunk_pos.append((span_start_char, boundary_char))
    token_spans.append((span_start_token, boundary_token))

    # the next chunk starts one character past the boundary (skipping the whitespace)
    span_start_char, span_start_token = boundary_char + 1, boundary_token

We can print out some of the chunks now

In [26]:
# Preview the first few chunks (fewer if the document produced fewer than three).
for i, (char_start, char_end) in enumerate(chunk_pos[:3]):
    # Strip outside the f-string: a backslash inside an f-string expression is
    # a SyntaxError before Python 3.12.
    chunk_text = text[char_start:char_end].strip("\n")

    print(f"Chunk: {i}")
    print(f'Char span: ({char_start}:{char_end}): {chunk_text}')

Chunk: 0
Char span: (0:960): Article

## DNA Damage in Moderate and Severe COVID-19 Cases: Relation to Demographic, Clinical, and Laboratory Parameters

Kalashyan 1D , Naira Stepanyan (D) Hovhannisyan 1,2 @ Lily

- Laboratory of General and Molecular Genetics , Research Institute of Yerevan State University, Alex Manoogian 1, Yerevan 0025, Armenia; tigranharutyunyan@ysu.am (TH); angela.sargsyan@ysu.am (A.S.); lilikalashyan@ysuam (LK.); genetik@ysu.am (RA.); galinahovhannisyan@ysu.am (G.H:) Biology
- Department of Genetics and Cytology; Yerevan State University, Alex Manoogian 1, Yerevan 0025 , Armenia
- National Center for Infectious Diseases , Arno Babajanyan 21, Yerevan 0064, Armenia; nsstepanyang@gmailcom
- Jena University Hospital, Institute of Human Genetics, Friedrich Schiller University, Am Klinikum 1, D-07747 Jena, Germany

Abstract: The of the SARS-CoV-2 virus to cause DNA damage in infected humans requires its study as a potential indicator of COVID-19 progression.
Chunk: 1
C

In [27]:
# now we can use our chunks in token space to chunk the token embeddings

In [28]:
start_token, end_token = token_spans[0]
chunk_embedding = token_embeddings[0, start_token:end_token]
chunk_embedding.shape


torch.Size([297, 768])

This is how many tokens are in the first chunk, and we have an embedding for each token. But, we need to "average" these down to get a single embedding for the chunk.

In [29]:
# Mean-pool the first chunk's token embeddings into a single vector.
# Recomputed from `token_embeddings` rather than from `chunk_embedding` itself,
# so re-running this cell gives the same result instead of pooling twice.
chunk_embedding = token_embeddings[0, start_token:end_token].mean(dim=0)
chunk_embedding.shape

torch.Size([768])

In [30]:
len(token_spans)

78

In [31]:
# Mean-pool every chunk: one vector per token span that holds at least one token.
embeddings = [
    token_embeddings[0, tok_start:tok_end].mean(dim=0)
    for tok_start, tok_end in token_spans
    if tok_end > tok_start
]

len(embeddings)

78

# Creating a function for it all

In [32]:
def late_chunking(document, model, tokenizer, max_length=None):
    """Embed a document with late chunking: encode once, then pool per sentence.

    The document is tokenized and passed through ``model`` in a single forward
    pass, so every token embedding is computed with the rest of the (truncated)
    document in view. A chunk ends after each period that is followed by a
    space or newline; whatever text remains after the last such boundary
    becomes a final chunk. Each chunk's embedding is the mean of its token
    embeddings.

    Args:
        document: The text to chunk and embed.
        model: Callable that takes the tokenizer's tensors as keyword arguments
            and returns a sequence whose first element holds the token
            embeddings, indexed as ``[batch, token, hidden]``.
        tokenizer: Tokenizer that supports ``return_offsets_mapping=True`` and
            ``convert_tokens_to_ids``.
        max_length: Maximum number of tokens to keep; longer documents are
            truncated. Defaults to the notebook-level ``MAX_LEN``.

    Returns:
        ``(chunks, embeddings)``: ``chunks`` is a list of stripped text chunks
        and ``embeddings`` is a NumPy array with one row per chunk, in the same
        order. A document with no text yields an empty list and an array with
        zero rows.
    """
    if max_length is None:
        max_length = MAX_LEN

    # Tokenize once. The offsets are only needed for boundary detection and are
    # stripped again before the forward pass.
    encoded = tokenizer(
        document,
        return_tensors='pt',
        return_offsets_mapping=True,
        truncation=True,
        max_length=max_length,
    )
    # plain Python ints are much cheaper to iterate than 0-d tensors
    token_offsets = encoded['offset_mapping'][0].tolist()
    token_ids = encoded['input_ids'][0].tolist()
    n_tokens = len(token_ids)

    # Find chunk boundaries: a period followed by a space or newline.
    punctuation_mark_id = tokenizer.convert_tokens_to_ids('.')
    chunk_positions, token_span_annotations = [], []
    span_start_char, span_start_token = 0, 0

    # the last token is never treated as a boundary, hence token_ids[:-1]
    for i, (token_id, (_, end)) in enumerate(zip(token_ids[:-1], token_offsets)):
        if token_id == punctuation_mark_id and document[end:end + 1] in (' ', '\n'):
            chunk_positions.append((span_start_char, end))
            token_span_annotations.append((span_start_token, i + 1))
            # skip the whitespace character that follows the period
            span_start_char, span_start_token = end + 1, i + 1

    # Keep the text after the last boundary (e.g. a final sentence that ends at
    # end-of-document, or a document with no boundary at all) instead of
    # silently dropping it. Trailing zero-width tokens carry no text and are
    # left out of the span.
    tail_end_token = n_tokens
    while (
        tail_end_token > span_start_token
        and token_offsets[tail_end_token - 1][0] == token_offsets[tail_end_token - 1][1]
    ):
        tail_end_token -= 1
    if tail_end_token > span_start_token:
        tail_end_char = token_offsets[tail_end_token - 1][1]
        if document[span_start_char:tail_end_char].strip():
            chunk_positions.append((span_start_char, tail_end_char))
            token_span_annotations.append((span_start_token, tail_end_token))

    # Create text chunks from character positions
    chunks = [document[start:end].strip() for start, end in chunk_positions]

    # Encode the entire document in one pass; inference only, so no autograd graph.
    model_inputs = {key: value for key, value in encoded.items() if key != 'offset_mapping'}
    with torch.no_grad():
        model_output = model(**model_inputs)
    token_embeddings = model_output[0]

    if not token_span_annotations:
        # nothing to pool: return a (0, hidden) array rather than failing in np.stack
        return chunks, token_embeddings[0, :0].detach().cpu().numpy()

    # Every recorded span holds at least one token by construction, so chunks
    # and embeddings stay aligned index-for-index.
    embeddings = np.stack([
        token_embeddings[0, start_token:end_token].mean(dim=0).detach().cpu().numpy()
        for start_token, end_token in token_span_annotations
    ])

    return chunks, embeddings

In [33]:
# get the embeddings
late_chunks, late_embeds = late_chunking(text, model, tokenizer)

In [34]:
def matched_late_retrieval(query, chunks, chunk_embeddings, top_k=3):
    """Return the ``top_k`` chunks most similar to ``query``, best match first.

    The query is embedded with the notebook-level ``tokenizer`` and ``model``
    and mean-pooled with ``mean_pooling``, i.e. the same pooling used for the
    chunk embeddings, then ranked against ``chunk_embeddings`` by cosine
    similarity.

    Args:
        query: The query text.
        chunks: Text chunks, aligned index-for-index with ``chunk_embeddings``.
        chunk_embeddings: One embedding per chunk.
        top_k: Maximum number of results to return.

    Returns:
        ``(top_chunks, top_sims)``: the selected chunks and their cosine
        similarities, both ordered from most to least similar.
    """
    # embed the query, pooling as we did the chunks (inference only, no autograd)
    query_tokens = tokenizer(query, return_tensors='pt')
    with torch.no_grad():
        query_output = model(**query_tokens)
    query_embedding = mean_pooling(query_output, query_tokens['attention_mask'])

    # Flatten the similarity matrix to 1-D explicitly. A bare .squeeze() would
    # collapse a single-chunk result to a 0-d array that cannot be indexed below.
    similarities = (
        cos_sim(query_embedding, chunk_embeddings).detach().cpu().numpy().reshape(-1)
    )

    # indices of the most similar chunks, best first
    top_idx = np.argsort(similarities)[::-1][:top_k]

    top_chunks = [chunks[i] for i in top_idx]
    top_sims = [similarities[i] for i in top_idx]

    return top_chunks, top_sims

In [35]:
# sample saving to file
embeds_arr, chunks_arr = np.array(late_embeds), np.array(late_chunks)
np.savez("scratch/late_embeds.npz", embeds=embeds_arr, chunks=chunks_arr)

In [36]:
best, sims = matched_late_retrieval(
    "Was COVID more severe in men or women?",
    late_chunks,
    late_embeds,
    top_k=5,
)

In [37]:
print(best)

['The only difference found between men and women with severe COVID-19 was that the procalcitonin (inflammation index) was higher in men than in women (Table 3).\n\nTable 3.', 'Men with severe versus moderate illness had higher BMI and CRP: Women with severe versus moderate disease had higher INR levels and longer hospital stays.', 'Thus, the sexual composition of the two groups was not significantly different.', 'In the group of severely ill patients , 55.2% were women and 44.8% were men.', 'WBC, white blood cells; NEU , neutrophils; LYM, lymphocytes; NLR, neutrophil to lymphocyte ratio; PLT, platelets; CRP; C-reactive protein; PCT, procalcitonin; INR, international nor malized ratio; APTT, activated thromboplastin time; ALT, alanine transaminase; AST, aspartate transferase; LOS, length of hospital stay partial\n\nLaboratory parameters were analyzed in COVID-19 patients in the context of age- and sex-related changes (Tables 2 and 3) Both men and women with COVID-19 were older in the s

# Look into the specific text

In [38]:
text_subset = ' '.join(late_chunks)

In [39]:
# print(text_subset)

In [40]:
late_chunks[1]

'DNA damage was studied in leukocytes of 65 COVID-19 patients stratified by sex, age, and disease severity in relation to demographic, clinical, and laboratory parameters.'

# Meta

## Why Late Chunking Works

Late chunking solves the lost-context problem in several important ways:

Bidirectional context awareness: Each token embedding is influenced by all other tokens in the document, both before and after it. This means references like "the city" can be properly linked to "Berlin" mentioned earlier.

Consistent representation: All chunks from the same document share the same contextual foundation, ensuring that related concepts are represented similarly regardless of which chunk they appear in.

Preservation of long-range dependencies: Information from the beginning of a document can influence the representation of content at the end, maintaining semantic connections across the entire text.

Resilience to boundary selection: Since each token's embedding already contains document-wide context, the specific chunking boundaries become less critical. This means simpler chunking strategies can work just as well as complex ones.

## The Importance of Long-Context Models

Late chunking requires embedding models that can handle long contexts—ideally 8K tokens or more. These models aren't just standard embedding models with longer input windows; they're specifically designed to maintain coherent representations across thousands of tokens.

The key advantages of these long-context models for late chunking include:

- Attention across the entire document: They can attend to relationships between distant parts of the text.
- Training on document-level tasks: They're often fine-tuned on tasks that require understanding document structure.
- Optimized pooling strategies: They use pooling methods that effectively compress long sequences.

Without these capabilities, late chunking wouldn't be possible or effective.