In [1]:
import sys
sys.path.append("../scripts")

import os, torch
from sklearn.model_selection import train_test_split
import pickle
import torch_geometric.transforms as T
import numpy as np
from torch_geometric.nn.models import Node2Vec
from torch_geometric.data import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
GCNConv._orig_propagate = GCNConv.propagate
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from torch_geometric.explain import GNNExplainer, Explainer
from models import *
from tg_functions import *
from bike_functions import *
import pandas as pd

dropout_p = 0.5
use_gat = True
bins = [int(i) for i in os.getenv("BINS", "400 800 1300 2100 3000 3700 4700 7020 9660").split(' ')]
# bins = [int(i) for i in os.getenv("BINS", "3000").split(' ')]

bins = torch.tensor(bins, device='cuda' if torch.cuda.is_available() else 'cpu')
# bins = 'regression'
hidden_c = 200
num_layers = 0
random_seed = 100
nh = 1

graph_num = 17  # Replace with your graph number

model_name = 'rich-sun-141' # Replace with your model name

weight_prefix = 'best_accuracy'  # Replace with your weight prefix


if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using CUDA device: {torch.cuda.get_device_name(0)}", flush = True)
else:
    device = torch.device('cpu')
    print("Using CPU", flush = True)

# device = 'cpu'

with open(f'../data/graphs/{graph_num}/linegraph_tg.pkl', 'rb') as f:
    data = pickle.load(f)

data.edge_index = data.edge_index.contiguous()
data.x = data.x.contiguous()
data.y = data.y.contiguous()

data = stratified_split(data = data , random_seed = random_seed)

# --- Model Instantiation ---
model = GAT(hidden_c, num_layers, random_seed, bins, data, nh).to(device) if use_gat else GCN(hidden_c, num_layers, random_seed, bins, data).to(device)

if use_gat == 'MLP':
    model = MLP(hidden_c, num_layers, random_seed, bins, data, nh).to(device)

# Load the model with the GCN class
model = torch.load(f'../data/graphs/{graph_num}/models/{model_name}.pt', map_location=device)
model = model.to(device)

model.load_state_dict(torch.load(f'../data/graphs/{graph_num}/models/{model_name}_{weight_prefix}.pt', map_location=device))

criterion = torch.nn.CrossEntropyLoss()


Using CPU


In [2]:
data.edge_index = data.edge_index.contiguous()
data.x = data.x.contiguous()
data.y = data.y.contiguous()
print(data.x.shape, data.edge_index.shape, data.y.shape, flush = True)
data = stratified_split(data)
criterion = torch.nn.CrossEntropyLoss()

test(model, data, criterion, device, bins)[0]
from torch_geometric.explain import GNNExplainer, Explainer

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=1),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type=None,
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)

torch.Size([78168, 31]) torch.Size([2, 152596]) torch.Size([78168])


In [55]:
mask = data.val_mask.squeeze() & (data.y > 0).squeeze()
node_idx = 10
node_idx = torch.where(mask)[0][node_idx].item()  # Get the first node index where mask is True
print(f"Node index for explanation: {node_idx}", flush=True)
explanation = explainer(data.x, data.edge_index, index=node_idx)


Node index for explanation: 18912


In [None]:
explainer.vis

In [66]:
feature_names = pd.read_csv(f'../data/graphs/{graph_num}/node_features.csv').columns.tolist()

tick_dict = {i: feature_names[i] for i in range(len(feature_names))}
# fig , ax = plt.subplots()
### plot the feature importance
node_mask = explanation.node_mask.squeeze()


In [67]:
node_mask

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [24]:
def get_node_neighborhood_subgraph(data: Data, node_idx: int):
    """
    Returns the 1-hop neighborhood subgraph of a node, including:
    - all neighbor nodes (including the center)
    - all edges between these nodes

    Parameters:
    - data (torch_geometric.data.Data): PyG graph object
    - node_idx (int): Index of the center node

    Returns:
    - sub_nodes (torch.Tensor): Tensor of node indices in the subgraph
    - sub_edges (torch.Tensor): 2 x N tensor of edge indices in the subgraph
    """
    src, dst = data.edge_index

    # Step 1: Get direct neighbors
    mask = (src == node_idx) | (dst == node_idx)
    direct_edges = mask.nonzero(as_tuple=False).view(-1)
    neighbor_nodes = torch.unique(torch.cat([src[direct_edges], dst[direct_edges], torch.tensor([node_idx], device=src.device)]))

    # Step 2: Get all edges where both nodes are in neighbor_nodes
    node_mask = torch.zeros(data.num_nodes, dtype=torch.bool, device=src.device)
    node_mask[neighbor_nodes] = True
    edge_in_subgraph = node_mask[src] & node_mask[dst]
    sub_edges = data.edge_index[:, edge_in_subgraph]

    return neighbor_nodes, sub_edges


n_nodes, n_edges = get_node_neighborhood(data, node_idx)

In [33]:
def show_labeled_x_features(data: Data, csv_path: str, node_indices: torch.Tensor):
    """
    Displays selected rows from data.x with proper column names from CSV header.

    Parameters:
    - data (torch_geometric.data.Data): PyG graph object with `x` feature matrix.
    - csv_path (str): Path to the CSV file with correct column names.
    - node_indices (torch.Tensor): Node indices to extract and display features for.
    """
    import pandas as pd

    # Load column names only from CSV
    col_names = pd.read_csv(csv_path, nrows=0).columns.tolist()

    # Sanity check
    assert data.x.size(1) == len(col_names), f"Feature dim mismatch: data.x has {data.x.size(1)} features, but CSV has {len(col_names)} columns."

    # Slice features
    features = data.x[node_indices].cpu().numpy()

    # Create dataframe
    df = pd.DataFrame(features, columns=col_names, index=[f"Node {i}" for i in node_indices.tolist()])

    # drop columns with all 0s
    df = df.loc[:, (df != 0).any(axis=0)]
    # Display the dataframe
    display(df)

show_labeled_x_features(data, '../data/graphs/17/node_features.csv', n_nodes)

Unnamed: 0,lanes,maxspeed,oneway,length,bc,public_transport,place,military,healthcare,power,water
Node 18911,2.0,40.0,0.0,47.932362,1.225812e-07,0.0,3.0,4.0,0.0,2.0,0.0
Node 18912,0.0,40.0,1.0,94.692169,1.621869e-07,1.0,8.0,4.0,1.0,0.0,0.0
Node 25232,0.0,40.0,1.0,84.394104,1.26509e-07,0.0,3.0,10.0,0.0,0.0,1.0
Node 25233,0.0,50.0,1.0,128.800964,7.610181e-08,0.0,3.0,11.0,0.0,0.0,0.0
Node 77716,2.0,40.0,0.0,16.672379,1.877178e-07,0.0,0.0,0.0,0.0,0.0,0.0


In [30]:
n_nodes

tensor([18911, 18912, 25232, 25233, 77716])

In [None]:
def get_node_features(data: Data, node_indices: torch.Tensor) -> torch.Tensor:
    """
    Returns feature matrix rows for the given node indices.

    Parameters:
    - data (torch_geometric.data.Data): PyG graph object
    - node_indices (torch.Tensor): Node indices to extract features for

    Returns:
    - torch.Tensor: Features of the selected nodes
    """
    return data.x[node_indices]

node_features = get_node_features(data, n_nodes)



In [32]:
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

def visualize_subgraph(data: Data, edge_index: torch.Tensor, highlight_nodes: torch.Tensor = None):
    """
    Visualizes a subgraph defined by given edges.

    Parameters:
    - data (torch_geometric.data.Data): Full graph data
    - edge_index (torch.Tensor): 2 x N tensor of edges in subgraph
    - highlight_nodes (torch.Tensor, optional): Nodes to highlight (e.g., the center + its neighbors)
    """
    # Create a shallow copy of data with only the subgraph edges
    sub_data = Data(x=data.x, edge_index=edge_index)

    # Convert to NetworkX for plotting
    G = to_networkx(sub_data, to_undirected=True)

    # Draw with optional node highlighting
    pos = nx.spring_layout(G, seed=42)
    nx.draw(G, pos, with_labels=True, node_color='lightblue', edge_color='gray', node_size=500, font_size=10)

    if highlight_nodes is not None:
        highlight = [int(i) for i in highlight_nodes]
        nx.draw_networkx_nodes(G, pos, nodelist=highlight, node_color='orange')

    plt.title("Subgraph Visualization")
    plt.show()

visualize_subgraph(data, data.edge_index, highlight_nodes=n_nodes)

KeyboardInterrupt: 