# 🚁 PPO Training Notebook - CrazyFlie RL Environment

Welcome to the interactive PPO training system for autonomous drone navigation! This notebook provides:

- **Live Training Visualization**: Watch the drone learn in the MuJoCo viewer
- **Real-time Camera View**: Drone's FPV camera in a separate window  
- **Training Metrics**: Live plots of rewards, episode length, and success rate
- **Interactive Controls**: Pause, resume, and adjust training parameters
- **Camera Window Controls**: Move and resize the drone camera window

## 📋 Features Overview

1. **Real-time 3D Environment Visualization** - Watch the drone navigate in MuJoCo
2. **First-Person View Camera** - See what the drone sees during training
3. **Live Performance Metrics** - Track learning progress with interactive plots
4. **Dynamic Parameter Tuning** - Adjust hyperparameters during training
5. **Model Management** - Save/load checkpoints and export trained models

---

## 1. Setup and Import Libraries

Import all necessary components for PPO training with live visualization.

In [3]:
# Core imports
import os
import sys
import time
import threading
import numpy as np
from datetime import datetime
from typing import Dict, Any, Optional, Tuple
from collections import deque
import json

# Environment and RL imports
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CheckpointCallback
from stable_baselines3.common.monitor import Monitor

# Visualization imports
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.patches import Circle
import seaborn as sns

# Interactive widgets
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import tkinter as tk
from tkinter import ttk

# Custom environment and training modules
from crazy_flie_env import CrazyFlieEnv, TrainingConfig, TestingConfig
from train import TrainingManager, AlgorithmFactory, LiveVisualizationCallback
from train.callbacks import PerformanceMonitorCallback, SafetyMonitorCallback
from train.config import create_default_configs
from train.utils import print_system_info, estimate_training_time

# Configure matplotlib for inline and interactive plotting
%matplotlib widget
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Setup notebook display
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

print("✅ All libraries imported successfully!")
print("🔧 Setting up interactive training environment...")

# Enable GPU if available
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥️  Using device: {device}")

✅ Registered algorithm: A2C
✅ All required dependencies found!


  from IPython.core.display import display, HTML


✅ All libraries imported successfully!
🔧 Setting up interactive training environment...
🖥️  Using device: cpu


## 2. Load Environment and Training Configuration

Configure the CrazyFlie environment and training parameters.

In [5]:
# Create training configuration
# NOTE(review): the recorded output of this cell is
#   TypeError: TrainingConfig.__init__() got an unexpected keyword argument 'total_timesteps'
# so the keyword names below do not match the TrainingConfig imported from
# crazy_flie_env. Confirm the real constructor signature before relying on
# this cell; every later cell that reads `training_config` depends on it.
training_config = TrainingConfig(
    total_timesteps=100000,
    learning_rate=3e-4,
    batch_size=64,
    n_steps=2048,
    n_epochs=10,
    device=str(device),
    seed=42
)

# Additional training settings
# These are set as plain attributes after construction.
# TODO confirm TrainingConfig actually reads them (not visible from here).
training_config.num_envs = 1  # Single environment for better visualization
training_config.enable_live_training = True
training_config.render_freq = 10

# Environment configuration with vision enabled
# This dict is passed as-is to CrazyFlieEnv(...) by the environment factory in
# the next cell; the accepted keys are defined by that class, not here.
env_config = {
    'dt': 0.02,  # 50Hz control frequency
    'max_episode_steps': 1000,
    'initial_height': 0.5,
    'target_height': 1.5,
    'enable_vision': True,
    'image_size': (64, 64),
    'camera_fps': 30,
    'enable_rendering': True,
    'room_size': [8, 8, 3],
    'spawn_radius': 1.0
}

# Echo the effective settings so the run is self-describing in the notebook.
# NOTE(review): `training_config.algorithm` is never assigned in this cell —
# presumably a TrainingConfig default; verify it exists.
print("🎯 Training Configuration:")
print(f"   Algorithm: {training_config.algorithm}")
print(f"   Total Timesteps: {training_config.total_timesteps:,}")
print(f"   Learning Rate: {training_config.learning_rate}")
print(f"   Batch Size: {training_config.batch_size}")
print(f"   Device: {training_config.device}")

print("\n🌍 Environment Configuration:")
print(f"   Control Frequency: {1/env_config['dt']:.0f}Hz")
print(f"   Max Episode Steps: {env_config['max_episode_steps']}")
print(f"   Vision Enabled: {env_config['enable_vision']}")
print(f"   Image Size: {env_config['image_size']}")

# Display system information
print("\n💻 System Information:")
print_system_info()

TypeError: TrainingConfig.__init__() got an unexpected keyword argument 'total_timesteps'

## 3. Initialize PPO Algorithm and Networks

In [None]:
# Create environment with monitoring
def create_monitored_env():
    """Build one CrazyFlie environment wrapped in a ``Monitor``.

    Used as the factory handed to ``DummyVecEnv``. The environment is
    configured from the module-level ``env_config`` dict.

    Returns:
        The ``Monitor``-wrapped ``CrazyFlieEnv`` instance.
    """
    # filename=None: monitor in memory only, no log file, for cleaner output.
    return Monitor(CrazyFlieEnv(env_config), filename=None)

# Initialize training environment
# DummyVecEnv receives a list of factories; one factory -> one environment.
print("🏗️  Creating training environment...")
train_env = DummyVecEnv([create_monitored_env])

# Create evaluation environment
# Built from the same factory, so it shares env_config with the training env.
# NOTE(review): env_config enables rendering, so this presumably creates a
# second rendering environment — confirm that is intended.
print("🧪 Creating evaluation environment...")
eval_env = DummyVecEnv([create_monitored_env])

# Initialize PPO model with custom policy
print("🤖 Initializing PPO algorithm...")

# PPO configuration optimized for drone control
# Learning rate, rollout length, batch size and epoch count come from
# training_config; the remaining values are fixed in this cell.
ppo_config = {
    "policy": "MultiInputPolicy",  # For dict observation space (state + image)
    "env": train_env,
    "learning_rate": training_config.learning_rate,
    "n_steps": training_config.n_steps,
    "batch_size": training_config.batch_size,
    "n_epochs": training_config.n_epochs,
    "gamma": 0.99,
    "gae_lambda": 0.95,
    "clip_range": 0.2,
    "clip_range_vf": None,
    "normalize_advantage": True,
    "ent_coef": 0.01,
    "vf_coef": 0.5,
    "max_grad_norm": 0.5,
    "target_kl": 0.01,
    "tensorboard_log": "./logs/",
    "device": device,
    "verbose": 1
}

# Create PPO model
# The dict is expanded into keyword arguments of the PPO constructor.
model = PPO(**ppo_config)

print("✅ PPO algorithm initialized!")
print(f"📊 Policy Architecture: {model.policy}")
print(f"🧠 Network Device: {model.device}")
print(f"📈 Total Parameters: {sum(p.numel() for p in model.policy.parameters()):,}")

# Display model architecture summary
# NOTE(review): the lines below are hard-coded descriptions, not values read
# from the model or environment — verify they match the actual observation
# and action spaces.
print("\n🏗️  Model Architecture:")
print("   - State Input: 12D vector (position, velocity, orientation, angular velocity)")
print("   - Image Input: 64x64x3 RGB (drone FPV camera)")
print("   - Action Output: 4D continuous (roll, pitch, yaw_rate, thrust)")
print("   - Policy Network: CNN + MLP feature extractor")
print("   - Value Network: Shared feature extractor with value head")

## 4. Create Interactive Training Controls

Build interactive widgets for controlling training parameters in real-time.

In [None]:
# Training control state.
# Each flag lives in a one-key dict so the button handlers and the training
# callback can mutate the same object without `global` statements.
training_active = {"value": False}
training_paused = {"value": False}
training_thread = {"value": None}

# Shared widget style (description column width) used by the controls below
style = {'description_width': '150px'}

# Training control buttons.
# Pause and Stop start out disabled: nothing is running yet, and this matches
# the idle state that on_stop_clicked (and the end of a training run) restores.
# on_start_clicked enables them again.
start_button = widgets.Button(description="▶️ Start Training", button_style='success', layout=widgets.Layout(width='150px'))
pause_button = widgets.Button(description="⏸️ Pause", button_style='warning', layout=widgets.Layout(width='150px'), disabled=True)
stop_button = widgets.Button(description="⏹️ Stop", button_style='danger', layout=widgets.Layout(width='150px'), disabled=True)
reset_button = widgets.Button(description="🔄 Reset", button_style='info', layout=widgets.Layout(width='150px'))

# --- Hyperparameter sliders -------------------------------------------------
# Every slider gets its own 400px Layout instance and the shared `style`.

# Log-scale slider: min/max are base-10 exponents (-5 .. -2).
learning_rate_slider = widgets.FloatLogSlider(
    value=training_config.learning_rate, base=10, min=-5, max=-2, step=0.1,
    description='Learning Rate:', style=style, layout=widgets.Layout(width='400px'),
)

# Mini-batch size, in steps of 16.
batch_size_slider = widgets.IntSlider(
    value=training_config.batch_size, min=16, max=512, step=16,
    description='Batch Size:', style=style, layout=widgets.Layout(width='400px'),
)

# Entropy bonus coefficient.
entropy_coef_slider = widgets.FloatSlider(
    value=0.01, min=0.0, max=0.1, step=0.005,
    description='Entropy Coef:', style=style, layout=widgets.Layout(width='400px'),
)

# PPO clipping range.
clip_range_slider = widgets.FloatSlider(
    value=0.2, min=0.1, max=0.5, step=0.05,
    description='Clip Range:', style=style, layout=widgets.Layout(width='400px'),
)

# --- Progress / status read-outs --------------------------------------------

# Bar spans 0 .. total_timesteps; the training callback writes the step count.
progress_bar = widgets.IntProgress(
    value=0, min=0, max=training_config.total_timesteps,
    description='Progress:', bar_style='info',
    style=style, layout=widgets.Layout(width='400px'),
)

# One-line status banner, rewritten by the button handlers.
status_text = widgets.HTML(value="<b>Status:</b> Ready to start training")

# Latest episode metrics, rewritten by the training callback.
metrics_text = widgets.HTML(value="<b>Current Metrics:</b><br/>Episode: 0<br/>Reward: 0.0<br/>Length: 0")

# Button event handlers
def on_start_clicked(b):
    """Switch the UI and the shared flags into the "training running" state.

    Args:
        b: The clicked button (unused).
    """
    training_active["value"], training_paused["value"] = True, False

    # Only one run at a time: lock Start, unlock Pause and Stop.
    start_button.disabled = True
    for running_control in (pause_button, stop_button):
        running_control.disabled = False

    status_text.value = "<b>Status:</b> <span style='color:green'>Training Active</span>"

def on_pause_clicked(b):
    """Toggle between paused and running, updating button label and status.

    Args:
        b: The clicked button (unused).
    """
    now_paused = not training_paused["value"]
    training_paused["value"] = now_paused

    if now_paused:
        # The same button becomes the "resume" control while paused.
        pause_button.description = "▶️ Resume"
        status_text.value = "<b>Status:</b> <span style='color:orange'>Training Paused</span>"
        return

    pause_button.description = "⏸️ Pause"
    status_text.value = "<b>Status:</b> <span style='color:green'>Training Resumed</span>"

def on_stop_clicked(b):
    """Request a stop and put every control back into its idle state.

    Args:
        b: The clicked button (unused).
    """
    # Clearing both flags lets a paused run fall out of its wait loop as well.
    for flag in (training_active, training_paused):
        flag["value"] = False

    start_button.disabled = False
    pause_button.disabled = True
    stop_button.disabled = True

    # Undo a possible "Resume" label left over from a pause.
    pause_button.description = "⏸️ Pause"
    status_text.value = "<b>Status:</b> <span style='color:red'>Training Stopped</span>"

def on_reset_clicked(b):
    """Clear the progress bar and metric read-out and show a reset status.

    Only the display widgets are touched; the model, the environment and the
    training flags are left as they are.

    Args:
        b: The clicked button (unused).
    """
    initial_metrics = "<b>Current Metrics:</b><br/>Episode: 0<br/>Reward: 0.0<br/>Length: 0"
    progress_bar.value = 0
    metrics_text.value = initial_metrics
    status_text.value = "<b>Status:</b> Environment Reset"

# Connect button events
# Each call registers one more click callback on the button.
# NOTE(review): on_click adds to the button's handler list rather than
# replacing it, so a later on_click(...) on the same button runs in addition
# to the handler registered here — confirm against the ipywidgets docs.
start_button.on_click(on_start_clicked)
pause_button.on_click(on_pause_clicked)
stop_button.on_click(on_stop_clicked)
reset_button.on_click(on_reset_clicked)

# Parameter update handlers
def update_learning_rate(change):
    """Apply a new learning rate from the slider to the live model.

    Stable-Baselines3 turns ``learning_rate`` into a schedule callable
    (``lr_schedule``) when the model is built and reads only that callable
    when it refreshes the optimizer before each update. Assigning
    ``model.learning_rate`` alone therefore never reaches the optimizer, so
    the schedule is replaced with a constant one as well.

    Args:
        change: Traitlets change dict; ``change['new']`` is the slider value.
    """
    if not hasattr(model, 'learning_rate'):
        return

    new_lr = float(change['new'])
    model.learning_rate = new_lr
    # Constant schedule: ignores the "progress remaining" argument SB3 passes.
    model.lr_schedule = lambda _progress_remaining: new_lr
    print(f"Updated learning rate to: {new_lr:.2e}")

def update_parameters():
    """Push the entropy-coefficient and clip-range slider values into the model.

    ``ent_coef`` is used by PPO as a plain number. ``clip_range`` is different:
    PPO converts it to a schedule callable at construction and *calls* it on
    every update, so writing a bare float would make the next update fail with
    "'float' object is not callable". It is therefore stored as a constant
    schedule.
    """
    if hasattr(model, 'ent_coef'):
        model.ent_coef = float(entropy_coef_slider.value)

    if hasattr(model, 'clip_range'):
        clip_value = float(clip_range_slider.value)
        # Constant schedule: ignores the "progress remaining" argument.
        model.clip_range = lambda _progress_remaining: clip_value

# Hook the sliders up to their handlers (fired whenever `value` changes).
# The entropy and clip sliders share one handler that re-reads both values,
# so the change payload itself is ignored.
learning_rate_slider.observe(update_learning_rate, names='value')
entropy_coef_slider.observe(lambda _change: update_parameters(), names='value')
clip_range_slider.observe(lambda _change: update_parameters(), names='value')

# Assemble the three panels: button row, hyperparameter column, status column.
control_buttons = widgets.HBox(children=[start_button, pause_button, stop_button, reset_button])
parameter_controls = widgets.VBox(
    children=[learning_rate_slider, batch_size_slider, entropy_coef_slider, clip_range_slider]
)
status_panel = widgets.VBox(children=[progress_bar, status_text, metrics_text])

# Summary of what this cell created.
print("\n".join([
    "🎮 Interactive training controls created!",
    "   • Start/Pause/Stop/Reset buttons",
    "   • Real-time parameter adjustment sliders",
    "   • Training progress and metrics display",
]))

## 5. Setup Live Visualization Windows

Initialize MuJoCo viewer and drone FPV camera windows.

In [None]:
class CameraWindow:
    """Separate Tk window for the drone FPV camera, with move/resize controls.

    The window is created lazily by ``create_window``; until then ``root`` is
    ``None`` and every geometry/update method is a no-op. The ``window_*``
    attributes hold the geometry this object last requested.

    NOTE(review): Tk widgets are generally only safe to use from the thread
    that created them. Confirm which thread calls ``update_image`` during
    training.
    """

    # Default geometry (pixels); also the target of ``reset_size``.
    DEFAULT_X = 100
    DEFAULT_Y = 100
    DEFAULT_WIDTH = 400
    DEFAULT_HEIGHT = 400
    # Lower bound enforced by ``resize_window``.
    MIN_SIZE = 200
    # Pixel step used by the arrow and zoom buttons.
    MOVE_STEP = 50
    RESIZE_STEP = 50

    def __init__(self):
        self.root = None
        self.canvas = None
        self.image_label = None
        self.is_running = False
        self.current_image = None

        # Window properties
        self.window_x = self.DEFAULT_X
        self.window_y = self.DEFAULT_Y
        self.window_width = self.DEFAULT_WIDTH
        self.window_height = self.DEFAULT_HEIGHT

    def _apply_geometry(self):
        """Push the stored size/position to the Tk window (``root`` must exist)."""
        self.root.geometry(
            f"{self.window_width}x{self.window_height}+{self.window_x}+{self.window_y}"
        )

    @staticmethod
    def _to_uint8(image_array):
        """Convert a frame to ``uint8`` for PIL.

        ``uint8`` input is returned unchanged. Floating-point input is assumed
        to be normalised to [0, 1]; it is clipped to that range before scaling
        so out-of-range values saturate instead of wrapping around. Any other
        dtype is scaled by 255 and cast, as before.
        """
        frame = np.asarray(image_array)
        if frame.dtype == np.uint8:
            return frame
        if np.issubdtype(frame.dtype, np.floating):
            frame = np.clip(frame, 0.0, 1.0)
        return (frame * 255).astype(np.uint8)

    def create_window(self):
        """Create the camera window with controls (no-op if it already exists)."""
        if self.root is not None:
            return

        self.root = tk.Toplevel()
        self.root.title("🚁 Drone FPV Camera")
        self._apply_geometry()
        self.root.resizable(True, True)
        # Closing via the window manager must also reset our own state;
        # otherwise later updates would target a destroyed widget.
        self.root.protocol("WM_DELETE_WINDOW", self.close)

        # Create main frame
        main_frame = ttk.Frame(self.root)
        main_frame.pack(fill=tk.BOTH, expand=True)

        # Camera display
        self.image_label = tk.Label(main_frame, text="Camera Feed", bg="black", fg="white")
        self.image_label.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        # Control frame
        control_frame = ttk.Frame(main_frame)
        control_frame.pack(fill=tk.X, padx=5, pady=5)

        # Position controls: 3x3 pad of arrows with "home/centre" in the middle.
        pos_frame = ttk.LabelFrame(control_frame, text="Window Position")
        pos_frame.pack(fill=tk.X, pady=2)

        pad = (
            ("↖", -1, -1), ("↑", 0, -1), ("↗", 1, -1),
            ("←", -1, 0), ("⌂", 0, 0), ("→", 1, 0),
            ("↙", -1, 1), ("↓", 0, 1), ("↘", 1, 1),
        )
        for index, (glyph, sign_x, sign_y) in enumerate(pad):
            if sign_x == 0 and sign_y == 0:
                command = self.center_window
            else:
                # Bind the offsets as defaults so each button keeps its own.
                command = (
                    lambda dx=sign_x * self.MOVE_STEP, dy=sign_y * self.MOVE_STEP:
                    self.move_window(dx, dy)
                )
            ttk.Button(pos_frame, text=glyph, command=command).grid(
                row=index // 3, column=index % 3
            )

        # Size controls
        size_frame = ttk.LabelFrame(control_frame, text="Window Size")
        size_frame.pack(fill=tk.X, pady=2)

        step = self.RESIZE_STEP
        ttk.Button(size_frame, text="🔍+", command=lambda: self.resize_window(step, step)).pack(side=tk.LEFT, padx=2)
        ttk.Button(size_frame, text="🔍-", command=lambda: self.resize_window(-step, -step)).pack(side=tk.LEFT, padx=2)
        ttk.Button(size_frame, text="📐", command=self.reset_size).pack(side=tk.LEFT, padx=2)

        self.is_running = True
        print("📷 Camera window created!")

    def move_window(self, dx, dy):
        """Move window by relative offset."""
        if self.root:
            self.window_x += dx
            self.window_y += dy
            self._apply_geometry()

    def resize_window(self, dw, dh):
        """Resize window by relative amount, never below ``MIN_SIZE``."""
        if self.root:
            self.window_width = max(self.MIN_SIZE, self.window_width + dw)
            self.window_height = max(self.MIN_SIZE, self.window_height + dh)
            self._apply_geometry()

    def center_window(self):
        """Center window on screen, using the screen size reported by Tk."""
        if self.root:
            screen_width = self.root.winfo_screenwidth()
            screen_height = self.root.winfo_screenheight()
            # Clamp at 0 so an oversized window still starts on-screen.
            self.window_x = max(0, (screen_width - self.window_width) // 2)
            self.window_y = max(0, (screen_height - self.window_height) // 2)
            self._apply_geometry()

    def reset_size(self):
        """Reset window to default size."""
        if self.root:
            self.window_width = self.DEFAULT_WIDTH
            self.window_height = self.DEFAULT_HEIGHT
            self._apply_geometry()

    def update_image(self, image_array):
        """Update the camera feed with a new frame.

        Does nothing when the window is not open or ``image_array`` is None.
        Conversion or display errors are shown as text in the image area
        rather than raised.

        Args:
            image_array: Frame as a numpy array (see ``_to_uint8`` for the
                dtype handling).
        """
        if not (self.root and self.image_label) or image_array is None:
            return

        try:
            # Imported lazily so the class can be defined without Pillow.
            from PIL import Image, ImageTk

            pil_image = Image.fromarray(self._to_uint8(image_array))

            # Fit inside the window (aspect ratio kept), leaving room for controls.
            display_size = (self.window_width - 20, self.window_height - 100)
            pil_image.thumbnail(display_size, Image.Resampling.LANCZOS)

            tk_image = ImageTk.PhotoImage(pil_image)
            self.image_label.configure(image=tk_image, text="")
            # Tk does not hold a reference; without this the image is collected.
            self.image_label.image = tk_image

        except Exception as e:
            try:
                self.image_label.configure(text=f"Camera Error: {str(e)}")
            except tk.TclError:
                # The label itself is gone (window destroyed underneath us).
                self.close()

    def close(self):
        """Close the camera window and reset state (safe to call repeatedly)."""
        if self.root:
            try:
                self.root.destroy()
            except tk.TclError:
                pass  # already destroyed by the window manager
        self.root = None
        self.image_label = None
        self.is_running = False

# Single shared camera-window controller; the Tk window itself is only
# created later, when the "open" button handler calls create_window().
camera_window = CameraWindow()

# No MuJoCo viewer object is created here: per the original notes, the
# environment creates and manages its own viewer.

print("\n".join([
    "🖼️  Visualization windows initialized!",
    "   • Camera window with position/size controls",
    "   • MuJoCo viewer managed by environment",
    "   • Ready for live training visualization",
]))

## 6. Implement Real-time Metrics Dashboard

Create live plots for tracking training progress and performance metrics.

In [None]:
class MetricsDashboard:
    """Real-time training metrics dashboard with live plots.

    Keeps bounded histories (``max_history`` entries each) of per-episode and
    per-update metrics and renders them in a 2x3 matplotlib figure:

        rewards | lengths | success rate
        policy loss | value loss | learning rate
    """

    def __init__(self, max_history=1000):
        """Create empty histories and build the figure.

        Args:
            max_history: Maximum number of points kept per series; older
                points are dropped.
        """
        self.max_history = max_history

        # Data storage
        self.episode_rewards = deque(maxlen=max_history)
        self.episode_lengths = deque(maxlen=max_history)
        self.success_rates = deque(maxlen=max_history)
        self.timesteps = deque(maxlen=max_history)
        self.policy_losses = deque(maxlen=max_history)
        self.value_losses = deque(maxlen=max_history)
        self.learning_rates = deque(maxlen=max_history)

        # Current episode tracking
        self.current_episode = 0
        self.current_timestep = 0
        self.current_reward = 0.0
        self.recent_success_count = 0
        self.recent_episodes = deque(maxlen=100)  # For success rate calculation

        # Setup plots
        self.setup_plots()

    @staticmethod
    def _init_panel(ax, fmt, title, xlabel, ylabel):
        """Label one subplot and return its (initially empty) line artist."""
        line, = ax.plot([], [], fmt, linewidth=2)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)
        return line

    def setup_plots(self):
        """Initialize the matplotlib figure, axes and empty line artists."""
        self.fig, self.axes = plt.subplots(2, 3, figsize=(15, 8))
        self.fig.suptitle('🚁 PPO Training Dashboard - Live Metrics', fontsize=16)

        self.reward_line = self._init_panel(
            self.axes[0,0], 'b-', 'Episode Rewards', 'Episode', 'Reward')
        self.length_line = self._init_panel(
            self.axes[0,1], 'g-', 'Episode Lengths', 'Episode', 'Steps')
        self.success_line = self._init_panel(
            self.axes[0,2], 'r-', 'Success Rate (Last 100 Episodes)', 'Episode', 'Success Rate (%)')
        self.axes[0,2].set_ylim(0, 100)  # percentage axis stays fixed

        self.policy_loss_line = self._init_panel(
            self.axes[1,0], 'purple', 'Policy Loss', 'Update', 'Loss')
        self.value_loss_line = self._init_panel(
            self.axes[1,1], 'orange', 'Value Loss', 'Update', 'Loss')
        self.lr_line = self._init_panel(
            self.axes[1,2], 'brown', 'Learning Rate', 'Update', 'Learning Rate')
        self.axes[1,2].set_yscale('log')

        plt.tight_layout()

    def update_episode_data(self, reward, length, success=False):
        """Record one finished episode.

        Args:
            reward: Total episode reward.
            length: Episode length in steps.
            success: Whether the episode counted as a success.

        The timestep stored alongside is whatever ``current_timestep`` holds
        at call time, so callers should set it before calling.
        """
        self.current_episode += 1
        self.episode_rewards.append(reward)
        self.episode_lengths.append(length)
        self.timesteps.append(self.current_timestep)

        # Rolling success rate over the last 100 episodes. The window always
        # contains at least the episode just appended, so no division by zero.
        self.recent_episodes.append(success)
        success_rate = (sum(self.recent_episodes) / len(self.recent_episodes)) * 100
        self.success_rates.append(success_rate)

    def update_training_data(self, policy_loss=None, value_loss=None, learning_rate=None):
        """Record optimizer-level metrics; ``None`` arguments are skipped."""
        if policy_loss is not None:
            self.policy_losses.append(policy_loss)
        if value_loss is not None:
            self.value_losses.append(value_loss)
        if learning_rate is not None:
            self.learning_rates.append(learning_rate)

    def _episode_axis(self, count):
        """Return the 0-based episode indices of the last ``count`` episodes.

        Once the histories roll over (more than ``max_history`` episodes) the
        stored points no longer start at episode 0; offsetting by the running
        episode counter keeps the x-axis on true episode numbers.
        """
        first = max(0, self.current_episode - count)
        return range(first, first + count)

    @staticmethod
    def _refresh(line, ax, xs, ys):
        """Replace a line's data and rescale its axes to fit."""
        line.set_data(list(xs), list(ys))
        ax.relim()
        ax.autoscale_view()

    def update_plots(self):
        """Redraw all panels from the stored histories.

        Does nothing until at least one episode has been recorded.
        """
        if len(self.episode_rewards) == 0:
            return

        # Per-episode panels share the episode-number x-axis.
        episode_panels = (
            (self.reward_line, self.axes[0,0], self.episode_rewards),
            (self.length_line, self.axes[0,1], self.episode_lengths),
            (self.success_line, self.axes[0,2], self.success_rates),
        )
        for line, ax, series in episode_panels:
            if len(series) > 0:
                self._refresh(line, ax, self._episode_axis(len(series)), series)

        # Per-update panels are indexed by position in their own history.
        update_panels = (
            (self.policy_loss_line, self.axes[1,0], self.policy_losses),
            (self.value_loss_line, self.axes[1,1], self.value_losses),
            (self.lr_line, self.axes[1,2], self.learning_rates),
        )
        for line, ax, series in update_panels:
            if len(series) > 0:
                self._refresh(line, ax, range(len(series)), series)

        # Refresh the plot
        self.fig.canvas.draw()

    def get_current_stats(self):
        """Return a summary dict for the UI.

        Returns:
            Dict with ``episode`` (episodes recorded so far), ``avg_reward``
            and ``avg_length`` (means over the last 10 episodes) and
            ``success_rate`` (latest rolling success percentage). All zeros
            before the first episode.
        """
        if len(self.episode_rewards) == 0:
            return {"episode": 0, "avg_reward": 0.0, "avg_length": 0, "success_rate": 0.0}

        recent_rewards = list(self.episode_rewards)[-10:]  # Last 10 episodes
        recent_lengths = list(self.episode_lengths)[-10:]

        return {
            "episode": self.current_episode,
            "avg_reward": np.mean(recent_rewards),
            "avg_length": np.mean(recent_lengths),
            "success_rate": self.success_rates[-1] if self.success_rates else 0.0,
        }

# Build the dashboard; constructing it also creates the 2x3 figure.
dashboard = MetricsDashboard()

# Render the (still empty) figure in the notebook.
plt.show()

print("\n".join([
    "📊 Live metrics dashboard created!",
    "   • Episode rewards and lengths tracking",
    "   • Success rate monitoring (last 100 episodes)",
    "   • Policy and value loss visualization",
    "   • Learning rate tracking",
    "   • Real-time plot updates during training",
]))

## 7. Create Camera Window Management

Display the camera window controls and initialize the FPV camera system.

In [None]:
# Camera control widgets

# Master switch read by the training callback before pushing frames.
camera_enabled = widgets.Checkbox(
    value=True,
    description='Enable FPV Camera',
    style={'description_width': '150px'}
)

open_camera_button = widgets.Button(
    description="📷 Open Camera Window",
    button_style='info',
    layout=widgets.Layout(width='200px')
)

# Starts disabled: no window exists yet. This matches the state that
# on_close_camera_clicked restores; on_open_camera_clicked enables it.
close_camera_button = widgets.Button(
    description="❌ Close Camera Window",
    button_style='warning',
    layout=widgets.Layout(width='200px'),
    disabled=True
)

# Camera position / size sliders (pixels, steps of 50). The initial values
# mirror the CameraWindow defaults (100, 100, 400x400).
camera_x_slider = widgets.IntSlider(
    value=100,
    min=0, max=1000, step=50,
    description='X Position:',
    style={'description_width': '100px'},
    layout=widgets.Layout(width='300px')
)

camera_y_slider = widgets.IntSlider(
    value=100,
    min=0, max=800, step=50,
    description='Y Position:',
    style={'description_width': '100px'},
    layout=widgets.Layout(width='300px')
)

camera_width_slider = widgets.IntSlider(
    value=400,
    min=200, max=800, step=50,
    description='Width:',
    style={'description_width': '100px'},
    layout=widgets.Layout(width='300px')
)

camera_height_slider = widgets.IntSlider(
    value=400,
    min=200, max=600, step=50,
    description='Height:',
    style={'description_width': '100px'},
    layout=widgets.Layout(width='300px')
)

# Camera window event handlers
def on_open_camera_clicked(b):
    """Open the camera window."""
    try:
        camera_window.create_window()
        open_camera_button.disabled = True
        close_camera_button.disabled = False
        print("📷 Camera window opened!")
    except Exception as e:
        print(f"❌ Error opening camera window: {e}")

def on_close_camera_clicked(b):
    """Close the camera window."""
    try:
        camera_window.close()
        open_camera_button.disabled = False
        close_camera_button.disabled = True
        print("❌ Camera window closed!")
    except Exception as e:
        print(f"❌ Error closing camera window: {e}")

def update_camera_position(change):
    """Update camera window position."""
    if camera_window.root:
        camera_window.window_x = camera_x_slider.value
        camera_window.window_y = camera_y_slider.value
        camera_window.window_width = camera_width_slider.value
        camera_window.window_height = camera_height_slider.value
        
        geometry = f"{camera_window.window_width}x{camera_window.window_height}+{camera_window.window_x}+{camera_window.window_y}"
        camera_window.root.geometry(geometry)

# Button wiring: open / close the FPV window.
open_camera_button.on_click(on_open_camera_clicked)
close_camera_button.on_click(on_close_camera_clicked)

# All four geometry sliders feed the same observer, which re-applies the
# complete geometry each time any of them changes.
camera_x_slider.observe(update_camera_position, names='value')
camera_y_slider.observe(update_camera_position, names='value')
camera_width_slider.observe(update_camera_position, names='value')
camera_height_slider.observe(update_camera_position, names='value')

# Panel layout: checkbox on top, then the button row, then the slider column.
camera_buttons = widgets.HBox(children=[open_camera_button, close_camera_button])
camera_position_controls = widgets.VBox(children=[
    widgets.HTML("<b>Camera Window Position & Size:</b>"),
    camera_x_slider,
    camera_y_slider,
    camera_width_slider,
    camera_height_slider,
])
camera_controls = widgets.VBox(children=[camera_enabled, camera_buttons, camera_position_controls])

print("\n".join([
    "📷 Camera window management controls created!",
    "   • Enable/disable FPV camera feed",
    "   • Open/close camera window",
    "   • Real-time position and size adjustment",
    "   • Integrated with training loop",
]))

## 8. Main Training Loop with Visualization

Execute the PPO training with integrated live visualization and metrics collection.

In [None]:
class InteractiveTrainingCallback(BaseCallback):
    """Custom callback for interactive training with live visualization.

    On every environment step it honours the notebook's stop/pause flags,
    advances the progress bar, records finished episodes on the dashboard and
    forwards the latest camera frame to the FPV window.

    It reads these notebook-level objects: ``training_active``,
    ``training_paused``, ``progress_bar``, ``metrics_text``, ``status_text``
    and ``camera_enabled``.
    """

    def __init__(self, dashboard, camera_window, update_freq=10):
        """
        Args:
            dashboard: Metrics dashboard receiving episode data.
            camera_window: FPV window receiving camera frames.
            update_freq: Redraw the plots every this many finished episodes.
        """
        super().__init__()
        self.dashboard = dashboard
        self.camera_window = camera_window
        self.update_freq = update_freq
        self.episode_count = 0
        self.episode_reward = 0
        self.episode_length = 0
        # Camera failures are reported once instead of on every step.
        self._camera_error_reported = False

    def _on_step(self) -> bool:
        """Called at each environment step.

        Returns:
            False to stop training (Stop was pressed), True to continue.
        """
        # Check for pause/stop commands
        if not training_active["value"]:
            return False  # Stop training

        while training_paused["value"]:
            time.sleep(0.1)  # Wait while paused
            if not training_active["value"]:
                return False

        # Update progress
        progress_bar.value = self.num_timesteps

        infos = self.locals.get('infos', [])
        if len(infos) == 0:
            return True

        # Only the first environment is tracked.
        if self.locals.get('dones', [False])[0]:
            self._record_episode(infos[0])

        self._update_camera()
        return True  # Continue training

    def _record_episode(self, info):
        """Push one finished episode to the dashboard and the metrics widget."""
        self.episode_count += 1

        # The 'episode' entry is missing if the env is not monitored; fall
        # back to zeros in that case.
        episode_stats = info.get('episode', {})
        episode_reward = episode_stats.get('r', 0)
        episode_length = episode_stats.get('l', 0)
        success = info.get('success', False)

        # Set the timestep first: update_episode_data stores whatever
        # current_timestep holds at call time, so setting it afterwards would
        # record the previous episode's value.
        self.dashboard.current_timestep = self.num_timesteps
        self.dashboard.update_episode_data(episode_reward, episode_length, success)

        # Update UI metrics
        stats = self.dashboard.get_current_stats()
        metrics_text.value = f"""
                <b>Current Metrics:</b><br/>
                Episode: {stats['episode']}<br/>
                Avg Reward (10 eps): {stats['avg_reward']:.2f}<br/>
                Avg Length (10 eps): {stats['avg_length']:.1f}<br/>
                Success Rate: {stats['success_rate']:.1f}%
                """

        # Update plots every few episodes
        if self.episode_count % self.update_freq == 0:
            self.dashboard.update_plots()

    def _update_camera(self):
        """Forward the current camera frame to the FPV window, best effort."""
        if not (camera_enabled.value and self.camera_window.is_running):
            return

        # Only vec envs that keep their environments in-process expose `.envs`.
        envs = getattr(self.training_env, 'envs', None)
        if not envs:
            return

        # Look on the unwrapped environment first: wrappers do not reliably
        # forward custom attributes. Fall back to the wrapper itself.
        wrapped = envs[0]
        base_env = getattr(wrapped, 'unwrapped', wrapped)
        grab_frame = getattr(base_env, 'get_camera_image', None)
        if grab_frame is None:
            grab_frame = getattr(wrapped, 'get_camera_image', None)
        if grab_frame is None:
            return

        try:
            camera_image = grab_frame()
            if camera_image is not None:
                self.camera_window.update_image(camera_image)
        except Exception as err:
            # A camera failure must never abort training, but report the
            # first one so it is not invisible.
            if not self._camera_error_reported:
                self._camera_error_reported = True
                print(f"⚠️ Camera update failed (further errors suppressed): {err}")

    def _on_training_end(self) -> None:
        """Called at the end of training."""
        # training_active is still True here only when the run ended without
        # Stop being pressed; a user stop has already set its own status, so
        # do not overwrite it with "Completed".
        if training_active["value"]:
            print("🏁 Training completed!")
            status_text.value = "<b>Status:</b> <span style='color:blue'>Training Completed</span>"

        # Final plot update
        self.dashboard.update_plots()

def run_interactive_training():
    """Run PPO training with the dashboard and checkpoint callbacks attached.

    Used as the target of the background thread started by
    ``on_start_training``. Reads the module-level ``model``, ``dashboard``,
    ``camera_window`` and ``training_config`` and reports the outcome through
    ``status_text``. Whatever happens, the start/pause/stop buttons and the
    ``training_active`` flag are reset when the function exits.
    """
    import html  # local import: only needed to sanitise the error banner

    print("🚀 Starting interactive PPO training...")

    callbacks = [
        # Live dashboard / camera integration.
        InteractiveTrainingCallback(dashboard, camera_window),
        # Periodic checkpoints under ./checkpoints/ with the 'ppo_drone' prefix.
        CheckpointCallback(
            save_freq=10000,
            save_path='./checkpoints/',
            name_prefix='ppo_drone'
        ),
    ]

    try:
        model.learn(
            total_timesteps=training_config.total_timesteps,
            callback=callbacks,
            progress_bar=False  # We have our own progress bar
        )

    except KeyboardInterrupt:
        print("⏹️  Training interrupted by user")
        status_text.value = "<b>Status:</b> <span style='color:orange'>Training Interrupted</span>"

    except Exception as e:
        print(f"❌ Training error: {e}")
        # The message goes into an HTML widget: escape it so text such as
        # "<class 'Foo'>" or "a & b" is displayed literally instead of being
        # parsed as markup (and silently dropped).
        safe_message = html.escape(str(e))
        status_text.value = f"<b>Status:</b> <span style='color:red'>Error: {safe_message}</span>"

    finally:
        # Return the controls to their idle state on every exit path.
        training_active["value"] = False
        start_button.disabled = False
        pause_button.disabled = True
        stop_button.disabled = True

# Modified start button handler to run training
def on_start_training(b):
    """Handle a click on the start button: update the UI, then train in the background."""
    # Reuse the existing handler for the button/status bookkeeping.
    on_start_clicked(b)

    # A daemon worker keeps the notebook UI responsive while learn() runs.
    worker = threading.Thread(target=run_interactive_training)
    worker.daemon = True
    training_thread["value"] = worker
    worker.start()

# Replace the start button click handler.
# Button.on_click() only *adds* a callback, so first unregister
# on_start_clicked (a no-op if it was never registered): on_start_training
# already calls it, and leaving both attached would run it twice per click.
start_button.on_click(on_start_clicked, remove=True)
start_button.on_click(on_start_training)

print("🎮 Interactive training loop configured!")
print("   • Real-time visualization integration")
print("   • Pause/resume/stop functionality")
print("   • Live metrics updates")
print("   • Camera feed integration")
print("   • Automatic checkpointing")

## 9. Interactive Parameter Adjustment

Runtime modification of training parameters and learning rate scheduling.

In [None]:
# Advanced parameter controls
# All sliders below share the same 150px label width and 400px overall width.

# NOTE(review): no observer is attached to this slider in this section, so
# moving it does not change anything by itself — confirm whether it is wired
# up elsewhere.
exploration_noise_slider = widgets.FloatSlider(
    value=0.1,
    min=0.0, max=1.0, step=0.05,
    description='Exploration Noise:',
    style={'description_width': '150px'},
    layout=widgets.Layout(width='400px')
)

# Discount factor; pushed to the model by update_advanced_parameters().
gamma_slider = widgets.FloatSlider(
    value=0.99,
    min=0.8, max=0.999, step=0.001,
    description='Discount Factor (γ):',
    style={'description_width': '150px'},
    layout=widgets.Layout(width='400px')
)

# GAE lambda; pushed to the model by update_advanced_parameters().
gae_lambda_slider = widgets.FloatSlider(
    value=0.95,
    min=0.8, max=1.0, step=0.01,
    description='GAE Lambda (λ):',
    style={'description_width': '150px'},
    layout=widgets.Layout(width='400px')
)

# Learning rate scheduling
# The selected value ('constant' / 'linear' / 'exponential') is read by
# apply_lr_schedule().
lr_schedule_dropdown = widgets.Dropdown(
    options=[('Constant', 'constant'), ('Linear Decay', 'linear'), ('Exponential Decay', 'exponential')],
    value='constant',
    description='LR Schedule:',
    style={'description_width': '150px'}
)

# Environment parameter controls
# Initial value comes from env_config['target_height'], so that key must exist.
target_height_slider = widgets.FloatSlider(
    value=env_config['target_height'],
    min=0.5, max=3.0, step=0.1,
    description='Target Height (m):',
    style={'description_width': '150px'},
    layout=widgets.Layout(width='400px')
)

# Forwarded to the environment's set_wind_strength() when that method exists.
wind_strength_slider = widgets.FloatSlider(
    value=0.0,
    min=0.0, max=0.5, step=0.01,
    description='Wind Strength:',
    style={'description_width': '150px'},
    layout=widgets.Layout(width='400px')
)

# Parameter update functions
def update_advanced_parameters():
    """Push the advanced slider values into the live model and environments.

    Applies the discount factor and GAE lambda to ``model`` and to its rollout
    buffer, then forwards target height and wind strength to every
    sub-environment of ``train_env`` that exposes the matching setter.
    Errors are printed rather than raised because this runs as a widget
    callback.
    """
    try:
        gamma = gamma_slider.value
        gae_lambda = gae_lambda_slider.value

        # Update model parameters if accessible
        if hasattr(model, 'gamma'):
            model.gamma = gamma
        if hasattr(model, 'gae_lambda'):
            model.gae_lambda = gae_lambda

        # Stable-Baselines3 on-policy algorithms copy gamma / gae_lambda into
        # the rollout buffer when the model is built, and the buffer's copies
        # are the ones used to compute returns and advantages. Without this
        # the slider change would not reach the advantage estimates.
        rollout_buffer = getattr(model, 'rollout_buffer', None)
        if rollout_buffer is not None:
            if hasattr(rollout_buffer, 'gamma'):
                rollout_buffer.gamma = gamma
            if hasattr(rollout_buffer, 'gae_lambda'):
                rollout_buffer.gae_lambda = gae_lambda

        # Update environment parameters on every sub-environment so that all
        # of them train against the same settings (identical to the previous
        # behaviour when there is a single environment).
        for sub_env in train_env.envs:
            if hasattr(sub_env, 'set_target_height'):
                sub_env.set_target_height(target_height_slider.value)
            if hasattr(sub_env, 'set_wind_strength'):
                sub_env.set_wind_strength(wind_strength_slider.value)

        print(f"Updated parameters: γ={gamma_slider.value:.3f}, λ={gae_lambda_slider.value:.3f}")

    except Exception as e:
        print(f"Parameter update error: {e}")

def apply_lr_schedule():
    """Compute a learning rate from the selected schedule and apply it to the model.

    The base rate is the current value of ``learning_rate_slider``. Progress is
    ``progress_bar.value / training_config.total_timesteps`` clamped to
    ``[0, 1]``:

    * ``'linear'``      -> ``base * (1 - progress)``
    * ``'exponential'`` -> ``base * 0.95 ** (progress * 100)``
    * anything else     -> ``base`` (constant)

    This is a one-shot adjustment: the rate is recomputed only when the
    function is called, not continuously during training.

    Returns:
        The learning rate that was applied.
    """
    base_lr = learning_rate_slider.value
    schedule_type = lr_schedule_dropdown.value

    if schedule_type in ('linear', 'exponential'):
        total_timesteps = training_config.total_timesteps
        # Guard against a zero budget and clamp so an overshooting progress
        # bar cannot produce a negative (linear) learning rate.
        if total_timesteps:
            progress = min(max(progress_bar.value / total_timesteps, 0.0), 1.0)
        else:
            progress = 0.0

        if schedule_type == 'linear':
            # Linear decay over training
            new_lr = base_lr * (1 - progress)
        else:
            # Exponential decay
            new_lr = base_lr * (0.95 ** (progress * 100))
    else:
        # Constant
        new_lr = base_lr

    # Update model learning rate
    if hasattr(model, 'learning_rate'):
        model.learning_rate = new_lr
    # Stable-Baselines3 does not re-read `learning_rate` after setup: before
    # each optimisation phase it sets the optimiser's rate from
    # `lr_schedule(progress_remaining)`. Replace the schedule so the new value
    # is actually used from the next update onwards.
    if hasattr(model, 'lr_schedule'):
        model.lr_schedule = lambda _progress_remaining: new_lr

    return new_lr

# Connect parameter update events
# Every change of any of these four sliders re-applies *all* advanced
# parameters; the change payload itself (x) is ignored.
gamma_slider.observe(lambda x: update_advanced_parameters(), names='value')
gae_lambda_slider.observe(lambda x: update_advanced_parameters(), names='value')
target_height_slider.observe(lambda x: update_advanced_parameters(), names='value')
wind_strength_slider.observe(lambda x: update_advanced_parameters(), names='value')

# Manual parameter update button
# Same action as the observers above, triggered explicitly.
update_params_button = widgets.Button(
    description="🔄 Update Parameters",
    button_style='info',
    layout=widgets.Layout(width='200px')
)
update_params_button.on_click(lambda x: update_advanced_parameters())

# Learning rate scheduler button
# The schedule is only evaluated when this button is clicked; the return
# value of apply_lr_schedule() is discarded here.
schedule_lr_button = widgets.Button(
    description="📈 Apply LR Schedule",
    button_style='info', 
    layout=widgets.Layout(width='200px')
)
schedule_lr_button.on_click(lambda x: apply_lr_schedule())

# Advanced controls layout
# Vertical stack: hyperparameters, LR scheduling, then environment settings.
advanced_controls = widgets.VBox([
    widgets.HTML("<b>Advanced Training Parameters:</b>"),
    exploration_noise_slider,
    gamma_slider,
    gae_lambda_slider,
    widgets.HTML("<b>Learning Rate Scheduling:</b>"),
    lr_schedule_dropdown,
    widgets.HBox([update_params_button, schedule_lr_button]),
    widgets.HTML("<b>Environment Parameters:</b>"),
    target_height_slider,
    wind_strength_slider
])

print("⚙️  Advanced parameter controls created!")
print("   • Real-time hyperparameter adjustment")
print("   • Learning rate scheduling")
print("   • Environment parameter modification")
print("   • Exploration noise control")
print("   • Discount factor and GAE lambda tuning")

## 10. Save and Load Model Functionality

Controls for model management, checkpointing, and deployment.

In [None]:
# Model management controls
# Default name is stamped once, when this cell runs (not at save time).
model_name_text = widgets.Text(
    value=f"ppo_drone_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    description='Model Name:',
    style={'description_width': '100px'},
    layout=widgets.Layout(width='400px')
)

# Handlers for the three buttons below are attached further down, after the
# corresponding functions are defined.
save_model_button = widgets.Button(
    description="💾 Save Model",
    button_style='success',
    layout=widgets.Layout(width='150px')
)

load_model_button = widgets.Button(
    description="📂 Load Model",
    button_style='info',
    layout=widgets.Layout(width='150px')
)

export_model_button = widgets.Button(
    description="📤 Export for Deployment",
    button_style='warning',
    layout=widgets.Layout(width='180px')
)

# Model file browser
# Starts empty; refresh_model_list() fills in paths relative to ./models/.
model_files = widgets.Dropdown(
    options=[],
    description='Available Models:',
    style={'description_width': '130px'},
    layout=widgets.Layout(width='400px')
)

# Model information display
# Rewritten by the save and load handlers.
model_info_text = widgets.HTML(
    value="<b>Current Model:</b> Newly initialized PPO model"
)

def refresh_model_list():
    """Rescan ``./models/`` and repopulate the ``model_files`` dropdown.

    Lists ``.zip`` archives located directly in ``./models/`` (which is where
    ``save_current_model`` writes them for a plain model name) as well as
    those one level down in sub-folders. Entries are paths relative to
    ``./models/`` and are sorted so the dropdown order is stable. A missing
    directory yields an empty list; any other error is printed.
    """
    try:
        model_dir = "./models/"
        found = []
        if os.path.isdir(model_dir):
            for entry in sorted(os.listdir(model_dir)):
                entry_path = os.path.join(model_dir, entry)
                if os.path.isdir(entry_path):
                    # Look for .zip files in the folder
                    found.extend(
                        os.path.join(entry, file_name)
                        for file_name in sorted(os.listdir(entry_path))
                        if file_name.endswith('.zip')
                    )
                elif entry.endswith('.zip'):
                    # Archive saved directly under ./models/
                    found.append(entry)
        model_files.options = found
    except Exception as e:
        print(f"Error refreshing model list: {e}")

def save_current_model(b):
    """Save the current model plus a JSON sidecar describing the run.

    Writes the model to ``./models/<name>`` and the settings/statistics to
    ``./models/<name>_config.json``, where ``<name>`` is the (stripped) content
    of ``model_name_text``. An empty name is rejected so that no hidden
    ``./models/.zip`` file is produced. Errors are printed, not raised.

    Args:
        b: The clicked button (unused; required by the on_click signature).
    """
    try:
        model_name = model_name_text.value.strip()
        if not model_name:
            print("❌ No model name provided")
            return

        model_path = f"./models/{model_name}"
        # dirname() also covers names containing a sub-folder ("run1/best").
        os.makedirs(os.path.dirname(model_path), exist_ok=True)

        # Save the model
        model.save(model_path)

        # Save training configuration
        config_path = f"{model_path}_config.json"
        config_data = {
            'model_name': model_name,
            'algorithm': 'PPO',
            'total_timesteps': progress_bar.value,
            'learning_rate': learning_rate_slider.value,
            'batch_size': batch_size_slider.value,
            'gamma': gamma_slider.value,
            'gae_lambda': gae_lambda_slider.value,
            'saved_at': datetime.now().isoformat(),
            'training_stats': dashboard.get_current_stats()
        }

        with open(config_path, 'w') as f:
            json.dump(config_data, f, indent=2)

        print(f"✅ Model saved to: {model_path}")
        model_info_text.value = f"<b>Current Model:</b> {model_name} (saved)"
        refresh_model_list()

    except Exception as e:
        print(f"❌ Error saving model: {e}")

def load_selected_model(b):
    """Replace the global ``model`` with the archive picked in the dropdown.

    If a ``<path>_config.json`` sidecar exists next to the archive, the
    learning-rate, batch-size, gamma and GAE-lambda sliders are set from it.
    Errors are printed, not raised.

    Args:
        b: The clicked button (unused; required by the on_click signature).
    """
    global model

    try:
        selection = model_files.value
        if not selection:
            print("❌ No model selected")
            return

        # Paths are handled without the .zip suffix from here on.
        model_path = f"./models/{selection}"
        if model_path.endswith('.zip'):
            model_path = model_path[:-len('.zip')]

        model = PPO.load(model_path, env=train_env)

        # Optional sidecar with the hyperparameters stored at save time.
        config_path = f"{model_path}_config.json"
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                saved = json.load(f)

            # Mirror the stored values in the UI, one key at a time.
            if 'learning_rate' in saved:
                learning_rate_slider.value = saved['learning_rate']
            if 'batch_size' in saved:
                batch_size_slider.value = saved['batch_size']
            if 'gamma' in saved:
                gamma_slider.value = saved['gamma']
            if 'gae_lambda' in saved:
                gae_lambda_slider.value = saved['gae_lambda']

        print(f"✅ Model loaded from: {model_path}")
        model_info_text.value = f"<b>Current Model:</b> {model_files.value} (loaded)"

    except Exception as e:
        print(f"❌ Error loading model: {e}")

def export_for_deployment(b):
    """Write the current model and a deployment descriptor under ``./exports/``.

    Produces ``./exports/<name>`` (the saved model) and
    ``./exports/<name>_deployment.json`` holding the archive file name, the
    string forms of the observation/action spaces, ``env_config`` and a
    timestamp. ``<name>`` is the (stripped) content of ``model_name_text``; an
    empty name is rejected. Errors are printed, not raised.

    Args:
        b: The clicked button (unused; required by the on_click signature).
    """
    try:
        model_name = model_name_text.value.strip()
        if not model_name:
            print("❌ No model name provided")
            return

        export_path = f"./exports/{model_name}"
        os.makedirs(os.path.dirname(export_path), exist_ok=True)

        # Save model
        model.save(export_path)

        # Create deployment package
        deployment_info = {
            'model_file': f"{model_name}.zip",
            'algorithm': 'PPO',
            'observation_space': str(train_env.observation_space),
            'action_space': str(train_env.action_space),
            'env_config': env_config,
            'deployment_ready': True,
            'exported_at': datetime.now().isoformat()
        }

        with open(f"{export_path}_deployment.json", 'w') as f:
            json.dump(deployment_info, f, indent=2)

        print(f"✅ Model exported for deployment: {export_path}")
        print("📦 Deployment package includes model and configuration")

    except Exception as e:
        print(f"❌ Error exporting model: {e}")

# Connect button events
save_model_button.on_click(save_current_model)
load_model_button.on_click(load_selected_model)
export_model_button.on_click(export_for_deployment)

# Refresh model list on startup
refresh_model_list()

# Refresh button: created under its own name so the handler can be attached
# directly, instead of fishing the widget back out of the VBox by position
# (which silently targets the wrong child as soon as the layout changes).
refresh_button = widgets.Button(description="🔄 Refresh List", button_style='', layout=widgets.Layout(width='120px'))
refresh_button.on_click(lambda x: refresh_model_list())

# Model management layout
model_controls = widgets.VBox([
    widgets.HTML("<b>Model Management:</b>"),
    model_name_text,
    widgets.HBox([save_model_button, load_model_button, export_model_button]),
    model_files,
    refresh_button,
    model_info_text
])

print("💾 Model management system created!")
print("   • Save/load trained models")
print("   • Model configuration persistence")
print("   • Deployment package export")
print("   • Training state preservation")
print("   • Model browser with metadata")

## 🎮 Interactive Training Dashboard

Launch the complete interactive training interface with all controls and visualizations.

In [None]:
# Create the complete dashboard layout
# One tab per concern; each tab wraps a control group built in an earlier cell.
dashboard_tabs = widgets.Tab()

# Training Control Tab
training_tab = widgets.VBox([
    widgets.HTML("<h3>🚁 PPO Training Controls</h3>"),
    control_buttons,
    status_panel,
    widgets.HTML("<hr>"),
    widgets.HTML("<h4>Training Parameters</h4>"),
    parameter_controls
])

# Camera Control Tab
camera_tab = widgets.VBox([
    widgets.HTML("<h3>📷 Camera System</h3>"),
    camera_controls
])

# Advanced Settings Tab
advanced_tab = widgets.VBox([
    widgets.HTML("<h3>⚙️ Advanced Settings</h3>"),
    advanced_controls
])

# Model Management Tab
model_tab = widgets.VBox([
    widgets.HTML("<h3>💾 Model Management</h3>"),
    model_controls
])

# Add tabs to the dashboard
# Title indices must follow the order of `children`.
dashboard_tabs.children = [training_tab, camera_tab, advanced_tab, model_tab]
dashboard_tabs.set_title(0, "🎮 Training")
dashboard_tabs.set_title(1, "📷 Camera")
dashboard_tabs.set_title(2, "⚙️ Advanced")
dashboard_tabs.set_title(3, "💾 Models")

# Quick action buttons (always visible)
# Shown above the tabs; these are the same button objects used by the
# training handlers, not copies.
quick_actions = widgets.HBox([
    widgets.HTML("<b>Quick Actions:</b> "),
    start_button,
    pause_button,
    stop_button,
    open_camera_button
], layout=widgets.Layout(margin='10px'))

# Instructions
# Static help panel rendered above the dashboard; the button labels quoted in
# the text are written out by hand, so keep them in sync with the widgets.
instructions = widgets.HTML("""
<div style="background-color: #f0f8ff; padding: 15px; border-radius: 5px; margin: 10px 0;">
<h4>🚀 Getting Started:</h4>
<ol>
<li><b>Open Camera Window:</b> Click "📷 Open Camera Window" to see the drone's FPV view</li>
<li><b>Start Training:</b> Click "▶️ Start Training" to begin PPO training</li>
<li><b>Monitor Progress:</b> Watch the live metrics dashboard and MuJoCo viewer</li>
<li><b>Adjust Parameters:</b> Use sliders to modify training parameters in real-time</li>
<li><b>Control Training:</b> Use pause/resume/stop buttons as needed</li>
<li><b>Save Progress:</b> Save your trained model at any time</li>
</ol>

<h4>📊 Visualization Features:</h4>
<ul>
<li><b>MuJoCo Viewer:</b> 3D environment visualization (opens automatically during training)</li>
<li><b>FPV Camera:</b> Drone's first-person view with positioning controls</li>
<li><b>Live Metrics:</b> Real-time plots of rewards, episode length, success rate</li>
<li><b>Training Stats:</b> Episode count, average performance, and learning progress</li>
</ul>

<h4>🎮 Control Tips:</h4>
<ul>
<li>Use <b>Pause</b> to temporarily stop training while keeping the session active</li>
<li>Adjust <b>Learning Rate</b> and other parameters during training for better results</li>
<li>Move the <b>Camera Window</b> to your preferred screen position</li>
<li>Save models frequently to preserve training progress</li>
</ul>
</div>
""")

# Display the complete dashboard
# Render order: help panel, quick-action bar, then the tabbed controls.
display(instructions)
display(quick_actions)
display(dashboard_tabs)

# Static banner only: these lines do not check the state of any component.
print("=" * 60)
print("🎉 INTERACTIVE PPO TRAINING DASHBOARD READY!")
print("=" * 60)
print("📊 Live Metrics Dashboard: Initialized")
print("📷 Camera Window System: Ready")
print("🎮 Interactive Controls: Active")
print("💾 Model Management: Available")
print("🤖 PPO Algorithm: Configured")
print("🌍 CrazyFlie Environment: Loaded")
print("=" * 60)
print("\n🚀 Click 'Start Training' to begin your drone RL training journey!")
print("📈 Watch your drone learn to fly in real-time with full visualization!")
print("\n💡 Pro Tip: Open the camera window first to see the drone's perspective!")

# PPO Drone Training Notebook

This notebook demonstrates Proximal Policy Optimization (PPO) training for a drone agent in the Drone-UAV environment with:
- **Live MuJoCo visualization**
- **Real-time FPV camera view**
- **Live training metrics**
- **Interactive controls** (pause, resume, adjust parameters)
- **Camera window controls** (move, resize)

---

## 1. Import Required Libraries and Modules

In [None]:
# Import standard libraries
import numpy as np
import matplotlib.pyplot as plt
import threading
import time
import cv2
from IPython.display import display, clear_output
import ipywidgets as widgets

# Import MuJoCo viewer (assuming mujoco-py or mujoco is installed)
try:
    import mujoco
    from mujoco.viewer import launch
except ImportError:
    print("MuJoCo viewer not found. Please install mujoco or mujoco-py.")

# Import Drone-UAV environment and PPO agent
from crazy_flie_env.core.environment import DroneEnvironment
from train.algorithms import PPOAgent
from train.config import PPOConfig
from crazy_flie_env.vision.cameras import DroneCamera
from crazy_flie_env.vision.rendering import render_fpv

# Import any additional utilities
from train.utils import plot_metrics


## 2. Initialize Environment and PPO Agent

In [None]:
# Instantiate environment and PPO agent

# Keyword arguments passed straight to DroneEnvironment below.
env_config = {
    'render_mode': 'mujoco',
    'camera_enabled': True,
    # Add other environment config params as needed
}

# Initialize environment
env = DroneEnvironment(**env_config)

# Load PPO config and agent
# The agent is built from the environment's spaces and a default PPOConfig.
ppo_config = PPOConfig()
agent = PPOAgent(env.observation_space, env.action_space, ppo_config)

# Initialize FPV camera
# Frames are later pulled from this object via camera.get_frame().
camera = DroneCamera(env)


## 3. Set Up MuJoCo Viewer for Live Visualization

In [None]:
# Launch MuJoCo viewer in a separate thread for live visualization
def launch_mujoco_viewer(env):
    """Call ``env.render(mode='human')`` on a background daemon thread.

    Any exception, whether raised while starting the thread or by
    ``env.render`` itself, is reported as a printed
    "MuJoCo viewer launch failed" message instead of propagating.

    Args:
        env: Object exposing ``render(mode=...)``.
    """
    def _render_in_background():
        # The handler has to live inside the thread target: an exception
        # raised by env.render() occurs on the worker thread and would never
        # reach a try/except wrapped around Thread.start() in the caller.
        try:
            env.render(mode='human')
        except Exception as e:
            print(f"MuJoCo viewer launch failed: {e}")

    try:
        viewer_thread = threading.Thread(target=_render_in_background)
        viewer_thread.daemon = True
        viewer_thread.start()
    except Exception as e:
        # Covers failures to create/start the thread itself.
        print(f"MuJoCo viewer launch failed: {e}")

# Start the viewer
launch_mujoco_viewer(env)


## 4. Set Up Real-time FPV Camera Window

In [None]:
# Function to display FPV camera in a separate OpenCV window
def show_fpv_camera(camera, window_name="Drone FPV"):
    """Show frames from ``camera`` in an OpenCV window driven by a daemon thread.

    The loop keeps polling ``camera.get_frame()`` and displays every non-None
    frame until 'q' is pressed in the window or an error occurs; the window is
    closed in both cases.

    Args:
        camera: Object exposing ``get_frame()``.
        window_name: Title (and OpenCV identifier) of the window.
    """
    def camera_loop():
        window_created = False
        try:
            # Create the window explicitly as WINDOW_NORMAL: a window created
            # implicitly by imshow() is WINDOW_AUTOSIZE, on which
            # cv2.resizeWindow() has no effect.
            cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
            window_created = True
            while True:
                frame = camera.get_frame()
                if frame is not None:
                    cv2.imshow(window_name, frame)
                # waitKey also pumps the GUI event loop; 'q' ends the view.
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        except Exception as e:
            # Report instead of letting the thread die with the window frozen.
            print(f"FPV camera loop stopped: {e}")
        finally:
            # Only destroy a window that was actually created.
            if window_created:
                try:
                    cv2.destroyWindow(window_name)
                except cv2.error:
                    pass  # window was already closed (e.g. by the user)
    thread = threading.Thread(target=camera_loop)
    thread.daemon = True
    thread.start()

# Start FPV camera window
show_fpv_camera(camera)


## 5. Define Training Metrics and Live Plotting

In [None]:
# Set up live plotting for training metrics
# IPython magic (not plain Python): selects the interactive notebook backend.
%matplotlib notebook

# Per-episode histories; appended to by training_loop() and read by
# update_plots().
rewards = []
episode_lengths = []
success_rates = []

# Three stacked axes, one per metric.
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(6, 10))
fig.suptitle('Training Metrics')

# Lines start empty; update_plots() replaces their data in place.
reward_line, = ax1.plot([], [], label='Reward')
ax1.set_ylabel('Reward')
ax1.legend()

length_line, = ax2.plot([], [], label='Episode Length')
ax2.set_ylabel('Length')
ax2.legend()

success_line, = ax3.plot([], [], label='Success Rate')
ax3.set_ylabel('Success Rate')
ax3.set_xlabel('Episode')
ax3.legend()

plt.tight_layout()

# Function to update plots
def update_plots():
    """Push the current metric histories into the three plot lines and redraw."""
    series = (
        (reward_line, rewards),
        (length_line, episode_lengths),
        (success_line, success_rates),
    )
    # x is simply the episode index for every metric.
    for line, values in series:
        line.set_data(range(len(values)), values)

    # Recompute the data limits so new points stay in view.
    for axis in (ax1, ax2, ax3):
        axis.relim()
        axis.autoscale_view()

    canvas = fig.canvas
    canvas.draw()
    canvas.flush_events()


## 6. Implement Interactive Controls (Pause, Resume, Adjust Parameters)

In [None]:
# Interactive widgets for training control
pause_button = widgets.ToggleButton(value=False, description='Pause', button_style='warning')
resume_button = widgets.Button(description='Resume', button_style='success')
learning_rate_slider = widgets.FloatSlider(value=ppo_config.learning_rate, min=1e-5, max=1e-2, step=1e-5, description='LR')
batch_size_slider = widgets.IntSlider(value=ppo_config.batch_size, min=32, max=1024, step=32, description='Batch Size')

controls_box = widgets.HBox([pause_button, resume_button, learning_rate_slider, batch_size_slider])
display(controls_box)

# State variable for pausing
is_paused = threading.Event()
is_paused.clear()

def on_pause_change(change):
    if change['new']:
        is_paused.set()
    else:
        is_paused.clear()

def on_resume_clicked(b):
    pause_button.value = False
    is_paused.clear()

def on_lr_change(change):
    agent.set_learning_rate(change['new'])

def on_bs_change(change):
    agent.set_batch_size(change['new'])

pause_button.observe(on_pause_change, names='value')
resume_button.on_click(on_resume_clicked)
learning_rate_slider.observe(on_lr_change, names='value')
batch_size_slider.observe(on_bs_change, names='value')


## 7. Implement Camera Window Controls (Move, Resize)

In [None]:
# Camera window controls using OpenCV
move_x_slider = widgets.IntSlider(value=100, min=0, max=1920, step=10, description='Move X')
move_y_slider = widgets.IntSlider(value=100, min=0, max=1080, step=10, description='Move Y')
resize_w_slider = widgets.IntSlider(value=320, min=100, max=1280, step=10, description='Width')
resize_h_slider = widgets.IntSlider(value=240, min=100, max=720, step=10, description='Height')

camera_controls = widgets.HBox([move_x_slider, move_y_slider, resize_w_slider, resize_h_slider])
display(camera_controls)

# Function to update camera window position and size
def update_camera_window(*args):
    cv2.moveWindow('Drone FPV', move_x_slider.value, move_y_slider.value)
    cv2.resizeWindow('Drone FPV', resize_w_slider.value, resize_h_slider.value)

move_x_slider.observe(update_camera_window, names='value')
move_y_slider.observe(update_camera_window, names='value')
resize_w_slider.observe(update_camera_window, names='value')
resize_h_slider.observe(update_camera_window, names='value')


## 8. Run Training Loop with Live Updates

In [None]:
# Main training loop with live updates
def training_loop(num_episodes=1000):
    """Run episodes against ``env`` with ``agent`` and refresh the live plots.

    For each episode: roll out until ``done``, storing every transition in the
    agent; call ``agent.update()`` once; append the episode's return, length
    and the running success rate to the module-level histories; then redraw
    the figure and print a one-line summary. While ``is_paused`` is set the
    loop idles in 0.1 s sleeps without stepping the environment.

    Args:
        num_episodes: Number of episodes to run.

    NOTE(review): ``env.reset()`` / ``env.step()`` are used with the
    single-observation / 4-tuple return convention — confirm that this matches
    DroneEnvironment.
    """
    for episode in range(num_episodes):
        obs = env.reset()
        done = False
        total_reward = 0
        steps = 0
        while not done:
            if is_paused.is_set():
                time.sleep(0.1)
                continue
            action = agent.select_action(obs)
            next_obs, reward, done, info = env.step(action)
            agent.store_transition(obs, action, reward, next_obs, done)
            obs = next_obs
            total_reward += reward
            steps += 1
            # Update MuJoCo viewer and FPV camera handled by threads
        agent.update()

        episode_success = info.get('success', 0)
        rewards.append(total_reward)
        episode_lengths.append(steps)

        # Keep `success_rates` an actual rate: the fraction of successful
        # episodes so far, updated incrementally from the previous entry,
        # rather than the raw 0/1 flag of the latest episode.
        finished_before = len(success_rates)
        previous_rate = success_rates[-1] if success_rates else 0.0
        success_rates.append(
            (previous_rate * finished_before + float(bool(episode_success)))
            / (finished_before + 1)
        )

        update_plots()
        clear_output(wait=True)
        display(fig)
        print(f"Episode {episode+1}: Reward={total_reward:.2f}, Steps={steps}, Success={info.get('success', 0)}")

# Start training (can be run in a thread for UI responsiveness)
# training_thread = threading.Thread(target=training_loop, args=(1000,))
# training_thread.start()
