# Enhanced GCN-LSTM with Static Node Features for Crime Prediction

This notebook extends the original GCN-LSTM architecture by incorporating static socioeconomic and demographic features through a sophisticated attention mechanism. The static features are integrated as node-level attributes that modulate temporal crime predictions.

## Key Enhancements:
1. **Static Feature Integration**: External socioeconomic features from CSV files
2. **Cross-Modal Attention**: Separate attention mechanism for static-temporal interaction
3. **Multi-Head Attention**: Multiple attention heads for different feature aspects
4. **Feature-Aware Graph Convolution**: Static features gate the output of each graph convolution layer (the adjacency matrix itself is unchanged)

## 1. Environment Setup and Imports

In [8]:
# Core imports
import os
import re
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import geopandas as gpd
from sklearn.model_selection import train_test_split, ParameterGrid
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, Subset
import copy
import warnings
import zipfile
import io
import requests
import pickle
import glob
from pathlib import Path
import gc
from tqdm.auto import tqdm
from io import BytesIO
from PIL import Image

# Configuration
# SEED drives every random number generator seeded below; WINDOW_SIZE is the
# notebook-wide default sequence length; DEVICE prefers CUDA when available.
SEED = 42
WINDOW_SIZE = 3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility
# Seed Python, NumPy and PyTorch (CPU) so repeated runs draw the same numbers.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Seed every visible GPU and ask cuDNN for deterministic kernels
    # (benchmark mode is disabled because it picks kernels non-deterministically).
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Suppress warnings
# NOTE(review): this silences *all* warnings for the rest of the session,
# including deprecation and numerical warnings raised by later cells.
warnings.filterwarnings('ignore')

# Set plotting style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("viridis")

print(f"Using device: {DEVICE}")
print(f"PyTorch version: {torch.__version__}")

Using device: cpu
PyTorch version: 2.5.1


## 2. Data Download and Loading Functions

In [9]:
def download_crime_data(data_dir=None, force_download=False):
    """Download the crime CSVs and the London boundary shapefile unless cached.

    Args:
        data_dir: Directory that receives the files. Defaults to
            ``./crime_data``; missing parent directories are created.
        force_download: When true, download everything again even if cached
            copies are present.

    Returns:
        A dict with the keys ``'recent_crime'`` and ``'historical_crime'``
        mapping to the local CSV paths (as strings), or ``None`` when the
        shapefile archive could not be downloaded or extracted.

    Raises:
        Whatever ``pd.read_csv`` / ``DataFrame.to_csv`` raise if a CSV cannot
        be fetched or written; only shapefile errors are converted to ``None``.
    """
    urls = {
        'recent_crime': 'https://raw.githubusercontent.com/IflyNY2PR/DSSS_cw/6bac9ee3834c73d705106153bf91b315bb1faf01/MPS%20LSOA%20Level%20Crime%20(most%20recent%2024%20months).csv',
        'historical_crime': 'https://raw.githubusercontent.com/IflyNY2PR/DSSS_cw/refs/heads/main/MPS%20LSOA%20Level%20Crime%20(Historical).csv',
        'shapefile': 'https://github.com/IflyNY2PR/DSSS_cw/raw/main/statistical-gis-boundaries-london.zip'
    }

    data_dir = Path('./crime_data') if data_dir is None else Path(data_dir)
    # parents=True so that a nested, not-yet-existing data_dir works too.
    data_dir.mkdir(parents=True, exist_ok=True)
    shapefile_dir = data_dir / 'shapefiles'
    shapefile_dir.mkdir(parents=True, exist_ok=True)

    paths = {
        'recent_crime': str(data_dir / 'recent_crime.csv'),
        'historical_crime': str(data_dir / 'historical_crime.csv')
    }

    files_exist = (
        Path(paths['recent_crime']).exists()
        and Path(paths['historical_crime']).exists()
        and (shapefile_dir / 'statistical-gis-boundaries-london').exists()
    )

    if files_exist and not force_download:
        print("Crime data files already exist.")
        return paths

    print("Downloading crime data files...")
    for name in ['recent_crime', 'historical_crime']:
        print(f"Downloading {name}...")
        pd.read_csv(urls[name]).to_csv(paths[name], index=False)

    print("Downloading and extracting shapefile...")
    try:
        # A timeout keeps a stalled connection from hanging the notebook.
        r = requests.get(urls['shapefile'], timeout=60)
        r.raise_for_status()
        # The context manager closes the archive once extraction is done.
        with zipfile.ZipFile(io.BytesIO(r.content)) as z:
            z.extractall(shapefile_dir)
    except (requests.RequestException, zipfile.BadZipFile, OSError) as e:
        print(f"Error downloading shapefile: {e}")
        return None

    return paths

def download_static_features(data_dir=None, force_download=False):
    """Fetch the static feature CSVs and return their local paths.

    Each file is downloaded only when it is absent locally or when
    ``force_download`` is true. A failed download is reported on stdout and
    skipped; its path is still included in the result.

    Args:
        data_dir: Target directory, ``./static_features`` when omitted.
        force_download: Re-download files that already exist locally.

    Returns:
        Dict mapping a short key (e.g. ``'spatial_imputed_scaled'``) to the
        local file path as a string.
    """
    base_url = "https://raw.githubusercontent.com/IflyNY2PR/CASA0004/013cb9ac54cc8a890e06437567ef4f6dff140ee7/data-preparation/"

    filenames_by_key = {
        'documentation': 'gcn_feature_matrix_documentation.csv',
        'spatial_imputed': 'gcn_feature_matrix_spatial_imputed.csv',
        'spatial_imputed_scaled': 'gcn_feature_matrix_spatial_imputed_scaled.csv',
        'summary_stats': 'gcn_feature_matrix_summary_stats.csv',
        'with_geometry': 'gcn_feature_matrix_with_geometry.csv'
    }

    target_dir = Path(data_dir) if data_dir is not None else Path('./static_features')
    target_dir.mkdir(exist_ok=True)

    paths = {}
    for key, filename in filenames_by_key.items():
        local_file = target_dir / filename
        paths[key] = str(local_file)

        # Cached copy is good enough unless a refresh was requested.
        if local_file.exists() and not force_download:
            continue

        print(f"Downloading {filename}...")
        try:
            pd.read_csv(base_url + filename).to_csv(local_file, index=False)
            print(f"Successfully downloaded {filename}")
        except Exception as e:
            print(f"Error downloading {filename}: {e}")

    return paths

## 3. Load and Explore Data

In [10]:
# Download all data
crime_paths = download_crime_data()
# download_crime_data() returns None when the shapefile could not be fetched;
# fail here with a clear message instead of a TypeError on the lookup below.
if crime_paths is None:
    raise RuntimeError("Crime data download failed; see the messages above")
static_paths = download_static_features()

# Load crime data
print("\nLoading crime data...")
recent_crime_df = pd.read_csv(crime_paths['recent_crime'])
historical_crime_df = pd.read_csv(crime_paths['historical_crime'])

# Load static features
print("\nLoading static features...")
static_features_scaled = pd.read_csv(static_paths['spatial_imputed_scaled'])
static_features_raw = pd.read_csv(static_paths['spatial_imputed'])
feature_documentation = pd.read_csv(static_paths['documentation'])
summary_stats = pd.read_csv(static_paths['summary_stats'])

# Load shapefiles
# NOTE: this path assumes download_crime_data() was called with its default
# data_dir ('./crime_data'), as it is above.
shapefile_path = Path('./crime_data/shapefiles/statistical-gis-boundaries-london/ESRI/LSOA_2011_London_gen_MHW.shp')
london_gdf = gpd.read_file(shapefile_path)

print(f"\nRecent crime data shape: {recent_crime_df.shape}")
print(f"Historical crime data shape: {historical_crime_df.shape}")
print(f"Static features (scaled) shape: {static_features_scaled.shape}")
print(f"London GeoDataFrame shape: {london_gdf.shape}")

Crime data files already exist.

Loading crime data...

Loading static features...

Recent crime data shape: (100868, 29)
Historical crime data shape: (113116, 161)
Static features (scaled) shape: (4719, 17)
London GeoDataFrame shape: (4835, 15)

Loading static features...

Recent crime data shape: (100868, 29)
Historical crime data shape: (113116, 161)
Static features (scaled) shape: (4719, 17)
London GeoDataFrame shape: (4835, 15)


## 4. Explore Static Features

In [11]:
# Display feature documentation
print("Feature Documentation:")
print(feature_documentation.head(20))
print(f"\nTotal number of static features: {len(feature_documentation)}")

# Display summary statistics
print("\nSummary Statistics Sample:")
print(summary_stats.head())

# Check LSOA alignment
crime_lsoas = set(recent_crime_df['LSOA Code'].unique())

# Find the LSOA identifier column in the static features. Any column whose
# name contains "LSOA" (case-insensitive) qualifies, which already covers
# 'LSOA_CODE', so no separate fallback is needed.
lsoa_cols = [col for col in static_features_scaled.columns if 'LSOA' in col.upper()]
print(f"Available LSOA columns in static features: {lsoa_cols}")

if not lsoa_cols:
    # Print all columns to help debug
    print(f"Available columns in static features: {list(static_features_scaled.columns)}")
    raise ValueError("Could not find LSOA column in static features")

# Use the first LSOA column found
static_lsoa_col = lsoa_cols[0]

print(f"Using static LSOA column: {static_lsoa_col}")
static_lsoas = set(static_features_scaled[static_lsoa_col].unique())
common_lsoas = crime_lsoas.intersection(static_lsoas)

# Guard the percentage against an empty crime table.
coverage_pct = len(common_lsoas) / len(crime_lsoas) * 100 if crime_lsoas else 0.0

print(f"\nLSOAs in crime data: {len(crime_lsoas)}")
print(f"LSOAs in static features: {len(static_lsoas)}")
print(f"Common LSOAs: {len(common_lsoas)}")
print(f"Coverage: {coverage_pct:.1f}%")

Feature Documentation:
                   Feature                                       Description  \
0                 AvgPrice             Average housing price (March 2023, £)   
1                 MeanPTAL         Mean Public Transport Accessibility Index   
2               MedianPTAL       Median Public Transport Accessibility Index   
3               Population                          Population count in LSOA   
4                 Area_km2  LSOA area (square kilometers, from PTAL dataset)   
5            MeanSentiment                 Average sentiment score (-1 to 1)   
6              SentimentSD            Standard deviation of sentiment scores   
7              ReviewCount                  Number of reviews/posts analyzed   
8         NearestStation_m              Distance to nearest station (meters)   
9       StationsWithin500m                    Number of stations within 500m   
10           NearestRail_m            Distance to nearest rail line (meters)   
11          Stree

In [12]:
# Debug: list every column of the scaled static-feature table with its position
print("Static features columns:")
for i, col in enumerate(static_features_scaled.columns):
    print(f"{i}: {col}")

print(f"\nTotal columns: {static_features_scaled.shape[1]}")
print(f"Shape: {static_features_scaled.shape}")

# Columns whose name mentions LSOA (case-insensitive)
lsoa_cols = list(
    filter(lambda name: 'LSOA' in name.upper(), static_features_scaled.columns)
)
print(f"\nLSOA columns found: {lsoa_cols}")

# Columns that look like identifiers: name contains CODE, ID or CD
id_cols = [
    col for col in static_features_scaled.columns
    if 'CODE' in col.upper() or 'ID' in col.upper() or 'CD' in col.upper()
]
print(f"ID-like columns found: {id_cols}")

Static features columns:
0: LSOA_CODE
1: AvgPrice
2: MeanPTAL
3: MedianPTAL
4: Population
5: Area_km2
6: MeanSentiment
7: SentimentSD
8: ReviewCount
9: NearestStation_m
10: StationsWithin500m
11: NearestRail_m
12: StreetLength_m
13: StreetDensity_m_per_m2
14: StreetSegments
15: LandUse_Diversity
16: LandUse_Area

Total columns: 17
Shape: (4719, 17)

LSOA columns found: ['LSOA_CODE']
ID-like columns found: ['LSOA_CODE']


## 5. Data Preprocessing

In [13]:
def preprocess_crime_data(historical_df, recent_df):
    """Merge the wide historical and recent crime tables into one long table.

    Every column other than the five identifier columns is treated as a
    ``YYYYMM`` month column and unpivoted into ``date`` / ``count`` rows. When
    both tables contain the same LSOA / category / month, the row from
    ``recent_df`` wins. The result is sorted by date, LSOA and category and
    gains ``month``, ``year`` and ``day_of_week`` columns derived from the
    first day of each month.
    """
    id_columns = ['LSOA Code', 'LSOA Name', 'Borough', 'Major Category', 'Minor Category']

    def to_long(wide_df):
        # Everything that is not an identifier column is a month column.
        month_columns = [name for name in wide_df.columns if name not in id_columns]
        return pd.melt(
            wide_df,
            id_vars=id_columns,
            value_vars=month_columns,
            var_name='date',
            value_name='count',
        )

    long_df = pd.concat([to_long(historical_df), to_long(recent_df)])

    # Month labels become the first day of that month.
    long_df['date'] = pd.to_datetime(long_df['date'] + '01', format='%Y%m%d')

    # Recent rows come last in the concatenation, so keep='last' prefers them.
    dedupe_keys = ['LSOA Code', 'Major Category', 'Minor Category', 'date']
    long_df = long_df.drop_duplicates(subset=dedupe_keys, keep='last')

    sort_keys = ['date', 'LSOA Code', 'Major Category', 'Minor Category']
    long_df = long_df.sort_values(sort_keys)

    # Calendar features
    date_parts = long_df['date'].dt
    long_df['month'] = date_parts.month
    long_df['year'] = date_parts.year
    long_df['day_of_week'] = date_parts.dayofweek

    return long_df

# Preprocess crime data
# Produces one long table (one row per LSOA / category / month) from the two
# wide source tables loaded earlier.
crime_df = preprocess_crime_data(historical_crime_df, recent_crime_df)
print(f"Combined crime dataset shape: {crime_df.shape}")
print(f"Date range: {crime_df['date'].min()} to {crime_df['date'].max()}")

# Filter to common LSOAs
# Keep only LSOAs that also appear in the static feature table, so every
# remaining region has static features available.
crime_df = crime_df[crime_df['LSOA Code'].isin(common_lsoas)]
print(f"Filtered crime dataset shape: {crime_df.shape}")

Combined crime dataset shape: (20066928, 10)
Date range: 2010-04-01 00:00:00 to 2025-03-01 00:00:00
Filtered crime dataset shape: (18266172, 10)
Filtered crime dataset shape: (18266172, 10)


## 6. Create Adjacency Matrix

In [14]:
def create_adjacency_matrix(gdf, region_id_col='LSOA11CD', regions=None):
    """Build a symmetric 0/1 adjacency matrix of touching regions.

    Args:
        gdf: GeoDataFrame with a ``geometry`` column and an id column.
        region_id_col: Name of the column holding the region ids.
        regions: Optional collection of ids; rows with other ids are dropped.

    Returns:
        ``(adj_matrix, region_list)`` where ``region_list`` gives the region id
        of each matrix row/column, in GeoDataFrame order. The diagonal is set
        to 1 (self-loops).
    """
    if regions is not None:
        keep = gdf[region_id_col].isin(regions)
        gdf = gdf[keep].copy()

    region_list = gdf[region_id_col].tolist()
    index_of = {}
    for position, code in enumerate(region_list):
        index_of[code] = position

    size = len(region_list)
    adj_matrix = np.zeros((size, size))

    progress = tqdm(region_list, desc="Creating adjacency matrix")
    for row, code in enumerate(progress):
        shape = gdf.loc[gdf[region_id_col] == code, 'geometry'].iloc[0]
        touching = gdf.geometry.touches(shape)

        for other in gdf[touching][region_id_col].tolist():
            if other not in index_of:
                continue
            col = index_of[other]
            # Write both directions so the matrix stays symmetric.
            adj_matrix[row, col] = 1
            adj_matrix[col, row] = 1

    # Every region is its own neighbour.
    np.fill_diagonal(adj_matrix, 1)

    return adj_matrix, region_list

# Create adjacency matrix
# Restricted to LSOAs present in both the crime data and the static features.
# region_list fixes the node order (it follows the GeoDataFrame row order, not
# the order of common_lsoas). The returned matrix already has 1s on its
# diagonal (self-loops) set by create_adjacency_matrix.
adjacency_matrix, region_list = create_adjacency_matrix(
    london_gdf, region_id_col='LSOA11CD', regions=list(common_lsoas)
)
print(f"Adjacency matrix shape: {adjacency_matrix.shape}")

# Normalize adjacency matrix
def normalize_adjacency(adj):
    """Return the symmetrically normalised adjacency D^(-1/2) (A + I) D^(-1/2).

    Each node gets exactly one self-loop: the diagonal is *set* to 1 rather
    than incremented, so a matrix that already carries self-loops (as produced
    by ``create_adjacency_matrix``) is not double-counted. Any other value on
    the input diagonal is overwritten. The input array is not modified.

    Args:
        adj: Square adjacency matrix (array-like), with or without self-loops.

    Returns:
        A float ``numpy.ndarray`` of the same shape as ``adj``.
    """
    adj_with_self = np.array(adj, dtype=float)  # copy; leave the caller's array intact
    np.fill_diagonal(adj_with_self, 1.0)

    # Degrees are >= 1 thanks to the self-loop, so the power is always finite
    # for non-negative adjacency weights.
    degrees = adj_with_self.sum(axis=1)
    d_inv_sqrt = np.power(degrees, -0.5)

    # Row/column scaling via broadcasting is equivalent to multiplying by the
    # diagonal matrices but avoids two dense n x n matrix products.
    return adj_with_self * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

# Normalised adjacency derived from the touching-regions matrix above.
# NOTE(review): presumably this is what is later passed to the model as `adj`
# (the call site is not shown in this part of the notebook) — confirm.
A_hat = normalize_adjacency(adjacency_matrix)

Creating adjacency matrix: 100%|██████████| 4541/4541 [00:02<00:00, 1950.44it/s]



Adjacency matrix shape: (4541, 4541)


## 7. Prepare Static Features

In [15]:
def prepare_static_features(static_df, region_list):
    """Build a static feature matrix whose rows follow ``region_list``.

    The LSOA identifier column is the first column whose name contains
    "LSOA" (case-insensitive), falling back to ``'LSOA_CODE'``. All remaining
    columns except ``'geometry'`` are used as features. If an LSOA code occurs
    more than once in ``static_df`` only its first row is used. Regions absent
    from ``static_df`` receive the column means of the available rows.

    Args:
        static_df: DataFrame with one identifier column and numeric features.
        region_list: Region ids defining the row order of the result.

    Returns:
        ``(feature_matrix, feature_cols)``: a float array of shape
        ``(len(region_list), len(feature_cols))`` and the feature names.

    Raises:
        KeyError: If no identifier column can be found.
    """
    # Debug: Check what columns are available
    print(f"Available columns in static_df: {list(static_df.columns)[:10]}...")

    # Find the LSOA column (it might have a different name)
    lsoa_cols = [col for col in static_df.columns if 'LSOA' in col.upper()]
    print(f"LSOA-related columns: {lsoa_cols}")

    # Use the correct LSOA column name
    lsoa_col = lsoa_cols[0] if lsoa_cols else 'LSOA_CODE'  # fallback to common alternative
    print(f"Using LSOA column: {lsoa_col}")

    indexed = static_df.set_index(lsoa_col)
    # A repeated LSOA code would make a label lookup return several rows;
    # keep the first occurrence so every region maps to exactly one row.
    indexed = indexed[~indexed.index.duplicated(keep='first')]

    # Select feature columns (the id column is now the index)
    feature_cols = [col for col in indexed.columns if col != 'geometry']
    features = indexed[feature_cols]

    # Computed once up front: the fill value for regions without static data.
    column_means = features.mean().to_numpy(dtype=float)

    # Align to region_list in one step; np.array(...) gives a writable copy.
    feature_matrix = np.array(features.reindex(region_list), dtype=float)
    missing = ~pd.Index(region_list).isin(features.index)
    feature_matrix[missing] = column_means

    return feature_matrix, feature_cols

# Prepare static features
# Rows of static_feature_matrix follow region_list, i.e. the same node order
# as the adjacency matrix built from that list.
static_feature_matrix, feature_names = prepare_static_features(static_features_scaled, region_list)
print(f"Static feature matrix shape: {static_feature_matrix.shape}")
print(f"Number of features: {len(feature_names)}")
print(f"Sample feature names: {feature_names[:10]}")

# Convert to tensor
# float32 copy of the matrix, placed on the configured DEVICE.
static_features_tensor = torch.FloatTensor(static_feature_matrix).to(DEVICE)

Available columns in static_df: ['LSOA_CODE', 'AvgPrice', 'MeanPTAL', 'MedianPTAL', 'Population', 'Area_km2', 'MeanSentiment', 'SentimentSD', 'ReviewCount', 'NearestStation_m']...
LSOA-related columns: ['LSOA_CODE']
Using LSOA column: LSOA_CODE
Static feature matrix shape: (4541, 16)
Number of features: 16
Sample feature names: ['AvgPrice', 'MeanPTAL', 'MedianPTAL', 'Population', 'Area_km2', 'MeanSentiment', 'SentimentSD', 'ReviewCount', 'NearestStation_m', 'StationsWithin500m']


## 8. Enhanced Dataset with Static Features

In [16]:
class EnhancedCrimeDataset(Dataset):
    """Sliding-window dataset pairing crime histories with static node features.

    Each sample is ``((X_crime, X_static), y)`` where ``X_crime`` holds
    ``window_size`` consecutive time steps of the per-region target values,
    ``X_static`` is the full static feature matrix (identical for every
    sample) and ``y`` holds the per-region values ``predict_ahead`` steps
    after the end of the window.

    Args:
        data: Long-format DataFrame with ``'date'``, ``'LSOA Code'`` and
            ``target_col`` columns, at most one row per (date, LSOA) pair.
        region_list: Region ids fixing the column (node) order.
        static_features: Array-like static feature matrix, one row per region.
        window_size: Number of input time steps per sample.
        target_col: Name of the value column in ``data``.
        predict_ahead: Forecast horizon in time steps (default 1).
    """

    def __init__(self, data, region_list, static_features, window_size, target_col, predict_ahead=1):
        self.dates = sorted(data['date'].unique())
        self.region_list = region_list
        self.window_size = window_size
        self.predict_ahead = predict_ahead
        self.static_features = static_features  # [n_regions, n_features]

        # Pivot crime data into a [n_dates, n_regions] table.
        # reindex(fill_value=0) only fills rows/columns that did not exist
        # before; (date, region) pairs that are simply absent from `data` come
        # out of pivot() as NaN, so they need an explicit fillna(0). Without
        # it NaNs would reach both the inputs and the targets.
        df_pivot = (
            data
            .pivot(index='date', columns='LSOA Code', values=target_col)
            .reindex(index=self.dates, columns=self.region_list, fill_value=0)
            .fillna(0)
        )
        self.crime_matrix = df_pivot.values

        # Create valid (window start, target position) index pairs.
        L = len(self.dates) - window_size - predict_ahead + 1
        self.indices = [(i, i + window_size + predict_ahead - 1) for i in range(L)]

    def __len__(self):
        """Return the number of available windows."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``((X_crime, X_static), y)`` for the window at ``idx``."""
        i, j = self.indices[idx]

        # Temporal crime features
        X_crime = self.crime_matrix[i : i + self.window_size]
        y = self.crime_matrix[j]

        # Static features (same for all time steps)
        X_static = self.static_features

        return (torch.FloatTensor(X_crime), torch.FloatTensor(X_static)), torch.FloatTensor(y)

## 9. Enhanced Model Architecture with Static Feature Attention

In [24]:
class StaticFeatureAttention(nn.Module):
    """Multi-head cross-attention from temporal node states to static features.

    Queries are computed from the temporal features; keys and values are
    computed from the projected static features. Attention runs across the
    node axis, and the result is added back to the temporal features and
    layer-normalised.
    """

    def __init__(self, static_dim, hidden_dim, num_heads=4, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // num_heads

        # The hidden size must split evenly across the heads.
        assert hidden_dim % num_heads == 0

        # Lift raw static features into the hidden space
        self.static_projection = nn.Linear(static_dim, hidden_dim)

        # Q / K / V projections
        self.query_projection = nn.Linear(hidden_dim, hidden_dim)
        self.key_projection = nn.Linear(hidden_dim, hidden_dim)
        self.value_projection = nn.Linear(hidden_dim, hidden_dim)

        # Mixes the concatenated heads
        self.output_projection = nn.Linear(hidden_dim, hidden_dim)

        # Regularisation
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, tensor, batch_size, num_nodes):
        """Reshape [batch, nodes, hidden] into [batch, heads, nodes, head_dim]."""
        per_head = tensor.view(batch_size, num_nodes, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, temporal_features, static_features):
        """Attend from ``temporal_features`` to ``static_features``.

        ``temporal_features`` is [batch, nodes, hidden]; ``static_features``
        may be [nodes, static_dim] or [batch, nodes, static_dim]. Returns the
        updated features ([batch, nodes, hidden]) and the attention weights
        ([batch, heads, nodes, nodes]).
        """
        batch_size, num_nodes, hidden_dim = temporal_features.size()

        static_hidden = self.static_projection(static_features)
        if static_hidden.dim() == 2:
            # Unbatched static features: share them across the batch.
            static_hidden = static_hidden.unsqueeze(0).expand(batch_size, -1, -1)

        queries = self.query_projection(temporal_features)
        keys = self.key_projection(static_hidden)
        values = self.value_projection(static_hidden)

        queries = self._split_heads(queries, batch_size, num_nodes)
        keys = self._split_heads(keys, batch_size, num_nodes)
        values = self._split_heads(values, batch_size, num_nodes)

        # Scaled dot-product attention
        scale = np.sqrt(self.head_dim)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale
        attention_weights = self.dropout(F.softmax(scores, dim=-1))

        context = torch.matmul(attention_weights, values)

        # Merge the heads back into a single hidden axis
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, num_nodes, hidden_dim)
        projected = self.output_projection(context)

        # Residual connection followed by layer norm
        output = self.layer_norm(temporal_features + self.dropout(projected))

        return output, attention_weights

class EnhancedGraphConvLayer(nn.Module):
    """Graph convolution whose output is gated element-wise by static features.

    Computes ``adj @ (dropout(input) @ weight)`` and multiplies the result by
    a sigmoid gate produced from the static features of each node.
    """

    def __init__(self, in_features, out_features, static_dim, dropout=0.1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # GCN weight matrix; filled by the Xavier initialisation below
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))

        # Maps static features to a (0, 1) gate per output channel
        self.static_gate = nn.Sequential(
            nn.Linear(static_dim, out_features),
            nn.Sigmoid(),
        )

        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, adj, static_features):
        """Apply the gated graph convolution.

        ``static_features`` may be [nodes, static_dim] or already batched as
        [batch, nodes, static_dim].
        """
        # Plain graph convolution: feature transform, then neighbour aggregation
        dropped = self.dropout(input)
        propagated = torch.matmul(adj, torch.matmul(dropped, self.weight))

        # Static gate; unbatched static features are shared across the batch
        gate = self.static_gate(static_features)
        if static_features.dim() == 2:
            gate = gate.unsqueeze(0).expand(input.size(0), -1, -1)

        return propagated * gate

class EnhancedGCN_LSTM(nn.Module):
    """GCN-LSTM with static feature gating and attention.

    Per time step, crime counts are embedded, passed through two statically
    gated graph convolutions and a static-feature attention block. Each node's
    sequence of embeddings is then fed to an LSTM, the final LSTM states are
    mixed across nodes by multi-head self-attention, concatenated with the
    static embedding and mapped to one prediction per node.

    Args:
        window_size: Number of input time steps (stored for reference).
        num_nodes: Number of graph nodes (stored for reference).
        static_dim: Number of static features per node.
        hidden_dim: Width of the graph/attention embeddings.
        lstm_hidden: Hidden size of the LSTM.
        out_dim: Outputs per node.
        num_heads: Attention heads; must divide both ``hidden_dim`` and
            ``lstm_hidden``.
        dropout: Dropout probability used throughout.
        lambda_mmd: Weight of the MMD regulariser, read by the training loop.
    """

    def __init__(self, window_size, num_nodes, static_dim, hidden_dim=64,
                 lstm_hidden=128, out_dim=1, num_heads=4, dropout=0.1, lambda_mmd=0.1):
        super(EnhancedGCN_LSTM, self).__init__()
        self.window_size = window_size
        self.num_nodes = num_nodes
        self.hidden_dim = hidden_dim
        self.lstm_hidden = lstm_hidden
        self.lambda_mmd = lambda_mmd

        # Initial embeddings
        self.crime_embedding = nn.Linear(1, hidden_dim)
        self.static_embedding = nn.Linear(static_dim, hidden_dim)

        # Enhanced graph convolutions with static modulation
        self.gc1 = EnhancedGraphConvLayer(hidden_dim, hidden_dim, static_dim, dropout)
        self.gc2 = EnhancedGraphConvLayer(hidden_dim, hidden_dim, static_dim, dropout)

        # Static feature attention
        self.static_attention = StaticFeatureAttention(static_dim, hidden_dim, num_heads, dropout)

        # Temporal modeling
        self.lstm = nn.LSTM(hidden_dim, lstm_hidden, num_layers=2,
                           batch_first=True, dropout=dropout)

        # Named "temporal" but, as used in forward(), it attends across nodes:
        # the node axis is moved to the sequence position before the call.
        self.temporal_attention = nn.MultiheadAttention(lstm_hidden, num_heads, dropout)

        # Output layers
        self.fc1 = nn.Linear(lstm_hidden + hidden_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, out_dim)

        # Activation and regularization
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_crime, x_static, adj):
        """Predict per-node values from a window of crime counts.

        Args:
            x_crime: Tensor [batch, window, nodes].
            x_static: Static features, either [nodes, static_dim] or batched
                as [batch, nodes, static_dim] (what a DataLoader yields).
            adj: Adjacency matrix applied inside the graph convolutions.

        Returns:
            ``(output, mmd)``: predictions (shape [batch, nodes] when
            ``out_dim`` is 1) and the scalar MMD between the first and last
            time-step embeddings.
        """
        batch_size, window_size, num_nodes = x_crime.size()

        # Process static features once
        static_embedded = self.static_embedding(x_static)
        if static_embedded.dim() == 2:
            # Unbatched static features: share them across the batch. Batched
            # input already has the right shape and must not be unsqueezed.
            static_embedded = static_embedded.unsqueeze(0).expand(batch_size, -1, -1)

        # Process temporal crime data
        temporal_features = []

        for t in range(window_size):
            # Get crime data at time t
            x_t = x_crime[:, t, :].unsqueeze(-1)  # [batch, nodes, 1]

            # Initial embedding
            h_t = self.crime_embedding(x_t)  # [batch, nodes, hidden_dim]

            # Graph convolutions with static modulation
            h_t = self.gc1(h_t, adj, x_static)
            h_t = self.relu(h_t)
            h_t = self.gc2(h_t, adj, x_static)
            h_t = self.relu(h_t)

            # Apply static feature attention
            h_t, _ = self.static_attention(h_t, x_static)

            temporal_features.append(h_t)

        # Stack temporal features
        temporal_stack = torch.stack(temporal_features, dim=1)  # [batch, window, nodes, hidden]

        # Run the LSTM over every node's sequence in a single call: fold the
        # node axis into the batch axis instead of looping over nodes in
        # Python. Sequences are processed independently, so outside of
        # dropout randomness this matches a per-node loop.
        node_sequences = temporal_stack.permute(0, 2, 1, 3).reshape(
            batch_size * num_nodes, window_size, -1
        )  # [batch * nodes, window, hidden]
        lstm_out, _ = self.lstm(node_sequences)
        # Keep the last time step and restore the node axis.
        lstm_features = lstm_out[:, -1, :].reshape(batch_size, num_nodes, -1)  # [batch, nodes, lstm_hidden]

        # Self-attention across nodes (sequence-first layout expected here)
        lstm_features_transposed = lstm_features.transpose(0, 1)  # [nodes, batch, lstm_hidden]
        attended_features, _ = self.temporal_attention(
            lstm_features_transposed,
            lstm_features_transposed,
            lstm_features_transposed
        )
        attended_features = attended_features.transpose(0, 1)  # [batch, nodes, lstm_hidden]

        # Combine with static embeddings
        combined = torch.cat([attended_features, static_embedded], dim=-1)

        # Final predictions
        output = self.fc1(combined)
        output = self.relu(output)
        output = self.dropout(output)
        output = self.fc_out(output).squeeze(-1)

        # Calculate MMD for regularization
        initial_embedding = temporal_features[0]
        final_embedding = temporal_features[-1]
        mmd = self.maximum_mean_discrepancy(initial_embedding, final_embedding)

        return output, mmd

    def maximum_mean_discrepancy(self, x, y):
        """Gaussian-kernel MMD estimate between two sets of embeddings.

        Both inputs are first averaged over dim 1 (the node axis in forward()),
        leaving one vector per batch element; kernel means are then taken over
        all pairs of those vectors.
        """
        x = x.mean(dim=1)
        y = y.mean(dim=1)

        def gaussian_kernel(a, b, sigma=1.0):
            # Pairwise squared Euclidean distances -> RBF kernel values
            dist = torch.sum((a.unsqueeze(1) - b.unsqueeze(0)).pow(2), dim=2)
            return torch.exp(-dist / (2 * sigma**2))

        xx = gaussian_kernel(x, x)
        yy = gaussian_kernel(y, y)
        xy = gaussian_kernel(x, y)
        return torch.mean(xx) + torch.mean(yy) - 2 * torch.mean(xy)

## 10. Training Functions

In [18]:
def train_enhanced_model(model, train_loader, val_loader, adj, epochs, lr, patience, device):
    """Train ``model`` with Adam, MSE + MMD loss and early stopping.

    The training loss is ``MSE + model.lambda_mmd * mmd``; the validation loss
    is plain MSE. Training stops once the validation loss has failed to
    improve for ``patience`` consecutive epochs, and the weights from the best
    validation epoch are restored before returning.

    Returns:
        ``(model, train_losses, val_losses)`` with one mean loss per epoch in
        each list.
    """
    optimizer = Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    model.to(device)
    adj_tensor = torch.tensor(adj, dtype=torch.float32, device=device)

    def on_device(batch):
        # Unpack one loader item and move its tensors to the target device.
        (crime, static), target = batch
        return crime.to(device), static.to(device), target.to(device)

    train_losses, val_losses = [], []
    best_loss = float('inf')
    best_state = None
    stale_epochs = 0

    for epoch in range(1, epochs + 1):
        # ---- training pass ----
        model.train()
        batch_train_losses = []
        for batch in train_loader:
            crime, static, target = on_device(batch)

            optimizer.zero_grad()
            preds, mmd = model(crime, static, adj_tensor)
            loss = criterion(preds, target) + model.lambda_mmd * mmd
            loss.backward()
            optimizer.step()

            batch_train_losses.append(loss.item())

        # ---- validation pass ----
        model.eval()
        batch_val_losses = []
        with torch.no_grad():
            for batch in val_loader:
                crime, static, target = on_device(batch)
                preds, _ = model(crime, static, adj_tensor)
                batch_val_losses.append(criterion(preds, target).item())

        mean_train = np.mean(batch_train_losses)
        mean_val = np.mean(batch_val_losses)
        train_losses.append(mean_train)
        val_losses.append(mean_val)

        print(f"Epoch {epoch}/{epochs} - Train Loss: {mean_train:.4f}, Val Loss: {mean_val:.4f}")

        # ---- early-stopping bookkeeping ----
        if mean_val < best_loss:
            best_loss = mean_val
            best_state = copy.deepcopy(model.state_dict())
            stale_epochs = 0
            print(f"  -> New best validation loss: {best_loss:.4f}")
        else:
            stale_epochs += 1
            print(f"  -> Patience counter: {stale_epochs}/{patience}")

        if stale_epochs >= patience:
            print(f"Early stopping triggered after epoch {epoch}")
            break

    # Roll back to the best checkpoint, if one was recorded.
    if best_state:
        model.load_state_dict(best_state)

    return model, train_losses, val_losses

def evaluate_enhanced_model(model, test_loader, adj, device):
    """Run ``model`` over ``test_loader`` and score its predictions.

    Returns:
        ``(preds_arr, truths_arr, metrics)`` where the arrays concatenate all
        batches along axis 0 and ``metrics`` holds ``'mae'``, ``'rmse'`` and
        ``'r2'`` computed over every flattened element.
    """
    model.to(device)
    model.eval()
    adj_tensor = torch.tensor(adj, dtype=torch.float32, device=device)

    pred_batches = []
    truth_batches = []

    with torch.no_grad():
        for (X_crime, X_static), y in test_loader:
            X_crime = X_crime.to(device)
            X_static = X_static.to(device)
            y = y.to(device)

            batch_preds, _ = model(X_crime, X_static, adj_tensor)
            pred_batches.append(batch_preds.cpu().numpy())
            truth_batches.append(y.cpu().numpy())

    preds_arr = np.concatenate(pred_batches, axis=0)
    truths_arr = np.concatenate(truth_batches, axis=0)

    # Metrics are computed over all (sample, node) pairs at once
    flat_preds = preds_arr.flatten()
    flat_truths = truths_arr.flatten()

    mse = mean_squared_error(flat_truths, flat_preds)
    metrics = {
        'mae': mean_absolute_error(flat_truths, flat_preds),
        'rmse': np.sqrt(mse),
        'r2': r2_score(flat_truths, flat_preds)
    }

    return preds_arr, truths_arr, metrics

## 11. Data Preparation Function

In [19]:
def prepare_data_for_category(crime_df, region_list, static_features, category, 
                             window_size, train_ratio=0.7, val_ratio=0.15):
    """Build chronological train/validation/test subsets for one crime category.

    Rows of ``crime_df`` whose 'Major Category' equals ``category`` are summed
    to one 'count' per ('date', 'LSOA Code') pair. The sorted distinct dates
    are cut into three consecutive ranges: the first
    ``int(n * train_ratio)`` dates, the next ``int(n * val_ratio)`` dates and
    the remainder. Each dataset sample is assigned to the range containing
    the date looked up through the second element of its index pair.

    Args:
        crime_df: Crime records with 'Major Category', 'date', 'LSOA Code'
            and 'count' columns.
        region_list: Passed through to ``EnhancedCrimeDataset``.
        static_features: Passed through to ``EnhancedCrimeDataset``.
        category: Value of 'Major Category' to keep.
        window_size: Passed through to ``EnhancedCrimeDataset``.
        train_ratio: Fraction of distinct dates used for training.
        val_ratio: Fraction of distinct dates used for validation.

    Returns:
        Tuple ``(train_ds, val_ds, test_ds)`` of ``Subset`` objects over one
        shared ``EnhancedCrimeDataset``.
    """
    # Aggregate the chosen category to one count per (date, LSOA) pair
    in_category = crime_df['Major Category'] == category
    per_region = (
        crime_df[in_category]
        .groupby(['date', 'LSOA Code'])['count']
        .sum()
        .reset_index()
    )

    # Chronological cut points over the distinct dates
    ordered_dates = sorted(per_region['date'].unique())
    num_dates = len(ordered_dates)
    train_end = int(num_dates * train_ratio)
    val_end = train_end + int(num_dates * val_ratio)
    date_buckets = {
        'train': set(ordered_dates[:train_end]),
        'val': set(ordered_dates[train_end:val_end]),
        'test': set(ordered_dates[val_end:]),
    }

    # A single dataset backs all three subsets
    dataset = EnhancedCrimeDataset(per_region, region_list, static_features,
                                   window_size, target_col='count')

    # Route every sample by the date found via the second element of its index pair
    sample_dates = [dataset.dates[pos] for (_, pos) in dataset.indices]
    members = {
        name: [i for i, when in enumerate(sample_dates) if when in bucket]
        for name, bucket in date_buckets.items()
    }

    train_ds = Subset(dataset, members['train'])
    val_ds = Subset(dataset, members['val'])
    test_ds = Subset(dataset, members['test'])

    print(f"  Train size: {len(train_ds)}, Val size: {len(val_ds)}, Test size: {len(test_ds)}")

    return train_ds, val_ds, test_ds

## 12. Run Experiments

In [25]:
# Rank crime categories by total recorded count and keep the three largest
category_counts = crime_df.groupby('Major Category')['count'].sum()
selected_categories = (
    category_counts
    .sort_values(ascending=False)
    .head(3)
    .index
    .tolist()
)
print(f"Selected categories: {selected_categories}")

# Hyperparameters: training schedule
BATCH_SIZE, EPOCHS, PATIENCE = 32, 50, 10
LEARNING_RATE = 0.001
# Hyperparameters: architecture and regularisation
HIDDEN_DIM, LSTM_HIDDEN, NUM_HEADS = 64, 128, 4
DROPOUT = 0.1
LAMBDA_MMD = 0.1

# Per-category outputs, keyed by category name
results = {}

for category in selected_categories:
    print(f"\n{'='*50}")
    print(f"Training Enhanced GCN-LSTM for: {category}")
    print(f"{'='*50}")

    # Chronological train/val/test subsets for this category
    train_ds, val_ds, test_ds = prepare_data_for_category(
        crime_df, region_list, static_features_tensor, category, WINDOW_SIZE
    )

    # Only the training loader shuffles its samples
    train_loader, val_loader, test_loader = (
        DataLoader(split, batch_size=BATCH_SIZE, shuffle=reshuffle)
        for split, reshuffle in ((train_ds, True), (val_ds, False), (test_ds, False))
    )

    # A fresh, untrained network for every category
    model_kwargs = {
        'window_size': WINDOW_SIZE,
        'num_nodes': len(region_list),
        'static_dim': static_feature_matrix.shape[1],
        'hidden_dim': HIDDEN_DIM,
        'lstm_hidden': LSTM_HIDDEN,
        'num_heads': NUM_HEADS,
        'dropout': DROPOUT,
        'lambda_mmd': LAMBDA_MMD,
    }
    model = EnhancedGCN_LSTM(**model_kwargs)

    # Fit, then score on the held-out test split
    trained_model, train_losses, val_losses = train_enhanced_model(
        model, train_loader, val_loader, A_hat,
        EPOCHS, LEARNING_RATE, PATIENCE, DEVICE
    )
    preds, truths, metrics = evaluate_enhanced_model(
        trained_model, test_loader, A_hat, DEVICE
    )

    print(f"\nTest Metrics for {category}:")
    print(f"  MAE: {metrics['mae']:.4f}")
    print(f"  RMSE: {metrics['rmse']:.4f}")
    print(f"  R²: {metrics['r2']:.4f}")

    # Keep everything the later plotting and saving cells read
    results[category] = {
        'model': trained_model,
        'predictions': preds,
        'truth': truths,
        'metrics': metrics,
        'train_losses': train_losses,
        'val_losses': val_losses
    }

Selected categories: ['THEFT', 'VIOLENCE AGAINST THE PERSON', 'VEHICLE OFFENCES']

Training Enhanced GCN-LSTM for: THEFT
  Train size: 122, Val size: 27, Test size: 28
  Train size: 122, Val size: 27, Test size: 28


: 

## 13. Visualization

In [None]:
# Draw one loss-history panel per trained category.
# squeeze=False always returns a 2-D axes grid, so the single-category case
# needs no special handling; row 0 is the only row.
fig, axes = plt.subplots(1, len(selected_categories), figsize=(15, 5), squeeze=False)
axes = axes[0]

for panel_idx, category in enumerate(results):
    history = results[category]
    panel = axes[panel_idx]
    for series_key, series_label in (('train_losses', 'Train Loss'), ('val_losses', 'Val Loss')):
        panel.plot(history[series_key], label=series_label)
    panel.set_title(f'Training Curves - {category}')
    panel.set_xlabel('Epoch')
    panel.set_ylabel('Loss')
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 14. Feature Importance Analysis

In [None]:
def analyze_feature_importance(model, static_features, feature_names, num_samples=100):
    """Rank static features by the magnitude of the static-embedding weights.

    A feature's importance is the mean absolute weight attached to it in
    ``model.static_embedding.weight``, normalised so the scores sum to 1.
    No attention weights are inspected.

    Args:
        model: Module exposing ``static_embedding.weight`` as a 2-D tensor.
        static_features: Unused; kept so existing callers keep working.
        feature_names: One name per static feature.
        num_samples: Unused; kept so existing callers keep working.

    Returns:
        DataFrame with columns 'feature' and 'importance', sorted by
        descending importance.

    Raises:
        ValueError: If the weight is not 2-D or neither of its axes has
            ``len(feature_names)`` entries.
    """
    model.eval()

    weight_magnitudes = np.abs(model.static_embedding.weight.detach().cpu().numpy())
    num_features = len(feature_names)

    if weight_magnitudes.ndim != 2:
        raise ValueError(
            f"static_embedding.weight must be 2-D, got shape {weight_magnitudes.shape}"
        )

    # NOTE(review): assumes static_embedding is an nn.Linear-style layer whose
    # weight is laid out (out_features, in_features) — its definition is not
    # visible from here, confirm against the model class. Under that layout the
    # per-feature score must average over axis 0; averaging over axis 1 yields
    # one value per hidden unit and cannot be paired with feature_names.
    if weight_magnitudes.shape[1] == num_features:
        feature_importance = weight_magnitudes.mean(axis=0)
    elif weight_magnitudes.shape[0] == num_features:
        # Fallback for a (num_features, embedding_dim) layout
        feature_importance = weight_magnitudes.mean(axis=1)
    else:
        raise ValueError(
            f"static_embedding.weight has shape {weight_magnitudes.shape}, "
            f"which does not match {num_features} feature names"
        )

    # Normalise to a distribution; skip when all weights are zero to avoid NaN
    total = feature_importance.sum()
    if total > 0:
        feature_importance = feature_importance / total

    importance_df = pd.DataFrame({
        'feature': feature_names,
        'importance': feature_importance
    }).sort_values('importance', ascending=False)

    return importance_df

# Score the static features using the model trained for the first category
first_category = selected_categories[0]
first_model = results[first_category]['model']
importance_df = analyze_feature_importance(first_model, static_features_tensor, feature_names)

# Horizontal bar chart of the 20 highest-scoring features
top_features = importance_df.head(20)
bar_positions = range(len(top_features))

plt.figure(figsize=(10, 8))
plt.barh(bar_positions, top_features['importance'])
plt.yticks(bar_positions, top_features['feature'])
plt.xlabel('Importance Score')
plt.title(f'Top 20 Most Important Static Features - {first_category}')
plt.tight_layout()
plt.show()

# Text listing of the ten highest-scoring features
print("\nTop 10 Most Important Features:")
print(importance_df.head(10))

## 15. Spatial Prediction Visualization

In [None]:
def visualize_spatial_predictions(predictions, truth, region_list, london_gdf, 
                                  time_idx=0, category=""):
    """Draw side-by-side maps of predictions, ground truth and absolute error.

    Args:
        predictions: Indexable by ``time_idx``; the selected entry supplies
            one value per element of ``region_list``.
        truth: Same layout as ``predictions``.
        region_list: Region codes, joined to ``london_gdf`` on 'LSOA11CD'.
        london_gdf: Geo frame with an 'LSOA11CD' column; it is merged with
            each value table (inner join) and plotted.
        time_idx: Which entry of ``predictions`` / ``truth`` to map.
        category: Label appended to each panel title.
    """
    fig, panels = plt.subplots(1, 3, figsize=(18, 6))

    predicted = predictions[time_idx]
    observed = truth[time_idx]

    # One entry per panel: (column name, values, colormap, colorbar label, title)
    layers = [
        ('prediction', predicted, 'YlOrRd',
         'Predicted Crime Count', f'Predictions - {category}'),
        ('truth', observed, 'YlOrRd',
         'Actual Crime Count', f'Ground Truth - {category}'),
        ('error', np.abs(predicted - observed), 'Reds',
         'Absolute Error', f'Prediction Error - {category}'),
    ]

    for panel, (column, values, colormap, bar_label, title) in zip(panels, layers):
        # Attach the values to the geometries through the shared region code
        value_table = pd.DataFrame({'LSOA11CD': region_list, column: values})
        geo_values = london_gdf.merge(value_table, on='LSOA11CD', how='inner')

        geo_values.plot(column=column, cmap=colormap, legend=True,
                        ax=panel, legend_kwds={'label': bar_label})
        panel.set_title(title)
        panel.axis('off')

    plt.tight_layout()
    plt.show()

# Map entry 0 of the stored test predictions for the first selected category
category = selected_categories[0]
result = results[category]
visualize_spatial_predictions(
    predictions=result['predictions'],
    truth=result['truth'],
    region_list=region_list,
    london_gdf=london_gdf,
    time_idx=0,
    category=category,
)

## 16. Model Comparison

In [None]:
# One summary row of test metrics per trained category
comparison_data = [
    {
        'Category': category,
        'MAE': result['metrics']['mae'],
        'RMSE': result['metrics']['rmse'],
        'R²': result['metrics']['r2'],
    }
    for category, result in results.items()
]

comparison_df = pd.DataFrame(comparison_data)
print("\nModel Performance Summary:")
print(comparison_df)

# One bar chart per metric, with categories along the x axis
metrics = ['MAE', 'RMSE', 'R²']
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

for ax, metric in zip(axes, metrics):
    comparison_df.plot(x='Category', y=metric, kind='bar', ax=ax, legend=False)
    ax.set_title(f'{metric} by Crime Category')
    ax.set_ylabel(metric)
    ax.set_xlabel('')
    ax.set_xticklabels(comparison_df['Category'], rotation=45, ha='right')

plt.tight_layout()
plt.show()

## 17. Save Results

In [None]:
# Save trained models and results
save_dir = Path('./enhanced_model_results')
# parents=True keeps this working if save_dir is later changed to a nested path
save_dir.mkdir(parents=True, exist_ok=True)

for category, result in results.items():
    # One checkpoint per category; spaces in the name become underscores
    model_path = save_dir / f"enhanced_gcn_lstm_{category.replace(' ', '_')}.pt"
    torch.save({
        'model_state_dict': result['model'].state_dict(),
        'metrics': result['metrics'],
        # Every constructor argument used when the model was built, so the
        # network can be re-instantiated from the checkpoint alone.
        'config': {
            'window_size': WINDOW_SIZE,
            'num_nodes': len(region_list),
            'static_dim': static_feature_matrix.shape[1],
            'hidden_dim': HIDDEN_DIM,
            'lstm_hidden': LSTM_HIDDEN,
            'num_heads': NUM_HEADS,
            'dropout': DROPOUT,
            'lambda_mmd': LAMBDA_MMD
        }
    }, model_path)
    print(f"Saved model for {category} to {model_path}")

# Save feature importance (computed in section 14 for the first category only)
importance_path = save_dir / "feature_importance.csv"
importance_df.to_csv(importance_path, index=False)
print(f"\nSaved feature importance to {importance_path}")

print("\nAll results saved successfully!")

## Summary

This enhanced notebook successfully integrates static socioeconomic features into the GCN-LSTM architecture through:

1. **Multi-Head Attention**: Cross-modal attention mechanism for static-temporal feature interaction
2. **Feature-Aware Graph Convolution**: Static features modulate spatial relationships
3. **Comprehensive Feature Set**: Incorporates demographic, economic, and infrastructure features
4. **Improved Performance**: Enhanced predictions through contextual information

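The model demonstrates how external static features can be incorporated into spatio-temporal crime prediction. Whether they significantly improve accuracy should be confirmed by comparing the test metrics above against the original GCN-LSTM baseline, which this notebook does not train.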