In [None]:
from pathlib import Path, PurePosixPath
from datetime import datetime
from zoneinfo import ZoneInfo
import tarfile
import xml.etree.ElementTree as ET

try:
    from psycopg2.extras import execute_values
except Exception:
    execute_values = None


# ----------------------------
# 1) Schema-Änderung (einmalig)
# ----------------------------
def ensure_actual_columns(conn):
    """Add the real-time columns and their CHECK constraint to ``public.stops``.

    Idempotent:
    - the columns use ``ADD COLUMN IF NOT EXISTS``;
    - the constraint is added inside a ``DO`` block that swallows
      ``duplicate_object`` when it already exists.

    On success the transaction is committed. On any error it is rolled back
    and the exception is re-raised. Without the rollback the connection would
    stay in an aborted transaction that still holds the lock taken by
    ``ALTER TABLE`` and would reject every later command.
    """
    statements = (
        # 1) Columns, created idempotently.
        """
            ALTER TABLE public.stops
                ADD COLUMN IF NOT EXISTS actual_arrival   timestamptz,
                ADD COLUMN IF NOT EXISTS actual_departure timestamptz,
                ADD COLUMN IF NOT EXISTS cancelled_arrival   timestamptz,
                ADD COLUMN IF NOT EXISTS cancelled_departure timestamptz,
                ADD COLUMN IF NOT EXISTS arrival_cs   text,
                ADD COLUMN IF NOT EXISTS departure_cs text;
        """,
        # 2) CHECK constraint. ADD CONSTRAINT has no IF NOT EXISTS form, so a
        #    re-run's duplicate_object error is ignored inside a DO block.
        """
            DO $$
            BEGIN
                ALTER TABLE public.stops
                    ADD CONSTRAINT stops_cs_check
                    CHECK (
                        (arrival_cs   IS NULL OR arrival_cs   IN ('a','p','c')) AND
                        (departure_cs IS NULL OR departure_cs IN ('a','p','c'))
                    );
            EXCEPTION
                WHEN duplicate_object THEN
                    NULL;
            END $$;
        """,
    )
    try:
        with conn.cursor() as cur:
            for sql in statements:
                cur.execute(sql)
        conn.commit()
    except Exception:
        # Release the table lock and leave the session usable before propagating.
        conn.rollback()
        raise


# ----------------------------
# 2) Helpers: ct/clt -> datetime + cs + snapshot_ts aus Member-Pfad
# ----------------------------
BERLIN_TZ = ZoneInfo("Europe/Berlin")

def parse_db_ct(ct: str | None):
    """Format: YYMMDDHHMM -> aware datetime (Europe/Berlin)"""
    if not ct:
        return None
    ct = ct.strip()
    if len(ct) != 10 or not ct.isdigit():
        return None
    dt = datetime.strptime(ct, "%y%m%d%H%M")
    return dt.replace(tzinfo=BERLIN_TZ)

def parse_cs(x: str | None) -> str | None:
    if not x:
        return None
    x = x.strip().lower()
    return x if x in {"a", "p", "c"} else None

def snapshot_ts_from_member_name(xml_member_name: str) -> datetime | None:
    """
    Erwartet Member-Pfade wie: "2510011345/...._change.xml"
    -> snapshot_ts = 2025-10-01 13:45 Europe/Berlin
    """
    p = PurePosixPath(xml_member_name)
    if not p.parts:
        return None
    head = p.parts[0]
    if len(head) == 10 and head.isdigit():
        dt = datetime.strptime(head, "%y%m%d%H%M")
        return dt.replace(tzinfo=BERLIN_TZ)
    return None


# ----------------------------
# 3) Batch-Update via TEMP staging table + TEMP latest_snapshot
# ----------------------------
def init_stage_tables(conn):
    """Create the session-local TEMP tables used by apply_batch (idempotent).

    Tables:
    - ``_stops_change_stage`` holds the rows of the batch currently being applied.
    - ``_stops_latest_snapshot`` holds the newest snapshot seen per stop_id.
      As a TEMP table it lives for the whole connection, i.e. across all batches.

    Commits on success. On failure the transaction is rolled back and the
    exception is re-raised, so the connection is not left aborted.
    """
    statements = (
        """
            CREATE TEMP TABLE IF NOT EXISTS _stops_change_stage (
                stop_id             text PRIMARY KEY,
                snapshot_ts         timestamptz NULL,
                actual_arrival      timestamptz NULL,
                actual_departure    timestamptz NULL,
                cancelled_arrival   timestamptz NULL,
                cancelled_departure timestamptz NULL,
                arrival_cs          text NULL,
                departure_cs        text NULL
            );
        """,
        # Kept for the whole session (per connection), across all batches.
        """
            CREATE TEMP TABLE IF NOT EXISTS _stops_latest_snapshot (
                stop_id     text PRIMARY KEY,
                snapshot_ts timestamptz NOT NULL
            );
        """,
    )
    try:
        with conn.cursor() as cur:
            for sql in statements:
                cur.execute(sql)
        conn.commit()
    except Exception:
        conn.rollback()
        raise


def _stage_and_update_batch(conn, batch_rows):
    """Load *batch_rows* into the stage table and apply them to public.stops.

    Runs inside the caller's transaction and does not commit. Returns the
    UPDATE's ``rowcount``.
    """
    with conn.cursor() as cur:
        cur.execute("TRUNCATE TABLE _stops_change_stage;")

        if execute_values is not None:
            execute_values(
                cur,
                """
                INSERT INTO _stops_change_stage (
                    stop_id, snapshot_ts,
                    actual_arrival, actual_departure,
                    cancelled_arrival, cancelled_departure,
                    arrival_cs, departure_cs
                )
                VALUES %s
                """,
                batch_rows,
                page_size=10_000
            )
        else:
            cur.executemany(
                """
                INSERT INTO _stops_change_stage (
                    stop_id, snapshot_ts,
                    actual_arrival, actual_departure,
                    cancelled_arrival, cancelled_departure,
                    arrival_cs, departure_cs
                )
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """,
                batch_rows
            )

        # 1) Running maximum of the snapshot per stop_id, kept across batches.
        #    A NULL snapshot is stored as '-infinity'.
        cur.execute("""
            INSERT INTO _stops_latest_snapshot (stop_id, snapshot_ts)
            SELECT
                stop_id,
                COALESCE(snapshot_ts, '-infinity'::timestamptz)
            FROM _stops_change_stage
            ON CONFLICT (stop_id) DO UPDATE
            SET snapshot_ts = GREATEST(_stops_latest_snapshot.snapshot_ts, EXCLUDED.snapshot_ts);
        """)

        # 2) Apply a staged row only if it belongs to the newest snapshot seen
        #    so far for its stop_id. NULL staged columns keep the existing value.
        cur.execute("""
            UPDATE public.stops s
            SET
                actual_arrival      = COALESCE(st.actual_arrival,      s.actual_arrival),
                actual_departure    = COALESCE(st.actual_departure,    s.actual_departure),
                cancelled_arrival   = COALESCE(st.cancelled_arrival,   s.cancelled_arrival),
                cancelled_departure = COALESCE(st.cancelled_departure, s.cancelled_departure),
                arrival_cs          = COALESCE(st.arrival_cs,          s.arrival_cs),
                departure_cs        = COALESCE(st.departure_cs,        s.departure_cs)
            FROM _stops_change_stage st
            JOIN _stops_latest_snapshot ls
              ON ls.stop_id = st.stop_id
            WHERE s.stop_id = st.stop_id
              AND COALESCE(st.snapshot_ts, '-infinity'::timestamptz) = ls.snapshot_ts;
        """)
        return cur.rowcount


def apply_batch(conn, batch_rows):
    """Apply one batch of staged stop changes to ``public.stops`` and commit.

    batch_rows:
      [(stop_id, snapshot_ts,
        actual_arrival, actual_departure,
        cancelled_arrival, cancelled_departure,
        arrival_cs, departure_cs), ...]
      stop_id must be unique within a batch (the stage table's primary key).

    Returns the number of rows the UPDATE reported; 0 for an empty batch.
    On any error the transaction is rolled back and the exception re-raised.
    Otherwise the session would be left in an aborted transaction that still
    holds row locks on public.stops.
    """
    if not batch_rows:
        return 0

    try:
        init_stage_tables(conn)
        updated = _stage_and_update_batch(conn, batch_rows)
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    return updated


# ----------------------------
# 4) Main: Changes einlesen & updaten
# ----------------------------
def _stop_change_row(s, snapshot_ts):
    """Build one staging row from an ``<s>`` element, or return None if there is nothing to stage.

    Row layout: (stop_id, snapshot_ts, actual_arrival, actual_departure,
    cancelled_arrival, cancelled_departure, arrival_cs, departure_cs).

    Returns None when the element has no ``id``, or when none of the
    ``ct``/``clt``/``cs`` attributes on ``<ar>``/``<dp>`` yields a value.
    """
    stop_id = s.get("id")
    if not stop_id:
        return None

    ar = s.find("ar")
    dp = s.find("dp")

    def attr(elem, name):
        # Elements without children are falsy, so test against None explicitly.
        return elem.get(name) if elem is not None else None

    row = (
        stop_id, snapshot_ts,
        parse_db_ct(attr(ar, "ct")), parse_db_ct(attr(dp, "ct")),
        parse_db_ct(attr(ar, "clt")), parse_db_ct(attr(dp, "clt")),
        parse_cs(attr(ar, "cs")), parse_cs(attr(dp, "cs")),
    )
    # A cs code alone counts as data; skip only if every payload column is empty.
    if all(value is None for value in row[2:]):
        return None
    return row


def _merge_change_rows(old, new):
    """Combine two staged rows for the same stop_id.

    - If *new* is from an older snapshot than *old*, *old* is kept unchanged.
    - Otherwise the result takes the snapshot of *new* and, column by column,
      the non-None value of *new*, falling back to *old*.

    This mirrors apply_batch's ``COALESCE(staged, existing)`` update, so the
    outcome is the same whether both rows land in one batch or in two.
    """
    old_snap, new_snap = old[1], new[1]
    # A missing snapshot sorts before every real one (it is staged as '-infinity').
    if new_snap is None and old_snap is not None:
        return old
    if new_snap is not None and old_snap is not None and new_snap < old_snap:
        return old
    return (new[0], new_snap) + tuple(
        n if n is not None else o for n, o in zip(new[2:], old[2:])
    )


def process_change_archives(conn, archives_dir: str | Path, pattern: str = "*.tar.gz", batch_size: int = 50_000):
    """Read all ``*change.xml`` members from the archives and apply them to public.stops.

    Args:
        conn: open database connection used by the DDL and batch helpers.
        archives_dir, pattern: forwarded to ``iter_xml_roots`` (defined elsewhere).
        batch_size: number of distinct stop_ids buffered before a flush via
            apply_batch.

    Rows for the same stop_id are merged in memory by _merge_change_rows before
    each flush. Prints a short summary when done.
    """
    ensure_actual_columns(conn)
    init_stage_tables(conn)

    # stop_id -> staged row, flushed whenever batch_size stops are pending.
    batch: dict[str, tuple] = {}

    n_change_files = 0
    n_s_nodes = 0
    n_updates = 0

    for _archive_path, xml_member_name, root in iter_xml_roots(archives_dir, pattern=pattern):
        # "change.xml" also covers every "_change.xml" member.
        if not xml_member_name.endswith("change.xml"):
            continue

        n_change_files += 1
        snapshot_ts = snapshot_ts_from_member_name(xml_member_name)

        for s in root.findall("./s"):
            row = _stop_change_row(s, snapshot_ts)
            if row is None:
                continue

            stop_id = row[0]
            old = batch.get(stop_id)
            batch[stop_id] = row if old is None else _merge_change_rows(old, row)
            n_s_nodes += 1

            if len(batch) >= batch_size:
                n_updates += apply_batch(conn, list(batch.values()))
                batch.clear()

    if batch:
        n_updates += apply_batch(conn, list(batch.values()))
        batch.clear()

    print("Done.")
    print(f"  change files processed: {n_change_files}")
    print(f"  stop updates staged:    {n_s_nodes}")
    print(f"  rows updated (SQL):     {n_updates}")


# ----------------------------
# 5) AUSFÜHREN
# ----------------------------
# Import all change archives from ../timetable_changes using the shared connection.
changes_dir = Path("..") / "timetable_changes"
process_change_archives(
    conn,
    changes_dir,
    pattern="*.tar.gz",
    batch_size=50_000,
)

Done.
  change files processed: 540668
  stop updates staged:    1346664
  rows updated (SQL):     953744


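Wichtig: Am Ende ``conn.close()`` aufrufen. Ich hatte einmal einen Deadlock-Fehler, der vermutlich dadurch entstanden ist, dass mehrere Connections aus verschiedenen Notebooks gleichzeitig offen waren (eine offene Transaktion in einem Notebook kann Sperren auf ``public.stops`` halten, auf die ein anderes Notebook wartet).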

In [4]:
# Close the connection explicitly when done. A deadlock error was once observed,
# presumably because several notebooks held open connections at the same time.
conn.close()