# GNN-based Antenna Array Clustering

This notebook implements Graph Neural Networks for clustering irregular antenna arrays
using unsupervised learning with MinCut optimization.

**Architecture overview:**
1. Graph Construction: Convert antenna positions into a k-NN graph
2. GNN Layers: GAT/GCN layers learn node embeddings
3. Clustering Head: Soft cluster assignment via a linear layer followed by softmax
4. Loss: MinCut + Orthogonality, optionally plus entropy regularization (no labels needed)

## 1. Configuration

In [None]:
"""
Configuration dataclasses for GNN-based antenna clustering.
"""

from dataclasses import dataclass, field
from typing import Literal, Optional, Tuple, Union, List, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


@dataclass
class GNNConfig:
    """
    GNN architecture configuration.

    Attributes:
        in_dim: Input feature dimension (2 for x,y positions)
        hidden_dim: Hidden layer dimension
        num_clusters: K, number of output clusters
        num_layers: Number of GNN layers
        heads: Number of attention heads (GAT only)
        dropout: Dropout probability for regularization
        layer_type: Type of GNN layer ('gat' or 'gcn')
        use_edge_features: Whether to incorporate edge features (distance, coupling)
        edge_dim: Dimension of the raw edge features (used only when
            edge features are enabled)

    Raises:
        ValueError: On construction, if ``layer_type`` is not 'gat'/'gcn',
            if ``num_clusters``, ``hidden_dim`` or ``heads`` is < 1, if
            ``num_layers`` is negative, or if ``dropout`` is outside [0, 1].
    """
    in_dim: int = 2
    hidden_dim: int = 64
    num_clusters: int = 4
    num_layers: int = 3
    heads: int = 4
    dropout: float = 0.1
    layer_type: Literal["gat", "gcn"] = "gat"
    use_edge_features: bool = False
    edge_dim: int = 1  # Dimension of edge features if used

    def __post_init__(self):
        # ``Literal`` is not enforced at runtime, and the model treats any
        # non-'gat' value as GCN, so reject typos here instead of silently
        # building the wrong architecture.
        if self.layer_type not in ("gat", "gcn"):
            raise ValueError(
                f"layer_type must be 'gat' or 'gcn', got {self.layer_type!r}"
            )
        for name in ("num_clusters", "hidden_dim", "heads"):
            if getattr(self, name) < 1:
                raise ValueError(f"{name} must be >= 1, got {getattr(self, name)}")
        if self.num_layers < 0:
            raise ValueError(f"num_layers must be >= 0, got {self.num_layers}")
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout}")


@dataclass
class GraphConfig:
    """
    Graph construction configuration.

    Attributes:
        k_neighbors: Number of neighbors for k-NN graph
        connection_type: Strategy for edge creation ('knn', 'radius', 'coupling')
        radius: Connection radius for radius-based graphs (in normalized units)
        coupling_threshold: Threshold for mutual coupling-based edges
        add_self_loops: Whether to add self-loops to the adjacency matrix

    Raises:
        ValueError: On construction, if ``connection_type`` is unknown, or if
            the parameter used by the selected strategy is invalid
            (``k_neighbors < 1`` for 'knn', ``radius <= 0`` for 'radius').
    """
    k_neighbors: int = 8
    connection_type: Literal["knn", "radius", "coupling"] = "knn"
    radius: float = 0.5
    coupling_threshold: float = 0.1
    add_self_loops: bool = True

    def __post_init__(self):
        # ``Literal`` is only a static hint; validate at runtime so a typo
        # fails here rather than deep inside graph construction.
        if self.connection_type not in ("knn", "radius", "coupling"):
            raise ValueError(
                "connection_type must be 'knn', 'radius' or 'coupling', "
                f"got {self.connection_type!r}"
            )
        # Only check the parameter that the chosen strategy actually uses.
        if self.connection_type == "knn" and self.k_neighbors < 1:
            raise ValueError(f"k_neighbors must be >= 1, got {self.k_neighbors}")
        if self.connection_type == "radius" and self.radius <= 0:
            raise ValueError(f"radius must be > 0, got {self.radius}")


@dataclass
class TrainingConfig:
    """
    Optimisation settings for an unsupervised clustering run.

    Attributes:
        epochs: How many optimisation iterations to run
        lr: Step size handed to the Adam optimizer
        weight_decay: L2 penalty coefficient
        lambda_ortho: Multiplier on the orthogonality term of the loss
        lambda_entropy: Multiplier on the entropy regulariser
        device: 'auto' (pick CUDA when available), 'cuda', or 'cpu'
        verbose: Report progress every N epochs; 0 disables reporting
    """
    epochs: int = 500            # optimisation iterations
    lr: float = 1e-3             # Adam learning rate
    weight_decay: float = 5e-4   # L2 regularisation strength
    lambda_ortho: float = 1.0    # orthogonality-loss weight
    lambda_entropy: float = 0.0  # entropy-loss weight
    device: str = "auto"         # one of 'auto', 'cuda', 'cpu'
    verbose: int = 50            # progress interval in epochs (0 = silent)


@dataclass
class ClusteringConfig:
    """
    Top-level bundle holding every configuration section.

    Each section is built by its own ``default_factory``, so separate
    ClusteringConfig instances never share sub-config objects.
    """
    # Model architecture settings.
    gnn: GNNConfig = field(default_factory=GNNConfig)
    # How the antenna graph is built.
    graph: GraphConfig = field(default_factory=GraphConfig)
    # Optimiser / loss weighting settings.
    training: TrainingConfig = field(default_factory=TrainingConfig)

## 2. Utilities

In [None]:
"""
Utility functions for GNN-based antenna clustering.
"""


def normalize_positions(
    positions: torch.Tensor,
    method: str = "standard"
) -> torch.Tensor:
    """
    Normalize antenna positions for stable training.

    Non-floating-point tensors (e.g. integer grid indices) are first promoted
    to the default floating-point dtype, since ``mean``/``std`` are not
    defined for integer tensors in PyTorch.

    Args:
        positions: (N, 2) tensor of (x, y) coordinates
        method: Normalization method
            - "standard": Zero mean, unit variance (z-score, sample std)
            - "minmax": Scale to [0, 1] range

    Returns:
        Normalized positions (N, 2)

    Raises:
        ValueError: If ``method`` is not recognised.
    """
    if not torch.is_floating_point(positions):
        positions = positions.to(torch.get_default_dtype())

    if method == "standard":
        mean = positions.mean(dim=0, keepdim=True)
        if positions.shape[0] > 1:
            std = positions.std(dim=0, keepdim=True)
        else:
            # The unbiased std of a single sample is NaN; treat it as zero
            # spread so a one-element array maps to the origin instead of NaN.
            std = torch.zeros_like(mean)
        # eps guards against zero spread along an axis (e.g. a linear array).
        return (positions - mean) / (std + 1e-8)

    elif method == "minmax":
        min_val = positions.min(dim=0, keepdim=True).values
        max_val = positions.max(dim=0, keepdim=True).values
        return (positions - min_val) / (max_val - min_val + 1e-8)

    else:
        raise ValueError(f"Unknown normalization method: {method}")


def get_hard_assignments(z: torch.Tensor) -> torch.Tensor:
    """
    Collapse soft cluster probabilities into one label per node.

    Each node i receives label c_i = argmax_k z_ik, taken over the last axis.

    Args:
        z: (N, K) soft assignment matrix

    Returns:
        c: (N,) integer cluster labels in {0, ..., K-1}
    """
    return torch.argmax(z, dim=-1)


def get_device(device_str: str = "auto") -> torch.device:
    """
    Resolve a device specification string into a ``torch.device``.

    Args:
        device_str: "auto" picks CUDA when available and CPU otherwise;
            any other value is passed straight to ``torch.device``.

    Returns:
        torch.device object
    """
    if device_str != "auto":
        return torch.device(device_str)
    has_gpu = torch.cuda.is_available()
    return torch.device("cuda") if has_gpu else torch.device("cpu")


def cluster_sizes(
    assignments: Union[torch.Tensor, np.ndarray],
    num_clusters: Optional[int] = None
) -> np.ndarray:
    """
    Count elements in each cluster.

    Without ``num_clusters`` the result has length ``max(label) + 1``, so
    clusters whose label is above the largest used label are omitted. Pass
    ``num_clusters`` to always get a length-K array, with empty clusters
    reported as 0.

    Args:
        assignments: (N,) integer cluster labels
        num_clusters: K, minimum length of the returned array (optional)

    Returns:
        sizes: count per cluster label
    """
    if isinstance(assignments, torch.Tensor):
        assignments = assignments.detach().cpu().numpy()
    return np.bincount(assignments, minlength=num_clusters or 0)


def cluster_to_list(
    assignments: Union[torch.Tensor, np.ndarray],
    num_clusters: Optional[int] = None
) -> List[np.ndarray]:
    """
    Convert flat assignments to list of index arrays per cluster.

    Useful for interfacing with antenna_physics.py which expects
    List[np.ndarray] format for clusters.

    Args:
        assignments: (N,) cluster labels
        num_clusters: K (inferred as ``max(label) + 1`` if None; an empty
            input then yields an empty list)

    Returns:
        List of K arrays, each containing indices of elements in that cluster
        (empty clusters give an empty array)
    """
    if isinstance(assignments, torch.Tensor):
        assignments = assignments.detach().cpu().numpy()
    assignments = np.asarray(assignments)

    if num_clusters is None:
        # ``.max()`` raises on an empty array, so handle that case explicitly.
        num_clusters = int(assignments.max()) + 1 if assignments.size else 0

    return [np.flatnonzero(assignments == k) for k in range(num_clusters)]


def assignments_to_antenna_format(
    assignments: Union[torch.Tensor, np.ndarray],
    grid_shape: tuple = (16, 16),
    num_clusters: Optional[int] = None
) -> List[np.ndarray]:
    """
    Convert flat cluster assignments to antenna array format.

    Compatible with AntennaArray.index_to_position_cluster() method.

    Args:
        assignments: (N,) cluster labels for flattened array
        grid_shape: (Nz, Ny) shape of the antenna grid; only Ny (row length)
            is needed for the index conversion
        num_clusters: K (inferred as ``max(label) + 1`` if None; an empty
            input then yields an empty list)

    Returns:
        List of K arrays with shape (L_k, 2) containing [col, row] indices
        (an empty cluster gives a (0, 2) array)
    """
    if isinstance(assignments, torch.Tensor):
        assignments = assignments.detach().cpu().numpy()
    assignments = np.asarray(assignments)

    _, Ny = grid_shape
    if num_clusters is None:
        # ``.max()`` raises on an empty array, so handle that case explicitly.
        num_clusters = int(assignments.max()) + 1 if assignments.size else 0

    clusters = []
    for k in range(num_clusters):
        flat_indices = np.flatnonzero(assignments == k)

        # Row-major (C) ordering: flat_idx = row * Ny + col
        rows, cols = np.divmod(flat_indices, Ny)

        # Format as [col, row] to match antenna_physics convention
        clusters.append(np.stack([cols, rows], axis=1))

    return clusters


def compute_clustering_metrics(
    assignments: np.ndarray,
    positions: np.ndarray
) -> dict:
    """
    Compute basic clustering quality metrics.

    Args:
        assignments: (N,) integer cluster labels
        positions: (N, 2) antenna positions

    Returns:
        Dictionary with metrics:
            - num_clusters: Actual number of non-empty clusters
            - cluster_sizes: Elements per label 0..max(label) (np.bincount)
            - size_variance: Population variance of ``cluster_sizes``
            - mean_intra_distance: Mean Euclidean distance over all unordered
              pairs of elements that share a cluster (0.0 if no such pair)

    Raises:
        ValueError: If ``assignments`` and ``positions`` disagree on N.
    """
    labels = np.asarray(assignments).reshape(-1)
    coords = np.asarray(positions, dtype=float)
    if coords.shape[0] != labels.shape[0]:
        raise ValueError(
            f"assignments has {labels.shape[0]} entries but positions has {coords.shape[0]}"
        )

    if labels.size == 0:
        return {
            "num_clusters": 0,
            "cluster_sizes": np.zeros(0, dtype=np.int64),
            "size_variance": 0.0,
            "mean_intra_distance": 0.0,
        }

    sizes = np.bincount(labels)

    # Pool all within-cluster pairs, so larger clusters carry more weight.
    total_distance = 0.0
    num_pairs = 0
    for label in np.flatnonzero(sizes > 1):
        pts = coords[labels == label]
        dist = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)
        upper = np.triu_indices(len(pts), k=1)  # each unordered pair once
        total_distance += float(dist[upper].sum())
        num_pairs += upper[0].size

    return {
        "num_clusters": int(np.count_nonzero(sizes)),
        "cluster_sizes": sizes,
        "size_variance": float(np.var(sizes)),
        "mean_intra_distance": total_distance / num_pairs if num_pairs else 0.0,
    }


def set_seed(seed: int):
    """
    Seed every random number generator used in this workflow.

    Covers Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU
    RNG, plus all CUDA devices when CUDA is available.

    Args:
        seed: Seed value applied to every generator.
    """
    import random  # local import: keeps the file's import block untouched

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

## 3. Graph Construction

In [None]:
"""
Graph construction utilities for antenna array clustering.

Converts antenna positions into graph structures (nodes, edges, adjacency matrix).
Supports multiple edge creation strategies: k-NN, radius-based, mutual coupling.
"""


class GraphBuilder:
    """
    Builds graph representations from antenna positions.

    The antenna array is represented as:
        - Nodes: Individual antenna elements
        - Edges: Connections based on proximity or coupling
        - Node features: Positions (and optionally other attributes)
        - Edge features: Distance, mutual coupling magnitude/phase

    Conventions:
        - ``edge_index[0]`` holds source nodes and ``edge_index[1]`` target
          nodes; ``adj[src, dst] = 1`` for every edge.
        - Construction uses dense pairwise distances (``torch.cdist``), so
          memory grows as O(N^2).
    """

    def __init__(self, config: Optional[GraphConfig] = None):
        self.config = config or GraphConfig()

    @staticmethod
    def _to_float_tensor(data, device=None) -> torch.Tensor:
        """Convert array-like data to a tensor, promoting integer dtypes to float."""
        tensor = torch.as_tensor(data, device=device)
        if not (torch.is_floating_point(tensor) or torch.is_complex(tensor)):
            tensor = tensor.to(torch.get_default_dtype())
        return tensor

    @staticmethod
    def _pairwise_distances(positions: torch.Tensor) -> torch.Tensor:
        """Exact (N, N) Euclidean distances.

        The matmul-based cdist path can perturb distances slightly, which
        matters for radius thresholds and k-NN ties on regular grids.
        """
        return torch.cdist(
            positions, positions, compute_mode="donot_use_mm_for_euclid_dist"
        )

    def build_knn_edges(
        self,
        positions: torch.Tensor,
        k: Optional[int] = None
    ) -> torch.Tensor:
        """
        Build edge index using k-nearest neighbors.

        For every node, its k nearest other nodes become sources of edges
        pointing to that node. k is clipped to N-1.

        Args:
            positions: (N, 2) tensor of antenna positions
            k: Number of neighbors (uses config default if None)

        Returns:
            edge_index: (2, N*k) tensor of edge indices in COO format
        """
        k = self.config.k_neighbors if k is None else k
        positions = self._to_float_tensor(positions)
        num_nodes = positions.shape[0]
        k = min(k, num_nodes - 1)  # a node has at most N-1 other nodes
        if k <= 0:
            return torch.empty((2, 0), dtype=torch.long, device=positions.device)

        dist = self._pairwise_distances(positions)
        dist.fill_diagonal_(float("inf"))  # never pick a node as its own neighbor
        neighbors = dist.topk(k, dim=1, largest=False).indices  # (N, k)

        src = neighbors.reshape(-1)
        dst = torch.arange(num_nodes, device=positions.device).repeat_interleave(k)
        return torch.stack([src, dst], dim=0)

    def build_radius_edges(
        self,
        positions: torch.Tensor,
        radius: Optional[float] = None
    ) -> torch.Tensor:
        """
        Build edge index using radius-based connectivity (epsilon-ball).

        Connects all distinct nodes with distance <= ``radius``. Both edge
        directions are emitted, because distance is symmetric.

        Args:
            positions: (N, 2) tensor of antenna positions
            radius: Connection radius (uses config default if None)

        Returns:
            edge_index: (2, E) tensor of edge indices
        """
        radius = self.config.radius if radius is None else radius
        positions = self._to_float_tensor(positions)
        num_nodes = positions.shape[0]

        dist = self._pairwise_distances(positions)
        not_self = ~torch.eye(num_nodes, dtype=torch.bool, device=positions.device)
        mask = (dist <= radius) & not_self
        return mask.nonzero().t().contiguous()

    def build_coupling_edges(
        self,
        coupling_matrix: torch.Tensor,
        threshold: Optional[float] = None
    ) -> torch.Tensor:
        """
        Build edge index based on mutual coupling magnitude.

        Creates edge (i, j) for every i != j with |M_ij| >= threshold. An
        asymmetric M can give one-directional edges.

        Args:
            coupling_matrix: (N, N) complex (or real) mutual coupling matrix M
            threshold: Minimum |M_ij| to create edge (config default if None)

        Returns:
            edge_index: (2, E) tensor of edge indices

        Raises:
            ValueError: If ``coupling_matrix`` is not square.
        """
        threshold = self.config.coupling_threshold if threshold is None else threshold
        coupling = self._to_float_tensor(coupling_matrix)
        if coupling.dim() != 2 or coupling.shape[0] != coupling.shape[1]:
            raise ValueError(
                f"coupling_matrix must be square (N, N), got {tuple(coupling.shape)}"
            )

        num_nodes = coupling.shape[0]
        not_self = ~torch.eye(num_nodes, dtype=torch.bool, device=coupling.device)
        mask = (coupling.abs() >= threshold) & not_self
        return mask.nonzero().t().contiguous()

    def compute_adjacency_matrix(
        self,
        edge_index: torch.Tensor,
        num_nodes: int,
        symmetric: bool = True
    ) -> torch.Tensor:
        """
        Convert edge index to dense binary adjacency matrix A.

        Args:
            edge_index: (2, E) COO format edges
            num_nodes: N, number of nodes
            symmetric: Symmetrize the adjacency matrix (A = max(A, A^T))

        Returns:
            adj: (N, N) adjacency matrix in the default float dtype
                (duplicate edges collapse to 1)
        """
        adj = torch.zeros(
            (num_nodes, num_nodes),
            dtype=torch.get_default_dtype(),
            device=edge_index.device,
        )
        adj[edge_index[0], edge_index[1]] = 1.0
        if symmetric:
            adj = torch.maximum(adj, adj.t())
        return adj

    def compute_degree_matrix(self, adj: torch.Tensor) -> torch.Tensor:
        """
        Compute degree matrix D from adjacency matrix.

        D_ii = sum_j A_ij (number of neighbors; self-loops count once)

        Args:
            adj: (N, N) adjacency matrix

        Returns:
            deg: (N, N) diagonal degree matrix
        """
        return torch.diag(adj.sum(dim=1))

    def compute_edge_features(
        self,
        positions: torch.Tensor,
        edge_index: torch.Tensor,
        coupling_matrix: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute physics-informed edge features.

        Features per edge (i, j):
            - Euclidean distance d_ij
            - Mutual coupling magnitude |M_ij| (if provided)
            - Mutual coupling phase angle(M_ij) (if provided)

        Args:
            positions: (N, 2) node positions
            edge_index: (2, E) edges
            coupling_matrix: (N, N) optional mutual coupling matrix

        Returns:
            edge_attr: (E, 1) without coupling, (E, 3) with coupling
        """
        positions = self._to_float_tensor(positions)
        src, dst = edge_index[0], edge_index[1]

        distance = (positions[src] - positions[dst]).norm(dim=-1, keepdim=True)
        features = [distance]

        if coupling_matrix is not None:
            coupling = self._to_float_tensor(coupling_matrix, device=positions.device)
            m = coupling[src, dst]
            # |M| and angle(M) are real; cast to the node-feature dtype so all
            # columns can be concatenated.
            features.append(m.abs().unsqueeze(-1).to(distance.dtype))
            features.append(torch.angle(m).unsqueeze(-1).to(distance.dtype))

        return torch.cat(features, dim=-1)

    def build_graph(
        self,
        positions: Union[torch.Tensor, np.ndarray],
        coupling_matrix: Optional[Union[torch.Tensor, np.ndarray]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Build complete graph representation from antenna positions.

        This is the main entry point for graph construction. Edges are built
        according to ``config.connection_type``, then symmetrized. Self-loops
        are added when ``config.add_self_loops`` is set. The returned
        ``edge_index`` is re-derived from the final adjacency, so
        message-passing layers and dense operators see exactly the same graph.

        Args:
            positions: (N, 2) antenna positions
            coupling_matrix: (N, N) optional mutual coupling matrix (required
                for connection_type='coupling')

        Returns:
            Tuple of:
                - edge_index: (2, E) edge indices
                - adj: (N, N) adjacency matrix
                - deg: (N, N) degree matrix
                - edge_attr: (E, F) edge features (F=1, or 3 with coupling)

        Raises:
            ValueError: On an unknown connection type, a missing or
                mis-shaped coupling matrix.
        """
        positions = self._to_float_tensor(positions)
        num_nodes = positions.shape[0]

        coupling = None
        if coupling_matrix is not None:
            coupling = self._to_float_tensor(coupling_matrix, device=positions.device)
            if tuple(coupling.shape) != (num_nodes, num_nodes):
                raise ValueError(
                    f"coupling_matrix must have shape ({num_nodes}, {num_nodes}), "
                    f"got {tuple(coupling.shape)}"
                )

        connection = self.config.connection_type
        if connection == "knn":
            edge_index = self.build_knn_edges(positions)
        elif connection == "radius":
            edge_index = self.build_radius_edges(positions)
        elif connection == "coupling":
            if coupling is None:
                raise ValueError("connection_type='coupling' requires a coupling_matrix")
            edge_index = self.build_coupling_edges(coupling)
        else:
            raise ValueError(f"Unknown connection_type: {connection!r}")

        adj = self.compute_adjacency_matrix(edge_index, num_nodes, symmetric=True)
        if self.config.add_self_loops:
            adj.fill_diagonal_(1.0)
        deg = self.compute_degree_matrix(adj)

        edge_index = adj.nonzero().t().contiguous()
        edge_attr = self.compute_edge_features(positions, edge_index, coupling)
        return edge_index, adj, deg, edge_attr


def normalized_adjacency(adj: torch.Tensor, deg: torch.Tensor) -> torch.Tensor:
    """
    Compute normalized adjacency: D^{-1/2} A D^{-1/2}

    Used in GCN to prevent exploding/vanishing gradients during message passing.
    Nodes with zero degree get a zero scaling factor instead of inf/NaN.

    Args:
        adj: (N, N) adjacency matrix (with or without self-loops)
        deg: (N, N) diagonal degree matrix, or an (N,) vector of degrees

    Returns:
        norm_adj: (N, N) normalized adjacency matrix
    """
    degrees = torch.diagonal(deg) if deg.dim() == 2 else deg
    degrees = degrees.to(adj.dtype)

    # Masked rsqrt: isolated nodes keep a 0 factor rather than inf.
    inv_sqrt = torch.zeros_like(degrees)
    positive = degrees > 0
    inv_sqrt[positive] = degrees[positive].rsqrt()

    # Broadcasting the diagonal scaling avoids two dense (N, N) matmuls.
    return inv_sqrt[:, None] * adj * inv_sqrt[None, :]

## 4. GNN Layers

In [None]:
"""
GNN layer implementations for antenna clustering.

Implements message passing layers:
    - GCN (Graph Convolutional Network): Simple aggregation with learned weights
    - GAT (Graph Attention Network): Attention-weighted aggregation
"""


class GCNLayer(nn.Module):
    """
    Graph Convolutional Network layer.

    Implements the linear part of H^{l+1} = sigma(D^{-1/2} A D^{-1/2} H^{l} W^{l});
    the non-linearity sigma is left to the caller.

    Each node aggregates normalized neighbor features after a shared linear
    transform, then a bias is added.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True
    ):
        """
        Args:
            in_features: Input dimension per node
            out_features: Output dimension per node
            bias: Include learnable bias term
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Learnable weight matrix W
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize weights using Xavier/Glorot initialization."""
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(
        self,
        x: torch.Tensor,
        adj_norm: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass: A_norm @ (X W) + b.

        Args:
            x: (N, in_features) node features
            adj_norm: (N, N) normalized adjacency matrix (dense or sparse COO)

        Returns:
            out: (N, out_features) updated node features (pre-activation)
        """
        # Transform first: (N, in) @ (in, out) is cheaper to propagate than x
        # when out_features <= in_features, and the result is identical.
        support = x @ self.weight
        if adj_norm.is_sparse:
            out = torch.sparse.mm(adj_norm, support)
        else:
            out = adj_norm @ support
        if self.bias is not None:
            out = out + self.bias
        return out


class GATLayer(nn.Module):
    """
    Graph Attention Network layer.

    Learns attention weights to focus on relevant neighbors during aggregation.
    Supports multi-head attention for richer representations.

    Messages flow along ``edge_index[0] -> edge_index[1]``; attention is
    normalised over each target node's incoming edges. A node with no incoming
    edge (include self-loops to avoid this) gets a zero output.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        heads: int = 1,
        concat: bool = True,
        dropout: float = 0.0,
        negative_slope: float = 0.2
    ):
        """
        Args:
            in_features: Input dimension per node
            out_features: Output dimension per head
            heads: Number of attention heads
            concat: Concatenate heads (True) or average (False)
            dropout: Dropout on attention weights
            negative_slope: LeakyReLU negative slope
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.heads = heads
        self.concat = concat
        self.dropout = dropout
        self.negative_slope = negative_slope

        # Linear transform for node features: W
        self.W = nn.Parameter(torch.empty(heads, in_features, out_features))

        # Attention mechanism parameters: a = [a_l || a_r]
        self.a_l = nn.Parameter(torch.empty(heads, out_features, 1))
        self.a_r = nn.Parameter(torch.empty(heads, out_features, 1))

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize using Xavier/Glorot."""
        nn.init.xavier_uniform_(self.W)
        nn.init.xavier_uniform_(self.a_l)
        nn.init.xavier_uniform_(self.a_r)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass with attention mechanism.

        Args:
            x: (N, in_features) node features
            edge_index: (2, E) edge indices in COO format (source, target)
            edge_attr: (E, edge_dim) optional edge features; accepted for
                interface compatibility but not used by this layer

        Returns:
            out: (N, heads * out_features) if concat else (N, out_features)
        """
        num_nodes = x.size(0)

        # 1. Per-head linear transform: (N, in) x (H, in, out) -> (N, H, out)
        h = torch.einsum("ni,hio->nho", x, self.W)

        # 2-3. Edge attention, softmax-normalised per target node: (E, H)
        alpha = self._compute_attention(h, h, edge_index)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # 4. h'_i = sum_j alpha_ij * Wh_j, scattered into the target nodes
        src, dst = edge_index[0], edge_index[1]
        messages = h[src] * alpha.unsqueeze(-1)  # (E, H, out)
        out = torch.zeros_like(h).index_add(0, dst, messages)  # (N, H, out)

        # 5. Concatenate or average heads
        if self.concat:
            return out.reshape(num_nodes, self.heads * self.out_features)
        return out.mean(dim=1)

    def _compute_attention(
        self,
        h_l: torch.Tensor,
        h_r: torch.Tensor,
        edge_index: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute attention coefficients for each edge.

        e_ij = LeakyReLU(a_l . h_l[src] + a_r . h_r[dst]), followed by a
        softmax over all edges sharing the same target node.

        Args:
            h_l: (N, heads, out_features) source node representations
            h_r: (N, heads, out_features) target node representations
            edge_index: (2, E) edges

        Returns:
            alpha: (E, heads) attention weights (sum to 1 per target & head)
        """
        src, dst = edge_index[0], edge_index[1]

        # Splitting a^T [h_src || h_dst] into two per-node dot products avoids
        # materialising the (E, H, 2*out) concatenation.
        score_l = (h_l * self.a_l.squeeze(-1)).sum(dim=-1)  # (N, H)
        score_r = (h_r * self.a_r.squeeze(-1)).sum(dim=-1)  # (N, H)
        e = F.leaky_relu(score_l[src] + score_r[dst], self.negative_slope)  # (E, H)

        num_nodes, num_heads = h_r.size(0), e.size(1)

        # Numerically stable segment softmax: subtract each target's max score.
        # The max is detached; softmax is shift-invariant so gradients are unchanged.
        e_max = torch.full(
            (num_nodes, num_heads), float("-inf"), dtype=e.dtype, device=e.device
        ).scatter_reduce(
            0, dst.unsqueeze(-1).expand_as(e), e.detach(),
            reduce="amax", include_self=True,
        )
        exp = (e - e_max[dst]).exp()
        denom = torch.zeros(
            (num_nodes, num_heads), dtype=e.dtype, device=e.device
        ).index_add(0, dst, exp)
        return exp / denom[dst].clamp_min(1e-16)


class EdgeConvLayer(nn.Module):
    """
    Edge-conditioned convolution layer.

    Incorporates edge features (distance, coupling) into message passing.
    Useful for physics-informed learning. Messages flow along
    ``edge_index[0] -> edge_index[1]`` and are summed at the target node.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        edge_features: int
    ):
        """
        Args:
            in_features: Node feature dimension
            out_features: Output dimension
            edge_features: Edge feature dimension
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.edge_features = edge_features

        # MLP for combining node and edge features
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_features + edge_features, out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features)
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass with edge features.

        h'_i = sum_{j -> i} MLP([h_i || h_j || e_ij])

        Args:
            x: (N, in_features) node features
            edge_index: (2, E) edges (source j, target i)
            edge_attr: (E, edge_features) edge features

        Returns:
            out: (N, out_features) updated features; nodes without incoming
                edges get zeros
        """
        src, dst = edge_index[0], edge_index[1]
        messages = self.mlp(torch.cat([x[dst], x[src], edge_attr], dim=-1))  # (E, out)
        out = messages.new_zeros((x.size(0), self.out_features))
        return out.index_add(0, dst, messages)

## 5. Loss Functions

In [None]:
"""
Loss functions for unsupervised GNN clustering.

No labels are needed - clustering quality is measured by graph structure:
    - MinCut: Minimize edges between clusters (maximize within-cluster connectivity)
    - Orthogonality: Ensure balanced, non-overlapping clusters
    - Entropy: Encourage confident (non-uniform) assignments
"""


def mincut_loss(
    z: torch.Tensor,
    adj: torch.Tensor,
    deg: torch.Tensor,
    eps: float = 1e-8
) -> torch.Tensor:
    """
    Normalized MinCut loss.

    L_cut = -Tr(Z^T A Z) / Tr(Z^T D Z)

    Minimizing this loss maximizes within-cluster edges (good clustering
    keeps connected nodes together).

    Args:
        z: (N, K) soft cluster assignment matrix
        adj: (N, N) adjacency matrix
        deg: (N, N) degree matrix (diagonal)
        eps: Small constant for numerical stability

    Returns:
        loss: Scalar MinCut loss in [-1, 0] for non-negative A (negative,
            to be minimized)
    """
    # Tr(Z^T M Z) == sum(Z * (M @ Z)); this avoids building the (K, K) product.
    within = torch.sum(z * (adj @ z))    # within-cluster edge mass
    volume = torch.sum(z * (deg @ z))    # total degree mass (normaliser)
    return -within / (volume + eps)


def orthogonality_loss(
    z: torch.Tensor,
    eps: float = 1e-8
) -> torch.Tensor:
    """
    Orthogonality regularization loss.

    L_ortho = || Z^T Z / N  -  I_K / K ||_F^2

    Prevents trivial solution where all nodes go to one cluster.
    Encourages balanced cluster sizes: perfectly balanced one-hot assignments
    give Z^T Z = (N/K) I_K and therefore zero loss.

    Args:
        z: (N, K) soft cluster assignment matrix
        eps: Accepted for signature symmetry with the other losses; unused
            (an empty z is handled by dividing by max(N, 1))

    Returns:
        loss: Scalar orthogonality loss
    """
    num_nodes, num_clusters = z.shape
    gram = z.t() @ z / max(num_nodes, 1)
    target = torch.eye(num_clusters, dtype=z.dtype, device=z.device) / num_clusters
    return ((gram - target) ** 2).sum()


def entropy_loss(
    z: torch.Tensor,
    eps: float = 1e-8
) -> torch.Tensor:
    """
    Entropy regularization loss.

    L_entropy = -(1/N) * sum_i sum_k z_ik * log(z_ik)

    Low entropy means confident predictions (z_ik close to 0 or 1).
    Can be used to encourage sharper cluster assignments.

    Args:
        z: (N, K) soft cluster assignment matrix
        eps: Lower clamp applied inside the log to avoid log(0)

    Returns:
        loss: Scalar mean Shannon entropy, always >= 0: 0 for one-hot rows,
            log(K) for uniform rows. An empty z gives 0.
    """
    if z.shape[0] == 0:
        return z.new_zeros(())
    # Clamping (rather than adding eps) keeps 0 * log(0) exactly 0 and
    # 1 * log(1) exactly 0.
    return -(z * torch.log(z.clamp_min(eps))).sum(dim=-1).mean()


def cluster_size_loss(
    z: torch.Tensor,
    target_size: Optional[int] = None
) -> torch.Tensor:
    """
    Cluster size regularization.

    Penalizes deviation from uniform cluster sizes.
    Optional: can target specific cluster sizes.

    The soft size of cluster k is s_k = sum_i z_ik, and the loss is
    mean_k (s_k - target)^2.

    Args:
        z: (N, K) soft cluster assignment matrix
        target_size: Target elements per cluster (default: N/K)

    Returns:
        loss: Scalar size variance loss
    """
    num_nodes, num_clusters = z.shape
    target = num_nodes / num_clusters if target_size is None else float(target_size)
    sizes = z.sum(dim=0)
    return ((sizes - target) ** 2).mean()


def total_loss(
    z: torch.Tensor,
    adj: torch.Tensor,
    deg: torch.Tensor,
    lambda_ortho: float = 1.0,
    lambda_entropy: float = 0.0,
    eps: float = 1e-8
) -> torch.Tensor:
    """
    Weighted sum of the unsupervised clustering objectives.

    L_total = L_cut + lambda_ortho * L_ortho + lambda_entropy * L_entropy

    The entropy term is evaluated only when ``lambda_entropy`` is strictly
    positive; otherwise it is skipped entirely.

    Args:
        z: (N, K) soft cluster assignments
        adj: (N, N) adjacency matrix
        deg: (N, N) degree matrix
        lambda_ortho: Weight of the orthogonality penalty
        lambda_entropy: Weight of the entropy penalty (optional)
        eps: Numerical stability constant forwarded to each term

    Returns:
        loss: Total scalar loss
    """
    cut_term = mincut_loss(z, adj, deg, eps)
    ortho_term = orthogonality_loss(z, eps)
    combined = cut_term + lambda_ortho * ortho_term

    if not lambda_entropy > 0:
        return combined
    return combined + lambda_entropy * entropy_loss(z, eps)


class ClusteringLoss(nn.Module):
    """
    ``nn.Module`` front-end for ``total_loss``.

    Stores the term weights once so training loops can call
    ``criterion(z, adj, deg)`` directly.
    """

    def __init__(
        self,
        lambda_ortho: float = 1.0,
        lambda_entropy: float = 0.0
    ):
        """
        Args:
            lambda_ortho: Weight applied to the orthogonality term
            lambda_entropy: Weight applied to the entropy term
        """
        super().__init__()
        self.lambda_ortho = lambda_ortho
        self.lambda_entropy = lambda_entropy

    def forward(
        self,
        z: torch.Tensor,
        adj: torch.Tensor,
        deg: torch.Tensor
    ) -> torch.Tensor:
        """
        Evaluate the weighted clustering objective.

        Args:
            z: Soft cluster assignments
            adj: Adjacency matrix
            deg: Degree matrix

        Returns:
            Total loss scalar
        """
        weights = (self.lambda_ortho, self.lambda_entropy)
        return total_loss(z, adj, deg, *weights)

## 6. Model

In [None]:
"""
Main GNN model for antenna array clustering.

Architecture:
    Input(N, 2) -> GAT/GCN layers -> Embeddings(N, d) -> Linear+Softmax -> Z(N, K)

The model outputs soft cluster assignments Z where Z_ik = P(node i in cluster k).
"""


class AntennaClusteringGNN(nn.Module):
    """
    GNN model for unsupervised clustering of irregular antenna arrays.

    Pipeline:
        1. Stack of GAT/GCN layers to learn node embeddings
           (between layers: ELU -> dropout -> LayerNorm)
        2. Final linear layer to map embeddings to K cluster logits
        3. Softmax to get soft cluster probabilities
    """

    def __init__(self, config: Optional[GNNConfig] = None):
        """
        Args:
            config: GNN architecture configuration

        Raises:
            ValueError: If ``config.layer_type`` is neither 'gat' nor 'gcn'.
        """
        super().__init__()
        self.config = config or GNNConfig()

        self._build_layers()

    def _build_layers(self):
        """Construct GNN layers based on configuration."""
        cfg = self.config

        # The config may have been mutated after construction, so check here
        # rather than silently treating an unknown type as GCN.
        if cfg.layer_type not in ("gat", "gcn"):
            raise ValueError(
                f"layer_type must be 'gat' or 'gcn', got {cfg.layer_type!r}"
            )

        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()

        # Input dimension for first layer
        in_dim = cfg.in_dim

        # Build GNN layers
        for i in range(cfg.num_layers):
            # Last layer: single head, no concat
            is_last = (i == cfg.num_layers - 1)

            if cfg.layer_type == "gat":
                heads = 1 if is_last else cfg.heads
                out_dim = cfg.hidden_dim
                layer = GATLayer(
                    in_features=in_dim,
                    out_features=out_dim,
                    heads=heads,
                    concat=not is_last,
                    dropout=cfg.dropout
                )
                # Concatenated heads widen the next layer's input
                in_dim = out_dim if is_last else out_dim * heads
            else:  # gcn
                out_dim = cfg.hidden_dim
                layer = GCNLayer(
                    in_features=in_dim,
                    out_features=out_dim
                )
                in_dim = out_dim

            self.layers.append(layer)

            # LayerNorm after every layer except the last (in_dim is now
            # this layer's output width)
            if not is_last:
                self.norms.append(nn.LayerNorm(in_dim))

        # Final embedding dimension
        self.embed_dim = in_dim

        # Output layer: embeddings -> cluster logits
        self.classifier = nn.Linear(self.embed_dim, cfg.num_clusters)

        # Dropout
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        adj_norm: Optional[torch.Tensor] = None,
        edge_attr: Optional[torch.Tensor] = None,
        return_embeddings: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass.

        Args:
            x: (N, in_dim) input node features (positions)
            edge_index: (2, E) edges in COO format
            adj_norm: (N, N) normalized adjacency (required for GCN layers)
            edge_attr: (E, edge_dim) edge features (optional)
            return_embeddings: Also return intermediate embeddings

        Returns:
            z: (N, K) soft cluster assignment probabilities
            h: (N, embed_dim) node embeddings (if return_embeddings=True),
                otherwise None

        Raises:
            ValueError: If the model has GCN layers and ``adj_norm`` is None.
        """
        h = x
        last_index = len(self.layers) - 1

        # Pass through GNN layers
        for i, layer in enumerate(self.layers):
            if isinstance(layer, GATLayer):
                h = layer(h, edge_index, edge_attr)
            else:  # GCN
                if adj_norm is None:
                    raise ValueError(
                        "adj_norm is required when the model uses GCN layers"
                    )
                h = layer(h, adj_norm)

            # Activation + dropout + norm (except last layer)
            if i < last_index:
                h = F.elu(h)
                h = self.dropout(h)
                if i < len(self.norms):
                    h = self.norms[i](h)

        # Final activation
        h = F.elu(h)

        # Cluster probabilities via softmax
        logits = self.classifier(h)
        z = F.softmax(logits, dim=-1)

        if return_embeddings:
            return z, h
        return z, None

    def get_hard_assignments(self, z: torch.Tensor) -> torch.Tensor:
        """
        Convert soft probabilities to hard cluster assignments.

        c_i = argmax_k z_ik

        Args:
            z: (N, K) soft assignment probabilities

        Returns:
            c: (N,) hard cluster labels in {0, ..., K-1}
        """
        return z.argmax(dim=-1)

    @property
    def num_parameters(self) -> int:
        """Total number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


class AntennaClusteringGNNWithEdgeFeatures(AntennaClusteringGNN):
    """
    Extended model that incorporates physics-informed edge features.

    Edge features can include:
        - Euclidean distance between antennas
        - Mutual coupling magnitude
        - Mutual coupling phase

    Raw edge features of width ``config.edge_dim`` are encoded to
    ``config.hidden_dim`` by a small MLP before being handed to the GNN layers.
    """

    def __init__(self, config: Optional[GNNConfig] = None):
        from dataclasses import replace  # local import: file imports only dataclass/field

        # Ensure edge features are enabled on a private copy, so the caller's
        # config object (possibly shared with other models) is not mutated.
        if config is None:
            config = GNNConfig(use_edge_features=True)
        else:
            config = replace(config, use_edge_features=True)

        super().__init__(config)

    def _build_layers(self):
        """Build the edge-feature encoder, then the standard GNN stack."""
        cfg = self.config

        # Edge feature encoder: (E, edge_dim) -> (E, hidden_dim)
        self.edge_encoder = nn.Sequential(
            nn.Linear(cfg.edge_dim, cfg.hidden_dim),
            nn.ReLU(),
            nn.Linear(cfg.hidden_dim, cfg.hidden_dim)
        )

        # Rest of the architecture
        super()._build_layers()

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        adj_norm: Optional[torch.Tensor] = None,
        edge_attr: Optional[torch.Tensor] = None,
        return_embeddings: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward pass with edge feature processing.

        Encodes ``edge_attr`` (if given) and delegates to the base model.
        """
        if edge_attr is not None:
            edge_attr = self.edge_encoder(edge_attr)

        # Standard forward pass
        return super().forward(
            x, edge_index, adj_norm, edge_attr, return_embeddings
        )

## 7. Training

In [None]:
"""
Training utilities for GNN-based antenna clustering.

Provides a Trainer class that handles:
    - Model initialization
    - Graph construction from positions
    - Training loop with loss computation
    - Cluster extraction
"""


@dataclass
class TrainingResult:
    """Outcome of a single clustering training run."""
    # Hard label per element, shape (N,).
    cluster_assignments: np.ndarray
    # Per-element cluster probabilities, shape (N, K).
    soft_assignments: np.ndarray
    # One loss value recorded per epoch.
    loss_history: List[float]
    # Loss value at the end of training.
    final_loss: float


class Trainer:
    """
    Trainer for GNN-based antenna array clustering.

    Normalizes antenna positions, builds a graph over them, trains an
    ``AntennaClusteringGNN`` with ``ClusteringLoss`` (no labels needed) and
    extracts hard and soft cluster assignments.

    Usage:
        trainer = Trainer(num_clusters=4)
        result = trainer.fit(positions)
        clusters = result.cluster_assignments
    """

    def __init__(
        self,
        num_clusters: int = 4,
        gnn_config: Optional[GNNConfig] = None,
        graph_config: Optional[GraphConfig] = None,
        training_config: Optional[TrainingConfig] = None
    ):
        """
        Args:
            num_clusters: Number of clusters K. Always takes precedence over
                ``gnn_config.num_clusters``.
            gnn_config: GNN architecture config. A copy is stored; the
                caller's object is never modified.
            graph_config: Graph construction config
            training_config: Training hyperparameters
        """
        from dataclasses import replace

        # Copy rather than mutate: a GNNConfig shared between several
        # trainers (or models) must not have its K silently rewritten.
        if gnn_config is None:
            self.gnn_config = GNNConfig(num_clusters=num_clusters)
        else:
            self.gnn_config = replace(gnn_config, num_clusters=num_clusters)

        self.graph_config = graph_config or GraphConfig()
        self.training_config = training_config or TrainingConfig()

        self.model: Optional[AntennaClusteringGNN] = None
        self.graph_builder = GraphBuilder(self.graph_config)

        # Graph tensors for the positions passed to the most recent fit().
        self._edge_index: Optional[torch.Tensor] = None
        self._adj: Optional[torch.Tensor] = None
        self._deg: Optional[torch.Tensor] = None
        self._adj_norm: Optional[torch.Tensor] = None

    def fit(
        self,
        positions: Union[np.ndarray, torch.Tensor],
        coupling_matrix: Optional[Union[np.ndarray, torch.Tensor]] = None
    ) -> TrainingResult:
        """
        Train the GNN model on antenna positions.

        Args:
            positions: (N, 2) array of antenna (x, y) positions
            coupling_matrix: (N, N) optional mutual coupling matrix, handed to
                graph construction (not yet used by the placeholder builder)

        Returns:
            TrainingResult with cluster assignments and training info

        Raises:
            ValueError: if ``positions`` is not a 2-D array.
        """
        device = get_device(self.training_config.device)

        positions = self._prepare_positions(positions, device)
        self._build_graph(positions, coupling_matrix, device)

        # A fresh model per fit() call, so refitting never reuses old weights.
        self.model = AntennaClusteringGNN(self.gnn_config).to(device)

        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.training_config.lr,
            weight_decay=self.training_config.weight_decay
        )
        criterion = ClusteringLoss(
            lambda_ortho=self.training_config.lambda_ortho,
            lambda_entropy=self.training_config.lambda_entropy
        )

        loss_history = self._train_loop(
            positions, optimizer, criterion, device
        )

        # Final inference pass on the training graph in eval mode.
        self.model.eval()
        with torch.no_grad():
            z, _ = self.model(positions, self._edge_index, self._adj_norm)
            clusters = z.argmax(dim=-1).cpu().numpy()
            soft_assignments = z.cpu().numpy()

        return TrainingResult(
            cluster_assignments=clusters,
            soft_assignments=soft_assignments,
            loss_history=loss_history,
            # inf marks a run in which no epochs were executed
            final_loss=loss_history[-1] if loss_history else float('inf')
        )

    def _prepare_positions(
        self,
        positions: Union[np.ndarray, torch.Tensor],
        device: torch.device
    ) -> torch.Tensor:
        """
        Convert positions to a float32 tensor, normalize, and move to device.

        Raises:
            ValueError: if ``positions`` is not 2-D.
        """
        # as_tensor accepts ndarrays and tensors alike and coerces the dtype,
        # so float64 or integer inputs (torch tensors included) match the
        # model's float32 weights instead of failing inside the first layer.
        positions = torch.as_tensor(positions, dtype=torch.float32)
        if positions.dim() != 2:
            raise ValueError(
                "positions must be a 2-D (N, D) array, got shape "
                f"{tuple(positions.shape)}"
            )

        positions = normalize_positions(positions)
        return positions.to(device)

    def _build_graph(
        self,
        positions: torch.Tensor,
        coupling_matrix: Optional[Union[np.ndarray, torch.Tensor]],
        device: torch.device
    ):
        """Construct the graph for ``positions`` and cache it on the trainer."""
        (self._edge_index, self._adj, self._deg,
         self._adj_norm) = self._make_graph(positions, coupling_matrix, device)

    def _make_graph(
        self,
        positions: torch.Tensor,
        coupling_matrix: Optional[Union[np.ndarray, torch.Tensor]],
        device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Return ``(edge_index, adj, deg, adj_norm)`` for ``positions``.

        NOTE(review): still a placeholder. It returns an edge-less graph
        (empty ``edge_index``, all-zero matrices) and ignores
        ``coupling_matrix``; training on it is not meaningful yet.
        """
        # TODO: Implement using GraphBuilder (self.graph_builder)
        n = positions.shape[0]

        # Placeholder edge_index (will be built by GraphBuilder)
        edge_index = torch.zeros((2, 0), dtype=torch.long, device=device)

        # Placeholder adjacency, degree and normalized adjacency matrices
        adj = torch.zeros((n, n), device=device)
        deg = torch.zeros((n, n), device=device)
        adj_norm = torch.zeros((n, n), device=device)
        return edge_index, adj, deg, adj_norm

    def _train_loop(
        self,
        positions: torch.Tensor,
        optimizer: torch.optim.Optimizer,
        criterion: ClusteringLoss,
        device: torch.device
    ) -> List[float]:
        """
        Run full-graph training for ``training_config.epochs`` epochs.

        ``device`` is unused here; ``fit`` has already placed the positions
        and cached graph tensors on it.

        Returns:
            The loss computed in each epoch (before that epoch's update).
        """
        self.model.train()
        loss_history: List[float] = []
        epochs = self.training_config.epochs
        log_every = self.training_config.verbose

        for epoch in range(epochs):
            optimizer.zero_grad()

            z, _ = self.model(positions, self._edge_index, self._adj_norm)
            loss = criterion(z, self._adj, self._deg)

            loss.backward()
            optimizer.step()

            loss_val = loss.item()
            loss_history.append(loss_val)

            # verbose <= 0 disables logging; otherwise log every `verbose` epochs
            if log_every > 0 and (epoch + 1) % log_every == 0:
                print(f"Epoch {epoch + 1}/{epochs}: "
                      f"Loss = {loss_val:.4f}")

        return loss_history

    def predict(
        self,
        positions: Union[np.ndarray, torch.Tensor],
        coupling_matrix: Optional[Union[np.ndarray, torch.Tensor]] = None
    ) -> np.ndarray:
        """
        Get cluster assignments for new positions (after training).

        Args:
            positions: (N, 2) antenna positions
            coupling_matrix: (N, N) optional mutual coupling matrix used when
                building the graph for ``positions``

        Returns:
            clusters: (N,) cluster labels

        Raises:
            RuntimeError: if ``fit`` has not been called yet.
            ValueError: if ``positions`` is not a 2-D array.
        """
        if self.model is None:
            raise RuntimeError("Model not trained. Call fit() first.")

        device = next(self.model.parameters()).device
        positions = self._prepare_positions(positions, device)

        # Build the graph of *these* positions: the cached training graph is
        # sized and wired for the fit() inputs, not for an arbitrary array.
        edge_index, _, _, adj_norm = self._make_graph(
            positions, coupling_matrix, device
        )

        self.model.eval()
        with torch.no_grad():
            z, _ = self.model(positions, edge_index, adj_norm)
            return z.argmax(dim=-1).cpu().numpy()


def train_clustering(
    positions: Union[np.ndarray, torch.Tensor],
    num_clusters: int = 4,
    k_neighbors: int = 8,
    epochs: int = 500,
    lr: float = 1e-3,
    verbose: int = 50
) -> np.ndarray:
    """
    One-call helper: configure a Trainer, fit it, and return hard labels.

    Args:
        positions: (N, 2) antenna positions
        num_clusters: Number of clusters K
        k_neighbors: Neighbour count placed in the GraphConfig
        epochs: Number of training epochs
        lr: Adam learning rate
        verbose: Print the loss every this many epochs (0 disables)

    Returns:
        (N,) array with one cluster label per antenna
    """
    graph_cfg = GraphConfig(k_neighbors=k_neighbors)
    train_cfg = TrainingConfig(epochs=epochs, lr=lr, verbose=verbose)

    fitted = Trainer(
        num_clusters=num_clusters,
        graph_config=graph_cfg,
        training_config=train_cfg,
    ).fit(positions)
    return fitted.cluster_assignments

## 8. Example Usage

In [None]:
# Example: Create dummy antenna positions and run clustering

# Generate a 4x4 grid of antenna positions
# positions = np.array([[i, j] for i in range(4) for j in range(4)], dtype=np.float32)
# print(f"Antenna positions shape: {positions.shape}")

# Run clustering (will fail until TODOs are implemented)
# clusters = train_clustering(positions, num_clusters=4, epochs=100, verbose=20)