In [14]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torchaudio
import torch
from datetime import timedelta
import os
import re

# Load processor and model from Hugging Face
print("🔄 Loading Whisper large-v2 from Hugging Face...")
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")
model.eval()

# Speed optimization: Use half precision (Float16) if on CUDA 
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
use_half = device == "cuda"
if use_half:
    model = model.half()  # Use FP16 precision for faster inference
print(f"✅ Model loaded on {device}" + (" with FP16 precision" if use_half else ""))

# Load and preprocess audio
audio_path = "4-russian-japanese-war.mp3"
print(f"🎧 Loading audio: {audio_path}")
waveform, sample_rate = torchaudio.load(audio_path)

# Resample to 16000 Hz if needed
if sample_rate != 16000:
    print("🔄 Resampling audio to 16kHz...")
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
    waveform = resampler(waveform)

# Speed optimization: Process in chunks if audio is long
audio_array = waveform.squeeze().numpy()
max_length_seconds = 30  # Process 30 seconds at a time
sample_rate = 16000
audio_length_seconds = len(audio_array) / sample_rate
print(f"Audio length: {audio_length_seconds:.2f} seconds")

# Single-pass approach: Use return_timestamps right away
print("📝 Generating translation with timestamps in a single pass...")

# Function to format timestamp for SRT files
def format_timestamp(seconds):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    millisecs = int((seconds % 1) * 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millisecs:03d}"

# Function to process audio chunks
def process_audio_chunk(chunk, chunk_offset=0.0):
    input_features = processor(
        chunk,
        sampling_rate=sample_rate,
        return_tensors="pt"
    ).input_features
    
    # Convert to half precision if using FP16
    if use_half:
        input_features = input_features.half()
    
    # Move to device after setting precision
    input_features = input_features.to(device)
    
    # Generate translation with timestamps in a single pass
    with torch.no_grad():
        forced_decoder_ids = processor.get_decoder_prompt_ids(language="ru", task="translate")
        predicted_ids = model.generate(
            input_features,
            return_timestamps=True,
            forced_decoder_ids=forced_decoder_ids,
            num_beams=1  # Speed optimization: Use greedy search instead of beam search
        )
    
    # Get output with timestamps
    output = processor.batch_decode(predicted_ids, return_timestamps=True, skip_special_tokens=False)[0]
    
    # Process output to extract segments
    segments = []
    
    # Check if output is already in the right format
    if isinstance(output, dict) and "chunks" in output:
        for chunk in output["chunks"]:
            # Add chunk offset to timestamps
            start_time = chunk["timestamp"][0] + chunk_offset
            end_time = chunk["timestamp"][1] + chunk_offset
            text = chunk["text"].strip()
            segments.append({
                "timestamp": (start_time, end_time),
                "text": text
            })
    else:
        # Parse timestamps from raw output
        text = processor.tokenizer.decode(predicted_ids[0], skip_special_tokens=False)
        
        # Extract timestamp tokens and text segments
        timestamp_pattern = re.compile(r'<\|(\d+\.\d+)\|>')
        timestamp_matches = re.findall(timestamp_pattern, text)
        
        # Convert to float and pair into start/end timestamps
        timestamps = []
        for i in range(0, len(timestamp_matches), 2):
            if i+1 < len(timestamp_matches):
                try:
                    # Safely convert to float and handle empty strings
                    if timestamp_matches[i] and timestamp_matches[i+1]:
                        start_time = float(timestamp_matches[i]) + chunk_offset
                        end_time = float(timestamp_matches[i+1]) + chunk_offset
                        timestamps.append((start_time, end_time))
                except (ValueError, TypeError):
                    # Skip this pair if there's any conversion issue
                    print(f"Warning: Skipping invalid timestamp pair: {timestamp_matches[i]}, {timestamp_matches[i+1]}")
                    continue
        
        # Get text segments between timestamps
        text_segments = re.split(r'<\|\d+\.\d+\|>', text)
        if text_segments and not text_segments[0].strip():
            text_segments = text_segments[1:]
        
        # Create segments by combining timestamps with text
        for i in range(min(len(timestamps), len(text_segments))):
            segments.append({
                "timestamp": timestamps[i],
                "text": text_segments[i].strip()
            })
    
    return segments

# Process audio in chunks if it's long, otherwise process it all at once
all_segments = []
if audio_length_seconds > max_length_seconds:
    print(f"Processing long audio in chunks of {max_length_seconds} seconds...")
    chunk_size = max_length_seconds * sample_rate
    
    for i in range(0, len(audio_array), int(chunk_size)):
        chunk_start_time = i / sample_rate
        print(f"Processing chunk starting at {chunk_start_time:.2f}s...")
        
        chunk_end = min(i + int(chunk_size), len(audio_array))
        audio_chunk = audio_array[i:chunk_end]
        
        # Process this chunk
        chunk_segments = process_audio_chunk(audio_chunk, chunk_offset=chunk_start_time)
        all_segments.extend(chunk_segments)
        
        print(f"  ✅ Processed chunk with {len(chunk_segments)} segments")
else:
    # Process the entire audio file at once
    all_segments = process_audio_chunk(audio_array)

# Sort segments by start time (in case they got out of order)
all_segments.sort(key=lambda x: x["timestamp"][0])

🔄 Loading Whisper large-v2 from Hugging Face...
✅ Model loaded on cuda with FP16 precision
🎧 Loading audio: 4-russian-japanese-war.mp3
🔄 Resampling audio to 16kHz...
Audio length: 0.00 seconds
📝 Generating translation with timestamps in a single pass...


In [15]:
# Write SRT file
print("💾 Writing SRT file...")
srt_file = audio_path.rsplit(".", 1)[0] + ".translated.srt"

with open(srt_file, "w", encoding="utf-8") as f:
    for i, segment in enumerate(all_segments, 1):
        if not segment["text"]:  # Skip empty segments
            continue
        start = format_timestamp(segment["timestamp"][0])
        end = format_timestamp(segment["timestamp"][1])
        text = segment["text"]
        f.write(f"{i}\n{start} --> {end}\n{text}\n\n")

print(f"✅ SRT file written to {srt_file}")

💾 Writing SRT file...
✅ SRT file written to 4-russian-japanese-war.translated.srt
