
# GNNs for Combinatorial Optimization: Satellite Constellation Management

This notebook provides a detailed implementation of Graph Neural Networks (GNNs) for solving combinatorial optimization problems in satellite constellation management. Each code block is paired with LaTeX equations and explanations.



### Representing the Problem as a Graph

The problem is represented as a graph:

$$
G = (V, E)
$$

where:

- $V$ is the set of nodes (e.g., satellites).
- $E$ is the set of edges representing communication links.

Each node $v \in V$ has features $\mathbf{x}_v$ and each edge $(u, v) \in E$ has features $\mathbf{e}_{u,v}$.

### Cost Function and Objective

The cost function for optimization is given as:

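$$
c(S) = \sum_{v \in S} w(v) + \sum_{\substack{(u, v) \in E \\ u, v \in S}} w(u, v)
$$

where $S \subseteq V$ is a candidate solution and $w(v)$ and $w(u, v)$ are node and edge weights, respectively. The edge sum runs over the links whose endpoints both lie in $S$.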

The objective is to find:

$$
S^* = \operatorname*{arg\,min}_{S \in \mathcal{F}} c(S)
$$

where $\mathcal{F} \subseteq 2^V$ is the set of feasible solutions, i.e. the subsets of nodes that satisfy the problem's constraints.
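
To make the objective concrete, the sketch below enumerates every subset of a tiny, made-up instance and picks the cheapest feasible one. The weights, the `feasible` rule (select exactly two satellites), and the reading of the edge sum as "links with both endpoints in $S$" are assumptions for illustration only. Exhaustive search like this is practical only for a handful of nodes, which is the motivation for learning a heuristic with a GNN.

```python
from itertools import combinations

# Hypothetical toy instance: 4 satellites with node weights and weighted links.
node_w = {0: 1.0, 1: 2.0, 2: 0.5, 3: 1.5}
edge_w = {(0, 1): 0.3, (1, 2): 0.2, (2, 3): 0.4, (0, 3): 0.6}

def cost(S):
    """c(S): node weights in S plus weights of links with both endpoints in S."""
    nodes = sum(node_w[v] for v in S)
    links = sum(w for (u, v), w in edge_w.items() if u in S and v in S)
    return nodes + links

def feasible(S):
    """Assumed constraint for illustration: select exactly two satellites."""
    return len(S) == 2

# Enumerate all subsets of V (the power set 2^V), keep the feasible ones.
candidates = [frozenset(c) for r in range(len(node_w) + 1)
              for c in combinations(node_w, r)]
best = min((S for S in candidates if feasible(S)), key=cost)
print(sorted(best), cost(best))
```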

In [None]:

# Step 1: Representing the Problem as a Graph
import torch
from torch_geometric.data import Data

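def prepare_satellite_graph(sat_positions, sat_features, links, link_features):
    '''
    Converts satellite constellation data into a PyG graph.

    sat_positions: [num_nodes, pos_dim] satellite coordinates
    sat_features:  [num_nodes, feature_dim] per-satellite features
    links:         list of (source, target) node-index pairs
    link_features: [num_edges, feature_dim] per-link features
    '''
    # Edge index: [2, num_edges]
    edge_index = torch.tensor(links, dtype=torch.long).t().contiguous()

    # Node features: [num_nodes, feature_dim]
    x = torch.tensor(sat_features, dtype=torch.float)

    # Edge features: [num_edges, feature_dim]
    edge_attr = torch.tensor(link_features, dtype=torch.float)

    # Node positions: [num_nodes, pos_dim]
    pos = torch.tensor(sat_positions, dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)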
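
PyG treats each column of `edge_index` as a directed edge, so a two-way communication link normally has to appear once per direction. The helper below is a minimal, framework-free sketch of that preprocessing step. The name `make_bidirectional` and the sample data are hypothetical, and it assumes a link has the same features in both directions. PyG's `torch_geometric.utils.to_undirected` serves a similar purpose on tensors.

```python
def make_bidirectional(links, link_features):
    """Return links and features with every (u, v) also listed as (v, u)."""
    both_links, both_features = [], []
    for (u, v), feat in zip(links, link_features):
        both_links.append([u, v])
        both_features.append(list(feat))
        both_links.append([v, u])
        both_features.append(list(feat))  # assumed: same features in reverse
    return both_links, both_features

# Hypothetical example: two links, each with one feature (e.g. link quality).
links = [(0, 1), (1, 2)]
link_features = [[0.9], [0.4]]
bi_links, bi_features = make_bidirectional(links, link_features)
print(bi_links)
```

The two returned lists can then be passed as `links` and `link_features` when building the graph.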


### Node Embedding Update

The node embeddings are updated iteratively using a GNN:

$$
\mathbf{h}_v^{(t+1)} = \text{Update}\left(\mathbf{h}_v^{(t)}, \text{Aggregate}\left(\{\mathbf{m}_{u \to v}^{(t)} : u \in \mathcal{N}(v)\}\right)\right)
$$

where:

- **Messages**:

$$
\mathbf{m}_{u \to v}^{(t)} = \phi_{\text{message}}(\mathbf{h}_u^{(t)}, \mathbf{h}_v^{(t)}, \mathbf{e}_{u,v})
$$

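- **Aggregation**: the incoming messages are combined with a permutation-invariant function such as sum, mean, or max:

$$
\mathbf{a}_v^{(t)} = \text{Aggregate}\left(\{\mathbf{m}_{u \to v}^{(t)} : u \in \mathcal{N}(v)\}\right)
$$

- **Update**:

$$
\mathbf{h}_v^{(t+1)} = \phi_{\text{update}}\left(\mathbf{h}_v^{(t)}, \mathbf{a}_v^{(t)}\right)
$$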
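
One round of this scheme can be written out by hand on a three-node graph. The sketch below is deliberately framework-free. The concrete choices for $\phi_{\text{message}}$ (sender embedding scaled by a scalar edge feature), the aggregation (sum), and $\phi_{\text{update}}$ (add, then ReLU) are assumptions picked to keep the arithmetic easy to follow, not the functions used by the model in this notebook.

```python
def message(h_u, h_v, e_uv):
    """phi_message (assumed form): sender embedding scaled by the edge feature."""
    return [e_uv * x for x in h_u]

def aggregate(messages, dim):
    """Aggregate: element-wise sum, which is permutation-invariant."""
    total = [0.0] * dim
    for m in messages:
        total = [a + b for a, b in zip(total, m)]
    return total

def update(h_v, agg):
    """phi_update (assumed form): add the aggregated message, then apply ReLU."""
    return [max(0.0, a + b) for a, b in zip(h_v, agg)]

def message_passing_step(h, edges, edge_feat):
    """One round: every node v receives messages from its in-neighbours u."""
    dim = len(next(iter(h.values())))
    new_h = {}
    for v in h:
        msgs = [message(h[u], h[v], edge_feat[(u, t)])
                for (u, t) in edges if t == v]
        new_h[v] = update(h[v], aggregate(msgs, dim))
    return new_h

# Hypothetical path graph 0 - 1 - 2 with both directions listed explicitly.
h0 = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
edge_feat = {e: 0.5 for e in edges}
h1 = message_passing_step(h0, edges, edge_feat)
print(h1)
```

Node 1 has two neighbours, so its new embedding mixes information from both, while nodes 0 and 2 only see node 1. Stacking $T$ such rounds lets information travel $T$ hops.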

In [None]:

# Step 2: Graph Neural Network Framework
from torch_geometric.nn import GATConv
import torch.nn.functional as F

class SatelliteGNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Two graph-attention layers; the first concatenates its 4 heads
        self.conv1 = GATConv(input_dim, hidden_dim, heads=4, concat=True)
        self.conv2 = GATConv(hidden_dim * 4, hidden_dim, heads=4, concat=False)
        # Edge head: scores each link from the embeddings of its two endpoints
        self.edge_head = torch.nn.Linear(2 * hidden_dim, 1)
        # Node head: per-node logits (e.g., routing decisions)
        self.out = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        # Node embeddings
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))

        # Edge-level output: one predicted quality per link, shape [num_edges]
        src, dst = edge_index
        edge_out = self.edge_head(torch.cat([x[src], x[dst]], dim=-1)).squeeze(-1)

        # Node-level output, shape [num_nodes, output_dim]
        node_out = self.out(x)
        return node_out, edge_out



## Step 3: Training

The training objective is to minimize routing latency and to predict link failures.

To train the GNN, we:

1. Initialize parameters $\theta$.
2. Convert the input data to a graph $G$.
3. Perform $T$ layers of message passing to compute embeddings.
4. Use a decoder to predict $S_{\text{pred}}$.
5. Compute the loss $\mathcal{L}$.
6. Update $\theta$ via backpropagation.
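
The six steps above can be sketched as a short supervised training loop. Everything in this block is a stand-in chosen so that it runs with PyTorch alone: `TinyGNN` is a hypothetical model with hand-written sum aggregation (not the `SatelliteGNN` from this notebook), and the random features, labels, learning rate, and epoch count are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Step 2: hypothetical toy graph with 4 nodes, 3 features, 3 route classes.
x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
y_true = torch.tensor([0, 1, 2, 1])

class TinyGNN(torch.nn.Module):
    """Stand-in model: T rounds of sum aggregation plus a linear decoder."""
    def __init__(self, in_dim, hidden_dim, num_classes, T=2):
        super().__init__()
        self.encoder = torch.nn.Linear(in_dim, hidden_dim)
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(2 * hidden_dim, hidden_dim) for _ in range(T)])
        self.decoder = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index):
        src, dst = edge_index
        h = F.relu(self.encoder(x))
        for layer in self.layers:
            # Sum the embeddings of each node's in-neighbours
            agg = torch.zeros_like(h).index_add_(0, dst, h[src])
            h = F.relu(layer(torch.cat([h, agg], dim=-1)))
        return self.decoder(h)

model = TinyGNN(3, 8, 3)                                    # step 1
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
losses = []
for epoch in range(50):
    optimizer.zero_grad()
    logits = model(x, edge_index)                           # steps 3 and 4
    loss = F.cross_entropy(logits, y_true)                  # step 5
    loss.backward()                                         # step 6
    optimizer.step()
    losses.append(loss.item())
```

On a real dataset the loop would iterate over batches of graphs and track a validation metric rather than the training loss alone.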


### Loss Function

The loss function for supervised learning is given by:

$$
\mathcal{L}_{\text{sup}} = \text{CrossEntropy}(S_{\text{pred}}, S_{\text{true}})
$$

For reinforcement learning, the loss function is:

$$
\mathcal{L}_{\text{RL}} = - \mathbb{E}_{S \sim \pi_\theta} [R(S)]
$$

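The total loss adds a constraint penalty to whichever of the two is used:

$$
\mathcal{L} = \mathcal{L}_{\text{sup/RL}} + \lambda \mathcal{L}_{\text{constraint}}
$$

where $\mathcal{L}_{\text{constraint}}$ penalizes predictions that violate the feasibility constraints and $\lambda \geq 0$ sets the weight of the penalty.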

#### Routing Optimization:

$$
\mathcal{L}_{\text{routing}} = \text{CrossEntropy}(S_{\text{pred}}, S_{\text{true}})
$$

#### Link Quality Prediction:

$$
\mathcal{L}_{\text{link}} = \text{MSE}(\text{links}_{\text{pred}}, \text{links}_{\text{true}})
$$

#### Combined Loss:

$$
\mathcal{L} = \alpha \mathcal{L}_{\text{routing}} + \beta \mathcal{L}_{\text{link}}
$$


In [None]:

# Step 3: Training
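def combined_loss(node_out, edge_out, true_routes, true_links, alpha=0.5, beta=0.5):
    '''
    Compute combined loss for satellite constellation management.

    node_out:    [num_nodes, num_classes] routing logits
    true_routes: [num_nodes] integer class labels
    edge_out:    predicted link qualities
    true_links:  target link qualities, same shape as edge_out
    '''
    routing_loss = F.cross_entropy(node_out, true_routes)  # L_routing
    link_loss = F.mse_loss(edge_out, true_links)            # L_link
    return alpha * routing_loss + beta * link_loss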



## Step 4: Inference

After training, use the GNN to predict routing paths and link qualities.
Given a new graph $G$, we:

1. Compute embeddings $\mathbf{h}_v^{(T)}$ using the trained GNN.
2. Decode $S_{\text{pred}}$ from the embeddings.
3. Evaluate $S_{\text{pred}}$ for feasibility and quality.
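
Step 3 depends on how the predictions are interpreted. As one possible reading, the sketch below assumes each node's predicted class is the index of its next hop, then checks that following those hops from a source reaches the target over real links without looping. The function name `follow_route`, the next-hop interpretation, and the toy topology are all assumptions made for illustration.

```python
def follow_route(next_hop, neighbors, source, target):
    """Follow predicted next hops from source; return the path or None if infeasible."""
    path, seen = [source], {source}
    node = source
    while node != target:
        nxt = next_hop[node]
        if nxt not in neighbors[node] or nxt in seen:
            return None  # the hop is not a real link, or the route loops
        path.append(nxt)
        seen.add(nxt)
        node = nxt
    return path

# Hypothetical chain topology 0 - 1 - 2 - 3.
neighbors = {0: {1}, 1: {0, 2}, 2: {1, 3}, 3: {2}}
good = follow_route({0: 1, 1: 2, 2: 3, 3: 3}, neighbors, source=0, target=3)
bad = follow_route({0: 1, 1: 0, 2: 3, 3: 3}, neighbors, source=0, target=3)
print(good, bad)
```

A feasible path can then be scored with the cost function, for example by summing the weights of the links it traverses.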

### Routing Path Prediction:

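$$
S_{\text{pred}} = \operatorname*{arg\,max}(\text{node\_out})
$$

The argmax is taken independently for each node over the output dimension of `node_out`.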

### Link Quality Prediction:

`edge_out` holds the predicted link qualities.


In [None]:

def predict_routing(data, model, device):
    model.eval()
    data = data.to(device)
    with torch.no_grad():
        node_out, _ = model(data)
    # Predicted routes
    return node_out.argmax(dim=1) 

def predict_link_quality(data, model, device):
    model.eval()
    data = data.to(device)
    with torch.no_grad():
        _, edge_out = model(data)
    # Predicted link qualities
    return edge_out  


### 1. Discounted Returns
The discounted return at timestep $t$ is computed recursively as:

$$
G_t = R_t + \gamma G_{t+1}
$$

where the return after the final timestep is taken to be zero, and:

- $R_t$ is the reward at timestep $t$,
- $\gamma$ is the discount factor ($0 \leq \gamma \leq 1$),
- $G_t$ is the cumulative discounted reward starting from timestep $t$.
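
Because each return depends on the next one, the recursion is easiest to evaluate backwards from the end of the episode. A minimal sketch, with a made-up reward sequence and discount factor:

```python
def discounted_returns(rewards, gamma):
    """Compute G_t = R_t + gamma * G_{t+1} for every t, working backwards."""
    returns = [0.0] * len(rewards)
    running = 0.0  # return after the final timestep
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# Hypothetical 3-step episode.
returns = discounted_returns([1.0, 0.0, 2.0], gamma=0.5)
print(returns)  # [1.5, 1.0, 2.0]
```

Here $G_2 = 2$, $G_1 = 0 + 0.5 \cdot 2 = 1$, and $G_0 = 1 + 0.5 \cdot 1 = 1.5$.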


### 2. Policy Loss
The policy is trained to maximize the expected return, which gives the policy-gradient (REINFORCE) loss:

$$
\mathcal{L}_{\text{policy}} = - \mathbb{E}_{\tau \sim \pi_\theta} \left[ \sum_t G_t \cdot \log \pi_\theta(a_t | s_t) \right]
$$

where:

- $\pi_\theta(a_t | s_t)$ is the probability of taking action $a_t$ in state $s_t$,
- $G_t$ is the discounted return starting from timestep $t$,
- $\tau$ is a trajectory sampled under policy $\pi_\theta$.
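
In practice the expectation is estimated from sampled trajectories. The sketch below evaluates the loss for a single trajectory by summing the return-weighted log-probabilities over its timesteps. The probabilities and returns are made-up numbers, and the function name `policy_loss` is hypothetical.

```python
import math

def policy_loss(action_probs, returns):
    """Single-trajectory estimate: -sum_t G_t * log pi(a_t | s_t)."""
    return -sum(g * math.log(p) for p, g in zip(action_probs, returns))

# Hypothetical trajectory: probability the policy assigned to each chosen
# action, and the discounted return from that timestep onwards.
probs = [0.5, 0.25, 0.5]
returns = [1.5, 1.0, 2.0]
loss = policy_loss(probs, returns)
print(round(loss, 4))
```

Minimizing this value raises the probability of actions that were followed by high returns. In a real implementation the log-probabilities would be tensors so that gradients flow back into $\theta$.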


### 3. Value Loss
The value network is trained to minimize the difference between predicted values and actual returns:

$$
\mathcal{L}_{\text{value}} = \frac{1}{2} \sum_t \left( V_\phi(s_t) - G_t \right)^2
$$

where:

- $V_\phi(s_t)$ is the value function parameterized by $\phi$,
- $G_t$ is the discounted return.


### 4. Total Loss
The total loss combines the policy and value losses, with an optional entropy term to encourage exploration:

$$
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{policy}} + \lambda \mathcal{L}_{\text{value}} - \beta \mathcal{L}_{\text{entropy}}
$$

where:

- $\lambda$ and $\beta$ weight the value and entropy terms,
- $\mathcal{L}_{\text{entropy}} = - \sum_a \pi_\theta(a | s) \log \pi_\theta(a | s)$ is the entropy of the policy. Because it is subtracted, a more uniform (higher-entropy) policy lowers the total loss, which encourages exploration.
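
The pieces can be combined numerically as follows. This is a plain-Python sketch with made-up inputs. The weights `lam=0.5` and `beta=0.01` are arbitrary assumptions, not values taken from this notebook.

```python
import math

def entropy(probs):
    """Entropy of an action distribution: -sum_a pi(a) * log pi(a)."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def value_loss(values, returns):
    """0.5 * sum_t (V(s_t) - G_t)^2."""
    return 0.5 * sum((v - g) ** 2 for v, g in zip(values, returns))

def total_loss(policy_l, value_l, entropy_l, lam=0.5, beta=0.01):
    """L_total = L_policy + lam * L_value - beta * L_entropy."""
    return policy_l + lam * value_l - beta * entropy_l

# A uniform policy has higher entropy than a nearly deterministic one.
h_uniform = entropy([0.25, 0.25, 0.25, 0.25])
h_peaked = entropy([0.97, 0.01, 0.01, 0.01])

# Hypothetical value predictions against the returns [1.5, 1.0, 2.0].
v_loss = value_loss([1.0, 1.0, 1.0], [1.5, 1.0, 2.0])
loss = total_loss(2.0, v_loss, h_uniform)
print(h_uniform > h_peaked, v_loss)
```

The entropy bonus is usually kept small so that it nudges the policy away from premature determinism without dominating the return signal.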


### 5. GNN Message Passing
Each node embedding $\mathbf{h}_v^{(t+1)}$ is updated using:

$$
\mathbf{h}_v^{(t+1)} = \phi_{\text{update}} \left( \mathbf{h}_v^{(t)}, \text{Aggregate} \left( \{ \phi_{\text{message}}(\mathbf{h}_u^{(t)}, \mathbf{h}_v^{(t)}, \mathbf{e}_{u,v}) : u \in \mathcal{N}(v) \} \right) \right)
$$

where:

- $\phi_{\text{message}}$ is the message function,
- $\text{Aggregate}$ is a permutation-invariant aggregation function (e.g., sum, mean, max),
- $\phi_{\text{update}}$ is the update function (e.g., an MLP),
- $\mathbf{e}_{u,v}$ are edge features,
- $\mathcal{N}(v)$ is the set of neighbors of node $v$.
