<a href="https://colab.research.google.com/github/frank-morales2020/Cloud_curious/blob/master/FT_V2TXT_DEMO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q trl

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# 1. INSTALL REQUIRED DEPENDENCIES
!pip install -q mistral-common

In [None]:
!pip install --upgrade transformers -q
!pip install -U bitsandbytes>=0.46.1 -q

In [None]:
import os
os.kill(os.getpid(), 9)

Original duration: 936.28s for barackobama2004dncARXE.mp3

Original duration: 210.68s for barackobamatransitionaddress1.mp3

Original duration: 183.02s for brad_pitt_sag_2020.mp3

Original duration: 2415.33s for mandela_davos_1999.mp3

Original duration: 1801.37s for mark_carney_davos_2026.mp3

Original duration: 2585.63s for mlk_mountaintop_1968.mp3

## FINE TUNING

In [None]:
# SUPPRESS WARNINGS - Add this at the very top
import warnings
warnings.filterwarnings("ignore", message="Some matrices hidden dimension is not a multiple of 64")
warnings.filterwarnings("ignore", module="bitsandbytes")

import os
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import librosa
from datasets import Dataset
from dataclasses import dataclass
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    BitsAndBytesConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import transformers.models.voxtral_realtime.modeling_voxtral_realtime as vox_mod

from warnings import simplefilter
simplefilter(action="ignore", category=FutureWarning)

# ==================== CONFIG ====================
MODEL_ID = "mistralai/Voxtral-Mini-4B-Realtime-2602"
OUTPUT_DIR = "/content/drive/MyDrive/data/H2E_Challenge/Voxtral_FineTune"
AUDIT_PATH = "/content/drive/MyDrive/data/H2E_Challenge/H2E_Final_Performance_Audit.csv"
TARGET_SR = 16000
CHUNK_LENGTH_SEC = 30.0
MAX_TEXT_LENGTH = 448

# ==================== 1. DATASET PREPARATION ====================
df = pd.read_csv(AUDIT_PATH)
paths = [f"/content/drive/MyDrive/data/{f}" for f in df['File']]
texts = df['Transcript'].astype(str).tolist()

chunked_audios, chunked_texts = [], []
for path, text in zip(paths, texts):
    try:
        array, _ = librosa.load(path, sr=TARGET_SR, mono=True)
        max_samples = int(CHUNK_LENGTH_SEC * TARGET_SR)
        if len(array) > max_samples: array = array[:max_samples]
        chunked_audios.append({"array": array.astype(np.float32).tolist(), "sampling_rate": TARGET_SR})
        chunked_texts.append(text)
    except Exception as e: print(f"‚úó Failed {path}: {e}")

dataset = Dataset.from_dict({"audio": chunked_audios, "text": chunked_texts})

# ==================== 2. MODEL SETUP ====================
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="eager"
)

# Get model dimensions
lm_hidden_size = model.config.hidden_size
print(f"Language model hidden size: {lm_hidden_size}")

# Audio feature dimension for Voxtral
audio_feature_dim = 1280
print(f"Audio feature dimension: {audio_feature_dim}")
print(f"Setting up adapter: {audio_feature_dim} ‚Üí {lm_hidden_size}")

# Create adapter that projects audio features to match language model hidden size
class AudioFeatureAdapter(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)
    def forward(self, x):
        return self.proj(x)

# Initialize adapter with correct dimensions
model.audio_adapter = AudioFeatureAdapter(audio_feature_dim, lm_hidden_size).to(model.device).to(torch.bfloat16)

def preprocess_function(examples):
    audio_arrays = []
    for x in examples["audio"]:
        arr = np.array(x["array"], dtype=np.float32)
        audio_arrays.append(arr)

    audio_inputs = processor.feature_extractor(
        audio_arrays,
        sampling_rate=TARGET_SR, return_tensors="np", padding=True
    )

    input_features = torch.tensor(audio_inputs["input_features"], dtype=torch.bfloat16)

    text_inputs = processor.tokenizer(
        examples["text"], return_tensors="np", padding="max_length",
        truncation=True, max_length=MAX_TEXT_LENGTH
    )

    return {
        "input_features": input_features,
        "input_ids": torch.tensor(text_inputs["input_ids"], dtype=torch.long),
        "labels": torch.tensor(text_inputs["input_ids"], dtype=torch.long),
        "attention_mask": torch.tensor(text_inputs["attention_mask"], dtype=torch.long)
    }

dataset = dataset.map(preprocess_function, batched=True, batch_size=2, remove_columns=dataset.column_names)
dataset.set_format("torch")

# Prepare for k-bit training but disable gradient checkpointing
model = prepare_model_for_kbit_training(model)
model.gradient_checkpointing_disable()  # Explicitly disable gradient checkpointing

model = get_peft_model(model, LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM"
))

# ==================== 3. SIMPLIFIED PATCHED FORWARD ====================
def patched_forward(
    self, input_ids=None, input_features=None, attention_mask=None, position_ids=None,
    past_key_values=None, encoder_past_key_values=None, padding_cache=None,
    inputs_embeds=None, encoder_inputs_embeds=None, labels=None, use_cache=None,
    cache_position=None, logits_to_keep=0, num_delay_tokens=None, **kwargs
):
    # Get text embeddings
    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    batch_size = inputs_embeds.shape[0]
    device = inputs_embeds.device

    # Process audio if provided
    if input_features is not None:
        # Get audio features
        audio_outputs = self.audio_tower(
            input_features.to(device, dtype=torch.bfloat16),
            output_hidden_states=True,
            return_dict=True
        )

        # Project audio features
        audio_features = audio_outputs.last_hidden_state
        audio_projected = self.audio_adapter(audio_features)

        # Simple concatenation
        inputs_embeds = torch.cat([audio_projected, inputs_embeds], dim=1)

        # Update attention mask
        if attention_mask is not None:
            audio_mask = torch.ones(batch_size, audio_projected.shape[1],
                                   device=device, dtype=attention_mask.dtype)
            attention_mask = torch.cat([audio_mask, attention_mask], dim=1)

    # Time conditioning - FIXED: properly handle dimensions
    if num_delay_tokens is None:
        num_delay_tokens = self.config.default_num_delay_tokens

    # Create time tensor with correct shape for t_cond
    time_tensor = torch.full((batch_size,), num_delay_tokens, device=device, dtype=torch.long)
    t_cond = self.time_embedding(time_tensor)  # Shape: [batch_size, hidden_size]

    # Call language model
    lm_outputs = self.language_model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=False,  # Disable cache to avoid issues
        t_cond=t_cond,  # Pass t_cond directly
        **kwargs
    )

    logits = lm_outputs.logits
    loss = None

    if labels is not None:
        # Create labels that ignore audio tokens
        if input_features is not None:
            audio_len = audio_projected.shape[1]
            new_labels = torch.full((batch_size, audio_len + labels.shape[1]), -100,
                                   device=device, dtype=labels.dtype)
            new_labels[:, audio_len:] = labels
            labels = new_labels

        # Shift for next token prediction
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100
        )

    return vox_mod.VoxtralRealtimeCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=lm_outputs.past_key_values
    )

# Apply the patch
vox_mod.VoxtralRealtimeForConditionalGeneration.forward = patched_forward

# ==================== 4. TRAINING ====================
@dataclass
class SimpleCollator:
    def __call__(self, features):
        batch = {}
        for key in ["input_features", "input_ids", "labels", "attention_mask"]:
            if key in features[0]:
                tensors = [f[key] for f in features]
                if key == "input_features":
                    batch[key] = torch.stack(tensors).to(torch.bfloat16)
                else:
                    batch[key] = torch.stack(tensors)
        return batch

# Disable all caching and checkpointing in training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    max_steps=500,
    bf16=True,
    logging_steps=10,
    report_to="none",
    remove_unused_columns=False,
    save_strategy="no",
    dataloader_drop_last=False,
    gradient_checkpointing=False,  # Explicitly disable
    use_cache=False,  # Disable cache
    ddp_find_unused_parameters=False,
    prediction_loss_only=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=SimpleCollator(),
)


In [2]:
print("üöÄ Starting Training with Fixed Forward Pass...")
print(f"Model device: {model.device}")
print(f"Hidden size: {lm_hidden_size}")
trainer.train()


print("\n" + "="*50)
print("üíæ Saving fine-tuned model...")

# Create the adapter directory
final_adapter_path = os.path.join(OUTPUT_DIR, "final_adapter")
os.makedirs(final_adapter_path, exist_ok=True)

# Save the model and processor
model.save_pretrained(final_adapter_path)
processor.save_pretrained(final_adapter_path)

print(f"‚úÖ Model saved to: {final_adapter_path}")

# Verify the save
print("\nüìÅ Saved files:")
if os.path.exists(final_adapter_path):
    for file in os.listdir(final_adapter_path):
        file_path = os.path.join(final_adapter_path, file)
        size = os.path.getsize(file_path) / 1024  # Size in KB
        print(f"  - {file} ({size:.2f} KB)")
else:
    print("‚ùå Save failed - directory not found!")

print("\nüéâ Training complete! You can now use the model for inference.")

üöÄ Starting Training with Fixed Forward Pass...
Model device: cuda:0
Hidden size: 3072


Step,Training Loss
10,38.537854
20,24.828914
30,20.967784
40,20.354013
50,20.307916
60,20.243488
70,20.211427
80,20.114462
90,20.052942
100,20.009079



üíæ Saving fine-tuned model...
‚úÖ Model saved to: /content/drive/MyDrive/data/H2E_Challenge/Voxtral_FineTune/final_adapter

üìÅ Saved files:
  - tekken.json (14560.89 KB)
  - README.md (5.11 KB)
  - adapter_config.json (0.97 KB)
  - processor_config.json (0.38 KB)
  - adapter_model.safetensors (15840.23 KB)

üéâ Training complete! You can now use the model for inference.


## DEBUG-INFERENCE-CODE

In [3]:
!ls /content/drive/MyDrive/data/*.mp3

/content/drive/MyDrive/data/barackobama2004dncARXE.mp3
/content/drive/MyDrive/data/barackobamatransitionaddress1.mp3
/content/drive/MyDrive/data/brad_pitt_sag_2020.mp3
/content/drive/MyDrive/data/mandela_davos_1999.mp3
/content/drive/MyDrive/data/mark_carney_davos_2026.mp3
/content/drive/MyDrive/data/mlk_mountaintop_1968.mp3


In [4]:
!ls -ltha /content/drive/MyDrive/data/H2E_Challenge/Voxtral_FineTune/final_adapter/

total 30M
-rw------- 1 root root  384 Feb 23 05:25 processor_config.json
-rw------- 1 root root  15M Feb 23 05:25 tekken.json
-rw------- 1 root root  991 Feb 23 05:25 adapter_config.json
-rw------- 1 root root  16M Feb 23 05:25 adapter_model.safetensors
-rw------- 1 root root 5.2K Feb 23 05:25 README.md


In [None]:
import torch
import librosa
import numpy as np
import os
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from peft import PeftModel
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
import matplotlib.pyplot as plt
import soundfile as sf

# ==================== CONFIG ====================
MODEL_ID = "mistralai/Voxtral-Mini-4B-Realtime-2602"
ADAPTER_PATH = "/content/drive/MyDrive/data/H2E_Challenge/Voxtral_FineTune/final_adapter"
AUDIO_FILES = [
    "/content/drive/MyDrive/data/barackobama2004dncARXE.mp3",
    "/content/drive/MyDrive/data/barackobamatransitionaddress1.mp3",
    "/content/drive/MyDrive/data/brad_pitt_sag_2020.mp3",
    "/content/drive/MyDrive/data/mandela_davos_1999.mp3",
    "/content/drive/MyDrive/data/mark_carney_davos_2026.mp3",
    "/content/drive/MyDrive/data/mlk_mountaintop_1968.mp3"
]

# ==================== 1. SETUP ====================
print("üîß Setting up tokenizer and processor...")

mistral_tokenizer = MistralTokenizer.v3(is_tekken=True)
backend_tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# ==================== 2. LOAD MODEL ====================
print("üîÑ Loading base model...")
base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True
)

print("üîÑ Loading and merging LoRA adapter...")
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
model = model.merge_and_unload()
model.eval()

In [None]:
# ==================== 3. DEBUGGING FUNCTIONS ====================

def analyze_audio_debug(audio_path):
    """Detailed audio analysis"""
    print(f"\n  üìä AUDIO ANALYSIS for {os.path.basename(audio_path)}:")

    # Load audio
    speech, sr = librosa.load(audio_path, sr=16000)

    # Basic stats
    duration = len(speech) / sr
    print(f"     Duration: {duration:.2f} seconds")
    print(f"     Sample rate: {sr} Hz")
    print(f"     Samples: {len(speech)}")

    # Amplitude stats
    print(f"     Mean amplitude: {np.mean(np.abs(speech)):.6f}")
    print(f"     Max amplitude: {np.max(np.abs(speech)):.6f}")
    print(f"     RMS energy: {np.sqrt(np.mean(speech**2)):.6f}")

    # Check for silence/audio issues
    silence_threshold = 0.01
    silent_samples = np.sum(np.abs(speech) < silence_threshold)
    silence_ratio = silent_samples / len(speech)
    print(f"     Silence ratio (<{silence_threshold}): {silence_ratio:.2%}")

    # Check if audio might be problematic
    if np.max(np.abs(speech)) < 0.01:
        print("  ‚ö†Ô∏è  WARNING: Audio amplitude is very low!")
    if silence_ratio > 0.8:
        print("  ‚ö†Ô∏è  WARNING: Audio is mostly silence!")

    # Save a small segment for inspection (optional)
    debug_dir = "debug_audio"
    os.makedirs(debug_dir, exist_ok=True)
    debug_path = os.path.join(debug_dir, f"debug_{os.path.basename(audio_path)}.wav")
    sf.write(debug_path, speech, sr)
    print(f"     Saved debug copy to: {debug_path}")

    return speech, sr

def test_different_segments(audio_path, segment_duration=10):
    """Test different segments of the audio file"""
    print(f"\n  üîç TESTING DIFFERENT SEGMENTS:")

    speech, sr = librosa.load(audio_path, sr=16000)
    total_duration = len(speech) / sr

    # Test first 10s, middle 10s, and last 10s (if enough length)
    segments = []

    # First segment
    if total_duration >= segment_duration:
        segments.append(("first", speech[:int(segment_duration * sr)]))

    # Middle segment (if longer than 2x segment_duration)
    if total_duration >= segment_duration * 2:
        mid_start = int((total_duration/2 - segment_duration/2) * sr)
        segments.append(("middle", speech[mid_start:mid_start + int(segment_duration * sr)]))

    # Last segment
    if total_duration >= segment_duration:
        segments.append(("last", speech[-int(segment_duration * sr):]))

    results = {}
    for seg_name, seg_audio in segments:
        print(f"     Testing {seg_name} {len(seg_audio)/sr:.1f}s segment...")

        # Normalize
        seg_audio = seg_audio / (np.max(np.abs(seg_audio)) + 1e-8)

        # Process
        inputs = processor(audio=seg_audio, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to("cuda", dtype=torch.bfloat16)

        with torch.no_grad():
            generated_ids = model.generate(
                input_features=input_features,
                max_new_tokens=128,
                do_sample=False,
                num_beams=1,
                use_cache=True
            )

        transcription = processor.batch_decode(
            generated_ids,
            skip_special_tokens=True
        )[0].strip()

        results[seg_name] = transcription
        print(f"       ‚Üí '{transcription if transcription else '[EMPTY]'}'")

    return results

def transcribe_with_different_params(audio_segment):
    """Try different generation parameters"""
    print(f"\n  üéõÔ∏è  TESTING DIFFERENT PARAMETERS:")

    param_sets = [
        {"name": "Greedy", "do_sample": False, "num_beams": 1, "temperature": 1.0},
        {"name": "Beam 3", "do_sample": False, "num_beams": 3, "temperature": 1.0},
        {"name": "Sampling (temp=0.3)", "do_sample": True, "num_beams": 1, "temperature": 0.3},
        {"name": "Sampling (temp=0.7)", "do_sample": True, "num_beams": 1, "temperature": 0.7},
        {"name": "Beam + Sampling", "do_sample": True, "num_beams": 2, "temperature": 0.5},
    ]

    results = {}
    for params in param_sets:
        print(f"     Trying {params['name']}...")

        # Process
        inputs = processor(audio=audio_segment, sampling_rate=16000, return_tensors="pt")
        input_features = inputs.input_features.to("cuda", dtype=torch.bfloat16)

        with torch.no_grad():
            generated_ids = model.generate(
                input_features=input_features,
                max_new_tokens=128,
                use_cache=True,
                do_sample=params["do_sample"],
                num_beams=params["num_beams"],
                temperature=params["temperature"]
            )

        transcription = processor.batch_decode(
            generated_ids,
            skip_special_tokens=True
        )[0].strip()

        results[params["name"]] = transcription
        print(f"       ‚Üí '{transcription if transcription else '[EMPTY]'}'")

    return results

# ==================== 4. RUN DEBUGGING INFERENCE ====================
print(f"üöÄ Processing {len(AUDIO_FILES)} files with debugging...\n")

all_results = {}

for i, path in enumerate(AUDIO_FILES, 1):
    if not os.path.exists(path):
        print(f"‚ö†Ô∏è [{i}/{len(AUDIO_FILES)}] File not found: {path}")
        continue

    print(f"\n{'='*60}")
    print(f"üìÅ [{i}/{len(AUDIO_FILES)}] Processing: {os.path.basename(path)}")
    print('='*60)

    try:
        # Step 1: Analyze audio
        speech, sr = analyze_audio_debug(path)

        # Step 2: Try different segments
        segment_results = test_different_segments(path)

        # Step 3: If first segment failed, try different parameters on first 10s
        first_10s = speech[:int(10 * sr)]
        first_10s = first_10s / (np.max(np.abs(first_10s)) + 1e-8)

        param_results = transcribe_with_different_params(first_10s)

        # Store results
        all_results[os.path.basename(path)] = {
            'segment_tests': segment_results,
            'param_tests': param_results
        }

        # Summary
        print(f"\n  üìã SUMMARY for {os.path.basename(path)}:")
        successful = [v for v in segment_results.values() if v]
        if successful:
            print(f"     ‚úÖ Best result: '{max(successful, key=len)}'")
        else:
            print(f"     ‚ùå All transcriptions empty!")

    except Exception as e:
        print(f"‚ùå Error processing {path}: {str(e)}")
        import traceback
        traceback.print_exc()

    print('='*60)

# ==================== 5. FINAL SUMMARY ====================
print("\n" + "="*60)
print("üìä FINAL SUMMARY")
print("="*60)

for filename, results in all_results.items():
    print(f"\n{filename}:")

    # Check if any successful transcriptions
    all_transcriptions = []
    all_transcriptions.extend(results['segment_tests'].values())
    all_transcriptions.extend(results['param_tests'].values())

    successful = [t for t in all_transcriptions if t]

    if successful:
        print(f"  ‚úÖ SUCCESS - Found {len(successful)} non-empty transcriptions")
        print(f"     Best: '{max(successful, key=len)}'")
    else:
        print(f"  ‚ùå FAILED - All transcriptions empty")

print("\n‚úÖ Debugging complete! Check the output to identify issues.")

üöÄ Processing 6 files with debugging...


üìÅ [1/6] Processing: barackobama2004dncARXE.mp3

  üìä AUDIO ANALYSIS for barackobama2004dncARXE.mp3:
     Duration: 936.28 seconds
     Sample rate: 16000 Hz
     Samples: 14980494
     Mean amplitude: 0.086645
     Max amplitude: 0.668423
     RMS energy: 0.130867
     Silence ratio (<0.01): 17.98%
     Saved debug copy to: debug_audio/debug_barackobama2004dncARXE.mp3.wav

  üîç TESTING DIFFERENT SEGMENTS:
     Testing first 10.0s segment...


`num_delay_tokens` was not provided. Falling back to `config.default_num_delay_tokens=6`. Consider preparing inputs with [`~VoxtralRealtimeProcessor.__call__`] which automatically sets this parameter.


       ‚Üí 'Thank you so much. Thank you so much. Thank you. Thank you.'
     Testing middle 10.0s segment...


## INFERENCE-GOOD

In [5]:
import torch
import librosa
import numpy as np
import os
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from peft import PeftModel
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
import json
from datetime import datetime

# ==================== CONFIG ====================
MODEL_ID = "mistralai/Voxtral-Mini-4B-Realtime-2602"
ADAPTER_PATH = "/content/drive/MyDrive/data/H2E_Challenge/Voxtral_FineTune/final_adapter"
AUDIO_FILES = [
    "/content/drive/MyDrive/data/barackobama2004dncARXE.mp3",
    "/content/drive/MyDrive/data/barackobamatransitionaddress1.mp3",
    "/content/drive/MyDrive/data/brad_pitt_sag_2020.mp3",
    "/content/drive/MyDrive/data/mandela_davos_1999.mp3",
    "/content/drive/MyDrive/data/mark_carney_davos_2026.mp3",
    "/content/drive/MyDrive/data/mlk_mountaintop_1968.mp3"
]

# ==================== 1. SETUP ====================
print("üîß Setting up tokenizer and processor...")

mistral_tokenizer = MistralTokenizer.v3(is_tekken=True)
backend_tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# ==================== 2. LOAD MODEL ====================
print("üîÑ Loading base model...")
base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True
)

print("üîÑ Loading and merging LoRA adapter...")
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
model = model.merge_and_unload()
model.eval()

# ==================== 3. ENHANCED FUNCTIONS ====================

def find_best_speech_segment(audio, sr=16000, segment_duration=20, min_energy=0.01):
    """Find the segment with highest speech energy"""

    segment_samples = int(segment_duration * sr)
    hop_samples = int(2 * sr)

    best_energy = 0
    best_segment = None
    best_start = 0
    all_segments = []

    for start in range(0, max(1, len(audio) - segment_samples), hop_samples):
        end = min(start + segment_samples, len(audio))
        segment = audio[start:end]
        energy = np.sqrt(np.mean(segment**2))

        all_segments.append((energy, start, segment))

        if energy > best_energy and energy > min_energy:
            best_energy = energy
            best_segment = segment
            best_start = start

    # Sort all segments by energy
    all_segments.sort(reverse=True)

    return best_segment, best_energy, best_start, all_segments[:5]  # Return top 5

def transcribe_segment(segment, sr=16000, segment_duration=20):
    """Transcribe a specific audio segment"""

    # Normalize
    if np.max(np.abs(segment)) > 0:
        segment = segment / np.max(np.abs(segment))

    # Prepare inputs
    inputs = processor(audio=segment, sampling_rate=sr, return_tensors="pt")
    input_features = inputs.input_features.to("cuda", dtype=torch.bfloat16)

    # Try multiple generation strategies
    strategies = [
        {"name": "Greedy", "params": {"do_sample": False, "num_beams": 1}},
        {"name": "Beam 3", "params": {"do_sample": False, "num_beams": 3}},
        {"name": "Sampling (temp=0.3)", "params": {"do_sample": True, "temperature": 0.3, "num_beams": 1}},
        {"name": "Sampling (temp=0.5)", "params": {"do_sample": True, "temperature": 0.5, "num_beams": 1}},
        {"name": "Beam+Sample", "params": {"do_sample": True, "temperature": 0.4, "num_beams": 2}},
    ]

    best_transcription = ""

    for strategy in strategies:
        with torch.no_grad():
            generated_ids = model.generate(
                input_features=input_features,
                max_new_tokens=128,
                use_cache=True,
                **strategy["params"]
            )

        transcription = processor.batch_decode(
            generated_ids,
            skip_special_tokens=True
        )[0].strip()

        if transcription and len(transcription) > len(best_transcription):
            best_transcription = transcription
            print(f"     ‚úì {strategy['name']}: {transcription[:50]}...")

    return best_transcription

def transcribe_audio_enhanced(audio_path, segment_duration=15):
    """Enhanced transcription with fallback strategies"""

    print(f"  üìÇ Loading: {os.path.basename(audio_path)}")
    speech, sr = librosa.load(audio_path, sr=16000)

    # Special handling for Mandela
    if "mandela" in audio_path.lower():
        print("  üîç Using enhanced Mandela mode...")

        # Try different segment durations
        durations = [10, 15, 20, 25, 30]

        for dur in durations:
            print(f"  \n  üìè Trying {dur}s segments...")

            # Find best segments
            _, _, _, top_segments = find_best_speech_segment(
                speech, sr, segment_duration=dur, min_energy=0.01
            )

            # Try top 3 segments
            for i, (energy, start, segment) in enumerate(top_segments[:3]):
                print(f"    Segment {i+1} at {start/sr:.1f}s (energy: {energy:.4f})")

                # Try transcribing with multiple strategies
                transcription = transcribe_segment(segment, sr, dur)

                if transcription:
                    return transcription

        # If still no transcription, try the exact segment that worked before
        print("\n  üéØ Trying known working segment (1778s)...")
        working_start = 1778  # The segment that worked in debugging
        working_end = working_start + 15
        working_segment = speech[int(working_start * sr):int(working_end * sr)]
        transcription = transcribe_segment(working_segment, sr)

        if transcription:
            return transcription

        return ""

    else:
        # Regular handling for other files
        best_segment, energy, start_time, _ = find_best_speech_segment(
            speech, sr, segment_duration=segment_duration
        )

        if best_segment is not None:
            print(f"  ‚úÖ Found speech at {start_time/sr:.1f}s (energy: {energy:.4f})")

            # Normalize
            if np.max(np.abs(best_segment)) > 0:
                best_segment = best_segment / np.max(np.abs(best_segment))

            # Prepare inputs
            inputs = processor(
                audio=best_segment,
                sampling_rate=16000,
                return_tensors="pt"
            )

            input_features = inputs.input_features.to("cuda", dtype=torch.bfloat16)

            # Generate
            with torch.no_grad():
                generated_ids = model.generate(
                    input_features=input_features,
                    max_new_tokens=128,
                    use_cache=True,
                    do_sample=False,
                    num_beams=1
                )

            transcription = processor.batch_decode(
                generated_ids,
                skip_special_tokens=True
            )[0].strip()

            return transcription
        else:
            print(f"  ‚ö†Ô∏è No clear speech found")
            return ""

# ==================== 4. BATCH PROCESSING ====================
print(f"üöÄ Processing {len(AUDIO_FILES)} files...\n")

results = {}
successful = 0

for i, path in enumerate(AUDIO_FILES, 1):
    if not os.path.exists(path):
        print(f"‚ö†Ô∏è [{i}/{len(AUDIO_FILES)}] File not found: {path}")
        continue

    print(f"\nüìÅ [{i}/{len(AUDIO_FILES)}] {os.path.basename(path)}")
    print("-" * 60)

    try:
        transcription = transcribe_audio_enhanced(path)
        results[os.path.basename(path)] = transcription

        if transcription:
            successful += 1
            print(f"\n‚úÖ {transcription}")
        else:
            print(f"\n‚ùå No transcription")

    except Exception as e:
        print(f"‚ùå Error: {str(e)}")
        import traceback
        traceback.print_exc()

    print("-" * 60)

# ==================== 5. SAVE RESULTS ====================
print("\n" + "="*60)
print("üìä FINAL RESULTS")
print("="*60)

for filename, transcription in results.items():
    if transcription:
        print(f"\n‚úÖ {filename}:")
        print(f"   {transcription}")
    else:
        print(f"\n‚ùå {filename}: No transcription")

# Save to file with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
output_file = f"transcriptions_{timestamp}.json"
with open(output_file, 'w') as f:
    json.dump(results, f, indent=2)

# Also save as readable text
text_file = f"transcriptions_{timestamp}.txt"
with open(text_file, 'w') as f:
    for filename, transcription in results.items():
        f.write(f"{filename}:\n{transcription}\n\n")

print(f"\nüìù Results saved to: {output_file} and {text_file}")
print(f"\nüìä Summary: {successful}/{len(AUDIO_FILES)} successful")

# Special note for Mandela
if not results.get("mandela_davos_1999.mp3"):
    print("\nüîç MANDELA DEBUG INFO:")
    print("   The model found speech but returned empty.")
    print("   This could be due to:")
    print("   - Strong accent challenging the model")
    print("   - Background noise/interference")
    print("   - Very short phrases")
    print("\n   Try extracting a longer segment manually:")
    print("   import soundfile as sf")
    print("   speech, sr = librosac.load('mandela_davos_1999.mp3', sr=16000)")
    print("   segment = speech[1778*16000:(1778+30)*16000]  # 30s from 1778s")
    print("   sf.write('mandela_segment.wav', segment, 16000)")

print("\n‚úÖ Inference complete!")

üîß Setting up tokenizer and processor...
üîÑ Loading base model...


Loading weights:   0%|          | 0/711 [00:00<?, ?it/s]

üîÑ Loading and merging LoRA adapter...
üöÄ Processing 6 files...


üìÅ [1/6] barackobama2004dncARXE.mp3
------------------------------------------------------------
  üìÇ Loading: barackobama2004dncARXE.mp3
  ‚úÖ Found speech at 838.0s (energy: 0.1753)

‚úÖ who believes that America has a place for him too.
------------------------------------------------------------

üìÅ [2/6] barackobamatransitionaddress1.mp3
------------------------------------------------------------
  üìÇ Loading: barackobamatransitionaddress1.mp3
  ‚úÖ Found speech at 148.0s (energy: 0.2811)

‚úÖ impact of the financial crisis on other sectors of our economy. And ensure that the rescue plan that passed Congress is working to
------------------------------------------------------------

üìÅ [3/6] brad_pitt_sag_2020.mp3
------------------------------------------------------------
  üìÇ Loading: brad_pitt_sag_2020.mp3
  ‚úÖ Found speech at 124.0s (energy: 0.0901)

‚úÖ I love our communities. I love our comm

In [6]:
import torch
import librosa
import numpy as np
import os
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from peft import PeftModel
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
import json
from datetime import datetime

# ==================== CONFIG ====================
MODEL_ID = "mistralai/Voxtral-Mini-4B-Realtime-2602"
ADAPTER_PATH = "/content/drive/MyDrive/data/H2E_Challenge/Voxtral_FineTune/final_adapter"
AUDIO_FILES = [
    "/content/drive/MyDrive/data/barackobama2004dncARXE.mp3",
    "/content/drive/MyDrive/data/barackobamatransitionaddress1.mp3",
    "/content/drive/MyDrive/data/brad_pitt_sag_2020.mp3",
    "/content/drive/MyDrive/data/mandela_davos_1999.mp3",
    "/content/drive/MyDrive/data/mark_carney_davos_2026.mp3",
    "/content/drive/MyDrive/data/mlk_mountaintop_1968.mp3"
]

# ==================== 1. SETUP ====================
print("üîß Setting up tokenizer and processor...")

mistral_tokenizer = MistralTokenizer.v3(is_tekken=True)
backend_tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# ==================== 2. LOAD MODEL ====================
print("üîÑ Loading base model...")
base_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True
)

print("üîÑ Loading and merging LoRA adapter...")
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
model = model.merge_and_unload()
model.eval()

# ==================== 3. ROBUST TRANSCRIPTION FUNCTION ====================

def find_speech_segments(audio, sr=16000, segment_duration=15, min_energy=0.01, top_k=5):
    """Find top k speech segments by energy"""

    segment_samples = int(segment_duration * sr)
    hop_samples = int(2 * sr)

    segments = []

    for start in range(0, max(1, len(audio) - segment_samples), hop_samples):
        end = min(start + segment_samples, len(audio))
        segment = audio[start:end]
        energy = np.sqrt(np.mean(segment**2))

        if energy > min_energy:
            segments.append((energy, start, segment))

    # Sort by energy and return top k
    segments.sort(reverse=True)
    return segments[:top_k]

def transcribe_segment(segment, sr=16000):
    """Transcribe a single audio segment"""

    # Normalize
    if np.max(np.abs(segment)) > 0:
        segment = segment / np.max(np.abs(segment))

    # Prepare inputs
    inputs = processor(audio=segment, sampling_rate=sr, return_tensors="pt")
    input_features = inputs.input_features.to("cuda", dtype=torch.bfloat16)

    # Generate with greedy decoding (fastest)
    with torch.no_grad():
        generated_ids = model.generate(
            input_features=input_features,
            max_new_tokens=128,
            use_cache=True,
            do_sample=False,
            num_beams=1
        )

    transcription = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True
    )[0].strip()

    return transcription

def transcribe_audio_robust(audio_path, segment_duration=15):
    """Robust transcription that handles various audio types"""

    print(f"  üìÇ Loading: {os.path.basename(audio_path)}")
    speech, sr = librosa.load(audio_path, sr=16000)

    # Special handling for Mandela (known to have speech later in the file)
    if "mandela" in audio_path.lower():
        print("  üîç Using enhanced mode for Mandela speech...")

        # Try different segment durations
        for dur in [10, 15, 20]:
            segments = find_speech_segments(speech, sr, segment_duration=dur, top_k=3)

            for energy, start, segment in segments:
                print(f"    Trying {dur}s segment at {start/sr:.1f}s (energy: {energy:.4f})")
                transcription = transcribe_segment(segment, sr)

                if transcription:
                    return transcription

        # If still no transcription, try the known working segment
        print("    Trying known working segment at 1778s...")
        working_start = 1778
        working_segment = speech[int(working_start * sr):int((working_start + 15) * sr)]
        transcription = transcribe_segment(working_segment, sr)

        return transcription

    else:
        # Regular handling for other files
        segments = find_speech_segments(speech, sr, segment_duration=segment_duration, top_k=1)

        if segments:
            energy, start, best_segment = segments[0]
            print(f"  ‚úÖ Found speech at {start/sr:.1f}s (energy: {energy:.4f})")
            return transcribe_segment(best_segment, sr)
        else:
            print(f"  ‚ö†Ô∏è No clear speech found")
            return ""

# ==================== 4. BATCH PROCESSING ====================
print(f"üöÄ Processing {len(AUDIO_FILES)} files...\n")

results = {}
successful = 0

for i, path in enumerate(AUDIO_FILES, 1):
    if not os.path.exists(path):
        print(f"‚ö†Ô∏è [{i}/{len(AUDIO_FILES)}] File not found: {path}")
        continue

    print(f"\nüìÅ [{i}/{len(AUDIO_FILES)}] {os.path.basename(path)}")
    print("-" * 60)

    try:
        transcription = transcribe_audio_robust(path)
        results[os.path.basename(path)] = transcription

        if transcription:
            successful += 1
            print(f"\n‚úÖ {transcription}")
        else:
            print(f"\n‚ùå No transcription")

    except Exception as e:
        print(f"‚ùå Error: {str(e)}")
        import traceback
        traceback.print_exc()

    print("-" * 60)

# ==================== 5. SAVE RESULTS ====================
print("\n" + "="*60)
print("üìä FINAL RESULTS")
print("="*60)

for filename, transcription in results.items():
    if transcription:
        print(f"\n‚úÖ {filename}:")
        print(f"   {transcription}")
    else:
        print(f"\n‚ùå {filename}: No transcription")

# Save with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
output_file = f"transcriptions_{timestamp}.json"
text_file = f"transcriptions_{timestamp}.txt"

with open(output_file, 'w') as f:
    json.dump(results, f, indent=2)

with open(text_file, 'w') as f:
    for filename, transcription in results.items():
        f.write(f"{filename}:\n{transcription}\n\n")

print(f"\nüìù Results saved to: {output_file} and {text_file}")
print(f"\nüìä Summary: {successful}/{len(AUDIO_FILES)} successful")
print("\n‚úÖ Inference complete!")

üîß Setting up tokenizer and processor...
üîÑ Loading base model...


Loading weights:   0%|          | 0/711 [00:00<?, ?it/s]

üîÑ Loading and merging LoRA adapter...
üöÄ Processing 6 files...


üìÅ [1/6] barackobama2004dncARXE.mp3
------------------------------------------------------------
  üìÇ Loading: barackobama2004dncARXE.mp3
  ‚úÖ Found speech at 838.0s (energy: 0.1753)

‚úÖ who believes that America has a place for him too.
------------------------------------------------------------

üìÅ [2/6] barackobamatransitionaddress1.mp3
------------------------------------------------------------
  üìÇ Loading: barackobamatransitionaddress1.mp3
  ‚úÖ Found speech at 148.0s (energy: 0.2811)

‚úÖ impact of the financial crisis on other sectors of our economy. And ensure that the rescue plan that passed Congress is working to
------------------------------------------------------------

üìÅ [3/6] brad_pitt_sag_2020.mp3
------------------------------------------------------------
  üìÇ Loading: brad_pitt_sag_2020.mp3
  ‚úÖ Found speech at 124.0s (energy: 0.0901)

‚úÖ I love our communities. I love our comm