Merge pull request #122 from moj-analytical-services/issue118
first attempt to fix issue 118
mamonu authored Aug 21, 2020
2 parents 3b39b75 + baaa6d3 commit a7003f6
Showing 3 changed files with 133 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -143,3 +143,6 @@ dmypy.json

# gitignore exception for docs files (otherwise docs/build does not get saved in repo
!docs/*

# files created by certain tests should not be committed to the repo
saved_model.json
11 changes: 9 additions & 2 deletions splink/term_frequencies.py
@@ -2,6 +2,8 @@
# https://github.com/moj-analytical-services/splink/pull/107

import logging
import math
import warnings

try:
from pyspark.sql.dataframe import DataFrame
@@ -52,8 +54,13 @@ def sql_gen_generate_adjusted_lambda(column_name, params, table_name="df_e"):
max_level = params.params["π"][f"gamma_{column_name}"]["num_levels"] - 1
m = params.params["π"][f"gamma_{column_name}"]["prob_dist_match"][f"level_{max_level}"]["probability"]
u = params.params["π"][f"gamma_{column_name}"]["prob_dist_non_match"][f"level_{max_level}"]["probability"]
average_adjustment = m/(m+u)


# ensure the average adjustment calculation does not divide by zero (see issue 118)
if math.isclose(m + u, 0.0, abs_tol=1e-9):
    average_adjustment = 0.5
    warnings.warn(
        f"Is most or all of column {column_name} comprised of NULL values? "
        "There are levels where no comparisons are found."
    )
else:
    average_adjustment = m / (m + u)

sql = f"""
with temp_adj as
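
To make the behaviour of the new guard concrete, here is a minimal standalone sketch (not the library's API; the helper name and the abs_tol value are illustrative). When both the match probability m and the non-match probability u for the top comparison level are zero, as can happen when a column is entirely NULL, m / (m + u) would raise ZeroDivisionError, so the calculation falls back to a neutral adjustment of 0.5.

import math
import warnings

def safe_average_adjustment(m, u, column_name="weird"):
    # Both probabilities can be zero when a comparison level is never observed,
    # e.g. for a column made up entirely of NULL values, so guard the division.
    if math.isclose(m + u, 0.0, abs_tol=1e-9):
        warnings.warn(
            f"Is most or all of column {column_name} comprised of NULL values? "
            "There are levels where no comparisons are found."
        )
        return 0.5  # neutral adjustment: no evidence either way
    return m / (m + u)

print(safe_average_adjustment(0.0, 0.0))   # 0.5, with a warning
print(safe_average_adjustment(0.2, 0.05))  # 0.8
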
121 changes: 121 additions & 0 deletions tests/test_adj.py
@@ -0,0 +1,121 @@
import logging
import warnings

import pytest
import pandas as pd
import pyspark
import pyspark.sql.functions as f

from splink import Splink

logger = logging.getLogger(__name__)


@pytest.fixture(scope="module")
def spark():

    try:
        from pyspark import SparkContext, SparkConf
        from pyspark.sql import SparkSession
        from pyspark.sql import types

        conf = SparkConf()

        conf.set("spark.jars.ivy", "/home/jovyan/.ivy2/")
        conf.set("spark.driver.extraClassPath", "jars/scala-udf-similarity-0.0.6.jar")
        conf.set("spark.jars", "jars/scala-udf-similarity-0.0.6.jar")
        conf.set("spark.driver.memory", "4g")
        conf.set("spark.sql.shuffle.partitions", "24")

        sc = SparkContext.getOrCreate(conf=conf)

        spark = SparkSession(sc)

        udfs = [
            ("jaro_winkler_sim", "JaroWinklerSimilarity", types.DoubleType()),
            ("jaccard_sim", "JaccardSimilarity", types.DoubleType()),
            ("cosine_distance", "CosineDistance", types.DoubleType()),
            ("Dmetaphone", "DoubleMetaphone", types.StringType()),
            ("QgramTokeniser", "QgramTokeniser", types.StringType()),
            ("Q3gramTokeniser", "Q3gramTokeniser", types.StringType()),
            ("Q4gramTokeniser", "Q4gramTokeniser", types.StringType()),
            ("Q5gramTokeniser", "Q5gramTokeniser", types.StringType()),
        ]

        for a, b, c in udfs:
            spark.udf.registerJavaFunction(a, "uk.gov.moj.dash.linkage." + b, c)

        SPARK_EXISTS = True
    except Exception:
        SPARK_EXISTS = False

    if SPARK_EXISTS:
        print("Spark exists, running spark tests")
        yield spark
    else:
        spark = None
        logger.error("Spark not available")
        print("Spark not available")
        yield spark



@pytest.fixture(scope="module")
def sparkdf(spark):

    data = [
        {"surname": "smith", "firstname": "john"},
        {"surname": "smith", "firstname": "john"},
        {"surname": "smithe", "firstname": "john"},
    ]

    dfpd = pd.DataFrame(data)
    df = spark.createDataFrame(dfpd)
    yield df


def test_freq_adj_divzero(spark, sparkdf):

    # create a settings object that requests term_frequency_adjustments on column 'weird'
    settings = {
        "link_type": "dedupe_only",
        "blocking_rules": [
            "l.surname = r.surname",
        ],
        "comparison_columns": [
            {
                "col_name": "firstname",
                "num_levels": 3,
            },
            {
                "col_name": "surname",
                "num_levels": 3,
                "term_frequency_adjustments": True,
            },
            {
                "col_name": "weird",
                "num_levels": 3,
                "term_frequency_adjustments": True,
            },
        ],
        "additional_columns_to_retain": ["unique_id"],
        "em_convergence": 0.01,
    }

    sparkdf = sparkdf.withColumn("unique_id", f.monotonically_increasing_id())
    # create column 'weird' in a way that could trigger a divide-by-zero in the
    # average adjustment calculation before the fix (an all-NULL column)
    sparkdf = sparkdf.withColumn("weird", f.lit(None))

    try:
        linker = Splink(settings, spark, df=sparkdf)
        notpassing = False
    except ZeroDivisionError:
        notpassing = True

    assert not notpassing
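
For reference, the new test can be run on its own with pytest; this assumes a local Spark installation and the scala-udf-similarity jar referenced in the fixture are available:

pytest tests/test_adj.py -k test_freq_adj_divzero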
