# GNN-based Antenna Array Clustering

This notebook demonstrates Graph Neural Networks for clustering irregular antenna arrays
using unsupervised learning with MinCut optimization.

**Architecture overview:**
1. Graph Construction: Convert antenna positions to a k-NN graph
2. GNN Layers: GAT/GCN for learning node embeddings
3. Clustering Head: Soft assignment via softmax
4. Loss: MinCut + Orthogonality (no labels needed)

## 1. Imports and Setup

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, '..')

# Import GNN modules
from gnn import (
    GNNConfig,
    GraphConfig,
    TrainingConfig,
    Trainer,
    train_clustering,
    cluster_sizes,
    compute_clustering_metrics,
    set_seed
)

# Fix the random seed through the project's set_seed helper for reproducibility
set_seed(42)

# Report the runtime environment: library version and GPU visibility
env_report = (
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
)
for label, value in env_report:
    print(f"{label}: {value}")

## 2. Generate Sample Antenna Array

Create an irregular 16x16 antenna array by adding Gaussian noise (standard deviation 0.3 grid units) to the element positions of a regular, unit-spaced grid.

In [None]:
def generate_irregular_array(grid_size=16, perturbation=0.3, seed=42):
    """
    Generate an irregular antenna array.

    A regular ``grid_size x grid_size`` grid with unit spacing is built first,
    then every coordinate is displaced by independent Gaussian noise.

    The noise is drawn from a private ``np.random.RandomState(seed)``, so the
    call is reproducible without reseeding NumPy's global random generator
    (reseeding it here would silently override any global seed set earlier).
    For a given seed the samples are identical to those produced by
    ``np.random.seed(seed)`` followed by ``np.random.randn``.

    Args:
        grid_size: Number of elements per dimension
        perturbation: Standard deviation of position noise
        seed: Random seed

    Returns:
        positions: (N, 2) float32 array of antenna positions, N = grid_size**2
    """
    # Local generator: same stream as the legacy global seeding, no side effect
    rng = np.random.RandomState(seed)

    # Create base grid (both axes share the same coordinates)
    axis_coords = np.linspace(0, grid_size - 1, grid_size)
    xx, yy = np.meshgrid(axis_coords, axis_coords)

    # Flatten to (N, 2)
    positions = np.stack([xx.flatten(), yy.flatten()], axis=1).astype(np.float32)

    # Add random perturbations
    noise = rng.randn(positions.shape[0], 2) * perturbation
    positions += noise.astype(np.float32)

    return positions


# Build the 16x16 perturbed array, then report its size and bounding box
positions = generate_irregular_array(grid_size=16, perturbation=0.3)

x_coords, y_coords = positions[:, 0], positions[:, 1]
print(f"Generated {positions.shape[0]} antenna elements")
print(
    f"Position range: x=[{x_coords.min():.2f}, {x_coords.max():.2f}], "
    f"y=[{y_coords.min():.2f}, {y_coords.max():.2f}]"
)

In [None]:
# Scatter plot of the raw element positions (object-oriented Matplotlib API)
fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(positions[:, 0], positions[:, 1], c='blue', s=20, alpha=0.7)
ax.set_xlabel('X position')
ax.set_ylabel('Y position')
ax.set_title('Irregular 16x16 Antenna Array')
ax.axis('equal')  # same scale on both axes so the layout is not distorted
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 3. Configure and Train GNN Clustering

Set up the GNN configuration and train the model.

In [None]:
# Number of clusters (K) shared by the model config and the trainer
num_clusters = 4

# Model architecture
gnn_config = GNNConfig(
    in_dim=2,                   # node features: x, y positions
    hidden_dim=64,              # hidden layer dimension
    num_clusters=num_clusters,
    num_layers=3,               # 3 GAT layers
    heads=4,                    # 4 attention heads
    dropout=0.1,
    layer_type="gat",           # use GAT (recommended)
)

# Graph construction
graph_config = GraphConfig(
    k_neighbors=8,              # k-NN with k=8
    connection_type="knn",
    add_self_loops=True,
)

# Optimisation settings and loss weights
training_config = TrainingConfig(
    epochs=500,
    lr=1e-3,
    weight_decay=5e-4,
    lambda_ortho=1.0,           # orthogonality loss weight
    lambda_entropy=0.0,         # entropy regularization (optional)
    verbose=50,                 # print every 50 epochs
)

# Echo the key settings as a single report
summary_lines = [
    "GNN Configuration:",
    f"  - Layers: {gnn_config.num_layers} x {gnn_config.layer_type.upper()}",
    f"  - Hidden dim: {gnn_config.hidden_dim}",
    f"  - Attention heads: {gnn_config.heads}",
    f"  - Output clusters: {gnn_config.num_clusters}",
    "\nGraph Configuration:",
    f"  - k-NN neighbors: {graph_config.k_neighbors}",
    "\nTraining Configuration:",
    f"  - Epochs: {training_config.epochs}",
    f"  - Learning rate: {training_config.lr}",
]
print("\n".join(summary_lines))

In [None]:
# Assemble the trainer from the three configuration objects
trainer = Trainer(
    num_clusters=num_clusters,
    gnn_config=gnn_config,
    graph_config=graph_config,
    training_config=training_config,
)

# Fit directly on the raw antenna positions
print("Training GNN...\n")
result = trainer.fit(positions)

print(f"\nTraining complete!\nFinal loss: {result.final_loss:.4f}")

## 4. Analyze Results

In [None]:
# Cluster assignments from the training result and the per-cluster sizes
clusters = result.cluster_assignments
sizes = cluster_sizes(clusters)

total_elements = len(clusters)
print("Cluster Distribution:")
for k, size in enumerate(sizes):
    share = 100 * size / total_elements  # percentage of all elements
    print(f"  Cluster {k}: {size} elements ({share:.1f}%)")

In [None]:
# Side-by-side diagnostics: loss history (left) and cluster sizes (right)
fig, (ax_loss, ax_sizes) = plt.subplots(1, 2, figsize=(10, 4))

ax_loss.plot(result.loss_history)
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Training Loss')
ax_loss.grid(True, alpha=0.3)

cluster_ids = range(len(sizes))
bar_colors = ['C0', 'C1', 'C2', 'C3'][:len(sizes)]
ax_sizes.bar(cluster_ids, sizes, color=bar_colors)
ax_sizes.set_xlabel('Cluster')
ax_sizes.set_ylabel('Number of Elements')
ax_sizes.set_title('Cluster Sizes')
ax_sizes.set_xticks(cluster_ids)

fig.tight_layout()
plt.show()

In [None]:
# One scatter series per cluster, coloured from a fixed palette
colors = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']

fig, ax = plt.subplots(figsize=(10, 10))

for k in range(num_clusters):
    in_cluster = clusters == k
    ax.scatter(
        positions[in_cluster, 0],
        positions[in_cluster, 1],
        c=colors[k % len(colors)],  # wrap around if K exceeds the palette size
        s=50,
        label=f'Cluster {k} ({sizes[k]} elements)',
        alpha=0.7,
    )

ax.set_xlabel('X position')
ax.set_ylabel('Y position')
ax.set_title(f'GNN Clustering Result (K={num_clusters})')
ax.legend(loc='upper right')
ax.axis('equal')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 5. Quick Clustering with Convenience Function

For simple use cases, use `train_clustering()` directly.

In [None]:
# Quick clustering example using the train_clustering() convenience function
quick_num_clusters = 4
clusters_quick = train_clustering(
    positions,
    num_clusters=quick_num_clusters,
    k_neighbors=8,
    epochs=300,
    lr=1e-3,
    verbose=100
)

# minlength keeps one count per requested cluster, so a cluster that ended up
# empty is reported as 0 instead of being dropped from the end of the list.
quick_counts = np.bincount(clusters_quick, minlength=quick_num_clusters)
print(f"\nCluster distribution: {quick_counts}")

## 6. Compute Clustering Metrics

In [None]:
# Quality metrics for the cluster assignments from the Trainer run
metrics = compute_clustering_metrics(clusters, positions)

# (label, key in the metrics dict, format spec) for each reported value
report_fields = (
    ("Number of clusters", "num_clusters", ""),
    ("Cluster sizes", "cluster_sizes", ""),
    ("Size variance", "size_variance", ".2f"),
    ("Mean intra-cluster distance", "mean_intra_distance", ".2f"),
)

print("Clustering Metrics:")
for label, key, spec in report_fields:
    print(f"  - {label}: {metrics[key]:{spec}}")

## 7. Soft Assignments Visualization

The GNN outputs soft cluster probabilities. Let's visualize the confidence of assignments.

In [None]:
# Soft cluster assignments from the training result
soft = result.soft_assignments

# Confidence of an element = its largest cluster probability
confidence = soft.max(axis=1)

fig, ax = plt.subplots(figsize=(10, 10))
scatter = ax.scatter(
    positions[:, 0],
    positions[:, 1],
    c=confidence,
    cmap='RdYlGn',
    s=50,
    vmin=0.5,
    vmax=1.0,
)
fig.colorbar(scatter, ax=ax, label='Assignment Confidence')
ax.set_xlabel('X position')
ax.set_ylabel('Y position')
ax.set_title('Cluster Assignment Confidence')
ax.axis('equal')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# Summary statistics of the confidence values
n_confident = (confidence > 0.9).sum()
print(f"Mean confidence: {confidence.mean():.3f}")
print(f"Min confidence: {confidence.min():.3f}")
print(f"Elements with >90% confidence: {n_confident} / {len(confidence)}")

## 8. Radiation Pattern Plots

Compute the far-field pattern produced by the GNN clusters and plot it with the
same lobe-analysis layout as `clustering_comparison.ipynb`.


In [None]:
import sys
from matplotlib.gridspec import GridSpec
from scipy.signal import find_peaks

# Add antenna physics utilities to path
sys.path.insert(0, '../optimization/pyvers')

from antenna_physics import (
    LatticeConfig,
    SystemConfig,
    MaskConfig,
    ElementPatternConfig,
    AntennaArray,
)
from gnn import assignments_to_antenna_format


def extract_lobe_metrics(FF_I_dB, azi, ele, azi0, ele0, G_boresight=None):
    """
    Extract lobe performance metrics from far-field pattern.

    Args:
        FF_I_dB: 2D far-field pattern in dB, indexed as [elevation, azimuth].
        azi: 1D array of azimuth angles (one per column of FF_I_dB).
        ele: 1D array of elevation angles (one per row of FF_I_dB).
        azi0: Azimuth angle at which the elevation cut is taken.
        ele0: Elevation angle at which the azimuth cut is taken.
        G_boresight: Optional main-lobe gain to report. When None, the
            maximum of FF_I_dB is reported instead. A value of 0 is a valid
            gain and is reported as given.

    Returns:
        dict with keys 'main_lobe_gain', 'hpbw_ele', 'hpbw_azi',
        'sll_ele_relative', 'sll_azi_relative', 'n_lobes_ele', 'n_lobes_azi',
        'ele_cut' and 'azi_cut'.
    """
    # Grid samples closest to the requested cut angles
    ele_idx = np.argmin(np.abs(ele - ele0))
    azi_idx = np.argmin(np.abs(azi - azi0))

    # Principal-plane cuts
    ele_cut = FF_I_dB[:, azi_idx]  # Elevation cut at azimuth = azi0
    azi_cut = FF_I_dB[ele_idx, :]  # Azimuth cut at elevation = ele0

    # Main lobe gain. Compare against None explicitly: a truthiness test would
    # discard a legitimate boresight gain of exactly 0.
    main_lobe_gain = G_boresight if G_boresight is not None else np.max(FF_I_dB)

    def find_hpbw(cut, angles):
        """Half-power beam width: angular span between the first samples
        below (peak - 3 dB) on either side of the cut's maximum."""
        max_idx = np.argmax(cut)
        threshold = cut[max_idx] - 3

        # Walk outwards from the peak. If the cut never drops below the
        # threshold on a side, that side falls back to the peak index itself.
        left_idx = next(
            (i for i in range(max_idx, -1, -1) if cut[i] < threshold), max_idx
        )
        right_idx = next(
            (i for i in range(max_idx, len(cut)) if cut[i] < threshold), max_idx
        )
        return angles[right_idx] - angles[left_idx]

    hpbw_ele = find_hpbw(ele_cut, ele)
    hpbw_azi = find_hpbw(azi_cut, azi)

    def find_sll_relative(cut):
        """Highest local peak that lies more than 3 dB below the cut maximum.

        The value is returned as read from the cut, so it is relative to the
        main lobe only when the pattern is normalized to 0 dB.
        NOTE(review): confirm that callers always pass a normalized pattern.
        """
        threshold = np.max(cut) - 3
        peaks, _ = find_peaks(cut)

        # Peaks are classified by level, not position: anything within 3 dB
        # of the maximum is treated as belonging to the main lobe.
        side_levels = [cut[p] for p in peaks if cut[p] < threshold]
        if side_levels:
            return max(side_levels)
        return -30  # Default if no side lobes found

    sll_ele_relative = find_sll_relative(ele_cut)
    sll_azi_relative = find_sll_relative(azi_cut)

    # Count lobes: local peaks at or above -30 dB in each cut
    peaks_ele, _ = find_peaks(ele_cut, height=-30)
    peaks_azi, _ = find_peaks(azi_cut, height=-30)

    return {
        'main_lobe_gain': main_lobe_gain,
        'hpbw_ele': hpbw_ele,
        'hpbw_azi': hpbw_azi,
        'sll_ele_relative': sll_ele_relative,
        'sll_azi_relative': sll_azi_relative,
        'n_lobes_ele': len(peaks_ele),
        'n_lobes_azi': len(peaks_azi),
        'ele_cut': ele_cut,
        'azi_cut': azi_cut,
    }


def plot_lobe_analysis(FF_I_dB, antenna_array, G_boresight=None,
                       title="Lobe Analysis", save_path=None):
    """
    Plot lobe analysis: elevation/azimuth cuts, 2D pattern, metrics table, polar plots.

    Args:
        FF_I_dB: 2D far-field pattern in dB, passed to extract_lobe_metrics
            and drawn as a filled contour.
        antenna_array: Object providing the steering angles
            (``system.ele0``, ``system.azi0``), the angle axes (``ele``,
            ``azi``) and the 2D angle grids (``ELE``, ``AZI``).
        G_boresight: Optional main-lobe gain forwarded to extract_lobe_metrics.
        title: Figure super-title.
        save_path: If given, the figure is also written to this path.

    Returns:
        The metrics dict computed by extract_lobe_metrics.
    """
    ele0 = antenna_array.system.ele0
    azi0 = antenna_array.system.azi0
    ele = antenna_array.ele
    azi = antenna_array.azi

    metrics = extract_lobe_metrics(FF_I_dB, azi, ele, azi0, ele0, G_boresight)

    fig = plt.figure(figsize=(16, 10))
    gs = GridSpec(2, 3, figure=fig, hspace=0.3, wspace=0.3)

    # 1. Elevation Cut with Lobes
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.plot(ele, metrics['ele_cut'], 'b-', linewidth=2, label='Elevation Cut')
    ax1.axhline(y=-3, color='r', linestyle='--', alpha=0.7, label='-3dB (HPBW)')
    ax1.axhline(y=metrics['sll_ele_relative'], color='g', linestyle=':', alpha=0.7,
                label=f"SLL: {metrics['sll_ele_relative']:.1f}dB")
    ax1.set_xlabel('Elevation [deg]')
    ax1.set_ylabel('Normalized Gain [dB]')
    # Two-line title; the break must be the \n escape, since a raw line break
    # inside a single-quoted string literal is a syntax error.
    ax1.set_title(f"Elevation Cut (azi={azi0} deg)\nHPBW={metrics['hpbw_ele']:.1f} deg")
    ax1.legend(loc='upper right', fontsize=8)
    ax1.grid(True, alpha=0.3)
    ax1.set_ylim([-25, 25])

    # 2. Azimuth Cut with Lobes
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.plot(azi, metrics['azi_cut'], 'b-', linewidth=2, label='Azimuth Cut')
    ax2.axhline(y=-3, color='r', linestyle='--', alpha=0.7, label='-3dB (HPBW)')
    ax2.axhline(y=metrics['sll_azi_relative'], color='g', linestyle=':', alpha=0.7,
                label=f"SLL: {metrics['sll_azi_relative']:.1f}dB")
    ax2.set_xlabel('Azimuth [deg]')
    ax2.set_ylabel('Normalized Gain [dB]')
    ax2.set_title(f"Azimuth Cut (ele={ele0} deg)\nHPBW={metrics['hpbw_azi']:.1f} deg")
    ax2.legend(loc='upper right', fontsize=8)
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim([-25, 25])

    # 3. 2D Pattern (contour)
    ax3 = fig.add_subplot(gs[0, 2])
    levels = np.arange(-40, 5, 3)
    contour = ax3.contourf(antenna_array.AZI, antenna_array.ELE, FF_I_dB,
                           levels=levels, cmap='jet', extend='both')
    plt.colorbar(contour, ax=ax3, label='dB')
    # Star marker at the steering direction (azi0, ele0)
    ax3.plot(azi0, ele0, 'w*', markersize=15, markeredgecolor='k')
    ax3.set_xlabel('Azimuth [deg]')
    ax3.set_ylabel('Elevation [deg]')
    ax3.set_title('2D Far-Field Pattern')

    # 4. Metrics Summary Table
    ax4 = fig.add_subplot(gs[1, 0])
    ax4.axis('off')

    table_data = [
        ['Main Lobe Gain', f"{metrics['main_lobe_gain']:.2f} dBi"],
        ['HPBW Elevation', f"{metrics['hpbw_ele']:.1f} deg"],
        ['HPBW Azimuth', f"{metrics['hpbw_azi']:.1f} deg"],
        ['SLL Elevation', f"{metrics['sll_ele_relative']:.1f} dB"],
        ['SLL Azimuth', f"{metrics['sll_azi_relative']:.1f} dB"],
        ['Lobes (Ele)', f"{metrics['n_lobes_ele']}"],
        ['Lobes (Azi)', f"{metrics['n_lobes_azi']}"],
    ]

    table = ax4.table(cellText=table_data, colLabels=['Metric', 'Value'],
                      loc='center', cellLoc='center',
                      colWidths=[0.5, 0.3])
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.2, 1.5)
    ax4.set_title('Performance Metrics', fontsize=12, fontweight='bold')

    # 5. Polar plot elevation
    ax5 = fig.add_subplot(gs[1, 1], projection='polar')
    theta_rad = np.deg2rad(ele)
    r = metrics['ele_cut'] + 40  # Shift to positive
    ax5.plot(theta_rad, r, 'b-', linewidth=2)
    ax5.set_theta_zero_location('N')
    ax5.set_title('Elevation Pattern (Polar)', y=1.1)

    # 6. Polar plot azimuth
    ax6 = fig.add_subplot(gs[1, 2], projection='polar')
    theta_rad = np.deg2rad(azi)
    r = metrics['azi_cut'] + 40  # Shift to positive
    ax6.plot(theta_rad, r, 'b-', linewidth=2)
    ax6.set_theta_zero_location('N')
    ax6.set_title('Azimuth Pattern (Polar)', y=1.1)

    fig.suptitle(title, fontsize=14, fontweight='bold', y=1.02)

    plt.tight_layout()

    if save_path:
        # bbox_inches='tight' keeps the super-title (placed at y=1.02) in the file
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")

    plt.show()
    return metrics


In [None]:
# Re-express the GNN cluster assignments in the antenna (col, row) format
clusters_antenna = assignments_to_antenna_format(
    clusters,
    grid_shape=(16, 16),
)

# Antenna model settings, kept identical to clustering_comparison.ipynb
lattice = LatticeConfig(
    Nz=16,
    Ny=16,
    dist_z=0.6,
    dist_y=0.53,
    lattice_type=1,
)
system = SystemConfig(
    freq=29.5e9,
    azi0=0,
    ele0=0,
    dele=0.5,
    dazi=0.5,
)
mask = MaskConfig(
    elem=30,
    azim=60,
    SLL_level=20,
    SLLin=15,
)
eef = ElementPatternConfig(
    P=1,
    Gel=5,
    load_file=0,
)

# Evaluate the far field for this clustering
array = AntennaArray(lattice, system, mask, eef)
result_ff = array.evaluate_clustering(clusters_antenna)

# Left as the cell's final bare expression, as in the original cell
plot_lobe_analysis(
    result_ff['FF_I_dB'],
    array,
    G_boresight=result_ff['G_boresight'],
    title=f"GNN Radiation Pattern (K={num_clusters})",
)
