Commit
Showing 5 changed files with 373 additions and 4 deletions.
@@ -0,0 +1,95 @@
def _sql_expr_move_left_to_right(
    col_name,
    unique_id_col: str = "unique_id",
    source_dataset_col: str = "source_dataset",
):

    sds_l = f"{source_dataset_col}_l"
    uid_l = f"{unique_id_col}_l"
    sds_r = f"{source_dataset_col}_r"
    uid_r = f"{unique_id_col}_r"
    col_name_l = f"{col_name}_l"
    col_name_r = f"{col_name}_r"

    # Where a source_dataset column exists, ids are only unique within a dataset,
    # so compare a composite '<source_dataset>-__-<unique_id>' key
    if source_dataset_col:
        uid_expr_l = f"concat({sds_l}, '-__-', {uid_l})"
        uid_expr_r = f"concat({sds_r}, '-__-', {uid_r})"
    else:
        uid_expr_l = uid_l
        uid_expr_r = uid_r

    move_to_left = f"""
    CASE
    WHEN {uid_expr_l} < {uid_expr_r}
    THEN {col_name_l}
    ELSE {col_name_r}
    END as {col_name_l}
    """

    move_to_right = f"""
    CASE
    WHEN {uid_expr_l} < {uid_expr_r}
    THEN {col_name_r}
    ELSE {col_name_l}
    END as {col_name_r}
    """

    exprs = f"""
    {move_to_left},
    {move_to_right}
    """

    return exprs


def lower_id_to_left_hand_side(
    df,
    source_dataset_col: str = "source_dataset",
    unique_id_col: str = "unique_id",
):
    """Take a dataframe in the format of splink record comparisons (with _l and _r suffixes)
    and return a dataframe where the _l columns correspond to the record with the lower id.

    For example:

    | source_dataset_l | unique_id_l | source_dataset_r | unique_id_r | a_l | a_r | other_col |
    |:-----------------|------------:|:-----------------|------------:|----:|----:|:----------|
    | df               |           0 | df               |           1 |   0 |   1 | a         |
    | df               |           2 | df               |           0 |   2 |   0 | b         |
    | df               |           0 | df               |           3 |   0 |   3 | c         |

    becomes

    | source_dataset_l | unique_id_l | source_dataset_r | unique_id_r | a_l | a_r | other_col |
    |:-----------------|------------:|:-----------------|------------:|----:|----:|:----------|
    | df               |           0 | df               |           1 |   0 |   1 | a         |
    | df               |           0 | df               |           2 |   0 |   2 | b         |
    | df               |           0 | df               |           3 |   0 |   3 | c         |

    Returns:
        df: a dataframe with the _l and _r columns swapped in the case where
            unique_id_r < unique_id_l
    """
    spark = df.sql_ctx.sparkSession
    cols = list(df.columns)

    l_cols = [c for c in cols if c.endswith("_l")]
    r_cols = [c for c in cols if c.endswith("_r")]
    other_cols = [c for c in cols if c not in (l_cols + r_cols)]

    # Build a pair of CASE expressions for every _l/_r column pair
    case_exprs = []
    for col in l_cols:
        this_col = col[:-2]
        expr = _sql_expr_move_left_to_right(this_col, unique_id_col, source_dataset_col)
        case_exprs.append(expr)
    case_exprs.extend(other_cols)
    select_expr = ", ".join(case_exprs)

    df.createOrReplaceTempView("df")
    sql = f"""
    select {select_expr}
    from df
    """

    df = spark.sql(sql)
    return df.select(cols)
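
As context for review (not part of the commit), a minimal usage sketch of lower_id_to_left_hand_side on toy data mirroring the docstring example; the SparkSession setup and data values are illustrative only.

# Usage sketch only - toy data mirroring the docstring example, not part of the commit
from pyspark.sql import Row, SparkSession

from splink.lower_id_on_lhs import lower_id_to_left_hand_side

spark = SparkSession.builder.getOrCreate()

rows = [
    {"source_dataset_l": "df", "unique_id_l": 0,
     "source_dataset_r": "df", "unique_id_r": 1, "a_l": 0, "a_r": 1, "other_col": "a"},
    {"source_dataset_l": "df", "unique_id_l": 2,
     "source_dataset_r": "df", "unique_id_r": 0, "a_l": 2, "a_r": 0, "other_col": "b"},
]
comparisons = spark.createDataFrame(Row(**r) for r in rows)

# The second row has unique_id_r < unique_id_l, so its _l and _r values are swapped;
# the first row is returned unchanged
lower_id_to_left_hand_side(comparisons).show()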
@@ -0,0 +1,197 @@
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import lit

from splink.blocking import (
    _get_columns_to_retain_blocking,
    block_using_rules,
    sql_gen_comparison_columns,
)
from splink.cluster import _check_graphframes_installation
from splink.gammas import add_gammas
from splink.lower_id_on_lhs import lower_id_to_left_hand_side
from splink.maximisation_step import run_maximisation_step
from splink.model import Model
from splink.settings import Settings, complete_settings_dict
from splink.vertically_concat import vertically_concatenate_datasets


def estimate_m_from_labels(
    settings: dict,
    df_or_dfs: DataFrame,
    labels: DataFrame,
    use_connected_components,
    fix_m_probabilities=False,
):
    """Estimate m values from labelled data.

    Args:
        settings (dict): splink settings dictionary
        df_or_dfs (DataFrame or list of DataFrames): the dataset(s) containing the
            records referred to by the labels
        labels (DataFrame): labelled data.
            For 'link' or 'link_and_dedupe' jobs, should have the columns
            'source_dataset_l', 'unique_id_l', 'source_dataset_r' and 'unique_id_r'.
            For 'dedupe_only' jobs, only 'unique_id_l' and 'unique_id_r' are needed.
        use_connected_components (bool): whether to use the connected components
            approach, described here:
            https://github.com/moj-analytical-services/splink/issues/245
        fix_m_probabilities (bool, optional): if True, output comparison column settings
            will have fix_m_probabilities set to True. Defaults to False.
    """

    # dfs is a list of dfs irrespective of whether the input was a df or a list of dfs
    if type(df_or_dfs) == DataFrame:
        dfs = [df_or_dfs]
    else:
        dfs = df_or_dfs

    spark = dfs[0].sql_ctx.sparkSession

    if use_connected_components:
        _check_graphframes_installation(spark)

    df_nodes = vertically_concatenate_datasets(dfs)

    settings_complete = complete_settings_dict(settings, spark)
    if settings_complete["link_type"] == "dedupe_only":
        use_source_dataset = False
    else:
        use_source_dataset = True

    source_dataset_colname = settings_complete["source_dataset_column_name"]
    uid_colname = settings_complete["unique_id_column_name"]

    if use_connected_components:
        df_gammas = _get_comparisons_using_connected_components(
            df_nodes,
            labels,
            settings_complete,
            use_source_dataset,
            source_dataset_colname,
            uid_colname,
        )
    else:
        df_gammas = _get_comparisons_using_joins(
            df_nodes,
            labels,
            settings_complete,
            use_source_dataset,
            source_dataset_colname,
            uid_colname,
        )

    # Labelled pairs are treated as definite matches for the maximisation step
    df_e = df_gammas.withColumn("match_probability", lit(1.0))

    model = Model(settings_complete, spark)
    run_maximisation_step(df_e, model, spark)

    settings_with_m_dict = model.current_settings_obj.settings_dict

    # Copy the estimated m probabilities (but not the u probabilities) back onto
    # the user's original settings
    settings_obj = Settings(settings)

    settings_obj.overwrite_m_u_probs_from_other_settings_dict(
        settings_with_m_dict, overwrite_u=False
    )

    if fix_m_probabilities:
        for cc in settings_obj.comparison_columns_list:
            cc.fix_m_probabilities = True

    return settings_obj.settings_dict


def _get_comparisons_using_connected_components(
    df_nodes,
    df_labels,
    settings_complete,
    use_source_dataset,
    source_dataset_colname,
    uid_colname,
):
    from graphframes import GraphFrame

    spark = df_nodes.sql_ctx.sparkSession

    if use_source_dataset:
        uid_node = f"concat({source_dataset_colname}, '-__-', {uid_colname}) as id"
        uid_l = f"concat({source_dataset_colname}_l, '-__-', {uid_colname}_l) as src"
        uid_r = f"concat({source_dataset_colname}_r, '-__-', {uid_colname}_r) as dst"
    else:
        uid_node = f"{uid_colname} as id"
        uid_l = f"{uid_colname}_l as src"
        uid_r = f"{uid_colname}_r as dst"

    # Find the connected components implied by the labelled pairs, ignoring
    # any record that does not appear in a label
    cc_nodes = df_nodes.selectExpr(uid_node)
    edges = df_labels.selectExpr(uid_l, uid_r)
    g = GraphFrame(cc_nodes, edges)
    g = g.dropIsolatedVertices()
    cc = g.connectedComponents()

    df_nodes.createOrReplaceTempView("df_nodes")
    cc.createOrReplaceTempView("cc")

    if use_source_dataset:
        join_col_expr = (
            f"concat(df_nodes.{source_dataset_colname}, '-__-', df_nodes.{uid_colname})"
        )
    else:
        join_col_expr = f"df_nodes.{uid_colname}"

    sql = f"""
    select df_nodes.*, cc.component as cluster
    from df_nodes
    inner join cc
    on cc.id = {join_col_expr}
    """

    df_with_cluster = spark.sql(sql)

    # Block on the cluster id so that every within-cluster pair is compared
    settings_complete["blocking_rules"] = ["l.cluster = r.cluster"]

    df_comparison = block_using_rules(settings_complete, df_with_cluster, spark)
    df_gammas = add_gammas(df_comparison, settings_complete, spark)

    return df_gammas


def _get_comparisons_using_joins(
    df_nodes,
    df_labels,
    settings_complete,
    use_source_dataset,
    source_dataset_colname,
    uid_colname,
):
    spark = df_nodes.sql_ctx.sparkSession

    # Orient the labels so the _l columns always hold the record with the lower id
    df_labels = lower_id_to_left_hand_side(
        df_labels, source_dataset_colname, uid_colname
    )

    df_nodes.createOrReplaceTempView("df_nodes")
    df_labels.createOrReplaceTempView("df_labels")

    columns_to_retain = _get_columns_to_retain_blocking(settings_complete, df_nodes)

    sql_select_expr = sql_gen_comparison_columns(columns_to_retain)

    if use_source_dataset:
        sql = f"""
        select {sql_select_expr}, '0' as match_key
        from df_nodes as l
        inner join df_labels
        on l.{source_dataset_colname} = df_labels.{source_dataset_colname}_l
            and l.{uid_colname} = df_labels.{uid_colname}_l
        inner join df_nodes as r
        on r.{source_dataset_colname} = df_labels.{source_dataset_colname}_r
            and r.{uid_colname} = df_labels.{uid_colname}_r
        """
    else:
        sql = f"""
        select {sql_select_expr}, '0' as match_key
        from df_nodes as l
        inner join df_labels
        on l.{uid_colname} = df_labels.{uid_colname}_l
        inner join df_nodes as r
        on r.{uid_colname} = df_labels.{uid_colname}_r
        """

    df_comparison = spark.sql(sql)
    df_gammas = add_gammas(df_comparison, settings_complete, spark)
    return df_gammas
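
To contrast the two helper paths (a sketch for review, not part of the commit): the join-based path scores exactly the labelled pairs, whereas the connected-components path blocks on the cluster id and therefore scores every pair within each connected component. Using the labelled pairs from the test below:

# Sketch only: how the two paths differ on the labelled pairs used in the test below
from itertools import combinations

labelled_pairs = [("1", "0"), ("2", "0"), ("0", "3")]

# Join path (_get_comparisons_using_joins): exactly the labelled pairs are scored
join_comparisons = labelled_pairs                        # 3 comparisons

# Connected components path: records 0-3 form a single cluster, and blocking on
# "l.cluster = r.cluster" scores every within-cluster pair
cluster = sorted({uid for pair in labelled_pairs for uid in pair})
cc_comparisons = list(combinations(cluster, 2))          # 6 comparisons

print(len(join_comparisons), len(cc_comparisons))        # prints: 3 6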
@@ -0,0 +1,70 @@
import pytest
from pyspark.sql import Row

from splink.case_statements import sql_gen_case_smnt_strict_equality_2
from splink.m_from_labels import estimate_m_from_labels


def test_m_from_labels(spark):

    # fmt: off
    df_rows = [
        {"uid": "0", "sds": "df1", "first_name": "Robin", "dob": "1909-10-11"},
        {"uid": "1", "sds": "df1", "first_name": "Robin", "dob": "1909-10-11"},
        {"uid": "2", "sds": "df1", "first_name": "Robim", "dob": "1909-10-11"},
        {"uid": "3", "sds": "df1", "first_name": "James", "dob": "1909-10-10"},
    ]

    labels_rows = [
        {"uid_l": "1", "sds_l": "df1", "uid_r": "0", "sds_r": "df1"},
        {"uid_l": "2", "sds_l": "df1", "uid_r": "0", "sds_r": "df1"},
        {"uid_l": "0", "sds_l": "df1", "uid_r": "3", "sds_r": "df1"},
    ]
    # fmt: on

    df = spark.createDataFrame(Row(**x) for x in df_rows)
    df_labels = spark.createDataFrame(Row(**x) for x in labels_rows)

    sql_name = """
    case
    when first_name_l = first_name_r then 2
    when substr(first_name_l, 1, 3) = substr(first_name_r, 1, 3) then 1
    else 0
    end
    """

    settings = {
        "comparison_columns": [
            {"col_name": "first_name", "case_expression": sql_name, "num_levels": 3},
            {
                "col_name": "dob",
                "case_expression": sql_gen_case_smnt_strict_equality_2("dob"),
            },
        ],
        "link_type": "dedupe_only",
        "unique_id_column_name": "uid",
        "source_dataset_column_name": "sds",
    }

    # The connected components variant requires graphframes, which is not a dev
    # dependency, so it is left commented out; it has been run manually and passes:
    # set_cc = estimate_m_from_labels(
    #     settings, df, df_labels, use_connected_components=True
    # )
    #
    # m_first_name = set_cc["comparison_columns"][0]["m_probabilities"]
    # assert pytest.approx(m_first_name) == [3 / 6, 2 / 6, 1 / 6]
    #
    # m_dob = set_cc["comparison_columns"][1]["m_probabilities"]
    # assert pytest.approx(m_dob) == [3 / 6, 3 / 6]

    set_nocc = estimate_m_from_labels(
        settings, df, df_labels, use_connected_components=False
    )
    m_first_name = set_nocc["comparison_columns"][0]["m_probabilities"]
    assert pytest.approx(m_first_name) == [1 / 3, 1 / 3, 1 / 3]

    m_dob = set_nocc["comparison_columns"][1]["m_probabilities"]
    assert pytest.approx(m_dob) == [1 / 3, 2 / 3]
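
For reviewers, the expected values can be checked by hand (each labelled pair is treated as a definite match, with m_probabilities listed from gamma level 0 upwards). The three labelled pairs are (1, 0), (2, 0) and (0, 3): for first_name they score Robin/Robin = level 2, Robin/Robim = level 1 and Robin/James = level 0, one pair per level, giving [1/3, 1/3, 1/3]; for dob two pairs agree and one does not, giving [1/3, 2/3]. The commented-out connected-components variant scores all six pairs among records 0-3, giving level counts of 3, 2 and 1 for first_name ([3/6, 2/6, 1/6]) and 3 and 3 for dob ([3/6, 3/6]).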