In [1]:
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
from torch_geometric.utils import remove_self_loops
import networkx as nx
import matplotlib.pyplot as plt
from dsc2024 import datasets


# Function to generate graph from DataFrame
def generate_multigraph(df: pd.DataFrame):
    """Build a directed multigraph with one edge per row of ``df``.

    Every row becomes an edge from its ``origem`` value to its ``destino``
    value. The edge key is the 0-based position of the row in ``df``, and the
    row's ``espera`` and ``flightid`` values are stored as edge attributes.

    Rows are read positionally, so ``df`` does not need a default
    ``RangeIndex``; a filtered, shuffled or re-indexed frame works as well.

    Args:
        df: Frame providing the ``origem``, ``destino``, ``espera`` and
            ``flightid`` columns.

    Returns:
        The populated ``nx.MultiDiGraph``.
    """
    G = nx.MultiDiGraph()

    # Walk the four columns in lockstep instead of looking rows up with
    # ``df.loc[i, ...]``: ``enumerate`` yields positions, whereas ``.loc``
    # expects index labels, and the two only coincide for a RangeIndex.
    rows = zip(df["origem"], df["destino"], df["espera"], df["flightid"])
    for position, (origem, destino, espera, flightid) in enumerate(rows):
        G.add_edge(origem, destino, key=position, espera=espera, flightid=flightid)

    return G

# Function to prepare the PyTorch Geometric data structure from NetworkX multigraph
def prepare_data_from_multigraph(G: nx.MultiDiGraph, df: pd.DataFrame, edge_features: list):
    """Convert a NetworkX multigraph into a PyTorch Geometric ``Data`` object.

    Args:
        G: Multigraph whose edges carry the attributes named in
            ``edge_features`` (and normally an ``espera`` attribute).
        df: Source frame with an ``espera`` column. Only used as a fallback
            for the targets when some edge has no ``espera`` attribute.
        edge_features: Names of the edge attributes to stack into ``edge_attr``.

    Returns:
        ``Data`` with:
          * ``x``: random placeholder node features, 3 per node;
          * ``edge_index``: ``(2, num_edges)`` long tensor of node positions;
          * ``edge_attr``: ``(num_edges, len(edge_features))`` float tensor;
          * ``y``: ``(num_edges, 1)`` float tensor of ``espera`` targets, in
            the same order as the columns of ``edge_index``.

    Raises:
        KeyError: If an edge lacks one of the requested ``edge_features``.
    """
    node_map = {node: i for i, node in enumerate(G.nodes())}

    # Node features (can be modified as needed): random placeholders.
    node_features = torch.randn((len(node_map), 3))

    edge_pairs = []
    edge_rows = []
    targets = []
    for u, v, _key, data in G.edges(keys=True, data=True):
        edge_pairs.append([node_map[u], node_map[v]])
        edge_rows.append([data[feat] for feat in edge_features])
        # Collect the target in the same pass so that y[i] always describes
        # the edge stored in column i of edge_index.
        targets.append(data.get("espera"))

    num_edges = len(edge_pairs)
    # Explicit reshapes keep the tensors 2-D even when the graph has no edges.
    edge_index = (
        torch.tensor(edge_pairs, dtype=torch.long).reshape(num_edges, 2).t().contiguous()
    )
    edge_attr = torch.tensor(edge_rows, dtype=torch.float).reshape(
        num_edges, len(edge_features)
    )

    if any(target is None for target in targets):
        # Some edge has no 'espera' attribute: fall back to the DataFrame
        # column. NOTE: this is in row order, which need not match the edge
        # iteration order of G.
        y = torch.tensor(df['espera'].values, dtype=torch.float).unsqueeze(1)
    else:
        y = torch.tensor(targets, dtype=torch.float).unsqueeze(1)

    return Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr, y=y)

# Define the GNN model using GAT
class FlightGNNWithGAT(torch.nn.Module):
    """Edge-level regressor built from two GAT layers and a linear head.

    Node features pass through a plain ``GATConv`` and then through a second
    ``GATConv`` that also receives the edge attributes (``edge_dim``). For
    every edge, the embeddings of its two endpoints are concatenated with the
    raw edge attributes and mapped to ``out_channels`` values.

    Args:
        node_in_channels: Number of input features per node.
        edge_in_channels: Number of features per edge.
        hidden_channels: Width of the node embeddings produced by both GAT
            layers (``concat=False``, so the width does not grow with ``heads``).
        out_channels: Number of values predicted per edge.
        heads: Number of attention heads in each GAT layer.
    """

    def __init__(self, node_in_channels, edge_in_channels, hidden_channels, out_channels, heads=4):
        super().__init__()

        # Node embedding using GATConv with multiple attention heads
        self.node_embedding = GATConv(node_in_channels, hidden_channels, heads=heads, concat=False)

        # Edge-aware attention layer: edge attributes enter the attention scores
        self.edge_attention = GATConv(hidden_channels, hidden_channels, heads=heads, concat=False, edge_dim=edge_in_channels)

        # Final edge regression (concatenating source, destination, and edge features)
        self.regressor = nn.Linear(2 * hidden_channels + edge_in_channels, out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Predict one output row per edge.

        Args:
            x: Node feature matrix.
            edge_index: ``(2, num_edges)`` tensor; row 0 holds source node
                positions and row 1 destination node positions.
            edge_attr: Edge feature matrix with one row per edge.

        Returns:
            Output of the linear head, one row per edge.
        """
        # Step 1: Node embedding using GATConv
        x = self.node_embedding(x, edge_index)

        # Step 2: Refine the node embeddings with edge-aware attention.
        # The result must be kept: discarding it leaves this layer without
        # any influence on the prediction and without gradients.
        x = self.edge_attention(x, edge_index, edge_attr)

        # Step 3: Get source and destination node embeddings
        x_src = x[edge_index[0]]
        x_dst = x[edge_index[1]]

        # Step 4: Concatenate source, destination node embeddings, and edge features
        out = torch.cat([x_src, x_dst, edge_attr], dim=-1)

        # Step 5: Edge feature regression to predict `espera`
        return self.regressor(out)

# Training loop
def train(model, data, optimizer, criterion, epochs=100):
    """Run full-batch gradient descent on ``data`` for ``epochs`` iterations.

    Each iteration puts the model in training mode, clears the gradients,
    evaluates ``model(data.x, data.edge_index, data.edge_attr)`` against
    ``data.y`` with ``criterion``, back-propagates and applies one optimizer
    step. The loss is printed on every tenth epoch, starting with epoch 0.

    Args:
        model: Callable module taking ``(x, edge_index, edge_attr)``.
        data: Object exposing ``x``, ``edge_index``, ``edge_attr`` and ``y``.
        optimizer: Optimizer bound to the model's parameters.
        criterion: Loss callable taking ``(prediction, target)``.
        epochs: Number of optimization steps to perform.
    """
    report_interval = 10

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        predictions = model(data.x, data.edge_index, data.edge_attr)
        loss = criterion(predictions, data.y)

        loss.backward()
        optimizer.step()

        # Only report on multiples of the interval.
        if epoch % report_interval:
            continue
        print(f"Epoch {epoch}, Loss: {loss.item()}")

# Test function (MSE)
def test(model, data):
    """Evaluate ``model`` on ``data`` and report the mean squared error.

    The model is switched to evaluation mode and run without gradient
    tracking; the MSE between its output and ``data.y`` is printed and
    returned.

    Args:
        model: Callable module taking ``(x, edge_index, edge_attr)``.
        data: Object exposing ``x``, ``edge_index``, ``edge_attr`` and ``y``.

    Returns:
        The mean squared error as a Python float.
    """
    model.eval()
    loss_fn = nn.MSELoss()

    with torch.no_grad():
        predictions = model(data.x, data.edge_index, data.edge_attr)
        mse_value = loss_fn(predictions, data.y).item()

    print(f"Test MSE: {mse_value}")
    return mse_value

# Main function
def main():
    """Load the training set, fit the GAT edge regressor and plot the graph.

    Steps: fetch the raw training frame, turn it into a multigraph, convert
    that into a PyTorch Geometric ``Data`` object, train for 100 epochs with
    Adam and MSE loss, print the MSE, and finally draw the graph.
    """
    df = datasets.get_train_dataset(raw_data=True)
    graph = generate_multigraph(df)

    # Edge attributes handed to the model as input features.
    # NOTE(review): 'espera' is also the regression target that
    # prepare_data_from_multigraph stores in ``y``, so the model sees the
    # value it is asked to predict — confirm this is intended.
    edge_features = ['espera']
    data = prepare_data_from_multigraph(graph, df, edge_features)

    model = FlightGNNWithGAT(
        node_in_channels=3,
        edge_in_channels=len(edge_features),
        hidden_channels=16,
        out_channels=1,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    train(model, data, optimizer, criterion, epochs=100)

    # NOTE(review): evaluation reuses the training data; there is no
    # held-out split here.
    test(model, data)

    # Draw the flight graph with a force-directed layout.
    plt.figure(figsize=(8, 8))
    layout = nx.spring_layout(graph)
    nx.draw(graph, layout, with_labels=True, node_color='lightblue', edge_color='gray')
    plt.show()


if __name__ == "__main__":
    main()


ModuleNotFoundError: No module named 'torch_geometric'