In [1]:
!pip install graphdatascience



In [2]:
from graphdatascience.session import GdsSessions, AuraAPICredentials, DbmsConnectionInfo, AlgorithmCategory
from datetime import timedelta
import pandas as pd
import os
from google.colab import userdata

## Load in Data
Load in data from my github that covers how the NYC subway connects together.

In [3]:
lines = pd.read_csv("https://raw.githubusercontent.com/corydonbaylor/aura-graph-analytics/refs/heads/main/mta_subways/data/lines.csv")
stations = pd.read_csv("https://raw.githubusercontent.com/corydonbaylor/aura-graph-analytics/refs/heads/main/mta_subways/data/nodes.csv")

## Set Up Sessions
Set up sessions with credentials and then spin up a *session*

In [4]:
CLIENT_ID = userdata.get("CLIENT_ID")
CLIENT_SECRET = userdata.get("CLIENT_SECRET")
TENANT_ID = userdata.get("TENANT_ID")


# Neo4j Database Connection Info
SUPPLIER_URI = userdata.get("SUPPLIER_URI")
NEO4J_USER = userdata.get("NEO4J_USER")
SUPPLIER_PASSWORD = userdata.get("SUPPLIER_PASSWORD")

In [5]:
sessions = GdsSessions(api_credentials=AuraAPICredentials(CLIENT_ID, CLIENT_SECRET, TENANT_ID))

name = "my-new-session"
memory = sessions.estimate(
    node_count=475,
    relationship_count=800,
    algorithm_categories=[AlgorithmCategory.CENTRALITY, AlgorithmCategory.NODE_EMBEDDING],
)

db_connection_info = DbmsConnectionInfo(SUPPLIER_URI, NEO4J_USER, SUPPLIER_PASSWORD)


In [6]:
# Create or retrieve a session
gds = sessions.get_or_create(
    session_name=name,
    memory=memory,
    db_connection=db_connection_info, # this is checking for a bolt server currently
    ttl=timedelta(hours=5),
)

# Creating a Projection
You can create a projection directly from python dataframes. We have two dataframes-- one that represents stations and one that represents lines.

In [7]:
stations

Unnamed: 0.1,Unnamed: 0,station_name,id
0,0,Van Cortlandt Park-242 - Bx,0
1,1,238 St - Bx,1
2,2,231 St - Bx,2
3,3,Marble Hill-225 St - M,3
4,4,215 St - M,4
...,...,...,...
416,416,World Trade Center - M,416
417,417,Broad St - M,417
418,418,Canarsie-Rockaway Pkwy - Bk,418
419,419,Middle Village-Metropolitan Av - Q,419


In [8]:
lines

Unnamed: 0,sourceNodeId,targetNodeId,relationshipType
0,0,1,GOES_TO
1,1,2,GOES_TO
2,2,3,GOES_TO
3,3,4,GOES_TO
4,4,5,GOES_TO
...,...,...,...
694,408,336,GOES_TO
695,336,32,GOES_TO
696,32,34,GOES_TO
697,34,61,GOES_TO


Currently, Graph Analytics only accepts directed graphs. So we need to explicitly create the relationships going in the other direction.

In [9]:
lines2 = lines.rename(
    columns={
        'sourceNodeId' : 'targetNodeId',
        'targetNodeId' : 'sourceNodeId'
    }
)

lines = pd.concat([lines, lines2], ignore_index=True)
lines

Unnamed: 0,sourceNodeId,targetNodeId,relationshipType
0,0,1,GOES_TO
1,1,2,GOES_TO
2,2,3,GOES_TO
3,3,4,GOES_TO
4,4,5,GOES_TO
...,...,...,...
1393,336,408,GOES_TO
1394,32,336,GOES_TO
1395,34,32,GOES_TO
1396,61,34,GOES_TO


We do need to do some mild clean up to make sure that everything has the right names.

For the dataframe representing nodes:
- The first column should be called `nodeId`
- There can be no characters so we will have to drop the station names

For the dataframe representing relationships:
- We need to have columns called `sourceNodeId` and `targetNodeId`
- As well as what we want to call that relationship in a column called `relationshipType`

In [10]:
stations = stations.rename(columns={'id': 'nodeId'})
nodes = stations[['nodeId']]
nodes

Unnamed: 0,nodeId
0,0
1,1
2,2
3,3
4,4
...,...
416,416
417,417
418,418
419,419


## Graph Construct
Using `graph.construct`, we can easily create a projection.

In [11]:
G = gds.graph.construct("subways2", nodes, lines)

Uploading Nodes:   0%|          | 0/421 [00:00<?, ?Records/s]

Uploading Relationships:   0%|          | 0/1398 [00:00<?, ?Records/s]

And then we can run a path-finding algorithm to check things out.

## Returning Results
We will use Dijkstra shortest path to see how we can move through the system efficiently.

We can create a simple wrapper function below, so that we can use the names of stations rather than their `nodeIds`:

In [12]:
station_crosswalk = dict(zip(stations['station_name'], stations['nodeId']))

In [13]:
# Function to get the node IDs from station names and run Dijkstra
def get_shortest_path(source_station, target_station, G):
    # Map the station names to node IDs
    source_node_id = station_crosswalk.get(source_station)
    target_node_id = station_crosswalk.get(target_station)

    result = gds.shortestPath.dijkstra.stream(
          G,
          sourceNode=source_node_id,
          targetNode=target_node_id
      )
    node_ids = result['nodeIds'][0]
    id_to_station = {v: k for k, v in station_crosswalk.items()}
    ordered_subset = {id_to_station[i]: i for i in node_ids if i in id_to_station}
    return ordered_subset

Let's see how to get from Grand Army Plaza in Brooklyn to Times Square:

In [14]:
# Example usage
# Assuming 'G' is your graph
source_station = "Grand Army Plaza - Bk"
target_station = "Times Sq-42 St - M"

# Call the function
path_df = get_shortest_path(source_station, target_station, G)
path_df

{'Grand Army Plaza - Bk': 69,
 'Bergen St - Bk': 68,
 'Atlantic Av-Barclays Ctr - Bk': 67,
 'Canal St - M': 32,
 '14 St-Union Sq - M': 104,
 '34 St-Herald Sq - M': 230,
 'Times Sq-42 St - M': 24}

But what if one of those stations closed? What would be the quickest path there? Let's see what would happen if Canal St was closed:

In [15]:
def exclude_node(nodes_df, lines_df, node_to_exclude):
    closed = nodes_df[nodes_df['nodeId'] != node_to_exclude]
    closed_lines = lines_df[
        (lines_df['sourceNodeId'] != node_to_exclude) &
        (lines_df['targetNodeId'] != node_to_exclude)
    ]
    return closed, closed_lines

closed_nodes, closed_lines = exclude_node(nodes, lines, 230)

We then need to create a new projection without canal street:

In [17]:
G = gds.graph.construct("subways6", closed_nodes, closed_lines)

Uploading Nodes:   0%|          | 0/420 [00:00<?, ?Records/s]

Uploading Relationships:   0%|          | 0/1366 [00:00<?, ?Records/s]

In [18]:
# Example usage
# Assuming 'G' is your graph
source_station = "Grand Army Plaza - Bk"
target_station = "Times Sq-42 St - M"

# Call the function
path_df = get_shortest_path(source_station, target_station, G)
path_df

{'Grand Army Plaza - Bk': 69,
 'Bergen St - Bk': 68,
 'Atlantic Av-Barclays Ctr - Bk': 67,
 'Canal St - M': 32,
 'Chambers St - M': 34,
 '14 St - M': 29,
 '34 St-Penn Station - M': 25,
 'Times Sq-42 St - M': 24}

In [19]:
sessions.delete(session_name="my-new-session")

True