In [2]:
import os
import uuid
import random
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd


# ================================================================
# CONFIG
# ================================================================

MEDIUM_N_TRACES = 10_000
RANDOM_SEED = 123
DATA_DIR_MEDIUM = "./data_medium"

rng = np.random.default_rng(RANDOM_SEED)
random.seed(RANDOM_SEED)


# ================================================================
# TRANSITION GRAPH (EVENT-LEVEL)
# ================================================================

def build_medium_transition_graph() -> pd.DataFrame:
    """
    Defines allowed transitions between events, including the two
    branching points:
      - RiskCheckCompleted -> {AutoApprovalGranted, ManualReviewRequired}
      - DeliveryExceptionOccurred -> {ReroutePlanned, DelayLogged,
                                      DriverReassigned, OrderCancelled}
    """
    transitions = {
        "CustomerCreated": ["AddressAdded"],
        "AddressAdded": ["DeviceRegistered"],
        "DeviceRegistered": ["MerchantCreated"],
        "MerchantCreated": ["LocationActivated"],
        "LocationActivated": ["OrderCreated"],
        "OrderCreated": ["CartSnapshotSaved"],
        "CartSnapshotSaved": ["ExperimentBucketAssigned"],
        "ExperimentBucketAssigned": ["PromoValidated"],
        "PromoValidated": ["TaxCalculated"],
        "TaxCalculated": ["OrderConfirmed"],
        "OrderConfirmed": ["RiskCheckStarted"],
        "RiskCheckStarted": ["RiskScoreCalculated"],
        "RiskScoreCalculated": ["FraudCheckPerformed"],
        "FraudCheckPerformed": ["RiskCheckCompleted"],
        # Branch 1 (2-way)
        "RiskCheckCompleted": ["AutoApprovalGranted", "ManualReviewRequired"],
        "AutoApprovalGranted": ["KitchenTicketCreated"],
        "ManualReviewRequired": ["ManualReviewCompleted"],
        "ManualReviewCompleted": ["KitchenTicketCreated"],
        "KitchenTicketCreated": ["KitchenCookingStarted"],
        "KitchenCookingStarted": ["KitchenCookingFinished"],
        "KitchenCookingFinished": ["DispatchRequested"],
        "DispatchRequested": ["DriverShiftOnline"],
        "DriverShiftOnline": ["DriverAssigned"],
        "DriverAssigned": ["RoutePlanned"],
        "RoutePlanned": ["SegmentStarted"],
        "SegmentStarted": ["SegmentCompleted"],
        "SegmentCompleted": ["Delivered"],
        "Delivered": ["PaymentAuthorized", "DeliveryExceptionOccurred"],
        # Branch 2 (4-way)
        "DeliveryExceptionOccurred": [
            "ReroutePlanned",
            "DelayLogged",
            "DriverReassigned",
            "OrderCancelled",
        ],
        "ReroutePlanned": ["SegmentStarted"],
        "DelayLogged": ["PaymentAuthorized"],
        "DriverReassigned": ["RoutePlanned"],
        "OrderCancelled": [],
        "PaymentAuthorized": ["PaymentCaptured"],
        "PaymentCaptured": ["PaymentSettled"],
        "PaymentSettled": ["SupportTicketOpened", "ReviewSubmitted"],
        "SupportTicketOpened": ["SupportMessageAdded"],
        "SupportMessageAdded": ["SupportIssueResolved"],
        "SupportIssueResolved": ["ReviewSubmitted"],
        "ReviewSubmitted": ["RewardCredited", "AuditLogWritten"],
        "RewardCredited": ["AuditLogWritten"],
        "AuditLogWritten": [],
    }

    rows = []
    for src, dst_list in transitions.items():
        rows.append({"From": src, "To_List": str(dst_list)})

    return pd.DataFrame(rows, columns=["From", "To_List"])


# ================================================================
# HELPER: TIMESTAMP ADVANCER
# ================================================================

def make_ts_advancer(base_ts: pd.Timestamp):
    """
    Returns a closure next_ts(delta_minutes) that always moves forward.
    """
    current_ts = base_ts

    def next_ts(delta_minutes: int) -> pd.Timestamp:
        nonlocal current_ts
        current_ts = current_ts + pd.Timedelta(minutes=delta_minutes)
        return current_ts

    return next_ts


# ================================================================
# 1. CUSTOMER-RELATED TABLES
# ================================================================

def generate_customers_medium(n_customers: int = 3000) -> Tuple[
    pd.DataFrame,
    pd.DataFrame,
    pd.DataFrame,
    pd.DataFrame,
    pd.DataFrame,
    pd.DataFrame,
]:
    customer_ids = np.arange(1, n_customers + 1)

    base_dates = pd.Timestamp("2024-01-01") + pd.to_timedelta(
        rng.integers(0, 90, size=n_customers), unit="D"
    )
    created_ts = base_dates
    updated_ts = created_ts + pd.to_timedelta(
        rng.integers(1, 10, size=n_customers), unit="D"
    )

    customers = pd.DataFrame({
        "customer_id": customer_ids,
        "name": [f"Customer_{i}" for i in customer_ids],
        "region": rng.choice(["NORTH", "SOUTH", "EAST", "WEST"], size=n_customers),
        "credit_rating": rng.integers(1, 6, size=n_customers),
        "outstanding_amount": rng.uniform(0, 500, size=n_customers).round(2),
        "created_ts": created_ts,
        "updated_ts": updated_ts,
    })

    # Addresses
    addr_rows = []
    addr_id = 1
    for idx, row in customers.iterrows():
        cid = row["customer_id"]
        reg = row["region"]
        base = row["created_ts"]
        n_addr = rng.integers(1, 4)
        for _ in range(n_addr):
            addr_rows.append({
                "address_id": addr_id,
                "cust_ref": cid,
                "address_line": f"{addr_id} Main Street",
                "zone": reg,
                "created_ts": base + pd.Timedelta(days=int(rng.integers(0, 30))),
                "updated_ts": base + pd.Timedelta(days=int(rng.integers(5, 40))),
            })
            addr_id += 1
    customer_addresses = pd.DataFrame(addr_rows)

    # Devices
    dev_rows = []
    dev_id = 1
    for idx, row in customers.iterrows():
        cid = row["customer_id"]
        base = row["created_ts"]
        n_dev = rng.integers(1, 3)
        for _ in range(n_dev):
            dev_rows.append({
                "device_id": dev_id,
                "cust_ref": cid,
                "device_type": rng.choice(["IOS", "ANDROID", "WEB"]),
                "created_ts": base + pd.Timedelta(days=int(rng.integers(0, 20))),
                "last_seen_ts": base + pd.Timedelta(days=int(rng.integers(10, 60))),
            })
            dev_id += 1
    customer_devices = pd.DataFrame(dev_rows)

    # Payment methods
    pay_rows = []
    pay_id = 1
    for idx, row in customers.iterrows():
        cid = row["customer_id"]
        base = row["created_ts"]
        n_methods = rng.integers(1, 4)
        for _ in range(n_methods):
            mtype = rng.choice(["CARD", "WALLET", "BANK"])
            wallet_balance = (
                rng.uniform(0, 300) if mtype == "WALLET" else rng.uniform(20, 300)
            )
            pay_rows.append({
                "pay_method_id": pay_id,
                "cust_ref": cid,
                "method_type": mtype,
                "wallet_balance": round(float(wallet_balance), 2),
                "created_ts": base + pd.Timedelta(days=int(rng.integers(0, 10))),
                "updated_ts": base + pd.Timedelta(days=int(rng.integers(5, 30))),
            })
            pay_id += 1
    customer_payment_methods = pd.DataFrame(pay_rows)

    # Verification (KYC)
    ver_rows = []
    for idx, row in customers.iterrows():
        cid = row["customer_id"]
        base = row["created_ts"]
        submitted_ts = base + pd.Timedelta(days=int(rng.integers(0, 10)))
        verified_ts = submitted_ts + pd.Timedelta(days=int(rng.integers(1, 5)))
        ver_rows.append({
            "kyc_id": cid,  # one row per customer
            "cust_ref": cid,
            "status": "VERIFIED",
            "submitted_ts": submitted_ts,
            "verified_ts": verified_ts,
            "updated_ts": verified_ts + pd.Timedelta(days=1),
        })
    customer_verification = pd.DataFrame(ver_rows)

    # Loyalty
    loy_rows = []
    for idx, row in customers.iterrows():
        cid = row["customer_id"]
        base = row["created_ts"]
        loy_rows.append({
            "loyalty_id": cid,
            "cust_ref": cid,
            "tier": rng.choice(["BRONZE", "SILVER", "GOLD", "PLATINUM"]),
            "effective_ts": base + pd.Timedelta(days=int(rng.integers(5, 40))),
            "updated_ts": base + pd.Timedelta(days=int(rng.integers(30, 80))),
        })
    customer_loyalty_status = pd.DataFrame(loy_rows)

    return (
        customers,
        customer_addresses,
        customer_devices,
        customer_payment_methods,
        customer_verification,
        customer_loyalty_status,
    )


# ================================================================
# 2. MERCHANT / MENU TABLES
# ================================================================

def generate_merchants_and_menu_medium(
    n_merchants: int = 300,
) -> Tuple[
    pd.DataFrame,
    pd.DataFrame,
    pd.DataFrame,
    pd.DataFrame,
    pd.DataFrame,
    pd.DataFrame,
    pd.DataFrame,
    pd.DataFrame,
    pd.DataFrame,
    pd.DataFrame,
]:
    merch_ids = np.arange(1, n_merchants + 1)
    base_dates = pd.Timestamp("2024-01-01") + pd.to_timedelta(
        rng.integers(0, 90, size=n_merchants), unit="D"
    )

    merchants = pd.DataFrame({
        "merchant_id": merch_ids,
        "name": [f"Merchant_{i}" for i in merch_ids],
        "category": rng.choice(["RESTAURANT", "GROCERY", "LIQUOR"], size=n_merchants),
        "created_ts": base_dates,
        "updated_ts": base_dates + pd.to_timedelta(
            rng.integers(5, 40, size=n_merchants), unit="D"
        ),
    })

    # Locations
    loc_rows = []
    loc_id = 1
    for idx, row in merchants.iterrows():
        mid = row["merchant_id"]
        base = row["created_ts"]
        n_locs = rng.integers(1, 4)
        for _ in range(n_locs):
            loc_rows.append({
                "location_id": loc_id,
                "merchant_ref": mid,
                "address_line": f"MerchantLoc_{loc_id}",
                "region": rng.choice(["NORTH", "SOUTH", "EAST", "WEST"]),
                "created_ts": base + pd.Timedelta(days=int(rng.integers(0, 20))),
                "updated_ts": base + pd.Timedelta(days=int(rng.integers(10, 60))),
            })
            loc_id += 1
    merchant_locations = pd.DataFrame(loc_rows)

    # Business docs
    doc_rows = []
    doc_id = 1
    for idx, row in merchants.iterrows():
        mid = row["merchant_id"]
        base = row["created_ts"]
        submitted = base + pd.Timedelta(days=int(rng.integers(0, 15)))
        verified = submitted + pd.Timedelta(days=int(rng.integers(1, 10)))
        doc_rows.append({
            "doc_id": doc_id,
            "merchant_ref": mid,
            "submitted_ts": submitted,
            "verified_ts": verified,
            "updated_ts": verified + pd.Timedelta(days=1),
        })
        doc_id += 1
    merchant_business_docs = pd.DataFrame(doc_rows)

    # Bank accounts
    bank_rows = []
    bank_id = 1
    for idx, row in merchants.iterrows():
        mid = row["merchant_id"]
        base = row["created_ts"]
        bank_rows.append({
            "bank_id": bank_id,
            "merchant_ref": mid,
            "account_number": f"ACCT_{bank_id:05d}",
            "linked_ts": base + pd.Timedelta(days=int(rng.integers(1, 20))),
            "updated_ts": base + pd.Timedelta(days=int(rng.integers(20, 60))),
        })
        bank_id += 1
    merchant_bank_accounts = pd.DataFrame(bank_rows)

    # Operational status
    op_rows = []
    op_id = 1
    for idx, row in merchants.iterrows():
        mid = row["merchant_id"]
        base = row["created_ts"]
        op_rows.append({
            "status_id": op_id,
            "merchant_ref": mid,
            "status": rng.choice(["ACTIVE", "INACTIVE"], p=[0.9, 0.1]),
            "effective_ts": base + pd.Timedelta(days=int(rng.integers(1, 10))),
            "batch_run_ts": base + pd.Timedelta(days=int(rng.integers(20, 80))),
        })
        op_id += 1
    merchant_operational_status = pd.DataFrame(op_rows)

    # Menu items + options + inventory + updates + inventory reservations
    menu_rows = []
    opt_rows = []
    inv_rows = []
    adj_rows = []
    upd_rows = []
    inv_reserve_rows = []

    item_id = 1
    opt_id = 1
    inv_id = 1
    adj_id = 1
    upd_id = 1
    inv_res_id = 1

    categories = ["FOOD", "DRINK", "DESSERT", "GROCERY"]

    for _, loc_row in merchant_locations.iterrows():
        base = loc_row["created_ts"]
        loc = loc_row["location_id"]
        n_items = rng.integers(5, 15)

        for _ in range(n_items):
            cat = rng.choice(categories)
            price = float(rng.uniform(5, 60))

            menu_rows.append({
                "item_id": item_id,
                "location_ref": loc,
                "name": f"Item_{item_id}",
                "category": cat,
                "base_price": round(price, 2),
                "created_ts": base + pd.Timedelta(days=int(rng.integers(0, 10))),
                "updated_ts": base + pd.Timedelta(days=int(rng.integers(10, 40))),
            })

            # Options
            n_opts = rng.integers(0, 4)
            for _ in range(n_opts):
                opt_rows.append({
                    "option_id": opt_id,
                    "menu_ref": item_id,
                    "option_name": rng.choice(["LARGE", "SPICY", "ADD_ON", "SIZE_UP"]),
                    "extra_price": round(float(rng.uniform(0.5, 8)), 2),
                    "created_ts": base + pd.Timedelta(days=int(rng.integers(0, 10))),
                    "updated_ts": base + pd.Timedelta(days=int(rng.integers(10, 40))),
                })
                opt_id += 1

            # Inventory snapshot
            inv_rows.append({
                "inventory_id": inv_id,
                "menu_ref": item_id,
                "stock_level": int(rng.integers(0, 500)),
                "snapshot_ts": base + pd.Timedelta(days=int(rng.integers(0, 5))),
                "updated_ts": base + pd.Timedelta(days=int(rng.integers(5, 30))),
            })
            inv_id += 1

            # Inventory adjustment
            adj_rows.append({
                "adjustment_id": adj_id,
                "menu_ref": item_id,
                "delta": int(rng.integers(-10, 10)),
                "created_ts": base + pd.Timedelta(days=int(rng.integers(5, 25))),
                "batch_run_ts": base + pd.Timedelta(days=int(rng.integers(25, 60))),
            })
            adj_id += 1

            # Menu updates
            upd_rows.append({
                "update_id": upd_id,
                "menu_ref": item_id,
                "update_type": rng.choice(["PRICE_CHANGE", "DESC_CHANGE", "IMAGE_CHANGE"]),
                "created_ts": base + pd.Timedelta(days=int(rng.integers(2, 30))),
                "applied_ts": base + pd.Timedelta(days=int(rng.integers(5, 40))),
            })
            upd_id += 1

            # Inventory reservations (pseudo-dim)
            inv_reserve_rows.append({
                "reservation_id": inv_res_id,
                "menu_ref": item_id,
                "reserved_qty": int(rng.integers(0, 20)),
                "reserved_ts": base + pd.Timedelta(days=int(rng.integers(1, 20))),
            })
            inv_res_id += 1

            item_id += 1

    menu_items = pd.DataFrame(menu_rows)
    menu_item_options = pd.DataFrame(opt_rows)
    inventory = pd.DataFrame(inv_rows)
    inventory_adjustments = pd.DataFrame(adj_rows)
    merchant_menu_updates = pd.DataFrame(upd_rows)
    inventory_reservations = pd.DataFrame(inv_reserve_rows)

    return (
        merchants,
        merchant_locations,
        merchant_business_docs,
        merchant_bank_accounts,
        merchant_operational_status,
        menu_items,
        menu_item_options,
        inventory_reservations,
        inventory_adjustments,
        merchant_menu_updates,
    )


# ================================================================
# 3. DRIVER / LOGISTICS TABLES
# ================================================================

def generate_drivers_and_logistics_medium(
    n_drivers: int = 1000,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    driver_ids = np.arange(1, n_drivers + 1)
    base_dates = pd.Timestamp("2024-02-01") + pd.to_timedelta(
        rng.integers(0, 90, size=n_drivers), unit="D"
    )

    drivers = pd.DataFrame({
        "driver_id": driver_ids,
        "name": [f"Driver_{i}" for i in driver_ids],
        "rating": rng.uniform(2.5, 5.0, size=n_drivers).round(2),
        "region": rng.choice(["NORTH", "SOUTH", "EAST", "WEST"], size=n_drivers),
        "created_ts": base_dates,
        "updated_ts": base_dates + pd.to_timedelta(
            rng.integers(5, 60, size=n_drivers), unit="D"
        ),
    })

    # Shift status
    shift_rows = []
    shift_id = 1
    for idx, row in drivers.iterrows():
        did = row["driver_id"]
        base = row["created_ts"]
        shift_rows.append({
            "shift_id": shift_id,
            "driver_ref": did,
            "shift_start_ts": base + pd.Timedelta(days=int(rng.integers(0, 10))),
            "shift_end_ts": base + pd.Timedelta(days=int(rng.integers(10, 20))),
        })
        shift_id += 1
    driver_shift_status = pd.DataFrame(shift_rows)

    # Location updates
    loc_rows = []
    loc_id = 1
    for idx, row in drivers.iterrows():
        did = row["driver_id"]
        base = row["created_ts"]
        n_ping = rng.integers(1, 5)
        for _ in range(n_ping):
            loc_rows.append({
                "location_update_id": loc_id,
                "driver_ref": did,
                "lat": rng.uniform(-90, 90),
                "lon": rng.uniform(-180, 180),
                "ping_ts": base + pd.Timedelta(days=int(rng.integers(0, 30))),
            })
            loc_id += 1
    driver_location_updates = pd.DataFrame(loc_rows)

    return drivers, driver_shift_status, driver_location_updates


# ================================================================
# 4. MAIN ORDER + EVENT GENERATION
# ================================================================

def generate_orders_and_events_medium(
    n_traces: int,
    customers: pd.DataFrame,
    customer_addresses: pd.DataFrame,
    customer_payment_methods: pd.DataFrame,
    merchants: pd.DataFrame,
    merchant_locations: pd.DataFrame,
    menu_items: pd.DataFrame,
    drivers: pd.DataFrame,
) -> Tuple[Dict[str, pd.DataFrame], pd.DataFrame]:

    # --- Prepare lookups ---
    addr_by_cust = customer_addresses.groupby("cust_ref")["address_id"].apply(list).to_dict()
    pay_by_cust = customer_payment_methods.groupby("cust_ref")["pay_method_id"].apply(list).to_dict()
    loc_ids = merchant_locations["location_id"].tolist()
    driver_ids = drivers["driver_id"].tolist()

    customers_idx = customers.set_index("customer_id")
    payment_idx = customer_payment_methods.set_index("pay_method_id")
    menu_by_loc = menu_items.groupby("location_ref")

    # --- Collectors for ~50+ tables ---
    orders = []
    order_items = []
    order_discounts = []
    order_status_events = []
    order_tax_calculation = []
    order_promo_validation = []
    order_cart_snapshots = []
    order_experiments = []

    risk_requests = []
    risk_scores = []
    fraud_checks = []
    manual_review_cases = []

    kitchen_tickets = []
    kitchen_events = []

    dispatch_requests = []
    driver_assignments = []
    delivery_status_events = []
    delivery_package_scans = []
    handover_confirmations = []

    delivery_exceptions = []
    reassignment_requests = []
    fallback_delivery_methods = []

    payment_authorizations = []
    payment_captures = []
    payment_settlements = []
    refund_requests = []
    refund_approvals = []
    refund_settlements = []

    support_tickets = []
    support_ticket_messages = []
    support_issue_resolutions = []

    customer_reviews = []
    customer_rewards = []
    system_audit_logs = []
    user_action_logs = []

    route_plans = []
    route_segments = []
    traffic_incidents = []

    # ID counters
    order_id = 1
    order_status_id = 1
    tax_id = 1
    promo_id = 1
    cart_id = 1
    exp_id = 1
    order_item_id = 1
    discount_id = 1

    risk_req_id = 1
    risk_score_id = 1
    fraud_id = 1
    manual_review_id = 1

    kitchen_ticket_id = 1
    kitchen_event_id = 1

    dispatch_req_id = 1
    driver_assign_id = 1
    delivery_evt_id = 1
    pkg_scan_id = 1
    handover_id = 1

    del_ex_id = 1
    reassignment_id = 1
    fallback_id = 1

    auth_id = 1
    capture_id = 1
    settle_id = 1
    refund_req_id = 1
    refund_app_id = 1
    refund_settle_id = 1

    ticket_id = 1
    ticket_msg_id = 1
    issue_res_id = 1

    review_id = 1
    reward_id = 1
    audit_id = 1
    user_action_id = 1

    route_plan_id = 1
    segment_id = 1
    incident_id = 1

    # Event traces
    event_trace_rows = []

    for _ in range(n_traces):
        # --- Choose base entities ---
        cust_id = int(rng.choice(customers["customer_id"].to_numpy()))
        cust_row = customers_idx.loc[cust_id]

        addr_id = int(rng.choice(addr_by_cust[cust_id]))
        pay_id = int(rng.choice(pay_by_cust[cust_id]))
        pay_row = payment_idx.loc[pay_id]

        loc_id = int(rng.choice(loc_ids))
        loc_menu = menu_by_loc.get_group(loc_id)
        menu_row = loc_menu.sample(1).iloc[0]

        driver_id = int(rng.choice(driver_ids))

        # Base timestamp for trace
        base_ts = pd.Timestamp("2024-03-01") + pd.to_timedelta(
            int(rng.integers(0, 60)), unit="D"
        )
        next_ts = make_ts_advancer(base_ts)

        events: List[str] = []
        join_path: List[str] = []
        key_uuid = str(uuid.uuid4())

        # --- CUSTOMER + MERCHANT EVENTS (fixed) ---

        # CustomerCreated
        events.append("CustomerCreated")
        join_path.append("customers.customer_id")
        ts_cust_created = next_ts(1)

        # AddressAdded
        events.append("AddressAdded")
        join_path.append("customer_addresses.address_id")
        ts_addr_added = next_ts(1)

        # DeviceRegistered (logged in user_action_logs)
        events.append("DeviceRegistered")
        join_path.append("user_action_logs.action_id")
        ts_device = next_ts(1)
        user_action_logs.append({
            "action_id": user_action_id,
            "cust_ref": cust_id,
            "action_type": "DEVICE_REGISTERED",
            "created_ts": ts_device,
        })
        user_action_id += 1

        # MerchantCreated (audit log)
        events.append("MerchantCreated")
        join_path.append("system_audit_logs.audit_id")
        ts_merch = next_ts(1)
        system_audit_logs.append({
            "audit_id": audit_id,
            "entity_type": "MERCHANT",
            "entity_id": loc_id,
            "action": "MERCHANT_USED_FOR_ORDER",
            "created_ts": ts_merch,
        })
        audit_id += 1

        # LocationActivated (audit log)
        events.append("LocationActivated")
        join_path.append("system_audit_logs.audit_id")
        ts_loc = next_ts(1)
        system_audit_logs.append({
            "audit_id": audit_id,
            "entity_type": "LOCATION",
            "entity_id": loc_id,
            "action": "LOCATION_ACTIVE",
            "created_ts": ts_loc,
        })
        audit_id += 1

        # --- ORDER CREATION SEGMENT ---

        # OrderCreated
        events.append("OrderCreated")
        join_path.append("orders.order_id")
        ts_order_created = next_ts(1)

        quantity = int(rng.integers(1, 6))
        item_price = float(menu_row["base_price"])
        subtotal = quantity * item_price

        orders.append({
            "order_id": order_id,
            "cust_ref": cust_id,
            "location_ref": loc_id,
            "pay_method_ref": pay_id,
            "primary_address_ref": addr_id,
            "order_amount": subtotal,
            "created_ts": ts_order_created,
            "updated_ts": ts_order_created,
        })

        # CartSnapshotSaved
        events.append("CartSnapshotSaved")
        join_path.append("order_cart_snapshots.snapshot_id")
        ts_cart = next_ts(1)
        order_cart_snapshots.append({
            "snapshot_id": cart_id,
            "order_ref": order_id,
            "created_ts": ts_cart,
        })
        cart_id += 1

        # ExperimentBucketAssigned
        events.append("ExperimentBucketAssigned")
        join_path.append("order_experiments.experiment_id")
        ts_exp = next_ts(1)
        bucket = rng.choice(["A", "B", "C"])
        order_experiments.append({
            "experiment_id": exp_id,
            "order_ref": order_id,
            "bucket": bucket,
            "assigned_ts": ts_exp,
        })
        exp_id += 1

        # PromoValidated
        events.append("PromoValidated")
        join_path.append("order_promo_validation.promo_id")
        ts_promo = next_ts(1)
        discount = float(rng.uniform(0, 10))
        order_promo_validation.append({
            "promo_id": promo_id,
            "order_ref": order_id,
            "valid": True,
            "discount_amount": discount,
            "validated_ts": ts_promo,
        })
        promo_id += 1

        if discount > 0:
            order_discounts.append({
                "discount_id": discount_id,
                "order_ref": order_id,
                "discount_amount": discount,
                "created_ts": ts_promo,
            })
            discount_id += 1

        # TaxCalculated
        events.append("TaxCalculated")
        join_path.append("order_tax_calculation.tax_id")
        ts_tax = next_ts(1)
        tax_amount = round(subtotal * 0.08, 2)
        order_tax_calculation.append({
            "tax_id": tax_id,
            "order_ref": order_id,
            "tax_amount": tax_amount,
            "calculated_ts": ts_tax,
        })
        tax_id += 1

        # OrderConfirmed
        events.append("OrderConfirmed")
        join_path.append("order_status_events.status_event_id")
        ts_conf = next_ts(1)
        order_status_events.append({
            "status_event_id": order_status_id,
            "order_ref": order_id,
            "status": "CONFIRMED",
            "event_ts": ts_conf,
        })
        order_status_id += 1

        # Order items
        order_items.append({
            "order_item_id": order_item_id,
            "order_ref": order_id,
            "menu_ref": int(menu_row["item_id"]),
            "quantity": quantity,
            "unit_price": item_price,
            "created_ts": ts_order_created,
        })
        order_item_id += 1

        # --- RISK SEGMENT (Branch 1) ---

        # RiskCheckStarted
        events.append("RiskCheckStarted")
        join_path.append("risk_requests.risk_request_id")
        ts_risk_start = next_ts(1)
        risk_requests.append({
            "risk_request_id": risk_req_id,
            "order_ref": order_id,
            "created_ts": ts_risk_start,
        })

        # RiskScoreCalculated
        events.append("RiskScoreCalculated")
        join_path.append("risk_scores.risk_score_id")
        ts_risk_score = next_ts(1)

        # Deterministic risk_level based on customer + wallet + amount
        wallet_balance = float(pay_row["wallet_balance"])
        outstanding = float(cust_row["outstanding_amount"])
        credit = int(cust_row["credit_rating"])

        raw_score = (
            (subtotal / 100.0) +
            (outstanding / 200.0) -
            (wallet_balance / 300.0) +
            (6 - credit) * 0.5
        )
        if raw_score < 1.0:
            risk_level = "LOW"
        elif raw_score < 2.0:
            risk_level = "MEDIUM"
        else:
            risk_level = "HIGH"

        risk_scores.append({
            "risk_score_id": risk_score_id,
            "risk_request_ref": risk_req_id,
            "risk_level": risk_level,
            "score_value": raw_score,
            "calculated_ts": ts_risk_score,
        })

        # FraudCheckPerformed
        events.append("FraudCheckPerformed")
        join_path.append("fraud_checks.fraud_check_id")
        ts_fraud = next_ts(1)
        fraud_checks.append({
            "fraud_check_id": fraud_id,
            "risk_request_ref": risk_req_id,
            "rule_engine_version": "v1.0",
            "check_ts": ts_fraud,
        })
        fraud_id += 1

        # RiskCheckCompleted
        events.append("RiskCheckCompleted")
        join_path.append("risk_scores.risk_score_id")
        ts_risk_done = next_ts(1)

        # Branch 1: Auto vs Manual
        if risk_level in ["LOW", "MEDIUM"]:
            # AutoApprovalGranted
            events.append("AutoApprovalGranted")
            join_path.append("risk_scores.risk_score_id")
            ts_auto = next_ts(1)
        else:
            # ManualReviewRequired
            events.append("ManualReviewRequired")
            join_path.append("manual_review_cases.review_id")
            ts_mr_req = next_ts(1)
            manual_review_cases.append({
                "review_id": manual_review_id,
                "risk_score_ref": risk_score_id,
                "created_ts": ts_mr_req,
                "status": "REQUIRED",
            })

            # ManualReviewCompleted (always approve in this clean version)
            events.append("ManualReviewCompleted")
            join_path.append("manual_review_cases.review_id")
            ts_mr_done = next_ts(3)
            manual_review_cases[-1]["completed_ts"] = ts_mr_done
            manual_review_cases[-1]["status"] = "APPROVED"
            manual_review_id += 1

        risk_req_id += 1
        risk_score_id += 1

        # --- KITCHEN / DISPATCH / DELIVERY ---

        # KitchenTicketCreated
        events.append("KitchenTicketCreated")
        join_path.append("kitchen_tickets.ticket_id")
        ts_kt = next_ts(1)
        kitchen_tickets.append({
            "ticket_id": kitchen_ticket_id,
            "order_ref": order_id,
            "created_ts": ts_kt,
        })

        # KitchenCookingStarted
        events.append("KitchenCookingStarted")
        join_path.append("kitchen_events.kitchen_event_id")
        ts_kstart = next_ts(5)
        kitchen_events.append({
            "kitchen_event_id": kitchen_event_id,
            "ticket_ref": kitchen_ticket_id,
            "event_type": "COOKING_STARTED",
            "event_ts": ts_kstart,
        })
        kitchen_event_id += 1

        # KitchenCookingFinished
        events.append("KitchenCookingFinished")
        join_path.append("kitchen_events.kitchen_event_id")
        ts_kdone = next_ts(15)
        kitchen_events.append({
            "kitchen_event_id": kitchen_event_id,
            "ticket_ref": kitchen_ticket_id,
            "event_type": "COOKING_FINISHED",
            "event_ts": ts_kdone,
        })
        kitchen_event_id += 1
        kitchen_ticket_id += 1

        # DispatchRequested
        events.append("DispatchRequested")
        join_path.append("dispatch_requests.dispatch_request_id")
        ts_disp = next_ts(1)
        dispatch_requests.append({
            "dispatch_request_id": dispatch_req_id,
            "order_ref": order_id,
            "created_ts": ts_disp,
        })

        # DriverShiftOnline (simplified log via assignment)
        events.append("DriverShiftOnline")
        join_path.append("driver_assignments.assignment_id")
        ts_shift = next_ts(1)

        # DriverAssigned
        events.append("DriverAssigned")
        join_path.append("driver_assignments.assignment_id")
        ts_assigned = next_ts(1)
        driver_assignments.append({
            "assignment_id": driver_assign_id,
            "dispatch_ref": dispatch_req_id,
            "driver_ref": driver_id,
            "assigned_ts": ts_assigned,
            "updated_ts": ts_assigned,
        })
        dispatch_req_id += 1

        # RoutePlanned (create route_plan)
        events.append("RoutePlanned")
        join_path.append("route_plans.route_plan_id")
        ts_route = next_ts(2)
        route_plans.append({
            "route_plan_id": route_plan_id,
            "order_ref": order_id,
            "driver_ref": driver_id,
            "created_ts": ts_route,
            "updated_ts": ts_route,
        })

        # Delivery status event for route planned
        delivery_status_events.append({
            "delivery_event_id": delivery_evt_id,
            "assignment_ref": driver_assign_id,
            "status": "ROUTE_PLANNED",
            "event_ts": ts_route,
        })
        delivery_evt_id += 1

        # SegmentStarted (segment #1)
        events.append("SegmentStarted")
        join_path.append("delivery_status_events.delivery_event_id")
        ts_seg_start = next_ts(2)
        delivery_status_events.append({
            "delivery_event_id": delivery_evt_id,
            "assignment_ref": driver_assign_id,
            "status": "SEGMENT_STARTED",
            "event_ts": ts_seg_start,
        })
        delivery_evt_id += 1

        # SegmentCompleted (segment #1)
        events.append("SegmentCompleted")
        join_path.append("delivery_status_events.delivery_event_id")
        ts_seg_done = next_ts(5)
        delivery_status_events.append({
            "delivery_event_id": delivery_evt_id,
            "assignment_ref": driver_assign_id,
            "status": "SEGMENT_COMPLETED",
            "event_ts": ts_seg_done,
        })
        delivery_evt_id += 1

        # Register route segment in route_segments table
        route_segments.append({
            "segment_id": segment_id,
            "route_plan_ref": route_plan_id,
            "sequence_no": 1,
            "start_ts": ts_seg_start,
            "end_ts": ts_seg_done,
        })
        segment_id += 1

        # Delivered
        events.append("Delivered")
        join_path.append("handover_confirmations.handover_id")
        ts_delivered = next_ts(5)
        handover_confirmations.append({
            "handover_id": handover_id,
            "assignment_ref": driver_assign_id,
            "confirmed_ts": ts_delivered,
        })
        handover_id += 1

        # Package scan log
        delivery_package_scans.append({
            "package_scan_id": pkg_scan_id,
            "assignment_ref": driver_assign_id,
            "scan_ts": ts_delivered,
        })
        pkg_scan_id += 1

        driver_assign_id += 1
        route_plan_id += 1

        # --- DELIVERY EXCEPTION BRANCH (Branch 2) ---

        driver_rating = float(drivers.loc[drivers["driver_id"] == driver_id, "rating"].iloc[0])

        # Simple route risk score
        route_risk_score = 0.0
        if risk_level == "HIGH":
            route_risk_score += 1.0
        if driver_rating < 3.5:
            route_risk_score += 0.8
        if cust_row["region"] == "NORTH":
            route_risk_score += 0.4

        # Exception occurs if route_risk_score >= 1.2
        if route_risk_score >= 1.2:
            events.append("DeliveryExceptionOccurred")
            join_path.append("delivery_exceptions.exception_id")
            ts_exc = next_ts(1)

            # Determine exception_type deterministically
            if driver_rating < 3.0:
                exception_type = "REASSIGN_DRIVER"
            elif cust_row["region"] == "NORTH":
                exception_type = "REROUTE"
            elif route_risk_score > 1.8:
                exception_type = "CANCEL"
            else:
                exception_type = "DELAY"

            delivery_exceptions.append({
                "exception_id": del_ex_id,
                "assignment_ref": driver_assign_id - 1,
                "exception_type": exception_type,
                "exception_ts": ts_exc,
            })

            # For REROUTE/DELAY, we can log traffic_incidents
            if exception_type in ["REROUTE", "DELAY"]:
                traffic_incidents.append({
                    "incident_id": incident_id,
                    "route_plan_ref": route_plan_id - 1,
                    "incident_ts": ts_exc,
                    "severity": rng.choice(["LOW", "MEDIUM", "HIGH"]),
                })
                incident_id += 1

            # Branch 2: 4 routes
            if exception_type == "REROUTE":
                events.append("ReroutePlanned")
                join_path.append("fallback_delivery_methods.fallback_id")
                ts_reroute = next_ts(2)
                fallback_delivery_methods.append({
                    "fallback_id": fallback_id,
                    "assignment_ref": driver_assign_id - 1,
                    "method": "ALT_ROUTE",
                    "created_ts": ts_reroute,
                })
                fallback_id += 1

                # New segment #2
                events.append("SegmentStarted")
                join_path.append("delivery_status_events.delivery_event_id")
                ts_seg2_start = next_ts(2)
                delivery_status_events.append({
                    "delivery_event_id": delivery_evt_id,
                    "assignment_ref": driver_assign_id - 1,
                    "status": "SEGMENT2_STARTED",
                    "event_ts": ts_seg2_start,
                })
                delivery_evt_id += 1

                events.append("SegmentCompleted")
                join_path.append("delivery_status_events.delivery_event_id")
                ts_seg2_done = next_ts(5)
                delivery_status_events.append({
                    "delivery_event_id": delivery_evt_id,
                    "assignment_ref": driver_assign_id - 1,
                    "status": "SEGMENT2_COMPLETED",
                    "event_ts": ts_seg2_done,
                })
                delivery_evt_id += 1

                # Register route segment #2
                route_segments.append({
                    "segment_id": segment_id,
                    "route_plan_ref": route_plan_id - 1,
                    "sequence_no": 2,
                    "start_ts": ts_seg2_start,
                    "end_ts": ts_seg2_done,
                })
                segment_id += 1

            elif exception_type == "DELAY":
                events.append("DelayLogged")
                join_path.append("delivery_exceptions.exception_id")
                ts_delay = next_ts(5)

            elif exception_type == "REASSIGN_DRIVER":
                events.append("DriverReassigned")
                join_path.append("reassignment_requests.reassignment_id")
                ts_reassign = next_ts(3)
                new_driver = int(rng.choice(driver_ids))
                reassignment_requests.append({
                    "reassignment_id": reassignment_id,
                    "old_assignment_ref": driver_assign_id - 1,
                    "new_driver_ref": new_driver,
                    "created_ts": ts_reassign,
                })
                reassignment_id += 1

                # Continue with new route (secondary route plan)
                events.append("RoutePlanned")
                join_path.append("route_plans.route_plan_id")
                ts_route2 = next_ts(3)
                route_plans.append({
                    "route_plan_id": route_plan_id,
                    "order_ref": order_id,
                    "driver_ref": new_driver,
                    "created_ts": ts_route2,
                    "updated_ts": ts_route2,
                })

                delivery_status_events.append({
                    "delivery_event_id": delivery_evt_id,
                    "assignment_ref": driver_assign_id - 1,
                    "status": "ROUTE_REPLANNED",
                    "event_ts": ts_route2,
                })
                delivery_evt_id += 1

                route_plan_id += 1

            elif exception_type == "CANCEL":
                events.append("OrderCancelled")
                join_path.append("order_status_events.status_event_id")
                ts_cancel = next_ts(1)
                order_status_events.append({
                    "status_event_id": order_status_id,
                    "order_ref": order_id,
                    "status": "CANCELLED_AFTER_DELIVERY_ISSUE",
                    "event_ts": ts_cancel,
                })
                order_status_id += 1

                # END TRACE HERE
                event_trace_rows.append({
                    "Key_Selector": "Order_ID",
                    "Key_ID": key_uuid,
                    "Event_Trace": str(events),
                    "Join_Path": str(join_path),
                })
                del_ex_id += 1
                order_id += 1
                continue

            del_ex_id += 1

        # --- PAYMENT / REFUND / SUPPORT / REVIEW ---

        # PaymentAuthorized
        events.append("PaymentAuthorized")
        join_path.append("payment_authorizations.auth_id")
        ts_auth = next_ts(30)
        total_amount = subtotal + tax_amount - discount
        payment_authorizations.append({
            "auth_id": auth_id,
            "order_ref": order_id,
            "amount": total_amount,
            "created_ts": ts_auth,
        })

        # PaymentCaptured
        events.append("PaymentCaptured")
        join_path.append("payment_captures.capture_id")
        ts_cap = next_ts(60)
        payment_captures.append({
            "capture_id": capture_id,
            "auth_ref": auth_id,
            "captured_amount": total_amount,
            "captured_ts": ts_cap,
        })

        # PaymentSettled
        events.append("PaymentSettled")
        join_path.append("payment_settlements.settlement_id")
        ts_settle = next_ts(24 * 60)  # +1 day
        payment_settlements.append({
            "settlement_id": settle_id,
            "capture_ref": capture_id,
            "settled_amount": total_amount,
            "settled_ts": ts_settle,
        })

        # Occasionally refund
        if total_amount > 40 and rng.random() < 0.2:
            events.append("RefundRequested")
            join_path.append("refund_requests.refund_request_id")
            ts_rreq = next_ts(10)
            refund_requests.append({
                "refund_request_id": refund_req_id,
                "order_ref": order_id,
                "requested_amount": total_amount * 0.5,
                "requested_ts": ts_rreq,
            })

            events.append("RefundApproved")
            join_path.append("refund_approvals.refund_approval_id")
            ts_rappr = next_ts(30)
            refund_approvals.append({
                "refund_approval_id": refund_app_id,
                "refund_request_ref": refund_req_id,
                "approved_ts": ts_rappr,
            })

            events.append("RefundSettled")
            join_path.append("refund_settlements.refund_settlement_id")
            ts_rsettle = next_ts(24 * 60)
            refund_settlements.append({
                "refund_settlement_id": refund_settle_id,
                "refund_approval_ref": refund_app_id,
                "settled_ts": ts_rsettle,
            })

            refund_req_id += 1
            refund_app_id += 1
            refund_settle_id += 1

        auth_id += 1
        capture_id += 1
        settle_id += 1

        # Support path (sometimes)
        if rng.random() < 0.3:
            events.append("SupportTicketOpened")
            join_path.append("support_tickets.ticket_id")
            ts_topn = next_ts(60)
            support_tickets.append({
                "ticket_id": ticket_id,
                "order_ref": order_id,
                "opened_ts": ts_topn,
            })

            events.append("SupportMessageAdded")
            join_path.append("support_ticket_messages.message_id")
            ts_msg = next_ts(10)
            support_ticket_messages.append({
                "message_id": ticket_msg_id,
                "ticket_ref": ticket_id,
                "created_ts": ts_msg,
            })
            ticket_msg_id += 1

            events.append("SupportIssueResolved")
            join_path.append("support_issue_resolutions.resolution_id")
            ts_res = next_ts(50)
            support_issue_resolutions.append({
                "resolution_id": issue_res_id,
                "ticket_ref": ticket_id,
                "resolved_ts": ts_res,
            })
            issue_res_id += 1
            ticket_id += 1

        # ReviewSubmitted
        events.append("ReviewSubmitted")
        join_path.append("customer_reviews.review_id")
        ts_review = next_ts(60)
        rating = int(rng.integers(1, 6))
        customer_reviews.append({
            "review_id": review_id,
            "order_ref": order_id,
            "rating": rating,
            "created_ts": ts_review,
        })
        review_id += 1

        # RewardCredited
        events.append("RewardCredited")
        join_path.append("customer_rewards.reward_id")
        ts_reward = next_ts(5)
        customer_rewards.append({
            "reward_id": reward_id,
            "cust_ref": cust_id,
            "points": max(1, rating) * 10,
            "credited_ts": ts_reward,
        })
        reward_id += 1

        # AuditLogWritten
        events.append("AuditLogWritten")
        join_path.append("system_audit_logs.audit_id")
        ts_audit = next_ts(1)
        system_audit_logs.append({
            "audit_id": audit_id,
            "entity_type": "ORDER",
            "entity_id": order_id,
            "action": "ORDER_COMPLETED",
            "created_ts": ts_audit,
        })
        audit_id += 1

        # --- Final trace row ---
        event_trace_rows.append({
            "Key_Selector": "Order_ID",
            "Key_ID": key_uuid,
            "Event_Trace": str(events),
            "Join_Path": str(join_path),
        })

        order_id += 1

    # Build all tables dict (57 total)
    tables: Dict[str, pd.DataFrame] = {
        "orders": pd.DataFrame(orders),
        "order_items": pd.DataFrame(order_items),
        "order_discounts": pd.DataFrame(order_discounts),
        "order_status_events": pd.DataFrame(order_status_events),
        "order_tax_calculation": pd.DataFrame(order_tax_calculation),
        "order_promo_validation": pd.DataFrame(order_promo_validation),
        "order_cart_snapshots": pd.DataFrame(order_cart_snapshots),
        "order_experiments": pd.DataFrame(order_experiments),
        "risk_requests": pd.DataFrame(risk_requests),
        "risk_scores": pd.DataFrame(risk_scores),
        "fraud_checks": pd.DataFrame(fraud_checks),
        "manual_review_cases": pd.DataFrame(manual_review_cases),
        "kitchen_tickets": pd.DataFrame(kitchen_tickets),
        "kitchen_events": pd.DataFrame(kitchen_events),
        "dispatch_requests": pd.DataFrame(dispatch_requests),
        "driver_assignments": pd.DataFrame(driver_assignments),
        "delivery_status_events": pd.DataFrame(delivery_status_events),
        "delivery_package_scans": pd.DataFrame(delivery_package_scans),
        "handover_confirmations": pd.DataFrame(handover_confirmations),
        "delivery_exceptions": pd.DataFrame(delivery_exceptions),
        "reassignment_requests": pd.DataFrame(reassignment_requests),
        "fallback_delivery_methods": pd.DataFrame(fallback_delivery_methods),
        "payment_authorizations": pd.DataFrame(payment_authorizations),
        "payment_captures": pd.DataFrame(payment_captures),
        "payment_settlements": pd.DataFrame(payment_settlements),
        "refund_requests": pd.DataFrame(refund_requests),
        "refund_approvals": pd.DataFrame(refund_approvals),
        "refund_settlements": pd.DataFrame(refund_settlements),
        "support_tickets": pd.DataFrame(support_tickets),
        "support_ticket_messages": pd.DataFrame(support_ticket_messages),
        "support_issue_resolutions": pd.DataFrame(support_issue_resolutions),
        "customer_reviews": pd.DataFrame(customer_reviews),
        "customer_rewards": pd.DataFrame(customer_rewards),
        "system_audit_logs": pd.DataFrame(system_audit_logs),
        "user_action_logs": pd.DataFrame(user_action_logs),
        "route_plans": pd.DataFrame(route_plans),
        "route_segments": pd.DataFrame(route_segments),
        "traffic_incidents": pd.DataFrame(traffic_incidents),
    }

    event_traces_df = pd.DataFrame(event_trace_rows)

    return tables, event_traces_df


# ================================================================
# 5. FULL MEDIUM DATASET GENERATOR
# ================================================================

def generate_food_delivery_medium(
    n_traces: int = MEDIUM_N_TRACES,
    data_dir: str = DATA_DIR_MEDIUM,
) -> None:
    """
    Generate the medium-scale dataset:
      - ~57 tables in `data_dir`
      - medium_event_traces.csv
      - medium_transition_graph.csv
    """
    os.makedirs(data_dir, exist_ok=True)

    # Customers and related tables
    (
        customers,
        customer_addresses,
        customer_devices,
        customer_payment_methods,
        customer_verification,
        customer_loyalty_status,
    ) = generate_customers_medium()

    # Merchants, locations, menu, inventory, updates
    (
        merchants,
        merchant_locations,
        merchant_business_docs,
        merchant_bank_accounts,
        merchant_operational_status,
        menu_items,
        menu_item_options,
        inventory_reservations,
        inventory_adjustments,
        merchant_menu_updates,
    ) = generate_merchants_and_menu_medium()

    # Drivers
    (
        drivers,
        driver_shift_status,
        driver_location_updates,
    ) = generate_drivers_and_logistics_medium()

    # Orders + main fact tables + event traces
    fact_tables, event_traces = generate_orders_and_events_medium(
        n_traces=n_traces,
        customers=customers,
        customer_addresses=customer_addresses,
        customer_payment_methods=customer_payment_methods,
        merchants=merchants,
        merchant_locations=merchant_locations,
        menu_items=menu_items,
        drivers=drivers,
    )

    # Transition graph
    transition_graph = build_medium_transition_graph()

    # Save event traces & transition graph at root
    event_traces.to_csv("medium_event_traces.csv", index=False)
    transition_graph.to_csv("medium_transition_graph.csv", index=False)

    # Save dimension tables
    customers.to_csv(os.path.join(data_dir, "customers.csv"), index=False)
    customer_addresses.to_csv(os.path.join(data_dir, "customer_addresses.csv"), index=False)
    customer_devices.to_csv(os.path.join(data_dir, "customer_devices.csv"), index=False)
    customer_payment_methods.to_csv(
        os.path.join(data_dir, "customer_payment_methods.csv"), index=False
    )
    customer_verification.to_csv(
        os.path.join(data_dir, "customer_verification.csv"), index=False
    )
    customer_loyalty_status.to_csv(
        os.path.join(data_dir, "customer_loyalty_status.csv"), index=False
    )

    merchants.to_csv(os.path.join(data_dir, "merchants.csv"), index=False)
    merchant_locations.to_csv(os.path.join(data_dir, "merchant_locations.csv"), index=False)
    merchant_business_docs.to_csv(
        os.path.join(data_dir, "merchant_business_docs.csv"), index=False
    )
    merchant_bank_accounts.to_csv(
        os.path.join(data_dir, "merchant_bank_accounts.csv"), index=False
    )
    merchant_operational_status.to_csv(
        os.path.join(data_dir, "merchant_operational_status.csv"), index=False
    )
    menu_items.to_csv(os.path.join(data_dir, "menu_items.csv"), index=False)
    menu_item_options.to_csv(os.path.join(data_dir, "menu_item_options.csv"), index=False)
    inventory_reservations.to_csv(
        os.path.join(data_dir, "inventory_reservations.csv"), index=False
    )
    inventory_adjustments.to_csv(
        os.path.join(data_dir, "inventory_adjustments.csv"), index=False
    )
    merchant_menu_updates.to_csv(
        os.path.join(data_dir, "merchant_menu_updates.csv"), index=False
    )

    drivers.to_csv(os.path.join(data_dir, "drivers.csv"), index=False)
    driver_shift_status.to_csv(
        os.path.join(data_dir, "driver_shift_status.csv"), index=False
    )
    driver_location_updates.to_csv(
        os.path.join(data_dir, "driver_location_updates.csv"), index=False
    )

    # Save all fact/log tables (including route_plans/segments/incidents)
    for name, df in fact_tables.items():
        df.to_csv(os.path.join(data_dir, f"{name}.csv"), index=False)

    print(f"Generated {len(event_traces)} traces and wrote tables to '{data_dir}'.")


# ================================================================
# ENTRY POINT
# ================================================================

if __name__ == "__main__":
    generate_food_delivery_medium(50000, 'data')


Generated 50000 traces and wrote tables to 'data'.


In [4]:
"""
Noise Injector for Medium Dataset (57 tables)
---------------------------------------------

Adds:
  - 30% Schema Drift (Attribute Swap)
  - 15% Timestamp Missingness

Ensures:
  - PK/FK safety
  - No KeyErrors
  - No breakage of join paths
  - No event timestamp corruption
"""

import os
import random
import numpy as np
import pandas as pd

# ============================================================
# CONFIG
# ============================================================

ATTRIBUTE_SWAP_RATIO = 0.30
TIMESTAMP_MISSING_RATIO = 0.15
rng = np.random.default_rng(42)
random.seed(42)

# strings that imply timestamps
TS_KEYWORDS = ["ts", "time", "date"]

# PK / FK detection heuristics
PK_HINTS = ["id"]
FK_HINTS = ["_ref"]


# ============================================================
# LOAD / SAVE TABLES
# ============================================================

def load_tables(data_dir: str) -> dict:
    tables = {}
    for f in os.listdir(data_dir):
        if f.endswith(".csv"):
            name = f.replace(".csv", "")
            df = pd.read_csv(os.path.join(data_dir, f), parse_dates=True)
            tables[name] = df
    return tables


def save_tables(tables: dict, data_dir: str):
    for name, df in tables.items():
        df.to_csv(os.path.join(data_dir, f"{name}.csv"), index=False)


# ============================================================
# IDENTIFY SAFE ATTRIBUTES
# ============================================================

def is_timestamp_col(col: str) -> bool:
    col_low = col.lower()
    return any(tag in col_low for tag in TS_KEYWORDS)


def identify_timestamp_columns(df: pd.DataFrame) -> list:
    ts_cols = []
    for col in df.columns:
        if df[col].dtype == "datetime64[ns]":
            ts_cols.append(col)
        elif is_timestamp_col(col):
            ts_cols.append(col)
    return ts_cols


def identify_pk_fk_cols(df: pd.DataFrame) -> list:
    pkfk = []
    for col in df.columns:
        c = col.lower()
        if any(h in c for h in PK_HINTS):
            pkfk.append(col)
        if any(h in c for h in FK_HINTS):
            pkfk.append(col)
    return list(set(pkfk))


def detect_safe_columns(df: pd.DataFrame) -> list:
    """
    Safe means:
      - not PK/FK
      - not timestamp-like
      - not datetime dtype
      - not boolean
      - not extremely high cardinality (IDs)
    """
    safe = []
    pkfk = identify_pk_fk_cols(df)
    ts_cols = identify_timestamp_columns(df)

    for col in df.columns:
        if col in pkfk or col in ts_cols:
            continue
        if df[col].dtype == "datetime64[ns]":
            continue
        if df[col].dtype == bool:
            continue

        # high-cardinality numeric likely ID
        if df[col].dtype != object and df[col].nunique() > len(df) * 0.90:
            continue

        safe.append(col)

    return safe


# ============================================================
# ATTRIBUTE SWAP
# ============================================================

def apply_attribute_swap(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    safe_cols = detect_safe_columns(df)

    if len(safe_cols) == 0:
        return df

    n_drift = max(1, int(len(safe_cols) * ATTRIBUTE_SWAP_RATIO))
    cols_to_modify = random.sample(safe_cols, n_drift)

    # --- 1. Rename columns ---
    rename_map = {}
    for col in cols_to_modify:
        new_col = f"attr_{random.randint(100,999)}"
        rename_map[col] = new_col

    df.rename(columns=rename_map, inplace=True)

    renamed_cols = list(rename_map.values())

    # --- 2. Value Swap (if >=2)
    if len(renamed_cols) >= 2:
        c1, c2 = random.sample(renamed_cols, 2)
        df[c1], df[c2] = df[c2], df[c1]

    # --- 3. Mild type drift ---
    for col in renamed_cols:
        if df[col].dtype == object:
            df[col] = df[col].astype(str) + "_drift"
        else:
            df[col] = df[col].astype(float) * rng.uniform(0.7, 1.3)

    return df


def apply_attribute_swap_to_all_tables(data_dir: str):
    tables = load_tables(data_dir)
    drifted = {}

    for name, df in tables.items():
        # Protect event traces and transition graph
        if name in ["medium_event_traces", "medium_transition_graph"]:
            drifted[name] = df
            continue

        drifted[name] = apply_attribute_swap(df)

    save_tables(drifted, data_dir)


# ============================================================
# TIMESTAMP MISSINGNESS
# ============================================================

def apply_timestamp_missingness(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    ts_cols = identify_timestamp_columns(df)

    if len(ts_cols) == 0:
        return df

    n = len(df)
    n_missing = int(n * TIMESTAMP_MISSING_RATIO)
    if n_missing == 0:
        return df

    missing_idx = rng.choice(df.index, size=n_missing, replace=False)

    for col in ts_cols:
        df.loc[missing_idx, col] = pd.NaT

    return df


def apply_timestamp_missingness_to_all(data_dir: str):
    tables = load_tables(data_dir)
    noised = {}

    for name, df in tables.items():
        if name in ["medium_event_traces", "medium_transition_graph"]:
            noised[name] = df
            continue

        noised[name] = apply_timestamp_missingness(df)

    save_tables(noised, data_dir)


# ============================================================
# ENTRY POINT
# ============================================================

def inject_noise_medium(data_dir: str = "./data"):
    print("ðŸŒ€ Starting noise injection for MEDIUM dataset...")
    print("â†’ Applying 30% schema drift (safe attribute swap)...")
    apply_attribute_swap_to_all_tables(data_dir)

    print("â†’ Applying 15% timestamp missingness...")
    apply_timestamp_missingness_to_all(data_dir)

    print("âœ” Noise injection complete.")


if __name__ == "__main__":
    inject_noise_medium()


ðŸŒ€ Starting noise injection for MEDIUM dataset...
â†’ Applying 30% schema drift (safe attribute swap)...
â†’ Applying 15% timestamp missingness...


  df.loc[missing_idx, col] = pd.NaT
  df.loc[missing_idx, col] = pd.NaT
  df.loc[missing_idx, col] = pd.NaT
  df.loc[missing_idx, col] = pd.NaT


âœ” Noise injection complete.
