This notebook collects public transport stops and stations from OpenStreetMap (OSM) for the generated grids and then counts the number of public transport options in each 100 m × 100 m grid cell.

In [None]:
import geopandas as gpd
import os
import glob
import logging
import osmium
import shapely.geometry
from shapely.geometry import box
import subprocess
import gc

In [None]:
# set path for grid files
# you might have grids generated here from the 'generate_grids.ipynb'
grid_path = 'data/*.parquet'

# list all parquet grid files in the specified directory; glob returns them in
# arbitrary, filesystem-dependent order, so sort to make the processing order
# (and the printed progress log) reproducible between runs
grids_list = sorted(glob.glob(grid_path))
print(grids_list)

# # configure logging (recommended if you monitor processing over a lot of files)
# log_path = 'logs/public_transport.log'

# # ensure log directory exists
# os.makedirs(os.path.dirname(log_path), exist_ok=True)

# logging.basicConfig(filename=log_path, level=logging.INFO,
#                     format='%(asctime)s:%(levelname)s:%(message)s', force=True)

In [None]:
# (key, value) OSM tag pairs that mark a node as public transport. Shared by the
# node filter and the "type" label below so the two can never drift apart.
PUBLIC_TRANSPORT_TAGS = frozenset({
    ("highway", "bus_stop"),
    ("public_transport", "station"),
    ("public_transport", "platform"),
    ("railway", "station"),
    ("railway", "tram_stop"),
    ("railway", "halt"),
    ("railway", "subway_entrance"),
    ("amenity", "bus_station"),
})


def is_public_transport_node(tags):
    """
    Determine if a node has public transport-related tags.

    Args:
    - tags: Tag container supporting ``tags.get(key)`` (e.g. an osmium
      TagList or a plain dict of key -> value).

    Returns:
    - bool: True if any (key, value) pair in PUBLIC_TRANSPORT_TAGS is present.
    """
    return any(tags.get(key) == value for key, value in PUBLIC_TRANSPORT_TAGS)


def public_transport_type(tags):
    """
    Return the first public transport tag of a node as ``"key=value"``, or None.

    Tags are scanned in the node's own order, so when a node carries several
    matching tags (e.g. highway=bus_stop and public_transport=platform) the
    one listed first on the node is returned.

    Args:
    - tags: Iterable of tag objects exposing ``.k`` and ``.v`` (osmium Tag).

    Returns:
    - str or None: the matching tag formatted as "key=value", or None if the
      node has no public transport tag.
    """
    return next(
        (f"{tag.k}={tag.v}" for tag in tags if (tag.k, tag.v) in PUBLIC_TRANSPORT_TAGS),
        None,
    )


def extract_transport_nodes_from_pbf(pbf_path):
    """
    Extract public transportation nodes from an OSM PBF file.

    Args:
    - pbf_path (str): Path to the .osm.pbf file.

    Returns:
    - GeoDataFrame in EPSG:4326 with columns id, name, type, geometry (one
      Point per matching node); empty with the same columns if none match.
    """
    transport_nodes = []

    class Handler(osmium.SimpleHandler):
        def node(self, n):
            if not n.location.valid():
                return
            # A node is kept exactly when it has a matching tag, so the type
            # lookup doubles as the filter (a single pass over the tags).
            transport_type = public_transport_type(n.tags)
            if transport_type is None:
                return
            transport_nodes.append({
                "id": n.id,
                "name": n.tags.get("name"),
                "type": transport_type,
                "geometry": shapely.geometry.Point(n.location.lon, n.location.lat),
            })

    handler = Handler()
    # Only node callbacks are used and nodes carry their own coordinates, so
    # the way-location index that locations=True builds is not needed.
    handler.apply_file(pbf_path)

    columns = ["id", "name", "type", "geometry"]
    if not transport_nodes:
        return gpd.GeoDataFrame(columns=columns, geometry="geometry", crs="EPSG:4326")

    return gpd.GeoDataFrame(transport_nodes, columns=columns, geometry="geometry", crs="EPSG:4326")

In [None]:
def process_grid(grid_path, source_pbf='/Volumes/ssd1/osm_europe/europe-latest.osm.pbf'):
    """
    Processes each grid file to query and save public transport data.

    Clips ``source_pbf`` to the grid's (padded) bounding box with the
    ``osmium extract`` CLI, extracts public transport nodes, adds a
    ``pub_trans_count`` column (distinct nodes intersecting each cell) to the
    grid file in place, and saves the nodes to
    ``<grid dir>/public_transport_data/public_transport_points_<n>.parquet``.

    The points file doubles as a "done" marker: grids that already have it are
    skipped, so it is written last. Grids without any public transport get
    ``pub_trans_count = 0`` but no points file, so they are re-checked on the
    next run. Errors are printed rather than raised so one bad grid does not
    abort a batch run.

    Args:
    - grid_path (str): Path to the grid file; the grid number is taken from
      the text after the last '_' in the file name (e.g. 'grid_12.parquet').
    - source_pbf (str): Path to the OSM .osm.pbf file to clip from.
    """
    clipped_pbf_path = None
    try:
        # Parse from the file name only so '_' or '.' in directory names can't leak in.
        grid_number = os.path.basename(grid_path).split('_')[-1].split('.')[0]
        output_dir = os.path.join(os.path.dirname(grid_path), 'public_transport_data')
        output_file = f'public_transport_points_{grid_number}.parquet'
        output_path = os.path.join(output_dir, output_file)
        os.makedirs(output_dir, exist_ok=True)

        if os.path.exists(output_path):
            # logging.info(f'Skipping {output_file} as it already exists')
            print(f'Skipping {output_file} as it already exists')
            return

        grid_gdf = gpd.read_parquet(grid_path)
        if 'index' not in grid_gdf.columns:
            grid_gdf.reset_index(inplace=True)
        # A count left over from an earlier, interrupted run would otherwise make
        # the merge below produce pub_trans_count_x / pub_trans_count_y columns.
        grid_gdf = grid_gdf.drop(columns='pub_trans_count', errors='ignore')

        grid_gdf_4326 = grid_gdf.to_crs('epsg:4326')
        bounds = grid_gdf_4326.total_bounds
        # pad by 0.01 degrees so stops just outside the outer cells are not lost
        buffered_box = box(*bounds).buffer(0.01)
        # logging.info(f'Started processing grid {grid_number}')
        print(f'Started processing grid {grid_number}')

        # clip .pbf file (bounds are minx,miny,maxx,maxy = left,bottom,right,top)
        clipped_pbf_path = f'data/{grid_number}.pbf'
        os.makedirs(os.path.dirname(clipped_pbf_path), exist_ok=True)

        bbox_str = ','.join(map(str, buffered_box.bounds))
        subprocess.run([
            'osmium', 'extract',
            '-b', bbox_str,
            source_pbf,
            '-o', clipped_pbf_path,
            '--overwrite'
        ], check=True)
        print(f'clipped PBF for {grid_number}')

        # process with pyosmium
        data_gdf = extract_transport_nodes_from_pbf(clipped_pbf_path)

        if data_gdf.empty:
            # logging.info(f'No public transport data found for grid {grid_path}')
            print(f'No public transport data found for grid {grid_path}')
            grid_gdf_f = grid_gdf.assign(pub_trans_count=0)
        else:
            data_gdf = data_gdf.to_crs(grid_gdf.crs)
            joined = gpd.sjoin(grid_gdf, data_gdf, how="left", predicate='intersects')
            # nunique ignores the NaN index_right of cells with no match -> 0
            node_counts_per_grid = joined.groupby('index')['index_right'].nunique().reset_index(name='pub_trans_count')
            grid_gdf_f = grid_gdf.merge(node_counts_per_grid, on='index', how='left').fillna({'pub_trans_count': 0})

        grid_gdf_f.to_parquet(grid_path)
        # logging.info(f'Successfully processed grid {grid_path}')
        print(f'Successfully processed grid {grid_path}')

        # Written last: its existence is what makes later runs skip this grid.
        if not data_gdf.empty:
            data_gdf.to_parquet(output_path)
            # logging.info(f'Saved public transport data to {output_file}')
            print(f'Saved public transport data to {output_file}')
    except Exception as e:
        # logging.error(f'Error processing grid {grid_path}: {e}')
        print(f'Error processing grid {grid_path}: {e}')
    finally:
        # clean up temp PBF file on every path (success, no data, or error)
        if clipped_pbf_path is not None and os.path.exists(clipped_pbf_path):
            os.remove(clipped_pbf_path)
        gc.collect()

In [None]:
# sequential
for elem in grids_list:
    process_grid(elem)

# # parallel processing if you want to process a lot of files
# # NOTE: Pool is not imported in the first cell, so import it when enabling this.
# # Where workers are started with "spawn" (default on macOS and Windows),
# # functions defined inside a notebook usually cannot be unpickled by the
# # workers; if you hit that, move process_grid and its helpers into an
# # importable .py module first.
# from multiprocessing import Pool
#
# num_processes = 5
#
# with Pool(processes=num_processes) as pool:
#     pool.map(process_grid, grids_list)