# Configuration


Use this section to adapt the configuration to your dataset and model.

In [None]:
# Log into the Hugging Face Hub so the datasets and models can be downloaded.
import os
from huggingface_hub import login
from huggingface_hub import whoami

# The access token is read from the environment. When HF_TOKEN is unset this
# is None and `login` is called with token=None.
hf_token = os.environ.get("HF_TOKEN")

login(token=hf_token)
# `whoami()` returns the full account record as a dict; report only the
# account name instead of printing the whole record.
user_info = whoami()
username = user_info["name"]
print(f"You are logged in as: {username}")

In [None]:
# Checkpoints used for evaluation: the processor is loaded from
# `base_processor_path` and the model weights from `pretrained_model_path`.
# The alternatives below are grouped in pairs by model size; switch both lines
# of a pair together.
base_processor_path = 'openai/whisper-tiny'
pretrained_model_path = 'cdli/whisper-tiny_Akan_standardspeech_spec_and_audio_augment'

# base_processor_path = 'openai/whisper-small'
# pretrained_model_path = 'cdli/whisper-small_Akan_standardspeech_spec_and_audio_augment'

# base_processor_path = 'openai/whisper-large-v3-turbo'
# pretrained_model_path = 'cdli/whisper-large_v3_turbo_Akan_standardspeech_spec_and_audio_augment'



# Language and task handed to the processor and the generation config.
# NOTE(review): 'yo' is used although the checkpoints above are Akan models --
# presumably a deliberate proxy language code matching how the models were
# fine-tuned; confirm before changing.
LANGUAGE = 'yo'
TASK = "transcribe"

# Imports and Function Definitions

In [None]:
import datasets
import ipywidgets as widgets
import matplotlib.pyplot as plt
import pandas as pd


from evaluate import load as metrics_loader
from transformers.models.whisper.english_normalizer import BasicTextNormalizer


import torch
import time

from transformers import WhisperProcessor
from transformers import WhisperForConditionalGeneration

from tqdm.auto import tqdm
from transformers import pipeline


from transformers.pipelines.pt_utils import Dataset

import os
import csv
import shutil
import soundfile as sf
import numpy as np

In [None]:
# Disable `datasets` caching for this session and report the resulting state.
datasets.disable_caching()
print('cache:', datasets.is_caching_enabled())

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device is: ', device)

# Limit PyTorch to a single CPU thread and print the effective setting (a bare
# `torch.get_num_threads()` in the middle of a cell displays nothing).
torch.set_num_threads(1)
print('torch threads:', torch.get_num_threads())
num_proc = os.cpu_count()

In [None]:
# note: this is slower when more stored in dataset (eg features), as it leads to data being copied around
class AudioTextDataset(Dataset):
    """Expose rows of an audio dataset as flat dicts for batched inference.

    Each item carries the utterance metadata, the audio samples under ``raw``
    with their ``sampling_rate``, and the reference transcript under
    ``ground_truth``. ``severity`` is included only when the underlying row
    has that field.
    """

    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        # Fetch the row exactly once. Every `self.dataset[i]` lookup builds the
        # row again, which for lazily decoded audio datasets may mean
        # re-reading and re-decoding the audio on each access.
        item = self.dataset[i]
        audio = item['audio']

        sample = {"utterance_id": item['utterance_id'],
                  "speaker_id": item['speaker_id']}
        # Optional field: only some datasets provide severity labels.
        if 'severity' in item:
            sample["severity"] = item['severity']
        sample.update({"gender": item['gender'],
                       "age": item['age'],
                       "environment": item['environment'],
                       "device": item['device'],
                       "sampling_rate": audio['sampling_rate'],
                       # "path": audio['path'],
                       "raw": audio['array'],
                       "ground_truth": item['text']})
        return sample



In [None]:
# Word and character error rate metrics, loaded through `evaluate.load`.
wer_metric = metrics_loader("wer")
cer_metric = metrics_loader("cer")

# Text normalizer applied to references and predictions before scoring.
transcript_normalizer = BasicTextNormalizer()

def get_wer_cer(references, predictions,
                calculate_utterance_level_averaged_wer=False,
                normalize=True, verbose=True,
                ):
  """Compute word and character error rate for paired transcripts.

  Args:
    references: ground-truth transcripts.
    predictions: model transcripts, aligned one-to-one with `references`.
    calculate_utterance_level_averaged_wer: when True, score every utterance
      on its own, cap each score at 1.0 and average the capped scores. This
      is not the standard corpus-level WER, but in a high-WER scenario (as
      for NSS) it bounds the contribution of each utterance. When False, a
      single corpus-level score is computed and capped at 1.0.
    normalize: when True, run both sides through `transcript_normalizer`
      before scoring; when False, score the strings exactly as given.
    verbose: when True, print every reference/prediction pair (plus the
      per-utterance scores in utterance-level mode).

  Returns:
    A `(wer, cer)` tuple.

  Raises:
    ValueError: if `references` and `predictions` differ in length.
  """
  references = list(references)
  predictions = list(predictions)
  if len(references) != len(predictions):
    # zip() below would otherwise silently drop the unmatched tail.
    raise ValueError(
        'references and predictions must have the same length, got '
        f'{len(references)} and {len(predictions)}')

  # Normalise once, up front. Without normalisation the raw strings are
  # scored (previously the names used below were never bound in that case).
  if normalize:
    pred_strs = [transcript_normalizer(x) for x in predictions]
    label_strs = [transcript_normalizer(x) for x in references]
  else:
    pred_strs = predictions
    label_strs = references

  if calculate_utterance_level_averaged_wer:
    wers = []
    cers = []
    for pred_str, label_str in zip(pred_strs, label_strs):
      wer = wer_metric.compute(predictions=[pred_str], references=[label_str])
      cer = cer_metric.compute(predictions=[pred_str], references=[label_str])
      wers.append(wer)
      cers.append(cer)
      if verbose:
        print(label_str, '-->', pred_str, '-->', wer, cer)
    # Cap every utterance at 1.0 before averaging.
    wer = np.mean([min(1.0, x) for x in wers])
    cer = np.mean([min(1.0, x) for x in cers])
  else:
    # Float cap so the return type does not switch to int when capped.
    wer = min(1.0, wer_metric.compute(references=label_strs, predictions=pred_strs))
    cer = min(1.0, cer_metric.compute(references=label_strs, predictions=pred_strs))

    if verbose:
      for pred_str, label_str in zip(pred_strs, label_strs):
        print(label_str, '-->', pred_str)

  return (wer, cer)

In [None]:
# Caching is disabled again here although the setup cell already does so --
# presumably so that this cell can be run on its own; confirm before removing.
datasets.disable_caching()
print('cache:', datasets.is_caching_enabled())

from huggingface_hub import hf_hub_download
import tarfile

def extract_tar_gz(archive_path, export_dir):
    """Extract a gzip-compressed tar archive into `export_dir`.

    The archive is downloaded content, so members are extracted through
    tarfile's "data" extraction filter when the running Python provides it;
    that filter refuses members that would land outside `export_dir`. On
    interpreters without extraction filters the archive is extracted
    unfiltered, as before.
    """
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(export_dir, filter="data")
        else:
            tar.extractall(export_dir)


def get_standard_speech_dataset(export_dir):
    """Download the standard speech dataset archive and unpack it into `export_dir`."""
    REPO = 'cdli/akan_standard_speech_data_16khz'
    DATA_FILE = "data.tar.gz"

    tar_gz_file = hf_hub_download(
        repo_id=REPO, repo_type="dataset", filename=DATA_FILE)
    extract_tar_gz(tar_gz_file, export_dir)


def get_nonstandard_speech_dataset(export_dir):
    """Download the non-standard speech dataset archive and unpack it into `export_dir`."""
    REPO = 'cdli/akan_nonstandard_speech_data_16khz'
    # Audio and metadata used to be published as two separate files
    # ("Kumasi_Batch_16khz.tar.gz" and "metadata.csv"); they are now combined
    # in one archive, so no renaming of the extracted directory or copying of
    # the metadata file is needed any more.
    DATA_FILE = "data.tar.gz"

    tar_gz_file = hf_hub_download(
        repo_id=REPO, repo_type="dataset", filename=DATA_FILE)
    extract_tar_gz(tar_gz_file, export_dir)


# Get Test Set

In [None]:
#################
## directories
#################

# Root directory under which the evaluation datasets are stored.
LOCAL_DATA_DIR = '/tmp/eval_datasets'
# Create the directory from Python rather than through a `!mkdir -p` shell
# escape, so the cell is plain Python and does not depend on IPython or on a
# POSIX shell being available.
os.makedirs(LOCAL_DATA_DIR, exist_ok=True)


## Get SS Data

In [None]:
# Download and extract the standard speech (SS) dataset into its own
# sub-directory of LOCAL_DATA_DIR.

SS_DATASET_DIR = os.path.join(LOCAL_DATA_DIR, 'SS')
get_standard_speech_dataset(SS_DATASET_DIR)

print('Downloaded dataset to:', SS_DATASET_DIR)

In [None]:
# Load the `test` split of the extracted SS data with the "audiofolder"
# loader (non-streaming). The trailing expression displays the dataset
# summary as the cell output.
ss_test_set = datasets.load_dataset("audiofolder", data_dir=SS_DATASET_DIR, split='test', streaming=False)
ss_test_set

## Get NSS Data

In [None]:
# Download and extract the non-standard speech (NSS) dataset into its own
# sub-directory of LOCAL_DATA_DIR.

NSS_DATASET_DIR = os.path.join(LOCAL_DATA_DIR, 'NSS')
get_nonstandard_speech_dataset(NSS_DATASET_DIR)

print('Downloaded dataset to:', NSS_DATASET_DIR)

In [None]:
# Load the `test` split of the extracted NSS data with the "audiofolder"
# loader (non-streaming). The trailing expression displays the dataset
# summary as the cell output.
nss_test_set = datasets.load_dataset("audiofolder", data_dir=NSS_DATASET_DIR, split='test', streaming=False)
nss_test_set

# Load Model

In [None]:
# Build the speech recognition pipeline: the processor (tokenizer and feature
# extractor) comes from the base checkpoint, the weights from the fine-tuned
# checkpoint.
print('loading model from:', pretrained_model_path)
processor = WhisperProcessor.from_pretrained(base_processor_path, language=LANGUAGE, task=TASK)

tuned_model = WhisperForConditionalGeneration.from_pretrained(pretrained_model_path)

# Move the model to the selected device; the return value is assigned to `_`
# so the notebook does not display it.
_ = tuned_model.to(device)
# Clear forced decoder ids on the generation config and select language and
# task through the generation config instead.
tuned_model.generation_config.forced_decoder_ids = None
tuned_model.generation_config.language = LANGUAGE
tuned_model.generation_config.task = TASK

pipe = pipeline(task="automatic-speech-recognition",
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                model=tuned_model,
                device=device, torch_dtype="auto",
                chunk_length_s=30, # we need that, otherwise we run into some processing issues
                )

# NOTE(review): this sets forced decoder ids on the *model* config after they
# were cleared on the *generation* config above -- the two settings look
# redundant or conflicting; confirm which one the installed transformers
# version actually honours.
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=LANGUAGE, task=TASK)
print('torch dtype:', pipe.torch_dtype)

# Run Inference

In [None]:
BATCH_SIZE = 16

# Test set to transcribe; switch the assignment to evaluate the other one.
# ds = ss_test_set
ds = nss_test_set

# Generation settings. None keeps the defaults; alternatively control the
# generation this way (increase the beam size for better performance at the
# cost of higher inference time):
# gen_args={
#     "num_beams": 1,
#     "temperature": 0.0,
#     "do_sample": False,
#     "max_new_tokens": 128,
#     }
gen_args = None

res = []
t1 = time.time()
# Outputs are collected in dataset order while tqdm reports progress.
predictions_iter = pipe(AudioTextDataset(ds), batch_size=BATCH_SIZE)
# predictions_iter = pipe(AudioTextDataset(ds), batch_size=BATCH_SIZE, generate_kwargs=gen_args)
res.extend(tqdm(predictions_iter, total=len(ds)))
t2 = time.time()

print('total inference time:', (t2-t1))
print('num examples:', len(res))
df = pd.DataFrame(res)

# Calculate WER and CER

In [None]:
%%time

# Metadata columns carried over from the inference output.
BASE_COLS = ['utterance_id', 'device', 'gender', 'speaker_id', 'age', 'environment', 'ground_truth']

# Severity labels are only present for some datasets.
if 'severity' in df.columns:
  BASE_COLS.append('severity')


def clean_res(row):
  """Unwrap metadata values that arrive as lists, keeping their first element."""
  for f in BASE_COLS:
    v = row[f]
    if isinstance(v, list):
      row[f] = v[0]

  return row


def get_row_wer(row):
  """Return the WER of one result row (normalized reference vs. prediction)."""
  reference = transcript_normalizer(row['ground_truth'])
  prediction = transcript_normalizer(row['prediction'])
  return get_wer_cer(references=[reference], predictions=[prediction], normalize=True, verbose=False)[0]


def get_row_cer(row):
  """Return the CER of one result row (normalized reference vs. prediction)."""
  reference = transcript_normalizer(row['ground_truth'])
  prediction = transcript_normalizer(row['prediction'])
  return get_wer_cer(references=[reference], predictions=[prediction], normalize=True, verbose=False)[1]


# Clean up the metadata and store the pipeline's `text` output under
# `prediction`. Renaming replaces the earlier copy-then-drop, whose `drop`
# result was discarded and therefore had no effect.
results_df = df.apply(clean_res, axis=1).rename(columns={'text': 'prediction'})

# add normalized ground truth and prediction
results_df['ground_truth_normalized'] = results_df['ground_truth'].apply(transcript_normalizer)
results_df['prediction_normalized'] = results_df['prediction'].apply(transcript_normalizer)

# Score every utterance once and split the (wer, cer) pairs into two columns,
# instead of computing both metrics twice per row.
row_scores = [
    get_wer_cer(references=[reference], predictions=[prediction],
                normalize=True, verbose=False)
    for reference, prediction in zip(results_df['ground_truth_normalized'],
                                     results_df['prediction_normalized'])
]
results_df['wer'] = [row_wer for row_wer, _ in row_scores]
results_df['cer'] = [row_cer for _, row_cer in row_scores]

# reorder columns
PRED_COLS = ['prediction', 'ground_truth_normalized', 'prediction_normalized', 'wer', 'cer']
COLS = BASE_COLS + PRED_COLS
results_df = results_df[COLS]

# calculate overall WER
overall_wer_cer = get_wer_cer(references=results_df.ground_truth.tolist(),
                              predictions=results_df.prediction.tolist(),
                              calculate_utterance_level_averaged_wer=False,
                              normalize=True, verbose=False)

# per-utterance WER/CER, capped at 1.0 and averaged
avg_utterance_level_wer_cer = get_wer_cer(references=results_df.ground_truth.tolist(),
                              predictions=results_df.prediction.tolist(),
                              calculate_utterance_level_averaged_wer=True,
                              normalize=True, verbose=False)


print('Overall WER (normalized):', round(overall_wer_cer[0], 3))
print('Overall CER (normalized):', round(overall_wer_cer[1], 3))
print('Avg WER (normalized):', round(avg_utterance_level_wer_cer[0], 3))
print('Avg CER (normalized):', round(avg_utterance_level_wer_cer[1], 3))

In [None]:
# WER / CER per speaker and severity (needs severity labels)
if 'severity' not in results_df.columns:
    print('Skipped as no Severity in results_df')
else:
    per_speaker = results_df[['speaker_id', 'severity', 'wer', 'cer']]
    print(per_speaker.groupby(['speaker_id', 'severity']).agg(['mean', 'count']))

In [None]:
# WER / CER per severity level (needs severity labels)
if 'severity' not in results_df.columns:
    print('Skipped as no Severity in results_df')
else:
    per_severity = results_df[['severity', 'wer', 'cer']]
    print(per_severity.groupby(['severity']).agg(['mean', 'count']))

In [None]:
# WER / CER per gender: mean score and number of utterances in each group
results_df.groupby(['gender'])[['wer', 'cer']].agg(['mean', 'count'])

# Save Predictions

In [None]:
# Save results
# `ds` is bound to one of the two loaded test sets, so test object identity
# rather than relying on how `==` behaves for dataset objects.
dataset = 'NSS' if ds is nss_test_set else 'SS'
print('saving predictions on', dataset)
# NOTE(review): `output_dir` is computed but neither created nor used; the
# file name below points to /tmp instead -- confirm the intended location.
output_dir = os.path.join(LOCAL_DATA_DIR, 'results', dataset)
output_filename = os.path.join('/tmp', pretrained_model_path.replace('cdli/','') + '_' + dataset + '.tsv')
print('output_filename:', output_filename)
# NOTE(review): the write is currently disabled, so nothing is saved despite
# the message above; uncomment to export the predictions as TSV.
#results_df.to_csv(output_filename, index=False, sep='\t')