In [3]:
import networkx as nx
import pandas as pd
import numpy as np
import random

In [4]:
def create_dummy_topology_data(num_rings=10, nodes_per_ring=10, logical_rings_per_physical=3):
    """
    Create dummy topology data for testing
    
    Args:
        num_rings: Number of physical rings to create
        nodes_per_ring: Number of nodes per ring
        logical_rings_per_physical: Number of logical rings per physical ring
        
    Returns:
        DataFrame with topology data
    """
    data = []
    
    for pr_idx in range(num_rings):
        physical_ring = f"RING_{pr_idx}"
        block_name = f"BLOCK_{pr_idx}"  # One block per physical ring
        
        # Each physical ring has multiple logical rings from the same block
        for lr_idx in range(logical_rings_per_physical):
            logical_ring = f"LR_{pr_idx}_{lr_idx}"
            
            # Create a linear path for each logical ring (not a complete ring)
            for i in range(nodes_per_ring - 1):  # Connect nodes in a path, not a ring
                node_a = f"NODE_{pr_idx}_{lr_idx}_{i}"
                node_b = f"NODE_{pr_idx}_{lr_idx}_{i+1}"
                
                data.append({
                    'aendname': node_a,
                    'bendname': node_b,
                    'aendip': f"10.{pr_idx}.{lr_idx}.{i}",
                    'bendip': f"10.{pr_idx}.{lr_idx}.{i+1}",
                    'aendifIndex': i,
                    'bendifIndex': i+1,
                    'block_name': block_name,
                    'physicalringname': physical_ring,
                    'lrname': logical_ring
                })
            
            # Connect the first and last nodes to the block
            # First node connects to block
            data.append({
                'aendname': f"NODE_{pr_idx}_{lr_idx}_0",
                'bendname': block_name,
                'aendip': f"10.{pr_idx}.{lr_idx}.0",
                'bendip': f"10.{pr_idx}.99.99",  # Special IP for block
                'aendifIndex': 100 + lr_idx,
                'bendifIndex': 100 + lr_idx,
                'block_name': block_name,
                'physicalringname': physical_ring,
                'lrname': logical_ring
            })
            
            # Last node connects to block
            data.append({
                'aendname': f"NODE_{pr_idx}_{lr_idx}_{nodes_per_ring-1}",
                'bendname': block_name,
                'aendip': f"10.{pr_idx}.{lr_idx}.{nodes_per_ring-1}",
                'bendip': f"10.{pr_idx}.99.99",  # Special IP for block
                'aendifIndex': 200 + lr_idx,
                'bendifIndex': 200 + lr_idx,
                'block_name': block_name,
                'physicalringname': physical_ring,
                'lrname': logical_ring
            })
    
    # Add connections between blocks from different physical rings
    # for pr_idx in range(num_rings):
    #     if pr_idx < num_rings - 1:  # Connect to next physical ring
    #         # Connect this block to the next ring's block
    #         block_a = f"BLOCK_{pr_idx}"
    #         block_b = f"BLOCK_{pr_idx+1}"
            
    #         data.append({
    #             'aendname': block_a,
    #             'bendname': block_b,
    #             'aendip': f"10.{pr_idx}.99.99",
    #             'bendip': f"10.{pr_idx+1}.99.99",
    #             'aendifIndex': 300 + pr_idx,
    #             'bendifIndex': 300 + pr_idx + 1,
    #             'block_name': block_a,
    #             'physicalringname': f"RING_{pr_idx}",
    #             'lrname': "INTER_BLOCK"  # Inter-block connection
    #         })
    
    return pd.DataFrame(data)

In [5]:
topo_data = create_dummy_topology_data()


In [5]:


def build_network_graph(topology_df):
    """Build a graph from topology_data_logical table"""
    
    G = nx.Graph()
    
    # Add nodes and edges with attributes
    for _, row in topology_df.iterrows():
        aend = row['aendname'].upper()
        bend = row['bendname'].upper()
        
        # Add node attributes
        G.add_node(aend, ip=row['aendip'], 
                   physicalringname=row['physicalringname'], 
                   lrname=row['lrname'],
                   block_name=row['block_name'])
        
        G.add_node(bend, ip=row['bendip'], 
                   physicalringname=row['physicalringname'], 
                   lrname=row['lrname'],
                   block_name=row['block_name'])
        
        # Add edge with attributes
        G.add_edge(aend, bend, 
                  physicalringname=row['physicalringname'], 
                  lrname=row['lrname'],
                  aendifIndex=row['aendifIndex'],
                  bendifIndex=row['bendifIndex'])
    
    return G


In [6]:
G =build_network_graph(topo_data)

In [7]:

from pyvis.network import Network
import matplotlib.colors as mcolors

def visualize_network(G, filename='network_graph.html'):
    """
    Visualize the network graph using PyVis
    
    Args:
        G: NetworkX graph
        filename: Output HTML file name
    """
    # Create a PyVis network
    net = Network(notebook=True, height="750px", width="100%")
    
    # Generate a color map for different physical rings and logical rings
    physical_rings = set(nx.get_node_attributes(G, 'physicalringname').values())
    logical_rings = set(nx.get_node_attributes(G, 'lrname').values())
    
    # Create color maps
    pr_colors = {pr: f"#{hash(pr) % 0xffffff:06x}" for pr in physical_rings}
    lr_colors = {lr: f"#{hash(lr) % 0xffffff:06x}" for lr in logical_rings}
    
    # Add nodes with attributes
    for node in G.nodes():
        pr = G.nodes[node]['physicalringname']
        lr = G.nodes[node]['lrname']
        is_block = node == G.nodes[node]['block_name']
        
        # Set node properties
        node_title = f"Node: {node}<br>IP: {G.nodes[node]['ip']}<br>PR: {pr}<br>LR: {lr}"
        node_color = pr_colors[pr]  # Color by physical ring
        node_shape = 'diamond' if is_block else 'dot'
        node_size = 25 if is_block else 15
        
        net.add_node(node, title=node_title, color=node_color, shape=node_shape, size=node_size, label=node)
    
    # Add edges with attributes
    for u, v in G.edges():
        pr = G.edges[u, v]['physicalringname']
        lr = G.edges[u, v]['lrname']
        
        # Set edge properties
        edge_title = f"PR: {pr}<br>LR: {lr}"
        edge_color = lr_colors[lr]  # Color by logical ring
        
        net.add_edge(u, v, title=edge_title, color=edge_color)
    
    # Set physics layout
    net.barnes_hut(spring_length=200)
    
    # Save and show the graph
    net.show(filename)
    
    return net

In [8]:
net = visualize_network(G, 'network_topology.html')

network_topology.html


In [9]:
def simulate_node_failure(G, node_to_fail):
    """
    Simulate a node failure and identify isolated nodes
    
    Returns:
        - isolated_nodes: list of nodes that become isolated
    """
    # Get the node's ring information
    if node_to_fail not in G.nodes():
        return []
    
    node_pr = G.nodes[node_to_fail]['physicalringname']
    node_lr = G.nodes[node_to_fail]['lrname']
    block_name = G.nodes[node_to_fail]['block_name']
    
    # Get all connected nodes in the same ring before failure
    before_graph = G.copy()
    block_component = set(nx.node_connected_component(before_graph, block_name))
    connected_before = {n for n in block_component 
                        if n in before_graph.nodes() and 
                        before_graph.nodes[n].get('physicalringname') == node_pr and 
                        before_graph.nodes[n].get('lrname') == node_lr}
    
    # Remove the failed node
    after_graph = G.copy()
    after_graph.remove_node(node_to_fail)
    
    # Get connected nodes after failure
    if block_name in after_graph.nodes():
        block_component_after = set(nx.node_connected_component(after_graph, block_name))
        connected_after = {n for n in block_component_after 
                          if n in after_graph.nodes() and 
                          after_graph.nodes[n].get('physicalringname') == node_pr and 
                          after_graph.nodes[n].get('lrname') == node_lr}
    else:
        connected_after = set()
    
    # Find isolated nodes (nodes connected before but not after)
    isolated_nodes = connected_before - connected_after - {node_to_fail}
    
    return list(isolated_nodes)


In [10]:

def simulate_edge_failure(G, edge_to_fail):
    """
    Simulate an edge failure and identify isolated nodes
    
    Returns:
        - isolated_nodes: list of nodes that become isolated
    """
    u, v = edge_to_fail
    if not G.has_edge(u, v):
        return []
    
    # Get edge information
    edge_pr = G.edges[u, v]['physicalringname']
    edge_lr = G.edges[u, v]['lrname']
    block_name = G.nodes[u]['block_name']  # Assuming same block for connected nodes
    
    # Get connected nodes before failure
    before_graph = G.copy()
    block_component = set(nx.node_connected_component(before_graph, block_name))
    connected_before = {n for n in block_component 
                        if n in before_graph.nodes() and 
                        before_graph.nodes[n].get('physicalringname') == edge_pr and 
                        before_graph.nodes[n].get('lrname') == edge_lr}
    
    # Remove the edge
    after_graph = G.copy()
    after_graph.remove_edge(u, v)
    
    # Get connected nodes after failure
    block_component_after = set(nx.node_connected_component(after_graph, block_name))
    connected_after = {n for n in block_component_after 
                       if n in after_graph.nodes() and 
                       after_graph.nodes[n].get('physicalringname') == edge_pr and 
                       after_graph.nodes[n].get('lrname') == edge_lr}
    
    # Find isolated nodes
    isolated_nodes = connected_before - connected_after
    
    return list(isolated_nodes)

In [11]:
# Print all edges in the graph
for edge in G.edges():
    print(edge)


('NODE_0_0_0', 'NODE_0_0_1')
('NODE_0_0_0', 'BLOCK_0')
('NODE_0_0_1', 'NODE_0_0_2')
('NODE_0_0_2', 'NODE_0_0_3')
('NODE_0_0_3', 'NODE_0_0_4')
('NODE_0_0_4', 'NODE_0_0_5')
('NODE_0_0_5', 'NODE_0_0_6')
('NODE_0_0_6', 'NODE_0_0_7')
('NODE_0_0_7', 'NODE_0_0_8')
('NODE_0_0_8', 'NODE_0_0_9')
('NODE_0_0_9', 'BLOCK_0')
('BLOCK_0', 'NODE_0_1_0')
('BLOCK_0', 'NODE_0_1_9')
('BLOCK_0', 'NODE_0_2_0')
('BLOCK_0', 'NODE_0_2_9')
('NODE_0_1_0', 'NODE_0_1_1')
('NODE_0_1_1', 'NODE_0_1_2')
('NODE_0_1_2', 'NODE_0_1_3')
('NODE_0_1_3', 'NODE_0_1_4')
('NODE_0_1_4', 'NODE_0_1_5')
('NODE_0_1_5', 'NODE_0_1_6')
('NODE_0_1_6', 'NODE_0_1_7')
('NODE_0_1_7', 'NODE_0_1_8')
('NODE_0_1_8', 'NODE_0_1_9')
('NODE_0_2_0', 'NODE_0_2_1')
('NODE_0_2_1', 'NODE_0_2_2')
('NODE_0_2_2', 'NODE_0_2_3')
('NODE_0_2_3', 'NODE_0_2_4')
('NODE_0_2_4', 'NODE_0_2_5')
('NODE_0_2_5', 'NODE_0_2_6')
('NODE_0_2_6', 'NODE_0_2_7')
('NODE_0_2_7', 'NODE_0_2_8')
('NODE_0_2_8', 'NODE_0_2_9')
('NODE_1_0_0', 'NODE_1_0_1')
('NODE_1_0_0', 'BLOCK_1')
('NODE

In [12]:
isolated_nodes = simulate_node_failure(G,'BLOCK_1')
isolated_nodes

['NODE_1_2_6',
 'NODE_1_2_2',
 'NODE_1_2_4',
 'NODE_1_2_9',
 'NODE_1_2_7',
 'NODE_1_2_1',
 'NODE_1_2_3',
 'NODE_1_2_0',
 'NODE_1_2_5',
 'NODE_1_2_8']

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from torch_geometric.data import Data, DataLoader
import torch_geometric.transforms as T
import numpy as np

In [14]:
def simulate_multiple_failures(G, elements_to_fail, failure_type='node'):
    """
    Simulate multiple node or edge failures and identify isolated nodes
    
    Args:
        G: NetworkX graph
        elements_to_fail: List of nodes or edges to fail
        failure_type: 'node' or 'edge'
    
    Returns:
        List of isolated nodes
    """
    # Create a copy of the graph 
    after_graph = G.copy()
    
    # Remove all failed elements
    if failure_type == 'node':
        for node in elements_to_fail:
            if node in after_graph:
                after_graph.remove_node(node)
    else:  # edge
        for edge in elements_to_fail:
            if after_graph.has_edge(*edge):
                after_graph.remove_edge(*edge)
    
    # Find isolated nodes (nodes that can't reach their block)
    isolated_nodes = []
    
    # Group nodes by physical ring and logical ring
    nodes_by_ring = {}
    for node in G.nodes():
        if 'BLOCK' in node:
            continue
            
        pr = G.nodes[node]['physicalringname']
        lr = G.nodes[node]['lrname']
        block = G.nodes[node]['block_name']
        
        key = (pr, lr, block)
        if key not in nodes_by_ring:
            nodes_by_ring[key] = []
        nodes_by_ring[key].append(node)
    
    # Check connectivity for each group
    for (pr, lr, block), nodes in nodes_by_ring.items():
        if block not in after_graph:
            # If block is gone, all nodes in this group are isolated
            isolated_nodes.extend(nodes)
            continue
            
        for node in nodes:
            if node in elements_to_fail or node not in after_graph:
                continue  # Skip failed nodes
                
            # Check if node can reach its block
            try:
                path = nx.shortest_path(after_graph, node, block)
            except nx.NetworkXNoPath:
                isolated_nodes.append(node)
    
    return isolated_nodes

In [15]:
isolated_nodes = simulate_multiple_failures(G,['BLOCK_1'],'node')
isolated_nodes


['NODE_1_0_0',
 'NODE_1_0_1',
 'NODE_1_0_2',
 'NODE_1_0_3',
 'NODE_1_0_4',
 'NODE_1_0_5',
 'NODE_1_0_6',
 'NODE_1_0_7',
 'NODE_1_0_8',
 'NODE_1_0_9',
 'NODE_1_1_0',
 'NODE_1_1_1',
 'NODE_1_1_2',
 'NODE_1_1_3',
 'NODE_1_1_4',
 'NODE_1_1_5',
 'NODE_1_1_6',
 'NODE_1_1_7',
 'NODE_1_1_8',
 'NODE_1_1_9',
 'NODE_1_2_0',
 'NODE_1_2_1',
 'NODE_1_2_2',
 'NODE_1_2_3',
 'NODE_1_2_4',
 'NODE_1_2_5',
 'NODE_1_2_6',
 'NODE_1_2_7',
 'NODE_1_2_8',
 'NODE_1_2_9']

In [16]:
def create_isolation_prediction_dataset(topology_df, num_simulations=1000, max_failures=3, balance_ratio=2):
    """
    Create an optimized dataset for predicting node isolations after failures
    
    Args:
        topology_df: DataFrame with network topology information
        num_simulations: Total number of examples to generate
        max_failures: Maximum number of simultaneous failures to simulate
        balance_ratio: Maximum ratio of non-isolation to isolation examples
        
    Returns:
        data_list: List of PyG Data objects
        node_list: List of all node names
        node_to_idx: Dictionary mapping node names to indices
    """
    print("Building network graph...")
    G = build_network_graph(topology_df)
    
    # Create node mapping
    node_list = list(G.nodes())
    node_to_idx = {node: i for i, node in enumerate(node_list)}
    
    # Identify regular nodes (non-block nodes)
    regular_nodes = [n for n in node_list if "BLOCK" not in n]
    
    print("Computing network vulnerability metrics...")
    # Pre-compute network properties for isolation prediction
    articulation_points = set(nx.articulation_points(G))
    bridges = list(nx.bridges(G))
    
    # Group nodes by physical ring and logical ring
    nodes_by_ring = {}
    for node in regular_nodes:
        if node not in G.nodes():
            continue
            
        pr = G.nodes[node]['physicalringname']
        lr = G.nodes[node]['lrname']
        block = G.nodes[node]['block_name']
        key = (pr, lr, block)
        
        if key not in nodes_by_ring:
            nodes_by_ring[key] = []
        nodes_by_ring[key].append(node)
    
    # Helper function to calculate path redundancy
    def get_path_redundancy(G, node, block):
        """Return number of edge-disjoint paths from node to block"""
        try:
            return len(list(nx.edge_disjoint_paths(G, node, block)))
        except:
            return 0
    
    # Helper function to calculate vulnerability score
    def calculate_vulnerability_score(G, node, block, articulation_points, bridges):
        """Calculate a node's vulnerability to isolation"""
        vulnerability = 0.0
        
        # Factor 1: Few paths to block
        paths = get_path_redundancy(G, node, block)
        if paths == 0:
            vulnerability += 1.0  # Already isolated
        elif paths == 1:
            vulnerability += 0.6  # Critical - only one path
        elif paths == 2:
            vulnerability += 0.3  # Vulnerable - two paths
        
        # Factor 2: Node is articulation point
        if node in articulation_points:
            vulnerability += 0.2
        
        # Factor 3: Connected to bridges
        bridge_connections = 0
        for u, v in bridges:
            if node == u or node == v:
                bridge_connections += 1
        vulnerability += min(0.3, bridge_connections * 0.1)
        
        # Factor 4: Path to block goes through articulation points
        try:
            path = nx.shortest_path(G, node, block)
            critical_on_path = sum(1 for n in path[1:-1] if n in articulation_points)
            vulnerability += min(0.3, critical_on_path * 0.1)
        except:
            vulnerability += 0.2  # No path exists
        
        return min(1.0, vulnerability)  # Cap at 1.0
    
    print("Creating node features...")
    # Create isolation-focused node features
    node_features = []
    for node in node_list:
        features = []
        
        # Basic features
        pr = G.nodes[node]['physicalringname']
        pr_hash = hash(pr) % 10
        features.append(pr_hash/10.0)
        
        lr = G.nodes[node]['lrname']
        lr_hash = hash(lr) % 10
        features.append(lr_hash/10.0)
        
        is_block = 1.0 if "BLOCK" in node else 0.0
        features.append(is_block)
        
        # Skip advanced features for block nodes
        if is_block:
            # Add placeholder values for block nodes to match feature dimensions
            features.extend([0.0] * 7)  # 7 extra features for non-block nodes
        else:
            # Get block for this node
            block = G.nodes[node]['block_name']
            
            # 1. Path redundancy - key indicator of isolation risk
            path_redundancy = get_path_redundancy(G, node, block) / 5.0
            features.append(path_redundancy)
            
            # 2. Critical point features
            is_articulation = 1.0 if node in articulation_points else 0.0
            has_bridge = any((node == u or node == v) for u, v in bridges)
            features.extend([is_articulation, float(has_bridge)])
            
            # 3. Path to block features
            try:
                path_length = nx.shortest_path_length(G, node, block) / 10.0
                path = nx.shortest_path(G, node, block)
                critical_on_path = len(set(path[1:-1]).intersection(articulation_points)) / 5.0
            except:
                path_length = 1.0
                critical_on_path = 1.0
            features.extend([path_length, critical_on_path])
            
            # 4. Vulnerability score (combined metric)
            vulnerability = calculate_vulnerability_score(G, node, block, articulation_points, bridges)
            features.append(vulnerability)
            
            # 5. Network centrality
            degree = G.degree(node) / len(G)
            features.append(degree)
        
        node_features.append(features)
    
    # Convert to tensor
    node_features = torch.tensor(node_features, dtype=torch.float)
    
    # Create edge index
    edges = []
    for u, v in G.edges():
        edges.append([node_to_idx[u], node_to_idx[v]])
        edges.append([node_to_idx[v], node_to_idx[u]])  # Add reverse edge for undirected graph
    
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    # Track dataset statistics
    examples_with_isolations = []
    examples_without_isolations = []
    isolation_counts = []
    failure_counts = []
    
    print("Generating dataset examples...")
    # PHASE 1: Generate examples that are likely to have isolations (50% of dataset)
    target_isolation_examples = num_simulations // 2
    phase1_attempts = 0
    
    while len(examples_with_isolations) < target_isolation_examples and phase1_attempts < num_simulations * 2:
        phase1_attempts += 1
        
        # Choose failure type (node or edge)
        failure_type = random.choice(['node', 'edge'])
        
        if failure_type == 'node':
            # STRATEGY: Target critical nodes or nodes in the same ring
            strategy = random.choice(['critical', 'same_ring', 'random'])
            
            if strategy == 'critical':
                # Target articulation points and their neighbors
                potential_targets = list(articulation_points.intersection(set(regular_nodes)))
                if len(potential_targets) == 0:
                    potential_targets = regular_nodes
                
                # Select 1-3 critical nodes
                num_failures = random.randint(1, min(max_failures, len(potential_targets)))
                nodes_to_fail = random.sample(potential_targets, num_failures)
                
            elif strategy == 'same_ring':
                # Target multiple nodes in the same ring
                if not nodes_by_ring:
                    # Fallback to random
                    nodes_to_fail = random.sample(regular_nodes, random.randint(1, min(max_failures, len(regular_nodes))))
                else:
                    # Pick a random ring
                    ring_key = random.choice(list(nodes_by_ring.keys()))
                    ring_nodes = nodes_by_ring[ring_key]
                    
                    # Select 1-3 nodes from this ring
                    num_failures = random.randint(1, min(max_failures, len(ring_nodes)))
                    nodes_to_fail = random.sample(ring_nodes, num_failures)
            
            else:  # random strategy
                num_failures = random.randint(1, min(max_failures, len(regular_nodes)))
                nodes_to_fail = random.sample(regular_nodes, num_failures)
            
            # Simulate failures
            isolated_nodes = simulate_multiple_failures(G, nodes_to_fail, 'node')
            
            # Create target tensor
            y = torch.zeros(len(node_list), dtype=torch.float)
            for node in isolated_nodes:
                if node in node_to_idx:
                    y[node_to_idx[node]] = 1.0
            
            # Create edge mask
            valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
            for node in nodes_to_fail:
                node_idx = node_to_idx[node]
                invalid_edges = (edge_index[0] == node_idx) | (edge_index[1] == node_idx)
                valid_edge_mask[invalid_edges] = False
            
            # Create data object
            data = Data(
                x=node_features,
                edge_index=edge_index,
                y=y,
                valid_edge_mask=valid_edge_mask,
                failure_type=torch.tensor([0]),
                num_failures=torch.tensor([len(nodes_to_fail)], dtype=torch.long)
            )
            
        else:  # edge failure
            # STRATEGY: Target bridge edges or edges in the same ring
            strategy = random.choice(['bridges', 'same_ring', 'random'])
            
            if strategy == 'bridges':
                # Target bridge edges
                potential_targets = [(u, v) for u, v in bridges if "BLOCK" not in u and "BLOCK" not in v]
                if len(potential_targets) == 0:
                    # Fallback to random edges
                    all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                    potential_targets = all_edges
                
                # Select 1-3 bridge edges
                num_failures = random.randint(1, min(max_failures, len(potential_targets)))
                edges_to_fail = random.sample(potential_targets, num_failures)
                
            elif strategy == 'same_ring':
                # Target multiple edges in the same ring
                ring_edges = {}
                for u, v in G.edges():
                    if "BLOCK" in u or "BLOCK" in v:
                        continue
                        
                    try:
                        pr_u = G.nodes[u]['physicalringname']
                        lr_u = G.nodes[u]['lrname']
                        pr_v = G.nodes[v]['physicalringname']
                        lr_v = G.nodes[v]['lrname']
                        
                        if pr_u == pr_v and lr_u == lr_v:
                            key = (pr_u, lr_u)
                            if key not in ring_edges:
                                ring_edges[key] = []
                            ring_edges[key].append((u, v))
                    except:
                        continue
                
                if not ring_edges:
                    # Fallback to random
                    all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                    num_failures = random.randint(1, min(max_failures, len(all_edges)))
                    edges_to_fail = random.sample(all_edges, num_failures)
                else:
                    # Pick a random ring
                    ring_key = random.choice(list(ring_edges.keys()))
                    ring_edge_list = ring_edges[ring_key]
                    
                    # Select 1-3 edges from this ring
                    num_failures = random.randint(1, min(max_failures, len(ring_edge_list)))
                    edges_to_fail = random.sample(ring_edge_list, num_failures)
            
            else:  # random strategy
                all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                num_failures = random.randint(1, min(max_failures, len(all_edges)))
                edges_to_fail = random.sample(all_edges, num_failures)
            
            # Simulate failures
            isolated_nodes = simulate_multiple_failures(G, edges_to_fail, 'edge')
            
            # Create target tensor
            y = torch.zeros(len(node_list), dtype=torch.float)
            for node in isolated_nodes:
                if node in node_to_idx:
                    y[node_to_idx[node]] = 1.0
            
            # Create edge mask
            valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
            for u, v in edges_to_fail:
                u_idx, v_idx = node_to_idx[u], node_to_idx[v]
                
                for i in range(edge_index.size(1)):
                    e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
                    if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                        valid_edge_mask[i] = False
            
            # Create data object
            data = Data(
                x=node_features,
                edge_index=edge_index,
                y=y,
                valid_edge_mask=valid_edge_mask,
                failure_type=torch.tensor([1]),
                num_failures=torch.tensor([len(edges_to_fail)], dtype=torch.long)
            )
        
        # Track isolation information
        isolation_count = y.sum().item()
        isolation_counts.append(isolation_count)
        failure_counts.append(len(nodes_to_fail if failure_type == 'node' else edges_to_fail))
        
        # Add to appropriate list
        if isolation_count > 0:
            examples_with_isolations.append(data)
        else:
            examples_without_isolations.append(data)
            
        # Report progress
        if phase1_attempts % 100 == 0:
            print(f"Phase 1: Generated {len(examples_with_isolations)} examples with isolations after {phase1_attempts} attempts")
    
    # PHASE 2: Generate remaining examples randomly to complete the dataset
    remaining = num_simulations - len(examples_with_isolations) - len(examples_without_isolations)
    
    if remaining > 0:
        print(f"Generating {remaining} additional examples for Phase 2...")
        
        for _ in range(remaining):
            # Choose failure type (node or edge)
            failure_type = random.choice(['node', 'edge'])
            
            if failure_type == 'node':
                # Choose random nodes
                num_failures = random.randint(1, min(max_failures, len(regular_nodes)))
                nodes_to_fail = random.sample(regular_nodes, num_failures)
                
                # Simulate failures
                isolated_nodes = simulate_multiple_failures(G, nodes_to_fail, 'node')
                
                # Create target tensor
                y = torch.zeros(len(node_list), dtype=torch.float)
                for node in isolated_nodes:
                    if node in node_to_idx:
                        y[node_to_idx[node]] = 1.0
                
                # Create edge mask
                valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
                for node in nodes_to_fail:
                    node_idx = node_to_idx[node]
                    invalid_edges = (edge_index[0] == node_idx) | (edge_index[1] == node_idx)
                    valid_edge_mask[invalid_edges] = False
                
                # Create data object
                data = Data(
                    x=node_features,
                    edge_index=edge_index,
                    y=y,
                    valid_edge_mask=valid_edge_mask,
                    failure_type=torch.tensor([0]),
                    num_failures=torch.tensor([len(nodes_to_fail)], dtype=torch.long)
                )
                
            else:  # edge failure
                # Choose random edges
                all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                num_failures = random.randint(1, min(max_failures, len(all_edges)))
                edges_to_fail = random.sample(all_edges, num_failures)
                
                # Simulate failures
                isolated_nodes = simulate_multiple_failures(G, edges_to_fail, 'edge')
                
                # Create target tensor
                y = torch.zeros(len(node_list), dtype=torch.float)
                for node in isolated_nodes:
                    if node in node_to_idx:
                        y[node_to_idx[node]] = 1.0
                
                # Create edge mask
                valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
                for u, v in edges_to_fail:
                    u_idx, v_idx = node_to_idx[u], node_to_idx[v]
                    
                    for i in range(edge_index.size(1)):
                        e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
                        if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                            valid_edge_mask[i] = False
                
                # Create data object
                data = Data(
                    x=node_features,
                    edge_index=edge_index,
                    y=y,
                    valid_edge_mask=valid_edge_mask,
                    failure_type=torch.tensor([1]),
                    num_failures=torch.tensor([len(edges_to_fail)], dtype=torch.long)
                )
            
            # Track isolation information
            isolation_count = y.sum().item()
            isolation_counts.append(isolation_count)
            failure_counts.append(len(nodes_to_fail if failure_type == 'node' else edges_to_fail))
            
            # Add to appropriate list
            if isolation_count > 0:
                examples_with_isolations.append(data)
            else:
                examples_without_isolations.append(data)
    
    # Balance the dataset using the specified ratio
    print(f"Balancing dataset with isolation:non-isolation ratio of 1:{balance_ratio}")
    if len(examples_without_isolations) > len(examples_with_isolations) * balance_ratio:
        # Downsample negative examples
        target_non_isolations = len(examples_with_isolations) * balance_ratio
        examples_without_isolations = random.sample(examples_without_isolations, int(target_non_isolations))
    
    # Combine and shuffle the final dataset
    final_data_list = examples_with_isolations + examples_without_isolations
    random.shuffle(final_data_list)
    
    # Print dataset statistics
    total_examples = len(final_data_list)
    total_isolations = len(examples_with_isolations)
    total_non_isolations = len(examples_without_isolations)
    
    print("\n=== Dataset Statistics ===")
    print(f"Total examples: {total_examples}")
    print(f"Examples with isolations: {total_isolations} ({total_isolations/total_examples*100:.1f}%)")
    print(f"Examples without isolations: {total_non_isolations} ({total_non_isolations/total_examples*100:.1f}%)")
    print(f"Average isolations per positive example: {sum(c for c in isolation_counts if c > 0)/max(1, total_isolations):.2f}")
    print(f"Average failures per example: {sum(failure_counts)/len(failure_counts):.2f}")
    
    # Detailed breakdown by failure type
    node_isolations = sum(1 for data in examples_with_isolations if data.failure_type.item() == 0)
    edge_isolations = sum(1 for data in examples_with_isolations if data.failure_type.item() == 1)
    
    print("\nIsolation examples by failure type:")
    print(f"  Node failures: {node_isolations} ({node_isolations/max(1, total_isolations)*100:.1f}%)")
    print(f"  Edge failures: {edge_isolations} ({edge_isolations/max(1, total_isolations)*100:.1f}%)")
    
    # Breakdown by number of failures
    failure_counts_dist = {}
    for data in final_data_list:
        num_fails = data.num_failures.item()
        has_isolations = data.y.sum().item() > 0
        
        key = f"{num_fails}_{'iso' if has_isolations else 'non'}"
        if key not in failure_counts_dist:
            failure_counts_dist[key] = 0
        failure_counts_dist[key] += 1
    
    print("\nExamples by number of failures:")
    for k in sorted(failure_counts_dist.keys()):
        num, iso_status = k.split('_')
        status = "with isolations" if iso_status == 'iso' else "without isolations"
        print(f"  {num} failures {status}: {failure_counts_dist[k]}")
    
    return final_data_list, node_list, node_to_idx

In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, GCNConv, GraphConv, global_mean_pool, global_max_pool

class ImprovedGNN(torch.nn.Module):
    def __init__(self, num_node_features, edge_dim=None, hidden_channels=64):
        super(ImprovedGNN, self).__init__()
        # Increase model capacity with more sophisticated layers
        
        # Initial feature extraction with GCN
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.bn1 = nn.BatchNorm1d(hidden_channels)
        
        # Attention layer to focus on important connections
        self.gat1 = GATv2Conv(hidden_channels, hidden_channels, heads=4, dropout=0.2)
        self.bn2 = nn.BatchNorm1d(hidden_channels * 4)
        
        # Additional graph convolution for message passing
        self.conv2 = GraphConv(hidden_channels * 4, hidden_channels * 2)
        self.bn3 = nn.BatchNorm1d(hidden_channels * 2)
        
        # Final GCN layer
        self.conv3 = GCNConv(hidden_channels * 2, hidden_channels)
        self.bn4 = nn.BatchNorm1d(hidden_channels)
        
        # Output layer
        self.classifier = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_channels // 2, 1)
        )
        
    def forward(self, data):
        # Handle either individual inputs or a PyG Data object
        if hasattr(data, 'x') and hasattr(data, 'edge_index'):
            x, edge_index = data.x, data.edge_index
        else:
            # Assume direct inputs
            x, edge_index = data
        
        # First convolution block
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.dropout(x, p=0.2, training=self.training)
        
        # Attention block - learn which connections matter most
        x = self.gat1(x, edge_index)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.dropout(x, p=0.2, training=self.training)
        
        # Second convolution block
        x = self.conv2(x, edge_index)
        x = self.bn3(x)
        x = F.relu(x)
        x = F.dropout(x, p=0.25, training=self.training)
        
        # Final convolution
        x = self.conv3(x, edge_index)
        x = self.bn4(x)
        x = F.relu(x)
        
        # Classification head
        x = self.classifier(x)
        
        return x

In [35]:
def create_isolation_prediction_dataset(topology_df, num_simulations=5000, max_failures=5, min_isolation_ratio=0.3):
    """
    Create a high-quality dataset for training a GNN to predict node isolations after failures.
    
    Args:
        topology_df: DataFrame with network topology information
        num_simulations: Total number of examples to generate
        max_failures: Maximum number of simultaneous failures to simulate
        min_isolation_ratio: Minimum ratio of examples that must contain isolations
        
    Returns:
        data_list: List of PyG Data objects
        node_list: List of all node names
        node_to_idx: Dictionary mapping node names to indices
    """
    print("Building network graph...")
    G = build_network_graph(topology_df)
    
    # Create node mapping
    node_list = list(G.nodes())
    node_to_idx = {node: i for i, node in enumerate(node_list)}
    
    # Identify regular nodes (non-block nodes)
    regular_nodes = [n for n in node_list if "BLOCK" not in n]
    
    print("Computing network vulnerability metrics...")
    # Pre-compute network properties for better feature engineering
    articulation_points = set(nx.articulation_points(G))
    bridges = list(nx.bridges(G))
    
    # Pre-compute edge betweenness - identifying critical edges
    edge_betweenness = nx.edge_betweenness_centrality(G)
    
    # Group nodes by physical ring and logical ring
    nodes_by_ring = {}
    for node in regular_nodes:
        if node not in G.nodes():
            continue
            
        pr = G.nodes[node].get('physicalringname', 'unknown')
        lr = G.nodes[node].get('lrname', 'unknown')
        block = G.nodes[node].get('block_name', 'unknown')
        key = (pr, lr, block)
        
        if key not in nodes_by_ring:
            nodes_by_ring[key] = []
        nodes_by_ring[key].append(node)
    
    print("Calculating node vulnerability scores...")
    # Calculate vulnerability score for each node
    vulnerability_scores = {}
    for node in regular_nodes:
        if "BLOCK" in node:
            vulnerability_scores[node] = 0.0
            continue
            
        try:
            block = G.nodes[node].get('block_name', None)
            if not block:
                vulnerability_scores[node] = 0.5
                continue
                
            # Factor 1: Path redundancy to block
            try:
                paths = len(list(nx.edge_disjoint_paths(G, node, block)))
                if paths == 0:
                    path_score = 1.0  # Already isolated
                elif paths == 1: 
                    path_score = 0.8  # Critical - only one path
                elif paths == 2:
                    path_score = 0.5  # Vulnerable - two paths
                else:
                    path_score = 0.1  # Low risk
            except:
                path_score = 1.0  # Unable to calculate paths
                
            # Factor 2: Is it an articulation point?
            is_articulation = 1.0 if node in articulation_points else 0.0
            
            # Factor 3: Connected to bridges
            bridge_connections = sum(1 for u, v in bridges if node in (u, v))
            bridge_score = min(1.0, bridge_connections * 0.2)
            
            # Factor 4: Betweenness centrality - how important is it for connecting other nodes
            betweenness = nx.betweenness_centrality(G, k=min(100, len(G))).get(node, 0)
            
            # Factor 5: Closeness to block
            try:
                closeness = 1.0 / max(1.0, nx.shortest_path_length(G, node, block))
            except:
                closeness = 0.0
                
            # Combined score (weighted average)
            vulnerability = (
                0.4 * path_score + 
                0.2 * is_articulation + 
                0.2 * bridge_score + 
                0.1 * betweenness +
                0.1 * closeness
            )
            
            vulnerability_scores[node] = min(1.0, vulnerability)
        except Exception as e:
            print(f"Error calculating vulnerability for {node}: {e}")
            vulnerability_scores[node] = 0.5
    
    print("Creating node features...")
    # Create more informative node features
    node_features = []
    for node in node_list:
        features = []
        
        # Basic features
        pr = G.nodes[node].get('physicalringname', 'unknown')
        pr_hash = int(hashlib.md5(pr.encode()).hexdigest(), 16) % 10
        features.append(pr_hash/10.0)
        
        lr = G.nodes[node].get('lrname', 'unknown')
        lr_hash = int(hashlib.md5(lr.encode()).hexdigest(), 16) % 10
        features.append(lr_hash/10.0)
        
        is_block = 1.0 if "BLOCK" in node else 0.0
        features.append(is_block)
        
        # Skip advanced features for block nodes
        if is_block:
            # Add placeholder values for block nodes
            features.extend([0.0] * 8)  # 8 extra features 
        else:
            # Topological importance features
            degree = G.degree(node) / max(1, max(dict(G.degree()).values()))
            features.append(degree)
            
            clustering = nx.clustering(G, node)
            features.append(clustering)
            
            is_articulation = 1.0 if node in articulation_points else 0.0
            features.append(is_articulation)
            
            # Edge connectivity features
            bridges_connected = min(1.0, sum(1 for u, v in bridges if node in (u, v)) / 5.0)
            features.append(bridges_connected)
            
            # Block connectivity
            block = G.nodes[node].get('block_name', None)
            if block:
                try:
                    path_length = nx.shortest_path_length(G, node, block) / 10.0
                    disjoint_paths = len(list(nx.edge_disjoint_paths(G, node, block))) / 5.0
                except:
                    path_length = 1.0
                    disjoint_paths = 0.0
            else:
                path_length = 1.0
                disjoint_paths = 0.0
            features.append(path_length)
            features.append(disjoint_paths)
            
            # Centrality metrics
            betweenness = nx.betweenness_centrality(G, k=min(100, len(G))).get(node, 0)
            features.append(betweenness)
            
            # Pre-computed vulnerability score
            vulnerability = vulnerability_scores.get(node, 0.5)
            features.append(vulnerability)
        
        node_features.append(features)
    
    # Convert to tensor
    node_features = torch.tensor(node_features, dtype=torch.float)
    
    # Create edge index
    edges = []
    for u, v in G.edges():
        edges.append([node_to_idx[u], node_to_idx[v]])
        edges.append([node_to_idx[v], node_to_idx[u]])  # Undirected graph
    
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    # Create edge features
    edge_features = []
    for u, v in G.edges():
        # Compute various edge features
        # 1. Edge betweenness
        betweenness = edge_betweenness.get((u, v), edge_betweenness.get((v, u), 0))
        
        # 2. Same physical ring
        same_pr = int(G.nodes[u].get('physicalringname') == G.nodes[v].get('physicalringname'))
        
        # 3. Same logical ring
        same_lr = int(G.nodes[u].get('lrname') == G.nodes[v].get('lrname'))
        
        # 4. Is this a bridge?
        is_bridge = 1.0 if (u, v) in bridges or (v, u) in bridges else 0.0
        
        # Duplicate for reverse edge
        edge_features.append([betweenness, same_pr, same_lr, is_bridge])
        edge_features.append([betweenness, same_pr, same_lr, is_bridge])
    
    edge_attr = torch.tensor(edge_features, dtype=torch.float)
    
    # Track dataset statistics
    examples_with_isolations = []
    examples_without_isolations = []
    
    # Generate dataset with strategic failure selection
    print("Generating dataset examples...")
    
    # PHASE 1: Generate examples targeting vulnerabilities
    phase1_target = int(num_simulations * 0.7)  # 70% from targeted strategies
    
    while len(examples_with_isolations) + len(examples_without_isolations) < phase1_target:
        # Choose failure type (node or edge with 70/30 split)
        failure_type = 'node' if random.random() < 0.7 else 'edge'
        
        # Strategies weighted by effectiveness
        if failure_type == 'node':
            # Node failure strategies
            strategies = [
                ('high_vulnerability', 0.4),  # Target vulnerable nodes
                ('articulation', 0.3),        # Target articulation points
                ('same_ring', 0.2),           # Target nodes in same ring
                ('random', 0.1)               # Random targets
            ]
            
            strategy = random.choices([s[0] for s in strategies], 
                                     [s[1] for s in strategies])[0]
            
            if strategy == 'high_vulnerability':
                # Target high vulnerability nodes
                sorted_nodes = sorted([(n, vulnerability_scores.get(n, 0)) 
                                      for n in regular_nodes if "BLOCK" not in n],
                                     key=lambda x: x[1], reverse=True)
                
                # Take from top 20% vulnerable nodes
                top_vulnerable = sorted_nodes[:max(1, int(len(sorted_nodes) * 0.2))]
                potential_targets = [n[0] for n in top_vulnerable]
                
                if not potential_targets:
                    potential_targets = regular_nodes
                
                # Select 1-N failures
                num_failures = random.randint(1, min(max_failures, len(potential_targets)))
                nodes_to_fail = random.sample(potential_targets, num_failures)
                
            elif strategy == 'articulation':
                # Target articulation points
                potential_targets = list(articulation_points & set(regular_nodes))
                if len(potential_targets) < 2:
                    potential_targets = regular_nodes
                    
                num_failures = random.randint(1, min(max_failures, len(potential_targets)))
                nodes_to_fail = random.sample(potential_targets, num_failures)
                
            elif strategy == 'same_ring':
                # Select nodes in the same ring 
                if not nodes_by_ring:
                    # Fallback to random
                    nodes_to_fail = random.sample(regular_nodes, 
                                                 random.randint(1, min(max_failures, len(regular_nodes))))
                else:
                    # Pick a random ring group
                    rings_with_enough_nodes = [key for key, nodes in nodes_by_ring.items() 
                                              if len(nodes) >= 2]
                    
                    if not rings_with_enough_nodes:
                        # Fallback to random  
                        nodes_to_fail = random.sample(regular_nodes,
                                                     random.randint(1, min(max_failures, len(regular_nodes))))
                    else:
                        ring_key = random.choice(rings_with_enough_nodes)
                        ring_nodes = nodes_by_ring[ring_key]
                        
                        # Select multiple nodes from this ring
                        num_failures = random.randint(1, min(max_failures, len(ring_nodes)))
                        nodes_to_fail = random.sample(ring_nodes, num_failures)
            
            else:  # random strategy
                num_failures = random.randint(1, min(max_failures, len(regular_nodes)))
                nodes_to_fail = random.sample(regular_nodes, num_failures)
            
            # Simulate failures
            isolated_nodes = simulate_multiple_failures(G, nodes_to_fail, 'node')
            
            # Create target tensor
            y = torch.zeros(len(node_list), dtype=torch.float)
            for node in isolated_nodes:
                if node in node_to_idx:
                    y[node_to_idx[node]] = 1.0
            
            # Create edge mask
            valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
            for node in nodes_to_fail:
                node_idx = node_to_idx[node]
                invalid_edges = (edge_index[0] == node_idx) | (edge_index[1] == node_idx)
                valid_edge_mask[invalid_edges] = False
            
            # Create data object
            data = Data(
                x=node_features,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=y,
                valid_edge_mask=valid_edge_mask,
                failure_type=torch.tensor([0]),
                failed_elements=[node_to_idx[n] for n in nodes_to_fail],
                num_failures=torch.tensor([len(nodes_to_fail)], dtype=torch.long)
            )
            
        else:  # edge failure
            # Edge failure strategies
            strategies = [
                ('bridges', 0.4),       # Target bridge edges 
                ('high_betweenness', 0.3),  # Target high betweenness edges
                ('same_ring', 0.2),     # Target edges in the same ring
                ('random', 0.1)         # Random edges
            ]
            
            strategy = random.choices([s[0] for s in strategies], 
                                     [s[1] for s in strategies])[0]
            
            if strategy == 'bridges':
                # Target bridge edges
                potential_targets = [(u, v) for u, v in bridges if "BLOCK" not in u and "BLOCK" not in v]
                if len(potential_targets) == 0:
                    # Fallback to random edges
                    all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                    potential_targets = all_edges
                
                # Select multiple bridge edges
                num_failures = random.randint(1, min(max_failures, len(potential_targets)))
                edges_to_fail = random.sample(potential_targets, num_failures)
                
            elif strategy == 'high_betweenness':
                # Target high betweenness edges
                sorted_edges = sorted([(u, v) for u, v in edge_betweenness.keys() 
                                     if "BLOCK" not in u and "BLOCK" not in v],
                                    key=lambda e: edge_betweenness.get(e, 0), 
                                    reverse=True)
                
                # Take top 20% high betweenness edges
                potential_targets = sorted_edges[:max(1, int(len(sorted_edges) * 0.2))]
                
                if not potential_targets:
                    # Fallback to random
                    all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                    potential_targets = all_edges
                    
                num_failures = random.randint(1, min(max_failures, len(potential_targets)))
                edges_to_fail = random.sample(potential_targets, num_failures)
                
            elif strategy == 'same_ring':
                # Target multiple edges in the same ring
                ring_edges = {}
                for u, v in G.edges():
                    if "BLOCK" in u or "BLOCK" in v:
                        continue
                        
                    try:
                        pr_u = G.nodes[u].get('physicalringname')
                        lr_u = G.nodes[u].get('lrname')
                        pr_v = G.nodes[v].get('physicalringname')
                        lr_v = G.nodes[v].get('lrname')
                        
                        if pr_u == pr_v and lr_u == lr_v:
                            key = (pr_u, lr_u)
                            if key not in ring_edges:
                                ring_edges[key] = []
                            ring_edges[key].append((u, v))
                    except:
                        continue
                
                if not ring_edges:
                    # Fallback to random
                    all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                    num_failures = random.randint(1, min(max_failures, len(all_edges)))
                    edges_to_fail = random.sample(all_edges, num_failures)
                else:
                    # Pick a random ring
                    ring_key = random.choice(list(ring_edges.keys()))
                    ring_edge_list = ring_edges[ring_key]
                    
                    # Select multiple edges from this ring
                    num_failures = random.randint(1, min(max_failures, len(ring_edge_list)))
                    edges_to_fail = random.sample(ring_edge_list, num_failures)
            
            else:  # random strategy
                all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                num_failures = random.randint(1, min(max_failures, len(all_edges)))
                edges_to_fail = random.sample(all_edges, num_failures)
            
            # Simulate failures
            isolated_nodes = simulate_multiple_failures(G, edges_to_fail, 'edge')
            
            # Create target tensor
            y = torch.zeros(len(node_list), dtype=torch.float)
            for node in isolated_nodes:
                if node in node_to_idx:
                    y[node_to_idx[node]] = 1.0
            
            # Create edge mask
            valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
            failed_edge_indices = []
            
            for u, v in edges_to_fail:
                u_idx, v_idx = node_to_idx[u], node_to_idx[v]
                
                for i in range(edge_index.size(1)):
                    e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
                    if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                        valid_edge_mask[i] = False
                        failed_edge_indices.append(i)
            
            # Create data object
            data = Data(
                x=node_features,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=y,
                valid_edge_mask=valid_edge_mask,
                failure_type=torch.tensor([1]),
                failed_elements=failed_edge_indices,
                num_failures=torch.tensor([len(edges_to_fail)], dtype=torch.long)
            )
        
        # Add to appropriate list
        if y.sum().item() > 0:
            examples_with_isolations.append(data)
        else:
            examples_without_isolations.append(data)
            
        # Report progress
        total_examples = len(examples_with_isolations) + len(examples_without_isolations)
        if total_examples % 100 == 0:
            isolation_ratio = len(examples_with_isolations) / max(1, total_examples)
            print(f"Generated {total_examples} examples, {len(examples_with_isolations)} with isolations ({isolation_ratio:.2f})")
    
    # PHASE 2: Generate remaining examples with random approach
    phase2_target = num_simulations - len(examples_with_isolations) - len(examples_without_isolations)
    
    if phase2_target > 0:
        print(f"Generating {phase2_target} additional examples for Phase 2...")
        
        for _ in range(phase2_target):
            # Choose failure type (node or edge)
            failure_type = random.choice(['node', 'edge'])
            
            if failure_type == 'node':
                # Choose random nodes
                num_failures = random.randint(1, min(max_failures, len(regular_nodes)))
                nodes_to_fail = random.sample(regular_nodes, num_failures)
                
                # Simulate failures
                isolated_nodes = simulate_multiple_failures(G, nodes_to_fail, 'node')
                
                # Create target tensor
                y = torch.zeros(len(node_list), dtype=torch.float)
                for node in isolated_nodes:
                    if node in node_to_idx:
                        y[node_to_idx[node]] = 1.0
                
                # Create edge mask
                valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
                for node in nodes_to_fail:
                    node_idx = node_to_idx[node]
                    invalid_edges = (edge_index[0] == node_idx) | (edge_index[1] == node_idx)
                    valid_edge_mask[invalid_edges] = False
                
                # Create data object
                data = Data(
                    x=node_features,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    y=y,
                    valid_edge_mask=valid_edge_mask,
                    failure_type=torch.tensor([0]),
                    failed_elements=[node_to_idx[n] for n in nodes_to_fail],
                    num_failures=torch.tensor([len(nodes_to_fail)], dtype=torch.long)
                )
                
            else:  # edge failure
                # Choose random edges
                all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                num_failures = random.randint(1, min(max_failures, len(all_edges)))
                edges_to_fail = random.sample(all_edges, num_failures)
                
                # Simulate failures
                isolated_nodes = simulate_multiple_failures(G, edges_to_fail, 'edge')
                
                # Create target tensor
                y = torch.zeros(len(node_list), dtype=torch.float)
                for node in isolated_nodes:
                    if node in node_to_idx:
                        y[node_to_idx[node]] = 1.0
                
                # Create edge mask
                valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
                failed_edge_indices = []
                
                for u, v in edges_to_fail:
                    u_idx, v_idx = node_to_idx[u], node_to_idx[v]
                    
                    for i in range(edge_index.size(1)):
                        e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
                        if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                            valid_edge_mask[i] = False
                            failed_edge_indices.append(i)
                
                # Create data object
                data = Data(
                    x=node_features,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    y=y,
                    valid_edge_mask=valid_edge_mask,
                    failure_type=torch.tensor([1]),
                    failed_elements=failed_edge_indices,
                    num_failures=torch.tensor([len(edges_to_fail)], dtype=torch.long)
                )
            
            # Add to appropriate list
            if y.sum().item() > 0:
                examples_with_isolations.append(data)
            else:
                examples_without_isolations.append(data)
    
    # Balance the dataset based on min_isolation_ratio
    isolation_count = len(examples_with_isolations)
    total_count = isolation_count + len(examples_without_isolations)
    current_ratio = isolation_count / total_count
    
    print(f"Current isolation ratio: {current_ratio:.2f}, target: {min_isolation_ratio:.2f}")
    
    if current_ratio < min_isolation_ratio and len(examples_without_isolations) > 0:
        # Need to remove some non-isolation examples
        target_non_isolation = int(isolation_count * (1 - min_isolation_ratio) / min_isolation_ratio)
        if target_non_isolation < len(examples_without_isolations):
            print(f"Downsampling non-isolation examples from {len(examples_without_isolations)} to {target_non_isolation}")
            examples_without_isolations = random.sample(examples_without_isolations, target_non_isolation)
    
    # Combine and shuffle the final dataset
    final_data_list = examples_with_isolations + examples_without_isolations
    random.shuffle(final_data_list)
    
    # Print dataset statistics
    total_examples = len(final_data_list)
    total_isolations = len(examples_with_isolations)
    total_non_isolations = len(examples_without_isolations)
    
    print("\n=== Dataset Statistics ===")
    print(f"Total examples: {total_examples}")
    print(f"Examples with isolations: {total_isolations} ({total_isolations/total_examples*100:.1f}%)")
    print(f"Examples without isolations: {total_non_isolations} ({total_non_isolations/total_examples*100:.1f}%)")
    
    # Calculate average isolations per positive example
    avg_isolations = sum(data.y.sum().item() for data in examples_with_isolations) / max(1, total_isolations)
    print(f"Average isolated nodes per positive example: {avg_isolations:.2f}")
    
    # Detailed breakdown by failure type
    node_failures = sum(1 for data in final_data_list if data.failure_type.item() == 0)
    edge_failures = sum(1 for data in final_data_list if data.failure_type.item() == 1)
    
    node_isolations = sum(1 for data in examples_with_isolations if data.failure_type.item() == 0)
    edge_isolations = sum(1 for data in examples_with_isolations if data.failure_type.item() == 1)
    
    print("\nExamples by failure type:")
    print(f"  Node failures: {node_failures} ({node_failures/total_examples*100:.1f}%)")
    print(f"  Edge failures: {edge_failures} ({edge_failures/total_examples*100:.1f}%)")
    
    print("\nIsolation examples by failure type:")
    print(f"  Node failures with isolations: {node_isolations} ({node_isolations/total_isolations*100:.1f}%)")
    print(f"  Edge failures with isolations: {edge_isolations} ({edge_isolations/total_isolations*100:.1f}%)")
    
    # Breakdown by number of failures
    failure_counts = {}
    for data in final_data_list:
        num_fails = data.num_failures.item()
        has_isolations = data.y.sum().item() > 0
        
        key = f"{num_fails}_{'iso' if has_isolations else 'non'}"
        if key not in failure_counts:
            failure_counts[key] = 0
        failure_counts[key] += 1
    
    print("\nExamples by number of failures:")
    for k in sorted(failure_counts.keys()):
        num, iso_status = k.split('_')
        status = "with isolations" if iso_status == 'iso' else "without isolations"
        print(f"  {num} failures {status}: {failure_counts[k]}")
    
    return final_data_list, node_list, node_to_idx

In [36]:
import hashlib


def train_isolation_model(data_list, epochs=50):
    # Split data
    total = len(data_list)
    train_size = int(0.7 * total)
    val_size = int(0.15 * total)
    test_size = total - train_size - val_size
    
    # Shuffle and split
    random.shuffle(data_list)
    train_data = data_list[:train_size]
    val_data = data_list[train_size:train_size+val_size]
    test_data = data_list[train_size+val_size:]
    
    # Count positives in training set
    positive_count = sum(data.y.sum().item() > 0 for data in train_data)
    positive_rate = positive_count / len(train_data)
    print(f"Training set: {len(train_data)} examples, {positive_count} with isolations ({positive_rate:.2%})")
    
    # Create data loaders
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
    
    # Initialize model with edge features if available
    edge_dim = train_data[0].edge_attr.size(1) if hasattr(train_data[0], 'edge_attr') else None
    model = ImprovedGNN(num_node_features=train_data[0].x.size(1), edge_dim=edge_dim)

    
    # Calculate positive weight based on class imbalance
    # This is crucial! Use a large value (50.0) to force the model to predict positives
    pos_weight = torch.tensor(40.0)  # Very high weight for rare isolation cases
    
    # You could also calculate from your data:
    # pos_samples = sum(d.y.sum().item() for d in train_data)
    # neg_samples = sum((~d.y.bool()).sum().item() for d in train_data)
    # pos_weight = torch.tensor(neg_samples / max(1, pos_samples))
    
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    
    # Use Adam with weight decay and a lower learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=10, verbose=True
    )
    
    # Training loop
    best_f1 = 0
    patience = 10
    patience_counter = 0
    
    for epoch in range(epochs):
        # Training
        model.train()
        total_loss = 0
        
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch)
            
            # Flatten predictions and targets
            pred = out.squeeze()
            target = batch.y
            
            loss = criterion(pred, target)
            loss.backward()
            
            # Clip gradients to prevent explosion
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            total_loss += loss.item() * batch.num_graphs
        
        train_loss = total_loss / len(train_loader.dataset)
        
        # Validation
        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        
        with torch.no_grad():
            for batch in val_loader:
                out = model(batch)
                pred = out.squeeze()
                target = batch.y
                
                loss = criterion(pred, target)
                val_loss += loss.item() * batch.num_graphs
                
                # Use a MUCH LOWER threshold (0.1 or 0.2) to detect isolations
                # This counteracts the model's tendency to predict all negatives
                binary_pred = (torch.sigmoid(pred) > 0.2).float()
                
                all_preds.append(binary_pred)
                all_targets.append(target)
        
        val_loss /= len(val_loader.dataset)
        
        # Concatenate predictions and targets
        all_preds = torch.cat(all_preds, dim=0)
        all_targets = torch.cat(all_targets, dim=0)
        
        # Calculate metrics - but handle the case of no positive predictions
        tp = ((all_preds == 1) & (all_targets == 1)).sum().item()
        fp = ((all_preds == 1) & (all_targets == 0)).sum().item()
        fn = ((all_preds == 0) & (all_targets == 1)).sum().item()
        tn = ((all_preds == 0) & (all_targets == 0)).sum().item()
        
        precision = tp / max(tp + fp, 1)
        recall = tp / max(tp + fn, 1)
        f1 = 2 * precision * recall / max(precision + recall, 1e-8)
        accuracy = (tp + tn) / (tp + tn + fp + fn)
        
        # Update learning rate based on validation loss
        scheduler.step(val_loss)
        
        print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}")
        
        # Early stopping based on F1 score
        if f1 > best_f1:
            best_f1 = f1
            patience_counter = 0
            torch.save(model.state_dict(), "best_isolation_model.pt")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}. Best F1: {best_f1:.4f}")
                model.load_state_dict(torch.load("best_isolation_model.pt"))
                break
    
    return model

In [37]:
def create_test_example(G, node_list, node_to_idx, failure_type, elements_to_fail):
    """
    Create a PyG Data object for testing the model with specific failures.
    
    Args:
        G: NetworkX graph
        node_list: List of all nodes in the graph
        node_to_idx: Mapping from node names to indices
        failure_type: 'node' or 'edge'
        elements_to_fail: List of nodes or edges to fail
        
    Returns:
        PyG Data object ready for model inference
    """
    # Create node features - must match the exact features used in training
    node_features = []
    for node in node_list:
        features = []
        
        # Basic features
        pr = G.nodes[node].get('physicalringname', 'unknown')
        pr_hash = int(hashlib.md5(pr.encode()).hexdigest(), 16) % 10
        features.append(pr_hash/10.0)
        
        lr = G.nodes[node].get('lrname', 'unknown')
        lr_hash = int(hashlib.md5(lr.encode()).hexdigest(), 16) % 10
        features.append(lr_hash/10.0)
        
        is_block = 1.0 if "BLOCK" in node else 0.0
        features.append(is_block)
        
        # Add the same advanced features used in your training dataset
        if is_block:
            # Add placeholder values for block nodes
            features.extend([0.0] * 8)  # Must match the number used in training
        else:
            # Calculate same advanced features
            degree = G.degree(node) / max(1, max(dict(G.degree()).values()))
            features.append(degree)
            
            clustering = nx.clustering(G, node)
            features.append(clustering)
            
            # Precompute these if needed for efficiency
            articulation_points = set(nx.articulation_points(G))
            bridges = list(nx.bridges(G))
            
            is_articulation = 1.0 if node in articulation_points else 0.0
            features.append(is_articulation)
            
            bridges_connected = min(1.0, sum(1 for u, v in bridges if node in (u, v)) / 5.0)
            features.append(bridges_connected)
            
            # Add the rest of your advanced features here
            # For example:
            block = G.nodes[node].get('block_name', None)
            if block:
                try:
                    path_length = nx.shortest_path_length(G, node, block) / 10.0
                    disjoint_paths = len(list(nx.edge_disjoint_paths(G, node, block))) / 5.0
                    features.append(path_length)
                    features.append(disjoint_paths)
                except:
                    features.append(1.0)  # Max path length
                    features.append(0.0)  # No paths
            else:
                features.append(1.0)
                features.append(0.0)
                
            # Isolation risk features
            single_path_vulnerability = 1.0 if disjoint_paths <= 1 else 0.0
            features.append(single_path_vulnerability)
            
            # Add other isolation risk features
            neighbors = list(G.neighbors(node))
            connects_to_articulation = 1.0 if any(n in articulation_points for n in neighbors) else 0.0
            features.append(connects_to_articulation)
            
            isolation_risk = max(single_path_vulnerability, is_articulation, bridges_connected)
            features.append(isolation_risk)
            
        node_features.append(features)
    
    # Convert to tensor
    node_features = torch.tensor(node_features, dtype=torch.float)
    
    # Create edge index
    edge_list = []
    for u, v in G.edges():
        u_idx, v_idx = node_to_idx[u], node_to_idx[v]
        edge_list.append([u_idx, v_idx])
        edge_list.append([v_idx, u_idx])  # Add both directions
    
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    
    # Create edge features if used in your model
    edge_attr = None
    if hasattr(G.edges[list(G.edges())[0]], 'features'):
        edge_attr = []
        for i, (u, v) in enumerate(edge_list):
            u_name, v_name = node_list[u], node_list[v]
            if G.has_edge(u_name, v_name):
                edge_attr.append(G.edges[u_name, v_name].get('features', [0.0]))
            else:
                edge_attr.append([0.0])
        edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    
    # Create edge mask for failure simulation
    valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
    
    if failure_type == 'node':
        for node in elements_to_fail:
            node_idx = node_to_idx[node]
            invalid_edges = (edge_index[0] == node_idx) | (edge_index[1] == node_idx)
            valid_edge_mask[invalid_edges] = False
    else:  # edge failure
        for u, v in elements_to_fail:
            u_idx, v_idx = node_to_idx[u], node_to_idx[v]
            for i in range(edge_index.size(1)):
                e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
                if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                    valid_edge_mask[i] = False
    
    # Create Data object
    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_attr,
        valid_edge_mask=valid_edge_mask,
        failure_type=torch.tensor([0 if failure_type == 'node' else 1]),
        failed_elements=[node_to_idx[n] if failure_type == 'node' 
                        else [node_to_idx[u], node_to_idx[v]] for n in elements_to_fail]
    )
    
    return data

In [38]:
import random
import torch
import copy
from torch_geometric.loader import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv

# 1. Create better dataset with proper balance
data_list, node_list, node_to_idx = create_isolation_prediction_dataset(
    topo_data,
    num_simulations=5000,
    max_failures=20,
    min_isolation_ratio=0.5 # Force 50% of examples to have isolations
)


Building network graph...
Computing network vulnerability metrics...
Calculating node vulnerability scores...
Creating node features...
Generating dataset examples...
Generated 100 examples, 65 with isolations (0.65)


KeyboardInterrupt: 

In [39]:

# 2. Train with the improved loss function
model = train_isolation_model(data_list, epochs=100)

# 3. Test with a known isolation case (lower threshold!)
def test_model_with_known_case():
    # Find a node that causes isolations when it fails
    for node in G.nodes():
        if "BLOCK" not in node:
            isolated = simulate_multiple_failures(G, [node], 'node')
            if len(isolated) > 0:
                print(f"Testing with node {node} which should cause {len(isolated)} isolations")
                
                # Create test data
                test_data = create_test_example(G, node_list, node_to_idx, 'node', [node])
                
                # Predict
                model.eval()
                with torch.no_grad():
                    pred = model(test_data)
                    # Use lower threshold here!
                    binary_pred = (torch.sigmoid(pred.squeeze()) > 0.2).float()
                
                # Check results
                predicted_nodes = [node_list[i] for i in range(len(node_list)) if binary_pred[i] > 0.5]
                print(f"Actual isolated: {isolated}")
                print(f"Predicted isolated: {predicted_nodes}")
                return

Training set: 3500 examples, 2201 with isolations (62.89%)




Epoch 0: Train Loss: 1.0576, Val Loss: 1.0294, Precision: 0.0236, Recall: 1.0000, F1: 0.0461, Accuracy: 0.2441
Epoch 1: Train Loss: 1.0311, Val Loss: 1.0282, Precision: 0.0236, Recall: 1.0000, F1: 0.0461, Accuracy: 0.2441
Epoch 2: Train Loss: 1.0296, Val Loss: 1.0267, Precision: 0.0236, Recall: 1.0000, F1: 0.0461, Accuracy: 0.2441
Epoch 3: Train Loss: 1.0289, Val Loss: 1.0259, Precision: 0.0236, Recall: 1.0000, F1: 0.0461, Accuracy: 0.2441
Epoch 4: Train Loss: 1.0282, Val Loss: 1.0258, Precision: 0.0236, Recall: 1.0000, F1: 0.0461, Accuracy: 0.2441
Epoch 5: Train Loss: 1.0284, Val Loss: 1.0256, Precision: 0.0236, Recall: 1.0000, F1: 0.0461, Accuracy: 0.2441
Epoch 6: Train Loss: 1.0276, Val Loss: 1.0258, Precision: 0.0236, Recall: 1.0000, F1: 0.0461, Accuracy: 0.2441
Epoch 7: Train Loss: 1.0279, Val Loss: 1.0261, Precision: 0.0236, Recall: 1.0000, F1: 0.0461, Accuracy: 0.2441
Epoch 8: Train Loss: 1.0274, Val Loss: 1.0255, Precision: 0.0236, Recall: 1.0000, F1: 0.0461, Accuracy: 0.2441
E

In [49]:
def predict_isolated_nodes(model, G, node_list, node_to_idx, failure_type, elements_to_fail, threshold=0.2):
    """
    Predict which nodes will become isolated after failures using a trained GNN model.
    
    Args:
        model: Trained GNN model
        G: NetworkX graph
        node_list: List of all nodes in the graph
        node_to_idx: Mapping from node names to indices
        failure_type: 'node' or 'edge'
        elements_to_fail: Node(s) or edge(s) to fail
        threshold: Probability threshold for considering a node isolated (lower is more sensitive)
    
    Returns:
        List of node names predicted to be isolated
    """
    # Set model to evaluation mode
    model.eval()
    
    # Pre-compute network properties for better feature engineering
    articulation_points = set(nx.articulation_points(G))
    bridges = list(nx.bridges(G))
    
    # Create node features
    node_features = []
    for node in node_list:
        # Feature 1-2: Physical and logical ring
        pr = G.nodes[node].get('physicalringname', 'unknown')
        pr_hash = hash(pr) % 10
        
        lr = G.nodes[node].get('lrname', 'unknown')
        lr_hash = hash(lr) % 10
        
        # Feature 3: Is it a block node?
        is_block = 1.0 if "BLOCK" in node else 0.0
        
        # Feature 4-5: Network properties
        degree = G.degree(node) / max(1, len(G))
        clustering = nx.clustering(G, node)
        
        # Feature 6: Is articulation point?
        is_articulation = 1.0 if node in articulation_points else 0.0
        
        # Feature 7: Connected to bridges
        bridges_connected = min(1.0, sum(1 for u, v in bridges if node in (u, v)) / 5.0)
        
        # Features 8-9: Path to block
        block = G.nodes[node].get('block_name', None)
        if block and not is_block:
            try:
                path_length = nx.shortest_path_length(G, node, block) / 10.0
                disjoint_paths = len(list(nx.edge_disjoint_paths(G, node, block))) / 5.0
            except:
                path_length = 1.0
                disjoint_paths = 0.0
        else:
            path_length = 0.0 if is_block else 1.0
            disjoint_paths = 1.0 if is_block else 0.0
        
        # Feature 10: Single path vulnerability
        single_path_vulnerability = 1.0 if not is_block and disjoint_paths <= 0.2 else 0.0
        
        # Feature 11: Isolation risk
        isolation_risk = max(single_path_vulnerability, is_articulation, bridges_connected) if not is_block else 0.0
        
        # Combine all features
        node_features.append([
            pr_hash/10.0, 
            lr_hash/10.0, 
            is_block, 
            degree, 
            clustering, 
            is_articulation, 
            bridges_connected, 
            path_length, 
            disjoint_paths, 
            single_path_vulnerability, 
            isolation_risk
        ])
    
    node_features = torch.tensor(node_features, dtype=torch.float)
    
    # Create edge index
    edges = []
    for u, v in G.edges():
        edges.append([node_to_idx[u], node_to_idx[v]])
        edges.append([node_to_idx[v], node_to_idx[u]])  # Add reverse edge for undirected graph
    
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    # Create edge attributes - FIXED to have exactly 4 features per edge
    edge_attr = []
    for i in range(edge_index.size(1)):
        u_idx, v_idx = edge_index[0, i].item(), edge_index[1, i].item()
        u_node, v_node = node_list[u_idx], node_list[v_idx]
        
        # Edge features
        # 1. Same physical ring
        same_pr = 1.0 if G.nodes[u_node].get('physicalringname') == G.nodes[v_node].get('physicalringname') else 0.0
        
        # 2. Same logical ring
        same_lr = 1.0 if G.nodes[u_node].get('lrname') == G.nodes[v_node].get('lrname') else 0.0
        
        # 3. Is this a bridge?
        is_bridge = 1.0 if (u_node, v_node) in bridges or (v_node, u_node) in bridges else 0.0
        
        # 4. Block connection
        block_connection = 1.0 if ("BLOCK" in u_node) != ("BLOCK" in v_node) else 0.0
        
        # IMPORTANT: Use exactly 4 features for edge_attr
        edge_attr.append([same_pr, same_lr, is_bridge, block_connection])
    
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    
    # Convert elements_to_fail to list if it's a single element
    if not isinstance(elements_to_fail, list):
        elements_to_fail = [elements_to_fail]
    
    # Create edge mask for failure simulation
    valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
    
    if failure_type == 'node':
        # Mark all edges connected to failed nodes as invalid
        for node in elements_to_fail:
            node_idx = node_to_idx[node]
            invalid_edges = (edge_index[0] == node_idx) | (edge_index[1] == node_idx)
            valid_edge_mask[invalid_edges] = False
    else:  # Edge failure
        for failed_edge in elements_to_fail:
            u_idx, v_idx = node_to_idx[failed_edge[0]], node_to_idx[failed_edge[1]]
            
            for i in range(edge_index.size(1)):
                e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
                if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                    valid_edge_mask[i] = False
    
    # Create PyG Data object for the model
    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_attr,
        valid_edge_mask=valid_edge_mask,
        failure_type=torch.tensor([0 if failure_type == 'node' else 1]),
        num_failures=torch.tensor([len(elements_to_fail)], dtype=torch.long)
    )
    
    # Get model predictions
    with torch.no_grad():
        out = model(data)
        pred_probs = torch.sigmoid(out.squeeze())
        
        # Handle different output shapes (batch vs single)
        if len(pred_probs.shape) > 0:
            pred_probs = pred_probs.numpy()
        else:
            pred_probs = pred_probs.item()
    
    # Find nodes predicted to be isolated
    threshold = 0.7
    predicted_isolated_indices = np.where(pred_probs > threshold)[0]
    predicted_isolated_nodes = [node_list[idx] for idx in predicted_isolated_indices]
    
    return predicted_isolated_nodes

In [26]:
elements_to_fail = ['NODE_2_2_0','NODE_2_2_4']
actual_isolated_nodes = simulate_multiple_failures(G, elements_to_fail, failure_type='node')
actual_isolated_nodes


['NODE_2_2_1', 'NODE_2_2_2', 'NODE_2_2_3']

In [51]:
predicted_isolated_nodes = predict_isolated_nodes(model, G, node_list, node_to_idx, 'node', elements_to_fail)
predicted_isolated_nodes



[]

In [105]:
def evaluate_prediction(G, model, node_list, node_to_idx, failure_type, elements_to_fail):
    """
    Evaluate how well the model predicts isolated nodes compared to actual simulation
    
    Args:
        G: NetworkX graph
        model: Trained GNN model
        node_list: List of all nodes
        node_to_idx: Mapping from node names to indices
        failure_type: 'node' or 'edge'
        elements_to_fail: List of nodes or edges to fail
        
    Returns:
        Dict with evaluation metrics
    """
    # Get actual isolated nodes via simulation
    if failure_type == 'node':
        actual_isolated = simulate_multiple_failures(G, elements_to_fail, 'node')
    else:
        actual_isolated = simulate_multiple_failures(G, elements_to_fail, 'edge')
    
    # Get predicted isolated nodes from model
    predicted_isolated = predict_isolated_nodes(model, G, node_list, node_to_idx, 
                                               failure_type, elements_to_fail)
    
    # Calculate evaluation metrics
    true_positives = len(set(actual_isolated) & set(predicted_isolated))
    false_positives = len(set(predicted_isolated) - set(actual_isolated))
    false_negatives = len(set(actual_isolated) - set(predicted_isolated))
    
    # Handle division by zero
    if true_positives + false_positives > 0:
        precision = true_positives / (true_positives + false_positives)
    else:
        precision = 1.0 if len(actual_isolated) == 0 else 0.0
        
    if true_positives + false_negatives > 0:
        recall = true_positives / (true_positives + false_negatives)
    else:
        recall = 1.0 if len(predicted_isolated) == 0 else 0.0
    
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    
    accuracy = len(set(actual_isolated) & set(predicted_isolated)) / max(1, len(set(actual_isolated) | set(predicted_isolated)))
    
    # Print results
    if failure_type == 'node':
        print(f"Node failure: {', '.join(elements_to_fail)}")
    else:
        print(f"Edge failure: {', '.join([f'({e[0]}-{e[1]})' for e in elements_to_fail])}")
    
    print(f"Actual isolated nodes: {', '.join(actual_isolated) if actual_isolated else 'None'}")
    print(f"Predicted isolated nodes: {', '.join(predicted_isolated) if predicted_isolated else 'None'}")
    print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}, Accuracy: {accuracy:.2f}")
    
    return {
        'actual_isolated': actual_isolated,
        'predicted_isolated': predicted_isolated,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'accuracy': accuracy
    }

In [106]:
import random


In [107]:
# Test single node failures
print("\n=== TESTING SINGLE NODE FAILURES ===")
for node in random.sample(node_list, 3):  # Test 3 random nodes
    if "BLOCK" not in node:  # Skip block nodes
        evaluate_prediction(G, model, node_list, node_to_idx, 'node', [node])

# Test double node failures (more likely to cause isolations)
print("\n=== TESTING DOUBLE NODE FAILURES ===")
# Find nodes in the same logical ring
for pr_idx in range(2):
    for lr_idx in range(2):
        ring_nodes = [n for n in G.nodes() 
                     if 'physicalringname' in G.nodes[n] and 
                        G.nodes[n]['physicalringname'] == f"RING_{pr_idx}" and
                        'lrname' in G.nodes[n] and
                        G.nodes[n]['lrname'] == f"LR_{pr_idx}_{lr_idx}" and
                        "BLOCK" not in n]
        
        if len(ring_nodes) >= 3:
            # Pick two adjacent nodes in the ring
            nodes_to_fail = [ring_nodes[1], ring_nodes[2]]
            evaluate_prediction(G, model, node_list, node_to_idx, 'node', nodes_to_fail)
            break

# Test single edge failures
print("\n=== TESTING SINGLE EDGE FAILURES ===")
for edge in random.sample(list(G.edges()), 3):
    evaluate_prediction(G, model, node_list, node_to_idx, 'edge', [edge])

# Test double edge failures
print("\n=== TESTING DOUBLE EDGE FAILURES ===")
# Find edges in the same logical ring
edge_pairs = []
for pr_idx in range(2):
    for lr_idx in range(2):
        ring_edges = [(u, v) for u, v in G.edges() 
                      if 'physicalringname' in G.nodes[u] and 
                         G.nodes[u]['physicalringname'] == f"RING_{pr_idx}" and
                         'lrname' in G.nodes[u] and 
                         G.nodes[u]['lrname'] == f"LR_{pr_idx}_{lr_idx}"]
        
        if len(ring_edges) >= 2:
            edges_to_fail = [ring_edges[0], ring_edges[1]]
            evaluate_prediction(G, model, node_list, node_to_idx, 'edge', edges_to_fail)
            break


=== TESTING SINGLE NODE FAILURES ===
Node failure: NODE_0_0_4
Actual isolated nodes: None
Predicted isolated nodes: None
Precision: 1.00, Recall: 1.00, F1: 1.00, Accuracy: 0.00
Node failure: NODE_2_2_3
Actual isolated nodes: None
Predicted isolated nodes: None
Precision: 1.00, Recall: 1.00, F1: 1.00, Accuracy: 0.00
Node failure: NODE_1_2_1
Actual isolated nodes: None
Predicted isolated nodes: None
Precision: 1.00, Recall: 1.00, F1: 1.00, Accuracy: 0.00

=== TESTING DOUBLE NODE FAILURES ===
Node failure: NODE_0_0_1, NODE_0_0_2
Actual isolated nodes: None
Predicted isolated nodes: None
Precision: 1.00, Recall: 1.00, F1: 1.00, Accuracy: 0.00
Node failure: NODE_1_0_1, NODE_1_0_2
Actual isolated nodes: None
Predicted isolated nodes: None
Precision: 1.00, Recall: 1.00, F1: 1.00, Accuracy: 0.00

=== TESTING SINGLE EDGE FAILURES ===
Edge failure: (NODE_0_1_3-NODE_0_1_4)
Actual isolated nodes: None
Predicted isolated nodes: None
Precision: 1.00, Recall: 1.00, F1: 1.00, Accuracy: 0.00
Edge fail

In [92]:
model.train()
for epoch in range(100):
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out.squeeze(), batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    # Validation
    if epoch % 5 == 0:
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                out = model(batch)
                val_loss += criterion(out.squeeze(), batch.y).item()
        
        print(f"Epoch {epoch}: Train Loss: {total_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}")
        model.train()

KeyError: 'failed_edge_mask'

In [5]:
def create_gnn_dataset(topology_df, num_simulations=100):
    """
    Create a dataset for training a GNN to predict isolated nodes
    
    Returns:
        - List of PyTorch Geometric Data objects
    """
    # Build the graph
    G = build_network_graph(topology_df)
    
    # Node mapping (for creating numerical indices)
    node_list = list(G.nodes())
    node_to_idx = {node: i for i, node in enumerate(node_list)}
    
    # Create node features
    node_features = []
    for node in node_list:
        # Feature 1: One-hot encoded physical ring
        pr = G.nodes[node]['physicalringname']
        pr_hash = hash(pr) % 10  # Simple encoding
        
        # Feature 2: One-hot encoded logical ring
        lr = G.nodes[node]['lrname']
        lr_hash = hash(lr) % 10  # Simple encoding
        
        # Feature 3: Is it a block node?
        is_block = 1.0 if node == G.nodes[node]['block_name'] else 0.0
        
        # Feature 4-5: Degree centrality and clustering coefficient
        degree = G.degree(node) / len(G)
        clustering = nx.clustering(G, node)
        
        node_features.append([pr_hash/10.0, lr_hash/10.0, is_block, degree, clustering])
    
    node_features = torch.tensor(node_features, dtype=torch.float)
    
    # Create edge index for PyG
    edges = []
    for u, v in G.edges():
        edges.append([node_to_idx[u], node_to_idx[v]])
        edges.append([node_to_idx[v], node_to_idx[u]])  # Add reverse edge for undirected graph
    
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    # Create dataset entries
    data_list = []
    
    # Simulate node failures
    for _ in range(num_simulations // 2):  # Half node failures, half edge failures
        # Randomly select a node to fail
        node_to_fail = np.random.choice(node_list)
        
        # Find isolated nodes
        isolated_nodes = simulate_node_failure(G, node_to_fail)
        
        # Create target tensor (1 for isolated nodes, 0 for others)
        y = torch.zeros(len(node_list), dtype=torch.float)
        for node in isolated_nodes:
            y[node_to_idx[node]] = 1.0
        
        # Create node failure mask
        node_mask = torch.zeros(len(node_list), dtype=torch.bool)
        node_mask[node_to_idx[node_to_fail]] = True
        
        # Create PyG Data object
        data = Data(
            x=node_features,
            edge_index=edge_index,
            y=y,
            failed_node_mask=node_mask,
            failed_edge_mask=None
        )
        
        data_list.append(data)
    
    # Simulate edge failures
    edge_list = list(G.edges())
    for _ in range(num_simulations // 2):
        # Randomly select an edge to fail
        edge_to_fail = edge_list[np.random.randint(0, len(edge_list))]
        
        # Find isolated nodes
        isolated_nodes = simulate_edge_failure(G, edge_to_fail)
        
        # Create target tensor
        y = torch.zeros(len(node_list), dtype=torch.float)
        for node in isolated_nodes:
            y[node_to_idx[node]] = 1.0
        
        # Create edge failure mask
        edge_mask = torch.zeros(edge_index.size(1), dtype=torch.bool)
        u_idx, v_idx = node_to_idx[edge_to_fail[0]], node_to_idx[edge_to_fail[1]]
        
        for i in range(edge_index.size(1)):
            e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
            if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                edge_mask[i] = True
        
        # Create PyG Data object
        data = Data(
            x=node_features,
            edge_index=edge_index,
            y=y,
            failed_node_mask=None,
            failed_edge_mask=edge_mask
        )
        
        data_list.append(data)
    
    return data_list


In [None]:
import mysql.connector
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

def load_topology_data(config):
    """Load topology data from MySQL database"""
    connection = mysql.connector.connect(**config)
    cursor = connection.cursor(dictionary=True)
    
    query = """
        SELECT 
           aendname, 
           bendname, 
           aendip, 
           bendip, 
           aendifIndex,
           bendifIndex,
           block_name, 
           physicalringname, 
           lrname 
        FROM topology_data_logical
    """
    cursor.execute(query)
    rows = cursor.fetchall()
    connection.close()
    
    return pd.DataFrame(rows)


In [40]:
def create_position_based_dataset(topology_df, num_simulations=5000):
    """
    Create a dataset focused only on sequential position, logical ring ID and physical ring ID
    
    Args:
        topology_df: DataFrame with topology data
        num_simulations: Number of simulations to generate
        
    Returns:
        data_list: List of PyG Data objects for training
        node_list: List of all node names in order
        node_to_idx: Dictionary mapping node names to indices
    """
    print("Building network graph...")
    G = build_network_graph(topology_df)
    
    # Node mapping
    node_list = list(G.nodes())
    node_to_idx = {node: i for i, node in enumerate(node_list)}
    
    # Extract physical ring, logical ring, and position information from node names
    node_features = []
    ring_groups = {}  # Group nodes by ring
    
    for node in node_list:
        # Parse node information from name (NODE_PR_LR_POS) or get from attributes
        if "NODE_" in node:
            parts = node.split("_")
            if len(parts) >= 4:  # Has the pattern NODE_PR_LR_POS
                pr_id = int(parts[1])
                lr_id = int(parts[2])
                pos = int(parts[3])
                is_block = 0.0
            else:
                # Unknown format, use defaults
                pr_id = 0
                lr_id = 0
                pos = 0
                is_block = 0.0
        elif "BLOCK_" in node:
            # Block nodes
            parts = node.split("_")
            pr_id = int(parts[1]) if len(parts) > 1 else 0
            lr_id = 0  # Blocks aren't part of a specific logical ring
            pos = -1  # Special position for blocks
            is_block = 1.0
        else:
            # Default values if pattern doesn't match
            pr_id = 0
            lr_id = 0
            pos = 0
            is_block = 0.0
        
        # Normalize features
        pr_norm = pr_id / 10.0  # Assuming fewer than 10 physical rings
        lr_norm = lr_id / 10.0  # Assuming fewer than 10 logical rings per physical ring
        pos_norm = pos / 10.0   # Assuming fewer than 10 positions per logical ring
        
        # Create feature vector [physical_ring_id, logical_ring_id, position, is_block]
        features = [pr_norm, lr_norm, pos_norm, is_block]
        node_features.append(features)
        
        # Group nodes by physical ring and logical ring for isolation logic
        if not is_block:
            key = (pr_id, lr_id)
            if key not in ring_groups:
                ring_groups[key] = []
            ring_groups[key].append((node, pos))
    
    # Convert to tensor
    node_features = torch.tensor(node_features, dtype=torch.float)
    
    # Create edge index
    edges = []
    for u, v in G.edges():
        edges.append([node_to_idx[u], node_to_idx[v]])
        edges.append([node_to_idx[v], node_to_idx[u]])  # Add reverse edge for undirected graph
    
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    # Sort nodes within each ring group by position for sequential isolation logic
    for key in ring_groups:
        ring_groups[key].sort(key=lambda x: x[1])  # Sort by position
    
    print("Generating dataset examples...")
    data_list = []
    examples_with_isolations = 0
    
    for i in range(num_simulations):
        # Choose a random ring that has enough nodes
        valid_rings = [key for key, nodes in ring_groups.items() if len(nodes) >= 4]
        if not valid_rings:
            continue
            
        ring_key = random.choice(valid_rings)
        ring_nodes = ring_groups[ring_key]
        
        # Choose two nodes to fail with positions that have nodes between them
        positions = sorted([node[1] for node in ring_nodes])
        if len(positions) < 3:
            continue
            
        # Choose two positions with at least one node between them
        pos1_idx = random.randint(0, len(positions) - 3)
        pos2_idx = random.randint(pos1_idx + 2, len(positions) - 1)
        
        pos1 = positions[pos1_idx]
        pos2 = positions[pos2_idx]
        
        # Find the corresponding nodes
        node1 = next(node for node, pos in ring_nodes if pos == pos1)
        node2 = next(node for node, pos in ring_nodes if pos == pos2)
        
        # Find nodes that should be isolated (positions between pos1 and pos2)
        isolated_nodes = [
            node for node, pos in ring_nodes 
            if pos > pos1 and pos < pos2  # Position is between the two failed nodes
        ]
        
        # Create target tensor
        y = torch.zeros(len(node_list), dtype=torch.float)
        for node in isolated_nodes:
            y[node_to_idx[node]] = 1.0
        
        # Create node failure mask
        node_mask = torch.zeros(len(node_list), dtype=torch.bool)
        node_mask[node_to_idx[node1]] = True
        node_mask[node_to_idx[node2]] = True
        
        # Create PyG Data object
        data = Data(
            x=node_features,
            edge_index=edge_index,
            y=y,
            failed_nodes=torch.tensor([node_to_idx[node1], node_to_idx[node2]]),
            failure_positions=torch.tensor([pos1, pos2]),
            ring_key=torch.tensor(list(ring_key))
        )
        
        data_list.append(data)
        
        if len(isolated_nodes) > 0:
            examples_with_isolations += 1
            
        # Print progress
        if (i + 1) % 500 == 0:
            print(f"Generated {i+1}/{num_simulations} examples, {examples_with_isolations} with isolations")
    
    print(f"Training set: {len(data_list)} examples, {examples_with_isolations} with isolations ({examples_with_isolations/len(data_list)*100:.2f}%)")
    
    return data_list, node_list, node_to_idx

In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, DataLoader

class PositionalIsolationModel(nn.Module):
    def __init__(self, num_features=4):
        super(PositionalIsolationModel, self).__init__()
        # Simple architecture with position awareness
        self.conv1 = GCNConv(num_features, 32)
        
        # Position-aware layers
        self.position_encoder = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 32)
        )
        
        # Final prediction layer
        self.predictor = nn.Linear(32, 1)
    
    def forward(self, x, edge_index, failed_nodes=None):
        # Basic feature extraction
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        
        # Position encoding
        x = self.position_encoder(x)
        
        # Final prediction
        x = self.predictor(x)
        
        return torch.sigmoid(x)


In [None]:
def train_position_model(data_list, epochs=50):
    # Split data
    train_size = int(0.8 * len(data_list))
    train_data = data_list[:train_size]
    val_data = data_list[train_size:]
    
    # Create loaders
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32)
    
    # Create model
    model = PositionalIsolationModel()
    
    # Use weighted loss since most nodes aren't isolated
    pos_weight = torch.tensor([10.0])
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    
    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index, batch.failed_nodes)
            loss = criterion(out.squeeze(), batch.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        # Validation
        model.eval()
        correct = 0
        total = 0
        val_preds = []
        val_targets = []
        
        with torch.no_grad():
            for batch in val_loader:
                out = model(batch.x, batch.edge_index, batch.failed_nodes)
                pred = (out.squeeze() > 0.5).float()
                val_preds.append(pred)
                val_targets.append(batch.y)
        
        val_preds = torch.cat(val_preds)
        val_targets = torch.cat(val_targets)
        
        # Calculate metrics
        precision = ((val_preds == 1) & (val_targets == 1)).sum().float() / (val_preds.sum() + 1e-6)
        recall = ((val_preds == 1) & (val_targets == 1)).sum().float() / (val_targets.sum() + 1e-6)
        f1 = 2 * precision * recall / (precision + recall + 1e-6)
        accuracy = (val_preds == val_targets).float().mean()
        
        print(f"Epoch {epoch+1}/{epochs}: Loss: {total_loss/len(train_loader):.4f}, "
              f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Acc: {accuracy:.4f}")
    
    return model

In [46]:
data_list,_,_ = create_position_based_dataset(topo_data, num_simulations=5000)

Building network graph...
Generating dataset examples...
Generated 500/5000 examples, 500 with isolations
Generated 1000/5000 examples, 1000 with isolations
Generated 1500/5000 examples, 1500 with isolations
Generated 2000/5000 examples, 2000 with isolations
Generated 2500/5000 examples, 2500 with isolations
Generated 3000/5000 examples, 3000 with isolations
Generated 3500/5000 examples, 3500 with isolations
Generated 4000/5000 examples, 4000 with isolations
Generated 4500/5000 examples, 4500 with isolations
Generated 5000/5000 examples, 5000 with isolations
Training set: 5000 examples, 5000 with isolations (100.00%)


In [47]:
train_position_model(data_list, epochs=100)

Epoch 1/100: Loss: 0.8259, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9913
Epoch 2/100: Loss: 0.6949, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9913
Epoch 3/100: Loss: 0.6933, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9913
Epoch 4/100: Loss: 0.6932, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9913
Epoch 5/100: Loss: 0.6932, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9913
Epoch 6/100: Loss: 0.6932, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9913
Epoch 7/100: Loss: 0.6932, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9913
Epoch 8/100: Loss: 0.6932, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9913
Epoch 9/100: Loss: 0.6932, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9913
Epoch 10/100: Loss: 0.6932, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9913
Early stopping at epoch 10


FileNotFoundError: [Errno 2] No such file or directory: 'best_position_model.pt'

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
import networkx as nx
import numpy as np
import random

class DirectPositionalModel(nn.Module):
    def __init__(self, hidden_size=64):
        super(DirectPositionalModel, self).__init__()
        # Simpler, more direct architecture
        self.fc1 = nn.Linear(5, hidden_size)  # 5 features explained below
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
    
    def forward(self, x):
        # No graph structure, just direct positional reasoning
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def create_direct_dataset(topology_df, num_examples=1000):
    """Create dataset with explicit positional features"""
    # Extract nodes with position information
    nodes_by_ring = {}
    
    for _, row in topology_df.iterrows():
        # Extract ring IDs and positions from node names
        for node_col in ['aendname', 'bendname']:
            node = row[node_col]
            if not isinstance(node, str) or 'NODE_' not in node:
                continue
                
            parts = node.split('_')
            if len(parts) >= 4:
                try:
                    pr_id = int(parts[1])
                    lr_id = int(parts[2]) 
                    pos = int(parts[3])
                    
                    key = (pr_id, lr_id)
                    if key not in nodes_by_ring:
                        nodes_by_ring[key] = []
                    
                    nodes_by_ring[key].append((node, pos))
                except ValueError:
                    continue
    
    # Create dataset
    X = []
    Y = []
    
    for _ in range(num_examples):
        # Pick a ring with enough nodes
        valid_rings = [ring for ring, nodes in nodes_by_ring.items() if len(nodes) >= 4]
        if not valid_rings:
            continue
            
        ring = random.choice(valid_rings)
        nodes = nodes_by_ring[ring]
        
        # Sort nodes by position
        nodes.sort(key=lambda x: x[1])
        
        # Pick two positions to fail with at least one node between them
        idx1 = random.randint(0, len(nodes) - 3)
        idx2 = random.randint(idx1 + 2, len(nodes) - 1)
        
        fail_node1, pos1 = nodes[idx1]
        fail_node2, pos2 = nodes[idx2]
        
        # Create features for all nodes in this ring
        for node, pos in nodes:
            # Key features:
            # 1. Is this node between the two failed positions?
            is_between = 1.0 if pos1 < pos < pos2 else 0.0
            
            # 2. Normalized position within the ring
            norm_pos = pos / 100.0  # Scale position
            
            # 3. Normalized distance from failed nodes
            dist_to_fail1 = abs(pos - pos1) / 100.0
            dist_to_fail2 = abs(pos - pos2) / 100.0
            
            # 4. Is this one of the failed nodes?
            is_failed = 1.0 if node in [fail_node1, fail_node2] else 0.0
            
            # Feature vector
            features = [norm_pos, dist_to_fail1, dist_to_fail2, is_failed, is_between]
            X.append(features)
            
            # Target: 1 if node should be isolated
            Y.append(1.0 if is_between else 0.0)
    
    return torch.tensor(X, dtype=torch.float), torch.tensor(Y, dtype=torch.float).unsqueeze(1)

def train_direct_model(X, Y, test_split=0.2):
    """Train model with direct positional features"""
    # Split data
    indices = torch.randperm(len(X))
    test_size = int(len(X) * test_split)
    train_indices = indices[:-test_size]
    test_indices = indices[-test_size:]
    
    X_train, Y_train = X[train_indices], Y[train_indices]
    X_test, Y_test = X[test_indices], Y[test_indices]
    
    # Create model
    model = DirectPositionalModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Balance loss with high weight for positive class
    pos_weight = (len(Y_train) - Y_train.sum()) / Y_train.sum() if Y_train.sum() > 0 else 10.0
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))
    
    # Train
    batch_size = 64
    epochs = 30
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        # Training batches
        for i in range(0, len(X_train), batch_size):
            batch_X = X_train[i:i+batch_size]
            batch_Y = Y_train[i:i+batch_size]
            
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_Y)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        # Evaluation
        model.eval()
        with torch.no_grad():
            outputs = model(X_test)
            predictions = (torch.sigmoid(outputs) > 0.5).float()
            
            # Metrics
            correct = (predictions == Y_test).sum().item()
            accuracy = correct / len(Y_test)
            
            tp = ((predictions == 1) & (Y_test == 1)).sum().item()
            fp = ((predictions == 1) & (Y_test == 0)).sum().item()
            fn = ((predictions == 0) & (Y_test == 1)).sum().item()
            
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        
        print(f"Epoch {epoch+1}/{epochs}: Loss: {total_loss/len(X_train):.4f}, "
              f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Acc: {accuracy:.4f}")
    
    return model

In [9]:
# Load your topology data


# Create the simplified dataset
X, Y = create_direct_dataset(topo_data, num_examples=5000)

# See the class distribution
print(f"Dataset created with {len(X)} examples")
print(f"Isolated nodes: {Y.sum().item()} ({Y.sum().item()/len(Y)*100:.2f}%)")

# Train the model
model = train_direct_model(X, Y)

Dataset created with 100000 examples
Isolated nodes: 21334.0 (21.33%)
Epoch 1/30: Loss: 0.0010, Precision: 1.0000, Recall: 1.0000, F1: 1.0000, Acc: 1.0000
Epoch 2/30: Loss: 0.0000, Precision: 1.0000, Recall: 1.0000, F1: 1.0000, Acc: 1.0000
Epoch 3/30: Loss: 0.0000, Precision: 1.0000, Recall: 1.0000, F1: 1.0000, Acc: 1.0000
Epoch 4/30: Loss: 0.0000, Precision: 1.0000, Recall: 1.0000, F1: 1.0000, Acc: 1.0000
Epoch 5/30: Loss: 0.0000, Precision: 1.0000, Recall: 1.0000, F1: 1.0000, Acc: 1.0000
Epoch 6/30: Loss: 0.0000, Precision: 1.0000, Recall: 1.0000, F1: 1.0000, Acc: 1.0000
Epoch 7/30: Loss: 0.0000, Precision: 1.0000, Recall: 1.0000, F1: 1.0000, Acc: 1.0000
Epoch 8/30: Loss: 0.0000, Precision: 1.0000, Recall: 1.0000, F1: 1.0000, Acc: 1.0000
Epoch 9/30: Loss: 0.0000, Precision: 1.0000, Recall: 1.0000, F1: 1.0000, Acc: 1.0000
Epoch 10/30: Loss: 0.0000, Precision: 1.0000, Recall: 1.0000, F1: 1.0000, Acc: 1.0000
Epoch 11/30: Loss: 0.0000, Precision: 1.0000, Recall: 1.0000, F1: 1.0000, Acc: 

In [None]:
def predict_ring_isolations(model, topology_df, failed_node1, failed_node2):
    """
    Predict isolations for all nodes in a ring when two specific nodes fail
    
    Args:
        model: Trained DirectPositionalModel
        topology_df: DataFrame with topology data
        failed_node1, failed_node2: The two nodes that will fail
    
    Returns:
        Dictionary of {node_name: isolation_probability}
    """
    # Get node position info
    def get_node_info(node):
        if not isinstance(node, str) or 'NODE_' not in node:
            return None
        parts = node.split('_')
        if len(parts) < 4:
            return None
        try:
            return (int(parts[1]), int(parts[2]), int(parts[3]))
        except ValueError:
            return None
    
    # Get failed nodes' info
    node1_info = get_node_info(failed_node1)
    node2_info = get_node_info(failed_node2)
    
    if not node1_info or not node2_info:
        print("Invalid node names")
        return {}
    
    # Get ring ID and positions
    pr_id, lr_id = node1_info[0], node1_info[1]
    pos1, pos2 = node1_info[2], node2_info[2]
    
    # Ensure pos1 < pos2
    if pos1 > pos2:
        pos1, pos2 = pos2, pos1
        
    # Collect unique nodes in the ring
    ring_nodes = set()
    for _, row in topology_df.iterrows():
        for col in ['aendname', 'bendname']:
            node = row[col]
            info = get_node_info(node)
            if info and info[0] == pr_id and info[1] == lr_id:
                ring_nodes.add(node)
    
    # Generate features for all nodes
    features = []
    node_names = []
    
    for node in ring_nodes:
        info = get_node_info(node)
        if not info:
            continue
            
        pos = info[2]
        
        # Generate feature vector
        norm_pos = pos / 100.0
        dist_to_fail1 = abs(pos - pos1) / 100.0
        dist_to_fail2 = abs(pos - pos2) / 100.0
        is_failed = 1.0 if node in [failed_node1, failed_node2] else 0.0
        is_between = 1.0 if pos1 < pos < pos2 else 0.0
        
        features.append([norm_pos, dist_to_fail1, dist_to_fail2, is_failed, is_between])
        node_names.append(node)
    
    # Make predictions for all nodes
    model.eval()
    with torch.no_grad():
        X = torch.tensor(features, dtype=torch.float)
        outputs = model(X)
        probs = torch.sigmoid(outputs).squeeze().tolist()
    
    # Return results as dictionary
    results = {}
    for i, node in enumerate(node_names):
        # Skip failed nodes themselves
        if node in [failed_node1, failed_node2]:
            continue
        results[node] = probs[i]
    
    return results

In [16]:
predictions = predict_ring_isolations(model, topo_data, "NODE_1_1_6", "NODE_1_1_9")

print("Predicted isolations:")
for node, prob in sorted(predictions.items(), key=lambda x: x[1], reverse=True):
    if prob > 0.5:
        print(f"  {node}: {prob:.4f} - ISOLATED")
    else:
        print(f"  {node}: {prob:.4f}")

# Verify against ground truth


Predicted isolations:
  NODE_1_1_8: 1.0000 - ISOLATED
  NODE_1_1_7: 1.0000 - ISOLATED
  NODE_1_1_5: 0.0000
  NODE_1_1_4: 0.0000
  NODE_1_1_3: 0.0000
  NODE_1_1_2: 0.0000
  NODE_1_1_1: 0.0000
  NODE_1_1_0: 0.0000


In [17]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data, DataLoader
import networkx as nx
import pandas as pd
import numpy as np
import random

# Build network graph from topology data
def build_network_graph(topology_df):
    """Convert topology DataFrame to NetworkX graph with node attributes"""
    G = nx.Graph()
    
    # Add edges from topology data
    for _, row in topology_df.iterrows():
        aend = row['aendname']
        bend = row['bendname']
        
        # Skip if missing node names
        if pd.isna(aend) or pd.isna(bend):
            continue
            
        # Extract attributes from node names (NODE_PR_LR_POS)
        aend_attrs = {}
        bend_attrs = {}
        
        if isinstance(aend, str) and 'NODE_' in aend:
            parts = aend.split('_')
            if len(parts) >= 4:
                aend_attrs['pr_id'] = int(parts[1])
                aend_attrs['lr_id'] = int(parts[2])
                aend_attrs['position'] = int(parts[3])
                aend_attrs['is_block'] = False
                
        if isinstance(bend, str) and 'NODE_' in bend:
            parts = bend.split('_')
            if len(parts) >= 4:
                bend_attrs['pr_id'] = int(parts[1])
                bend_attrs['lr_id'] = int(parts[2])
                bend_attrs['position'] = int(parts[3])
                bend_attrs['is_block'] = False
        
        # Add nodes with attributes
        G.add_node(aend, **(aend_attrs or {}))
        G.add_node(bend, **(bend_attrs or {}))
        
        # Add edge with any additional attributes you want
        G.add_edge(aend, bend)
    
    return G

# GNN model for isolation prediction
class IsolationGNN(torch.nn.Module):
    def __init__(self, num_features, hidden_dim=64):
        super(IsolationGNN, self).__init__()
        # First layer: standard graph convolution
        self.conv1 = GCNConv(num_features, hidden_dim)
        
        # Second layer: attention for understanding important connections
        self.conv2 = GATConv(hidden_dim, hidden_dim, heads=2, dropout=0.2)
        
        # Third layer: combine and refine
        self.conv3 = GCNConv(hidden_dim * 2, hidden_dim)
        
        # Final prediction layer
        self.classifier = torch.nn.Linear(hidden_dim, 1)
    
    def forward(self, x, edge_index, failed_nodes=None):
        # First convolution
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.2, training=self.training)
        
        # Attention layer
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        
        # Final convolution
        x = self.conv3(x, edge_index)
        x = F.relu(x)
        
        # Prediction
        x = self.classifier(x)
        return torch.sigmoid(x)

# Create dataset for GNN training
def create_isolation_dataset(G, num_examples=5000):
    """Generate training examples by simulating node failures"""
    # Create node mappings
    node_list = list(G.nodes())
    node_to_idx = {node: i for i, node in enumerate(node_list)}
    
    # Create node features
    num_nodes = len(node_list)
    node_features = torch.zeros((num_nodes, 6), dtype=torch.float)
    
    # Prepare node features
    for i, node in enumerate(node_list):
        attrs = G.nodes[node]
        # PR ID
        node_features[i, 0] = attrs.get('pr_id', 0) / 10.0
        # LR ID
        node_features[i, 1] = attrs.get('lr_id', 0) / 10.0
        # Position
        node_features[i, 2] = attrs.get('position', 0) / 100.0
        # Is block node
        node_features[i, 3] = 1.0 if attrs.get('is_block', False) else 0.0
        # Node degree (normalized)
        node_features[i, 4] = G.degree(node) / 5.0
        # Status (0=normal, 1=failed)
        node_features[i, 5] = 0.0  # Will be set to 1.0 for failed nodes
    
    # Create edge index
    edges = []
    for u, v in G.edges():
        edges.append([node_to_idx[u], node_to_idx[v]])
        edges.append([node_to_idx[v], node_to_idx[u]])  # Undirected graph
    
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    # Group nodes by ring for failure simulation
    ring_groups = {}
    for node in G.nodes():
        attrs = G.nodes[node]
        if 'pr_id' in attrs and 'lr_id' in attrs:
            key = (attrs['pr_id'], attrs['lr_id'])
            if key not in ring_groups:
                ring_groups[key] = []
            ring_groups[key].append(node)
    
    # Sort nodes by position within each ring
    for key in ring_groups:
        ring_groups[key].sort(key=lambda n: G.nodes[n].get('position', 0))
    
    # Generate examples
    data_list = []
    examples_with_isolations = 0
    
    for _ in range(num_examples):
        # Choose a ring with enough nodes
        valid_rings = [k for k, v in ring_groups.items() if len(v) >= 4]
        if not valid_rings:
            continue
            
        ring_key = random.choice(valid_rings)
        ring_nodes = ring_groups[ring_key]
        
        # Choose two positions with nodes between them
        idx1 = random.randint(0, len(ring_nodes) - 3)
        idx2 = random.randint(idx1 + 2, len(ring_nodes) - 1)
        
        node1 = ring_nodes[idx1]
        node2 = ring_nodes[idx2]
        
        # Get positions
        pos1 = G.nodes[node1].get('position', 0)
        pos2 = G.nodes[node2].get('position', 0)
        
        # Find isolated nodes (positions between pos1 and pos2)
        isolated_nodes = []
        for node in ring_nodes:
            if node == node1 or node == node2:
                continue
                
            pos = G.nodes[node].get('position', 0)
            if pos1 < pos < pos2:
                isolated_nodes.append(node)
        
        # Create feature matrix for this example (copy the base features)
        x = node_features.clone()
        
        # Mark failed nodes
        x[node_to_idx[node1], 5] = 1.0
        x[node_to_idx[node2], 5] = 1.0
        
        # Create target tensor
        y = torch.zeros(num_nodes, dtype=torch.float)
        for node in isolated_nodes:
            y[node_to_idx[node]] = 1.0
        
        # Create PyG Data object
        data = Data(
            x=x,
            edge_index=edge_index,
            y=y,
            failed_nodes=torch.tensor([node_to_idx[node1], node_to_idx[node2]], dtype=torch.long)
        )
        
        data_list.append(data)
        
        if isolated_nodes:
            examples_with_isolations += 1
    
    print(f"Dataset: {len(data_list)} examples, {examples_with_isolations} with isolations ({examples_with_isolations/len(data_list)*100:.2f}%)")
    return data_list, node_list, node_to_idx

# Training function
def train_isolation_gnn(data_list, epochs=50, lr=0.001):
    """Train GNN model for isolation prediction"""
    # Split data into training and validation
    random.shuffle(data_list)
    split = int(0.8 * len(data_list))
    train_data = data_list[:split]
    val_data = data_list[split:]
    
    # Create data loaders
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32)
    
    # Create model
    num_features = train_data[0].x.shape[1]
    model = IsolationGNN(num_features)
    
    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    criterion = torch.nn.BCELoss()
    
    # Training loop
    best_f1 = 0
    patience = 10
    patience_counter = 0
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)
            loss = criterion(out.squeeze(), batch.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        # Validation
        model.eval()
        val_preds = []
        val_targets = []
        
        with torch.no_grad():
            for batch in val_loader:
                out = model(batch.x, batch.edge_index)
                pred = (out.squeeze() > 0.5).float()
                val_preds.append(pred)
                val_targets.append(batch.y)
        
        val_preds = torch.cat(val_preds)
        val_targets = torch.cat(val_targets)
        
        # Calculate metrics
        tp = ((val_preds == 1) & (val_targets == 1)).sum().item()
        fp = ((val_preds == 1) & (val_targets == 0)).sum().item()
        fn = ((val_preds == 0) & (val_targets == 1)).sum().item()
        
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        accuracy = (val_preds == val_targets).float().mean().item()
        
        print(f"Epoch {epoch+1}/{epochs}: Loss: {total_loss/len(train_loader):.4f}, "
              f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Acc: {accuracy:.4f}")
        
        # Early stopping
        if f1 > best_f1:
            best_f1 = f1
            patience_counter = 0
            torch.save(model.state_dict(), 'best_isolation_gnn.pt')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered. Best F1: {best_f1:.4f}")
                break
    
    # Load best model
    model.load_state_dict(torch.load('best_isolation_gnn.pt'))
    return model

In [18]:
G = build_network_graph(topo_data)
data_list,_,_ = create_isolation_dataset(G, num_examples=5000)
model = train_isolation_gnn(data_list, epochs=100)

Dataset: 5000 examples, 5000 with isolations (100.00%)




Epoch 1/100: Loss: 0.1496, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9914
Epoch 2/100: Loss: 0.0305, Precision: 1.0000, Recall: 0.0518, F1: 0.0985, Acc: 0.9919
Epoch 3/100: Loss: 0.0232, Precision: 0.4657, Recall: 0.1351, F1: 0.2095, Acc: 0.9912
Epoch 4/100: Loss: 0.0226, Precision: 1.0000, Recall: 0.1351, F1: 0.2381, Acc: 0.9926
Epoch 5/100: Loss: 0.0224, Precision: 1.0000, Recall: 0.0946, F1: 0.1728, Acc: 0.9922
Epoch 6/100: Loss: 0.0220, Precision: 1.0000, Recall: 0.0248, F1: 0.0484, Acc: 0.9916
Epoch 7/100: Loss: 0.0218, Precision: 1.0000, Recall: 0.0019, F1: 0.0037, Acc: 0.9914
Epoch 8/100: Loss: 0.0216, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9914
Epoch 9/100: Loss: 0.0215, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9914
Epoch 10/100: Loss: 0.0213, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9914
Epoch 11/100: Loss: 0.0212, Precision: 0.0000, Recall: 0.0000, F1: 0.0000, Acc: 0.9914
Epoch 12/100: Loss: 0.0210, Precision: 0.0000, Recal

In [24]:
def ring_aware_loss(predictions, targets, node_features, failed_pr_id, failed_lr_id, penalty_weight=10.0):
    """
    Custom loss that penalizes predicting isolation for nodes in different rings.
    Handles batched data properly.
    """
    # Fix shape mismatch by squeezing predictions to match targets
    if predictions.dim() > targets.dim():
        predictions_flat = predictions.squeeze()
    else:
        predictions_flat = predictions
    
    # Use binary_cross_entropy_with_logits for numerical stability
    base_loss = F.binary_cross_entropy_with_logits(predictions_flat, targets)
    
    # Get batch size and nodes per batch
    batch_size = failed_pr_id.size(0)
    nodes_per_batch = node_features.size(0) // batch_size
    
    # Reshape node features for batch processing
    batched_features = node_features.view(batch_size, nodes_per_batch, -1)
    
    # Extract ring IDs from batched features
    pr_ids = batched_features[:, :, 1] * 10.0  # Multiply by 10.0 to reverse scaling
    lr_ids = batched_features[:, :, 2] * 10.0  # Multiply by 10.0 to reverse scaling
    
    # Reshape failed IDs for broadcasting
    failed_pr_id_expanded = failed_pr_id.unsqueeze(1)  # [batch_size, 1]
    failed_lr_id_expanded = failed_lr_id.unsqueeze(1)  # [batch_size, 1]
    
    # Create mask for nodes in different rings (with proper broadcasting)
    diff_ring_mask = ((pr_ids != failed_pr_id_expanded) | (lr_ids != failed_lr_id_expanded)).float()
    
    # Reshape predictions for batch processing
    batched_preds = torch.sigmoid(predictions_flat).view(batch_size, nodes_per_batch)
    
    # Calculate penalty for each batch and sum
    batch_penalties = torch.sum(batched_preds * diff_ring_mask, dim=1)
    penalty = torch.mean(batch_penalties) * penalty_weight / nodes_per_batch
    
    # Combine base loss and penalty
    total_loss = base_loss + penalty
    
    return total_loss, base_loss, penalty

In [86]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data, DataLoader
import networkx as nx
import pandas as pd
import numpy as np
import random
from sklearn.model_selection import train_test_split

class EnhancedGNN(nn.Module):
    def __init__(self, num_node_features):
        super(EnhancedGNN, self).__init__()
        # First layer - extract neighborhood features
        self.conv1 = GCNConv(num_node_features, 64)
        
        # Attention layer to focus on important connections
        self.attention = GATConv(64, 64, heads=2, dropout=0.2)
        
        # Position-aware layer (crucial for your task)
        self.position_encoder = nn.Sequential(
            nn.Linear(64*2, 64), 
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 64)
        )
        
        # Final prediction
        self.classifier = nn.Linear(64, 1)
        
    def forward(self, x, edge_index, failed_nodes_mask=None):
        # Initial feature processing
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.2,training=self.training)
        
        # Attention mechanism
        x = self.attention(x, edge_index)
        x = F.relu(x)
        
        x = self.position_encoder(x)

        
        # Final prediction
        x = self.classifier(x)
        return x

def build_graph_with_position_features(topology_df):
    """Build NetworkX graph with enhanced position features"""
    G = nx.Graph()
    
    # Track ring membership and positions
    ring_positions = {}  # (pr_id, lr_id) -> list of positions
    
    # First pass: identify all rings and node positions
    for _, row in topology_df.iterrows():
        for node_col in ['aendname', 'bendname']:
            node = row[node_col]
            if not isinstance(node, str) or 'NODE_' not in node:
                continue
                
            parts = node.split('_')
            if len(parts) >= 4:
                try:
                    pr_id = int(parts[1])
                    lr_id = int(parts[2])
                    pos = int(parts[3])
                    
                    key = (pr_id, lr_id)
                    if key not in ring_positions:
                        ring_positions[key] = []
                    
                    if pos not in ring_positions[key]:
                        ring_positions[key].append(pos)
                except ValueError:
                    continue
    
    # Sort positions within each ring
    for key in ring_positions:
        ring_positions[key].sort()
    
    # Add nodes and edges with position-aware features
    for _, row in topology_df.iterrows():
        aend = row['aendname']
        bend = row['bendname']
        
        # Add nodes with enhanced features
        for node in [aend, bend]:
            if node in G:
                continue  # Skip if already added
                
            # Default features
            features = {
                'is_block': 'BLOCK' in str(node),
                'pr_id': -1,
                'lr_id': -1,
                'position': -1,
  
            }
            
            # Extract position information
            if isinstance(node, str) and 'NODE_' in node:
                parts = node.split('_')
                if len(parts) >= 4:
                    try:
                        pr_id = int(parts[1])
                        lr_id = int(parts[2])
                        pos = int(parts[3])
                        
                        # Get normalized position (crucial for learning the pattern)


                        
                        features.update({
                            'pr_id': pr_id,
                            'lr_id': lr_id,
                            'position': pos,

                        })
                    except ValueError:
                        pass
            
            G.add_node(node, **features)
        
        # Add edge
        G.add_edge(aend, bend)
    
    return G

def create_enhanced_dataset(G, num_simulations=5000):
    """Create GNN dataset with enhanced position-aware features"""
    # Group nodes by ring
    ring_nodes = {}
    for node, attrs in G.nodes(data=True):
        pr_id = attrs.get('pr_id', -1)
        lr_id = attrs.get('lr_id', -1)
        
        if pr_id >= 0 and lr_id >= 0:
            key = (pr_id, lr_id)
            if key not in ring_nodes:
                ring_nodes[key] = []
            ring_nodes[key].append((node, attrs.get('position', -1)))
    
    # Sort nodes by position within each ring
    for key in ring_nodes:
        ring_nodes[key] = sorted([n for n in ring_nodes[key] if n[1] >= 0], key=lambda x: x[1])
    
    # Create node mapping and feature matrix
    node_list = list(G.nodes())
    node_to_idx = {node: i for i, node in enumerate(node_list)}
    
    # Create edge index
    edge_index = []
    for u, v in G.edges():
        edge_index.append([node_to_idx[u], node_to_idx[v]])
        edge_index.append([node_to_idx[v], node_to_idx[u]])  # Add reverse edge
    
    edge_index = torch.tensor(edge_index, dtype=torch.long).t()
    
    # Create node features - critical for capturing position patterns
    x = []
    for node in node_list:
        attrs = G.nodes[node]
        
        # Core features that help predict isolations
        features = [
            float(attrs.get('is_block', False)),
            attrs.get('pr_id', -1) / 10.0,  # Normalize ring IDs
            attrs.get('lr_id', -1) / 10.0,
            attrs.get('position', -1) / 10.0,  # Normalize position
        ]
        
        x.append(features)
    
    x = torch.tensor(x, dtype=torch.float)
    
    # Create dataset with balanced sampling
    data_list = []
    positive_count = 0
    
    for _ in range(num_simulations):
        # Select a random ring with enough nodes
        valid_rings = [k for k, nodes in ring_nodes.items() if len(nodes) >= 4]
        if not valid_rings:
            continue
            
        ring_key = random.choice(valid_rings)
        nodes = ring_nodes[ring_key]
        
        if len(nodes) < 4:  # Need at least 4 nodes to have isolated nodes
            continue
        
        # Select two positions with nodes between them
        idx1 = random.randint(0, len(nodes) - 3)
        idx2 = random.randint(idx1 + 2, len(nodes) - 1)
        
        fail_node1, pos1 = nodes[idx1]
        fail_node2, pos2 = nodes[idx2]
        pos1 = G.nodes[fail_node1]['position']
        pos2 = G.nodes[fail_node2]['position']
        min_pos = min(pos1, pos2)
        max_pos = max(pos1, pos2)        
        # Create mask for failed nodes
        failed_mask = torch.zeros(len(node_list), dtype=torch.bool)
        failed_mask[node_to_idx[fail_node1]] = True
        failed_mask[node_to_idx[fail_node2]] = True
        
        # Create target: nodes between failed positions are isolated
        y = torch.zeros(len(node_list), dtype=torch.float)
        
        # Nodes between the two failed nodes in the same ring are isolated
        for node, pos in nodes:
            if pos1 < pos < pos2:  # Position is between failed nodes
                y[node_to_idx[node]] = 1.0
                positive_count += 1
        
        # Create additional feature: distance from failed nodes
        dist_features = torch.ones((len(node_list),2),dtype=torch.float)
        
        for node, attrs in G.nodes(data=True):
            node_idx = node_to_idx[node]
            node_pos = attrs.get('position', -1)
            
            if attrs.get('pr_id', -1) == ring_key[0] and attrs.get('lr_id', -1) == ring_key[1] and node_pos >= 0:
                # Normalized distances to both failed nodes
                dist_to_fail1 = abs(node_pos - pos1) / 100.0
                dist_to_fail2 = abs(node_pos - pos2) / 100.0
                dist_features[node_idx, 0] = dist_to_fail1
                dist_features[node_idx, 1] = dist_to_fail2
                is_between = 1.0 if (min_pos < node_pos < max_pos) else 0.0

        
        # Combine all features
        node_features = torch.cat([x, dist_features,is_between], dim=1)
        failed_pr_id = G.nodes[fail_node1].get('pr_id', -1)
        failed_lr_id = G.nodes[fail_node1].get('lr_id', -1)
        # Create Data object
        data = Data(
            x=node_features,
            edge_index=edge_index,
            y=y,
            failed_nodes=failed_mask,
            failed_pr_id=torch.tensor([failed_pr_id], dtype=torch.long),
            failed_lr_id=torch.tensor([failed_lr_id], dtype=torch.long),
            )
        
        data_list.append(data)
    
    print(f"Created {len(data_list)} examples with {positive_count} positive instances")
    
    return data_list, node_list, node_to_idx

def train_isolation_gnn(data_list, num_epochs=100):
    """Train the enhanced GNN model with strategies to address class imbalance"""
    # Split data
    train_data, val_data = train_test_split(data_list, test_size=0.2, random_state=42)
    
    # Calculate class weights based on imbalance
    all_y = torch.cat([data.y for data in train_data])
    pos_weight = (len(all_y) - all_y.sum()) / all_y.sum() if all_y.sum() > 0 else 10.0
    
    print(f"Using positive class weight: {pos_weight:.2f}")
    
    # Create DataLoaders
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32)
    
    # Create model
    num_node_features = train_data[0].x.size(1)
    model = EnhancedGNN(num_node_features)
    
    # Use weighted loss
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))
    
    # Optimizer with weight decay
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)
    
    # Tracking best performance
    best_f1 = 0
    best_model_state = None
    patience = 10
    patience_counter = 0
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        total_loss = 0
 
        
        for data in train_loader:
            optimizer.zero_grad()
            
            # Forward pass
            out = model(data.x, data.edge_index, data.failed_nodes)

            loss = F.binary_cross_entropy_with_logits(
                out.squeeze(), 
                data.y,
                pos_weight=pos_weight  # Handle class imbalance
            )            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            total_loss += loss.item()
        
        # Validation
        model.eval()
        y_true = []
        y_pred = []
        
        with torch.no_grad():
            for data in val_loader:
                out = model(data.x, data.edge_index, data.failed_nodes)
                pred = (torch.sigmoid(out) > 0.95).float().squeeze()
                
                y_true.append(data.y)
                y_pred.append(pred)
        
        # Concatenate results
        y_true = torch.cat(y_true)
        y_pred = torch.cat(y_pred)
        
        # Calculate metrics
        true_pos = ((y_pred == 1) & (y_true == 1)).sum().item()
        pred_pos = y_pred.sum().item()
        real_pos = y_true.sum().item()
        
        precision = true_pos / max(1, pred_pos)
        recall = true_pos / max(1, real_pos)
        f1 = 2 * precision * recall / max(1e-8, precision + recall)
        accuracy = (y_pred == y_true).float().mean().item()
        
        print(f"Epoch {epoch+1}/{num_epochs}: Loss: {total_loss/len(train_loader):.4f}, "
              f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Acc: {accuracy:.4f}")
        
        # Early stopping
        if f1 > best_f1:
            best_f1 = f1
            best_model_state = model.state_dict().copy()
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered. Best F1: {best_f1:.4f}")
                break
    
    # Load best model
    if best_model_state:
        model.load_state_dict(best_model_state)
    
    return model

# Example usage:
def run_enhanced_gnn_pipeline(topology_df):
    print("Building graph with position features...")
    G = build_graph_with_position_features(topology_df)
    
    print("Creating enhanced dataset...")
    data_list, node_list, node_to_idx = create_enhanced_dataset(G, num_simulations=1000)
    
    print("Training isolation prediction model...")
    model = train_isolation_gnn(data_list)
    
    # Save the trained model
    torch.save({
        'model_state_dict': model.state_dict(),
        'node_list': node_list,
        'node_to_idx': node_to_idx
    }, 'isolation_gnn_model.pt')
    
    print("Model training complete and saved to isolation_gnn_model.pt")
    return model, G, node_list, node_to_idx

In [87]:



# Example usage:
def run_enhanced_gnn_pipeline(topology_df):
    print("Building graph with position features...")
    G = build_graph_with_position_features(topology_df)
    
    print("Creating enhanced dataset...")
    data_list, node_list, node_to_idx = create_enhanced_dataset(G, num_simulations=1000)
    
    print("Training isolation prediction model...")
    model = train_isolation_gnn(data_list)
    
    # Save the trained model
    torch.save({
        'model_state_dict': model.state_dict(),
        'node_list': node_list,
        'node_to_idx': node_to_idx
    }, 'isolation_gnn_model.pt')
    
    print("Model training complete and saved to isolation_gnn_model.pt")
    return model, G, node_list, node_to_idx

In [88]:
run_enhanced_gnn_pipeline(topo_data)

Building graph with position features...
Creating enhanced dataset...


TypeError: expected Tensor as element 2 in argument 0, but got float

In [None]:
def predict_isolations(model, G, node_list, node_to_idx, failed_node1, failed_node2):
    """
    Predict which nodes will be isolated when two specific nodes fail
    
    Args:
        model: Trained EnhancedGNN model
        G: NetworkX graph with node features
        node_list: List of all node names
        node_to_idx: Dictionary mapping node names to indices
        failed_node1: First node that will fail
        failed_node2: Second node that will fail
        
    Returns:
        List of nodes predicted to be isolated
    """
    # Check if failed nodes exist in the graph
    if failed_node1 not in G or failed_node2 not in G:
        missing = []
        if failed_node1 not in G:
            missing.append(failed_node1)
        if failed_node2 not in G:
            missing.append(failed_node2)
        print(f"Error: These nodes don't exist in the graph: {', '.join(missing)}")
        return []
    
    # Get ring information for failed nodes
    node1_attrs = G.nodes[failed_node1]
    node2_attrs = G.nodes[failed_node2]
    
    pr_id1, lr_id1 = node1_attrs.get('pr_id', -1), node1_attrs.get('lr_id', -1)
    pr_id2, lr_id2 = node2_attrs.get('pr_id', -1), node2_attrs.get('lr_id', -1)
    pos1, pos2 = node1_attrs.get('position', -1), node2_attrs.get('position', -1)
    
    # Verify they're in the same ring
    if pr_id1 != pr_id2 or lr_id1 != lr_id2 or pr_id1 < 0 or lr_id1 < 0:
        print(f"Warning: Failed nodes are not in the same logical ring")
        print(f"  {failed_node1}: PR={pr_id1}, LR={lr_id1}, Pos={pos1}")
        print(f"  {failed_node2}: PR={pr_id2}, LR={lr_id2}, Pos={pos2}")
        return []
    
    # Ensure pos1 < pos2
    if pos1 > pos2:
        failed_node1, failed_node2 = failed_node2, failed_node1
        pos1, pos2 = pos2, pos1
    
    # Create edge index tensor from the graph
    edge_index = []
    for u, v in G.edges():
        edge_index.append([node_to_idx[u], node_to_idx[v]])
        edge_index.append([node_to_idx[v], node_to_idx[u]])
    
    edge_index = torch.tensor(edge_index, dtype=torch.long).t()
    
    # Create feature matrix with the current state
    x = []
    for node in node_list:
        attrs = G.nodes[node]
        
        # Basic features
        features = [
            float(attrs.get('is_block', False)),
            attrs.get('pr_id', -1) / 10.0,
            attrs.get('lr_id', -1) / 10.0,
            attrs.get('position', -1) / 100.0,
            attrs.get('position_normalized', 0.0),
            attrs.get('ring_size', 0) / 20.0
        ]
        
        # Add distance features to failed nodes
        node_pos = attrs.get('position', -1)
        node_pr = attrs.get('pr_id', -1)
        node_lr = attrs.get('lr_id', -1)
        
        # Calculate distances only for nodes in the same ring
        if node_pr == pr_id1 and node_lr == lr_id1 and node_pos >= 0:
            dist_to_fail1 = abs(node_pos - pos1) / 100.0
            dist_to_fail2 = abs(node_pos - pos2) / 100.0
        else:
            dist_to_fail1 = 1.0  # Max normalized distance
            dist_to_fail2 = 1.0
        
        features.extend([dist_to_fail1, dist_to_fail2])
        x.append(features)
    
    x = torch.tensor(x, dtype=torch.float)
    
    # Create mask for failed nodes
    failed_mask = torch.zeros(len(node_list), dtype=torch.bool)
    failed_mask[node_to_idx[failed_node1]] = True
    failed_mask[node_to_idx[failed_node2]] = True
    
    # Make predictions
    model.eval()
    with torch.no_grad():
        # Create a Data object
        data = Data(x=x, edge_index=edge_index, failed_nodes=failed_mask)
        
        # Run inference
        out = model(data.x, data.edge_index, data.failed_nodes)
        predictions = torch.sigmoid(out).squeeze()
        
        # Get isolated nodes (probability > 0.5)
        isolated_indices = torch.where(predictions > 0.5)[0].tolist()
        isolated_nodes = [node_list[idx] for idx in isolated_indices 
                         if node_list[idx] not in [failed_node1, failed_node2]]
    
    # Print results
    print(f"\nPrediction results when {failed_node1} and {failed_node2} fail:")
    print(f"  Failed node 1: {failed_node1} (Position {pos1})")
    print(f"  Failed node 2: {failed_node2} (Position {pos2})")
    print(f"  Physical Ring: {pr_id1}, Logical Ring: {lr_id1}")
    print(f"  Predicted isolated nodes: {len(isolated_nodes)}")
    
    # Optional: Calculate expected isolations based on position rule
    expected_isolated = []
    for node in node_list:
        attrs = G.nodes[node]
        if (attrs.get('pr_id', -1) == pr_id1 and 
            attrs.get('lr_id', -1) == lr_id1 and 
            pos1 < attrs.get('position', -1) < pos2):
            expected_isolated.append(node)
    
    # Compare with rule-based expectation
    print("\nIsolated nodes based on position rule:")
    for node in expected_isolated:
        node_pos = G.nodes[node].get('position', -1)
        is_predicted = node in isolated_nodes
        print(f"  {node} (Position {node_pos}): {'✓' if is_predicted else '✗'}")
    
    print("\nModel predictions:")
    for node in isolated_nodes:
        node_pos = G.nodes[node].get('position', -1)
        node_pr = G.nodes[node].get('pr_id', -1)
        node_lr = G.nodes[node].get('lr_id', -1)
        is_expected = pos1 < node_pos < pos2 and node_pr == pr_id1 and node_lr == lr_id1
        print(f"  {node} (Position {node_pos}, PR={node_pr}, LR={node_lr}): {'✓' if is_expected else '✗'}")
    
    # Calculate prediction accuracy
    correct = sum(1 for node in isolated_nodes if node in expected_isolated)
    total_expected = len(expected_isolated)
    if total_expected > 0:
        accuracy = correct / total_expected * 100
        print(f"\nAccuracy: {correct}/{total_expected} correct predictions ({accuracy:.2f}%)")
    
    return isolated_nodes

In [None]:
# Load a trained model (if you've saved it)
def load_model_and_predict(model_path, topology_df, failed_node1, failed_node2):
    # Load the model
    checkpoint = torch.load(model_path)
    
    # Create a new model instance
    num_features = 6 # 6 basic features + 2 distance features
    model = EnhancedGNN(num_features)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Get node list and mapping
    node_list = checkpoint.get('node_list')
    node_to_idx = checkpoint.get('node_to_idx')
    
    # If node mapping wasn't saved, rebuild the graph
    if not node_list or not node_to_idx:
        print("Rebuilding graph from topology data...")
        G = build_graph_with_position_features(topology_df)
        node_list = list(G.nodes())
        node_to_idx = {node: i for i, node in enumerate(node_list)}
    else:
        G = build_graph_with_position_features(topology_df)
    
    # Predict isolations
    isolated_nodes = predict_isolations(
        model, G, node_list, node_to_idx, failed_node1, failed_node2
    )
    
    return isolated_nodes


isolated_nodes = load_model_and_predict(
    'isolation_gnn_model.pt',
    topo_data,
    'NODE_4_1_2',
    'NODE_4_1_5'
)


Prediction results when NODE_4_1_2 and NODE_4_1_5 fail:
  Failed node 1: NODE_4_1_2 (Position 2)
  Failed node 2: NODE_4_1_5 (Position 5)
  Physical Ring: 4, Logical Ring: 1
  Predicted isolated nodes: 6

Isolated nodes based on position rule:
  NODE_4_1_3 (Position 3): ✓
  NODE_4_1_4 (Position 4): ✓

Model predictions:
  NODE_4_1_1 (Position 1, PR=4, LR=1): ✗
  NODE_4_1_3 (Position 3, PR=4, LR=1): ✓
  NODE_4_1_4 (Position 4, PR=4, LR=1): ✓
  NODE_4_1_6 (Position 6, PR=4, LR=1): ✗
  NODE_4_1_7 (Position 7, PR=4, LR=1): ✗
  NODE_4_1_8 (Position 8, PR=4, LR=1): ✗

Accuracy: 2/2 correct predictions (100.00%)


In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data, DataLoader

class PrecisionIsolationGNN(nn.Module):
    def __init__(self, in_features):
        super(PrecisionIsolationGNN, self).__init__()
        # Print the input features dimension for debugging
        print(f"Model initialized with {in_features} input features")
        
        # Graph convolution layers
        self.conv1 = GCNConv(in_features, 32)
        self.conv2 = GATConv(32, 32, heads=2)
        
        # Prediction layers
        self.fc1 = nn.Linear(32*2, 32)
        self.fc2 = nn.Linear(32, 1)
    
    def forward(self, x, edge_index):
        # Basic GNN architecture with attention
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.2, training=self.training)
        
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        
        # Final prediction
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return x

def train_precision_gnn(data_list, num_epochs=100):
    """Train with focus on precision"""
    # Check feature dimensions
    print(f"Input data has {data_list[0].x.shape[1]} features per node")
    
    # Split training and validation
    train_size = int(0.8 * len(data_list))
    train_data = data_list[:train_size]
    val_data = data_list[train_size:]
    
    # Create data loaders
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32)
    
    # Create model with correct input dimension
    in_features = data_list[0].x.shape[1]
    model = PrecisionIsolationGNN(in_features)
    
    # Weighted loss - higher penalty for false positives
    pos_samples = sum(data.y.sum() for data in train_data)
    total_samples = sum(len(data.y) for data in train_data)
    pos_weight = (total_samples - pos_samples) / max(1, pos_samples)
    
    print(f"Using positive weight: {pos_weight:.2f}")
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))
    
    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    
    # Tracking best model
    best_precision = 0
    best_model = None
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        total_loss = 0
        
        for data in train_loader:
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)
            loss = criterion(out.squeeze(), data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        # Validation with higher threshold for precision
        model.eval()
        with torch.no_grad():
            all_preds = []
            all_targets = []
            
            for data in val_loader:
                out = model(data.x, data.edge_index)
                # Higher threshold increases precision
                pred = (torch.sigmoid(out) > 0.7).float().squeeze()
                all_preds.append(pred)
                all_targets.append(data.y)
            
            all_preds = torch.cat(all_preds)
            all_targets = torch.cat(all_targets)
            
            # Calculate metrics
            true_pos = ((all_preds == 1) & (all_targets == 1)).sum().item()
            pred_pos = all_preds.sum().item()
            actual_pos = all_targets.sum().item()
            
            # Handle division by zero
            precision = true_pos / max(1, pred_pos)
            recall = true_pos / max(1, actual_pos)
            f1 = 2 * precision * recall / max(0.001, precision + recall)
            accuracy = (all_preds == all_targets).float().mean().item()
        
        print(f"Epoch {epoch+1}/{num_epochs}: Loss: {total_loss/len(train_loader):.4f}, "
              f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Acc: {accuracy:.4f}")
        
        # Save best model based on precision
        if precision > best_precision and recall > 0.2:
            best_precision = precision
            best_model = model.state_dict().copy()
            print(f"New best model (precision: {precision:.4f})")
    
    # Load best model
    if best_model:
        model.load_state_dict(best_model)
    
    return model

# Usage:


In [43]:
model = train_precision_gnn(data_list)

Input data has 6 features per node
Model initialized with 6 input features
Using positive weight: 111.90
Epoch 1/100: Loss: 0.9874, Precision: 0.3163, Recall: 0.8765, F1: 0.4648, Acc: 0.9823
New best model (precision: 0.3163)
Epoch 2/100: Loss: 0.4849, Precision: 0.3084, Recall: 0.8765, F1: 0.4563, Acc: 0.9817
Epoch 3/100: Loss: 0.4752, Precision: 0.3232, Recall: 0.8765, F1: 0.4723, Acc: 0.9828
New best model (precision: 0.3232)
Epoch 4/100: Loss: 0.4726, Precision: 0.3171, Recall: 0.8765, F1: 0.4658, Acc: 0.9824
Epoch 5/100: Loss: 0.4680, Precision: 0.3294, Recall: 0.8765, F1: 0.4788, Acc: 0.9833
New best model (precision: 0.3294)
Epoch 6/100: Loss: 0.4601, Precision: 0.3296, Recall: 0.8765, F1: 0.4790, Acc: 0.9833
New best model (precision: 0.3296)
Epoch 7/100: Loss: 0.4567, Precision: 0.3294, Recall: 0.8765, F1: 0.4788, Acc: 0.9833
Epoch 8/100: Loss: 0.4587, Precision: 0.3296, Recall: 0.8765, F1: 0.4791, Acc: 0.9833
New best model (precision: 0.3296)
Epoch 9/100: Loss: 0.4515, Preci

In [51]:
# Save the trained model
torch.save({
    'model_state_dict': model.state_dict(),
    'node_list': node_list,
    'node_to_idx': node_to_idx
}, 'precision_isolation_gnn_model.pt')

print("Model saved to precision_isolation_gnn_model.pt")

Model saved to precision_isolation_gnn_model.pt


In [52]:
def predict_with_trained_gnn(model, G, failed_node1, failed_node2):
    """
    Make predictions with a trained PrecisionIsolationGNN model
    
    Args:
        model: Trained PrecisionIsolationGNN model
        G: NetworkX graph with node features
        failed_node1, failed_node2: Names of nodes that will fail
        
    Returns:
        List of nodes predicted to be isolated
    """
    # Extract input feature dimension from model
    in_features = model.conv1.lin.weight.size(1)
    print(f"Model expects {in_features} input features")
    
    # Get node information
    if failed_node1 not in G or failed_node2 not in G:
        print(f"Error: Failed nodes not found in graph")
        return []
    
    # Get position information
    node1_attrs = G.nodes[failed_node1]
    node2_attrs = G.nodes[failed_node2]
    
    pr_id1 = node1_attrs.get('pr_id', -1)
    lr_id1 = node1_attrs.get('lr_id', -1)
    pos1 = node1_attrs.get('position', -1)
    
    pr_id2 = node2_attrs.get('pr_id', -1)
    lr_id2 = node2_attrs.get('lr_id', -1)
    pos2 = node2_attrs.get('position', -1)
    
    # Ensure position order
    if pos1 > pos2:
        pos1, pos2 = pos2, pos1
    
    print(f"\nPredicting isolations when nodes fail:")
    print(f"  {failed_node1} (PR={pr_id1}, LR={lr_id1}, Pos={pos1})")
    print(f"  {failed_node2} (PR={pr_id2}, LR={lr_id2}, Pos={pos2})")
    
    # Create node list and mapping
    node_list = list(G.nodes())
    node_to_idx = {node: i for i, node in enumerate(node_list)}
    
    # Create features with EXACT dimension matching
    x = []
    for node in node_list:
        attrs = G.nodes[node]
        
        # Extract basic features
        node_pr = attrs.get('pr_id', -1)
        node_lr = attrs.get('lr_id', -1)
        node_pos = attrs.get('position', -1)
        
        # Create appropriate number of features based on model's expected input
        if in_features == 6:
            features = [
                float(attrs.get('is_block', False)),
                node_pr / 10.0,
                node_lr / 10.0,
                node_pos / 100.0,
                attrs.get('position_normalized', 0.0),
                attrs.get('ring_size', 0) / 20.0
            ]
        elif in_features == 8:
            # Calculate distances to failed nodes
            if node_pr == pr_id1 and node_lr == lr_id1 and node_pos >= 0:
                dist_to_fail1 = abs(node_pos - pos1) / 100.0
                dist_to_fail2 = abs(node_pos - pos2) / 100.0
            else:
                dist_to_fail1 = 1.0
                dist_to_fail2 = 1.0
                
            features = [
                float(attrs.get('is_block', False)),
                node_pr / 10.0,
                node_lr / 10.0,
                node_pos / 100.0,
                attrs.get('position_normalized', 0.0),
                attrs.get('ring_size', 0) / 20.0,
                dist_to_fail1,
                dist_to_fail2
            ]
        else:
            print(f"Error: Unknown feature dimension: {in_features}")
            return []
            
        x.append(features)
    
    # Create edge index
    edge_index = []
    for u, v in G.edges():
        if u in node_to_idx and v in node_to_idx:  # Safety check
            edge_index.append([node_to_idx[u], node_to_idx[v]])
            edge_index.append([node_to_idx[v], node_to_idx[u]])  # Add reverse edge
        
    # Convert to tensors
    edge_index = torch.tensor(edge_index, dtype=torch.long).t()
    x = torch.tensor(x, dtype=torch.float)
    
    print(f"Feature shape: {x.shape}, Edge index shape: {edge_index.shape}")
    
    # Make prediction
    model.eval()
    with torch.no_grad():
        output = model(x, edge_index)
        probs = torch.sigmoid(output).squeeze()
        
        # Use same threshold as in training (0.7)
        predictions = (probs > 0.7).float()
    
    # Get isolated nodes
    isolated_nodes = []
    for i, pred in enumerate(predictions):
        if pred == 1 and node_list[i] not in [failed_node1, failed_node2]:
            isolated_nodes.append(node_list[i])
    
    # Print results
    print(f"\nPredicted {len(isolated_nodes)} isolated nodes:")
    for node in sorted(isolated_nodes, key=lambda n: G.nodes[n].get('position', -1)):
        node_pos = G.nodes[node].get('position', -1)
        node_pr = G.nodes[node].get('pr_id', -1)
        node_lr = G.nodes[node].get('lr_id', -1)
        is_expected = (node_pr == pr_id1 and node_lr == lr_id1 and pos1 < node_pos < pos2)
        status = "✓" if is_expected else "✗"
        print(f"  {node} (PR={node_pr}, LR={node_lr}, Pos={node_pos}) {status}")
    
    # Compare with ground truth
    expected = []
    for node in node_list:
        attrs = G.nodes[node]
        node_pr = attrs.get('pr_id', -1)
        node_lr = attrs.get('lr_id', -1)
        node_pos = attrs.get('position', -1)
        
        if (node_pr == pr_id1 and node_lr == lr_id1 and pos1 < node_pos < pos2):
            expected.append(node)
    
    correct = len([n for n in isolated_nodes if n in expected])
    precision = correct / max(1, len(isolated_nodes))
    recall = correct / max(1, len(expected))
    
    print(f"\nExpected {len(expected)} isolated nodes")
    print(f"Precision: {precision:.2f}, Recall: {recall:.2f}")
    
    return isolated_nodes

# Usage:
# isolated_nodes = predict_with_trained_gnn(model, G, "NODE_1_2_5", "NODE_1_2_9")

In [53]:
failed_node1 = "NODE_1_2_5"
failed_node2 = "NODE_1_2_9"
node_to_idx = {node: i for i, node in enumerate(node_list)}

isolated_nodes = predict_with_trained_gnn(model, G, "NODE_1_2_5", "NODE_1_2_9")

Model expects 6 input features

Predicting isolations when nodes fail:
  NODE_1_2_5 (PR=1, LR=2, Pos=5)
  NODE_1_2_9 (PR=1, LR=2, Pos=9)
Feature shape: torch.Size([310, 6]), Edge index shape: torch.Size([2, 660])

Predicted 0 isolated nodes:

Expected 3 isolated nodes
Precision: 0.00, Recall: 0.00
