# Basic Drone Training with PPO

This notebook demonstrates how to train a drone agent with the PPO (Proximal Policy Optimization) algorithm.
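PPO updates the policy by maximizing a clipped surrogate objective. The clipping limits how far a single update can move action probabilities away from the policy that collected the data. The NumPy sketch below computes that loss for a batch of samples. The function name `ppo_clip_loss` is hypothetical. The `clip_range=0.2` default is a common choice (it is Stable-Baselines3's default), not necessarily the value `train_ppo` uses.

```python
import numpy as np

def ppo_clip_loss(new_log_probs, old_log_probs, advantages, clip_range=0.2):
    """Clipped surrogate policy loss (negated, so it can be minimized)."""
    # Probability ratio r = pi_new(a|s) / pi_old(a|s), computed in log space
    ratio = np.exp(new_log_probs - old_log_probs)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_range, 1.0 + clip_range) * advantages
    # PPO maximizes the mean of the element-wise minimum; negate for a loss
    return -np.mean(np.minimum(unclipped, clipped))

# A sample whose probability rose 50% with a positive advantage:
# the ratio 1.5 is clipped to 1.2, so the objective's gain is capped.
print(ppo_clip_loss(np.array([np.log(1.5)]), np.array([0.0]), np.array([1.0])))
```

Taking the minimum makes the objective pessimistic. Clipping can only remove incentive to push the ratio further from 1. It never adds incentive.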

## Quick Start
1. Import training functions
2. Configure training parameters
3. Train the model
4. Test the trained model

## 1. Setup and Imports

In [None]:
import sys
import os

# Add parent directory to path if needed
parent_dir = os.path.abspath('..')
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# Import training utilities
from train.simple_train import train_ppo, train_sac, load_model
from train.config import TrainingConfig
from train.test_utils import test_model, quick_test, visualize_episode

print("✅ Imports successful!")

## 2. Quick Test (Optional)

Test the environment before training to make sure everything works.

In [None]:
# Quick environment test
from crazy_flie_env import CrazyFlieEnv

env = CrazyFlieEnv()
obs, info = env.reset()

print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}")
print(f"State shape: {obs['state'].shape}")
print(f"Image shape: {obs['image'].shape}")

# Test one step
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
print(f"Step completed! Reward: {reward:.3f}")

env.close()
print("\n✅ Environment test passed!")
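Before training, record how a purely random policy scores. This gives later test rewards a reference point. The sketch below assumes only the Gymnasium-style `reset()`/`step()` interface used in the cell above. `random_baseline` is a hypothetical helper and is not part of `train.test_utils`.

```python
def random_baseline(env, num_episodes=5, max_steps=1000):
    """Average return of a uniformly random policy on a Gymnasium-style env."""
    returns = []
    for _ in range(num_episodes):
        env.reset()
        total = 0.0
        for _ in range(max_steps):
            action = env.action_space.sample()  # random action, as in the test above
            _, reward, terminated, truncated, _ = env.step(action)
            total += float(reward)
            if terminated or truncated:
                break
        returns.append(total)
    return sum(returns) / len(returns)
```

In the notebook you would call it as `random_baseline(CrazyFlieEnv(), num_episodes=5)`, then close the environment.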

## 3. Configure Training

Create a training configuration. You can modify any parameters here.

In [None]:
# Quick training (for testing)
config = TrainingConfig(
    algorithm="PPO",
    total_timesteps=100_000,  # Short for testing
    num_envs=4,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    eval_freq=10_000,
    save_freq=20_000
)

print("Training Configuration:")
print(f"  Algorithm: {config.algorithm}")
print(f"  Total timesteps: {config.total_timesteps:,}")
print(f"  Parallel envs: {config.num_envs}")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Device: {config.device}")
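It helps to know how much data each PPO update sees with these settings. This sketch assumes `train_ppo` follows the Stable-Baselines3 conventions: `n_steps` counts steps *per environment*, and a rollout already in progress is always completed. `ppo_rollout_stats` is a hypothetical helper.

```python
import math

def ppo_rollout_stats(total_timesteps, num_envs, n_steps, batch_size):
    """Back-of-the-envelope rollout arithmetic for a vectorized PPO run."""
    rollout_size = num_envs * n_steps          # transitions collected per update
    minibatches = rollout_size // batch_size   # gradient steps per epoch
    # Training stops only after the step budget is reached, so a partial
    # final rollout still runs in full (hence the ceiling).
    num_rollouts = math.ceil(total_timesteps / rollout_size)
    return rollout_size, minibatches, num_rollouts

print(ppo_rollout_stats(100_000, 4, 2048, 64))  # → (8192, 128, 13)
```

Under these assumptions, each update uses 8,192 transitions split into 128 minibatches of 64. The 100K budget actually runs 13 rollouts, or 106,496 environment steps.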

## 4. Train PPO Agent

Running this cell trains the agent. A progress bar shows how far training has progressed.

⚠️ **Note**: Training time grows with `total_timesteps`. Rough wall-clock estimates (actual times depend heavily on your hardware and `num_envs`):
- 100K steps: ~10-30 minutes
- 500K steps: ~1-2 hours
- 1M steps: ~2-4 hours
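Training time scales roughly linearly with `total_timesteps`. The most reliable estimate therefore comes from your own measured throughput in environment steps per second. The hypothetical helper below performs that extrapolation. A straight-line extrapolation from the 10-30 minute range for 100K steps gives a wider band than the list above, so treat all of these numbers as rough guides.

```python
def estimate_hours(total_timesteps, steps_per_second):
    """Wall-clock hours for a run at a constant environment throughput."""
    return total_timesteps / steps_per_second / 3600

# Throughput implied by "100K steps in 10-30 minutes"
slow_fps = 100_000 / (30 * 60)  # about 56 steps/s
fast_fps = 100_000 / (10 * 60)  # about 167 steps/s

for steps in (500_000, 1_000_000):
    low, high = estimate_hours(steps, fast_fps), estimate_hours(steps, slow_fps)
    print(f"{steps:>9,} steps: {low:.1f}-{high:.1f} h")
```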

In [None]:
# Train the model
model, results = train_ppo(config, verbose=True)

print("\n✅ Training completed!")
print(f"Model saved to: {results['final_model_path']}")
print(f"Logs available at: {results['log_dir']}")

## 5. Test the Trained Model

Evaluate the trained model over 10 episodes.

In [None]:
# Test the model (will render first 3 episodes)
avg_reward, metrics = test_model(
    model_path=results['final_model_path'],
    algorithm="PPO",
    num_episodes=10,
    render=True
)

print("\n📊 Test Results:")
print(f"  Average Reward: {metrics['avg_reward']:.2f}")
print(f"  Std Dev: {metrics['std_reward']:.2f}")
print(f"  Success Rate: {metrics['success_rate']:.1%}")
print(f"  Average Episode Length: {metrics['avg_length']:.1f} steps")
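The metrics above are aggregates over episodes. The sketch below illustrates how such numbers are typically computed from a list of per-episode records. It is not the actual implementation of `test_model`. The record keys `reward`, `length` and `success` are assumptions, and how the environment defines "success" is not stated in this notebook.

```python
import statistics

def summarize_episodes(episodes):
    """Aggregate hypothetical per-episode records into summary metrics."""
    rewards = [ep["reward"] for ep in episodes]
    return {
        "avg_reward": statistics.fmean(rewards),
        "std_reward": statistics.pstdev(rewards),  # population std, like np.std's default
        "success_rate": sum(bool(ep["success"]) for ep in episodes) / len(episodes),
        "avg_length": statistics.fmean(ep["length"] for ep in episodes),
    }
```

The returned dictionary uses the same keys as `metrics` above, so the print statements in the previous cell work on it unchanged.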

## 6. View Training Logs with TensorBoard

Run this cell to view training metrics in TensorBoard.

In [None]:
# Load TensorBoard extension
%load_ext tensorboard

# Launch TensorBoard
%tensorboard --logdir logs/

# Alternatively, run this in terminal:
# tensorboard --logdir logs/
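Episode-reward curves from RL training are noisy. TensorBoard's smoothing slider applies an exponential moving average to make trends readable. The sketch below implements a debiased EMA of that kind for a list of reward values you have exported from the logs. It is an illustration and is not guaranteed to reproduce TensorBoard's exact output. The name `ema_smooth` is hypothetical.

```python
def ema_smooth(values, weight=0.6):
    """Debiased exponential moving average; weight must be in [0, 1)."""
    if not 0.0 <= weight < 1.0:
        raise ValueError("weight must be in [0, 1)")
    smoothed, last = [], 0.0
    for i, value in enumerate(values, start=1):
        last = weight * last + (1.0 - weight) * value
        # Divide out the bias toward the 0.0 starting value
        smoothed.append(last / (1.0 - weight ** i))
    return smoothed
```

Higher `weight` values give smoother curves that react more slowly to real changes. `weight=0.0` returns the raw series.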

## 7. Visualize Episode Trajectory (Optional)

Visualize the drone's trajectory during an episode.

In [None]:
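import numpy as np  # needed for np.array below
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection on older Matplotlib)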

# Get trajectory data
trajectory = visualize_episode(
    model_path=results['final_model_path'],
    algorithm="PPO",
    max_steps=500
)

# Extract positions
positions = [t['position'] for t in trajectory]
positions = np.array(positions)

# Plot 3D trajectory
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection='3d')

ax.plot(positions[:, 0], positions[:, 1], positions[:, 2], 'b-', linewidth=2, alpha=0.7)
ax.scatter(positions[0, 0], positions[0, 1], positions[0, 2], c='g', s=100, label='Start')
ax.scatter(positions[-1, 0], positions[-1, 1], positions[-1, 2], c='r', s=100, label='End')

ax.set_xlabel('X Position (m)')
ax.set_ylabel('Y Position (m)')
ax.set_zlabel('Z Position (m)')
ax.set_title('Drone Flight Trajectory')
ax.legend()

plt.tight_layout()
plt.show()

print(f"Episode completed in {len(trajectory)} steps")
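Besides the plot, two numbers summarize a flight: the total path length and how far the drone ends up from its goal. The sketch below computes both from the `positions` array built above. The helper name `trajectory_stats` and the default `target=(0.0, 0.0, 1.0)` are hypothetical. Substitute whatever goal your environment actually uses.

```python
import numpy as np

def trajectory_stats(positions, target=(0.0, 0.0, 1.0)):
    """Return (path length, final distance to target) for an (N, 3) trajectory."""
    positions = np.asarray(positions, dtype=float)
    # Euclidean length of each step between consecutive positions
    step_lengths = np.linalg.norm(np.diff(positions, axis=0), axis=1)
    path_length = float(step_lengths.sum())
    final_error = float(np.linalg.norm(positions[-1] - np.asarray(target, dtype=float)))
    return path_length, final_error
```

For example, `path_length, final_error = trajectory_stats(positions)`. A path much longer than the straight-line distance to the goal suggests the policy is wandering or oscillating.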

## 8. Train SAC Agent (Alternative)

If you want to try SAC instead of PPO:

In [None]:
# Configure for SAC
sac_config = TrainingConfig(
    algorithm="SAC",
    total_timesteps=100_000,
    num_envs=4,
    learning_rate=3e-4,
    batch_size=256,
    buffer_size=100_000
)

# Train SAC
sac_model, sac_results = train_sac(sac_config, verbose=True)

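# Test SAC and report its average reward for comparison with PPO
sac_avg_reward, sac_metrics = test_model(
    model_path=sac_results['final_model_path'],
    algorithm="SAC",
    num_episodes=10
)
print(f"SAC Average Reward: {sac_metrics['avg_reward']:.2f}")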
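Unlike PPO, SAC is off-policy. It stores past transitions in a replay buffer, capped here by `buffer_size=100_000`, and trains on random minibatches drawn from it (`batch_size=256`). The class below is a minimal sketch of that data structure. A `deque` with `maxlen` evicts the oldest transitions automatically. It is a hypothetical illustration, not the buffer that `train_sac` uses internally.

```python
import random
from collections import deque

class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform sampling."""

    def __init__(self, capacity, seed=None):
        self.storage = deque(maxlen=capacity)  # oldest transitions drop out when full
        self.rng = random.Random(seed)

    def add(self, obs, action, reward, next_obs, done):
        self.storage.append((obs, action, reward, next_obs, done))

    def sample(self, batch_size):
        # Uniform sampling without replacement within one batch;
        # raises ValueError if batch_size > len(self)
        return self.rng.sample(list(self.storage), batch_size)

    def __len__(self):
        return len(self.storage)
```

Copying the deque to a list on every `sample` call costs O(capacity). That is acceptable for a sketch, but production buffers usually preallocate arrays and index into them instead.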

## Next Steps

- **Longer training**: Increase `total_timesteps` to 500K or 1M for better performance
- **Hyperparameter tuning**: Experiment with learning rate, batch size, etc.
- **Custom rewards**: Modify the environment reward function
- **Different algorithms**: Compare PPO with SAC, or try A2C or custom algorithms
- **Advanced features**: Add curriculum learning or domain randomization
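As a starting point for the **Custom rewards** item, here is a hypothetical shaped reward for hovering. It penalizes distance from a target point and control effort, and adds a small per-step bonus for staying airborne. The target, the weights and the function name are illustrative assumptions. Wiring this into `CrazyFlieEnv` depends on how that environment computes its reward, which this notebook does not show.

```python
import numpy as np

def hover_reward(position, action, target=(0.0, 0.0, 1.0),
                 distance_weight=1.0, action_weight=0.01, alive_bonus=0.1):
    """Hypothetical shaped reward: stay near target with little control effort."""
    position = np.asarray(position, dtype=float)
    action = np.asarray(action, dtype=float)
    distance = np.linalg.norm(position - np.asarray(target, dtype=float))
    effort = np.sum(action ** 2)  # squared-magnitude penalty on the action
    return float(alive_bonus - distance_weight * distance - action_weight * effort)
```

The relative weights matter more than their absolute values. If `action_weight` is too large, the agent may learn to do nothing and collect the alive bonus.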