# Whisper Transcript Preprocessing Pipeline
This notebook precomputes Whisper transcripts for every audio sample in the dataset, using three different time windows around each event timestamp.
The cached transcripts are loaded during training so that expensive Whisper inference does not have to be rerun.

## 1️⃣ Setup and Imports

In [1]:
import torch
import pandas as pd
import soundfile as sf
from pathlib import Path
from tqdm.auto import tqdm
import json
import pickle
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import warnings
warnings.filterwarnings('ignore')

print("‚úÖ Imports successful")

: 

## 2️⃣ Configuration

In [1]:
# Configuration
CONFIG = {
    'csv_path': '../SoccerNet_audio_labels.csv',
    'output_dir': '../transcript_cache',
    # Any Hugging Face Whisper checkpoint id, e.g. 'openai/whisper-small',
    # 'openai/whisper-medium', 'openai/whisper-large-v3'.
    'whisper_model': 'openai/whisper-base',
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    # True -> Whisper "translate" task (English output); False -> no task is
    # passed to generate() and the model's generation defaults apply.
    'translate_to_english': True,  # Set to True to match your encoder
    'max_samples': None,  # Set to a number for testing, None for full dataset
    # NOTE: not read anywhere in this notebook -- the processing loop
    # transcribes one segment at a time regardless of this value.
    'batch_size': 8,
}

# Window configurations: (seconds_before, seconds_after, dataset_name)
WINDOW_CONFIGS = [
    (5, 5, 'window_5s_centered'),      # ±5 seconds around event
    (10, 10, 'window_10s_centered'),   # ±10 seconds around event
    (0, 10, 'window_10s_after')        # 10 seconds after event only
]

# Sanity-check the windows before any expensive work:
# - offsets must be non-negative and cover some audio;
# - Whisper's feature extractor pads/truncates input to 30 s, so longer
#   windows would be silently cut;
# - names must be unique because they become dict keys and file names.
assert all(
    before >= 0 and after >= 0 and 0 < before + after <= 30
    for before, after, _ in WINDOW_CONFIGS
), "each window needs before/after >= 0 and a total length in (0, 30] seconds"
assert len({window_name for _, _, window_name in WINDOW_CONFIGS}) == len(WINDOW_CONFIGS), \
    "window dataset names must be unique"

print(f"üìç Device: {CONFIG['device']}")
print(f"üé§ Whisper Model: {CONFIG['whisper_model']}")
print(f"üåç Translate to English: {CONFIG['translate_to_english']}")
print(f"\nüìä Window Configurations:")
for before, after, name in WINDOW_CONFIGS:
    print(f"  ‚Ä¢ {name}: {before}s before + {after}s after = {before+after}s total")

NameError: name 'torch' is not defined

## 3️⃣ Load Whisper Model

In [None]:
print("üîß Loading Whisper model...")

# The processor bundles Whisper's feature extractor (audio -> log-mel input)
# and tokenizer (token ids -> text); it must come from the same checkpoint.
processor = WhisperProcessor.from_pretrained(CONFIG['whisper_model'])

# Load weights, move them to the configured device, and switch to eval mode
# (nn.Module.to / .eval both return the module, so the calls chain).
model = (
    WhisperForConditionalGeneration
    .from_pretrained(CONFIG['whisper_model'])
    .to(CONFIG['device'])
    .eval()
)

print(f"‚úÖ Whisper model loaded on {CONFIG['device']}")
print(f"üì¶ Model parameters: {sum(param.numel() for param in model.parameters()):,}")

## 4️⃣ Helper Functions

In [None]:
def load_audio_segment(audio_path, center_time, before_sec, after_sec):
    """
    Load audio segment around an event timestamp.

    Only the requested frame range is read instead of the whole file. This
    matters because the function is called once per window for every event.

    Args:
        audio_path: Path to an audio file readable by soundfile.
        center_time: Event time in seconds.
        before_sec: Seconds of audio to include before the event.
        after_sec: Seconds of audio to include after the event.

    Returns:
        waveform, sr, actual_start, actual_end
        - waveform: 1-D array holding the first channel of the segment
          (channel 0 only, not a downmix of all channels).
        - sr: the file's sample rate.
        - actual_start / actual_end: segment bounds in seconds after clamping
          to the file. Both equal the file duration when the window lies
          entirely past the end of the audio; the waveform is then empty.
    """
    with sf.SoundFile(audio_path) as audio_file:
        sr = audio_file.samplerate
        total_frames = audio_file.frames

        # Calculate segment boundaries in seconds, then clamp frames to the file.
        start_time = max(0, center_time - before_sec)
        end_time = center_time + after_sec

        # Clamping start_frame as well keeps the segment duration non-negative
        # when the event lies past the end of the file.
        start_frame = min(int(start_time * sr), total_frames)
        end_frame = max(start_frame, min(int(end_time * sr), total_frames))

        audio_file.seek(start_frame)
        data = audio_file.read(end_frame - start_frame, always_2d=True)

    # Extract mono channel (first channel only)
    segment = data[:, 0]

    actual_start = start_frame / sr
    actual_end = end_frame / sr

    return segment, sr, actual_start, actual_end


def transcribe_audio(waveform, sr, processor, model, device, translate=True, target_sr=16000):
    """
    Generate transcript for audio waveform.

    Args:
        waveform: 1-D array of audio samples (as returned by load_audio_segment).
        sr: Sample rate of ``waveform`` in Hz.
        processor: WhisperProcessor used for feature extraction and decoding.
        model: WhisperForConditionalGeneration used for generation.
        device: Device the input features are moved to (should match the model).
        translate: If True, run Whisper's "translate" task (English output).
            Otherwise no task is passed and the model's generation defaults apply.
        target_sr: Sample rate the Whisper feature extractor expects (16 kHz).

    Returns:
        The first decoded transcript with surrounding whitespace stripped, or
        "" for an empty waveform.
    """
    # Nothing to transcribe. The feature extractor would pad this to pure
    # silence, and Whisper may emit spurious text for silent input.
    if len(waveform) == 0:
        return ""

    # The Whisper feature extractor raises ValueError when sampling_rate differs
    # from the rate it was trained on, so resample here instead of letting every
    # non-16 kHz file fail.
    if sr != target_sr:
        # Local import: only needed when the audio is not already at target_sr.
        from scipy.signal import resample_poly
        waveform = resample_poly(waveform, int(target_sr), int(sr))
        sr = target_sr

    input_features = processor(
        waveform,
        sampling_rate=sr,
        return_tensors="pt"
    ).input_features.to(device)

    # Generate transcript
    generate_opts = {"task": "translate"} if translate else {}

    with torch.no_grad():
        predicted_ids = model.generate(input_features, **generate_opts)

    transcript = processor.batch_decode(
        predicted_ids,
        skip_special_tokens=True
    )[0].strip()

    return transcript


def create_unique_key(audio_path, timestamp):
    """
    Build the cache key for one event: ``"<file stem>_<timestamp, 2 decimals>"``.

    Only the file stem is used (directories and the extension are dropped).
    Files with the same name in different directories therefore share a key,
    as do timestamps that are equal after rounding to 0.01 s. Any lookup on
    the training side must build keys exactly the same way.
    """
    stem = Path(audio_path).stem
    return "{}_{:.2f}".format(stem, timestamp)


print("‚úÖ Helper functions defined")

## 5️⃣ Load Dataset

In [None]:
print(f"üìÇ Loading dataset from: {CONFIG['csv_path']}")

df = pd.read_csv(CONFIG['csv_path'])

# The processing loop reads these columns from every row (outside its
# try/except). Fail fast with a clear message instead of crashing mid-run
# with a KeyError.
REQUIRED_COLUMNS = {'audio_path', 'time_seconds', 'label'}
missing_columns = REQUIRED_COLUMNS - set(df.columns)
if missing_columns:
    raise ValueError(f"CSV is missing required columns: {sorted(missing_columns)}")

if CONFIG['max_samples']:
    df = df.head(CONFIG['max_samples'])
    print(f"‚ö†Ô∏è  Limited to {CONFIG['max_samples']} samples for testing")

print(f"üìä Total samples: {len(df)}")
print(f"\nüìã Dataset preview:")
display(df.head())

print(f"\nüè∑Ô∏è  Label distribution:")
display(df['label'].value_counts())

## 6️⃣ Initialize Storage

In [None]:
# Create the output directory (no error if it already exists)
output_path = Path(CONFIG['output_dir'])
output_path.mkdir(parents=True, exist_ok=True)

# One container per window configuration. 'transcripts' and 'metadata' start
# empty and are filled by the processing loop; 'config' records how the
# transcripts in that container were produced.
datasets = {
    name: {
        'transcripts': {},
        'metadata': {},
        'config': {
            'before_sec': before_sec,
            'after_sec': after_sec,
            'total_duration': before_sec + after_sec,
            'whisper_model': CONFIG['whisper_model'],
            'translate_to_english': CONFIG['translate_to_english'],
        },
    }
    for before_sec, after_sec, name in WINDOW_CONFIGS
}

print(f"‚úÖ Initialized {len(datasets)} dataset containers")
print(f"üìÅ Output directory: {output_path}")

## 7️⃣ Process All Samples

In [None]:
print("\n" + "="*80)
print("üöÄ Starting transcript generation...")
print("="*80 + "\n")

failed_samples = []
# Rows whose key was already produced earlier in this run. create_unique_key
# keeps only the file stem and a 2-decimal timestamp, so distinct rows can
# collide; the later row then overwrites the earlier one in every dataset.
duplicate_keys = []
seen_keys = set()
successful = 0

# Progress bar for overall processing; the context manager closes it even if
# the loop is interrupted.
with tqdm(df.iterrows(), total=len(df), desc="Processing samples") as pbar:
    for idx, row in pbar:
        audio_path = row['audio_path']
        timestamp = row['time_seconds']
        label = row['label']

        # Create unique key
        key = create_unique_key(audio_path, timestamp)
        if key in seen_keys:
            duplicate_keys.append({'idx': idx, 'key': key, 'label': label})
        seen_keys.add(key)

        try:
            # Transcribe every window before touching `datasets`. A failure in
            # a later window then cannot leave this key in some datasets but
            # not in others.
            pending = {}
            for before_sec, after_sec, name in WINDOW_CONFIGS:
                waveform, sr, actual_start, actual_end = load_audio_segment(
                    audio_path, timestamp, before_sec, after_sec
                )

                transcript = transcribe_audio(
                    waveform, sr, processor, model,
                    CONFIG['device'], CONFIG['translate_to_english']
                )

                pending[name] = (transcript, {
                    'audio_path': audio_path,
                    'event_timestamp': timestamp,
                    'label': label,
                    'segment_start': actual_start,
                    'segment_end': actual_end,
                    'segment_duration': actual_end - actual_start,
                    'sample_rate': sr,
                    'transcript_length': len(transcript),
                    'word_count': len(transcript.split())
                })
        except Exception as e:
            # Best effort: record the failure and move on to the next row.
            failed_samples.append({
                'idx': idx,
                'audio_path': audio_path,
                'timestamp': timestamp,
                'error': str(e),
                'error_type': type(e).__name__,
            })
        else:
            # Every window succeeded: commit them together.
            for name, (transcript, meta) in pending.items():
                datasets[name]['transcripts'][key] = transcript
                datasets[name]['metadata'][key] = meta
            successful += 1

        pbar.set_postfix({'successful': successful, 'failed': len(failed_samples)})

print(f"\n‚úÖ Processing complete!")
print(f"   Successful: {successful}/{len(df)}")
print(f"   Failed: {len(failed_samples)}/{len(df)}")
if duplicate_keys:
    print(f"   Duplicate keys (later row overwrote earlier): {len(duplicate_keys)}")

## 8️⃣ Save Results

In [None]:
print("\n" + "="*80)
print("üíæ Saving transcript datasets...")
print("="*80 + "\n")

for name, data in datasets.items():
    # Save as pickle (most efficient for Python)
    pickle_path = output_path / f"transcripts_{name}.pkl"
    with open(pickle_path, 'wb') as f:
        pickle.dump(data, f)

    file_size = pickle_path.stat().st_size / (1024 * 1024)  # MB
    print(f"‚úÖ Saved: {pickle_path.name}")
    print(f"   Transcripts: {len(data['transcripts'])}")
    print(f"   File size: {file_size:.2f} MB\n")

    # Also save as JSON (human-readable backup); explicit encoding so the
    # output does not depend on the platform's locale.
    json_path = output_path / f"transcripts_{name}.json"
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump({
            'config': data['config'],
            'transcripts': data['transcripts'],
            'metadata': data['metadata']
        }, f, indent=2)
    print(f"üìÑ JSON backup: {json_path.name}\n")

# Always (re)write the failure log, even when it is empty. Otherwise a
# failed_samples.json left by an earlier run could be mistaken for this
# run's failures.
failed_path = output_path / "failed_samples.json"
with open(failed_path, 'w', encoding='utf-8') as f:
    json.dump(failed_samples, f, indent=2)
if failed_samples:
    print(f"‚ö†Ô∏è  Failed samples log: {failed_path.name}")

print("\n" + "="*80)
print("üéâ All files saved successfully!")
print("="*80)

## 9️⃣ Generate Statistics

In [None]:
print("\nüìä Dataset Statistics:\n")


def _mean(values):
    """Return the arithmetic mean of ``values``, or 0.0 when it is empty."""
    return sum(values) / len(values) if values else 0.0


for name, data in datasets.items():
    transcripts = list(data['transcripts'].values())
    metadata = list(data['metadata'].values())
    n_samples = len(transcripts)

    # Averages fall back to 0.0 for a dataset with no transcripts (e.g. every
    # sample failed) instead of raising ZeroDivisionError.
    summary_lines = [
        f"Total Samples: {n_samples}",
        f"Avg Transcript Length: {_mean([len(t) for t in transcripts]):.1f} chars",
        f"Avg Word Count: {_mean([len(t.split()) for t in transcripts]):.1f} words",
        f"Avg Segment Duration: {_mean([m['segment_duration'] for m in metadata]):.2f}s",
        f"Empty Transcripts: {sum(1 for t in transcripts if not t)}",
    ]

    # Label distribution, most frequent first (sorted() is stable, so ties
    # keep first-seen order). label_counts is empty when n_samples == 0, so the
    # percentage never divides by zero.
    label_counts = {}
    for m in metadata:
        label_counts[m['label']] = label_counts.get(m['label'], 0) + 1
    label_lines = [
        f"  {label}: {count} ({count / n_samples * 100:.1f}%)"
        for label, count in sorted(label_counts.items(), key=lambda x: x[1], reverse=True)
    ]

    # The same lines go to stdout and to the stats file, so the two cannot drift apart.
    print(f"\n{'='*60}")
    print(f"Dataset: {name}")
    print(f"{'='*60}")
    print("\n".join(summary_lines))
    print("\nLabel Distribution:")
    for line in label_lines:
        print(line)

    # Save stats to file
    stats_path = output_path / f"stats_{name}.txt"
    with open(stats_path, 'w', encoding='utf-8') as f:
        f.write(f"Transcript Dataset Statistics: {name}\n")
        f.write("="*60 + "\n")
        f.write("\n".join(summary_lines) + "\n\n")
        f.write("Label Distribution:\n")
        f.writelines(line + "\n" for line in label_lines)

print(f"\n‚úÖ Statistics saved to individual files")

## 🔟 View Sample Transcripts

In [None]:
print("\nüìù Sample Transcripts:\n")

# Preview the first five entries (insertion order) of every window dataset.
separator = "=" * 80
for name, data in datasets.items():
    print(f"\n{separator}\nDataset: {name}\n{separator}\n")

    first_keys = list(data['transcripts'])[:5]
    for sample_no, key in enumerate(first_keys, start=1):
        meta = data['metadata'][key]
        transcript = data['transcripts'][key]
        print(
            f"Sample {sample_no}:\n"
            f"  Key: {key}\n"
            f"  Label: {meta['label']}\n"
            f"  Duration: {meta['segment_duration']:.2f}s\n"
            f"  Transcript: '{transcript}'\n"
        )

## ✅ Usage in Training

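Load the precomputed transcripts in your training script as shown below. The lookup key must be built exactly as `create_unique_key` builds it: the audio file stem plus the event time formatted to two decimals. Samples that failed during preprocessing have no entry, so looking them up raises `KeyError`.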

```python
# Load the transcript cache
import pickle
with open('transcript_cache/transcripts_window_10s_centered.pkl', 'rb') as f:
    transcript_data = pickle.load(f)

# In your training loop:
def get_transcript(audio_path, timestamp):
    from pathlib import Path
    key = f"{Path(audio_path).stem}_{timestamp:.2f}"
    return transcript_data['transcripts'][key]

# Use it:
transcript = get_transcript('path/to/audio.wav', 123.45)
```