In [0]:
from PIL import Image
import numpy as np
import requests
from io import BytesIO
import time
from pyspark.sql.functions import *
from delta.tables import DeltaTable
from datetime import datetime
from pyspark.sql.types import StructType, StructField, TimestampType, IntegerType

In [0]:
# Names and locations shared by every cell below.
mask_path = "/Volumes/land_topografisk-gdb_dev/external_dev/static_data/DL_SNUPLASSER/predicted_masks"
catalog_dev = "`land_topografisk-gdb_dev`"
schema_dev = "ai2025"
log_table = "logs_predicted_masks"
table = "predicted_bronze"
endepunkt_silver_table = "endepunkt_silver"

# Select the catalog, make sure the schema exists, then make it the default so
# the unqualified table names above resolve inside it.
spark.sql(f"USE CATALOG {catalog_dev}")
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {schema_dev}")
spark.sql(f"USE SCHEMA {schema_dev}")

In [0]:
# Create the run-log table if it does not exist yet. Its four columns are the
# same ones log_predicted_mask() appends: a timestamp plus the number of rows
# inserted, updated and deleted.
q = f"""
CREATE TABLE IF NOT EXISTS {log_table} (
  processed_time TIMESTAMP,
  num_inserted INT,
  num_updated INT,
  num_deleted INT

) USING DELTA
"""
spark.sql(q)

In [0]:
# Create the target table for processed masks if it does not exist yet.
# white_pixels, turning_space, bbox and source_file are produced per mask by
# write_to_sdf(); row_hash and ingest_time are the metadata columns it adds.
q = f"""
CREATE TABLE IF NOT EXISTS {table} (
    white_pixels INT,
    turning_space BOOLEAN,
    bbox ARRAY<DOUBLE>,
    source_file STRING,
    row_hash STRING,
    ingest_time TIMESTAMP
) USING DELTA
"""
spark.sql(q)

In [0]:
# Set the session context (catalog and schema). This repeats the statements
# from the configuration cell; every statement is idempotent, so re-running
# them is harmless.
spark.sql(f"USE CATALOG {catalog_dev}")
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {schema_dev}")
spark.sql(f"USE SCHEMA {schema_dev}")

In [0]:
def log_predicted_mask(log_data: list):
    """
    Append run statistics to the Delta log table named by `log_table`.

    Args:
        log_data: Rows shaped as
            (processed_time, num_inserted, num_updated, num_deleted).
    """
    counter_columns = ("num_inserted", "num_updated", "num_deleted")

    # One nullable timestamp column followed by three nullable integer counters.
    fields = [StructField("processed_time", TimestampType(), True)]
    fields.extend(StructField(name, IntegerType(), True) for name in counter_columns)
    schema = StructType(fields)

    log_df = spark.createDataFrame(log_data, schema=schema)
    writer = log_df.write.format("delta").mode("append")
    writer.saveAsTable(log_table)

In [0]:
def check_for_new_predicted_masks() -> list:
    """
    Find predicted-mask PNG files that have not been ingested yet.

    Lists `mask_path`, keeps the entries whose path ends in ".png", and drops
    every file name already present in the `source_file` column of `table`.

    Returns:
        list: File names (not full paths) of the unprocessed masks, in the
        order returned by `dbutils.fs.ls`.
    """
    all_masks = [
        f.path.rstrip("/").split("/")[-1]
        for f in dbutils.fs.ls(mask_path)
        if f.path.endswith(".png")
    ]

    processed_masks_df = spark.read.table(table).select("source_file")
    # Use a set so each membership test below is O(1) instead of a scan over
    # every processed file name.
    processed_masks = {row["source_file"] for row in processed_masks_df.collect()}

    return [mask for mask in all_masks if mask not in processed_masks]

In [0]:
def get_buildings(bbox: list) -> np.ndarray:
    """
    Fetch the `bygning` (building) WMS layer for a bounding box as an inverted mask.

    The layer is requested as a 369x369 PNG, converted to 8-bit grayscale and
    inverted: pixels equal to 255 become 0, every other pixel becomes 255.

    Args:
        bbox: Sequence of bounding-box coordinates; each element is converted
            with `str` and the values are joined into the `BBox` query
            parameter. Presumably four EPSG:25833 values - confirm against
            the `endepunkt_silver` table.

    Returns:
        np.ndarray: uint8 array containing only the values 0 and 255.

    Raises:
        requests.HTTPError: If the WMS server answers with an error status.
    """
    bbox_str = ", ".join(map(str, bbox))  # plain "a, b, c, d" without list brackets
    url = f"https://openwms.statkart.no/skwms1/wms.fkb?VERSION=1.3.0&service=WMS&request=GetMap&Format=image/png&GetFeatureInfo=text/plain&CRS=EPSG:25833&Layers=bygning&BBox={bbox_str}&width=369&height=369"
    response = requests.get(url, timeout=10)
    # Fail here with the HTTP status instead of letting an error body reach
    # PIL, where it would surface as an unrelated image-decoding error.
    response.raise_for_status()
    time.sleep(2)  # pause between consecutive requests to the WMS service
    img_gray = Image.open(BytesIO(response.content)).convert("L")
    arr = np.array(img_gray)
    inv_arr = np.where(arr == 255, 0, 255).astype(np.uint8)
    return inv_arr


In [0]:
def read_bbox_from_table(nodeid: str) -> list:
    """
    Look up the bounding box stored for a node in `endepunkt_silver_table`.

    Args:
        nodeid: Value matched against the table's `nodeid` column.

    Returns:
        The `bbox` value of the first matching row (presumably a list of
        doubles - confirm against the table schema).

    Raises:
        ValueError: If no row has the given `nodeid`.
    """
    row = (
        spark.read.table(endepunkt_silver_table)
        .filter(col("nodeid") == nodeid)
        .select(col("bbox"))
        .first()
    )
    # first() returns None when nothing matches; report that explicitly rather
    # than failing on attribute access of None.
    if row is None:
        raise ValueError(
            f"No row with nodeid {nodeid!r} found in table {endepunkt_silver_table!r}"
        )
    return row.bbox

In [0]:
def write_to_sdf(predicted_masks: list) -> "DataFrame":
    """
    Build a Spark DataFrame with one row per predicted mask.

    For each mask file the function looks up the bounding box, loads the mask
    as grayscale, subtracts the building mask returned by `get_buildings`
    (clipped to 0..255) and counts the pixels that are still exactly 255.

    Args:
        predicted_masks: File names of masks located in `mask_path`. Must be
            non-empty; the DataFrame schema is inferred from the records.

    Returns:
        DataFrame with the columns `white_pixels`, `turning_space`, `bbox`
        and `source_file`, plus `row_hash` (SHA-256 over those columns) and
        `ingest_time`.
    """
    records = []
    for mask in predicted_masks:
        # Drops the first 11 characters and the 4-character extension of the
        # file name. NOTE(review): presumably a fixed prefix followed by the
        # node id and ".png" - confirm against the mask naming convention.
        bbox = read_bbox_from_table(mask[11:-4])

        # The context manager closes the underlying file once the pixel data
        # has been copied into the array.
        with Image.open(f"{mask_path}/{mask}") as img:
            arr = np.array(img.convert("L"))
        inv_arr = get_buildings(bbox)
        # Subtract as int so the difference can go negative before clipping;
        # uint8 arithmetic would wrap around instead.
        result_arr = np.clip(arr.astype(int) - inv_arr.astype(int), 0, 255).astype(np.uint8)
        count_255 = int((result_arr == 255).sum())

        records.append(
            {
                "white_pixels": count_255,
                "turning_space": bool(count_255 > 0),
                "bbox": bbox,
                "source_file": mask,
            }
        )

    sdf = spark.createDataFrame(records)

    # Add metadata: a content hash over all data columns and the ingest time.
    sdf = sdf.withColumn(
        "row_hash", sha2(concat_ws("||", *sdf.columns), 256)
    ).withColumn("ingest_time", current_timestamp())

    return sdf

In [0]:
def write_delta_table(sdf: "DataFrame"):
    """
    Upsert a Spark DataFrame into the Delta table named by `table`.

    If the table does not exist it is created from `sdf`. Otherwise the rows
    are merged on `source_file`:

    * a row whose `source_file` already exists is updated only when its
      `row_hash` differs, i.e. when the content changed;
    * a row with an unknown `source_file` is inserted.

    Args:
        sdf: DataFrame containing at least `source_file` and `row_hash`.
    """
    if not spark.catalog.tableExists(table):
        sdf.write.format("delta").mode("overwrite").saveAsTable(table)
        return

    column_mapping = {name: f"source.{name}" for name in sdf.columns}
    (
        DeltaTable.forName(spark, table)
        .alias("target")
        .merge(
            source=sdf.alias("source"),
            # Match on the file name. Matching on row_hash equality would make
            # the "hash differs" update condition below impossible to satisfy
            # and would insert changed files as duplicates.
            condition="target.source_file = source.source_file",
        )
        .whenMatchedUpdate(
            condition="target.row_hash != source.row_hash",
            set=column_mapping,
        )
        .whenNotMatchedInsert(values=column_mapping)
        .execute()
    )

In [0]:
def write_to_delta_table(predicted_masks: list):
    """
    Write new predicted masks to the Delta table and log the row counts.

    The table version is read before and after the write. When the version
    advanced, the inserted/updated/deleted counts are taken from the operation
    metrics of the latest history entry; otherwise all counts are zero. One
    log row is written on every call, including calls with no new masks.

    Args:
        predicted_masks: File names of the masks to process; may be empty.
    """
    delta_tbl = None
    version_before = None
    if spark.catalog.tableExists(table):
        delta_tbl = DeltaTable.forName(spark, table)
        version_before = delta_tbl.history(1).select("version").collect()[0][0]

    sdf = None
    if predicted_masks:
        sdf = write_to_sdf(predicted_masks)
        write_delta_table(sdf)

    if delta_tbl is not None:
        # Read version and metrics from the same history entry so they cannot
        # describe two different commits.
        latest = delta_tbl.history(1).select("version", "operationMetrics").collect()[0]
        version_after, metrics = latest[0], latest[1]
        if version_after > version_before:
            metrics = metrics or {}
            updated = int(metrics.get("numTargetRowsUpdated", 0))
            inserted = int(metrics.get("numTargetRowsInserted", 0))
            deleted = int(metrics.get("numTargetRowsDeleted", 0))
            print(f"Updated: {updated}, Inserted: {inserted}, Deleted: {deleted}")
        else:
            inserted, updated, deleted = 0, 0, 0
            print("No new Delta version found after merge.")
    else:
        # The table did not exist beforehand, so every written row is an
        # insert. With no masks nothing was written and sdf was never built.
        inserted = sdf.count() if sdf is not None else 0
        updated, deleted = 0, 0
        print(f"Updated: {updated}, Inserted: {inserted}, Deleted: {deleted}")

    log_predicted_mask(log_data=[(datetime.now(), inserted, updated, deleted)])

In [0]:
def main():
    """
    Run the pipeline: collect the masks that have not been processed yet and
    hand them to the Delta-table writer.
    """
    write_to_delta_table(check_for_new_predicted_masks())

In [0]:
# Entry point: process any new predicted masks and log the outcome.
main()