In [None]:
import sys
sys.path.append("../scripts")

import os, torch
from sklearn.model_selection import train_test_split
import pickle
import torch_geometric.transforms as T
import numpy as np
from torch_geometric.nn.models import Node2Vec
from torch_geometric.data import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
GCNConv._orig_propagate = GCNConv.propagate
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from torch_geometric.explain import GNNExplainer, Explainer
from models import *
from tg_functions import *
from bike_functions import *
import pandas as pd
import seaborn as sns
import random
import numpy 

# ---------------------------------------------------------------------------
# Notebook configuration, data loading and model restoration.
# ---------------------------------------------------------------------------
# NOTE(review): dropout_p is not referenced in any cell visible here.
dropout_p = 0.5
use_gat = True
# Class boundaries: space-separated integers taken from the BINS environment
# variable, falling back to the default list below.
bins = [int(i) for i in os.getenv("BINS", "400 800 1300 2100 3000 3700 4700 7020 9660").split(' ')]

# The boundaries are turned into a tensor placed on CUDA when available (CPU otherwise);
# later cells pass it to torch.bucketize and to the model constructors.
bins = torch.tensor(bins, device='cuda' if torch.cuda.is_available() else 'cpu')
hidden_c = 200
num_layers = 0
random_seed = 100
nh = 1
# Seed torch, the stdlib `random` module and numpy with the same value.
torch.manual_seed(random_seed)
random.seed(random_seed)
numpy.random.seed(random_seed)

graph_num = 17  # Replace with your graph number

model_name = 'rich-sun-141' # Replace with your model name

weight_prefix = 'best_accuracy'  # Replace with your weight prefix


# Pick the compute device used for the model (CUDA when available, CPU otherwise).
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"Using CUDA device: {torch.cuda.get_device_name(0)}", flush = True)
else:
    device = torch.device('cpu')
    print("Using CPU", flush = True)

# device = 'cpu'

# NOTE(review): pickle.load can execute arbitrary code; only load graph files from a trusted source.
with open(f'../data/graphs/{graph_num}/linegraph_tg.pkl', 'rb') as f:
    data = pickle.load(f)

# NOTE(review): `data` is not moved to `device` in this cell, while the model is;
# confirm the pickled tensors (or stratified_split) already place them on the right device.
data.edge_index = data.edge_index.contiguous()
data.x = data.x.contiguous()
data.y = data.y.contiguous()

# stratified_split comes from one of the star-imported project modules; later cells
# read `data.val_mask` from its result.
data = stratified_split(data = data , random_seed = random_seed)

# --- Model Instantiation ---
# NOTE(review): the instance built here is replaced by the torch.load(...) result below.
model = GAT(hidden_c, num_layers, random_seed, bins, data, nh).to(device) if use_gat else GCN(hidden_c, num_layers, random_seed, bins, data).to(device)

# NOTE(review): with use_gat = True (a bool) this comparison is never true, so the MLP branch is dead.
if use_gat == 'MLP':
    model = MLP(hidden_c, num_layers, random_seed, bins, data, nh).to(device)

# Load the serialized model object from '{model_name}.pt'; this overwrites the instance created above.
model = torch.load(f'../data/graphs/{graph_num}/models/{model_name}.pt', map_location=device)
model = model.to(device)

# Then load the weights stored under '{model_name}_{weight_prefix}.pt' and switch to eval mode.
model.load_state_dict(torch.load(f'../data/graphs/{graph_num}/models/{model_name}_{weight_prefix}.pt', map_location=device))
model.eval()
criterion = torch.nn.CrossEntropyLoss()


In [None]:
# Make sure the graph tensors are contiguous (harmless if already done) and report their shapes.
data.edge_index = data.edge_index.contiguous()
data.x = data.x.contiguous()
data.y = data.y.contiguous()
print(data.x.shape, data.edge_index.shape, data.y.shape, flush = True)
# Re-split with the SAME seed as the loading cell. Calling stratified_split without
# `random_seed` falls back to the helper's default, which may yield different
# train/val masks than the seeded split performed right after loading the graph.
data = stratified_split(data=data, random_seed=random_seed)
criterion = torch.nn.CrossEntropyLoss()

from torch_geometric.explain import GNNExplainer, Explainer

# GNNExplainer configured to explain the model's own prediction ('model'), learning a
# per-attribute node mask and no edge mask.
# NOTE(review): return_type='log_probs' presumes the model outputs log-probabilities — confirm.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=100),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type=None,
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)

In [None]:
# Load the node-feature table; print its column count and show its row count as the cell output.
df_ = pd.read_csv(f'../data/graphs/{graph_num}/node_features.csv')
n_rows, n_cols = df_.shape
print(n_cols)
n_rows

In [None]:
# Candidate nodes: validation nodes whose target value is strictly positive.
mask = data.val_mask.squeeze() & (data.y > 0).squeeze()
# 0-based position inside the candidate list (NOT a global node id).
val_position = 8
# Translate that position into the global node index used by the model and the explainer.
node_idx = torch.where(mask)[0][val_position].item()
print(f"Node index for explanation: {node_idx}", flush=True)
explanation = explainer(data.x, data.edge_index, index=node_idx)
# Plain inference: no autograd graph is needed to read off the predicted class.
with torch.no_grad():
    pred = model(data.x, data.edge_index)
pred = pred.argmax(dim=1)
print(f"Node {node_idx} prediction: {pred[node_idx].item()}", flush=True)
# Bin the raw target with the same boundaries and print it as a plain int, so it is
# directly comparable with the prediction line above (instead of a tensor repr).
true_label = torch.bucketize(data.y[node_idx].item(), bins).item()
print(f"Node {node_idx} true label: {true_label}", flush=True)

In [None]:
# Target values of the candidate nodes selected by `mask`.
masked_y = data.y[mask]
# Tensor-native reduction; the builtin max() would iterate the tensor element by element in Python.
max_val = masked_y.max()
# Position of the first occurrence of the maximum WITHIN the masked subset.
# This is not a global node index; map it through torch.where(mask)[0] if one is needed.
max_idx = torch.where(masked_y == max_val)[0][0].item()


In [None]:
from torch_geometric.data import Data
from torch_geometric.utils import subgraph
import torch

def get_node_k_hop_subgraph(data: Data, node_idx: int, k_hop: int):
    """
    Returns the k-hop neighborhood subgraph of a node.

    Edges are followed in both directions while expanding: a node's neighbors
    are all nodes sharing an edge with it, whichever row of `edge_index` it
    appears in. Each hop is expanded with one vectorised pass over the edge
    list instead of one full scan per frontier node.

    Parameters:
    - data (torch_geometric.data.Data): PyG graph object (only `edge_index` is read)
    - node_idx (int): Index of the center node
    - k_hop (int): Number of hops to include in the neighborhood

    Returns:
    - sub_nodes (torch.Tensor): Sorted tensor of node indices in the subgraph
      (always contains `node_idx`)
    - sub_edges (torch.Tensor): 2 x N tensor of the edges of `edge_index` whose
      endpoints are both in `sub_nodes`, with the original (non-relabelled) node ids
    """
    edge_index = data.edge_index
    src, dst = edge_index[0], edge_index[1]

    # `visited` accumulates the neighborhood; `frontier` holds only the nodes added in the last hop.
    visited = {node_idx}
    frontier = torch.tensor([node_idx], dtype=edge_index.dtype, device=edge_index.device)

    for _ in range(k_hop):
        if frontier.numel() == 0:
            # Nothing new was reached in the previous hop, so further hops cannot add nodes.
            break
        # All edges with at least one endpoint on the current frontier.
        touching = torch.isin(src, frontier) | torch.isin(dst, frontier)
        neighbors = torch.unique(torch.cat([src[touching], dst[touching]]))
        newly_reached = set(neighbors.tolist()) - visited
        visited |= newly_reached
        frontier = torch.tensor(sorted(newly_reached), dtype=edge_index.dtype, device=edge_index.device)

    sub_nodes = torch.tensor(sorted(visited), device=edge_index.device)
    # Get subgraph edges using PyG utility (node ids are kept as in the full graph)
    sub_edges, _ = subgraph(sub_nodes, edge_index, relabel_nodes=False)

    return sub_nodes, sub_edges


# 1-hop neighborhood (node ids and edges) around the node selected for explanation.
n_nodes, n_edges = get_node_k_hop_subgraph(data, node_idx, 1)

In [None]:
def show_labeled_x_features(data: Data, csv_path: str, node_indices: torch.Tensor):
    """
    Render the feature rows of the chosen nodes as a labelled table.

    Column labels are taken from the header row of the CSV at `csv_path`, and
    each row is labelled "Node <index>". Columns that are zero for every
    selected node are removed before the table is displayed and returned.

    Parameters:
    - data (torch_geometric.data.Data): PyG graph object with `x` feature matrix.
    - csv_path (str): Path to the CSV file whose header names the feature columns.
    - node_indices (torch.Tensor): Node indices to extract and display features for.

    Returns:
    - pandas.DataFrame: The displayed table (selected nodes x non-zero feature columns).
    """
    import pandas as pd

    # nrows=0 parses the header only; no data rows are read.
    header = list(pd.read_csv(csv_path, nrows=0).columns)

    # The CSV header must describe exactly the columns of data.x.
    n_features = data.x.size(1)
    assert n_features == len(header), f"Feature dim mismatch: data.x has {n_features} features, but CSV has {len(header)} columns."

    row_labels = [f"Node {i}" for i in node_indices.tolist()]
    values = data.x[node_indices].cpu().numpy()
    df = pd.DataFrame(values, index=row_labels, columns=header)

    # Keep only the columns holding at least one non-zero entry.
    has_nonzero = df.ne(0).any(axis=0)
    df = df.loc[:, has_nonzero]

    display(df)
    return df

# Tabulate the features of the neighborhood nodes and export the table as LaTeX.
# Paths are derived from `graph_num` (instead of a hard-coded graph id) so they
# always refer to the same graph that was loaded above.
neighborhood = show_labeled_x_features(data, f'../data/graphs/{graph_num}/node_features.csv', n_nodes)
neighborhood.to_latex(f'../data/graphs/{graph_num}/neighborhood.tex', index=True, header=True)

In [None]:
# Only the header is needed to name the features, so skip reading the data rows.
feature_names = pd.read_csv(f'../data/graphs/{graph_num}/node_features.csv', nrows=0).columns.tolist()

# Feature position -> feature name.
tick_dict = {i: feature_names[i] for i in range(len(feature_names))}

### plot the feature importance
# NOTE(review): with node_mask_type='attributes' the mask is presumably (num_nodes, num_features);
# the row of the explained node is used as its per-feature importance — confirm.
node_mask = explanation.node_mask.squeeze()
score = node_mask[node_idx].cpu().numpy()
import seaborn as sns
feat_df = pd.DataFrame({
    'Feature': [tick_dict[i] for i in range(len(tick_dict))],
    'Importance': score
})
## keep only the features that appear as columns of the neighborhood table
feat_df = feat_df[feat_df['Feature'].isin(neighborhood.columns)]

### order bars from most to least important (no importance threshold is applied);
### sort_values returns a new frame, avoiding an in-place edit of a filtered slice
feat_df = feat_df.sort_values(by='Importance', ascending=False)
plt.figure(figsize=(6, 6))
sns.barplot(data=feat_df, x = 'Feature', y='Importance')
### add grid lines
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.xticks(ticks=range(len(feat_df)), labels=feat_df['Feature'], rotation=90, ha='right')
plt.title(f'Feature Importance for Node {node_idx}')
# Axis labels match the plotted variables: features on x, importance on y.
plt.xlabel('Feature')
plt.ylabel('Importance')
plt.tight_layout()
plt.show()

In [None]:
def get_node_features(data: Data, node_indices: torch.Tensor) -> torch.Tensor:
    """
    Look up the rows of the feature matrix `data.x` addressed by `node_indices`.

    Parameters:
    - data (torch_geometric.data.Data): PyG graph object (only `x` is read)
    - node_indices (torch.Tensor): Node indices whose feature rows are wanted

    Returns:
    - torch.Tensor: `data.x` indexed by `node_indices` (one row per requested node)
    """
    feature_matrix = data.x
    selected_rows = feature_matrix[node_indices]
    return selected_rows

# Feature rows of the neighborhood nodes collected in `n_nodes`.
node_features = get_node_features(data, n_nodes)



In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import torch
def visualize_edges(edge_index: torch.Tensor, highlight_nodes: torch.Tensor = None, scale_with_size: bool = False):
    """
    Visualizes a simple undirected graph from given edges, without showing node labels.

    Highlighted nodes that do not occur in any edge (e.g. an isolated center
    node whose subgraph has no edges) are added to the graph as isolated nodes,
    so they get a layout position and can be drawn instead of raising an error.

    Parameters:
    - edge_index (torch.Tensor): 2 x N tensor of edges
    - highlight_nodes (torch.Tensor, optional): Nodes to highlight
    - scale_with_size (bool): If True, shrinks the node size as the graph grows
      (never below 100)
    """
    G = nx.Graph()
    G.add_edges_from(edge_index.t().tolist())

    highlight = [] if highlight_nodes is None else [int(i) for i in highlight_nodes]
    # No-op for nodes already present; guarantees every highlighted node has a position.
    G.add_nodes_from(highlight)

    num_nodes = G.number_of_nodes()

    # Dynamic scaling; max(num_nodes, 1) avoids a division by zero on an empty graph.
    if scale_with_size:
        node_size = max(1000 / (max(num_nodes, 1) ** 0.5), 100)
    else:
        node_size = 500

    plt.figure(figsize=(10, 10))
    plt.axis('off')
    # Fixed seed keeps the layout reproducible between runs.
    pos = nx.spring_layout(G, seed=42)

    # Draw graph without labels
    nx.draw(G, pos, node_color='lightblue', edge_color='gray', node_size=node_size)

    # Overdraw the highlighted nodes in a stronger colour.
    if highlight:
        nx.draw_networkx_nodes(G, pos, nodelist=highlight, node_color='blue', node_size=node_size)

    # plt.title("Edge-based Subgraph")
    plt.show()


# Draw the 1-hop neighborhood; every node listed in `n_nodes` is drawn in the highlight colour.
visualize_edges(n_edges, highlight_nodes=n_nodes, scale_with_size=True)

In [None]:
# Re-run inference on the full graph and compare the explained node's predicted class
# with its binned target. No autograd graph is needed for plain inference.
with torch.no_grad():
    pred = model(data.x, data.edge_index)
pred = pred.argmax(dim=1)
print(f"Node {node_idx} prediction: {pred[node_idx].item()}", flush=True)
# Print the bin index as a plain int so it matches the format of the prediction line.
true_label = torch.bucketize(data.y[node_idx].item(), bins).item()
print(f"Node {node_idx} true label: {true_label}", flush=True)

In [None]:
import torch
from collections import deque

def get_reachable_subgraph(data, edge_index: torch.Tensor, source_idx: int, max_distance: float):
    """
    Collects all nodes and edges reachable from a source node within a given
    cumulative distance, using a shortest-path (Dijkstra) traversal.

    Edge costs come from feature index 4 of `data.x` (assumed to be a length in
    meters — TODO confirm). The cost of edge column (u, v) is the feature value
    of `u`, the node listed in row 0 of `edge_index`, and the same cost is used
    when the edge is traversed in either direction.

    Nodes are settled in order of increasing distance, so a node is reported as
    reachable whenever ANY path to it fits within `max_distance`. A plain FIFO
    traversal would fix a node's distance at its first visit and could miss
    nodes that are only within budget via a cheaper, later-discovered path.

    NOTE(review): this assumes non-negative costs. If `data.x` was standardized,
    feature 4 may contain negative values — verify against the data pipeline.

    Parameters:
    - data (torch_geometric.data.Data): PyG data object with node features
    - edge_index (torch.Tensor): 2 x N edge index tensor
    - source_idx (int): Starting node index
    - max_distance (float): Max cumulative travel distance

    Returns:
    - reachable_nodes (torch.Tensor): Sorted tensor of reachable node indices
      (always contains `source_idx`)
    - reachable_edges (torch.Tensor): 2 x M tensor with the columns of
      `edge_index` traversed within budget, in ascending column order
    """
    import heapq  # local import: only this function needs the priority queue

    # One bulk conversion instead of a tensor .item() call per edge.
    lengths = data.x[:, 4].tolist()

    # node -> list of (neighbor, edge column, cost); every edge is usable in both directions.
    adjacency = {}
    for edge_id, (u, v) in enumerate(edge_index.t().tolist()):
        cost = lengths[u]  # cost is read from the node in row 0 of this edge column
        adjacency.setdefault(u, []).append((v, edge_id, cost))
        adjacency.setdefault(v, []).append((u, edge_id, cost))

    settled = set()
    edge_ids = set()
    heap = [(0, source_idx)]  # (distance from source, node)

    while heap:
        dist_so_far, node = heapq.heappop(heap)
        if node in settled:
            # Stale entry: the node was already settled with a distance that is not larger.
            continue
        settled.add(node)

        for neighbor, edge_id, cost in adjacency.get(node, []):
            candidate = dist_so_far + cost
            if neighbor not in settled and candidate <= max_distance:
                heapq.heappush(heap, (candidate, neighbor))
                edge_ids.add(edge_id)

    reachable_nodes = torch.tensor(sorted(settled), device=edge_index.device)
    # Sorted column ids make the edge order deterministic; an explicit long tensor
    # also handles the "no edge reachable" case, yielding a 2 x 0 result.
    edge_columns = torch.tensor(sorted(edge_ids), dtype=torch.long, device=edge_index.device)
    reachable_edges = edge_index[:, edge_columns]

    return reachable_nodes, reachable_edges

# Nodes and edges reachable from the explained node within a cumulative distance budget of 1000.
reachable_nodes, reachable_edges = get_reachable_subgraph(data, data.edge_index, node_idx, 1000)


In [None]:
# Draw the distance-bounded subgraph; every node listed in `reachable_nodes` is drawn in the highlight colour.
visualize_edges(reachable_edges, highlight_nodes=reachable_nodes, scale_with_size=True)

In [None]:
import numpy as np

# Empirical-CDF lookup: for each probe value, count how many nodes have a
# feature-5 value less than or equal to it, and report that count as a fraction of all nodes.
x = data.x[:, 5].cpu().numpy()  # Assuming feature index 5 is 'BC' — TODO confirm
array = x
number = 1.245451e-07
value = number
for i in (1.245451e-07, 3.09317e-08, 8.9521905e-08, 3.8050902e-07, 1.343647e-07, 2.662745e-07):
    value = i
    rank = (array <= value).sum()  # Number of values ≤ value
    quantile = rank / array.shape[0]
    print(f"Value: {value:.8e}, Rank: {rank}, Quantile: {quantile:.4f}")
