# utils

> Text processing utilities for segmentation: word counting, word-to-character position mapping, segment statistics, and source-boundary helpers

In [None]:
#| default_exp utils

In [None]:
#| export
from typing import List, Dict, Any, Set, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from cjm_transcript_segmentation.models import TextSegment

## Word Operations

In [None]:
#| export
def count_words(
    text: str  # Text to count words in
) -> int:  # Word count
    """Return how many whitespace-separated tokens `text` contains; empty (falsy) input gives 0."""
    return len(text.split()) if text else 0

## Position Mapping

In [None]:
#| export
def word_index_to_char_position(
    text: str,  # Full text
    word_index: int  # Word index (0-based, split happens before this word)
) -> int:  # Character position for split
    """Convert a word index to the character position where a split should occur.
    
    Words are whitespace-delimited, exactly as `str.split()` (and therefore
    `count_words`) defines them. The result is the index of the first
    character of the word at `word_index`, so `text[:pos]` holds the preceding
    words and `text[pos:]` starts with that word. Indices `<= 0` map to 0;
    indices at or beyond the word count map to `len(text)`.
    """
    if word_index <= 0:
        return 0
    
    words = text.split()
    if word_index >= len(words):
        return len(text)
    
    # Locate each word at or after the end of the previous match. Words contain
    # no whitespace, so the first occurrence found from `end` is always the
    # real next word, whatever kind or amount of whitespace separates the
    # words (tabs, newlines, runs of spaces) or precedes the text.
    end = 0
    start = 0
    for word in words[:word_index + 1]:
        start = text.find(word, end)
        end = start + len(word)
    return start

## Segment Statistics

In [None]:
#| export
def calculate_segment_stats(
    segments: List["TextSegment"]  # List of segments to analyze
) -> Dict[str, Any]:  # Statistics dictionary with total_words, total_segments
    """Summarize `segments` as word and segment totals.
    
    Returns a dict with "total_words" (the sum of `count_words` over every
    segment's text) and "total_segments" (the number of segments).
    """
    word_total = 0
    for segment in segments:
        word_total += count_words(segment.text)
    return dict(total_words=word_total, total_segments=len(segments))

## Source Boundaries

In [None]:
#| export
def get_source_boundaries(
    segments: List["TextSegment"],  # Ordered list of segments
) -> Set[int]:  # Indices where source_id changes from the previous segment
    """Return every index whose segment's source_id differs from its predecessor's.
    
    Index N is included only when segments[N-1].source_id and
    segments[N].source_id are both non-None and unequal; a None on either
    side never produces a boundary.
    """
    ids = [seg.source_id for seg in segments]
    # Pair each id with the one after it; position 1 corresponds to segments[1].
    return {
        pos
        for pos, (left, right) in enumerate(zip(ids, ids[1:]), start=1)
        if left is not None and right is not None and left != right
    }

In [None]:
#| export
def get_source_count(
    segments: List["TextSegment"],  # Ordered list of segments
) -> int:  # Number of unique non-None source_ids
    """Return how many distinct, non-None source_id values occur across `segments`."""
    distinct = set()
    for seg in segments:
        if seg.source_id is not None:
            distinct.add(seg.source_id)
    return len(distinct)

In [None]:
#| export
def get_source_position(
    segments: List["TextSegment"],  # Ordered list of segments
    focused_index: int,  # Index of the focused segment
) -> Optional[int]:  # 1-based position in ordered unique sources, or None
    """Get the source position (1-based) of the focused segment.
    
    Sources are numbered in order of first appearance among segments whose
    source_id is not None. Returns None when `segments` is empty,
    `focused_index` is out of range (negative indices are not wrapped), or
    the focused segment has no source_id.
    """
    if not segments or not 0 <= focused_index < len(segments):
        return None
    focused_sid = segments[focused_index].source_id
    if focused_sid is None:
        return None
    # Number sources as they are first encountered and stop once the focused
    # source is reached. The set keeps each membership test O(1) (source_ids
    # are already assumed hashable by get_source_count).
    seen = set()
    for s in segments:
        sid = s.source_id
        if sid is None or sid in seen:
            continue
        seen.add(sid)
        if sid == focused_sid:
            return len(seen)
    # Defensive: the focused segment itself carries focused_sid, so the loop
    # should always have returned above.
    return None

## Tests

In [None]:
# Empty, single-word, and multi-word inputs
assert count_words("") == 0
assert count_words("hello") == 1
assert count_words("The art of war") == 4
# Whitespace-only input has no words; runs of spaces, tabs, and newlines act as one separator
assert count_words("   ") == 0
assert count_words("  The   art\tof\nwar  ") == 4
print("count_words tests passed")

In [None]:
text = "The art of war is vital"

assert word_index_to_char_position(text, 0) == 0
assert word_index_to_char_position(text, 100) == len(text)
# Interior indices land on the first character of the target word
assert word_index_to_char_position(text, 2) == 8
assert text[word_index_to_char_position(text, 2):] == "of war is vital"
assert text[word_index_to_char_position(text, 5):] == "vital"
print("word_index_to_char_position tests passed")

In [None]:
from cjm_transcript_segmentation.models import TextSegment

test_segments = [
    TextSegment(index=0, text="The art of war"),
    TextSegment(index=1, text="is of vital importance"),
    TextSegment(index=2, text="to the state"),
]

stats = calculate_segment_stats(test_segments)
assert stats["total_segments"] == 3
assert stats["total_words"] == 11  # 4 + 4 + 3
# An empty segment list yields zero totals
assert calculate_segment_stats([]) == {"total_words": 0, "total_segments": 0}
print("calculate_segment_stats tests passed")

In [None]:
# Test get_source_boundaries
# A single source throughout: no boundaries
segs_single = [
    TextSegment(index=i, text=t, source_id="src1")
    for i, t in enumerate(["a", "b", "c"])
]
assert get_source_boundaries(segs_single) == set()

# Boundaries sit at the first segment of each new source
segs_multi = [
    TextSegment(index=i, text=t, source_id=sid)
    for i, (t, sid) in enumerate(
        [("a", "src1"), ("b", "src1"), ("c", "src2"), ("d", "src2"), ("e", "src3")]
    )
]
assert get_source_boundaries(segs_multi) == {2, 4}

# A None source_id on either side of a change suppresses the boundary
segs_none = [
    TextSegment(index=i, text=t, source_id=sid)
    for i, (t, sid) in enumerate([("a", "src1"), ("b", None), ("c", "src2")])
]
assert get_source_boundaries(segs_none) == set()

# Empty input
assert get_source_boundaries([]) == set()
print("get_source_boundaries tests passed")

In [None]:
# Test get_source_count
# Each case pairs a segment list with its expected number of distinct non-None sources
for case_segs, expected_count in [
    (segs_single, 1),
    (segs_multi, 3),
    (segs_none, 2),
    ([], 0),
]:
    assert get_source_count(case_segs) == expected_count
print("get_source_count tests passed")

In [None]:
# Test get_source_position
assert get_source_position(segs_multi, 0) == 1  # src1
assert get_source_position(segs_multi, 1) == 1  # src1
assert get_source_position(segs_multi, 2) == 2  # src2
assert get_source_position(segs_multi, 4) == 3  # src3
assert get_source_position(segs_multi, 99) is None  # out of bounds
assert get_source_position(segs_multi, -1) is None  # negative index is not wrapped
assert get_source_position(segs_none, 1) is None  # focused segment has no source_id
assert get_source_position(segs_none, 2) == 2  # None-source segments are skipped when numbering
assert get_source_position([], 0) is None
print("get_source_position tests passed")

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()  # run nbdev's export so the #| export cells are written to the package