In [87]:
import sys
import os

# Make the project root (the parent directory of this notebook) importable so
# that the `src` package can be resolved. The membership check keeps the cell
# idempotent: re-running it does not pile up duplicate sys.path entries.
PROJECT_ROOT = os.path.join('..')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [88]:
# Settings for autoreloading: with mode 2, modules (e.g. the project's `src`
# package) are re-imported before each cell runs, so edits to project code are
# picked up without restarting the kernel.

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [89]:
from src.utils.seed import set_random_seed

# Set the random seed for deterministic operations.
# NOTE(review): which RNGs (python / numpy / torch) get seeded is decided by
# set_random_seed — confirm in src/utils/seed.py.
SEED = 42
set_random_seed(SEED)

In [90]:
import torch

# Run the model on the GPU whenever PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f'The selected device is: "{DEVICE}"')

The selected device is: "cuda"


# 1 Loading the Data
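In this section, the fitted scaler, the trained STGNN (together with its adjacency matrix), the sensor locations and the preprocessed train/validation/test splits are loaded.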

In [91]:
import os

# Root of the METR-LA data, relative to the notebook's directory.
BASE_DATA_DIR = os.path.join(os.pardir, 'data', 'metr-la')

In [92]:
import pickle

# Load the scaler object saved in the processed data directory.
# NOTE: pickle.load can execute arbitrary code — only load files you trust.
scaler_path = os.path.join(BASE_DATA_DIR, 'processed', 'scaler.pkl')
with open(scaler_path, 'rb') as scaler_file:
    scaler = pickle.load(scaler_file)

In [93]:
from src.spatial_temporal_gnn.model import SpatialTemporalGNN
from src.data.data_extraction import get_adjacency_matrix

# The adjacency pickle unpacks into (header, node_ids_dict, adj_matrix);
# node_ids_dict maps sensor ids to node indices (the agents below index it
# with sensor ids to obtain matrix positions).
adj_matrix_path = os.path.join(BASE_DATA_DIR, 'raw', 'adj_mx_metr_la.pkl')
adj_matrix_structure = get_adjacency_matrix(adj_matrix_path)
header, node_ids_dict, adj_matrix = adj_matrix_structure

# Build the spatial-temporal GNN over this graph.
# NOTE(review): the positional arguments (9, 1, 12, 12, ..., 64) are magic
# numbers — pass them by keyword once the SpatialTemporalGNN signature is confirmed.
spatial_temporal_gnn = SpatialTemporalGNN(9, 1, 12, 12, adj_matrix, DEVICE, 64)

# Restore the trained weights, placing tensors on the selected device.
stgnn_checkpoints_path = os.path.join(
    '..', 'models', 'checkpoints', 'st_gnn_metr_la.pth')
stgnn_checkpoints = torch.load(stgnn_checkpoints_path,
                               map_location=torch.device(DEVICE))
spatial_temporal_gnn.load_state_dict(stgnn_checkpoints['model_state_dict'])

# Evaluation mode for inference; the trailing ';' hides the module repr.
spatial_temporal_gnn.eval();

In [94]:
from src.data.data_extraction import get_locations_dataframe

# Sensor metadata: one row per sensor with its id, latitude and longitude.
sensor_locations_path = os.path.join(
    BASE_DATA_DIR, 'raw', 'graph_sensor_locations_metr_la.csv')
locations_df = get_locations_dataframe(sensor_locations_path, has_header=True)
locations_df

Unnamed: 0_level_0,sensor_id,latitude,longitude
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,773869,34.15497,-118.31829
1,767541,34.11621,-118.23799
2,767542,34.11641,-118.23819
3,717447,34.07248,-118.26772
4,717446,34.07142,-118.26572
...,...,...,...
202,717592,34.14604,-118.22430
203,717595,34.14163,-118.18290
204,772168,34.16542,-118.47985
205,718141,34.15133,-118.37456


In [95]:
# Reverse lookup of node_ids_dict: node index -> sensor id.
node_pos_dict = dict(zip(node_ids_dict.values(), node_ids_dict.keys()))

In [96]:
import os
import numpy as np
from src.spatial_temporal_gnn.prediction import predict

# Load the preprocessed train / validation / test splits (x_* inputs and
# y_* targets) from the processed data directory.
PROCESSED_DIR = os.path.join(BASE_DATA_DIR, 'processed')


def _load_processed(name):
    """Load ``<name>.npy`` from the processed METR-LA directory."""
    return np.load(os.path.join(PROCESSED_DIR, f'{name}.npy'))


x_train = _load_processed('x_train')
y_train = _load_processed('y_train')
x_val = _load_processed('x_val')
y_val = _load_processed('y_val')
x_test = _load_processed('x_test')
y_test = _load_processed('y_test')

In [97]:
from src.data.dataloaders import get_dataloader

# Every split uses the same batch size; only the training split is shuffled.
BATCH_SIZE = 64
train_dataloader = get_dataloader(x_train, y_train, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = get_dataloader(x_val, y_val, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = get_dataloader(x_test, y_test, batch_size=BATCH_SIZE, shuffle=False)

# 2 Predictions
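In this section, a LangGraph multi-agent workflow is built around the STGNN. The workflow collects a routing request and forecasts traffic speeds with the STGNN. It then gathers background knowledge with Gemini and recommends a congestion-aware route. Finally, it shows the route on a map and explains the recommendation to the user.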

In [98]:
from langgraph.graph import StateGraph, END
import numpy as np
import pandas as pd
import datetime

In [99]:
from src.data.data_extraction import get_node_values_dataframe

def ui_agent(state: dict) -> dict:
    """Collect the routing request and resolve the endpoint coordinates.

    Request fields may be supplied through the incoming ``state``:
    ``user_prompt``, ``start_node``, ``end_node`` and ``time_str``. The time
    string uses the format ``"%Y-%m-%d %H:%M"``. Any field that is missing
    falls back to the demo values below, so invoking the graph with an empty
    state runs the built-in demo request.

    Returns:
        A new state dict with ``user_prompt``, ``start_node`` and ``end_node``
        (sensor ids as strings), ``query_time`` (a ``datetime``), and
        ``start_coord`` / ``end_coord`` (latitude, longitude pairs taken from
        ``locations_df``).

    Raises:
        ValueError: if either sensor id is not present in ``locations_df``.
    """
    # 1. Collect user input (state overrides, else demo defaults).
    request = state or {}
    user_prompt = request.get('user_prompt', 'Suggest best route avoiding congestion')
    start_node = str(request.get('start_node', '773869'))
    end_node = str(request.get('end_node', '717446'))
    time_str = request.get('time_str', '2012-03-01 15:00')
    try:
        query_time = datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M")
    except (TypeError, ValueError):
        print("Invalid time format. Using current time.")
        query_time = datetime.datetime.now()

    # Normalise the column to int: this lookup, and the route description
    # built later, compare it against int sensor ids.
    # NOTE(review): this mutates the global locations_df in place.
    locations_df['sensor_id'] = locations_df['sensor_id'].astype(int)
    start_node_int = int(start_node)
    end_node_int = int(end_node)

    start_row = locations_df[locations_df['sensor_id'] == start_node_int]
    end_row = locations_df[locations_df['sensor_id'] == end_node_int]
    if start_row.empty or end_row.empty:
        raise ValueError(f"Start node {start_node} or end node {end_node} not found in locations_df.")

    start_coord = start_row[['latitude', 'longitude']].values[0]
    end_coord = end_row[['latitude', 'longitude']].values[0]

    # 2. Build the initial state for the downstream agents.
    state = {
        "user_prompt": user_prompt,
        "start_node": start_node,
        "end_node": end_node,
        "query_time": query_time,
        "start_coord": start_coord,
        "end_coord": end_coord
    }

    return state

In [100]:
import google.generativeai as genai
from dotenv import load_dotenv

# Pull variables from a local .env file (if any) into the process environment.
load_dotenv()

api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
    raise EnvironmentError("GOOGLE_API_KEY not found in environment variables.")

# Configure the client with the key validated above rather than re-reading
# the environment a second time.
genai.configure(api_key=api_key)

# The Gemini model name can be overridden through GEMINI_MODEL_NAME, so a
# different or newer model version can be used without editing this cell.
GEMINI_MODEL_NAME = os.getenv("GEMINI_MODEL_NAME", "gemini-1.5-pro-002")
model = genai.GenerativeModel(GEMINI_MODEL_NAME)

In [101]:
def call_gemini(prompt, max_tokens=256, temperature=0.7):
    """Send ``prompt`` to the configured Gemini ``model`` and return its text.

    Args:
        prompt: Text prompt passed to ``model.generate_content``.
        max_tokens: Upper bound on generated tokens (``max_output_tokens``).
        temperature: Sampling temperature; nucleus sampling uses ``top_p=0.95``.

    Returns:
        The response text with surrounding whitespace stripped.
    """
    sampling = {
        "temperature": temperature,
        "max_output_tokens": max_tokens,
        "top_p": 0.95,
    }
    reply = model.generate_content(prompt, generation_config=sampling)
    return reply.text.strip()

In [102]:
def forecasting_agent(state: dict) -> dict:
    """Forecast traffic at the start and end sensors with the trained STGNN.

    Reads ``start_node`` and ``end_node`` from ``state``. Adds a ``forecast``
    dict holding the prediction at the last forecast step for the start node,
    the end node, and every node (``full_pred``). Predictions are
    inverse-transformed when ``scaler`` provides ``inverse_transform``.

    NOTE(review): ``query_time`` is not used yet — the first test window is
    always fed to the model.
    """
    start, end = state['start_node'], state['end_node']

    # sensor id -> node index. Deliberately not called node_pos_dict: the
    # global of that name maps the opposite way (index -> sensor id).
    sensor_to_idx = node_ids_dict
    try:
        start_idx = sensor_to_idx[str(start)]
        end_idx = sensor_to_idx[str(end)]
    except KeyError as e:
        # Best-effort fallback so the pipeline keeps running for demos.
        print(f"[WARN] Failed to resolve node index: {e}")
        start_idx = 0
        end_idx = 1

    # Only the first window is needed, so memory-map instead of reading the
    # whole test array. Assumed shape [batch, 12, 207, 9] — TODO confirm.
    x_data = np.load(os.path.join(BASE_DATA_DIR, 'processed', 'x_test.npy'),
                     mmap_mode='r')
    x_sample = np.asarray(x_data[0:1])

    x_tensor = torch.tensor(x_sample, dtype=torch.float32).to(DEVICE)
    with torch.no_grad():
        y_pred = spatial_temporal_gnn(x_tensor)
    y_pred_np = y_pred.cpu().numpy()

    # Map predictions back to the original scale when the scaler supports it.
    if hasattr(scaler, 'inverse_transform'):
        y_pred_np = scaler.inverse_transform(y_pred_np)

    # All nodes at the last forecast step.
    last_step = y_pred_np[0, -1, :, 0]
    state['forecast'] = {
        "start_node": float(last_step[start_idx]),
        "end_node": float(last_step[end_idx]),
        "full_pred": last_step
    }
    return state


In [103]:
def knowledge_agent(state: dict) -> dict:
    """Ask Gemini for background traffic context about the requested trip.

    Builds a prompt from ``start_node``, ``end_node``, ``query_time``,
    ``user_prompt`` and the optional ``start_coord`` / ``end_coord`` entries.
    Stores the model's answer in ``state['knowledge']``.
    """
    start, end = state['start_node'], state['end_node']
    query_time, prompt = state['query_time'], state['user_prompt']

    def _fmt_coord(coord):
        # Coordinates are optional in the state; describe them instead of
        # failing with a TypeError when they are missing.
        if coord is None:
            return "(unknown location)"
        return f"({coord[0]:.5f}, {coord[1]:.5f})"

    llm_prompt = (
        "You are a traffic knowledge assistant. The locations are sensor node IDs from the METR-LA dataset, "
        "each with known latitude and longitude. "
        f"Start node: {start} at {_fmt_coord(state.get('start_coord'))}\n"
        f"End node: {end} at {_fmt_coord(state.get('end_coord'))}\n"
        f"Time: {query_time.strftime('%Y-%m-%d %H:%M')}\n"
        f"User prompt: {prompt}\n"
        "Summarize relevant traffic patterns, incidents, or historical congestion for these nodes and time."
    )
    knowledge = call_gemini(llm_prompt, max_tokens=200, temperature=0.2)
    state['knowledge'] = knowledge
    return state

In [104]:
def _route_edge_weights(pred_speed):
    """Scale column j of ``adj_matrix`` by 1 / predicted speed at node j."""
    # Clip to avoid division by zero; a slower predicted speed gives a larger weight.
    inv_speed = 1.0 / np.clip(pred_speed, 1e-3, None)
    # NOTE(review): if adj_matrix stores similarity weights (larger = closer)
    # rather than distances, closer sensors end up *more* expensive here —
    # confirm what get_adjacency_matrix returns.
    return adj_matrix * inv_speed[np.newaxis, :]


def _reconstruct_path(predecessors, start_idx, end_idx):
    """Follow Dijkstra predecessors from ``end_idx`` back to ``start_idx``."""
    path = [end_idx]
    node = end_idx
    while node != start_idx:
        node = predecessors[node]
        if node < 0:  # -9999 sentinel: chain broke before reaching the start
            raise ValueError("Predecessor chain does not lead back to the start node.")
        path.append(node)
    return path[::-1]


def _describe_route(path, idx_to_sensor):
    """Return (route string, list of {sensor_id, latitude, longitude}) for ``path``."""
    # Build the coordinate lookup once; the first row wins for duplicate ids.
    coord_lookup = {}
    for row in locations_df.itertuples(index=False):
        coord_lookup.setdefault(int(row.sensor_id), (row.latitude, row.longitude))

    route_steps, routes_list = [], []
    for idx in path:
        sensor_id = idx_to_sensor.get(idx, "Unknown")
        coords = None if sensor_id == "Unknown" else coord_lookup.get(int(sensor_id))
        if coords is not None:
            lat, lon = coords
            route_steps.append(f"Sensor {sensor_id} ({lat:.5f}, {lon:.5f})")
        else:
            lat = lon = None
            route_steps.append(f"Sensor {sensor_id}")
        routes_list.append({
            "sensor_id": sensor_id,
            "latitude": lat,
            "longitude": lon
        })
    return " → ".join(route_steps), routes_list


def recommendation_agent(state: dict) -> dict:
    """Recommend a congestion-aware route between the start and end sensors.

    Runs Dijkstra on the adjacency matrix. Edges into a node are weighted by
    the inverse of that node's predicted speed (``state['forecast']['full_pred']``).
    Adds ``route`` (a readable string) and ``routes_list`` (per-sensor dicts
    whose coordinates are ``None`` when unknown) to ``state``.

    Raises:
        ValueError: if the end sensor is not reachable from the start sensor.
    """
    from scipy.sparse.csgraph import dijkstra

    # sensor id -> node index, plus the reverse lookup for describing the path.
    sensor_to_idx = node_ids_dict
    idx_to_sensor = {idx: sensor for sensor, idx in sensor_to_idx.items()}

    start, end = state['start_node'], state['end_node']
    try:
        start_idx = sensor_to_idx[str(start)]
        end_idx = sensor_to_idx[str(end)]
    except KeyError:
        print("[WARN] Failed to resolve node index, using defaults.")
        start_idx = 0
        end_idx = 1

    weighted_adj = _route_edge_weights(state['forecast']["full_pred"])
    # NOTE(review): directed=False lets the path traverse edges in either
    # direction; confirm this is intended for a road-sensor graph.
    dist, predecessors = dijkstra(weighted_adj, directed=False, indices=start_idx,
                                  return_predecessors=True)
    if not np.isfinite(dist[end_idx]):
        raise ValueError(f"No route found from sensor {start} to sensor {end}.")

    path = _reconstruct_path(predecessors, start_idx, end_idx)
    route_str, routes_list = _describe_route(path, idx_to_sensor)

    state['route'] = route_str
    state['routes_list'] = routes_list
    return state

In [105]:
def explanation_agent(state: dict) -> dict:
    """Ask Gemini for a user-facing explanation of the recommended route.

    Combines ``forecast``, ``knowledge``, ``route``, the endpoints and
    ``query_time`` into one prompt. Stores the answer in ``state['explanation']``.
    """
    start, end, query_time = state['start_node'], state['end_node'], state['query_time']
    forecast, knowledge, route = state['forecast'], state['knowledge'], state['route']
    # METR-LA speed readings are in miles per hour, so label them as mph
    # (assuming the forecast has been mapped back to the original scale).
    llm_prompt = (
        "You are an expert traffic assistant. Given a forecast, knowledge, and a recommended route, "
        "write a clear, accurate, and highly explainable summary for a user. "
        "Explain the traffic situation, why the route was chosen, and cite relevant knowledge.\n"
        f"Forecast: Start node speed: {forecast['start_node']:.2f} mph, End node speed: {forecast['end_node']:.2f} mph\n"
        f"Knowledge: {knowledge}\n"
        f"Recommended route: {route}\n"
        f"Start: {start}\nEnd: {end}\nTime: {query_time.strftime('%Y-%m-%d %H:%M')}\n"
        "Write a user-facing explanation."
    )
    explanation = call_gemini(llm_prompt, max_tokens=300, temperature=0.2)
    state['explanation'] = explanation
    return state


In [106]:
import folium

def visualization_agent(state: dict) -> dict:
    """Show the sensors of the recommended route on a folium map.

    The map is centred on the mean coordinate of the route sensors that have a
    known location. Entries whose latitude/longitude is ``None`` are skipped.
    ``display`` is not imported here; it is assumed to be the IPython builtin
    available in the notebook. Returns ``state`` unchanged.
    """
    located = [
        loc for loc in state['routes_list']
        if loc['latitude'] is not None and loc['longitude'] is not None
    ]
    if not located:
        # Nothing to plot; avoid a ZeroDivisionError when centring the map.
        print("[WARN] No sensor coordinates available for the route; skipping map.")
        return state

    avg_lat = sum(loc['latitude'] for loc in located) / len(located)
    avg_lon = sum(loc['longitude'] for loc in located) / len(located)

    m = folium.Map(location=[avg_lat, avg_lon], zoom_start=13)
    for loc in located:
        folium.Marker(
            location=[loc['latitude'], loc['longitude']],
            popup=f"Sensor ID: {loc['sensor_id']}",
            icon=folium.Icon(color="blue", icon="info-sign")
        ).add_to(m)

    display(m)

    return state

In [107]:
def output_agent(state: dict) -> dict:
    """Print the explanation and the recommended route, then return ``state``."""
    sections = (
        ("Final Response to User:", "explanation"),
        ("Recommended Route:", "route"),
    )
    for heading, key in sections:
        print(f"\n[UIAgent Output] {heading}")
        print(state[key])
    return state

In [108]:
def build_graph():
    """Assemble and compile the multi-agent routing workflow.

    Execution order: UIAgent -> ForecastingAgent -> KnowledgeAgent ->
    RecommendationAgent -> VisualizationAgent -> ExplanationAgent ->
    OutputAgent -> END.
    """
    workflow = StateGraph(dict)

    agents = {
        "UIAgent": ui_agent,
        "ForecastingAgent": forecasting_agent,
        "KnowledgeAgent": knowledge_agent,
        "RecommendationAgent": recommendation_agent,
        "ExplanationAgent": explanation_agent,
        "VisualizationAgent": visualization_agent,
        "OutputAgent": output_agent,
    }
    for node_name, node_fn in agents.items():
        workflow.add_node(node_name, node_fn)

    execution_order = [
        "UIAgent",
        "ForecastingAgent",
        "KnowledgeAgent",
        "RecommendationAgent",
        "VisualizationAgent",
        "ExplanationAgent",
        "OutputAgent",
        END,
    ]
    workflow.set_entry_point(execution_order[0])
    for source, target in zip(execution_order, execution_order[1:]):
        workflow.add_edge(source, target)

    return workflow.compile()

In [109]:
# Compile the agent workflow and run it once from an empty initial state.
graph = build_graph()
initial_state = {}
final_state = graph.invoke(initial_state)


[UIAgent Output] Final Response to User:
It's 3 PM on Thursday, March 1st, 2012, and you're heading from sensor 773869 to 717446.  Expect heavy traffic due to the afternoon rush hour in Los Angeles.  Major freeways and arterial roads are likely congested at this time.

While I can't access real-time data for that specific date and time, I've calculated a route designed to *minimize* your exposure to the worst of the rush hour congestion:

**Recommended Route:** 773869 → 717570 → 764760 → 769403 → 717468 → 717456 → 717446

This route likely prioritizes less congested side streets and possibly surface roads over the more heavily trafficked freeways.  If I had access to the historical traffic data from the METR-LA dataset, I could have refined this route further by considering:

* **Historical Speed Data:** I would have looked at the average speeds along different road segments at 3 PM on Thursdays in the dataset to identify historically congested areas.
* **Flow Data:**  Knowing the typ