This notebook implements, in Apache Spark, the basic expectation-maximisation (EM) approach used by the R fastLink package.

In [None]:
import logging

logging.basicConfig()

# Logger.setLevel() returns None, so the logger has to be bound first and
# configured in a separate statement; chaining the two left `log` set to None.
log = logging.getLogger("sql_logs")
log.setLevel(logging.ERROR)

In [None]:
from pyspark.context import SparkContext, SparkConf
from pyspark.sql import SparkSession, Window
from pyspark.sql.types import StructType
import pyspark.sql.functions as f


# Spark configuration used when this notebook creates the context itself.
conf = SparkConf()
conf.set("spark.driver.memory", "8g")
conf.set("spark.sql.shuffle.partitions", "4")

# getOrCreate() returns the already-running context if there is one, so a single
# call is enough; the second, argument-less call in the original was redundant.
sc = SparkContext.getOrCreate(conf=conf)
spark = SparkSession(sc)

In [None]:
# Record the wall-clock start time; the final cell subtracts it from time.time()
# to report how long the notebook took to run.
import time
start_time = time.time()

In [None]:
df = spark.read.parquet("parquet/fake_1000.parquet")
# Spark DataFrames are immutable: dropDuplicates() returns a new DataFrame, so the
# result must be assigned, otherwise the de-duplication is silently discarded.
df = df.dropDuplicates()

In [None]:
# Put the local "gluejoblib" directory on the import path.
# NOTE(review): presumably this is where the helper modules imported in the next
# cell (utility_functions, sql_steps, ...) live - confirm.
import sys 
sys.path.append("gluejoblib")

In [None]:
from utility_functions import *
from sql_steps import *
from pipelines import get_features_df
from accuracy import *
from rules import *

In [None]:
# The pairwise SQL further down refers to l.row_id / r.row_id.
# NOTE(review): presumably this helper is what adds the unique `row_id` column - confirm.
# The cell rebinds `df`, so re-running it applies the helper to its own output.
df = sql_add_unique_row_id_to_original_table(df, spark)

In [None]:
# Each rule is interpolated into the ON clause of a self-join in the next cell,
# so every rule is expected to be a SQL join condition over aliases `l` and `r`.
rules = get_test_data_rules()

In [None]:
sql_select_expr = sql_gen_col_selection_compare_cols(df)
# createOrReplaceTempView supersedes registerTempTable, which is deprecated.
df.createOrReplaceTempView("df")

# Build one self-join query per rule. `l.row_id < r.row_id` keeps each unordered
# pair once and excludes a record being paired with itself; because that filter
# is on r.row_id, unmatched left rows (r.* is null) are dropped as well, so the
# left join behaves like an inner join.
sqls = []
for rule in rules:
    sql = f"""
    select {sql_select_expr}  
     from df as l
        left join df as r
        on
        {rule}
        where l.row_id < r.row_id

    """
    sqls.append(sql)

# UNION removes rows that are identical across rules ...
df_comparison = spark.sql(" union ".join(sqls))

# ... and this guarantees at most one row per (row_id_l, row_id_r) pair.
df_comparison = df_comparison.dropDuplicates(["row_id_l", "row_id_r"])

df_comparison.show()

In [None]:
# Generate gammas dataset 
def gammas_case_statement(col_name, i):
    """Return the SQL CASE expression for one exact-match comparison column.

    The expression compares ``{col_name}_l`` with ``{col_name}_r`` and yields 1
    when they are equal and 0 otherwise, aliased as ``gamma_{i}``.
    """
    left_col = f"{col_name}_l"
    right_col = f"{col_name}_r"
    pieces = [
        "case ",
        f"    when {left_col} = {right_col} then 1",
        f"    else 0 end as gamma_{i}",
    ]
    return "\n".join(pieces)
    
    

def add_gammas(df, spark, binary_comparison_cols=None, approximate_string_cols=None):
    """Append one ``gamma_i`` exact-match indicator column per comparison column.

    Args:
        df: DataFrame of record pairs holding ``<col>_l`` and ``<col>_r`` for
            every name in ``binary_comparison_cols``.
        spark: Kept for backward compatibility; no longer used because the
            columns are added through the DataFrame API rather than a SQL query.
        binary_comparison_cols: Column base names to compare for equality. The
            position in this sequence becomes the gamma index. ``None`` or an
            empty sequence adds no columns.
        approximate_string_cols: Accepted but currently ignored.

    Returns:
        A new DataFrame with every original column followed by the gamma columns.
    """
    gamma_select_expressions = [
        gammas_case_statement(col_name, i)
        for i, col_name in enumerate(binary_comparison_cols or [])
    ]

    # selectExpr evaluates the same SQL expressions without registering the input
    # as a temp view named "df", which used to replace any existing view of that
    # name (for example the one created for the raw records). It also copes with
    # zero expressions, where the string-built query was invalid SQL.
    return df.selectExpr("*", *gamma_select_expressions)

    
# Columns compared for exact equality; position i in this list becomes gamma_i.
# "group" is last, so gamma_5 records whether the two records share the same
# group value. Later cells use gamma_5 as the match label `m`.
cols = ["first_name", "surname", "dob", "city", "email", "group"]
df_with_gamma = add_gammas(df_comparison, spark, cols)


In [None]:
# Cache the gamma table because the cells below query it repeatedly. persist() is
# lazy, so count() is called to trigger the computation that fills the cache.
df_with_gamma.persist()
df_with_gamma.count()

In [None]:
# Preview the candidate pairs together with their gamma_* indicator columns.
df_with_gamma.show()

In [None]:
# createOrReplaceTempView supersedes registerTempTable, which is deprecated.
df_with_gamma.createOrReplaceTempView("df_with_gamma")
sql = """
select avg(gamma_5)
from df_with_gamma
"""
# gamma_5 is the exact-match flag on "group", so its mean is the share of
# candidate pairs whose two records carry the same group value.
spark.sql(sql).collect()[0][0]

In [None]:
# For each of the first five comparison columns, show how often that column
# agrees (avg of its gamma) and how many pairs there are, split by gamma_5.
for i in range(5):
    col, field = cols[i], f"gamma_{i}"
    print(col)
    sql = f"""
    select gamma_5 as m, avg({field}) as prop_match_{field}, count(*) as num_records
    from df_with_gamma
    group by gamma_5
    """
    spark.sql(sql).show()

In [None]:
# Eyeball a random sample (fraction 0.001) of the pairs, printing at most 20 rows.
# No seed is passed to sample(), so the rows shown differ between runs.
df_with_gamma.sample(0.001).show(20)

In [None]:
# Sample pairs whose city values are equal but whose gamma_5 flag is 0.
(
    df_with_gamma
    .filter(df_with_gamma["city_l"] == df_with_gamma["city_r"])
    .filter(df_with_gamma["gamma_5"] == 0)
    .sample(0.001)
    .show()
)

In [None]:
cols = ["first_name", "surname", "dob", "city", "email"]


def generate_params(binary_comparison_cols=cols):
    """Build the starting parameter dictionary for the EM iterations.

    The result has two top-level keys:

    * ``"λ"`` - initialised to 0.8.
      NOTE(review): presumably the starting match proportion - confirm.
    * ``"π"`` - one entry per comparison column, keyed ``gamma_<i>`` by position,
      holding a description, a type tag and two two-level distributions
      (``prob_dist_match`` and ``prob_dist_non_match``) over the values 0 and 1.

    Every column receives the same starting probabilities.
    """

    def _two_level_dist(prob_level_0, prob_level_1):
        # A fresh dict on every call, so no column shares nested state with another.
        return {
            "level_0": {"value": 0, "probability": prob_level_0},
            "level_1": {"value": 1, "probability": prob_level_1},
        }

    return {
        "λ": 0.8,
        "π": {
            f"gamma_{i}": {
                "desc": f"Exact match on {col}",
                "type": "exact_match_only",
                "prob_dist_match": _two_level_dist(0.1, 0.9),
                "prob_dist_non_match": _two_level_dist(0.8, 0.2),
            }
            for i, col in enumerate(binary_comparison_cols)
        },
    }


params = generate_params(cols)

## Expectation step

In [None]:
%autoreload True

In [None]:
from em_in_spark.fns import *


# Number of iterations to run (the original cell unrolled ten of them by hand).
NUM_EM_ITERATIONS = 10

# Starting values, printed for comparison with the updated ones below.
print(params["λ"])
print(params["π"]["gamma_0"])

# Each pass scores the pairs with the current parameters, then derives new
# parameters from those scores; the first pass starts from `params`.
new_params = params
for iteration in range(NUM_EM_ITERATIONS):
    df_e = run_expectation_step(df_with_gamma, spark, new_params)
    new_params = update_params(df_e, spark, new_params)
    print(new_params["λ"])
    print(new_params["π"]["gamma_0"])

# Output of the last expectation step.
df_e.show()

In [None]:
# Display the full parameter dictionary produced by the last update_params call.
new_params

In [None]:
# Wall-clock time since `start_time` was recorded near the top of the notebook.
print("--- {} seconds ---".format(time.time() - start_time))