From 74d49e62698a8ce2961d72f5afb746991b406040 Mon Sep 17 00:00:00 2001
From: Robin Linacre
Date: Mon, 16 Mar 2020 14:33:17 +0000
Subject: [PATCH] load and save model

---
 splink/__init__.py  | 41 +++++++++++++++++++++++++++++++++--------
 tests/test_spark.py | 24 +++++++++++++++++++++++-
 2 files changed, 56 insertions(+), 9 deletions(-)

diff --git a/splink/__init__.py b/splink/__init__.py
index 4a1236303f..b1e7781292 100644
--- a/splink/__init__.py
+++ b/splink/__init__.py
@@ -9,7 +9,7 @@
 from splink.settings import complete_settings_dict
 from splink.validate import validate_settings
 
-from splink.params import Params
+from splink.params import Params, load_params_from_json
 from splink.case_statements import _check_jaro_registered
 from splink.blocking import block_using_rules
 from splink.gammas import add_gammas
@@ -118,21 +118,15 @@ def manually_apply_fellegi_sunter_weights(self):
         df_gammas = add_gammas(df_comparison, self.settings, self.spark)
         return run_expectation_step(df_gammas, self.params, self.settings, self.spark)
 
-    def get_scored_comparisons(self, num_iterations:int=None):
+    def get_scored_comparisons(self):
         """Use the EM algorithm to estimate model parameters and return match probabilities.
 
         Note: Does not compute term frequency adjustments.
 
-        Args:
-            num_iterations (int, optional): Override to allow user to specify max iterations. Defaults to None.
-
         Returns:
             DataFrame: A spark dataframe including a match probability column
         """
 
-        if not num_iterations:
-            num_iterations = self.settings["max_iterations"]
-
         df_comparison = self._get_df_comparison()
 
         df_gammas = add_gammas(df_comparison, self.settings, self.spark)
@@ -168,3 +162,34 @@ def make_term_frequency_adjustments(self, df_e: DataFrame):
             spark=self.spark,
         )
 
+    def save_model_as_json(self, path: str, overwrite=False):
+        """Save the model (settings, parameters and parameter history) as a json file so it can later be re-loaded using load_from_json.
+
+        Args:
+            path (str): Path to the json file.
+            overwrite (bool): Whether to overwrite the file if it exists.
+        """
+        self.params.save_params_to_json_file(path, overwrite=overwrite)
+
+
+def load_from_json(path: str,
+                   spark: SparkSession,
+                   df_l: DataFrame = None,
+                   df_r: DataFrame = None,
+                   df: DataFrame = None,
+                   save_state_fn: Callable = None):
+    """Load a splink model from a json file previously created using 'save_model_as_json'.
+
+    Args:
+        path (str): Path to a json file created using Splink.save_model_as_json.
+        spark (SparkSession): SparkSession object.
+        df_l (DataFrame, optional): A dataframe to link/dedupe. Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted where `link_type` is `dedupe_only`.
+        df_r (DataFrame, optional): A dataframe to link/dedupe. Where `link_type` is `link_only` or `link_and_dedupe`, one of the two dataframes to link. Should be omitted where `link_type` is `dedupe_only`.
+        df (DataFrame, optional): The dataframe to dedupe. Where `link_type` is `dedupe_only`, the dataframe to dedupe. Should be omitted where `link_type` is `link_only` or `link_and_dedupe`.
+        save_state_fn (function, optional): A user-provided function that takes two arguments, params and settings, and is executed at each iteration. This hook allows the user to save state between iterations, which is mostly useful for very large jobs that may need to be restarted from where they left off if they fail.
+    """
+    params = load_params_from_json(path)
+    settings = params.settings
+    linker = Splink(settings, spark, df_l, df_r, df, save_state_fn)
+    linker.params = params
+    return linker
\ No newline at end of file
diff --git a/tests/test_spark.py b/tests/test_spark.py
index 3157a48d03..2512d928e5 100644
--- a/tests/test_spark.py
+++ b/tests/test_spark.py
@@ -3,6 +3,7 @@
 import copy
 import os
 
+from splink import Splink, load_from_json
 from splink.blocking import block_using_rules
 from splink.params import Params
 from splink.gammas import add_gammas, complete_settings_dict
@@ -569,4 +570,25 @@ def test_link_option_dedupe_only(spark, link_dedupe_data_repeat_ids):
     df = df.sort_values(["unique_id_l", "unique_id_r"])
 
     assert list(df["unique_id_l"]) == [2]
-    assert list(df["unique_id_r"]) == [3]
\ No newline at end of file
+    assert list(df["unique_id_r"]) == [3]
+
+
+def test_main_api(spark, sqlite_con_1):
+
+    settings = {
+        "link_type": "dedupe_only",
+        "comparison_columns": [{"col_name": "surname"},
+                               {"col_name": "mob"}],
+        "blocking_rules": ["l.mob = r.mob", "l.surname = r.surname"],
+        "max_iterations": 2
+    }
+    settings = complete_settings_dict(settings, spark=None)
+    dfpd = pd.read_sql("select * from test1", sqlite_con_1)
+
+    df = spark.createDataFrame(dfpd)
+
+    linker = Splink(settings, spark, df=df)
+    df_e = linker.get_scored_comparisons()
+    linker.save_model_as_json("saved_model.json", overwrite=True)
+    linker_2 = load_from_json("saved_model.json", spark=spark, df=df)
+    df_e = linker_2.get_scored_comparisons()
\ No newline at end of file