# Evaluation of speech transcriptions

In [None]:
import os
import re
import jiwer
import Levenshtein
from collections import Counter
from rich.jupyter import print
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
## Visualización
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.font_manager as fm
import seaborn as sns

In [None]:
# Patterns that split one transcription line into an utterance id and its text.
# In both layouts the id is a single token containing no space character:
#   head layout: "<id> <text>"   -> the id is the first token
#   tail layout: "<text> <id>"   -> the id is the last token
re_id_head, re_id_tail = (
    re.compile(pattern)
    for pattern in (
        r"^(?P<id>[^ ]+) (?P<trans>.*)",
        r"(?P<trans>.*) (?P<id>[^ ]+)$",
    )
)

In [None]:
# Set the input files and directories.

# Configure this
REF_FILE = "../data/test/CIEMPIESS_TEST.trans"
HYP_DIR = "../data/output/exp_november2022"

# ---
# Every entry of HYP_DIR keyed by its file name. The listing is sorted so the
# output is the same on every run (os.listdir returns entries in arbitrary order).
HYP_FILES = {file: os.path.join(HYP_DIR, file) for file in sorted(os.listdir(HYP_DIR))}
_listing = "\n".join(f"  {file}" for file in HYP_FILES)
# "\\[" makes rich print a literal "[" so the directory is never parsed as a markup tag.
print(f"[green bold]Files to analyse[/] in dir \\[{HYP_DIR}]:\n{_listing}")

In [None]:
# Map the short report/plot labels to hypothesis file names
# (take the file names from the listing printed by the previous cell).

# Configure this
labels = ["Go.", "QN", "W2V", "W tn.", "W bs.", "W sm.", "W med.", "W lg.", "QN FT", "W2V FT"]  # report/plot order
label2name = {
    "Go.": "google_ciempiess_test.trans",
    "QN": "nemo_nvidia_ciempiess_test.trans",
    "W2V": "wav2vec_jonatasgrosman_ciempiess_test.trans",
    "W tn.": "whisper_tiny_ciempiess_test.trans",
    "W bs.": "whisper_base_ciempiess_test.trans",
    "W sm.": "whisper_small_ciempiess_test.trans",
    "W med.": "whisper_medium_ciempiess_test.trans",
    "W lg.": "whisper_large_ciempiess_test.trans",
    "QN FT": "nemo_carlos_ciempiess_test.trans",
    "W2V FT": "wav2vec_carlos_ciempiess_test.trans",
}

# ---
# Fail here with a clear message rather than with a KeyError in a later cell.
_unmapped = [label for label in labels if label not in label2name]
if _unmapped:
    raise ValueError(f"labels without a file in label2name: {_unmapped}")
name2label = {name: label for label, name in label2name.items()}
# Inverting the mapping would silently drop a label that shares a file with another.
if len(name2label) != len(label2name):
    raise ValueError("label2name maps more than one label to the same file")

## Auxiliary functions

In [None]:
# Functions to open transcription files
def _open_with_regex(filename, regex):
    """Read ``(id, transcription)`` pairs from *filename* using *regex*.

    Each line is stripped of surrounding whitespace before matching. *regex*
    must define the named groups ``id`` and ``trans``. Blank lines are skipped.
    A non-blank line that *regex* does not match is kept as ``(line, "")``.
    With the head/tail patterns this only happens for a line containing no
    space, i.e. an id whose transcription is empty. Keeping it preserves every
    id instead of crashing on ``None.groupdict()``.
    """
    pairs = []
    with open(filename, encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            match = regex.match(line)
            if match is None:
                pairs.append((line, ""))
            else:
                pairs.append((match["id"], match["trans"]))
    return pairs


def open_head(filename, regex=None):
    """Open a file whose lines are "<id> <transcription>" (e.g. the reference).

    param filename: path of the file
    param regex: pattern with ``id``/``trans`` groups; defaults to ``re_id_head``
    returns: list of ``(id, transcription)`` tuples in file order
    """
    return _open_with_regex(filename, re_id_head if regex is None else regex)


def open_tail(filename, regex=None):
    """Open a file whose lines are "<transcription> <id>" (id as last token).

    param filename: path of the file
    param regex: pattern with ``id``/``trans`` groups; defaults to ``re_id_tail``
    returns: list of ``(id, transcription)`` tuples in file order
    """
    return _open_with_regex(filename, re_id_tail if regex is None else regex)


def open_transcription_file(filename, format=None):
    """Open a transcription file.

    param filename: path of the file
    param format: None (no ids, every line is a transcription), "head" (id is
        the first token) or "tail" (id is the last token)
    returns: list of ``(id, transcription)`` tuples; the id is None when
        ``format`` is None
    raises ValueError: for any other ``format``
    """
    if format is None:
        # Strip like the head/tail readers so no trailing newline leaks into the text.
        with open(filename, encoding="utf-8") as f:
            return [(None, line.strip()) for line in f]
    if format == "head":
        return open_head(filename)
    if format == "tail":
        return open_tail(filename)
    raise ValueError(f"unknown transcription format: {format!r} (expected None, 'head' or 'tail')")

## Checks

In [None]:
# Check that the reference and every hypothesis file agree: the reference has
# one transcription per id, and each hypothesis has as many transcriptions as
# the reference and contains all of its ids.
ref = open_transcription_file(REF_FILE, format="head")
print(f"[green]Reference number of lines:[/] [bold] {len(ref)} [/]")
ids = {utt_id for utt_id, _ in ref}
if len(ref) == len(ids):
    print("[green]✓ [magenta]Reference [green]has different ids per transcription[/]")
else:
    print(f"[red]✗ Reference has repeated ids ({len(ref) - len(ids)} extra lines share an id)[/]")

for idd in labels:
    name = label2name[idd]
    if name not in HYP_FILES:
        print(f"[red]✗ {name} was not found in the hypothesis directory[/]")
        continue
    file = HYP_FILES[name]
    if idd in ["Go."]:
        # This file puts the id last with one extra character on each side
        # (presumably parentheses); [1:-1] makes it comparable to reference ids.
        hyp = open_transcription_file(file, format="tail")
        ids_ = {utt_id[1:-1] for utt_id, _ in hyp}
    else:
        hyp = open_transcription_file(file, format="head")
        ids_ = {utt_id for utt_id, _ in hyp}

    if len(ref) == len(hyp):
        print(f"[green]✓ [magenta]{name} [green]has right number of transcriptions[/]")
    else:
        print(f"[red]✗ {name} has wrong number of transcriptions ({len(hyp)} vs {len(ref)})[/]")

    missing = ids - ids_
    if not missing:
        print(f"[green]✓ [magenta]{name} [green]has all ids[/]")
    else:
        print(f"[red]✗ {name} does not have all ids ({len(missing)} missing)[/]")
    

## jiwer evaluation

In [None]:
# Word/character error rates and edit counts for every system, in `labels` order.
ref = open_head(REF_FILE)
ref_ids = [utt_id for utt_id, _ in ref]
ref_trans = [trans for _, trans in ref]
for idd in labels:
    name = label2name[idd]
    file = HYP_FILES[name]
    if idd in ["Go."]:
        # This file puts the id last with one extra character on each side
        # (presumably parentheses); [1:-1] makes it comparable to reference ids.
        hyp_by_id = {utt_id[1:-1]: trans for utt_id, trans in open_tail(file)}
    else:
        hyp_by_id = dict(open_head(file))
    # Pair by id rather than by line position, so a hypothesis file in another
    # order is still scored against the right references. A missing id is
    # scored as an empty hypothesis.
    hyp_trans = [hyp_by_id.get(utt_id, "") for utt_id in ref_ids]
    measures = jiwer.compute_measures(ref_trans, hyp_trans)
    wer = measures["wer"]
    cer = jiwer.cer(ref_trans, hyp_trans)
    # "\\[" makes rich print a literal "[" instead of swallowing the bracketed
    # file name as an (unknown) markup tag.
    print(f"[magenta]{idd:6s} \\[{os.path.splitext(name)[0]}][/]:")
    print(f"[green]  wer:[/] {wer*100:3.2f}")
    print(f"[green]  cer:[/] {cer*100:3.2f}")
    print(f"[green]  hits:[/] {measures['hits']:3d}")
    print(f"[green]  deletions:[/] {measures['deletions']:3d}")
    print(f"[green]  insertions:[/] {measures['insertions']:3d}")
    print(f"[green]  substitutions:[/] {measures['substitutions']:3d}")
    

## Errors

In [None]:
# Per-system inventory of word-level edit operations between each reference
# transcription and the hypothesis with the same id.
ref = open_head(REF_FILE)
MIN_COUNT = 0  # hide entries seen fewer than this many times
TOP_K = 10  # how many of the most frequent entries to print per operation
OPS = {idd: {} for idd in labels}


def _format_top(counter):
    """Render the TOP_K most common entries of *counter* (count >= MIN_COUNT) as rich markup."""
    return " ".join(
        f"[white bold]{k}[/]:[cyan]{c: 3d}[/]"
        for k, c in counter.most_common(TOP_K)
        if c >= MIN_COUNT
    )


for idd in labels:
    ops = {
        "delete": Counter(),  # reference words missing from the hypothesis
        "insert": Counter(),  # hypothesis words with no reference counterpart
        "replace": Counter(),  # "ref->hyp" word pairs
        "replace_R": Counter(),  # reference side of each replaced span
        "replace_H": Counter(),  # hypothesis side of each replaced span
    }
    name = label2name[idd]
    file = HYP_FILES[name]
    if idd in ["Go."]:
        # Ids in this file carry one extra character on each side (presumably
        # parentheses); [1:-1] makes them comparable to reference ids.
        hyp_by_id = {utt_id[1:-1]: trans for utt_id, trans in open_tail(file)}
    else:
        hyp_by_id = dict(open_head(file))
    # Pair by id rather than by line position; a missing id counts as an empty hypothesis.
    for utt_id, ref_trans in ref:
        t1 = ref_trans.split()
        t2 = hyp_by_id.get(utt_id, "").split()
        for op, ri, rf, hi, hf in Levenshtein.opcodes(t1, t2):
            if op == "delete":
                ops[op].update(t1[ri:rf])
            elif op == "insert":
                ops[op].update(t2[hi:hf])
            elif op == "replace":
                ops[op].update(f"{x}->{y}" for x, y in zip(t1[ri:rf], t2[hi:hf]))
                ops["replace_R"][" ".join(t1[ri:rf])] += 1
                ops["replace_H"][" ".join(t2[hi:hf])] += 1
    # "\\[" makes rich print a literal "[" instead of swallowing the bracketed
    # file name as an (unknown) markup tag.
    print(f"[magenta]{idd:6s} \\[{os.path.splitext(name)[0]}][/]:")
    for title, key in (("Delete", "delete"), ("Insert", "insert"), ("Replace", "replace")):
        print(f"[yellow]{title:<7}[/] ({sum(ops[key].values())}): ", _format_top(ops[key]))
    # replace_R / replace_H are not printed; they are kept for plot_error_frequencies.
    OPS[idd] = ops

                

In [None]:
#@markdown Control size of figure

# Named matplotlib dash patterns, one per plotted system. Each pattern is an
# (offset, (on, off, ...)) dash tuple accepted by the `linestyle` argument.
_DASH_PATTERNS = {
    "loosely dotted": (0, (1, 10)),
    "densely dotted": (0, (1, 1)),
    "dashed": (0, (5, 5)),
    "densely dashed": (0, (5, 1)),
    "densely dashdotted": (0, (3, 1, 1, 1)),
    "densely dashdotdotted": (0, (3, 1, 1, 1, 1, 1)),
}
# Ordered (name, pattern) pairs; the plotting code reads element [1].
linestyle_tuple = list(_DASH_PATTERNS.items())

# Labels left out of the error-frequency plot (edit to compare other systems).
PLOT_EXCLUDE = {"Go.", "QN", "W tn.", "W bs.", "W sm.", "W med."}


def plot_error_frequencies(op, x_sizefigure=10, y_sizefigure=4):
    """Plot, on log-log axes, how often each distinct error occurs versus its rank.

    Draws one line per system in ``OPS`` (skipping ``PLOT_EXCLUDE``). The legend
    shows ``label (distinct entries/total occurrences)`` for the chosen operation.

    param op: key of the per-system ``OPS`` counters ("delete", "insert",
        "replace", "replace_R" or "replace_H")
    param x_sizefigure: figure width in inches
    param y_sizefigure: figure height in inches
    """
    # Pass figsize to subplots itself: a separate plt.figure() call would leave
    # an empty sized figure behind and draw on a new default-size one.
    fig, ax = plt.subplots(figsize=(int(x_sizefigure), int(y_sizefigure)))
    ax.set_title(f"Error frequencies in '{op}' edit operation")
    ax.set_ylabel("Word error occurrences (log base 10)")
    ax.set_xlabel("Word error rank (log base 10)")

    palette = sns.color_palette("tab10")
    n_plotted = 0
    for idd, ops in OPS.items():
        if idd in PLOT_EXCLUDE:
            continue
        counter = ops[op]
        freq = [f for _, f in counter.most_common()]
        # Ranks start at 1: a rank of 0 cannot be placed on a log axis.
        ranks = range(1, len(freq) + 1)
        ax.loglog(
            ranks,
            freq,
            base=10,
            # Cycle colours/styles so plotting more systems than styles can't raise IndexError.
            color=palette[n_plotted % len(palette)],
            linewidth=1.5 + 0.3 * n_plotted,
            linestyle=linestyle_tuple[n_plotted % len(linestyle_tuple)][1],
            label=f"{idd} ({len(counter)}/{sum(counter.values())})",
        )
        n_plotted += 1
    if n_plotted:
        ax.legend()
    plt.show()

In [None]:
# Interactive error-frequency plot. The figure-size sliders get explicit
# (min, max) ranges: interact's automatic slider for a positive int default v
# spans [-v, 3v], which would offer negative figure sizes that matplotlib rejects.
interact(
    plot_error_frequencies,
    op=["delete", "insert", "replace", "replace_R", "replace_H"],
    x_sizefigure=(2, 30),
    y_sizefigure=(2, 12),
)