In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, SAGEConv, Linear, GATConv
from torch_geometric.data import HeteroData
import pandas as pd

In [2]:
# Load the starting and expected graphs
starting_graph_pyg = torch.load("data/start_graph.pt")
expected_graph_pyg = torch.load("data/expected_graph.pt")
travel_diaries_df = pd.read_csv("data/sample_travel_diaries.csv")

In [3]:
class TravelGNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = HeteroConv({
            ('person', 'belongs_to', 'household'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('household', 'located_in', 'zone'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('zone', 'has_purpose', 'purpose'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('person', 'performs', 'purpose'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('person', 'parent', 'person'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('person', 'child', 'person'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('person', 'spouse', 'person'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('person', 'housemate', 'person'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            ('person', 'sibling', 'person'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
        }, aggr='sum')

        # Output heads
        self.edge_classifier = Linear(hidden_channels, 1)
        self.duration_regressor = Linear(hidden_channels, 1)
        self.joint_classifier = Linear(hidden_channels, 1)

    def forward(self, x_dict, edge_index_dict):
        embeddings = self.conv1(x_dict, edge_index_dict)
        embeddings = {k: F.relu(v) for k, v in embeddings.items()}
        return embeddings

    def predict_edges(self, person_emb, purpose_emb):
        combined = person_emb * purpose_emb
        return self.edge_classifier(combined).view(-1)

    def predict_duration(self, person_emb, purpose_emb):
        combined = person_emb * purpose_emb
        return self.duration_regressor(combined).view(-1)

    def predict_joint(self, person_emb, purpose_emb):
        combined = person_emb * purpose_emb
        return self.joint_classifier(combined).view(-1)


In [4]:
model = TravelGNN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

criterion_edge = torch.nn.BCEWithLogitsLoss()
criterion_duration = torch.nn.MSELoss()
criterion_joint = torch.nn.BCEWithLogitsLoss()

data = expected_graph_pyg.to('cpu')

In [5]:
def train_model(model, data, epochs=50, lr=0.005, max_duration=1440):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    criterion_edge = torch.nn.BCEWithLogitsLoss()
    criterion_duration = torch.nn.MSELoss()
    criterion_joint = torch.nn.BCEWithLogitsLoss()

    data = data.to('cpu')

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        embeddings = model(data.x_dict, data.edge_index_dict)
        person_emb = embeddings["person"]
        purpose_emb = embeddings["purpose"]

        # Positive edges
        edge_index = data["person", "performs", "purpose"].edge_index
        pos_person_emb = person_emb[edge_index[0]]
        pos_purpose_emb = purpose_emb[edge_index[1]]

        # Negative Sampling
        num_neg = edge_index.size(1)
        neg_person_idx = torch.randint(0, person_emb.size(0), (num_neg,))
        neg_purpose_idx = torch.randint(0, purpose_emb.size(0), (num_neg,))
        neg_person_emb = person_emb[neg_person_idx]
        neg_purpose_emb = purpose_emb[neg_purpose_idx]

        # Predictions
        pos_edge_preds = model.predict_edges(pos_person_emb, pos_purpose_emb)
        neg_edge_preds = model.predict_edges(neg_person_emb, neg_purpose_emb)

        duration_preds = model.predict_duration(pos_person_emb, pos_purpose_emb)
        joint_preds = model.predict_joint(pos_person_emb, pos_purpose_emb)

        # Labels (normalized)
        duration_labels = data["person", "performs", "purpose"].duration / max_duration
        edge_labels = torch.cat([torch.ones_like(pos_edge_preds), torch.zeros_like(neg_edge_preds)])
        edge_preds = torch.cat([pos_edge_preds, neg_edge_preds])

        joint_labels = data["person", "performs", "purpose"].joint_activity.float()

        # Losses
        loss_edge = criterion_edge(edge_preds, edge_labels)
        loss_duration = criterion_duration(duration_preds, duration_labels)
        loss_joint = criterion_joint(joint_preds, joint_labels)

        loss = loss_edge + loss_duration + loss_joint
        loss.backward()
        optimizer.step()

        if epoch % 5 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch}: Total Loss={loss.item():.4f}, "
                  f"Edge={loss_edge.item():.4f}, Duration={loss_duration.item():.4f}, Joint={loss_joint.item():.4f}")

    return model


In [6]:
model = train_model(model, data)

Epoch 0: Total Loss=61.0375, Edge=2.8907, Duration=54.2714, Joint=3.8754
Epoch 5: Total Loss=11.4663, Edge=1.4612, Duration=8.8814, Joint=1.1238
Epoch 10: Total Loss=4.3246, Edge=0.9068, Duration=2.7847, Joint=0.6332
Epoch 15: Total Loss=2.1321, Edge=0.9215, Duration=0.4432, Joint=0.7674
Epoch 20: Total Loss=2.3424, Edge=0.8148, Duration=0.8653, Joint=0.6623
Epoch 25: Total Loss=1.8060, Edge=0.7393, Duration=0.4966, Joint=0.5702
Epoch 30: Total Loss=1.4051, Edge=0.6748, Duration=0.1732, Joint=0.5572
Epoch 35: Total Loss=1.2106, Edge=0.5797, Duration=0.0841, Joint=0.5468
Epoch 40: Total Loss=1.2107, Edge=0.6177, Duration=0.0626, Joint=0.5304
Epoch 45: Total Loss=1.1282, Edge=0.5694, Duration=0.0410, Joint=0.5178
Epoch 49: Total Loss=1.1191, Edge=0.5774, Duration=0.0377, Joint=0.5040


In [7]:
def infer(model, data, max_duration=1440): # just to test
    model.eval()
    with torch.no_grad():
        embeddings = model(data.x_dict, data.edge_index_dict)
        person_emb = embeddings["person"]
        purpose_emb = embeddings["purpose"]

        edge_index = data["person", "performs", "purpose"].edge_index
        pos_person_emb = person_emb[edge_index[0]]
        pos_purpose_emb = purpose_emb[edge_index[1]]

        edge_probs = torch.sigmoid(model.predict_edges(pos_person_emb, pos_purpose_emb))
        predicted_durations = model.predict_duration(pos_person_emb, pos_purpose_emb) * max_duration
        joint_probs = torch.sigmoid(model.predict_joint(pos_person_emb, pos_purpose_emb))

    return edge_probs, predicted_durations, joint_probs


In [8]:
# Run inference
edge_probs, durations, joint_probs = infer(model, data)

print("Edge Probabilities (first 5):", edge_probs[:5])
print("Predicted Durations (min, first 5):", durations[:5])
print("Joint Activity Probabilities (first 5):", joint_probs[:5])

Edge Probabilities (first 5): tensor([0.5202, 0.5638, 0.5115, 0.5202, 0.7503])
Predicted Durations (min, first 5): tensor([293.7080, 561.0504, 255.6238, 293.7080, -51.8185])
Joint Activity Probabilities (first 5): tensor([0.7231, 0.6913, 0.7146, 0.7231, 0.8327])


In [9]:
# save model
torch.save({
    "model_state_dict": model.state_dict(),
    "hidden_channels": 64,  # Any hyperparameters you may have
}, "data/trained_travel_gnn_model.pt")