In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GATConv

# Seed torch's global RNG so the random node features (and anything sampled
# afterwards, e.g. model initialisation) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

NUM_NODES = 5      # number of rooms (graph nodes)
NUM_FEATURES = 16  # features per room

# Undirected path graph 0-1-2-3-4 between rooms: every edge is listed in both
# directions. COO layout: row 0 holds source node indices, row 1 target indices.
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3, 3, 4],
    [1, 0, 2, 1, 3, 2, 4, 3]
], dtype=torch.long)

# Placeholder room features (random for now; real ones might encode room type, size, ...).
x = torch.randn(NUM_NODES, NUM_FEATURES)

# Sanity check: every edge endpoint must refer to an existing node.
assert int(edge_index.max()) < NUM_NODES, "edge_index references a node that does not exist"

# Bundle features and connectivity into a PyG graph object.
data = Data(x=x, edge_index=edge_index)


In [None]:
class GraphGenerator(nn.Module):
    """Single GAT layer, mean pooling over nodes, and a linear head producing one value.

    The GAT layer concatenates its attention heads, so each node embedding has
    ``out_channels * heads`` features. All node embeddings are averaged into one
    graph-level vector, which a linear layer maps to a single output.

    Args:
        in_channels: Number of input features per node.
        out_channels: Output features per attention head.
        heads: Number of attention heads (their outputs are concatenated).
    """

    def __init__(self, in_channels, out_channels, heads=4):
        super().__init__()
        # Multi-head graph attention over each node's neighbourhood; concat=True
        # yields out_channels * heads features per node.
        self.gat = GATConv(in_channels, out_channels, heads=heads, concat=True)
        # Regression head: pooled graph embedding -> one scalar.
        self.fc = nn.Linear(out_channels * heads, 1)

    def forward(self, data):
        """Predict a single value for the graph held in ``data``.

        Args:
            data: Graph object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            Tensor with one element, assuming ``GATConv`` returns a
            ``(num_nodes, out_channels * heads)`` matrix.
        """
        node_embeddings = self.gat(data.x, data.edge_index)

        # Mean-pool over ALL nodes (dim 0). NOTE: if ``data`` is a batch of several
        # graphs, they are collapsed into one vector; use per-graph pooling
        # (e.g. global_mean_pool with data.batch) before batching.
        graph_embedding = node_embeddings.mean(dim=0)

        return self.fc(graph_embedding)


In [None]:
# Training hyperparameters.
NUM_EPOCHS = 100
LEARNING_RATE = 0.01
LOG_EVERY = 10  # print the loss every LOG_EVERY epochs (plus the final epoch)

# in_channels is read from the data so model and features cannot drift apart.
model = GraphGenerator(in_channels=data.num_node_features, out_channels=32, heads=4)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Placeholder regression target, drawn ONCE so every epoch optimises toward the
# same value; re-sampling it each epoch would make the loss pure noise.
# TODO: replace with the actual target node features.
target = torch.randn(1)

model.train()
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()

    # Forward pass: one prediction for the whole graph.
    output = model(data)

    # Mean-squared-error regression toward the fixed target.
    loss = F.mse_loss(output, target)

    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
