# Overview
This notebook reads from a Postgres database and keeps an up-to-date replica of a selection of tables and views in Unity Catalog.

This notebook can be run manually or through a scheduled job.

It could also be converted to a Python script and/or stored in GitHub, with the scheduled job running it from there.

## Notes
* If a table doesn't exist at the destination in Databricks, it is created.
* If a refresh time field is specified, it is used to check whether the source has a newer version than the current replica.
* Currently, the script assumes that the refresh time value is the same across the whole source table (as is currently the case). If this changes, the script will have to be updated to reflect it.
* If a refresh is needed, the source table is loaded completely, a hash value is computed to identify each record, and these hashes are compared with the records in the destination.
* The destination tables have an additional `dbr_hash` field, used to identify which records need to be inserted into or deleted from the destination table based on changes in the source table.
* The destination tables keep a version history recording the inserts and deletes (resulting from inserts, deletes or updates in the source table).
* The `last_refresh_time` field is dropped from the destination table and saved as a table property called `source_timestamp`. Since this value is currently the same for all records in the source table, keeping it would take storage unnecessarily on every version update (especially for big tables), so it is practical to drop it.
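
The insert/delete logic described in these notes can be modelled with plain Python sets. This is a minimal sketch, not the notebook's Spark code: `row_hash` uses SHA-256 from `hashlib` as a stand-in for Spark's `hash()`, `plan_sync` is a hypothetical helper name, and the rows are made up.

```python
import hashlib

def row_hash(row):
    """Stand-in for dbr_hash: SHA-256 of the comma-joined row values (the notebook uses Spark's hash())."""
    joined = ",".join(str(v) for v in row)
    return hashlib.sha256(joined.encode("utf-8")).hexdigest()

def plan_sync(source_rows, destination_hashes):
    """Set-based model of the two EXCEPT queries; duplicate source rows are collapsed."""
    source_by_hash = {row_hash(r): r for r in source_rows}
    # source EXCEPT destination: new or updated records to insert
    to_insert = [r for h, r in source_by_hash.items() if h not in destination_hashes]
    # destination EXCEPT source: outdated records to delete
    to_delete = set(destination_hashes) - set(source_by_hash)
    return to_insert, to_delete

old_rows = [(1, "soy", 2014), (2, "beef", 2015)]
new_rows = [(1, "soy", 2014), (2, "beef", 2016), (3, "palm", 2017)]
to_insert, to_delete = plan_sync(new_rows, {row_hash(r) for r in old_rows})
print(to_insert)                                   # [(2, 'beef', 2016), (3, 'palm', 2017)]
print(to_delete == {row_hash((2, "beef", 2015))})  # True
```

An update in the source therefore shows up as an insert of the new version plus a delete of the old one, which is what produces the insert/delete entries in the Delta version history.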

## Notes (2)
* To reference a table with special characters in its name (e.g. "-") in Postgres, use double quotes (for example: `*****."*****-dashboard"`)
* To reference it within Spark / Databricks Delta Table, use back-tick (for example: *****.`` `*****-dashboard` ``)
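
A small sketch of both quoting rules, assuming hypothetical helper names and placeholder schema/catalog names. Postgres escapes an embedded double quote by doubling it, and Spark SQL does the same for back-ticks. Note that quoting makes Postgres treat the name case-sensitively.

```python
def quote_postgres_identifier(name: str) -> str:
    """Double-quote a Postgres identifier, doubling any embedded double quotes."""
    return '"' + name.replace('"', '""') + '"'

def quote_spark_identifier(name: str) -> str:
    """Back-tick a Spark SQL identifier, doubling any embedded back-ticks."""
    return "`" + name.replace("`", "``") + "`"

def qualified(parts, quote):
    """Quote each part of a dotted name, e.g. schema.table or catalog.schema.table."""
    return ".".join(quote(p) for p in parts)

print(qualified(["my_schema", "supply-chains-latest"], quote_postgres_identifier))
# "my_schema"."supply-chains-latest"
print(qualified(["my_catalog", "my_schema", "supply-chains-latest"], quote_spark_identifier))
# `my_catalog`.`my_schema`.`supply-chains-latest`
```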
## To-do
* Allow for passing parameters through a Databricks task: for example for specifying the tables to process
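
One possible shape for this to-do, sketched under assumptions: job task parameters for a notebook task are read in Databricks through `dbutils.widgets.get`, and the widget name `tables` and the helper `parse_table_list` are hypothetical. Only the parsing part runs outside Databricks.

```python
from typing import List

def parse_table_list(raw: str) -> List[str]:
    """Split a comma-separated parameter value into table names, ignoring blanks."""
    return [name.strip() for name in raw.split(",") if name.strip()]

# In Databricks (the widget name "tables" is hypothetical):
#   dbutils.widgets.text("tables", "")
#   table_list = parse_table_list(dbutils.widgets.get("tables"))
#   process_table_list(table_list, schema, databricks_catalog, refresh_field="last_refresh_time")

print(parse_table_list("commitments, commodities,,logistics "))
# ['commitments', 'commodities', 'logistics']
```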

# Supporting definitions and functions

In [0]:
from pyspark.sql.functions import col, concat_ws, hash, when
from datetime import datetime

# Number of partitions (parallel Postgres connections); set it to the number of available cores in the cluster.
# Going beyond ~50 needs special consideration, since each partition opens its own connection to Postgres.
num_partitions = 8

pg_host = dbutils.secrets.get(scope="****", key="****")
pg_port = "****"
pg_database = "****"
pg_user = dbutils.secrets.get(scope="****", key="****")
pg_pass = dbutils.secrets.get(scope="****", key="****")

In [0]:
def get_table(table, schema, query):
    """Return a DataFrame with the result of `query`, read in a single partition. Takes connection details from global variables"""
    dbtable = f"({query}) AS result_table"

    table_df = (spark.read
        .format("postgresql")
        .option("host", pg_host)
        .option("port", pg_port)
        .option("database", pg_database)
        .option("user",  pg_user)
        .option("password", pg_pass)
        .option("numPartitions", num_partitions)
        .option("dbtable", dbtable)
        .load()
    )
    return table_df
  
def get_big_table(table, schema, query, partitions_config):
    """Return a DataFrame with the result of `query`. Takes connection details from global variables.
    Requires partition information (`partitions_config`) to read parts of the table in parallel"""
    
    num_partitions = partitions_config['num_partitions']
    column_partition = partitions_config['column_partition']
    lower_bound = partitions_config['lower_bound']
    upper_bound = partitions_config['upper_bound']
    
    dbtable = f"({query}) AS result_table"

    table_df = (spark.read
        .format("postgresql")
        .option("host", pg_host)
        .option("port", pg_port)
        .option("database", pg_database)
        .option("user",  pg_user)
        .option("password", pg_pass)
        .option("dbtable", dbtable)
        .option("fetchsize", "10000")
        .option("partitionColumn", column_partition)
        .option("numPartitions", num_partitions)
        .option("lowerBound", lower_bound)
        .option("upperBound", upper_bound)
        .option("stringtype", "unspecified")
        .load()
    )
    return table_df

def dest_table_exists(table, schema, catalog):
    schema_tables = spark.sql(f"SHOW TABLES IN {catalog}.{schema}")
    return bool(schema_tables.filter(schema_tables.tableName == f"{table}").collect())

def get_column_bounds(table, schema, column_partition):
    q_bounds = (f'SELECT MIN("{column_partition}") AS lower_bound, '
                f'MAX("{column_partition}") AS upper_bound '
                f'FROM {schema}."{table}"')
    df_bounds = get_table(table, schema, q_bounds)
    bounds = df_bounds.head()  # run the query once and read both values from the same row
    return bounds['lower_bound'], bounds['upper_bound']

def get_schema_info(dataframe, exclude_columns):
    '''Returns the schema of the dataframe as (field name, type) tuples from Spark's .dtypes,
    excluding the column names in 'exclude_columns' '''
    schema_info = []
    for column in dataframe.dtypes:
        if column[0] not in exclude_columns:
            schema_info.append(column)
    return schema_info 

def register_table_properties(table, schema, catalog):
    table_description = (f"Table `{table}` replicated from Postgres. "
                    f"Version history can be seen using `DESCRIBE HISTORY {catalog}.{schema}.{table}`. "
                    "Use the @ syntax to specify the timestamp or version as part of the table name to query "
                    "a specific version. Specify a version after `@` by prepending a `v` to the version number. "
                    "Timestamps must be in `yyyyMMddHHmmssSSS` format. "
                    f"Examples: `SELECT * FROM {catalog}.{schema}.{table}@20190101000000000` or "
                    f"`SELECT * FROM {catalog}.{schema}.{table}@v123` ")
    q_register_description = (f"COMMENT ON TABLE {catalog}.{schema}.`{table}` IS \"{table_description}\"")
    spark.sql(q_register_description)
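    # Keep the Delta transaction log (table history) for a long time. Note that time travel also
    # needs the data files, which VACUUM removes once they are older than
    # 'delta.deletedFileRetentionDuration' (7 days by default)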
    q_retention_policy = (f"ALTER TABLE {catalog}.{schema}.`{table}` "
                        f"SET TBLPROPERTIES ('delta.logRetentionDuration'='interval 50000 weeks')")
    spark.sql(q_retention_policy)

def register_source_timestamp_property(table, schema, catalog, source_time):
    if source_time != "":
        # Register the source refresh field timestamp as a table property
        q_register_source_timestamp = (f"ALTER TABLE {catalog}.{schema}.`{table}` "
                                       f"SET TBLPROPERTIES ('source_timestamp'='{source_time}')")
        spark.sql(q_register_source_timestamp)
        print(f"Registered the table property of '{table}': 'source_timestamp'= {source_time}")

def process_table(table, schema, catalog, refresh_field="", partitions_config=""):
    '''Loads table from Postgres into a Spark Dataframe, and merges it into the corresponding
    Databricks destination table. Assumes the table name and the schema have the same name
    in Postgres as in Databricks selected catalog'''
    refresh = False # Variable to indicate if there needs to be a full refresh
    source_time = ""
    print(f"\n*** Processing table {schema}.{table} ***")

    if (refresh_field != ""):
        # Check the source time
        q_source_time = f'SELECT {refresh_field} FROM {schema}."{table}" WHERE {refresh_field} IS NOT NULL LIMIT 1'
        df_source_time = get_table(table, schema, q_source_time)
        source_time = df_source_time.head()[f'{refresh_field}']
    
    # Check if the destination table exists. If not, will create it later on with the full refresh
    destination_table_exists = dest_table_exists(table, schema, catalog)
    if (destination_table_exists == False):
        refresh = True
        print(f"The destination table '{schema}.{table}' doesn't exist. Creating it and doing a full refresh")
    
    if (refresh == False and refresh_field != ""):
        df_dest_time = spark.sql(f"SHOW TBLPROPERTIES {catalog}.{schema}.`{table}` ('source_timestamp')")
        # Check if there is need to refresh the destination table based on 'refresh_field'
        str_destination_time = df_dest_time.head()['value']
        destination_time = datetime.fromisoformat(str_destination_time)  # accepts values with or without microseconds
        
        if (source_time > destination_time):
            print(f"The destination table '{schema}.{table}' is outdated, will review source table for updates")
        elif (source_time == destination_time):
            print(f"The destination table '{schema}.{table}' seems up to date, with '{refresh_field}': {destination_time}. No need to update it")
            return
        else:
            print(f"Inconsistency comparing '{refresh_field}' of '{schema}.{table}'. Please check. Source Table value: {source_time} . Destination table value: {destination_time}")

    # Get the source table schema
    q_load_schema = f'SELECT * FROM {schema}."{table}" WHERE 1=2 '
    df_source_schema = get_table(table, schema, q_load_schema)
    exclude_columns = [f'{refresh_field}', 'dbr_hash']
    source_schema = get_schema_info(df_source_schema, exclude_columns)
    # Take the field names of the schema for doing the source full SELECT
    source_col_names = [column[0] for column in source_schema]
    str_source_cols = ", ".join(source_col_names)
    
    # Load all the source table (excluding refresh field if existing)
    # Bigger tables are read in parallel through partitions, specified in 'partitions_config'
    q_load_source = f'SELECT {str_source_cols} FROM {schema}."{table}"'
    if (partitions_config == ""):
        df_source = get_table(table, schema, q_load_source)
    else:  
        print(f"Processing table '{table}', reading it through custom partitions for faster load and processing")
        df_source = get_big_table(table, schema, q_load_source, partitions_config)

    print("Creating a hash column to uniquely identify records. Might take some time depending on the table size")
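    # Note: hash() returns a 32-bit Murmur3 value and concat_ws() skips NULLs, so distinct rows can collide
    df_source = df_source.withColumn("dbr_hash", hash(concat_ws(",", *df_source.columns)))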
    df_source.createOrReplaceTempView("temp_source_table") # Register it in sparksql

    # Compare the source and destination schema. If different, do a full refresh
    if (destination_table_exists == True):
        df_destination = spark.read.table(f"{catalog}.{schema}.`{table}`")
        destination_schema = get_schema_info(df_destination, exclude_columns=['dbr_hash'])
        if (source_schema != destination_schema):
            refresh = True
            print("The schema of the source table seems to have changed. Doing a full refresh with the new structure")
            print(f"Source schema: {source_schema}\n Destination schema: {destination_schema}")

    # Do a full refresh of the destination table if specified
    if (refresh == True):
        df_source.write.mode("overwrite") \
                            .option("overwriteSchema", "True") \
                            .saveAsTable(f"{catalog}.{schema}.`{table}`")
        print(f"Made a full refresh of '{schema}.{table}' in Databricks, and added a 'dbr_hash' column for uniquely identifying records")
        register_table_properties(table, schema, catalog)
        register_source_timestamp_property(table, schema, catalog, source_time)
        return

    # If not doing a full refresh, continue with comparisons
    
    # Load hashes from destination table to do comparisons
    df_destination_hashes = spark.sql(f"SELECT dbr_hash FROM {catalog}.{schema}.`{table}`")
    df_destination_hashes.createOrReplaceTempView("temp_destination_hashes") # Register it in sparksql

    # Identify the new or updated records from the source table and insert them in the actual destination table
    q_updated_hashes = f"SELECT dbr_hash FROM temp_source_table EXCEPT SELECT dbr_hash FROM temp_destination_hashes"
    df_updated_hashes = spark.sql(q_updated_hashes)
    df_updated_hashes.createOrReplaceTempView("updated_hashes")
    q_updated_records = (f"SELECT * FROM temp_source_table WHERE dbr_hash IN "
                    "(SELECT dbr_hash FROM updated_hashes)")
    df_updated_records = spark.sql(q_updated_records)

    if (df_updated_records.count() > 0):
        df_updated_records.createOrReplaceTempView("updated_records")
        # Make the changes to the destination table
        insert_result = spark.sql(f"INSERT INTO {catalog}.{schema}.`{table}` SELECT * FROM updated_records")
        print(f"Records have been inserted or updated at '{table}' destination table")
    else: 
        print(f"No records needed to be inserted or updated at '{table}' destination table")

    # If the source and destination tables are now the same size, there is no need to delete records
    df_destination_size = spark.sql(f"SELECT COUNT(*) AS records FROM {catalog}.{schema}.`{table}`")
    destination_size = df_destination_size.head()['records']
    if (df_source.count() == destination_size):
        print(f"No records have needed to be deleted from the destination '{table}' table")
        register_source_timestamp_property(table, schema, catalog, source_time)
        return

    # Otherwise, identify and delete records from the destination table that are not in the source table
    q_hashes_todel = (f"SELECT dbr_hash FROM {catalog}.{schema}.`{table}` "
                        "EXCEPT SELECT dbr_hash FROM temp_source_table")
    df_hashes_todel = spark.sql(q_hashes_todel)
    df_hashes_todel.createOrReplaceTempView(f"hashes_to_del")
    q_delete = (f"DELETE FROM {catalog}.{schema}.`{table}` WHERE dbr_hash IN "
                "(SELECT dbr_hash FROM hashes_to_del)")
    delete_result = spark.sql(q_delete)
    print(f"Records have been deleted from the destination '{table}' table as they were outdated")

    register_source_timestamp_property(table, schema, catalog, source_time)
    return

def process_table_list(table_list, schema, databricks_catalog, refresh_field=""):
    for table in table_list:
        process_table(table, schema, databricks_catalog, refresh_field=refresh_field)
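
The `source_timestamp` property is written with an f-string, i.e. as `str(source_time)`. Python's `str()` of a `datetime` omits the fractional part when the microseconds are zero, so parsing it back with a fixed `'%Y-%m-%d %H:%M:%S.%f'` format can fail for some values, while `datetime.fromisoformat` (Python 3.7+) accepts both forms. A small check, assuming the property holds exactly that `str()` output; `parse_source_timestamp` is a hypothetical helper name.

```python
from datetime import datetime

def parse_source_timestamp(value: str) -> datetime:
    """Parse a timestamp written with str(datetime), with or without microseconds."""
    return datetime.fromisoformat(value)

with_micros = datetime(2023, 4, 9, 4, 3, 29, 342338)
without_micros = datetime(2023, 4, 9, 4, 3, 29)

for ts in (with_micros, without_micros):
    stored = str(ts)  # how the value ends up in the 'source_timestamp' property
    assert parse_source_timestamp(stored) == ts
    print(stored)

try:
    datetime.strptime(str(without_micros), "%Y-%m-%d %H:%M:%S.%f")
except ValueError:
    print("strptime with '.%f' fails when the microseconds are zero")
```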

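`dbr_hash` is built with Spark's `hash()`, a 32-bit Murmur3 hash, over `concat_ws(",", ...)`, which skips NULL values. Two consequences can be checked in plain Python: rows that differ only in where a NULL sits produce the same concatenated string, and, by the birthday bound, the expected number of colliding pairs among n distinct rows is about n(n-1)/2 / 2^32. The sketch below is an illustration, not the notebook's code: `concat_ws_like` mimics Spark's NULL-skipping, the `\N` NULL marker is an arbitrary assumption, and 5 million rows is an illustrative size. In PySpark, the analogous change would be a wider hash such as `sha2(..., 256)` or `xxhash64` over NULL-coalesced columns; that also changes the type of `dbr_hash`, so existing destination tables would need to be rebuilt.

```python
def concat_ws_like(sep, *values):
    """Mimic Spark's concat_ws: NULL (None) values are skipped, not written."""
    return sep.join(str(v) for v in values if v is not None)

def null_safe_key(sep, *values, null_marker="\\N"):
    """Alternative: write an explicit marker for NULLs so their position is kept."""
    return sep.join(null_marker if v is None else str(v) for v in values)

def expected_collisions(n_rows, hash_bits=32):
    """Birthday-bound estimate of colliding pairs among n distinct rows."""
    return n_rows * (n_rows - 1) / 2 / 2 ** hash_bits

row_a = ("soy", None, "BR")
row_b = ("soy", "BR", None)
print(concat_ws_like(",", *row_a) == concat_ws_like(",", *row_b))  # True: same string
print(null_safe_key(",", *row_a) == null_safe_key(",", *row_b))    # False

print(round(expected_collisions(5_000_000)))         # about 2910 pairs with a 32-bit hash
print(expected_collisions(5_000_000, hash_bits=64))  # well below one pair with a 64-bit hash
```
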

# Actual update of tables

In [0]:
pg_host = dbutils.secrets.get(scope="****", key="****")

## Small tables

In [0]:
databricks_catalog = "****"

schema = "*****"
*****tables = ["commitments", "commodities", "*****-dashboard", "logistics"]
process_table_list(*****tables, schema, databricks_catalog, refresh_field="last_refresh_time")


*** Processing table *****.commitments ***
The destination table '*****.commitments' seems up to date, with 'last_refresh_time': 2023-04-09 04:03:29.342338. No need to update it

*** Processing table *****.commodities ***
The destination table '*****.commodities' seems up to date, with 'last_refresh_time': 2023-04-09 04:03:29.335084. No need to update it

*** Processing table *****.*****-dashboard ***
The destination table '*****.*****-dashboard' seems up to date, with 'last_refresh_time': 2023-04-09 04:03:29.337951. No need to update it

*** Processing table *****.logistics ***
The destination table '*****.logistics' seems up to date, with 'last_refresh_time': 2023-04-09 04:03:29.331664. No need to update it


## Small views
Note that views don't have a refresh field, so on every run they are reloaded in full and compared using `dbr_hash`.

In [0]:
schema = "views"
view_tables = ["datasets", "f500", "indicators_type", "rejected_name_matches", "supply_chains_pipeline"]
process_table_list(view_tables, schema, databricks_catalog)


*** Processing table views.datasets ***
Creating a hash column to uniquely identify records. Might take some time depending on the table size
No records needed to be inserted or updated at 'datasets' destination table
No records have needed to be deleted from the destination 'datasets' table

*** Processing table views.f500 ***
Creating a hash column to uniquely identify records. Might take some time depending on the table size
No records needed to be inserted or updated at 'f500' destination table
No records have needed to be deleted from the destination 'f500' table

*** Processing table views.indicators_type ***
Creating a hash column to uniquely identify records. Might take some time depending on the table size
No records needed to be inserted or updated at 'indicators_type' destination table
No records have needed to be deleted from the destination 'indicators_type' table

*** Processing table views.rejected_name_matches ***
Creating a hash column to uniquely identify records. Mi

## Big tables (>100MB)
For the bigger tables, specify partitioning information, including `column_partition` and its corresponding `lower_bound` and `upper_bound`, so they can be read and processed in parallel in `num_partitions` chunks.

It is recommended that `column_partition` be an integer or date column, and that it be evenly distributed so that each partition processes roughly the same number of records.
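
Spark's JDBC reader does not look at how the data is distributed: it splits the `[lower_bound, upper_bound]` range into `num_partitions` equal strides and issues one query per stride. The first partition also picks up NULLs and values below `lower_bound`, and the last picks up values above `upper_bound`. The sketch below approximates that logic for non-negative integer columns; the exact stride arithmetic differs between Spark versions, so treat it as an illustration rather than Spark's code. `jdbc_partition_clauses` is a hypothetical name.

```python
def jdbc_partition_clauses(column, lower_bound, upper_bound, num_partitions):
    """Approximate the per-partition WHERE clauses of Spark's JDBC range partitioning
    for a non-negative integer column. [None] means one unfiltered read."""
    if lower_bound > upper_bound:
        raise ValueError("lower_bound must not be larger than upper_bound")
    if num_partitions <= 1 or lower_bound == upper_bound:
        return [None]
    # Fewer values in the range than partitions: Spark reduces the partition count
    if upper_bound - lower_bound < num_partitions:
        num_partitions = upper_bound - lower_bound
    stride = upper_bound // num_partitions - lower_bound // num_partitions
    clauses, current = [], lower_bound
    for i in range(num_partitions):
        lower = f"{column} >= {current}" if i != 0 else None
        current += stride
        upper = f"{column} < {current}" if i != num_partitions - 1 else None
        if upper is None:
            clauses.append(lower)  # last partition: everything above (None if only one partition)
        elif lower is None:
            clauses.append(f"{upper} or {column} is null")  # first partition also gets NULLs
        else:
            clauses.append(f"{lower} AND {upper}")
    return clauses

for clause in jdbc_partition_clauses("row_number", 1, 100, 4):
    print(clause)
```

Because the strides are equal in value range, not in row count, a column where a few values hold most of the rows sends most of the data to a few queries, while the other partitions stay nearly empty.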

### supply-chains-latest
Notes:
* Full refresh takes 1.4 minutes on a single node, 1 minute on a two-node cluster
* Using `row_number` for partitioning
* Source table size: 3.11 GB
* Destination table size: 250MB

In [0]:
catalog = "*****"
schema = "*****"
table = "supply-chains-latest"

# Will use the row number for dividing partitions
num_partitions = 18 # rule of thumb: 2 - 3 times number of worker cpus
column_partition = "row_number"
lower_bound, upper_bound = get_column_bounds(table, schema, column_partition)

partitions_config = {
    "num_partitions": num_partitions, 
    "column_partition": column_partition,
    "lower_bound": lower_bound,
    "upper_bound": upper_bound
}

process_table(table, schema, catalog, refresh_field="last_refresh_time", partitions_config=partitions_config)


*** Processing table *****.supply-chains-latest ***
The destination table '*****.supply-chains-latest' seems up to date, with 'last_refresh_time': 2023-04-06 09:14:22.801232. No need to update it


### supply-chains-concatenated
Notes:
* Currently using `commodity_id` as the partitioning column to split the table for reading. As it's not evenly distributed, it puts all the processing load on two of the table partitions
  * To improve this, find a more evenly distributed column, or split the read into an array of reads based on a hashed value [(example link - using Scala)](https://dzlab.github.io/spark/2022/02/10/spark-jdbc-partitioning/)
* Takes 2.86 minutes for a full refresh on a single node, 2 minutes on a two-node cluster
* Destination table size: 380MB

In [0]:
catalog = "****"
schema = "*****"
table = "supply-chains-concatenated"

num_partitions = 18 # rule of thumb: 2 - 3 times number of worker cpus (or number of distinct partitions)
column_partition = "commodity_id"
lower_bound, upper_bound = get_column_bounds(table, schema, column_partition)

partitions_config = {
    "num_partitions": num_partitions, 
    "column_partition": column_partition,
    "lower_bound": lower_bound,
    "upper_bound": upper_bound
}

process_table(table, schema, catalog, refresh_field="last_refresh_time", partitions_config=partitions_config)


*** Processing table *****.supply-chains-concatenated ***
The destination table '*****.supply-chains-concatenated' seems up to date, with 'last_refresh_time': 2023-04-06 09:14:22.801232. No need to update it


### regional-indicators

In [0]:
catalog = "****"
schema = "*****"
table = "regional-indicators"

num_partitions = 18 # rule of thumb: 2 - 3 times number of worker cpus (or number of distinct partitions)
column_partition = "level"
lower_bound, upper_bound = get_column_bounds(table, schema, column_partition)

partitions_config = {
    "num_partitions": num_partitions, 
    "column_partition": column_partition,
    "lower_bound": lower_bound,
    "upper_bound": upper_bound
}

process_table(table, schema, catalog, refresh_field="last_refresh_time", partitions_config=partitions_config)


*** Processing table *****.regional-indicators ***
The destination table '*****.regional-indicators' doesn't exist. Creating it and doing a full refresh
Processing table 'regional-indicators', reading it through custom partitions for faster load and processing
Creating a hash column to uniquely identify records. Might take some time depending on the table size
Made a full refresh of '*****.regional-indicators' in Databricks, and added a 'dbr_hash' column for uniquely identifying records
Registered the table property of 'regional-indicators': 'source_timestamp'= 2023-04-09 04:03:29.320122


### regions

In [0]:
catalog = "*****"
schema = "*****"
table = "regions"

num_partitions = 18 # note however that the 'level' column has fewer values
column_partition = "level"
lower_bound, upper_bound = get_column_bounds(table, schema, column_partition)

partitions_config = {
    "num_partitions": num_partitions, 
    "column_partition": column_partition,
    "lower_bound": lower_bound,
    "upper_bound": upper_bound
}

process_table(table, schema, catalog, refresh_field="last_refresh_time", partitions_config=partitions_config)


*** Processing table *****.regions ***
The destination table '*****.regions' doesn't exist. Creating it and doing a full refresh
Processing table 'regions', reading it through custom partitions for faster load and processing
Creating a hash column to uniquely identify records. Might take some time depending on the table size
Made a full refresh of '*****.regions' in Databricks, and added a 'dbr_hash' column for uniquely identifying records
Registered the table property of 'regions': 'source_timestamp'= 2023-04-09 04:04:55.366554


### traders
Note that, as it doesn't have an integer column to partition on while reading, it is read in a single process.

Another option is to do multiple reads based on a hash value of one of its string columns.
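
A sketch of that option, under assumptions. It targets Spark's generic `DataFrameReader.jdbc()` with its `predicates` argument, where each predicate becomes one partition, instead of the `postgresql` format used above. The bucket expression uses Postgres's `md5()`: the first 7 hex digits (28 bits) are cast to a non-negative integer and taken modulo the number of buckets. `hash_bucket_predicates` and the column name `trader_name` are hypothetical.

```python
def hash_bucket_predicates(column, num_buckets):
    """One Postgres WHERE predicate per bucket, splitting rows by an md5 hash of `column`."""
    quoted = '"' + column.replace('"', '""') + '"'
    # First 7 hex digits of md5 = 28 bits -> always a non-negative integer.
    # coalesce() puts NULLs in a bucket instead of dropping them (md5(NULL) is NULL).
    bucket = (f"(('x' || substr(md5(coalesce({quoted}::text, '')), 1, 7))"
              f"::bit(28)::int % {num_buckets})")
    return [f"{bucket} = {i}" for i in range(num_buckets)]

for predicate in hash_bucket_predicates("trader_name", 4):  # hypothetical column name
    print(predicate)

# Hypothetical use in Databricks with the generic JDBC reader; each predicate becomes one partition:
# df = spark.read.jdbc(
#     url=f"jdbc:postgresql://{pg_host}:{pg_port}/{pg_database}",
#     table=f'{schema}."{table}"',
#     predicates=hash_bucket_predicates("trader_name", 8),
#     properties={"user": pg_user, "password": pg_pass, "driver": "org.postgresql.Driver"},
# )
```

Unlike range partitioning, the buckets do not depend on the value range of an integer column, so rows spread roughly evenly as long as the hashed column has many distinct values.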

In [0]:
catalog = "****"
schema = "*****"
table = "traders"

process_table(table, schema, catalog, refresh_field="last_refresh_time")


*** Processing table *****.traders ***
The destination table '*****.traders' doesn't exist. Creating it and doing a full refresh
Creating a hash column to uniquely identify records. Might take some time depending on the table size
Made a full refresh of '*****.traders' in Databricks, and added a 'dbr_hash' column for uniquely identifying records
Registered the table property of 'traders': 'source_timestamp'= 2023-04-06 09:14:21.236168


## Big views
Review whether to include `country_node_attributes_view` and `flows_long`: I couldn't run a full count on them to check their size (I stopped it after 10 minutes).

# From here on, test and debugging cells

## Commitments updates tests

In [0]:
# host is actual postgres server
pg_host = dbutils.secrets.get(scope="****", key="****")
catalog = "*****"

In [0]:
schema = "*****"
table = "commitments"

process_table(table, schema, catalog, refresh_field="last_refresh_time")


*** Processing table *****.commitments ***
The destination table '*****.commitments' is outdated, will review source table for updates
Creating a hash column to uniquely identify records. Might take some time depending on the table size
Records have been inserted or updated at 'commitments' destination table
Records have been deleted from the destination 'commitments' table as they were outdated
Registered the table property of 'commitments': 'source_timestamp'= 2023-04-13 09:00:38.682497


## supply-chains-latest

In [0]:
catalog = "****"
schema = "*****"
table = "supply-chains-latest"

results = process_table(table, schema, catalog, refresh_field="last_refresh_time")


*** Processing table *****.supply-chains-latest ***
The destination table '*****.supply-chains-latest' doesn't exist. Creating it and doing a full refresh
Creating a hash column to uniquely identify records. Might take some time depending on the table size
Made a full refresh of '*****.supply-chains-latest' in Databricks, and added a 'dbr_hash' column for uniquely identifying records
Registered the table property of 'supply-chains-latest': 'source_timestamp'= 2023-04-06 09:14:22.801232


---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<command-338952632807930> in <cell line: 5>()
      3 table = "supply-chains-latest"
      4 
----> 5 results = process_table(table, schema, catalog, refresh_field="last_refresh_time")

<command-338952632807917> in process_table(table, schema, catalog, refresh_field, big_table)
    138         register_table_properties(table, schema, catalog)
    139         regist

In [0]:
catalog = "****"
schema = "*****"
table = "supply-chains-latest"

supply_chains_latest_config = {
    "num_partitions": 18,
    "column_partition": "row_number",
    "lower_bound": 1,
    "upper_bound": 5000000
}

results = process_table(table, schema, catalog, refresh_field="last_refresh_time",
                        partitions_config=supply_chains_latest_config)


*** Processing table *****.supply-chains-latest ***
The destination table '*****.supply-chains-latest' doesn't exist. Creating it and doing a full refresh
Processing big table: supply-chains-latest
Creating a hash column to uniquely identify records. Might take some time depending on the table size
Made a full refresh of '*****.supply-chains-latest' in Databricks, and added a 'dbr_hash' column for uniquely identifying records
Registered the table property of 'supply-chains-latest': 'source_timestamp'= 2023-04-06 09:14:22.801232


In [0]:
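# process_table() returns None, so display the replicated table instead
display(spark.read.table(f"{catalog}.{schema}.`{table}`"))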

In [0]:
col_list = []
col_result = spark.sql(f"SHOW COLUMNS IN *****postgres.*****.commitments").collect()
for value in col_result:
    col_list.append(value['col_name'])
col_string = ", ".join(col_list)
print(col_string)

links_reference_id, last_refresh_time, version, link_id, country_of_production, commodity, year, biome, trader_group, trader_group_id, commitment, tbl_bighash, dbr_hash


# Exploration queries

In [0]:
%sql
INSERT INTO *****.commitments
SELECT links_reference_id+2, last_refresh_time, version, link_id, country_of_production, commodity, year+2, biome, trader_group, trader_group_id, commitment
FROM *****.commitments 
WHERE links_reference_id = 1744 AND year = 2014

In [0]:
%sql
SELECT COUNT(*)
FROM *****postgres.*****.commitments 
WHERE links_reference_id = 1746 AND year = 2016

count(1)
426


In [0]:
%sql
SELECT * FROM ****.*****.commitments 
WHERE links_reference_id = 1744 and year = 2014
LIMIT 100

In [0]:
spark.sql("DELETE FROM *****postgres.*****.commitments")

Out[12]: DataFrame[num_affected_rows: bigint]

In [0]:
%sql
SELECT COUNT(*)
FROM *****postgres.*****.regions

count(1)
16848
