In [2]:
import networkx as nx
import pandas as pd
import numpy as np
import random

  match = re.match("^#\s*version\s*([0-9a-z]*)\s*$", line)


In [3]:
def create_dummy_topology_data(num_rings=10, nodes_per_ring=10, logical_rings_per_physical=3):
    """
    Create dummy topology data for testing
    
    Args:
        num_rings: Number of physical rings to create
        nodes_per_ring: Number of nodes per ring
        logical_rings_per_physical: Number of logical rings per physical ring
        
    Returns:
        DataFrame with topology data
    """
    data = []
    
    for pr_idx in range(num_rings):
        physical_ring = f"RING_{pr_idx}"
        block_name = f"BLOCK_{pr_idx}"  # One block per physical ring
        
        # Each physical ring has multiple logical rings from the same block
        for lr_idx in range(logical_rings_per_physical):
            logical_ring = f"LR_{pr_idx}_{lr_idx}"
            
            # Create a linear path for each logical ring (not a complete ring)
            for i in range(nodes_per_ring - 1):  # Connect nodes in a path, not a ring
                node_a = f"NODE_{pr_idx}_{lr_idx}_{i}"
                node_b = f"NODE_{pr_idx}_{lr_idx}_{i+1}"
                
                data.append({
                    'aendname': node_a,
                    'bendname': node_b,
                    'aendip': f"10.{pr_idx}.{lr_idx}.{i}",
                    'bendip': f"10.{pr_idx}.{lr_idx}.{i+1}",
                    'aendifIndex': i,
                    'bendifIndex': i+1,
                    'block_name': block_name,
                    'physicalringname': physical_ring,
                    'lrname': logical_ring
                })
            
            # Connect the first and last nodes to the block
            # First node connects to block
            data.append({
                'aendname': f"NODE_{pr_idx}_{lr_idx}_0",
                'bendname': block_name,
                'aendip': f"10.{pr_idx}.{lr_idx}.0",
                'bendip': f"10.{pr_idx}.99.99",  # Special IP for block
                'aendifIndex': 100 + lr_idx,
                'bendifIndex': 100 + lr_idx,
                'block_name': block_name,
                'physicalringname': physical_ring,
                'lrname': logical_ring
            })
            
            # Last node connects to block
            data.append({
                'aendname': f"NODE_{pr_idx}_{lr_idx}_{nodes_per_ring-1}",
                'bendname': block_name,
                'aendip': f"10.{pr_idx}.{lr_idx}.{nodes_per_ring-1}",
                'bendip': f"10.{pr_idx}.99.99",  # Special IP for block
                'aendifIndex': 200 + lr_idx,
                'bendifIndex': 200 + lr_idx,
                'block_name': block_name,
                'physicalringname': physical_ring,
                'lrname': logical_ring
            })
    
    # Add connections between blocks from different physical rings
    # for pr_idx in range(num_rings):
    #     if pr_idx < num_rings - 1:  # Connect to next physical ring
    #         # Connect this block to the next ring's block
    #         block_a = f"BLOCK_{pr_idx}"
    #         block_b = f"BLOCK_{pr_idx+1}"
            
    #         data.append({
    #             'aendname': block_a,
    #             'bendname': block_b,
    #             'aendip': f"10.{pr_idx}.99.99",
    #             'bendip': f"10.{pr_idx+1}.99.99",
    #             'aendifIndex': 300 + pr_idx,
    #             'bendifIndex': 300 + pr_idx + 1,
    #             'block_name': block_a,
    #             'physicalringname': f"RING_{pr_idx}",
    #             'lrname': "INTER_BLOCK"  # Inter-block connection
    #         })
    
    return pd.DataFrame(data)

In [4]:
topo_data = create_dummy_topology_data()


In [5]:


def build_network_graph(topology_df):
    """Build a graph from topology_data_logical table"""
    
    G = nx.Graph()
    
    # Add nodes and edges with attributes
    for _, row in topology_df.iterrows():
        aend = row['aendname'].upper()
        bend = row['bendname'].upper()
        
        # Add node attributes
        G.add_node(aend, ip=row['aendip'], 
                   physicalringname=row['physicalringname'], 
                   lrname=row['lrname'],
                   block_name=row['block_name'])
        
        G.add_node(bend, ip=row['bendip'], 
                   physicalringname=row['physicalringname'], 
                   lrname=row['lrname'],
                   block_name=row['block_name'])
        
        # Add edge with attributes
        G.add_edge(aend, bend, 
                  physicalringname=row['physicalringname'], 
                  lrname=row['lrname'],
                  aendifIndex=row['aendifIndex'],
                  bendifIndex=row['bendifIndex'])
    
    return G


In [6]:
G =build_network_graph(topo_data)

In [7]:

from pyvis.network import Network
import matplotlib.colors as mcolors

def visualize_network(G, filename='network_graph.html'):
    """
    Visualize the network graph using PyVis
    
    Args:
        G: NetworkX graph
        filename: Output HTML file name
    """
    # Create a PyVis network
    net = Network(notebook=True, height="750px", width="100%")
    
    # Generate a color map for different physical rings and logical rings
    physical_rings = set(nx.get_node_attributes(G, 'physicalringname').values())
    logical_rings = set(nx.get_node_attributes(G, 'lrname').values())
    
    # Create color maps
    pr_colors = {pr: f"#{hash(pr) % 0xffffff:06x}" for pr in physical_rings}
    lr_colors = {lr: f"#{hash(lr) % 0xffffff:06x}" for lr in logical_rings}
    
    # Add nodes with attributes
    for node in G.nodes():
        pr = G.nodes[node]['physicalringname']
        lr = G.nodes[node]['lrname']
        is_block = node == G.nodes[node]['block_name']
        
        # Set node properties
        node_title = f"Node: {node}<br>IP: {G.nodes[node]['ip']}<br>PR: {pr}<br>LR: {lr}"
        node_color = pr_colors[pr]  # Color by physical ring
        node_shape = 'diamond' if is_block else 'dot'
        node_size = 25 if is_block else 15
        
        net.add_node(node, title=node_title, color=node_color, shape=node_shape, size=node_size, label=node)
    
    # Add edges with attributes
    for u, v in G.edges():
        pr = G.edges[u, v]['physicalringname']
        lr = G.edges[u, v]['lrname']
        
        # Set edge properties
        edge_title = f"PR: {pr}<br>LR: {lr}"
        edge_color = lr_colors[lr]  # Color by logical ring
        
        net.add_edge(u, v, title=edge_title, color=edge_color)
    
    # Set physics layout
    net.barnes_hut(spring_length=200)
    
    # Save and show the graph
    net.show(filename)
    
    return net

In [8]:
net = visualize_network(G, 'network_topology.html')

network_topology.html


In [9]:
def simulate_node_failure(G, node_to_fail):
    """
    Simulate a node failure and identify isolated nodes
    
    Returns:
        - isolated_nodes: list of nodes that become isolated
    """
    # Get the node's ring information
    if node_to_fail not in G.nodes():
        return []
    
    node_pr = G.nodes[node_to_fail]['physicalringname']
    node_lr = G.nodes[node_to_fail]['lrname']
    block_name = G.nodes[node_to_fail]['block_name']
    
    # Get all connected nodes in the same ring before failure
    before_graph = G.copy()
    block_component = set(nx.node_connected_component(before_graph, block_name))
    connected_before = {n for n in block_component 
                        if n in before_graph.nodes() and 
                        before_graph.nodes[n].get('physicalringname') == node_pr and 
                        before_graph.nodes[n].get('lrname') == node_lr}
    
    # Remove the failed node
    after_graph = G.copy()
    after_graph.remove_node(node_to_fail)
    
    # Get connected nodes after failure
    if block_name in after_graph.nodes():
        block_component_after = set(nx.node_connected_component(after_graph, block_name))
        connected_after = {n for n in block_component_after 
                          if n in after_graph.nodes() and 
                          after_graph.nodes[n].get('physicalringname') == node_pr and 
                          after_graph.nodes[n].get('lrname') == node_lr}
    else:
        connected_after = set()
    
    # Find isolated nodes (nodes connected before but not after)
    isolated_nodes = connected_before - connected_after - {node_to_fail}
    
    return list(isolated_nodes)


In [10]:

def simulate_edge_failure(G, edge_to_fail):
    """
    Simulate an edge failure and identify isolated nodes
    
    Returns:
        - isolated_nodes: list of nodes that become isolated
    """
    u, v = edge_to_fail
    if not G.has_edge(u, v):
        return []
    
    # Get edge information
    edge_pr = G.edges[u, v]['physicalringname']
    edge_lr = G.edges[u, v]['lrname']
    block_name = G.nodes[u]['block_name']  # Assuming same block for connected nodes
    
    # Get connected nodes before failure
    before_graph = G.copy()
    block_component = set(nx.node_connected_component(before_graph, block_name))
    connected_before = {n for n in block_component 
                        if n in before_graph.nodes() and 
                        before_graph.nodes[n].get('physicalringname') == edge_pr and 
                        before_graph.nodes[n].get('lrname') == edge_lr}
    
    # Remove the edge
    after_graph = G.copy()
    after_graph.remove_edge(u, v)
    
    # Get connected nodes after failure
    block_component_after = set(nx.node_connected_component(after_graph, block_name))
    connected_after = {n for n in block_component_after 
                       if n in after_graph.nodes() and 
                       after_graph.nodes[n].get('physicalringname') == edge_pr and 
                       after_graph.nodes[n].get('lrname') == edge_lr}
    
    # Find isolated nodes
    isolated_nodes = connected_before - connected_after
    
    return list(isolated_nodes)

In [11]:
# Print all edges in the graph
for edge in G.edges():
    print(edge)


('NODE_0_0_0', 'NODE_0_0_1')
('NODE_0_0_0', 'BLOCK_0')
('NODE_0_0_1', 'NODE_0_0_2')
('NODE_0_0_2', 'NODE_0_0_3')
('NODE_0_0_3', 'NODE_0_0_4')
('NODE_0_0_4', 'NODE_0_0_5')
('NODE_0_0_5', 'NODE_0_0_6')
('NODE_0_0_6', 'NODE_0_0_7')
('NODE_0_0_7', 'NODE_0_0_8')
('NODE_0_0_8', 'NODE_0_0_9')
('NODE_0_0_9', 'BLOCK_0')
('BLOCK_0', 'NODE_0_1_0')
('BLOCK_0', 'NODE_0_1_9')
('BLOCK_0', 'NODE_0_2_0')
('BLOCK_0', 'NODE_0_2_9')
('NODE_0_1_0', 'NODE_0_1_1')
('NODE_0_1_1', 'NODE_0_1_2')
('NODE_0_1_2', 'NODE_0_1_3')
('NODE_0_1_3', 'NODE_0_1_4')
('NODE_0_1_4', 'NODE_0_1_5')
('NODE_0_1_5', 'NODE_0_1_6')
('NODE_0_1_6', 'NODE_0_1_7')
('NODE_0_1_7', 'NODE_0_1_8')
('NODE_0_1_8', 'NODE_0_1_9')
('NODE_0_2_0', 'NODE_0_2_1')
('NODE_0_2_1', 'NODE_0_2_2')
('NODE_0_2_2', 'NODE_0_2_3')
('NODE_0_2_3', 'NODE_0_2_4')
('NODE_0_2_4', 'NODE_0_2_5')
('NODE_0_2_5', 'NODE_0_2_6')
('NODE_0_2_6', 'NODE_0_2_7')
('NODE_0_2_7', 'NODE_0_2_8')
('NODE_0_2_8', 'NODE_0_2_9')
('NODE_1_0_0', 'NODE_1_0_1')
('NODE_1_0_0', 'BLOCK_1')
('NODE

In [12]:
isolated_nodes = simulate_node_failure(G,'BLOCK_1')
isolated_nodes

['NODE_1_2_6',
 'NODE_1_2_2',
 'NODE_1_2_4',
 'NODE_1_2_9',
 'NODE_1_2_7',
 'NODE_1_2_1',
 'NODE_1_2_3',
 'NODE_1_2_0',
 'NODE_1_2_5',
 'NODE_1_2_8']

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from torch_geometric.data import Data, DataLoader
import torch_geometric.transforms as T
import numpy as np

In [14]:
def simulate_multiple_failures(G, elements_to_fail, failure_type='node'):
    """
    Simulate multiple node or edge failures and identify isolated nodes
    
    Args:
        G: NetworkX graph
        elements_to_fail: List of nodes or edges to fail
        failure_type: 'node' or 'edge'
    
    Returns:
        List of isolated nodes
    """
    # Create a copy of the graph 
    after_graph = G.copy()
    
    # Remove all failed elements
    if failure_type == 'node':
        for node in elements_to_fail:
            if node in after_graph:
                after_graph.remove_node(node)
    else:  # edge
        for edge in elements_to_fail:
            if after_graph.has_edge(*edge):
                after_graph.remove_edge(*edge)
    
    # Find isolated nodes (nodes that can't reach their block)
    isolated_nodes = []
    
    # Group nodes by physical ring and logical ring
    nodes_by_ring = {}
    for node in G.nodes():
        if 'BLOCK' in node:
            continue
            
        pr = G.nodes[node]['physicalringname']
        lr = G.nodes[node]['lrname']
        block = G.nodes[node]['block_name']
        
        key = (pr, lr, block)
        if key not in nodes_by_ring:
            nodes_by_ring[key] = []
        nodes_by_ring[key].append(node)
    
    # Check connectivity for each group
    for (pr, lr, block), nodes in nodes_by_ring.items():
        if block not in after_graph:
            # If block is gone, all nodes in this group are isolated
            isolated_nodes.extend(nodes)
            continue
            
        for node in nodes:
            if node in elements_to_fail or node not in after_graph:
                continue  # Skip failed nodes
                
            # Check if node can reach its block
            try:
                path = nx.shortest_path(after_graph, node, block)
            except nx.NetworkXNoPath:
                isolated_nodes.append(node)
    
    return isolated_nodes

In [15]:
isolated_nodes = simulate_multiple_failures(G,['BLOCK_1'],'node')
isolated_nodes


['NODE_1_0_0',
 'NODE_1_0_1',
 'NODE_1_0_2',
 'NODE_1_0_3',
 'NODE_1_0_4',
 'NODE_1_0_5',
 'NODE_1_0_6',
 'NODE_1_0_7',
 'NODE_1_0_8',
 'NODE_1_0_9',
 'NODE_1_1_0',
 'NODE_1_1_1',
 'NODE_1_1_2',
 'NODE_1_1_3',
 'NODE_1_1_4',
 'NODE_1_1_5',
 'NODE_1_1_6',
 'NODE_1_1_7',
 'NODE_1_1_8',
 'NODE_1_1_9',
 'NODE_1_2_0',
 'NODE_1_2_1',
 'NODE_1_2_2',
 'NODE_1_2_3',
 'NODE_1_2_4',
 'NODE_1_2_5',
 'NODE_1_2_6',
 'NODE_1_2_7',
 'NODE_1_2_8',
 'NODE_1_2_9']

In [16]:
def create_isolation_prediction_dataset(topology_df, num_simulations=1000, max_failures=3, balance_ratio=2):
    """
    Create an optimized dataset for predicting node isolations after failures
    
    Args:
        topology_df: DataFrame with network topology information
        num_simulations: Total number of examples to generate
        max_failures: Maximum number of simultaneous failures to simulate
        balance_ratio: Maximum ratio of non-isolation to isolation examples
        
    Returns:
        data_list: List of PyG Data objects
        node_list: List of all node names
        node_to_idx: Dictionary mapping node names to indices
    """
    print("Building network graph...")
    G = build_network_graph(topology_df)
    
    # Create node mapping
    node_list = list(G.nodes())
    node_to_idx = {node: i for i, node in enumerate(node_list)}
    
    # Identify regular nodes (non-block nodes)
    regular_nodes = [n for n in node_list if "BLOCK" not in n]
    
    print("Computing network vulnerability metrics...")
    # Pre-compute network properties for isolation prediction
    articulation_points = set(nx.articulation_points(G))
    bridges = list(nx.bridges(G))
    
    # Group nodes by physical ring and logical ring
    nodes_by_ring = {}
    for node in regular_nodes:
        if node not in G.nodes():
            continue
            
        pr = G.nodes[node]['physicalringname']
        lr = G.nodes[node]['lrname']
        block = G.nodes[node]['block_name']
        key = (pr, lr, block)
        
        if key not in nodes_by_ring:
            nodes_by_ring[key] = []
        nodes_by_ring[key].append(node)
    
    # Helper function to calculate path redundancy
    def get_path_redundancy(G, node, block):
        """Return number of edge-disjoint paths from node to block"""
        try:
            return len(list(nx.edge_disjoint_paths(G, node, block)))
        except:
            return 0
    
    # Helper function to calculate vulnerability score
    def calculate_vulnerability_score(G, node, block, articulation_points, bridges):
        """Calculate a node's vulnerability to isolation"""
        vulnerability = 0.0
        
        # Factor 1: Few paths to block
        paths = get_path_redundancy(G, node, block)
        if paths == 0:
            vulnerability += 1.0  # Already isolated
        elif paths == 1:
            vulnerability += 0.6  # Critical - only one path
        elif paths == 2:
            vulnerability += 0.3  # Vulnerable - two paths
        
        # Factor 2: Node is articulation point
        if node in articulation_points:
            vulnerability += 0.2
        
        # Factor 3: Connected to bridges
        bridge_connections = 0
        for u, v in bridges:
            if node == u or node == v:
                bridge_connections += 1
        vulnerability += min(0.3, bridge_connections * 0.1)
        
        # Factor 4: Path to block goes through articulation points
        try:
            path = nx.shortest_path(G, node, block)
            critical_on_path = sum(1 for n in path[1:-1] if n in articulation_points)
            vulnerability += min(0.3, critical_on_path * 0.1)
        except:
            vulnerability += 0.2  # No path exists
        
        return min(1.0, vulnerability)  # Cap at 1.0
    
    print("Creating node features...")
    # Create isolation-focused node features
    node_features = []
    for node in node_list:
        features = []
        
        # Basic features
        pr = G.nodes[node]['physicalringname']
        pr_hash = hash(pr) % 10
        features.append(pr_hash/10.0)
        
        lr = G.nodes[node]['lrname']
        lr_hash = hash(lr) % 10
        features.append(lr_hash/10.0)
        
        is_block = 1.0 if "BLOCK" in node else 0.0
        features.append(is_block)
        
        # Skip advanced features for block nodes
        if is_block:
            # Add placeholder values for block nodes to match feature dimensions
            features.extend([0.0] * 7)  # 7 extra features for non-block nodes
        else:
            # Get block for this node
            block = G.nodes[node]['block_name']
            
            # 1. Path redundancy - key indicator of isolation risk
            path_redundancy = get_path_redundancy(G, node, block) / 5.0
            features.append(path_redundancy)
            
            # 2. Critical point features
            is_articulation = 1.0 if node in articulation_points else 0.0
            has_bridge = any((node == u or node == v) for u, v in bridges)
            features.extend([is_articulation, float(has_bridge)])
            
            # 3. Path to block features
            try:
                path_length = nx.shortest_path_length(G, node, block) / 10.0
                path = nx.shortest_path(G, node, block)
                critical_on_path = len(set(path[1:-1]).intersection(articulation_points)) / 5.0
            except:
                path_length = 1.0
                critical_on_path = 1.0
            features.extend([path_length, critical_on_path])
            
            # 4. Vulnerability score (combined metric)
            vulnerability = calculate_vulnerability_score(G, node, block, articulation_points, bridges)
            features.append(vulnerability)
            
            # 5. Network centrality
            degree = G.degree(node) / len(G)
            features.append(degree)
        
        node_features.append(features)
    
    # Convert to tensor
    node_features = torch.tensor(node_features, dtype=torch.float)
    
    # Create edge index
    edges = []
    for u, v in G.edges():
        edges.append([node_to_idx[u], node_to_idx[v]])
        edges.append([node_to_idx[v], node_to_idx[u]])  # Add reverse edge for undirected graph
    
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    # Track dataset statistics
    examples_with_isolations = []
    examples_without_isolations = []
    isolation_counts = []
    failure_counts = []
    
    print("Generating dataset examples...")
    # PHASE 1: Generate examples that are likely to have isolations (50% of dataset)
    target_isolation_examples = num_simulations // 2
    phase1_attempts = 0
    
    while len(examples_with_isolations) < target_isolation_examples and phase1_attempts < num_simulations * 2:
        phase1_attempts += 1
        
        # Choose failure type (node or edge)
        failure_type = random.choice(['node', 'edge'])
        
        if failure_type == 'node':
            # STRATEGY: Target critical nodes or nodes in the same ring
            strategy = random.choice(['critical', 'same_ring', 'random'])
            
            if strategy == 'critical':
                # Target articulation points and their neighbors
                potential_targets = list(articulation_points.intersection(set(regular_nodes)))
                if len(potential_targets) == 0:
                    potential_targets = regular_nodes
                
                # Select 1-3 critical nodes
                num_failures = random.randint(1, min(max_failures, len(potential_targets)))
                nodes_to_fail = random.sample(potential_targets, num_failures)
                
            elif strategy == 'same_ring':
                # Target multiple nodes in the same ring
                if not nodes_by_ring:
                    # Fallback to random
                    nodes_to_fail = random.sample(regular_nodes, random.randint(1, min(max_failures, len(regular_nodes))))
                else:
                    # Pick a random ring
                    ring_key = random.choice(list(nodes_by_ring.keys()))
                    ring_nodes = nodes_by_ring[ring_key]
                    
                    # Select 1-3 nodes from this ring
                    num_failures = random.randint(1, min(max_failures, len(ring_nodes)))
                    nodes_to_fail = random.sample(ring_nodes, num_failures)
            
            else:  # random strategy
                num_failures = random.randint(1, min(max_failures, len(regular_nodes)))
                nodes_to_fail = random.sample(regular_nodes, num_failures)
            
            # Simulate failures
            isolated_nodes = simulate_multiple_failures(G, nodes_to_fail, 'node')
            
            # Create target tensor
            y = torch.zeros(len(node_list), dtype=torch.float)
            for node in isolated_nodes:
                if node in node_to_idx:
                    y[node_to_idx[node]] = 1.0
            
            # Create edge mask
            valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
            for node in nodes_to_fail:
                node_idx = node_to_idx[node]
                invalid_edges = (edge_index[0] == node_idx) | (edge_index[1] == node_idx)
                valid_edge_mask[invalid_edges] = False
            
            # Create data object
            data = Data(
                x=node_features,
                edge_index=edge_index,
                y=y,
                valid_edge_mask=valid_edge_mask,
                failure_type=torch.tensor([0]),
                num_failures=torch.tensor([len(nodes_to_fail)], dtype=torch.long)
            )
            
        else:  # edge failure
            # STRATEGY: Target bridge edges or edges in the same ring
            strategy = random.choice(['bridges', 'same_ring', 'random'])
            
            if strategy == 'bridges':
                # Target bridge edges
                potential_targets = [(u, v) for u, v in bridges if "BLOCK" not in u and "BLOCK" not in v]
                if len(potential_targets) == 0:
                    # Fallback to random edges
                    all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                    potential_targets = all_edges
                
                # Select 1-3 bridge edges
                num_failures = random.randint(1, min(max_failures, len(potential_targets)))
                edges_to_fail = random.sample(potential_targets, num_failures)
                
            elif strategy == 'same_ring':
                # Target multiple edges in the same ring
                ring_edges = {}
                for u, v in G.edges():
                    if "BLOCK" in u or "BLOCK" in v:
                        continue
                        
                    try:
                        pr_u = G.nodes[u]['physicalringname']
                        lr_u = G.nodes[u]['lrname']
                        pr_v = G.nodes[v]['physicalringname']
                        lr_v = G.nodes[v]['lrname']
                        
                        if pr_u == pr_v and lr_u == lr_v:
                            key = (pr_u, lr_u)
                            if key not in ring_edges:
                                ring_edges[key] = []
                            ring_edges[key].append((u, v))
                    except:
                        continue
                
                if not ring_edges:
                    # Fallback to random
                    all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                    num_failures = random.randint(1, min(max_failures, len(all_edges)))
                    edges_to_fail = random.sample(all_edges, num_failures)
                else:
                    # Pick a random ring
                    ring_key = random.choice(list(ring_edges.keys()))
                    ring_edge_list = ring_edges[ring_key]
                    
                    # Select 1-3 edges from this ring
                    num_failures = random.randint(1, min(max_failures, len(ring_edge_list)))
                    edges_to_fail = random.sample(ring_edge_list, num_failures)
            
            else:  # random strategy
                all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                num_failures = random.randint(1, min(max_failures, len(all_edges)))
                edges_to_fail = random.sample(all_edges, num_failures)
            
            # Simulate failures
            isolated_nodes = simulate_multiple_failures(G, edges_to_fail, 'edge')
            
            # Create target tensor
            y = torch.zeros(len(node_list), dtype=torch.float)
            for node in isolated_nodes:
                if node in node_to_idx:
                    y[node_to_idx[node]] = 1.0
            
            # Create edge mask
            valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
            for u, v in edges_to_fail:
                u_idx, v_idx = node_to_idx[u], node_to_idx[v]
                
                for i in range(edge_index.size(1)):
                    e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
                    if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                        valid_edge_mask[i] = False
            
            # Create data object
            data = Data(
                x=node_features,
                edge_index=edge_index,
                y=y,
                valid_edge_mask=valid_edge_mask,
                failure_type=torch.tensor([1]),
                num_failures=torch.tensor([len(edges_to_fail)], dtype=torch.long)
            )
        
        # Track isolation information
        isolation_count = y.sum().item()
        isolation_counts.append(isolation_count)
        failure_counts.append(len(nodes_to_fail if failure_type == 'node' else edges_to_fail))
        
        # Add to appropriate list
        if isolation_count > 0:
            examples_with_isolations.append(data)
        else:
            examples_without_isolations.append(data)
            
        # Report progress
        if phase1_attempts % 100 == 0:
            print(f"Phase 1: Generated {len(examples_with_isolations)} examples with isolations after {phase1_attempts} attempts")
    
    # PHASE 2: Generate remaining examples randomly to complete the dataset
    remaining = num_simulations - len(examples_with_isolations) - len(examples_without_isolations)
    
    if remaining > 0:
        print(f"Generating {remaining} additional examples for Phase 2...")
        
        for _ in range(remaining):
            # Choose failure type (node or edge)
            failure_type = random.choice(['node', 'edge'])
            
            if failure_type == 'node':
                # Choose random nodes
                num_failures = random.randint(1, min(max_failures, len(regular_nodes)))
                nodes_to_fail = random.sample(regular_nodes, num_failures)
                
                # Simulate failures
                isolated_nodes = simulate_multiple_failures(G, nodes_to_fail, 'node')
                
                # Create target tensor
                y = torch.zeros(len(node_list), dtype=torch.float)
                for node in isolated_nodes:
                    if node in node_to_idx:
                        y[node_to_idx[node]] = 1.0
                
                # Create edge mask
                valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
                for node in nodes_to_fail:
                    node_idx = node_to_idx[node]
                    invalid_edges = (edge_index[0] == node_idx) | (edge_index[1] == node_idx)
                    valid_edge_mask[invalid_edges] = False
                
                # Create data object
                data = Data(
                    x=node_features,
                    edge_index=edge_index,
                    y=y,
                    valid_edge_mask=valid_edge_mask,
                    failure_type=torch.tensor([0]),
                    num_failures=torch.tensor([len(nodes_to_fail)], dtype=torch.long)
                )
                
            else:  # edge failure
                # Choose random edges
                all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                num_failures = random.randint(1, min(max_failures, len(all_edges)))
                edges_to_fail = random.sample(all_edges, num_failures)
                
                # Simulate failures
                isolated_nodes = simulate_multiple_failures(G, edges_to_fail, 'edge')
                
                # Create target tensor
                y = torch.zeros(len(node_list), dtype=torch.float)
                for node in isolated_nodes:
                    if node in node_to_idx:
                        y[node_to_idx[node]] = 1.0
                
                # Create edge mask
                valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
                for u, v in edges_to_fail:
                    u_idx, v_idx = node_to_idx[u], node_to_idx[v]
                    
                    for i in range(edge_index.size(1)):
                        e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
                        if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                            valid_edge_mask[i] = False
                
                # Create data object
                data = Data(
                    x=node_features,
                    edge_index=edge_index,
                    y=y,
                    valid_edge_mask=valid_edge_mask,
                    failure_type=torch.tensor([1]),
                    num_failures=torch.tensor([len(edges_to_fail)], dtype=torch.long)
                )
            
            # Track isolation information
            isolation_count = y.sum().item()
            isolation_counts.append(isolation_count)
            failure_counts.append(len(nodes_to_fail if failure_type == 'node' else edges_to_fail))
            
            # Add to appropriate list
            if isolation_count > 0:
                examples_with_isolations.append(data)
            else:
                examples_without_isolations.append(data)
    
    # Balance the dataset using the specified ratio
    print(f"Balancing dataset with isolation:non-isolation ratio of 1:{balance_ratio}")
    if len(examples_without_isolations) > len(examples_with_isolations) * balance_ratio:
        # Downsample negative examples
        target_non_isolations = len(examples_with_isolations) * balance_ratio
        examples_without_isolations = random.sample(examples_without_isolations, int(target_non_isolations))
    
    # Combine and shuffle the final dataset
    final_data_list = examples_with_isolations + examples_without_isolations
    random.shuffle(final_data_list)
    
    # Print dataset statistics
    total_examples = len(final_data_list)
    total_isolations = len(examples_with_isolations)
    total_non_isolations = len(examples_without_isolations)
    
    print("\n=== Dataset Statistics ===")
    print(f"Total examples: {total_examples}")
    print(f"Examples with isolations: {total_isolations} ({total_isolations/total_examples*100:.1f}%)")
    print(f"Examples without isolations: {total_non_isolations} ({total_non_isolations/total_examples*100:.1f}%)")
    print(f"Average isolations per positive example: {sum(c for c in isolation_counts if c > 0)/max(1, total_isolations):.2f}")
    print(f"Average failures per example: {sum(failure_counts)/len(failure_counts):.2f}")
    
    # Detailed breakdown by failure type
    node_isolations = sum(1 for data in examples_with_isolations if data.failure_type.item() == 0)
    edge_isolations = sum(1 for data in examples_with_isolations if data.failure_type.item() == 1)
    
    print("\nIsolation examples by failure type:")
    print(f"  Node failures: {node_isolations} ({node_isolations/max(1, total_isolations)*100:.1f}%)")
    print(f"  Edge failures: {edge_isolations} ({edge_isolations/max(1, total_isolations)*100:.1f}%)")
    
    # Breakdown by number of failures
    failure_counts_dist = {}
    for data in final_data_list:
        num_fails = data.num_failures.item()
        has_isolations = data.y.sum().item() > 0
        
        key = f"{num_fails}_{'iso' if has_isolations else 'non'}"
        if key not in failure_counts_dist:
            failure_counts_dist[key] = 0
        failure_counts_dist[key] += 1
    
    print("\nExamples by number of failures:")
    for k in sorted(failure_counts_dist.keys()):
        num, iso_status = k.split('_')
        status = "with isolations" if iso_status == 'iso' else "without isolations"
        print(f"  {num} failures {status}: {failure_counts_dist[k]}")
    
    return final_data_list, node_list, node_to_idx

In [17]:
class IsolationPredictionGNN(torch.nn.Module):
    def __init__(self, num_node_features, edge_dim=None):
        super(IsolationPredictionGNN, self).__init__()
        # Increase hidden dimensions
        hidden_dim = 128
        
        # Pre-failure feature extraction
        self.node_encoder = nn.Sequential(
            nn.Linear(num_node_features, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.3)
        )
        
        # Edge features processing if available
        self.use_edge_features = edge_dim is not None
        if self.use_edge_features:
            self.edge_encoder = nn.Sequential(
                nn.Linear(edge_dim, hidden_dim),
                nn.ReLU()
            )
        
        # Initial structure learning (before failure)
        self.conv1 = GATConv(hidden_dim, hidden_dim//4, heads=4, concat=True, dropout=0.3)
        self.conv2 = GATConv(hidden_dim, hidden_dim//4, heads=4, concat=True, dropout=0.3)
        
        # Post-failure processing with separate convolutions
        self.post_failure_conv1 = GATConv(hidden_dim, hidden_dim//4, heads=4, concat=True, dropout=0.3)
        self.post_failure_conv2 = GATConv(hidden_dim, hidden_dim//4, heads=4, concat=True, dropout=0.3)
        
        # Explicit isolation prediction with deeper network
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim//2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim//2, 1)
        )
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        # Get node features
        x = self.node_encoder(x)
        
        # Process edge features if available
        edge_attr = None
        if self.use_edge_features and hasattr(data, 'edge_attr'):
            edge_attr = self.edge_encoder(data.edge_attr)
        
        # Learn network structure BEFORE failure
        h1 = F.relu(self.conv1(x, edge_index, edge_attr))
        h1 = F.dropout(h1, p=0.3, training=self.training)
        h_before = F.relu(self.conv2(h1, edge_index, edge_attr))
        
        # Apply failure mask and re-process
        if hasattr(data, 'valid_edge_mask') and data.valid_edge_mask is not None:
            # Use only valid edges (those not failed)
            valid_edges = edge_index[:, data.valid_edge_mask]
            
            if self.use_edge_features and edge_attr is not None:
                valid_edge_attr = edge_attr[data.valid_edge_mask]
            else:
                valid_edge_attr = None
                
            # Process AFTER failure with the post-failure convolutions
            if valid_edges.size(1) > 0:  # Only if we have valid edges
                h2 = F.relu(self.post_failure_conv1(h_before, valid_edges, valid_edge_attr))
                h2 = F.dropout(h2, p=0.3, training=self.training)
                h_after = F.relu(self.post_failure_conv2(h2, valid_edges, valid_edge_attr))
            else:
                # If all edges failed, just use the node features
                h_after = h_before  # Cannot propagate anything
        else:
            # No failures applied, use the same edges
            h_after = h_before
        
        # Concatenate before and after state for each node
        # This explicitly models the change from pre-failure to post-failure
        combined = torch.cat([h_before, h_after], dim=1)
        
        # Predict isolation probability
        return self.predictor(combined)

In [18]:
def create_isolation_prediction_dataset(topology_df, num_simulations=5000, max_failures=5, min_isolation_ratio=0.3):
    """
    Create a high-quality dataset for training a GNN to predict node isolations after failures.
    
    Args:
        topology_df: DataFrame with network topology information
        num_simulations: Total number of examples to generate
        max_failures: Maximum number of simultaneous failures to simulate
        min_isolation_ratio: Minimum ratio of examples that must contain isolations
        
    Returns:
        data_list: List of PyG Data objects
        node_list: List of all node names
        node_to_idx: Dictionary mapping node names to indices
    """
    print("Building network graph...")
    G = build_network_graph(topology_df)
    
    # Create node mapping
    node_list = list(G.nodes())
    node_to_idx = {node: i for i, node in enumerate(node_list)}
    
    # Identify regular nodes (non-block nodes)
    regular_nodes = [n for n in node_list if "BLOCK" not in n]
    
    print("Computing network vulnerability metrics...")
    # Pre-compute network properties for better feature engineering
    articulation_points = set(nx.articulation_points(G))
    bridges = list(nx.bridges(G))
    
    # Pre-compute edge betweenness - identifying critical edges
    edge_betweenness = nx.edge_betweenness_centrality(G)
    
    # Group nodes by physical ring and logical ring
    nodes_by_ring = {}
    for node in regular_nodes:
        if node not in G.nodes():
            continue
            
        pr = G.nodes[node].get('physicalringname', 'unknown')
        lr = G.nodes[node].get('lrname', 'unknown')
        block = G.nodes[node].get('block_name', 'unknown')
        key = (pr, lr, block)
        
        if key not in nodes_by_ring:
            nodes_by_ring[key] = []
        nodes_by_ring[key].append(node)
    
    print("Calculating node vulnerability scores...")
    # Calculate vulnerability score for each node
    vulnerability_scores = {}
    for node in regular_nodes:
        if "BLOCK" in node:
            vulnerability_scores[node] = 0.0
            continue
            
        try:
            block = G.nodes[node].get('block_name', None)
            if not block:
                vulnerability_scores[node] = 0.5
                continue
                
            # Factor 1: Path redundancy to block
            try:
                paths = len(list(nx.edge_disjoint_paths(G, node, block)))
                if paths == 0:
                    path_score = 1.0  # Already isolated
                elif paths == 1: 
                    path_score = 0.8  # Critical - only one path
                elif paths == 2:
                    path_score = 0.5  # Vulnerable - two paths
                else:
                    path_score = 0.1  # Low risk
            except:
                path_score = 1.0  # Unable to calculate paths
                
            # Factor 2: Is it an articulation point?
            is_articulation = 1.0 if node in articulation_points else 0.0
            
            # Factor 3: Connected to bridges
            bridge_connections = sum(1 for u, v in bridges if node in (u, v))
            bridge_score = min(1.0, bridge_connections * 0.2)
            
            # Factor 4: Betweenness centrality - how important is it for connecting other nodes
            betweenness = nx.betweenness_centrality(G, k=min(100, len(G))).get(node, 0)
            
            # Factor 5: Closeness to block
            try:
                closeness = 1.0 / max(1.0, nx.shortest_path_length(G, node, block))
            except:
                closeness = 0.0
                
            # Combined score (weighted average)
            vulnerability = (
                0.4 * path_score + 
                0.2 * is_articulation + 
                0.2 * bridge_score + 
                0.1 * betweenness +
                0.1 * closeness
            )
            
            vulnerability_scores[node] = min(1.0, vulnerability)
        except Exception as e:
            print(f"Error calculating vulnerability for {node}: {e}")
            vulnerability_scores[node] = 0.5
    
    print("Creating node features...")
    # Create more informative node features
    node_features = []
    for node in node_list:
        features = []
        
        # Basic features
        pr = G.nodes[node].get('physicalringname', 'unknown')
        pr_hash = int(hashlib.md5(pr.encode()).hexdigest(), 16) % 10
        features.append(pr_hash/10.0)
        
        lr = G.nodes[node].get('lrname', 'unknown')
        lr_hash = int(hashlib.md5(lr.encode()).hexdigest(), 16) % 10
        features.append(lr_hash/10.0)
        
        is_block = 1.0 if "BLOCK" in node else 0.0
        features.append(is_block)
        
        # Skip advanced features for block nodes
        if is_block:
            # Add placeholder values for block nodes
            features.extend([0.0] * 8)  # 8 extra features 
        else:
            # Topological importance features
            degree = G.degree(node) / max(1, max(dict(G.degree()).values()))
            features.append(degree)
            
            clustering = nx.clustering(G, node)
            features.append(clustering)
            
            is_articulation = 1.0 if node in articulation_points else 0.0
            features.append(is_articulation)
            
            # Edge connectivity features
            bridges_connected = min(1.0, sum(1 for u, v in bridges if node in (u, v)) / 5.0)
            features.append(bridges_connected)
            
            # Block connectivity
            block = G.nodes[node].get('block_name', None)
            if block:
                try:
                    path_length = nx.shortest_path_length(G, node, block) / 10.0
                    disjoint_paths = len(list(nx.edge_disjoint_paths(G, node, block))) / 5.0
                except:
                    path_length = 1.0
                    disjoint_paths = 0.0
            else:
                path_length = 1.0
                disjoint_paths = 0.0
            features.append(path_length)
            features.append(disjoint_paths)
            
            # Centrality metrics
            betweenness = nx.betweenness_centrality(G, k=min(100, len(G))).get(node, 0)
            features.append(betweenness)
            
            # Pre-computed vulnerability score
            vulnerability = vulnerability_scores.get(node, 0.5)
            features.append(vulnerability)
        
        node_features.append(features)
    
    # Convert to tensor
    node_features = torch.tensor(node_features, dtype=torch.float)
    
    # Create edge index
    edges = []
    for u, v in G.edges():
        edges.append([node_to_idx[u], node_to_idx[v]])
        edges.append([node_to_idx[v], node_to_idx[u]])  # Undirected graph
    
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    # Create edge features
    edge_features = []
    for u, v in G.edges():
        # Compute various edge features
        # 1. Edge betweenness
        betweenness = edge_betweenness.get((u, v), edge_betweenness.get((v, u), 0))
        
        # 2. Same physical ring
        same_pr = int(G.nodes[u].get('physicalringname') == G.nodes[v].get('physicalringname'))
        
        # 3. Same logical ring
        same_lr = int(G.nodes[u].get('lrname') == G.nodes[v].get('lrname'))
        
        # 4. Is this a bridge?
        is_bridge = 1.0 if (u, v) in bridges or (v, u) in bridges else 0.0
        
        # Duplicate for reverse edge
        edge_features.append([betweenness, same_pr, same_lr, is_bridge])
        edge_features.append([betweenness, same_pr, same_lr, is_bridge])
    
    edge_attr = torch.tensor(edge_features, dtype=torch.float)
    
    # Track dataset statistics
    examples_with_isolations = []
    examples_without_isolations = []
    
    # Generate dataset with strategic failure selection
    print("Generating dataset examples...")
    
    # PHASE 1: Generate examples targeting vulnerabilities
    phase1_target = int(num_simulations * 0.7)  # 70% from targeted strategies
    
    while len(examples_with_isolations) + len(examples_without_isolations) < phase1_target:
        # Choose failure type (node or edge with 70/30 split)
        failure_type = 'node' if random.random() < 0.7 else 'edge'
        
        # Strategies weighted by effectiveness
        if failure_type == 'node':
            # Node failure strategies
            strategies = [
                ('high_vulnerability', 0.4),  # Target vulnerable nodes
                ('articulation', 0.3),        # Target articulation points
                ('same_ring', 0.2),           # Target nodes in same ring
                ('random', 0.1)               # Random targets
            ]
            
            strategy = random.choices([s[0] for s in strategies], 
                                     [s[1] for s in strategies])[0]
            
            if strategy == 'high_vulnerability':
                # Target high vulnerability nodes
                sorted_nodes = sorted([(n, vulnerability_scores.get(n, 0)) 
                                      for n in regular_nodes if "BLOCK" not in n],
                                     key=lambda x: x[1], reverse=True)
                
                # Take from top 20% vulnerable nodes
                top_vulnerable = sorted_nodes[:max(1, int(len(sorted_nodes) * 0.2))]
                potential_targets = [n[0] for n in top_vulnerable]
                
                if not potential_targets:
                    potential_targets = regular_nodes
                
                # Select 1-N failures
                num_failures = random.randint(1, min(max_failures, len(potential_targets)))
                nodes_to_fail = random.sample(potential_targets, num_failures)
                
            elif strategy == 'articulation':
                # Target articulation points
                potential_targets = list(articulation_points & set(regular_nodes))
                if len(potential_targets) < 2:
                    potential_targets = regular_nodes
                    
                num_failures = random.randint(1, min(max_failures, len(potential_targets)))
                nodes_to_fail = random.sample(potential_targets, num_failures)
                
            elif strategy == 'same_ring':
                # Select nodes in the same ring 
                if not nodes_by_ring:
                    # Fallback to random
                    nodes_to_fail = random.sample(regular_nodes, 
                                                 random.randint(1, min(max_failures, len(regular_nodes))))
                else:
                    # Pick a random ring group
                    rings_with_enough_nodes = [key for key, nodes in nodes_by_ring.items() 
                                              if len(nodes) >= 2]
                    
                    if not rings_with_enough_nodes:
                        # Fallback to random  
                        nodes_to_fail = random.sample(regular_nodes,
                                                     random.randint(1, min(max_failures, len(regular_nodes))))
                    else:
                        ring_key = random.choice(rings_with_enough_nodes)
                        ring_nodes = nodes_by_ring[ring_key]
                        
                        # Select multiple nodes from this ring
                        num_failures = random.randint(1, min(max_failures, len(ring_nodes)))
                        nodes_to_fail = random.sample(ring_nodes, num_failures)
            
            else:  # random strategy
                num_failures = random.randint(1, min(max_failures, len(regular_nodes)))
                nodes_to_fail = random.sample(regular_nodes, num_failures)
            
            # Simulate failures
            isolated_nodes = simulate_multiple_failures(G, nodes_to_fail, 'node')
            
            # Create target tensor
            y = torch.zeros(len(node_list), dtype=torch.float)
            for node in isolated_nodes:
                if node in node_to_idx:
                    y[node_to_idx[node]] = 1.0
            
            # Create edge mask
            valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
            for node in nodes_to_fail:
                node_idx = node_to_idx[node]
                invalid_edges = (edge_index[0] == node_idx) | (edge_index[1] == node_idx)
                valid_edge_mask[invalid_edges] = False
            
            # Create data object
            data = Data(
                x=node_features,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=y,
                valid_edge_mask=valid_edge_mask,
                failure_type=torch.tensor([0]),
                failed_elements=[node_to_idx[n] for n in nodes_to_fail],
                num_failures=torch.tensor([len(nodes_to_fail)], dtype=torch.long)
            )
            
        else:  # edge failure
            # Edge failure strategies
            strategies = [
                ('bridges', 0.4),       # Target bridge edges 
                ('high_betweenness', 0.3),  # Target high betweenness edges
                ('same_ring', 0.2),     # Target edges in the same ring
                ('random', 0.1)         # Random edges
            ]
            
            strategy = random.choices([s[0] for s in strategies], 
                                     [s[1] for s in strategies])[0]
            
            if strategy == 'bridges':
                # Target bridge edges
                potential_targets = [(u, v) for u, v in bridges if "BLOCK" not in u and "BLOCK" not in v]
                if len(potential_targets) == 0:
                    # Fallback to random edges
                    all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                    potential_targets = all_edges
                
                # Select multiple bridge edges
                num_failures = random.randint(1, min(max_failures, len(potential_targets)))
                edges_to_fail = random.sample(potential_targets, num_failures)
                
            elif strategy == 'high_betweenness':
                # Target high betweenness edges
                sorted_edges = sorted([(u, v) for u, v in edge_betweenness.keys() 
                                     if "BLOCK" not in u and "BLOCK" not in v],
                                    key=lambda e: edge_betweenness.get(e, 0), 
                                    reverse=True)
                
                # Take top 20% high betweenness edges
                potential_targets = sorted_edges[:max(1, int(len(sorted_edges) * 0.2))]
                
                if not potential_targets:
                    # Fallback to random
                    all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                    potential_targets = all_edges
                    
                num_failures = random.randint(1, min(max_failures, len(potential_targets)))
                edges_to_fail = random.sample(potential_targets, num_failures)
                
            elif strategy == 'same_ring':
                # Target multiple edges in the same ring
                ring_edges = {}
                for u, v in G.edges():
                    if "BLOCK" in u or "BLOCK" in v:
                        continue
                        
                    try:
                        pr_u = G.nodes[u].get('physicalringname')
                        lr_u = G.nodes[u].get('lrname')
                        pr_v = G.nodes[v].get('physicalringname')
                        lr_v = G.nodes[v].get('lrname')
                        
                        if pr_u == pr_v and lr_u == lr_v:
                            key = (pr_u, lr_u)
                            if key not in ring_edges:
                                ring_edges[key] = []
                            ring_edges[key].append((u, v))
                    except:
                        continue
                
                if not ring_edges:
                    # Fallback to random
                    all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                    num_failures = random.randint(1, min(max_failures, len(all_edges)))
                    edges_to_fail = random.sample(all_edges, num_failures)
                else:
                    # Pick a random ring
                    ring_key = random.choice(list(ring_edges.keys()))
                    ring_edge_list = ring_edges[ring_key]
                    
                    # Select multiple edges from this ring
                    num_failures = random.randint(1, min(max_failures, len(ring_edge_list)))
                    edges_to_fail = random.sample(ring_edge_list, num_failures)
            
            else:  # random strategy
                all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                num_failures = random.randint(1, min(max_failures, len(all_edges)))
                edges_to_fail = random.sample(all_edges, num_failures)
            
            # Simulate failures
            isolated_nodes = simulate_multiple_failures(G, edges_to_fail, 'edge')
            
            # Create target tensor
            y = torch.zeros(len(node_list), dtype=torch.float)
            for node in isolated_nodes:
                if node in node_to_idx:
                    y[node_to_idx[node]] = 1.0
            
            # Create edge mask
            valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
            failed_edge_indices = []
            
            for u, v in edges_to_fail:
                u_idx, v_idx = node_to_idx[u], node_to_idx[v]
                
                for i in range(edge_index.size(1)):
                    e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
                    if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                        valid_edge_mask[i] = False
                        failed_edge_indices.append(i)
            
            # Create data object
            data = Data(
                x=node_features,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=y,
                valid_edge_mask=valid_edge_mask,
                failure_type=torch.tensor([1]),
                failed_elements=failed_edge_indices,
                num_failures=torch.tensor([len(edges_to_fail)], dtype=torch.long)
            )
        
        # Add to appropriate list
        if y.sum().item() > 0:
            examples_with_isolations.append(data)
        else:
            examples_without_isolations.append(data)
            
        # Report progress
        total_examples = len(examples_with_isolations) + len(examples_without_isolations)
        if total_examples % 100 == 0:
            isolation_ratio = len(examples_with_isolations) / max(1, total_examples)
            print(f"Generated {total_examples} examples, {len(examples_with_isolations)} with isolations ({isolation_ratio:.2f})")
    
    # PHASE 2: Generate remaining examples with random approach
    phase2_target = num_simulations - len(examples_with_isolations) - len(examples_without_isolations)
    
    if phase2_target > 0:
        print(f"Generating {phase2_target} additional examples for Phase 2...")
        
        for _ in range(phase2_target):
            # Choose failure type (node or edge)
            failure_type = random.choice(['node', 'edge'])
            
            if failure_type == 'node':
                # Choose random nodes
                num_failures = random.randint(1, min(max_failures, len(regular_nodes)))
                nodes_to_fail = random.sample(regular_nodes, num_failures)
                
                # Simulate failures
                isolated_nodes = simulate_multiple_failures(G, nodes_to_fail, 'node')
                
                # Create target tensor
                y = torch.zeros(len(node_list), dtype=torch.float)
                for node in isolated_nodes:
                    if node in node_to_idx:
                        y[node_to_idx[node]] = 1.0
                
                # Create edge mask
                valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
                for node in nodes_to_fail:
                    node_idx = node_to_idx[node]
                    invalid_edges = (edge_index[0] == node_idx) | (edge_index[1] == node_idx)
                    valid_edge_mask[invalid_edges] = False
                
                # Create data object
                data = Data(
                    x=node_features,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    y=y,
                    valid_edge_mask=valid_edge_mask,
                    failure_type=torch.tensor([0]),
                    failed_elements=[node_to_idx[n] for n in nodes_to_fail],
                    num_failures=torch.tensor([len(nodes_to_fail)], dtype=torch.long)
                )
                
            else:  # edge failure
                # Choose random edges
                all_edges = [(u, v) for u, v in G.edges() if "BLOCK" not in u and "BLOCK" not in v]
                num_failures = random.randint(1, min(max_failures, len(all_edges)))
                edges_to_fail = random.sample(all_edges, num_failures)
                
                # Simulate failures
                isolated_nodes = simulate_multiple_failures(G, edges_to_fail, 'edge')
                
                # Create target tensor
                y = torch.zeros(len(node_list), dtype=torch.float)
                for node in isolated_nodes:
                    if node in node_to_idx:
                        y[node_to_idx[node]] = 1.0
                
                # Create edge mask
                valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
                failed_edge_indices = []
                
                for u, v in edges_to_fail:
                    u_idx, v_idx = node_to_idx[u], node_to_idx[v]
                    
                    for i in range(edge_index.size(1)):
                        e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
                        if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                            valid_edge_mask[i] = False
                            failed_edge_indices.append(i)
                
                # Create data object
                data = Data(
                    x=node_features,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    y=y,
                    valid_edge_mask=valid_edge_mask,
                    failure_type=torch.tensor([1]),
                    failed_elements=failed_edge_indices,
                    num_failures=torch.tensor([len(edges_to_fail)], dtype=torch.long)
                )
            
            # Add to appropriate list
            if y.sum().item() > 0:
                examples_with_isolations.append(data)
            else:
                examples_without_isolations.append(data)
    
    # Balance the dataset based on min_isolation_ratio
    isolation_count = len(examples_with_isolations)
    total_count = isolation_count + len(examples_without_isolations)
    current_ratio = isolation_count / total_count
    
    print(f"Current isolation ratio: {current_ratio:.2f}, target: {min_isolation_ratio:.2f}")
    
    if current_ratio < min_isolation_ratio and len(examples_without_isolations) > 0:
        # Need to remove some non-isolation examples
        target_non_isolation = int(isolation_count * (1 - min_isolation_ratio) / min_isolation_ratio)
        if target_non_isolation < len(examples_without_isolations):
            print(f"Downsampling non-isolation examples from {len(examples_without_isolations)} to {target_non_isolation}")
            examples_without_isolations = random.sample(examples_without_isolations, target_non_isolation)
    
    # Combine and shuffle the final dataset
    final_data_list = examples_with_isolations + examples_without_isolations
    random.shuffle(final_data_list)
    
    # Print dataset statistics
    total_examples = len(final_data_list)
    total_isolations = len(examples_with_isolations)
    total_non_isolations = len(examples_without_isolations)
    
    print("\n=== Dataset Statistics ===")
    print(f"Total examples: {total_examples}")
    print(f"Examples with isolations: {total_isolations} ({total_isolations/total_examples*100:.1f}%)")
    print(f"Examples without isolations: {total_non_isolations} ({total_non_isolations/total_examples*100:.1f}%)")
    
    # Calculate average isolations per positive example
    avg_isolations = sum(data.y.sum().item() for data in examples_with_isolations) / max(1, total_isolations)
    print(f"Average isolated nodes per positive example: {avg_isolations:.2f}")
    
    # Detailed breakdown by failure type
    node_failures = sum(1 for data in final_data_list if data.failure_type.item() == 0)
    edge_failures = sum(1 for data in final_data_list if data.failure_type.item() == 1)
    
    node_isolations = sum(1 for data in examples_with_isolations if data.failure_type.item() == 0)
    edge_isolations = sum(1 for data in examples_with_isolations if data.failure_type.item() == 1)
    
    print("\nExamples by failure type:")
    print(f"  Node failures: {node_failures} ({node_failures/total_examples*100:.1f}%)")
    print(f"  Edge failures: {edge_failures} ({edge_failures/total_examples*100:.1f}%)")
    
    print("\nIsolation examples by failure type:")
    print(f"  Node failures with isolations: {node_isolations} ({node_isolations/total_isolations*100:.1f}%)")
    print(f"  Edge failures with isolations: {edge_isolations} ({edge_isolations/total_isolations*100:.1f}%)")
    
    # Breakdown by number of failures
    failure_counts = {}
    for data in final_data_list:
        num_fails = data.num_failures.item()
        has_isolations = data.y.sum().item() > 0
        
        key = f"{num_fails}_{'iso' if has_isolations else 'non'}"
        if key not in failure_counts:
            failure_counts[key] = 0
        failure_counts[key] += 1
    
    print("\nExamples by number of failures:")
    for k in sorted(failure_counts.keys()):
        num, iso_status = k.split('_')
        status = "with isolations" if iso_status == 'iso' else "without isolations"
        print(f"  {num} failures {status}: {failure_counts[k]}")
    
    return final_data_list, node_list, node_to_idx

In [19]:
import hashlib


def train_isolation_model(data_list, epochs=50):
    # Split data
    total = len(data_list)
    train_size = int(0.7 * total)
    val_size = int(0.15 * total)
    test_size = total - train_size - val_size
    
    # Shuffle and split
    random.shuffle(data_list)
    train_data = data_list[:train_size]
    val_data = data_list[train_size:train_size+val_size]
    test_data = data_list[train_size+val_size:]
    
    # Count positives in training set
    positive_count = sum(data.y.sum().item() > 0 for data in train_data)
    positive_rate = positive_count / len(train_data)
    print(f"Training set: {len(train_data)} examples, {positive_count} with isolations ({positive_rate:.2%})")
    
    # Create data loaders
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
    
    # Initialize model with edge features if available
    edge_dim = train_data[0].edge_attr.size(1) if hasattr(train_data[0], 'edge_attr') else None
    model = IsolationPredictionGNN(num_node_features=train_data[0].x.size(1), edge_dim=edge_dim)
    
    # Calculate positive weight based on class imbalance
    # This is crucial! Use a large value (50.0) to force the model to predict positives
    pos_weight = torch.tensor(30.0)  # Very high weight for rare isolation cases
    
    # You could also calculate from your data:
    # pos_samples = sum(d.y.sum().item() for d in train_data)
    # neg_samples = sum((~d.y.bool()).sum().item() for d in train_data)
    # pos_weight = torch.tensor(neg_samples / max(1, pos_samples))
    
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    
    # Use Adam with weight decay and a lower learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
    
    # Training loop
    best_f1 = 0
    patience = 10
    patience_counter = 0
    
    for epoch in range(epochs):
        # Training
        model.train()
        total_loss = 0
        
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch)
            
            # Flatten predictions and targets
            pred = out.squeeze()
            target = batch.y
            
            loss = criterion(pred, target)
            loss.backward()
            
            # Clip gradients to prevent explosion
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            total_loss += loss.item() * batch.num_graphs
        
        train_loss = total_loss / len(train_loader.dataset)
        
        # Validation
        model.eval()
        val_loss = 0
        all_preds = []
        all_targets = []
        
        with torch.no_grad():
            for batch in val_loader:
                out = model(batch)
                pred = out.squeeze()
                target = batch.y
                
                loss = criterion(pred, target)
                val_loss += loss.item() * batch.num_graphs
                
                # Use a MUCH LOWER threshold (0.1 or 0.2) to detect isolations
                # This counteracts the model's tendency to predict all negatives
                binary_pred = (torch.sigmoid(pred) > 0.2).float()
                
                all_preds.append(binary_pred)
                all_targets.append(target)
        
        val_loss /= len(val_loader.dataset)
        
        # Concatenate predictions and targets
        all_preds = torch.cat(all_preds, dim=0)
        all_targets = torch.cat(all_targets, dim=0)
        
        # Calculate metrics - but handle the case of no positive predictions
        tp = ((all_preds == 1) & (all_targets == 1)).sum().item()
        fp = ((all_preds == 1) & (all_targets == 0)).sum().item()
        fn = ((all_preds == 0) & (all_targets == 1)).sum().item()
        tn = ((all_preds == 0) & (all_targets == 0)).sum().item()
        
        precision = tp / max(tp + fp, 1)
        recall = tp / max(tp + fn, 1)
        f1 = 2 * precision * recall / max(precision + recall, 1e-8)
        accuracy = (tp + tn) / (tp + tn + fp + fn)
        
        # Update learning rate based on validation loss
        scheduler.step(val_loss)
        
        print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}")
        
        # Early stopping based on F1 score
        if f1 > best_f1:
            best_f1 = f1
            patience_counter = 0
            torch.save(model.state_dict(), "best_isolation_model.pt")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}. Best F1: {best_f1:.4f}")
                model.load_state_dict(torch.load("best_isolation_model.pt"))
                break
    
    return model

In [20]:
def create_test_example(G, node_list, node_to_idx, failure_type, elements_to_fail):
    """
    Create a PyG Data object for testing the model with specific failures.
    
    Args:
        G: NetworkX graph
        node_list: List of all nodes in the graph
        node_to_idx: Mapping from node names to indices
        failure_type: 'node' or 'edge'
        elements_to_fail: List of nodes or edges to fail
        
    Returns:
        PyG Data object ready for model inference
    """
    # Create node features - must match the exact features used in training
    node_features = []
    for node in node_list:
        features = []
        
        # Basic features
        pr = G.nodes[node].get('physicalringname', 'unknown')
        pr_hash = int(hashlib.md5(pr.encode()).hexdigest(), 16) % 10
        features.append(pr_hash/10.0)
        
        lr = G.nodes[node].get('lrname', 'unknown')
        lr_hash = int(hashlib.md5(lr.encode()).hexdigest(), 16) % 10
        features.append(lr_hash/10.0)
        
        is_block = 1.0 if "BLOCK" in node else 0.0
        features.append(is_block)
        
        # Add the same advanced features used in your training dataset
        if is_block:
            # Add placeholder values for block nodes
            features.extend([0.0] * 8)  # Must match the number used in training
        else:
            # Calculate same advanced features
            degree = G.degree(node) / max(1, max(dict(G.degree()).values()))
            features.append(degree)
            
            clustering = nx.clustering(G, node)
            features.append(clustering)
            
            # Precompute these if needed for efficiency
            articulation_points = set(nx.articulation_points(G))
            bridges = list(nx.bridges(G))
            
            is_articulation = 1.0 if node in articulation_points else 0.0
            features.append(is_articulation)
            
            bridges_connected = min(1.0, sum(1 for u, v in bridges if node in (u, v)) / 5.0)
            features.append(bridges_connected)
            
            # Add the rest of your advanced features here
            # For example:
            block = G.nodes[node].get('block_name', None)
            if block:
                try:
                    path_length = nx.shortest_path_length(G, node, block) / 10.0
                    disjoint_paths = len(list(nx.edge_disjoint_paths(G, node, block))) / 5.0
                    features.append(path_length)
                    features.append(disjoint_paths)
                except:
                    features.append(1.0)  # Max path length
                    features.append(0.0)  # No paths
            else:
                features.append(1.0)
                features.append(0.0)
                
            # Isolation risk features
            single_path_vulnerability = 1.0 if disjoint_paths <= 1 else 0.0
            features.append(single_path_vulnerability)
            
            # Add other isolation risk features
            neighbors = list(G.neighbors(node))
            connects_to_articulation = 1.0 if any(n in articulation_points for n in neighbors) else 0.0
            features.append(connects_to_articulation)
            
            isolation_risk = max(single_path_vulnerability, is_articulation, bridges_connected)
            features.append(isolation_risk)
            
        node_features.append(features)
    
    # Convert to tensor
    node_features = torch.tensor(node_features, dtype=torch.float)
    
    # Create edge index
    edge_list = []
    for u, v in G.edges():
        u_idx, v_idx = node_to_idx[u], node_to_idx[v]
        edge_list.append([u_idx, v_idx])
        edge_list.append([v_idx, u_idx])  # Add both directions
    
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
    
    # Create edge features if used in your model
    edge_attr = None
    if hasattr(G.edges[list(G.edges())[0]], 'features'):
        edge_attr = []
        for i, (u, v) in enumerate(edge_list):
            u_name, v_name = node_list[u], node_list[v]
            if G.has_edge(u_name, v_name):
                edge_attr.append(G.edges[u_name, v_name].get('features', [0.0]))
            else:
                edge_attr.append([0.0])
        edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    
    # Create edge mask for failure simulation
    valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
    
    if failure_type == 'node':
        for node in elements_to_fail:
            node_idx = node_to_idx[node]
            invalid_edges = (edge_index[0] == node_idx) | (edge_index[1] == node_idx)
            valid_edge_mask[invalid_edges] = False
    else:  # edge failure
        for u, v in elements_to_fail:
            u_idx, v_idx = node_to_idx[u], node_to_idx[v]
            for i in range(edge_index.size(1)):
                e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
                if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                    valid_edge_mask[i] = False
    
    # Create Data object
    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_attr,
        valid_edge_mask=valid_edge_mask,
        failure_type=torch.tensor([0 if failure_type == 'node' else 1]),
        failed_elements=[node_to_idx[n] if failure_type == 'node' 
                        else [node_to_idx[u], node_to_idx[v]] for n in elements_to_fail]
    )
    
    return data

In [21]:
import random
import torch
import copy
from torch_geometric.loader import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv

# 1. Create better dataset with proper balance
data_list, node_list, node_to_idx = create_isolation_prediction_dataset(
    topo_data,
    num_simulations=5000,
    max_failures=10,
    min_isolation_ratio=0.4 # Force 50% of examples to have isolations
)


Building network graph...
Computing network vulnerability metrics...
Calculating node vulnerability scores...
Creating node features...
Generating dataset examples...
Generated 100 examples, 41 with isolations (0.41)
Generated 200 examples, 67 with isolations (0.34)
Generated 300 examples, 104 with isolations (0.35)
Generated 400 examples, 142 with isolations (0.35)
Generated 500 examples, 192 with isolations (0.38)
Generated 600 examples, 230 with isolations (0.38)
Generated 700 examples, 267 with isolations (0.38)
Generated 800 examples, 294 with isolations (0.37)
Generated 900 examples, 329 with isolations (0.37)
Generated 1000 examples, 366 with isolations (0.37)
Generated 1100 examples, 408 with isolations (0.37)
Generated 1200 examples, 443 with isolations (0.37)
Generated 1300 examples, 483 with isolations (0.37)
Generated 1400 examples, 522 with isolations (0.37)
Generated 1500 examples, 563 with isolations (0.38)
Generated 1600 examples, 605 with isolations (0.38)
Generated 17

In [None]:

# 2. Train with the improved loss function
model = train_isolation_model(data_list, epochs=100)

# 3. Test with a known isolation case (lower threshold!)
def test_model_with_known_case():
    # Find a node that causes isolations when it fails
    for node in G.nodes():
        if "BLOCK" not in node:
            isolated = simulate_multiple_failures(G, [node], 'node')
            if len(isolated) > 0:
                print(f"Testing with node {node} which should cause {len(isolated)} isolations")
                
                # Create test data
                test_data = create_test_example(G, node_list, node_to_idx, 'node', [node])
                
                # Predict
                model.eval()
                with torch.no_grad():
                    pred = model(test_data)
                    # Use lower threshold here!
                    binary_pred = (torch.sigmoid(pred.squeeze()) > 0.2).float()
                
                # Check results
                predicted_nodes = [node_list[i] for i in range(len(node_list)) if binary_pred[i] > 0.5]
                print(f"Actual isolated: {isolated}")
                print(f"Predicted isolated: {predicted_nodes}")
                return

Training set: 3128 examples, 1284 with isolations (41.05%)




Epoch 0: Train Loss: 0.8430, Val Loss: 0.7388, Precision: 0.0059, Recall: 1.0000, F1: 0.0118, Accuracy: 0.0059
Epoch 1: Train Loss: 0.7217, Val Loss: 0.6036, Precision: 0.0059, Recall: 1.0000, F1: 0.0118, Accuracy: 0.0091
Epoch 2: Train Loss: 0.5831, Val Loss: 0.4293, Precision: 0.0113, Recall: 1.0000, F1: 0.0224, Accuracy: 0.4846
Epoch 3: Train Loss: 0.4686, Val Loss: 0.3550, Precision: 0.0148, Recall: 0.9829, F1: 0.0291, Accuracy: 0.6123
Epoch 4: Train Loss: 0.4037, Val Loss: 0.2935, Precision: 0.1104, Recall: 0.6979, F1: 0.1907, Accuracy: 0.9650
Epoch 5: Train Loss: 0.3554, Val Loss: 0.2766, Precision: 0.0821, Recall: 0.6783, F1: 0.1464, Accuracy: 0.9532
Epoch 6: Train Loss: 0.3345, Val Loss: 0.2672, Precision: 0.1303, Recall: 0.6498, F1: 0.2171, Accuracy: 0.9723
Epoch 7: Train Loss: 0.3201, Val Loss: 0.2712, Precision: 0.1302, Recall: 0.6466, F1: 0.2168, Accuracy: 0.9724
Epoch 8: Train Loss: 0.3136, Val Loss: 0.2677, Precision: 0.1302, Recall: 0.6466, F1: 0.2168, Accuracy: 0.9724
E

In [49]:
def predict_isolated_nodes(model, G, node_list, node_to_idx, failure_type, elements_to_fail, threshold=0.2):
    """
    Predict which nodes will become isolated after failures using a trained GNN model.
    
    Args:
        model: Trained GNN model
        G: NetworkX graph
        node_list: List of all nodes in the graph
        node_to_idx: Mapping from node names to indices
        failure_type: 'node' or 'edge'
        elements_to_fail: Node(s) or edge(s) to fail
        threshold: Probability threshold for considering a node isolated (lower is more sensitive)
    
    Returns:
        List of node names predicted to be isolated
    """
    # Set model to evaluation mode
    model.eval()
    
    # Pre-compute network properties for better feature engineering
    articulation_points = set(nx.articulation_points(G))
    bridges = list(nx.bridges(G))
    
    # Create node features
    node_features = []
    for node in node_list:
        # Feature 1-2: Physical and logical ring
        pr = G.nodes[node].get('physicalringname', 'unknown')
        pr_hash = hash(pr) % 10
        
        lr = G.nodes[node].get('lrname', 'unknown')
        lr_hash = hash(lr) % 10
        
        # Feature 3: Is it a block node?
        is_block = 1.0 if "BLOCK" in node else 0.0
        
        # Feature 4-5: Network properties
        degree = G.degree(node) / max(1, len(G))
        clustering = nx.clustering(G, node)
        
        # Feature 6: Is articulation point?
        is_articulation = 1.0 if node in articulation_points else 0.0
        
        # Feature 7: Connected to bridges
        bridges_connected = min(1.0, sum(1 for u, v in bridges if node in (u, v)) / 5.0)
        
        # Features 8-9: Path to block
        block = G.nodes[node].get('block_name', None)
        if block and not is_block:
            try:
                path_length = nx.shortest_path_length(G, node, block) / 10.0
                disjoint_paths = len(list(nx.edge_disjoint_paths(G, node, block))) / 5.0
            except:
                path_length = 1.0
                disjoint_paths = 0.0
        else:
            path_length = 0.0 if is_block else 1.0
            disjoint_paths = 1.0 if is_block else 0.0
        
        # Feature 10: Single path vulnerability
        single_path_vulnerability = 1.0 if not is_block and disjoint_paths <= 0.2 else 0.0
        
        # Feature 11: Isolation risk
        isolation_risk = max(single_path_vulnerability, is_articulation, bridges_connected) if not is_block else 0.0
        
        # Combine all features
        node_features.append([
            pr_hash/10.0, 
            lr_hash/10.0, 
            is_block, 
            degree, 
            clustering, 
            is_articulation, 
            bridges_connected, 
            path_length, 
            disjoint_paths, 
            single_path_vulnerability, 
            isolation_risk
        ])
    
    node_features = torch.tensor(node_features, dtype=torch.float)
    
    # Create edge index
    edges = []
    for u, v in G.edges():
        edges.append([node_to_idx[u], node_to_idx[v]])
        edges.append([node_to_idx[v], node_to_idx[u]])  # Add reverse edge for undirected graph
    
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    # Create edge attributes - FIXED to have exactly 4 features per edge
    edge_attr = []
    for i in range(edge_index.size(1)):
        u_idx, v_idx = edge_index[0, i].item(), edge_index[1, i].item()
        u_node, v_node = node_list[u_idx], node_list[v_idx]
        
        # Edge features
        # 1. Same physical ring
        same_pr = 1.0 if G.nodes[u_node].get('physicalringname') == G.nodes[v_node].get('physicalringname') else 0.0
        
        # 2. Same logical ring
        same_lr = 1.0 if G.nodes[u_node].get('lrname') == G.nodes[v_node].get('lrname') else 0.0
        
        # 3. Is this a bridge?
        is_bridge = 1.0 if (u_node, v_node) in bridges or (v_node, u_node) in bridges else 0.0
        
        # 4. Block connection
        block_connection = 1.0 if ("BLOCK" in u_node) != ("BLOCK" in v_node) else 0.0
        
        # IMPORTANT: Use exactly 4 features for edge_attr
        edge_attr.append([same_pr, same_lr, is_bridge, block_connection])
    
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    
    # Convert elements_to_fail to list if it's a single element
    if not isinstance(elements_to_fail, list):
        elements_to_fail = [elements_to_fail]
    
    # Create edge mask for failure simulation
    valid_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
    
    if failure_type == 'node':
        # Mark all edges connected to failed nodes as invalid
        for node in elements_to_fail:
            node_idx = node_to_idx[node]
            invalid_edges = (edge_index[0] == node_idx) | (edge_index[1] == node_idx)
            valid_edge_mask[invalid_edges] = False
    else:  # Edge failure
        for failed_edge in elements_to_fail:
            u_idx, v_idx = node_to_idx[failed_edge[0]], node_to_idx[failed_edge[1]]
            
            for i in range(edge_index.size(1)):
                e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
                if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                    valid_edge_mask[i] = False
    
    # Create PyG Data object for the model
    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_attr,
        valid_edge_mask=valid_edge_mask,
        failure_type=torch.tensor([0 if failure_type == 'node' else 1]),
        num_failures=torch.tensor([len(elements_to_fail)], dtype=torch.long)
    )
    
    # Get model predictions
    with torch.no_grad():
        out = model(data)
        pred_probs = torch.sigmoid(out.squeeze())
        
        # Handle different output shapes (batch vs single)
        if len(pred_probs.shape) > 0:
            pred_probs = pred_probs.numpy()
        else:
            pred_probs = pred_probs.item()
    
    # Find nodes predicted to be isolated
    threshold = 0.7
    predicted_isolated_indices = np.where(pred_probs > threshold)[0]
    predicted_isolated_nodes = [node_list[idx] for idx in predicted_isolated_indices]
    
    return predicted_isolated_nodes

In [50]:
elements_to_fail = ['NODE_2_2_0','NODE_2_2_4']
actual_isolated_nodes = simulate_multiple_failures(G, elements_to_fail, failure_type='node')
actual_isolated_nodes


['NODE_2_2_1', 'NODE_2_2_2', 'NODE_2_2_3']

In [51]:
predicted_isolated_nodes = predict_isolated_nodes(model, G, node_list, node_to_idx, 'node', elements_to_fail)
predicted_isolated_nodes



[]

In [105]:
def evaluate_prediction(G, model, node_list, node_to_idx, failure_type, elements_to_fail):
    """
    Evaluate how well the model predicts isolated nodes compared to actual simulation
    
    Args:
        G: NetworkX graph
        model: Trained GNN model
        node_list: List of all nodes
        node_to_idx: Mapping from node names to indices
        failure_type: 'node' or 'edge'
        elements_to_fail: List of nodes or edges to fail
        
    Returns:
        Dict with evaluation metrics
    """
    # Get actual isolated nodes via simulation
    if failure_type == 'node':
        actual_isolated = simulate_multiple_failures(G, elements_to_fail, 'node')
    else:
        actual_isolated = simulate_multiple_failures(G, elements_to_fail, 'edge')
    
    # Get predicted isolated nodes from model
    predicted_isolated = predict_isolated_nodes(model, G, node_list, node_to_idx, 
                                               failure_type, elements_to_fail)
    
    # Calculate evaluation metrics
    true_positives = len(set(actual_isolated) & set(predicted_isolated))
    false_positives = len(set(predicted_isolated) - set(actual_isolated))
    false_negatives = len(set(actual_isolated) - set(predicted_isolated))
    
    # Handle division by zero
    if true_positives + false_positives > 0:
        precision = true_positives / (true_positives + false_positives)
    else:
        precision = 1.0 if len(actual_isolated) == 0 else 0.0
        
    if true_positives + false_negatives > 0:
        recall = true_positives / (true_positives + false_negatives)
    else:
        recall = 1.0 if len(predicted_isolated) == 0 else 0.0
    
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    
    accuracy = len(set(actual_isolated) & set(predicted_isolated)) / max(1, len(set(actual_isolated) | set(predicted_isolated)))
    
    # Print results
    if failure_type == 'node':
        print(f"Node failure: {', '.join(elements_to_fail)}")
    else:
        print(f"Edge failure: {', '.join([f'({e[0]}-{e[1]})' for e in elements_to_fail])}")
    
    print(f"Actual isolated nodes: {', '.join(actual_isolated) if actual_isolated else 'None'}")
    print(f"Predicted isolated nodes: {', '.join(predicted_isolated) if predicted_isolated else 'None'}")
    print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1: {f1:.2f}, Accuracy: {accuracy:.2f}")
    
    return {
        'actual_isolated': actual_isolated,
        'predicted_isolated': predicted_isolated,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'accuracy': accuracy
    }

In [106]:
import random


In [107]:
# Test single node failures
print("\n=== TESTING SINGLE NODE FAILURES ===")
for node in random.sample(node_list, 3):  # Test 3 random nodes
    if "BLOCK" not in node:  # Skip block nodes
        evaluate_prediction(G, model, node_list, node_to_idx, 'node', [node])

# Test double node failures (more likely to cause isolations)
print("\n=== TESTING DOUBLE NODE FAILURES ===")
# Find nodes in the same logical ring
for pr_idx in range(2):
    for lr_idx in range(2):
        ring_nodes = [n for n in G.nodes() 
                     if 'physicalringname' in G.nodes[n] and 
                        G.nodes[n]['physicalringname'] == f"RING_{pr_idx}" and
                        'lrname' in G.nodes[n] and
                        G.nodes[n]['lrname'] == f"LR_{pr_idx}_{lr_idx}" and
                        "BLOCK" not in n]
        
        if len(ring_nodes) >= 3:
            # Pick two adjacent nodes in the ring
            nodes_to_fail = [ring_nodes[1], ring_nodes[2]]
            evaluate_prediction(G, model, node_list, node_to_idx, 'node', nodes_to_fail)
            break

# Test single edge failures
print("\n=== TESTING SINGLE EDGE FAILURES ===")
for edge in random.sample(list(G.edges()), 3):
    evaluate_prediction(G, model, node_list, node_to_idx, 'edge', [edge])

# Test double edge failures
print("\n=== TESTING DOUBLE EDGE FAILURES ===")
# Find edges in the same logical ring
edge_pairs = []
for pr_idx in range(2):
    for lr_idx in range(2):
        ring_edges = [(u, v) for u, v in G.edges() 
                      if 'physicalringname' in G.nodes[u] and 
                         G.nodes[u]['physicalringname'] == f"RING_{pr_idx}" and
                         'lrname' in G.nodes[u] and 
                         G.nodes[u]['lrname'] == f"LR_{pr_idx}_{lr_idx}"]
        
        if len(ring_edges) >= 2:
            edges_to_fail = [ring_edges[0], ring_edges[1]]
            evaluate_prediction(G, model, node_list, node_to_idx, 'edge', edges_to_fail)
            break


=== TESTING SINGLE NODE FAILURES ===
Node failure: NODE_0_0_4
Actual isolated nodes: None
Predicted isolated nodes: None
Precision: 1.00, Recall: 1.00, F1: 1.00, Accuracy: 0.00
Node failure: NODE_2_2_3
Actual isolated nodes: None
Predicted isolated nodes: None
Precision: 1.00, Recall: 1.00, F1: 1.00, Accuracy: 0.00
Node failure: NODE_1_2_1
Actual isolated nodes: None
Predicted isolated nodes: None
Precision: 1.00, Recall: 1.00, F1: 1.00, Accuracy: 0.00

=== TESTING DOUBLE NODE FAILURES ===
Node failure: NODE_0_0_1, NODE_0_0_2
Actual isolated nodes: None
Predicted isolated nodes: None
Precision: 1.00, Recall: 1.00, F1: 1.00, Accuracy: 0.00
Node failure: NODE_1_0_1, NODE_1_0_2
Actual isolated nodes: None
Predicted isolated nodes: None
Precision: 1.00, Recall: 1.00, F1: 1.00, Accuracy: 0.00

=== TESTING SINGLE EDGE FAILURES ===
Edge failure: (NODE_0_1_3-NODE_0_1_4)
Actual isolated nodes: None
Predicted isolated nodes: None
Precision: 1.00, Recall: 1.00, F1: 1.00, Accuracy: 0.00
Edge fail

In [92]:
model.train()
for epoch in range(100):
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out.squeeze(), batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    # Validation
    if epoch % 5 == 0:
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                out = model(batch)
                val_loss += criterion(out.squeeze(), batch.y).item()
        
        print(f"Epoch {epoch}: Train Loss: {total_loss/len(train_loader):.4f}, Val Loss: {val_loss/len(val_loader):.4f}")
        model.train()

KeyError: 'failed_edge_mask'

In [5]:
def create_gnn_dataset(topology_df, num_simulations=100):
    """
    Create a dataset for training a GNN to predict isolated nodes
    
    Returns:
        - List of PyTorch Geometric Data objects
    """
    # Build the graph
    G = build_network_graph(topology_df)
    
    # Node mapping (for creating numerical indices)
    node_list = list(G.nodes())
    node_to_idx = {node: i for i, node in enumerate(node_list)}
    
    # Create node features
    node_features = []
    for node in node_list:
        # Feature 1: One-hot encoded physical ring
        pr = G.nodes[node]['physicalringname']
        pr_hash = hash(pr) % 10  # Simple encoding
        
        # Feature 2: One-hot encoded logical ring
        lr = G.nodes[node]['lrname']
        lr_hash = hash(lr) % 10  # Simple encoding
        
        # Feature 3: Is it a block node?
        is_block = 1.0 if node == G.nodes[node]['block_name'] else 0.0
        
        # Feature 4-5: Degree centrality and clustering coefficient
        degree = G.degree(node) / len(G)
        clustering = nx.clustering(G, node)
        
        node_features.append([pr_hash/10.0, lr_hash/10.0, is_block, degree, clustering])
    
    node_features = torch.tensor(node_features, dtype=torch.float)
    
    # Create edge index for PyG
    edges = []
    for u, v in G.edges():
        edges.append([node_to_idx[u], node_to_idx[v]])
        edges.append([node_to_idx[v], node_to_idx[u]])  # Add reverse edge for undirected graph
    
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    # Create dataset entries
    data_list = []
    
    # Simulate node failures
    for _ in range(num_simulations // 2):  # Half node failures, half edge failures
        # Randomly select a node to fail
        node_to_fail = np.random.choice(node_list)
        
        # Find isolated nodes
        isolated_nodes = simulate_node_failure(G, node_to_fail)
        
        # Create target tensor (1 for isolated nodes, 0 for others)
        y = torch.zeros(len(node_list), dtype=torch.float)
        for node in isolated_nodes:
            y[node_to_idx[node]] = 1.0
        
        # Create node failure mask
        node_mask = torch.zeros(len(node_list), dtype=torch.bool)
        node_mask[node_to_idx[node_to_fail]] = True
        
        # Create PyG Data object
        data = Data(
            x=node_features,
            edge_index=edge_index,
            y=y,
            failed_node_mask=node_mask,
            failed_edge_mask=None
        )
        
        data_list.append(data)
    
    # Simulate edge failures
    edge_list = list(G.edges())
    for _ in range(num_simulations // 2):
        # Randomly select an edge to fail
        edge_to_fail = edge_list[np.random.randint(0, len(edge_list))]
        
        # Find isolated nodes
        isolated_nodes = simulate_edge_failure(G, edge_to_fail)
        
        # Create target tensor
        y = torch.zeros(len(node_list), dtype=torch.float)
        for node in isolated_nodes:
            y[node_to_idx[node]] = 1.0
        
        # Create edge failure mask
        edge_mask = torch.zeros(edge_index.size(1), dtype=torch.bool)
        u_idx, v_idx = node_to_idx[edge_to_fail[0]], node_to_idx[edge_to_fail[1]]
        
        for i in range(edge_index.size(1)):
            e_u, e_v = edge_index[0, i].item(), edge_index[1, i].item()
            if (e_u == u_idx and e_v == v_idx) or (e_u == v_idx and e_v == u_idx):
                edge_mask[i] = True
        
        # Create PyG Data object
        data = Data(
            x=node_features,
            edge_index=edge_index,
            y=y,
            failed_node_mask=None,
            failed_edge_mask=edge_mask
        )
        
        data_list.append(data)
    
    return data_list


In [None]:
import mysql.connector
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

def load_topology_data(config):
    """Load topology data from MySQL database"""
    connection = mysql.connector.connect(**config)
    cursor = connection.cursor(dictionary=True)
    
    query = """
        SELECT 
           aendname, 
           bendname, 
           aendip, 
           bendip, 
           aendifIndex,
           bendifIndex,
           block_name, 
           physicalringname, 
           lrname 
        FROM topology_data_logical
    """
    cursor.execute(query)
    rows = cursor.fetchall()
    connection.close()
    
    return pd.DataFrame(rows)


In [11]:


# Database configuration
db_config = {
    "host": "192.168.30.15",
    "user": "nms",
    "password": "Nms@1234",
    "database": "cnmsip"
}

# Load topology data
topology_df = load_topology_data(db_config)

# Create dataset for GNN training
data_list = create_gnn_dataset(topology_df, num_simulations=200)
print(f"Created {len(data_list)} data entries")

KeyError: 'Pakhanjur'