# Testing the max_chunk_tokens Feature with IMDB Reviews

This notebook tests the new `max_chunk_tokens` and `split_long_sents` parameters added to the Encoder class, plus the **aligned lists** feature for experimenting with multiple chunk configurations.

## Features to Test

1. **max_chunk_tokens**: Greedy token-based chunking that accumulates sentences until the token limit is reached
2. **split_long_sents**: Controls how sentences that exceed the token limit are handled
   - `True`: Split long sentences at token boundaries
   - `False`: Keep sentences intact even if they exceed the limit
3. **Combination with max_chunk_sents**: "At most N sentences AND at most M tokens"
4. **Aligned lists (NEW!)**: Pass equal-length lists to both parameters to run multiple configurations
   - `max_chunk_sents=[1,2,3]` and `max_chunk_tokens=[64,128,256]` creates 3 configs: (1,64), (2,128), (3,256)
   - DataFrames include `max_chunk_sents` and `max_chunk_tokens` columns for filtering
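
The aligned-lists behaviour in item 4 can be illustrated without the library. The sketch below is only an illustration: `aligned_configs` is a hypothetical helper, not part of `afterthoughts`, and it simply contrasts element-wise pairing with a Cartesian product.

```python
from itertools import product


def aligned_configs(max_chunk_sents, max_chunk_tokens):
    """Pair two equal-length lists element by element (hypothetical helper)."""
    if len(max_chunk_sents) != len(max_chunk_tokens):
        raise ValueError("lists must have the same length")
    return list(zip(max_chunk_sents, max_chunk_tokens))


sents = [1, 2, 3]
tokens = [64, 128, 256]

aligned = aligned_configs(sents, tokens)
cartesian = list(product(sents, tokens))

print(aligned)         # [(1, 64), (2, 128), (3, 256)]
print(len(cartesian))  # 9
```

Pairing keeps the number of configurations equal to the list length, whereas a Cartesian product would grow it multiplicatively.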

In [None]:
import numpy as np
import polars as pl
from datasets import load_dataset

from afterthoughts import Encoder

# Load IMDB dataset
print("Loading IMDB dataset...")
dataset = load_dataset("imdb", split="test[:100]")  # Just 100 samples for testing
docs = dataset["text"]
labels = dataset["label"]

print(f"Loaded {len(docs)} reviews")
print(f"\nFirst review preview (first 500 chars):\n{docs[0][:500]}...")

## Initialize Encoder

Using a small model for fast testing.

In [None]:
model_name = "sentence-transformers/all-MiniLM-L6-v2"
encoder = Encoder(model_name, normalize=True)

print(f"Model: {model_name}")
print(f"Max length: {encoder.tokenizer.model_max_length}")
print(f"Device: {encoder.device}")

## Test 1: Baseline - Sentence-based chunking (original behavior)

No token limit, just 2 sentences per chunk.

In [None]:
print("Test 1: Baseline sentence-based chunking (2 sentences per chunk)")
df1, embeds1 = encoder.encode(
    docs,
    max_chunk_sents=2,
    chunk_overlap_sents=0,
)

print("\nResults:")
print(f"Total chunks: {len(df1)}")
print(f"Embeddings shape: {embeds1.shape}")
print("\nDataFrame preview:")
print(df1.head(10))
print("\nChunk size distribution (num_sents):")
print(df1["num_sents"].value_counts().sort("num_sents"))

## Test 2: Token-based chunking with max_chunk_tokens only

Greedy accumulation of sentences up to 128 tokens.
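
To make "greedy accumulation" concrete, here is a minimal sketch that works on per-sentence token counts. It is an assumption about the general technique, not the `Encoder` implementation; `greedy_token_chunks` is a hypothetical name, and a sentence longer than the limit simply gets a chunk of its own here.

```python
def greedy_token_chunks(sent_token_counts, max_tokens):
    """Group consecutive sentences so each chunk stays within max_tokens.

    A sentence longer than max_tokens gets a chunk of its own in this sketch;
    splitting such sentences is a separate concern.
    """
    chunks, current, current_tokens = [], [], 0
    for idx, n_tokens in enumerate(sent_token_counts):
        # Close the current chunk if adding this sentence would exceed the budget.
        if current and current_tokens + n_tokens > max_tokens:
            chunks.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n_tokens
    if current:
        chunks.append(current)
    return chunks


counts = [40, 50, 30, 90, 20, 200]  # illustrative token counts per sentence
chunks = greedy_token_chunks(counts, max_tokens=128)
print(chunks)  # [[0, 1, 2], [3, 4], [5]]
```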

In [None]:
print("Test 2: Token-based chunking (max 128 tokens, no sentence limit)")
df2, embeds2 = encoder.encode(
    docs,
    max_chunk_sents=None,  # No sentence limit
    max_chunk_tokens=128,
    chunk_overlap_sents=0,
    split_long_sents=True,
)

print("\nResults:")
print(f"Total chunks: {len(df2)}")
print(f"Embeddings shape: {embeds2.shape}")
print("\nDataFrame preview:")
print(df2.head(10))
print("\nChunk size distribution (num_sents):")
print(df2["num_sents"].value_counts().sort("num_sents"))

## Test 3: Combined limits (max_chunk_sents AND max_chunk_tokens)

Each chunk holds at most 3 sentences and at most 100 tokens; it closes at whichever limit is hit first.
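
The interaction of the two limits can be sketched as follows. This is a hedged illustration with a hypothetical function name (`chunk_with_limits`); it records which limit closed each chunk so the "whichever comes first" rule is visible.

```python
def chunk_with_limits(sent_token_counts, max_sents=None, max_tokens=None):
    """Return (sentence_indices, reason) pairs; reason names the limit that closed the chunk."""
    chunks, current, current_tokens = [], [], 0
    for idx, n_tokens in enumerate(sent_token_counts):
        if current:
            if max_sents is not None and len(current) >= max_sents:
                chunks.append((current, "sents"))
                current, current_tokens = [], 0
            elif max_tokens is not None and current_tokens + n_tokens > max_tokens:
                chunks.append((current, "tokens"))
                current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n_tokens
    if current:
        chunks.append((current, "end"))  # last chunk closed by end of document
    return chunks


counts = [20, 25, 30, 60, 50, 10]  # illustrative token counts per sentence
result = chunk_with_limits(counts, max_sents=3, max_tokens=100)
for sent_ids, reason in result:
    print(sent_ids, reason)
# [0, 1, 2] sents
# [3] tokens
# [4, 5] end
```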

In [None]:
print("Test 3: Combined limits (max 3 sentences AND max 100 tokens)")
df3, embeds3 = encoder.encode(
    docs,
    max_chunk_sents=3,
    max_chunk_tokens=100,
    chunk_overlap_sents=0,
    split_long_sents=True,
)

print("\nResults:")
print(f"Total chunks: {len(df3)}")
print(f"Embeddings shape: {embeds3.shape}")
print("\nDataFrame preview:")
print(df3.head(10))
print("\nChunk size distribution (num_sents):")
print(df3["num_sents"].value_counts().sort("num_sents"))

## Test 4: split_long_sents=False

Keep long sentences intact even if they exceed max_chunk_tokens.
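
The difference between the two `split_long_sents` settings can be shown on a single over-long sentence. The sketch below is an assumption about what "split at token boundaries" means (fixed-size slices of the token-id list); `split_long_sentence` is a hypothetical helper and the library may split differently.

```python
def split_long_sentence(token_ids, max_tokens, split_long_sents=True):
    """Return the token-id pieces of one sentence.

    With split_long_sents=False the sentence is returned whole, even when it
    is longer than max_tokens.
    """
    if not split_long_sents or len(token_ids) <= max_tokens:
        return [token_ids]
    return [token_ids[i:i + max_tokens] for i in range(0, len(token_ids), max_tokens)]


sentence = list(range(70))  # stand-in for 70 token ids
split_lengths = [len(p) for p in split_long_sentence(sentence, 30, split_long_sents=True)]
intact_lengths = [len(p) for p in split_long_sentence(sentence, 30, split_long_sents=False)]
print(split_lengths)   # [30, 30, 10]
print(intact_lengths)  # [70]
```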

In [None]:
print("Test 4: split_long_sents=False (keep sentences intact)")
df4, embeds4 = encoder.encode(
    docs,
    max_chunk_sents=None,
    max_chunk_tokens=128,
    chunk_overlap_sents=0,
    split_long_sents=False,  # Keep sentences intact
)

print("\nResults:")
print(f"Total chunks: {len(df4)}")
print(f"Embeddings shape: {embeds4.shape}")
print("\nDataFrame preview:")
print(df4.head(10))
print("\nChunk size distribution (num_sents):")
print(df4["num_sents"].value_counts().sort("num_sents"))

## Test 5: With chunk overlap

Test overlap with token-based chunking. The overlap must be an integer number of sentences.
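
One way to picture sentence overlap under a token budget is sketched below. It is a hypothetical illustration (`overlapping_token_chunks` is not a library function): each new chunk restarts from the last `overlap_sents` sentences of the previous one, while always advancing by at least one sentence so the loop terminates.

```python
def overlapping_token_chunks(sent_token_counts, max_tokens, overlap_sents=1):
    """Greedy token-limited chunks where each chunk repeats the last
    overlap_sents sentences of the previous chunk."""
    chunks, start, n = [], 0, len(sent_token_counts)
    while start < n:
        end, total = start, 0
        # Always take at least one sentence, then keep adding while the budget allows.
        while end < n and (end == start or total + sent_token_counts[end] <= max_tokens):
            total += sent_token_counts[end]
            end += 1
        chunks.append(list(range(start, end)))
        if end == n:
            break
        # Step back by the overlap, but always advance by at least one sentence.
        start = max(end - overlap_sents, start + 1)
    return chunks


counts = [40, 50, 30, 60, 20]  # illustrative token counts per sentence
overlapped = overlapping_token_chunks(counts, max_tokens=128, overlap_sents=1)
print(overlapped)  # [[0, 1, 2], [2, 3, 4]]
```

Because the overlap is counted in whole sentences, a fractional value has no obvious meaning here, which is consistent with the integer requirement above.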

In [None]:
print("Test 5: Token-based chunking with overlap (1 sentence overlap)")
df5, embeds5 = encoder.encode(
    docs,
    max_chunk_sents=None,
    max_chunk_tokens=128,
    chunk_overlap_sents=1,  # Must be integer (number of sentences)
    split_long_sents=True,
)

print("\nResults:")
print(f"Total chunks: {len(df5)}")
print(f"Embeddings shape: {embeds5.shape}")
print("\nDataFrame preview:")
print(df5.head(10))
print("\nChunk size distribution (num_sents):")
print(df5["num_sents"].value_counts().sort("num_sents"))

## Test 6: Multiple chunk sizes (list approach)

Test multiple `max_chunk_sents` values to extract different granularities in a single call.

**NEW**: The DataFrame now includes `max_chunk_sents` and `max_chunk_tokens` columns to track which configuration produced each chunk!

In [None]:
print("Test 6: Multiple chunk sizes (1, 2, 3 sentences)")
df6, embeds6 = encoder.encode(
    docs,
    max_chunk_sents=[1, 2, 3],  # Extract 1-sent, 2-sent, and 3-sent chunks
    chunk_overlap_sents=0,
)

print("\nResults:")
print(f"Total chunks: {len(df6)}")
print(f"Embeddings shape: {embeds6.shape}")
print("\nDataFrame preview:")
print(df6.head(15))
print("\nChunk size distribution (num_sents):")
print(df6["num_sents"].value_counts().sort("num_sents"))

print("\n✅ NEW: The DataFrame now has 'max_chunk_sents' and 'max_chunk_tokens' columns!")
print("You can filter by configuration:")
print("\nChunks with max_chunk_sents=1:")
# Cast to int for comparison since the column is Object dtype (can contain None)
print(df6.filter(pl.col("max_chunk_sents").cast(pl.Int64, strict=False) == 1).head(3))
print("\nChunks with max_chunk_sents=3:")
print(df6.filter(pl.col("max_chunk_sents").cast(pl.Int64, strict=False) == 3).head(3))

## Test 6.5: Aligned Lists (NEW Feature!)

Test the new aligned lists feature: pass lists to BOTH `max_chunk_sents` and `max_chunk_tokens`.
When both are lists, they must have the same length and are processed as **aligned pairs** (NOT a Cartesian product).

Example: `[1,2,3]` paired with `[64,128,256]` creates 3 configs: (1,64), (2,128), (3,256)

In [None]:
print("Test 6.5: Aligned lists (NEW!)")
print("Testing: max_chunk_sents=[1, 2] and max_chunk_tokens=[64, 128]")
print("Expected: 2 configs (1,64) and (2,128) - NOT 4 configs!")

df6_5, embeds6_5 = encoder.encode(
    docs[:20],  # Use fewer docs for speed
    max_chunk_sents=[1, 2],
    max_chunk_tokens=[64, 128],  # Same length - creates aligned pairs!
    chunk_overlap_sents=0,
    split_long_sents=True,
)

print("\nResults:")
print(f"Total chunks: {len(df6_5)}")
print(f"Embeddings shape: {embeds6_5.shape}")

# Show unique configurations
print("\n✅ Unique configurations (max_chunk_sents, max_chunk_tokens):")
configs = df6_5.select(["max_chunk_sents", "max_chunk_tokens"]).unique()
print(configs)

# Count chunks per configuration
print("\n✅ Chunk count per configuration:")
config_counts = (
    df6_5.group_by(["max_chunk_sents", "max_chunk_tokens"])
    .agg(pl.len().alias("count"))  # pl.len() replaces the deprecated pl.count()
    .sort(["max_chunk_sents", "max_chunk_tokens"])
)
print(config_counts)

print("\n✅ Filter chunks by specific configuration:")
print("\nChunks with config (1, 64):")
# Cast for comparison since columns are Object dtype
filtered_1_64 = df6_5.filter(
    (pl.col("max_chunk_sents").cast(pl.Int64, strict=False) == 1)
    & (pl.col("max_chunk_tokens").cast(pl.Int64, strict=False) == 64)
)
print(filtered_1_64.head(3))

print("\nChunks with config (2, 128):")
filtered_2_128 = df6_5.filter(
    (pl.col("max_chunk_sents").cast(pl.Int64, strict=False) == 2)
    & (pl.col("max_chunk_tokens").cast(pl.Int64, strict=False) == 128)
)
print(filtered_2_128.head(3))

## Test 7: Inspect specific chunks

Look at actual chunk text to verify behavior.

In [None]:
print("Test 7: Inspect chunk text from token-based chunking")

# Get chunks from first document
doc0_chunks = df2.filter(pl.col("document_idx") == 0)

print(f"\nDocument 0 has {len(doc0_chunks)} chunks")
print("\nFirst 5 chunks:")
for i, row in enumerate(doc0_chunks.head(5).iter_rows(named=True)):
    print(
        f"\n--- Chunk {i} (document_idx={row['document_idx']}, chunk_idx={row['chunk_idx']}, num_sents={row['num_sents']}) ---"
    )
    print(f"{row['chunk'][:300]}...")  # First 300 chars

    # Count tokens to verify limit
    token_ids = encoder.tokenizer.encode(row["chunk"], add_special_tokens=False)
    print(f"Token count: {len(token_ids)}")

## Test 8: Verify edge cases

Test with very small token limits to trigger sentence splitting.

In [None]:
print("Test 8: Very small token limit (30 tokens) to trigger splitting")
df8, embeds8 = encoder.encode(
    docs[:10],  # Just 10 docs for speed
    max_chunk_sents=None,
    max_chunk_tokens=30,
    chunk_overlap_sents=0,
    split_long_sents=True,
)

print("\nResults:")
print(f"Total chunks: {len(df8)}")
print(f"Embeddings shape: {embeds8.shape}")
print("\nChunk size distribution (num_sents):")
print(df8["num_sents"].value_counts().sort("num_sents"))

# Show some example chunks
print("\nExample chunks (first 5):")
for i, row in enumerate(df8.head(5).iter_rows(named=True)):
    token_count = len(encoder.tokenizer.encode(row["chunk"], add_special_tokens=False))
    print(f"\nChunk {i}: {row['num_sents']} sents, {token_count} tokens")
    print(f"  {row['chunk'][:150]}...")

## Test 9: Semantic search with token-based chunks

Test query encoding and similarity search with the token-based chunks.

In [None]:
print("Test 9: Semantic search with token-based chunks")

# Use token-based chunks from Test 2
queries = [
    "great acting and cinematography",
    "terrible plot and boring story",
    "amazing special effects",
]

query_embeds = encoder.encode_queries(queries)
print(f"Query embeddings shape: {query_embeds.shape}")

# Compute similarities (cosine similarity via dot product since normalized)
similarities = query_embeds @ embeds2.T
print(f"Similarities shape: {similarities.shape}")

# Find top 3 chunks for each query
for i, query in enumerate(queries):
    print(f"\n{'='*80}")
    print(f"Query: '{query}'")
    print(f"{'='*80}")

    # Get top 3 indices
    top_k = 3
    top_indices = np.argsort(similarities[i])[::-1][:top_k]

    for rank, idx in enumerate(top_indices, 1):
        # Fetch the row as a dict; int() converts the NumPy index to a plain Python int
        row = df2.row(int(idx), named=True)
        similarity = similarities[i, idx]
        print(f"\nRank {rank} (similarity: {similarity:.4f}):")
        print(f"  Document: {row['document_idx']}")
        print(f"  Chunk: {row['chunk_idx']}")
        print(
            f"  Sentiment: {'Positive' if labels[row['document_idx']] == 1 else 'Negative'}"
        )
        print(f"  Text: {row['chunk'][:200]}...")
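
The search above relies on the fact that, for unit-length vectors, the dot product equals cosine similarity. The following dependency-free sketch checks that identity on made-up toy vectors (not real embeddings) and ranks them the same way the cell above does.

```python
import math


def normalize(vec):
    """Scale a vector to unit L2 length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    """Full cosine similarity formula for unnormalized vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


query = [1.0, 2.0, 2.0]
chunk_vecs = [[2.0, 4.0, 4.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]

q = normalize(query)
scores = [dot(q, normalize(c)) for c in chunk_vecs]

# On unit vectors the dot product matches the full cosine formula.
matches = all(
    math.isclose(s, cosine(query, c), abs_tol=1e-12) for s, c in zip(scores, chunk_vecs)
)

# Rank chunk indices from most to least similar.
ranking = sorted(range(len(chunk_vecs)), key=lambda i: scores[i], reverse=True)
print(matches)  # True
print(ranking)  # [0, 1, 2]
```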

## Summary Statistics

Compare different chunking strategies.

In [None]:
print("\n" + "=" * 80)
print("SUMMARY COMPARISON")
print("=" * 80)

results = [
    ("Test 1: Sentence-based (2 sents)", df1),
    ("Test 2: Token-based (128 tokens)", df2),
    ("Test 3: Combined (3 sents & 100 tokens)", df3),
    ("Test 4: Token-based (no split)", df4),
    ("Test 5: Token-based (with overlap)", df5),
]

for name, df in results:
    print(f"\n{name}:")
    print(f"  Total chunks: {len(df)}")
    print(f"  Chunks per document (avg): {len(df) / len(docs):.2f}")
    print(f"  Sentences per chunk (avg): {df['num_sents'].mean():.2f}")
    print(f"  Unique documents: {df['document_idx'].n_unique()}")

## Error Validation Tests

Test that proper errors are raised for invalid parameters.
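
The four cases exercised in this section can be summarized as a small validation routine. The sketch below is hypothetical: `validate_chunk_params` and its messages are illustrative and only mirror the exception types the tests expect, not the library's actual checks.

```python
def validate_chunk_params(max_chunk_sents, max_chunk_tokens, chunk_overlap_sents, max_length):
    """Hypothetical checks mirroring the four error cases exercised in this section."""
    if max_chunk_sents is None and max_chunk_tokens is None:
        raise ValueError("max_chunk_sents=None requires max_chunk_tokens")
    if isinstance(max_chunk_sents, list) and isinstance(max_chunk_tokens, list):
        if len(max_chunk_sents) != len(max_chunk_tokens):
            raise ValueError("aligned lists must have the same length")
    if max_chunk_tokens is not None:
        if isinstance(chunk_overlap_sents, float):
            raise TypeError("chunk_overlap_sents must be an int with max_chunk_tokens")
        limits = max_chunk_tokens if isinstance(max_chunk_tokens, list) else [max_chunk_tokens]
        if any(limit > max_length for limit in limits):
            raise ValueError("max_chunk_tokens exceeds max_length")


def raised(**kwargs):
    """Return the name of the exception raised, or None if the parameters pass."""
    params = dict(max_chunk_sents=2, max_chunk_tokens=None, chunk_overlap_sents=0, max_length=512)
    params.update(kwargs)
    try:
        validate_chunk_params(**params)
    except (ValueError, TypeError) as exc:
        return type(exc).__name__
    return None


print(raised(max_chunk_tokens=1000))                                  # ValueError
print(raised(max_chunk_tokens=128, chunk_overlap_sents=0.5))          # TypeError
print(raised(max_chunk_sents=None))                                   # ValueError
print(raised(max_chunk_sents=[1, 2, 3], max_chunk_tokens=[64, 128]))  # ValueError
print(raised(max_chunk_tokens=128))                                   # None
```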

In [None]:
print("Test: Error handling for max_chunk_tokens > max_length")

try:
    df_error, _ = encoder.encode(
        docs[:5],
        max_chunk_tokens=1000,  # Exceeds model max_length (512)
        max_length=512,
    )
    print("ERROR: Should have raised ValueError!")
except ValueError as e:
    print(f"✓ Correctly raised ValueError: {e}")

print("\nTest: Error handling for float chunk_overlap_sents with max_chunk_tokens")
try:
    df_error, _ = encoder.encode(
        docs[:5],
        max_chunk_tokens=128,
        chunk_overlap_sents=0.5,  # Float not allowed with token-based chunking
    )
    print("ERROR: Should have raised TypeError!")
except TypeError as e:
    print(f"✓ Correctly raised TypeError: {e}")

print("\nTest: Error handling for max_chunk_sents=None without max_chunk_tokens")
try:
    df_error, _ = encoder.encode(
        docs[:5],
        max_chunk_sents=None,  # None only valid with max_chunk_tokens
        max_chunk_tokens=None,
    )
    print("ERROR: Should have raised ValueError!")
except ValueError as e:
    print(f"✓ Correctly raised ValueError: {e}")

print("\nTest: Aligned lists with mismatched lengths (should fail)")
try:
    df_error, _ = encoder.encode(
        docs[:5],
        max_chunk_sents=[1, 2, 3],  # 3 items
        max_chunk_tokens=[64, 128],  # 2 items - mismatch!
    )
    print("ERROR: Should have raised ValueError!")
except ValueError as e:
    print(f"✓ Correctly raised ValueError: {e}")

## Done!

All tests completed. Review the outputs above to verify the `max_chunk_tokens` and `split_long_sents` features are working correctly.