In [None]:
import time
import itertools
import json
import h3
import pandas as pd
import psycopg2
import psycopg2.extras
from inspect import cleandoc
from dotenv import load_dotenv
from h3_transformation import H3Transformation
from pg_client import PgClient

## Setup

In [None]:
# Client for the 'spatial_dwh' database. The pipeline cells below reuse this
# single instance, and the last cell of the pipeline closes it.
pg_client = PgClient(database='spatial_dwh')

## Definitions

### Administrative border

In [None]:
def get_administrative_border_df(pg: PgClient, table: str, as_geojson: bool = True) -> pd.DataFrame | None:
    """Read every row of an administrative-border table into a DataFrame.

    On top of the table's own columns the query adds a serialized copy of the
    ``geometry`` column and the X/Y of the ``centroid`` column
    (``centroid_lng`` / ``centroid_lat``).

    Args:
        pg: client used to run the query (``fetchall``) and to read the column
            names (``cursor().description``).
        table: table to read from; interpolated into the SQL as-is.
        as_geojson: when True the geometry is serialized with ``ST_AsGeoJSON``
            into a ``geojson`` column, otherwise with ``ST_AsText`` into a
            ``geometry`` column.

    Returns:
        The query result, or None when anything fails (the error is printed).
    """
    # Pick the PostGIS serializer together with the alias of its output column.
    serializer, alias = ("ST_AsGeoJSON", "geojson") if as_geojson else ("ST_AsText", "geometry")

    query = """
    SELECT 
        *,
        {st}(geometry) AS {geometry_alias},
        ST_X(centroid) AS centroid_lng,
        ST_Y(centroid) AS centroid_lat
    FROM {table}
    """.format(st=serializer, geometry_alias=alias, table=table)

    # Fetch database
    try:
        rows: list = pg.fetchall(query)
        # Column names are read from the cursor only after the fetch.
        header = [column[0] for column in pg.cursor().description]
        return pd.DataFrame(rows, columns=header)
    except (Exception, psycopg2.DatabaseError) as error:
        print(error)
        return None

# Transform MultiPolygon GeoJSON into several Polygon GeoJSONs
def multipolygon_to_polygons(multipolygon_geojson: dict) -> list:
    """Split a MultiPolygon GeoJSON geometry into Polygon geometries.

    Each entry of the input's ``"coordinates"`` becomes the ``"coordinates"``
    of one ``{"type": "Polygon", ...}`` dict. The coordinate lists are reused,
    not copied.
    """
    return [
        {"type": "Polygon", "coordinates": rings}
        for rings in multipolygon_geojson["coordinates"]
    ]

### GeoJSON

In [None]:
def polygon_geojsons_to_h3(polygon_geojsons: list[str] | None, res: int) -> set[str]:
    # Exceptions
    if polygon_geojsons is None:
        # throw exception
        raise Exception("Polygon GeoJSONs must not be None")
    
    # resolution must be within [0, 15]
    if res < 0 or res > 15:
        raise Exception("Resolution must be in range [0, 15]")
    
    # Transform Polygons into sets of H3 cells
    h_cell_sets: list[set] = []
    for polygon_geojson in polygon_geojsons:
        h_cell_sets.append(h3.polyfill(geojson=polygon_geojson, res=res, geo_json_conformant=True))
        
    # Union sets into a single set of H3 cells
    h_cells: set[str] = set().union(*h_cell_sets)
        
    return h_cells

### H3

In [None]:
def get_h3_table_name(adm: str, res: int) -> str:
    """Return the H3 table name for an administrative area and resolution.

    The name has the form ``h3_<adm>_r<res>``.
    """
    return "_".join(("h3", f"{adm}", f"r{res}"))

def get_create_table_sql(table: str, schema: dict[str, str]) -> str:
    """Build the CREATE TABLE statement for an H3 cell table.

    Args:
        table: name of the table to create; interpolated into the SQL as-is.
        schema: column name -> column definition, emitted in dict order.

    Returns:
        A CREATE TABLE statement with one line per column, followed by a
        ``ck_resolution`` CHECK constraint that keeps the ``resolution``
        column within [0, 15].
    """
    # Column names are padded so that the definitions line up. The explicit
    # space guarantees a separator even when a name is as long as, or longer
    # than, the padded width; padding alone would glue the name to its type.
    COLUMN_MAX_LENGTH = 18
    SEP = ",\n        "

    definitions = [
        f"{column.ljust(COLUMN_MAX_LENGTH - 1)} {definition}"
        for column, definition in schema.items()
    ]
    definitions.append(
        "CONSTRAINT ck_resolution CHECK (resolution >= 0 AND resolution <= 15)"
    )
    table_body = SEP.join(definitions)

    # Format the column definitions into CREATE TABLE SQL
    sql = f"""
    CREATE TABLE {table} (
        {table_body}
    );
    """
    return sql

# Batch INSERT operation for PostgreSQL
def save_h3_to_postgresql(
    pg: PgClient, h3_cells: set[str], table: str, schema: dict[str, str]
) -> None:
    """Insert one row per H3 cell into ``table`` in batches.

    Each row carries, in this order: the integer cell index, the resolution,
    ``h3.edge_length`` for that resolution in metres, the cell area in square
    metres, the centroid as an SRID-4326 point, and the cell geometry parsed
    from GeoJSON. The INSERT column list is taken from the keys of ``schema``,
    so the key order must match that value order.

    Args:
        pg: client providing ``cursor()``, ``connection()`` and ``execute()``.
        h3_cells: H3 cell ids as strings.
        table: target table name; interpolated into the SQL as-is.
        schema: column name -> column definition; only the keys are used here.
    """
    POSTGRES_STATEMENT_MAX_RECORDS = 1000
    TOTAL_HEXAGONS = len(h3_cells)

    # Parse schema to build INSERT sql
    def _build_insert_statement(table: str, schema: dict[str, str]) -> str:
        columns = ", ".join(schema)
        _stmt = f"""
        INSERT INTO {table} ({columns})
        VALUES
        """
        return _stmt

    # Compute the values of one row; every h3 lookup happens once per cell.
    def _cell_row(cell: str) -> tuple:
        resolution = h3.h3_get_resolution(cell)
        # h3_to_geo yields (lat, lng); ST_Point below takes (lng, lat).
        centroid_lat, centroid_lng = h3.h3_to_geo(h=cell)
        geometry_geojson = H3Transformation.cell_to_geojson(
            h3_cell=cell, include_default_properties=False, geometry_only=True
        )
        return (
            h3.string_to_h3(cell),
            resolution,
            h3.edge_length(resolution, unit="m"),
            h3.cell_area(cell, unit="m^2"),
            centroid_lng,
            centroid_lat,
            geometry_geojson,
        )

    # Alternative strategy (not called by default): build multi-row
    # INSERT ... VALUES statements as text and run them via pg.execute.
    def _multi_values_execute(pg: PgClient, table: str, h3_cells: set[str]) -> None:
        # Iterate over H3 cells & batch INSERT
        values_arguments_count = 0
        values_arguments_sql = ""
        insert_statement_sql = _build_insert_statement(table=table, schema=schema)

        for record_count, cell in enumerate(h3_cells, start=1):
            (
                idx,
                resolution,
                circumradius_m,
                area_m2,
                centroid_lng,
                centroid_lat,
                geometry_geojson,
            ) = _cell_row(cell)

            # Prepare values_arguments
            values_arguments_sql += f"""
            (
                {idx},
                {resolution},
                {circumradius_m},
                {area_m2},
                ST_SetSRID(ST_Point({centroid_lng}, {centroid_lat}), 4326),
                ST_GeomFromGeoJSON('{geometry_geojson}')
            ),"""
            values_arguments_count += 1

            # Flush when the batch is full or the last cell has been reached
            if (record_count % POSTGRES_STATEMENT_MAX_RECORDS == 0) or (
                record_count == TOTAL_HEXAGONS
            ):
                print("--- Executing record: ", record_count)

                # Drop the trailing comma of the last VALUES tuple
                values_arguments_sql = values_arguments_sql.rstrip(",")
                query = insert_statement_sql + values_arguments_sql
                pg.execute(query=query)
                print(f"--- Executed {values_arguments_count} values_arguments.")

                # Reset
                values_arguments_sql = ""
                values_arguments_count = 0

        return None

    # Default strategy: let psycopg2 expand and page the VALUES list.
    def _execute_values(pg: PgClient, table: str, h3_cells: set[str]) -> None:
        insert_statement_sql = _build_insert_statement(table=table, schema=schema)

        psycopg2.extras.execute_values(
            pg.cursor(),
            insert_statement_sql + "%s;",
            (_cell_row(cell) for cell in h3_cells),
            template="(%s, %s, %s, %s,  ST_SetSRID(ST_Point(%s, %s), 4326), ST_GeomFromGeoJSON(%s))",
            page_size=POSTGRES_STATEMENT_MAX_RECORDS,
        )

        # Commit changes
        pg.connection().commit()
        print("Successfully inserted H3 cells into table", table)

        return None

    # Actual execution
    _execute_values(pg=pg, table=table, h3_cells=h3_cells)
    # _multi_values_execute(pg=pg, table=table, h3_cells=h3_cells)

def get_alter_table_pk_sql(table: str) -> str:
    """Return the statement that adds a primary key on ``idx`` to ``table``."""
    return "ALTER TABLE {} ADD PRIMARY KEY (idx);".format(table)

def get_create_gist_index_sql(table: str) -> str:
    """Return the statement that creates a GIST index on ``table.geometry``.

    The index is named ``gidx_<table>``.
    """
    template = """
    CREATE INDEX gidx_{name}
    ON {name} USING GIST (geometry);
    """
    return template.format(name=table)

## H3 Vietnam Pipeline

### Configurations

In [None]:
# Administrative area and H3 resolution to generate; together they determine
# the name of the target table.
administrative = "vietnam"
resolution = 6

table = get_h3_table_name(adm=administrative, res=resolution)

# Column name -> PostgreSQL column definition of the H3 table.
# The key order matters: save_h3_to_postgresql builds its INSERT column list
# from these keys and supplies the values positionally in this same order.
h3_table_schema = {
    "idx": "INT8 NOT NULL",
    "resolution": "INT2 NOT NULL",
    "circumradius_m": "FLOAT8 NOT NULL",
    "area_m2": "FLOAT8 NOT NULL",
    "centroid": "GEOMETRY (POINT, 4326) NOT NULL",
    "geometry": "GEOMETRY (POLYGON, 4326) NOT NULL",
}

### vietnam_border (GeoJSON)

In [None]:
# Get Vietnam border
vietnam_border_df = get_administrative_border_df(pg=pg_client, table='vietnam_border', as_geojson=True)
display(vietnam_border_df)

# Extract data from PostgreSQL.
# A failed query yields None and a table without rows yields an empty frame.
# Both count as "no data", so reading the first row below cannot raise IndexError.
if vietnam_border_df is not None and not vietnam_border_df.empty:
    vn_centroid_lng = vietnam_border_df["centroid_lng"].values[0]
    vn_centroid_lat = vietnam_border_df["centroid_lat"].values[0]
    vn_multipolygon_geojson = json.loads(vietnam_border_df["geojson"].values[0])
    print(vn_centroid_lng, vn_centroid_lat)
    print(type(vn_multipolygon_geojson))
else:
    vn_centroid_lng = None
    vn_centroid_lat = None
    vn_multipolygon_geojson = None
    print("No data")

# Convert Vietnam border MultiPolygon to Polygons
if vn_multipolygon_geojson is not None:
    vn_polygon_geojsons = multipolygon_to_polygons(vn_multipolygon_geojson)
    print(vn_polygon_geojsons[0])
else:
    vn_polygon_geojsons = None
    print("No data")

### Generate H3 & save to PostgreSQL

In [None]:
# Generate the H3 cells covering the Vietnam border polygons at the configured
# resolution. polygon_geojsons_to_h3 raises when vn_polygon_geojsons is None,
# i.e. when the border cell above found no data.
vietnam_border_cells = polygon_geojsons_to_h3(polygon_geojsons=vn_polygon_geojsons, res=resolution)
print(len(vietnam_border_cells))

# For testing, a small subset can be used instead (the full set can grow very large):
# vietnam_border_cells_subset = set(itertools.islice(vietnam_border_cells, 10))

In [None]:
# PostgreSQL operations: create the table, bulk-insert the cells, then add the
# primary key and the GIST index. The key and index are added only after the
# insert step has finished.
h3_cells = vietnam_border_cells

try:
    create_table_sql = get_create_table_sql(table=table, schema=h3_table_schema)
    pg_client.execute(create_table_sql)
    print("Successfully created table", table)

    # batch INSERT H3 cells
    save_h3_to_postgresql(pg=pg_client, h3_cells=h3_cells, table=table, schema=h3_table_schema)
    
    # add PRIMARY KEY to table
    alter_table_pk_sql = get_alter_table_pk_sql(table)
    pg_client.execute(alter_table_pk_sql)
    print("Successfully added PRIMARY KEY to table", table)
    
    # create GIST index for table
    create_gist_index_sql = get_create_gist_index_sql(table)
    pg_client.execute(create_gist_index_sql)
    print("Successfully created GIST index for table", table)
# Any failure is only printed and the remaining steps are skipped.
# NOTE(review): the psycopg2 classes listed here are Exception subclasses, so
# `Exception` alone already covers them.
except (Exception, psycopg2.DatabaseError, psycopg2.ProgrammingError) as error:
    print(error)
finally:
    # The client is closed whether or not the steps succeeded.
    # NOTE(review): presumably pg_client cannot be reused after close(), so the
    # Setup cell has to be re-run before re-running any database cell - confirm.
    pg_client.close()