In [159]:
import networkx as nx
import torch
from torch_geometric.utils import to_networkx

from src.accessibility_indices import global_efficiency, number_independent_paths
from src.config import PATH_MODELS
from src.criticality_score import criticality_score
from src.graph_neural_networks import LightningGNNCustom
from src.utils import convert_nx_to_pyg


In [151]:
# Restore the GNN checkpoint (presumably the model trained with global efficiency as the
# accessibility index, per its file name), mapping all tensors onto the CPU, then switch
# it to inference mode. The returned module is the cell's displayed output.
checkpoint_path = PATH_MODELS / "model_global_efficiency.ckpt"
gnn_ge = LightningGNNCustom.load_from_checkpoint(checkpoint_path, map_location="cpu")
gnn_ge.eval()

LightningGNNCustom(
  (model): GNNCustom(
    (msg_block): GAT(3, 64, num_layers=3)
    (edge_mlp): MLP(
      (0): Linear(in_features=129, out_features=64, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.28, inplace=False)
      (3): Linear(in_features=64, out_features=1, bias=True)
      (4): Dropout(p=0.28, inplace=False)
    )
  )
)

In [152]:
# Toy directed network: one source (A), one terminal (F), and two alternative
# routes A-B-D-F and A-C-E-F, joined by the cross link B-E.
G = nx.DiGraph()

nodes = ['A', 'B', 'C', 'D', 'E', 'F']
G.add_nodes_from(nodes)

# Each edge carries a 'weight' (distance) and a 'capacity' (number of trips).
edges = [
    ('A', 'B', {'weight': 1, 'capacity': 10}),
    ('A', 'C', {'weight': 2, 'capacity': 5}),
    ('B', 'D', {'weight': 1, 'capacity': 8}),
    ('C', 'E', {'weight': 2, 'capacity': 4}),
    ('D', 'F', {'weight': 1, 'capacity': 10}),
    ('E', 'F', {'weight': 1, 'capacity': 2}),
    ('B', 'E', {'weight': 3, 'capacity': 3}),
]
G.add_edges_from(edges)

# Source and terminal node sets.
sources = ['A']
terminals = ['F']

# Record each node's role in its 'profile' attribute. Sources are checked first;
# any node that is neither a source nor a terminal is labelled 'regular'.
for node, node_attrs in G.nodes(data=True):
    if node in sources:
        role = 'source'
    elif node in terminals:
        role = 'terminal'
    else:
        role = 'regular'
    node_attrs['profile'] = role

# All single-link disruption scenarios (one scenario per edge).
# NOTE(review): not referenced by any cell shown here (process_network is not given it);
# remove it if nothing else in the notebook uses it.
single_link_disruptions = [[edge] for edge in G.edges]

# Passed to process_network as `max_links_in_disruption`; presumably a fraction of the
# links that may fail together. On this 7-link graph it produced 7 scenarios.
MAX_LINKS_IN_DISRUPTION = 0.2

# NOTE(review): `process_network` is not imported by this notebook's import cell, so this
# cell fails on a fresh kernel unless the name was defined earlier in the session.
# Import it explicitly from the module that defines it.
link_scores = process_network(
    G,
    index_accesibility=global_efficiency,  # (sic) keyword spelling belongs to process_network
    max_links_in_disruption=MAX_LINKS_IN_DISRUPTION,
)["link_scores"]

# Attach each link's score to the graph as the 'criticality_score' edge attribute.
# The loop variable is `score`, not `criticality_score`: reusing that name would overwrite
# the function imported from src.criticality_score with a float.
for edge, score in link_scores:
    G.edges[edge]["criticality_score"] = score

Testing with 7 disruption scenarios


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 16256.99it/s]


In [153]:
# Per-link criticality scores from process_network, as (edge, score) pairs in G.edges order.
link_scores

[(('A', 'B'), 0.1333333333333333),
 (('A', 'C'), 0.0),
 (('B', 'D'), 0.1333333333333333),
 (('B', 'E'), 0.0),
 (('C', 'E'), 0.0),
 (('D', 'F'), 0.1333333333333333),
 (('E', 'F'), 0.0)]

In [154]:
# Convert the annotated networkx graph into the model's input format (presumably a
# torch_geometric Data object, given the .edge_attr / .edge_y fields used below).
# normalize_scores=False appears to keep the raw criticality scores as edge targets:
# edge_y below matches link_scores exactly.
G_torch = convert_nx_to_pyg(G, normalize_scores=False)

In [155]:
# Edge features. On this graph the values equal each edge's 'weight', listed in the same
# order as link_scores (A-B, A-C, B-D, B-E, C-E, D-F, E-F).
G_torch.edge_attr

tensor([1., 2., 1., 3., 2., 1., 1.])

In [156]:
# Predict per-edge criticality. The model is already in eval mode (dropout disabled);
# no_grad additionally stops autograd from recording a computation graph for this
# inference-only forward pass.
# NOTE(review): on this toy graph the predictions (~0.1-4.7) are on a very different scale
# from the raw targets in G_torch.edge_y (0-0.13). The checkpoint may have been trained on
# normalized scores; confirm before comparing them directly.
with torch.no_grad():
    edge_predictions = gnn_ge(G_torch)
edge_predictions

tensor([1.1461, 0.8551, 1.2243, 0.1457, 0.5421, 4.2030, 4.7472],
       grad_fn=<SqueezeBackward1>)

In [157]:
# Ground-truth per-edge criticality targets, in the same edge order as the predictions above.
G_torch.edge_y

tensor([0.1333, 0.0000, 0.1333, 0.0000, 0.0000, 0.1333, 0.0000])