In [11]:
import torch
import pandas as pd
from PopSynthesis.Methods.GNN_activity.model import TravelGNN
from PopSynthesis.Methods.GNN_activity.graph_construct import construct_starting_graph_pyg
from PopSynthesis.Methods.GNN_activity.utils import visualize_pyg_graph_with_zones

In [4]:
# Recreate the model with original hyperparameters
model = TravelGNN(hidden_channels=64)

# Load model parameters
checkpoint = torch.load("data/trained_travel_gnn_model.pt")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # Set to evaluation mode


TravelGNN(
  (conv1): HeteroConv(num_relations=9)
  (edge_classifier): Linear(64, 1, bias=True)
  (duration_regressor): Linear(64, 1, bias=True)
  (joint_classifier): Linear(64, 1, bias=True)
)

In [8]:
def infer(model, data, max_duration=1440):
    model.eval()
    with torch.no_grad():
        embeddings = model(data.x_dict, data.edge_index_dict)
        person_emb = embeddings["person"]
        purpose_emb = embeddings["purpose"]

        # For prediction, create all possible combinations (person × purpose)
        person_indices = torch.arange(person_emb.size(0))
        purpose_indices = torch.arange(purpose_emb.size(0))

        # Create all combinations
        grid_person, grid_purpose = torch.meshgrid(person_indices, purpose_indices, indexing='ij')
        grid_person_flat = grid_person.reshape(-1)
        grid_purpose_flat = grid_purpose.reshape(-1)

        # Embeddings for all combinations
        combined_person_emb = person_emb[grid_person_flat]
        combined_purpose_emb = purpose_emb[grid_purpose_flat]

        # Predict
        edge_probs = torch.sigmoid(model.predict_edges(combined_person_emb, combined_purpose_emb))
        predicted_durations = model.predict_duration(combined_person_emb, combined_purpose_emb) * max_duration
        joint_probs = torch.sigmoid(model.predict_joint(combined_person_emb, combined_purpose_emb))

    # Reshape results into readable format
    predictions = {
        "person_idx": grid_person_flat.numpy(),
        "purpose_idx": grid_purpose_flat.numpy(),
        "edge_probability": edge_probs.numpy(),
        "predicted_duration": predicted_durations.numpy(),
        "joint_activity_probability": joint_probs.numpy()
    }

    return pd.DataFrame(predictions)


In [7]:
to_predict_households = pd.read_csv("data/to_predict_households.csv")
to_predict_persons = pd.read_csv("data/to_predict_people.csv")
zone_data = pd.read_csv("data/zones.csv")
purpose_data = pd.read_csv("data/purposes.csv")
od_matrix_data = pd.read_csv("data/od_matrix.csv")
to_predict_data = construct_starting_graph_pyg(zone_data, purpose_data, to_predict_households, to_predict_persons, od_matrix_data)

In [9]:
predictions_df = infer(model, to_predict_data)

# Filter predictions for high-probability edges, e.g., probability > 0.5
high_confidence_predictions = predictions_df[predictions_df["edge_probability"] > 0.5]

# Example output:
print(high_confidence_predictions.head())


    person_idx  purpose_idx  edge_probability  predicted_duration  \
0            0            0          0.627325         -561.442566   
6            0            6          0.626436         -158.505157   
12           0           12          0.596545          158.145035   
24           0           24          0.524741          750.746033   
30           1            0          0.636125         -608.462646   

    joint_activity_probability  
0                     0.435014  
6                     0.422689  
12                    0.467246  
24                    0.577943  
30                    0.435636  


In [13]:
a = visualize_pyg_graph_with_zones(to_predict_data, to_predict_persons, to_predict_households, purpose_data, zone_data)
a.save_graph("data/to_predict_graph.html")

In [14]:
high_confidence_predictions

Unnamed: 0,person_idx,purpose_idx,edge_probability,predicted_duration,joint_activity_probability
0,0,0,0.627325,-561.442566,0.435014
6,0,6,0.626436,-158.505157,0.422689
12,0,12,0.596545,158.145035,0.467246
24,0,24,0.524741,750.746033,0.577943
30,1,0,0.636125,-608.462646,0.435636
...,...,...,...,...,...
1506,50,6,0.571353,667.999573,0.629695
1512,50,12,0.583164,828.645447,0.642250
1530,51,0,0.554389,494.484467,0.610751
1536,51,6,0.571808,673.939026,0.630025
