In [21]:
import os
import pandas as pd
import numpy as np
from datasets import load_dataset, load_metric, Dataset

# Word-error-rate scorer used by the per-group evaluation loop further down.
# NOTE(review): `load_metric` is reported as deprecated in newer `datasets` releases in
# favour of the separate `evaluate` package — confirm against the pinned version
# before upgrading the environment.
wer_metric = load_metric("wer")

Downloading builder script:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

In [28]:
# Root folder that contains both the `results/` and `data/` sub-folders referenced below.
# The original absolute path stays the default so existing runs are unaffected; set the
# AFRISPEECH_ROOT environment variable to run the notebook on another machine.
results_dir = os.environ.get('AFRISPEECH_ROOT', '/data/AfriSpeech-Dataset-Paper/')

# Prediction CSVs to evaluate, relative to `results_dir`.
result_paths = [
    'results/intron-open-test-wav2vec2-large-xlsr-53-generative_single_task_baseline-wer-0.2519-5474.csv',
    'results/intron-open-test-wav2vec2-large-xlsr-53-generative-multitask-asr-domain-prepend-wer-0.2467-5474.csv',
    'results/intron-open-test-atnafu-asr-domain-9-1-wer-0.2455-5474.csv',
]

# Metadata CSVs, relative to `results_dir`. The order (train, dev, test) matters:
# the loading cell indexes into this list by position.
data_paths = [
    'data/intron-train-public-58000-clean.csv',
    'data/intron-dev-public-3231-clean.csv',
    'data/intron-test-public-6346-clean.csv',
]

In [10]:
# Load the three metadata splits in one pass; `data_paths` is ordered train, dev, test.
train, val, test = (
    pd.read_csv(os.path.join(results_dir, split_csv)) for split_csv in data_paths
)

In [11]:
# Map every accent label in the training split to its number of rows.
accent_counts = train['accent'].value_counts().to_dict()

In [12]:
# Display the per-accent training row counts computed in the previous cell.
accent_counts

{'yoruba': 14396,
 'igbo': 8104,
 'swahili': 5484,
 'hausa': 5453,
 'ijaw': 2371,
 'afrikaans': 1911,
 'idoma': 1767,
 'twi': 1321,
 'zulu': 1309,
 'setswana': 1275,
 'igala': 906,
 'kiswahili': 811,
 'izon': 783,
 'isizulu': 779,
 'ebira': 696,
 'urhobo': 578,
 'nembe': 546,
 'luganda': 529,
 'ibibio': 482,
 'pidgin': 442,
 'kinyarwanda': 439,
 'luhya': 426,
 'esan': 353,
 'xhosa': 342,
 'tshivenda': 334,
 'alago': 310,
 'tswana': 289,
 'isoko': 259,
 'fulani': 256,
 'efik': 232,
 'akan (fante)': 230,
 'edo': 201,
 'ikwere': 200,
 'hausa/fulani': 192,
 'isindebele': 188,
 'luo': 179,
 'sepedi': 176,
 'venda and xitsonga': 174,
 'bekwarra': 165,
 'kikuyu': 163,
 'isixhosa': 160,
 'epie': 147,
 'luganda and kiswahili': 134,
 'akan': 131,
 'sotho': 129,
 'afemai': 125,
 'kagoma': 123,
 'nasarawa eggon': 120,
 'south african english': 114,
 'borana': 112,
 'swahili ,luganda ,arabic': 109,
 'nupe': 106,
 'bette': 103,
 'benin': 103,
 'venda': 98,
 'damara': 92,
 'okrika': 90,
 'southern so

In [13]:
# Minimum number of training rows an accent needs to be counted as a "majority" accent;
# accents below this count are treated as "minority".
majority_threshold = 500

In [14]:
# Partition the training accents in a single pass: accents with at least
# `majority_threshold` training rows are "majority", the rest are "minority".
majority_accents, minority_accents = [], []
for accent_name, n_train_rows in accent_counts.items():
    if n_train_rows >= majority_threshold:
        majority_accents.append(accent_name)
    else:
        minority_accents.append(accent_name)

In [17]:
# Accents present in the test split.
test_accents = list(test.accent.unique())

# Accents seen during training or validation (kept as a list; it may contain repeats
# because an accent can occur in both splits).
train_accents = [*train.accent.unique(), *val.accent.unique()]

# Out-of-distribution accents: in the test split but in neither train nor dev.
seen_accents = set(train_accents)
ood_accents = [name for name in test_accents if name not in seen_accents]

In [19]:
# Sanity-check the size of each accent group, one number per line.
print(
    len(set(train_accents)),   # distinct accents seen in train + dev
    len(test_accents),         # distinct accents in the test split
    len(majority_accents),     # training accents at or above the threshold
    len(minority_accents),     # training accents below the threshold
    len(ood_accents),          # test accents never seen in train/dev
    sep='\n',
)

79
108
18
53
41


In [5]:
# `test_data` is not defined anywhere in the visible notebook (the test split is loaded
# as `test`), so the old name fails on a fresh kernel. Inspect the loaded frame instead.
test.shape

(6346, 14)

In [6]:
# `test_data` is not defined anywhere in the visible notebook (the test split is loaded
# as `test`), so the old name fails on a fresh kernel. List the metadata columns
# available for grouping and joining.
test.columns

Index(['idx', 'user_ids', 'accent', 'age_group', 'country', 'transcript',
       'nchars', 'audio_ids', 'audio_paths', 'duration', 'origin', 'domain',
       'split', 'gender'],
      dtype='object')

In [None]:
# Names of the test-set slices that WER is reported for.
groups = [
    'clinical',   # rows whose domain is clinical
    'general',    # rows whose domain is general
    'ood',        # accents absent from train/dev
    'majority',   # well-represented training accents
    'minority',   # sparsely represented training accents
]

In [None]:
# For each result file in `result_paths`:
#   merge its predictions with the test metadata on `audio_paths`.

In [33]:
# Report WER for every prediction file on five slices of the test set
# (clinical / general domain, and OOD / majority / minority accents).
#
# Reads:  results_dir, result_paths, test, wer_metric,
#         ood_accents, majority_accents, minority_accents
# Leaves: `res` and `pred` bound to the last result file processed, and `test`
#         de-duplicated on `audio_paths` (same end state as before).

# De-duplicating the metadata does not depend on the result file, so do it once
# up front instead of on every iteration.
test.drop_duplicates(subset=['audio_paths'], inplace=True)

# Path fragments written by the machines that produced the predictions, each mapped
# to the fragment they are rewritten to before joining (presumably the layout used
# in the test metadata — confirm). Applied in this order.
path_rewrites = {
    '/data/data/intron/': '/AfriSpeech-100/test/',
    '/media/4T/atnafu/adata/test/': '/AfriSpeech-100/test/',
}

for result_path in result_paths:
    print('\n\n', result_path)
    res = pd.read_csv(os.path.join(results_dir, result_path))

    # Vectorised literal replacement (regex=False) instead of a row-wise apply.
    for old_fragment, new_fragment in path_rewrites.items():
        res['audio_paths'] = res['audio_paths'].str.replace(
            old_fragment, new_fragment, regex=False
        )
    # De-duplicate on the rewritten key so the join below cannot fan out test rows.
    res = res.drop_duplicates(subset=['audio_paths'])

    # Keep every test row; the join key is spelled out rather than inferred from
    # whichever columns the two frames happen to share.
    pred = test.merge(
        res[['audio_paths', 'ref_clean', 'pred_clean']],
        how='left',
        on='audio_paths',
    )[['audio_paths', 'ref_clean', 'pred_clean', 'accent', 'domain']]

    # The left join leaves NaN in ref/pred for test rows with no match in this result
    # file. Those placeholders are not transcripts and must not reach the metric, so
    # rows without a reference are excluded from scoring. A row that has a reference
    # but a missing prediction is scored as an empty hypothesis.
    scored = pred.dropna(subset=['ref_clean']).fillna({'pred_clean': ''})
    n_unscored = len(pred) - len(scored)
    if n_unscored:
        print("rows_without_reference: ", n_unscored)

    # One boolean mask per reported slice; insertion order fixes the print order.
    group_masks = {
        'clinical': scored.domain == 'clinical',
        'general': scored.domain == 'general',
        'ood': scored.accent.isin(ood_accents),
        'majority': scored.accent.isin(majority_accents),
        'minority': scored.accent.isin(minority_accents),
    }

    for group_name, mask in group_masks.items():
        subset = scored[mask]
        if subset.empty:
            # Nothing to score for this slice; say so instead of failing inside the metric.
            print(f"{group_name}_wer: ", 'n/a (no rows)')
            continue
        wer = wer_metric.compute(
            predictions=subset['pred_clean'].tolist(),
            references=subset['ref_clean'].tolist(),
        )
        print(f"{group_name}_wer: ", round(wer, 4))



 results/intron-open-test-wav2vec2-large-xlsr-53-generative_single_task_baseline-wer-0.2519-5474.csv
clinical_wer:  0.2669
general_wer:  0.2324
ood_wer:  0.2605
majority_wer:  0.2392
minority_wer:  0.242


 results/intron-open-test-wav2vec2-large-xlsr-53-generative-multitask-asr-domain-prepend-wer-0.2467-5474.csv
clinical_wer:  0.264
general_wer:  0.225
ood_wer:  0.2539
majority_wer:  0.2331
minority_wer:  0.2397


 results/intron-open-test-atnafu-asr-domain-9-1-wer-0.2455-5474.csv
clinical_wer:  0.2598
general_wer:  0.2267
ood_wer:  0.255
majority_wer:  0.2291
minority_wer:  0.2395


In [None]:
# Draft cell: attach NER annotations to the predictions, one row per audio file.
#
# NOTE(review): `ner` and `merge_cols` are not defined anywhere in the visible notebook,
# so this cell raises NameError on a fresh kernel — confirm where they should come from.
# NOTE(review): the only visible assignment to `pred` is inside the per-result loop, so
# here it refers to whichever result file was processed last — confirm that is intended.

# Make `audio_paths` unique on both sides before joining.
pred.drop_duplicates(subset=['audio_paths'], inplace=True)
ner.drop_duplicates(subset=['audio_paths'], inplace=True)
# Mutates `merge_cols` in place: re-running this cell appends 'audio_paths' again.
merge_cols.append('audio_paths')

# Columns taken from the NER frame.
ner_cols = ['audio_ids', 'audio_paths', 'has_entity', 'PER']
# No `on=` is given, so pandas joins on the columns the two selections have in common;
# every `ner` row is kept (left join).
predm = ner[ner_cols].merge(pred[merge_cols], how='left')
#print(predm.shape)
# Enforce uniqueness on both identifiers after the join.
predm.drop_duplicates(subset=['audio_ids'], inplace=True)
predm.drop_duplicates(subset=['audio_paths'], inplace=True)

In [None]:
# define groups/filters
# loop through group 
# loop through results
# and compute wer