In [17]:
import os
import re
import pandas as pd
import logging
from tqdm import tqdm
from nemo.collections.asr.models import EncDecMultiTaskModel
from nemo.collections.asr.metrics.wer import word_error_rate
import warnings
warnings.filterwarnings('ignore')

In [None]:
# === Silence warnings and NeMo logs ===
warnings.filterwarnings('ignore')
logging.getLogger('nemo').setLevel(logging.ERROR)

# === Paths ===
model_path = 'canary-180m-flash/canary-180m-flash.nemo'
csv_path = '../../TORGO_CLEANED.csv'
results_dir = 'results'
checkpoint_path = os.path.join(results_dir, 'checkpoint.csv')
os.makedirs(results_dir, exist_ok=True)

# === Load model ===
model = EncDecMultiTaskModel.restore_from(restore_path=model_path)
model.eval()

# === Load dataset ===
df = pd.read_csv(csv_path)

# === Load checkpoint if available ===
if os.path.exists(checkpoint_path):
    df_checkpoint = pd.read_csv(checkpoint_path)
    processed_paths = set(df_checkpoint['wav'])
    transcriptions = df_checkpoint['prediction'].tolist()
    references = df_checkpoint['reference'].tolist()
    speakers = df_checkpoint['spk_id'].tolist()
else:
    df_checkpoint = pd.DataFrame(columns=['wav', 'spk_id', 'prediction', 'reference'])
    processed_paths = set()
    transcriptions = []
    references = []
    speakers = []

# === Inference with checkpointing ===
for _, row in tqdm(df.iterrows(), total=len(df)):
    if row['wav'] in processed_paths:
        continue

    audio_path = os.path.join('../../', row['wav'])
    ref = str(row['wrd']).lower()
    spk = row['spk_id']

    if not os.path.exists(audio_path):
        pred = ''
    else:
        result = model.transcribe([audio_path])[0]

        # Clean transcription
        pred = re.sub(r'\[[^\]]*\]', '', result)
        pred = pred.lower()
        pred = re.sub(r"[^a-z0-9\s']", '', pred)
        pred = pred.strip()

    # Store results
    transcriptions.append(pred)
    references.append(ref)
    speakers.append(spk)
    new_row = pd.DataFrame([{
        'wav': row['wav'],
        'spk_id': spk,
        'prediction': pred,
        'reference': ref
    }])
    df_checkpoint = pd.concat([df_checkpoint, new_row], ignore_index=True)

    # Save checkpoint
    df_checkpoint.to_csv(checkpoint_path, index=False)

# === Merge results with original metadata ===
df_result = df.merge(df_checkpoint, on=['wav', 'spk_id'])
df_result = df_result[df_result['prediction'].str.strip() != '']
df_result = df_result[['ID', 'duration', 'wav', 'spk_id', 'prediction', 'reference']]
df_result.to_csv(os.path.join(results_dir, 'transcription_results.csv'), index=False)

# === Compute WERs ===
wer_per_spk = (
    df_result.groupby('spk_id')
    .apply(lambda g: pd.Series({
        'wer': word_error_rate(
            hypotheses=g['prediction'].tolist(),
            references=g['reference'].tolist()
        )
    }))
    .reset_index()
)
wer_per_spk.to_csv(os.path.join(results_dir, 'wer_per_speaker.csv'), index=False)

total_wer = word_error_rate(
    hypotheses=df_result['prediction'].tolist(),
    references=df_result['reference'].tolist()
)

# === Save WER summary ===
summary_path = os.path.join(results_dir, 'wer_summary.txt')
with open(summary_path, 'w') as f:
    f.write(f'Total WER: {total_wer:.4f}\n\n')
    f.write('WER per speaker:\n')
    for _, row in wer_per_spk.iterrows():
        f.write(f'{row["spk_id"]}: {row["wer"]:.4f}\n')

print(f"Total WER: {total_wer:.4f}")
print(f"Results saved in: {results_dir}")

[NeMo I 2025-03-27 21:12:21 mixins:200] _setup_tokenizer: detected an aggregate tokenizer
[NeMo I 2025-03-27 21:12:21 mixins:339] Tokenizer SentencePieceTokenizer initialized with 1152 tokens
[NeMo I 2025-03-27 21:12:21 mixins:339] Tokenizer SentencePieceTokenizer initialized with 1024 tokens
[NeMo I 2025-03-27 21:12:21 mixins:339] Tokenizer SentencePieceTokenizer initialized with 1024 tokens
[NeMo I 2025-03-27 21:12:21 mixins:339] Tokenizer SentencePieceTokenizer initialized with 1024 tokens
[NeMo I 2025-03-27 21:12:21 mixins:339] Tokenizer SentencePieceTokenizer initialized with 1024 tokens
[NeMo I 2025-03-27 21:12:21 aggregate_tokenizer:73] Aggregate vocab size: 5248


[NeMo W 2025-03-27 21:12:21 modelPT:176] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    use_lhotse: true
    input_cfg: null
    tarred_audio_filepaths: null
    manifest_filepath: null
    sample_rate: 16000
    shuffle: true
    num_workers: 8
    pin_memory: true
    prompt_format: canary2
    max_tps: 25
    max_duration: 40.0
    text_field: answer
    lang_field: target_lang
    use_bucketing: true
    bucket_duration_bins:
    - - 3.56
      - 30
    - - 3.56
      - 77
    - - 4.608
      - 38
    - - 4.608
      - 88
    - - 5.48
      - 49
    - - 5.48
      - 106
    - - 6.05
      - 52
    - - 6.05
      - 109
    - - 6.85
      - 54
    - - 6.85
      - 124
    - - 7.914
      - 59
    - - 7.914
      - 137
    - - 8.52
      - 67
    - - 8.52
      - 158
    - - 9.51
      - 67
    - - 9.51
      - 153
    - - 10.29
      - 78
 

[NeMo I 2025-03-27 21:12:21 features:305] PADDING: 0
[NeMo I 2025-03-27 21:12:24 save_restore_connector:275] Model EncDecMultiTaskModel was successfully restored from /scratch/flatala/pre_trained_models/cannary/canary-180m-flash/canary-180m-flash.nemo.


  0%|          | 0/16538 [00:00<?, ?it/s][NeMo W 2025-03-27 21:12:24 dataloader:230] You are using a non-tarred dataset and requested tokenization during data sampling (pretokenize=True). This will cause the tokenization to happen in the main (GPU) process, possibly impacting the training speed if your tokenizer is very large. If the impact is noticable, set pretokenize=False in dataloader config. (note: that will disable token-per-second filtering and 2D bucketing features)

Transcribing: 1it [00:00,  8.70it/s]
  0%|          | 1/16538 [00:00<43:46,  6.30it/s][NeMo W 2025-03-27 21:12:24 dataloader:230] You are using a non-tarred dataset and requested tokenization during data sampling (pretokenize=True). This will cause the tokenization to happen in the main (GPU) process, possibly impacting the training speed if your tokenizer is very large. If the impact is noticable, set pretokenize=False in dataloader config. (note: that will disable token-per-second filtering and 2D bucketing feat