In [None]:
import os, torch
from sklearn.model_selection import train_test_split
import pickle
import torch_geometric.transforms as T
import numpy as np
from torch_geometric.nn.models import Node2Vec
from torch_geometric.data import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
import networkx as nx
import pandas as pd

# --- Run configuration ---
model_name = 'fluent-lake-17'  # Replace with your model name
graph_num = 26  # sub-directory of ../data/graphs/ that is read below
weights_prefix = 'best_accuracy'  # Replace with your desired weights prefix
random_seed = 100
# Cut-offs used later to bucketize AADT values and to size the model output (len(bins) + 1)
bins = [int(i) for i in "400 800 1300 2100 3000 3700 4700 7020 9660".split(' ')]
dropout_p = 0.5
epochs = 100

# This notebook runs on the CPU even when a GPU is present. The choice is an
# explicit flag so the message printed below always names the device actually used
# (previously the CUDA branch printed its message and was then overridden).
force_cpu = True

if torch.cuda.is_available() and not force_cpu:
    device = torch.device('cuda')
    print(f"Using CUDA device: {torch.cuda.get_device_name(0)}", flush = True)
else:
    device = torch.device('cpu')
    print("Using CPU", flush = True)

bins = torch.tensor(bins, device=device)

### load graph data
# NOTE: pickle.load can execute arbitrary code; only open files you produced yourself.
with open(f'../data/graphs/{graph_num}/linegraph_tg.pkl', 'rb') as f:
    data = pickle.load(f)

def stratified_split(data, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Attach train/val/test boolean node masks to ``data``, split separately for y > 0 and y <= 0.

    Nodes with ``y > 0`` and the remaining nodes are each divided with
    ``train_test_split``: first ``train_ratio`` goes to training, then the
    remainder is divided between validation and test using
    ``test_ratio / (val_ratio + test_ratio)`` as the test share. Every split
    uses the module-level ``random_seed``.

    Parameters:
    - data: object exposing ``y`` and ``num_nodes``; it is modified in place
    - train_ratio, val_ratio, test_ratio: fractions for the three sets

    Returns the same ``data`` object with ``train_mask``, ``val_mask`` and
    ``test_mask`` set.
    """
    is_positive = data.y > 0
    print(f"Positive nodes: {is_positive.sum().item()}, Total nodes: {data.num_nodes}", flush = True)

    # Share of the held-out remainder that ends up in the test set
    test_share = test_ratio / (val_ratio + test_ratio)

    def _three_way(indices):
        # train vs. rest first, then rest into validation and test
        train_part, rest = train_test_split(indices, train_size=train_ratio, random_state=random_seed)
        val_part, test_part = train_test_split(rest, test_size=test_share, random_state=random_seed)
        return train_part, val_part, test_part

    positive_parts = _three_way(is_positive.nonzero(as_tuple=False).squeeze())
    negative_parts = _three_way((~is_positive).nonzero(as_tuple=False).squeeze())

    # One mask per set, covering the positive and the negative share of that set
    mask_names = ('train_mask', 'val_mask', 'test_mask')
    for mask_name, pos_idx, neg_idx in zip(mask_names, positive_parts, negative_parts):
        mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        mask[torch.cat([pos_idx, neg_idx])] = True
        setattr(data, mask_name, mask)

    # Report how many nodes with y > 0 landed in each set
    print(f"Train nodes with y > 0: {(data.y[data.train_mask] > 0.0).sum()}", flush = True)
    print(f"Validation nodes with y > 0: {(data.y[data.val_mask] > 0.0).sum()}", flush = True)
    print(f"Test nodes with y > 0: {(data.y[data.test_mask] > 0.0).sum()}", flush = True)
    return data

# Optional: remove the first feature of x (currently disabled)
# data.x = data.x[:, 1:]

# Store contiguous versions of the tensors used below
for attr_name in ('edge_index', 'x', 'y'):
    setattr(data, attr_name, getattr(data, attr_name).contiguous())

print(data.x.shape, data.edge_index.shape, data.y.shape, flush = True)

# Adds train_mask / val_mask / test_mask to data
data = stratified_split(data)


In [None]:
# Load the pickled graph H stored next to the tensor data (file name suggests the networkx line graph)
linegraph_nx_path = f'../data/graphs/{graph_num}/linegraph_nx.pkl'
with open(linegraph_nx_path, 'rb') as f:
    H = pickle.load(f)


In [None]:
import networkx as nx

# Build a small MultiDiGraph with two parallel edges between nodes 1 and 2
MDG = nx.MultiDiGraph()
for source, target, edge_kwargs in (
    (1, 2, {'key': 'a', 'weight': 3}),
    (1, 2, {'key': 'b', 'weight': 7}),
    (2, 3, {'weight': 5}),
):
    MDG.add_edge(source, target, **edge_kwargs)

# Convert it to a plain Graph and print the resulting edges with their data
G = nx.Graph(MDG)
print(G.edges(data=True))


In [None]:
# --- Model Definitions ---
class GCN(torch.nn.Module):
    """Stack of GCNConv layers producing ``len(bins) + 1`` scores per node.

    Layout: one input layer, ``num_layers`` hidden layers of width
    ``hidden_channels``, one output layer. The input width is read from the
    module-level ``data.num_features``; the dropout probability from
    ``dropout_p``. Constructing the model calls ``torch.manual_seed(random_seed)``.
    """

    def __init__(self, hidden_channels, num_layers):
        super().__init__()
        torch.manual_seed(random_seed)

        num_outputs = len(bins) + 1

        self.input_layer = GCNConv(data.num_features, hidden_channels, improved=True, cached=True)
        # Intermediate layers; empty when num_layers == 0
        self.hidden_layers = torch.nn.ModuleList([
            GCNConv(hidden_channels, hidden_channels, improved=True, cached=True)
            for _ in range(num_layers)
        ])
        self.output_layer = GCNConv(hidden_channels, num_outputs, cached=True)

    def forward(self, x, edge_index):
        """Return the output layer's scores; no softmax is applied here."""
        # Input and hidden layers share the same conv -> ReLU -> dropout pattern
        for conv in (self.input_layer, *self.hidden_layers):
            x = F.dropout(F.relu(conv(x, edge_index)), p=dropout_p, training=self.training)

        return self.output_layer(x, edge_index)


class GAT(torch.nn.Module):
    """Stack of GATConv layers producing ``len(bins) + 1`` scores per node.

    ``convs`` holds, in order: one input layer, ``num_layers`` hidden layers
    (each with ``num_heads`` concatenated heads) and a single-head output
    layer. The input width is read from the module-level ``data.num_features``.
    """

    def __init__(self,hidden_channels, num_layers, num_heads):
        super().__init__()
        # NOTE(review): GCN seeds with the shared `random_seed`; this class keeps its own
        # fixed seed so existing initialisations are unchanged. Unify if that is intended.
        torch.manual_seed(42)  # Replace with your desired seed

        self.convs = torch.nn.ModuleList()

        # Input layer
        self.convs.append(GATConv(data.num_features, hidden_channels, heads=num_heads, concat=True))

        # Hidden layers: concatenated heads make the incoming width hidden_channels * num_heads
        for _ in range(num_layers):
            self.convs.append(GATConv(hidden_channels * num_heads, hidden_channels, heads=num_heads, concat=True))

        # Output layer
        self.convs.append(GATConv(hidden_channels * num_heads, len(bins) + 1, heads=1, concat=False))

    def forward(self, x, edge_index):
        """Return the last layer's scores; no softmax is applied here."""
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.elu(x)
            # Use the shared dropout setting, as GCN does, instead of a hard-coded 0.5
            x = F.dropout(x, p=dropout_p, training=self.training)

        x = self.convs[-1](x, edge_index)
        return x


In [None]:
model_dir = f'../data/graphs/{graph_num}/models'
weights_path = f'{model_dir}/{model_name}_{weights_prefix}.pt'

# The first file holds a whole pickled model object, not just a state_dict, so it must be
# loaded with weights_only=False (newer PyTorch releases default to weights_only=True and
# would refuse it). Unpickling can execute arbitrary code: only load files you trust.
model = torch.load(f'{model_dir}/{model_name}.pt', map_location=device, weights_only=False)
### load weights onto model
model.load_state_dict(torch.load(weights_path, map_location=device))
model = model.to(device)
model.eval()

In [None]:
# Pure inference: skip autograd bookkeeping
with torch.no_grad():
    logits = model(data.x.to(device), data.edge_index.to(device))
# Predicted class per node = index of the largest score
preds = logits.argmax(dim=1)

# Predictions are matched to H's nodes by position; zip would silently truncate on a mismatch
graph_nodes = list(H.nodes())
if len(graph_nodes) != len(preds):
    raise ValueError(f"H has {len(graph_nodes)} nodes but the model returned {len(preds)} predictions")

for node, pred in zip(graph_nodes, preds.tolist()):
    H.nodes[node]['pred'] = pred

In [None]:
# Position of the largest target value in data.y, as a plain Python int
# (np.argmax on a torch tensor only works through a NumPy fallback and yields np.int64)
highest_idx = int(torch.argmax(data.y))
# (node, attribute-dict) pair at that position in H's node order
highest_node = list(H.nodes(data=True))[highest_idx]

In [None]:
from copy import deepcopy
# Work on a private copy: later cells write edge attributes to g and remove edges from it.
# Copying keeps data.old_graph intact and lets the notebook be re-run from this cell.
g = deepcopy(data.old_graph)

In [None]:
# Weighted shortest-path distance from the highest-target node to every node of H within the cutoff
# (units are those of the 'weight' attribute)
search_cutoff = 2000
lengths = nx.single_source_dijkstra_path_length(H, highest_node[0], cutoff=search_cutoff, weight='weight')

### create subgraph with all nodes within the cutoff of the highest node
subgraph_nodes = list(lengths.keys())
# Set copy for O(1) membership tests; subgraph_nodes itself stays a list for later cells
reachable_nodes = set(subgraph_nodes)
subgraph_edges = [(u, v) for u, v in H.edges() if u in reachable_nodes and v in reachable_nodes]
subgraph = H.edge_subgraph(subgraph_edges)
subgraph = subgraph.to_undirected()

In [None]:
# Per-node lookup tables for the three attributes stored on H's nodes
node_preds = {name: attrs['pred'] for name, attrs in H.nodes(data=True)}
node_aadt = {name: attrs['aadt'] for name, attrs in H.nodes(data=True)}
node_ebc = {name: attrs['bc'] for name, attrs in H.nodes(data=True)}

In [None]:
edge_preds, edge_aadt, edge_ebc = {}, {}, {}

# (source table keyed by H node, target table keyed by g's (u, v, key) edge)
attribute_tables = (
    (node_preds, edge_preds),
    (node_aadt, edge_aadt),
    (node_ebc, edge_ebc),
)

# An edge (u, v, k) of g takes the values stored under the H node (u, v), when present
for u, v, k in g.edges(keys=True):
    endpoints = (u, v)
    for node_table, edge_table in attribute_tables:
        if endpoints in node_table:
            edge_table[(u, v, k)] = node_table[endpoints]

for attr_name, edge_table in (('pred', edge_preds), ('aadt', edge_aadt), ('ebc', edge_ebc)):
    nx.set_edge_attributes(g, edge_table, name=attr_name)

In [None]:
# Remove edges from g whose (u, v) pair is not one of the H nodes kept in subgraph_nodes
# (H's nodes are looked up by g's (u, v) edge pairs elsewhere in this notebook).
# NOTE(review): the match is orientation-sensitive; confirm g.edges() yields pairs in the
# same order as H's node tuples.
edges_to_keep = set(subgraph_nodes)  # set membership instead of scanning the list per edge
edges_to_remove = [edge for edge in g.edges() if edge not in edges_to_keep]

for edge in edges_to_remove:
    g.remove_edge(edge[0], edge[1])


In [None]:
import momepy as mp 
# Convert g to GeoDataFrames; element 1 carries the per-edge columns plotted here
gdf = mp.nx_to_gdf(g)
pred_plot_options = dict(column='pred', cmap='turbo', legend=True, figsize=(10, 10))
gdf[1].plot(**pred_plot_options)

In [None]:
import momepy as mp 
# Same conversion, this time showing the aadt column for rows where aadt > 0
gdf = mp.nx_to_gdf(g)
edges_with_aadt = gdf[1].query('aadt > 0')
edges_with_aadt.plot(column='aadt', cmap='turbo', legend=True, figsize=(10, 10))

In [None]:
import matplotlib.pyplot as plt

def plot_edge_gdf_columns(gdf, columns=['pred', 'ebc', 'aadt'], cmap='viridis', k=10):
    """
    Draw one panel per requested column, side by side, using the frame's own plot method.

    Parameters:
    - gdf: GeoDataFrame holding the geometries and the attribute columns
    - columns: names of the columns to draw, one panel each
    - cmap: colormap passed to every panel
    - k: forwarded unchanged to ``gdf.plot`` as its ``k`` argument
    """
    panel_count = len(columns)
    fig, axes = plt.subplots(1, panel_count, figsize=(5 * panel_count, 5))

    # A single panel comes back as a bare Axes; wrap it so it can be paired with columns
    panels = [axes] if panel_count == 1 else axes

    for panel, column_name in zip(panels, columns):
        gdf.plot(column=column_name, cmap=cmap, legend=True, ax=panel, linewidth=2, k = k)
        panel.set_title(f"{column_name} values")
        panel.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
# Display element 1 of the mp.nx_to_gdf result (the frame whose 'pred' and 'aadt' columns are plotted above)
gdf[1]

In [None]:
import torch
import numpy as np

# Bucketize the aadt column with the configured cut-offs
aadt_tensor = torch.tensor(gdf[1]['aadt'].values)
# `bins` is already a tensor; as_tensor reuses it rather than copy-constructing
# (torch.tensor(tensor) emits a warning)
bins_tensor = torch.as_tensor(bins)

# right=False: bucket i holds values with bins[i-1] < value <= bins[i]
bucket_indices = torch.bucketize(aadt_tensor, bins_tensor, right=False).numpy()

# Shift to zero-based interval labels and clamp to 0 .. len(bins) - 2, so values at or
# below the first cut-off and above the last one fall into the first / last interval.
# NOTE(review): the models output len(bins) + 1 classes; confirm these labels use the
# same class definition as the training targets before reading pred_diff.
bucket_indices = np.clip(bucket_indices - 1, 0, len(bins) - 2)

# Assign
gdf[1]['binned_aadt'] = bucket_indices

# Signed difference between predicted class and binned observation
gdf[1]['pred_diff'] = gdf[1]['pred'] - gdf[1]['binned_aadt']


In [None]:
def max_bin_cutoff_diff(row, bins_array):
    """Width of the cut-off range spanned by a row's predicted and observed bins.

    ``row`` must provide ``'pred'`` and ``'binned_aadt'``; ``bins_array`` is the
    sequence of cut-offs. The result is the absolute difference between the
    cut-off at the smaller bin index and the cut-off one past the larger bin
    index (capped at the last cut-off). Returns NaN when either value is missing.
    """
    pred_idx, true_idx = row['pred'], row['binned_aadt']
    if pd.isna(pred_idx) or pd.isna(true_idx):
        return float('nan')

    lower_bin, upper_bin = sorted((int(pred_idx), int(true_idx)))

    # Upper edge is the cut-off after the larger bin, never past the end of the list
    last_idx = len(bins_array) - 1
    upper_cutoff_idx = upper_bin + 1 if upper_bin + 1 < last_idx else last_idx

    return abs(bins_array[upper_cutoff_idx] - bins_array[lower_bin])

# Use plain Python numbers for the cut-offs: list(bins) on a tensor yields 0-d tensors,
# which would leave tensor objects in the new column.
bins_array = torch.as_tensor(bins).tolist()
gdf[1]['aadt_diff_est'] = gdf[1].apply(lambda row: max_bin_cutoff_diff(row, bins_array), axis=1)


In [None]:
# plot_edge_gdf_columns(gdf[1].query('aadt > 0'), columns=['pred', 'binned_aadt'], cmap= 'turbo')


In [None]:
# Interactive map of the binned observed AADT for rows with aadt > 0
gdf[1].query('aadt > 0').explore(
    column='binned_aadt',
    cmap='RdBu',
    legend=True,
    tooltip=['pred', 'ebc', 'aadt', 'binned_aadt', 'pred_diff', 'name', 'aadt_diff_est', 'highway'],
    # This layer shows the binned observations, not the predictions
    name='Observed AADT bin',
    style_kwds={
        'fillOpacity': 0.7,
        'weight': 5
    },
    # tiles='Esri.WorldImagery',  #Satellite background
    tiles='cartodb positron',
    # vmin=-max_abs,
    # vmax= max_abs
)

In [None]:
# Interactive map of the predicted class for rows with aadt > 0
pred_map_options = dict(
    column='pred',
    cmap='RdBu',
    legend=True,
    tooltip=['pred', 'ebc', 'aadt', 'binned_aadt', 'pred_diff', 'name', 'aadt_diff_est', 'highway'],
    name='Predicted AADT',
    # thicker lines
    style_kwds={'fillOpacity': 0.7, 'weight': 5},
    # alternative background: tiles='Esri.WorldImagery'
    tiles='cartodb positron',
)
gdf[1].query('aadt > 0').explore(**pred_map_options)

In [None]:
# Static map of predicted-minus-observed bin for rows with aadt > 0
observed_edges = gdf[1].query('aadt > 0')
observed_edges.plot(column='pred_diff', cmap='turbo', legend=True, figsize=(10, 10))

In [None]:
from torch_geometric.explain import GNNExplainer, Explainer

# Move data to device
x = data.x.to(device)
edge_index = data.edge_index.to(device)

# Create explainer
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=5),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type=None,  # No edge mask in this case
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        # The forward() of the models defined in this notebook returns the last layer's
        # output without softmax / log_softmax, i.e. raw scores, so declare them as such.
        return_type='raw',
    ),
)

# Run explanation for a single node (index passed as a plain int)
explanation = explainer(x=x, edge_index=edge_index, index=int(highest_idx))

In [None]:
# Feature names, intended to replace the plot's y values; the CSV's first column is skipped
import matplotlib.pyplot as plt
node_features_csv = f'../data/graphs/{graph_num}/node_features.csv'
features = pd.read_csv(node_features_csv).columns[1:]


In [None]:
# The selected H node is an edge of g (H is the line graph of g)
center_edge = list(H.nodes())[highest_idx]

# Its neighbours in H are the edges of g sharing an endpoint with it; keep the centre too
neighbor_edges = [*H.neighbors(center_edge), center_edge]

h_subgraph = H.subgraph(neighbor_edges)

# Each H node is a (u, v) pair: turn those pairs back into edges of a small plain graph
new_g = nx.Graph()
new_g.add_edges_from((u, v) for u, v in h_subgraph.nodes())

In [None]:
edge_labels = {}
for s, t in new_g.edges():
    # Look only at the parallel edges that actually join s and t. g.edges((s, t)) treats
    # (s, t) as a bunch of nodes and returns every edge touching s OR t, so a label could
    # previously come from a different edge. The second lookup covers a directed g whose
    # edge is stored as (t, s).
    parallel_edges = g.get_edge_data(s, t) or g.get_edge_data(t, s) or {}
    named_attrs = next((attrs for attrs in parallel_edges.values() if 'name' in attrs), None)

    if named_attrs is None:
        print('NO NAME')
    else:
        edge_labels[(s, t)] = named_attrs['name']
        print(s, t, named_attrs['name'])
        ### copy all attributes from g to new_g
        nx.set_edge_attributes(new_g, {(s, t): named_attrs})

    # Endpoints get g's node attributes whether or not the edge has a name
    nx.set_node_attributes(new_g, {s: g.nodes[s], t: g.nodes[t]})

In [None]:
# Display new_g's edges together with the attribute dicts copied onto them
new_g.edges(data=True)

In [None]:
# Convert the small graph to GeoDataFrames, requesting both the point and the line frame
gdf_sub = mp.nx_to_gdf(new_g, lines = True, points=True)

In [None]:
# True when the labelled (s, t) keys are exactly new_g's edges, i.e. every edge received a name
sorted(new_g.edges()) == sorted(edge_labels.keys())

In [None]:
# Number of edges in the small graph
len(new_g.edges())

In [None]:
# spring_layout starts from random positions; seed it so the figure is the same on every run
pos = nx.spring_layout(new_g, seed=random_seed)
plt.figure(figsize=(7, 7))
nx.draw(new_g, with_labels=True, pos=pos)
# Overlay the names collected in edge_labels on the matching edges
nx.draw_networkx_edge_labels(new_g, edge_labels=edge_labels, pos=pos)
plt.show()

In [None]:
a = 'amenity shop building aerialway aeroway barrier boundary craft emergency highway historic landuse leisure healthcare military natural office power public_transport railway place service tourism waterway route water'
b = 'aerialway aeroway amenity barrier boundary building craft emergency healthcare highway historic landuse leisure man_made military natural office place power public_transport railway route shop telecom tourism water waterway'

# Split each space-separated list once
a_items = a.split(' ')
b_items = b.split(' ')

# Items of b that are missing from a
for item in b_items:
    if item not in a_items:
        print(item)

# Items of a that are missing from b
for item in a_items:
    if item not in b_items:
        print(item)