In [None]:
import datetime
import hashlib
import logging
import os
from urllib.parse import quote

from dateutil.rrule import rrule, MONTHLY
from dateutil.relativedelta import relativedelta
from dotenv import load_dotenv
from sqlalchemy import create_engine

from airflow_submodule.hdb_resale import api, sql

In [None]:
# Retrieve environment variables. load_dotenv() first copies values from a
# .env file (if one is present) into os.environ.
load_dotenv()

POSTGRESQL_DASH_USER = os.environ.get("POSTGRESQL_DASH_USER")
POSTGRESQL_DASH_PASSWORD = os.environ.get("POSTGRESQL_DASH_PASSWORD")
POSTGRESQL_DASH_DATABASE = os.environ.get("POSTGRESQL_DASH_DATABASE")
POSTGRESQL_HOST = os.environ.get("POSTGRESQL_HOST")
POSTGRESQL_PORT = os.environ.get("POSTGRESQL_PORT")

# Fail fast with a clear message: an unset variable would otherwise be
# interpolated as the literal text "None" into the database connection string
# and only fail later with a confusing URL-parsing/connection error.
_missing_env_vars = [
    var_name
    for var_name in (
        "POSTGRESQL_DASH_USER",
        "POSTGRESQL_DASH_PASSWORD",
        "POSTGRESQL_DASH_DATABASE",
        "POSTGRESQL_HOST",
        "POSTGRESQL_PORT",
    )
    if os.environ.get(var_name) is None
]
if _missing_env_vars:
    raise RuntimeError(
        f"Missing required environment variable(s): {', '.join(_missing_env_vars)}"
    )

# Setup logger. Configure the root logger at INFO level: Python's default
# threshold is WARNING, which would silently drop the logger.info progress
# messages emitted while loading data.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
# Run parameters for this load.
#   run_mode   -- "full" reloads every month from the fixed start date onwards;
#                 "delta" reloads only the most recent months.
#   past_n_mth -- used by "delta" only: how many months to cover, counting the
#                 current month (in March, a value of 3 covers January-March).
run_mode: str = "full"  # alternative: "delta"
past_n_mth: int = 3

In [None]:
# Create the connection string. User name and password are percent-encoded
# (quote with safe="") so characters such as "@", ":", "/", "#" or spaces in
# them cannot corrupt the URL; SQLAlchemy decodes them when parsing the URL.
# str() keeps the previous rendering ("None") should a variable be unset.
connection_string = (
    "postgresql+psycopg2://"
    f"{quote(str(POSTGRESQL_DASH_USER), safe='')}:{quote(str(POSTGRESQL_DASH_PASSWORD), safe='')}"
    f"@{POSTGRESQL_HOST}:{POSTGRESQL_PORT}/{POSTGRESQL_DASH_DATABASE}"
)

# Create a SQLAlchemy engine
engine = create_engine(connection_string)

# Define schema & tables in database using SQLAlchemy
metadata_obj = sql.define_all_schema_table()

# Create schema & tables if they do not exist yet
sql.create_all_table(engine=engine, metadata=metadata_obj)

In [None]:
# # FOR TESTING ONLY - uncomment the line below to drop all tables (all loaded data is lost)
# sql.drop_all_table(engine=engine, metadata=metadata_obj)

In [None]:
api_hdb_url = api.build_sggov_hdb_url()

table_name = "hdb_resale.source_data"
# Split the qualified name so that to_sql() writes to the same table that
# sql.delete_data_primary_key() deletes from.
table_schema, table_short_name = table_name.split(".", 1)
# NOTE: a flat cannot be uniquely identified without the API-provided `_id`,
# because many flats can share exactly the same characteristics captured here.
hash_key_columns = [
    "_id", "month", "town",
    "flat_type", "block", "street_name",
    "storey_range", "floor_area_sqm",
    "flat_model", "lease_commence_date",
    "resale_price"
]

end_month = datetime.date.today().replace(day=1)  # e.g. datetime.date(2024, 3, 1) for testing
if run_mode == "full":
    start_month = datetime.date(2017, 1, 1)  # e.g. datetime.date(2024, 1, 1) for testing
elif run_mode == "delta":
    # A value below 1 would put start_month after end_month and silently
    # produce an empty month range.
    if past_n_mth < 1:
        raise ValueError(f"past_n_mth must be >= 1 for a delta run, got {past_n_mth}.")
    start_month = end_month - relativedelta(months=past_n_mth - 1)
else:
    raise ValueError(f"run_mode {run_mode} not implemented.")
# First day of every month from start_month to end_month, inclusive
range_month = [dt.date() for dt in rrule(MONTHLY, dtstart=start_month, until=end_month)]

limit = 100  # page size, following the default of the API
start_offset = 0  # starting offset value


def _create_hash(row):
    """Return the SHA-1 hex digest of a row's values, stringified and joined by "_".

    The digest is stored as the primary key ``_row_hash_id`` of rows already
    in the database, so this exact format must not change; otherwise
    re-loaded rows would no longer match (and replace) their earlier copies.
    """
    row_id = "_".join(row.values.astype(str)).encode("utf-8")
    return hashlib.sha1(row_id).hexdigest()


for cur_month in range_month:  # loop through all months
    cur_offset = start_offset
    cur_row_retrieved = 0
    exp_row_retrieved = 0
    # JSON filter selecting a single month; identical for every page of the month
    data_query = f'{{"month":"{cur_month.strftime("%Y-%m")}"}}'

    while True:  # page through all rows of the month until an empty page
        # Final URL with base and query strings for the current page
        api_final_url = f"{api_hdb_url}&q={data_query}&limit={limit}&offset={cur_offset}"

        logger.info("Retrieving data from endpoint with query - %s", api_final_url)

        res = api.get_sggov_hdb_data(api_url=api_final_url)

        # An empty page means all data for this month has been retrieved.
        # Explicit check (not `assert`) so it also runs under `python -O`.
        if res.empty:
            if int(cur_row_retrieved) != int(exp_row_retrieved):
                raise RuntimeError(
                    f"Row count mismatch for month {cur_month.strftime('%Y-%m')}: "
                    f"retrieved {cur_row_retrieved} rows but the API reported {exp_row_retrieved}."
                )
            break

        # Rename source column names that contain a space
        res = res.rename(columns={"rank month": "rank_month"})

        # Create row hash identifier using key columns
        res["_row_hash_id"] = res[hash_key_columns].apply(_create_hash, axis=1)

        # Remove existing rows with the same primary key before inserting, to
        # prevent duplicate-key errors. Keys that do not exist yet delete nothing.
        sql.delete_data_primary_key(
            engine=engine, metadata=metadata_obj,
            table_name=table_name,
            primary_key=res["_row_hash_id"].to_list()
        )

        # Insert data into database; engine.begin() commits when the block exits
        with engine.begin() as con:
            res.to_sql(
                name=table_short_name, schema=table_schema, con=con,
                if_exists="append", index=False, chunksize=10000
            )

        logger.info("Loaded %s rows of data into table %s.", res.shape[0], table_name)

        cur_offset += limit
        cur_row_retrieved += res.shape[0]
        # Total row count for the month as reported by the API endpoint
        exp_row_retrieved = res["_full_count"].iloc[0]