In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from typing import List, Tuple, Set, Union, Dict
import difflib
import Levenshtein
import re
import pprint
import copy

In [3]:
def match_sequences(a: str, b: str) -> List[Tuple[int, int]]:
    """
    Aligns the whitespace-separated words of a and b and returns the matched word pairs.

    A dynamic program maximizes the number of equal words that can be paired while
    keeping their order (longest common subsequence on word level). Ties are resolved
    in favor of skipping a word of a, then skipping a word of b, then the diagonal step.

    :param a: first string, split into words with str.split()
    :param b: second string, split into words with str.split()
    :return: list of (word index in a, word index in b) pairs in ascending order
    """
    left = a.split()
    right = b.split()
    n_rows = len(left) + 1
    n_cols = len(right) + 1

    # score[r][c]: best number of matches between the first r words of a and first c words of b
    score = [[0] * n_cols for _ in range(n_rows)]
    # step[r][c]: how the best score at (r, c) was reached
    step = [[" "] * n_cols for _ in range(n_rows)]
    step[0][1:] = ["i"] * (n_cols - 1)
    for row in range(1, n_rows):
        step[row][0] = "d"

    for row, left_word in enumerate(left, start=1):
        for col, right_word in enumerate(right, start=1):
            # start with skipping the word of a, only strictly better candidates replace it
            best, best_step = score[row - 1][col], "d"
            if score[row][col - 1] > best:
                best, best_step = score[row][col - 1], "i"
            same = left_word == right_word
            diagonal = score[row - 1][col - 1] + (1 if same else 0)
            if diagonal > best:
                best, best_step = diagonal, ("m" if same else "nm")
            score[row][col] = best
            step[row][col] = best_step

    # walk back from the bottom right corner and collect the matched pairs
    moves = {"d": (1, 0), "i": (0, 1), "m": (1, 1), "nm": (1, 1)}
    pairs = []
    row, col = len(left), len(right)
    while row > 0 or col > 0:
        move = step[row][col]
        if move not in moves:
            raise RuntimeError("should not happen")
        row_delta, col_delta = moves[move]
        row -= row_delta
        col -= col_delta
        if move == "m":
            pairs.append((row, col))

    pairs.reverse()
    return pairs

In [4]:
%timeit [match_sequences("this test is a", "this is a test") for _ in range(10000)]

329 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
import edit_distance_rs

In [6]:
%timeit edit_distance_rs.batch_match_words(["this test is a" for _ in range(10000)], ["this is a test" for _ in range(10000)], 256)

11.3 ms ± 892 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
def edit_operations(
    a: str, 
    b: str, 
    with_swap: bool = False, 
    spaces_insert_delete_only: bool = False, 
    return_distance_only: bool = False
) -> Union[int, List[Tuple[str, int, int]]]:
    """
    Computes the edit operations that transform a into b.
    Follows optimal string alignment distance at https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance.

    :param a: source string
    :param b: target string
    :param with_swap: also allow swapping two adjacent characters as a single operation
    :param spaces_insert_delete_only: forbid replace and swap operations that involve a space,
        so spaces can only be inserted or deleted
    :param return_distance_only: return only the number of edit operations
    :return: the edit distance if return_distance_only is set, otherwise a list of
        (operation name, index in a, index in b) tuples ordered from the start of the strings,
        where the operation name is one of "insert", "delete", "replace", "swap"
    """
    n_rows = len(a) + 1
    n_cols = len(b) + 1

    # dist holds the costs (-1 --> not yet filled in), back holds the chosen operation codes:
    # ' ' --> not yet filled in, k --> keep, i --> insert, d --> delete, r --> replace, s --> swap
    dist = [[-1] * n_cols for _ in range(n_rows)]
    back = [[" "] * n_cols for _ in range(n_rows)]
    for col in range(n_cols):
        dist[0][col] = col
        back[0][col] = "i"
    for row in range(1, n_rows):
        dist[row][0] = row
        back[row][0] = "d"
    back[0][0] = "k"

    for row in range(1, n_rows):
        # matrix indices are offset by +1 compared to string indices
        ch_a = a[row - 1]
        for col in range(1, n_cols):
            ch_b = b[col - 1]

            # deletion is the first candidate, later candidates only win if strictly cheaper
            best_cost, best_code = dist[row - 1][col] + 1, "d"
            rivals = [(dist[row][col - 1] + 1, "i")]

            if ch_a == ch_b:
                rivals.append((dist[row - 1][col - 1], "k"))
            elif not spaces_insert_delete_only or (ch_a != " " and ch_b != " "):
                # replacing is only allowed if no space is involved or spaces are unrestricted
                rivals.append((dist[row - 1][col - 1] + 1, "r"))

            if (
                with_swap
                and row > 1
                and col > 1
                and ch_a == b[col - 2]
                and a[row - 2] == ch_b
                and (
                    not spaces_insert_delete_only
                    or (ch_a != " " and a[row - 2] != " ")
                )
            ):
                # the last two chars of both prefixes are mirrored and may be swapped
                rivals.append((dist[row - 2][col - 2] + 1, "s"))

            for cost, code in rivals:
                if cost < best_cost:
                    best_cost, best_code = cost, code

            dist[row][col] = best_cost
            back[row][col] = best_code

    # every cell must have been filled in
    assert all(cell >= 0 for line in dist for cell in line)

    if return_distance_only:
        return dist[len(a)][len(b)]

    # operation code --> (operation name, rows to go back, columns to go back)
    steps = {
        "d": ("delete", 1, 0),
        "i": ("insert", 0, 1),
        "r": ("replace", 1, 1),
        "s": ("swap", 2, 2),
    }
    collected = []
    row, col = len(a), len(b)
    while row > 0 or col > 0:
        code = back[row][col]
        if code == "k":
            # keep operations are not part of the result
            row -= 1
            col -= 1
        elif code in steps:
            name, row_delta, col_delta = steps[code]
            row -= row_delta
            col -= col_delta
            collected.append((name, row, col))
        else:
            raise RuntimeError("should not happen")

    collected.reverse()
    return collected

def edit_distance(
    a: str, 
    b: str, 
    with_swap: bool = True, 
    spaces_insert_delete_only: bool = False
) -> int:
    """
    Returns the number of edit operations needed to transform a into b.

    Delegates to edit_operations with return_distance_only=True. Note that swaps
    are allowed by default here, while edit_operations disallows them by default.

    :param a: source string
    :param b: target string
    :param with_swap: passed on to edit_operations
    :param spaces_insert_delete_only: passed on to edit_operations
    :return: the edit distance between a and b
    """
    distance = edit_operations(
        a,
        b,
        with_swap=with_swap,
        spaces_insert_delete_only=spaces_insert_delete_only,
        return_distance_only=True,
    )
    return distance

In [27]:
def find_word_boundaries(s: str) -> List[Tuple[int, int]]:
    """
    Returns the character spans of all whitespace-separated words in s.

    The spans are taken from the actual positions of the words in s, so they stay
    correct for leading whitespace and for words separated by more than one
    whitespace character. For strings whose words are separated by exactly one
    space and that have no leading whitespace, the result is the same as summing
    up word lengths plus one.

    :param s: string to find the words in
    :return: list of (start, end) index pairs, one per word, with end being exclusive
    """
    # \S+ matches maximal runs of non-whitespace characters, which are the words str.split() yields
    return [word_match.span() for word_match in re.finditer(r"\S+", s)]

def get_edited_words2(ipt: str, tgt: str) -> Tuple[Set[int], Set[int]]:
    """
    Determines which words of ipt and which words of tgt are touched by the edit
    operations transforming ipt into tgt (spaces can only be inserted or deleted).

    An edit inside a word marks that word. An edit directly at the end of a word
    marks that word and, in the cases handled below, also the following word,
    but only if such a following word exists.

    :param ipt: input string
    :param tgt: target string
    :return: tuple of (edited word indices in ipt, edited word indices in tgt)
    """
    edit_ops = edit_operations(ipt, tgt, spaces_insert_delete_only=True)
    ipt_wb = find_word_boundaries(ipt)
    tgt_wb = find_word_boundaries(tgt)
    edited_ipt_indices = set()
    edited_tgt_indices = set()
    for op_code, ipt_idx, tgt_idx in edit_ops:
        word_idx = 0
        for wb_s, wb_e in ipt_wb:
            if wb_s <= ipt_idx < wb_e:
                edited_ipt_indices.add(word_idx)
                break
            elif ipt_idx == wb_e:
                edited_ipt_indices.add(word_idx)
                # an edit at the very end of ipt has no following word,
                # so never add an index past the last word
                if word_idx + 1 < len(ipt_wb):
                    edited_ipt_indices.add(word_idx + 1)
                break
            word_idx += 1
        # the edit position must belong to some word of ipt
        assert word_idx < len(ipt_wb)
        tgt_word_idx = 0
        for wb_s, wb_e in tgt_wb:
            if wb_s <= tgt_idx < wb_e:
                edited_tgt_indices.add(tgt_word_idx)
                break
            elif tgt_idx == wb_e:
                # replace and swap never involve spaces, so they cannot end up here
                assert op_code == "delete" or op_code == "insert"
                if op_code == "delete":
                    edited_tgt_indices.add(tgt_word_idx)
                else:
                    # an insertion right after a word must be the space splitting two words
                    assert tgt[tgt_idx] == " "
                    edited_tgt_indices.add(tgt_word_idx)
                    if tgt_word_idx + 1 < len(tgt_wb):
                        edited_tgt_indices.add(tgt_word_idx + 1)
                break
            tgt_word_idx += 1
        # the edit position must belong to some word of tgt
        assert tgt_word_idx < len(tgt_wb)
    return edited_ipt_indices, edited_tgt_indices
            

def get_edited_words(ipt: str, tgt: str) -> Set[int]:
    """
    Returns the indices of the words in tgt that are touched by the edit operations
    transforming ipt into tgt (spaces can only be inserted or deleted).

    Each edit operation marks the first word of tgt whose end is at or after the
    operation's position in tgt. An inserted space additionally marks the following
    word, since it splits one input word into two target words.

    :param ipt: input string without leading or trailing whitespace
    :param tgt: target string without leading or trailing whitespace
    :return: set of edited word indices in tgt
    """
    assert tgt.strip() == tgt and ipt.strip() == ipt, "the two strings must not contain leading or trailing whitespaces"
    boundaries = find_word_boundaries(tgt)
    operations = edit_operations(ipt, tgt, spaces_insert_delete_only=True)
    num_words = len(boundaries)

    touched = set()
    for op_code, _, tgt_pos in operations:
        # first word ending at or after the edit position, num_words if there is none
        word = next(
            (k for k, (_start, end) in enumerate(boundaries) if tgt_pos <= end),
            num_words,
        )
        inserted_space = op_code == "insert" and tgt[tgt_pos] == " "
        if inserted_space:
            assert word < num_words - 1
        touched.add(word)
        if inserted_space:
            touched.add(word + 1)

    return touched

def match_words(pred: str, tgt: str) -> Tuple[Set[int], Set[int]]:
    """
    Matches the whitespace-separated words of pred and tgt with difflib.SequenceMatcher.

    :param pred: predicted string
    :param tgt: target string
    :return: tuple of (indices of matched words in pred, indices of matched words in tgt)
    """
    pred_words = pred.split()
    tgt_words = tgt.split()
    matcher = difflib.SequenceMatcher(a=pred_words, b=tgt_words)

    pred_hits = set()
    tgt_hits = set()
    # every block is a (start in pred, start in tgt, length) triple, the final
    # dummy block has length 0 and contributes nothing
    for pred_start, tgt_start, length in matcher.get_matching_blocks():
        pred_hits.update(range(pred_start, pred_start + length))
        tgt_hits.update(range(tgt_start, tgt_start + length))
    return pred_hits, tgt_hits

def group_words(
    ipt: str, 
    pred: str,
    matching_in_pred: Set[int]
) -> Set[int]:
    """
    Returns the indices of the words in ipt whose corresponding words in pred are all
    contained in matching_in_pred.

    The correspondence between input and predicted words is derived from the edit
    operations transforming ipt into pred: a deleted space merges an input word with
    its successor, an inserted space splits an input word into one more predicted word.
    A group of merged input words is only counted if all predicted words it maps to
    are in matching_in_pred.

    :param ipt: input string without leading or trailing whitespace
    :param pred: predicted string without leading or trailing whitespace
    :param matching_in_pred: indices of words in pred that are considered matching
    :return: set of word indices in ipt
    """
    assert pred.strip() == pred and ipt.strip() == ipt, "the two strings must not contain leading or trailing whitespaces"
    operations = edit_operations(ipt, pred, spaces_insert_delete_only=True)
    boundaries = find_word_boundaries(ipt)
    num_words = len(boundaries)

    joined_with_successor = set()
    inserted_spaces = {}
    for op_code, ipt_pos, pred_pos in operations:
        # first input word ending at or after the edit position, num_words if there is none
        word = next(
            (k for k, (_start, end) in enumerate(boundaries) if ipt_pos <= end),
            num_words,
        )

        if op_code == "delete" and ipt[ipt_pos] == " ":
            joined_with_successor.add(word)

        if op_code == "insert" and pred[pred_pos] == " ":
            inserted_spaces[word] = inserted_spaces.get(word, 0) + 1

    correct = set()
    word = 0
    pred_word = 0
    while word < num_words:
        # collect all input words that got merged into one group
        group = [word]
        extra_words = inserted_spaces.get(word, 0)
        while word in joined_with_successor:
            word += 1
            group.append(word)
            extra_words += inserted_spaces.get(word, 0)

        # the group maps to the predicted words in [pred_word, group_end)
        group_end = pred_word + extra_words + 1
        if not any(k not in matching_in_pred for k in range(pred_word, group_end)):
            correct.update(group)

        word += 1
        pred_word = group_end

    assert word == num_words and pred_word == len(pred.split())
    return correct

In [10]:
# small hand-crafted example: a prediction, its target and the corrupted input,
# each also wrapped into a single-element list
test_pred = "The cute act eats delicate fi sh."
test_predicted = [test_pred]
test_tgt = "The cute cat eats delicious fish."
test_target = [test_tgt]
test_ipt = "Te cute cteats delicious fi sh."
test_inputs = [test_ipt]

# expected word index sets for the checks below
misspelled = {0, 2, 3, 5}
restored = {0, 1, 3}
changed = {0, 2, 3}
correct = {0, 1}

# words of the target that are touched by the edits from input to target
edited_in_tgt = get_edited_words(test_ipt, test_tgt)
assert misspelled == edited_in_tgt, (misspelled, edited_in_tgt)
assert len(misspelled) <= len(test_tgt.split())

# words of the input that are touched by the edits from prediction to input
edited_in_ipt = get_edited_words(test_pred, test_ipt)
assert changed == edited_in_ipt, (changed, edited_in_ipt)
assert len(changed) <= len(test_ipt.split())

# words that match between prediction and target, checked on the target side
matching_in_pred, matching_in_tgt = match_words(test_pred, test_tgt)
assert restored == matching_in_tgt, (restored, matching_in_tgt)
assert len(restored) <= len(test_tgt.split())

# input words whose corresponding predicted words all match the target
correct_in_ipt = group_words(test_ipt, test_pred, matching_in_pred)
assert correct == correct_in_ipt
assert len(correct) <= len(test_ipt.split())

In [22]:
get_edited_words("this is a test", "thisis a test")

{0}

In [28]:
get_edited_words2("thisa is a test", "this is a test")

({0}, {0})

In [6]:
def correction_f1_prec_rec(predicted_sequences: List[str], target_sequences: List[str], input_sequences: List[str]) -> Tuple[float, float, float]:
    """
    Computes word-level correction F1, precision and recall over a set of sequences.

    For every (prediction, target, input) triple the following word sets are determined:
      - misspelled: target words touched by the edits from input to target
      - restored: target words that match a word in the prediction
      - changed: input words touched by the edits from prediction to input
      - correct: input words whose corresponding predicted words all match the target
    True positives are misspelled words that got restored, false negatives are
    misspelled words that did not get restored, and false positives are changed
    words that are not correct. The counts are summed over all triples before the
    metrics are computed.

    A metric with a zero denominator is defined as 0.0.

    :param predicted_sequences: predicted strings
    :param target_sequences: target strings
    :param input_sequences: input strings
    :return: tuple of (F1, precision, recall)
    """
    def _tp_fp_fn(pred: str, tgt: str, ipt: str) -> Tuple[int, int, int]:
        misspelled = get_edited_words(ipt, tgt)
        changed = get_edited_words(pred, ipt)
        matching_in_pred, restored = match_words(pred, tgt)
        correct = group_words(ipt, pred, matching_in_pred)
        tp = misspelled.intersection(restored)
        fn = misspelled.difference(restored)
        fp = changed.difference(correct)
        return len(tp), len(fp), len(fn)
        
    total_tp = total_fp = total_fn = 0
    for pred, tgt, ipt in zip(predicted_sequences, target_sequences, input_sequences):
        tp, fp, fn = _tp_fp_fn(pred, tgt, ipt)
        total_tp += tp
        total_fp += fp
        total_fn += fn
    
    # turn the accumulated counts into the metrics the function promises to return,
    # guarding against division by zero
    num_predicted = total_tp + total_fp
    num_actual = total_tp + total_fn
    precision = total_tp / num_predicted if num_predicted > 0 else 0.0
    recall = total_tp / num_actual if num_actual > 0 else 0.0
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return f1, precision, recall

In [7]:
import edit_distance_rs

In [8]:
print(Levenshtein.editops(test_ipt, test_tgt))
print(edit_distance_rs.edit_operations(test_ipt, test_tgt, False, False))

[('insert', 1, 1), ('insert', 9, 10), ('insert', 10, 12), ('delete', 27, 30)]
[('insert', 1, 1), ('insert', 9, 10), ('insert', 10, 12), ('delete', 27, 30)]


In [9]:
batch_size = 10000
batch_ipt = [test_ipt for _ in range(batch_size)]
batch_tgt = [test_tgt for _ in range(batch_size)]

### Unbatched

In [10]:
%timeit edit_operations(test_ipt, test_tgt)

1.02 ms ± 3.35 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [11]:
%timeit edit_distance(test_ipt, test_tgt)

1.12 ms ± 3.17 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [12]:
%timeit difflib.SequenceMatcher(a=test_ipt, b=test_tgt).get_matching_blocks()

40.8 µs ± 395 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [13]:
%timeit Levenshtein.distance(test_ipt, test_tgt)

1.14 µs ± 0.762 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [14]:
%timeit Levenshtein.editops(test_ipt, test_tgt)

1.82 µs ± 2.32 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [15]:
%timeit edit_distance_rs.edit_distance(test_ipt, test_tgt, False, False)

134 µs ± 126 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [16]:
%timeit edit_distance_rs.edit_operations(test_ipt, test_tgt, False, False)

124 µs ± 204 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### Batched

In [10]:
from tqdm.notebook import tqdm
import time
from nsc.api import utils
from spell_checking.utils import edit_distance as ed

In [11]:
benchmark_corrupt = utils.load_text_file("../benchmarks/test/sec/wikidump/artificial/corrupt.txt")
benchmark_correct = utils.load_text_file("../benchmarks/test/sec/wikidump/artificial/correct.txt")

len(benchmark_corrupt)

10000

In [12]:
_ = [edit_distance(i, t) for i, t in tqdm(zip(benchmark_corrupt, benchmark_correct), total=len(benchmark_correct))]

  0%|          | 0/10000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [13]:
_ = [Levenshtein.distance(i, t) for i, t in tqdm(zip(benchmark_corrupt, benchmark_correct), total=len(benchmark_corrupt))]

  0%|          | 0/10000 [00:00<?, ?it/s]

In [15]:
_ = [ed.edit_distance(i, t) for i, t in tqdm(zip(benchmark_corrupt, benchmark_correct), total=len(benchmark_corrupt))]

  0%|          | 0/10000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [19]:
start = time.perf_counter()
_ = ed.batch_edit_distance(benchmark_corrupt, benchmark_correct, batch_size=256)
end = time.perf_counter()
print(end - start)

22.007166981999944
