In [1]:
#%%
import pickle as pkl
from collections import defaultdict
import pandas as pd

# NOTE(review): not referenced by any visible cell; looks like a leftover
# placeholder — confirm nothing downstream reads it before removing.
output_path = None

def load_results(predictions_path):
    '''
    Return the distribution of predicted labels for each accent.

    Loads a pickled evaluation dump and counts, for every accent, how many
    times each label was predicted.

    Args:
        predictions_path: path to a pickle whose top-level object supports
            ["preds"] and ["accents"]; the two sequences are read in parallel
            (the i-th prediction is tallied under the i-th accent).

    Returns:
        defaultdict: accent -> defaultdict(int) of predicted label -> count.
        Looking up an unseen accent/label yields an empty mapping / 0 rather
        than raising.

    Raises:
        ValueError: if "preds" and "accents" differ in length; zip() would
            otherwise silently drop the unmatched tail.
    '''
    # NOTE(review): pickle.load can execute arbitrary code — only point this
    # at dumps produced by trusted evaluation runs.
    with open(predictions_path, "rb") as f:
        eval_data = pkl.load(f)

    preds = eval_data["preds"]
    accents = eval_data["accents"]
    if len(preds) != len(accents):
        raise ValueError(
            f"{predictions_path}: {len(preds)} predictions but {len(accents)} accents"
        )

    predictions_by_accent = defaultdict(lambda: defaultdict(int))
    for prediction, accent in zip(preds, accents):
        predictions_by_accent[accent][prediction] += 1

    return predictions_by_accent

In [21]:
def print_top_3_confusions(preds_by_accent, k=3, exclude_label="en"):
    '''
    Print, for each accent, the k most frequent predictions other than
    `exclude_label`, as percentages.

    Percentages are taken over all predictions for the accent minus those
    equal to `exclude_label`, so the printed shares are relative to the
    non-excluded predictions only.

    Args:
        preds_by_accent: mapping accent -> mapping predicted label -> count
            (e.g. the output of `load_results`). It is NOT modified, so this
            can be called repeatedly on the same data.
        k: number of labels to print per accent (default 3).
        exclude_label: label left out of both the ranking and the denominator
            (default "en").
    '''
    for accent, predictions in preds_by_accent.items():
        # .get() avoids a KeyError on plain dicts and avoids inserting a
        # spurious "en": 0 entry into defaultdict inputs.
        total = sum(predictions.values()) - predictions.get(exclude_label, 0)
        # Rank on raw counts so labels whose percentages round to the same
        # value still come out in true frequency order.
        ranked = sorted(
            ((label, count) for label, count in predictions.items() if label != exclude_label),
            key=lambda item: item[1],
            reverse=True,
        )[:k]

        print(f"Top {k} confusions for {accent}:")
        for label, count in ranked:
            share = round((count / total) * 100, 1) if total > 0 else 0
            print(f"{label}: {share}")
        print()

def get_top_3_confusions(preds_by_accent, k=3):
    '''
    Return the k most frequent predicted labels per accent, as percentages.

    Each percentage is count / (all predictions for that accent) * 100,
    rounded to one decimal. No label is excluded, so an accent's correct
    label can appear in its own list.

    Args:
        preds_by_accent: mapping accent -> mapping predicted label -> count
            (e.g. the output of `load_results`). It is NOT modified, so the
            calling cell can be re-run without re-loading the data.
        k: number of labels to keep per accent (default 3).

    Returns:
        defaultdict: accent -> defaultdict(float) of label -> percentage,
        inserted in descending count order. Accents with no predictions are
        omitted; looking up a missing accent yields an empty mapping instead
        of raising (the cross-system comparison relies on this).
    '''
    top_confusions = defaultdict(lambda: defaultdict(float))
    for accent, predictions in preds_by_accent.items():
        total = sum(predictions.values())
        # Rank on raw counts so labels whose percentages round to the same
        # value still come out in true frequency order.
        ranked = sorted(predictions.items(), key=lambda item: item[1], reverse=True)[:k]
        for label, count in ranked:
            top_confusions[accent][label] = round((count / total) * 100, 1) if total > 0 else 0
    return top_confusions

In [27]:
# Project roots for the prediction dumps. Edit these (rather than every path)
# when re-running on another machine or account.
EXP_PROJECT_ROOT = "/exp/nbafna/projects/mitigating-accent-bias-in-lid"
HOME_PROJECT_ROOT = "/home/hltcoe/nbafna/projects/mitigating-accent-bias-in-lid"

# ET+PS: ECAPA-TDNN representations combined with phone sequences.
etps_predictions_path = f"{EXP_PROJECT_ROOT}/reps-phoneseq_exps/vl107/ecapa-tdnn_wav2vec2-xlsr-53-espeak-cv-ft/attentions-linear-8/reps-phoneseq_lid_model_outputs/nistlre_predictions.pkl"
# predictions_path = f"{EXP_PROJECT_ROOT}/reps-phoneseq_exps/vl107/ecapa-tdnn_wav2vec2-xlsr-53-espeak-cv-ft/attentions-linear-4/reps-phoneseq_lid_model_outputs/nistlre_predictions.pkl"
# PS: phone-sequence-only model.
ps_predictions_path = f"{EXP_PROJECT_ROOT}/phoneseq_exps/vl107/wav2vec2-xlsr-53-espeak-cv-ft/attentions-linear-8/phoneseq_lid_model_outputs/nistlre_new_predictions.pkl"
# ET: preliminary-evaluation predictions.
et_predictions_path = f"{HOME_PROJECT_ROOT}/prelim_evals/preds/formatted/nistlre_predictions.pkl"
# predictions_path = f"{EXP_PROJECT_ROOT}/dists-phoneseq-systemcombo_exps/vl107/ecapa-tdnn_wav2vec2-xlsr-53-espeak-cv-ft/attentions-linear-8/lid_model_outputs/nistlre_predictions.pkl"


In [28]:
# Per-accent prediction counts for each system (ET, ET+PS, PS), loaded in
# that order so a failing load leaves the earlier results available.
et_preds_by_accent = load_results(predictions_path=et_predictions_path)
etps_preds_by_accent = load_results(predictions_path=etps_predictions_path)
ps_preds_by_accent = load_results(predictions_path=ps_predictions_path)

In [29]:
# Top predicted labels (as percentages) per accent for each system.
et_top_confusions = get_top_3_confusions(preds_by_accent=et_preds_by_accent)
etps_top_confusions = get_top_3_confusions(preds_by_accent=etps_preds_by_accent)
ps_top_confusions = get_top_3_confusions(preds_by_accent=ps_preds_by_accent)

In [30]:
# Side-by-side comparison of each system's top predictions, keyed by the
# accents present in the ET results.
for accent in et_top_confusions:
    print(f"Top 3 confusions for {accent}:")
    et_list = list(et_top_confusions[accent].items())
    ps_list = list(ps_top_confusions[accent].items())
    etps_list = list(etps_top_confusions[accent].items())
    for system_label, system_list in (("ET", et_list), ("PS", ps_list), ("ET+PS", etps_list)):
        print(f"{system_label}: {system_list}")
    print("\n\n\n")

Top 3 confusions for brz:
ET: [('pt', 52.1), ('yi', 9.3), ('ro', 8.9)]
PS: [('pt', 26.4), ('ro', 8.5), ('ca', 6.1)]
ET+PS: [('pt', 37.9), ('ro', 15.0), ('gl', 5.8)]




Top 3 confusions for gbr:
ET: [('en', 43.6), ('mk', 9.5), ('cy', 9.2)]
PS: [('en', 37.8), ('cy', 18.5), ('la', 13.8)]
ET+PS: [('en', 41.0), ('cy', 18.2), ('ab', 8.8)]




Top 3 confusions for usg:
ET: [('en', 39.1), ('fo', 10.9), ('mk', 7.6)]
PS: [('en', 36.7), ('la', 17.4), ('ab', 6.2)]
ET+PS: [('en', 43.8), ('fo', 7.1), ('ab', 6.2)]




Top 3 confusions for car:
ET: [('es', 28.9), ('mk', 10.6), ('fo', 8.2)]
PS: [('la', 17.8), ('gl', 16.4), ('es', 8.9)]
ET+PS: [('es', 26.6), ('gl', 8.1), ('ab', 6.2)]




Top 3 confusions for eur:
ET: [('ca', 22.9), ('es', 14.0), ('gl', 13.9)]
PS: [('gl', 17.7), ('es', 14.0), ('la', 10.6)]
ET+PS: [('es', 23.0), ('gl', 18.3), ('el', 7.8)]




Top 3 confusions for lac:
ET: [('es', 40.3), ('gl', 6.2), ('fo', 4.6)]
PS: [('gl', 19.5), ('es', 15.1), ('la', 12.4)]
ET+PS: [('es', 38.0), ('gl', 