In [1]:
%load_ext autoreload
%autoreload 2

In [22]:
import torch
import pandas as pd
from PopSynthesis.Methods.GNN_activity.model import convert_to_temporal_data, GraphAttentionEmbedding, LinkPredictor
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator, LastNeighborLoader
from torch_geometric.loader import TemporalDataLoader

In [23]:
# Load the starting and expected graphs
data = torch.load("data/train_graph.pt")

In [24]:
# Convert PyG HeteroData to TemporalData
train_data = convert_to_temporal_data(data)

In [25]:
to_predict = torch.load("data/to_predict_graph.pt")
test_data = convert_to_temporal_data(to_predict)

In [26]:
# === Model Configuration ===
# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

memory_dim = time_dim = embedding_dim = 100

memory = TGNMemory(
    # num_nodes=train_data.num_nodes,
    # raw_msg_dim=train_data.msg.size(-1),
    num_nodes=test_data.num_nodes,
    raw_msg_dim=test_data.msg.size(-1),
    memory_dim=memory_dim,
    time_dim=time_dim,
    # message_module=IdentityMessage(train_data.msg.size(-1), memory_dim, time_dim),
    message_module=IdentityMessage(test_data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,
    out_channels=embedding_dim,
    # msg_dim=train_data.msg.size(-1),
    msg_dim=test_data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

link_pred = LinkPredictor(in_channels=embedding_dim).to(device)

# === Optimizer & Loss ===
optimizer = torch.optim.Adam(
    set(memory.parameters()) | set(gnn.parameters()) | set(link_pred.parameters()), lr=0.0001)
criterion_edge = torch.nn.BCEWithLogitsLoss()
criterion_joint = torch.nn.BCEWithLogitsLoss()

In [27]:
# === DataLoader ===
train_loader = TemporalDataLoader(
    train_data,
    batch_size=200,
    neg_sampling_ratio=1.0,
)

neighbor_loader = LastNeighborLoader(test_data.num_nodes, size=10, device=device)

# === Training State Variables ===
assoc = torch.empty(train_data.num_nodes, dtype=torch.long, device=device)

In [28]:
def train():
    memory.train()
    gnn.train()
    link_pred.train()

    memory.reset_state()  # Fresh memory
    neighbor_loader.reset_state()  # Empty graph at start

    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device)

        # Neighborhood sampling
        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Update memory
        z, last_update = memory(n_id)
        z = gnn(z, last_update, edge_index, train_data.t[e_id].to(device), train_data.msg[e_id].to(device))

        # Positive & Negative samples
        pos_edge_pred, pos_joint_pred = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        neg_edge_pred, neg_joint_pred = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])

        # Labels
        edge_labels = torch.cat([torch.ones_like(pos_edge_pred), torch.zeros_like(neg_edge_pred)])
        edge_preds = torch.cat([pos_edge_pred, neg_edge_pred])

        # Extract `joint_activity` from graph data
        joint_labels = batch.msg[:, 0].view(-1, 1) # Joint is first
        # Loss Calculation
        loss_edge = criterion_edge(edge_preds, edge_labels)
        loss_joint = criterion_joint(pos_joint_pred, joint_labels)  # Apply only on positive edges

        loss = loss_edge + loss_joint  # Total loss

        # Memory Update
        memory.update_state(batch.src.long(), batch.dst.long(), batch.t.float(), batch.msg.float())
        neighbor_loader.insert(batch.src, batch.dst)

        # Backpropagation
        loss.backward()
        optimizer.step()
        memory.detach()
        total_loss += float(loss) * batch.num_events

    return total_loss / train_data.num_events


In [29]:
# === Run Training ===
for epoch in range(1, 101):
    loss = train()
    print(f"Epoch: {epoch:02d}, Loss: {loss:.4f}")

Epoch: 01, Loss: 1.3964
Epoch: 02, Loss: 1.3932
Epoch: 03, Loss: 1.3901
Epoch: 04, Loss: 1.3869
Epoch: 05, Loss: 1.3838
Epoch: 06, Loss: 1.3807
Epoch: 07, Loss: 1.3776
Epoch: 08, Loss: 1.3745
Epoch: 09, Loss: 1.3715
Epoch: 10, Loss: 1.3684
Epoch: 11, Loss: 1.3654
Epoch: 12, Loss: 1.3625
Epoch: 13, Loss: 1.3595
Epoch: 14, Loss: 1.3565
Epoch: 15, Loss: 1.3535
Epoch: 16, Loss: 1.3505
Epoch: 17, Loss: 1.3476
Epoch: 18, Loss: 1.3447
Epoch: 19, Loss: 1.3418
Epoch: 20, Loss: 1.3391
Epoch: 21, Loss: 1.3364
Epoch: 22, Loss: 1.3337
Epoch: 23, Loss: 1.3310
Epoch: 24, Loss: 1.3283
Epoch: 25, Loss: 1.3256
Epoch: 26, Loss: 1.3229
Epoch: 27, Loss: 1.3201
Epoch: 28, Loss: 1.3174
Epoch: 29, Loss: 1.3148
Epoch: 30, Loss: 1.3121
Epoch: 31, Loss: 1.3094
Epoch: 32, Loss: 1.3066
Epoch: 33, Loss: 1.3038
Epoch: 34, Loss: 1.3010
Epoch: 35, Loss: 1.2983
Epoch: 36, Loss: 1.2956
Epoch: 37, Loss: 1.2928
Epoch: 38, Loss: 1.2900
Epoch: 39, Loss: 1.2872
Epoch: 40, Loss: 1.2844
Epoch: 41, Loss: 1.2816
Epoch: 42, Loss:

In [69]:
@torch.no_grad()
def predict_daily_schedule(test_data, memory, gnn, link_pred, neighbor_loader, max_steps=10):
    """
    Simulates daily activity schedules for all persons in the test data.
    
    Args:
        test_data (TemporalData): The PyG TemporalData graph containing all nodes and edges.
        memory (TGNMemory): The trained memory module from TGN.
        gnn (GraphAttentionEmbedding): The trained GNN for message passing.
        link_pred (LinkPredictor): The trained link prediction model.
        neighbor_loader (LastNeighborLoader): Loader for retrieving historical interactions.
        max_steps (int): Maximum number of activity steps to predict.
    
    Returns:
        predicted_schedules (dict): A dictionary mapping each person ID to their predicted activities.
    """
    memory.eval()
    gnn.eval()
    link_pred.eval()

    memory.reset_state()  # Reset memory
    neighbor_loader.reset_state()  # Reset neighbor tracking

    predicted_schedules = {}  # Store results per person

    unique_persons = torch.unique(test_data.src[test_data.src >= 0]).tolist()  # Extract only person nodes

    for person_id in unique_persons:
        person_id = int(person_id)
        schedule = []
        current_time = 0  # Start at t=0 (home activity)

        for step in range(max_steps):
            if step == 0:
                # First step: Only the person's own memory is used
                n_id = torch.tensor([person_id], device=memory.memory.device)
            else:
                # Later steps: Retrieve from neighbor loader
                n_id, edge_index, e_id = neighbor_loader(torch.tensor([person_id], device=memory.memory.device))
            
            # Retrieve updated memory
            z, last_update = memory(n_id)

            # Get updated embeddings
            if step > 0:  # Only do neighbor lookup after first step
                assoc = torch.arange(n_id.size(0), device=z.device)
                z_new = gnn(z, last_update, edge_index, test_data.t[e_id], test_data.msg[e_id])
            else:
                z_new = z  # No edges yet, use initial memory

            # Predict next activity (Link Prediction)
            link_scores_pred, joint_scores_pred = link_pred(z_new[0].repeat(z_new.size(0), 1), z_new)
            predicted_activity = torch.multinomial(torch.softmax(link_scores_pred, dim=0)[0], 1).item()

            # Predict joint activity
            joint_activity = torch.sigmoid(joint_scores_pred[predicted_activity]).item()  # Threshold at 0.5
            joint_activity = 1 if joint_activity > 0.5 else 0
            
            # Predict timestamp using last_update
            next_time = last_update[0] + torch.sigmoid(link_scores_pred.max()).item() * 1440

            # Save prediction step
            schedule.append({
                "person_id": person_id,
                "activity_id": predicted_activity,
                "start_time": current_time,
                "end_time": next_time,
                "joint_activity": int(joint_activity)
            })

            # Update memory with the predicted activity
            memory.update_state(torch.tensor([person_id]), torch.tensor([predicted_activity]), torch.tensor([next_time]), test_data.msg)
            neighbor_loader.insert(torch.tensor([person_id]), torch.tensor([predicted_activity]))

            # Move to the next predicted timestamp
            current_time = next_time

            # Stop if we reach the end of the day (1440 minutes)
            if current_time >= 1440:
                break
        
        # Store the full predicted schedule for the person
        predicted_schedules[person_id] = schedule

    return predicted_schedules


In [70]:
# # === DataLoader ===
# test_loader = TemporalDataLoader(
#     test_data,
#     batch_size=200,
#     neg_sampling_ratio=1.0,
# )
predict_daily_schedule(test_data, memory, gnn, link_pred, neighbor_loader)

{0: [{'person_id': 0,
   'activity_id': 0,
   'start_time': 0,
   'end_time': tensor(718.4775),
   'joint_activity': 1},
  {'person_id': 0,
   'activity_id': 0,
   'start_time': tensor(718.4775),
   'end_time': tensor(1496.0037),
   'joint_activity': 1}],
 1: [{'person_id': 1,
   'activity_id': 0,
   'start_time': 0,
   'end_time': tensor(718.4775),
   'joint_activity': 1},
  {'person_id': 1,
   'activity_id': 0,
   'start_time': tensor(718.4775),
   'end_time': tensor(2233.1782),
   'joint_activity': 1}],
 2: [{'person_id': 2,
   'activity_id': 0,
   'start_time': 0,
   'end_time': tensor(718.4775),
   'joint_activity': 1},
  {'person_id': 2,
   'activity_id': 0,
   'start_time': tensor(718.4775),
   'end_time': tensor(2215.1816),
   'joint_activity': 1}],
 3: [{'person_id': 3,
   'activity_id': 0,
   'start_time': 0,
   'end_time': tensor(718.4775),
   'joint_activity': 1},
  {'person_id': 3,
   'activity_id': 0,
   'start_time': tensor(718.4775),
   'end_time': tensor(2264.0215),
  

In [71]:


# Print the first 5 predicted schedules
# for person_id, schedule in list(predicted_schedules.items())[:5]:
#     print(f"\nPerson {person_id} Schedule:")
#     for activity in schedule:
#         print(f"  Step {activity['step']}: Purpose {activity['predicted_purpose']} at {activity['predicted_time']} min, Joint: {activity['joint_activity']}")
