Commit a702cdb

address unique id issues and test

RobinL committed Mar 3, 2020
1 parent d88153f
Showing 5 changed files with 139 additions and 13 deletions.
34 changes: 24 additions & 10 deletions splink/blocking.py
@@ -21,6 +21,8 @@ def _get_columns_to_retain_blocking(settings):

     # Use ordered dict as an ordered set - i.e. to make sure we don't have duplicate cols to retain
+    # That means we're only interested in the keys so we set values to None
+
     columns_to_retain = OrderedDict()
     columns_to_retain[settings["unique_id_column_name"]] = None

@@ -31,7 +33,7 @@ def _get_columns_to_retain_blocking(settings):
     for c in settings["additional_columns_to_retain"]:
         columns_to_retain[c] = None

-    return columns_to_retain.keys()
+    return list(columns_to_retain.keys())


 def sql_gen_and_not_previous_rules(previous_rules: list):
     if previous_rules:
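The switch to list(columns_to_retain.keys()) matters because a dict keys view cannot be appended to, and block_using_rules (later in this diff) now appends "_source_table" to this result. For illustration, not part of the commit, the ordered-set idiom behaves like this:

from collections import OrderedDict

# The OrderedDict keys act as an ordered set: duplicates collapse,
# insertion order is preserved
columns_to_retain = OrderedDict()
for col in ["unique_id", "first_name", "unique_id"]:
    columns_to_retain[col] = None

cols = list(columns_to_retain.keys())  # ['unique_id', 'first_name']
cols.append("_source_table")           # a .keys() view would raise AttributeError here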
@@ -49,10 +51,10 @@ def sql_gen_vertically_concatenate(columns_to_retain: list, table_name_l = "df_l
     retain = ", ".join(columns_to_retain)

     sql = f"""
-    select {retain}, 'left' as source_table
+    select {retain}, 'left' as _source_table
     from {table_name_l}
     union all
-    select {retain}, 'right' as source_table
+    select {retain}, 'right' as _source_table
     from {table_name_r}
     """

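For illustration, not part of the commit: with the default table names and two retained columns, the generated union now carries _source_table rather than source_table. The leading underscore presumably reduces the chance of colliding with a user's own column names. Assuming the function is imported from splink.blocking, as the file path suggests (output shown with whitespace trimmed):

from splink.blocking import sql_gen_vertically_concatenate

print(sql_gen_vertically_concatenate(["unique_id", "first_name"]))
# select unique_id, first_name, 'left' as _source_table
# from df_l
# union all
# select unique_id, first_name, 'right' as _source_table
# from df_r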
@@ -109,8 +111,12 @@ def sql_gen_block_using_rules(

     if link_type == "link_only":
         where_condition = ""
-    else:
+    elif link_type == "dedupe_only":
         where_condition = f"where l.{unique_id_col} < r.{unique_id_col}"
+    elif link_type == "link_and_dedupe":
+        # Where a record from the left table is compared with one from the right table, you want the left record to end up in the _l fields and the right record in the _r fields.
+        where_condition = f"where (l._source_table < r._source_table) or (l.{unique_id_col} < r.{unique_id_col} and l._source_table = r._source_table)"


     sqls = []
     previous_rules =[]
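To see why the link_and_dedupe condition compares each pair exactly once while keeping left-table records in the _l fields, here is a sketch, illustrative and not part of the commit, that enumerates the pairs it admits; note that 'left' < 'right' holds as a string comparison:

from itertools import product

records = [("left", 1), ("left", 2), ("right", 1), ("right", 2)]  # (_source_table, unique_id)

kept = [
    (l, r)
    for l, r in product(records, records)
    if l[0] < r[0] or (l[1] < r[1] and l[0] == r[0])
]

# Six comparisons survive: every unordered pair exactly once, and each
# cross-table pair has the left-table record on the l side, e.g.
# (('left', 1), ('right', 1)) is kept but (('right', 1), ('left', 1)) is not.
assert len(kept) == 6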
@@ -153,13 +159,12 @@ def block_using_rules(
     Returns:
         pyspark.sql.dataframe.DataFrame: A dataframe of each record comparison
     """

     if len(settings["blocking_rules"])==0:
         return cartesian_block(settings, spark, df_l, df_r, df)

     link_type = settings["link_type"]

-
     columns_to_retain = _get_columns_to_retain_blocking(settings)
     unique_id_col = settings["unique_id_column_name"]

@@ -172,6 +177,7 @@

     if link_type == "link_and_dedupe":
         df_concat = vertically_concatenate_datasets(df_l, df_r, settings, spark=spark)
+        columns_to_retain.append("_source_table")
         df_concat.createOrReplaceTempView("df")
         df_concat.persist()
@@ -211,7 +217,7 @@ def sql_gen_cartesian_block(
     Returns:
         str: A SQL statement that implements the join
     """

+    # In both these cases the data is in a single table
+    # (In the link_and_dedupe case the two tables have already been vertically concatenated)
     if link_type in ['dedupe_only', 'link_and_dedupe']:
@@ -223,18 +229,26 @@

     sql_select_expr = sql_gen_comparison_columns(columns_to_retain)

+    if link_type == "link_only":
+        where_condition = ""
+    elif link_type == "dedupe_only":
+        where_condition = f"where l.{unique_id_col} < r.{unique_id_col}"
+    elif link_type == "link_and_dedupe":
+        # Where a record from the left table is compared with one from the right table, you want the left record to end up in the _l fields and the right record in the _r fields.
+        where_condition = f"where (l._source_table < r._source_table) or (l.{unique_id_col} < r.{unique_id_col} and l._source_table = r._source_table)"
+
     sql = f"""
     select
     {sql_select_expr}
     from {table_name_l} as l
     cross join {table_name_r} as r
-    where l.{unique_id_col} < r.{unique_id_col}
+    {where_condition}
     """

     return sql


 def cartesian_block(
     settings: dict,
     spark: SparkSession,
     df_l: DataFrame=None,
@@ -274,7 +288,7 @@ def cartesian_block(
         df_concat.persist()

     sql = sql_gen_cartesian_block(link_type, columns_to_retain, unique_id_col)

     logger.debug(format_sql(sql))

     df_comparison = spark.sql(sql)
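For illustration, not part of the commit, a rough rendering of the statement this produces for link_and_dedupe. The exact _l/_r aliasing comes from sql_gen_comparison_columns, which is not shown in this diff, so the select list below is an assumption:

# Assumed shape of the generated SQL for link_and_dedupe, where the data
# sits in a single concatenated table "df":
illustrative_sql = """
select
    l.unique_id as unique_id_l, r.unique_id as unique_id_r,
    l._source_table as _source_table_l, r._source_table as _source_table_r
from df as l
cross join df as r
where (l._source_table < r._source_table)
   or (l.unique_id < r.unique_id and l._source_table = r._source_table)
"""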
6 changes: 6 additions & 0 deletions splink/expectation_step.py
@@ -98,6 +98,9 @@ def _sql_gen_gamma_prob_columns(params, settings, table_name="df_with_gamma"):
         select_cols[f"prob_gamma_{col_name}_non_match"] = case_statements[f"prob_gamma_{col_name}_non_match"]
         select_cols[f"prob_gamma_{col_name}_match"] = case_statements[f"prob_gamma_{col_name}_match"]

+    if settings["link_type"] == 'link_and_dedupe':
+        select_cols = _add_left_right(select_cols, "_source_table")
+
     for c in settings["additional_columns_to_retain"]:
         select_cols = _add_left_right(select_cols, c)

@@ -133,6 +136,9 @@ def _column_order_df_e_select_expr(settings, tf_adj_cols=False):
         if col["term_frequency_adjustments"]:
             select_cols[col_name+"_adj"] = col_name+"_adj"

+    if settings["link_type"] == 'link_and_dedupe':
+        select_cols = _add_left_right(select_cols, "_source_table")
+
     for c in settings["additional_columns_to_retain"]:
         select_cols = _add_left_right(select_cols, c)
     return ", ".join(select_cols.values())
3 changes: 3 additions & 0 deletions splink/gammas.py
@@ -46,6 +46,9 @@ def _get_select_expression_gammas(settings: dict):
         cols_to_retain = _add_left_right(cols_to_retain, col_name)
         cols_to_retain["gamma_" + col_name] = col["case_expression"]

+    if settings["link_type"] == 'link_and_dedupe':
+        cols_to_retain = _add_left_right(cols_to_retain, "_source_table")
+
     for c in settings["additional_columns_to_retain"]:
         cols_to_retain = _add_left_right(cols_to_retain, c)
31 changes: 31 additions & 0 deletions tests/conftest.py
@@ -62,6 +62,37 @@ def link_dedupe_data():

     yield con

+
+@pytest.fixture(scope='module')
+def link_dedupe_data_repeat_ids():
+
+    # Create the database and the database table
+    con = sqlite3.connect(":memory:")
+    con.row_factory = sqlite3.Row
+
+    data_l = [
+        {"unique_id": 1, "surname": "Linacre", "first_name": "Robin"},
+        {"unique_id": 2, "surname": "Smith", "first_name": "John"},
+        {"unique_id": 3, "surname": "Smith", "first_name": "John"}
+    ]
+
+    data_into_table(data_l, "df_l", con)
+
+    data_r = [
+        {"unique_id": 1, "surname": "Linacre", "first_name": "Robin"},
+        {"unique_id": 2, "surname": "Smith", "first_name": "John"},
+        {"unique_id": 3, "surname": "Smith", "first_name": "Robin"}
+    ]
+
+    data_into_table(data_r, "df_r", con)
+
+    cols_to_retain = ["unique_id", "surname", "first_name"]
+    sql = sql_gen_vertically_concatenate(cols_to_retain)
+    df = pd.read_sql(sql, con)
+    df.to_sql("df", con, index=False)
+
+    yield con
+
 @pytest.mark.filterwarnings("ignore:*")
 @pytest.fixture(scope="function")
 def gamma_settings_1():
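The point of this fixture is that unique_id values 1 to 3 repeat across df_l and df_r, so an id alone no longer identifies a record; only the (unique_id, _source_table) pair does. For reference, not part of the commit, the concatenated "df" table it builds contains:

import pandas as pd

# Illustrative contents of the concatenated "df" table built by the fixture
df = pd.DataFrame({
    "unique_id":     [1, 2, 3, 1, 2, 3],
    "surname":       ["Linacre", "Smith", "Smith", "Linacre", "Smith", "Smith"],
    "first_name":    ["Robin", "John", "John", "Robin", "John", "Robin"],
    "_source_table": ["left", "left", "left", "right", "right", "right"],
})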
78 changes: 75 additions & 3 deletions tests/test_spark.py
@@ -82,11 +82,11 @@ def test_no_blocking(spark, link_dedupe_data):
     dfpd_r = pd.read_sql("select * from df_r", link_dedupe_data)
     df_l = spark.createDataFrame(dfpd_l)
     df_r = spark.createDataFrame(dfpd_r)
-
     df_comparison = block_using_rules(settings, spark, df_l=df_l, df_r=df_r)
     df = df_comparison.toPandas()
+    df = df.sort_values(["unique_id_l", "unique_id_r"])

     assert list(df["unique_id_l"]) == [1,1,1,2,2,2]
     assert list(df["unique_id_r"]) == [7,8,9,7,8,9]
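The added sort fixes a flaky assertion: Spark gives no ordering guarantee on collected rows, so asserting on unsorted toPandas() output can fail intermittently. The stabilising pattern, for reference:

df = df_comparison.toPandas()
# Row order out of Spark is nondeterministic; sort before asserting against fixtures
df = df.sort_values(["unique_id_l", "unique_id_r"])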
@@ -429,3 +429,75 @@ def test_iteration_known_data_generating_process(
         "level_1"
     ]["probability"] == pytest.approx(0.5, abs=0.01)

+
+def test_link_option_link_dedupe(spark, link_dedupe_data_repeat_ids):
+    settings = {
+        "link_type": "link_and_dedupe",
+        "comparison_columns": [{"col_name": "first_name"},
+                               {"col_name": "surname"}],
+        "blocking_rules": [
+            "l.first_name = r.first_name",
+            "l.surname = r.surname"
+        ]
+    }
+    settings = complete_settings_dict(settings, spark=None)
+    dfpd_l = pd.read_sql("select * from df_l", link_dedupe_data_repeat_ids)
+    df_l = spark.createDataFrame(dfpd_l)
+    dfpd_r = pd.read_sql("select * from df_r", link_dedupe_data_repeat_ids)
+    df_r = spark.createDataFrame(dfpd_r)
+    df = block_using_rules(settings, spark, df_l=df_l, df_r=df_r)
+    df = df.toPandas()
+    df["u_l"] = df["unique_id_l"].astype(str) + df["_source_table_l"].str.slice(0,1)
+    df["u_r"] = df["unique_id_r"].astype(str) + df["_source_table_r"].str.slice(0,1)
+    df = df.sort_values(["_source_table_l", "_source_table_r", "unique_id_l", "unique_id_r"])
+
+    assert list(df["u_l"]) == ['2l', '1l', '1l', '2l', '2l', '3l', '3l', '1r', '2r']
+    assert list(df["u_r"]) == ['3l', '1r', '3r', '2r', '3r', '2r', '3r', '3r', '3r']
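For illustration, not part of the commit: the expected lists can be derived by hand from the fixture. Blocking admits a pair when either rule fires, and each surviving pair is oriented with the record that sorts first on (_source_table, unique_id) on the l side. Restricting the same enumeration to cross-table pairs gives the link_only expectations in the next test, and restricting it to df_l alone shows why dedupe_only finds only the (2, 3) pair:

from itertools import combinations

# Fixture records as (unique_id, source, first_name, surname)
records = [
    (1, "l", "Robin", "Linacre"), (2, "l", "John", "Smith"), (3, "l", "John", "Smith"),
    (1, "r", "Robin", "Linacre"), (2, "r", "John", "Smith"), (3, "r", "Robin", "Smith"),
]

# combinations() keeps each unordered pair once, oriented like the blocking
# where-condition: left table first, lower id first within a table
pairs = [
    (f"{a[0]}{a[1]}", f"{b[0]}{b[1]}")
    for a, b in combinations(records, 2)
    if a[2] == b[2] or a[3] == b[3]  # either blocking rule fires
]

assert sorted(pairs) == sorted([
    ("2l", "3l"),                                            # dedupe within left
    ("1l", "1r"), ("1l", "3r"), ("2l", "2r"), ("2l", "3r"),
    ("3l", "2r"), ("3l", "3r"),                              # links across tables
    ("1r", "3r"), ("2r", "3r"),                              # dedupe within right
])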
+
+
+def test_link_option_link(spark, link_dedupe_data_repeat_ids):
+    settings = {
+        "link_type": "link_only",
+        "comparison_columns": [{"col_name": "first_name"},
+                               {"col_name": "surname"}],
+        "blocking_rules": [
+            "l.first_name = r.first_name",
+            "l.surname = r.surname"
+        ]
+    }
+    settings = complete_settings_dict(settings, spark=None)
+    dfpd_l = pd.read_sql("select * from df_l", link_dedupe_data_repeat_ids)
+    df_l = spark.createDataFrame(dfpd_l)
+    dfpd_r = pd.read_sql("select * from df_r", link_dedupe_data_repeat_ids)
+    df_r = spark.createDataFrame(dfpd_r)
+    df = block_using_rules(settings, spark, df_l=df_l, df_r=df_r)
+    df = df.toPandas()
+
+    df = df.sort_values(["unique_id_l", "unique_id_r"])
+
+    assert list(df["unique_id_l"]) == [1, 1, 2, 2, 3, 3]
+    assert list(df["unique_id_r"]) == [1, 3, 2, 3, 2, 3]
+
+
+def test_link_option_dedupe_only(spark, link_dedupe_data_repeat_ids):
+    settings = {
+        "link_type": "dedupe_only",
+        "comparison_columns": [{"col_name": "first_name"},
+                               {"col_name": "surname"}],
+        "blocking_rules": [
+            "l.first_name = r.first_name",
+            "l.surname = r.surname"
+        ]
+    }
+    settings = complete_settings_dict(settings, spark=None)
+    dfpd = pd.read_sql("select * from df_l", link_dedupe_data_repeat_ids)
+    df = spark.createDataFrame(dfpd)
+
+    df = block_using_rules(settings, spark, df=df)
+    df = df.toPandas()
+
+    df = df.sort_values(["unique_id_l", "unique_id_r"])
+
+    assert list(df["unique_id_l"]) == [2]
+    assert list(df["unique_id_r"]) == [3]