In [None]:
# ============================================================================
# TEST 3: CBF (Control Barrier Function) - Learning Safety
# ============================================================================
#
# Builds a fresh CBF network and prints its untrained output on a few probe
# states so the values can be compared with the post-training ones.

# Probe states. The network is created with state_dim=4 and the summary
# printed below labels the components as (x, y, vx, vy).
test_states = torch.FloatTensor([
    [1.0, 1.0, 0.0, 0.0],   # Safe position
    [2.0, 3.0, 0.0, 0.0],   # Near obstacle center
    FINAL_GOAL_STATE,       # Goal state (defined outside this cell)
])


cbf = CBFNetwork(state_dim=4, hidden_dims=(64, 64), device='cpu')

print("\nCBF Network Architecture:")
print("  - Input dimension: 4 (x, y, vx, vy)")
print("  - Hidden layers: (64, 64)")
print("  - Output: h(s) ∈ ℝ (scalar safety value)")
print(f"  - Parameters: {sum(p.numel() for p in cbf.parameters()):,}")


print("\nCBF Values (before training):")
with torch.no_grad():  # evaluation only; no gradients are needed here
    h_values = cbf(test_states).squeeze()
    for i, (state, h) in enumerate(zip(test_states, h_values)):
        # Only the first two components of each state are shown.
        print(f"  State {i+1} {state[:2].numpy()}: h(s) = {h.item():.4f}")

# Train CBF on dataset
print("\n🏋️  Training CBF on dataset...")

# Split the collected transitions by their safety label.
safe_states = torch.FloatTensor([t.state for t in transitions if t.is_safe])
unsafe_states = torch.FloatTensor([t.state for t in transitions if not t.is_safe])

print(f"  - Safe training states: {len(safe_states)}")
print(f"  - Unsafe training states: {len(unsafe_states)}")

# Each loss term below averages over its own batch, so an empty class would
# leave that term undefined. Fail early with a clear message instead.
if len(safe_states) == 0 or len(unsafe_states) == 0:
    raise ValueError(
        "CBF training needs at least one safe and one unsafe state "
        f"(got {len(safe_states)} safe, {len(unsafe_states)} unsafe)"
    )

optimizer = torch.optim.Adam(cbf.parameters(), lr=1e-3)

# Full-batch training: every epoch uses all safe and all unsafe states.
losses = []
for epoch in range(100):
    optimizer.zero_grad()

    # Safe states should have h(s) ≥ 0: penalise only negative outputs.
    h_safe = cbf(safe_states).squeeze()
    loss_safe = torch.mean(torch.clamp(-h_safe, min=0.0) ** 2)

    # Unsafe states should have h(s) < 0: penalise only positive outputs.
    # NOTE(review): both hinge terms are exactly zero at h(s) = 0, so nothing
    # pushes unsafe states strictly below zero. Consider adding a margin if a
    # strict separation is required.
    h_unsafe = cbf(unsafe_states).squeeze()
    loss_unsafe = torch.mean(torch.clamp(h_unsafe, min=0.0) ** 2)

    loss = loss_safe + loss_unsafe
    loss.backward()
    optimizer.step()

    # Keep the scalar loss for the training-curve plot.
    losses.append(loss.item())

    if epoch % 20 == 0:
        print(f"  Epoch {epoch}: Loss = {loss.item():.4f}")



# TEST

# Re-evaluate the same probe states now that the CBF has been trained.
print("\n📊 CBF Values (after training):")
with torch.no_grad():  # evaluation only; no gradients are needed here
    h_values = cbf(test_states).squeeze()
    for i, (state, h) in enumerate(zip(test_states, h_values)):
        h_val = h.item()
        # A state is reported as safe when h(s) >= 0 (zero counts as safe).
        safe_label = "✓ SAFE" if h_val >= 0 else "✗ UNSAFE"
        print(f"  State {i+1} {state[:2].numpy()}: h(s) = {h_val:+.4f} {safe_label}")

# Visualize using clean visualization module
print("\n📊 Creating visualizations...")

# One figure with two side-by-side panels: loss curve and CBF heatmap.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))

# Plot 1: Training loss (the per-epoch values collected in `losses`)
func_viz = FunctionVisualizer(env)
func_viz.plot_training_curves(losses, title='CBF Training Loss', ax=ax1, color='red')

# Plot 2: CBF heatmap of the trained network
func_viz.plot_cbf_heatmap(cbf, ax=ax2)

plt.tight_layout()
plt.show()

# Closing summary for TEST 3, framed by 70-character rules.
print("\n" + "="*70)
print("KEY INSIGHTS:")
print("="*70)
for insight in (
    "• CBF learns to classify: h(s) ≥ 0 for safe, h(s) < 0 for unsafe",
    "• Black boundary line shows h(s) = 0 (safety boundary)",
    "• Green regions: robot can operate safely",
    "• Red regions: collision risk (near obstacles)",
    "• This provides a differentiable safety certificate!",
):
    print(insight)
print("="*70)
print("\n✅ CBF test complete!")

The next cell augments the training data with additional samples placed explicitly along the safe path between the obstacles, acting as a kind of expert trajectory demonstration.

In [None]:
# ============================================================================
# TEST 3b: Fix CBF - Add Safe Path Sampling Using Global Config
# ============================================================================

from src.training.training_utils import (
    generate_safe_path_samples,
    train_cbf_with_path_augmentation,
    verify_cbf_on_path,
)
from src.utils.visualization import plot_cbf_comparison

print("\n" + "=" * 70)
print("TEST 3b: Fixing CBF with Path-Aware Sampling")
print("=" * 70)

# Intended route after START_POS. The same list feeds both the printed path
# and the sample generator so the two cannot drift apart.
waypoints = [G3_POS, G1_POS, FINAL_GOAL_POS]

print("\n🔍 Problem: CBF trained on random data marks safe paths as unsafe")
print("   Solution: Generate safe samples along intended path")
print("   Path: " + " → ".join(f"{pos}" for pos in [START_POS, *waypoints]))

# Generate safe path samples using utility function
path_samples = generate_safe_path_samples(
    env, START_POS, waypoints,
    samples_per_segment=30, offset_samples=5,
)

print(f"\n✅ Generated {len(path_samples)} safe path samples")

# Retrain CBF with path augmentation.
# A brand-new network and optimizer are created, so the previously trained
# `cbf` is left untouched by this step.
print("\n🏋️  Re-training CBF with path-aware data...")
cbf_fixed = CBFNetwork(state_dim=4, hidden_dims=(64, 64), device='cpu')
optimizer_cbf_fixed = torch.optim.Adam(cbf_fixed.parameters(), lr=1e-3)

# NOTE: this rebinds `losses`, replacing the loss history of the first
# training run with whatever the helper returns.
losses = train_cbf_with_path_augmentation(
    cbf_fixed, optimizer_cbf_fixed,
    safe_states, unsafe_states,
    path_samples=path_samples,
    num_epochs=150,
    verbose=True,
)

# Verify CBF on path waypoints.
# `test_waypoints` and `waypoint_names` are parallel lists (same order).
print("\n📊 Verifying CBF on Path Waypoints:")
test_waypoints = [START_STATE, G3_STATE, G1_STATE, FINAL_GOAL_STATE]
waypoint_names = ['Start', 'G3', 'G1', 'Goal']
# The helper returns a pair; the first element is used later as the overall
# success flag. NOTE: `h_values` is rebound here to the helper's second result.
all_safe, h_values = verify_cbf_on_path(cbf_fixed, test_waypoints, waypoint_names)

# Visualize comparison between the earlier network (`cbf`) and the retrained
# one (`cbf_fixed`), with the intended path passed along.
print("\n📊 Creating before/after visualization...")
path_for_viz = [START_POS, G3_POS, G1_POS, FINAL_GOAL_POS]
fig = plot_cbf_comparison(cbf, cbf_fixed, env, path_for_viz)
plt.show()

# Summary: report whether verification flagged every waypoint as safe.
print("\n" + "="*70)
if all_safe:
    print("✅ CBF FIX SUCCESSFUL!")
    print("   - All waypoints marked safe")
    print("   - Path is accessible")
    print("   - Ready for joint training!")
else:
    print("⚠️  Some waypoints still unsafe - may need more samples")
print("="*70)

# Update CBF reference so later code that uses `cbf` gets the retrained network.
# NOTE(review): after this rebinding, re-running this cell on its own passes
# the already-fixed network as the "before" model to plot_cbf_comparison.
# Re-run the original training cell first to get a true before/after plot.
cbf = cbf_fixed
print("\n✅ CBF updated for use in training.")