In [None]:
from try_settings import settings_dict
import pandas as pd
import boto3
from dataengineeringutils3.s3 import delete_s3_folder_contents
import awswrangler as wr
# Explicit boto3 session pinned to the eu-west-1 region. It is handed to the
# awswrangler and AWSLinker calls further down via `boto3_session=`.
my_session = boto3.Session(region_name="eu-west-1")

# Job-size switch read by the cells below:
#   True  -> load the 1000-row demo CSV into `df_left` / `df_right` and link
#            with `settings_dict`;
#   False -> load the replicated `synthetic_data_all` table and link with
#            `settings_dict_large`.
small_link_job = False

In [None]:
import os

# Work from the splink repository root so that relative paths used later
# (e.g. "./tests/datasets/fake_1000_from_splink_demos.csv") resolve.
# The location defaults to the platform checkout but can be overridden with
# the SPLINK_REPO_DIR environment variable instead of editing the notebook.
os.chdir(os.environ.get("SPLINK_REPO_DIR", "/home/jovyan/splink/"))

In [None]:
# Reset the test database so that every run starts from a clean slate.
# All catalog calls use `my_session` so that the lookup, the delete and the
# create are guaranteed to hit the same account/region.
import time

database_name = "splink_awswrangler_test"

if database_name in wr.catalog.databases(
    limit=10000, boto3_session=my_session
).values:
    wr.catalog.delete_database(
        name=database_name,
        boto3_session=my_session,
    )
    print("Cleaning up folder contents...")
    # Clear the S3 prefixes written by earlier runs.
    # (Could potentially become a helper on the AWS linker.)
    delete_s3_folder_contents("s3://alpha-splink-db-testing/splink_warehouse/")
    delete_s3_folder_contents("s3://alpha-splink-db-testing/data/")
    # Only report a deletion when one actually happened.
    print("Deleted existing db")

if database_name not in wr.catalog.databases(
    limit=10000, boto3_session=my_session
).values:
    # Short pause before recreating — presumably gives the delete time to
    # settle on the AWS side; TODO confirm it is still needed.
    time.sleep(3)
    wr.catalog.create_database(
        database_name,
        exist_ok=True,
        boto3_session=my_session,
    )

## Write all synthetic data

Only runs if we want a larger link job (`small_link_job = False`).

    table_name = "synthetic_data_all"
    data_path = f"s3://alpha-splink-db-testing/perm_data/{table_name}.parquet"
    athena_s3_path = "s3://alpha-splink-db-testing/data/"
    df = wr.s3.read_parquet(
        path=data_path,
        boto3_session=my_session
    )

In [None]:
if not small_link_job:
    # Large job: read the permanent synthetic dataset, stack several copies of
    # it to bulk it up, give every row a fresh id and register the result as
    # the `synthetic_data_all` table in the test database.
    table_name = "synthetic_data_all"
    data_path = f"s3://alpha-splink-db-testing/perm_data/{table_name}.parquet"
    athena_s3_path = "s3://alpha-splink-db-testing/data/"
    bucket = "alpha-splink-db-testing"
    path = f"s3://{bucket}/data/"
    # Number of stacked copies of the source frame (2**4).
    replication_factor = 16

    synthetic_source = wr.s3.read_parquet(
        path=data_path,
        boto3_session=my_session,
    )
    # Drop the unwanted column once on the small frame, then replicate in a
    # single concat rather than repeatedly doubling (fewer intermediate copies,
    # identical row order).
    synthetic_all = pd.concat(
        [synthetic_source.drop(columns=["source_dataset"])] * replication_factor
    )
    # Replicated rows would otherwise share ids; renumber 0..n-1.
    synthetic_all["unique_id"] = range(len(synthetic_all))

    wr.s3.to_parquet(
        df=synthetic_all,
        path=path,
        dataset=True,
        mode="overwrite",
        database="splink_awswrangler_test",
        table=table_name,
        boto3_session=my_session,  # same session as the read above
    )
    print(len(synthetic_all))
    display(synthetic_all.head(10))
    # Free the (large) frames; the linker reads the data back via the catalog.
    del synthetic_source, synthetic_all

## Write df_left and df_right

Only runs if we are running a small link job

In [None]:
if small_link_job:
    # Small job: register the same 1000-row demo CSV twice, as `df_left` and
    # `df_right`, in the test database.
    df = pd.read_csv("./tests/datasets/fake_1000_from_splink_demos.csv")
    bucket = "alpha-splink-db-testing"
    path = f"s3://{bucket}/data/"
    for input_table in ("df_left", "df_right"):
        wr.s3.to_parquet(
            df=df,
            # Each table gets its own prefix: with mode="overwrite" a shared
            # prefix means the second write deletes the first table's files.
            path=f"{path}{input_table}/",
            dataset=True,
            mode="overwrite",
            database="splink_awswrangler_test",
            table=input_table,
            boto3_session=my_session,
        )

## Run linking job

In [None]:
from splink.aws.aws_linker import AWSLinker
from try_settings import settings_dict, settings_dict_large
import time

# Choose inputs, settings, a label and the EM blocking rules for the job size.
if small_link_job:
    inputs = {"df_left": "df_left", "df_right": "df_right"}
    settings = settings_dict
    txt = "smaller link job"
    m_training_rules = [
        "l.first_name = r.first_name and l.surname = r.surname",
        "l.dob = r.dob",
    ]
else:
    inputs = {"synthetic_data_all": "synthetic_data_all"}
    settings = settings_dict_large
    txt = "large link job"
    m_training_rules = [
        "l.full_name = r.full_name",
        "l.dob = r.dob and substr(l.postcode,1,2) = substr(r.postcode,1,2)",
    ]

print(f"Running {txt}")
print("====================")

# Wall-clock timer for the whole job (linker set-up, training and prediction).
job_started = time.time()
linker = AWSLinker(
    settings,
    input_tables=inputs,
    boto3_session=my_session,
    output_bucket="alpha-splink-db-testing",
    database_name="splink_awswrangler_test",
)

# Estimate u from random sampling, then run one EM pass per blocking rule,
# in the order listed above.
linker.train_u_using_random_sampling(target_rows=1e6)
for blocking_rule in m_training_rules:
    linker.train_m_using_expectation_maximisation(blocking_rule)

# Time the prediction step separately.
predict_started = time.time()
df = linker.predict()
print(f"Predict step took {time.time()-predict_started} seconds")

print(f"Total time taken {time.time()-job_started} seconds")