# Learning String Similarity

The paper [@bilenko2002learning] outlines a process by which we can learn the
edit distance between two strings using minimal training over pairs of
coreferent strings.
Coreferent strings are strings that belong to entity references that refer to
the same real world object.
The idea is that if we fine tune the algorithm that computes the string edit
distance to the business domain where we perform entity resolution, we will be
able to identify coreferent entity references more accurately than if we use
a generic edit distance.
The aforementioned paper references an algorithm [@ristad1998learning] which
learns to compute the Levenshtein distance between two strings from training
data.

This notebook showcases Ristad and Yianilos' algorithm using Milenko's approach
to constructing the training corpus. 

In [1]:
import os
import random

import numpy as np
import polars as pl
from matchescu.matching.similarity import LevenshteinLearner


DATADIR = os.path.abspath("../../data")
CSV_PATH = os.path.join(DATADIR, "cora", "cora.csv")

In [2]:
df = pl.read_csv(CSV_PATH, has_header=False, ignore_errors=True).rename({
    "column_1": "id",
    "column_3": "class",
    "column_4": "author",
    "column_5": "volume",
    "column_6": "title",
    "column_7": "institution",
    "column_8": "venue",
    "column_11": "year"
}).select(pl.col("id", "class", "author", "title", "venue", "year"))
display(df)

id,class,author,title,venue,year
i64,str,str,str,str,str
1,"""blum1993""","""avrim blum, merrick furst, mic…","""cryptographic primitives based…","""in pre-proceedings of crypto '…","""1993"""
2,"""blum1993""","""avrim blum, merrick furst, mic…","""cryptographic primitives based…","""proc. crypto 93,""","""1994"""
3,"""blum1993""","""a. blum, m. furst, m. kearns, …","""cryptographic primitives based…","""crypto,""","""1993"""
4,"""blum1994""","""blum, a., furst, m., jackson, …","""weakly learning dnf and charac…","""proceedings of the 26th annual…","""(1994)."""
5,"""blum1994""","""blum, a., furst, m., jackson, …","""weakly learning dnf and charac…","""in proceedings of the twenty-s…","""(1994)."""
…,…,…,…,…,…
1289,"""schapire1998""","""robert e. schapire and yoram s…","""improved boosting algorithms u…","""in proceedings of the eleventh…","""1998"""
1290,"""schapire""","""schapire, r. e., freund, y., b…","""boosting the margin: a new exp…",,"""(1998)."""
1291,"""schapire1998mm""","""robert e. schapire and yoram s…","""a system for multiclass multi-…","""unpublished manuscript,""","""1998"""
1292,"""singer""","""robert e. schapire yoram singe…","""improved boosting algorithms u…",,


In [3]:
records = list(df.iter_rows(named=True))
dedupe_data = []
y = []
for i, left_record in enumerate(records):
    for j, right_record in enumerate(records, i+1):
        lclass, rclass = left_record["class"], right_record["class"]
        row = {f"{k}_left": v for k,v in left_record.items()}
        row.update(
            {f"{k}_right": v for k,v in right_record.items()}
        )
        dedupe_data.append(row)
        y.append(int(lclass == rclass))
X = pl.DataFrame(dedupe_data).to_numpy()
y = np.array(y)
display(X, y)

array([[1, 'blum1993',
        'avrim blum, merrick furst, michael kearns, and richard j. lipton.',
        ..., 'cryptographic primitives based on hard learning problems.',
        "in pre-proceedings of crypto '93,", '1993'],
       [1, 'blum1993',
        'avrim blum, merrick furst, michael kearns, and richard j. lipton.',
        ..., 'cryptographic primitives based on hard learning problems.',
        'proc. crypto 93,', '1994'],
       [1, 'blum1993',
        'avrim blum, merrick furst, michael kearns, and richard j. lipton.',
        ..., 'cryptographic primitives based on hard learning problems.',
        'crypto,', '1993'],
       ...,
       [1293, 'singer', 'robert e. schapire yoram singer.', ...,
        'a system for multiclass multi-label text categorization.',
        'unpublished manuscript,', '1998'],
       [1293, 'singer', 'robert e. schapire yoram singer.', ...,
        'improved boosting algorithms using confidence-rated predictions.',
        None, None],
       [

array([1, 1, 1, ..., 0, 1, 1])

We're going to split the data into 100 folds. Using 10 folds is extremely slow
in Python because of how slow Python's for loops are.

In [45]:
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

n_splits = 100
X_shuffled, y_shuffled = shuffle(X, y, random_state=42)
X_splits = np.array_split(X, n_splits)
y_splits = np.array_split(y, n_splits)

folds = []
for i in range(n_splits):
    X_split_train, X_split_test, y_split_train, y_split_test = train_test_split(X_splits[i], y_splits[i], train_size=0.7) 
    folds.append((X_splits[i], y_splits[i], X_split_train, y_split_train, X_split_test, y_split_test))

In [46]:
import random
from sklearn.svm import LinearSVC


def train_distance_model(features, target, n_samples=100, epochs=10):
    idx_of_matches, = np.where(target == 1)
    matching_recods = features[idx_of_matches]
    coreferent_values = [(val[3], val[9]) for val in matching_recods]
    corpus = list(random.choices(coreferent_values, k=n_samples))
    return LevenshteinLearner().fit(corpus, epochs)


def train_svm(features, target):
    model = LinearSVC()
    model.fit(features, target)
    return model

Next, train the Levenshtein distance estimator using 100 random samples from
the data over 10 epochs. 

In [47]:
import itertools
from sklearn.metrics import precision_score, recall_score, f1_score
from jellyfish import levenshtein_distance
from ipywidgets import IntProgress
from IPython.display import display


def _title_levenshtein(values: tuple) -> tuple:
    return (levenshtein_distance(values[3], values[9]),)


def _title_learned_levenshtein(values: tuple) -> tuple:
    return (distance_model.compute_distance(values[3], values[9]),)


stats = []
limit = 3
max_count = limit * 8
f = IntProgress(min=0, max=max_count)
display(f)

for idx, (X, y, X_train, y_train, X_test, y_test) in itertools.islice(enumerate(folds), limit):
    distance_model = train_distance_model(X, y, 50)
    title_train = list(map(_title_levenshtein, X_train))
    f.value += 1
    model = train_svm(title_train, y_train)
    f.value += 1
    print("trained SVM levenshtein #", idx+1)
    title_test = list(map(_title_levenshtein, X_test))
    f.value += 1
    prediction = model.predict(title_test)
    f.value += 1
    print("evaluated levenshtein #", idx+1)

    learned_title_train = list(map(_title_learned_levenshtein, X_train))
    f.value += 1
    learned_model = train_svm(learned_title_train, y_train)
    f.value += 1
    print("trained SVM learned-levenshtein #", idx+1)
    learned_title_test = list(map(_title_learned_levenshtein, X_test))
    f.value += 1
    learned_prediction = model.predict(learned_title_test)
    f.value += 1
    print("evaluated learned-levenshtein #", idx+1)
    
    stats.append({
        "levenshtein precision": precision_score(y_test, prediction),
        "levenshtein recall": recall_score(y_test, prediction),
        "levenshtein f1": f1_score(y_test, prediction),
        "learned levenshtein precision": precision_score(y_test, learned_prediction),
        "learned levenshtein recall": recall_score(y_test, learned_prediction),
        "learned levenshtein f1": f1_score(y_test, learned_prediction),
    })

IntProgress(value=0, max=24)

converged after 9 epochs
trained SVM levenshtein # 1
evaluated levenshtein # 1
trained SVM learned-levenshtein # 1
evaluated learned-levenshtein # 1
converged after 7 epochs
trained SVM levenshtein # 2
evaluated levenshtein # 2
trained SVM learned-levenshtein # 2
evaluated learned-levenshtein # 2
converged after 7 epochs
trained SVM levenshtein # 3
evaluated levenshtein # 3
trained SVM learned-levenshtein # 3
evaluated learned-levenshtein # 3


In [48]:
display(pl.DataFrame(stats))

levenshtein precision,levenshtein recall,levenshtein f1,learned levenshtein precision,learned levenshtein recall,learned levenshtein f1
f64,f64,f64,f64,f64,f64
1.0,1.0,1.0,0.002001,0.357143,0.003979
1.0,0.787986,0.881423,0.0,0.0,0.0
1.0,0.763948,0.86618,0.0,0.0,0.0


In [8]:
stats = pl.DataFrame(stats)
display(stats)
display(stats.mean())

levenshtein precision,levenshtein recall,levenshtein f1,learned levenshtein precision,learned levenshtein recall,learned levenshtein f1
f64,f64,f64,f64,f64,f64
0.766539,0.841559,0.802299,0.766539,0.841559,0.802299
0.758516,0.836792,0.795733,0.758516,0.836792,0.795733
0.762889,0.838194,0.798771,0.762889,0.838194,0.798771
0.770097,0.843522,0.805139,0.770097,0.843522,0.805139
0.765259,0.843803,0.802614,0.765259,0.843803,0.802614
0.763351,0.846003,0.802555,0.763351,0.846003,0.802555
0.758167,0.839832,0.796912,0.758167,0.839832,0.796912
0.764512,0.846003,0.803196,0.764512,0.846003,0.803196
0.745201,0.838429,0.789071,0.745201,0.838429,0.789071
0.764602,0.851893,0.805891,0.764602,0.851893,0.805891


levenshtein precision,levenshtein recall,levenshtein f1,learned levenshtein precision,learned levenshtein recall,learned levenshtein f1
f64,f64,f64,f64,f64,f64
0.761913,0.842603,0.800218,0.761913,0.842603,0.800218
