In [None]:
import os
import pickle
from datetime import datetime, timedelta
import pandas as pd
from gnn_package import training, preprocessing
from private_uoapi import (
    LightsailWrapper,
    LSAuth,
    LSConfig,
    DateRangeParams,
    convert_to_dataframe,
)

In [None]:
adj_matrix_dense, node_ids, metadata = preprocessing.load_graph_data(
    prefix="25022025_test", return_df=False
)
name_id_map = preprocessing.get_sensor_name_id_map()
# node_names = [name_id_map[str(node_id)] for node_id in node_ids]
adj_matrix_dense.max()

# get sensor data
if os.path.exists("test_data_1yr.pkl"):
    with open("test_data_1yr.pkl", "rb") as f:
        results_containing_data = pickle.load(f)
else:
    config = LSConfig()
    auth = LSAuth(config)
    client = LightsailWrapper(config, auth)

    print(f"Using base URL: {config.base_url}")
    print(f"Using username: {config.username}")
    print(f"Using secret key: {'*' * len(config.secret_key)}")  # Mask the secret key

    sensor_locations = client.get_traffic_sensors()
    sensor_locations = pd.DataFrame(sensor_locations)
    display(sensor_locations.head())

    date_range_params = DateRangeParams(
        start_date=datetime(2024, 2, 18, 0, 0, 0),
        end_date=datetime(
            2024, 2, 18, 0, 0, 0
        )
        max_date_range=timedelta(days=365),
    )

    count_data = await client.get_traffic_data(date_range_params)

    counts_df = convert_to_dataframe(count_data)

    counts_dict = {}
    for location in sensor_locations["location"]:
        df = counts_df[counts_df["location"] == location]
        series = pd.Series(df["value"].values, index=df["dt"])
        location_id = name_id_map[location]
        counts_dict[location_id] = series if not df.empty else None

    len([series for series in counts_dict.values() if series is not None])

    results_containing_data = {
        node_id: data for node_id, data in counts_dict.items() if data is not None
    }



In [None]:
# Serialise the collected sensor series to the local pickle cache
# ("test_data_1yr.pkl"), overwriting any previous copy of the file.
with open("test_data_1yr.pkl", mode="wb") as cache_file:
    pickle.dump(results_containing_data, cache_file)

In [None]:
# Preprocess the sensor series into the object that train_model receives as
# `data_loaders`.
preprocess_options = {
    "graph_prefix": "25022025_test",
    "window_size": 24,  # 24 time steps as input
    "horizon": 6,  # Predict 6 steps ahead (1.5 hours with 15-min data)
    "batch_size": 32,
}
data = training.preprocess_data(results_containing_data, **preprocess_options)

# Train the model with the hyper-parameters below.
training_options = {
    "input_dim": 1,  # Traffic count is a single value
    "hidden_dim": 64,  # Size of hidden layers
    "num_epochs": 50,
    "patience": 10,  # Early stopping after 10 epochs with no improvement
}
results = training.train_model(data_loaders=data, **training_options)

# Save the trained model: only its state dict is written to disk.
from torch import save

trained_model = results["model"]
save(trained_model.state_dict(), "stgnn_model.pth")