In [None]:
# Interactive (ipympl) matplotlib backend so the map figure can be zoomed/panned.
%matplotlib widget

In [None]:
import asyncio
import itertools
import os

import dotenv
import geopandas as gpd
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import shapely
import shapely.plotting
import tqdm
import tqdm.asyncio
from shapely.geometry import LineString, Point

import tfl.api
import tfl.exceptions
import tfl.models

In [None]:
# NOTE(review): duplicate of the first cell; re-selecting the same backend is presumably a no-op.
%matplotlib widget

In [None]:
# Load variables from a local .env file into os.environ (e.g. FLATHUNT__TFL_API_KEY used below).
dotenv.load_dotenv()

In [None]:
# TfL API client (its methods are awaited below). os.environ[...] raises KeyError
# immediately if FLATHUNT__TFL_API_KEY is not set.
tf_client = tfl.api.Tfl(app_key=os.environ["FLATHUNT__TFL_API_KEY"])

In [None]:
# NOTE(review): stations_facilities is not used by any later cell in this notebook.
stations_facilities = await tf_client.get_stations_facilities()

In [None]:
# Rail modes included in the transport graph. The journey-planner cell further
# down passes the same four modes (plus walking) — keep the two lists in sync.
lines = await tf_client.get_lines_by_mode(
    [
        tfl.models.ModeId.TUBE,
        tfl.models.ModeId.OVERGROUND,
        tfl.models.ModeId.DLR,
        tfl.models.ModeId.ELIZABETH_LINE,
        # tfl.models.ModeId.CABLE_CAR,
        # tfl.models.ModeId.NATIONAL_RAIL,
        # tfl.models.ModeId.TRAM,
    ]
)

In [None]:
# line_id -> stop points served by that line, fetched one line at a time
# (sequential awaits, one request per line).
line_id_stop_points: dict[str, list[tfl.models.StopPointDetail]] = {}
for line in tqdm.tqdm(lines):
    line_id_stop_points[line.id] = await tf_client.get_stop_points_by_line(line.id)

In [None]:
import httpx

line_id_stop_point_timetables: dict[
    str, dict[str, dict[tfl.api.Direction, tfl.models.TimetableResponse]]
] = {}


async def do_work(line_id_stop_point_timetables, line_id, stop_points):
    for stop_point in stop_points:
        try:
            result = await tf_client.get_timetable(line_id, stop_point.naptan_id, None)
            if result.disambiguation is None:
                line_id_stop_point_timetables.setdefault(line_id, {}).setdefault(
                    stop_point.naptan_id, {}
                )[tfl.api.Direction(result.direction)] = result
            else:
                for direction in tfl.api.Direction:
                    result = await tf_client.get_timetable(
                        line_id, stop_point.naptan_id, direction
                    )
                    line_id_stop_point_timetables.setdefault(line_id, {}).setdefault(
                        stop_point.naptan_id, {}
                    )[direction] = result
        except tfl.exceptions.TflApiError as e:
            # 404: Stop not found, 400: Direction not found (common for National Rail)
            if e.http_status_code in (400, 404):
                continue
            raise
        except httpx.HTTPStatusError as e:
            # Catch raw HTTP errors not wrapped by TflApiError (e.g., 400 for tram/national rail)
            if e.response.status_code == 400:
                continue
            raise


async for future in tqdm.asyncio.tqdm(
    asyncio.as_completed(
        [
            do_work(line_id_stop_point_timetables, line_id, stop_points)
            for line_id, stop_points in line_id_stop_points.items()
        ]
    ),
    total=len(line_id_stop_points),
):
    await future

In [None]:
# Travel-time lookup between station pairs, read from each timetable's station_intervals.
# Shape: line_id -> from_station_id -> to_station_id -> duration (minutes)

all_station_durations: dict[str, dict[str, dict[str, float]]] = {}

for line_id, stop_timetables in tqdm.tqdm(line_id_stop_point_timetables.items()):
    durations_for_line = all_station_durations[line_id] = {}

    for direction_timetables in stop_timetables.values():
        for timetable_response in direction_timetables.values():
            timetable = (
                None if timetable_response is None else timetable_response.timetable
            )
            if timetable is None or not timetable.routes:
                continue

            # Every interval in this timetable is measured from its departure stop.
            from_station_id = timetable.departure_stop_id

            for route in timetable.routes:
                if not route.station_intervals:
                    continue
                # Only the first interval set is read; the others are expected to agree.
                first_interval_set = route.station_intervals[0]
                from_durations = durations_for_line.setdefault(from_station_id, {})
                for interval in first_interval_set.intervals:
                    # The first duration seen for a pair wins; later ones are ignored.
                    from_durations.setdefault(interval.stop_id, interval.time_to_arrival)

print("Summary of station duration data:")
for line_id, from_stations in all_station_durations.items():
    total_pairs = sum(map(len, from_stations.values()))
    print(
        f"  {line_id}: {len(from_stations)} departure stations, {total_pairs} total pairs"
    )

In [None]:
# Directory holding the OSM "*_free_1" shapefile layers. Override it with the
# FLATHUNT__OSM_DIR environment variable instead of editing this cell; the
# default keeps the original location.
OSM_DIR = os.environ.get(
    "FLATHUNT__OSM_DIR", "/Users/cemlyn/Downloads/greater-london-251126-free"
)
roads_gdf = gpd.read_file(os.path.join(OSM_DIR, "gis_osm_roads_free_1.shp"))

In [None]:
# Reproject roads to British National Grid (EPSG:27700) so coordinates, and
# therefore the edge lengths computed later, are in metres.
roads_gdf = roads_gdf.to_crs("EPSG:27700")

In [None]:
import functools


@functools.lru_cache(maxsize=None)
def project_to_meters(lon: float, lat: float):
    """Project a WGS84 (lon, lat) pair to British National Grid (EPSG:27700).

    Args:
        lon: Longitude in degrees.
        lat: Latitude in degrees.

    Returns:
        ``(x, y)`` easting/northing in metres, as Python floats.

    Results are memoised. Each uncached call builds and reprojects a
    one-element GeoSeries, and callers may project the same station
    coordinates many times.
    """
    point_osgb36 = gpd.GeoSeries([Point(lon, lat)], crs="EPSG:4326").to_crs(
        "EPSG:27700"
    )
    return point_osgb36.x.item(), point_osgb36.y.item()

In [None]:
def euclidean(x1, y1, x2, y2):
    """Straight-line distance between (x1, y1) and (x2, y2).

    Works element-wise on numpy arrays (broadcasting), so one point can be
    compared against many at once. The result is in the units of the inputs.
    """
    dx = x1 - x2
    dy = y1 - y2
    return np.sqrt(dx**2 + dy**2)

In [None]:
# Walkable road network. Nodes are the vertex coordinates of every road
# LineString (metres, EPSG:27700); edges join consecutive vertices and carry
# their straight-line length in metres plus a LineString geometry.
graph = nx.Graph()
# Iterate the geometry column directly: iterrows() builds a full Series per row,
# and only the geometry is needed here.
for road_geometry in tqdm.tqdm(roads_gdf.geometry, total=len(roads_gdf)):
    for (x1, y1), (x2, y2) in itertools.pairwise(road_geometry.coords):
        if (x1, y1) not in graph:
            graph.add_node((x1, y1), x=x1, y=y1)
        if (x2, y2) not in graph:
            graph.add_node((x2, y2), x=x2, y=y2)
        if not graph.has_edge((x1, y1), (x2, y2)):
            graph.add_edge(
                (x1, y1),
                (x2, y2),
                length=euclidean(x1, y1, x2, y2).item(),  # metres
                geometry=LineString([(x1, y1), (x2, y2)]),
            )

In [None]:
from collections import Counter

# Flat allowance (minutes) added to every rail hop for boarding/alighting.
BOARDING_PENALTY_MINUTES = 5

# Rail network: one node per station, keyed by its projected (x, y) in metres,
# and one edge per pair of stations on the same line, weighted by the
# timetabled travel time ("time", minutes) plus the boarding penalty.
transport_graph = nx.Graph()
missing_pairs = []

for line_id, stop_points in line_id_stop_points.items():
    # Project each stop once. Projecting inside the pairwise loop below would
    # repeat the same reprojection O(n^2) times per line.
    projected_stops = [
        (stop_point, project_to_meters(stop_point.lon, stop_point.lat))
        for stop_point in stop_points
    ]
    for stop_point, (x, y) in projected_stops:
        if (x, y) not in transport_graph:
            transport_graph.add_node(
                (x, y),
                x=x,
                y=y,
                station_name=stop_point.common_name,
            )

    line_durations = all_station_durations.get(line_id, {})

    for (stop_point, xy1), (other_stop_point, xy2) in itertools.combinations(
        projected_stops, 2
    ):
        # Use naptan_id to match the keys in all_station_durations (from departure_stop_id)
        stop_id = stop_point.naptan_id
        other_id = other_stop_point.naptan_id

        # station_intervals only go one way, so look the pair up in both orders.
        travel_time = line_durations.get(stop_id, {}).get(other_id)
        if travel_time is None:
            travel_time = line_durations.get(other_id, {}).get(stop_id)
        if travel_time is None:
            missing_pairs.append((line_id, stop_id, other_id))
            continue

        travel_time += BOARDING_PENALTY_MINUTES

        # More than one line can connect the same two stations. Keep the fastest
        # connection rather than whichever line happened to be processed last.
        existing = transport_graph.get_edge_data(xy1, xy2)
        if existing is not None and existing["time"] <= travel_time:
            continue

        transport_graph.add_edge(
            xy1,
            xy2,
            time=travel_time,
            geometry=LineString([xy1, xy2]),
        )

print(f"Missing pairs: {len(missing_pairs)}")
if missing_pairs:
    # Show a sample of missing pairs by line
    line_counts = Counter(line_id for line_id, _, _ in missing_pairs)
    print("Missing pairs by line:")
    for line_id, count in line_counts.most_common():
        print(f"  {line_id}: {count}")

In [None]:
# One network holding both the road graph and the rail graph. The next cell
# adds the station-to-road link edges that connect the two.
whole_graph = nx.compose_all([graph, transport_graph])

In [None]:
# Road-network node keys and their coordinates, in matching order. Used to snap
# arbitrary points (stations, query locations) onto the walkable graph.
non_transport_nodes = list(graph.nodes)
points = np.array([(data["x"], data["y"]) for _, data in graph.nodes(data=True)])


def find_nearest_node(x, y, coords=None):
    """Return the index of the candidate coordinate nearest to (x, y).

    Args:
        x: Query easting, in the same CRS as ``coords`` (metres, EPSG:27700).
        y: Query northing.
        coords: (N, 2) array of candidate positions. Defaults to the
            module-level ``points`` array, so the result indexes
            ``non_transport_nodes``.

    Returns:
        Integer index into ``coords``.

    Raises:
        ValueError: if there are no candidate coordinates.
    """
    if coords is None:
        coords = points
    if len(coords) == 0:
        raise ValueError("find_nearest_node: no candidate coordinates to search")
    distances = euclidean(x, y, coords[:, 0], coords[:, 1])
    return distances.argmin(axis=0).item()


# Link every station to its nearest road node with a walkable edge, so
# shortest paths can move between the road and rail networks.
for transport_node_key in tqdm.tqdm(transport_graph.nodes):
    x = transport_graph.nodes[transport_node_key]["x"]
    y = transport_graph.nodes[transport_node_key]["y"]
    non_transport_key = non_transport_nodes[find_nearest_node(x, y)]
    road_x = graph.nodes[non_transport_key]["x"]
    road_y = graph.nodes[non_transport_key]["y"]
    whole_graph.add_edge(
        transport_node_key,
        non_transport_key,
        length=euclidean(x, y, road_x, road_y).item(),  # metres
        geometry=LineString([(x, y), (road_x, road_y)]),
    )

In [None]:
# Walking speed used to turn road/link edge lengths (metres) into minutes.
meters_per_minute = 60
for a, b, data in whole_graph.edges(data=True):
    is_rail_edge = (
        "station_name" in whole_graph.nodes[a]
        and "station_name" in whole_graph.nodes[b]
    )
    if is_rail_edge:
        # Rail edges must already carry their timetabled time.
        if "time" not in data:
            raise ValueError(
                f"Station-to-station edge {a!r} -> {b!r} has no timetabled 'time'"
            )
    else:
        data["time"] = data["length"] / meters_per_minute

In [None]:
def isochrones(G, node, trip_time: float):
    """Split the area reachable from ``node`` within ``trip_time`` into islands.

    Takes the time-weighted ego graph around ``node`` (edge attribute
    ``"time"``), then deletes every station-to-station (rail) edge. What remains
    falls apart into connected components: typically the walkable area around
    the origin plus the walkable areas around stations reached by rail.

    Args:
        G: Combined road + rail graph whose edges carry ``"time"`` (minutes).
        node: Start node key in ``G``.
        trip_time: Maximum travel time in minutes.

    Returns:
        List of subgraph views, one per connected component. Station nodes and
        station-to-road link edges are kept; rail edges are not.
    """
    subgraph = nx.ego_graph(G, node, radius=trip_time, distance="time")

    rail_edges = [
        (n_fr, n_to)
        for n_fr, n_to in subgraph.edges()
        if "station_name" in subgraph.nodes[n_fr]
        and "station_name" in subgraph.nodes[n_to]
    ]
    subgraph.remove_edges_from(rail_edges)

    # Build components from the reachable subgraph itself rather than from a
    # global graph, so the result reflects ``G`` and keeps station nodes.
    return [subgraph.subgraph(nodes) for nodes in nx.connected_components(subgraph)]


def make_poly(G, edge_buff: float, node_buff: float):
    """Union of the buffered nodes and non-rail edges of ``G``.

    Args:
        G: Graph whose nodes carry ``"x"``/``"y"`` and whose non-rail edges
            carry a ``"geometry"`` LineString (metres).
        edge_buff: Buffer distance applied to each non-rail edge geometry.
        node_buff: Buffer distance applied to each node point.

    Returns:
        A single shapely geometry (may be empty when ``G`` has nothing to draw).
    """
    node_points = gpd.GeoSeries(
        [Point(data["x"], data["y"]) for _, data in G.nodes(data=True)]
    )
    # Rail edges are straight station-to-station chords, not walkable area.
    edge_lines = [
        data["geometry"]
        for n_fr, n_to, data in G.edges(data=True)
        if not ("station_name" in G.nodes[n_fr] and "station_name" in G.nodes[n_to])
    ]
    buffered = list(node_points.buffer(node_buff)) + list(
        gpd.GeoSeries(edge_lines).buffer(edge_buff)
    )
    return gpd.GeoSeries(buffered).union_all()

In [None]:
# Snap each query location (lon, lat in WGS84) to its nearest road node and
# compute the 25-minute isochrone components over the combined graph.
queries = [
    (-0.10813726002192411, 51.51804484802881),
    (-0.0207016567503272, 51.503329567778614),
]
query_iso_subgraphs = []
for query_lon, query_lat in queries:
    qx, qy = project_to_meters(query_lon, query_lat)
    start_node = non_transport_nodes[find_nearest_node(qx, qy)]
    query_iso_subgraphs.append(isochrones(whole_graph, start_node, 25))

In [None]:
import concurrent.futures

# Buffer distances in metres.
NODE_BUFFER = 0
EDGE_BUFFER = 25

# Buffer each isochrone component into a polygon. make_poly's signature is
# (G, edge_buff, node_buff); keyword arguments keep the two distances from
# being swapped.
all_polys = []
for subgraphs in query_iso_subgraphs:
    subgraph_polys = []
    with tqdm.tqdm(total=len(subgraphs)) as pbar:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            for poly in executor.map(
                lambda sg: make_poly(sg, edge_buff=EDGE_BUFFER, node_buff=NODE_BUFFER),
                subgraphs,
            ):
                subgraph_polys.append(poly)
                pbar.update(1)
    all_polys.append(subgraph_polys)

In [None]:
# Cheap geometric prefilter: pair up components from the two queries whose
# buffered footprints overlap. Only these pairs are graph-intersected below.
pairs = []
a_subgraphs, b_subgraphs = query_iso_subgraphs
a_polys, b_polys = all_polys
for a_subgraph, a_poly in tqdm.tqdm(
    zip(a_subgraphs, a_polys, strict=True), total=len(a_subgraphs)
):
    if a_poly.is_empty:
        continue
    for b_subgraph, b_poly in zip(b_subgraphs, b_polys, strict=True):
        # Test the areas, not just their boundaries: boundaries do not touch
        # when one footprint lies entirely inside the other.
        if not b_poly.is_empty and a_poly.intersects(b_poly):
            pairs.append((a_subgraph, b_subgraph))

In [None]:
import concurrent.futures
import itertools

# Intersect each overlapping pair of isochrone graphs. Every connected piece of
# a non-empty intersection is kept as an area reachable from both queries.
compatible_intersections = []
left_graphs = [a for a, _ in pairs]
right_graphs = [b for _, b in pairs]
with tqdm.tqdm(total=len(pairs)) as pbar, concurrent.futures.ThreadPoolExecutor() as executor:
    for intersection in executor.map(nx.intersection, left_graphs, right_graphs):
        if intersection.number_of_nodes():
            compatible_intersections += [
                nx.subgraph(intersection, component)
                for component in nx.connected_components(intersection)
            ]
        pbar.update(1)

In [None]:
# nx.intersection builds a fresh graph without node or edge attributes. Copy the
# views into mutable graphs, then restore node attributes (x, y, station_name)
# and every whole_graph edge between nodes of the same piece, with its attributes.
compatible_intersections = [piece.copy() for piece in compatible_intersections]
for component in compatible_intersections:
    for node_id, attrs in component.nodes(data=True):
        attrs.update(whole_graph.nodes[node_id])
        for neighbor, edge_attrs in whole_graph[node_id].items():
            if neighbor in component:
                component.add_edge(node_id, neighbor, **edge_attrs)

In [None]:
# Buffer every isochrone component of both queries plus every intersection
# piece, giving three groups of polygons to draw on one map.
all_polys = []
all_graphs = query_iso_subgraphs + [compatible_intersections]
with tqdm.tqdm(total=sum(map(len, all_graphs))) as pbar:
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Executor.map submits eagerly, so every group is queued before any
        # result is consumed.
        pending = [
            executor.map(lambda sg: make_poly(sg, EDGE_BUFFER, NODE_BUFFER), group)
            for group in all_graphs
        ]
        for results in pending:
            group_polys = []
            for poly in results:
                group_polys.append(poly)
                pbar.update(1)
            all_polys.append(group_polys)

In [None]:
from matplotlib.patches import PathPatch
from matplotlib.path import Path


def _geometry_patch(geom, **style):
    """Convert a shapely polygonal or linear geometry into a matplotlib patch.

    Raises:
        ValueError: for geometry types that are not (Multi)Polygon or
            (Multi)LineString.
    """
    if isinstance(geom, (shapely.Polygon, shapely.MultiPolygon)):
        return shapely.plotting.patch_from_polygon(geom, **style)
    if isinstance(geom, shapely.MultiLineString):
        path = Path.make_compound_path(
            *[Path(np.asarray(line.coords)[:, :2]) for line in geom.geoms]
        )
        return PathPatch(path, **style)
    if isinstance(geom, shapely.LineString):
        return PathPatch(Path(np.asarray(geom.coords)[:, :2]), **style)
    raise ValueError(f"Unexpected geometry type: {geom.geom_type}")


# One merged patch per group: query 1 (blue), query 2 (red), and the
# intersections (cyan, drawn on top and opaque).
patches = []
for polys, color, zorder in tqdm.tqdm(
    zip(
        [[poly for poly in ps if not poly.is_empty] for ps in all_polys],
        ["blue", "red", "cyan"],
        [0, 0, 1],
        strict=True,
    ),
    total=len(all_polys),
):
    merged = shapely.union_all(polys)
    if merged.is_empty:
        # Nothing to draw for this group (e.g. no intersections were found).
        continue
    # A union can collapse to a single Polygon, which is handled alongside MultiPolygon.
    patches.append(
        _geometry_patch(
            merged,
            facecolor=color,
            edgecolor=color,
            linewidth=0.1,
            alpha=0.5 if zorder != 1 else 1.0,
            zorder=zorder,
        )
    )

figure, ax = plt.subplots(dpi=300)
roads_gdf.geometry.plot(ax=ax, color="black", linewidth=0.1)
for patch in patches:
    ax.add_patch(patch)
ax.autoscale_view()
plt.show()

In [None]:
# Centroid of each non-empty intersection polygon, converted back to WGS84
# (lon, lat) for the journey-planner queries below. All centroids are
# reprojected in one batch rather than one GeoSeries per polygon.
intersections = all_polys[-1]
centroids = [poly.centroid for poly in intersections if not poly.is_empty]
check_coords = [
    (point.x, point.y)
    for point in gpd.GeoSeries(centroids, crs="EPSG:27700").to_crs("EPSG:4326")
]

In [None]:
# Display the query (lon, lat) pairs used for the isochrones.
queries

In [None]:
min_times = {}
for lon, lat in tqdm.tqdm(check_coords):
    for i, (query_lon, query_lat) in enumerate(queries):
        journey_results = await tf_client.get_journey_results(
            from_location=(lat, lon),
            to_location=(query_lat, query_lon),
            arrival_datetime=None,
            modes=[
                tfl.models.ModeId.TUBE,
                tfl.models.ModeId.OVERGROUND,
                tfl.models.ModeId.DLR,
                tfl.models.ModeId.ELIZABETH_LINE,
                tfl.models.ModeId.WALKING,
            ],
            use_multi_modal_call=False,
        )
        if isinstance(journey_results, tfl.models.DisambiguationResult):
            print(f"  Query {(lon, lat)} {i + 1}: Disambiguation result, skipping")
            continue
        min_time = min(journey.duration for journey in journey_results.journeys)
        min_times.setdefault((lon, lat), {})[(query_lon, query_lat)] = min_time

In [140]:
# Debug inspection: legs of the first journey from the *last* request made in
# the loop above (journey_results is left over from that loop).
journey_results.journeys[0].legs

[Leg(type='Tfl.Api.Presentation.Entities.JourneyPlanner.Leg, Tfl.Api.Presentation.Entities', duration=9, instruction=Instruction(type='Tfl.Api.Presentation.Entities.Instruction, Tfl.Api.Presentation.Entities', summary='Walk to Green Park Station', detailed='Walk to Green Park Station', steps=[InstructionStep(type='Tfl.Api.Presentation.Entities.InstructionStep, Tfl.Api.Presentation.Entities', description=' for 19 metres', turn_direction='STRAIGHT', street_name='', distance=19, cumulative_distance=19, sky_direction=56, sky_direction_description='NorthEast', cumulative_travel_time=15, latitude=51.506890630757, longitude=-0.14276047247499998, path_attribute=PathAttribute(type='Tfl.Api.Presentation.Entities.PathAttribute, Tfl.Api.Presentation.Entities'), description_heading='Continue along ', track_type='None', travel_time=15), InstructionStep(type='Tfl.Api.Presentation.Entities.InstructionStep, Tfl.Api.Presentation.Entities', description='on to Stratton Street, continue for 8 metres', turn

In [None]:
import itertools

# Every journey time collected above, flattened across candidate points and queries.
[duration for per_query in min_times.values() for duration in per_query.values()]

In [None]:
# Exterior rings (coordinate lists) of every polygon part of the intersection
# geometries. MultiPolygons contribute one ring per part; empty and
# non-polygonal parts (which have no .exterior) are skipped.
all_poly_lines = [
    list(part.exterior.coords)
    for iso_poly in all_polys[-1]
    for part in shapely.get_parts(iso_poly)
    if isinstance(part, shapely.Polygon) and not part.is_empty
]

In [None]:
# Disabled with `if False:` so this never runs. It would load the OSM transport
# layer and keep only rows whose fclass is railway_station or tram_stop.
if False:
    transport_gdf = gpd.read_file(
        "/Users/cemlyn/Downloads/greater-london-251126-free/gis_osm_transport_free_1.shp"
    )
    transport_gdf = transport_gdf[
        transport_gdf["fclass"].isin(["railway_station", "tram_stop"])
    ]