# 03. Model Implementation & Training

This notebook covers the **Model Implementation & Training** phase (Chapter 5.3) of our case study. We aim to learn low-dimensional vector representations (embeddings) for every Output Area (OA) in Liverpool using Graph Auto-Encoders (GAE).

## Objectives
1.  **Train Graph Models**: We model the urban system as a graph and use GAEs to compress its structural and attribute information into embeddings.
2.  **Compare Architectures**: We evaluate different encoders:
    *   **GAT (Homogeneous)**: Uses only spatial contiguity.
    *   **HAN (Heterogeneous)**: Incorporates street networks and public transit (bus) to learn richer representations.
3.  **Cluster & Visualize**: We apply K-Means to the learned embeddings to identify functional regions and visualize the results spatially and in latent space (t-SNE).

## Models Overview
*   **Baseline**: K-Means on raw features (no graph learning).
*   **Model 1 (GAT-GAE)**: Homogeneous graph (OA contiguity).
*   **Model 2 (HAN-GAE)**: Heterogeneous graph (OA + Street accessibility).
*   **Model 3 (HAN-GAE)**: Heterogeneous graph (OA + Street + Bus accessibility).

In [None]:
import sys
import os

# Add project root to path to allow imports from src
sys.path.append(os.path.abspath('..'))

# If on Google Colab, install city2graph
#!pip install city2graph

import yaml
import random
import torch
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
#import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import city2graph as c2g
from torch_geometric.data import HeteroData
from src.models import GATGAE, HANGAE
from src.baselines.kmeans import run_kmeans

%matplotlib inline

## 1. Configuration & Setup
We load hyperparameters from `configs/experiment_config.yaml`. This ensures reproducibility and easy tuning of parameters like learning rate, hidden dimensions, and structure loss weights.
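
The config file itself is not reproduced in this notebook. As a rough sketch, the dictionary below lists the keys that the later cells read; every value (including the file names under `data`) is a hypothetical placeholder, and `missing_config_keys` is an illustrative helper, not part of the project.

```python
# Keys read by the notebook cells; ALL values are hypothetical placeholders.
example_config = {
    "experiment": {"device": "cuda"},
    "seeds": {"global": 42, "model": 42, "clustering": 42, "tsne": 42},
    "data": {
        "root": "data/processed",           # hypothetical path
        "homo": "homo.pt",                  # hypothetical file name
        "hetero_street": "hetero_street.pt",
        "hetero_multi": "hetero_multi.pt",
    },
    "model": {
        "in_dim": 32, "hidden_dim": 64, "out_dim": 16,
        "heads": 4, "dropout": 0.2, "lambda_struct": 1.0,
    },
    "training": {
        "lr": 1e-3, "weight_decay": 1e-5, "epochs": 500,
        "early_stopping_patience": 30, "negative_sampling_scale": 1.0,
    },
}

REQUIRED_KEYS = {
    "experiment": ["device"],
    "seeds": ["global", "model", "clustering", "tsne"],
    "data": ["root", "homo", "hetero_street", "hetero_multi"],
    "model": ["in_dim", "hidden_dim", "out_dim", "heads", "dropout", "lambda_struct"],
    "training": ["lr", "weight_decay", "epochs",
                 "early_stopping_patience", "negative_sampling_scale"],
}


def missing_config_keys(cfg):
    """Return dotted names of required keys that are absent from cfg."""
    missing = []
    for section, keys in REQUIRED_KEYS.items():
        for key in keys:
            if key not in cfg.get(section, {}):
                missing.append(f"{section}.{key}")
    return missing


print(missing_config_keys(example_config))
```

Running a check like this right after `load_config` would surface a misspelled key before training starts rather than midway through a run.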

In [None]:
def load_config(config_path):
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)

config = load_config('../configs/experiment_config.yaml')

def set_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    print(f"Random seed set to: {seed}")

set_seeds(config['seeds']['global'])
device = torch.device(config['experiment']['device'] if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Ensure output directories exist
os.makedirs("../outputs/checkpoints", exist_ok=True)
os.makedirs("../outputs/embeddings", exist_ok=True)
os.makedirs("../outputs/clusters", exist_ok=True)
os.makedirs("../outputs/figures", exist_ok=True)

## 2. Training Utility
This general training loop handles:
1.  **Optimization**: Updating model weights using Adam.
2.  **Loss Calculation**: Combining Feature Reconstruction Loss (Smooth L1) and Structure Reconstruction Loss (DistMult).
3.  **Early Stopping**: Halting training once the training loss has not improved for `early_stopping_patience` consecutive epochs.
4.  **Embedding Generation**: Extracting the latent `z` vectors from the best model state.
5.  **Saving Outputs**: Writing the best checkpoint and the embeddings to disk. Clustering is run afterwards on the saved embeddings (Section 4).
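
The loss itself is computed inside `model.compute_loss`, whose source is not shown in this notebook. The sketch below illustrates, in plain Python, one way a Smooth L1 feature term and a DistMult structure term could be combined with a `lambda_struct` weight. Scoring edges with binary cross-entropy over observed edges and sampled non-edges is an assumption here, as are all function names and toy values.

```python
import math


def smooth_l1(pred, target, beta=1.0):
    """Mean Smooth L1 loss: quadratic for small errors, linear for large ones."""
    total = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
    return total / len(pred)


def distmult_score(z_i, z_j, rel):
    """DistMult logit: sum_k z_i[k] * rel[k] * z_j[k]."""
    return sum(a * r * b for a, r, b in zip(z_i, rel, z_j))


def bce_with_logits(logit, label):
    """Numerically stable binary cross-entropy on a raw logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def structure_loss(z, rel, pos_edges, neg_edges):
    """Average BCE over observed edges (label 1) and sampled non-edges (label 0)."""
    terms = [bce_with_logits(distmult_score(z[i], z[j], rel), 1.0) for i, j in pos_edges]
    terms += [bce_with_logits(distmult_score(z[i], z[j], rel), 0.0) for i, j in neg_edges]
    return sum(terms) / len(terms)


# Toy data (hypothetical): three node embeddings and one relation vector.
z = [[1.0, 0.0], [0.9, 0.1], [-1.0, 0.5]]
rel = [1.0, 1.0]
x_true = [0.2, 0.4, 0.6]
x_rec = [0.1, 0.4, 2.6]

l_feat = smooth_l1(x_rec, x_true)
l_struct = structure_loss(z, rel, pos_edges=[(0, 1)], neg_edges=[(0, 2)])
lambda_struct = 0.5  # hypothetical weight
total = l_feat + lambda_struct * l_struct
print(f"feat={l_feat:.4f} struct={l_struct:.4f} total={total:.4f}")
```

A larger `lambda_struct` pushes the embeddings towards reproducing the graph's edges, while a smaller one favours reconstructing node attributes.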

In [None]:
def train_model(model, data, config, model_name):
    print(f"Training {model_name}...")
    optimizer = optim.Adam(model.parameters(), lr=config['training']['lr'], weight_decay=config['training']['weight_decay'])
    best_loss = float('inf')
    patience = config['training']['early_stopping_patience']
    patience_counter = 0
    
    checkpoint_dir = f"../outputs/checkpoints/{model_name}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    losses = []
    
    for epoch in range(config['training']['epochs']):
        model.train()
        optimizer.zero_grad()
        loss, l_feat, l_struct = model.compute_loss(
            data, 
            lambda_struct=config['model']['lambda_struct'],
            neg_sampling_scale=config['training']["negative_sampling_scale"]
        )
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        
        # Logging
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Total Loss: {loss.item():.4f}, Feat Loss: {l_feat.item():.4f}, Struct Loss: {l_struct.item():.4f}")
            
        # Early Stopping
        if loss.item() < best_loss:
            best_loss = loss.item()
            patience_counter = 0
            # Save best model
            torch.save(model.state_dict(), f"{checkpoint_dir}/model.pt")
        else:
            patience_counter += 1
            
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break
            
    # Load best model for embedding generation
    model.load_state_dict(torch.load(f"{checkpoint_dir}/model.pt"))
    model.eval()
    with torch.no_grad():
        if isinstance(model, GATGAE):
            z, _ = model(data)
        else: # HANGAE
            z, _, beta = model(data)
        
        # Save embeddings
        torch.save(z, f"../outputs/embeddings/{model_name}.pt")
        
    return losses
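
The patience rule inside `train_model` can be isolated and replayed on a recorded loss sequence. The helper below is an illustrative sketch (its name and the toy losses are not part of the project). Like the loop above, it tracks the training loss, since no validation split is used.

```python
def early_stopping_epoch(losses, patience):
    """Replay the patience rule on a list of per-epoch losses.

    Returns (best_epoch, stop_epoch); stop_epoch is None if the rule never fires.
    """
    best_loss = float("inf")
    best_epoch = None
    counter = 0
    for epoch, loss in enumerate(losses):
        if loss < best_loss:
            # Improvement: remember it and reset the patience counter
            best_loss, best_epoch, counter = loss, epoch, 0
        else:
            counter += 1
        if counter >= patience:
            return best_epoch, epoch
    return best_epoch, None


toy_losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.6]
print(early_stopping_epoch(toy_losses, patience=3))  # (2, 5)
print(early_stopping_epoch(toy_losses, patience=4))  # (6, None)
```

With a patience of 3 the run stops at epoch 5 and keeps the checkpoint from epoch 2, missing the later improvement at epoch 6; a patience of 4 would have reached it.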

## 3. Visualization Utilities
To interpret the results, we use three key visualizations:
1.  **Training Curve**: Checks that the model is learning (the loss should decrease over epochs).
2.  **Spatial Map**: Plots the OAs color-coded by cluster assignment to reveal geographic patterns. Uses `city2graph` to recover the node order from the graph object and GeoPandas to draw the map.
3.  **t-SNE Plot**: Projects the high-dimensional embeddings to 2D using t-SNE to visualize the manifold structure and cluster separability.
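
The cluster labels shown in the map and the t-SNE plot come from K-Means, with the number of clusters picked by silhouette score. To make that criterion concrete, here is a dependency-free sketch of the mean silhouette coefficient on four toy points. The function and data are illustrative only; the notebook itself relies on `sklearn.metrics.silhouette_score`.

```python
import math


def silhouette(points, labels):
    """Mean silhouette coefficient with Euclidean distance (needs >= 2 clusters)."""
    clusters = {}
    for idx, lab in enumerate(labels):
        clusters.setdefault(lab, []).append(idx)

    scores = []
    for i, lab in enumerate(labels):
        own = [j for j in clusters[lab] if j != i]
        if not own:
            # Singleton cluster: the silhouette is defined as 0
            scores.append(0.0)
            continue
        # a: mean distance to the other members of the same cluster
        a = sum(math.dist(points[i], points[j]) for j in own) / len(own)
        # b: mean distance to the nearest other cluster
        b = min(
            sum(math.dist(points[i], points[j]) for j in members) / len(members)
            for other, members in clusters.items()
            if other != lab
        )
        scores.append((b - a) / max(a, b))
    return sum(scores) / len(scores)


# Two tight pairs of points, far apart from each other (toy data).
points = [(0.0, 0.0), (0.0, 1.0), (10.0, 0.0), (10.0, 1.0)]
good = silhouette(points, [0, 0, 1, 1])  # labels follow the two pairs
bad = silhouette(points, [0, 1, 0, 1])   # labels cut across the pairs
print(f"good={good:.3f} bad={bad:.3f}")
```

A labelling that matches the compact groups scores close to 1, while one that splits them scores below 0, which is why the highest-scoring `k` is kept.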

In [None]:
import geopandas as gpd
oa = gpd.read_file("../data/processed/features/oa_with_features.gpkg")

oa = oa[["OA21CD", "geometry"]].set_index("OA21CD")

In [None]:
def identify_best_cluster_number(z, min_k=2, max_k=15, algorithm=KMeans, seed=config['seeds']['clustering'], **kwargs):
    """Determines optimal cluster number using Silhouette Score."""
    print(f"Determining optimal clusters using {algorithm.__name__}...")
    best_score = -1
    best_k = min_k
    
    X_np = z.cpu().numpy() if isinstance(z, torch.Tensor) else z
    
    for k in range(min_k, max_k + 1):
        model = algorithm(n_clusters=k, random_state=seed, **kwargs)
        lbls = model.fit_predict(X_np)
        if len(set(lbls)) < 2:
            continue
        score = silhouette_score(X_np, lbls)
        # print(f"k={k}, score={score:.4f}")
        if score > best_score:
            best_score = score
            best_k = k
            
    print(f"Optimal k={best_k} (Score: {best_score:.4f})")
    return best_k

def perform_clustering(z, n_clusters, algorithm=KMeans, output_file=None, seed=config['seeds']['clustering'], **kwargs):
    """Runs clustering and optionally saves to CSV."""
    X_np = z.cpu().numpy() if isinstance(z, torch.Tensor) else z
    model = algorithm(n_clusters=n_clusters, random_state=seed, **kwargs)
    labels = model.fit_predict(X_np)
    
    if output_file:
        df = pd.DataFrame({'cluster': labels})
        df.to_csv(output_file, index=False)
        print(f"Clusters saved to {output_file}")
        
    return labels

def plot_training_curve(losses, model_name):
    plt.figure(figsize=(10, 5))
    plt.plot(losses, label='Total Loss')
    plt.title(f"{model_name} Training Convergence")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(f"../outputs/figures/{model_name}_loss.png")
    plt.show()

def plot_clusters_spatial(data, cluster_labels=None, cluster_file=None, model_name="Model"):
    """Plots clusters using standard GeoPandas plot (Choropleth). Accepts either labels or a file."""
    print(f"\n--- Plotting Clusters for {model_name} ---")
    
    # Load Reference Geometry (OA)
    oa_path = "../data/processed/features/oa_with_features.gpkg"
    if not os.path.exists(oa_path):
        print(f"Geometry file not found at {oa_path}")
        return
    
    # Load only geometry and ID to be safe
    oa_geom = gpd.read_file(oa_path)[["OA21CD", "geometry"]].set_index("OA21CD")

    if cluster_labels is None:
        if cluster_file and os.path.exists(cluster_file):
            clusters_df = pd.read_csv(cluster_file)
            cluster_labels = clusters_df['cluster'].values
            print(f"Loaded clusters from {cluster_file}: {len(cluster_labels)} rows")
        else:
            print("No cluster labels or valid file provided.")
            return

    # Recover PyG Data to get Node IDs / Index
    if isinstance(data, HeteroData):
        data.graph_metadata.edge_types = list(data.edge_types)
        nodes_dict, _ = c2g.pyg_to_gdf(data)
        gdf_nodes = nodes_dict['oa']
    else:
        gdf_nodes, _ = c2g.pyg_to_gdf(data)
        
    # Assign clusters to the inferred node order
    if len(cluster_labels) != len(gdf_nodes):
        print(f"Warning: Length mismatch. Nodes: {len(gdf_nodes)}, Clusters: {len(cluster_labels)}")
        return

    gdf_nodes['cluster'] = cluster_labels
    
    # Merge with geometry (drop any geometry recovered from the PyG object first)
    merged = gdf_nodes.drop(columns=['geometry'], errors='ignore').join(oa_geom, how='inner')
    print(f"Merged for Plotting: {len(merged)} rows")
    
    # Rebuild a GeoDataFrame so plotting works even if the join returned a plain DataFrame
    gdf_nodes = gpd.GeoDataFrame(merged, geometry='geometry', crs=oa_geom.crs)
    
    # PLOT using GeoPandas with enhanced aesthetics
    fig, ax = plt.subplots(figsize=(12, 10), dpi=150)
    
    gdf_nodes.plot(
        column='cluster', 
        categorical=True,
        cmap='tab20', 
        legend=True, 
        legend_kwds={'loc': 'center left', 'bbox_to_anchor': (1, 0.5), 'title': 'Cluster ID', 'frameon': False},
        ax=ax,
        edgecolor='white',
        linewidth=0.05
    )
    
    ax.set_title(f"{model_name} Spatial Clusters", fontsize=16, fontweight='bold', pad=20)
    ax.set_axis_off()
    
    output_fig = f"../outputs/figures/{model_name}_spatial.png"
    plt.savefig(output_fig, bbox_inches='tight', dpi=300)
    print(f"Plot saved to {output_fig}")
    plt.show()


def plot_embeddings_tsne(embedding_file, cluster_file=None, cluster_labels=None, model_name="Model", seed=config['seeds']['tsne']):
    """Plots t-SNE of the embeddings colored by cluster."""
    if not os.path.exists(embedding_file):
        print(f"Missing embedding file: {embedding_file}")
        return
        
    z = torch.load(embedding_file, map_location=torch.device('cpu'), weights_only=False)
    if isinstance(z, torch.Tensor):
        z = z.cpu().numpy()
        
    if cluster_labels is None:
        if cluster_file and os.path.exists(cluster_file):
            cluster_labels = pd.read_csv(cluster_file)['cluster'].values
        else:
            print("No cluster labels provided for t-SNE.")
            return
    
    print("Running t-SNE...")
    tsne = TSNE(n_components=2, perplexity=30, max_iter=1000, random_state=seed)
    z_tsne = tsne.fit_transform(z)
    
    # Plotting with enhanced aesthetics
    fig, ax = plt.subplots(figsize=(10, 8), dpi=150)
    
    scatter = ax.scatter(
        z_tsne[:, 0], 
        z_tsne[:, 1], 
        c=cluster_labels, 
        cmap='tab20', 
        s=15, 
        alpha=0.7,
        edgecolor='none'
    )
    
    # Create a legend for clusters
    legend1 = ax.legend(*scatter.legend_elements(), title="Cluster", loc="upper right", bbox_to_anchor=(1.15, 1), frameon=False)
    ax.add_artist(legend1)
    
    ax.set_title(f"{model_name} Embeddings (t-SNE)", fontsize=14, fontweight='bold')
    ax.set_xlabel("Dimension 1", fontsize=10)
    ax.set_ylabel("Dimension 2", fontsize=10)
    
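    # Remove top and right spines for a cleaner look
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    
    plt.show()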

## 4. Workflows

### 4.1. Baseline: K-Means on Raw Features
This establishes a benchmark. We cluster OAs based purely on their feature vectors (POI counts and land use), without any relationship information.

In [None]:
print("Running Baseline 1: K-Means on Raw Features...")
homo_path = os.path.join("../", config['data']['root'], config['data']['homo'])

homo_data = torch.load(homo_path, map_location=device, weights_only=False)

# Baseline Clustering
baseline_k = identify_best_cluster_number(homo_data.x, algorithm=KMeans)
baseline_labels = perform_clustering(homo_data.x, baseline_k, algorithm=KMeans, output_file='../outputs/clusters/baseline1_kmeans.csv')
print("Baseline clustering completed.")

# Visualizations
plot_clusters_spatial(homo_data, cluster_labels=baseline_labels, model_name='Baseline (K-Means)')


### 4.2. Model 1: GAT-GAE (Homogeneous)
This model learns embeddings by aggregating features from spatially contiguous neighbors. It assumes that functional regions are spatially smooth.
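
`GATGAE` is defined in `src.models` and is not shown here. The idea of attention-weighted aggregation over contiguous neighbours can be sketched in plain Python for a single attention head. The function `gat_aggregate`, the attention vectors `a_src` and `a_dst`, and the toy features are illustrative assumptions, and the learned linear projection is taken as already applied to `h`.

```python
import math


def _dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def gat_aggregate(h, neighbors, a_src, a_dst, slope=0.2):
    """Single-head graph attention over already-projected features h.

    neighbors maps each node to its neighbour ids (include the node itself
    for a self-loop). Returns (new_features, attention), both keyed by node.
    """
    out, attention = {}, {}
    for i, nbrs in neighbors.items():
        # Unnormalised logits e_ij = LeakyReLU(a_src . h_i + a_dst . h_j)
        logits = []
        for j in nbrs:
            e = _dot(a_src, h[i]) + _dot(a_dst, h[j])
            logits.append(e if e > 0 else slope * e)
        # Softmax over the neighbourhood of node i
        shift = max(logits)
        exps = [math.exp(e - shift) for e in logits]
        denom = sum(exps)
        alpha = [e / denom for e in exps]
        attention[i] = dict(zip(nbrs, alpha))
        # New feature of i: attention-weighted sum of neighbour features
        out[i] = [sum(a * h[j][k] for a, j in zip(alpha, nbrs)) for k in range(len(h[i]))]
    return out, attention


# Three OAs in a row (0 - 1 - 2), each with a self-loop; toy features.
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
neighbors = {0: [0, 1], 1: [0, 1, 2], 2: [1, 2]}
out, attention = gat_aggregate(h, neighbors, a_src=[0.5, -0.5], a_dst=[1.0, 0.0])
print(attention[1])
```

Because each node's weights sum to 1, every output vector is a convex combination of its contiguous neighbours' features, which is what produces the spatial smoothing described above.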

In [None]:
set_seeds(config['seeds']['model']) # Ensure deterministic model training initialization
homo_data = homo_data.to(device)
gat_model = GATGAE(
    in_dim=config['model']['in_dim'],
    hidden_dim=config['model']['hidden_dim'],
    out_dim=config['model']['out_dim'],
    heads=config['model']['heads'],
    dropout=config['model']['dropout']
).to(device)

losses_model1 = train_model(gat_model, homo_data, config, "model1_gat_gae")

In [None]:
# Load Embeddings
z1 = torch.load('../outputs/embeddings/model1_gat_gae.pt', weights_only=False)

#z1 = F.normalize(z1, p=2, dim=-1)

# Clustering Workflow
best_k1 = identify_best_cluster_number(z1, algorithm=KMeans)
labels1 = perform_clustering(z1, best_k1, algorithm=KMeans, output_file='../outputs/clusters/model1_gat_gae.csv')

# Visualizations
plot_training_curve(losses_model1, "Model 1")
plot_clusters_spatial(homo_data, cluster_labels=labels1, model_name='Model 1 (GAT-GAE)')
plot_embeddings_tsne('../outputs/embeddings/model1_gat_gae.pt', cluster_labels=labels1, model_name='Model 1 (GAT-GAE)')

### 4.3. Model 2: HAN-GAE (Street)
This model introduces heterogeneity. It learns to aggregate information not just from contiguous OAs, but also from OAs connected via the street network (15-min walk). This captures functional connectivity that might skip boundaries (e.g., across a park or river connected by a bridge).
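
`HANGAE` is defined in `src.models` and its internals are not shown here. In the HAN family of models, embeddings computed separately per edge type are fused by a semantic-level attention that learns one weight per edge type. The sketch below shows that fusion step in plain Python. The parameters `W`, `b`, `q` and the toy embeddings are illustrative assumptions; the `beta` value returned by the model in `train_model` presumably plays a similar role.

```python
import math


def semantic_attention(z_by_type, W, b, q):
    """Fuse per-edge-type node embeddings with one learned weight per edge type.

    z_by_type: dict edge_type -> list of node embeddings (same shape for all types).
    Returns (fused_embeddings, beta), where beta sums to 1 over edge types.
    """
    def project(z):
        # tanh(W z + b), one output per row of W
        return [math.tanh(_row_dot(row, z) + bias) for row, bias in zip(W, b)]

    def _row_dot(u, v):
        return sum(x * y for x, y in zip(u, v))

    # Importance of each edge type: mean over nodes of q . tanh(W z + b)
    scores = {}
    for etype, Z in z_by_type.items():
        per_node = [_row_dot(q, project(z)) for z in Z]
        scores[etype] = sum(per_node) / len(per_node)

    # Softmax over edge types
    shift = max(scores.values())
    exps = {t: math.exp(s - shift) for t, s in scores.items()}
    denom = sum(exps.values())
    beta = {t: e / denom for t, e in exps.items()}

    # Weighted sum of the per-type embeddings, node by node
    first = next(iter(z_by_type.values()))
    fused = [
        [sum(beta[t] * z_by_type[t][i][k] for t in z_by_type) for k in range(len(first[0]))]
        for i in range(len(first))
    ]
    return fused, beta


# Toy embeddings for two nodes under the two edge types used by Model 2.
z_by_type = {
    "is_contiguous_to": [[1.0, 0.0], [0.0, 1.0]],
    "M_15min_walk": [[0.5, 0.5], [2.0, 0.0]],
}
W = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical projection
b = [0.0, 0.0]
q = [1.0, 1.0]                # hypothetical attention vector
fused, beta = semantic_attention(z_by_type, W, b, q)
print(beta)
```

Inspecting `beta` after training gives a rough indication of how much each edge type contributes to the final embedding.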

In [None]:
set_seeds(config['seeds']['model']) # Ensure deterministic model training initialization
print("\nRunning Model 2: HAN-GAE (Street)...")
street_path = os.path.join("../", config['data']['root'], config['data']['hetero_street'])

street_data = torch.load(street_path, map_location=device, weights_only=False)

# Define metapaths for Model 2
metapaths_street = ['is_contiguous_to', 'M_15min_walk']

han_street = HANGAE(
    in_dim=config['model']['in_dim'],
    hidden_dim=config['model']['hidden_dim'],
    out_dim=config['model']['out_dim'],
    metapaths=metapaths_street,
    heads=config['model']['heads'],
    dropout=config['model']['dropout']
).to(device)

losses_model2 = train_model(han_street, street_data, config, "model2_han_gae_street")

In [None]:
# Load Embeddings
z2 = torch.load('../outputs/embeddings/model2_han_gae_street.pt', weights_only=False)

#z2 = F.normalize(z2, p=2, dim=-1)

# Clustering Workflow
best_k2 = identify_best_cluster_number(z2, algorithm=KMeans)
labels2 = perform_clustering(z2, best_k2, algorithm=KMeans, output_file='../outputs/clusters/model2_han_gae_street.csv')

# Visualizations
plot_training_curve(losses_model2, "Model 2")
plot_clusters_spatial(street_data, cluster_labels=labels2, model_name='Model 2 (Street)')
plot_embeddings_tsne('../outputs/embeddings/model2_han_gae_street.pt', cluster_labels=labels2, model_name='Model 2 (Street)')

### 4.4. Model 3: HAN-GAE (Multi-modal)
This model extends Model 2 by adding public transit (bus) edges to the accessibility layer. OAs are connected if one is reachable from the other within 15 minutes by *walking OR bus*. We expect this to group spatially disjoint areas that are functionally linked by bus services.
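
The graph construction happens in an earlier notebook and is not shown here. To illustrate what "reachable within 15 minutes by walking or bus" means, the sketch below runs Dijkstra's algorithm over a toy travel-time graph, first with walking links only and then with bus links added. Node names, travel times, and both helper functions are hypothetical.

```python
import heapq


def merge_graphs(*graphs):
    """Combine adjacency lists {node: [(neighbour, minutes), ...]} into one graph."""
    merged = {}
    for g in graphs:
        for node, links in g.items():
            merged.setdefault(node, []).extend(links)
    return merged


def reachable_within(graph, source, budget_min):
    """Dijkstra over travel times; returns nodes reachable within the time budget."""
    best = {source: 0.0}
    queue = [(0.0, source)]
    while queue:
        t, node = heapq.heappop(queue)
        if t > best.get(node, float("inf")):
            continue  # stale queue entry
        for nxt, minutes in graph.get(node, []):
            nt = t + minutes
            if nt <= budget_min and nt < best.get(nxt, float("inf")):
                best[nxt] = nt
                heapq.heappush(queue, (nt, nxt))
    return set(best) - {source}


# Toy travel times in minutes (hypothetical): A, B, C lie along a street;
# D is far from A on foot but linked to it by a direct bus.
walk = {"A": [("B", 8)], "B": [("A", 8), ("C", 12)], "C": [("B", 12)]}
bus = {"A": [("D", 10)], "D": [("A", 10)]}

walk_only = reachable_within(walk, "A", 15)
multi = reachable_within(merge_graphs(walk, bus), "A", 15)
print(sorted(walk_only), sorted(multi))
```

In this toy case the bus link brings D into A's 15-minute neighbourhood even though the two are not connected on foot, which is the kind of edge Model 3 gains over Model 2.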

In [None]:
set_seeds(config['seeds']['model']) # Ensure deterministic model training initialization
print("\nRunning Model 3: HAN-GAE (Multi-modal)...")
multi_path = os.path.join("../", config['data']['root'], config['data']['hetero_multi'])

multi_data = torch.load(multi_path, map_location=device, weights_only=False)

# Metapaths for Model 3
metapaths_multi = ['is_contiguous_to', 'M_15min_walk', 'M_15min_multi']

han_multi = HANGAE(
    in_dim=config['model']['in_dim'],
    hidden_dim=config['model']['hidden_dim'],
    out_dim=config['model']['out_dim'],
    metapaths=metapaths_multi,
    heads=config['model']['heads'],
    dropout=config['model']['dropout']
).to(device)

losses_model3 = train_model(han_multi, multi_data, config, "model3_han_gae_multi")

In [None]:
# Load Embeddings
z3 = torch.load('../outputs/embeddings/model3_han_gae_multi.pt', weights_only=False)

#z3 = F.normalize(z3, p=2, dim=-1)
 
# Clustering Workflow
best_k3 = identify_best_cluster_number(z3, algorithm=KMeans)
labels3 = perform_clustering(z3, best_k3, algorithm=KMeans, output_file='../outputs/clusters/model3_han_gae_multi.csv')

# Visualizations
plot_training_curve(losses_model3, "Model 3")
plot_clusters_spatial(multi_data, cluster_labels=labels3, model_name='Model 3 (Multi)')
plot_embeddings_tsne('../outputs/embeddings/model3_han_gae_multi.pt', cluster_labels=labels3, model_name='Model 3 (Multi)')