# UAV Quadcopter Finetuning Tutorial

This notebook demonstrates how to adapt GR00T, a vision-language-action (VLA) model pretrained for humanoid robots, to UAV quadcopter control.

## State Space (13D):
- **position**: x, y, z (3 dims)
- **orientation**: roll, pitch, yaw (3 dims)
- **velocity**: vx, vy, vz (3 dims)
- **battery**: battery level (1 dim)
- **gps**: lat, lon, alt (3 dims)

## Action Space (9D):
- **flight_control**: throttle, roll, pitch, yaw (4 dims)
- **velocity_command**: vx, vy, vz (3 dims)
- **gimbal**: gimbal_pitch, gimbal_yaw (2 dims)
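
As a rough illustration of how these components map onto a flat vector, the sketch below computes the index range of each component and splits a vector back into named parts. It assumes the components are concatenated in the order listed above. `build_slices`, `split_vector`, and the sample action values are illustrative names and numbers, not part of GR00T.

```python
# Component widths copied from the state and action lists above.
STATE_LAYOUT = {"position": 3, "orientation": 3, "velocity": 3, "battery": 1, "gps": 3}
ACTION_LAYOUT = {"flight_control": 4, "velocity_command": 3, "gimbal": 2}


def build_slices(layout):
    """Map each component name to the slice it occupies in the flat vector."""
    slices, offset = {}, 0
    for name, width in layout.items():
        slices[name] = slice(offset, offset + width)
        offset += width
    return slices, offset


def split_vector(vector, layout):
    """Split a flat vector into named components, checking its length first."""
    slices, total = build_slices(layout)
    if len(vector) != total:
        raise ValueError(f"expected {total} values, got {len(vector)}")
    return {name: list(vector[s]) for name, s in slices.items()}


state_slices, state_dim = build_slices(STATE_LAYOUT)
action_slices, action_dim = build_slices(ACTION_LAYOUT)
print(state_dim, action_dim)  # 13 9

# Hypothetical action: mid throttle, forward velocity, gimbal tilted down.
action = [0.5, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, -0.3, 0.0]
print(split_vector(action, ACTION_LAYOUT)["gimbal"])  # [-0.3, 0.0]
```

The length check makes a mismatch between the data and the declared layout fail early instead of silently shifting components.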

The key insight is to reuse the pretrained vision-language model (VLM) backbone unchanged and retrain only the diffusion action head that generates UAV actions.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from gr00t.data.dataset import LeRobotSingleDataset, ModalityConfig
from gr00t.data.embodiment_tags import EmbodimentTag
from gr00t.experiment.data_config import UAVQuadcopterDataConfig

## Step 1: Define UAV Modality Configurations

In [None]:
# UAV-specific modality configurations
video_modality = ModalityConfig(
    delta_indices=[0],
    modality_keys=["video.front_camera", "video.gimbal_camera"],
)

state_modality = ModalityConfig(
    delta_indices=[0],
    modality_keys=[
        "state.position",      # x, y, z
        "state.orientation",   # roll, pitch, yaw
        "state.velocity",      # vx, vy, vz
        "state.battery",       # battery level
        "state.gps",           # lat, lon, alt
    ],
)

action_modality = ModalityConfig(
    delta_indices=list(range(16)),  # 16-step action horizon (offsets 0..15)
    modality_keys=[
        "action.flight_control",    # throttle, roll, pitch, yaw
        "action.velocity_command",  # vx, vy, vz
        "action.gimbal",            # gimbal_pitch, gimbal_yaw
    ],
)

language_modality = ModalityConfig(
    delta_indices=[0],
    modality_keys=["annotation.human.task_description"],
)

modality_configs = {
    "video": video_modality,
    "state": state_modality,
    "action": action_modality,
    "language": language_modality,
}

print("UAV Modality configurations created successfully!")
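
The `delta_indices` above are offsets relative to the current timestep: video, state, and language use only offset 0, while actions cover offsets 0 through 15. A minimal sketch of how such offsets could be turned into absolute frame indices follows. Clamping to the episode bounds is an assumption made here for illustration, and `window_indices` is a hypothetical helper. The dataset class may pad differently.

```python
def window_indices(base_index, delta_indices, episode_length):
    """Absolute frame indices for one sample, clamped to the episode bounds."""
    last = episode_length - 1
    return [min(max(base_index + delta, 0), last) for delta in delta_indices]


observation_deltas = [0]          # video, state, language: current frame only
action_deltas = list(range(16))   # actions: current frame plus 15 future frames

print(window_indices(10, observation_deltas, episode_length=100))  # [10]

# Near the end of a 100-frame episode the window runs past the last frame:
# the result is 90..99, then 99 repeated six times (16 indices in total).
print(window_indices(90, action_deltas, episode_length=100))
```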

## Step 2: Configure UAV-specific Transforms

In [None]:
from gr00t.data.transform.base import ComposedModalityTransform
from gr00t.data.transform import VideoToTensor, VideoResize, VideoColorJitter, VideoToNumpy
from gr00t.data.transform.state_action import StateActionToTensor, StateActionTransform
from gr00t.data.transform.concat import ConcatTransform
from gr00t.model.transforms import GR00TTransform

transforms = [
    # Video transforms optimized for aerial footage
    VideoToTensor(apply_to=video_modality.modality_keys),
    VideoResize(
        apply_to=video_modality.modality_keys,
        height=224,
        width=224,
        antialias=True,
    ),
    VideoColorJitter(
        apply_to=video_modality.modality_keys,
        brightness=0.3,
        contrast=0.4,
        saturation=0.5,
        hue=0.08,
        backend="torchvision",
    ),
    VideoToNumpy(apply_to=video_modality.modality_keys),

    # State transforms for UAV telemetry
    StateActionToTensor(apply_to=state_modality.modality_keys),
    StateActionTransform(
        apply_to=state_modality.modality_keys, 
        normalization_modes={
            "state.position": "min_max",      # Normalize position coordinates
            "state.orientation": "min_max",   # Normalize Euler angles
            "state.velocity": "min_max",      # Normalize velocity vectors
            "state.battery": "min_max",       # Normalize battery percentage
            "state.gps": "min_max",           # Normalize GPS coordinates
        },
        target_rotations={
            "state.orientation": "euler_angles",  # Use Euler angles for UAV orientation
        },
    ),

    # Action transforms for UAV control
    StateActionToTensor(apply_to=action_modality.modality_keys),
    StateActionTransform(
        apply_to=action_modality.modality_keys, 
        normalization_modes={
            "action.flight_control": "min_max",    # Normalize flight controls
            "action.velocity_command": "min_max",  # Normalize velocity commands
            "action.gimbal": "min_max",            # Normalize gimbal controls
        },
    ),

    # Concatenation transform
    ConcatTransform(
        video_concat_order=video_modality.modality_keys,
        state_concat_order=state_modality.modality_keys,
        action_concat_order=action_modality.modality_keys,
    ),
    
    # GR00T-specific transform for UAV
    GR00TTransform(
        state_horizon=len(state_modality.delta_indices),
        action_horizon=len(action_modality.delta_indices),
        max_state_dim=64,   # Padded state width; must be at least 13
        max_action_dim=32,  # Padded action width; must be at least 9
    ),
]

composed_transforms = ComposedModalityTransform(transforms=transforms)
print("UAV-specific transforms configured successfully!")
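
To make the `min_max` mode and the `max_state_dim`/`max_action_dim` settings more concrete, here is a small standard-library sketch. It assumes that min-max normalization maps each dimension to [-1, 1] and that padding appends zeros. Both are assumptions about the library's behavior, and the helper names and the battery range are hypothetical.

```python
def min_max_normalize(values, lows, highs):
    """Scale each dimension from [low, high] to [-1, 1]."""
    out = []
    for v, lo, hi in zip(values, lows, highs):
        span = hi - lo
        # A constant dimension has no range; map it to 0 instead of dividing by zero.
        out.append(0.0 if span == 0 else 2.0 * (v - lo) / span - 1.0)
    return out


def min_max_denormalize(values, lows, highs):
    """Invert min_max_normalize for dimensions with a non-zero range."""
    return [(v + 1.0) / 2.0 * (hi - lo) + lo for v, lo, hi in zip(values, lows, highs)]


def pad_to(values, width):
    """Right-pad a vector with zeros up to a fixed width."""
    if len(values) > width:
        raise ValueError(f"vector of length {len(values)} exceeds width {width}")
    return list(values) + [0.0] * (width - len(values))


# Hypothetical battery statistics: 0 to 100 percent.
normalized = min_max_normalize([75.0], lows=[0.0], highs=[100.0])
print(normalized)  # [0.5]
print(min_max_denormalize(normalized, lows=[0.0], highs=[100.0]))  # [75.0]

# A 13D state padded to the configured max_state_dim of 64.
print(len(pad_to([0.0] * 13, 64)))  # 64
```

Because the scaling depends on the per-dimension minimum and maximum, actions predicted by the model have to be denormalized with the same statistics before they are sent to the vehicle.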

## Step 3: Load UAV Dataset (Example with demo data structure)

In [None]:
# Set up paths for UAV dataset
dataset_path = "./demo_data/uav.Landing"  # Replace with your UAV dataset path
embodiment_tag = EmbodimentTag.UAV_QUADCOPTER

print(f"Dataset path: {dataset_path}")
print(f"Embodiment tag: {embodiment_tag}")

# Note: the UAV dataset must be in LeRobot format, with a modality.json
# describing the state/action layout placed at meta/modality.json.

## Step 4: Initialize UAV Dataset

**Note**: This cell will work once you have actual UAV data prepared in LeRobot format.
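
As a starting point for the data preparation, the following sketch generates a candidate `modality.json` from the state and action layout used in this notebook. The schema (half-open `start`/`end` index ranges, an `original_key` per video stream, and the `observation.images.*` column names) is an assumption and should be checked against the dataset documentation of the GR00T version in use.

```python
import json

# Component widths copied from the state and action spaces of this notebook.
STATE_LAYOUT = {"position": 3, "orientation": 3, "velocity": 3, "battery": 1, "gps": 3}
ACTION_LAYOUT = {"flight_control": 4, "velocity_command": 3, "gimbal": 2}
VIDEO_KEYS = ["front_camera", "gimbal_camera"]


def index_ranges(layout):
    """Assign consecutive [start, end) index ranges to each component."""
    ranges, offset = {}, 0
    for name, width in layout.items():
        ranges[name] = {"start": offset, "end": offset + width}
        offset += width
    return ranges


modality = {
    "state": index_ranges(STATE_LAYOUT),
    "action": index_ranges(ACTION_LAYOUT),
    # Assumed column naming for the camera streams in the underlying dataset.
    "video": {key: {"original_key": f"observation.images.{key}"} for key in VIDEO_KEYS},
    "annotation": {"human.task_description": {}},
}

print(json.dumps(modality, indent=2))
```

The printed JSON could then be saved as `meta/modality.json` inside the dataset directory.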

In [None]:
# Uncomment when you have actual UAV data
# uav_dataset = LeRobotSingleDataset(
#     dataset_path=dataset_path,
#     modality_configs=modality_configs,
#     embodiment_tag=embodiment_tag,
#     transforms=composed_transforms,
#     video_backend="torchvision_av",
# )

# print(f"Initialized UAV dataset with {len(uav_dataset)} samples")
# print(f"Sample keys: {list(uav_dataset[0].keys())}")

print("Dataset initialization ready - add your UAV data to proceed!")

## Step 5: Visualize UAV State and Action Spaces

In [None]:
# Example UAV state visualization
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Mock data for visualization
time_steps = np.arange(0, 100)
position = np.random.randn(100, 3) * 10  # x, y, z
orientation = np.random.randn(100, 3) * 0.5  # roll, pitch, yaw
velocity = np.random.randn(100, 3) * 5  # vx, vy, vz
battery = 100 - time_steps * 0.5  # decreasing battery
flight_control = np.random.randn(100, 4) * 0.5  # throttle, roll, pitch, yaw
gimbal = np.random.randn(100, 2) * 0.3  # gimbal_pitch, gimbal_yaw

# Plot state space
axes[0, 0].plot(time_steps, position)
axes[0, 0].set_title('Position (x, y, z)')
axes[0, 0].legend(['x', 'y', 'z'])

axes[0, 1].plot(time_steps, orientation)
axes[0, 1].set_title('Orientation (roll, pitch, yaw)')
axes[0, 1].legend(['roll', 'pitch', 'yaw'])

axes[0, 2].plot(time_steps, velocity)
axes[0, 2].set_title('Velocity (vx, vy, vz)')
axes[0, 2].legend(['vx', 'vy', 'vz'])

# Plot the actions (flight control, gimbal) next to the battery level, which is part of the state
axes[1, 0].plot(time_steps, flight_control)
axes[1, 0].set_title('Flight Control (throttle, roll, pitch, yaw)')
axes[1, 0].legend(['throttle', 'roll', 'pitch', 'yaw'])

axes[1, 1].plot(time_steps, battery)
axes[1, 1].set_title('Battery Level')

axes[1, 2].plot(time_steps, gimbal)
axes[1, 2].set_title('Gimbal Control (pitch, yaw)')
axes[1, 2].legend(['pitch', 'yaw'])

plt.tight_layout()
plt.show()

print("\nUAV State/Action Space Dimensions:")
print("State Space: 13D (position: 3, orientation: 3, velocity: 3, battery: 1, gps: 3)")
print("Action Space: 9D (flight_control: 4, velocity_command: 3, gimbal: 2)")

## Step 6: Fine-tuning Configuration

Key differences for UAV adaptation:
1. **Leverage VLM**: Keep the vision-language model (VLM) frozen to retain its visual and language understanding
2. **Retrain Diffusion**: Fine-tune only the diffusion action head for UAV-specific control
3. **New Embodiment**: Use `EmbodimentTag.UAV_QUADCOPTER` so the UAV gets a separate action head
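
The freezing strategy amounts to partitioning the model's parameters by module and enabling gradients only for the action head. The sketch below illustrates that partition with the standard library alone. The parameter names and sizes are hypothetical stand-ins rather than GR00T's real module names. In PyTorch, the resulting flag would be applied to each parameter through its `requires_grad` attribute.

```python
# Hypothetical parameter names and element counts, for illustration only.
PARAMETERS = {
    "backbone.vision.patch_embed.weight": 1_000,
    "backbone.language.layers.0.weight": 4_000,
    "action_head.uav_quadcopter.encoder.weight": 300,
    "action_head.uav_quadcopter.decoder.weight": 200,
}

# Only parameters under these prefixes would receive gradients.
TRAINABLE_PREFIXES = ("action_head.",)


def trainable_flags(parameters, prefixes):
    """Return {name: True/False}, True when the parameter should be trained."""
    return {name: name.startswith(prefixes) for name in parameters}


flags = trainable_flags(PARAMETERS, TRAINABLE_PREFIXES)
trainable = sum(size for name, size in PARAMETERS.items() if flags[name])
total = sum(PARAMETERS.values())
print(f"trainable: {trainable}/{total} ({trainable / total:.1%})")
# trainable: 500/5500 (9.1%)
```

Printing the trainable fraction before a run is a cheap way to confirm that the backbone really is frozen.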

In [None]:
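# Example fine-tuning command (run in a terminal).
# The raw string (r"""...""") keeps the trailing backslashes when printed;
# in a normal string, a backslash before a newline joins the lines.
# Flag names are shown as written in this tutorial; check them against the
# argument definitions in scripts/gr00t_finetune.py before running.
finetune_command = r"""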
python scripts/gr00t_finetune.py \
    --model_path="nvidia/GR00T-N1.5-3B" \
    --data_path="./demo_data/uav.Landing" \
    --data_config="uav_quadcopter" \
    --embodiment_tag="uav_quadcopter" \
    --output_dir="./checkpoints/uav_quadcopter_finetune" \
    --batch_size=4 \
    --learning_rate=1e-4 \
    --num_epochs=50 \
    --freeze_backbone=true \
    --freeze_language_model=true \
    --only_train_action_head=true
"""

print("Fine-tuning command for UAV:")
print(finetune_command)

print("\nKey training parameters:")
print("- freeze_backbone=true: Keep VLM visual encoder frozen")
print("- freeze_language_model=true: Keep language model frozen")
print("- only_train_action_head=true: Only train UAV-specific diffusion head")
print("- embodiment_tag=uav_quadcopter: Use new UAV embodiment")
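
The same command can also be assembled as an argument list, which avoids shell line continuations and quoting issues. This is a sketch only: the flag names and values are copied from the command above and have not been verified against the fine-tuning script, and `build_command` is a hypothetical helper.

```python
import shlex

# Options copied from the command above; flag names are unverified.
options = {
    "model_path": "nvidia/GR00T-N1.5-3B",
    "data_path": "./demo_data/uav.Landing",
    "data_config": "uav_quadcopter",
    "embodiment_tag": "uav_quadcopter",
    "output_dir": "./checkpoints/uav_quadcopter_finetune",
    "batch_size": 4,
    "learning_rate": 1e-4,
    "num_epochs": 50,
    "freeze_backbone": "true",
    "freeze_language_model": "true",
    "only_train_action_head": "true",
}


def build_command(script, opts):
    """Build an argv list of the form: python <script> --name=value ..."""
    return ["python", script] + [f"--{name}={value}" for name, value in opts.items()]


argv = build_command("scripts/gr00t_finetune.py", options)

# shlex.join produces a single shell-safe line for copy and paste.
print(shlex.join(argv))
```

To launch the run from Python instead of a terminal, the list can be passed to `subprocess.run(argv, check=True)`.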

## Step 7: Policy Inference for UAV

Example of how to use the fine-tuned UAV model for inference.

In [None]:
# Example inference code, kept as a string so this notebook runs without a checkpoint.
# Paste it into its own cell once the model is trained.
inference_example = """
from gr00t.model.policy import Gr00tPolicy
from gr00t.data.embodiment_tags import EmbodimentTag

# Load the fine-tuned UAV model
uav_policy = Gr00tPolicy(
    model_path="./checkpoints/uav_quadcopter_finetune",
    modality_config=modality_configs,
    modality_transform=composed_transforms,
    embodiment_tag=EmbodimentTag.UAV_QUADCOPTER,
    device="cuda"
)

# Get action for current observation
# observation should contain:
# - video.front_camera: front camera image
# - video.gimbal_camera: gimbal camera image  
# - state.position: [x, y, z]
# - state.orientation: [roll, pitch, yaw]
# - state.velocity: [vx, vy, vz]
# - state.battery: [battery_level]
# - state.gps: [lat, lon, alt]
# - annotation.human.task_description: "Land on the designated platform"

action_chunk = uav_policy.get_action(observation)

# Extract UAV controls; each entry is expected to span the 16-step action horizon
flight_control = action_chunk["action.flight_control"]  # per step: [throttle, roll, pitch, yaw]
velocity_command = action_chunk["action.velocity_command"]  # per step: [vx, vy, vz]
gimbal_control = action_chunk["action.gimbal"]  # per step: [gimbal_pitch, gimbal_yaw]
"""

print("UAV Inference Example:")
print(inference_example)

print("\nExpected UAV action outputs:")
print("- flight_control: [throttle, roll, pitch, yaw] - Main flight controls")
print("- velocity_command: [vx, vy, vz] - Velocity setpoints")
print("- gimbal: [gimbal_pitch, gimbal_yaw] - Camera gimbal controls")
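
Since the action modality covers a 16-step horizon, a control loop has to decide how many of the predicted steps to execute before querying the policy again. The sketch below shows one common pattern, executing only the first few steps of each chunk. It assumes that every action entry is a sequence of per-step vectors. The chunk contents, `steps_to_execute`, and the choice of four steps are illustrative and not prescribed by GR00T.

```python
def steps_to_execute(action_chunk, execute_steps):
    """Yield per-step command dicts for the first `execute_steps` entries of a chunk."""
    horizon = min(len(values) for values in action_chunk.values())
    for t in range(min(execute_steps, horizon)):
        yield {key: values[t] for key, values in action_chunk.items()}


# Hypothetical 16-step chunk with constant commands, standing in for policy output.
chunk = {
    "action.flight_control": [[0.5, 0.0, 0.0, 0.0]] * 16,
    "action.velocity_command": [[0.0, 0.0, -0.2]] * 16,
    "action.gimbal": [[-0.3, 0.0]] * 16,
}

for step in steps_to_execute(chunk, execute_steps=4):
    throttle, roll, pitch, yaw = step["action.flight_control"]
    # A real loop would forward `step` to the autopilot here, then request
    # a fresh chunk from the policy after the last executed step.
    print(throttle, step["action.velocity_command"], step["action.gimbal"])
```

Executing fewer steps per chunk makes the vehicle react faster to new observations, at the cost of more frequent policy calls.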

## Summary

This notebook demonstrated how to adapt the GR00T humanoid VLA model for UAV quadcopter control:

### Key Adaptations:
1. **New Embodiment Tag**: `EmbodimentTag.UAV_QUADCOPTER` for a separate action head
2. **UAV State Space**: 13D including position, orientation, velocity, battery, GPS
3. **UAV Action Space**: 9D including flight controls, velocity commands, gimbal
4. **Modality Configuration**: Custom state/action mapping for UAV telemetry
5. **Transforms**: UAV-specific normalization and data processing

### Training Strategy:
- **Leverage VLM**: Keep the vision-language model frozen to retain its visual and language understanding
- **Retrain Diffusion**: Fine-tune only the action head for UAV-specific control patterns
- **Embodiment-Specific**: Train a separate action head while sharing the visual/language representations

### Next Steps:
1. Prepare your UAV dataset in LeRobot format
2. Copy the modality.json file to your dataset's meta/ directory
3. Run the fine-tuning script with UAV-specific parameters
4. Test the trained model on UAV control tasks
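
Before starting a fine-tuning run, a quick check that the dataset directory contains the expected metadata can save a failed job. The sketch below is one way to do that. Only `meta/modality.json` is named in this tutorial. The `meta/info.json` entry is an assumption about the LeRobot layout, and `missing_paths` is a hypothetical helper.

```python
import tempfile
from pathlib import Path

# meta/modality.json comes from this tutorial; meta/info.json is an assumed
# part of the LeRobot layout. Extend the list to match your dataset version.
REQUIRED_PATHS = ["meta/info.json", "meta/modality.json"]


def missing_paths(dataset_root, required=REQUIRED_PATHS):
    """Return the required relative paths that do not exist under dataset_root."""
    root = Path(dataset_root)
    return [rel for rel in required if not (root / rel).exists()]


# Demonstration on a throwaway directory instead of a real dataset.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "uav.Landing"
    (root / "meta").mkdir(parents=True)
    (root / "meta" / "info.json").write_text("{}")
    print(missing_paths(root))  # ['meta/modality.json']

    (root / "meta" / "modality.json").write_text("{}")
    print(missing_paths(root))  # []
```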

The key insight is that aerial robotics can benefit from the same visual language understanding as ground robots, with only the action generation needing UAV-specific adaptation.