In [0]:
from pyspark.sql import functions as F, types as T, Window
import pandas as pd

enriched = spark.table("genealogy.bronze_gedcom_enriched")

# Keep only lines that belong to an event AND are not inside ignored subtrees (e.g. SOUR).
# `blocked` is coalesced to False: in Spark SQL `NOT NULL` is NULL and a NULL filter
# predicate excludes the row, so without the coalesce a line whose flag is NULL would be
# silently dropped.
# NOTE(review): assumes a NULL `blocked` means "not blocked" — confirm against the
# enrichment step that produces bronze_gedcom_enriched.
in_scope = enriched.where(
    F.col("event_line_no").isNotNull()
    & ~F.coalesce(F.col("blocked"), F.lit(False))
)

# ----------------------------
# 1) Build one row per event
# ----------------------------

# An event root is the line that opens the event (its own line_no equals event_line_no).
# Only records whose xref starts with "@I" or "@F" are kept (presumably individuals and
# families). The result has one row per event occurrence.
events = (
    in_scope
    .filter(F.col("line_no") == F.col("event_line_no"))
    .filter(F.col("record_xref").startswith("@I") | F.col("record_xref").startswith("@F"))
    .select(
        "record_xref",
        "event_line_no",
        F.col("event_tag").alias("event_type"),
        F.col("value").alias("event_value"),
    )
)

# Helper: pick FIRST value by line_no for a given attribute tag per event
def first_attr(df, tag_name, out_col):
    """Return the earliest (lowest ``line_no``) value of ``tag_name`` per event.

    Args:
        df: line-level DataFrame with ``record_xref``, ``event_line_no``,
            ``line_no``, ``tag`` and ``value`` columns.
        tag_name: tag to look for (e.g. ``"DATE"``).
        out_col: name given to the selected ``value`` column.

    Returns:
        DataFrame with columns ``record_xref``, ``event_line_no`` and
        ``out_col``, holding one row per (record_xref, event_line_no) that has
        at least one ``tag_name`` line.

    NOTE(review): no depth/level filter is applied. Any row of ``df`` with a
    matching tag counts, so a nested line with the same tag could be chosen if
    it comes first. Confirm this is acceptable for the input data.
    """
    earliest_first = Window.partitionBy("record_xref", "event_line_no").orderBy("line_no")
    matching_lines = df.filter(F.col("tag") == tag_name)
    ranked = matching_lines.withColumn("rn", F.row_number().over(earliest_first))
    return ranked.filter(F.col("rn") == 1).select(
        "record_xref",
        "event_line_no",
        F.col("value").alias(out_col),
    )

# First DATE / PLAC / TYPE value of each event, each as its own
# (record_xref, event_line_no, <column>) frame.
event_date, event_place, event_subtype = (
    first_attr(in_scope, attr_tag, attr_col)
    for attr_tag, attr_col in (
        ("DATE", "event_date"),
        ("PLAC", "event_place"),
        ("TYPE", "event_subtype"),
    )
)

# ----------------------------
# 2) Assemble NOTE text (NOTE + CONC/CONT). Lines under SOUR need no special
#    handling here because the `blocked` filter has already removed them.
# ----------------------------

# Raw note lines: NOTE, CONC and CONT rows that carry a note_root_line_no
# (presumably set upstream for lines belonging to a NOTE subtree).
note_pieces = (
    in_scope
    .filter(
        F.col("tag").isin("NOTE", "CONC", "CONT")
        & F.col("note_root_line_no").isNotNull()
    )
    .select("record_xref", "event_line_no", "note_root_line_no", "line_no", "tag", "value")
)

# Collect the pieces of each NOTE root in line order and build one string from them:
# NOTE and CONC text is appended directly; CONT starts a new line before its text.
# Each piece is mapped to its text fragment and all fragments are joined once with
# array_join. This replaces an aggregate() fold that re-concatenated the growing
# accumulator for every piece.
note_by_root = (
    note_pieces
    .groupBy("record_xref", "event_line_no", "note_root_line_no")
    # Struct ordering compares line_no first, so sort_array restores line order.
    .agg(F.sort_array(F.collect_list(F.struct("line_no", "tag", "value"))).alias("pieces"))
    .withColumn(
        "note_text",
        F.expr("""
            array_join(
              transform(
                pieces,
                x -> case
                       when x.tag in ('NOTE', 'CONC') then coalesce(x.value, '')
                       when x.tag = 'CONT' then concat('\n', coalesce(x.value, ''))
                       else ''
                     end
              ),
              ''
            )
        """)
    )
    .select("record_xref", "event_line_no", "note_root_line_no", "note_text")
)

# One note string per event. When an event has several NOTE roots, their texts are
# joined with a blank line between them, in ascending note_root_line_no order
# (sort_array compares the struct's first field first).
event_note = (
    note_by_root
    .groupBy("record_xref", "event_line_no")
    .agg(F.sort_array(F.collect_list(F.struct("note_root_line_no", "note_text"))).alias("notes"))
    .select(
        "record_xref",
        "event_line_no",
        F.expr("array_join(transform(notes, x -> x.note_text), '\n\n')").alias("event_note"),
    )
)

# ----------------------------
# 3) Define windows for event identifier
# ----------------------------

# Events of one record, ordered by the line that opens each event.
event_seq_window = Window.partitionBy("record_xref").orderBy("event_line_no")

# Events of one record and one event type, ordered the same way.
type_seq_window = Window.partitionBy("record_xref", "event_type").orderBy("event_line_no")

# ----------------------------
# 4) Final event output
# ----------------------------

# Attach the optional per-event attributes (left joins keep events without them),
# then derive:
#   event_sequence_in_record  - 1-based position of the event within its record
#   event_id                  - SHA-256 hex of "record_xref|event_type|sequence"
#   event_ordinal_within_type - 1-based position among same-type events of the record
# NOTE(review): event_id depends on the event's position. Inserting or removing an
# earlier event in the same record changes the ids of all later events.
final_events = (
    events
    .join(event_date, ["record_xref", "event_line_no"], "left")
    .join(event_place, ["record_xref", "event_line_no"], "left")
    .join(event_note, ["record_xref", "event_line_no"], "left")
    .join(event_subtype, ["record_xref", "event_line_no"], "left")
    .withColumn("event_sequence_in_record", F.row_number().over(event_seq_window))
    .withColumn(
        "event_id",
        F.sha2(
            F.concat_ws(
                "|",
                "record_xref",
                "event_type",
                F.col("event_sequence_in_record").cast("string"),
            ),
            256,
        ),
    )
    .withColumn("event_ordinal_within_type", F.row_number().over(type_seq_window))
)

# ----------------------------
# 5) Persist to the silver layer
# ----------------------------
# (Removed a commented-out display() call: it ordered by "event_id_line_no",
#  a column that does not exist, so uncommenting it would have failed.)

target_table = "genealogy.silver_event"

# Placeholder column, populated later once the place-variants table has been generated.
final_events = final_events.withColumn(
    "event_place_variant_id",
    F.lit(None).cast("string")
)

# Replace the Delta table's data and schema in full, partitioned by event type.
(final_events
  .write
  .format("delta")
  .mode("overwrite")
  .option("overwriteSchema", "true")
  .partitionBy("event_type")
  .saveAsTable(target_table)
)
