# Evaluation of speech transcriptions

This code reproduces the analysis for the paper __Anonymized__.

In [None]:
import os
import re
import jiwer
import Levenshtein
from collections import Counter
from rich.jupyter import print
from rich.console import Console
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import numpy as np
## Visualization
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.font_manager as fm
import seaborn as sns

# ❶ Prepare _Reference_ transcriptions

In [None]:
# Prepare directory: create ../data/test, where the reference transcriptions
# are expected, unless it already exists.

DIR_TEST_FILES=os.path.join('..','data','test')

print(f"About to create directory: [green]{DIR_TEST_FILES}[/]")

# Reuse DIR_TEST_FILES instead of repeating the path literal; exist_ok makes
# re-running the cell harmless, while a non-directory at that path still raises.
os.makedirs(DIR_TEST_FILES, exist_ok=True)

## Download file with transcriptions

Open the following link: https://mega.nz/folder/0shhAaQT#KaoWJ7XOVjDu_2k_JYzyoA/file/FpZgDQqR

Download the file _CIEMPIESS_TEST.transcriptions_ into the folder __data/test__.

In [None]:
# Rename the downloaded reference file to the name used by the rest of the
# notebook (CIEMPIESS_TEST.trans), then confirm that it is in place.

downloaded_ref = os.path.join(DIR_TEST_FILES,'CIEMPIESS_TEST.transcriptions')
expected_ref = os.path.join(DIR_TEST_FILES,'CIEMPIESS_TEST.trans')

if os.path.exists(downloaded_ref):
    # os.replace overwrites an existing target on every platform
    # (os.rename raises on Windows when the target exists).
    os.replace(downloaded_ref, expected_ref)
    print(f"{downloaded_ref} -> {expected_ref}")
elif not os.path.exists(expected_ref):
    # Only complain about the missing download when the renamed file is not
    # already there (e.g. when the cell is re-run).
    print(f"[red]File {downloaded_ref} not found[/]")

# Test the file on disk, not the (always non-empty) path string.
if not os.path.exists(expected_ref):
    print(f"[bold bright_red]✗ Transcription file not found: {expected_ref} it is not possible to continue [/]")
else:
    print(f"[bold bright_green]✓ Reference transcription file in place[/]")

# ❷ Checking Hypothesis transcriptions

In [None]:
# Patterns that split one transcription line into an utterance id and its text.
# Both expect exactly one space between the id token and the transcription.
_HEAD_PATTERN = r"^(?P<id>[^ ]+) (?P<trans>.*)"  # "<id> <transcription>"
_TAIL_PATTERN = r"(?P<trans>.*) (?P<id>[^ ]+)$"  # "<transcription> <id>"
re_id_head, re_id_tail = (re.compile(pattern) for pattern in (_HEAD_PATTERN, _TAIL_PATTERN))

In [None]:
# Set the files and directories

# Configure this
REF_FILE = "../data/test/CIEMPIESS_TEST.trans"
HYP_DIR = "../data/output/exp_november2022"

# ---
# Every entry of HYP_DIR, keyed by its bare filename and mapped to its path.
HYP_FILES = dict((entry, os.path.join(HYP_DIR, entry)) for entry in os.listdir(HYP_DIR))
print(
    f"[green bold]Files to analyse[/] in dir [{HYP_DIR}]:\n ",
    "\n  ".join(HYP_FILES),
)

In [None]:
# labels to filenames (inspect files from previous output)

# Configure this
# Display order of the systems in every table and plot.
labels = [
    "Go.", "QN", "W2V",
    "W tn.", "W bs.", "W sm.", "W med.", "W lg.",
    "QN FT", "W2V FT",
]
# Short label -> hypothesis filename (a key of HYP_FILES).
label2name = {
    "W med.": "whisper_medium_ciempiess_test.trans",
    "W sm.": "whisper_small_ciempiess_test.trans",
    "W tn.": "whisper_tiny_ciempiess_test.trans",
    "Go.": "google_ciempiess_test.trans",
    "W2V FT": "wav2vec_carlos_ciempiess_test.trans",
    "W lg.": "whisper_large_ciempiess_test.trans",
    "W2V": "wav2vec_jonatasgrosman_ciempiess_test.trans",
    "QN FT": "nemo_carlos_ciempiess_test.trans",
    "QN": "nemo_nvidia_ciempiess_test.trans",
    "W bs.": "whisper_base_ciempiess_test.trans",
}

# ---
# Reverse lookup: hypothesis filename -> short label.
name2label = dict(zip(label2name.values(), label2name.keys()))

## Auxiliary functions

In [122]:
# Functions to open files
def open_head(filename):
    """Read a transcription file whose lines start with an utterance id.

    Each non-blank line is parsed with ``re_id_head`` as ``<id> <transcription>``.
    A non-blank line without a space-separated transcription is returned as
    ``(line, "")`` instead of failing. Blank lines are skipped.

    param filename: path of the file to read
    returns: list of ``(id, transcription)`` tuples, in file order
    """
    lines = []
    # NOTE(review): assumes the transcription files are UTF-8 encoded — confirm.
    with open(filename, encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            match = re_id_head.match(line)
            if match is None:
                # Only an id on the line: treat it as an empty transcription
                # rather than raising AttributeError on None.
                lines.append((line, ""))
            else:
                lines.append((match['id'], match['trans']))
    return lines

def open_tail(filename):
    """Read a transcription file whose lines end with an utterance id.

    Each non-blank line is parsed with ``re_id_tail`` as ``<transcription> <id>``.
    A non-blank line without a space-separated transcription is returned as
    ``(line, "")`` instead of failing. Blank lines are skipped.

    param filename: path of the file to read
    returns: list of ``(id, transcription)`` tuples, in file order
    """
    lines = []
    # NOTE(review): assumes the transcription files are UTF-8 encoded — confirm.
    with open(filename, encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            match = re_id_tail.match(line)
            if match is None:
                # Only an id on the line: treat it as an empty transcription
                # rather than raising AttributeError on None.
                lines.append((line, ""))
            else:
                lines.append((match['id'], match['trans']))
    return lines

def open_transcription_file(filename,format=None):
    """Open a transcription file.

    param filename: path of the file to read
    param format: how ids are laid out in each line
        - None: no parsing; every raw line (trailing newline kept) is paired
          with the id None
        - "head": delegate to open_head (id is the first token)
        - "tail": delegate to open_tail (id is the last token)
    returns: list of (id, transcription) tuples
    raises ValueError: if format is not None, "head" or "tail"
    """
    if format is None:
        # NOTE(review): assumes the transcription files are UTF-8 encoded — confirm.
        with open(filename, encoding="utf-8") as f:
            return [(None, line) for line in f]
    if format == "head":
        return open_head(filename)
    if format == "tail":
        return open_tail(filename)
    # Fail fast instead of silently returning None to the caller.
    raise ValueError(f"Unknown transcription format {format!r}; expected None, 'head' or 'tail'")
    
def normalize_counts(most_common,voca,min_count=3):
    """Rank error counts by their frequency relative to a vocabulary.

    param most_common: Counter of error occurrences; keys are either words or
        (word, word) pairs (the key type is detected from the top entry)
    param voca: Counter of word frequencies (e.g. of the reference transcriptions)
    param min_count: entries seen fewer than this many times are dropped
    returns: list of (key, count, normalized) tuples sorted by ``normalized``,
        highest first. ``normalized`` is count / frequency of the word in voca.
        For pairs, the smaller frequency of the two words is used. Words
        missing from voca count as frequency 1. An empty Counter gives [].
    """
    if not most_common:
        # Nothing to rank (e.g. a system without errors of this kind); probing
        # the key type below would raise IndexError.
        return []
    entries = most_common.most_common()
    if isinstance(entries[0][0], tuple):
        vals = [((k1, k2), c, c / min(voca.get(k1, 1), voca.get(k2, 1)))
                for (k1, k2), c in entries if c >= min_count]
    else:
        vals = [(k, c, c / voca.get(k, 1)) for k, c in entries if c >= min_count]
    vals.sort(key=lambda entry: entry[2], reverse=True)
    return vals

# ❸ Check against reference transcriptions

In [123]:
# Check the number of transcriptions and the ids of every system against the reference
ref=open_transcription_file(REF_FILE,format="head")
print(f"[green]Reference number of lines:[/] [bold] {len(ref)} [/]")
ids={utt_id for utt_id,_ in ref}
if len(ref) == len(ids):
    print(f"[green]✓ [magenta]Reference [green]has different ids per transcription[/]")
else:
    print(f"[red]✗ Reference has repeated ids[/]")

# Reference vocabulary: word -> number of occurrences over all reference
# transcriptions (whitespace tokenisation).
VOCA=Counter()
for _,l in ref:
    VOCA.update(l.split())

for idd in labels:
    name=label2name[idd]
    file=HYP_FILES[name]
    if idd in ["Go."]:
        # The Google output keeps the id as the last token, with one extra
        # character on each side (presumably brackets — confirm); strip them
        # so the ids compare with the reference ones.
        hyp=open_transcription_file(file,format="tail")
        ids_={utt_id[1:-1] for utt_id,_ in hyp}
    else:
        hyp=open_transcription_file(file,format="head")
        ids_={utt_id for utt_id,_ in hyp}
    if len(ref) == len(hyp):
        print(f"[green]✓ [magenta]{name} [green]has rigth number of transcriptions[/]")
    else:
        print(f"[red]✗ {name} has wrong number of transcriptions[/]")
    if ids.issubset(ids_):
        print(f"[green]✓ [magenta]{name} [green]has all ids[/]")
    else:
        print(f"[red]✗ {name} does not have all ids[/]")

# ❹ Evaluation with _jiwer_

In [124]:
# Compare reference with hypothesis files

ref=open_head(REF_FILE)
# The reference texts are the same for every system: build the list once.
ref_texts=[trans for _,trans in ref]
for idd in labels:
    name=label2name[idd]
    file=HYP_FILES[name]
    # The Google output has the id at the end of each line; the others at the start.
    hyp=open_tail(file) if idd in ["Go."] else open_head(file)
    # NOTE(review): reference and hypothesis are paired by position, which
    # assumes every hypothesis file lists utterances in reference order; the
    # consistency checks only compare counts and id sets — confirm.
    hyp_texts=[trans for _,trans in hyp]
    measures = jiwer.compute_measures(ref_texts, hyp_texts)
    wer = measures['wer']
    cer = jiwer.cer(ref_texts, hyp_texts)
    # name[:-6] drops the ".trans" extension.
    print(f"[magenta]{idd:6s} [{name[:-6]}][/]:")
    print(f"[green]  wer:[/] {wer*100:3.2f}")
    print(f"[green]  cer:[/] {cer*100:3.2f}")
    print(f"[green]  hits:[/] {measures['hits']:3d}")
    print(f"[green]  deletions:[/] {measures['deletions']:3d}")
    print(f"[green]  insertions:[/] {measures['insertions']:3d}")
    print(f"[green]  substitutions:[/] {measures['substitutions']:3d}")
    

## Error analysis

In [125]:
ref=open_head(REF_FILE)
MIN_COUNT=5     # minimum occurrences for an error to be listed
PRINT_COUNT=30  # maximum number of entries printed per operation
OPS={idd:{} for idd in labels}


def _format_top(counter, fmt_key=str):
    """Format the PRINT_COUNT top entries of normalize_counts(counter) as rich markup.

    Each entry is rendered as ``key:normalized,count``; fmt_key turns the
    Counter key into the displayed text.
    """
    return " ".join(
        f"[white bold]{fmt_key(k)}[/]:[cyan]{n:6.4f},{c}[/]"
        for k, c, n in normalize_counts(counter, VOCA, min_count=MIN_COUNT)[:PRINT_COUNT]
    )


# Word-level edit operations returned by Levenshtein.opcodes(hyp_words, ref_words),
# i.e. the edits that turn each hypothesis into its reference:
#   'delete'    -> hypothesis words absent from the reference
#   'insert'    -> reference words missing from the hypothesis
#   'replace'   -> (hypothesis word, reference word) substitution pairs
#   'replace_H' -> hypothesis side of the substitutions
#   'replace_R' -> reference side of the substitutions
# NOTE(review): this hyp->ref direction is the opposite of the usual ASR naming
# (jiwer's "deletions" are reference words missing from the hypothesis) — confirm.
for idd in labels:
    ops={
        'delete':Counter(),
        'insert':Counter(),
        'replace':Counter(),
        'replace_R':Counter(),
        'replace_H':Counter(),
        }
    name=label2name[idd]
    file=HYP_FILES[name]
    hyp=open_tail(file) if idd in ["Go."] else open_head(file)
    # Utterances are paired by position with the reference.
    for hyp_words, ref_words in zip([trans.split() for _,trans in hyp], [trans.split() for _,trans in ref]):
        for op, hi, hf, ri, rf in Levenshtein.opcodes(hyp_words, ref_words):
            if op == "delete":
                ops[op].update(hyp_words[hi:hf])
            elif op == "insert":
                ops[op].update(ref_words[ri:rf])
            elif op == "replace":
                pairs = list(zip(hyp_words[hi:hf], ref_words[ri:rf]))
                ops[op].update(pairs)
                ops['replace_H'].update(h for h, _ in pairs)
                ops['replace_R'].update(r for _, r in pairs)
    OPS[idd]=ops
    print(f"[magenta]{idd:6s} [{name[:-6]}][/]:")
    print(f"[yellow]Delete[/]  ({sum(ops['delete'].values())}): ", _format_top(ops['delete']))
    print(f"[yellow]Insert[/]  ({sum(ops['insert'].values())}): ", _format_top(ops['insert']))
    print(f"[yellow]Replace[/] ({sum(ops['replace'].values())}): ", _format_top(ops['replace'], lambda k: f"{k[0]}->{k[1]}"))
    # Each line's total and its listed entries come from the same counter.
    print(f"[yellow]Replace_H[/]  ({sum(ops['replace_H'].values())}): ", _format_top(ops['replace_H']))
    print(f"[yellow]Replace_R[/]  ({sum(ops['replace_R'].values())}): ", _format_top(ops['replace_R']))

## Plot frequency of errors

In [126]:
#@markdown Control size of figure
# Named matplotlib dash patterns, (name, (offset, on/off sequence)), one per
# plotted curve; the curve counter indexes this list and reads entry [1].
linestyle_tuple = [
    ("loosely dotted", (0, (1, 3))),
    ("densely dotted", (0, (1, 1))),
    ("long dash with offset", (5, (10, 3))),
    ("loosely dashed", (0, (5, 10))),
    ("dashed", (0, (5, 5))),
    ("densely dashed", (0, (5, 1))),
    ("loosely dashdotted", (0, (3, 10, 1, 10))),
    ("dashdotted", (0, (3, 5, 1, 5))),
    ("densely dashdotted", (0, (3, 1, 1, 1))),
    ("densely dashdotdotted", (0, (3, 1, 1, 1, 1, 1))),
]

def plot_error_frequencies(op,x_sizefigure=10,y_sizefigure=4):
    """Plot, per system, the rank/frequency curve of one kind of edit error.

    For every system in the module-level OPS, the counts of ``ops[op]`` are
    taken in decreasing order and drawn against their rank (starting at 1)
    on log-log axes.

    param op: key of the edit-operation Counter to plot ('delete', 'insert',
        'replace', 'replace_R' or 'replace_H')
    param x_sizefigure, y_sizefigure: figure size in inches (cast to int)
    """
    # Systems to leave out (e.g. "Go.", "QN", ...). A skipped system still
    # consumes its colour, so each system keeps the same colour.
    excluded = []
    # Create the sized figure and its axes together; a separate plt.figure()
    # followed by plt.subplots() leaves an empty sized figure behind.
    fig, ax = plt.subplots(figsize=(int(x_sizefigure),int(y_sizefigure)))
    plt.title(f"Error frequencies in '{op}' edit operation")
    plt.ylabel("Word error occurrences (log base 10)")
    plt.xlabel("Word error rank (log base 10)")

    palette = sns.color_palette('colorblind')
    c = 0  # index of the drawn curve (line style and width)
    labels_ = []
    for col, (idd, ops) in enumerate(OPS.items()):
        if idd in excluded:
            continue
        freq = [f for _, f in ops[op].most_common()]
        # Ranks start at 1: rank 0 cannot be represented on a log axis.
        index = range(1, len(freq) + 1)
        ax.loglog(
            index,
            freq,
            base=10,
            color=palette[col % len(palette)],
            linewidth=1.5+(0.3)*c,
            linestyle=linestyle_tuple[c % len(linestyle_tuple)][1],
        )
        labels_.append(idd)
        c += 1
    ax.set_xscale("log", base=10)
    ax.set_yscale("log", base=10)
    # Legend entry: label (distinct errors/total occurrences).
    plt.legend(labels=[f"{l} ({len(OPS[l][op])}/{sum(OPS[l][op].values())})" for l in labels_])
    plt.show()

In [127]:
# Interactive view: one selectable entry per edit-operation Counter stored in OPS.
interact(
    plot_error_frequencies,
    op=["delete", "insert", "replace", "replace_R", "replace_H"],
)

In [132]:
def plot_normalized_frequencies(op,x_sizefigure=10,y_sizefigure=4):
    """Plot, per system, normalized error frequencies against their rank.

    For every system in the module-level OPS, ``ops[op]`` is passed through
    ``normalize_counts(..., VOCA, min_count=0)``. The resulting normalized
    values, already sorted in decreasing order, are drawn against their rank
    (starting at 1) on log-log axes.

    param op: key of the edit-operation Counter to plot ('delete', 'insert',
        'replace', 'replace_R' or 'replace_H')
    param x_sizefigure, y_sizefigure: figure size in inches (cast to int)
    """
    # Systems to leave out (e.g. "Go.", "QN", ...). A skipped system still
    # consumes its colour, so each system keeps the same colour.
    excluded = []
    # Create the sized figure and its axes together; a separate plt.figure()
    # followed by plt.subplots() leaves an empty sized figure behind.
    fig, ax = plt.subplots(figsize=(int(x_sizefigure),int(y_sizefigure)))
    plt.title(f"Error normalize frequencies in '{op}' edit operation")
    plt.ylabel("Word normalized error (log base 10)")
    plt.xlabel("Word error rank (log base 10)")

    palette = sns.color_palette('colorblind')
    c = 0  # index of the drawn curve (line style and width)
    labels_ = []
    for col, (idd, ops) in enumerate(OPS.items()):
        if idd in excluded:
            continue
        counter = ops[op]
        # Guard: a system with no errors of this kind has nothing to normalize.
        freq = [n for _, _, n in normalize_counts(counter, VOCA, min_count=0)] if counter else []
        # Ranks start at 1: rank 0 cannot be represented on a log axis.
        index = range(1, len(freq) + 1)
        ax.loglog(
            index,
            freq,
            base=10,
            color=palette[col % len(palette)],
            linewidth=1.5+(0.3)*c,
            linestyle=linestyle_tuple[c % len(linestyle_tuple)][1],
        )
        labels_.append(idd)
        c += 1
    ax.set_xscale("log", base=10)
    ax.set_yscale("log", base=10)
    # With min_count=0 every entry of the Counter is kept, so the number of
    # plotted points is simply len(OPS[l][op]).
    plt.legend(labels=[f"{l} ({len(OPS[l][op])}/{sum(OPS[l][op].values())})" for l in labels_])
    plt.show()

In [133]:
# Interactive view: one selectable entry per edit-operation Counter stored in OPS.
interact(
    plot_normalized_frequencies,
    op=["delete", "insert", "replace", "replace_R", "replace_H"],
)

## Replace analysis

In [134]:
# One histogram per system of the character-level Levenshtein distance
# between the two words of each substitution pair in OPS[...]['replace'].
excluded = []  # systems to leave out of the figure
systems = [(idd, ops) for idd, ops in OPS.items() if idd not in excluded]
# One panel per plotted system; squeeze=False keeps `ax` indexable even for
# a single panel.
fig, ax = plt.subplots(1, len(systems), figsize=(15,2), squeeze=False)
ax = ax[0]
fig.suptitle("Number of edits in replacement operations")
palette = sns.color_palette('colorblind')
labels_ = []
for c, (idd, ops) in enumerate(systems):
    # NOTE(review): each distinct (word, word) pair is counted once,
    # regardless of how often it occurred — confirm this is intended.
    # Levenshtein.distance equals the length of the minimal editops sequence.
    hist = [Levenshtein.distance(w1, w2) for (w1, w2), _ in ops['replace'].most_common()]
    labels_.append(idd)

    ax[c].hist(hist, color=palette[c % len(palette)])
    ax[c].set_xlim(right=20, left=1)
    ax[c].set_ylim(top=8000)
    ax[c].set_xlabel(idd)
    ax[c].text(10, 7000, f"{sum(hist):,d}")

    # Only the first panel keeps its y tick labels.
    if c != 0:
        ax[c].set_yticks([])

plt.show()

## Length of words in error analysis

In [135]:
def plot_histograms_lenghts(op,top=7000,right=20,x_sizefigure=15,y_sizefigure=2):
    """Plot, per system, a histogram of the lengths of the words in ``ops[op]``.

    Each distinct word of the module-level ``OPS[label][op]`` Counter is
    counted once (its frequency is ignored). All panels share the same
    integer-aligned bins, x-range and y-limit so they can be compared side
    by side. Each panel is annotated with its median word length.

    param op: word-keyed Counter to use ('delete', 'insert', 'replace_R' or 'replace_H')
    param top, right: exposed as interactive widgets but not used by the plot
    param x_sizefigure, y_sizefigure: figure size in inches
    """
    excluded = []  # systems to leave out of the figure
    lengths = {
        idd: [len(w) for w, _ in ops[op].most_common()]
        for idd, ops in OPS.items()
        if idd not in excluded
    }
    # Longest word over all plotted systems; default=1 keeps max() from
    # failing when every Counter is empty.
    max_length = max((max(h) for h in lengths.values() if h), default=1)
    # One bin per integer length, centred on that integer, shared by all
    # panels. The y-limit below is computed with the very same bins.
    bins = np.arange(0.5, max_length + 1.5)
    max_ = max((np.histogram(h, bins=bins)[0].max() for h in lengths.values() if h), default=1)

    # squeeze=False keeps `ax` indexable even for a single panel.
    fig, ax = plt.subplots(1, len(lengths), figsize=(x_sizefigure,y_sizefigure), squeeze=False)
    ax = ax[0]
    fig.suptitle(f"Lenght of words in '{op}' operation")
    palette = sns.color_palette('colorblind')
    for c, (idd, hist) in enumerate(lengths.items()):
        ax[c].hist(hist, color=palette[c % len(palette)], bins=bins)
        ax[c].set_xlim(left=0.5, right=max_length + 0.5)
        ax[c].set_ylim(top=int(max_*1.1))
        ax[c].set_xlabel(idd)
        if hist:
            # Axes-relative position, so the label stays inside the panel
            # whatever the longest word is.
            ax[c].text(0.55, 0.9, f"Md: {int(np.median(hist)):d}", transform=ax[c].transAxes)
        # Only the first panel keeps its y tick labels.
        if c != 0:
            ax[c].set_yticks([])

    plt.show()
    
    

In [136]:
# Interactive view over the word-keyed Counters in OPS (the tuple-keyed
# 'replace' Counter is not offered here).
interact(
    plot_histograms_lenghts,
    op=["delete", "insert", "replace_R", "replace_H"],
)