In [1]:
import networkx as nx
import osmnx as ox
import osmnx_func as osmfunc
import overpass
import geojson
import geopandas as gpd
import pandas as pd
from shapely.geometry import Point, LineString
import xml.etree.ElementTree as ET
from shapely import length
import math
from pyproj import Transformer

In [32]:
def download_osm_transit_data(transport, city):
    """
    Download the OSM route relations of one transit mode inside a named area.

    Parameters
    ----------
    transport : str
        One of 'light_rail', 'subway', 'tram' or 'bus'.
    city : str
        Value matched against the ``name`` tag of the OSM area. It is
        interpolated into the Overpass query verbatim, so it must not
        contain a double quote.

    Returns
    -------
    xml.etree.ElementTree.ElementTree
        The parsed Overpass XML response.

    Raises
    ------
    ValueError
        If ``transport`` is not one of the supported modes (previously this
        surfaced as an UnboundLocalError on ``result``).
    """
    rail_modes = ('light_rail', 'subway', 'tram')
    if transport != 'bus' and transport not in rail_modes:
        raise ValueError(
            f"Unsupported transport {transport!r}; expected one of "
            f"{rail_modes + ('bus',)}"
        )

    if transport == 'bus':
        # The bus query additionally requires boundary=administrative on the
        # area and, unlike the rail query, has no `(._;>>;);` recursion step
        # before `out geom;`.
        query = f"""
                        area["name"="{city}"]["boundary"="administrative"] -> .a;
                        (
                        rel(area.a)[route=bus](area.a);
                        );
                        out geom;
                        >;
                        """
    else:
        # light_rail, subway and tram share one query; only the route value differs.
        query = f"""
                        area["name"="{city}"] -> .a;
                        (
                        rel [type=route][route={transport}][railway!=platform](area.a);
                        );
                        (._;>>;);
                        out geom;
                        >;
                        """

    api = overpass.API()
    result = api.get(query, responseformat="xml")

    return ET.ElementTree(ET.fromstring(result))

def get_meta_from_tree(tree, osm_type):
    """
    Collect the metadata of every element of one OSM type in the tree.

    Each returned dict holds the element's integer 'id', its 'osm_type',
    for nodes the float 'lat' and 'lon', and finally one entry per
    <tag k=... v=...> child (a tag key overwrites an earlier entry of the
    same name). Returns the dicts as a list, in document order.
    """
    is_node = osm_type == 'node'
    records = []
    for elem in tree.findall(osm_type):
        record = {'id': int(elem.get('id')), 'osm_type': osm_type}
        if is_node:
            record['lat'] = float(elem.get('lat'))
            record['lon'] = float(elem.get('lon'))
        # Tags go last so they keep the same key order as they appear in the XML.
        record.update((tag.get('k'), tag.get('v')) for tag in elem.findall('tag'))
        records.append(record)
    return records

def get_nodes(tree):
    """
    Collect the node members (stations) of every relation in the tree.

    Returns a dict keyed by the integer relation id. Each value is a list,
    in member order, of [(lon, lat), node_id] entries built from the
    'lon', 'lat' and 'ref' attributes of the relation's node members.
    """
    stations_by_relation = {}
    for relation in tree.findall('relation'):
        rel_key = int(relation.attrib['id'])
        stations_by_relation[rel_key] = [
            [(float(member.attrib['lon']), float(member.attrib['lat'])),
             int(member.attrib['ref'])]
            for member in relation.findall('member')
            if member.attrib['type'] == 'node'  # ways are handled by get_way_order
        ]
    return stations_by_relation

def get_way_order(tree):
    """
    Build one undirected graph per relation from the geometry of its ways.

    Every <nd> point of a way member becomes a graph node keyed by its
    (lon, lat) tuple, and consecutive points of the same way are joined by
    an edge whose 'attr' attribute is {'osm_id': <way id>}.
    Returns a dict mapping the integer relation id to its nx.Graph.
    """
    graphs_by_relation = {}
    for relation in tree.findall('relation'):
        rel_key = int(relation.attrib['id'])
        way_graph = nx.Graph()
        for member in relation.findall('member'):
            if member.attrib['type'] != 'way':
                continue
            way_id = int(member.attrib['ref'])
            tail = None  # previous point of this way; None until the first point is read
            for nd in member.findall('nd'):
                lon = float(nd.attrib['lon'])
                lat = float(nd.attrib['lat'])
                head = (lon, lat)
                if tail is not None:
                    way_graph.add_edge(tail, head, attr={'osm_id': way_id})
                tail = head
        graphs_by_relation[rel_key] = way_graph

    return graphs_by_relation

def check_stations(node_order, way_graphs):
    """
    Snap every station of every relation onto that relation's way graph.

    Parameters
    ----------
    node_order : dict
        relation id -> list of [(lon, lat), node_id] stations.
    way_graphs : dict
        relation id -> graph whose nodes are (lon, lat) tuples.

    Returns
    -------
    dict
        relation id -> list of ((lon, lat), node_id) tuples. The coordinate
        is the station's own point when it already is a graph node,
        otherwise the graph node with the smallest planar distance
        (computed directly on the lon/lat values). Stations of a relation
        whose graph has no nodes are dropped.
    """
    snapped_order = {}
    for rel_id, G in way_graphs.items():
        graph_points = list(G.nodes())
        snapped = []
        for station in node_order.get(rel_id, []):
            coord, osm_id = station[0], station[1]

            # The original passed the whole [coord, id] list to has_node,
            # which can never be a graph node; the coordinate is the key.
            if G.has_node(coord):
                snapped.append((coord, osm_id))
                continue

            # min() keeps the first of equally distant points, like the
            # original strict "<" comparison did.
            nearest = min(
                graph_points,
                key=lambda pt: math.hypot(pt[0] - coord[0], pt[1] - coord[1]),
                default=None,
            )
            if nearest is not None:
                snapped.append((nearest, osm_id))
        snapped_order[rel_id] = snapped
    return snapped_order

def get_meta(tree):
    """
    Build an id -> metadata lookup for all nodes and ways in the tree.

    Returns a dict keyed by OSM id; each value is a dict with one entry per
    column of the combined node/way table (including 'id'). Fields an
    element does not have are filled with NaN by pandas.

    Returns an empty dict when the tree contains no node or way elements;
    the original raised KeyError('id') there because an empty DataFrame has
    no 'id' column.

    NOTE(review): nodes and ways share one key space here. Nodes are listed
    first and drop_duplicates keeps the first occurrence, so a way whose id
    equals a node id is dropped.
    """
    records = (
        get_meta_from_tree(tree=tree, osm_type='node')
        + get_meta_from_tree(tree=tree, osm_type='way')
    )
    if not records:
        return {}

    meta_frame = (
        pd.DataFrame(records)
        .drop_duplicates('id')
        .set_index('id', drop=False)  # keep the 'id' column in each record
    )
    return meta_frame.to_dict(orient='index')

def _is_missing(value):
    """Return True for None and NaN; False for strings and any non-numeric value."""
    if value is None:
        return True
    if isinstance(value, str):
        return False
    try:
        return math.isnan(value)
    except (TypeError, ValueError, OverflowError):
        # Lists, dicts, huge ints, ... are real values, not missing ones.
        return False


def clean_meta_dict(d, keep_coord = True):
    """
    Return a cleaned copy of a metadata dict; the input is not modified.

    Entries whose value is NaN (or None) are removed. When ``keep_coord``
    is true, 'lat' is renamed to 'y' and 'lon' to 'x' and their values are
    kept as they are; when it is false, 'lat'/'lon' are treated like any
    other key and therefore disappear if they are NaN.

    Unlike the original, values that are neither strings nor numbers
    (e.g. None or a list) no longer raise TypeError.
    """
    cleaned = {}
    coords = {}
    for key, value in d.items():
        if keep_coord and key == 'lat':
            coords['y'] = value
        elif keep_coord and key == 'lon':
            coords['x'] = value
        elif not _is_missing(value):
            cleaned[key] = value
    # Coordinates are written last so they win over a tag literally named 'x'/'y'.
    cleaned.update(coords)
    return cleaned

def get_osmid_from_shortest_path(G, path):
    """
    Return the distinct way ids traversed by ``path`` in ``G``.

    ``path`` is a list of graph nodes; for each consecutive pair the edge's
    'attr' dict is read and its 'osm_id' collected. Ids are returned once
    each, in the order they are first met along the path.

    The original skipped the first and last edge "to exclude the stations".
    The path endpoints are way points though, so those edges are ordinary
    way segments, and a path of one or two edges produced no ids at all.
    Edges without an 'attr'/'osm_id' entry are skipped instead of raising.
    """
    ordered_ids = {}  # dict used as an insertion-ordered set
    for u, v in zip(path, path[1:]):
        edge_attr = G[u][v].get('attr')
        if edge_attr is not None and 'osm_id' in edge_attr:
            ordered_ids[edge_attr['osm_id']] = None
    return list(ordered_ids)

def line_lenght(line):
    """
    Return the length of a lon/lat LineString in metres.

    The length is the sum of the WGS84 geodesic distances between
    consecutive vertices. The original projected to EPSG:3857 and measured
    the planar length; Web Mercator stretches distances by roughly
    1 / cos(latitude), so that value was not ground metres away from the
    equator. A line with fewer than two vertices has length 0.0.
    """
    # Local import: the file's import cell only brings in pyproj.Transformer.
    from pyproj import Geod

    coords = list(line.coords)
    if len(coords) < 2:
        return 0.0

    # Index instead of unpacking so 3D coordinates (x, y, z) also work.
    lons = [coord[0] for coord in coords]
    lats = [coord[1] for coord in coords]
    return Geod(ellps="WGS84").line_length(lons, lats)

def get_components_ends(component):
    """
    Pick two nodes that act as the "ends" of a connected component.

    The edges of a DFS tree rooted at the component's first node are cut
    into chains: each pass over the remaining edges starts at the first
    edge and keeps appending every later edge that begins where the chain
    currently ends. Depending on the number of chains the result is:

    * 0 chains (single-node component): that node twice. The original
      raised IndexError here.
    * 1 chain: its first and last node.
    * 2 chains: the last node of each chain.
    * more: the first node of the longest chain and the last node of the
      second longest.

    Returns a 2-tuple of nodes.
    """
    source = list(component.nodes())[0]
    remaining = list(nx.dfs_tree(component, source=source).edges())
    if not remaining:
        return source, source

    chains = []
    while remaining:
        chain = []
        consumed = set()
        tail = None  # node the current chain ends at; None before the first edge
        for edge in remaining:
            if tail is None or tail == edge[0]:
                consumed.add(edge)
                chain.extend(edge)  # both endpoints, so inner nodes appear twice
                tail = edge[1]
        # Tree edges are unique, so filtering by membership equals removing
        # each consumed edge once, without the repeated list scans.
        remaining = [edge for edge in remaining if edge not in consumed]
        chains.append(chain)

    if len(chains) == 1:
        return chains[0][0], chains[0][-1]
    if len(chains) == 2:
        return chains[0][-1], chains[1][-1]

    # The sort is stable, so equally long chains keep their discovery order.
    first, second = sorted(chains, key=len, reverse=True)[:2]
    return first[0], second[-1]



def create_network(tree, network = None):
    """
    Add the stations and station-to-station links of an Overpass result to a graph.

    For every relation, consecutive stations (snapped onto the relation's
    way graph) become nodes keyed by OSM node id, carrying their cleaned
    metadata. When the two snapped points differ, the shortest path between
    them along the ways becomes a directed edge with the attributes
    'osmid', 'geometry', 'length' and 'attr_dict' (way id -> way metadata).

    Parameters
    ----------
    tree : xml.etree.ElementTree.ElementTree
        Parsed Overpass response.
    network : nx.MultiDiGraph, optional
        Graph to extend in place. A new graph is created when omitted; the
        original default was a single graph object shared by every call
        that did not pass one.

    Returns
    -------
    (nx.MultiDiGraph, geopandas.GeoDataFrame)
        The graph and a frame with columns 'from', 'to', 'geometry' holding
        one point row per station visit and one line row per created edge.

    Station pairs whose link cannot be built (no path, missing metadata, ...)
    are skipped; a single warning reports how many.
    """
    import warnings

    great_graph = nx.MultiDiGraph() if network is None else network

    node_order = get_nodes(tree)
    way_graphs = get_way_order(tree)
    meta_data = get_meta(tree)
    node_order = check_stations(node_order, way_graphs)

    data = []
    failures = []  # (relation id, from node, to node, error) for skipped links

    for rel_id, G in way_graphs.items():
        # TODO: relations whose ways form several disconnected components are
        # not stitched together yet, so stations on different components get
        # no edge. get_components_ends() exists for that step.
        nodes = node_order[rel_id]
        for u, v in zip(nodes, nodes[1:]):
            u_meta = clean_meta_dict(meta_data[u[1]])
            v_meta = clean_meta_dict(meta_data[v[1]])
            great_graph.add_nodes_from([(u[1], u_meta), (v[1], v_meta)])

            data.append([u, None, Point(u[0][0], u[0][1])])
            # The original built this point from u's coordinates.
            data.append([v, None, Point(v[0][0], v[0][1])])

            if u[0] == v[0]:
                # Both stations snapped to the same way point: nothing to link.
                continue

            try:
                shortest_path = nx.shortest_path(G, u[0], v[0])
                osm_ids = get_osmid_from_shortest_path(G, shortest_path)
                attr_dict = {}
                for way_id in osm_ids:
                    way_meta = clean_meta_dict(meta_data[way_id], keep_coord=False)
                    way_meta.pop('id', None)
                    attr_dict[int(way_id)] = way_meta
                line = LineString(shortest_path)
                length = line_lenght(line)
            except Exception as exc:
                # Best effort, as before: skip this link but remember why.
                failures.append((rel_id, u[1], v[1], repr(exc)))
                continue

            data.append([u, v, line])
            great_graph.add_edge(u_for_edge=u[1], v_for_edge=v[1],
                                 osmid=osm_ids, geometry=line,
                                 length=length, attr_dict=attr_dict)

    if failures:
        rel_id, u_id, v_id, reason = failures[0]
        warnings.warn(
            f"create_network skipped {len(failures)} station link(s); first: "
            f"relation {rel_id}, {u_id} -> {v_id}: {reason}"
        )

    return great_graph, gpd.GeoDataFrame(data, columns = ['from', 'to', 'geometry'])

def get_plublic_transport(city):
    """
    Build one combined transit graph for ``city`` and hand it to osmfunc.city_to_files.

    Subway, light rail, bus and tram data are downloaded in that order and
    merged into the same MultiDiGraph; a mode is only merged when its
    response contains at least one relation. The node count is printed
    after every mode.
    """
    combined = nx.MultiDiGraph()
    modes = ('subway', 'light_rail', 'bus', 'tram')

    for mode in modes:
        mode_tree = download_osm_transit_data(mode, city)
        has_routes = len(mode_tree.findall('relation')) > 0
        if has_routes:
            # The GeoDataFrame returned alongside the graph is not used here.
            combined, _ = create_network(tree = mode_tree, network = combined)
        print(f"After adding {mode} to the network, we have {combined.number_of_nodes()} nodes")

    osmfunc.city_to_files(G = combined, city = city, osm_type = 'public_transport', nx_type = 'multidigraph')

In [34]:
# Build and export the combined transit network. The argument is matched
# against the OSM area "name" tag inside the Overpass queries.
# NOTE(review): the recorded run reports 0 nodes after every mode, so
# presumably no relations were found for an area with exactly this name;
# the OSM "name" tag may hold the local-language name. TODO confirm.
get_plublic_transport("Copenhagen Municipality")

After adding subway to the network, we have 0 nodes
After adding light_rail to the network, we have 0 nodes
After adding bus to the network, we have 0 nodes
After adding tram to the network, we have 0 nodes


  _to_file_fiona(df, filename, driver, schema, crs, mode, **kwargs)
