In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [1]:
import os
import pandas as pd
import librosa
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from tqdm import tqdm
from itertools import islice
from pathlib import Path



  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# ─── CONFIG ─────────────────────────────────────────────────────────────────────

DATASET_ROOT = "[DATASET_ROOT_PATH]"
DATASET_NAME = "en"

# Path to dataset clips
ROOT_CLIPS = os.path.join(DATASET_ROOT, DATASET_NAME, 'clips')

# Path to your input TSV with columns:
#    client_id, path, sentence, ... etc.
INPUT_TSV = os.path.join(DATASET_ROOT, f"{DATASET_NAME}/annotation/test.tsv")
# Root directory where audio files live
AUDIO_ROOT = os.path.join(DATASET_ROOT, f"{DATASET_NAME}/clips")

# Path where to write the output TSV
SAVE_PATH = os.path.join(DATASET_ROOT, f"{DATASET_NAME}/whisper-outputs")
os.makedirs(SAVE_PATH, exist_ok=True)
OUTPUT_TSV = os.path.join(SAVE_PATH, f"asr_output_test.tsv")

# Whisper model checkpoint
WHISPER_MODEL = "openai/whisper-tiny.en"


In [3]:

# ─── MODULES ────────────────────────────────────────────────────────────────────

def load_metadata(tsv_path: str) -> pd.DataFrame:
    """
    Load the metadata TSV into a DataFrame.
    Expects a 'path' column pointing to audio filenames,
    and a 'sentence' column with the target text.
    """
    df = pd.read_csv(tsv_path, sep="\t", dtype=str)
    if "path" not in df.columns or "sentence" not in df.columns:
        raise ValueError("Input TSV must contain 'path' and 'sentence' columns.")
    return df

def init_asr_model(model_name: str):
    """
    Load and return the Whisper processor & model.
    """
    processor = WhisperProcessor.from_pretrained(model_name)
    model     = WhisperForConditionalGeneration.from_pretrained(model_name)
    return processor, model

def transcribe_audio(
    processor: WhisperProcessor,
    model: WhisperForConditionalGeneration,
    audio_path: str,
    sr: int = 16_000,
) -> str:
    """
    Load an audio file, run Whisper ASR, and return the transcription string.
    """
    # 1) Load with librosa at the expected sampling rate
    wav, _ = librosa.load(audio_path, sr=sr)

    # 2) Extract input features for Whisper
    inputs = processor(
        wav,
        sampling_rate=sr,
        return_tensors="pt",
    ).input_features.to(model.device)

    # 3) Generate predicted token IDs
    predicted_ids = model.generate(inputs)

    # 4) Decode to text (keep special tokens if you like, skip otherwise)
    transcription = processor.batch_decode(
        predicted_ids,
        skip_special_tokens=True
    )[0]

    return transcription

def run_inference(
    metadata_df: pd.DataFrame,
    processor: WhisperProcessor,
    model: WhisperForConditionalGeneration,
    audio_root: str
) -> pd.DataFrame:
    """
    For each row in metadata_df, transcribe the corresponding audio file.
    Returns a new DataFrame with columns: [audio_name, target_text, transcription].
    """
    results = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    for _, row in tqdm(metadata_df.iterrows(), total=len(metadata_df), desc="ASR Inference"):
        audio_name  = row["path"]
        target_text = row["sentence"]
        audio_path  = os.path.join(audio_root, audio_name)

        if not os.path.isfile(audio_path):
            # Skip missing files, or you could raise an error
            transcription = "<MISSING>"
        else:
            transcription = transcribe_audio(processor, model, audio_path)

        results.append({
            "audio_name":     audio_name,
            "target_text":    target_text,
            "transcription":  transcription
        })

    return pd.DataFrame(results)


def run_inference_batch(
    metadata_df: pd.DataFrame,
    processor: WhisperProcessor,
    model: WhisperForConditionalGeneration,
    audio_root: str,
    batch_size: int = 1024,
    sr: int = 16000
) -> pd.DataFrame:
    """
    Batch-process the metadata_df through Whisper.

    Args:
      metadata_df:   DataFrame with columns ['path','sentence']
      processor:     WhisperProcessor
      model:         WhisperForConditionalGeneration
      audio_root:    root directory for all audio files
      batch_size:    how many samples per batch
      sr:            sampling rate for librosa.load

    Returns:
      DataFrame with columns ['audio_name','target_text','transcription']
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    results = []
    it = iter(metadata_df.to_dict('records'))

    def batched_iterator(iterator, size):
        """Yield lists of up to `size` items from iterator."""
        while True:
            batch = list(islice(iterator, size))
            if not batch:
                break
            yield batch

    for batch in tqdm(batched_iterator(it, batch_size), desc="ASR Batches"):
        # 1) Load all wavs in this batch
        audio_names   = [row['path'] for row in batch]
        target_texts  = [row['sentence'] for row in batch]
        wavs = []
        for name in audio_names:
            full_path = os.path.join(audio_root, name)
            if os.path.isfile(full_path):
                wav, _ = librosa.load(full_path, sr=sr)
            else:
                wav = None
            wavs.append(wav)

        # 2) Prepare inputs — only for existing wavs
        valid_indices = [i for i,w in enumerate(wavs) if w is not None]
        valid_wavs    = [wavs[i] for i in valid_indices]

        if valid_wavs:
            inputs = processor(
                valid_wavs,
                sampling_rate=sr,
                return_tensors="pt",
                padding=True
            ).input_features.to(device)

            # 3) Generate in one go
            with torch.no_grad():
                predicted_ids = model.generate(inputs)

            # 4) Decode all
            decoded = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        # 5) Collect results, preserving order
        di = 0
        for idx, name, tgt in zip(range(len(batch)), audio_names, target_texts):
            if wavs[idx] is None:
                transcription = "<MISSING>"
            else:
                transcription = decoded[di]
                di += 1

            results.append({
                "audio_name":    name,
                "target_text":   tgt,
                "transcription": transcription
            })

    return pd.DataFrame(results)

CHECKPOINT_PATH = "partial_results.tsv"
CHECKPOINT_PERCENT = 5  # checkpoint every 5%

def run_inference_with_checkpoint(
    metadata_df: pd.DataFrame,
    processor,
    model,
    audio_root: str,
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_percent: int = CHECKPOINT_PERCENT
) -> pd.DataFrame:
    """
    Exactly like run_inference(), but every `checkpoint_percent`% of the way through
    it writes out a partial TSV so if you get killed you can resume.
    """

    total = len(metadata_df)
    interval = max(1, int(total * checkpoint_percent / 100))

    # 1) See if we have an existing checkpoint to resume from
    if os.path.exists(checkpoint_path):
        done_df = pd.read_csv(checkpoint_path, sep="\t")
        done_set = set(done_df["audio_name"])
        results = done_df.to_dict("records")
        start_idx = done_df.shape[0]
        print(f"🔄 Resuming from checkpoint: already have {start_idx}/{total}")
    else:
        results = []
        done_set = set()
        start_idx = 0

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()

    for i, (_, row) in enumerate(
        tqdm(
            list(metadata_df.iloc[start_idx:].iterrows()),
            total=total - start_idx,
            desc="ASR Inference (with ckpt)"
        ), start=start_idx
    ):
        audio_name  = row["path"]
        target_text = row["sentence"]

        if audio_name in done_set:
            continue

        audio_path = os.path.join(audio_root, audio_name)
        if not os.path.isfile(audio_path):
            transcription = "<MISSING>"
        else:
            transcription = transcribe_audio(processor, model, audio_path)

        results.append({
            "audio_name":     audio_name,
            "target_text":    target_text,
            "transcription":  transcription
        })
        done_set.add(audio_name)

        # every `interval` items, write a checkpoint
        if (i + 1) % interval == 0 or (i + 1) == total:
            pd.DataFrame(results).to_csv(checkpoint_path, sep="\t", index=False)
            print(f"💾 Checkpoint at {i+1}/{total} rows")

    return pd.DataFrame(results)


In [4]:
import os
import pandas as pd
import torch
from itertools import islice
from tqdm import tqdm

In [5]:
CHECKPOINT_PATH = "partial_batch_results.tsv"
CHECKPOINT_PERCENT = 5  # checkpoint every 5%

def run_inference_batch_with_checkpoint(
    metadata_df: pd.DataFrame,
    processor,
    model,
    audio_root: str,
    batch_size: int = 64,
    sr: int = 16000,
    checkpoint_path: str = CHECKPOINT_PATH,
    checkpoint_percent: int = CHECKPOINT_PERCENT
) -> pd.DataFrame:
    """
    Batch-process with Whisper, checkpointing every `checkpoint_percent`% of the total rows.
    On startup, resumes from any existing `checkpoint_path`.
    """
    # Prepare records and checkpoint
    records = metadata_df.to_dict("records")
    total = len(records)
    interval = max(1, int(total * checkpoint_percent / 100))

    if os.path.exists(checkpoint_path):
        done_df = pd.read_csv(checkpoint_path, sep="\t")
        done_set = set(done_df["audio_name"])
        results = done_df.to_dict("records")
        start_idx = done_df.shape[0]
        print(f"🔄 Resuming from checkpoint: {start_idx}/{total} rows done")
    else:
        done_set = set()
        results = []
        start_idx = 0

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()

    # Process in batches, but only from start_idx onward
    for batch_start in range(start_idx, total, batch_size):
        batch = records[batch_start : batch_start + batch_size]

        # 1) Load wavs
        wavs = []
        for rec in batch:
            name = rec["path"]
            path = os.path.join(audio_root, name)
            if os.path.isfile(path):
                wav, _ = librosa.load(path, sr=sr)
            else:
                wav = None
            wavs.append(wav)

        # 2) Prepare & run model on valid wavs
        valid_indices = [i for i, w in enumerate(wavs) if w is not None]
        decoded = []
        if valid_indices:
            valid_wavs = [wavs[i] for i in valid_indices]
            inputs = processor(
                valid_wavs,
                sampling_rate=sr,
                return_tensors="pt",
                padding=True
            ).input_features.to(device)
            with torch.no_grad():
                preds = model.generate(inputs)
            decoded = processor.batch_decode(preds, skip_special_tokens=True)

        # 3) Collect results, use absolute index for checkpoint logic
        di = 0
        for offset, rec in enumerate(batch):
            abs_idx = batch_start + offset
            name = rec["path"]
            target = rec["sentence"]
            if name in done_set:
                continue

            if wavs[offset] is None:
                text = "<MISSING>"
            else:
                text = decoded[di]
                di += 1

            results.append({
                "audio_name":    name,
                "target_text":   target,
                "transcription": text
            })
            done_set.add(name)

            # checkpoint every interval rows, or at the very end
            if (abs_idx + 1) % interval == 0 or (abs_idx + 1) == total:
                pd.DataFrame(results).to_csv(checkpoint_path, sep="\t", index=False)
                print(f"💾 Checkpoint at {abs_idx + 1}/{total}")

    return pd.DataFrame(results)

In [6]:
meta_df = load_metadata(INPUT_TSV)

In [7]:
meta_df.head(5)

Unnamed: 0,client_id,path,sentence,up_votes,down_votes,age,gender,accent
0,0013037a1d45cc33460806cc3f8ecee9d536c45639ba4c...,common_voice_en_699711.mp3,She'll be all right.,2,1,,,
1,001509f4624a7dee75247f6a8b642c4a0d09f8be3eeea6...,common_voice_en_18132047.mp3,All's well that ends well.,2,0,,,
2,003fb666a99eb3aa3ba05d9c8641c18e55cf7d34d1b981...,common_voice_en_17263741.mp3,Do you mean it?,2,0,,,
3,004017ba82a23768d58dff3b91da8e8f951ea5fb6d3cd9...,common_voice_en_17893917.mp3,The new patch is less invasive than the old on...,2,1,,,
4,0047f1aea3f39c4c6a9298d84f046c1f84f439f594d840...,common_voice_en_17561821.mp3,How is Mozilla going to handle ambiguities lik...,2,0,,,


In [8]:
## Find and Remove Faulty Audios
def remove_faulty_samples(df, root_clips="[Path_to_ASR_DATASET_AUDIO_CLIPS]", 
    ):
    audio_dir = Path(root_clips)

    # Find problematic files
    bad_files = []
    for mp3_path in audio_dir.glob("*.mp3"):
        if not mp3_path.exists():
            print(f"Missing file: {mp3_path}")
            continue
        if mp3_path.stat().st_size < 1024:  # Check for empty/small files
            print(f"Corrupted file: {mp3_path}")
            bad_files.append(mp3_path)
    
    # Assuming badfiles contains full paths to faulty audio files
    # Extract just the filenames from badfiles paths
    faulty_filenames = {os.path.basename(badfile) for badfile in bad_files}

    # Filter the dataframe to find matching entries
    faulty_samples = meta_df[meta_df['path'].isin(faulty_filenames)]

    # Print results
    print(f"Found {len(faulty_samples)} faulty audio samples:")
    for idx, row in faulty_samples.iterrows():
        print(f"- {row['path']} (Index: {idx})")
    
    # Get clean DataFrame (excluding faulty samples)
    clean_meta_df = meta_df[~meta_df['path'].isin(faulty_filenames)]

    # Verify removal
    print(f"Original rows: {len(meta_df)}")
    print(f"Clean rows: {len(clean_meta_df)}")
    print(f"Removed {len(meta_df) - len(clean_meta_df)} faulty samples")
        
    return clean_meta_df
    

In [9]:

# ─── MAIN ENTRYPOINT ─────────────────────────────────────────────────────────────

def main_old_old():
    # 1) Load metadata
    meta_df = load_metadata(INPUT_TSV)

    # 2) Init ASR model
    processor, model = init_asr_model(WHISPER_MODEL)

    # 3) Run inference
    out_df = run_inference(meta_df, processor, model, AUDIO_ROOT)

    # 4) Save to TSV
    out_df.to_csv(OUTPUT_TSV, sep="\t", index=False)
    print(f"✅ Wrote {len(out_df)} lines to {OUTPUT_TSV}")


def main_old():
    # 1) Load metadata
    meta_df = load_metadata(INPUT_TSV)

    # 2) Init ASR model
    processor, model = init_asr_model(WHISPER_MODEL)

    # 3) Run inference with checkpointing
    out_df = run_inference_with_checkpoint(meta_df, processor, model, AUDIO_ROOT)

    # 4) Save final TSV
    out_df.to_csv(OUTPUT_TSV, sep="\t", index=False)
    print(f"✅ Wrote {len(out_df)} lines to {OUTPUT_TSV}")

def main():
    meta_df   = load_metadata(INPUT_TSV)
    meta_df   = remove_faulty_samples(meta_df, root_clips=AUDIO_ROOT)
    processor, model = init_asr_model(WHISPER_MODEL)
    out_df    = run_inference_batch_with_checkpoint(
                    meta_df,
                    processor,
                    model,
                    AUDIO_ROOT,
                    batch_size=8,
                    sr=16_000
                )
    out_df.to_csv(OUTPUT_TSV, sep="\t", index=False)
    print(f"✅ Wrote {len(out_df)} lines to {OUTPUT_TSV}")

if __name__ == "__main__":
    main()

Corrupted file: [YOUR_ROOT_PATH]/datasets/audio-dataset/en/clips/common_voice_en_37101.mp3
Corrupted file: [YOUR_ROOT_PATH]/datasets/audio-dataset/en/clips/common_voice_en_626040.mp3
Corrupted file: [YOUR_ROOT_PATH]/datasets/audio-dataset/en/clips/common_voice_en_577779.mp3
Corrupted file: [YOUR_ROOT_PATH]/datasets/audio-dataset/en/clips/common_voice_en_641439.mp3
Found 0 faulty audio samples:
Original rows: 15531
Clean rows: 15531
Removed 0 faulty samples


The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


💾 Checkpoint at 776/15531
💾 Checkpoint at 1552/15531
💾 Checkpoint at 2328/15531
💾 Checkpoint at 3104/15531
💾 Checkpoint at 3880/15531
💾 Checkpoint at 4656/15531
💾 Checkpoint at 5432/15531
💾 Checkpoint at 6208/15531
💾 Checkpoint at 6984/15531
💾 Checkpoint at 7760/15531
💾 Checkpoint at 8536/15531
💾 Checkpoint at 9312/15531
💾 Checkpoint at 10088/15531
💾 Checkpoint at 10864/15531
💾 Checkpoint at 11640/15531
💾 Checkpoint at 12416/15531
💾 Checkpoint at 13192/15531
💾 Checkpoint at 13968/15531
💾 Checkpoint at 14744/15531
💾 Checkpoint at 15520/15531
💾 Checkpoint at 15531/15531
✅ Wrote 15531 lines to [YOUR_ROOT_PATH]/datasets/audio-dataset/en/whisper-outputs/asr_output_test.tsv


# Evaluate and Align

#### Filter Faulty Samples

In [12]:
import csv
import re

def normalize_text(text):
    """Robust text normalization handling float/None"""
    if text is None or isinstance(text, float):
        return ""
    if not isinstance(text, str):
        try:
            text = str(text)
        except:
            return ""
    
    # Normalization steps
    text = text.lower().strip()
    text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation
    text = re.sub(r'\s+', ' ', text)      # Normalize whitespace
    return text

def filter_faulty_rows(input_tsv, output_tsv):
    """Filter rows with invalid text fields"""
    valid_rows = 0
    total_rows = 0
    filtered_rows = 0
    
    with open(input_tsv, 'r') as infile, open(output_tsv, 'w') as outfile:
        reader = csv.DictReader(infile, delimiter='\t')
        writer = csv.DictWriter(outfile, fieldnames=reader.fieldnames, delimiter='\t')
        writer.writeheader()
        
        for total_rows, row in enumerate(reader, 1):
            target = row.get('target_text', '')
            trans = row.get('transcription', '')
            
            # Normalize and check validity
            norm_target = normalize_text(target)
            norm_trans = normalize_text(trans)
            
            # Filter criteria
            is_valid = (
                isinstance(target, str) and 
                isinstance(trans, str) and
                norm_target != "" and
                norm_trans != ""
            )
            
            if is_valid:
                writer.writerow(row)
                valid_rows += 1
            else:
                filtered_rows += 1
                print(f"Filtered row {total_rows}:")
                print(f"  Audio: {row.get('audio_name', '')}")
                print(f"  Target: {repr(target)}")
                print(f"  Transcription: {repr(trans)}")
                print("-" * 50)
    
    print("\nFiltering complete:")
    print(f"  Total rows: {total_rows}")
    print(f"  Valid rows: {valid_rows}")
    print(f"  Filtered rows: {filtered_rows}")
    print(f"  Clean TSV saved to: {output_tsv}")

# Example usage
if __name__ == "__main__":
    input_tsv = "[Path_to_your_asr_inference_annotation]" #includes audio path and target text
    output_tsv = "[Output_Path]" # a TSV file
    filter_faulty_rows(input_tsv, output_tsv)

Filtered row 203:
  Audio: common_voice_en_479944.mp3
  Target: 'But instead of being saddened, he was happy.'
  Transcription: ' .'
--------------------------------------------------
Filtered row 255:
  Audio: common_voice_en_509774.mp3
  Target: "You're just the one I wanted to see."
  Transcription: ' .'
--------------------------------------------------
Filtered row 459:
  Audio: common_voice_en_16759015.mp3
  Target: ''
  Transcription: ' HTML when it grows in'
--------------------------------------------------
Filtered row 1295:
  Audio: common_voice_en_17846037.mp3
  Target: 'First impressions are the most lasting.'
  Transcription: ' .'
--------------------------------------------------
Filtered row 1353:
  Audio: common_voice_en_696712.mp3
  Target: 'A young Arab, also loaded down with baggage, entered, and greeted the Englishman.'
  Transcription: ' .'
--------------------------------------------------
Filtered row 1464:
  Audio: common_voice_en_503476.mp3
  Target: 'But I di

In [13]:
!python data_preprocess.py "[Output_Path]"

Overall WER = 
Overall CER = 
Wrote detailed results to filtered_asr_output_test.tsv
