****
# Helpers: key normalization, nested-message flattening and coordinate parsing
****

In [1]:
import re
from typing import Dict, List

_CAMEL_RE = re.compile(r'(?<!^)(?=[A-Z])')
_SPECIALS = {
    "latDeg": "lat_deg",
    "lonDeg": "lon_deg",
    "trackNumber": "track_number",
    "flightId": "flight_id",
    "latLon": "lat_lon",
    "positionUpdated": "position_updated",
    "measurementId": "measurement_id",
    "transferType": "transfer_type", 
    "clearanceType": "clearance_type", 
    "flightId1": "flight_id1",
    "flightId2": "flight_id2",
    "modeUpdated": "mode_updated",
    "lengthSeconds":"length_seconds",
    "actionName": "action_name",
    "markType": "mark_type",
    "markVariant": "mark_variant",
    "markScope": "mark_scope",
    "markSet": "mark_set",
    "flightId": "flight_id",
    "trackNumber": "track_number",
}

def _to_snake(s): 
    if not isinstance(s,str): return s
    return _SPECIALS.get(s, _CAMEL_RE.sub('_', s).lower())

def _normalize_keys(obj):
    """Recursively rewrite every dict key in ``obj`` to snake_case via ``_to_snake``.

    Dicts and lists are rebuilt (the input is not mutated); any other value is
    returned unchanged.

    NOTE(review): this also rewrites the keys of protobuf ``map<string, ...>``
    fields, if the payload has any; confirm that no map keys carry data.
    """
    if isinstance(obj, dict):
        return {_to_snake(k): _normalize_keys(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_normalize_keys(x) for x in obj]
    return obj

def _flatten_mp(prefix, mp, *, log_conflicts=False):
    """MeasurementPoint for distance measurement-> flat dict under <prefix>_* with backfill from flight_id.track_number."""
    out = {}
    if not isinstance(mp, Dict):
        return out

    fi = mp.get("flight_id") or {}

    # --- track_number with backfill from nested flight_id.track_number ---
    tn_top = mp.get("track_number")
    tn_fi  = fi.get("track_number")
    tn = tn_top if tn_top is not None else tn_fi
    out[f"{prefix}_track_number"] = tn

    if log_conflicts and (tn_top is not None and tn_fi is not None and tn_top != tn_fi):
        out[f"{prefix}_track_number_conflict"] = True  # boolean flag

    # --- kind + lat/lon or flight_id expansion ---
    if isinstance(mp.get("lat_lon"), Dict):
        out[f"{prefix}_kind"] = "lat_lon"
        out[f"{prefix}_lat_deg"] = mp["lat_lon"].get("lat_deg")
        out[f"{prefix}_lon_deg"] = mp["lat_lon"].get("lon_deg")
    elif isinstance(fi, Dict) and len(fi):
        out[f"{prefix}_kind"] = "flight_id"
        for k, v in fi.items():
            out[f"{prefix}_flight_{k}"] = v
    else:
        out[f"{prefix}_kind"] = "unknown"

    return out

def _get(d, *names):
    """Safe getter that tries multiple key spellings (snake & camel)."""
    if not isinstance(d, Dict):
        return None
    for k in names:
        if k in d:
            return d[k]
    return None

_COORD_RE = re.compile(r"^(\d{2})(\d{2})([NS])(\d{3})(\d{2})([EW])$")  # e.g. 5700N01305E, 5624N01341E

def _norm_action_name(s):
    return None if not isinstance(s, str) else s.strip().upper().replace(" ", "_").replace("-", "_")

def _parse_compact_coord(s):
    """
    Parse strings like '5700N01305E' into (lat_deg, lon_deg).
    Returns (lat, lon) as floats or (None, None) if not matching.
    """
    if not isinstance(s, str):
        return None, None
    m = _COORD_RE.match(s.strip())
    if not m:
        return None, None
    lat_deg = int(m.group(1))
    lat_min = int(m.group(2))
    lat_hem = m.group(3)
    lon_deg = int(m.group(4))
    lon_min = int(m.group(5))
    lon_hem = m.group(6)

    lat = lat_deg + lat_min / 60.0
    lon = lon_deg + lon_min / 60.0
    if lat_hem == "S": lat = -lat
    if lon_hem == "W": lon = -lon
    return lat, lon

In [29]:
def rows_mouse_position(asd_dict):
    """Flatten a ``mouse_position`` ASD event into one row, or return None.

    Produces ``event_name`` plus ``mouse_position_x`` / ``mouse_position_y``.
    Returns None when the event has no ``mouse_position`` dict.

    NOTE(review): if x/y are proto3 scalars, MessageToDict (as called in
    iter_asd_events) omits zero values, so a coordinate of 0 appears here as
    None. Confirm against the proto definition.
    """
    prefix = "mouse_position"
    p = asd_dict.get(prefix)
    # Guard like the other extractors: a missing or non-dict payload means
    # "not this event type" rather than an AttributeError.
    if not isinstance(p, dict):
        return None
    return {
        "event_name": "mouse_position",
        f"{prefix}_x": p.get("x"),
        f"{prefix}_y": p.get("y"),
    }

def rows_track_screen_position(asd_dict):
    """Flatten a ``track_screen_position`` ASD event into one row, or return None.

    Produces ``event_name``, ``track_screen_position_{x,y,track_number,visible}``
    and one ``track_screen_position_flight_<key>`` column per ``flight_id`` field.
    ``track_number`` is backfilled from ``flight_id.track_number`` when the
    top-level value is missing, as the other extractors in this file do.
    Returns None when the event has no ``track_screen_position`` dict.
    """
    prefix = "track_screen_position"
    p = asd_dict.get(prefix)
    if not isinstance(p, dict):
        return None

    fi = p.get("flight_id")
    if not isinstance(fi, dict):
        fi = {}
    tn = p.get("track_number")
    if tn is None:
        tn = fi.get("track_number")

    row = {
        "event_name": "track_screen_position",
        f"{prefix}_x": p.get("x"),
        f"{prefix}_y": p.get("y"),
        f"{prefix}_track_number": tn,
        f"{prefix}_visible": p.get("visible"),
    }
    row.update({f"{prefix}_flight_{k}": v for k, v in fi.items()})
    return row

def rows_track_label_position(asd_dict):
    """Flatten a ``track_label_position`` ASD event into one row, or return None.

    Produces ``event_name`` and ``track_label_position_<field>`` for x, y,
    width, height, visible, hovered, selected, on_pip and track_number, plus
    one ``track_label_position_flight_<key>`` column per ``flight_id`` field.
    ``track_number`` is backfilled from ``flight_id.track_number`` when the
    top-level value is missing, as the other extractors in this file do.
    Returns None when the event has no ``track_label_position`` dict.
    """
    prefix = "track_label_position"
    p = asd_dict.get(prefix)
    if not isinstance(p, dict):
        return None

    fi = p.get("flight_id")
    if not isinstance(fi, dict):
        fi = {}
    tn = p.get("track_number")
    if tn is None:
        tn = fi.get("track_number")

    row = {"event_name": "track_label_position"}
    for field in ("x", "y", "width", "height", "visible", "hovered", "selected", "on_pip"):
        row[f"{prefix}_{field}"] = p.get(field)
    row[f"{prefix}_track_number"] = tn
    row.update({f"{prefix}_flight_{k}": v for k, v in fi.items()})
    return row

def rows_speed_vector(asd_dict):
    """Flatten a ``speed_vector`` ASD event into one row, or return None.

    Variants are checked in this order; the first one present wins:

    * ``mode_updated``: ``speed_vector_mode_name`` from ``mode_updated.mode``.
    * ``visibility``: ``speed_vector_track_number`` (top-level, or
      ``flight_id.track_number`` when that is falsy), ``speed_vector_visible``,
      ``speed_vector_visibility_event_type`` (``"set_true"`` / ``"set_false"``,
      or ``"touched"`` when ``visible`` is absent) and one
      ``speed_vector_flight_<key>`` per ``flight_id`` field.
    * ``length``: ``speed_vector_length_seconds``.

    Every row also carries ``event_name`` and ``speed_vector_variant``.
    Returns None when there is no ``speed_vector`` dict or no known variant.

    NOTE(review): if ``visible`` is a plain proto3 bool, MessageToDict omits it
    when False, so such events land in "touched" rather than "set_false".
    Confirm against the proto definition.
    """
    prefix = "speed_vector"
    sv = asd_dict.get("speed_vector")
    if not isinstance(sv, dict):
        return None

    def _sub(name):
        # Nested variant message; an unexpected non-dict value is treated as empty.
        val = sv.get(name)
        return val if isinstance(val, dict) else {}

    if "mode_updated" in sv:
        return {"event_name": "speed_vector", f"{prefix}_variant": "mode_updated",
                f"{prefix}_mode_name": _sub("mode_updated").get("mode")}

    if "visibility" in sv:
        vis_msg = _sub("visibility")
        fi = vis_msg.get("flight_id")
        if not isinstance(fi, dict):
            fi = {}
        tn = vis_msg.get("track_number") or fi.get("track_number")
        visible = vis_msg.get("visible")
        if visible is True:
            event_type = "set_true"
        elif visible is False:
            event_type = "set_false"
        else:
            event_type = "touched"
        row = {"event_name": "speed_vector", f"{prefix}_variant": "visibility",
               f"{prefix}_track_number": tn,
               f"{prefix}_visible": visible,
               f"{prefix}_visibility_event_type": event_type}
        # Distinct loop names: the old comprehension reused `v`, shadowing the
        # visibility payload.
        row.update({f"{prefix}_flight_{key}": val for key, val in fi.items()})
        return row

    if "length" in sv:
        return {"event_name": "speed_vector", f"{prefix}_variant": "length",
                f"{prefix}_length_seconds": _sub("length").get("length_seconds")}

    return None

def rows_popup(asd_dict):
    """Flatten a ``popup`` ASD event into one row, or return None if absent.

    Columns: ``event_name``, ``popup_name``, ``popup_opened``,
    ``popup_track_number`` (top-level value, or ``flight_id.track_number``
    when the top-level one is falsy) and ``popup_flight_<key>`` for every
    ``flight_id`` field.
    """
    popup = asd_dict.get("popup")
    if not isinstance(popup, dict):
        return None

    flight = popup.get("flight_id") or {}
    track_number = popup.get("track_number") or flight.get("track_number")

    row = {
        "event_name": "popup",
        "popup_name": popup.get("name"),
        "popup_opened": popup.get("opened"),
        "popup_track_number": track_number,
    }
    for key, value in flight.items():
        row[f"popup_flight_{key}"] = value
    return row

def rows_transfer(asd_dict):
    """Flatten a ``transfer`` ASD event into one row, or return None if absent.

    Columns: ``event_name``, ``transfer_type_name`` (the ``transfer_type``
    value copied as-is), ``transfer_track_number`` (top-level value, or
    ``flight_id.track_number`` when the top-level one is falsy) and
    ``transfer_flight_<key>`` for every ``flight_id`` field.
    """
    transfer = asd_dict.get("transfer")
    if not isinstance(transfer, dict):
        return None

    flight = transfer.get("flight_id") or {}
    row = {
        "event_name": "transfer",
        "transfer_type_name": transfer.get("transfer_type"),
        "transfer_track_number": transfer.get("track_number") or flight.get("track_number"),
    }
    row.update((f"transfer_flight_{key}", value) for key, value in flight.items())
    return row

def rows_clearance(asd_dict):
    """Flatten a ``clearance`` ASD event into one row, or return None if absent.

    Columns: ``event_name``, ``clearance_type``, ``clearance`` (both copied
    as-is), ``clearance_track_number`` (top-level value, or
    ``flight_id.track_number`` when the top-level one is falsy) and
    ``clearance_flight_<key>`` for every ``flight_id`` field.
    """
    clearance = asd_dict.get("clearance")
    if not isinstance(clearance, dict):
        return None

    flight = clearance.get("flight_id") or {}
    track_number = clearance.get("track_number")
    if not track_number:
        track_number = flight.get("track_number")

    row = dict(
        event_name="clearance",
        clearance_type=clearance.get("clearance_type"),
        clearance=clearance.get("clearance"),
        clearance_track_number=track_number,
    )
    for key, value in flight.items():
        row[f"clearance_flight_{key}"] = value
    return row

def rows_distance_measurement(asd_dict):
    """Flatten one DistanceMeasurement ASD event into a row dict, or return None.

    The payload is looked up as ``distance_measurement`` (or
    ``distanceMeasurement``). Variants are checked in the order ``added``,
    ``position_updated``, ``removed``; the first dict-valued one wins.

    Every column except ``event_name`` is prefixed with
    ``distance_measurement_`` (written ``<p>_`` below):

    * always: ``event_name`` (``"distance_measurement"``), ``<p>_change``,
      ``<p>_measurement_id``
    * ``added``: the MeasurementPoint columns produced by ``_flatten_mp`` for
      the non-empty ``first`` / ``second`` points, under ``<p>_first_*`` /
      ``<p>_second_*`` (``*_track_number``, ``*_kind``, and either
      ``*_lat_deg`` / ``*_lon_deg`` or ``*_flight_<key>``)
    * ``position_updated``: ``<p>_start_x``, ``<p>_start_y``, ``<p>_end_x``,
      ``<p>_end_y``
    * ``removed``: nothing extra

    NOTE(review): presumably consumed via SCHEMA_DISTANCE (defined elsewhere);
    confirm that its column names carry the same prefix.

    Returns None when the payload is missing or has none of the variants.
    """
    prefix = "distance_measurement"
    dm = _get(asd_dict, "distance_measurement", "distanceMeasurement")
    if not isinstance(dm, dict):
        return None

    row = {"event_name": "distance_measurement"}

    added = dm.get("added")
    if isinstance(added, dict):
        row[f"{prefix}_change"] = "added"
        row[f"{prefix}_measurement_id"] = _get(added, "measurement_id", "measurementId")
        # first / second MeasurementPoint (by lat_lon or flight_id). Absent or
        # empty points are skipped rather than emitted with an "unknown" kind.
        for side in ("first", "second"):
            point = added.get(side)
            if point:
                row.update(_flatten_mp(f"{prefix}_{side}", point))
        return row

    pos = _get(dm, "position_updated", "positionUpdated")
    if isinstance(pos, dict):
        row[f"{prefix}_change"] = "position_updated"
        row[f"{prefix}_measurement_id"] = _get(pos, "measurement_id", "measurementId")
        start = pos.get("start") or {}
        end = pos.get("end") or {}
        # _get tolerates a non-dict start/end and yields None.
        row[f"{prefix}_start_x"] = _get(start, "x")
        row[f"{prefix}_start_y"] = _get(start, "y")
        row[f"{prefix}_end_x"] = _get(end, "x")
        row[f"{prefix}_end_y"] = _get(end, "y")
        return row

    removed = dm.get("removed")
    if isinstance(removed, dict):
        row[f"{prefix}_change"] = "removed"
        row[f"{prefix}_measurement_id"] = _get(removed, "measurement_id", "measurementId")
        return row

    return None
    
def rows_sep_tool(asd_dict):
    """Flatten one Separation Tool (``sep_tool``) ASD event into a row dict.

    Returns None when there is no ``sep_tool`` / ``sepTool`` dict. Otherwise
    every column except ``event_name`` is prefixed with ``sep_tool_``:

    * always: ``event_name`` (``"sep_tool"``), ``sep_tool_type``,
      ``sep_tool_change``
    * ``opened``: ``sep_tool_opened_track_number`` and
      ``sep_tool_opened_flight_<key>``
    * ``connected``: ``sep_tool_connected_track_number_1`` / ``_2`` and
      ``sep_tool_connected_flight1_<key>`` / ``sep_tool_connected_flight2_<key>``
    * ``closed``: ``sep_tool_closed``
    * no recognised variant: ``sep_tool_change`` is None
    """
    prefix = "sep_tool"
    st = _get(asd_dict, "sep_tool", "sepTool")
    if not isinstance(st, dict):
        return None

    row = {"event_name": "sep_tool", f"{prefix}_type": st.get("type")}

    def _expand(tag, fi):
        # Copy every flight_id field under <prefix>_<tag>_<snake_key>.
        if isinstance(fi, dict):
            for k, v in fi.items():
                row[f"{prefix}_{tag}_{_to_snake(k)}"] = v

    # variant: opened
    opened = st.get("opened")
    if isinstance(opened, dict):
        row[f"{prefix}_change"] = "opened"
        fi = _get(opened, "flight_id", "flightId") or {}
        row[f"{prefix}_opened_track_number"] = _get(fi, "track_number", "trackNumber")
        _expand("opened_flight", fi)
        return row

    # variant: connected
    connected = st.get("connected")
    if isinstance(connected, dict):
        row[f"{prefix}_change"] = "connected"
        fi1 = _get(connected, "flight_id1", "flightId1") or {}
        fi2 = _get(connected, "flight_id2", "flightId2") or {}
        row[f"{prefix}_connected_track_number_1"] = _get(fi1, "track_number", "trackNumber")
        row[f"{prefix}_connected_track_number_2"] = _get(fi2, "track_number", "trackNumber")
        _expand("connected_flight1", fi1)
        _expand("connected_flight2", fi2)
        return row

    # variant: closed (boolean or dict)
    if "closed" in st:
        closed = st["closed"]
        row[f"{prefix}_change"] = "closed"
        # A dict payload is a message; its presence means the tool was
        # closed. bool() alone would turn an empty message ({}) into False.
        row[f"{prefix}_closed"] = True if isinstance(closed, dict) else bool(closed)
        return row

    # unknown shape
    row[f"{prefix}_change"] = None
    return row

def rows_route_interaction(asd_dict):
    """Flatten one ``route_interaction`` ASD event into a row dict, or return None.

    Every column except ``event_name`` (``"route_interaction"``) is prefixed
    with ``route_interaction_``:

    * ``action_type_raw`` and ``action_type_name`` (normalized with
      ``_norm_action_name``)
    * ``value``, plus ``value_lat_deg`` / ``value_lon_deg`` from
      ``_parse_compact_coord``
    * ``value_kind``: ``"coord"`` when the value parsed as a coordinate,
      ``"fix"`` for any other string, otherwise None
    * ``track_number``: top-level value, or ``flight_id.track_number`` when
      that is falsy
    * ``flight_<key>`` for every ``flight_id`` field
    """
    p = "route_interaction"
    payload = _get(asd_dict, "route_interaction", "routeInteraction")
    if not isinstance(payload, dict):
        return None

    action_raw = _get(payload, "action_type", "actionType")

    value = payload.get("value")
    lat, lon = _parse_compact_coord(value)
    if lat is not None:
        kind = "coord"
    elif isinstance(value, str):
        kind = "fix"
    else:
        kind = None

    flight = _get(payload, "flight_id", "flightId") or {}
    track_number = (_get(payload, "track_number", "trackNumber")
                    or _get(flight, "track_number", "trackNumber"))

    row = {
        "event_name": "route_interaction",
        f"{p}_action_type_raw": action_raw,
        f"{p}_action_type_name": _norm_action_name(action_raw),
        f"{p}_value": value,
        f"{p}_value_lat_deg": lat,
        f"{p}_value_lon_deg": lon,
        f"{p}_value_kind": kind,
        f"{p}_track_number": track_number,
    }
    if isinstance(flight, dict):
        row.update({f"{p}_flight_{_to_snake(k)}": v for k, v in flight.items()})
    return row

def rows_keyboard_shortcut(asd_dict):
    """Flatten one KeyboardShortcut ASD event into a row dict, or return None.

    Looks up ``keyboard_shortcut`` (or ``keyboardShortcut``) and emits:

    * ``event_name`` = ``"keyboard_shortcut"``
    * ``keyboard_action_name``: the raw ``action_name``
    * ``keyboard_action_name_norm``: that value normalized with
      ``_norm_action_name``

    NOTE(review): confirm SCHEMA_KEYBOARD (defined elsewhere) uses these
    ``keyboard_``-prefixed names.
    """
    shortcut = _get(asd_dict, "keyboard_shortcut", "keyboardShortcut")
    if not isinstance(shortcut, dict):
        return None

    action = _get(shortcut, "action_name", "actionName")
    row = {"event_name": "keyboard_shortcut"}
    row["keyboard_action_name"] = action
    row["keyboard_action_name_norm"] = _norm_action_name(action)
    return row

def rows_mark(asd_dict):
    """Flatten one track-mark ASD event into a row dict, or return None.

    The payload is looked up under ``track_mark``, ``trackMark`` or ``mark``
    (the first key present wins). Columns produced:

    * ``event_name`` = ``"track_mark"``
    * ``mark_type_raw``, ``mark_variant_raw``, ``mark_scope_raw``, ``mark_set``:
      copied as-is (these carry no ``track_mark_`` prefix)
    * ``mark_type_name``, ``mark_variant_name``, ``mark_scope_name``: the raw
      values normalized with ``_norm_action_name`` for grouping
    * ``mark_action``: ``"SET"`` if ``mark_set is True``, ``"UNSET"`` if it is
      ``False``, otherwise ``"TOUCH"``
    * ``track_mark_track_number``: top-level value, else ``flight_id.track_number``
    * ``track_mark_flight_<key>`` for every ``flight_id`` field

    NOTE(review): if ``mark_set`` is a plain proto3 bool, MessageToDict omits it
    when False, so an unset would be labelled "TOUCH". Confirm against the proto.
    """
    prefix = "track_mark"
    m = _get(asd_dict, "track_mark", "trackMark", "mark")
    if not isinstance(m, dict):
        return None

    out = {"event_name": "track_mark"}

    # raw values
    out["mark_type_raw"]    = _get(m, "mark_type", "markType")
    out["mark_variant_raw"] = _get(m, "mark_variant", "markVariant")
    out["mark_scope_raw"]   = _get(m, "mark_scope", "markScope")
    out["mark_set"]         = _get(m, "mark_set", "markSet")

    # normalized (for grouping)
    out["mark_type_name"]    = _norm_action_name(out["mark_type_raw"])
    out["mark_variant_name"] = _norm_action_name(out["mark_variant_raw"])
    out["mark_scope_name"]   = _norm_action_name(out["mark_scope_raw"])

    # derive an action label; identity checks so only real bools map to SET/UNSET
    if out["mark_set"] is True:
        out["mark_action"] = "SET"
    elif out["mark_set"] is False:
        out["mark_action"] = "UNSET"
    else:
        out["mark_action"] = "TOUCH"

    # track number: top-level, else from flight_id.track_number
    tn = _get(m, "track_number", "trackNumber")
    fi = _get(m, "flight_id", "flightId") or {}
    if tn is None and isinstance(fi, dict):
        tn = _get(fi, "track_number", "trackNumber")
    out[f"{prefix}_track_number"] = tn

    # expand flight_id.* -> track_mark_flight_*
    if isinstance(fi, dict):
        for k, v in fi.items():
            out[f"{prefix}_flight_{_to_snake(k)}"] = v

    return out

# Row extractors run, in this order, on every decoded ASD event by
# load_selected_asd. Each one takes the key-normalized event dict and returns
# a flat row dict, or None when the event is not of its type.
EXTRACTORS = [
    rows_mouse_position,
    rows_track_screen_position,
    rows_track_label_position,
    rows_speed_vector,
    rows_popup,
    rows_transfer,
    rows_clearance,
    rows_distance_measurement,
    rows_sep_tool,
    rows_route_interaction,
    rows_keyboard_shortcut,
    rows_mark,
]


****
# Pull raw ASD events as dicts from SQLite
****

In [30]:
import sqlite3
from pathlib import Path
from aware_protos.aware.proto import messages_pb2
from google.protobuf.json_format import MessageToDict

def iter_asd_events(db_path: Path, start_ms: int, end_ms: int, batch=50_000):
    """Yield ``(epoch_ms, asd_event_dict)`` for ASD events in a time window.

    Opens the SQLite database read-only and selects rows of the ``events``
    table with ``start_ms <= epoch_ms <= end_ms``, in ascending ``epoch_ms``
    order. Each ``payload`` blob is decoded as ``messages_pb2.Event``; only
    events whose ``payload`` oneof is ``asd_event`` are yielded.

    Args:
        db_path: Path (or str) to the SQLite file.
        start_ms: Inclusive lower bound in epoch milliseconds.
        end_ms: Inclusive upper bound in epoch milliseconds.
        batch: Number of rows fetched per ``fetchmany`` call.

    Yields:
        ``(int(epoch_ms), dict)`` where the dict is the ASD event converted
        with ``MessageToDict(..., preserving_proto_field_name=True)`` and then
        passed through ``_normalize_keys``.

    NOTE(review): MessageToDict by default omits fields that hold their
    default value (0, False, ""), so extractors see those fields as missing
    (None). Confirm this is acceptable downstream.
    """
    # Build a proper file: URI so paths containing '?', '#', '%' or spaces are
    # not misparsed as URI syntax.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro&immutable=1"
    con = sqlite3.connect(uri, uri=True)
    # Everything after connect() is inside try/finally so the connection is
    # closed even if a PRAGMA or the query itself fails.
    try:
        con.text_factory = bytes
        con.execute("PRAGMA query_only=ON")
        con.execute("PRAGMA mmap_size=268435456")
        con.execute("PRAGMA temp_store=MEMORY")

        sql = ('SELECT epoch_ms, payload FROM "events" '
               'WHERE epoch_ms BETWEEN ? AND ? ORDER BY epoch_ms')
        cur = con.execute(sql, (start_ms, end_ms))
        # fetchmany returns [] once the cursor is exhausted.
        for rows in iter(lambda: cur.fetchmany(batch), []):
            for ms, blob in rows:
                ev = messages_pb2.Event()
                ev.ParseFromString(blob)
                if ev.WhichOneof("payload") != "asd_event":
                    continue
                # preserving_proto_field_name=True keeps the .proto field names
                # (not lowerCamelCase JSON names); _normalize_keys then
                # snake_cases any remaining camelCase keys.
                d = MessageToDict(ev.asd_event, preserving_proto_field_name=True)
                yield int(ms), _normalize_keys(d)
    finally:
        con.close()

In [31]:
# import sys
# from utils.build_raw_inputs import find_scenarios, find_et_tsv, find_sim_db, build_et_frame

# root = "/store/kruu/eye_tracking/training_data"
# scenarios = find_scenarios(Path(root))

# i = 0

# for _, _, scen_dir in scenarios:
#     et = find_et_tsv(scen_dir)
#     db = find_sim_db(scen_dir)
#     if not et or not db:
#         print(f"[skip] Missing ET or DB in: {scen_dir}", file=sys.stderr)
#         continue
    
#     try:
#         df_et = build_et_frame(et)
#         if df_et.empty:
#             print(f"[warn] ET slice empty in {et}", file=sys.stderr)
#             continue
        
#         iter_asd = iter_asd_events(db, int(df_et["epoch_ms"].min()), int(df_et["epoch_ms"].max()))
#         if i == 10:
#             break
#         i+=1
#     except Exception as e:
#         print(f"[error] {scen_dir}: {e}", file=sys.stderr)
#         continue



# count_total = 0
# count_dm = 0
# examples = []

# for ms, asd in iter_asd:
#     count_total += 1

#     r_dm = rows_keyboard_shortcut(asd)
#     if r_dm:
#         count_dm += 1
#         r_dm["epoch_ms"] = ms
#         examples.append(r_dm)
#         # Stop early after a few to inspect structure
#         if len(examples) < 5:
#             print("\n--- Example event ---")
#             print(r_dm)

# print(f"\nDecoded {count_dm} events of {count_total} total ASD events")

****
# Merging
****

In [None]:
import pandas as pd
from zoneinfo import ZoneInfo
# Local time zone for converting epoch-ms timestamps to wall-clock time. In the
# code visible here it is only used by the commented-out merge helper.
TZ = ZoneInfo("Europe/Zagreb")

def load_selected_asd(db_path: Path, start_ms: int, end_ms: int) -> pd.DataFrame:
    """Decode ASD events in ``[start_ms, end_ms]`` into one row per extracted event.

    Every extractor in ``EXTRACTORS`` runs on every event yielded by
    ``iter_asd_events``. Each non-empty result becomes a row tagged with
    ``epoch_ms``. Rows are sorted by ``epoch_ms`` with a stable sort, so rows
    sharing a timestamp keep their database/extractor order.

    Column dtypes are tidied for common per-event fields. Extractors prefix
    their columns (e.g. ``mouse_position_x``, ``popup_opened``), so a column is
    matched when its name equals a field name or ends with ``_<field>``:

    * ``x, y, width, height, track_number, length_seconds``: converted to
      numeric. The result is nullable ``Int64`` when every non-null value is
      integral, and stays float otherwise, so no precision is lost.
    * ``visible, opened``: nullable ``boolean``.

    Returns:
        The combined DataFrame, or an empty DataFrame when nothing was extracted.
    """
    rows = []
    for ms, asd in iter_asd_events(db_path, start_ms, end_ms):
        for extract in EXTRACTORS:
            row = extract(asd)
            if row:
                row["epoch_ms"] = ms
                rows.append(row)
    if not rows:
        return pd.DataFrame()

    df = (pd.DataFrame(rows)
          .sort_values("epoch_ms", kind="stable")
          .reset_index(drop=True))

    int_fields = ("x", "y", "width", "height", "track_number", "length_seconds")
    bool_fields = ("visible", "opened")

    def _matches(col, fields):
        # Exact name, or a prefixed column such as "<event>_<field>".
        return any(col == f or col.endswith(f"_{f}") for f in fields)

    for col in df.columns:
        if _matches(col, int_fields):
            num = pd.to_numeric(df[col], errors="coerce")
            non_null = num.dropna()
            # Cast only when lossless; x % 1 is NaN for inf, which keeps float.
            if (non_null % 1 == 0).all():
                num = num.astype("Int64")
            df[col] = num
        elif _matches(col, bool_fields):
            df[col] = df[col].map(lambda v: None if pd.isna(v) else bool(v)).astype("boolean")
    return df

# ET and mouse features are engineered separately, so the two streams do not
# need to be merged: it is enough that both cover the same time window.

# def merge_et_mouse_asd(df_et: pd.DataFrame, df_asd: pd.DataFrame, out_parquet: Path, tol_ms=8):
#     dfe = df_et.astype({"epoch_ms":"int64"}).sort_values("epoch_ms")
#     # Merging mouse data to the nearest ET observation based on many to one:
#     # For each ET observation we take the closest mouse observation within tol_ms
#     # One mouse observation might be associated to several ET, and other to none.
#     # Other ASD events are saved separately
#     dfm = df_asd.astype({"epoch_ms":"int64"}).query("event_name=='mouse_position'")[["epoch_ms","mouse_position_x","mouse_position_y"]]
#     merged = pd.merge_asof(dfe, dfm, on="epoch_ms", direction="nearest", tolerance=tol_ms)

#     ts_utc = pd.to_datetime(merged["epoch_ms"], unit="ms", utc=True)
#     merged["ts_cet"] = ts_utc.dt.tz_convert(TZ)

#     # out_parquet.parent.mkdir(parents=True, exist_ok=True)
#     # merged.to_parquet(out_parquet, index=False)
#     return merged


In [43]:
import sys
from utils.build_raw_inputs import find_scenarios, find_et_tsv, find_sim_db, build_et_frame

root = "/store/kruu/eye_tracking/training_data"
scenarios = find_scenarios(Path(root))
# Fail fast with a clear message instead of an IndexError / TypeError further down.
if not scenarios:
    raise RuntimeError(f"No scenarios found under {root}")

scen_dir = scenarios[0][2]
et = find_et_tsv(scen_dir)
db = find_sim_db(scen_dir)
if not et or not db:
    raise FileNotFoundError(f"Missing ET or DB in: {scen_dir}")

df_et = build_et_frame(et)
if df_et.empty:
    raise ValueError(f"ET slice empty in {et}")

# Load only the ASD events that fall inside the eye-tracking time window.
df_asd = load_selected_asd(db, int(df_et["epoch_ms"].min()), int(df_et["epoch_ms"].max()))

  df = pd.read_csv(tsv_path, sep="\t")


In [None]:
# df_mouse = (
#     df_asd.query("event_name == 'mouse_position'")
#           .rename(columns={
#               "mouse_position_x": "Mouse position X",
#               "mouse_position_y": "Mouse position Y",
#           })
#           [["epoch_ms", "Mouse position X", "Mouse position Y"]]
# )
# merged = merge_and_write(df_et, df_mouse, out_parquet, tol_ms=8)