Commit

Merge fa1a6d8 into 4c3acd6
RobinL committed Nov 16, 2021
2 parents 4c3acd6 + fa1a6d8 commit e755955
Showing 5 changed files with 373 additions and 4 deletions.
11 changes: 9 additions & 2 deletions CHANGELOG.md
@@ -5,14 +5,21 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [2.0.2]

### Added

- Add function to compute m values from labelled data by @RobinL in https://github.com/moj-analytical-services/splink/pull/248

## [2.0.1]

### Added

* Add function that outputs the full path to the similarity jar by @RobinL in https://github.com/moj-analytical-services/splink/pull/237
- Add function that outputs the full path to the similarity jar by @RobinL in https://github.com/moj-analytical-services/splink/pull/237

### Changed
* Allow match weight to be used in the diagnostic histogram by @RobinL in https://github.com/moj-analytical-services/splink/pull/239

- Allow match weight to be used in the diagnostic histogram by @RobinL in https://github.com/moj-analytical-services/splink/pull/239

## [2.0.0]

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,7 +1,7 @@
[tool.poetry]
name = "splink"
version = "2.0.1"
description = "Implementation in Apache Spark of the EM algorithm to estimate parameters of Fellegi-Sunter's canonical model of record linkage."
version = "2.0.2"
description = "Implementation of Fellegi-Sunter's canonical model of record linkage in Apache Spark, including EM algorithm to estimate parameters"
authors = ["Robin Linacre <robinlinacre@hotmail.com>", "Sam Lindsay", "Theodore Manassis"]
license = "MIT"
homepage = "https://github.com/moj-analytical-services/splink"
95 changes: 95 additions & 0 deletions splink/lower_id_on_lhs.py
@@ -0,0 +1,95 @@
def _sql_expr_move_left_to_right(
col_name,
unique_id_col: str = "unique_id",
source_dataset_col: str = "source_dataset",
):
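    # Builds a pair of CASE expressions for one column: the value belonging to the
    # record with the lower (source_dataset, unique_id) is output as the _l column
    # and the other record's value as the _r column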

sds_l = f"{source_dataset_col}_l"
uid_l = f"{unique_id_col}_l"
sds_r = f"{source_dataset_col}_r"
uid_r = f"{unique_id_col}_r"
col_name_l = f"{col_name}_l"
col_name_r = f"{col_name}_r"

if source_dataset_col:
uid_expr_l = f"concat({sds_l}, '-__-', {uid_l})"
uid_expr_r = f"concat({sds_r}, '-__-', {uid_r})"
else:
uid_expr_l = uid_l
uid_expr_r = uid_r

move_to_left = f"""
CASE
WHEN {uid_expr_l} < {uid_expr_r}
THEN {col_name_l}
ELSE {col_name_r}
END as {col_name_l}
"""

move_to_right = f"""
CASE
WHEN {uid_expr_l} < {uid_expr_r}
THEN {col_name_r}
ELSE {col_name_l}
END as {col_name_r}
"""

exprs = f"""
{move_to_left},
{move_to_right}
"""

return exprs


def lower_id_to_left_hand_side(
df,
source_dataset_col: str = "source_dataset",
unique_id_col: str = "unique_id",
):
"""Take a dataframe in the format of splink record comparisons (with _l and _r suffixes)
and return a dataframe where the _l columns correspond to the record with the lower id.
For example:
| source_dataset_l | unique_id_l | source_dataset_r | unique_id_r | a_l | a_r | other_col |
|:-------------------|--------------:|:-------------------|--------------:|------:|------:|:------------|
| df | 0 | df | 1 | 0 | 1 | a |
| df | 2 | df | 0 | 2 | 0 | b |
| df | 0 | df | 3 | 0 | 3 | c |
Becomes
| source_dataset_l | unique_id_l | source_dataset_r | unique_id_r | a_l | a_r | other_col |
|:-------------------|--------------:|:-------------------|--------------:|------:|------:|:------------|
| df | 0 | df | 1 | 0 | 1 | a |
| df | 0 | df | 2 | 0 | 2 | b |
| df | 0 | df | 3 | 0 | 3 | c |
    Returns:
        DataFrame: a dataframe with the _l and _r column values swapped in rows
            where unique_id_r < unique_id_l
"""
spark = df.sql_ctx.sparkSession
cols = list(df.columns)

l_cols = [c for c in cols if c.endswith("_l")]
r_cols = [c for c in cols if c.endswith("_r")]
other_cols = [c for c in cols if c not in (l_cols + r_cols)]

case_exprs = []
for col in l_cols:
this_col = col[:-2]
expr = _sql_expr_move_left_to_right(this_col, unique_id_col, source_dataset_col)
case_exprs.append(expr)
case_exprs.extend(other_cols)
select_expr = ", ".join(case_exprs)

df.createOrReplaceTempView("df")
sql = f"""
select {select_expr}
from df
"""

df = spark.sql(sql)
return df.select(cols)
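
For context, a minimal usage sketch of lower_id_to_left_hand_side, assuming an active SparkSession named spark; the column a and the example values are illustrative and not part of this file:

from pyspark.sql import Row
from splink.lower_id_on_lhs import lower_id_to_left_hand_side

# Hypothetical record comparisons using the default source_dataset/unique_id column names
rows = [
    {"source_dataset_l": "df", "unique_id_l": "2", "source_dataset_r": "df", "unique_id_r": "0", "a_l": 2, "a_r": 0},
    {"source_dataset_l": "df", "unique_id_l": "0", "source_dataset_r": "df", "unique_id_r": "3", "a_l": 0, "a_r": 3},
]
df_comparisons = spark.createDataFrame([Row(**r) for r in rows])

# The first row has its _l and _r values swapped so that unique_id_l is the lower id;
# the second row is already in the right order and is returned unchanged
df_fixed = lower_id_to_left_hand_side(df_comparisons)
df_fixed.show()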
197 changes: 197 additions & 0 deletions splink/m_from_labels.py
@@ -0,0 +1,197 @@
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import lit
from splink.vertically_concat import vertically_concatenate_datasets
from splink.lower_id_on_lhs import lower_id_to_left_hand_side
from splink.blocking import _get_columns_to_retain_blocking, sql_gen_comparison_columns
from splink.settings import Settings
from splink.blocking import block_using_rules
from splink.gammas import add_gammas
from splink.maximisation_step import run_maximisation_step
from splink.model import Model
from splink.cluster import _check_graphframes_installation


def estimate_m_from_labels(
settings: dict,
df_or_dfs: DataFrame,
labels: DataFrame,
use_connected_components,
fix_m_probabilities=False,
):
"""Estimate m values from labels
Args:
settings (dict): splink settings dictionary
        df_or_dfs (DataFrame or list of DataFrames): Input dataframe(s) of records to dedupe and/or link.
labels (DataFrame): Labelled data.
For link or link and dedupe, should have columns:
'source_dataset_l', 'unique_id_l', 'source_dataset_r', and 'unique_id_r'
For dedupe only, only needs 'unique_id_l' and 'unique_id_r' columns
        use_connected_components (bool): Whether to use the connected components approach,
            described here: https://github.com/moj-analytical-services/splink/issues/245
        fix_m_probabilities (bool, optional): If True, output comparison column settings will have
            fix_m_probabilities set to True. Defaults to False.
"""

# dfs is a list of dfs irrespective of whether input was a df or list of dfs
    if isinstance(df_or_dfs, DataFrame):
dfs = [df_or_dfs]
else:
dfs = df_or_dfs

spark = dfs[0].sql_ctx.sparkSession

if use_connected_components:
_check_graphframes_installation(spark)

df_nodes = vertically_concatenate_datasets(dfs)

from splink.settings import complete_settings_dict

settings_complete = complete_settings_dict(settings, spark)
if settings_complete["link_type"] == "dedupe_only":
use_source_dataset = False
else:
use_source_dataset = True

source_dataset_colname = settings_complete["source_dataset_column_name"]
uid_colname = settings_complete["unique_id_column_name"]

if use_connected_components:
df_gammas = _get_comparisons_using_connected_components(
df_nodes,
labels,
settings_complete,
use_source_dataset,
source_dataset_colname,
uid_colname,
)
else:
df_gammas = _get_comparisons_using_joins(
df_nodes,
labels,
settings_complete,
use_source_dataset,
source_dataset_colname,
uid_colname,
)

df_e = df_gammas.withColumn("match_probability", lit(1.0))

model = Model(settings_complete, spark)
run_maximisation_step(df_e, model, spark)

settings_with_m_dict = model.current_settings_obj.settings_dict

    # Add the m probabilities from these estimates to a settings object built
    # from the original settings dict
settings_obj = Settings(settings)

settings_obj.overwrite_m_u_probs_from_other_settings_dict(
settings_with_m_dict, overwrite_u=False
)

for cc in settings_obj.comparison_columns_list:
if fix_m_probabilities:
cc.fix_m_probabilities = True

return settings_obj.settings_dict


def _get_comparisons_using_connected_components(
df_nodes,
df_labels,
settings_complete,
use_source_dataset,
source_dataset_colname,
uid_colname,
):
from graphframes import GraphFrame

spark = df_nodes.sql_ctx.sparkSession

    if use_source_dataset:
        uid_node = f"concat({source_dataset_colname}, '-__-', {uid_colname}) as id"
        uid_src = f"concat({source_dataset_colname}_l, '-__-', {uid_colname}_l) as src"
        uid_dst = f"concat({source_dataset_colname}_r, '-__-', {uid_colname}_r) as dst"
    else:
        uid_node = f"{uid_colname} as id"
        uid_src = f"{uid_colname}_l as src"
        uid_dst = f"{uid_colname}_r as dst"

    cc_nodes = df_nodes.selectExpr(uid_node)
    edges = df_labels.selectExpr(uid_src, uid_dst)
g = GraphFrame(cc_nodes, edges)
g = g.dropIsolatedVertices()
cc = g.connectedComponents()

df_nodes.createOrReplaceTempView("df_nodes")
cc.createOrReplaceTempView("cc")

if use_source_dataset:
join_col_expr = (
f"concat(df_nodes.{source_dataset_colname}, '-__-',df_nodes.{uid_colname})"
)
else:
join_col_expr = f"df_nodes.{uid_colname}"

sql = f"""
select df_nodes.*, cc.component as cluster
from df_nodes
inner join cc
on cc.id = {join_col_expr}
"""

df_with_cluster = spark.sql(sql)

settings_complete["blocking_rules"] = ["l.cluster = r.cluster"]
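    # Blocking on the connected component id means every pairwise comparison within
    # each labelled cluster is generated, not just the explicitly labelled pairs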

df_comparison = block_using_rules(settings_complete, df_with_cluster, spark)
df_gammas = add_gammas(df_comparison, settings_complete, spark)

return df_gammas


def _get_comparisons_using_joins(
df_nodes,
df_labels,
settings_complete,
use_source_dataset,
source_dataset_colname,
uid_colname,
):
spark = df_nodes.sql_ctx.sparkSession
df_labels = lower_id_to_left_hand_side(
df_labels, source_dataset_colname, uid_colname
)

df_nodes.createOrReplaceTempView("df_nodes")
df_labels.createOrReplaceTempView("df_labels")

columns_to_retain = _get_columns_to_retain_blocking(settings_complete, df_nodes)

sql_select_expr = sql_gen_comparison_columns(columns_to_retain)

if use_source_dataset:

sql = f"""
select {sql_select_expr}, '0' as match_key
from df_nodes as l
inner join df_labels
on l.{source_dataset_colname} = df_labels.{source_dataset_colname}_l and l.{uid_colname} = df_labels.{uid_colname}_l
inner join df_nodes as r
on r.{source_dataset_colname} = df_labels.{source_dataset_colname}_r and r.{uid_colname} = df_labels.{uid_colname}_r
"""
else:
sql = f"""
select {sql_select_expr}, '0' as match_key
from df_nodes as l
inner join df_labels
on l.{uid_colname} = df_labels.{uid_colname}_l
inner join df_nodes as r
on r.{uid_colname} = df_labels.{uid_colname}_r
"""

df_comparison = spark.sql(sql)
df_gammas = add_gammas(df_comparison, settings_complete, spark)
return df_gammas
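
A sketch of how estimate_m_from_labels might be called, assuming an active SparkSession and Spark DataFrames df (input records) and df_labels (labelled pairs) shaped like those in the test below; the call pattern is illustrative:

from splink.m_from_labels import estimate_m_from_labels

# Estimate m probabilities from the labelled pairs alone (no cluster expansion)
settings_with_m = estimate_m_from_labels(
    settings,                        # ordinary splink settings dictionary
    df,                              # a DataFrame or a list of DataFrames of input records
    df_labels,                       # labelled pairs with unique_id_l / unique_id_r (plus source_dataset columns for linking)
    use_connected_components=False,  # True expands labels into full clusters via graphframes
    fix_m_probabilities=True,        # mark the estimated m values as fixed in the returned settings
)

# settings_with_m is a copy of the settings dict with m_probabilities populated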
70 changes: 70 additions & 0 deletions tests/test_m_from_labels.py
@@ -0,0 +1,70 @@
from pyspark.sql import Row
from splink.case_statements import sql_gen_case_smnt_strict_equality_2
from splink.m_from_labels import estimate_m_from_labels

import pytest


def test_m_from_labels(spark):

# fmt: off
df_rows = [
{"uid": "0", "sds": "df1", "first_name": "Robin", "dob": "1909-10-11"},
{"uid": "1", "sds": "df1", "first_name": "Robin", "dob": "1909-10-11"},
{"uid": "2", "sds": "df1", "first_name": "Robim", "dob": "1909-10-11"},
{"uid": "3", "sds": "df1", "first_name": "James", "dob": "1909-10-10"},
]

labels_rows = [
{"uid_l": "1", "sds_l": "df1", "uid_r": "0", "sds_r": "df1"},
{"uid_l": "2", "sds_l": "df1", "uid_r": "0", "sds_r": "df1"},
{"uid_l": "0", "sds_l": "df1", "uid_r": "3", "sds_r": "df1"},
]
# fmt: on

df = spark.createDataFrame(Row(**x) for x in df_rows)

df_labels = spark.createDataFrame(Row(**x) for x in labels_rows)

sql_name = """
case
when first_name_l = first_name_r then 2
when substr(first_name_l, 1,3) = substr(first_name_r, 1,3) then 1
else 0
end
"""

settings = {
"comparison_columns": [
{"col_name": "first_name", "case_expression": sql_name, "num_levels": 3},
{
"col_name": "dob",
"case_expression": sql_gen_case_smnt_strict_equality_2("dob"),
},
],
"link_type": "dedupe_only",
"unique_id_column_name": "uid",
"source_dataset_column_name": "sds",
}

    # The connected components assertions below require graphframes, which is not
    # a dev dependency; they have been checked manually and pass
# set_cc = estimate_m_from_labels(
# settings, df, df_labels, use_connected_components=True
# )

# m_first_name = set_cc["comparison_columns"][0]["m_probabilities"]

# assert pytest.approx(m_first_name) == [3 / 6, 2 / 6, 1 / 6]

# m_dob = set_cc["comparison_columns"][1]["m_probabilities"]
# assert pytest.approx(m_dob) == [3 / 6, 3 / 6]

set_nocc = estimate_m_from_labels(
settings, df, df_labels, use_connected_components=False
)
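    # Without connected components only the three labelled pairs are compared:
    # (1,0) Robin/Robin -> level 2, (2,0) Robim/Robin -> level 1, (0,3) Robin/James -> level 0,
    # so first_name m = [1/3, 1/3, 1/3]; dob matches on two of the three pairs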
m_first_name = set_nocc["comparison_columns"][0]["m_probabilities"]
assert pytest.approx(m_first_name) == [1 / 3, 1 / 3, 1 / 3]

m_dob = set_nocc["comparison_columns"][1]["m_probabilities"]
assert pytest.approx(m_dob) == [1 / 3, 2 / 3]
