Commit

Merge 5c9aed1 into d13a84d
RobinL committed Jan 8, 2022
2 parents d13a84d + 5c9aed1 commit bbc64e3
Showing 20 changed files with 668 additions and 379 deletions.
4 changes: 2 additions & 2 deletions Dockerfile_testrunner
@@ -1,7 +1,7 @@
FROM mamonu/moj-spark-jovyan:baseenv

-RUN pip install pytest pytest-cov poetry coveralls typeguard
-RUN pip install --no-dependencies splink-data-generation==1.0.0
+RUN pip install pytest pytest-cov poetry coveralls typeguard sqlglot
+RUN pip install --no-dependencies splink-data-generation==1.0.1

ADD . /myfiles
WORKDIR /myfiles
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -12,11 +12,12 @@ readme = "README.md"
python = "^3.6"
jsonschema = "^3.2"
typeguard = "^2.10.0"
+sqlglot = "^1.16.1"

[tool.poetry.dev-dependencies]
pytest = "^5.3"
pandas = "^1.0.0"
-splink-data-generation = "^0.2.1"
+splink-data-generation = "^1.0.1"

[build-system]
requires = ["poetry>=0.12"]
2 changes: 1 addition & 1 deletion splink/__init__.py
@@ -36,7 +36,7 @@ def __init__(
):
"""Splink data linker
-        Provides easy access to the core user-facing functinoality of splink
+        Provides easy access to the core user-facing functionality of splink
Args:
settings (dict): splink settings dictionary
98 changes: 93 additions & 5 deletions splink/default_settings.py
@@ -1,5 +1,5 @@
import warnings

from splink.settings import ComparisonColumn
from pyspark.sql.session import SparkSession

from copy import deepcopy
@@ -20,6 +20,12 @@
_add_as_gamma_to_case_statement,
)

+from .parse_case_statement import (
+    parse_case_statement,
+    generate_sql_from_parsed_case_expr,
+    get_columns_used_from_sql_without_l_r_suffix,
+)
+

def _normalise_prob_list(prob_array: list):
sum_list = sum(prob_array)
@@ -101,8 +107,12 @@ def _get_default_probabilities(m_or_u, levels):

def _complete_case_expression(col_settings, spark):

+    cc = ComparisonColumn(col_settings)
+    if cc.has_case_expression_or_comparison_levels:
+        return col_settings
+
default_case_statements = _get_default_case_statements_functions(spark)
-    levels = col_settings["num_levels"]
+    levels = cc.num_levels

if "custom_name" in col_settings:
col_name_for_case_fn = col_settings["custom_name"]
@@ -136,7 +146,8 @@ def _complete_probabilities(col_settings: dict, mu_probabilities: str):
"""

if mu_probabilities not in col_settings:
-        levels = col_settings["num_levels"]
+        cc = ComparisonColumn(col_settings)
+        levels = cc.num_levels
probs = _get_default_probabilities(mu_probabilities, levels)
col_settings[mu_probabilities] = probs

@@ -149,11 +160,72 @@ def _complete_tf_adjustment_weights(col_settings):
f"All values of 'tf_adjustment_weights' must be between 0 and 1"
)
else:
-        weights = [0.0] * col_settings["num_levels"]
+        cc = ComparisonColumn(col_settings)
+
+        weights = [0.0] * cc.num_levels
weights[-1] = 1.0
col_settings["tf_adjustment_weights"] = weights


+def _complete_comparison_levels(col_settings):
+    if "comparison_levels" not in col_settings:
+        case_expression = col_settings["case_expression"]
+        col_settings["comparison_levels"] = parse_case_statement(case_expression)
+
+    if "case_expression" not in col_settings:
+        cl = col_settings["comparison_levels"]
+        col_settings["case_expression"] = generate_sql_from_parsed_case_expr(cl)
+
+    from splink.settings import ComparisonColumn
+
+    cc = ComparisonColumn(col_settings)
+    keys = cc.comparison_levels_dict.keys()
+    if "-1" not in keys:
+
+        warnings.warn(
+            "No -1 level found in case statement."
+            " You usually want to use -1 as the level for the null value."
+            " e.g. WHEN col_l is null or col_r is null then -1"
+            f" Case statement is:\n {col_settings['case_expression']}."
+        )
+
+
+def _complete_col_name(col_settings):
+
+    if "custom_name" in col_settings:
+        return
+
+    if "col_name" in col_settings:
+        return
+
+    sql = generate_sql_from_parsed_case_expr(col_settings["comparison_levels"])
+    sql_cols = get_columns_used_from_sql_without_l_r_suffix(sql)
+    if len(sql_cols) == 1:
+        col_settings["col_name"] = sql_cols[0]
+    else:
+        col_settings["custom_name"] = "_".join(sql_cols)
+    return col_settings
+
+
+def _complete_custom_columns(col_settings):
+
+    if "col_name" in col_settings:
+        return col_settings
+
+    if "custom_name" in col_settings:
+        sql = generate_sql_from_parsed_case_expr(col_settings["comparison_levels"])
+        sql_cols = get_columns_used_from_sql_without_l_r_suffix(sql)
+        if "columns_used" in col_settings:
+            if set(sql_cols) != set(col_settings["columns_used"]):
+                warnings.warn(
+                    f"The columns used in the case statement are {sql_cols} but the columns "
+                    f"specified in the settings dictionary are {col_settings['columns_used']}"
+                )
+        else:
+            col_settings["custom_columns_used"] = sql_cols
+    return col_settings


def complete_settings_dict(settings_dict: dict, spark: SparkSession):
"""Auto-populate any missing settings from the settings dictionary using the 'sensible defaults' that
are specified in the json schema (./splink/files/settings_jsonschema.json)
@@ -203,7 +275,6 @@ def complete_settings_dict(settings_dict: dict, spark: SparkSession):

# Populate non-existing keys from defaults
keys_for_defaults = [
"num_levels",
"data_type",
"term_frequency_adjustments",
"fix_u_probabilities",
@@ -215,10 +286,27 @@
default = get_default_value_from_schema(key, is_column_setting=True)
col_settings[key] = default

+        # Populate default value for num levels only if case_expression or comparison_levels is not specified
+        skip_if_present = set(["case_expression", "comparison_levels", "num_levels"])
+        keys = set(col_settings.keys())
+        intersect = keys.intersection(skip_if_present)
+        if len(intersect) == 0:
+            default = get_default_value_from_schema(
+                "num_levels", is_column_setting=True
+            )
+            col_settings["num_levels"] = default

# Doesn't need assignment because we modify the col_settings dictionary

_complete_case_expression(col_settings, spark)

+        _complete_comparison_levels(col_settings)
+        _complete_col_name(col_settings)
+        _complete_custom_columns(col_settings)

_complete_probabilities(col_settings, "m_probabilities")
_complete_probabilities(col_settings, "u_probabilities")

_complete_tf_adjustment_weights(col_settings)

return settings_dict
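
The bulk of this change teaches the settings logic to convert between a raw case_expression and a structured comparison_levels list, via the new parse_case_statement module (hence the sqlglot dependency added in the Dockerfile and pyproject.toml above). As a rough illustration of the parsing involved — this is a sketch, not splink's actual parse_case_statement implementation, and it is written against a recent sqlglot API rather than the ^1.16.1 pin — a CASE statement can be decomposed into per-level conditions like so:

    import sqlglot
    from sqlglot import expressions as exp

    # Hypothetical comparison over a "surname" column; the column name and
    # levels are illustrative, not taken from the commit.
    case_sql = """
    CASE
        WHEN surname_l IS NULL OR surname_r IS NULL THEN -1
        WHEN surname_l = surname_r THEN 1
        ELSE 0
    END
    """

    case = sqlglot.parse_one(case_sql).find(exp.Case)

    # Each WHEN ... THEN ... branch corresponds to one comparison level
    for if_clause in case.args["ifs"]:
        level = if_clause.args["true"].sql()      # the THEN value, e.g. -1 or 1
        condition = if_clause.args["this"].sql()  # the WHEN condition
        print(f"level {level}: {condition}")

    # Columns referenced, with their _l/_r suffixes still attached
    print({col.name for col in case.find_all(exp.Column)})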
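Taken together, complete_settings_dict can now accept a column comparison specified either as a raw case_expression or as a structured comparison_levels list, deriving whichever representation is missing, inferring col_name from the columns used in the SQL, and defaulting num_levels only when neither form is present. A hypothetical column-settings fragment (illustrative values, not taken from the commit) showing the behaviour:

    # Input: only a case_expression is supplied
    col_settings = {
        "case_expression": """
            CASE
                WHEN surname_l IS NULL OR surname_r IS NULL THEN -1
                WHEN surname_l = surname_r THEN 1
                ELSE 0
            END
        """
    }

    # After complete_settings_dict runs, per the functions added above:
    # - "comparison_levels" is populated by parse_case_statement(...)
    # - "col_name" is inferred as "surname", the single column used once the
    #   _l/_r suffixes are stripped; with several columns, a custom_name is
    #   built by joining them with underscores
    # - "num_levels" is not defaulted, since the levels are already implied
    # - a warning is raised if no -1 (null) level is present
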
6 changes: 1 addition & 5 deletions splink/diagnostics.py
@@ -8,11 +8,7 @@
from typeguard import typechecked

from .charts import load_chart_definition, altair_if_installed_else_json
-from .settings import complete_settings_dict, Settings
-from .vertically_concat import vertically_concatenate_datasets
-from .blocking import block_using_rules
-from .gammas import add_gammas
-from .estimate import _num_target_rows_to_rows_to_sample
+from .settings import Settings


def _equal_spaced_buckets(num_buckets, extent):
2 changes: 1 addition & 1 deletion splink/estimate.py
@@ -5,7 +5,7 @@
from .gammas import add_gammas
from .maximisation_step import run_maximisation_step
from .model import Model
-from .settings import complete_settings_dict
+from .default_settings import complete_settings_dict
from .vertically_concat import vertically_concatenate_datasets

import warnings
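
This import now points at splink/default_settings.py, where complete_settings_dict is defined (see the diff above); diagnostics.py correspondingly stops importing it from splink/settings.py.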