# Data Exploration Notebook
Goal: analyze the processed trajectory data in `all_trajs.pkl` to understand its structure and contents, extract insights, and relate it back to the original raw data.

#### Imports & Pandas display options

In [2]:
from __future__ import annotations
import os, re, json, pickle, random, math
from typing import Dict, List, Iterable, Tuple, Any

import numpy as np
import pandas as pd

# Display options for rich DataFrame output: row cap, column setting and line width.
pd.set_option(
    "display.max_rows", 50,
    "display.max_columns", 0,
    "display.width", 120,
)

# Seeded, module-wide random generator so sampling cells repeat on Restart & Run All.
RNG = random.Random(42)


#### Input file paths

In [3]:
# Path to the processed pickle holding per-driver trajectory data (relative to the notebook's working dir).
PKL_PATH: str = "all_trajs.pkl"

#### Helpers to load data

In [4]:
def load_all_trajs(pkl_path: str = PKL_PATH) -> Dict[Any, List[List[List[float]]]]:
    """
    Load the processed trajectory pickle from `pkl_path`.

    Args:
        pkl_path: Path to the pickle file (defaults to PKL_PATH).

    Returns:
        Whatever object the pickle contains. Expected structure (checked in later
        cells, not here): dict[driver_id] -> list[trajectory] -> list[state_action_vector length 126]

    Raises:
        FileNotFoundError: if `pkl_path` is empty or does not point to a regular file
            (a directory path is rejected here instead of failing later in `open`).

    Security: `pickle.load` can execute arbitrary code — only load files you trust.
    """
    if not pkl_path or not os.path.isfile(pkl_path):
        raise FileNotFoundError(
            f"The pickle file could not be loaded: {pkl_path!r} is not an existing file."
        )

    with open(pkl_path, "rb") as f:
        return pickle.load(f)

data = load_all_trajs(PKL_PATH)

# Number of drivers and the first 10 driver ids (per-driver counts are shown in the next section).
len(data), list(data)[:10]

(50, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

#### Basic shape checks

In [6]:
def describe_top_level(d: Dict[Any, Any], preview_n: int = 10):
    """
    Print a short structural summary of the top-level driver dict.

    Args:
        d: Mapping of driver id -> trajectories container.
        preview_n: How many drivers to show in the key sample and in the per-driver preview.

    Raises:
        AssertionError: if `d` is not a dict.
    """
    assert isinstance(d, dict), f"Top-level should be dict, got {type(d)}"
    preview_keys = list(d.keys())[:preview_n]
    print(f"Top-level type: {type(d).__name__}")
    print(f"Number of drivers (top-level keys): {len(d):,}")
    print(f"Sample keys (drivers): {preview_keys}")

    # Per-driver trajectory counts; None when the value has no length.
    print("\nValue preview (# of trajectories for each driver):")
    for k in preview_keys:
        v = d[k]
        n_trajs = len(v) if hasattr(v, "__len__") else None
        print(f"  - Driver {k!r}: # of trajectories={n_trajs}")

describe_top_level(d=data, preview_n=10)  # structural overview of the loaded dict


Top-level type: dict
Number of drivers (top-level keys): 50
Sample keys (drivers): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

Value preview (# of trajectories for each driver):
  - Driver 0: # of trajectories=1013
  - Driver 1: # of trajectories=797
  - Driver 2: # of trajectories=998
  - Driver 3: # of trajectories=667
  - Driver 4: # of trajectories=940
  - Driver 5: # of trajectories=1350
  - Driver 6: # of trajectories=899
  - Driver 7: # of trajectories=1038
  - Driver 8: # of trajectories=954
  - Driver 9: # of trajectories=616


#### Utility iterators (for memory-friendly sampling)

In [7]:
def iter_driver_ids(d: Dict[Any, Any]) -> Iterable[Any]:
    """Lazily yield every top-level key (driver id) of `d`, in dict order."""
    yield from d.keys()

def iter_trajectories(d: Dict[Any, List[List[List[float]]]],
                      driver_id: Any) -> Iterable[List[List[float]]]:
    """Lazily yield each trajectory stored under `driver_id` (lookup happens on first iteration)."""
    yield from d[driver_id]

def iter_steps(traj: List[List[float]]) -> Iterable[List[float]]:
    """Lazily yield each step vector of a single trajectory."""
    yield from traj

def sample_steps(d: Dict[Any, Any], max_steps: int = 20000,
                 max_trajs_per_driver: int = 20,
                 max_steps_per_traj: int = 200,
                 rng: random.Random | None = None) -> List[Tuple[Any, int, int, List[float]]]:
    """
    Collect up to `max_steps` step vectors from a shuffled subset of drivers/trajectories.

    This is NOT a uniform random sample of steps:
      * drivers are visited in shuffled order,
      * at most `max_trajs_per_driver` randomly chosen trajectories are taken per driver,
      * only the first `max_steps_per_traj` steps of each trajectory are kept (a contiguous prefix),
      * collection stops as soon as `max_steps` vectors are gathered, so drivers visited
        early are over-represented and later ones may contribute nothing.

    Args:
        d: Mapping driver_id -> list of trajectories (each a list of step vectors).
           Non-list drivers/trajectories are skipped.
        max_steps: Maximum number of vectors to return; <= 0 returns an empty list.
        max_trajs_per_driver: Cap on trajectories drawn per driver.
        max_steps_per_traj: Cap on leading steps taken from each trajectory.
        rng: Random generator to use; defaults to the module-level RNG (advancing its state).

    Returns:
        List of tuples (driver_id, traj_idx, step_idx, vector), where traj_idx and
        step_idx index into d[driver_id].
    """
    if max_steps <= 0:
        # Original loop appended before checking the limit, so it returned 1 item here.
        return []
    rng = RNG if rng is None else rng

    sampled: List[Tuple[Any, int, int, List[float]]] = []
    driver_ids = list(d.keys())
    rng.shuffle(driver_ids)

    for driver_id in driver_ids:
        trajs = d[driver_id]
        if not isinstance(trajs, list):
            continue
        # Shuffle trajectory indices and keep a small subset per driver to diversify.
        traj_idxs = list(range(len(trajs)))
        rng.shuffle(traj_idxs)
        for ti in traj_idxs[:max_trajs_per_driver]:
            traj = trajs[ti]
            if not isinstance(traj, list):
                continue
            for si, vec in enumerate(traj[:max_steps_per_traj]):
                sampled.append((driver_id, ti, si, vec))
                if len(sampled) >= max_steps:
                    return sampled
    return sampled


#### Validate nested structure & vector lengths on a sample

In [8]:
sample = sample_steps(data, max_steps=10000)  # adjust as needed
print(f"Sampled steps: {len(sample):,}")

lengths = []
bad_types = 0
# Counts VECTORS that contain at least one non-numeric value (not individual entries);
# the scan of each vector stops at the first offending value.
non_numeric_positions = 0

for (driver_id, ti, si, vec) in sample:
    if not isinstance(vec, (list, tuple, np.ndarray)):
        bad_types += 1
        continue
    lengths.append(len(vec))
    if any(not isinstance(x, (int, float, np.integer, np.floating)) for x in vec):
        non_numeric_positions += 1

print(f"Distinct vector lengths in sample: {sorted(set(lengths))}")
print(f"Count len==126: {sum(n == 126 for n in lengths):,} / {len(lengths):,}")
print(f"Vectors containing a non-numeric entry: {non_numeric_positions}")
print(f"Unexpected vector container types: {bad_types}")


Sampled steps: 10,000
Distinct vector lengths in sample: [126]
Count len==126: 10,000 / 10,000
Non-numeric vector entries encountered (early-stopped per vec): 0
Unexpected vector container types: 0


#### Action Analysis (last element) & basic stats

In [9]:
def extract_actions(sampled: List[Tuple[Any, int, int, List[float]]]) -> np.ndarray:
    """
    Pull the last element of every usable sampled vector into a float array.

    Entries whose vector is not a list/tuple/ndarray, or is empty, are skipped.
    An empty input (or no usable vectors) yields an empty float array.
    """
    last_elems = [
        vec[-1]
        for _driver, _traj, _step, vec in sampled
        if isinstance(vec, (list, tuple, np.ndarray)) and len(vec) > 0
    ]
    return np.array(last_elems, dtype=float)

actions = extract_actions(sample)
unique_actions = np.unique(actions)
print(f"Unique action values (sampled): {unique_actions[:50]}")
print(f"Action value min/max: {actions.min() if actions.size else None} / {actions.max() if actions.size else None}")

# Original hypothesis: 9 discrete actions (0..8). The recorded output contradicts it
# (values reach 18), so keep the number only as a reference point.
expected_num_actions = 9
print(f"Observed number of distinct actions in sample: {len(unique_actions)} (expected ~{expected_num_actions})")

# If actions are non-negative integer codes, list the codes in [0, max] that never appear.
# NOTE: sample_steps draws from a subset of drivers/trajectories, so "absent from the sample"
# does not imply "absent from the dataset".
if actions.size and np.all(actions >= 0) and np.all(actions == np.floor(actions)):
    max_code = int(actions.max())
    missing_codes = np.setdiff1d(np.arange(max_code + 1), unique_actions.astype(int))
    print(f"Integer action codes in [0, {max_code}] absent from sample: {missing_codes.tolist()}")


Unique action values (sampled): [ 0.  1.  2.  3.  4.  5.  6.  7.  9. 10. 11. 12. 13. 14. 15. 16. 17. 18.]
Action value min/max: 0.0 / 18.0
Observed number of distinct actions in sample: 18 (expected ~9)


#### Quick per-driver & per-trajectory sanity stats (sampled drivers only)

In [10]:
def quick_stats(d: Dict[Any, Any], max_drivers: int = 20,
                rng: random.Random | None = None) -> pd.DataFrame:
    """
    Summarize trajectory counts and trajectory lengths for a random subset of drivers.

    Args:
        d: Mapping driver_id -> list of trajectories.
        max_drivers: Maximum number of drivers (rows) to include.
        rng: Random generator used to pick drivers. Defaults to the module-level RNG;
             that default advances the shared RNG, so which drivers appear depends on
             how many RNG-consuming cells (e.g. sample_steps) ran before.

    Returns:
        DataFrame with one row per selected driver and columns driver_id, n_traj,
        traj_len_min, traj_len_p50, traj_len_p90, traj_len_max. Length stats are NaN
        when the driver's value is not a list or holds no list trajectories; n_traj
        is NaN when the value is not a list.
    """
    rng = RNG if rng is None else rng
    driver_ids = list(d.keys())
    rng.shuffle(driver_ids)

    rows = []
    for driver_id in driver_ids[:max_drivers]:
        trajs = d[driver_id]
        is_list = isinstance(trajs, list)
        # Only measure list containers: iterating arbitrary values can raise (e.g. None)
        # and would contradict the NaN reported for n_traj.
        lengths = [len(t) for t in trajs if isinstance(t, list)] if is_list else []
        rows.append(dict(
            driver_id=driver_id,
            n_traj=len(trajs) if is_list else np.nan,
            traj_len_min=min(lengths) if lengths else np.nan,
            traj_len_p50=float(np.median(lengths)) if lengths else np.nan,
            traj_len_p90=float(np.percentile(lengths, 90)) if lengths else np.nan,
            traj_len_max=max(lengths) if lengths else np.nan,
        ))
    return pd.DataFrame(rows)

# Trajectory-count and trajectory-length summary for up to 20 shuffled drivers.
df_quick = quick_stats(d=data, max_drivers=20)
df_quick


Unnamed: 0,driver_id,n_traj,traj_len_min,traj_len_p50,traj_len_p90,traj_len_max
0,41,914,1,15.0,38.0,45
1,31,815,1,15.0,34.0,45
2,40,953,2,14.0,32.0,45
3,36,987,1,13.0,28.0,45
4,19,1156,1,8.0,19.0,45
5,24,462,1,14.0,34.0,45
6,49,972,1,14.0,30.0,45
7,10,1181,1,7.0,13.0,39
8,20,960,1,14.0,28.0,45
9,38,425,1,22.0,45.0,45


#### Missing / NaN check on a small flattened slice

In [11]:
def flatten_sample(sampled: List[Tuple[Any, int, int, List[float]]], max_rows: int = 5000) -> pd.DataFrame:
    """
    Build a wide DataFrame from the first `max_rows` sampled steps.

    Columns: [driver_id, traj_idx, step_idx] + f0..f124 + action for length-126 vectors.
    A vector of length n >= 2 contributes min(125, n-1) feature columns plus its last
    element as `action`; shorter vectors contribute only the id columns, and columns a
    row lacks show up as NaN. Entries whose vector is not a list/tuple/ndarray are dropped.
    """
    records: List[Dict[str, Any]] = []
    for driver_id, traj_idx, step_idx, vec in sampled[:max_rows]:
        if isinstance(vec, (list, tuple, np.ndarray)):
            record: Dict[str, Any] = {
                "driver_id": driver_id,
                "traj_idx": traj_idx,
                "step_idx": step_idx,
            }
            n = len(vec)
            if n > 1:
                n_feats = min(125, n - 1)
                record.update({f"f{j}": vec[j] for j in range(n_feats)})
                record["action"] = vec[-1]
            records.append(record)
    return pd.DataFrame(records)

# Flatten the first 5,000 sampled steps into a wide table (ids + state features + action).
df_flat = flatten_sample(sampled=sample, max_rows=5000)
print(df_flat.shape)
df_flat.head()


(5000, 129)


Unnamed: 0,driver_id,traj_idx,step_idx,f0,f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12,f13,f14,f15,f16,f17,f18,f19,f20,f21,f22,f23,f24,f25,f26,f27,f28,f29,f30,f31,f32,f33,f34,f35,f36,...,f86,f87,f88,f89,f90,f91,f92,f93,f94,f95,f96,f97,f98,f99,f100,f101,f102,f103,f104,f105,f106,f107,f108,f109,f110,f111,f112,f113,f114,f115,f116,f117,f118,f119,f120,f121,f122,f123,f124,action
0,25,92,0,22,30,247,1,5,11,19,11,24,26,12,17,22,16,11,26,16,9,13,22,10,23,12,12,13,45.0,22.0,7.0,2.0,4.0,14.0,25.0,18.0,13.0,2.0,23.0,23.0,...,0.004318,0.005518,0.004669,0.004729,0.003824,0.00487,0.00379,0.004129,0.00412,0.005222,0.004655,0.005105,0.002013,0.003442,38.070856,41.244412,34.529167,9.745833,7.142857,12.528395,9.57801,7.209028,11.692857,9.093889,11.393171,8.807504,33.495698,13.330108,23.429692,12.596266,24.226293,28.69918,10.777171,14.437312,16.304264,10.477764,14.59697,26.693467,19.374369,4
1,25,92,1,22,29,247,1,4,12,20,12,23,25,13,18,21,17,12,25,15,10,14,21,11,22,13,13,14,5.0,45.0,22.0,7.0,2.0,1.0,14.0,25.0,18.0,13.0,5.0,23.0,...,0.005303,0.004318,0.005518,0.004669,0.004709,0.003824,0.00487,0.00379,0.004129,0.003141,0.005222,0.004655,0.005105,0.002013,6.766667,38.070856,41.244412,34.529167,9.745833,15.875423,12.528395,9.57801,7.209028,11.692857,11.501501,11.393171,8.807504,33.495698,13.330108,13.366852,12.596266,24.226293,28.69918,10.777171,20.661331,16.304264,10.477764,14.59697,26.693467,9
2,25,174,0,26,30,25,4,9,15,23,15,26,30,16,21,26,20,15,30,20,13,17,26,12,27,16,16,17,2.0,32.0,27.0,9.0,4.0,31.0,3.0,2.0,8.0,5.0,83.0,48.0,...,0.006567,0.002945,0.006198,0.000165,0.001771,0.003032,0.007467,0.018903,0.002492,0.017672,0.012179,0.003485,0.0114,0.000199,19.868954,12.787893,81.66985,1.0,51.973333,44.406403,19.502037,26.822222,25.275707,29.411905,39.894055,5.72037,45.516667,7.841667,98.410952,47.653333,50.416631,5.175,0.433333,60.0,2.0,12.491667,5.933333,0.0,46.07,18
3,25,174,1,26,30,26,4,9,15,23,15,26,30,16,21,26,20,15,30,20,13,17,26,12,27,16,16,17,10.0,24.0,17.0,20.0,2.0,23.0,4.0,2.0,4.0,21.0,74.0,37.0,...,0.004807,0.003681,0.006205,0.000177,0.002984,0.005458,0.008681,0.00968,0.003674,0.021995,0.010579,0.003925,0.011874,0.000153,49.464467,26.955871,116.479076,4.825,46.664423,35.540608,8.117033,0.0,43.014622,10.675397,49.00427,23.760833,39.2375,4.866667,58.082418,18.355556,18.498889,10.894444,0.0,25.091111,0.0,10.622222,4.416667,0.0,104.279365,9
4,25,406,0,26,32,237,5,11,13,21,15,28,32,16,19,28,18,13,32,22,13,15,28,10,29,16,14,17,8.0,7.0,10.0,7.0,0.0,3.0,13.0,7.0,1.0,0.0,9.0,0.0,...,0.004756,0.000537,0.007761,0.011197,0.004594,0.00731,0.006475,0.010013,0.010117,0.007092,0.008096,0.001506,0.002518,0.020104,18.079081,26.189357,10.263462,34.20463,8.321979,8.811111,13.019976,17.378947,4.45,1.1,12.411111,5.939167,145.915873,19.533333,0.916667,7.176508,31.416667,33.463889,0.366667,5.491606,6.869444,3.4,21.431667,4.956032,0.0,6


#### NaN/inf audit & quick summary of vector lengths

In [12]:
nan_counts = df_flat.isna().sum().sort_values(ascending=False)
has_nan = int(nan_counts.sum()) > 0
print(f"Any NaNs in flattened sample? {has_nan}")
if has_nan:
    print(nan_counts[nan_counts > 0].head(20))

# isna() does not flag +/-inf, so audit numeric columns for infinities explicitly.
numeric_cols = df_flat.select_dtypes(include="number")
inf_counts = np.isinf(numeric_cols).sum().sort_values(ascending=False)
has_inf = int(inf_counts.sum()) > 0
print(f"Any infs in flattened sample? {has_inf}")
if has_inf:
    print(inf_counts[inf_counts > 0].head(20))

# Vector length distribution in the sampled slice
vc_lengths = pd.Series([len(v) for (_, _, _, v) in sample]).value_counts().sort_index()
vc_lengths


Any NaNs in flattened sample? False


126    10000
Name: count, dtype: int64

#### (Optional) Assert core invariants so the notebook fails early if something is off

In [13]:
# Invariant 1 — the top level is a dict holding at least one driver.
assert isinstance(data, dict) and bool(data), "Top-level data must be a non-empty dict"

# Invariant 2 — every driver maps to a list (the trajectories container).
bad = [driver for driver, trajs in data.items() if not isinstance(trajs, list)]
assert len(bad) == 0, f"Some drivers have non-list trajectory containers: {bad[:5]}"

# Invariant 3 — more than 80% of sampled step vectors have length 126.
ok_ratio = sum(len(vec) == 126 for _d, _t, _s, vec in sample) / max(1, len(sample))
assert ok_ratio > 0.8, f"Too many non-126 vectors in sample (ok_ratio={ok_ratio:.2%})"

# Invariant 4 — sampled actions are all finite numbers.
assert bool(np.isfinite(actions).all()), "Actions contain NaN/inf"

print("✅ Basic structural validations passed (on sample).")


✅ Basic structural validations passed (on sample).


In [15]:
def compute_dimension_ranges(d: Dict[Any, List[List[List[float]]]],
                             vec_len: int = 126) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute the per-dimension min and max over every valid step vector in `d`.

    Args:
        d: Mapping driver_id -> list of trajectories -> list of step vectors.
           Non-list drivers/trajectories are skipped.
        vec_len: Expected vector length; vectors of any other length (or non-sequence
                 vectors) are ignored.

    Returns:
        (min_values, max_values), each a float array of length `vec_len`. If no vector
        qualifies, mins stay +inf and maxes stay -inf. A NaN in any vector propagates
        to that dimension's min/max.
    """
    min_vals = np.full(vec_len, np.inf)
    max_vals = np.full(vec_len, -np.inf)
    total_vectors = 0

    for trajs in d.values():
        if not isinstance(trajs, list):
            continue
        for traj in trajs:
            if not isinstance(traj, list):
                continue
            valid = [vec for vec in traj
                     if isinstance(vec, (list, tuple, np.ndarray)) and len(vec) == vec_len]
            if not valid:
                continue
            # Reduce a whole trajectory at once instead of one numpy call per vector.
            block = np.asarray(valid, dtype=float)  # shape (len(valid), vec_len)
            min_vals = np.minimum(min_vals, block.min(axis=0))
            max_vals = np.maximum(max_vals, block.max(axis=0))
            total_vectors += len(valid)

    print(f"Processed {total_vectors:,} vectors total")
    return min_vals, max_vals

# Compute the ranges
print("Computing min/max for all 126 dimensions across entire dataset...")
min_values, max_values = compute_dimension_ranges(data)

# Derive dimensionality from the result instead of repeating 126/125 below;
# the last dimension is the action, the rest are state features.
n_dims = len(min_values)
n_state = n_dims - 1

# Print min/max for each dimension
print("\n" + "="*60)
print(f"MIN and MAX values for all {n_dims} dimensions:")
print("="*60)

for i in range(n_dims):
    dim_type = "state" if i < n_state else "action"
    print(f"Index {i:3d} ({dim_type:6s}): min={min_values[i]:12.6f}  max={max_values[i]:12.6f}  range={max_values[i]-min_values[i]:12.6f}")

# Summary DataFrame; the last dimension is marked as action
range_df = pd.DataFrame({
    'dimension': range(n_dims),
    'min': min_values,
    'max': max_values,
    'range': max_values - min_values,
    'type': ['state'] * n_state + ['action'],
})

print("\n" + "="*60)
print("Summary Statistics:")
print("="*60)

print(f"\n--- State features (dimensions 0-{n_state - 1}) ---")
print(range_df.iloc[:n_state].describe())

print(f"\n--- Action (dimension {n_state}) ---")
print(range_df.iloc[n_state:])

Computing min/max for all 126 dimensions across entire dataset...
Processed 666,729 vectors total

MIN and MAX values for all 126 dimensions:
Index   0 (state ): min=    3.000000  max=   48.000000  range=   45.000000
Index   1 (state ): min=    1.000000  max=   81.000000  range=   80.000000
Index   2 (state ): min=    1.000000  max=  288.000000  range=  287.000000
Index   3 (state ): min=    1.000000  max=    6.000000  range=    5.000000
Index   4 (state ): min=    0.000000  max=   82.000000  range=   82.000000
Index   5 (state ): min=    0.000000  max=   74.000000  range=   74.000000
Index   6 (state ): min=    0.000000  max=   82.000000  range=   82.000000
Index   7 (state ): min=    0.000000  max=   86.000000  range=   86.000000
Index   8 (state ): min=    0.000000  max=   99.000000  range=   99.000000
Index   9 (state ): min=    0.000000  max=  103.000000  range=  103.000000
Index  10 (state ): min=    0.000000  max=   87.000000  range=   87.000000
Index  11 (state ): min=    0.000