In this notebook, the trained GNN predictive model is loaded and GNNExplainer is set up to analyse its predictions. The graphs that received the highest and lowest predicted amenity counts are then run through GNNExplainer. It outputs an importance score for each node and edge in the graph, and these importances are projected onto a map for visual inspection.

# Import Libraries

In [1]:
%matplotlib inline

In [1]:
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import TUDataset, Planetoid
from torch_geometric.nn import GCNConv, GINConv, SAGEConv, GATv2Conv, Set2Set, global_max_pool, MLP, global_mean_pool
from torch.nn import Linear, Sequential, ReLU, Module, ModuleList, BatchNorm1d as BN
import torch_geometric.transforms as T
import torch
import torch.nn.functional as F
import os
from tqdm import tqdm, trange

In [2]:
import matplotlib.pyplot as plt
import pandas as pd
import geopandas as gpd
import numpy as np
import contextily as cx
from shapely.wkt import loads
import folium
from folium.plugins import MeasureControl
from branca.colormap import linear

In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

In [4]:
from torch_geometric.explain import Explainer, GNNExplainer, PGExplainer

In [9]:
# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Import data

In [5]:
# Dataset class that lazily loads one saved graph (.pt file) per MSOA from a directory
class EW_msoa_graphs(Dataset):
    """Lazily-loaded collection of per-MSOA graphs stored as ``<MSOA code>.pt`` files.

    Args:
        root: Directory containing the ``.pt`` graph files.

    The file list is read once at construction and sorted, so integer indices
    (and therefore DataLoader order) are reproducible across machines.
    ``os.listdir`` alone returns entries in an arbitrary, filesystem-dependent order.
    """

    def __init__(self, root):
        super().__init__()
        self.root = root
        # Sorted for deterministic indexing; the set gives O(1) filename lookups.
        self.file_list = sorted(f for f in os.listdir(root) if f.endswith('.pt'))
        self._file_set = set(self.file_list)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load the graph at position ``idx`` (IndexError if ``idx`` is out of range)."""
        file_name = self.file_list[idx]
        return self.load_graph(os.path.join(self.root, file_name))

    def load_graph(self, file_path):
        """Load a single saved graph object from ``file_path``.

        Raises:
            FileNotFoundError: if ``file_path`` does not exist.
        """
        # NOTE(review): torch.load unpickles objects, so only use it on trusted files.
        # On newer torch releases where weights_only defaults to True, loading a saved
        # PyG ``Data`` object may need ``weights_only=False``. Confirm for the installed version.
        try:
            return torch.load(file_path)
        except FileNotFoundError:
            raise FileNotFoundError(f"File {file_path} not found.") from None

    def get_by_filename(self, filename):
        """Load a graph directly by file name (e.g. ``'E02000001.pt'``).

        Raises:
            FileNotFoundError: if ``filename`` is not one of the dataset's ``.pt`` files.
        """
        if filename not in self._file_set:
            raise FileNotFoundError(f"File {filename} not found in dataset directory.")
        return self.load_graph(os.path.join(self.root, filename))

# Create node/edge mapping CSVs for accurate visualization
When the graphs were created earlier, their node and edge indexes were reset, so the current graphs can no longer be linked directly to the OSM ids (OSMids) needed to plot the explanations. Here, the node/edge indexes of each graph are linked to the corresponding node/edge OSMids and saved in new CSVs.

In [7]:
# Build per-MSOA lookup tables linking the graphs' re-indexed node/edge ids back to OSM ids
# and geometries, re-applying the same positional re-indexing used when the graphs were created.
def read_data_to_csv(MSOA, node_file, edge_file):
    """Return ``(node_mapping, edge_mapping)`` DataFrames for one MSOA.

    Despite the name, nothing is written here; the caller saves the frames to CSV.

    Args:
        MSOA: MSOA code (not used in the computation).
        node_file: Path to the MSOA's node list CSV (needs ``osmid, x, y, geometry``).
        edge_file: Path to the MSOA's edge list CSV (needs ``u, v, osmid, geometry``).

    Returns:
        node_mapping: columns ``osmid, new_index, x, y, geometry``, where ``new_index``
            is the node's 0-based row position in ``node_file``.
        edge_mapping: columns ``edge_osmid, u, v, new_source, new_dest, geometry``, where
            ``new_source``/``new_dest`` are the ``new_index`` of endpoints ``u``/``v``
            (NaN when an endpoint is absent from ``node_file``, since the joins are left joins).
    """
    nodes = pd.read_csv(node_file)
    nodes['new_index'] = range(len(nodes))
    osmid_lookup = nodes[['osmid', 'new_index']]

    # Rename the edge's own osmid first so it cannot clash with the node osmid joined in below
    edges = pd.read_csv(edge_file).rename(columns={'osmid': 'edge_osmid'})
    for endpoint_col, index_col in (('u', 'new_source'), ('v', 'new_dest')):
        edges = (
            edges
            .merge(osmid_lookup, how='left', left_on=endpoint_col, right_on='osmid')
            .rename(columns={'new_index': index_col})
            .drop(columns=['osmid'])
        )

    node_mapping = nodes[['osmid', 'new_index', 'x', 'y', 'geometry']]
    edge_mapping = edges[['edge_osmid', 'u', 'v', 'new_source', 'new_dest', 'geometry']]
    return node_mapping, edge_mapping

In [8]:
# Write node/edge mapping CSVs for every non-Welsh MSOA folder found under `node_edge_source`
def save_data(node_edge_source, destination):
    """Create ``<destination>/<MSOA>/node_mapping.csv`` and ``edge_mapping.csv``.

    Each entry of ``node_edge_source`` is treated as an MSOA folder expected to hold
    ``node_list.csv`` and ``edge_list.csv``. Entries lacking either file are skipped
    silently, as are Welsh MSOAs (codes starting with ``W02``).
    """
    if not os.path.exists(destination):
        os.makedirs(destination, exist_ok=True)

    for msoa_code in os.listdir(node_edge_source):
        if msoa_code.startswith("W02"):  # Skip Welsh MSOAs
            continue

        source_dir = os.path.join(node_edge_source, msoa_code)
        node_file, edge_file = (
            os.path.join(source_dir, name) for name in ('node_list.csv', 'edge_list.csv')
        )
        if not (os.path.exists(node_file) and os.path.exists(edge_file)):
            continue

        node_mapping, edge_mapping = read_data_to_csv(msoa_code, node_file, edge_file)

        target_dir = os.path.join(destination, msoa_code)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir, exist_ok=True)

        node_mapping.to_csv(os.path.join(target_dir, 'node_mapping.csv'), index=False)
        edge_mapping.to_csv(os.path.join(target_dir, 'edge_mapping.csv'), index=False)

In [9]:
#save_data('EW_msoa_node_edge_drive', 'graph_osmid_mappings')

# Import trained model

In [6]:
# Same architecture (and attribute names) as the trained model, so its saved state_dict loads into it
class GIN(torch.nn.Module):
    """Graph Isomorphism Network for graph-level regression.

    Pipeline: linear projection of node features -> ``num_layers`` GINConv layers ->
    mean pooling per graph -> two-layer head producing one value per graph.

    Args:
        num_features: Number of input node features.
        num_layers: Number of GINConv layers.
        hidden_channels: Width of the hidden representations.
    """

    def __init__(self, num_features, num_layers, hidden_channels):
        super().__init__()

        # Initial linear transformation layer
        self.initial_layer = torch.nn.Linear(num_features, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            mlp = MLP([hidden_channels, hidden_channels, hidden_channels])
            self.convs.append(GINConv(nn=mlp, train_eps=False))

        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)
        # When True, forward() returns the pooled graph embeddings instead of predictions
        self.return_embeddings = False

    def forward(self, x, edge_index, batch=None):
        """Predict one value per graph.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity (COO edge index).
            batch: Optional node-to-graph assignment vector for mini-batches. When
                ``None``, all nodes are treated as belonging to a single graph.

        Returns:
            ``lin2`` output with one column per graph (shape ``(num_graphs, 1)``), or the
            pooled ``(num_graphs, hidden_channels)`` embeddings if ``return_embeddings`` is set.
        """
        # Respect a caller-supplied batch vector so mini-batches are pooled per graph;
        # overwriting it would silently average every graph in the batch together.
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        x = self.initial_layer(x)
        for conv in self.convs:
            x = conv(x, edge_index)
        embeddings = global_mean_pool(x, batch)
        if self.return_embeddings:
            return embeddings

        # Regression head
        x = F.relu(self.lin1(embeddings))
        x = F.dropout(x, p=0.5, training=self.training)
        return self.lin2(x)

    def enable_embedding_output(self):
        """Make forward() return pooled embeddings."""
        self.return_embeddings = True

    def disable_embedding_output(self):
        """Make forward() return predictions (the default)."""
        self.return_embeddings = False

In [7]:
# Rebuild the model with the hyperparameters it was saved with and load its trained weights
def import_model(model_class, num_layers, data, hidden_channels):
    """Load a trained GIN checkpoint together with its dataset.

    Args:
        model_class: Model name used in the checkpoint directory and file name (e.g. ``'GIN'``).
            The architecture instantiated is always ``GIN``.
        num_layers: Number of GINConv layers the checkpoint was trained with.
        data: Target/dataset name; graphs are read from ``graphs/graphs_<data>``.
        hidden_channels: Hidden width the checkpoint was trained with.

    Returns:
        ``(model, loader, dataset)``: the model in eval mode on ``device``, a
        ``DataLoader`` over the dataset with ``batch_size=1``, and the dataset itself.
    """
    dataset = EW_msoa_graphs(root=f'graphs/graphs_{data}')
    loader = DataLoader(dataset, batch_size=1)

    model_path = os.path.join(
        'trained_models', f'{data}', f'{model_class}', f'{num_layers}_layers',
        f"{model_class}_{num_layers}layer.pth",
    )
    model = GIN(dataset.num_node_features, num_layers, hidden_channels)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa)
    model_state = torch.load(model_path, map_location=device)
    model.load_state_dict(model_state)
    model.to(device)
    model.eval()
    return model, loader, dataset

In [10]:
# Hyperparameters must match those the checkpoint was trained and saved with
model_class = 'GIN'
num_layers = 3
data = 'n_amenities_15min'
hidden_channels = 64

model, loader, dataset = import_model(
    model_class=model_class,
    num_layers=num_layers,
    data=data,
    hidden_channels=hidden_channels,
)

# Initialize Explainer

In [11]:
# Wrap the trained model in a GNNExplainer that learns soft masks over node features and edges.
# Number of largest entries kept in each explanation mask; the rest are zeroed.
# (PyG applies the top-k threshold to the node mask and the edge mask alike. Confirm for the installed version.)
TOP_K = 30
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=50),
    explanation_type='model',      # explain the model's own prediction (no target needed)
    node_mask_type='attributes',   # per PyG docs: one importance per (node, feature) entry
    edge_mask_type='object',       # per PyG docs: one importance per edge
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    ),
    threshold_config=dict(threshold_type='topk', value=TOP_K),
)

In [12]:
# Run the explainer on a single MSOA graph
def get_explanation_for_msoa(msoa):
    """Explain the model's prediction for one MSOA graph.

    Args:
        msoa: MSOA code; the graph is loaded from ``<msoa>.pt`` via the global ``dataset``.

    Returns:
        ``(explanation, edge_index)`` on success, where ``edge_index`` is the tensor the
        explainer was run on (on ``device``). Returns ``None``, after printing the error,
        if the graph file is missing or the explanation step raises.
    """
    try:
        graph = dataset.get_by_filename(f'{msoa}.pt')
    except FileNotFoundError as e:
        print(f"Error loading file: {e}")
        return None

    try:
        # The model was moved to `device` in import_model; inputs must be on the same device.
        x = graph.x.to(device)
        edge_index = graph.edge_index.to(device)
        # No batch vector is passed, so the model pools all nodes as a single graph
        explanation = explainer(x, edge_index)
    except Exception as e:
        print(f"Error during explanation: {e}")
        return None
    return explanation, edge_index

In [13]:
# Tabulate per-node and per-edge importances from an explanation
def explanation_to_df(explanation, edge_index):
    """Convert an explanation's masks into node and edge importance tables.

    Args:
        explanation: Object with optional ``node_mask`` and ``edge_mask`` tensors.
            A multi-dimensional ``node_mask`` (one column per node feature, as produced
            with ``node_mask_type='attributes'``) is summed over its feature dimension,
            so each node receives a single importance.
        edge_index: ``(2, num_edges)`` tensor of source/target node indices aligned
            with ``edge_mask``.

    Returns:
        ``(df_nodes, df_edges)`` with columns ``Node, Importance`` and
        ``Source, Target, Importance``. Each is an empty frame with those columns
        when the corresponding mask is missing.
    """
    node_mask = getattr(explanation, 'node_mask', None)
    if node_mask is not None:
        node_mask = node_mask.detach().cpu()
        # Flattening a (num_nodes, num_features) mask would produce one row per
        # (node, feature) pair, so row i would no longer be node i. Aggregate per node.
        if node_mask.dim() > 1:
            node_mask = node_mask.reshape(node_mask.size(0), -1).sum(dim=1)
        node_importances = node_mask.numpy()
        df_nodes = pd.DataFrame({
            'Node': np.arange(len(node_importances)),
            'Importance': node_importances,
        })
    else:
        df_nodes = pd.DataFrame(columns=['Node', 'Importance'])

    edge_mask = getattr(explanation, 'edge_mask', None)
    if edge_mask is not None:
        edge_importances = edge_mask.detach().cpu().numpy().flatten()
        edge_index_np = edge_index.detach().cpu().numpy()
        df_edges = pd.DataFrame({
            'Source': edge_index_np[0],  # first row of edge_index: source nodes
            'Target': edge_index_np[1],  # second row of edge_index: target nodes
            'Importance': edge_importances,
        })
    else:
        df_edges = pd.DataFrame(columns=['Source', 'Target', 'Importance'])

    return df_nodes, df_edges

In [14]:
# Load the saved graph-index <-> OSM id mappings for one MSOA (as produced by save_data)
def retrieve_mappings(msoa):
    """Read ``node_mapping.csv`` and ``edge_mapping.csv`` for ``msoa``.

    Files are looked up under ``graph_osmid_mappings/<msoa>/`` relative to the
    current working directory.

    Returns:
        ``(node_mapping, edge_mapping)`` DataFrames.

    Raises:
        FileNotFoundError: if either mapping file is missing.
    """
    msoa_dir = os.path.join('graph_osmid_mappings', f'{msoa}')
    node_path, edge_path = (
        os.path.join(msoa_dir, name) for name in ('node_mapping.csv', 'edge_mapping.csv')
    )
    if not all(os.path.exists(path) for path in (node_path, edge_path)):
        raise FileNotFoundError(f"Mapping files for {msoa} not found")
    return pd.read_csv(node_path), pd.read_csv(edge_path)

In [15]:
# Plot importances with folium so the map can be zoomed in to examine different parts of the explanation
def plot_folium_gdf(gdf_lines, gdf_points):
    """Display an interactive folium map of edge and node importances.

    Args:
        gdf_lines: GeoDataFrame with LineString ``geometry`` and an ``Importance`` column.
        gdf_points: GeoDataFrame with Point ``geometry`` and an ``Importance`` column;
            the centroid of its points centres the map.

    Geometry coordinates are expected as (x=longitude, y=latitude) and are swapped to
    folium's (lat, lon) order below. Lines and points share one colour scale spanning
    their combined importance range, shown in a single legend.
    """
    tiles_url = 'https://tiles.stadiamaps.com/tiles/stamen_toner_lines/{z}/{x}/{y}{r}.png'
    tiles_attr = '&copy; <a href="https://stadiamaps.com/" target="_blank">Stadia Maps</a> <a href="https://stamen.com/" target="_blank">&copy; Stamen Design</a> &copy; <a href="https://openmaptiles.org/" target="_blank">OpenMapTiles</a> &copy; <a href="https://www.openstreetmap.org/copyright" target="_blank">OpenStreetMap</a>'

    # NOTE(review): `unary_union` is deprecated in recent geopandas in favour of `union_all()`
    center = gdf_points.geometry.unary_union.centroid
    m = folium.Map(location=[center.y, center.x], zoom_start=15,
                   tiles=tiles_url, attr=tiles_attr, control_scale=True)

    # One colour scale across both layers so edge and node colours are comparable
    importance_min = min(gdf_lines['Importance'].min(), gdf_points['Importance'].min())
    importance_max = max(gdf_lines['Importance'].max(), gdf_points['Importance'].max())
    colormap = linear.YlOrRd_09.scale(importance_min, importance_max)
    colormap.caption = 'Importance'

    for geom, importance in zip(gdf_lines.geometry, gdf_lines['Importance']):
        folium.PolyLine(
            [[y, x] for x, y, *_ in geom.coords],  # shapely (x, y) -> folium (lat, lon)
            color=colormap(importance),
            weight=4,
        ).add_to(m)

    for geom, importance in zip(gdf_points.geometry, gdf_points['Importance']):
        color = colormap(importance)
        folium.Circle(
            location=[geom.y, geom.x],
            radius=4,
            color=color,
            fill=True,
            fill_color=color,
            fill_opacity=0.5,
        ).add_to(m)

    # Add the legend exactly once; adding the colormap twice renders two legends
    m.add_child(colormap)

    display(m)

In [16]:
def call_all(msoa):
    """Explain one MSOA's prediction and map its node and edge importances.

    Runs the explainer, tabulates the mask values, joins them to the saved OSM
    node/edge geometries for ``msoa``, and plots the result with folium.
    Returns early (without plotting) if the explanation could not be produced.
    """
    result = get_explanation_for_msoa(msoa)
    if result is None:
        # get_explanation_for_msoa has already printed the reason
        return
    explanation, edge_index = result

    # Per-node and per-edge importance tables keyed by graph indices
    node_importance, edge_importance = explanation_to_df(explanation, edge_index)

    # Geometries keyed by the same re-indexed node ids
    node_mapping, edge_mapping = retrieve_mappings(msoa)

    # Attach importances to geometries via the graph's re-indexed node ids
    node_df = pd.merge(node_mapping, node_importance, left_on='new_index', right_on='Node')
    # NOTE(review): parallel (u, v) edges in the mapping, or duplicate pairs in edge_index,
    # make this inner join emit one row per combination. Confirm that is intended.
    edge_df = pd.merge(edge_mapping, edge_importance,
                       how='inner', left_on=['new_source', 'new_dest'], right_on=['Source', 'Target'])

    plot_folium_gdf(_importance_gdf(edge_df), _importance_gdf(node_df))


def _importance_gdf(df):
    """Return an EPSG:4326 GeoDataFrame of ``Importance`` with the WKT ``geometry`` column parsed."""
    # assign() builds a new frame, avoiding assignment into a column slice of `df`
    subset = df[['Importance', 'geometry']].assign(geometry=lambda d: d['geometry'].apply(loads))
    return gpd.GeoDataFrame(subset, geometry='geometry', crs='EPSG:4326')

# Project Explanations

In [35]:
# Load the saved predictions (dropping incomplete rows) and rank MSOAs from lowest to highest prediction
predictions = (
    pd.read_csv("trained_models/n_amenities_15min/GIN/3_layers/GIN_3layer_Predictions.csv")
    .dropna()
)
sorted_predictions = predictions.sort_values(by='Predicted_Value', ascending=True)

In [36]:
# NOTE(review): the saved output below is ordered by n_amenities_15min, not Predicted_Value
# (Predicted_Value is not monotonic there). It looks stale, so re-run the previous cell to refresh it.
sorted_predictions.head(n=10)

Unnamed: 0,MSOA11CD,MSOA11NM,geometry,spatial_signature,n_amenities_15min,Predicted_Value
3659,E02005675,Northampton 026,"POLYGON ((473861.4000813385 260012.400306092, ...",Open sprawl,0.249746,4.511558
1595,E02001762,North Tyneside 025,"POLYGON ((430166.0308246439 568779.4370096538,...",Open sprawl,0.254864,4.863844
1141,E02004290,County Durham 002,"POLYGON ((426130.5021168856 556430.8940221532,...",Warehouse/Park land,0.465188,3.840145
2611,E02003139,Plymouth 018,POLYGON ((255092.70253359422 57287.29957328958...,Open sprawl,0.475099,3.507073
3919,E02006414,Spelthorne 012,POLYGON ((506722.39569594257 170686.2618549006...,Open sprawl,0.479671,4.620358
2253,E02002589,Halton 016,POLYGON ((351384.75049527327 381968.7503716727...,Urban buffer,0.615762,4.336823
2749,E02003310,Thurrock 015,POLYGON ((556464.4419502685 180246.68221580447...,Warehouse/Park land,0.679734,5.973714
2046,E02002278,Kirklees 008,"POLYGON ((417770.3049890996 425428.1070841474,...",Open sprawl,0.69486,5.135843
2566,E02003066,North Somerset 002,POLYGON ((346046.0455943012 176822.85219661577...,Open sprawl,0.726978,3.76555
1917,E02002104,Solihull 024,POLYGON ((416193.53647563106 279297.5162280527...,Open sprawl,0.744171,3.501519


#### High scoring MSOAs

In [38]:
call_all(msoa='E02000730')

In [21]:
call_all(msoa='E02000120')

In [22]:
call_all(msoa='E02000393')

In [24]:
call_all(msoa='E02000125')

#### Low scoring MSOAs

In [27]:
call_all(msoa='E02006776')

In [31]:
call_all(msoa='E02001353')

In [28]:
call_all(msoa='E02004649')

In [29]:
call_all(msoa='E02004680')