# Chunk the texts into context-window-sized chunks for embedding generation

To generate an embedding for each case's content, we need to split the content into chunks that fit the model's context window, generate an embedding for every chunk, and then aggregate the chunk embeddings into a single embedding per case. This notebook outlines the steps I took to chunk the content in preparation for embedding generation.

# Import Libraries

In [9]:
#%pip install nupunkt -q
#%pip install transformers -q

import ast

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from nupunkt import sent_tokenize
from transformers import AutoTokenizer

# Global Variables & Helper Functions

In [2]:
# Hugging Face checkpoint whose tokenizer is used to measure chunk sizes.
MODEL_NAME = "answerdotai/ModernBERT-base"

# Token budget per chunk; matches the model's maximum sequence length (8192, as
# reported by the tokenizer warning further down).
MAX_TOKENS = 8192

# Number of trailing sentences repeated at the start of the next chunk:
# 0.2% of the token budget, i.e. int(16.384) == 16 for an 8192-token window.
MAX_SENTENCE = int(MAX_TOKENS * 0.002)

TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)

In [3]:
def split_text_into_chunks(
    text,
    tokenizer=TOKENIZER,
    max_tokens=MAX_TOKENS,
    num_overlap_sentences=MAX_SENTENCE,
    split_long_sentences=False,
) -> list[str]:
    """Split text into sentence-aligned chunks of at most ``max_tokens`` tokens, with sentence overlap.

    Sentences (from ``nupunkt.sent_tokenize``) are packed greedily into the current
    chunk. When the next sentence would push the chunk over ``max_tokens``, the chunk
    is emitted and a new one is started with the last ``num_overlap_sentences``
    sentences of the previous chunk, provided they still fit alongside the new sentence.

    A sentence that alone exceeds the budget is emitted as its own chunk(s), and the
    following chunk starts without overlap. By default such a sentence is truncated to
    the budget (the rest of that sentence is dropped); pass
    ``split_long_sentences=True`` to emit it as consecutive budget-sized windows instead.

    Args:
        text: Document text to split.
        tokenizer: Hugging Face tokenizer used to count and decode tokens.
        max_tokens: Token budget per chunk. The tokens produced by
            ``tokenizer.encode("")`` (the tokenizer's special tokens) are reserved
            out of this budget.
        num_overlap_sentences: Trailing sentences repeated at the start of the next
            chunk; ``0`` or a negative value disables overlap.
        split_long_sentences: If True, split over-budget sentences into several
            chunks instead of truncating them.

    Returns:
        The chunk strings in document order; empty if ``text`` yields no sentences.

    Raises:
        ValueError: If ``max_tokens`` leaves no room for text after the reserved tokens.
    """
    # Tokenize each sentence without special tokens so lengths can simply be summed.
    sentences = sent_tokenize(text)
    encoded_sentences = [
        tokenizer.encode(sentence, add_special_tokens=False)
        for sentence in sentences
    ]

    # Every chunk is prefixed with `lead_text`; encoding it *with* special tokens
    # reserves budget for the tokens the tokenizer adds around each chunk.
    lead_text = ""
    lead_len = len(tokenizer.encode(lead_text))
    sentence_budget = max_tokens - lead_len
    if sentence_budget <= 0:
        raise ValueError(
            f"max_tokens={max_tokens} leaves no room for text after "
            f"{lead_len} reserved tokens"
        )

    chunks: list[str] = []
    current_chunks: list[str] = []
    current_token_counts = lead_len

    # NOTE(review): counts are sums of per-sentence token lengths; re-encoding the
    # space-joined decoded text can differ by a few tokens, so a chunk may land
    # slightly off the budget — verify if the limit must be exact.
    for sentence_tokens in encoded_sentences:
        sentence_len = len(sentence_tokens)

        # Case 1: the sentence alone does not fit in a chunk.
        if sentence_len > sentence_budget:
            # Emit whatever has been accumulated so far.
            if current_chunks:
                chunks.append(lead_text + " ".join(current_chunks))
            # Truncate (default) or cut the sentence into consecutive windows.
            window_starts = (
                range(0, sentence_len, sentence_budget) if split_long_sentences else [0]
            )
            for start in window_starts:
                window = sentence_tokens[start : start + sentence_budget]
                chunks.append(lead_text + tokenizer.decode(window))
            # Start fresh with no overlap: the previous sentence is not worth
            # carrying into a chunk after a forced split.
            current_chunks = []
            current_token_counts = lead_len
            continue

        # Case 2: adding this sentence would overflow the current chunk.
        if current_token_counts + sentence_len > max_tokens:
            if current_chunks:
                chunks.append(lead_text + " ".join(current_chunks))

            # Guard zero/negative explicitly: `lst[-0:]` is the whole list, not empty.
            overlap_sentences = (
                current_chunks[-num_overlap_sentences:]
                if num_overlap_sentences > 0
                else []
            )
            overlap_len = len(
                tokenizer.encode(" ".join(overlap_sentences), add_special_tokens=False)
            )
            decoded_sentence = tokenizer.decode(sentence_tokens)

            # Drop the overlap if it would not fit together with the new sentence.
            if lead_len + overlap_len + sentence_len > max_tokens:
                current_chunks = [decoded_sentence]
                current_token_counts = lead_len + sentence_len
            else:
                current_chunks = overlap_sentences + [decoded_sentence]
                current_token_counts = lead_len + overlap_len + sentence_len
            continue

        # Case 3: the sentence fits; append it to the current chunk.
        current_chunks.append(tokenizer.decode(sentence_tokens))
        current_token_counts += sentence_len

    # Emit the final partial chunk, if any.
    if current_chunks:
        chunks.append(lead_text + " ".join(current_chunks))
    return chunks

In [10]:
# Issue-category label (the value looked up in each case's `issue_category`)
# -> short stem naming that category's data file, e.g. data/train/clean/COVID.json.
# Insertion order is preserved and determines the processing order below.
abbreviations = dict([
    ("COVID-19", "COVID"),
    ("Benefits (Source)", "BENEFITS"),
    ("LGBTQ+", "LGBTQ"),
    ("Reproductive rights", "REPRO"),
    ("Policing", "POLICING"),
    ("Affected National Origin/Ethnicity(s)", "NATION_ORIG"),
    ("Voting", "VOTE"),
    ("Immigration/Border", "IMMIGRATION"),
    ("Medical/Mental Health Care", "MED"),
    ("Disability and Disability Rights", "DISABILITY"),
    ("Affected Race(s)", "RACE"),
    ("EEOC-centric", "EEOC"),
    ("Jails, Prisons, Detention Centers, and Other Institutions", "PRISON"),
    ("Affected Sex/Gender(s)", "GENDER"),
    ("Discrimination Area", "DISC_AREA"),
    ("Discrimination Basis", "DISC_BASE"),
    ("General/Misc.", "GENERAL"),
])

# Process each file to create chunks that fit the 8,192-token context window for embedding generation

In [6]:
for key, filename in abbreviations.items():
    print("processing: ", filename)
    raw_df = pd.read_json(f'data/train/clean/{filename}.json')

    # Binary target for this category: 1 if `key` occurs in the case's issue_category.
    # NOTE(review): `in` is list membership if issue_category holds lists, but a
    # substring test if it holds strings — confirm which the JSON contains.
    labels = raw_df['issue_category'].apply(lambda categories: 1 if key in categories else 0)

    # Strip extraction noise before sentence splitting; keep this identical to the
    # cleaning applied to the val/test splits.
    # NOTE(review): deleting " - " (instead of replacing it with a space) glues the
    # neighbouring words together, and deleting " ." removes periods that follow a
    # space — confirm both are intended.
    clean_content = raw_df["content"].map(
        lambda text: text.replace(" .", "").replace("=", "").replace(" - ", "")
    )
    chunk_lists = clean_content.map(split_text_into_chunks)

    # Build the output frame in one step rather than with row-by-row `.at` writes,
    # which enlarged the frame column by column and made `num_chunks` float64.
    # `chunks` keeps the original format: the repr string of the list of chunks.
    df = pd.DataFrame({
        "case_id": raw_df["case_id"],
        "clean_content": clean_content,
        "chunks": chunk_lists.map(str),
        "num_chunks": chunk_lists.map(len),
        "label": labels,
    })

    display(df["num_chunks"].describe())

    df.to_json(f'data/train/clean/chunks/{filename}.json')

processing:  COVID


Token indices sequence length is longer than the specified maximum sequence length for this model (26739 > 8192). Running this sequence through the model will result in indexing errors


count    240.000000
mean       9.070833
std       14.054853
min        1.000000
25%        2.000000
50%        4.000000
75%       10.250000
max      132.000000
Name: num_chunks, dtype: float64

processing:  BENEFITS


count    240.000000
mean       6.762500
std        9.008472
min        1.000000
25%        2.000000
50%        3.000000
75%        8.000000
max       64.000000
Name: num_chunks, dtype: float64

processing:  LGBTQ


count    240.000000
mean       5.483333
std        6.396238
min        1.000000
25%        1.750000
50%        3.000000
75%        7.000000
max       58.000000
Name: num_chunks, dtype: float64

processing:  REPRO


count    240.000000
mean       6.491667
std        9.222490
min        1.000000
25%        2.000000
50%        4.000000
75%        7.000000
max       96.000000
Name: num_chunks, dtype: float64

processing:  POLICING


count    240.0000
mean       6.6125
std        7.3641
min        1.0000
25%        1.0000
50%        4.0000
75%        8.2500
max       44.0000
Name: num_chunks, dtype: float64

processing:  NATION_ORIG


count    240.000000
mean       7.829167
std       16.062378
min        1.000000
25%        1.000000
50%        3.000000
75%        7.000000
max      159.000000
Name: num_chunks, dtype: float64

processing:  VOTE


count    240.000000
mean       6.825000
std       13.039151
min        1.000000
25%        1.000000
50%        3.000000
75%        7.000000
max      159.000000
Name: num_chunks, dtype: float64

processing:  IMMIGRATION


count    240.000000
mean       7.733333
std       12.487807
min        1.000000
25%        2.000000
50%        4.000000
75%        8.000000
max      132.000000
Name: num_chunks, dtype: float64

processing:  MED


count    240.000000
mean       5.850000
std        6.905398
min        1.000000
25%        1.000000
50%        3.000000
75%        8.000000
max       46.000000
Name: num_chunks, dtype: float64

processing:  DISABILITY


count    240.000000
mean       7.004167
std        9.325584
min        1.000000
25%        2.000000
50%        4.000000
75%        8.000000
max       68.000000
Name: num_chunks, dtype: float64

processing:  RACE


count    240.000000
mean       6.837500
std       10.260348
min        1.000000
25%        1.000000
50%        3.000000
75%        9.000000
max       92.000000
Name: num_chunks, dtype: float64

processing:  EEOC


count    240.000000
mean       4.362500
std        7.483165
min        1.000000
25%        1.000000
50%        2.000000
75%        4.000000
max       68.000000
Name: num_chunks, dtype: float64

processing:  PRISON


count    240.000000
mean       5.616667
std        7.169418
min        1.000000
25%        1.000000
50%        3.000000
75%        6.000000
max       44.000000
Name: num_chunks, dtype: float64

processing:  GENDER


count    240.000000
mean       4.650000
std        7.427924
min        1.000000
25%        1.000000
50%        2.000000
75%        6.000000
max       60.000000
Name: num_chunks, dtype: float64

processing:  DISC_AREA


count    240.000000
mean       7.045833
std       11.242942
min        1.000000
25%        1.000000
50%        3.000000
75%        7.000000
max       92.000000
Name: num_chunks, dtype: float64

processing:  DISC_BASE


count    240.000000
mean       5.900000
std        8.266381
min        1.000000
25%        1.000000
50%        3.000000
75%        7.000000
max       58.000000
Name: num_chunks, dtype: float64

processing:  GENERAL


count    240.000000
mean       5.445833
std        6.833739
min        1.000000
25%        1.000000
50%        3.000000
75%        7.000000
max       48.000000
Name: num_chunks, dtype: float64

In [7]:
filename = "val"
print("processing: ", filename)
raw_df = pd.read_json(f'data/val/clean/{filename}.json')

# Same cleaning as the training files; keep the three splits in sync.
# NOTE(review): deleting " - " glues neighbouring words together, and deleting " ."
# removes periods that follow a space — confirm both are intended.
clean_content = raw_df["content"].map(
    lambda text: text.replace(" .", "").replace("=", "").replace(" - ", "")
)
chunk_lists = clean_content.map(split_text_into_chunks)

# Build the output in one step (integer `num_chunks`, chunks stored as the list's
# repr string). Val keeps the raw issue_category rather than a binary label.
df = pd.DataFrame({
    "case_id": raw_df["case_id"],
    "clean_content": clean_content,
    "chunks": chunk_lists.map(str),
    "num_chunks": chunk_lists.map(len),
    "issue_category": raw_df["issue_category"],
})

display(df["num_chunks"].describe())

df.to_json(f'data/val/clean/chunks/{filename}.json')

processing:  val


count    800.000000
mean       5.400000
std        8.283464
min        1.000000
25%        1.000000
50%        3.000000
75%        6.000000
max      108.000000
Name: num_chunks, dtype: float64

In [8]:
filename = "test"
print("processing: ", filename)
raw_df = pd.read_json(f'data/test/clean/{filename}.json')

# Same cleaning as the training files; keep the three splits in sync.
# NOTE(review): deleting " - " glues neighbouring words together, and deleting " ."
# removes periods that follow a space — confirm both are intended.
clean_content = raw_df["content"].map(
    lambda text: text.replace(" .", "").replace("=", "").replace(" - ", "")
)
chunk_lists = clean_content.map(split_text_into_chunks)

# Build the output in one step (integer `num_chunks`, chunks stored as the list's
# repr string). Test keeps the raw issue_category rather than a binary label.
df = pd.DataFrame({
    "case_id": raw_df["case_id"],
    "clean_content": clean_content,
    "chunks": chunk_lists.map(str),
    "num_chunks": chunk_lists.map(len),
    "issue_category": raw_df["issue_category"],
})

display(df["num_chunks"].describe())

df.to_json(f'data/test/clean/chunks/{filename}.json')

processing:  test


count    800.000000
mean       5.820000
std        9.515029
min        1.000000
25%        1.000000
50%        3.000000
75%        6.000000
max       99.000000
Name: num_chunks, dtype: float64