load and save model
RobinL committed Mar 16, 2020
1 parent 75b4bc3 commit 74d49e6
Showing 2 changed files with 56 additions and 9 deletions.
41 changes: 33 additions & 8 deletions splink/__init__.py
@@ -9,7 +9,7 @@

from splink.settings import complete_settings_dict
from splink.validate import validate_settings
- from splink.params import Params
+ from splink.params import Params, load_params_from_json
from splink.case_statements import _check_jaro_registered
from splink.blocking import block_using_rules
from splink.gammas import add_gammas
@@ -118,21 +118,15 @@ def manually_apply_fellegi_sunter_weights(self):
df_gammas = add_gammas(df_comparison, self.settings, self.spark)
return run_expectation_step(df_gammas, self.params, self.settings, self.spark)

- def get_scored_comparisons(self, num_iterations:int=None):
+ def get_scored_comparisons(self):
"""Use the EM algorithm to estimate model parameters and return match probabilities.
Note: Does not compute term frequency adjustments.
Args:
- num_iterations (int, optional): Override to allow user to specify max iterations. Defaults to None.
Returns:
DataFrame: A spark dataframe including a match probability column
"""

- if not num_iterations:
-     num_iterations = self.settings["max_iterations"]

df_comparison = self._get_df_comparison()

df_gammas = add_gammas(df_comparison, self.settings, self.spark)
@@ -168,3 +162,34 @@ def make_term_frequency_adjustments(self, df_e: DataFrame):
spark=self.spark,
)

def save_model_as_json(self, path: str, overwrite=False):
"""Save model (settings, parameters and parameter history) as a json file so it can later be re-loaded using load_from_json
Args:
path (str): Path to the json file.
overwrite (bool): Whether to overwrite the file if it exists
"""
self.params.save_params_to_json_file(path, overwrite=overwrite)


def load_from_json(path: str,
spark: SparkSession,
df_l: DataFrame = None,
df_r: DataFrame = None,
df: DataFrame = None,
save_state_fn: Callable = None):
"""Load a splink model from a json file which has previously been created using 'save_model_as_json'
Args:
path (string): path to json file created using Splink.save_model_as_json
spark (SparkSession): SparkSession object
df_l (DataFrame, optional): A dataframe to link/dedupe. Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted if `link_type` is `dedupe_only`.
df_r (DataFrame, optional): A dataframe to link/dedupe. Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted if `link_type` is `dedupe_only`.
df (DataFrame, optional): The dataframe to dedupe. Where `link_type` is `dedupe_only`, the dataframe to dedupe. Should be omitted if `link_type` is `link_only` or `link_and_dedupe`.
save_state_fn (function, optional): A function provided by the user that takes two arguments, params and settings, and is executed at each iteration. This is a hook that allows the user to save state between iterations, which is mostly useful for very large jobs that may need to be restarted from where they left off if they fail.
"""
params = load_params_from_json(path)
settings = params.settings
linker = Splink(settings, spark, df_l, df_r, df, save_state_fn)
linker.params = params
return linker
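
For orientation, a minimal sketch of how the new save_state_fn hook could be used together with load_from_json. The hook body, the checkpoint path, and the pre-existing spark and df variables are illustrative assumptions, not part of this commit; save_params_to_json_file is the same Params method that save_model_as_json calls above.

from splink import load_from_json

# A hypothetical checkpointing hook: splink calls it with (params, settings)
# at each EM iteration, so persisting the current parameter estimates here
# lets a long-running job be restarted from where it left off if it fails.
def checkpoint_fn(params, settings):
    params.save_params_to_json_file("checkpoint.json", overwrite=True)

# Re-load a previously saved model and continue scoring with checkpointing.
linker = load_from_json("saved_model.json", spark=spark, df=df,
                        save_state_fn=checkpoint_fn)
df_e = linker.get_scored_comparisons()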
24 changes: 23 additions & 1 deletion tests/test_spark.py
@@ -3,6 +3,7 @@
import copy
import os

+ from splink import Splink, load_from_json
from splink.blocking import block_using_rules
from splink.params import Params
from splink.gammas import add_gammas, complete_settings_dict
@@ -569,4 +570,25 @@ def test_link_option_dedupe_only(spark, link_dedupe_data_repeat_ids):
df = df.sort_values(["unique_id_l", "unique_id_r"])

assert list(df["unique_id_l"]) == [2]
assert list(df["unique_id_r"]) == [3]
assert list(df["unique_id_r"]) == [3]


def test_main_api(spark, sqlite_con_1):

settings = {
"link_type": "dedupe_only",
"comparison_columns": [{"col_name": "surname"},
{"col_name": "mob"}],
"blocking_rules": ["l.mob = r.mob", "l.surname = r.surname"],
"max_iterations": 2
}
settings = complete_settings_dict(settings, spark=None)
dfpd = pd.read_sql("select * from test1", sqlite_con_1)

df = spark.createDataFrame(dfpd)

linker = Splink(settings, spark, df=df)
df_e = linker.get_scored_comparisons()
linker.save_model_as_json("saved_model.json", overwrite=True)
linker_2 = load_from_json("saved_model.json", spark=spark, df=df)
df_e = linker_2.get_scored_comparisons()
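
Since save_model_as_json writes plain JSON (settings, parameters and parameter history), the artefact produced by the test above can be inspected directly. A small sketch; the exact key layout inside the file is an assumption, not specified by this commit:

import json

with open("saved_model.json") as f:
    saved_model = json.load(f)

# Top-level structure is an assumption; the file is expected to carry the
# settings dict plus the estimated parameters and their iteration history.
print(sorted(saved_model.keys()))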
