In [1]:
%load_ext autoreload
%autoreload 2

In [32]:
from typing import List, Tuple, Set
import difflib
import Levenshtein
import re

In [168]:
# Toy example: a model prediction, the ground-truth target and the corrupted input.
# Each string is also wrapped in a one-element list for the batch-style f1_prec_rec call.
test_pred = "The cute act eats delicate fi sh."
test_predicted = [test_pred]
test_tgt = "The cute cat eats delicious fish"
test_target = [test_tgt]
test_ipt = "Te cute cteats delicious fi sh."
test_inputs = [test_ipt]

In [169]:
# Hand-written expected word-index sets (0-based) for the toy example; they are
# compared against the computed sets by the assertions further down.
misspelled = {0, 2, 3, 5}  # expected get_edited_indices(test_ipt, test_tgt)
restored = {0, 1, 3}  # expected match_tokens(test_pred, test_tgt)
changed = {0, 2, 3}  # expected get_edited_indices(test_pred, test_ipt)
correct = {0, 1}  # NOTE(review): not referenced by any visible cell - confirm intended use

In [176]:
def find_word_boundaries(s: str) -> List[Tuple[int, int]]:
    """Return the character span of every whitespace-delimited word in ``s``.

    Args:
        s: Arbitrary text; leading, trailing and repeated whitespace is allowed.

    Returns:
        One ``(start, end)`` tuple per word, in order of appearance, where
        ``s[start:end]`` is the word (``end`` is exclusive). An empty or
        whitespace-only string yields an empty list.

    Raises:
        AssertionError: If the number of spans differs from ``len(s.split())``.
    """
    # Raw string: "\S" inside a normal literal is an invalid escape sequence
    # (DeprecationWarning since Python 3.6, SyntaxWarning since 3.12).
    word_boundary_pattern = re.compile(r"\S+")
    matches = [match.span() for match in word_boundary_pattern.finditer(s)]
    # Sanity check: the spans must line up one-to-one with str.split() tokens,
    # because callers use the span position as a token index.
    assert len(matches) == len(s.split())
    return matches

def is_transposition(edit_op, next_edit_op) -> bool:
    """Tell whether two consecutive edit operations look like a character swap.

    Both arguments are ``(op_code, source_index, target_index)`` triples;
    ``next_edit_op`` may be ``None`` when ``edit_op`` is the last operation.

    The pair counts as a transposition when both operations are ``"replace"``
    and the second one sits exactly one position after the first in both the
    source and the target. The characters themselves are not inspected, so any
    two adjacent replacements satisfy this test.
    """
    if next_edit_op is None:
        return False

    first_kind, first_src, first_tgt = edit_op
    second_kind, second_src, second_tgt = next_edit_op

    # Guard clauses mirror the short-circuit order of a single `and` chain.
    if not (first_kind == "replace" and second_kind == "replace"):
        return False
    if second_src - first_src != 1:
        return False
    return second_tgt - first_tgt == 1

def get_edited_indices(ipt: str, tgt: str) -> Set[int]:
    """Collect the indices of the words in ``tgt`` touched by editing ``ipt`` into ``tgt``.

    The character-level edit script from ``Levenshtein.editops(ipt, tgt)`` is
    walked once; every operation is attributed to a whitespace-delimited word
    of ``tgt`` through its target index.

    Args:
        ipt: Source string; must not have leading or trailing whitespace.
        tgt: Destination string; must not have leading or trailing whitespace.

    Returns:
        Set of 0-based word indices of ``tgt`` that at least one edit
        operation maps to.

    Raises:
        AssertionError: If either string has surrounding whitespace, or if an
            inserted space is attributed to the last word of ``tgt``.
    """
    assert tgt.strip() == tgt and ipt.strip() == ipt, "the two strings must not contain leading or trailing whitespaces"
    boundaries = find_word_boundaries(tgt)
    num_words = len(boundaries)
    ops = Levenshtein.editops(ipt, tgt)
    num_ops = len(ops)

    touched = set()
    pos = 0
    while pos < num_ops:
        current = ops[pos]
        kind, _, tgt_pos = current

        # First word whose exclusive end is >= the target position: a position
        # sitting directly on a word's end still counts as that word. Falls
        # back to num_words when no word qualifies.
        word_idx = next(
            (k for k, (_, end) in enumerate(boundaries) if tgt_pos <= end),
            num_words,
        )

        following = ops[pos + 1] if pos + 1 < num_ops else None

        # editops only reports replace/insert/delete, so transpositions are
        # recognised here as two adjacent replacements and consumed together.
        if is_transposition(current, following):
            _, _, following_tgt_pos = following
            touched.add(word_idx)
            # Swapping with a space moves the word break, which affects the
            # word on either side of it.
            if tgt[tgt_pos] == " " or tgt[following_tgt_pos] == " ":
                touched.add(word_idx + 1)
            pos += 2
            continue

        if kind == "insert" and tgt[tgt_pos] == " ":
            # An inserted space splits one input token into two target words.
            assert word_idx < num_words - 1
            touched.add(word_idx)
            touched.add(word_idx + 1)
        else:
            touched.add(word_idx)
        pos += 1

    return touched

def match_tokens(ipt: str, tgt: str) -> Set[int]:
    """Return the indices of the tokens in ``tgt`` that are matched by tokens of ``ipt``.

    Both strings are split on whitespace and aligned with
    ``difflib.SequenceMatcher``; every target position covered by a matching
    block is reported.

    Args:
        ipt: String whose tokens are matched against the target.
        tgt: Reference string; the returned indices refer to its tokens.

    Returns:
        Set of 0-based token indices of ``tgt`` that lie inside a matching block.
    """
    # autojunk=False: the default heuristic treats frequent tokens as junk once
    # the second sequence has 200 or more items, which would silently drop
    # genuine matches on long inputs.
    sm = difflib.SequenceMatcher(a=ipt.split(), b=tgt.split(), autojunk=False)
    # The zero-size sentinel block at the end contributes an empty range.
    return {
        idx
        for block in sm.get_matching_blocks()
        for idx in range(block.b, block.b + block.size)
    }

# Check the helpers against the hand-written expectations for the toy example.

# Target words touched by the edit script from the corrupted input to the target.
edited_in_tgt = get_edited_indices(test_ipt, test_tgt)
assert misspelled == edited_in_tgt, (misspelled, edited_in_tgt)

# Input words touched by the edit script from the prediction to the input
# (here the input plays the role of the "target" argument).
edited_in_ipt = get_edited_indices(test_pred, test_ipt)
assert changed == edited_in_ipt, (changed, edited_in_ipt)

# Target tokens that the prediction reproduces according to the token alignment.
matching_in_tgt = match_tokens(test_pred, test_tgt)
assert restored == matching_in_tgt, (restored, matching_in_tgt)

In [173]:
# Inspect the raw token alignment between prediction and target.
# The last block difflib reports is a zero-size sentinel, not a real match.
sm = difflib.SequenceMatcher(a=test_pred.split(), b=test_tgt.split())
sm.get_matching_blocks()

[Match(a=0, b=0, size=2), Match(a=3, b=3, size=1), Match(a=7, b=6, size=0)]

In [9]:
def f1_prec_rec(predicted_sequences: List[str], target_sequences: List[str], input_sequences: List[str]) -> Tuple[float, float, float]:
    """Accumulate true-positive, false-positive and false-negative counts over a batch.

    The three lists are consumed in parallel with ``zip``, so iteration stops
    at the shortest one.

    NOTE(review): this is still a scaffold. The per-sequence helper always
    reports zero for every count, and the function returns the summed counts
    ``(tp, fp, fn)`` rather than the F1 / precision / recall values that the
    name and return annotation suggest.

    Args:
        predicted_sequences: Model outputs, one string per example.
        target_sequences: Ground-truth strings, aligned with the predictions.
        input_sequences: Original input strings, aligned with the predictions.

    Returns:
        Tuple ``(total_tp, total_fp, total_fn)`` of integer counts.
    """
    def _tp_fp_fn(pred: str, tgt: str, ipt: str) -> Tuple[int, int, int]:
        # Placeholder: no counting logic yet, every count is zero.
        return 0, 0, 0

    per_sequence = [
        _tp_fp_fn(pred, tgt, ipt)
        for pred, tgt, ipt in zip(predicted_sequences, target_sequences, input_sequences)
    ]
    total_tp = sum(tp for tp, _, _ in per_sequence)
    total_fp = sum(fp for _, fp, _ in per_sequence)
    total_fn = sum(fn for _, _, fn in per_sequence)

    return total_tp, total_fp, total_fn

In [10]:
# Smoke-run the scaffold on the toy example; the inner helper is a placeholder, so all counts are zero.
f1_prec_rec(test_predicted, test_target, test_inputs)

(0, 0, 0)