In [14]:
# Audio
from IPython.display import Audio
import librosa

import numpy as np

# HuggingFace
from datasets import Dataset, load_from_disk
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# String manipulation
import re
import string
from difflib import SequenceMatcher

cslu = load_from_disk("../data/cslu_kids.ds")

In [3]:
cslu.select(range(3))["audio"]

[{'path': 'ksi3exx0.wav',
  'array': array([-0.04537964, -0.0453186 , -0.04519653, ..., -0.01287842,
         -0.01300049, -0.01263428]),
  'sampling_rate': 16000},
 {'path': 'ks93mxx0.wav',
  'array': array([-0.04711914, -0.04708862, -0.04693604, ..., -0.01290894,
         -0.0123291 , -0.01266479]),
  'sampling_rate': 16000},
 {'path': 'ks730xx0.wav',
  'array': array([-0.04574585, -0.04492188, -0.04437256, ..., -0.01333618,
         -0.01379395, -0.01345825]),
  'sampling_rate': 16000}]

In [4]:
device = "cuda:0"

model_id = "openai/whisper-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,
    batch_size=1,  # batch size for inference - set based on your device
    device=device,
)

Device set to use cuda:0


In [5]:
punctuation_remover = str.maketrans("", "", string.punctuation)


def normalize_transcript(text):
    # The original transcript has annotations, for example a pause is <pau>
    # Remove tags in angle brackets
    text = re.sub(r"<[^>]*>", "", text)

    # These are "false starts" in the original transcript, for example th*
    # These are ignored by ASR
    # Remove words that end with asterisks (e.g., th*)
    text = re.sub(r"\S*\*", "", text)

    # Remove all punctuation
    text = text.translate(punctuation_remover)

    # Clean up excess spaces in the original transcript or resulting from above operations
    text = re.sub(r"\s+", " ", text)

    return text.strip().lower()

In [6]:
def split_audio_by_timestamps(audio_array, sampling_rate, chunks):
    """
    Split audio array into segments based on timestamp chunks.

    Args:
        audio_array: numpy array of audio samples
        sampling_rate: sampling rate of audio
        chunks: list of dicts with 'text' and 'timestamp' keys

    Returns:
        list of (audio_segment, text, start_time, end_time) tuples
    """
    segments = []

    for chunk in chunks:
        text = chunk["text"].strip()
        timestamp = chunk["timestamp"]

        if timestamp[0] is None or timestamp[1] is None:
            # Skip chunks with invalid timestamps
            continue

        start_time, end_time = timestamp

        # Convert time to sample indices
        start_sample = int(start_time * sampling_rate)
        end_sample = int(end_time * sampling_rate)

        # Extract audio segment
        audio_segment = audio_array[start_sample:end_sample]

        # Skip very short segments (< 0.1 seconds)
        if len(audio_segment) < sampling_rate * 0.1:
            continue

        segments.append(
            {
                "audio": audio_segment,
                "text": text,
                "start_time": start_time,
                "end_time": end_time,
                "duration": end_time - start_time,
            }
        )

    return segments


def find_anchor_points(segments, original_transcript):
    """
    Find high-confidence alignment points between Whisper segments and gold transcript.

    Args:
        segments: list of audio segments from split_audio_by_timestamps
        original_transcript: original CSLU transcript string

    Returns:
        list of anchor points with audio timestamps and transcript positions
    """
    from difflib import SequenceMatcher

    # Extract words from original transcript
    gold_words = normalize_transcript(original_transcript).split()

    # Extract words from Whisper segments with their timing info
    whisper_words = []
    word_to_segment = []  # Maps word index to segment info

    for segment in segments:
        words = normalize_transcript(segment["text"]).split()
        whisper_words.extend(words)

        # Map each word to its segment
        for _ in words:
            word_to_segment.append(
                {
                    "segment": segment,
                    "word_idx_in_segment": len(word_to_segment) - len(whisper_words),
                }
            )

    # Use sequence matching to find aligned regions
    matcher = SequenceMatcher(None, whisper_words, gold_words)
    matches = matcher.get_matching_blocks()

    anchor_points = []

    for match in matches:
        whisper_start, gold_start, length = match.a, match.b, match.size

        if length == 0:
            continue

        # Only use matches of reasonable length as anchors
        if length >= 3:  # At least 3 consecutive matching words
            whisper_end = whisper_start + length
            gold_end = gold_start + length

            # Get timing from the segments
            start_segment = word_to_segment[whisper_start]["segment"]
            end_segment = word_to_segment[whisper_end - 1]["segment"]

            anchor_points.append(
                {
                    "whisper_word_start": whisper_start,
                    "whisper_word_end": whisper_end,
                    "gold_word_start": gold_start,
                    "gold_word_end": gold_end,
                    "audio_start_time": start_segment["start_time"],
                    "audio_end_time": end_segment["end_time"],
                    "confidence": length,  # Longer matches = higher confidence
                }
            )

    # Sort by audio start time
    anchor_points.sort(key=lambda x: x["audio_start_time"])

    return anchor_points


def create_training_chunks(
    audio_array,
    sampling_rate,
    original_transcript,
    anchor_points,
    max_duration=30.0,
    min_duration=5.0,
    target_duration=15.0,
):
    """
    Create training chunks using the middle of anchor points as boundaries, preferring longer chunks.

    Args:
        audio_array: numpy array of audio samples
        sampling_rate: sampling rate of audio
        original_transcript: original CSLU transcript string
        anchor_points: list of anchor points from find_anchor_points
        max_duration: maximum duration for each chunk in seconds
        min_duration: minimum preferred duration for chunks in seconds
        target_duration: target duration to aim for when possible

    Returns:
        list of training chunks with audio and transcript pairs
    """
    gold_words = normalize_transcript(original_transcript).split()
    total_audio_duration = len(audio_array) / sampling_rate

    training_chunks = []

    # Filter anchor points and only use well-spaced ones to avoid tiny chunks
    filtered_anchors = []
    if anchor_points:
        filtered_anchors.append(anchor_points[0])

        for anchor in anchor_points[1:]:
            last_anchor = filtered_anchors[-1]
            # Only add anchor if it's far enough from the previous one
            if (
                anchor["audio_start_time"] - last_anchor["audio_end_time"]
                > min_duration
            ):
                filtered_anchors.append(anchor)

    # Add boundaries at start and end
    boundaries = [{"audio_time": 0.0, "transcript_word_idx": 0}]

    # Add the MIDDLE of each filtered anchor point as a boundary
    for anchor in filtered_anchors:
        # Calculate middle of the anchor point
        mid_audio_time = (anchor["audio_start_time"] + anchor["audio_end_time"]) / 2
        mid_word_idx = (anchor["gold_word_start"] + anchor["gold_word_end"]) // 2

        boundaries.append(
            {"audio_time": mid_audio_time, "transcript_word_idx": mid_word_idx}
        )

    # Add final boundary
    boundaries.append(
        {"audio_time": total_audio_duration, "transcript_word_idx": len(gold_words)}
    )

    # Remove duplicates and sort
    unique_boundaries = []
    seen_times = set()
    for boundary in sorted(boundaries, key=lambda x: x["audio_time"]):
        if boundary["audio_time"] not in seen_times:
            unique_boundaries.append(boundary)
            seen_times.add(boundary["audio_time"])

    # Merge small adjacent segments to create longer chunks
    merged_boundaries = [unique_boundaries[0]]

    for i in range(1, len(unique_boundaries)):
        current_boundary = unique_boundaries[i]
        last_boundary = merged_boundaries[-1]

        segment_duration = current_boundary["audio_time"] - last_boundary["audio_time"]

        # If this segment would be too short, skip this boundary (merge with next segment)
        if segment_duration < min_duration and i < len(unique_boundaries) - 1:
            continue

        merged_boundaries.append(current_boundary)

    # Create chunks between merged boundaries
    for i in range(len(merged_boundaries) - 1):
        start_boundary = merged_boundaries[i]
        end_boundary = merged_boundaries[i + 1]

        start_time = start_boundary["audio_time"]
        end_time = end_boundary["audio_time"]
        duration = end_time - start_time

        # Skip very short segments that couldn't be merged
        if duration < 1.0:
            continue

        # If segment is longer than max_duration, split it intelligently
        if duration > max_duration:
            # Calculate number of sub-chunks, preferring target_duration
            num_subchunks = max(2, int(np.ceil(duration / target_duration)))
            subchunk_duration = duration / num_subchunks

            # Calculate word indices for sub-chunks (approximate)
            start_word_idx = start_boundary["transcript_word_idx"]
            end_word_idx = end_boundary["transcript_word_idx"]
            total_words = end_word_idx - start_word_idx

            for j in range(num_subchunks):
                sub_start_time = start_time + j * subchunk_duration
                sub_end_time = start_time + (j + 1) * subchunk_duration

                # Approximate word boundaries for sub-chunks
                sub_start_word = start_word_idx + int(j * total_words / num_subchunks)
                sub_end_word = start_word_idx + int(
                    (j + 1) * total_words / num_subchunks
                )

                # Extract audio and text
                start_sample = int(sub_start_time * sampling_rate)
                end_sample = int(sub_end_time * sampling_rate)
                audio_chunk = audio_array[start_sample:end_sample]

                if sub_end_word > sub_start_word:
                    text_chunk = " ".join(gold_words[sub_start_word:sub_end_word])

                    training_chunks.append(
                        {
                            "audio": audio_chunk,
                            "text": text_chunk,
                            "start_time": sub_start_time,
                            "end_time": sub_end_time,
                            "duration": sub_end_time - sub_start_time,
                            "word_start_idx": sub_start_word,
                            "word_end_idx": sub_end_word,
                        }
                    )
        else:
            # Use the segment as-is
            start_sample = int(start_time * sampling_rate)
            end_sample = int(end_time * sampling_rate)
            audio_chunk = audio_array[start_sample:end_sample]

            start_word_idx = start_boundary["transcript_word_idx"]
            end_word_idx = end_boundary["transcript_word_idx"]

            if end_word_idx > start_word_idx:
                text_chunk = " ".join(gold_words[start_word_idx:end_word_idx])

                training_chunks.append(
                    {
                        "audio": audio_chunk,
                        "text": text_chunk,
                        "start_time": start_time,
                        "end_time": end_time,
                        "duration": duration,
                        "word_start_idx": start_word_idx,
                        "word_end_idx": end_word_idx,
                    }
                )

    return training_chunks


def process_audio_for_training(
    audio_array,
    sampling_rate,
    whisper_chunks,
    original_transcript,
    max_duration=30.0,
    min_duration=5.0,
    target_duration=15.0,
    do_print=False,
):
    """
    Complete pipeline to process audio for training with preference for longer chunks.

    Args:
        audio_array: numpy array of audio samples
        sampling_rate: sampling rate of audio
        whisper_chunks: Whisper output chunks with timestamps
        original_transcript: gold standard transcript
        max_duration: maximum duration for training chunks
        min_duration: minimum preferred duration for chunks
        target_duration: target duration to aim for when splitting long segments

    Returns:
        list of training-ready audio-text pairs
    """
    # Step 1: Split audio by Whisper timestamps
    segments = split_audio_by_timestamps(audio_array, sampling_rate, whisper_chunks)

    # Step 2: Find high-confidence anchor points
    anchor_points = find_anchor_points(segments, original_transcript)

    if do_print:
        print(f"Found {len(anchor_points)} anchor points")

    # Step 3: Create training chunks using anchor points as boundaries
    training_chunks = create_training_chunks(
        audio_array,
        sampling_rate,
        original_transcript,
        anchor_points,
        max_duration,
        min_duration,
        target_duration,
    )
    if do_print:
        print(f"Created {len(training_chunks)} training chunks")
        print(
            f"Average chunk duration: {np.mean([chunk['duration'] for chunk in training_chunks]):.1f}s"
        )
        print(
            f"Chunk duration range: {np.min([chunk['duration'] for chunk in training_chunks]):.1f}s - {np.max([chunk['duration'] for chunk in training_chunks]):.1f}s"
        )

    return training_chunks

## Testing

In [7]:
audio = cslu[0]["audio"]
audio_array = audio["array"]
sampling_rate = audio["sampling_rate"]
sentence = cslu[0]["sentence"]

chunks = pipe(audio, return_timestamps=True, chunk_length_s=20)["chunks"]

segments = split_audio_by_timestamps(audio_array, sampling_rate, chunks)

Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [8]:
display(Audio(data=audio_array, rate=sampling_rate))

In [9]:
normalize_transcript(cslu[0]["sentence"])

'a b c d e f g h i j k l m n o p q r s t u v w x y and z my family she went to go pick up my little sister and shes gonna come tomorrow shes gonna come at eleven yeah okay clean my room and then when im done i get to play with my friend brittney we go over to her house and we play barbies and we uhm we ride our bikes after were done and then we eat some ice cream i have four sisters ones fifteen thirteen and ten and ones five yeah theyre nice and they let me uhm watch tv in their room and uhm and she when sometimes when i do a little bit of chores she gives me a dollar'

In [10]:
segments[0]

{'audio': array([0.11987305, 0.12039185, 0.12045288, ..., 0.1451416 , 0.14727783,
        0.14804077]),
 'text': 'A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, and Z.',
 'start_time': 0.0,
 'end_time': 14.08,
 'duration': 14.08}

In [11]:
anchor_points = find_anchor_points(segments, sentence)
create_training_chunks(audio_array, sampling_rate, sentence, anchor_points)

[{'audio': array([0.11987305, 0.12039185, 0.12045288, ..., 0.14581299, 0.14501953,
         0.14559937]),
  'text': 'a b c d e f g h i j k l m n o p q r s',
  'start_time': 0.0,
  'end_time': 14.165,
  'duration': 14.165,
  'word_start_idx': 0,
  'word_end_idx': 19},
 {'audio': array([0.14572144, 0.14562988, 0.14590454, ..., 0.16622925, 0.16622925,
         0.16964722]),
  'text': 't u v w x y and z my family she went to',
  'start_time': 14.165,
  'end_time': 26.96666666666667,
  'duration': 12.80166666666667,
  'word_start_idx': 19,
  'word_end_idx': 32},
 {'audio': array([0.17526245, 0.18057251, 0.18630981, ..., 0.14389038, 0.14401245,
         0.14370728]),
  'text': 'go pick up my little sister and shes gonna come tomorrow shes gonna',
  'start_time': 26.96666666666667,
  'end_time': 39.76833333333333,
  'duration': 12.801666666666662,
  'word_start_idx': 32,
  'word_end_idx': 45},
 {'audio': array([0.14349365, 0.14437866, 0.14483643, ..., 0.14379883, 0.14434814,
         0.144073

In [19]:
for sample in result:
    display(Audio(data=sample["audio"]["array"], rate=sample["audio"]["sampling_rate"]))

## Process with Whisper

In [7]:
# Add whisper chunks
def add_whisper_transcription(example):
    result = pipe(example["audio"], return_timestamps=True, chunk_length_s=20)
    example["whisper_chunks"] = result["chunks"]
    return example


# Apply to dataset
cslu = cslu.map(add_whisper_transcription, num_proc=1)

Map:   0%|          | 0/1101 [00:00<?, ? examples/s]

Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.
Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make s

In [13]:
# Save dataset
output_path = "../data/cslu_kids_with_chunks.ds"
cslu.save_to_disk(output_path)
print(f"\nDataset saved to: {output_path}")

Saving the dataset (0/1 shards):   0%|          | 0/1101 [00:00<?, ? examples/s]


Dataset saved to: ../data/cslu_kids_with_chunks.ds


## Construct Segmented Dataset

In [11]:
cslu2 = load_from_disk("../data/cslu_kids_with_chunks.ds")

In [15]:
cslu = cslu.add_column("whisper_chunks", cslu2["whisper_chunks"])
cslu

Dataset({
    features: ['audio', 'sentence', 'grade', 'id', 'whisper_chunks'],
    num_rows: 1101
})

In [30]:
def split_one(example, max_duration=30.0):
    """
    Process a single example from the dataset and return multiple chunks.

    Args:
        example: Single example from HF dataset with keys like 'audio', 'whisper_chunks', 'gold_transcript'
        max_duration: Maximum duration for each chunk

    Returns:
        Dictionary with lists of segmented data
    """
    # Extract data from the example
    audio_array = example["audio"][0]["array"]
    sampling_rate = example["audio"][0]["sampling_rate"]
    whisper_chunks = example["whisper_chunks"][0]  # Whisper output
    gold_transcript = example["sentence"][0]  # gold standard transcript

    # Apply segmentation pipeline
    training_chunks = process_audio_for_training(
        audio_array, sampling_rate, whisper_chunks, gold_transcript, max_duration
    )

    # Convert to the format expected by HF datasets (lists of values)
    if len(training_chunks) == 0:
        raise

    return {
        "audio": [
            {"array": chunk["audio"], "sampling_rate": sampling_rate}
            for chunk in training_chunks
        ],
        "text": [chunk["text"] for chunk in training_chunks],
        "start_time": [chunk["start_time"] for chunk in training_chunks],
        "end_time": [chunk["end_time"] for chunk in training_chunks],
        "duration": [chunk["duration"] for chunk in training_chunks],
        "word_start_idx": [chunk["word_start_idx"] for chunk in training_chunks],
        "word_end_idx": [chunk["word_end_idx"] for chunk in training_chunks],
        "grade": example["grade"] * len(training_chunks),
        "id": example["id"] * len(training_chunks),
    }


def segment_dataset(
    dataset: Dataset, max_duration: float = 30.0, num_proc: int = 4
) -> Dataset:
    """
    Apply audio segmentation to an entire HuggingFace dataset.

    Args:
        dataset: HuggingFace Dataset with audio and transcript data
        max_duration: Maximum duration for each chunk
        num_proc: Number of processes for parallel processing

    Returns:
        New dataset with segmented audio chunks
    """

    # Apply the segmentation function to each example
    segmented_dataset = dataset.map(
        lambda example: split_one(example, max_duration),
        batched=True,
        batch_size=1,  # Process one example at a time
        remove_columns=dataset.column_names,  # Remove original columns
        num_proc=num_proc,  # Use multiprocessing
        desc="Segmenting audio",
    )

    # Filter out examples that produced no chunks
    segmented_dataset = segmented_dataset.filter(
        lambda example: len(example["audio"]) > 0, num_proc=num_proc
    )

    return segmented_dataset

In [31]:
cslu_segmented = segment_dataset(cslu)
cslu_segmented

Segmenting audio (num_proc=4):   0%|          | 0/1101 [00:00<?, ? examples/s]

Filter (num_proc=4):   0%|          | 0/7639 [00:00<?, ? examples/s]

Dataset({
    features: ['audio', 'grade', 'id', 'text', 'start_time', 'end_time', 'duration', 'word_start_idx', 'word_end_idx'],
    num_rows: 7639
})

In [32]:
# Save segmented dataset
output_path = "../data/cslu_kids_segmented_2.ds"
cslu_segmented.save_to_disk(output_path)
print(f"\nDataset saved to: {output_path}")

Saving the dataset (0/8 shards):   0%|          | 0/7639 [00:00<?, ? examples/s]


Dataset saved to: ../data/cslu_kids_segmented_2.ds
