In [2]:
import librosa
import numpy as np
from datasets import Dataset, load_from_disk
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from difflib import SequenceMatcher
import re

cslu = load_from_disk("../data/cslu_kids.ds")

In [4]:
device = "cuda:0"

model_id = "openai/whisper-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, low_cpu_mem_usage=True,

)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,
    batch_size=1,  # batch size for inference - set based on your device
    device=device,
)

Device set to use cuda:0


In [6]:
def split_audio_by_timestamps(audio_array, sampling_rate, chunks):
    """
    Split audio array into segments based on timestamp chunks.
    
    Args:
        audio_array: numpy array of audio samples
        sampling_rate: sampling rate of audio
        chunks: list of dicts with 'text' and 'timestamp' keys
    
    Returns:
        list of (audio_segment, text, start_time, end_time) tuples
    """
    segments = []
    
    for chunk in chunks:
        text = chunk['text'].strip()
        timestamp = chunk['timestamp']
        
        if timestamp[0] is None or timestamp[1] is None:
            # Skip chunks with invalid timestamps
            continue
            
        start_time, end_time = timestamp
        
        # Convert time to sample indices
        start_sample = int(start_time * sampling_rate)
        end_sample = int(end_time * sampling_rate)
        
        # Extract audio segment
        audio_segment = audio_array[start_sample:end_sample]
        
        # Skip very short segments (< 0.1 seconds)
        if len(audio_segment) < sampling_rate * 0.1:
            continue
            
        segments.append({
            'audio': audio_segment,
            'text': text,
            'start_time': start_time,
            'end_time': end_time,
            'duration': end_time - start_time
        })
    
    return segments

def extract_words_from_transcript(transcript):
    """
    Extract words from the original CSLU transcript, removing annotations.
    """
    # Remove annotations in angle brackets
    text = re.sub(r'<[^>]*>', '', transcript)
    
    # Remove false starts (words ending with *)
    text = re.sub(r'\S*\*', '', text)
    
    # Split into words and clean up
    words = text.split()
    words = [word.strip() for word in words if word.strip()]
    
    return words

def align_segments_with_transcript(segments, original_transcript):
    """
    Align Whisper segments with the original transcript to create training pairs.
    
    Args:
        segments: list of audio segments from split_audio_by_timestamps
        original_transcript: original CSLU transcript string
    
    Returns:
        list of aligned (audio_segment, gold_text) pairs
    """
    # Extract words from original transcript
    gold_words = extract_words_from_transcript(original_transcript)
    
    # Extract words from Whisper segments
    whisper_words = []
    for segment in segments:
        words = segment['text'].split()
        whisper_words.extend(words)
    
    # Normalize for comparison (lowercase, remove punctuation)
    def normalize_word(word):
        return re.sub(r'[^\w]', '', word.lower())
    
    gold_normalized = [normalize_word(w) for w in gold_words]
    whisper_normalized = [normalize_word(w) for w in whisper_words]
    
    # Use sequence matching to align
    matcher = SequenceMatcher(None, whisper_normalized, gold_normalized)
    matches = matcher.get_matching_blocks()
    
    aligned_pairs = []
    segment_idx = 0
    word_idx_in_segment = 0
    
    for match in matches:
        whisper_start, gold_start, length = match.a, match.b, match.size
        
        if length == 0:
            continue
        
        # Find which segments these words belong to
        current_whisper_word = 0
        
        for seg_i, segment in enumerate(segments):
            seg_words = segment['text'].split()
            
            if current_whisper_word <= whisper_start < current_whisper_word + len(seg_words):
                # This segment contains part of the match
                
                # Calculate which words in this segment are part of the match
                seg_match_start = max(0, whisper_start - current_whisper_word)
                seg_match_end = min(len(seg_words), whisper_start + length - current_whisper_word)
                
                if seg_match_end > seg_match_start:
                    # Get corresponding gold words
                    gold_match_start = gold_start + (seg_match_start - (whisper_start - current_whisper_word))
                    gold_match_end = gold_match_start + (seg_match_end - seg_match_start)
                    
                    # Ensure we don't go out of bounds
                    gold_match_end = min(gold_match_end, len(gold_words))
                    
                    if gold_match_start < len(gold_words) and gold_match_end > gold_match_start:
                        gold_text = ' '.join(gold_words[gold_match_start:gold_match_end])
                        
                        aligned_pairs.append({
                            'audio': segment['audio'],
                            'whisper_text': ' '.join(seg_words[seg_match_start:seg_match_end]),
                            'gold_text': gold_text,
                            'start_time': segment['start_time'],
                            'end_time': segment['end_time'],
                            'duration': segment['duration']
                        })
            
            current_whisper_word += len(seg_words)
            
            if current_whisper_word > whisper_start + length:
                break
    
    return aligned_pairs

def create_training_dataset_from_aligned_pairs(aligned_pairs, sampling_rate=16000):
    """
    Create a HuggingFace dataset from aligned audio-text pairs.
    """
    dataset_dict = {
        'audio': [],
        'sentence': [],
        'whisper_text': [],
        'duration': []
    }
    
    for pair in aligned_pairs:
        # Create audio dict in HuggingFace format
        audio_dict = {
            'array': pair['audio'],
            'sampling_rate': sampling_rate
        }
        
        dataset_dict['audio'].append(audio_dict)
        dataset_dict['sentence'].append(pair['gold_text'])
        dataset_dict['whisper_text'].append(pair['whisper_text'])
        dataset_dict['duration'].append(pair['duration'])
    
    return Dataset.from_dict(dataset_dict)

def process_cslu_sample_with_timestamps(sample, pipe):
    """
    Process a single CSLU sample to create split training data.
    
    Args:
        sample: single sample from CSLU dataset
        pipe: Whisper pipeline with return_timestamps=True capability
    
    Returns:
        Dataset with split audio segments and aligned transcripts
    """
    # Get timestamped output from Whisper
    audio = sample['audio']
    chunks = pipe(audio, return_timestamps="word")["chunks"]
    
    # Split audio based on timestamps
    segments = split_audio_by_timestamps(
        audio['array'], 
        audio['sampling_rate'], 
        chunks
    )
    
    # Align with original transcript
    aligned_pairs = align_segments_with_transcript(
        segments, 
        sample['sentence']
    )
    
    # Create dataset
    if aligned_pairs:
        return create_training_dataset_from_aligned_pairs(
            aligned_pairs, 
            audio['sampling_rate']
        )
    else:
        return None

def create_split_dataset_from_cslu(cslu_dataset, pipe, max_samples=None):
    """
    Process entire CSLU dataset to create split training data.
    
    Args:
        cslu_dataset: CSLU dataset
        pipe: Whisper pipeline
        max_samples: maximum number of samples to process (for testing)
    
    Returns:
        Combined dataset with all split segments
    """
    all_datasets = []
    
    # Process each sample
    num_samples = len(cslu_dataset) if max_samples is None else min(max_samples, len(cslu_dataset))
    
    print(f"Processing {num_samples} samples...")
    
    for i in range(num_samples):
        if i % 10 == 0:
            print(f"Processing sample {i}/{num_samples}")
        
        try:
            sample = cslu_dataset[i]
            split_dataset = process_cslu_sample_with_timestamps(sample, pipe)
            
            if split_dataset is not None and len(split_dataset) > 0:
                all_datasets.append(split_dataset)
                
        except Exception as e:
            print(f"Error processing sample {i}: {e}")
            continue
    
    if all_datasets:
        # Concatenate all datasets
        from datasets import concatenate_datasets
        combined_dataset = concatenate_datasets(all_datasets)
        
        print(f"Created {len(combined_dataset)} training segments from {num_samples} original samples")
        print(f"Average segments per sample: {len(combined_dataset) / num_samples:.2f}")
        
        return combined_dataset
    else:
        print("No valid segments created!")
        return None

# Example usage:
def create_split_cslu_dataset():
    """
    Main function to create the split CSLU dataset.
    """
    print("Creating split CSLU dataset...")
    
    # Create split dataset (start with a small subset for testing)
    split_dataset = create_split_dataset_from_cslu(
        cslu, 
        pipe, 
        max_samples=50  # Start small for testing
    )
    
    if split_dataset is not None:
        # Filter out very short segments (< 1 second)
        split_dataset = split_dataset.filter(lambda x: x['duration'] > 1.0)
        
        # Filter out very long segments (> 10 seconds) 
        split_dataset = split_dataset.filter(lambda x: x['duration'] < 10.0)
        
        print(f"Final dataset size after filtering: {len(split_dataset)}")
        
        # Show some examples
        print("\nExample segments:")
        for i in range(min(5, len(split_dataset))):
            sample = split_dataset[i]
            print(f"Segment {i+1}:")
            print(f"  Duration: {sample['duration']:.2f}s")
            print(f"  Whisper: {sample['whisper_text']}")
            print(f"  Gold: {sample['sentence']}")
            print()
        
        return split_dataset
    
    return None

In [7]:
# Run the splitting process
split_cslu_data = create_split_cslu_dataset()

Creating split CSLU dataset...
Processing 50 samples...
Processing sample 0/50


Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


Error processing sample 0: 'array'


KeyboardInterrupt: 