In [72]:
from itertools import accumulate

import numpy as np

In [4]:
def generate_indexes(m, n, k):
    indexes = ((i, j) for i in range(m) for j in range(n))
    indexes = filter(lambda x: x[0] + x[1] == k, indexes)

    return indexes

def conv_n(x, y, m, f, g):
    indexes = generate_indexes(len(x), len(y), m)

    return g(f(x[i], y[j]) for i, j in indexes)

def conv(x, y, f, g):
    return [conv_n(x, y, m, f, g) for m in range(2 * len(x) - 1)]

In [93]:
def levenshtein(word_1, word_2):
    m = len(word_1)
    n = len(word_2)

    dist = np.zeros((m + 1, n + 1))

    dist[1:, 0] = list(accumulate(word_1, func=lambda x, y: x + np.linalg.norm(y), initial=0))[1:]
    dist[0, 1:] = list(accumulate(word_2, func=lambda x, y: x + np.linalg.norm(y), initial=0))[1:]


    for j in range(1, n+1):
        for i in range(1, m+1):
            deletion_cost = np.linalg.norm(word_1[i-1])  # Needs check
            insertion_cost = np.linalg.norm(word_2[j-1])  # Needs check
            substitution_cost = 0

            if all(word_1[i - 1] != word_2[j - 1]):
                substitution_cost = np.linalg.norm(word_1[i - 1] - word_2[j - 1])

            dist[i, j] = min(dist[i - 1, j] + deletion_cost, 
                             dist[i, j - 1] + insertion_cost, 
                             dist[i - 1, j - 1] + substitution_cost)

    return dist[len(word_1), len(word_2)]

In [94]:
levenshtein(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1, 5, 9], [7, 8, 9]]))

5.196152422706632

In [74]:
def distance(word1, word2):
    # return sum(conv(word1, word2, lambda x, y: abs(x-y), sum))
    return levenshtein(word1, word2)

In [4]:
w1 = [1, 2, 3]
w2 = [4, 5, 6]
w3 = [7, 8, 9, 10]

# Tests
print(len(w1) != len(w3) and isinstance(distance(w1, w3), float), "Property 1")
print(abs(distance(w1, w2)) == abs(distance(w2, w1)) , "Property 2")
print(distance(w1, w1) == 0 , "Property 3")
print(distance(w1, [5] + w1) < distance(w1, [5] + w1 + [7]), "Property 4")
print(distance(w1, [7] + w1) < distance(w1, w2) , "Property 5")
print(distance(w1, w1 + [8]) < distance(w1, w2) , "Property 6")
print(distance(w1, w2) + distance(w2, w3) >= distance(w1, w3) , "Property 7")
print(abs(distance(w1, [7] + w1) - distance(w1, w1 + [7])) < 0.05 , "Property 8")

False Property 1
True Property 2
True Property 3
True Property 4
True Property 5
True Property 6
True Property 7
True Property 8
