In [0]:
%pip install pyproj rasterio

In [0]:
import numpy as np
import pandas as pd
from PIL import Image
from rasterio.features import rasterize
from rasterio.transform import from_bounds
from datetime import datetime
from typing import List
from shapely import wkt
from shapely.ops import unary_union
from shapely.geometry import GeometryCollection, Polygon, MultiPolygon
from delta.tables import DeltaTable
from pyspark.sql import DataFrame
from pyspark.sql.functions import *
from pyspark.sql.types import (
    StructType,
    StructField,
    IntegerType,
    DoubleType,
    ArrayType,
    BooleanType,
    StringType,
)

from src.data.data_utils import write_delta_table
from src.data.log_utils import check_for_new_predicted_masks, log_predicted_masks
from src.data.transformation_utils import get_srid, transform_to_epsg

In [0]:
# Folder (Unity Catalog volume) that the predicted mask images are read from.
mask_path = "/Volumes/land_auto-gen-kart_dev/external_dev/static_data/DL_bildesegmentering/predicted_snuplasser"

# Working catalog and schema. The catalog name contains hyphens, so it is
# kept back-quoted for use inside SQL statements.
catalog_dev = "`land_auto-gen-kart_dev`"
schema_dev = "dl_bildesegmentering"

# Tables inside the working schema.
log_table = "logs_predicted_snuplasser"
table = "predicted_snuplasser_bronze"
endepunkt_silver_table = "endepunkt_silver"

# Fully qualified name of the building table that masks are compared against.
land_catalog = "land_ngis_dev"
bygning_schema = "silver_fkbbygning"
bygning_table = "bygning"
dataset = ".".join([land_catalog, bygning_schema, bygning_table])

# Point the Spark session at the working catalog/schema, creating the schema
# on first use. The order matters: the catalog must be selected first.
spark.sql(f"USE CATALOG {catalog_dev}")
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {schema_dev}")
spark.sql(f"USE SCHEMA {schema_dev}")

In [0]:
# Create the Delta log table if it does not exist yet. Its columns hold a
# processing timestamp plus the number of inserted, updated and deleted rows,
# matching the tuple that write_to_delta_table passes to log_predicted_masks.
q = f"""
CREATE TABLE IF NOT EXISTS {log_table} (
  processed_time TIMESTAMP,
  num_inserted INT,
  num_updated INT,
  num_deleted INT
) USING DELTA
"""
spark.sql(q)

In [0]:
# Create the Delta table for the per-mask results if it does not exist yet.
# Columns: a hash used as row id, the white-pixel count, the derived
# turning-space flag, the bounding box as an array of doubles, the mask file
# name and the ingest timestamp.
q = f"""
CREATE TABLE IF NOT EXISTS {table} (
    row_hash STRING,
    white_pixels INT,
    turning_space BOOLEAN,
    bbox ARRAY<DOUBLE>,
    source_file STRING,
    ingest_time TIMESTAMP
) USING DELTA
"""
spark.sql(q)

In [0]:
def get_buildings(
    bbox: List[float], polygons: List, out_shape: tuple = (369, 369)
) -> np.ndarray:
    """
    Rasterize building polygons onto a grid that covers ``bbox``.

    Args:
        bbox: Four bounds in the order (minx, miny, maxx, maxy), expressed in
            the same coordinate system as ``polygons``.
        polygons: Geometries to burn into the raster. May be empty.
        out_shape: (rows, cols) of the returned raster. The default of
            (369, 369) is assumed to match the predicted mask images
            (TODO confirm against the mask producer).

    Returns:
        A uint8 array of shape ``out_shape`` holding 1 for pixels burned by a
        polygon and 0 everywhere else.
    """
    minx, miny, maxx, maxy = bbox
    geoms = list(polygons)

    if not geoms:
        # rasterio's rasterize() rejects an empty shape list with a ValueError,
        # so an area without buildings is returned as an all-zero raster.
        return np.zeros(out_shape, dtype="uint8")

    rows, cols = out_shape
    # from_bounds expects (west, south, east, north, width, height).
    transform = from_bounds(minx, miny, maxx, maxy, cols, rows)

    # Burn every polygon with the value 1 into a raster of the mask's size.
    return rasterize(
        [(geom, 1) for geom in geoms],
        out_shape=out_shape,
        transform=transform,
        fill=0,
        dtype="uint8",
    )

In [0]:
def flatten_geometries(geom):
    """
    Collect the polygonal parts of a geometry into a flat list.

    ``None`` gives an empty list, a Polygon or MultiPolygon is returned as a
    single-element list, a GeometryCollection is walked recursively, and any
    other geometry type is dropped. The result is meant to be passed on to
    rasterization.
    """
    if geom is None:
        return []

    if isinstance(geom, (Polygon, MultiPolygon)):
        return [geom]

    if not isinstance(geom, GeometryCollection):
        # Points, lines etc. cannot be rasterized as areas and are skipped.
        return []

    # Recurse into each member so nested collections are flattened as well.
    return [part for member in geom.geoms for part in flatten_geometries(member)]

In [0]:
def write_to_sdf(predicted_masks: list) -> DataFrame:
    """
    Build the Spark DataFrame of per-mask results destined for the Delta table.

    Each mask file name is matched to an end point in the silver table. The
    buildings of the same municipality that intersect the end point's bounding
    box are collected, and the mask image is then read so that pixels burned by
    those buildings no longer count as white (255).

    Args:
        predicted_masks: Mask file names, relative to ``mask_path``.

    Returns:
        DataFrame with the columns ``white_pixels``, ``turning_space``,
        ``bbox``, ``source_file``, ``row_hash`` and ``ingest_time``.
    """
    # One row per mask file name.
    df = spark.createDataFrame([(m,) for m in predicted_masks], ["mask"])
    # Drops the first 11 and the last 4 characters of the file name.
    # NOTE(review): presumably a fixed prefix and a ".png"-style extension;
    # confirm against the naming of the mask files.
    df = df.withColumn("nodeid", expr("substring(mask, 12, length(mask)-15)"))

    silver = (
        spark.read.table(endepunkt_silver_table)
        .withColumn(
            "bbox_geom",
            expr("ST_PolygonFromEnvelope(bbox[0], bbox[1], bbox[2], bbox[3])"),
        )
        .select("nodeid", "geometry", "kommune_id", "bbox", "bbox_geom")
    )

    buildings = (
        spark.read.table(dataset)
        .select("kommunenummer", "geometry")
        .withColumnRenamed("geometry", "building_geometry")
    )
    buildings = transform_to_epsg(
        buildings,
        col="building_geometry",
        source_srid=get_srid(land_catalog, bygning_schema, bygning_table),
        target_srid="EPSG:25833",
    )

    # Pair every mask with the buildings of its municipality and flag the ones
    # that actually touch the bounding box. The left join keeps masks whose
    # municipality has no buildings at all.
    joined_all = (
        df.join(silver, "nodeid")
        .join(buildings, col("kommune_id") == col("kommunenummer"), "left")
        .withColumn("intersects", expr("ST_Intersects(bbox_geom, building_geometry)"))
    )

    # collect_list skips nulls, so only intersecting buildings end up in the list.
    grouped_df = joined_all.groupBy("nodeid", "bbox", "mask").agg(
        collect_list(
            when(col("intersects"), expr("ST_AsText(building_geometry)"))
        ).alias("building_wkts")
    )

    # bbox is declared as doubles to agree with the target table
    # (bbox ARRAY<DOUBLE>); an integer array would truncate coordinates.
    schema = StructType(
        [
            StructField("white_pixels", IntegerType(), False),
            StructField("turning_space", BooleanType(), False),
            StructField("bbox", ArrayType(DoubleType()), False),
            StructField("source_file", StringType(), False),
        ]
    )
    output_columns = schema.fieldNames()

    def process_partition(pdf_iter):
        """Yield one result frame per incoming pandas batch."""
        for pdf in pdf_iter:
            out_rows = []
            for _, row in pdf.iterrows():
                polygons = [wkt.loads(g) for g in row["building_wkts"] if g]
                merged = unary_union(polygons) if polygons else None
                geoms = flatten_geometries(merged)

                # Close the image file as soon as the pixels are copied out.
                with Image.open(f"{mask_path}/{row['mask']}") as img:
                    arr = np.array(img.convert("L"))

                if geoms:
                    inv_arr = get_buildings(row["bbox"], geoms)
                else:
                    # No buildings to subtract; avoids rasterizing an empty list.
                    inv_arr = np.zeros_like(arr)

                # Building pixels hold 1, so subtracting them turns 255 into 254
                # and removes those pixels from the white-pixel count below.
                result_arr = np.clip(
                    arr.astype(int) - inv_arr.astype(int), 0, 255
                ).astype(np.uint8)

                count_255 = int((result_arr == 255).sum())

                out_rows.append(
                    {
                        "white_pixels": count_255,
                        "turning_space": count_255 > 0,
                        "bbox": [float(v) for v in row["bbox"]],
                        "source_file": row["mask"],
                    }
                )

            # Naming the columns keeps an empty batch aligned with the schema.
            yield pd.DataFrame(out_rows, columns=output_columns)

    records_df = grouped_df.mapInPandas(process_partition, schema)
    # row_hash is derived from all result columns and serves as the row id.
    sdf = records_df.withColumn(
        "row_hash", sha2(concat_ws("||", *records_df.columns), 256)
    ).withColumn("ingest_time", current_timestamp())

    return sdf

In [0]:
def write_to_delta_table(predicted_masks: list):
    """
    Write new predicted masks to the Delta table and log the row-change counts.

    The Delta version of the table is read before and after the write. When a
    new version appeared, the insert/update/delete counts are taken from that
    version's operation metrics; otherwise all counts are zero. A log row is
    written in every case, also when there is nothing to process.

    Args:
        predicted_masks: Mask file names to process. An empty list skips the
            write but still produces a log row.
    """
    delta_tbl = None
    version_before = None
    if spark.catalog.tableExists(table):
        delta_tbl = DeltaTable.forName(spark, table)
        version_before = delta_tbl.history(1).select("version").collect()[0][0]

    sdf = None
    if predicted_masks:
        sdf = write_to_sdf(predicted_masks)
        write_delta_table(sdf, table, id_col="row_hash")

    if delta_tbl is not None:
        # Fetch version and metrics from the same history row so they are
        # guaranteed to describe the same commit.
        latest = (
            delta_tbl.history(1).select("version", "operationMetrics").collect()[0]
        )
        if latest["version"] > version_before:
            metrics = latest["operationMetrics"] or {}
            updated = int(metrics.get("numTargetRowsUpdated", 0))
            inserted = int(metrics.get("numTargetRowsInserted", 0))
            deleted = int(metrics.get("numTargetRowsDeleted", 0))
            print(f"Updated: {updated}, Inserted: {inserted}, Deleted: {deleted}")
        else:
            inserted, updated, deleted = 0, 0, 0
            print("No new Delta version found after merge.")
    else:
        # The table did not exist beforehand, so every written row is an insert.
        # Without any masks nothing was written and sdf was never built.
        inserted = sdf.count() if sdf is not None else 0
        updated, deleted = 0, 0
        print(f"Updated: {updated}, Inserted: {inserted}, Deleted: {deleted}")

    log_predicted_masks([(datetime.now(), inserted, updated, deleted)], log_table)

In [0]:
# Entry point of the notebook: look up the masks to process and write them.
# NOTE(review): check_for_new_predicted_masks presumably returns the mask file
# names under mask_path that are not yet present in `table`; confirm in
# src.data.log_utils.
predicted_masks = check_for_new_predicted_masks(mask_path, table)
write_to_delta_table(predicted_masks)