In [0]:
%pip install osmnx==1.7.1 networkx==3.2.1 geopy==2.4.1 shapely==2.0.3 databricks-sdk[openai]>=0.35.0

In [0]:
import json

try:
    # `dbutils` only exists on Databricks, and the lookup fails when the
    # `sim_cfg_json` widget has not been defined.
    sim_cfg_json = dbutils.widgets.get("sim_cfg_json")
except Exception:  # NameError off-Databricks, or a widget-lookup failure
    print("no widget")
    sim_cfg_json = ''

if sim_cfg_json.strip():
    SIM_CFG = json.loads(sim_cfg_json)
else:
    # No (or blank) widget value: fall back to the config file next to the
    # notebook, closing the handle deterministically.
    with open('./config.json') as cfg_file:
        SIM_CFG = json.load(cfg_file)

In [0]:
# ╔═══════════════════════════════════════════════════════════════════════╗
# ║  Ghost-Kitchen Event Simulator 2.1  →  JSON files on a UC volume      ║
# ╚═══════════════════════════════════════════════════════════════════════╝
#
#  All tunables live in SIM_CFG – supply them via the `sim_cfg_json` widget or ./config.json
# ------------------------------------------------------------------------

import asyncio, datetime as dt, json, math, pickle, random, time, uuid
from pathlib import Path
from typing import Dict, List, Tuple

import nest_asyncio, networkx as nx, numpy as np, osmnx as ox, pandas as pd
from pyspark.sql import SparkSession, functions as F

from datetime import datetime
import json

# Patch asyncio so an already-running event loop (the notebook kernel's) can be re-entered.
nest_asyncio.apply()

# Reuse the active Spark session, or create one if none exists.
spark = (
    SparkSession
    .builder
    .getOrCreate()
)

# ─────────────────────────────────────────────────────────────────────────
# 0.  CONFIGURATION                                                      │
# ─────────────────────────────────────────────────────────────────────────
CFG = SIM_CFG


def get(k):
    """Return the required config value ``CFG[k]``; a missing key raises a KeyError naming it."""
    try:
        return CFG[k]
    except KeyError:
        raise KeyError(f"SIM_CFG is missing required key {k!r}") from None


RAND = random.Random(get("random_seed")); np.random.seed(get("random_seed"))

CATALOG, SCHEMA, VOLUME = get("catalog"), get("schema"), get("volume")
_TS_FMT                = "%Y-%m-%d %H:%M:%S"
START_TS               = dt.datetime.strptime(get("start_ts"), _TS_FMT)
END_TS                 = dt.datetime.strptime(get("end_ts"), _TS_FMT)
SPEED_UP               = get("speed_up")
GK_ADDRESS, GK_R_MI    = get("gk_location"), get("radius_mi")
LOC_NAME               = get("location_name")
GK_DRIVER_MPH          = get("driver_mph")
NOISE, SVC             = get("noise_pct")/100, get("svc")
BATCH_ROWS, BATCH_SEC  = get("batch_rows"), get("batch_seconds")
PING_SEC               = get("ping_sec")
VOLUME_PATH = f"/Volumes/{CATALOG}/{SCHEMA}/{VOLUME}/"
VOLUME_DIR  = Path(VOLUME_PATH).expanduser()

# Fail fast on values that would otherwise crash (or silently misbehave) deep in the run.
if SPEED_UP <= 0:
    # the live scheduler divides the simulated wait by speed_up
    raise ValueError(f"speed_up must be > 0, got {SPEED_UP}")
if PING_SEC <= 0:
    # driver-ping hop counts are computed with `// PING_SEC`
    raise ValueError(f"ping_sec must be > 0, got {PING_SEC}")
if not 0 <= NOISE <= 1:
    # >100% noise can make the daily mean negative, which np.random.poisson rejects
    raise ValueError(f"noise_pct must be within [0, 100], got {NOISE * 100}")
if END_TS < START_TS:
    raise ValueError(f"end_ts ({END_TS}) is before start_ts ({START_TS})")

# ─────────────────────────────────────────────────────────────────────────
# 1.  ROAD GRAPH + NODES                                     │
# ─────────────────────────────────────────────────────────────────────────

# ─── NEW imports (all come with osmnx’s deps) ────────────────────────────
import geopandas as gpd
from shapely.geometry import Point

# ─── Bulk-address cache (lazy-loaded once per run) ───────────────────────
_BUILDINGS = None  # GeoDataFrame cache, filled by the first load_buildings() call

def load_buildings() -> gpd.GeoDataFrame:
    """
    Fetch OSM features that carry both addr:housenumber and addr:street tags
    around the geocoded GK_ADDRESS (query dist = GK_R_MI miles, in metres).

    Returns a GeoDataFrame with columns ``addr:housenumber``, ``addr:street``
    and ``geometry`` in EPSG:4326. The result is cached in ``_BUILDINGS``, so
    the Overpass query runs at most once per kernel session.
    """
    global _BUILDINGS
    if _BUILDINGS is not None:           # reuse if we already fetched
        return _BUILDINGS

    addr_cols = ["addr:housenumber", "addr:street"]
    center_pt = ox.geocoder.geocode(GK_ADDRESS)
    tags = {col: True for col in addr_cols}

    # osmnx 1.5+ deprecates geometries_from_point in favour of
    # features_from_point; prefer the new name and fall back to the old one.
    fetch = getattr(ox, "features_from_point", None) or ox.geometries_from_point
    # One Overpass query -> GeoDataFrame of every feature matching ANY of the tags
    feats = fetch(center_pt, tags=tags, dist=GK_R_MI * 1609.34)

    # A tag column is absent when no returned feature has that tag; add it so
    # the filter below just drops those rows instead of raising KeyError.
    for col in addr_cols:
        if col not in feats.columns:
            feats[col] = None
    # osmnx returns the union of the tag matches, so keep rows with both tags.
    feats = feats.dropna(subset=addr_cols)
    # Geometries (polygons or points) are kept as-is; sjoin_nearest measures
    # to the nearest point of each geometry.
    feats = feats.to_crs("EPSG:4326")  # same CRS as the graph nodes
    _BUILDINGS = feats[addr_cols + ["geometry"]]
    return _BUILDINGS

# ─── Node table: graph nodes + nearest real street address ──────────────
def load_nodes(g: nx.MultiDiGraph) -> pd.DataFrame:
    """
    Return one row per graph node with columns ``node_id``, ``lat``, ``lon``
    and ``addr``.

    ``addr`` is "<housenumber> <street>" taken from the nearest addressed OSM
    feature returned by load_buildings(). When that lookup yields nothing, a
    random "<n> Main St" label drawn from RAND is used instead. Once
    load_buildings() is cached, this makes no further network calls.
    """
    # ---- 1. Nodes → GeoDataFrame (lon/lat degrees) ----------------------
    node_rows = [
        {"node_id": nid,
         "lat": data["y"],
         "lon": data["x"],
         "geometry": Point(data["x"], data["y"])}
        for nid, data in g.nodes(data=True)
    ]
    gdf_nodes = gpd.GeoDataFrame(node_rows, crs="EPSG:4326")

    # ---- 2. Spatial join: nearest addressed feature ---------------------
    bldgs = load_buildings()                       # cached after first call
    if len(bldgs):
        # Nearest-neighbour search in degrees is distorted (a degree of
        # longitude shrinks with latitude), so join in a local metric CRS.
        # This also makes addr_dist_m a real distance in metres.
        metric_crs = gdf_nodes.estimate_utm_crs()
        joined = gpd.sjoin_nearest(
            gdf_nodes.to_crs(metric_crs), bldgs.to_crs(metric_crs),
            how="left",
            distance_col="addr_dist_m",            # kept for inspection
        )
        # sjoin_nearest returns ALL equidistant matches, which would duplicate
        # nodes (and over-weight them in sampling); keep one match per node.
        joined = joined[~joined.index.duplicated(keep="first")]
    else:
        # No addressed features at all: every node takes the fallback label.
        joined = gdf_nodes.assign(**{"addr:housenumber": None, "addr:street": None})

    # ---- 3. Build final rows with fallback for empties ------------------
    rows = []
    for _, r in joined.iterrows():
        if pd.notna(r["addr:housenumber"]) and pd.notna(r["addr:street"]):
            label = f"{r['addr:housenumber']} {r['addr:street']}"
        else:
            label = f"{RAND.randint(1, 9999)} Main St"
        # lat/lon come from the original columns, unaffected by the reprojection.
        rows.append(dict(node_id=int(r["node_id"]),
                         lat=r["lat"],
                         lon=r["lon"],
                         addr=label))
    return pd.DataFrame(rows)

def load_graph() -> nx.MultiDiGraph:
    """Download the drivable OSM street network around the geocoded GK_ADDRESS (dist = GK_R_MI miles)."""
    ox.settings.log_console = False            # keep osmnx logging out of the cell output
    radius_m = GK_R_MI * 1609.34               # miles -> metres
    center = ox.geocoder.geocode(GK_ADDRESS)   # (lat, lon) tuple
    return ox.graph_from_point(center, dist=radius_m, network_type="drive")

GRAPH = load_graph()
NODES = load_nodes(GRAPH)

# Kitchen coordinates and the graph node closest to them.
GK_LAT, GK_LON = ox.geocoder.geocode(GK_ADDRESS)
GK_NODE = ox.distance.nearest_nodes(GRAPH, GK_LON, GK_LAT)

# node -> component id, using direction-agnostic (weak) connectivity; keep
# only nodes in the kitchen's component so every sampled customer has at
# least an undirected route to the kitchen.
comp_map = {
    node: comp_id
    for comp_id, members in enumerate(nx.weakly_connected_components(GRAPH))
    for node in members
}
NODES = NODES.loc[NODES["node_id"].map(comp_map).eq(comp_map[GK_NODE])].reset_index(drop=True)
def rand_customer():
    """Pick one random row of NODES and return ``(node_id, lat, lon, addr)``."""
    seed = RAND.randrange(2**32)                  # drawn from the seeded RNG for reproducibility
    row = NODES.sample(1, random_state=seed).iloc[0]
    return int(row.node_id), row.lat, row.lon, row.addr

# ─────────────────────────────────────────────────────────────────────────
# 2.  MENU + BRAND MOMENTUM                                              │
# ─────────────────────────────────────────────────────────────────────────
# Menu items grouped per brand. The source table defaults to
# caspers.simulator.items and can be overridden with the optional
# "items_table" config key.
ITEMS_DF = spark.read.table(CFG.get("items_table", "caspers.simulator.items")).toPandas()
ITEMS_BY_BRAND = {bid: grp.to_dict("records") for bid, grp in ITEMS_DF.groupby("brand_id")}
BRANDS = list(ITEMS_BY_BRAND); RAND.shuffle(BRANDS)

# Split the shuffled brands into improving / flat / declining buckets using the
# configured fractions; declining takes whatever is left over.
bm = get("brand_momentum")
cuts = np.cumsum([bm["improving"], bm["flat"]])*len(BRANDS)
# Floor with a tiny epsilon: e.g. 0.7 + 0.2 == 0.8999999999999999 in floating
# point, and a plain int() would push one brand into the wrong bucket.
cut_impr, cut_flat = (int(math.floor(c + 1e-9)) for c in cuts)
IMPR, FLAT, DECL = BRANDS[:cut_impr], BRANDS[cut_impr:cut_flat], BRANDS[cut_flat:]
rates = get("momentum_rates")

def brand_weight(day, total, b):
    """
    Demand multiplier for brand ``b`` on simulation day ``day``.

    Brands in IMPR grow by ``rates["growth"]`` and brands in DECL shrink by
    ``rates["decline"]``, both compounded per 30 days; all others stay at 1.0.
    ``total`` is accepted for signature compatibility and is not used.
    """
    months = day / 30
    if b in IMPR:
        return (1 + rates["growth"]) ** months
    if b in DECL:
        return (1 - rates["decline"]) ** months
    return 1.0

def rand_basket(day, total, p_single=0.7, max_brands=4, items_rng=(1,4)):
    """
    Draw a random basket of menu items.

    With probability ``p_single`` one brand is chosen; otherwise 2..max_brands
    brands are drawn with replacement (so a brand can repeat). Brands are
    weighted by brand_weight(). For each chosen brand, a count in ``items_rng``
    is drawn, capped at the brand's menu size, and that many distinct items are
    sampled, each with a qty of 1-3.

    Returns a list of item dicts: copies of ITEMS_BY_BRAND records plus "qty".
    """
    w = np.array([brand_weight(day,total,b) for b in BRANDS]); w = w/w.sum()
    if RAND.random() < p_single:
        chosen = [RAND.choices(BRANDS, weights=w, k=1)[0]]
    else:
        chosen = RAND.choices(BRANDS, weights=w, k=RAND.randint(2, max_brands))

    items = []
    for b in chosen:
        menu = ITEMS_BY_BRAND[b]
        # random.sample raises ValueError when k > len(population); small menus
        # (fewer items than the drawn count) would otherwise crash the run.
        k = min(RAND.randint(*items_rng), len(menu))
        for itm in RAND.sample(menu, k):
            rec = itm.copy(); rec["qty"] = RAND.randint(1, 3); items.append(rec)
    return items

# ─────────────────────────────────────────────────────────────────────────
# 3.  DEMAND SHAPE                                                       │
# ─────────────────────────────────────────────────────────────────────────
def minute_weights(peaks=(("11:00", "13:30", 3), ("17:00", "20:00", 3.5))):
    """
    Relative order-arrival weight for each of the 1440 minutes of a day.

    Every minute starts at 1.0. Inside each ``(start, end, mult)`` window
    ("HH:MM" strings, end exclusive), a sin² bump is added that rises from 1.0
    at ``start`` to ``mult`` at the window's midpoint and falls back toward 1.0
    at ``end``. The default windows are the lunch and dinner rushes.

    Raises:
        ValueError: if a window does not end after it starts (windows that
            wrap past midnight are not supported).
    """
    w = np.ones(1440)
    for start, end, mult in peaks:
        s_dt, e_dt = [dt.datetime.strptime(x, "%H:%M") for x in (start, end)]
        s_m, e_m = s_dt.hour*60 + s_dt.minute, e_dt.hour*60 + e_dt.minute
        span = e_m - s_m
        if span <= 0:
            raise ValueError(f"peak window {start}-{end} must end after it starts")
        for mi in range(s_m, e_m):
            x = (mi - s_m)/span
            w[mi] += (mult - 1)*(math.sin(math.pi*x)**2)
    return w
MIN_W = minute_weights()

def orders_today(d, total, date):
    """
    Expected number of orders on simulation day ``d`` (of ``total``) for ``date``.

    Linearly interpolates from ``orders_day_1`` (day 0) to ``orders_last``
    (day ``total``), scales by a day-of-week factor, and applies a uniform
    ±NOISE jitter drawn from RAND.
    """
    # total == 0 when start and end fall on the same date; treat that as day 1.
    progress = d / total if total else 0.0
    base = CFG["orders_day_1"] + (CFG["orders_last"] - CFG["orders_day_1"]) * progress
    # Index by date.weekday() (Mon=0 … Sun=6) rather than strftime("%a"), whose
    # abbreviations follow the process locale and KeyError outside English.
    wd = (1, 1.05, 1.08, 1.10, 1.25, 1.35, 1.15)[date.weekday()]
    return base * wd * RAND.uniform(1 - NOISE, 1 + NOISE)

# ─────────────────────────────────────────────────────────────────────────
# 4.  ROUTING                                                            │
# ─────────────────────────────────────────────────────────────────────────
def shortest_route(lat,lon):
    """
    Route from the kitchen node GK_NODE to the graph node nearest ``(lat, lon)``.

    The directed drive graph is tried first. If it has no directed path, the
    route is computed on an undirected copy (i.e. one-way restrictions are
    ignored).

    Returns:
        (coords, dist): the path as a list of (lat, lon) tuples, and the sum of
        each hop's shortest parallel-edge "length".
    """
    target = ox.distance.nearest_nodes(GRAPH, lon, lat)
    graph = GRAPH
    try:
        path = nx.shortest_path(graph, GK_NODE, target, weight="length")
    except nx.NetworkXNoPath:
        graph = GRAPH.to_undirected()
        path = nx.shortest_path(graph, GK_NODE, target, weight="length")

    coords = [(graph.nodes[node]["y"], graph.nodes[node]["x"]) for node in path]
    dist = 0
    for u, v in zip(path, path[1:]):
        dist += min(edge["length"] for edge in graph[u][v].values())
    return coords, dist

# ─────────────────────────────────────────────────────────────────────────
# 5.  JSON WRITER (flat directory, no date partitions)
# ─────────────────────────────────────────────────────────────────────────
def write_event_json(row: Dict):
    """
    Write one event dict as its own JSON file directly under VOLUME_DIR.

    File name: ``<YYYYmmdd-HHMMSS.ffffff>-<event_id>.json``, built from the
    row's ``ts`` (format "%Y-%m-%d %H:%M:%S.%f") and ``event_id`` fields.
    NOTE(review): not called anywhere in the visible notebook; flush() writes batches instead.
    """
    event_ts = dt.datetime.strptime(row["ts"], "%Y-%m-%d %H:%M:%S.%f")
    stamp = event_ts.strftime("%Y%m%d-%H%M%S.%f")
    target = VOLUME_DIR / f"{stamp}-{row['event_id']}.json"
    target.write_text(json.dumps(row))

# ─────────────────────────────────────────────────────────────────────────
# 5b. EVENT QUEUE + BATCHED JSON WRITER + DATA-QUALITY                   │
# ─────────────────────────────────────────────────────────────────────────
# Live events wait here as (epoch_seconds, insertion_counter, row) tuples; the
# unique counter breaks timestamp ties so the row dicts are never compared.
EVENT_Q = asyncio.PriorityQueue()
CNT = 0
GK_ID = uuid.uuid4().hex   # random id stamped on every event as gk_id (new each run)
DQ = get("dq")             # per-event-type field -> null probability, used by maybe_corrupt

def maybe_corrupt(ev, payload):
    """
    Return a copy of ``payload`` in which each field may be replaced by None.

    The null probability for field ``k`` is ``DQ[ev][k]``, or 0 when the event
    type or the field is not configured. This simulates data-quality defects.
    """
    field_probs = DQ.get(ev, {})
    corrupted = {}
    for key, value in payload.items():
        corrupted[key] = None if RAND.random() < field_probs.get(key, 0) else value
    return corrupted

def enqueue(ts, ev, oid, seq, payload):
    """
    Push one live event onto EVENT_Q, prioritised by its timestamp.

    The row has the same fields as back-fill rows; ``body`` is the JSON-encoded
    payload after maybe_corrupt(). CNT is bumped first and used as a unique
    tie-breaker for equal timestamps.
    """
    global CNT
    CNT += 1
    row = {
        "event_id": uuid.uuid4().hex,
        "event_type": ev,
        "ts": ts.strftime("%Y-%m-%d %H:%M:%S.%f"),
        "gk_id": GK_ID,
        "location": LOC_NAME,
        "order_id": oid,
        "sequence": seq,
        "body": json.dumps(maybe_corrupt(ev, payload)),
    }
    EVENT_Q.put_nowait((ts.timestamp(), CNT, row))

def flush(rows):
    """
    Write ``rows`` as one JSON array file under VOLUME_DIR, named after the
    current local wall-clock time (``YYYYmmdd-HHMMSS.ffffff.json``).
    Empty input writes nothing.
    """
    if rows:
        stamp = dt.datetime.now().strftime("%Y%m%d-%H%M%S.%f")
        (VOLUME_DIR / f"{stamp}.json").write_text(json.dumps(rows))

async def consumer():
    """
    Drain EVENT_Q into batched JSON files via flush().

    A batch is written once it holds BATCH_ROWS rows, or once BATCH_SEC seconds
    have passed since the last write. While rows are buffered, the queue wait
    times out at that deadline, so a partial batch is still written when events
    stop arriving. When the task is cancelled, everything still buffered or
    queued is flushed before the cancellation propagates.

    NOTE(review): rows are written as soon as they are dequeued, not when their
    (possibly future) ``ts`` is reached — confirm that is intended.
    """
    buf, last = [], time.time()
    try:
        while True:
            # Only a non-empty buffer has a deadline; otherwise block for the next event.
            timeout = max(0.0, BATCH_SEC - (time.time() - last)) if buf else None
            try:
                _, _, row = await asyncio.wait_for(EVENT_Q.get(), timeout=timeout)
                buf.append(row)
            except asyncio.TimeoutError:
                pass
            if buf and (len(buf) >= BATCH_ROWS or (time.time() - last) >= BATCH_SEC):
                flush(buf); buf.clear(); last = time.time()
    finally:
        # Shutdown path: write out anything still buffered or left in the queue.
        while not EVENT_Q.empty():
            buf.append(EVENT_Q.get_nowait()[2])
        flush(buf)
# ─────────────────────────────────────────────────────────────────────────
# 6.  ORDER LIFECYCLE (real-time)                                        │
# ─────────────────────────────────────────────────────────────────────────
def MICRO(n):
    """Return a timedelta of ``n`` microseconds (used to nudge event timestamps apart)."""
    return dt.timedelta(microseconds=n)

def gauss(mu_sigma):
    """Draw RAND.gauss(*mu_sigma), floored at 0.1 (callers use it as a duration in minutes)."""
    sample = RAND.gauss(*mu_sigma)
    return sample if sample > 0.1 else 0.1

driver_cfg = get("driver_arrival")
def driver_arrival_time(created_at, t_ready, t_pick):
    """
    Sample when the driver reaches the kitchen, strictly before ``t_pick``.

    With probability ``driver_cfg["after_ready_pct"]`` the arrival falls in
    the window [t_ready, t_pick]; otherwise in [created_at, t_ready]. The
    position inside the window is a Beta(alpha, beta) fraction drawn from
    numpy's global RNG. The result is clamped to at most ``t_pick - 1µs``.
    """
    after_ready = RAND.random() < driver_cfg["after_ready_pct"]
    window_start, window_end = (t_ready, t_pick) if after_ready else (created_at, t_ready)
    frac = np.random.beta(driver_cfg["alpha"], driver_cfg["beta"])
    t_arr = window_start + (window_end - window_start) * frac
    # Keep driver_arrived strictly before driver_picked_up.
    return min(t_arr, t_pick - MICRO(1))

async def play_order(created_at, day, total):
    """
    Simulate one live order and enqueue its whole event lifecycle.

    Draws a customer, a basket and a route, then samples kitchen and hand-off
    durations from SVC. The events order_created, gk_started, gk_finished,
    gk_ready, driver_arrived, driver_picked_up, driver_ping… and delivered are
    enqueued with sequence numbers 0..n. Each timestamp is offset by ``seq``
    microseconds so that equal timestamps keep their lifecycle order in EVENT_Q.
    """
    order_id = uuid.uuid4().hex
    _, lat, lon, addr = rand_customer()
    items = rand_basket(day, total)
    route, dist = shortest_route(lat, lon)
    drive_min = dist / 1609.34 / GK_DRIVER_MPH * 60

    # Timeline: each stage lasts gauss(SVC[...]) minutes after the previous one.
    t_cs = created_at + dt.timedelta(minutes=gauss(SVC["cs"]))
    t_sf = t_cs + dt.timedelta(minutes=gauss(SVC["sf"]))
    t_ready = t_sf + dt.timedelta(minutes=gauss(SVC["fr"]))
    t_pick = t_ready + dt.timedelta(minutes=gauss(SVC["rp"]))
    t_drop = t_pick + dt.timedelta(minutes=drive_min)
    t_arr = driver_arrival_time(created_at, t_ready, t_pick)

    timeline = [
        (created_at, "order_created",
         dict(customer_lat=lat, customer_lon=lon, customer_addr=addr, items=items)),
        (t_cs, "gk_started", {}),
        (t_sf, "gk_finished", {}),
        (t_ready, "gk_ready", {}),
        (t_arr, "driver_arrived", {}),
        (t_pick, "driver_picked_up",
         dict(route_points=route, eta_mins=round(drive_min, 1))),
    ]

    # One ping every PING_SEC seconds while driving, placed along the route.
    hops = max(1, int(drive_min * 60 // PING_SEC))
    for h in range(1, hops):
        frac = h / hops
        ping_lat, ping_lon = route[int(frac * (len(route) - 1))]
        timeline.append((t_pick + dt.timedelta(seconds=h * PING_SEC), "driver_ping",
                         dict(progress_pct=round(frac * 100, 1), loc_lat=ping_lat, loc_lon=ping_lon)))
    timeline.append((t_drop, "delivered", dict(delivered_lat=lat, delivered_lon=lon)))

    for seq, (ts, ev, payload) in enumerate(timeline):
        enqueue(ts + MICRO(seq), ev, order_id, seq, payload)

# ─────────────────────────────────────────────────────────────────────────
# 7.  BACK-FILL (single write)                                           │
# ─────────────────────────────────────────────────────────────────────────
def build_rows(ts, lat, lon, addr, items, pts, dist, svc, day, total):
    """
    Build the complete event list for one back-filled order (nothing is queued).

    Uses the same lifecycle and timing model as play_order(), but takes a
    pre-chosen customer, basket and route, and writes timestamps as-is (no
    per-sequence microsecond offset). ``day`` and ``total`` are accepted but
    not used.

    Returns a list of row dicts in sequence order.
    """
    order_id = uuid.uuid4().hex
    drive = dist / 1609.34 / GK_DRIVER_MPH * 60
    t_cs = ts + dt.timedelta(minutes=gauss(svc["cs"]))
    t_sf = t_cs + dt.timedelta(minutes=gauss(svc["sf"]))
    t_ready = t_sf + dt.timedelta(minutes=gauss(svc["fr"]))
    t_pick = t_ready + dt.timedelta(minutes=gauss(svc["rp"]))
    t_drop = t_pick + dt.timedelta(minutes=drive)
    t_arr = driver_arrival_time(ts, t_ready, t_pick)

    events = [
        (ts, "order_created",
         dict(customer_lat=lat, customer_lon=lon, customer_addr=addr, items=items)),
        (t_cs, "gk_started", {}),
        (t_sf, "gk_finished", {}),
        (t_ready, "gk_ready", {}),
        (t_arr, "driver_arrived", {}),
        (t_pick, "driver_picked_up", dict(route_points=pts, eta_mins=round(drive, 1))),
    ]
    hops = max(1, int(drive * 60 // PING_SEC))
    for h in range(1, hops):
        frac = h / hops
        ping_lat, ping_lon = pts[int(frac * (len(pts) - 1))]
        events.append((t_pick + dt.timedelta(seconds=h * PING_SEC), "driver_ping",
                       dict(progress_pct=round(frac * 100, 1), loc_lat=ping_lat, loc_lon=ping_lon)))
    events.append((t_drop, "delivered", dict(delivered_lat=lat, delivered_lon=lon)))

    return [
        {
            "event_id": uuid.uuid4().hex,
            "event_type": ev,
            "ts": when.strftime("%Y-%m-%d %H:%M:%S.%f"),
            "gk_id": GK_ID,
            "location": LOC_NAME,
            "order_id": order_id,
            "sequence": seq,
            "body": json.dumps(maybe_corrupt(ev, payload)),
        }
        for seq, (when, ev, payload) in enumerate(events)
    ]

def gen_backfill(start, now, total_days):
    """
    Generate back-fill event rows for every simulated order arriving before ``now``.

    Days 0..total_days (inclusive) are counted from midnight of
    ``start.date()``; the time-of-day part of ``start`` is not used. For each
    day, per-minute Poisson arrival counts are drawn with rates proportional to
    MIN_W and summing to orders_today(). Each arrival gets a random second
    within its minute. Arrivals earlier than ``now`` get a customer, basket and
    route, and their full event list from build_rows().

    Returns a flat list of row dicts in generation order (grouped by order,
    not sorted by ts).
    """
    back=[]
    for d in range(total_days+1):
        date = start.date()+dt.timedelta(days=d)
        mean = orders_today(d,total_days,date)
        # Spread the day's expected order count across minutes in proportion to MIN_W.
        lam  = mean/MIN_W.sum()*MIN_W; midnight=dt.datetime.combine(date,dt.time.min)
        for m,v in enumerate(lam):
            for _ in range(np.random.poisson(v)):
                ts=midnight+dt.timedelta(minutes=m,seconds=RAND.randint(0,59))
                # Only past arrivals are back-filled; schedule() draws live (>= now) orders separately.
                if ts>=now: continue
                _,lat,lon,addr = rand_customer(); items=rand_basket(d,total_days)
                pts,dist=shortest_route(lat,lon)
                back.extend(build_rows(ts,lat,lon,addr,items,pts,dist,SVC,d,total_days))
    return back

# ─────────────────────────────────────────────────────────────────────────
# 8.  MAIN SCHEDULER                                                     │
# ─────────────────────────────────────────────────────────────────────────
async def schedule():
    """
    Run the simulator: back-fill past orders, then play future ones live.

    1. Every arrival before ``now`` (from START_TS's date through END_TS's
       date) is generated by gen_backfill() and written with one flush().
    2. consumer() is started as a task to batch live events from EVENT_Q.
    3. Each arrival at or after ``now`` gets a task that sleeps
       (ts - now) / SPEED_UP seconds and then calls play_order().

    Once all live orders have been played (or if scheduling fails), the
    consumer task is cancelled and awaited, so it is not left pending.
    """
    # NOTE(review): `now` is naive UTC, so start_ts/end_ts are presumably UTC too — confirm.
    now = dt.datetime.utcnow()
    total_days = (END_TS.date()-START_TS.date()).days

    flush(gen_backfill(START_TS, now, total_days))   # back-fill
    # Keep a reference: the event loop holds only weak references to tasks, so
    # an unreferenced task can be garbage-collected before it finishes.
    consumer_task = asyncio.create_task(consumer())  # live consumer

    async def later(ts, d):
        # Wait the simulated gap compressed by SPEED_UP, then play the order.
        await asyncio.sleep(max(0,(ts-now).total_seconds()/SPEED_UP))
        await play_order(ts, d, total_days)

    try:
        futures=[]
        for d in range(total_days+1):
            date = START_TS.date()+dt.timedelta(days=d)
            mean = orders_today(d,total_days,date)
            lam  = mean/MIN_W.sum()*MIN_W; midnight = dt.datetime.combine(date,dt.time.min)
            for m,v in enumerate(lam):
                for _ in range(np.random.poisson(v)):
                    ts=midnight+dt.timedelta(minutes=m,seconds=RAND.randint(0,59))
                    if ts>=now:
                        futures.append(asyncio.create_task(later(ts,d)))
        if futures: await asyncio.gather(*futures)
    finally:
        # Shut the consumer down explicitly rather than leaving it pending when
        # the cell returns; its CancelledError is expected and swallowed.
        consumer_task.cancel()
        try:
            await consumer_task
        except asyncio.CancelledError:
            pass

In [0]:
# Entry point: announce the output volume and speed-up, echo the full config,
# then run the simulator (top-level `await` works here because this is a notebook cell).
# NOTE(review): the config echo prints every SIM_CFG value — make sure it never contains secrets.
print(f"👻  GK-sim → {VOLUME_PATH}  (×{SPEED_UP})")
print(f"Running with config: {CFG}")
await schedule()