# waykit Grid Index Visualization

This notebook visualizes how the square-grid spatial index works as a
pre-filter for finding POIs near a GPX track.

Run with:
```
uv run --with jupyter --with folium jupyter lab notebooks/grid_visualization.ipynb
```

In [None]:
import folium
import gpxpy
from pathlib import Path

from waykit.grid_index import SquareGridIndex, project_local_m, cell_id_from_point, encode_cell_id
from waykit.cached_provider import load_index, _ORIGIN_LAT, _ORIGIN_LON, _CELL_SIZE_M
from waykit.geo import extract_gpx_points, haversine_m
from waykit.models import Feature

## Configuration

In [None]:
DISTANCE_M = 500.0      # proximity threshold for "nearby" POIs
PREFILTER_RADIUS_M = 5000.0  # radius to show POIs on the map

TESTDATA = Path("../testdata")
gpx_files = sorted(TESTDATA.glob("*.gpx"))
print(f"Found {len(gpx_files)} GPX files:")
for f in gpx_files:
    print(f"  {f.name}")

## Load GPX tracks and POI index

In [None]:
# Collect every GPX point as a (lon, lat) pair, across all track files.
all_gpx_points = []
for gpx_file in gpx_files:
    # GPX is XML and normally UTF-8; decode explicitly instead of relying on
    # the platform's locale encoding (e.g. cp1252 on Windows), which would
    # garble or reject non-ASCII names.
    with open(gpx_file, "r", encoding="utf-8") as f:
        gpx = gpxpy.parse(f)
    all_gpx_points.extend(extract_gpx_points(gpx))

print(f"Total GPX points: {len(all_gpx_points)}")

# Load the cached POI index (grid-bucketed features).
index = load_index()
print(f"POIs in index: {len(index)}, grid cells: {index.buckets()}")

## Classify POIs

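For every GPX point we query the grid index twice:
1. With `PREFILTER_RADIUS_M`, to collect the candidate POIs in the wider area (shown on the map for context)
2. With `DISTANCE_M`, to see which POIs the grid pre-filter alone returns

Finally, for each pre-filtered POI we compute the exact haversine distance to the track points. The true matches are the POIs within `DISTANCE_M` of at least one track point.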

In [None]:
from math import floor

# Query the grid index once per track point, at both radii:
#   all_nearby  - grid candidates for the wide PREFILTER_RADIUS_M (map context)
#   prefiltered - grid candidates the pre-filter yields for DISTANCE_M
# Both map feature id -> Feature; re-inserting an id keeps its first position.
all_nearby = {}
prefiltered = {}
for trk_lon, trk_lat in all_gpx_points:
    all_nearby.update(
        (poi.id, poi)
        for poi in index.candidates_near(trk_lat, trk_lon, radius_m=PREFILTER_RADIUS_M)
    )
    prefiltered.update(
        (poi.id, poi)
        for poi in index.candidates_near(trk_lat, trk_lon, radius_m=DISTANCE_M)
    )


def _touches_track(feature):
    """Return True if *feature* lies within DISTANCE_M of any GPX point."""
    poi_lon, poi_lat = feature.geometry.coordinates
    return any(
        haversine_m(poi_lon, poi_lat, pt_lon, pt_lat) <= DISTANCE_M
        for pt_lon, pt_lat in all_gpx_points
    )


# Exact haversine check, applied only to the grid pre-filtered candidates.
matched = {fid: feat for fid, feat in prefiltered.items() if _touches_track(feat)}

print(f"POIs within {PREFILTER_RADIUS_M/1000:.0f} km (displayed): {len(all_nearby)}")
print(f"Grid pre-filtered at {DISTANCE_M:.0f} m: {len(prefiltered)}")
print(f"Exact match at {DISTANCE_M:.0f} m: {len(matched)}")

## Compute active grid cells

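These are the cells searched when querying at `DISTANCE_M`. They are the cells that contain a track point ("route cells") plus a square ring of neighbouring cells around each of them ("search cells").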

In [None]:
from waykit.grid_index import EARTH_RADIUS_M
from math import radians, cos

def cell_to_latlon_bounds(cx, cy):
    """Return the lat/lon bounding box of grid cell ``(cx, cy)``.

    The cell covers ``[cx * _CELL_SIZE_M, (cx + 1) * _CELL_SIZE_M]`` metres
    east and ``[cy * _CELL_SIZE_M, (cy + 1) * _CELL_SIZE_M]`` metres north of
    the grid origin (``_ORIGIN_LAT``, ``_ORIGIN_LON``). These local offsets are
    turned back into degrees by inverting an equirectangular projection
    anchored at the origin.
    NOTE(review): assumes ``project_local_m`` uses that same projection.

    Returns:
        ``(lat_min, lon_min, lat_max, lon_max)`` in degrees.
    """
    from math import degrees

    lat0_r = radians(_ORIGIN_LAT)
    lon0_r = radians(_ORIGIN_LON)
    # Metres per radian of longitude shrink by cos(origin latitude).
    east_radius = EARTH_RADIUS_M * cos(lat0_r)

    def north_m_to_lat(y_m):
        return degrees(y_m / EARTH_RADIUS_M + lat0_r)

    def east_m_to_lon(x_m):
        return degrees(x_m / east_radius + lon0_r)

    west_edge, east_edge = cx * _CELL_SIZE_M, (cx + 1) * _CELL_SIZE_M
    south_edge, north_edge = cy * _CELL_SIZE_M, (cy + 1) * _CELL_SIZE_M

    return (
        north_m_to_lat(south_edge),
        east_m_to_lon(west_edge),
        north_m_to_lat(north_edge),
        east_m_to_lon(east_edge),
    )


from math import ceil

# Ring radius, in cells, around each route cell. It must be large enough that
# the square of neighbouring cells covers every location within DISTANCE_M.
# ceil() is correct for float inputs. The integer idiom `(d + c - 1) // c` is
# only a ceiling for integers: with float metres it under-counts
# (500.5 m with 500 m cells -> 1 instead of 2).
# NOTE(review): this should mirror the ring radius SquareGridIndex uses
# internally for candidates_near - confirm against waykit.grid_index.
r = ceil(DISTANCE_M / _CELL_SIZE_M)

route_cells = set()   # cells that directly contain a GPX track point
active_cells = set()  # route cells plus the surrounding (2r+1)x(2r+1) ring
for glon, glat in all_gpx_points:
    pt = project_local_m(glat, glon, _ORIGIN_LAT, _ORIGIN_LON)
    cx, cy = cell_id_from_point(pt, _CELL_SIZE_M)
    route_cells.add((cx, cy))
    active_cells.update(
        (cx + dx, cy + dy)
        for dx in range(-r, r + 1)
        for dy in range(-r, r + 1)
    )

# Search-only cells = active minus route
search_cells = active_cells - route_cells

print(f"Search radius: {r} cells in each direction")
print(f"Route cells (on track): {len(route_cells)}")
print(f"Search cells (surrounding): {len(search_cells)}")
print(f"Total active cells: {len(active_cells)}")

## Build the map

In [None]:
# Center the map on the bounding box of all GPX track points.
lats = [lat for lon, lat in all_gpx_points]
lons = [lon for lon, lat in all_gpx_points]
center_lat = (min(lats) + max(lats)) / 2
center_lon = (min(lons) + max(lons)) / 2

m = folium.Map(location=[center_lat, center_lon], zoom_start=13, tiles="OpenStreetMap")


def _cell_layer(name, cells, color, fill_opacity, weight):
    """Return a visible FeatureGroup with one filled rectangle per grid cell."""
    layer = folium.FeatureGroup(name=name, show=True)
    for cx, cy in cells:
        lat_min, lon_min, lat_max, lon_max = cell_to_latlon_bounds(cx, cy)
        folium.Rectangle(
            bounds=[[lat_min, lon_min], [lat_max, lon_max]],
            color=color,
            fill=True,
            fill_color=color,
            fill_opacity=fill_opacity,
            weight=weight,
        ).add_to(layer)
    return layer


# --- Layer: Search grid cells (lighter blue) ---
search_layer = _cell_layer("Grid cells (search radius)", search_cells, "#4A90D9", 0.12, 0.5)
search_layer.add_to(m)

# --- Layer: Route grid cells (darker blue) ---
route_layer = _cell_layer("Grid cells (on route)", route_cells, "#1A3A6B", 0.30, 1)
route_layer.add_to(m)

# --- Layer: GPX tracks ---
track_layer = folium.FeatureGroup(name="GPX tracks", show=True)
for gpx_file in gpx_files:
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(gpx_file, "r", encoding="utf-8") as f:
        gpx = gpxpy.parse(f)
    for track in gpx.tracks:
        for seg in track.segments:
            coords = [(p.latitude, p.longitude) for p in seg.points]
            folium.PolyLine(
                coords, color="red", weight=3, opacity=0.8,
                tooltip=gpx_file.stem,
            ).add_to(track_layer)
track_layer.add_to(m)

# --- Layer: POIs, colored by classification ---
poi_layer = folium.FeatureGroup(name="POIs", show=True)
for fid, feat in all_nearby.items():
    flon, flat = feat.geometry.coordinates
    props = feat.properties
    # `is not None` so a genuine 0 m elevation is still shown; a plain
    # truthiness test would hide it.
    # NOTE(review): assumes ele_m is None when the elevation is unknown.
    ele_str = f", {props.ele_m:.0f} m" if props.ele_m is not None else ""

    if fid in matched:
        color = "green"
        status = "matched"
    elif fid in prefiltered:
        color = "blue"
        status = "pre-filtered"
    else:
        color = "gray"
        status = "out of range"

    folium.CircleMarker(
        location=[flat, flon],
        radius=6,
        color=color,
        fill=True,
        fill_color=color,
        fill_opacity=0.8,
        tooltip=f"{props.name}{ele_str} [{status}]",
    ).add_to(poi_layer)
poi_layer.add_to(m)

# Layer control
folium.LayerControl().add_to(m)

m

## Legend

| Color | Meaning |
|-------|---------|
| **Red line** | GPX track |
| **Dark blue rectangles** | Grid cells directly on the GPX route |
| **Light blue rectangles** | Surrounding grid cells within the search radius |
| **Green dots** | POIs within the distance threshold (exact haversine match) |
| **Blue dots** | POIs returned by the grid pre-filter but outside exact distance |
| **Grey dots** | POIs in the wider area but outside the pre-filter |