diff --git a/tests/test_spark.py b/tests/test_spark.py index 8f9a16f06f..96b49794d9 100644 --- a/tests/test_spark.py +++ b/tests/test_spark.py @@ -384,6 +384,43 @@ def test_case_statements(spark, sqlite_con_3): assert df.loc[3, "gamma_str_col"] == -1 assert df.loc[4, "gamma_str_col"] == -1 + + data = [ + {"surname_l": "smith", "forename1_l": "john", "forename2_l": "david", + "surname_r": "smith", "forename1_r": "john", "forename2_r": "david"}, + + {"surname_l": "smith", "forename1_l": "john", "forename2_l": "david", + "surname_r": "smithe", "forename1_r": "john", "forename2_r": "david"}, + + {"surname_l": "smith", "forename1_l": "john", "forename2_l": "david", + "surname_r": "john", "forename1_r": "smith", "forename2_r": "david"}, + + {"surname_l": "smith", "forename1_l": "john", "forename2_l": "david", + "surname_r": "john", "forename1_r": "david", "forename2_r": "smithe"}, + + {"surname_l": "linacre", "forename1_l": "john", "forename2_l": "david", + "surname_r": "linaker", "forename1_r": "john", "forename2_r": "david"}, + + {"surname_l": "smith", "forename1_l": "john", "forename2_l": "david", + "surname_r": "john", "forename1_r": "david", "forename2_r": "smarty"} + ] + dfpd = pd.DataFrame(data) + df = spark.createDataFrame(dfpd) + df.createOrReplaceTempView("df_names") + + sql = sql_gen_gammas_name_inversion_3("surname", ["forename1", "forename2"], "surname") + + df_results = spark.sql(f"select {sql} from df_names").toPandas() + assert df_results.loc[0, "gamma_surname"] == 3 + assert df_results.loc[1, "gamma_surname"] == 3 + assert df_results.loc[2, "gamma_surname"] == 2 + assert df_results.loc[3, "gamma_surname"] == 2 + assert df_results.loc[4, "gamma_surname"] == 1 + assert df_results.loc[5, "gamma_surname"] == 0 + + + + from splink.gammas import add_gammas