<a href="https://colab.research.google.com/github/duckyngo/Word-Error-Rate-Visualization-with-Colab/blob/main/Manifest_STT_Calculate_WER.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import json
from IPython.display import HTML, display
try:
  import jiwer
except:
  !pip install jiwer
  import jiwer


def wer(ref, hyp, debug=False):
    """Compute the word error rate of ``hyp`` against ``ref`` and its edit breakdown.

    Both strings are split on whitespace and aligned with a Levenshtein-style
    dynamic programme in which substitutions, insertions and deletions all
    cost 1. The cheapest alignment is then traced back to count each
    operation.

    Args:
        ref: Reference (ground-truth) transcript.
        hyp: Hypothesis (predicted) transcript.
        debug: When true, print a colour-coded word-by-word alignment
            (built with ``colored``) and return its pieces.

    Returns:
        A tuple ``(stats, compares)``:

        * ``stats`` is a dict with keys ``'WER'`` (rounded to 3 decimals),
          ``'Cor'``, ``'Sub'``, ``'Ins'`` and ``'Del'``.
        * ``compares`` is the list of coloured alignment fragments in reading
          order; it is empty when ``debug`` is false.

        When the reference contains no words, ``'WER'`` is ``0.0`` if the
        hypothesis is empty too and ``float('inf')`` otherwise.
    """
    r = ref.split()
    h = hyp.split()
    n_ref = len(r)
    n_hyp = len(h)

    OP_OK = 0
    OP_SUB = 1
    OP_INS = 2
    OP_DEL = 3

    DEL_PENALTY = 1
    INS_PENALTY = 1
    SUB_PENALTY = 1

    # costs[i][j]: minimal edit cost turning the first i reference words into
    # the first j hypothesis words (as in the Levenshtein distance algorithm).
    costs = [[0] * (n_hyp + 1) for _ in range(n_ref + 1)]
    # backtrace[i][j]: the operation chosen to reach cell (i, j), so the best
    # route can be replayed afterwards.
    backtrace = [[OP_OK] * (n_hyp + 1) for _ in range(n_ref + 1)]

    # First column: reach zero hypothesis words by deleting every reference word.
    for i in range(1, n_ref + 1):
        costs[i][0] = DEL_PENALTY * i
        backtrace[i][0] = OP_DEL

    # First row: build the hypothesis by inserting every word into an empty reference.
    for j in range(1, n_hyp + 1):
        costs[0][j] = INS_PENALTY * j
        backtrace[0][j] = OP_INS

    for i in range(1, n_ref + 1):
        for j in range(1, n_hyp + 1):
            if r[i - 1] == h[j - 1]:
                costs[i][j] = costs[i - 1][j - 1]
                backtrace[i][j] = OP_OK
                continue

            substitution_cost = costs[i - 1][j - 1] + SUB_PENALTY
            insertion_cost = costs[i][j - 1] + INS_PENALTY
            deletion_cost = costs[i - 1][j] + DEL_PENALTY

            best = min(substitution_cost, insertion_cost, deletion_cost)
            costs[i][j] = best
            # Ties are resolved in the order substitution, insertion, deletion.
            if best == substitution_cost:
                backtrace[i][j] = OP_SUB
            elif best == insertion_cost:
                backtrace[i][j] = OP_INS
            else:
                backtrace[i][j] = OP_DEL

    # Walk back from the bottom-right corner along the best route.
    i = n_ref
    j = n_hyp
    num_sub = 0
    num_del = 0
    num_ins = 0
    num_cor = 0
    # Always bound, so the return statement is valid when debug is off.
    compares = []
    while i > 0 or j > 0:
        op = backtrace[i][j]
        if op == OP_OK:
            num_cor += 1
            i -= 1
            j -= 1
            if debug:
                compares.append(colored(0, 0, 0, h[j]))
        elif op == OP_SUB:
            num_sub += 1
            i -= 1
            j -= 1
            if debug:
                # Hypothesis word in green followed by the expected word in parentheses.
                compares.append(colored(0, 255, 0, h[j]) + colored(0, 0, 0, f'({r[i]})'))
        elif op == OP_INS:
            num_ins += 1
            j -= 1
            if debug:
                compares.append(colored(0, 0, 255, h[j]))
        elif op == OP_DEL:
            num_del += 1
            i -= 1
            if debug:
                compares.append(colored(255, 0, 0, r[i]))

    # Fragments were collected end-to-start; flip into reading order. A real
    # list (not a one-shot iterator) stays usable by the caller after printing.
    compares.reverse()
    if debug:
        for fragment in compares:
            print(fragment, end=" ")

    num_errors = num_sub + num_del + num_ins
    if n_ref:
        wer_result = round(num_errors / n_ref, 3)
    elif num_errors:
        # No reference words but a non-empty hypothesis: the rate is unbounded.
        wer_result = float('inf')
    else:
        wer_result = 0.0
    return {'WER': wer_result, 'Cor': num_cor, 'Sub': num_sub, 'Ins': num_ins, 'Del': num_del}, compares


def set_css(*_event_args):
    """Display a ``<style>`` element that makes ``<pre>`` output wrap long lines.

    Any positional arguments are accepted and ignored, so the function works
    both when called directly with no arguments and when used as an IPython
    event callback that is invoked with an execution-info object.
    """
    display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))

# Turn on line wrapping on Colab by running set_css before every cell execution.
# Ref: https://github.com/jupyter/notebook/issues/6274
get_ipython().events.register('pre_run_cell', set_css)


def colored(r, g, b, text):
    """Wrap ``text`` in a 24-bit ANSI foreground colour escape.

    The result is the colour escape for ``(r, g, b)``, then ``text``, one
    trailing space, and finally an escape that sets the foreground to white
    (255, 255, 255).
    """
    colour_on = f"\033[38;2;{r};{g};{b}m"
    back_to_white = "\033[38;2;255;255;255m"
    return f"{colour_on}{text} {back_to_white}"

def strike(text, color=None):
    """Decorate every character of ``text`` with a combining mark and colour it.

    When ``color`` is truthy, each character is prefixed with U+0336
    (combining long stroke overlay) and the result is coloured green;
    otherwise each character is prefixed with U+0332 (combining low line)
    and the result is coloured black. Only the truthiness of ``color`` is used.
    """
    # NOTE(review): the mark is emitted *before* each character; combining
    # marks normally follow their base character -- confirm this is intended.
    if not color:
        mark, rgb = u'\u0332', (0, 0, 0)
    else:
        mark, rgb = u'\u0336', (0, 255, 0)
    decorated = ''.join(mark + '{}'.format(ch) for ch in text)
    return colored(*rgb, decorated)

#@title Calculate and visualize WER or CER { run: "auto" }

rm_punctuation = True #@param {type:"boolean"}
input_ref='\uB2EC\uBCF4\uB2E4 \uC774 \uC810 \uCE60 \uD37C\uC13C\uD2B8 \uD3EC\uC778\uD2B8 \uB0AE\uC544 \uC84C\uC2B5\uB2C8\uB2E4 \uD314\uC2ED \uD37C\uC13C\uD2B8\uB97C \uB118\uC5B4 \uD3EC\uD654 \uC0C1\uD0DC\uC600\uB358 \uC218\uB3C4\uAD8C \uC911\uC9C4 \uBCD1\uC0C1 \uAC00\uB3D9\uB960\uB3C4 \uC721\uC2ED\uC774 \uC810 \uC0BC \uD37C\uC13C\uD2B8\uB85C \uB5A8\uC5B4\uC84C\uC2B5\uB2C8\uB2E4 \uC9C0\uB09C\uB2EC \uC0BC\uC2ED\uC77C \uCE60\uC2ED \uD37C\uC13C\uD2B8 \uC544\uB798\uB85C \uB5A8\uC5B4\uC9C4 \uB4A4 \uB098\uD758 \uC5F0\uC18D \uD558\uB77D \uC149\uB2C8\uB2E4 \uC9C0\uB09C\uB2EC \uC774\uC2ED \uAD6C\uC77C \uD558\uB8E8\uC5D0\uB9CC \uC624\uBC31 \uC5EC\uB4E0 \uBA85\uC0C1\uC744 \uD655\uCDA9\uD558\uB294 \uB4F1 \uBCD1\uC0C1 \uBD80\uC871 \uC0C1\uD669\uC774 \uB098\uC544\uC9C0\uBA74\uC11C \uC804\uAD6D\uC5D0\uC11C \uD558\uB8E8 \uB118\uAC8C \uC785\uC6D0\uC744 \uAE30\uB2E4\uB9AC\uB294 \uC0AC\uB78C\uC740 \uB2F7\uC0C8\uC9F8 \uD55C \uBA85\uB3C4 \uC5C6\uC5C8\uC2B5\uB2C8\uB2E4 \uBC29\uC5ED \uB2F9\uAD6D\uC740 \uC774\uB2EC \uB9D0\uAE4C\uC9C0 \uC911\uC99D \uD658\uC790 \uBCD1\uC0C1 \uCC9C \uC624\uBC31 \uC77C\uD754 \uC5EC\uB35F\uAC1C\uB97C \uBE44\uB86F\uD574 \uC721\uCC9C \uAD6C\uBC31 \uB9C8\uD754 \uB124\uAC1C\uC758 \uC785\uC6D0 \uBCD1\uC0C1\uC744 \uD655\uCDA9\uD574 \uD558\uB8E8 \uD655\uC9C4 \uB9CC \uBA85\uC5D0\uB3C4 \uB300\uC751\uD558\uACA0\uB2E4\uB294 \uBAA9\uD45C\uC785\uB2C8\uB2E4' #@param {type:"string"}
input_hyp='\uB2EC\uBCF4\uB2E4 \uC774 \uC810 \uCE60 \uD37C\uC13C\uD2B8 \uD3EC\uC778\uD2B8 \uB0AE\uC544\uC84C\uC2B5\uB2C8\uB2E4 \uD314\uC2ED \uD37C\uC13C\uD2B8\uB97C \uB118\uC5B4 \uD3EC\uD55C \uC0C1\uD0DC\uC600\uB358 \uC218\uB3C4\uAD8C \uC911\uC99D\uBCD1\uC0C1 \uAC00\uB3C5\uB960\uB3C4 \uC721\uC2ED \uC774 \uC810 \uC0BC \uD37C\uC13C\uD2B8\uB85C \uB5A8\uC5B4\uC84C\uC2B5\uB2C8\uB2E4 \uC9C0\uB09C \uB2EC \uC0BC\uC2ED \uC77C \uCE60\uC2ED \uD37C\uC13C\uD2B8 \uC544\uB798\uB85C \uB5A8\uC5B4\uC9C4 \uB4A4 \uB098\uD750\uC5F0 \uC18D \uD558\uB77D\uD230\uB2C8\uB2E4 \uC9C0\uB09C \uB2EC \uC774\uC2ED \uAD6C \uC77C \uD558\uB8E8\uC5D0\uB9CC \uC624\uBC31 \uC5EC\uB4E0 \uBCD1\uC0C1\uC744 \uD655\uCDA9\uD558\uB294 \uB4F1 \uBCD1\uC0C1 \uBD80\uC871 \uC0C1\uD669\uC774 \uB098\uC544\uC9C0\uBA74\uC11C \uC804\uAD6D\uC5D0\uC11C \uD558\uB8E8 \uB118\uAC8C \uC774 \uBC88\uC744 \uAE30\uB2E4\uB9AC\uB294 \uC0AC\uB78C\uC740 \uB2E4 \uC14B\uC9F8 \uD55C \uBA85\uB3C4 \uC5C6\uC5C8\uC2B5\uB2C8\uB2E4 \uBC29\uC5ED \uB2E8\uAD6D\uC740 \uC774 \uB2EC \uB9D0\uAE4C\uC9C0 \uC911\uC99D \uD658\uC790 \uBCD1\uC0C1 \uCC9C \uC624\uBC31 \uC77C\uD754 \uC5EC\uB35F \uAC1C\uB97C \uBE44\uC211\uD574 \uC721\uCC9C \uAD6C\uBC31 \uB9C8\uD754 \uB124 \uAC1C\uC5D0 \uC774\uBC88 \uBCD1\uC0C1\uC744 \uD655\uC911\uD574 \uD558\uB8E8 \uD655\uC9C4 \uB9CC\uBA85\uD574\uB3C4 \uB300\uC751\uD558\uACA0\uB2E4\uB294 \uBAA9\uD45C\uB2C8\uB2E4' #@param {type:"string"}
input_json = "" #@param {type:"string"}

# A JSON record, when supplied, overrides the reference/hypothesis form fields.
# An empty string is falsy, so no separate comparison against "" is needed.
if input_json:
  json_data = json.loads(input_json)
  input_ref = json_data['text']
  input_hyp = json_data['pred_text']


if rm_punctuation:
    # Build the transform once and apply it to both transcripts.
    strip_punctuation = jiwer.RemovePunctuation()
    ref = strip_punctuation(input_ref)
    hyp = strip_punctuation(input_hyp)
else:
    ref = input_ref
    hyp = input_hyp

print(f"REF: {ref}\n")
print(f"HYP: {hyp}")
print('-'* 30)

# debug=True makes wer() print the colour-coded word alignment.
output, compares = wer(ref, hyp, debug=True)

# Summary, using the same colours as the alignment printed above.
print()
print(colored(0, 0, 0,   f"N CORRECT   : {output['Cor']}"))
print(colored(255, 0, 0, f"N DELETE    : {output['Del']}"))
print(colored(0, 255, 0, f"N SUBSTITUTE: {output['Sub']}"))
print(colored(0, 0, 255, f"N INSERT    : {output['Ins']}"))
print(colored(0, 0, 0, f"WER: {output['WER']}"))

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jiwer
  Downloading jiwer-2.3.0-py3-none-any.whl (15 kB)
Collecting python-Levenshtein==0.12.2
  Downloading python-Levenshtein-0.12.2.tar.gz (50 kB)
[K     |████████████████████████████████| 50 kB 2.9 MB/s 
Building wheels for collected packages: python-Levenshtein
  Building wheel for python-Levenshtein (setup.py) ... [?25l[?25hdone
  Created wheel for python-Levenshtein: filename=python_Levenshtein-0.12.2-cp37-cp37m-linux_x86_64.whl size=149859 sha256=299e682680ddb5a13711436a0333496e8e9643e94a570a5794b69c4525742098
  Stored in directory: /root/.cache/pip/wheels/05/5f/ca/7c4367734892581bb5ff896f15027a932c551080b2abd3e00d
Successfully built python-Levenshtein
Installing collected packages: python-Levenshtein, jiwer
Successfully installed jiwer-2.3.0 python-Levenshtein-0.12.2
REF: 달보다 이 점 칠 퍼센트 포인트 낮아 졌습니다 팔십 퍼센트를 넘어 포화 상태였던 수도권 중진 병상 가동률도 육십이 점 삼 퍼센트로 떨어졌습니다 지난달 삼십일 칠십 퍼센트 

In [None]:
import json
try:
  import editdistance
except:
  !pip install editdistance
  import editdistance
from typing import List

def word_error_rate(hypotheses: List[str], references: List[str], use_cer=False) -> float:
    """
    Return the corpus-level error rate of ``hypotheses`` against ``references``.

    The edit distance of every hypothesis/reference pair is summed and divided
    by the total number of reference tokens. Tokens are whitespace-separated
    words, or single characters when ``use_cer`` is set.
    Args:
      hypotheses: list of predicted transcripts
      references: list of ground-truth transcripts, same length as hypotheses
      use_cer: bool, set True to score characters instead of words
    Returns:
      (float) total edit distance / total reference tokens, or ``inf`` when
      the references contain no tokens at all
    Raises:
      ValueError: if the two lists differ in length
    """
    n_hyp, n_ref = len(hypotheses), len(references)
    if n_hyp != n_ref:
        raise ValueError(
            "In word error rate calculation, hypotheses and reference"
            " lists must have the same number of elements. But I got:"
            "{0} and {1} correspondingly".format(n_hyp, n_ref)
        )

    def _tokens(text):
        # Characters for CER, whitespace-delimited words for WER.
        return list(text) if use_cer else text.split()

    total_edits = 0
    total_tokens = 0
    for hyp_text, ref_text in zip(hypotheses, references):
        hyp_tokens = _tokens(hyp_text)
        ref_tokens = _tokens(ref_text)
        total_tokens += len(ref_tokens)
        total_edits += editdistance.eval(hyp_tokens, ref_tokens)

    if total_tokens == 0:
        return float('inf')
    return 1.0 * total_edits / total_tokens


def move_dimension_to_the_front(tensor, dim_index):
    """Permute ``tensor`` so that axis ``dim_index`` comes first.

    The remaining axes keep their relative order. The new ordering is built
    from ``tensor.ndim`` and applied with ``tensor.permute``.
    """
    axes = list(range(tensor.ndim))
    before = axes[:dim_index]
    after = axes[dim_index + 1 :]
    return tensor.permute(dim_index, *before, *after)



#@title Calculate WER or CER for output transcript file { run: "auto" }
use_cer = False #@param {type:"boolean"}
manifest_path = "/content/1_kspon_eval_clean_output_character_20_202_ksponspeech.json" #@param {type: "string"}


# Parallel lists: reference transcript and model prediction for each manifest entry.
ground_truth_text = []
predicted_text = []
invalid_manifest = False

# The manifest is read as JSON Lines: one JSON object per line.
with open(manifest_path, 'r', encoding='utf-8') as f:
  for line in f:
      # Blank lines (e.g. a trailing newline) would make json.loads fail.
      if not line.strip():
          continue
      data = json.loads(line)

      if 'pred_text' not in data:
          invalid_manifest = True
          break
      ground_truth_text.append(data['text'])
      predicted_text.append(data['pred_text'])

# Test for invalid manifest supplied
if invalid_manifest:
    raise ValueError(
        f"Invalid manifest provided: {manifest_path} does not "
        f"contain value for `pred_text`."
    )



metric_name = 'CER' if use_cer else 'WER'
metric_value = word_error_rate(hypotheses=predicted_text, references=ground_truth_text, use_cer=use_cer)
# Label the value so it is clear which metric was computed.
print(f"{metric_name}: {metric_value}")

0.1945144407830975


In [None]:
# Re-score the transcripts loaded above at character level.
use_cer = True
metric_name = 'WER' if not use_cer else 'CER'
metric_value = word_error_rate(
    hypotheses=predicted_text,
    references=ground_truth_text,
    use_cer=use_cer,
)
print("CER: ", metric_value)

0.06909035362407394


In [None]:
# Running totals over the inspected slice of the manifest.
del_count = 0
sub_count = 0
insert_count = 0
total_wer = 0
num_samples = 0

for idx, (ref, hyp) in enumerate(zip(ground_truth_text[500: 900], predicted_text[500: 900])):
  print(colored(0, 0, 0, f"{idx} ---------------"))
  print(colored(0, 0, 0, f"REF: {ref}"))
  print(colored(0, 0, 0, f"HYP: {hyp}"))
  # debug=True prints the colour-coded alignment for this utterance.
  output, compares = wer(ref, hyp, debug=True)
  print("")
  metric_value = word_error_rate(hypotheses=[hyp], references=[ref], use_cer=True)
  print(colored(0, 0, 0, f"WER: {output['WER']}  | CER: {metric_value}"))
  print()

  # Accumulate deletions from 'Del' (not the count of correct words).
  del_count += output['Del']
  sub_count += output['Sub']
  insert_count += output['Ins']
  total_wer += output['WER']
  num_samples += 1

print(colored(255, 0, 0, f"N DELETE    : {del_count}"))
print(colored(0, 255, 0, f"N SUBSTITUTE: {sub_count}"))
print(colored(0, 0, 255, f"N INSERT    : {insert_count}"))
# Average over the utterances actually processed; the slice may yield fewer
# items than its nominal width, so a hard-coded divisor would be wrong.
mean_wer = total_wer / num_samples if num_samples else 0.0
print(colored(0, 0, 0, f"WER: {mean_wer}"))

[38;2;0;0;0m0 --------------- [38;2;255;255;255m
[38;2;0;0;0mREF: 얼 얼마 주는데 [38;2;255;255;255m
[38;2;0;0;0mHYP: 얼 얼마 줬는데 [38;2;255;255;255m
[38;2;0;0;0m얼 [38;2;255;255;255m [38;2;0;0;0m얼마 [38;2;255;255;255m [38;2;0;255;0m줬는데 [38;2;255;255;255m[38;2;0;0;0m(주는데) [38;2;255;255;255m 
[38;2;0;0;0mWER: 0.333  | CER: 0.125 [38;2;255;255;255m

[38;2;0;0;0m1 --------------- [38;2;255;255;255m
[38;2;0;0;0mREF: 이제 그냥 졸업 [38;2;255;255;255m
[38;2;0;0;0mHYP: 이제 그냥 졸업 [38;2;255;255;255m
[38;2;0;0;0m이제 [38;2;255;255;255m [38;2;0;0;0m그냥 [38;2;255;255;255m [38;2;0;0;0m졸업 [38;2;255;255;255m 
[38;2;0;0;0mWER: 0.0  | CER: 0.0 [38;2;255;255;255m

[38;2;0;0;0m2 --------------- [38;2;255;255;255m
[38;2;0;0;0mREF: 글썽거리면서 봤었지 [38;2;255;255;255m
[38;2;0;0;0mHYP: 글썽 거리면서 봤었지 [38;2;255;255;255m
[38;2;0;0;255m글썽 [38;2;255;255;255m [38;2;0;255;0m거리면서 [38;2;255;255;255m[38;2;0;0;0m(글썽거리면서) [38;2;255;255;255m [38;2;0;0;0m봤었지 [38;2;255;255;255m 
[38;2;0;0;0mWER: 1.0  | CER: 0