In [1]:
from __future__ import annotations

from pathlib import Path
import tarfile
import xml.etree.ElementTree as ET
from typing import Iterator

# Iterator Funktion
def iter_xml_roots(archives_dir: str | Path, pattern: str = "*.tar.gz"
                   ) -> Iterator[tuple[Path, str, ET.Element]]:
    """
    Iteriert über alle XML-Dateien in allen .tar.gz-Archiven in archives_dir.

    Yields:
        (archive_path, xml_member_name, xml_root_element)
    """
    archives_dir = Path(archives_dir)

    for archive_path in sorted(archives_dir.glob(pattern)):
        try:
            with tarfile.open(archive_path, mode="r:gz") as tar:
                # Iteriert streamend über Members (speichersparender als getmembers())
                for member in tar:
                    if not member.isfile():
                        continue
                    if not member.name.lower().endswith(".xml"):
                        continue

                    extracted = tar.extractfile(member)
                    if extracted is None:
                        continue

                    try:
                        with extracted as f:
                            tree = ET.parse(f)
                            yield archive_path, member.name, tree.getroot()
                    except ET.ParseError as e:
                        print(f"[WARN] XML ParseError in {archive_path}::{member.name}: {e}")

        except (tarfile.ReadError, OSError) as e:
            print(f"[WARN] Konnte Archiv nicht lesen: {archive_path} ({e})")

In [2]:
def extract_s_id_and_pts(root: ET.Element):
    """
    Returns a list of dicts:
      { "id": <s/@id>, "ar_pt": <ar/@pt or None>, "dp_pt": <dp/@pt or None> }
    """
    station_name = root.get("station", None)
    if not station_name:
        return
    rows = []
    for s in root.findall("s"):
        s_id = s.get("id")

        ar = s.find("ar")
        dp = s.find("dp")

        ar_pt = ar.get("pt") if ar is not None else None
        dp_pt = dp.get("pt") if dp is not None else None

        rows.append({"id": s_id, "ar_pt": ar_pt, "dp_pt": dp_pt})
    return station_name, rows

In [3]:
import os

# Directory that holds the *.tar.gz timetable archives. It can be overridden via the
# TIMETABLES_PATH environment variable without editing the notebook.
TIMETABLES_PATH = os.environ.get("TIMETABLES_PATH", "../timetables")

In [4]:
import os
from getpass import getpass

import psycopg2

# Connection settings come from the standard libpq environment variables
# (PGHOST, PGDATABASE, PGUSER, PGPASSWORD) instead of being hard-coded in the
# notebook. If PGPASSWORD is not set, the password is prompted for interactively.
conn = psycopg2.connect(
    host=os.environ.get("PGHOST", "localhost"),
    dbname=os.environ.get("PGDATABASE", "postgres"),
    user=os.environ.get("PGUSER", "postgres"),
    password=os.environ.get("PGPASSWORD") or getpass("Postgres password: "),
)

In [5]:
# Sanity check before the import: find the first XML file whose root has no `station`
# attribute (extract_s_id_and_pts() returns None for those and the import skips them).
# The loop variables keep pointing at that file, so it can be inspected afterwards.
# Note: when no such file exists, this scans every archive.
missing_station_file = None
for path, data_name, xml_root in iter_xml_roots(TIMETABLES_PATH):
    result = extract_s_id_and_pts(xml_root)
    if result is None:
        missing_station_file = (path, data_name)
        break

if missing_station_file is None:
    print("[INFO] Every XML file has a station attribute.")
else:
    print(f"[INFO] First XML without station attribute: {missing_station_file[0]}::{missing_station_file[1]}")

In [6]:
def ensure_stops_table(conn):
    """
    Create the ``stops`` table and its indexes if they do not exist yet, then commit.

    ``stops`` holds one row per (xml_member_name, stop_id). Each row stores the matched
    station ``eva`` (a foreign key to ``stationen.eva``) and the optional arrival and
    departure timestamps ``ar_ts``/``dp_ts``. When both timestamps are set, dp_ts >= ar_ts
    is enforced.

    If the DDL fails (e.g. because ``stationen`` does not exist), the transaction is
    rolled back before the exception is re-raised. Otherwise ``conn`` would be left in an
    aborted-transaction state that rejects every further statement.
    """
    create_sql = """
    CREATE TABLE IF NOT EXISTS stops (
        xml_member_name TEXT NOT NULL,
        stop_id         TEXT NOT NULL,

        eva             BIGINT NOT NULL REFERENCES stationen(eva),
        ar_ts           TIMESTAMPTZ,
        dp_ts           TIMESTAMPTZ,

        PRIMARY KEY (xml_member_name, stop_id),
        CHECK (ar_ts IS NULL OR dp_ts IS NULL OR dp_ts >= ar_ts)
    );

    CREATE INDEX IF NOT EXISTS stops_eva_ar_idx
      ON stops (eva, ar_ts);

    CREATE INDEX IF NOT EXISTS stops_stop_id_idx
      ON stops (stop_id);
    """
    try:
        with conn.cursor() as cur:
            cur.execute(create_sql)
        conn.commit()
    except Exception:
        conn.rollback()
        raise


In [7]:
import re
import difflib
from dataclasses import dataclass
from datetime import datetime
from zoneinfo import ZoneInfo

try:
    # Faster for large batches: many rows are sent per INSERT statement.
    from psycopg2.extras import execute_values
except ImportError:
    # flush_batch() falls back to cursor.executemany() when this is None.
    execute_values = None


# ----------------------------
# Einstellungen
# ----------------------------
TIMEZONE = "Europe/Berlin"  # time zone attached to the timestamps parsed from `pt` attributes

MATCH_THRESHOLD = 0.85      # minimum similarity (0..1) a station name needs in order to be matched
AMBIGUITY_DELTA = 0.02      # if best - second_best < delta, the match is treated as "ambiguous"
BATCH_SIZE = 5000           # number of rows per upsert batch


# ----------------------------
# Parsing von pt: YYMMDDHHMM -> aware datetime
# ----------------------------
_pt_re = re.compile(r"^\d{10}$")  # YYMMDDHHMM

def parse_pt(pt: str | None, tz: ZoneInfo) -> datetime | None:
    if pt is None:
        return None
    pt = pt.strip()
    if not _pt_re.match(pt):
        return None

    try:
        yy = int(pt[0:2])
        year = 2000 + yy
        month = int(pt[2:4])
        day = int(pt[4:6])
        hour = int(pt[6:8])
        minute = int(pt[8:10])
        return datetime(year, month, day, hour, minute, tzinfo=tz)
    except ValueError:
        return None


# ----------------------------
# Normalisierung & Matching
# ----------------------------
# Transliteration of lowercase umlauts and ß (input is lowercased before this is applied).
_umlaut_map = str.maketrans({"ä": "ae", "ö": "oe", "ü": "ue", "ß": "ss"})
# Everything except lowercase ASCII letters, digits and whitespace counts as punctuation.
# NOTE(review): other non-ASCII letters (e.g. "é") therefore become a space rather than
# being transliterated.
_punct_re = re.compile(r"[^a-z0-9\s]")
_ws_re = re.compile(r"\s+")

# Tokens dropped at the end of normalization. Names are lowercased in normalize_name(),
# so every entry must be lowercase.
# NOTE(review): "s+u" / "u+s" can never match, because "+" is replaced by a space before
# this filter runs. The single-letter entries "s" and "u" cover those prefixes instead.
_stopwords = {
    "berlin", "s", "u", "s+u", "u+s",
}

def normalize_name(name: str) -> str:
    """
    Normalize a station name into a canonical, space-separated lowercase token string.

    Steps, in order:
      1. lowercase and transliterate umlauts/ß;
      2. unify abbreviations: Hauptbahnhof -> hbf, Bahnhof/Bhf -> bf,
         Betriebsbf -> betriebsbahnhof;
      3. replace "&" with "und" and every other punctuation character with a space;
      4. expand "str"/"str." to "strasse" and split "...strasse" compounds into a
         separate "strasse" token;
      5. collapse whitespace and drop stopwords.

    Example: "S+U Berlin Hauptbahnhof" -> "hbf".
    """
    name = name.lower().translate(_umlaut_map)

    # --- Unify synonyms / abbreviations ---
    # "Straße" (full form).
    # NOTE(review): no-op in practice: translate() above has already turned "ß" into "ss".
    name = name.replace("straße", "strasse")

    # Betriebsbahnhof
    name = name.replace("betriebsbf", "betriebsbahnhof")

    # Hauptbahnhof -> hbf
    name = name.replace("hauptbahnhof", "hbf")

    # Unify Bahnhof/Bhf/Bf -> bf
    name = re.sub(r"\bbahnhof\b", "bf", name)
    name = re.sub(r"\bbhf\b", "bf", name)
    # NOTE(review): a "\b" after the optional "." only holds when a word character follows,
    # so this only changes "bf.<word char>" (drops the dot, e.g. "bf.zoo" -> "bfzoo").
    # Otherwise it rewrites "bf" to itself.
    name = re.sub(r"\bbf\.?\b", "bf", name)

    # --- General cleanup ---
    name = name.replace("&", " und ")
    name = _punct_re.sub(" ", name)   # replaces ., -, +, etc. with a space

    # Now expand the abbreviation "str." / "str" (also when glued to the preceding word)
    # Example: "feuerbachstr." -> "feuerbachstr" -> "feuerbachstrasse"
    name = re.sub(r"(?<=\w)str\b", "strasse", name)  # glued to the end of a word
    name = re.sub(r"\bstr\b", "strasse", name)       # standalone token

    # Split compounds: "...strasse" -> "... strasse"
    name = re.sub(r"(?<!\s)strasse\b", " strasse", name)

    name = _ws_re.sub(" ", name).strip()
    tokens = [t for t in name.split() if t and t not in _stopwords]
    return " ".join(tokens)



def token_set(s: str) -> set[str]:
    """Return the distinct whitespace-separated tokens of ``s`` that are at least two characters long."""
    return set(filter(lambda tok: len(tok) > 1, s.split()))

def similarity(a: str, b: str) -> float:
    """
    Score how similar two normalized names are, from 0.0 (unrelated) to 1.0 (identical).

    The result is the maximum of three heuristics:
      * containment: if one string is a (character-level) substring of the other,
        0.90 + 0.10 * len(shorter) / len(longer). This suits cases like
        "berlin alexanderplatz" vs "alexanderplatz";
      * the difflib.SequenceMatcher ratio of the two strings;
      * token overlap: 2 * |A & B| / (|A| + |B|) over token_set() of each string.
    Empty inputs score 0.0; equal inputs score 1.0.
    """
    if not (a and b):
        return 0.0
    if a == b:
        return 1.0

    shorter, longer = sorted((a, b), key=len)
    candidates = [difflib.SequenceMatcher(None, a, b).ratio()]

    if shorter in longer:
        candidates.append(0.90 + 0.10 * (len(shorter) / max(1, len(longer))))

    tokens_a, tokens_b = token_set(a), token_set(b)
    if tokens_a and tokens_b:
        candidates.append(2 * len(tokens_a & tokens_b) / (len(tokens_a) + len(tokens_b)))

    return max(candidates)


@dataclass(frozen=True)
class StationRec:
    """One row of the `stationen` table, pre-processed for fuzzy name matching."""
    eva: int                  # stationen.eva (stored into stops.eva when a name matches)
    raw_name: str             # name exactly as read from stationen.name
    norm_name: str            # normalize_name(raw_name)
    tokens: tuple[str, ...]   # sorted token_set(norm_name); keys of the token index


def load_stations(conn) -> list[StationRec]:
    """
    Read every row of ``stationen`` and precompute its normalized name and tokens.

    Rows whose name is NULL or empty are skipped.

    Returns:
        One StationRec per remaining row, in the order the query returned them.
    """
    with conn.cursor() as cur:
        cur.execute("SELECT eva, name FROM stationen;")
        station_rows = cur.fetchall()

    def to_record(eva, raw_name):
        norm_name = normalize_name(raw_name)
        return StationRec(int(eva), raw_name, norm_name, tuple(sorted(token_set(norm_name))))

    return [to_record(eva, raw_name) for eva, raw_name in station_rows if raw_name]


def build_token_index(stations: list[StationRec]) -> dict[str, list[int]]:
    """
    Build an inverted index mapping each token to the positions of the stations that contain it.

    Positions are indices into ``stations`` and appear in ascending order within each
    list. A station contributes one entry per element of its ``tokens``.
    """
    index: dict[str, list[int]] = {}
    for pos, station in enumerate(stations):
        for tok in station.tokens:
            if tok in index:
                index[tok].append(pos)
            else:
                index[tok] = [pos]
    return index


def best_station_match(
    query_name: str,
    stations: list[StationRec],
    token_index: dict[str, list[int]],
    threshold: float,
    ambiguity_delta: float,
) -> tuple[int | None, float, str | None, bool]:
    """
    Find the station whose normalized name is most similar to ``query_name``.

    Candidates are limited to stations that share at least one token with the
    normalized query (via ``token_index``). If there are none, every station is
    scored. Among equal best scores, the shorter normalized name wins. On a full tie,
    the station with the lower index in ``stations`` wins.

    The match is ambiguous when the runner-up (the best score of any *other*
    candidate) is less than ``ambiguity_delta`` below the winner's score. An exact
    tie between two different stations is therefore always ambiguous.

    Returns:
        (eva_or_None, best_score, matched_raw_name, is_ambiguous).
        eva is None and is_ambiguous is False when the query normalizes to an empty
        string, when no candidate scores above 0, or when the best score is below
        ``threshold``. In the last case matched_raw_name is still reported.
    """
    q_norm = normalize_name(query_name)
    if not q_norm:
        return None, 0.0, None, False

    # Narrow down candidates via the token index; fall back to a full scan without hits.
    cand_indices: set[int] = set()
    for t in token_set(q_norm):
        cand_indices.update(token_index.get(t, []))
    if not cand_indices:
        cand_indices = set(range(len(stations)))

    # Score every candidate once. Sorting makes the result independent of set iteration order.
    scored: list[tuple[float, StationRec]] = [
        (similarity(q_norm, stations[i].norm_name), stations[i]) for i in sorted(cand_indices)
    ]

    best_pos: int | None = None
    best_score = 0.0
    for pos, (sc, st) in enumerate(scored):
        if sc > best_score:
            best_pos, best_score = pos, sc
        elif (
            sc == best_score
            and best_pos is not None
            and len(st.norm_name) < len(scored[best_pos][1].norm_name)
        ):
            # Tie-break: prefer the shorter normalized name (e.g. "hbf" over "hbf tief").
            best_pos = pos

    if best_pos is None:
        return None, best_score, None, False

    best_st = scored[best_pos][1]
    if best_score < threshold:
        return None, best_score, best_st.raw_name, False

    # Runner-up: best score among all other candidates, including any that tied with the winner.
    second_best_score = max(
        (sc for pos, (sc, _) in enumerate(scored) if pos != best_pos),
        default=0.0,
    )
    is_ambiguous = (best_score - second_best_score) < ambiguity_delta
    return best_st.eva, best_score, best_st.raw_name, is_ambiguous


# ----------------------------
# DB Insert/Upsert Stops
# ----------------------------
# Multi-row upsert for psycopg2.extras.execute_values(): the single "%s" after VALUES is
# expanded into the list of (xml_member_name, stop_id, eva, ar_ts, dp_ts) row tuples.
# An existing row with the same (xml_member_name, stop_id) has eva/ar_ts/dp_ts overwritten.
UPSERT_SQL = """
INSERT INTO stops (xml_member_name, stop_id, eva, ar_ts, dp_ts)
VALUES %s
ON CONFLICT (xml_member_name, stop_id) DO UPDATE
SET eva   = EXCLUDED.eva,
    ar_ts = EXCLUDED.ar_ts,
    dp_ts = EXCLUDED.dp_ts;
"""

def flush_batch(conn, batch: list[tuple], use_execute_values: bool = True) -> int:
    """
    Upsert ``batch`` into ``stops`` and commit.

    Each row is (xml_member_name, stop_id, eva, ar_ts, dp_ts). The rows are sent with
    ``execute_values`` when it is available and requested. Otherwise one
    ``executemany`` runs a single-row upsert with the same conflict target.

    On any error the transaction is rolled back before the exception is re-raised,
    so ``conn`` remains usable for further statements.

    Returns:
        The number of rows sent. An empty batch returns 0 without touching ``conn``.
    """
    if not batch:
        return 0
    try:
        with conn.cursor() as cur:
            if use_execute_values and execute_values is not None:
                execute_values(cur, UPSERT_SQL, batch, page_size=2000)
            else:
                # Fallback without execute_values (must use the same composite key as UPSERT_SQL).
                upsert_one = """
                INSERT INTO stops (xml_member_name, stop_id, eva, ar_ts, dp_ts)
                VALUES (%s, %s, %s, %s, %s)
                ON CONFLICT (xml_member_name, stop_id) DO UPDATE
                SET eva = EXCLUDED.eva, ar_ts = EXCLUDED.ar_ts, dp_ts = EXCLUDED.dp_ts;
                """
                cur.executemany(upsert_one, batch)
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    return len(batch)


# ----------------------------
# Main Import
# ----------------------------
def import_stops_from_archives(conn, archives_dir: str, pattern: str = "*.tar.gz"):
    """
    Import all stops from the timetable archives into the ``stops`` table.

    Creates the table if needed and loads ``stationen`` for matching. Then, for every
    XML file yielded by iter_xml_roots(archives_dir, pattern), the root's station name
    is fuzzy-matched against the stations. Each distinct name is matched only once.
    Files whose station is unmatched or ambiguously matched are skipped and their
    names collected. Every ``<s>`` with an id becomes the row
    (xml_member_name, stop_id, eva, ar_ts, dp_ts), upserted in batches of BATCH_SIZE.

    Returns:
        dict with:
          "total_seen_stops": stop rows produced for matched stations (duplicates included);
          "total_upserted": rows sent to the database after per-batch deduplication;
          "unmatched_station_names" / "ambiguous_station_names": sets of skipped names;
          "match_cache_size": number of distinct station names matched.
    """
    ensure_stops_table(conn)

    tz = ZoneInfo(TIMEZONE)

    stations = load_stations(conn)
    token_index = build_token_index(stations)

    unmatched_station_names: set[str] = set()
    ambiguous_station_names: set[str] = set()

    # Each distinct station name is fuzzy-matched only once: name -> (eva_or_None, ambiguous)
    match_cache: dict[str, tuple[int | None, bool]] = {}

    # Pending rows keyed by the primary key (xml_member_name, stop_id). Deduplicating here
    # is required, not optional: a single INSERT ... ON CONFLICT DO UPDATE must not update
    # the same row twice (PostgreSQL raises a cardinality violation). Later duplicates win.
    batch: dict[tuple[str, str], tuple] = {}

    total_upserted = 0
    total_seen_stops = 0

    for _archive_path, xml_member_name, root in iter_xml_roots(archives_dir, pattern=pattern):
        res = extract_s_id_and_pts(root)
        if not res:
            continue  # the XML root has no station attribute
        station_name, stop_rows = res

        if station_name not in match_cache:
            eva, _score, _matched_name, is_ambiguous = best_station_match(
                station_name, stations, token_index,
                threshold=MATCH_THRESHOLD,
                ambiguity_delta=AMBIGUITY_DELTA,
            )
            match_cache[station_name] = (eva, is_ambiguous)
        eva, is_ambiguous = match_cache[station_name]

        if eva is None:
            unmatched_station_names.add(station_name)
            continue
        if is_ambiguous:
            ambiguous_station_names.add(station_name)
            continue

        for r in stop_rows:
            stop_id = r.get("id")
            if not stop_id:
                continue

            ar_ts = parse_pt(r.get("ar_pt"), tz)
            dp_ts = parse_pt(r.get("dp_pt"), tz)

            batch[(xml_member_name, stop_id)] = (xml_member_name, stop_id, eva, ar_ts, dp_ts)
            total_seen_stops += 1

            if len(batch) >= BATCH_SIZE:
                total_upserted += flush_batch(conn, list(batch.values()), use_execute_values=True)
                batch.clear()

    # Flush the remaining rows.
    total_upserted += flush_batch(conn, list(batch.values()), use_execute_values=True)
    batch.clear()

    return {
        "total_seen_stops": total_seen_stops,
        "total_upserted": total_upserted,
        "unmatched_station_names": unmatched_station_names,
        "ambiguous_station_names": ambiguous_station_names,
        "match_cache_size": len(match_cache),
    }


In [8]:
result = import_stops_from_archives(conn, archives_dir=TIMETABLES_PATH)

# Summary counters of the import run.
for label, value in (
    ("Seen stops:", result["total_seen_stops"]),
    ("Upserted stops:", result["total_upserted"]),
    ("Unmatched stations:", len(result["unmatched_station_names"])),
    ("Ambiguous stations:", len(result["ambiguous_station_names"])),
):
    print(label, value)

# Show up to 20 examples of skipped station names, alphabetically.
print("Beispiele unmatched:", sorted(result["unmatched_station_names"])[:20])
print("Beispiele ambiguous:", sorted(result["ambiguous_station_names"])[:20])

conn.close()

Seen stops: 2104080
Upserted stops: 2104080
Unmatched stations: 0
Ambiguous stations: 0
Beispiele unmatched: []
Beispiele ambiguous: []
