## Setup

Run cells 1-17 of the main notebook first: they create the `clients` list and the data that the experiments below depend on.

---
## Experiment 1: GNN-Enhanced Federated Learning

In [None]:
# CELL 6.6: GNN-ENHANCED FEDERATED LEARNING — DEPENDENCY CHECK + CLIENT GRAPH NETWORK
# ============================================================================
# NOTE(review): the printed banner mentions visualization, but this cell only
# probes optional dependencies and defines ClientGraphNetwork; the plotting /
# animation imports below are not used within this cell.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.patches import FancyBboxPatch, Circle, FancyArrowPatch
import seaborn as sns
import networkx as nx
from IPython.display import HTML, display
import pandas as pd
from datetime import datetime
import json
import os

print("=" * 90)
print("CELL 6.6: GNN-ENHANCED FEDERATED LEARNING + VISUALIZATION")
print("=" * 90)

# torch-geometric is optional: ClientGraphNetwork is only defined when this
# import succeeds (see the TORCH_GEOMETRIC_AVAILABLE guard below).
try:
    from torch_geometric.nn import GATConv, global_mean_pool
    from torch_geometric.data import Data, Batch
    TORCH_GEOMETRIC_AVAILABLE = True
    print("✓ torch-geometric available")
except ImportError:
    TORCH_GEOMETRIC_AVAILABLE = False
    # PyG extension wheels are published per torch build (version + CUDA/CPU
    # tag), so point the hint at the index matching the installed torch
    # instead of a hard-coded torch 2.0.0 CPU index.
    _pyg_wheel_tag = torch.__version__ if "+" in torch.__version__ else f"{torch.__version__}+cpu"
    print("⚠ torch-geometric not found. Install with:")
    print("  pip install torch-geometric")
    print(f"  pip install pyg-lib torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{_pyg_wheel_tag}.html")

# gymnasium is an independent optional dependency. Probing it separately keeps
# a missing gymnasium from being misreported as a missing torch-geometric.
try:
    import gymnasium as gym
    from gymnasium import spaces
    GYMNASIUM_AVAILABLE = True
    print("✓ gymnasium available")
except ImportError:
    GYMNASIUM_AVAILABLE = False
    print("⚠ gymnasium not found. Install with:")
    print("  pip install gymnasium")

# ============================================================================
# PART 1: CLIENT GRAPH NETWORK (GAT-based)
# ============================================================================

if TORCH_GEOMETRIC_AVAILABLE:
    
    class ClientGraphNetwork(nn.Module):
        """
        Graph Neural Network for modeling federated client relationships.
        
        Key Innovation:
        - Models clients as nodes in a graph
        - Edges represent data similarity (feature distributions, label overlap);
          the graph itself is built by the caller, not by this class
        - GAT learns to attend to similar/complementary clients
        - Outputs aggregation weights considering client relationships

        Architecture: GATConv (multi-head, concatenated) -> ELU -> dropout ->
        GATConv (single head) -> ELU -> MLP scorer -> softmax over nodes.

        Args:
            n_clients: number of client nodes. Stored for reference only; the
                forward pass works for whatever number of nodes it is given.
            node_features: length of each node's input feature vector.
            hidden_dim: output width of each GAT attention head.
            heads: number of attention heads in the first GAT layer.
        """
        
        def __init__(self, n_clients=5, node_features=4, hidden_dim=64, heads=4):
            super().__init__()
            
            self.n_clients = n_clients
            self.node_features = node_features
            
            # Graph Attention layers. The first layer concatenates its heads,
            # so the second layer's input width is hidden_dim * heads.
            self.gat1 = GATConv(
                in_channels=node_features,
                out_channels=hidden_dim,
                heads=heads,
                dropout=0.2,
                concat=True
            )
            
            self.gat2 = GATConv(
                in_channels=hidden_dim * heads,
                out_channels=hidden_dim,
                heads=1,
                dropout=0.2,
                concat=False
            )
            
            # MLP mapping each node embedding to a single unnormalized score
            self.weight_predictor = nn.Sequential(
                nn.Linear(hidden_dim, 32),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(32, 1)
            )
            
            print(f"✓ ClientGraphNetwork initialized:")
            print(f"  - Nodes (clients): {n_clients}")
            print(f"  - Node features: {node_features}")
            print(f"  - Hidden dim: {hidden_dim}")
            print(f"  - Attention heads: {heads}")
        
        def forward(self, node_features, edge_index):
            """Score every client node and return softmax aggregation weights.

            Args:
                node_features: node feature matrix, presumably shaped
                    (num_nodes, self.node_features) — TODO confirm with caller.
                edge_index: graph connectivity passed straight to GATConv
                    (PyG COO format, presumably shape (2, num_edges)).

            Returns:
                1-D tensor of length num_nodes with non-negative entries that
                sum to 1 (softmax over nodes). A single-node graph yields a
                length-1 tensor, not a 0-d scalar.
            """
            # First GAT layer with multi-head attention
            x = F.elu(self.gat1(node_features, edge_index))
            x = F.dropout(x, p=0.2, training=self.training)
            
            # Second GAT layer
            x = F.elu(self.gat2(x, edge_index))
            
            # Drop only the trailing size-1 score dimension: a bare squeeze()
            # would also collapse the node axis when there is a single client.
            scores = self.weight_predictor(x).squeeze(-1)
            
            # Softmax across nodes gives a valid probability distribution
            return F.softmax(scores, dim=0)
    
    print("✓ ClientGraphNetwork defined (requires torch_geometric)")

else:
    print("⚠ torch_geometric not installed - ClientGraphNetwork skipped")
    print("  Install with: pip install torch-geometric")

# Closing banner for Experiment 1: blank line, rule, status message, rule.
print("\n" + "=" * 90, "✅ Experiment 1 Complete", "=" * 90, sep="\n")

---
## Experiment 2: ONNX-Based Environment

In [None]:
# CELL 9: EXPORT TRAINED CLIENTS TO ONNX FORMAT
# ============================================================================
# Requires the `clients` list created by the main notebook. Each client must
# expose `model` and `ylocal`; `Xlocal` (used for the input width) and
# `device` (where the model is returned after export) are optional.

import torch
import os
import numpy as np  # used for label statistics below; don't rely on another cell's import

print("=" * 90)
print("EXPORTING TRAINED CLIENTS TO ONNX FORMAT")
print("=" * 90)

# Verify prerequisites
if 'clients' not in globals():
    raise RuntimeError("clients list not found. Run main notebook first.")

# Verify all clients are trained
for i, client in enumerate(clients):
    if not hasattr(client, 'model') or client.model is None:
        raise RuntimeError(f"Client {i} is not trained.")

print(f"✓ All {len(clients)} clients are trained and ready for export")


def _export_client_onnx(model, input_dim, onnx_path, restore_device=None, opset_version=11):
    """Export `model` to `onnx_path` (dynamic batch axis) and restore its state.

    The export runs on CPU in eval mode. Afterwards, even if the export
    raises, the model's previous train/eval mode is restored and it is moved
    to `restore_device` (or, if None, back to the device its first parameter
    was on; CPU for a parameter-less model).

    Raises:
        Whatever `torch.onnx.export` raises; state is restored first.
    """
    first_param = next(model.parameters(), None)
    original_device = first_param.device if first_param is not None else torch.device("cpu")
    target_device = restore_device if restore_device is not None else original_device
    was_training = model.training

    model.cpu()
    model.eval()
    try:
        # Tracing input: one float32 row; batch size is left dynamic below.
        dummy_input = torch.randn(1, input_dim, dtype=torch.float32)
        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            export_params=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size'},
                'output': {0: 'batch_size'}
            }
        )
    finally:
        # Exporting must not leave side effects on the live model.
        model.train(was_training)
        model.to(target_device)


# Initialize lists
onnx_client_paths = []
client_metadata = []

# Export each client to ONNX
print(f"\n{'─' * 90}")
print(f"Exporting {len(clients)} clients to ONNX format...")
print(f"{'─' * 90}\n")

for i, client in enumerate(clients):
    onnx_path = f"client_{i}_model.onnx"
    
    print(f"[Client {i+1}/{len(clients)}] {onnx_path}")
    
    try:
        # Input width: prefer the local feature matrix; otherwise read it off
        # the model's first layer (presence of `model` was verified above).
        if hasattr(client, 'Xlocal'):
            input_dim = client.Xlocal.shape[1]
        else:
            input_dim = client.model.seq[0].in_features
        
        _export_client_onnx(
            client.model,
            input_dim,
            onnx_path,
            restore_device=getattr(client, 'device', None),
        )
        
        file_size = os.path.getsize(onnx_path) / 1024  # KB
        
        # Store metadata
        onnx_client_paths.append(onnx_path)
        client_metadata.append({
            'client_id': i,
            'n_samples': len(client.ylocal),
            'n_classes': len(np.unique(client.ylocal)),
            'y': client.ylocal,
            'onnx_path': onnx_path,
            'file_size_kb': file_size
        })
        
        print(f"  ✓ Exported ({file_size:.1f} KB)")
        
    except Exception as e:
        print(f"  ✗ Export failed: {e}")
        raise

print(f"\n{'=' * 90}")
print("✅ ONNX Export Complete")
print(f"{'=' * 90}")

---
## Experiment 3: Generic FL-RL Baseline

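Alternative baseline implementations. The code cell below is only a placeholder; the implementations live in cells 33-41 of the original notebook.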

In [None]:
# Placeholder for the generic FL-RL baseline approaches.
# The actual implementations are cells 33-41 of the original notebook and
# are not reproduced here; this cell only points the reader to them.
print("Experimental baseline approaches...", "See original notebook cells 33-41 for implementations", sep="\n")