# OpenSERGE Inference

Load a trained checkpoint and perform road graph extraction on test images.

In [None]:
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import cv2

# Add parent directory to path
sys.path.insert(0, '..')

from openserge.models.wrapper import OpenSERGE
from openserge.data.dataset import CityScale
from openserge.utils.training import load_checkpoint

# Set device
# The previous expression returned 'cpu' from both branches of its conditional,
# so the accelerator probe never influenced the outcome; the CPU is now
# requested directly, which also avoids depending on torch.mps.is_available.
# NOTE(review): to opt into an accelerator, replace 'cpu' with 'mps' or 'cuda'
# after confirming the model runs on that backend.
device = torch.device('cpu')
print(f"Using device: {device}")

## Configuration

In [None]:
# --- Inputs ---
# Trained weights to evaluate
checkpoint_path = '/Users/gbahl/Downloads/fiery_wind.pt'

# Dataset root directory and the img_size handed to the dataset
img_size = 512
data_root = '/Users/gbahl/Code/OpenSERGE/data/Sat2Graph/data/'

# --- Inference thresholds ---
# Both start at the same value: junction_thresh is forwarded to the model as
# j_thr (junction detection) and edge_thresh as e_thr (edge prediction).
junction_thresh = edge_thresh = 0.4

## Load Model and Checkpoint

In [None]:
# Read the checkpoint onto the selected device; a checkpoint without a
# 'config' entry is treated as having an empty configuration.
# NOTE(review): torch.load unpickles the file — only open trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location=device)
config = checkpoint.get('config', {})

# Summarise what was loaded
print("Checkpoint info:")
print(f"  Epoch: {checkpoint['epoch']}")
print(f"  Backbone: {config.get('backbone', 'resnet50')}")
print(f"  k: {config.get('k', 'None')}")

# The validation loss is reported only when a non-empty record is present
if checkpoint.get('val_losses'):
    val_loss = checkpoint['val_losses']['total']
    print(f"  Validation Loss: {val_loss:.4f}")

# Normalisation statistics stored with the run (None when the config has none)
normalize_mean, normalize_std = (
    config.get(stat_key) for stat_key in ('normalize_mean', 'normalize_std')
)

In [None]:
# Assemble the constructor arguments: backbone and k come from the checkpoint
# config, while the FPN and positional-encoding switches are fixed to True.
model_kwargs = dict(
    backbone=config.get('backbone', 'resnet50'),
    k=config.get('k'),
    use_fpn=True,
    use_pos_encoding=True,
)
model = OpenSERGE(**model_kwargs).to(device)

# Restore the trained weights, then switch to evaluation mode
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Tally the number of scalar parameters
num_params = 0
for param in model.parameters():
    num_params += param.numel()
print(f"\nModel loaded successfully!")
print(f"Parameters: {num_params:,}")

## Load Test Data

In [None]:
# Build the dataset with augmentation disabled and the checkpoint's
# normalisation statistics.
# NOTE(review): split='train' is requested even though the notebook labels this
# data as test samples — confirm which split is intended.
dataset_kwargs = dict(
    split='train',
    img_size=img_size,
    aug=False,
    normalize_mean=normalize_mean,
    normalize_std=normalize_std,
)
test_dataset = CityScale(data_root, **dataset_kwargs)

print(f"Test samples: {len(test_dataset)}")
print(f"Image size: {img_size}x{img_size}")

In [None]:
# Index of the sample to inspect — edit this to look at a different one
sample_idx = 2

# Pull the individual fields out of the sample dictionary:
#   image            [3, H, W]
#   junction_map_gt  [1, h, w]
#   offset_map_gt    [2, h, w]
#   edges_gt         list of edge tuples
sample = test_dataset[sample_idx]
image, junction_map_gt, offset_map_gt, edges_gt, meta = (
    sample[field] for field in ('image', 'junction_map', 'offset_map', 'edges', 'meta')
)

# Short summary; grid cells above 0.5 are counted as ground-truth junctions
print(f"Sample {sample_idx}:")
print(f"  Region: {meta['region_id']}")
print(f"  Image shape: {image.shape}")
print(f"  Ground truth junctions: {(junction_map_gt > 0.5).sum().item()}")
print(f"  Ground truth edges: {len(edges_gt)}")

## Run Inference

In [None]:
# Add a batch dimension and move the image to the model's device
image_batch = image.unsqueeze(0).to(device)  # [1, 3, H, W]

# Forward pass with gradient tracking disabled
with torch.no_grad():
    output = model(image_batch, j_thr=junction_thresh, e_thr=edge_thresh)
print( output['graphs'][0]['edges'].shape)

# Split the output into the dense CNN maps and the graph of the only batch item
cnn_output, graph_output = output['cnn'], output['graphs'][0]
print(cnn_output.keys())
print(graph_output.keys())

# Dense maps, brought back to the CPU
junction_logits = cnn_output['junction_logits'][0, 0].cpu()  # [h, w]
junction_probs = junction_logits.sigmoid()
offset_pred = cnn_output['offset'][0].cpu()  # [2, h, w]

# Graph arrays as NumPy:
#   nodes       [N, 2] pixel coordinates
#   edges       [E, 2] node index pairs
#   edge_probs  [E]    edge probabilities
nodes, edges, edge_probs = (
    graph_output[field].cpu().numpy() for field in ('nodes', 'edges', 'edge_probs')
)

print(f"\nPrediction results:")
print(f"  Detected junctions: {len(nodes)}")
print(f"  Junction probability range: [{junction_probs.min():.3f}, {junction_probs.max():.3f}]")
print(f"  Predicted edges (after threshold): {len(edges)}")
# min()/max() are undefined on an empty array, so skip the range when no edges remain
if len(edge_probs) > 0:
    print(f"  Edge probability range: [{edge_probs.min():.3f}, {edge_probs.max():.3f}]")

## Visualize Results

In [None]:
# Prepare image for visualization
img_vis = image.permute(1, 2, 0).cpu().numpy()  # [H, W, 3]
# Undo the dataset normalisation only when the checkpoint recorded statistics:
# config.get() yields None otherwise and the arithmetic would raise a TypeError.
# NOTE(review): assumes the dataset applies no normalisation when both are None.
if normalize_mean is not None and normalize_std is not None:
    img_vis = img_vis * np.asarray(normalize_std) + np.asarray(normalize_mean)
# Clip before the uint8 cast so out-of-range values saturate instead of wrapping
img_vis = (np.clip(img_vis, 0.0, 1.0) * 255).astype(np.uint8)

# Prepare ground truth visualization
junction_gt = junction_map_gt[0].cpu().numpy()  # [h, w]
h, w = junction_gt.shape

# Resize junction maps to image size for visualization: nearest-neighbour for
# the ground-truth map, bilinear for the predicted probabilities
junction_gt_vis = cv2.resize(junction_gt, (img_size, img_size), interpolation=cv2.INTER_NEAREST)
junction_pred_vis = cv2.resize(junction_probs.numpy(), (img_size, img_size), interpolation=cv2.INTER_LINEAR)

In [None]:
# Build a 2x3 grid of panels: ground truth on the top row, predictions below
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
(ax_input, ax_gt_junc, ax_gt_graph), (ax_prob, ax_det, ax_pred_graph) = axes
title_style = dict(fontsize=14, fontweight='bold')

# Every panel starts from the input image and hides its axis decorations
for panel in axes.flat:
    panel.imshow(img_vis)
    panel.axis('off')

# ---- Top row: ground truth ----
ax_input.set_title('Input Image', **title_style)

# Junction map overlaid on the image
ax_gt_junc.imshow(junction_gt_vis, alpha=0.5, cmap='hot', vmin=0, vmax=1)
ax_gt_junc.set_title('Ground Truth Junctions', **title_style)

# Graph drawn from grid cells above 0.5, each shifted by its stored offset
stride = img_size // h
junction_positions = np.argwhere(junction_gt > 0.5)  # [N, 2] as (i, j) grid coords
if len(junction_positions) > 0:
    offset_gt = offset_map_gt.cpu().numpy()  # [2, h, w]
    i_coords, j_coords = junction_positions[:, 0], junction_positions[:, 1]

    # Channel 0 is read as the y offset and channel 1 as the x offset
    y_offsets = offset_gt[0, i_coords, j_coords]
    x_offsets = offset_gt[1, i_coords, j_coords]

    # Cell centre plus offset, scaled from grid cells to pixels
    x_positions = (j_coords + 0.5 + x_offsets) * stride
    y_positions = (i_coords + 0.5 + y_offsets) * stride
    nodes_gt = np.stack([x_positions, y_positions], axis=1)  # [N, 2]

    ax_gt_graph.scatter(nodes_gt[:, 0], nodes_gt[:, 1], c='red', s=50, alpha=0.7,
                        edgecolors='white', linewidths=1)

    # Lookup table from grid cell (i, j) to its refined pixel position (x, y)
    coord_to_pos = dict(zip(map(tuple, junction_positions), map(tuple, nodes_gt)))

    # An edge is drawn only when both of its endpoints are known junction cells
    for (i1, j1), (i2, j2) in edges_gt:
        start_pos = coord_to_pos.get((i1, j1))
        end_pos = coord_to_pos.get((i2, j2))
        if start_pos is None or end_pos is None:
            continue
        (x1, y1), (x2, y2) = start_pos, end_pos
        ax_gt_graph.plot([x1, x2], [y1, y2], 'yellow', linewidth=2, alpha=0.6)

ax_gt_graph.set_title(
    f'Ground Truth Graph (with offsets) ({len(junction_positions)} nodes, {len(edges_gt)} edges)',
    **title_style)

# ---- Bottom row: predictions ----
# Junction probability heat map with its colour bar
im = ax_prob.imshow(junction_pred_vis, alpha=0.6, cmap='hot', vmin=0, vmax=1)
ax_prob.set_title('Predicted Junction Probability', **title_style)
plt.colorbar(im, ax=ax_prob, fraction=0.046, pad=0.04)

# Nodes returned by the model for the chosen junction threshold
if len(nodes) > 0:
    ax_det.scatter(nodes[:, 0], nodes[:, 1], c='lime', s=50, alpha=0.8,
                   edgecolors='white', linewidths=1, marker='o')
ax_det.set_title(f'Detected Junctions (threshold={junction_thresh})', **title_style)

# Predicted graph: edges coloured by probability, nodes drawn on top of them
if len(nodes) > 0:
    for edge_idx, (src, dst) in enumerate(edges):
        x1, y1 = nodes[src]
        x2, y2 = nodes[dst]
        ax_pred_graph.plot([x1, x2], [y1, y2],
                           color=plt.cm.viridis(edge_probs[edge_idx]),
                           linewidth=2, alpha=0.7)

    ax_pred_graph.scatter(nodes[:, 0], nodes[:, 1], c='red', s=50, alpha=0.9,
                          edgecolors='white', linewidths=1.5, marker='o', zorder=10)

ax_pred_graph.set_title(f'Predicted Graph ({len(nodes)} nodes, {len(edges)} edges)',
                        **title_style)

plt.tight_layout()
plt.show()

## Offset Visualization

In [None]:
# Visualize ground-truth and predicted offsets side by side as quiver plots
fig, axes = plt.subplots(1, 2, figsize=(16, 8))

# Ground truth offsets and the cells that count as junctions
offset_gt = offset_map_gt.cpu().numpy()  # [2, h, w]
mask_gt = (junction_gt > 0.5)

# Keep every `step`-th grid cell so the arrows stay readable
step = 2
y_grid, x_grid = np.meshgrid(np.arange(0, h, step), np.arange(0, w, step), indexing='ij')
# Shift by half a cell so arrows start at cell centres
y_grid = y_grid.astype(np.float64) + 0.5
x_grid = x_grid.astype(np.float64) + 0.5

# Factor converting grid-cell units to image pixels
scale = img_size / h

# Ground truth panel
axes[0].imshow(img_vis)
offset_x_gt = offset_gt[1, ::step, ::step]  # x offsets
offset_y_gt = offset_gt[0, ::step, ::step]  # y offsets
mask_gt_down = mask_gt[::step, ::step]

# The mask is passed as the colour array, so it tints arrows rather than hiding them
axes[0].quiver(x_grid * scale, y_grid * scale,
               offset_x_gt * scale, offset_y_gt * scale,
               mask_gt_down.astype(float),
               cmap='autumn', scale=1, scale_units='xy', angles='xy',
               width=0.003, alpha=0.7)
axes[0].set_title('Ground Truth Offsets', fontsize=14, fontweight='bold')
axes[0].axis('off')

# Prediction panel
offset_pred_np = offset_pred.numpy()  # [2, h, w]
# A sigmoid output is always above 0, so the former `> 0.0` test marked every
# cell; thresholding at junction_thresh highlights detected junction cells,
# mirroring the ground-truth panel.
mask_pred = (junction_probs.numpy() > junction_thresh)

axes[1].imshow(img_vis)
offset_x_pred = offset_pred_np[1, ::step, ::step]
offset_y_pred = offset_pred_np[0, ::step, ::step]
mask_pred_down = mask_pred[::step, ::step]

axes[1].quiver(x_grid * scale, y_grid * scale,
               offset_x_pred * scale, offset_y_pred * scale,
               mask_pred_down.astype(float),
               cmap='autumn', scale=1, scale_units='xy', angles='xy',
               width=0.003, alpha=0.7)
axes[1].set_title('Predicted Offsets', fontsize=14, fontweight='bold')
axes[1].axis('off')

plt.tight_layout()
plt.show()

## Comparison Metrics

In [None]:
# Simple comparison metrics between the prediction and the ground truth
print("\n" + "="*50)
print("COMPARISON METRICS")
print("="*50)

# Ground-truth junction cells, binarised at 0.5 like everywhere else in this
# notebook; the same mask feeds both the count and the per-cell metrics below.
junction_gt_mask = junction_gt > 0.5

# Junction counts
num_junctions_gt = int(junction_gt_mask.sum())
num_junctions_pred = len(nodes)
print(f"\nJunctions:")
print(f"  Ground Truth: {num_junctions_gt}")
print(f"  Predicted:    {num_junctions_pred}")
print(f"  Difference:   {num_junctions_pred - num_junctions_gt:+d}")

# Edge counts
num_edges_gt = len(edges_gt)
num_edges_pred = len(edges)
print(f"\nEdges:")
print(f"  Ground Truth: {num_edges_gt}")
print(f"  Predicted:    {num_edges_pred}")
print(f"  Difference:   {num_edges_pred - num_edges_gt:+d}")

# Per-cell junction detection metrics on the [h, w] grid.
# The raw ground-truth map was previously used as-is, so any non-binary values
# would have produced fractional counts; the binarised mask is used instead.
junction_pred_binary = (junction_probs > junction_thresh).float()
junction_gt_tensor = torch.from_numpy(junction_gt_mask).float()

tp = (junction_pred_binary * junction_gt_tensor).sum().item()
fp = (junction_pred_binary * (1 - junction_gt_tensor)).sum().item()
fn = ((1 - junction_pred_binary) * junction_gt_tensor).sum().item()
tn = ((1 - junction_pred_binary) * (1 - junction_gt_tensor)).sum().item()

# Each ratio falls back to 0 when its denominator is empty
precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

print(f"\nPixel-level Junction Metrics:")
print(f"  Precision: {precision:.3f}")
print(f"  Recall:    {recall:.3f}")
print(f"  F1 Score:  {f1:.3f}")

print("\n" + "="*50)

## Try Different Samples

Run the cells below to quickly test on different samples.

In [None]:
# Quick inference function
def infer_and_visualize(sample_idx, j_thresh=0.5, e_thresh=None):
    """Run inference on one dataset sample and plot it against the ground truth.

    Draws three panels (input image, ground-truth graph with offsets applied,
    predicted graph) and prints a short summary.

    Args:
        sample_idx: Index into ``test_dataset``.
        j_thresh: Junction threshold forwarded to the model as ``j_thr``.
        e_thresh: Optional edge threshold forwarded as ``e_thr``. When None the
            argument is not passed and the model's own default applies.
    """

    # Load sample
    sample = test_dataset[sample_idx]
    image = sample['image']
    junction_gt = sample['junction_map'][0].cpu().numpy()
    offset_map_gt = sample['offset_map']  # [2, h, w]
    edges_gt = sample['edges']
    meta = sample['meta']

    # Run inference; e_thr is only supplied when the caller asked for one
    model_kwargs = {'j_thr': j_thresh}
    if e_thresh is not None:
        model_kwargs['e_thr'] = e_thresh
    with torch.no_grad():
        output = model(image.unsqueeze(0).to(device), **model_kwargs)

    # Extract results
    graph = output['graphs'][0]
    nodes = graph['nodes'].cpu().numpy()
    edges = graph['edges'].cpu().numpy()
    junction_probs = torch.sigmoid(output['cnn']['junction_logits'][0, 0]).cpu().numpy()

    # Image for display: undo the dataset normalisation when statistics are
    # available (matching the main visualization cell), then clip so the uint8
    # cast saturates instead of wrapping around.
    img_vis = image.permute(1, 2, 0).cpu().numpy()
    if normalize_mean is not None and normalize_std is not None:
        img_vis = img_vis * np.asarray(normalize_std) + np.asarray(normalize_mean)
    img_vis = (np.clip(img_vis, 0.0, 1.0) * 255).astype(np.uint8)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Original
    axes[0].imshow(img_vis)
    axes[0].set_title(f'Sample {sample_idx}: {meta["region_id"]}', fontsize=12, fontweight='bold')
    axes[0].axis('off')

    # Ground truth with offsets applied
    axes[1].imshow(img_vis)
    h, w = junction_gt.shape
    stride = img_size // h
    junction_positions = np.argwhere(junction_gt > 0.5)
    if len(junction_positions) > 0:
        # Get offsets for junction positions
        offset_gt = offset_map_gt.cpu().numpy()  # [2, h, w]
        i_coords = junction_positions[:, 0]
        j_coords = junction_positions[:, 1]

        # Channel 0 is read as the y offset and channel 1 as the x offset
        y_offsets = offset_gt[0, i_coords, j_coords]
        x_offsets = offset_gt[1, i_coords, j_coords]

        # Cell centre plus offset, scaled from grid cells to pixels
        x_positions = (j_coords + 0.5 + x_offsets) * stride
        y_positions = (i_coords + 0.5 + y_offsets) * stride
        nodes_gt = np.stack([x_positions, y_positions], axis=1)  # [N, 2]

        axes[1].scatter(nodes_gt[:, 0], nodes_gt[:, 1], c='red', s=40, alpha=0.7)

        # Map each junction cell (i, j) to its refined pixel position (x, y)
        coord_to_pos = {(i, j): (x, y) for (i, j), (x, y) in
                        zip(junction_positions, nodes_gt)}

        # Draw only edges whose two endpoints are both known junction cells
        for (i1, j1), (i2, j2) in edges_gt:
            if (i1, j1) in coord_to_pos and (i2, j2) in coord_to_pos:
                x1, y1 = coord_to_pos[(i1, j1)]
                x2, y2 = coord_to_pos[(i2, j2)]
                axes[1].plot([x1, x2], [y1, y2], 'yellow', linewidth=2, alpha=0.5)
    axes[1].set_title(f'GT (with offsets): {len(junction_positions)} nodes, {len(edges_gt)} edges',
                      fontsize=12, fontweight='bold')
    axes[1].axis('off')

    # Prediction: edges first, nodes on top
    axes[2].imshow(img_vis)
    if len(nodes) > 0:
        for src, dst in edges:
            axes[2].plot([nodes[src, 0], nodes[dst, 0]],
                        [nodes[src, 1], nodes[dst, 1]],
                        'cyan', linewidth=2, alpha=0.6)
        axes[2].scatter(nodes[:, 0], nodes[:, 1], c='lime', s=40, alpha=0.8,
                       edgecolors='white', linewidths=1)
    axes[2].set_title(f'Pred: {len(nodes)} nodes, {len(edges)} edges',
                      fontsize=12, fontweight='bold')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    # Print stats
    print(f"Sample {sample_idx} - Junction threshold: {j_thresh}")
    print(f"  GT: {len(junction_positions)} junctions, {len(edges_gt)} edges")
    print(f"  Pred: {len(nodes)} junctions, {len(edges)} edges")
    print(f"  Junction prob range: [{junction_probs.min():.3f}, {junction_probs.max():.3f}]")

In [None]:
# Visualise the first samples of the dataset (at most five), using the
# configured junction threshold
num_preview = min(5, len(test_dataset))
for idx in range(num_preview):
    infer_and_visualize(idx, j_thresh=junction_thresh)

## Threshold Sensitivity Analysis

In [None]:
# Test different junction thresholds on the same sample
sample_idx = 0
thresholds = [0.3, 0.4, 0.5, 0.6, 0.7]

sample = test_dataset[sample_idx]
image_batch = sample['image'].unsqueeze(0).to(device)

# Display image, computed once because it does not depend on the threshold.
# Undo the dataset normalisation when statistics are available (matching the
# main visualization cell) and clip so the uint8 cast cannot wrap around.
img_vis = sample['image'].permute(1, 2, 0).cpu().numpy()
if normalize_mean is not None and normalize_std is not None:
    img_vis = img_vis * np.asarray(normalize_std) + np.asarray(normalize_mean)
img_vis = (np.clip(img_vis, 0.0, 1.0) * 255).astype(np.uint8)

# squeeze=False always returns a 2-D array of axes, so indexing also works
# when `thresholds` holds a single value
fig, axes = plt.subplots(1, len(thresholds), figsize=(4*len(thresholds), 4), squeeze=False)
axes = axes[0]

for i, thresh in enumerate(thresholds):
    # Re-run the model with this junction threshold
    with torch.no_grad():
        output = model(image_batch, j_thr=thresh)

    nodes = output['graphs'][0]['nodes'].cpu().numpy()
    edges = output['graphs'][0]['edges'].cpu().numpy()

    axes[i].imshow(img_vis)

    # Edges first, then nodes on top
    if len(nodes) > 0:
        for src, dst in edges:
            axes[i].plot([nodes[src, 0], nodes[dst, 0]],
                        [nodes[src, 1], nodes[dst, 1]],
                        'cyan', linewidth=1.5, alpha=0.6)
        axes[i].scatter(nodes[:, 0], nodes[:, 1], c='red', s=30, alpha=0.8)

    axes[i].set_title(f'Thresh={thresh}\n{len(nodes)} nodes, {len(edges)} edges',
                      fontsize=10)
    axes[i].axis('off')

plt.tight_layout()
plt.show()

## Export Predictions

In [None]:
# Save predictions to file
import json

def export_prediction(sample_idx, output_path):
    """Write the predicted graph for one dataset sample to a JSON file.

    The model is run with the notebook-level ``junction_thresh``; the file
    stores the nodes, edges, and edge probabilities together with their counts
    and the threshold that was used.

    Args:
        sample_idx: Index into ``test_dataset``.
        output_path: Destination of the JSON file (opened for writing).
    """
    sample = test_dataset[sample_idx]
    batch = sample['image'].unsqueeze(0).to(device)

    # Forward pass with gradient tracking disabled
    with torch.no_grad():
        output = model(batch, j_thr=junction_thresh)
    graph = output['graphs'][0]

    # Bring the graph arrays back as NumPy so they can be turned into lists
    nodes, edges, edge_probs = (
        graph[field].cpu().numpy() for field in ('nodes', 'edges', 'edge_probs')
    )

    result = dict(
        sample_idx=sample_idx,
        region_id=sample['meta']['region_id'],
        nodes=nodes.tolist(),
        edges=edges.tolist(),
        edge_probabilities=edge_probs.tolist(),
        num_nodes=len(nodes),
        num_edges=len(edges),
        junction_threshold=junction_thresh,
    )

    with open(output_path, 'w') as out_file:
        json.dump(result, out_file, indent=2)

    print(f"Prediction saved to {output_path}")

# Example usage
# export_prediction(0, 'prediction_sample0.json')