In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GATConv
from torch_geometric.data import Data, Batch
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder

In [2]:
### Create new mock data ###

# Household Data
households_data = {
    "household_id": [1, 2, 3],
    "income_level": ["low", "medium", "high"],
    "num_rooms": [2, 3, 5],
    "household_size": [3, 4, 2]
}
households_df = pd.DataFrame(households_data)

# Person Data
persons_data = {
    "person_id": [1, 2, 3, 4, 5, 6, 7, 8, 9],
    "household_id": [1, 1, 1, 2, 2, 2, 2, 3, 3],  # Mapping to households
    "age": [35, 33, 5, 40, 38, 15, 10, 50, 48],
    "gender": ["M", "F", "M", "M", "F", "M", "F", "M", "F"],
    "role": ["Parent", "Parent", "Child", "Parent", "Parent", "Child", "Child", "Parent", "Parent"]
}
persons_df = pd.DataFrame(persons_data)

print("Households Data:")
print(households_df)

print("\nPersons Data:")
print(persons_df)
    

Households Data:
   household_id income_level  num_rooms  household_size
0             1          low          2               3
1             2       medium          3               4
2             3         high          5               2

Persons Data:
   person_id  household_id  age gender    role
0          1             1   35      M  Parent
1          2             1   33      F  Parent
2          3             1    5      M   Child
3          4             2   40      M  Parent
4          5             2   38      F  Parent
5          6             2   15      M   Child
6          7             2   10      F   Child
7          8             3   50      M  Parent
8          9             3   48      F  Parent


In [3]:


# ==== 1. Initialize Encoders ====
age_scaler = StandardScaler()
room_scaler = MinMaxScaler()
income_encoder = OneHotEncoder(sparse_output=False)
gender_encoder = OneHotEncoder(sparse_output=False)
# role_encoder = OneHotEncoder(sparse_output=False)

# ==== 2. Transform Household Data ====
households_df["num_rooms"] = room_scaler.fit_transform(households_df[["num_rooms"]])
income_encoded = income_encoder.fit_transform(households_df[["income_level"]])

# Convert encoded income to DataFrame
income_columns = income_encoder.get_feature_names_out(["income_level"])
households_encoded_df = households_df.drop(columns=["income_level"])
households_encoded_df[income_columns] = income_encoded

# ==== 3. Transform Person Data ====
persons_df["age"] = age_scaler.fit_transform(persons_df[["age"]])
gender_encoded = gender_encoder.fit_transform(persons_df[["gender"]])
# role_encoded = role_encoder.fit_transform(persons_df[["role"]])

# Convert encoded gender & role to DataFrame
gender_columns = gender_encoder.get_feature_names_out(["gender"])
# role_columns = role_encoder.get_feature_names_out(["role"])
persons_encoded_df = persons_df.drop(columns=["gender", "role"])
persons_encoded_df[gender_columns] = gender_encoded
# persons_encoded_df[role_columns] = role_encoded
persons_encoded_df["role"] = persons_df["role"]

# ==== 4. Display Encoded Data ====
print("\nTransformed Households Data:")
print(households_encoded_df)

print("\nTransformed Persons Data:")
print(persons_encoded_df)



Transformed Households Data:
   household_id  num_rooms  household_size  income_level_high  \
0             1   0.000000               3                0.0   
1             2   0.333333               4                0.0   
2             3   1.000000               2                1.0   

   income_level_low  income_level_medium  
0               1.0                  0.0  
1               0.0                  1.0  
2               0.0                  0.0  

Transformed Persons Data:
   person_id  household_id       age  gender_F  gender_M    role
0          1             1  0.293366       0.0       1.0  Parent
1          2             1  0.164571       1.0       0.0  Parent
2          3             1 -1.638559       0.0       1.0   Child
3          4             2  0.615354       0.0       1.0  Parent
4          5             2  0.486559       1.0       0.0  Parent
5          6             2 -0.994584       0.0       1.0   Child
6          7             2 -1.316572       1.0       0.

In [4]:
# ==== 2. Create Graphs for Each Household ====
household_graphs = []
relationship_types = ["parent-child", "spouse", "sibling"]
relationship_encoder = OneHotEncoder(sparse_output=False)
relationship_encoded = relationship_encoder.fit_transform([[rel] for rel in relationship_types])
relationship_mapping = {rel: torch.tensor(relationship_encoded[i], dtype=torch.float) for i, rel in enumerate(relationship_types)}

for household_id in households_encoded_df["household_id"].unique():
    # Get household attributes
    household_attr = households_encoded_df[households_encoded_df["household_id"] == household_id].drop(columns=["household_id", "household_size"]).values
    household_attr = torch.tensor(household_attr, dtype=torch.float)

    # Get people in the household
    household_people = persons_encoded_df[persons_encoded_df["household_id"] == household_id]
    node_features = torch.tensor(household_people.drop(columns=["household_id", "person_id", "role"]).values, dtype=torch.float)

    # Assign unique node indices
    person_ids = household_people["person_id"].tolist()
    person_index_map = {pid: i for i, pid in enumerate(person_ids)}

    mapping_id_role = dict(zip(household_people["person_id"], household_people["role"]))
    mapping_indx_role = {person_index_map[key]: val for key, val in mapping_id_role.items()}

    # Define edges based on relationships (Mock Example)
    edge_index_list = []
    edge_attr_list = []
    for i in range(len(person_ids)):  # Simple pairwise connection (replace with real relationships)
        for j in range(len(person_ids)):
          if i != j:
            edge_index_list.append([person_index_map[person_ids[i]], person_index_map[person_ids[j]]])
            if mapping_indx_role[i] == "Parent" and mapping_indx_role[j] == "Child":
              edge_attr_list.append(relationship_mapping["parent-child"])
            elif mapping_indx_role[i] == "Child" and mapping_indx_role[j] == "Parent":
              edge_attr_list.append(relationship_mapping["parent-child"])
            elif mapping_indx_role[i] == "Parent" and mapping_indx_role[j] == "Parent":
              edge_attr_list.append(relationship_mapping["spouse"])
            elif mapping_indx_role[i] == "Child" and mapping_indx_role[j] == "Child":
              edge_attr_list.append(relationship_mapping["sibling"])
            else:
              raise ValueError("something is wrong")

    # Convert edge lists to tensors
    edge_index = torch.tensor(edge_index_list, dtype=torch.long).T if edge_index_list else torch.empty((2, 0), dtype=torch.long)
    edge_attr = torch.stack(edge_attr_list) if edge_attr_list else torch.empty((0, len(relationship_types)), dtype=torch.float)

    # Broadcast global attributes to all nodes
    global_attr = household_attr.repeat(node_features.size(0), 1)

    # Create PyG Data object
    graph_data = Data(x=torch.cat([node_features, global_attr], dim=1), edge_index=edge_index, edge_attr=edge_attr)
    household_graphs.append(graph_data)

# ==== 3. Batch All Graphs Together ====
batched_graph = Batch.from_data_list(household_graphs)

# ==== 4. Print Graph Information ====
print(batched_graph)


DataBatch(x=[9, 7], edge_index=[2, 20], edge_attr=[20, 3], batch=[9], ptr=[4])


In [24]:
batched_graph.x

tensor([[ 0.2934,  0.0000,  1.0000,  0.0000,  0.0000,  1.0000,  0.0000],
        [ 0.1646,  1.0000,  0.0000,  0.0000,  0.0000,  1.0000,  0.0000],
        [-1.6386,  0.0000,  1.0000,  0.0000,  0.0000,  1.0000,  0.0000],
        [ 0.6154,  0.0000,  1.0000,  0.3333,  0.0000,  0.0000,  1.0000],
        [ 0.4866,  1.0000,  0.0000,  0.3333,  0.0000,  0.0000,  1.0000],
        [-0.9946,  0.0000,  1.0000,  0.3333,  0.0000,  0.0000,  1.0000],
        [-1.3166,  1.0000,  0.0000,  0.3333,  0.0000,  0.0000,  1.0000],
        [ 1.2593,  0.0000,  1.0000,  1.0000,  1.0000,  0.0000,  0.0000],
        [ 1.1305,  1.0000,  0.0000,  1.0000,  1.0000,  0.0000,  0.0000]])

In [23]:
household_attr_dim = 4  # Number of household attributes (last 4 columns)
node_attr_dim = batched_graph.x.shape[1] - household_attr_dim  # Remaining columns are node attributes
print(household_attr_dim, node_attr_dim)

4 3


In [32]:
torch.zeros(1,1)

tensor([[0.]])

In [None]:
def split_pp_att_hh_att(batched_graph, household_attr_dim, node_attr_dim):
    # Extract node attributes (excluding household attributes)
    node_attr = batched_graph.x[:, :node_attr_dim]  # [num_nodes, node_attr_dim]
    # node_attr = torch.cat((node_attr.T, torch.zeros(1,node_attr.shape[0]))).T

    # Extract household attributes
    household_attr = batched_graph.x[:, -household_attr_dim:]  # [num_nodes, household_attr_dim]
    return node_attr, household_attr
split_pp_att_hh_att(batched_graph, household_attr_dim, node_attr_dim)

(tensor([[ 0.2934,  0.0000,  1.0000],
         [ 0.1646,  1.0000,  0.0000],
         [-1.6386,  0.0000,  1.0000],
         [ 0.6154,  0.0000,  1.0000],
         [ 0.4866,  1.0000,  0.0000],
         [-0.9946,  0.0000,  1.0000],
         [-1.3166,  1.0000,  0.0000],
         [ 1.2593,  0.0000,  1.0000],
         [ 1.1305,  1.0000,  0.0000]]),
 tensor([[0.0000, 0.0000, 1.0000, 0.0000],
         [0.0000, 0.0000, 1.0000, 0.0000],
         [0.0000, 0.0000, 1.0000, 0.0000],
         [0.3333, 0.0000, 0.0000, 1.0000],
         [0.3333, 0.0000, 0.0000, 1.0000],
         [0.3333, 0.0000, 0.0000, 1.0000],
         [0.3333, 0.0000, 0.0000, 1.0000],
         [1.0000, 1.0000, 0.0000, 0.0000],
         [1.0000, 1.0000, 0.0000, 0.0000]]))

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [87]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GATConv
from torch_geometric.data import Data, Batch
import random

# ==== 1. DEFINE UPDATED GAT MODEL ====
class GATHouseholdModel(nn.Module):
    def __init__(self, given_dim, target_dim, hidden_dim, num_heads, edge_dim):
        super().__init__()
        self.gat1 = GATConv(target_dim, hidden_dim, heads=num_heads, concat=True)
        self.gat2 = GATConv(hidden_dim * num_heads, given_dim, heads=1, concat=False)
        
        # Edge attribute predictor (only takes node embeddings)
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * given_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, edge_dim)
        )

    def forward(self, x, edge_index):
        x = self.gat1(x, edge_index).relu()
        x = self.gat2(x, edge_index)
        
        # Edge prediction using node embeddings
        row, col = edge_index
        edge_input = torch.cat([x[row], x[col]], dim=1)
        edge_pred = self.edge_mlp(edge_input)
        
        return x, edge_pred

# ==== 2. SPLIT DATA INTO TRAINING & CRITIC ====
def split_data(graphs, critic_ratio=0.2):
    random.shuffle(graphs)
    split_idx = int(len(graphs) * (1 - critic_ratio))
    train_graphs = graphs[:split_idx]
    critic_graphs = graphs[split_idx:]
    return train_graphs, critic_graphs


In [89]:

# Example: Assume we have a list of household graphs
train_graphs, critic_graphs = split_data(batched_graph.to_data_list())

# ==== 3. TRAINING LOOP ====
model = GATHouseholdModel(given_dim=3, hidden_dim=8, target_dim=4, num_heads=4, edge_dim=3).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

batched_train = Batch.from_data_list(train_graphs).to(device)
batched_critic = Batch.from_data_list(critic_graphs).to(device)

train_pp_atts, train_hh_atts  = split_pp_att_hh_att(batched_train, household_attr_dim, node_attr_dim)
critic_pp_atts, critic_hh_atts = split_pp_att_hh_att(batched_critic, household_attr_dim, node_attr_dim)

for epoch in range(10000):
    model.train()
    optimizer.zero_grad()
    node_pred, edge_pred = model(train_hh_atts, batched_train.edge_index)
    loss = loss_fn(node_pred, train_pp_atts) + loss_fn(edge_pred, batched_train.edge_attr)
    loss.backward()
    optimizer.step()
    
    if epoch % 50 == 0:
        model.eval()
        with torch.no_grad():
            node_pred_critic, edge_pred_critic = model(critic_hh_atts, batched_critic.edge_index)
            critic_loss = loss_fn(node_pred_critic, critic_pp_atts) + loss_fn(edge_pred_critic, batched_critic.edge_attr)
        print(f"Epoch {epoch}: Train Loss = {loss.item():.4f}, Critic Loss = {critic_loss.item():.4f}")


Epoch 0: Train Loss = 1.2130, Critic Loss = 1.2338
Epoch 50: Train Loss = 0.8326, Critic Loss = 1.0260
Epoch 100: Train Loss = 0.5679, Critic Loss = 0.8624
Epoch 150: Train Loss = 0.4969, Critic Loss = 0.8045
Epoch 200: Train Loss = 0.4795, Critic Loss = 0.8078
Epoch 250: Train Loss = 0.4758, Critic Loss = 0.8155
Epoch 300: Train Loss = 0.4752, Critic Loss = 0.8189
Epoch 350: Train Loss = 0.4751, Critic Loss = 0.8196
Epoch 400: Train Loss = 0.4751, Critic Loss = 0.8198
Epoch 450: Train Loss = 0.4751, Critic Loss = 0.8198
Epoch 500: Train Loss = 0.4751, Critic Loss = 0.8198
Epoch 550: Train Loss = 0.4751, Critic Loss = 0.8198
Epoch 600: Train Loss = 0.4751, Critic Loss = 0.8200
Epoch 650: Train Loss = 0.4751, Critic Loss = 0.8198
Epoch 700: Train Loss = 0.4751, Critic Loss = 0.8198
Epoch 750: Train Loss = 0.4751, Critic Loss = 0.8198
Epoch 800: Train Loss = 0.4751, Critic Loss = 0.8198
Epoch 850: Train Loss = 0.4751, Critic Loss = 0.8198
Epoch 900: Train Loss = 0.4751, Critic Loss = 0.8

In [90]:
first_graph = batched_graph.to_data_list()[0]  # Get the first graph

# Extract number of nodes in the first graph
num_nodes_first_graph = first_graph.x.shape[0]

# Extract only household attributes and broadcast them to all nodes
household_attr = first_graph.x[0, -4:].unsqueeze(0).repeat(num_nodes_first_graph, 1)

# Create a new graph with only the structure and household attributes
graph_for_test = Data(
    x=household_attr,  # Use only household attributes, replicated for all nodes
    edge_index=first_graph.edge_index  # Keep edges
)

print(graph_for_test)


Data(x=[3, 4], edge_index=[2, 6])


In [91]:
# ==== 4. INFERENCE ====
model.eval()
with torch.no_grad():
    node_pred, edge_pred = model(graph_for_test.x, graph_for_test.edge_index)
print("Node Predictions:\n", node_pred)
print("Edge Predictions:\n", edge_pred)

Node Predictions:
 tensor([[0.2573, 0.4494, 0.2307],
        [0.2573, 0.4494, 0.2307],
        [0.2573, 0.4494, 0.2307]])
Edge Predictions:
 tensor([[0.2737, 0.0573, 0.4204],
        [0.2737, 0.0573, 0.4204],
        [0.2737, 0.0573, 0.4204],
        [0.2737, 0.0573, 0.4204],
        [0.2737, 0.0573, 0.4204],
        [0.2737, 0.0573, 0.4204]])
