In [142]:
import os
from dataclasses import dataclass
from typing import Any

import polars as pl

from matchescu.entity_matchers.common import FSComparison
from matchescu.typing import EntityReference, EntityReferenceIdentifier

In [143]:
DATADIR = os.path.abspath("../../data")

In [144]:
@dataclass
class DataSource:
    """A named table of entity references plus the metadata needed to match it.

    ``id_factory`` is called with a single row and returns that row's
    identifier. ``ground_truth_key`` is stored alongside the data but is not
    read by any code in this notebook section.
    """

    name: str
    id_factory: EntityReferenceIdentifier
    ground_truth_key: str
    dataframe: pl.DataFrame

    def head(self, limit: int | None = None) -> pl.DataFrame:
        """Return the first ``limit`` rows, or the whole frame when ``limit`` is ``None``."""
        return self.dataframe if limit is None else self.dataframe.head(limit)


class ColumnBasedIdRetriever:
    """Callable that reads an identifier from one fixed column of a row."""

    def __init__(self, col: int | str):
        # Index (tuple/list rows) or key (dict rows) under which the identifier lives.
        self._col = col

    def __call__(self, row: EntityReference) -> Any:
        """Return ``row[col]``.

        Raises:
            ValueError: if ``row`` is not a tuple, list or dict.
        """
        if isinstance(row, (tuple, list, dict)):
            return row[self._col]
        raise ValueError("only tuple, list and dict row types supported")

In [145]:
from decimal import Decimal


def _parse_float(value: str | None) -> Decimal | None:
    if value is None:
        return None
    value = value.strip("$").replace(",", "")
    return Decimal(value)


def read_csv(path: str) -> pl.DataFrame:
    """Read the CSV at ``path`` and replace its ``price`` column with parsed values.

    The file is read with ``ignore_errors=True``. Each ``price`` element is
    passed through ``_parse_float`` (declared return dtype: ``pl.Decimal``).
    """
    raw = pl.read_csv(path, ignore_errors=True)
    parsed_price = pl.col("price").map_elements(_parse_float, pl.Decimal).alias("real_price")
    # The parsed values first land in a temporary column, which then takes over
    # the "price" name once the original text column has been dropped.
    with_parsed = raw.with_columns([parsed_price])
    without_text_price = with_parsed.drop("price")
    return without_text_price.rename({"real_price": "price"})


# Both sources read their identifier from column 0 of each row.
abt = DataSource(
    name="abt",
    id_factory=ColumnBasedIdRetriever(0),
    ground_truth_key="idAbt",
    dataframe=read_csv(os.path.join(DATADIR, "abt-buy", "Abt.csv")),
)
buy = DataSource(
    name="buy",
    id_factory=ColumnBasedIdRetriever(0),
    ground_truth_key="idBuy",
    dataframe=read_csv(os.path.join(DATADIR, "abt-buy", "Buy.csv")),
)

# Ground truth: the set of row tuples found in the perfect-mapping file.
gt = set(
    pl.read_csv(
        os.path.join(DATADIR, "abt-buy", "abt_buy_perfectMapping.csv"),
        ignore_errors=True,
    ).iter_rows()
)

# NOTE(review): arguments presumably are (label, left key, right key, threshold, flag);
# confirm against the FSComparison API.
comparison = FSComparison().levenshtein("name", 1, 1, 0.5, True).jaro_winkler("description", 2, 2, 0.45, True)

In [146]:
from functools import partial
from itertools import product
from typing import Generator


def _prep_dataframe(ds: DataSource) -> pl.DataFrame:
    renames = {
        col: f"{col}_{ds.name}"
        for col in ds.dataframe.columns
        if not col.endswith(ds.name)
    }
    return ds.dataframe.rename(renames)


def _compute_row_values(row: tuple, left: DataSource, right: DataSource, ground_truth: set, cmp: FSComparison):
    divider = len(left.dataframe.columns)
    left_row = row[:divider]
    right_row = row[divider:]
    left_id = left.id_factory(left_row)
    right_id = right.id_factory(right_row)
    result = {
        "left_id": left_id,
        "right_id": right_id,
        "is_same_entity": (left_id, right_id) in ground_truth
    }
    match_pattern = []
    for config in cmp.specs:
        a = left_row[config.left_ref_key]
        b = right_row[config.right_ref_key]
        match_result = config.match_strategy(a, b).value
        result[config.label] = match_result
        match_pattern.append(str(match_result))
    result["match_pattern"] = "".join(match_pattern)
    return result,


def create_comparison_df(
        left: DataSource,
        right: DataSource,
        ground_truth: set[tuple],
        cmp: FSComparison,
        limit: int | None = None
) -> pl.DataFrame:
    """Build the pairwise comparison table between two data sources.

    Every row of ``left`` is paired with every row of ``right`` (cross join)
    and each pair is evaluated by ``_compute_row_values``.

    Args:
        left: left-hand data source.
        right: right-hand data source.
        ground_truth: set of ``(left_id, right_id)`` tuples known to match.
        cmp: comparison configuration; ``cmp.specs`` drives the per-attribute checks.
        limit: when given, only the first ``limit`` rows of *each* source are paired.

    Returns:
        One row per pair with ``left_id``, ``right_id``, ``is_same_entity``,
        one column per comparison label and ``match_pattern``.

    ``left`` and ``right`` are never modified: the suffixed column names needed
    for the cross join live only on local frames.
    """
    # Renaming keeps the column count, which is the only thing
    # _compute_row_values reads from ``left.dataframe`` to split a joined row,
    # so the caller's DataSource objects can be passed through untouched.
    left_df = _prep_dataframe(left)
    right_df = _prep_dataframe(right)
    if limit is not None:
        left_df = left_df.head(limit)
        right_df = right_df.head(limit)

    cross_df = left_df.join(right_df, how="cross")
    compute_row_values = partial(
        _compute_row_values, left=left, right=right, ground_truth=ground_truth, cmp=cmp
    )
    dtype = {
        "left_id": pl.Int64,
        "right_id": pl.Int64,
        "is_same_entity": pl.Boolean,
        "match_pattern": pl.String,
    }
    for config in cmp.specs:
        dtype[config.label] = pl.UInt8
    # _compute_row_values returns a 1-tuple, so the struct arrives as "column_0".
    return cross_df.map_rows(compute_row_values, return_dtype=pl.Struct(dtype)).unnest("column_0")

In [147]:
# Only the first 300 rows of each source are paired (300 x 300 comparisons).
comparison_df = create_comparison_df(
    left=abt,
    right=buy,
    ground_truth=gt,
    cmp=comparison,
    limit=300,
)
print(comparison_df)

shape: (90_000, 6)
┌─────────┬───────────┬────────────────┬──────┬─────────────┬───────────────┐
│ left_id ┆ right_id  ┆ is_same_entity ┆ name ┆ description ┆ match_pattern │
│ ---     ┆ ---       ┆ ---            ┆ ---  ┆ ---         ┆ ---           │
│ i64     ┆ i64       ┆ bool           ┆ i64  ┆ i64         ┆ str           │
╞═════════╪═══════════╪════════════════╪══════╪═════════════╪═══════════════╡
│ 552     ┆ 10011646  ┆ false          ┆ 0    ┆ 1           ┆ 01            │
│ 552     ┆ 10140760  ┆ false          ┆ 0    ┆ 1           ┆ 01            │
│ 552     ┆ 10221960  ┆ false          ┆ 0    ┆ 1           ┆ 01            │
│ 552     ┆ 10246269  ┆ false          ┆ 0    ┆ 1           ┆ 01            │
│ 552     ┆ 10315184  ┆ false          ┆ 0    ┆ 1           ┆ 01            │
│ …       ┆ …         ┆ …              ┆ …    ┆ …           ┆ …             │
│ 29829   ┆ 204053338 ┆ false          ┆ 0    ┆ 1           ┆ 01            │
│ 29829   ┆ 204071217 ┆ false          ┆ 0   

In [148]:
from numpy import log2

from matchescu.entity_matchers.attribute_matching import FSMatchResult


def _generate_fs_patterns(n_comparisons: int) -> Generator[tuple[FSMatchResult, ...], None, None]:
    """Yield every combination of match results for ``n_comparisons`` comparisons.

    Each item is a tuple (not a list) of ``n_comparisons`` ``FSMatchResult``
    members, produced in ``itertools.product`` order.
    """
    yield from product(FSMatchResult, repeat=n_comparisons)


def _logarithmic_weight(prob_tp: float, prob_fp: float) -> float:
    if prob_fp and prob_tp:
        return log2(prob_tp / prob_fp)
    if prob_tp and not prob_fp:
        return 1000000
    return -1000000


def _rows_matching_pattern(pattern: tuple, cmp: FSComparison):
    """Build a polars expression selecting rows whose comparison columns equal ``pattern``.

    ``pattern`` values are paired positionally with ``cmp.specs``; each spec's
    ``label`` names the column that must equal the corresponding value.
    """
    conditions = [
        pl.col(spec.label) == expected
        for expected, spec in zip(pattern, cmp.specs)
    ]
    # Seeded with a literal True so an empty pattern selects every row.
    combined = pl.lit(True)
    for condition in conditions:
        combined = combined & condition
    return combined


def create_pattern_weights(comp_df: pl.DataFrame, cmp: FSComparison) -> pl.DataFrame:
    """Summarise, per match pattern, how often it coincides with a true match.

    For every pattern produced by ``_generate_fs_patterns`` the rows of
    ``comp_df`` showing that pattern are counted and split by ``is_same_entity``.

    Returns:
        One row per pattern with ``pattern`` (values joined as text),
        ``tp_count``, ``fp_count``, ``total``, ``tp_prob`` and ``fp_prob``
        (the counts divided by ``total``; ``0.0`` when the pattern never
        occurs) and ``pattern_weight`` (``_logarithmic_weight`` of the two
        probabilities), sorted by ``pattern_weight`` in descending order.
    """
    rows = []
    n_comparisons = len(cmp.specs)
    for pattern in _generate_fs_patterns(n_comparisons):
        pattern_df = comp_df.filter(_rows_matching_pattern(pattern, cmp))
        total_pattern_occurrences = pattern_df.shape[0]

        entity_pattern_matches = pattern_df.filter(pl.col("is_same_entity")).shape[0]
        entity_pattern_mismatches = pattern_df.filter(pl.col("is_same_entity").not_()).shape[0]

        # Fall back to 0.0 (not the int 0) so these columns hold floats no
        # matter which patterns happen to occur in the data.
        if total_pattern_occurrences > 0:
            tp_prob = entity_pattern_matches / total_pattern_occurrences
            fp_prob = entity_pattern_mismatches / total_pattern_occurrences
        else:
            tp_prob = 0.0
            fp_prob = 0.0
        # The sentinel weights are ints and the log weight is a numpy scalar;
        # normalise both to a plain float for the same reason.
        pattern_weight = float(_logarithmic_weight(tp_prob, fp_prob))

        rows.append({
            "pattern": "".join(map(str, pattern)),
            "tp_count": entity_pattern_matches,
            "fp_count": entity_pattern_mismatches,
            "total": total_pattern_occurrences,
            "tp_prob": tp_prob,
            "fp_prob": fp_prob,
            "pattern_weight": pattern_weight
        })

    return pl.DataFrame(rows).sort(by="pattern_weight", descending=True)

In [149]:
# Per-pattern counts and weights, derived from the sampled comparison table.
pattern_weight_df = create_pattern_weights(comp_df=comparison_df, cmp=comparison)
print(pattern_weight_df)

shape: (9, 7)
┌─────────┬──────────┬──────────┬───────┬──────────┬──────────┬────────────────┐
│ pattern ┆ tp_count ┆ fp_count ┆ total ┆ tp_prob  ┆ fp_prob  ┆ pattern_weight │
│ ---     ┆ ---      ┆ ---      ┆ ---   ┆ ---      ┆ ---      ┆ ---            │
│ str     ┆ i64      ┆ i64      ┆ i64   ┆ f64      ┆ f64      ┆ f64            │
╞═════════╪══════════╪══════════╪═══════╪══════════╪══════════╪════════════════╡
│ 10      ┆ 35       ┆ 25       ┆ 60    ┆ 0.583333 ┆ 0.416667 ┆ 0.485427       │
│ 11      ┆ 76       ┆ 198      ┆ 274   ┆ 0.277372 ┆ 0.722628 ┆ -1.381429      │
│ 00      ┆ 51       ┆ 33775    ┆ 33826 ┆ 0.001508 ┆ 0.998492 ┆ -9.371243      │
│ 01      ┆ 53       ┆ 55689    ┆ 55742 ┆ 0.000951 ┆ 0.999049 ┆ -10.037184     │
│ 02      ┆ 0        ┆ 97       ┆ 97    ┆ 0.0      ┆ 1.0      ┆ -1e6           │
│ 12      ┆ 0        ┆ 1        ┆ 1     ┆ 0.0      ┆ 1.0      ┆ -1e6           │
│ 20      ┆ 0        ┆ 0        ┆ 0     ┆ 0.0      ┆ 0.0      ┆ -1e6           │
│ 21      ┆ 0 

In [150]:
# Overall class balance of the sampled pairs: true matches vs. everything else.
total_count = comparison_df.height
entity_count = comparison_df.filter(pl.col("is_same_entity")).height
mismatch_count = total_count - entity_count
print(f"total={total_count}; entities={entity_count}; non-entities={mismatch_count}")
print(pattern_weight_df.select("pattern", "tp_prob", "fp_prob", "pattern_weight"))

total=90000; entities=215; non-entities=89785
shape: (9, 4)
┌─────────┬──────────┬──────────┬────────────────┐
│ pattern ┆ tp_prob  ┆ fp_prob  ┆ pattern_weight │
│ ---     ┆ ---      ┆ ---      ┆ ---            │
│ str     ┆ f64      ┆ f64      ┆ f64            │
╞═════════╪══════════╪══════════╪════════════════╡
│ 10      ┆ 0.583333 ┆ 0.416667 ┆ 0.485427       │
│ 11      ┆ 0.277372 ┆ 0.722628 ┆ -1.381429      │
│ 00      ┆ 0.001508 ┆ 0.998492 ┆ -9.371243      │
│ 01      ┆ 0.000951 ┆ 0.999049 ┆ -10.037184     │
│ 02      ┆ 0.0      ┆ 1.0      ┆ -1e6           │
│ 12      ┆ 0.0      ┆ 1.0      ┆ -1e6           │
│ 20      ┆ 0.0      ┆ 0.0      ┆ -1e6           │
│ 21      ┆ 0.0      ┆ 0.0      ┆ -1e6           │
│ 22      ┆ 0.0      ┆ 0.0      ┆ -1e6           │
└─────────┴──────────┴──────────┴────────────────┘


In [151]:
import plotly.graph_objs as go

# One bar series per probability column, grouped side by side for each pattern.
fig = go.Figure(data=[
    go.Bar(x=pattern_weight_df["pattern"], y=pattern_weight_df[column], name=label)
    for column, label in (("tp_prob", "true positives"), ("fp_prob", "false positives"))
])
fig.update_layout(barmode="group")
fig.show()

In [152]:
T_mu = 0  # a pattern weighing at least this much is treated as a link
T_lambda = -4  # a pattern weighing at most this much is treated as a non-link

positive_patterns = set(pattern_weight_df.filter(pl.col("pattern_weight") >= T_mu)["pattern"].to_list())
negative_patterns = set(pattern_weight_df.filter(pl.col("pattern_weight") <= T_lambda)["pattern"].to_list())
# Strictly between the thresholds. is_between() includes both bounds by default,
# which would put a pattern weighing exactly T_mu or T_lambda into two sets at once.
undecided = set(
    pattern_weight_df.filter(
        pl.col("pattern_weight").is_between(T_lambda, T_mu, closed="none")
    )["pattern"].to_list()
)

print(positive_patterns, negative_patterns, undecided)

{'10'} {'01', '22', '12', '02', '20', '00', '21'} {'11'}


In [153]:
# Classify every abt x buy pair by its match pattern and score the decisions
# against the ground truth. Pairs whose pattern is undecided are only counted.
tp, tn, fp, fn = 0, 0, 0, 0
total_comparisons = 0
ambiguous = 0

# Materialise the right-hand rows once instead of re-iterating the frame for
# every left-hand row.
buy_rows = list(buy.dataframe.iter_rows())

for left_row in abt.dataframe.iter_rows():
    # The left identifier does not depend on the right-hand row.
    left_id = abt.id_factory(left_row)
    for right_row in buy_rows:
        total_comparisons += 1
        # str() of the raw match result, the same way the pattern strings held
        # in the lookup sets were built.
        match_pattern = "".join(
            str(
                config.match_strategy(
                    left_row[config.left_ref_key],
                    right_row[config.right_ref_key]
                )
            )
            for config in comparison.specs
        )
        # Undecided wins over the other two sets, so it is checked first.
        if match_pattern in undecided:
            ambiguous += 1
            continue

        is_link = match_pattern in positive_patterns
        right_id = buy.id_factory(right_row)
        is_same_entity = (left_id, right_id) in gt

        if is_link:
            if is_same_entity:
                tp += 1
            else:
                fp += 1
        else:  # everything that is neither undecided nor a link counts as a non-link
            if is_same_entity:
                fn += 1
            else:
                tn += 1

unambiguous = total_comparisons - ambiguous
print(
    f"total comparisons: {total_comparisons}; manual inspections: {ambiguous}; unambiguous comparisons: {unambiguous}")
print(f"tp={tp};fp={fp};tn={tn};fn={fn}")
print(f"tp+fp+fn+tn={tp + fp + fn + tn}; unambiguous={unambiguous}")
print(f"manual inspections required: {ambiguous}")

total comparisons: 1180452; manual inspections: 2090; unambiguous comparisons: 1178362
tp=182;fp=330;tn=1177306;fn=544
tp+fp+fn+tn=1178362; unambiguous=1178362
manual inspections required: 2090


In [154]:
# Precision, recall and F1 from the confusion counts; each one falls back to 0
# when its denominator is 0.
if tp + fp > 0:
    p = tp / (tp + fp)
else:
    p = 0

if tp + fn > 0:
    r = tp / (tp + fn)
else:
    r = 0

if p + r > 0:
    f1 = 2 * p * r / (p + r)
else:
    f1 = 0

print(p, r, f1)

0.35546875 0.25068870523415976 0.2940226171243942
