# KeyBot: Accurate and Efficient Vertebrae Keypoint Estimation
## ECCV 2024 - Demo Notebook

**Paper:** *Bones Can't Be Triangles: Accurate and Efficient Vertebrae Keypoint Estimation through Collaborative Error Revision*

---

This notebook demonstrates the KeyBot system for vertebrae keypoint estimation on spine X-ray images.

### System Overview:
- **Refiner**: Interactive keypoint estimation model (RITM_SE_HRNet32)
- **Detector**: Identifies potential errors in predictions
- **Corrector**: Automatically revises detected errors

### Dataset:
- AASCE (Accurate Automated Spinal Curvature Estimation) dataset of anterior-posterior spinal X-rays
- 68 keypoints per image (4 corners per vertebra × 17 vertebrae)
- Image size: 512×256 pixels


## 1. Setup and Imports


In [None]:
# Standard imports
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from PIL import Image
import json
import os
import sys
import warnings
from tqdm.auto import tqdm

# Suppress warnings for cleaner output.
# NOTE: this silences *all* warnings for the rest of the session (including
# deprecation notices from torch); comment it out when debugging.
warnings.filterwarnings('ignore')

# Configure matplotlib for better visualizations
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 10

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# Informational: the model-loading cell places models on `DEVICE` imported
# from the project's `util` module, not on this `device` variable.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"\n✓ Using device: {device}")


In [None]:
# Setup paths and imports from the `codes` directory.
# Remember the launch directory only on the first execution: this cell chdirs
# into ./codes, so re-running it would otherwise treat ./codes as the root and
# point `codes_dir` at ./codes/codes, breaking every path below.
if 'original_dir' not in globals():
    original_dir = os.getcwd()
codes_dir = os.path.join(original_dir, 'codes')

# Add codes directory to path (once) and change to it (required for imports)
if codes_dir not in sys.path:
    sys.path.insert(0, codes_dir)
os.chdir(codes_dir)

# Import KeyBot modules
from AnomalySuggestion_get_model import get_keypoint_model, get_test_data_loader, test
from suggest_codes.get_suggest_model import SuggestionConvModel
from suggest_codes.get_suggest_dataset import SuggestionDataset
from suggest_codes.get_pseudo_generation_model_image_heatmap import PseudoLabelModel, get_func_pseudo_label
from suggest_codes.get_pseudo_generation_dataset_image_negative_sample import RefineDataset
from util import DEVICE

print("✓ All modules imported successfully")


## 2. Load Pre-trained Models

We load three models that work collaboratively:
1. **Refiner**: Base keypoint estimation model
2. **Detector**: Identifies error-prone keypoints
3. **Corrector**: Refines detected errors


In [None]:
print("Loading models...\n")


def _load_checkpoint(model, weights_path, label):
    """Load a state dict from ``weights_path`` into ``model`` if the file exists.

    On success the model is switched to eval mode, moved to ``DEVICE`` and a
    confirmation line is printed. If the file is missing, a warning is printed
    and the model is left untouched.

    Args:
        model: The ``torch.nn.Module`` to populate.
        weights_path: Path to the saved state dict.
        label: Human-readable model name used in the status messages.

    Returns:
        bool: True if the weights were loaded, False if the file was missing.
    """
    if not os.path.exists(weights_path):
        print(f"  ⚠ Warning: {label} not found at {weights_path}")
        return False
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
    model.eval()
    model.to(DEVICE)
    print(f"  ✓ {label} loaded")
    return True


# 1. Load Refiner (Base Interactive Keypoint Model)
print("[1/3] Loading Refiner (RITM_SE_HRNet32)...")
trainer, save_manager = get_keypoint_model(data='spineweb')
refiner = trainer.model
refiner.eval()
print("  ✓ Refiner loaded")

# 2. Load Detector (Suggestion Model)
print("\n[2/3] Loading Detector (Error Detection)...")
detector = SuggestionConvModel()
detector_path = os.path.join(original_dir, 'save_suggestion', 'AASCE_suggestModel.pth')
detector_loaded = _load_checkpoint(detector, detector_path, "Detector")

# 3. Load Corrector (Pseudo Label Model)
print("\n[3/3] Loading Corrector (Error Refinement)...")
corrector = PseudoLabelModel(n_keypoint=68, num_bones=17)
corrector_path = os.path.join(original_dir, 'save_refine', 'AASCE_refineModel.pth')
corrector_loaded = _load_checkpoint(corrector, corrector_path, "Corrector")

get_pseudo_label = get_func_pseudo_label()

# Only claim full success when every checkpoint was actually found.
missing_weights = [
    name
    for name, loaded in (("Detector", detector_loaded), ("Corrector", corrector_loaded))
    if not loaded
]
print("\n" + "="*60)
if missing_weights:
    print(f"⚠ Models loaded, but weights are missing for: {', '.join(missing_weights)}")
else:
    print("✓ All models loaded successfully!")
print("="*60)


## 3. Load Test Data

Load the AASCE test dataset with preprocessed spine X-ray images.


In [None]:
print("Loading test data...")
keypoint_train_loader, keypoint_val_loader, keypoint_test_loader = get_test_data_loader(
    off_train_aug=True,
    data='spineweb'
)

print("✓ Test dataset loaded")
print(f"  - Test samples: {len(keypoint_test_loader.dataset)}")
print(f"  - Batch size: {keypoint_test_loader.batch_size}")
# The two lines below are fixed descriptive text, not values read from the loader.
print("  - Image size: 512×256 pixels")
print("  - Keypoints per image: 68 (17 vertebrae × 4 corners)")


## 4. Single Image Demo: Inference Pipeline

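Let's run the Refiner (the base keypoint model) on a single test image. This demonstrates the first stage of the KeyBot pipeline; the Detector and Corrector loaded above are not applied in this demo.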


In [None]:
# Fetch one batch from the test loader; the demo only uses its first sample.
batch = next(iter(keypoint_test_loader))

image_path = batch['input_image_path'][0]
image = batch['input_image'][0]             # expected shape: (3, 512, 256)
gt_coords = batch['label']['coord'][0]      # expected shape: (68, 2)
gt_heatmap = batch['label']['heatmap'][0]   # expected shape: (68, 512, 256)

print(f"Selected image: {image_path}")
print("\nData shapes:")
print(f"  - Input image: {tuple(image.shape)}")
print(f"  - GT coordinates: {tuple(gt_coords.shape)}")
print(f"  - GT heatmap: {tuple(gt_heatmap.shape)}")


In [None]:
# Visualize the input image
def denormalize_image(img_tensor):
    """Turn a channel-first image tensor into a 2-D grayscale array for plotting.

    The tensor is moved to the CPU and transposed to channels-last. Values are
    mapped with ``(x + 1) / 2`` (i.e. inputs are assumed to lie in [-1, 1])
    and clipped to [0, 1]. The channels are then averaged into one grayscale
    plane.

    Args:
        img_tensor: Tensor of shape (C, H, W).

    Returns:
        numpy.ndarray of shape (H, W) with values in [0, 1].
    """
    channels_last = img_tensor.cpu().permute(1, 2, 0).numpy()
    unit_range = np.clip((channels_last + 1) / 2, 0, 1)
    return unit_range.mean(axis=2)

# Grayscale view of the selected sample for display.
display_image = denormalize_image(image)

fig, ax = plt.subplots(1, 1, figsize=(6, 10))
ax.imshow(display_image, cmap='gray')
ax.set_title('Input X-ray Image', fontsize=14, fontweight='bold')
ax.axis('off')
plt.tight_layout()
plt.show()

print(f"Image dimensions: {display_image.shape[0]}×{display_image.shape[1]} pixels")


In [None]:
# Run the Refiner (base model) on the batch fetched in the previous cell.
print("Running inference with Refiner...")

with torch.no_grad():
    # Evaluation-mode forward pass with no manual hint clicks.
    batch['is_training'] = False
    batch['hint'] = {'index': [None]}

    # The refiner returns its outputs together with the (possibly updated) batch.
    out, batch = refiner(batch)

    # Keep the first sample of each prediction, moved to the CPU.
    pred_coords, pred_heatmap = (
        field[0].cpu() for field in (out.pred.sargmax_coord, out.pred.heatmap)
    )
# Compute error metrics
def _radial_errors(pred, gt):
    """Per-keypoint Euclidean distance between ``pred`` and ``gt``.

    The last dimension is treated as the coordinate axis. Both tensors are
    moved to the CPU first, so callers may pass tensors that live on
    different devices.
    """
    return torch.sqrt(((pred.cpu() - gt.cpu()) ** 2).sum(-1))

def compute_mre(pred, gt):
    """Mean Radial Error: average Euclidean distance, in coordinate units (pixels)."""
    return _radial_errors(pred, gt).mean().item()

def compute_sdr(pred, gt, threshold=2.0):
    """Success Detection Rate: percentage of keypoints with error <= ``threshold``.

    ``threshold`` is in the same units as the coordinates (pixels here).
    Returns a float in [0, 100].
    """
    return (_radial_errors(pred, gt) <= threshold).float().mean().item() * 100

mre = compute_mre(pred_coords, gt_coords)
# NOTE(review): these thresholds are applied to pixel errors; the "mm" in the
# labels is only accurate if one pixel equals one millimetre here — confirm.
sdr_2mm = compute_sdr(pred_coords, gt_coords, threshold=2.0)
sdr_4mm = compute_sdr(pred_coords, gt_coords, threshold=4.0)

print("\n✓ Inference complete")
print("\n📊 Metrics:")
print(f"  - Mean Radial Error (MRE): {mre:.2f} pixels")
print(f"  - SDR@2mm: {sdr_2mm:.1f}%")
print(f"  - SDR@4mm: {sdr_4mm:.1f}%")


## 5. Visualization: Predictions vs Ground Truth


In [None]:
# Create comprehensive visualization
# 2x3 figure built from globals of earlier cells (`gt_coords`, `pred_coords`,
# `gt_heatmap`, `pred_heatmap`, `display_image`, `mre`). Throughout, column 1
# of the coordinate arrays is plotted on the x-axis and column 0 on the y-axis.
gt_coords_np = gt_coords.cpu().numpy()
pred_coords_np = pred_coords.numpy()
# Per-keypoint Euclidean error (the un-averaged form of compute_mre).
errors = torch.sqrt(((pred_coords - gt_coords.cpu())**2).sum(-1))
errors_np = errors.numpy()

fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# 1. Ground Truth
ax = axes[0, 0]
ax.imshow(display_image, cmap='gray')
# NOTE: this `scatter` handle is unused; the name is reassigned in panel 6.
scatter = ax.scatter(gt_coords_np[:, 1], gt_coords_np[:, 0], 
                     c='lime', s=40, alpha=0.8, edgecolors='darkgreen', linewidths=1)
ax.set_title('Ground Truth Keypoints', fontsize=13, fontweight='bold')
ax.axis('off')

# 2. Predictions
ax = axes[0, 1]
ax.imshow(display_image, cmap='gray')
ax.scatter(pred_coords_np[:, 1], pred_coords_np[:, 0], 
           c='red', s=40, alpha=0.8, edgecolors='darkred', linewidths=1, marker='x')
ax.set_title('Model Predictions', fontsize=13, fontweight='bold')
ax.axis('off')

# 3. Overlay with Error Lines
ax = axes[0, 2]
ax.imshow(display_image, cmap='gray')
ax.scatter(gt_coords_np[:, 1], gt_coords_np[:, 0], 
           c='lime', s=50, alpha=0.7, label='Ground Truth', marker='o')
ax.scatter(pred_coords_np[:, 1], pred_coords_np[:, 0], 
           c='red', s=40, alpha=0.7, label='Prediction', marker='x')
# Draw error lines
# One yellow segment per keypoint, from its GT location to its prediction.
for i in range(len(gt_coords_np)):
    ax.plot([gt_coords_np[i, 1], pred_coords_np[i, 1]], 
            [gt_coords_np[i, 0], pred_coords_np[i, 0]], 
            'yellow', linewidth=0.8, alpha=0.5)
ax.set_title(f'Comparison (MRE: {mre:.2f}px)', fontsize=13, fontweight='bold')
ax.legend(loc='upper right')
ax.axis('off')

# 4. Ground Truth Heatmap
ax = axes[1, 0]
# Collapse the per-keypoint heatmaps into one map by summing over dim 0
# (assumed to be the keypoint axis — TODO confirm against the data loader).
gt_heatmap_sum = gt_heatmap.cpu().sum(dim=0)
# Semi-transparent X-ray underneath, heatmap blended on top.
ax.imshow(display_image, cmap='gray', alpha=0.6)
im = ax.imshow(gt_heatmap_sum, cmap='hot', alpha=0.5)
ax.set_title('Ground Truth Heatmap', fontsize=13, fontweight='bold')
ax.axis('off')
plt.colorbar(im, ax=ax, fraction=0.046)

# 5. Predicted Heatmap
ax = axes[1, 1]
# Same dim-0 summation as panel 4, applied to the refiner's heatmap output.
pred_heatmap_sum = pred_heatmap.sum(dim=0)
ax.imshow(display_image, cmap='gray', alpha=0.6)
im = ax.imshow(pred_heatmap_sum, cmap='hot', alpha=0.5)
ax.set_title('Predicted Heatmap', fontsize=13, fontweight='bold')
ax.axis('off')
plt.colorbar(im, ax=ax, fraction=0.046)

# 6. Per-Keypoint Error Visualization
ax = axes[1, 2]
ax.imshow(display_image, cmap='gray', alpha=0.7)
# Markers sit at the predicted locations; colour encodes each keypoint's error
# on a scale from 0 to the largest error in this image.
scatter = ax.scatter(pred_coords_np[:, 1], pred_coords_np[:, 0], 
                     c=errors_np, s=100, alpha=0.8, cmap='RdYlGn_r', 
                     edgecolors='black', linewidths=1, vmin=0, vmax=errors_np.max())
ax.set_title(f'Per-Keypoint Error (Max: {errors_np.max():.2f}px)', fontsize=13, fontweight='bold')
ax.axis('off')
cbar = plt.colorbar(scatter, ax=ax, fraction=0.046)
cbar.set_label('Error (pixels)', rotation=270, labelpad=15)

plt.tight_layout()
plt.show()

print("\nVisualization shows:")
print("  Top row: Keypoint locations (GT, Pred, Overlay)")
print("  Bottom row: Heatmaps and per-keypoint errors")
print("  Red/Hot colors in error map indicate higher localization errors")


## 6. Interactive Exploration: Browse Test Samples

Browse through different test samples to see predictions on various spine X-rays.


In [None]:
def visualize_sample(sample_idx, test_loader):
    """
    Run the Refiner on one test sample and plot GT, predictions and per-keypoint error.

    Args:
        sample_idx: Index of the sample in ``test_loader.dataset``. Must lie in
            ``[0, len(dataset) - 1]``; any other value (including a negative
            index) prints an error message and returns without plotting.
        test_loader: DataLoader for test data; only its ``.dataset`` is used.

    Returns:
        None. Shows a 1x3 matplotlib figure (GT, predictions, error map) and
        prints the image path together with MRE / SDR metrics.
    """
    dataset = test_loader.dataset
    # Reject negative indices too: Python-style wrap-around (if the dataset
    # supports it) would show a different sample labelled with a negative index.
    if not 0 <= sample_idx < len(dataset):
        print(f"Error: sample_idx {sample_idx} out of range (max: {len(dataset)-1})")
        return

    sample = dataset[sample_idx]

    # Wrap the single sample as a batch of size 1 (unsqueeze(0) adds the batch
    # dim) and request an evaluation-mode pass with no manual hints.
    batch = {
        'input_image': sample['input_image'].unsqueeze(0),
        'label': {
            'coord': sample['label']['coord'].unsqueeze(0),
            'heatmap': sample['label']['heatmap'].unsqueeze(0),
        },
        'input_image_path': [sample['input_image_path']],
        'is_training': False,
        'hint': {'index': [None]}
    }

    # Run inference
    with torch.no_grad():
        out, batch = refiner(batch)
        pred_coords = out.pred.sargmax_coord[0].cpu()

    # Extract data
    image = sample['input_image']
    gt_coords = sample['label']['coord'].cpu()

    # Compute metrics (errors are in pixel units; "mm" labels are kept as-is)
    mre = compute_mre(pred_coords, gt_coords)
    sdr_2mm = compute_sdr(pred_coords, gt_coords, 2.0)
    sdr_4mm = compute_sdr(pred_coords, gt_coords, 4.0)

    # Visualize: column 1 of the coordinates is x, column 0 is y.
    display_img = denormalize_image(image)
    gt_np = gt_coords.numpy()
    pred_np = pred_coords.numpy()
    errors = torch.sqrt(((pred_coords - gt_coords)**2).sum(-1)).numpy()

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Ground Truth
    axes[0].imshow(display_img, cmap='gray')
    axes[0].scatter(gt_np[:, 1], gt_np[:, 0], c='lime', s=40, alpha=0.8)
    axes[0].set_title('Ground Truth', fontsize=13, fontweight='bold')
    axes[0].axis('off')

    # Predictions
    axes[1].imshow(display_img, cmap='gray')
    axes[1].scatter(pred_np[:, 1], pred_np[:, 0], c='red', s=40, alpha=0.8, marker='x')
    axes[1].set_title('Predictions', fontsize=13, fontweight='bold')
    axes[1].axis('off')

    # Error Map: marker colour encodes each keypoint's error at its predicted location
    axes[2].imshow(display_img, cmap='gray', alpha=0.7)
    scatter = axes[2].scatter(pred_np[:, 1], pred_np[:, 0], c=errors, s=100,
                              cmap='RdYlGn_r', alpha=0.8, edgecolors='black', linewidths=1)
    axes[2].set_title('Error Map', fontsize=13, fontweight='bold')
    axes[2].axis('off')
    plt.colorbar(scatter, ax=axes[2], fraction=0.046, label='Error (pixels)')

    plt.suptitle(f"Sample {sample_idx} | MRE: {mre:.2f}px | SDR@2mm: {sdr_2mm:.1f}% | SDR@4mm: {sdr_4mm:.1f}%",
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()

    print(f"Image: {sample['input_image_path']}")
    print(f"MRE: {mre:.2f} pixels | SDR@2mm: {sdr_2mm:.1f}% | SDR@4mm: {sdr_4mm:.1f}%")

# Start by rendering the very first test sample.
print("Total test samples: {}\n".format(len(keypoint_test_loader.dataset)))
visualize_sample(sample_idx=0, test_loader=keypoint_test_loader)


In [None]:
# Browse other samples: set `sample_idx` to any value from 0 to N-1
# (N = number of test samples printed above) and re-run this cell.
sample_idx = 5  # Change this value!
visualize_sample(sample_idx, test_loader=keypoint_test_loader)


In [None]:
# Get a sample for vertebrae analysis (first sample of a fresh test batch)
batch = next(iter(keypoint_test_loader))

with torch.no_grad():
    batch['is_training'] = False
    batch['hint'] = {'index': [None]}
    out, batch = refiner(batch)
    pred_coords = out.pred.sargmax_coord[0].cpu()

image = batch['input_image'][0]
gt_coords = batch['label']['coord'][0].cpu()
display_img = denormalize_image(image)


def _draw_vertebrae(ax, background, coords, title, n_vertebrae=17, points_per_vertebra=4):
    """Draw each vertebra's corner keypoints, bounding box and ``V<k>`` label on ``ax``.

    Keypoints are grouped consecutively: vertebra ``k`` owns rows
    ``k * points_per_vertebra`` to ``(k + 1) * points_per_vertebra - 1`` of
    ``coords``. Column 1 is plotted as x and column 0 as y, matching the other
    cells. Each vertebra gets its own colour from the rainbow colormap.

    Args:
        ax: Matplotlib axes to draw on.
        background: 2-D grayscale image shown underneath.
        coords: CPU tensor of keypoint coordinates, one row per keypoint.
        title: Axes title.
        n_vertebrae: Number of vertebrae to draw.
        points_per_vertebra: Consecutive keypoints belonging to one vertebra.
    """
    palette = plt.cm.rainbow(np.linspace(0, 1, n_vertebrae))
    coords_np = coords.numpy()
    ax.imshow(background, cmap='gray')
    for k in range(n_vertebrae):
        color = palette[k]
        corners = coords_np[k * points_per_vertebra:(k + 1) * points_per_vertebra]
        ys, xs = corners[:, 0], corners[:, 1]

        # Plot corners
        ax.scatter(xs, ys, c=[color], s=50, alpha=0.8, edgecolors='black', linewidths=1)

        # Axis-aligned bounding box around this vertebra's corners
        min_y, max_y = ys.min(), ys.max()
        min_x, max_x = xs.min(), xs.max()
        ax.add_patch(Rectangle((min_x, min_y), max_x - min_x, max_y - min_y,
                               fill=False, edgecolor=color, linewidth=2, alpha=0.7))

        # Label the vertebra at the centre of its box
        ax.text((min_x + max_x) / 2, (min_y + max_y) / 2, f'V{k + 1}',
                fontsize=8, color='white', ha='center', va='center',
                bbox=dict(boxstyle='round,pad=0.3', facecolor=color, alpha=0.7))
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.axis('off')


# Visualize vertebrae structure: GT on the left, predictions on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 10))
_draw_vertebrae(axes[0], display_img, gt_coords,
                'Ground Truth Vertebrae (17 vertebrae, 4 keypoints each)')
_draw_vertebrae(axes[1], display_img, pred_coords, 'Predicted Vertebrae')

plt.tight_layout()
plt.show()

# NOTE(review): this code only establishes that V1 = keypoints 0-3, V2 = 4-7, ...;
# the "bottom to top" direction below is unverified — confirm against the
# dataset's landmark ordering.
print("Each vertebra is labeled V1-V17 (from bottom to top of spine)")
print("Each vertebra has 4 corner keypoints forming a rectangular region")


In [None]:
# Change back to original directory (the setup cell chdir'd into ./codes)
os.chdir(original_dir)

# Create output directory (no-op if it already exists)
output_dir = os.path.join(original_dir, 'demo_outputs')
os.makedirs(output_dir, exist_ok=True)

print(f"Output directory created: {output_dir}")
print("\nYou can save visualizations using plt.savefig() in the cells above.")
print("\nExample:")
print("  plt.savefig(os.path.join(output_dir, 'visualization.png'), dpi=150, bbox_inches='tight')")

print("\n" + "="*60)
print("✓ Demo notebook complete!")
print("="*60)
print("\nThank you for exploring KeyBot!")
print("For more information, visit: https://ts-kim.github.io/KeyBot/")
