# GRPO Snapshot and Branching Tutorial - WORKING VERSION

This notebook demonstrates the advanced GRPO features: snapshot/restore and branching functionality.

## What You'll Learn:
- **Snapshot System**: Save environment states for multi-turn rollouts
- **Branching Mode**: Run multiple CARLA instances in parallel
- **Trajectory Collection**: Collect multiple trajectories from the same state
- **Branch Selection**: Select best trajectory and continue exploration

### Prerequisites:
- Complete the "02_grpo_carla_interface.ipynb" notebook first
- CARLA server running at `http://localhost:8080`

### Key Concepts:
- **Snapshot**: Save complete environment state at decision points
- **Branching**: Create parallel environments from same snapshot
- **Multi-turn Rollouts**: Explore different actions from same state
- **Best Trajectory Selection**: Choose optimal path based on rewards

In [None]:
# Start CARLA services
import logging, os, signal, subprocess, time, shutil
from typing import Set, List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def _pids_listening_on_port(port: int, only_mine: bool = True) -> Set[int]:
    """Return PIDs listening on port."""
    pids = set()
    
    if shutil.which("lsof"):
        cmd = ["lsof", "-nP", "-tiTCP:%d" % port, "-sTCP:LISTEN"]
        if only_mine:
            cmd.insert(1, "-u")
            cmd.insert(2, os.getlogin())
        res = subprocess.run(cmd, capture_output=True, text=True)
        if res.returncode == 0 and res.stdout.strip():
            pids.update(int(x) for x in res.stdout.strip().splitlines() if x.strip().isdigit())
        return pids
    
    return set()

def _graceful_kill(pids: Set[int], grace: float = 2.5):
    """Send SIGTERM, wait, then SIGKILL."""
    if not pids:
        return
    me = os.getpid()
    pids = {pid for pid in pids if pid != me}
    if not pids:
        return

    for pid in pids:
        try:
            os.kill(pid, signal.SIGTERM)
            logger.info(f"SIGTERM sent to PID {pid}")
        except:
            pass

    deadline = time.time() + grace
    remaining = set(pids)
    while time.time() < deadline and remaining:
        for pid in list(remaining):
            try:
                os.kill(pid, 0)
            except:
                remaining.discard(pid)
        time.sleep(0.2)

    for pid in list(remaining):
        try:
            os.kill(pid, signal.SIGKILL)
            logger.info(f"SIGKILL sent to PID {pid}")
        except:
            pass

def clean_ports(ports: List[int], only_mine: bool = True):
    """Kill processes listening on the given ports."""
    all_pids = set()
    for port in ports:
        try:
            pids = _pids_listening_on_port(port, only_mine=only_mine)
            if pids:
                logger.info(f"Port {port} in use by PIDs {sorted(pids)}")
                all_pids.update(pids)
        except:
            pass

    if all_pids:
        _graceful_kill(all_pids)
    else:
        logger.info("No listeners found on target ports.")

def start_servers(num_services: int = 2, root_dir: str = "/mnt3/Documents/AD_Framework/bench2drive-gymnasium/bench2drive_microservices"):
    logger.info(f"Starting {num_services} CARLA services...")

    # Clean named processes
    kill_names = ["carla_server.py", "microservice_manager.py", "CarlaUE4", "server_manager.py"]
    for name in kill_names:
        subprocess.run(["pkill", "-f", name], capture_output=True)
        time.sleep(0.2)
        subprocess.run(["pkill", "-9", "-f", name], capture_output=True)

    # Clean ports
    api_ports = list(range(8080, 8084))
    carla_ports = list(range(2000, 2013))
    tm_ports = list(range(3000, 3013))
    clean_ports(api_ports + carla_ports + tm_ports, only_mine=True)

    time.sleep(1.0)

    # Start microservice manager
    cmd = [
        "python",
        os.path.join(root_dir, "server", "microservice_manager.py"),
        "--num-services", str(num_services),
        "--startup-delay", "30",
    ]
    logger.info(f"Launching: {' '.join(cmd)}")
    subprocess.Popen(cmd)

print("🚀 Starting CARLA services...")
start_servers(2)
print("✅ Services starting... Wait 60 seconds for full initialization")

INFO:__main__:Starting 2 CARLA services...


🚀 Starting CARLA services...


INFO:__main__:No listeners found on target ports.
INFO:__main__:Launching: python /mnt3/Documents/AD_Framework/bench2drive-gymnasium/bench2drive_microservices/server/microservice_manager.py --num-services 2 --startup-delay 30


✅ Services starting... Wait 60 seconds for full initialization


2025-09-30 13:45:28,550 [MicroserviceManager] INFO: Starting 2 services in parallel...
2025-09-30 13:45:28,550 [resource_manager] INFO: Allocated resources for service 0:
2025-09-30 13:45:28,550 [resource_manager] INFO:   GPU: 0
2025-09-30 13:45:28,550 [resource_manager] INFO:   API Port: 8080
2025-09-30 13:45:28,550 [resource_manager] INFO:   CARLA Port: 2000
2025-09-30 13:45:28,550 [resource_manager] INFO:   Streaming Port: 3000
2025-09-30 13:45:28,550 [resource_manager] INFO:   TM Port: 3000
2025-09-30 13:45:28,550 [MicroserviceManager] INFO: [Resource Manager] Allocated resources for service 0
2025-09-30 13:45:28,550 [Service-0] INFO: Starting service 0
2025-09-30 13:45:28,550 [Service-0] INFO: Cleaning up ports for service 0
2025-09-30 13:45:29,552 [Service-0] INFO: Starting carla_server.py:
2025-09-30 13:45:29,552 [Service-0] INFO:   API port: 8080
2025-09-30 13:45:29,552 [Service-0] INFO:   CARLA port: 2000
2025-09-30 13:45:29,552 [Service-0] INFO:   GPU: 0
2025-09-30 13:45:29,5

## Service Setup

⚠️ **IMPORTANT**: For GRPO branching functionality, you need multiple CARLA services running!

The notebook will automatically start 2 CARLA microservices for branching.

**Note**: Branching requires at least 2 services to work properly!

In [2]:
# Wait for services to be ready
import requests
import time

print("⏳ Waiting for services to be ready...")

def wait_for_services(base_port=8080, num_services=2, timeout=120):
    """Wait for all services to be ready."""
    start_time = time.time()
    
    while time.time() - start_time < timeout:
        all_ready = True
        status = []
        
        for i in range(num_services):
            port = base_port + i
            url = f"http://localhost:{port}/health"
            
            try:
                response = requests.get(url, timeout=5)
                if response.status_code == 200:
                    data = response.json()
                    status.append(f"✓ Service {i}: {data['status']}")
                else:
                    status.append(f"✗ Service {i}: HTTP {response.status_code}")
                    all_ready = False
            except Exception as e:
                status.append(f"✗ Service {i}: Error")
                all_ready = False
        
        print(f"\r[{int(time.time() - start_time)}s] {' | '.join(status)}", end="")
        
        if all_ready:
            print("\n🎉 All services ready!")
            return True
        
        time.sleep(5)
    
    print(f"\n❌ Timeout: Services not ready after {timeout} seconds")
    return False

# Wait for services
if wait_for_services():
    print("✅ Proceeding with notebook...")
else:
    print("⚠️  Services not ready. Some features may not work.")

⏳ Waiting for services to be ready...
[15s] ✓ Service 0: healthy | ✗ Service 1: Error

2025-09-30 13:46:01,607 [resource_manager] INFO: Allocated resources for service 1:
2025-09-30 13:46:01,607 [resource_manager] INFO:   GPU: 1
2025-09-30 13:46:01,607 [resource_manager] INFO:   API Port: 8081
2025-09-30 13:46:01,607 [resource_manager] INFO:   CARLA Port: 2004
2025-09-30 13:46:01,607 [resource_manager] INFO:   Streaming Port: 3010
2025-09-30 13:46:01,607 [resource_manager] INFO:   TM Port: 3004
2025-09-30 13:46:01,608 [MicroserviceManager] INFO: [Resource Manager] Allocated resources for service 1
2025-09-30 13:46:01,608 [Service-1] INFO: Starting service 1
2025-09-30 13:46:01,608 [Service-1] INFO: Cleaning up ports for service 1
2025-09-30 13:46:02,610 [Service-1] INFO: Starting carla_server.py:
2025-09-30 13:46:02,610 [Service-1] INFO:   API port: 8081
2025-09-30 13:46:02,610 [Service-1] INFO:   CARLA port: 2004
2025-09-30 13:46:02,610 [Service-1] INFO:   GPU: 1
2025-09-30 13:46:02,616 [Service-1] INFO: Server process started with PID 272545


[20s] ✓ Service 0: healthy | ✓ Service 1: healthy
🎉 All services ready!
✅ Proceeding with notebook...


In [3]:
# Import required libraries
import sys
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import logging

# Add client path to Python path
client_path = str(Path.cwd().parent / "client")
if client_path not in sys.path:
    sys.path.insert(0, client_path)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import GRPO Carla environment
try:
    from grpo_carla_env import GRPOCarlaEnv
    print("✅ Successfully imported GRPOCarlaEnv")
except ImportError as e:
    print(f"❌ Import error: {e}")
    raise

✅ Successfully imported GRPOCarlaEnv


In [4]:
# Configuration
BASE_API_PORT = 8080
TIMEOUT = 60.0

# Create GRPO environment
print("🔧 Setting up GRPO Environment...")

env = GRPOCarlaEnv(
    num_services=2,
    base_api_port=BASE_API_PORT,
    render_mode="rgb_array",
    max_steps=100,
    timeout=TIMEOUT
)

# Pre-initialize all services for fast branching
print("🌿 Pre-initializing all services for fast branching...")
init_status = env.initialize_all_services(route_id=0)

if init_status.ready:
    print("✓ All services pre-initialized successfully!")
else:
    print(f"⚠ Service pre-initialization issues: {init_status.message}")
    print("   Continuing anyway...")

# Helper function to create actions
def create_action(throttle=0.0, brake=0.0, steer=0.0):
    """Create a valid action vector."""
    action = np.array([throttle, brake, steer], dtype=np.float32)
    return np.clip(action, env.action_space.low, env.action_space.high)

print(f"\n✅ GRPO Environment created!")
print(f"📊 Configuration:")
print(f"   Max branches: {env.max_branches}")
print(f"   Service URLs: {env.service_urls}")
print(f"   Current mode: {env.current_mode}")
print(f"   Is branching: {env.is_branching}")

if len(env.service_urls) < 2:
    print("\n⚠️  WARNING: Only 1 service available!")
    print("   Make sure both services are running")
else:
    print(f"\n🎉 Perfect! Both services ready for branching!")

print(f"Environments: {[env is not None for env in env.envs]}")

  gym.logger.warn(
  gym.logger.warn(
INFO:carla_env:Connected to CARLA server at http://localhost:8080
INFO:grpo_carla_env:Created environment 0 at http://localhost:8080
INFO:grpo_carla_env:Pre-initializing 2 services for fast branching...
INFO:carla_env:Connected to CARLA server at http://localhost:8081
INFO:grpo_carla_env:Created environment 1 at http://localhost:8081


🔧 Setting up GRPO Environment...
🌿 Pre-initializing all services for fast branching...


2025-09-30 13:47:04,691 [Service-1] ERROR: Server not ready after 60 seconds
2025-09-30 13:47:04,692 [Service-1] ERROR: Server did not become ready
2025-09-30 13:47:04,692 [Service-1] INFO: Stopping service 1
2025-09-30 13:47:14,692 [Service-1] INFO: Cleaning up ports for service 1
2025-09-30 13:47:14,692 [MicroserviceManager] INFO: Port 8081 is in use, cleaning up...
2025-09-30 13:47:16,850 [MicroserviceManager] INFO: Port 2004 is in use, cleaning up...
2025-09-30 13:47:16,993 [MicroserviceManager] INFO: Killing process 272634 using port 2004
2025-09-30 13:47:20,996 [MicroserviceManager] INFO: Port 2004 successfully freed
2025-09-30 13:47:20,996 [Service-1] INFO: Service stopped and ports cleaned
2025-09-30 13:47:20,996 [MicroserviceManager] ERROR: Failed to spawn service 1
2025-09-30 13:47:20,996 [MicroserviceManager] INFO: Manager started with 1 services
2025-09-30 13:47:20,996 [MicroserviceManager] INFO: Services running. Press Ctrl+C to stop.
ERROR:carla_env:Failed to reset enviro

⚠ Service pre-initialization issues: Pre-initialized 1/2 services
   Continuing anyway...

✅ GRPO Environment created!
📊 Configuration:
   Max branches: 2
   Service URLs: ['http://localhost:8080', 'http://localhost:8081']
   Current mode: single
   Is branching: False

🎉 Perfect! Both services ready for branching!
Environments: [True, True]


## 1. Initial Exploration in Single Mode

First, let's explore the environment to reach an interesting decision point.

In [None]:
# Initial exploration
print("🚗 Starting initial exploration...")

obs, info = env.reset(options={"route_id": 0})
print(f"Reset complete. Mode: {env.current_mode}, is_branching: {env.is_branching}")

trajectory = []
total_reward = 0.0

# Explore for 20 steps to reach a decision point
for step in range(20):
    # Vary actions to create an interesting scenario
    if step < 10:
        # First, move forward
        action = create_action(throttle=1.0, brake=0.0, steer=0.0)
    else:
        # Then add some steering variation
        action = create_action(throttle=1.0, brake=0.0, steer=0.3 * np.sin(step * 0.5))
    
    obs, reward, terminated, truncated, info = env.single_step(action)
    total_reward += reward
    
    trajectory.append({
        'step': step,
        'position': obs['vehicle_state']['position'],
        'speed': obs['vehicle_state']['speed'][0],
        'action': action.tolist(),
        'reward': reward,
        'image': obs['center_image'].copy()
    })
    
    if step % 5 == 0:
        print(f"Step {step:2d}: pos=({obs['vehicle_state']['position'][0]:6.1f}, {obs['vehicle_state']['position'][1]:6.1f}), "
              f"speed={obs['vehicle_state']['speed'][0]:4.1f} m/s, reward={reward:.3f}")
    
    if terminated or truncated:
        print(f"🏁 Episode ended at step {step}")
        break

print(f"\n📊 Exploration completed:")
print(f"   Total steps: {len(trajectory)}")
print(f"   Total reward: {total_reward:.3f}")
print(f"   Final speed: {trajectory[-1]['speed']:.2f} m/s")

# Show final state before snapshot
plt.figure(figsize=(12, 8))
plt.imshow(trajectory[-1]['image'])
plt.title(f"State Before Snapshot - Step {len(trajectory)-1}")
plt.axis('off')
plt.show()

🚗 Starting initial exploration...


2025-09-30 13:48:31,014 [MicroserviceManager] INFO: Status: {'total': 1, 'healthy': 0, 'unhealthy': 1, 'dead': 0}
2025-09-30 13:49:41,086 [MicroserviceManager] INFO: Status: {'total': 1, 'healthy': 0, 'unhealthy': 1, 'dead': 0}


## 2. Snapshot System

Now let's explore the snapshot functionality - this is the core of GRPO's multi-turn rollout capability.

In [None]:
# Save current state for branching
print("💾 Saving snapshot for branching...")

try:
    snapshot_id = env.save_snapshot()
    print(f"✅ Snapshot saved successfully!")
    print(f"   Snapshot ID: {snapshot_id}")
    print(f"   Saved at step: {env.episode_steps}")
    print(f"   Current position: {trajectory[-1]['position']}")
    print(f"   Current speed: {trajectory[-1]['speed']:.2f} m/s")
    print(f"   Current mode: {env.current_mode}")
    print(f"   Is branching: {env.is_branching}")
except Exception as e:
    print(f"❌ Failed to save snapshot: {e}")
    print("💡 Make sure you're in single mode before saving snapshot")
    raise

In [None]:
# Demonstrate snapshot information
print("📋 Snapshot System Information")
print("=" * 40)
print()
print("✅ What is a Snapshot?")
print("   A snapshot captures the complete state of the CARLA environment:")
print("   • Vehicle position and orientation")
print("   • Vehicle velocity and speed")
print("   • Traffic light states")
print("   • Weather and time of day")
print("   • Scenario progress and triggers")
print("   • Agent internal state")
print()
print("✅ Why Use Snapshots?")
print("   • Multi-turn rollouts: Explore from same state multiple times")
print("   • Decision points: Save at critical moments for branching")
print("   • Reproducibility: Replay exact scenarios")
print("   • GRPO optimization: Compare different action sequences")
print()
print(f"✅ Current Snapshot Status:")
print(f"   Snapshot ID: {snapshot_id}")
print(f"   Branch start step: {env.branch_start_step}")
print(f"   Can branch: {env.current_snapshot is not None}")
print(f"   Mode: {env.current_mode}")

## 3. Simple Branching Demo

Now let's demonstrate the branching functionality with multiple CARLA instances running in parallel.

In [None]:
# Simple branching demo
print("\n🌿 Starting simple branching demonstration...")

try:
    # Enable branching with 2 branches
    print("🌿 Enabling branching mode...")
    status = env.enable_branching(snapshot_id, num_branches=2, async_setup=False)
    
    if status.status.value == "branching_ready":
        print("✅ Branching enabled successfully!")
        print(f"   Active branches: {env.active_branches}")
        print(f"   Mode: {env.current_mode}")
        
        # Run 15 steps of branching
        print("\n🧪 Running 15 steps with 2 branches...")
        print("   Branch 0: Straight driving")
        print("   Branch 1: Right turns")
        
        total_rewards = [0, 0]
        branch_images = []
        
        for step in range(15):
            # Create different actions for each branch
            actions = [
                np.array([1.0, 0.0, 0.0], dtype=np.float32),   # Branch 0: Straight
                np.array([1.0, 0.0, 0.5], dtype=np.float32)    # Branch 1: Right turn
            ]
            
            # Execute branch step
            observations, rewards, terminateds, truncateds, infos = env.branch_step(actions)
            
            # Accumulate rewards
            for i in range(2):
                total_rewards[i] += rewards[i]
            
            # Store images for final display
            branch_images.append([obs['center_image'].copy() for obs in observations])
            
            # Show progress every 5 steps
            if (step + 1) % 5 == 0:
                print(f"   Step {step + 1}: Branch 0 = {total_rewards[0]:.2f}, Branch 1 = {total_rewards[1]:.2f}")
            
            # Check if any branch terminated
            if any(terminateds) or any(truncateds):
                print(f"   Episode ended at step {step + 1}")
                break
        
        # Show final images from both branches
        print("\n📸 Final Branch States:")
        fig, axes = plt.subplots(1, 2, figsize=(16, 6))
        
        for i in range(2):
            axes[i].imshow(branch_images[-1][i])
            axes[i].set_title(f"Branch {i} - Final State (Reward: {total_rewards[i]:.3f})")
            axes[i].axis('off')
        
        plt.tight_layout()
        plt.show()
        
        # Show results but don't select yet
        print("\n🏆 Branching Results:")
        print(f"   Branch 0 total reward: {total_rewards[0]:.3f}")
        print(f"   Branch 1 total reward: {total_rewards[1]:.3f}")
        print("   Ready for branch selection...")
        
    else:
        print(f"❌ Branching failed: {status.message}")
        
except Exception as e:
    print(f"❌ Error: {e}")
    raise

## 4. Branch Selection

Now let's select the best branch and continue from there.

In [None]:
# Branch selection
print("\n🏆 Selecting best branch...")

# Select best branch based on rewards
best_branch = 0 if total_rewards[0] > total_rewards[1] else 1
print(f"   Branch 0 reward: {total_rewards[0]:.3f}")
print(f"   Branch 1 reward: {total_rewards[1]:.3f}")
print(f"   Best branch: {best_branch}")

# Select the branch
env.select_branch(best_branch)
print(f"✅ Branch {best_branch} selected - continuing in single mode")

print(f"   Current mode: {env.current_mode}")
print(f"   Is branching: {env.is_branching}")
print(f"   Active branches: {env.active_branches}")

## 5. Continue from Selected Branch

Let's continue from the selected branch to prove that the selection works correctly.

In [None]:
# Continue from selected branch
print(f"\n🚗 Continuing from Branch {best_branch} for 20 steps to prove selection works...")

# Continue for 20 more steps to prove the selection works
continuation_reward = 0
continuation_trajectory = []

for step in range(20):
    # Simple forward driving
    action = create_action(throttle=1.0, brake=0.0, steer=0.0)
    
    obs, reward, terminated, truncated, info = env.single_step(action)
    continuation_reward += reward
    
    continuation_trajectory.append({
        'step': step,
        'position': obs['vehicle_state']['position'],
        'speed': obs['vehicle_state']['speed'][0],
        'reward': reward,
        'image': obs['center_image'].copy()
    })
    
    if step % 5 == 0:
        pos = obs['vehicle_state']['position']
        speed = obs['vehicle_state']['speed'][0]
        print(f"   Step {step}: pos=({pos[0]:.1f}, {pos[1]:.1f}), speed={speed:.1f} m/s")
    
    if terminated or truncated:
        print(f"   Episode ended at step {step}")
        break

print(f"\n📊 Continuation completed:")
print(f"   Total steps: {len(continuation_trajectory)}")
print(f"   Total reward: {continuation_reward:.3f}")

print(f"\n✅ PROOF: Successfully continued for {len(continuation_trajectory)} steps from selected branch {best_branch}!")
print(f"   This proves the branch selection and continuation works correctly!")

# Show final state
plt.figure(figsize=(10, 6))
plt.imshow(obs['center_image'])
plt.title(f"Final State - Branch {best_branch} Selected (Proof of Continuation)")
plt.axis('off')
plt.show()

## 6. Summary

Let's summarize what we've accomplished:

In [None]:
# Summary
print("🎉 GRPO Tutorial Summary")
print("=" * 40)
print()
print("✅ What We Demonstrated:")
print(f"   1. Service Setup: Started {len(env.service_urls)} CARLA services")
print(f"   2. Initial Exploration: Explored for {len(trajectory)} steps to decision point")
print(f"   3. Snapshot System: Saved snapshot at step {env.branch_start_step}")
print(f"   4. Branching Mode: Ran {env.active_branches} parallel branches for 15 steps")
print(f"   5. Branch Selection: Selected branch {best_branch} with reward {max(total_rewards):.3f}")
print(f"   6. Continuation: Continued for {len(continuation_trajectory)} steps from selected branch")
print()
print("✅ Key GRPO Features:")
print("   • Snapshot/Restore: Save complete environment state")
print("   • Parallel Branching: Multiple CARLA instances")
print("   • Trajectory Collection: Collect multiple rollouts")
print("   • Best Branch Selection: Choose optimal trajectory")
print("   • Seamless Continuation: Continue from selected branch")
print()
print("✅ Technical Implementation:")
print("   • Microservices Architecture: Independent CARLA instances")
print("   • REST API Communication: HTTP-based service interaction")
print("   • Dynamic Port Allocation: Automatic port management")
print("   • GPU Resource Management: Multi-GPU support")
print("   • Error Handling: Robust service recovery")
print()
print("🚀 Ready for GRPO Training!")
print("   This setup provides all the necessary components for:")
print("   • Multi-turn rollout collection")
print("   • Trajectory optimization")
print("   • Policy gradient updates")
print("   • Large-scale RL training")

In [None]:
# Clean up
print("\n🧹 Cleaning up...")
env.close()
print("✅ Tutorial completed successfully!")

# Show final statistics
print("\n📊 Final Statistics:")
print(f"   Total exploration steps: {len(trajectory)}")
print(f"   Total branching steps: {len(branch_images) * env.active_branches}")
print(f"   Total continuation steps: {len(continuation_trajectory)}")
print(f"   Best branch reward: {max(total_rewards):.3f}")
print(f"   Total combined reward: {total_reward + continuation_reward:.3f}")