In [None]:
import pathlib
import csv
import networkx as nx
from pyvis import network as net

# NetworkX graphs

- **nx.Graph** - Undirected graph
- **nx.DiGraph** - Directed graph

In [None]:
# Node-type tags, stored in each node's "type" attribute.
STATION, LINE, PERSON = range(3)


def node_id(type, name):
    """Build a graph node id of the form ``"<type>_<name>"``.

    pyvis only accepts str or int node ids, so the type tag and the name are
    combined into a single string.
    """
    return "{}_{}".format(type, name)
    

def _read_csv_rows(path):
    """Read the CSV file at ``path`` and return its rows as a list of dicts.

    Each dict maps the header row's column names to that row's values. The
    file is opened with ``newline=""`` as the csv module requires (so quoted
    fields containing line breaks parse correctly) and with an explicit UTF-8
    encoding so parsing does not depend on the platform's default encoding.
    """
    with open(path, newline="", encoding="utf-8") as csvfile:
        return list(csv.DictReader(csvfile))


# Create graph
def create_graph(root=None):
    """Build the transit graph from the CSV files in ``root``.

    Nodes (node type stored in the ``type`` attribute):
      - STATION: one node per row of ``stations.csv`` (label = ``name``).
      - LINE: one node per distinct ``line_name`` in ``lines_final.csv``;
        ``color`` comes from the row, and a later row for the same line
        overwrites it.
      - PERSON: one node per row of ``person.csv`` (label = ``name``).

    Edges (edge kind stored in the ``type`` attribute):
      - LINE -> STATION      "CONSISTS_OF" (one direction only)
      - STATION <-> STATION  "CONNECTS_TO" (both directions)
      - PERSON <-> STATION   "LIVES_NEAR"  (both directions)

    Args:
        root: Directory containing the CSV files. Defaults to ``data`` under
            the current working directory.

    Returns:
        nx.DiGraph: The populated graph.

    Raises:
        KeyError: If a line or person row references a station id that is
            not present in ``stations.csv``, or a expected column is missing.
    """
    G = nx.DiGraph()
    root = pathlib.Path().resolve() / "data" if root is None else pathlib.Path(root)

    # Stations. The CSV "id" column is only used to resolve references from
    # the other files; the graph node itself is keyed by station name, so two
    # rows sharing a name collapse into a single node.
    all_stations = {}
    for row in _read_csv_rows(root / "stations.csv"):
        name = row["name"]
        node = node_id(STATION, name)
        all_stations[row["id"]] = node
        G.add_node(node, label=name, type=STATION)

    # Lines. Each row is one connection between two stations on a line; the
    # line node is re-added per row (add_node updates attributes in place).
    for row in _read_csv_rows(root / "lines_final.csv"):
        name = row["line_name"]
        line_node = node_id(LINE, name)
        G.add_node(line_node, color=row["color"], label=name, type=LINE)

        station1 = all_stations[row["station1"]]
        station2 = all_stations[row["station2"]]
        # Lines point at their stations in one direction only, so a path
        # search cannot hop between two stations through the line node.
        G.add_edge(line_node, station1, type="CONSISTS_OF")
        G.add_edge(line_node, station2, type="CONSISTS_OF")
        # Station-to-station connections are traversable both ways.
        G.add_edge(station1, station2, type="CONNECTS_TO")
        G.add_edge(station2, station1, type="CONNECTS_TO")

    # People, each linked both ways to the station they live near.
    for row in _read_csv_rows(root / "person.csv"):
        name = row["name"]
        node = node_id(PERSON, name)
        G.add_node(node, label=name, type=PERSON)
        station = all_stations[row["station"]]
        G.add_edge(node, station, type="LIVES_NEAR")
        G.add_edge(station, node, type="LIVES_NEAR")

    return G

G = create_graph()
print(f"Number of nodes: {G.number_of_nodes()}")
print(f"Number of edges: {G.number_of_edges()}")



In [None]:
# Visualization using pyvis
def show(G, directed, filename="sample.html"):
    """Render ``G`` as an interactive HTML page with pyvis.

    Args:
        G: A NetworkX graph; ``from_nx`` copies its nodes and edges (with
            their attributes, e.g. ``label`` and ``color``) into pyvis.
        directed: Whether pyvis should draw edges as arrows.
        filename: Path of the HTML file to write. Defaults to
            "sample.html", so successive calls overwrite the same file
            unless a different name is passed.
    """
    n = net.Network(cdn_resources="remote", directed=directed)
    n.from_nx(G)
    # notebook=False: write a standalone HTML file instead of returning an
    # inline IFrame (pyvis may also open it in a browser — version-dependent).
    n.show(filename, notebook=False)

In [None]:
# All stations on the Victoria Line. In the DiGraph the line's successors are
# exactly the stations it CONSISTS_OF; the line node itself is left out.
line_id = node_id(LINE, "Victoria Line")
victoria_stations = list(G.successors(line_id))
G_line = G.subgraph(victoria_stations)
show(G_line, directed=False)

In [None]:
def filter_person_location(G, person):
    """Return a read-only view of a person, their station(s), and those stations' lines.

    Args:
        G: Directed graph as built by ``create_graph`` (uses ``predecessors``,
            so an undirected Graph is not supported).
        person: The person's name, as used when the PERSON node was created.

    Returns:
        A ``nx.subgraph_view`` of ``G`` restricted to the person node, every
        STATION node the person has an edge to, and every LINE node with an
        edge into one of those stations.

    Raises:
        networkx.NetworkXError: If ``G`` has no node for ``person``.
    """
    person_id = node_id(PERSON, person)
    # A set gives O(1) membership checks. The view re-evaluates its filter
    # every time it is iterated, so a list would cost O(len) per node.
    nodes_to_show = {person_id}
    for neighbor in G.neighbors(person_id):
        # Only stations are of interest; skip any other node type.
        if G.nodes[neighbor]["type"] != STATION:
            continue
        nodes_to_show.add(neighbor)
        # Lines reference stations via LINE -> STATION edges, so they are
        # found among the station's predecessors.
        nodes_to_show.update(
            pred for pred in G.predecessors(neighbor)
            if G.nodes[pred]["type"] == LINE
        )
    return nx.subgraph_view(G, filter_node=nodes_to_show.__contains__)

In [None]:
# Alison's nearby station(s) together with the lines that serve them
G_alison = filter_person_location(G, person="Alison")
show(G_alison, directed=False)

In [None]:
# Find station and line where Bob lives.
# Stored in its own name: reusing G_alison here (copy-paste) would silently
# replace Alison's view with Bob's.
G_bob = filter_person_location(G, "Bob")
show(G_bob, False)

In [None]:
# Find the shortest route from Alison to Bob
# (Alison -> her station -> ... -> Bob's station -> Bob).
alison_id = node_id(PERSON, "Alison")
bob_id = node_id(PERSON, "Bob")
route = nx.shortest_path(G, alison_id, bob_id)
print(route)
# Keep only the edges actually traversed, in travel order. G.subgraph(route)
# would also include the reverse CONNECTS_TO / LIVES_NEAR edges between the
# same nodes, so the directed plot would show every hop as a two-way arrow.
G_route = G.edge_subgraph(zip(route, route[1:]))
show(G_route, True)

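This example shows why it is important to use a DiGraph here. If we used an undirected Graph, the shortest path would be **Person -> Station -> Line -> Station -> Person**. The line-to-station (CONSISTS_OF) edges could then be walked in both directions, so the line node would connect any two stations on that line in two hops. The path would "jump" through the line node instead of following the actual station-to-station connections.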