In [9]:
import torch as th


mes = th.randint(0, 4, (3, 10))
mask = (mes == 0).cumsum(dim=1) == 0
masked = mes * mask
print(mes)
print(masked)

tensor([[3, 1, 2, 1, 3, 1, 1, 2, 3, 2],
        [1, 1, 3, 1, 0, 3, 2, 2, 1, 0],
        [3, 1, 0, 1, 1, 2, 1, 0, 0, 3]])
tensor([[3, 1, 2, 1, 3, 1, 1, 2, 3, 2],
        [1, 1, 3, 1, 0, 0, 0, 0, 0, 0],
        [3, 1, 0, 0, 0, 0, 0, 0, 0, 0]])


In [19]:
import torch as th
from pprint import pprint
import editdistance
from numba import njit
import numpy as np
import time
import matplotlib.pyplot as plt


@njit
def edit_distance(s1: np.ndarray, s2: np.ndarray):
    if len(s1) < len(s2):
        return edit_distance(s2, s1)

    if len(s2) == 0:
        return len(s1)

    previous_row = np.arange(len(s2) + 1, dtype=np.int64)
    for i, c1 in enumerate(s1):
        current_row = np.zeros(len(s2) + 1, dtype=np.int64)
        current_row[0] = i + 1
        for j, c2 in enumerate(s2):
            insertions = previous_row[j + 1] + 1
            deletions = current_row[j] + 1
            substitutions = previous_row[j] + (c1 != c2)
            current_row[j + 1] = min(insertions, deletions, substitutions)
        previous_row = current_row

    return previous_row[-1]


def lansim(
    message1: np.ndarray,
    message2: np.ndarray,
    length1: np.ndarray | None = None,
    length2: np.ndarray | None = None,
):
    if length1 is None:
        length1 = np.argmin(message1, axis=1) + 1
    if length2 is None:
        length2 = np.argmin(message2, axis=1) + 1

    edit_distances = np.arange(message1.shape[0], dtype=np.int64)
    for i in range(message1.shape[0]):
        edit_distances[i] = editdistance.eval(
            message1[i, : length1[i]], message2[i, : length2[i]]
        )

    return 1 - edit_distances / np.maximum(length1, length2)


@njit
def lansim_numba(
    message1: np.ndarray,
    message2: np.ndarray,
    length1: np.ndarray | None = None,
    length2: np.ndarray | None = None,
    distance:str="edit_distance"
):
    if length1 is None:
        length1 = np.argmin(message1, axis=1) + 1
    if length2 is None:
        length2 = np.argmin(message2, axis=1) + 1

    if distance == "edit_distance":
        distance = edit_distance
    else: 
        raise ValueError("distance must be edit_distance")

    edit_distances = np.arange(message1.shape[0], dtype=np.int64)
    for i in range(message1.shape[0]):
        edit_distances[i] = edit_distance(
            message1[i, : length1[i]], message2[i, : length2[i]]
        )

    return 1 - edit_distances / np.maximum(length1, length2)


batch_size = 100000
lengths = list(range(1, 20, 1))
times = []
times_numba = []
for l in lengths:
    messages1 = np.random.randint(0, 20, (batch_size, l))
    messages2 = np.random.randint(0, 20, (batch_size, l))
    eos = np.zeros((batch_size, 1))
    messages1 = np.concatenate([messages1, eos], axis=1)
    messages2 = np.concatenate([messages2, eos], axis=1)

    start = time.time()
    lan1 = lansim(messages1, messages2)
    elapsed = time.time() - start
    times.append(elapsed)

    start = time.time()
    lan2 = lansim_numba(messages1, messages2)
    elapsed = time.time() - start
    times_numba.append(elapsed)

    print(l)

    assert np.allclose(lan1, lan2)

plt.plot(lengths, times, label="torch")
plt.plot(lengths, times_numba, label="torch_numba")
plt.legend()
plt.show()

ValueError: distance must be edit_distance