Issue 64 #66

Merged: 7 commits, Feb 19, 2020
2,406 changes: 1,961 additions & 445 deletions quickstart_demo.ipynb

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion sparklink/__init__.py
@@ -5,6 +5,7 @@
from sparklink.blocking import block_using_rules
from sparklink.gammas import add_gammas
from sparklink.iterate import iterate
from sparklink.expectation_step import run_expectation_step
from sparklink.term_frequencies import make_adjustment_for_term_frequencies

try:
@@ -73,13 +74,20 @@ def _get_df_comparison(self):
if self.settings["link_type"] in ("link_only", "link_and_dedupe"):
return block_using_rules(self.settings, self.spark, df_l = self.df_l, df_r=self.df_r)

def manually_apply_fellegi_sunter_weights(self):
df_comparison = self._get_df_comparison()
df_gammas = add_gammas(df_comparison, self.settings, self.spark, include_orig_cols = True)
return run_expectation_step(df_gammas, self.params, self.settings, self.spark)

def get_scored_comparisons(self, persist_df_gammas=True):
df_comparison = self._get_df_comparison()
df_gammas = add_gammas(df_comparison, self.settings, self.spark, include_orig_cols = True)
df_gammas.persist()

df_e = iterate(df_gammas, self.spark, self.params, log_iteration=True, num_iterations=5, compute_ll=False)
df_e = iterate(df_gammas, self.params, self.settings, self.spark, log_iteration=True, num_iterations=5, compute_ll=False)
df_gammas.unpersist()
return df_e
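
For orientation, here is a minimal usage sketch of how the new manually_apply_fellegi_sunter_weights method sits alongside get_scored_comparisons. The class name Sparklink, the constructor signature, the input data and the settings values are assumptions for illustration, not taken from this diff.

# Hypothetical usage sketch; class name, constructor and settings values are assumed
from pyspark.sql import SparkSession
from sparklink import Sparklink

spark = SparkSession.builder.getOrCreate()
df = spark.read.parquet("people.parquet")  # hypothetical input data

settings = {
    "link_type": "dedupe_only",
    "unique_id_column_name": "unique_id",
    "comparison_columns": [{"col_name": "first_name"}, {"col_name": "surname"}],
    "additional_columns_to_retain": [],
}

linker = Sparklink(settings, spark, df=df)

# Full iterative estimation followed by scoring
df_e = linker.get_scored_comparisons()

# Or apply a single expectation step using the current params
df_e_manual = linker.manually_apply_fellegi_sunter_weights()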

44 changes: 36 additions & 8 deletions sparklink/blocking.py
@@ -1,4 +1,5 @@
import logging
from collections import OrderedDict

# For type hints. Try/except ensures the sql_gen functions still work even if Spark isn't installed.
try:
@@ -11,11 +12,27 @@

from .logging_utils import log_sql, format_sql
from .sql import comparison_columns_select_expr, sql_gen_comparison_columns
from .settings import _get_columns_to_retain
from .check_types import check_spark_types
from .check_types import check_spark_types, check_types

log = logging.getLogger(__name__)


def _get_columns_to_retain_blocking(settings):

# Use ordered dict as an ordered set - i.e. to make sure we don't have duplicate cols to retain

columns_to_retain = OrderedDict()
columns_to_retain[settings["unique_id_column_name"]] = None

for c in settings["comparison_columns"]:
if c["col_is_in_input_df"]:
columns_to_retain[c["col_name"]] = None

for c in settings["additional_columns_to_retain"]:
columns_to_retain[c] = None

return columns_to_retain.keys()

def sql_gen_and_not_previous_rules(previous_rules: list):
if previous_rules:
# Note the isnull function is important here - otherwise
@@ -114,24 +131,35 @@ def sql_gen_block_using_rules(

return sql

@check_spark_types
@check_types
def block_using_rules(
settings,
settings: dict,
spark: SparkSession,
df_l: DataFrame=None,
df_r: DataFrame=None,
df: DataFrame=None,
columns_to_retain: list=None,
unique_id_col="unique_id",
logger=log
):
"""Apply a series of blocking rules to create a dataframe of record comparisons.

Args:
settings (dict): A sparklink settings dictionary
spark (SparkSession): The pyspark.sql.session.SparkSession
df_l (DataFrame, optional): Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted if `link_type` is `dedupe_only`.
df_r (DataFrame, optional): Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted if `link_type` is `dedupe_only`.
df (DataFrame, optional): Where `link_type` is `dedupe_only`, the dataframe to dedupe. Should be omitted if `link_type` is `link_only` or `link_and_dedupe`.
logger (logging.Logger, optional): Logger used to log the generated SQL. Defaults to log.

Returns:
pyspark.sql.dataframe.DataFrame: A dataframe of each record comparison
"""


link_type = settings["link_type"]

if columns_to_retain is None:
columns_to_retain = _get_columns_to_retain(settings)

columns_to_retain = _get_columns_to_retain_blocking(settings)
unique_id_col = settings["unique_id_column_name"]

if link_type == "dedupe_only":
df.createOrReplaceTempView("df")
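
Two small sketches related to the blocking changes above. The first isolates the OrderedDict-as-ordered-set idiom used by _get_columns_to_retain_blocking; the second shows a hypothetical call to the settings-driven block_using_rules signature. Column names, blocking rules and settings values are illustrative assumptions.

from collections import OrderedDict

# OrderedDict used as an ordered set: duplicate column names collapse,
# insertion order is preserved, and the values are never used.
cols = OrderedDict()
for c in ["unique_id", "first_name", "surname", "first_name", "group"]:
    cols[c] = None
print(list(cols.keys()))  # ['unique_id', 'first_name', 'surname', 'group']

# Hypothetical dedupe_only call; spark and df as in the earlier sketch
from sparklink.blocking import block_using_rules

settings = {
    "link_type": "dedupe_only",
    "unique_id_column_name": "unique_id",
    "blocking_rules": ["l.surname = r.surname"],  # assumed settings key
    "comparison_columns": [
        {"col_name": "first_name", "col_is_in_input_df": True},
        {"col_name": "surname", "col_is_in_input_df": True},
    ],
    "additional_columns_to_retain": ["group"],
}

df_comparison = block_using_rules(settings, spark, df=df)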
47 changes: 43 additions & 4 deletions sparklink/check_types.py
@@ -22,12 +22,15 @@ def wrapper(*args, **kwargs):
this_arg = args_dict[key]
this_type_hint = type_hints[key]

if type(this_type_hint) == type(Union):
possible_types = this_type_hint.__args__
else:
# If it's of type union it will have the __args__ argument
try:
if this_type_hint.__origin__ == Union:
possible_types = this_type_hint.__args__
else:
possible_types = (this_type_hint,)
except AttributeError:
possible_types = (this_type_hint,)


if DataFrame in possible_types or SparkSession in possible_types:
if isinstance(this_arg, possible_types):
pass
@@ -41,4 +44,40 @@ def wrapper(*args, **kwargs):

return func(*args, **kwargs)

return wrapper

def check_types(func):
def wrapper(*args, **kwargs):
type_hints = get_type_hints(func)

args_names = func.__code__.co_varnames[:func.__code__.co_argcount]
args_dict = {**dict(zip(args_names, args)), **kwargs}


for key in type_hints:
if key in args_dict:
this_arg = args_dict[key]
this_type_hint = type_hints[key]

# If it's of type union it will have the __args__ argument
try:
if this_type_hint.__origin__ == Union:
possible_types = this_type_hint.__args__
else:
possible_types = (this_type_hint,)
except AttributeError:
possible_types = (this_type_hint,)

if isinstance(this_arg, possible_types):
pass
else:
poss_types_str = [str(t) for t in possible_types]
poss_types_str = ' or '.join(poss_types_str)
raise TypeError(f"You passed the wrong type for argument {key}. "
f"You passed the argument {args_dict[key]} of type {type(args_dict[key])}. "
f"The type for this argument should be {poss_types_str}. ")


return func(*args, **kwargs)

return wrapper
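
A minimal sketch of the new check_types decorator in action; the decorated function and its arguments are hypothetical.

from typing import Union
from sparklink.check_types import check_types

@check_types
def add_one(x: Union[int, float], label: str = "result"):
    # Annotated arguments are validated against their type hints at call time
    return x + 1

add_one(2.5)  # passes: float is one of the types in Union[int, float]

try:
    add_one("two")  # wrong type for x
except TypeError as e:
    print(e)  # message names the argument, the value passed and the expected types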
125 changes: 107 additions & 18 deletions sparklink/expectation_step.py
@@ -5,11 +5,27 @@
"""

import logging
from collections import OrderedDict

# For type hints. Try/except ensures the sql_gen functions still work even if Spark isn't installed.
try:
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.session import SparkSession
except ImportError:
DataFrame = None
SparkSession = None

log = logging.getLogger(__name__)
from .logging_utils import log_sql, log_other

def run_expectation_step(df_with_gamma, spark, params, compute_ll=False, logger=log):
from .gammas import _add_left_right
from .params import Params

def run_expectation_step(df_with_gamma: DataFrame,
params: Params,
settings: dict,
spark: SparkSession,
compute_ll=False,
logger=log):
"""[summary]

Args:
@@ -22,7 +38,7 @@ def run_expectation_step(df_with_gamma, spark, params, compute_ll=False, logger=
[type]: [description]
"""

sql = sql_gen_gamma_prob_columns(params)
sql = sql_gen_gamma_prob_columns(params, settings)

df_with_gamma.createOrReplaceTempView("df_with_gamma")
log_sql(sql, logger)
@@ -37,7 +53,7 @@ def run_expectation_step(df_with_gamma, spark, params, compute_ll=False, logger=
log_other(message, logger, level='INFO')
params.params["log_likelihood"] = ll

sql = sql_gen_expected_match_prob(params)
sql = sql_gen_expected_match_prob(params, settings)

log_sql(sql, logger)
df_with_gamma_probs.createOrReplaceTempView("df_with_gamma_probs")
@@ -47,31 +63,76 @@ def run_expectation_step(df_with_gamma, spark, params, compute_ll=False, logger=
return df_e


def sql_gen_gamma_prob_columns(params, table_name="df_with_gamma"):
def sql_gen_gamma_prob_columns(params, settings, table_name="df_with_gamma"):
"""
For each row, look up the probability of observing the gamma value given that the record
is a match and a non-match respectively
"""

case_statements = []
# Get case statements
case_statements = {}
for gamma_str in params.gamma_cols:
for match in [0, 1]:
case_statements.append(
sql_gen_gamma_case_when(gamma_str, match, params))
alias = _case_when_col_alias(gamma_str, match)
case_statement = sql_gen_gamma_case_when(gamma_str, match, params)
case_statements[alias] = case_statement


# Column order for case statement. We want orig_col_l, orig_col_r, gamma_orig_col, prob_gamma_u, prob_gamma_m
select_cols = OrderedDict()
select_cols = _add_left_right(select_cols, settings["unique_id_column_name"])

for col in settings["comparison_columns"]:
col_name = col["col_name"]
if settings["retain_matching_columns"]:
select_cols = _add_left_right(select_cols, col_name)
if col["term_frequency_adjustments"]:
select_cols = _add_left_right(select_cols, col_name)
select_cols["gamma_" + col_name] = "gamma_" + col_name

select_cols[f"prob_gamma_{col_name}_non_match"] = case_statements[f"prob_gamma_{col_name}_non_match"]
select_cols[f"prob_gamma_{col_name}_match"] = case_statements[f"prob_gamma_{col_name}_match"]

for c in settings["additional_columns_to_retain"]:
select_cols[c] = c

select_expr = ", ".join(select_cols.values())

case_statements = ", \n\n".join(case_statements)

sql = f"""
-- We use case statements for these lookups rather than joins for performance and simplicity
select *,
{case_statements}
select {select_expr}
from {table_name}
"""

return sql
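
A small sketch of the column ordering this select expression aims for, for a single hypothetical comparison column, assuming _add_left_right appends _l and _r suffixes (consistent with the comment above); the case statement bodies are elided.

# Intended ordering for one comparison column "surname" with unique id "unique_id"
expected_order = [
    "unique_id_l", "unique_id_r",
    "surname_l", "surname_r",  # only if retain_matching_columns is set
    "gamma_surname",
    "case ... end as prob_gamma_surname_non_match",
    "case ... end as prob_gamma_surname_match",
]
select_expr = ", ".join(expected_order)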


def sql_gen_expected_match_prob(params, table_name="df_with_gamma_probs"):
def _column_order_df_e_select_expr(settings, tf_adj_cols=False):
# Column order for case statement. We want orig_col_l, orig_col_r, gamma_orig_col, prob_gamma_u, prob_gamma_m
select_cols = OrderedDict()
select_cols = _add_left_right(select_cols, settings["unique_id_column_name"])

for col in settings["comparison_columns"]:
col_name = col["col_name"]
if settings["retain_matching_columns"]:
select_cols = _add_left_right(select_cols, col_name)
if col["term_frequency_adjustments"]:
select_cols = _add_left_right(select_cols, col_name)
select_cols["gamma_" + col_name] = "gamma_" + col_name

if settings["retain_intermediate_calculation_columns"]:
select_cols[f"prob_gamma_{col_name}_non_match"] = f"prob_gamma_{col_name}_non_match"
select_cols[f"prob_gamma_{col_name}_match"] = f"prob_gamma_{col_name}_match"
if tf_adj_cols:
if col["term_frequency_adjustments"]:
select_cols[col_name+"_adj"] = col_name+"_adj"

for c in settings["additional_columns_to_retain"]:
select_cols[c] = c
return ", ".join(select_cols.values())

def sql_gen_expected_match_prob(params, settings, table_name="df_with_gamma_probs"):
gamma_cols = params.gamma_cols

numerator = " * ".join([f"prob_{g}_match" for g in gamma_cols])
@@ -82,13 +143,44 @@ def sql_gen_expected_match_prob(params, table_name="df_with_gamma_probs"):
castoneminusλ = f"cast({1-λ} as double)"
match_prob_expression = f"({castλ} * {numerator})/(( {castλ} * {numerator}) + ({castoneminusλ} * {denom_part})) as match_probability"

# Get select expression for the other columns to select.
# The column ordering logic is factored out into _column_order_df_e_select_expr.

select_expr = _column_order_df_e_select_expr(settings)

sql = f"""
select {match_prob_expression}, *
select {match_prob_expression}, {select_expr}
from {table_name}
"""

return sql
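
The match_prob_expression generated above follows the Fellegi-Sunter form: lambda times the product of the m probabilities, divided by that quantity plus (1 - lambda) times the product of the u probabilities. A numeric sketch with made-up m and u probabilities:

# Hypothetical m/u probabilities for two comparison columns and a prior lambda
lam = 0.01                   # prior probability that a comparison is a match
prob_match = [0.9, 0.8]      # m probabilities, P(gamma | match)
prob_non_match = [0.2, 0.1]  # u probabilities, P(gamma | non-match)

numerator = 1.0
denom_part = 1.0
for m, u in zip(prob_match, prob_non_match):
    numerator *= m
    denom_part *= u

match_probability = (lam * numerator) / (lam * numerator + (1 - lam) * denom_part)
print(round(match_probability, 3))  # 0.267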

def _case_when_col_alias(gamma_str, match):

if match == 1:
name_suffix = "_match"
if match == 0:
name_suffix = "_non_match"

return f"prob_{gamma_str}{name_suffix}"

def sql_gen_gamma_case_when(gamma_str, match, params):
"""
@@ -111,12 +203,9 @@ def sql_gen_gamma_case_when(gamma_str, match, params):

case_statements = "\n".join(case_statements)

if match == 1:
name_suffix = "_match"
if match == 0:
name_suffix = "_non_match"
alias = _case_when_col_alias(gamma_str, match)

sql = f""" case \n{case_statements} \nend \nas prob_{gamma_str}{name_suffix}"""
sql = f""" case \n{case_statements} \nend \nas {alias}"""

return sql.strip()
