# Levenshtein distance and WER/CER

References:

- https://www.rev.ai/blog/how-to-calculate-word-error-rate/
- https://en.wikipedia.org/wiki/Levenshtein_distance

In [1]:
import typing as t
from collections import deque

import numpy as np


def levenshtein_distance(s: t.Sequence, t: t.Sequence):
    """Return the Levenshtein (edit) distance between two sequences.

    Fills the classic dynamic-programming table with unit costs for
    deletion, insertion and substitution, where ``table[i, j]`` is the
    distance between the first ``i`` items of ``s`` and the first ``j``
    items of ``t``.

    Args:
        s: Source sequence; its prefixes index the rows of the table.
        t: Target sequence; its prefixes index the columns of the table.

    Returns:
        A tuple ``(distance, table)``. ``table`` is a float NumPy array of
        shape ``(len(s) + 1, len(t) + 1)`` and ``distance`` is its
        bottom-right entry.
    """
    n_rows, n_cols = len(s) + 1, len(t) + 1
    table = np.zeros((n_rows, n_cols))

    # Border cells: reducing a prefix to the empty sequence costs one
    # deletion per item, and building a prefix from nothing costs one
    # insertion per item.
    table[:, 0] = np.arange(n_rows)
    table[0, :] = np.arange(n_cols)

    # Row-major fill: every cell only needs its upper, left and upper-left
    # neighbours, all of which are already computed in this order.
    for row in range(1, n_rows):
        src_item = s[row - 1]
        for col in range(1, n_cols):
            cost = 0 if src_item == t[col - 1] else 1
            table[row, col] = min(
                table[row - 1, col] + 1,  # drop src_item
                table[row, col - 1] + 1,  # insert t[col - 1]
                table[row - 1, col - 1] + cost,  # keep or substitute
            )

    return table[-1, -1], table


def levenshtein_distance_details(s1, s2, d, separator=" "):
    """Backtrack through a Levenshtein matrix and group the edit operations.

    Walks ``d`` from its bottom-right cell back to ``d[0, 0]`` and collapses
    consecutive steps that share a label into a single tuple:

    - ``("SAME", tokens)``: tokens common to ``s1`` and ``s2``
    - ``("DEL", tokens)``: tokens of ``s1`` that are absent from ``s2``
    - ``("INS", tokens)``: tokens of ``s2`` that are absent from ``s1``
    - ``("SUBS", s1_tokens, s2_tokens)``: tokens of ``s1`` replaced by
      tokens of ``s2``

    Args:
        s1: Source sequence of string tokens (indexes the rows of ``d``).
        s2: Target sequence of string tokens (indexes the columns of ``d``).
        d: Matrix of shape ``(len(s1) + 1, len(s2) + 1)`` such as the one
            returned by ``levenshtein_distance(s1, s2)``.
        separator: String used to join the grouped tokens in each tuple.

    Returns:
        A ``collections.deque`` of the tuples above, in left-to-right order.

    Raises:
        ValueError: If a cell of ``d`` cannot be explained by any edit step.
    """
    i = d.shape[0] - 1
    j = d.shape[1] - 1

    result = deque()
    prev_label = None
    prev_tokens_same = deque()
    prev_tokens_deleted = deque()
    prev_tokens_inserted = deque()

    SAME = "SAME"
    DELETED = "DEL"
    INSERTED = "INS"
    SUBSTITUTED = "SUBS"

    # Value used for neighbours that lie outside the matrix. Without it,
    # d[i - 1, ...] at i == 0 (or d[..., j - 1] at j == 0) silently wraps to
    # the last row/column under NumPy negative indexing and corrupts the path.
    out_of_bounds = float("inf")

    def prepend_prev_label():
        # Flush the tokens accumulated for the previous run as one tuple.
        if prev_label == SAME:
            result.appendleft((prev_label, separator.join(prev_tokens_same)))
            prev_tokens_same.clear()
        elif prev_label == SUBSTITUTED:
            result.appendleft(
                (
                    prev_label,
                    separator.join(prev_tokens_deleted),
                    separator.join(prev_tokens_inserted),
                )
            )
            prev_tokens_deleted.clear()
            prev_tokens_inserted.clear()
        elif prev_label == DELETED:
            result.appendleft((prev_label, separator.join(prev_tokens_deleted)))
            prev_tokens_deleted.clear()
        elif prev_label == INSERTED:
            result.appendleft((prev_label, separator.join(prev_tokens_inserted)))
            prev_tokens_inserted.clear()

    def check_and_process_label(current_label):
        # A change of label closes the previous run.
        if prev_label is None:
            return

        if prev_label != current_label:
            prepend_prev_label()

    while i > 0 or j > 0:
        v_current = d[i, j]
        v_deletion = d[i - 1, j] if i > 0 else out_of_bounds
        v_insertion = d[i, j - 1] if j > 0 else out_of_bounds
        v_diagonal = d[i - 1, j - 1] if i > 0 and j > 0 else out_of_bounds

        v_min = min(v_deletion, v_insertion, v_diagonal)

        # A match keeps the cost unchanged along the diagonal; requiring the
        # diagonal explicitly means this can never fire on the matrix border.
        if v_min == v_current == v_diagonal:
            check_and_process_label(SAME)
            prev_tokens_same.appendleft(s1[i - 1])
            prev_label = SAME
            i -= 1
            j -= 1
        elif v_min == v_diagonal:
            check_and_process_label(SUBSTITUTED)
            prev_tokens_deleted.appendleft(s1[i - 1])
            prev_tokens_inserted.appendleft(s2[j - 1])
            prev_label = SUBSTITUTED
            i -= 1
            j -= 1
        elif v_min == v_deletion:
            check_and_process_label(DELETED)
            prev_tokens_deleted.appendleft(s1[i - 1])
            prev_label = DELETED
            i -= 1
        elif v_min == v_insertion:
            check_and_process_label(INSERTED)
            prev_tokens_inserted.appendleft(s2[j - 1])
            prev_label = INSERTED
            j -= 1
        else:
            raise ValueError("Found a case not expected within d")

    prepend_prev_label()

    return result

def WER(ref_sentence: str, hyp_sentence: str, requires_details: bool = False):
    """Compute the word error rate of a hypothesis against a reference.

    WER is the word-level Levenshtein distance (substitutions + deletions +
    insertions) divided by the number of words in the reference.

    Args:
        ref_sentence: Reference text.
        hyp_sentence: Hypothesis text to score against the reference.
        requires_details: When True, also return the grouped edit operations
            computed by ``levenshtein_distance_details``.

    Returns:
        A tuple ``(wer, details)``; ``details`` is ``None`` unless
        ``requires_details`` is True.

    Raises:
        ValueError: If ``ref_sentence`` contains no words, since WER is
            undefined for an empty reference.
    """
    sep = " "
    # split() with no argument splits on runs of any whitespace and drops
    # empty strings, so repeated, leading or trailing spaces do not create
    # phantom "" words that would be counted as errors and in the denominator.
    ref_tokens = ref_sentence.split()
    hyp_tokens = hyp_sentence.split()
    if not ref_tokens:
        raise ValueError("WER is undefined for a reference with no words")
    val, d = levenshtein_distance(ref_tokens, hyp_tokens)
    return (
        val / len(ref_tokens),
        levenshtein_distance_details(ref_tokens, hyp_tokens, d, separator=sep)
        if requires_details
        else None,
    )

def CER(ref_word: str, hyp_word: str, requires_details: bool = False):
    """Compute the character error rate of a hypothesis against a reference.

    Whitespace is removed from both strings before comparison, so CER is the
    character-level Levenshtein distance divided by the number of
    non-whitespace characters in the reference.

    Args:
        ref_word: Reference text.
        hyp_word: Hypothesis text to score against the reference.
        requires_details: When True, also return the grouped edit operations
            computed by ``levenshtein_distance_details``.

    Returns:
        A tuple ``(cer, details)``; ``details`` is ``None`` unless
        ``requires_details`` is True.

    Raises:
        ValueError: If ``ref_word`` has no non-whitespace characters, since
            CER is undefined for an empty reference.
    """
    # "".join(s.split()) drops every whitespace character (spaces, tabs,
    # newlines), not only plain spaces.
    ref_tokens = "".join(ref_word.split())
    hyp_tokens = "".join(hyp_word.split())
    if not ref_tokens:
        raise ValueError("CER is undefined for a reference with no characters")
    val, d = levenshtein_distance(ref_tokens, hyp_tokens)
    return (
        val / len(ref_tokens),
        levenshtein_distance_details(ref_tokens, hyp_tokens, d, separator="")
        if requires_details
        else None,
    )

In [2]:
# Demo: s1 is passed to WER/CER as the reference, s2 as the hypothesis.
s1 = (
    "We wanted people to know that we've got something brand new "
    "and essentially this product is uh what we call disruptive "
    "changes the way that people interact with technology"
)
s2 = (
    "We wanted people to know that how to me where i know "
    "and essentially this product is what we call scripted "
    "changes the way people are rapid technology"
)

print(f"{WER(s1, s2, requires_details=True)=}\n")
print(f"{CER(s1, s2, requires_details=True)=}")

WER(s1, s2, requires_details=True)=(0.3793103448275862, deque([('SAME', 'We wanted people to know that'), ('INS', 'how'), ('SUBS', "we've got something brand new", 'to me where i know'), ('SAME', 'and essentially this product is'), ('DEL', 'uh'), ('SAME', 'what we call'), ('SUBS', 'disruptive', 'scripted'), ('SAME', 'changes the way'), ('DEL', 'that'), ('SAME', 'people'), ('SUBS', 'interact with', 'are rapid'), ('SAME', 'technology')]))

CER(s1, s2, requires_details=True)=(0.26573426573426573, deque([('SAME', 'Wewantedpeopletoknowthat'), ('DEL', "we've"), ('SUBS', 'g', 'h'), ('SAME', 'o'), ('SUBS', 'ts', 'wt'), ('SAME', 'ome'), ('SUBS', 't', 'w'), ('SAME', 'h'), ('DEL', 'ing'), ('SUBS', 'b', 'e'), ('SAME', 'r'), ('SUBS', 'and', 'eik'), ('SAME', 'n'), ('SUBS', 'e', 'o'), ('SAME', 'wandessentiallythisproductis'), ('DEL', 'uh'), ('SAME', 'whatwecall'), ('DEL', 'd'), ('SUBS', 'is', 'sc'), ('SAME', 'r'), ('SUBS', 'u', 'i'), ('SAME', 'pt'), ('DEL', 'i'), ('SUBS', 've', 'ed'), ('SAME', 'chang