# UAV Quadcopter GR00T Adaptation Test

This notebook tests the UAV quadcopter adaptation of the GR00T model.

**Key Components:**
- New UAV Embodiment Tag: `EmbodimentTag.UAV_QUADCOPTER`
- UAV State Space: 13D (position, orientation, velocity, battery, GPS)
- UAV Action Space: 9D (flight_control, velocity_command, gimbal)
- Reuse the pretrained VLM backbone (kept frozen) and retrain only the diffusion action head

In [1]:
# Bring in the GR00T components used throughout the UAV adaptation test
from gr00t.utils.misc import any_describe
from gr00t.data.dataset import LeRobotSingleDataset, ModalityConfig
from gr00t.data.embodiment_tags import EmbodimentTag
from gr00t.experiment.data_config import UAVQuadcopterDataConfig

# Enumerate every registered embodiment tag so we can confirm the UAV tag exists.
tag_lines = [f"  {member.name}: {member.value}" for member in EmbodimentTag]
print("Available embodiment tags:")
if tag_lines:
    print("\n".join(tag_lines))

print(f"\nUAV Quadcopter tag: {EmbodimentTag.UAV_QUADCOPTER}")

  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()
2025-07-22 07:46:15.503540: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-07-22 07:46:15.503601: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-07-22 07:46:15.505228: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
`use_fast` is set to `True` but the image processor class does not have a fast version.  Falling back to the slow version.


Available embodiment tags:
  GR1: gr1
  OXE_DROID: oxe_droid
  AGIBOT_GENIE1: agibot_genie1
  NEW_EMBODIMENT: new_embodiment
  UAV_QUADCOPTER: uav_quadcopter

UAV Quadcopter tag: EmbodimentTag.UAV_QUADCOPTER


In [2]:
import os

import numpy as np

import gr00t

# Build the UAV data config and list the modalities and transforms it declares.
print("Testing UAV Quadcopter Data Configuration...")

uav_config = UAVQuadcopterDataConfig()
print("✓ UAV config initialized")

modality_configs = uav_config.modality_config()
print("✓ Modality configs generated")

print("\nUAV Modality Keys:")
for modality_name, modality_cfg in modality_configs.items():
    print(f"  {modality_name}: {modality_cfg.modality_keys}")

transforms = uav_config.transform()
print("✓ UAV transforms configured")

transform_steps = transforms.transforms
print(f"\nTotal transforms: {len(transform_steps)}")
for step_idx, step in enumerate(transform_steps):
    print(f"  {step_idx}: {type(step).__name__}")

Testing UAV Quadcopter Data Configuration...
✓ UAV config initialized
✓ Modality configs generated

UAV Modality Keys:
  video: ['video.front_camera', 'video.gimbal_camera']
  state: ['state.position', 'state.orientation', 'state.velocity', 'state.battery', 'state.gps']
  action: ['action.flight_control', 'action.velocity_command', 'action.gimbal']
  language: ['annotation.human.task_description']
✓ UAV transforms configured

Total transforms: 10
  0: VideoToTensor
  1: VideoResize
  2: VideoColorJitter
  3: VideoToNumpy
  4: StateActionToTensor
  5: StateActionTransform
  6: StateActionToTensor
  7: StateActionTransform
  8: ConcatTransform
  9: GR00TTransform


In [3]:
# Describe the UAV state/action spaces; totals are summed from the config's keys
# rather than hard-coded, so they stay correct if the config changes.
print("UAV State and Action Space Definition:")
print("=" * 50)

# Human-readable components and dimensionality for each known key.
STATE_KEY_INFO = {
    "state.position": ("x, y, z", 3),
    "state.orientation": ("roll, pitch, yaw", 3),
    "state.velocity": ("vx, vy, vz", 3),
    "state.battery": ("battery level", 1),
    "state.gps": ("lat, lon, alt", 3),
}
ACTION_KEY_INFO = {
    "action.flight_control": ("throttle, roll, pitch, yaw", 4),
    "action.velocity_command": ("vx, vy, vz", 3),
    "action.gimbal": ("gimbal_pitch, gimbal_yaw", 2),
}


def describe_keys(keys, key_info):
    """Print one line per key and return the summed dimensionality.

    Keys absent from ``key_info`` are reported (instead of silently skipped)
    and contribute nothing to the returned total.
    """
    total = 0
    for key in keys:
        if key not in key_info:
            print(f"  {key}: <unknown layout>")
            continue
        components, dims = key_info[key]
        unit = "dim" if dims == 1 else "dims"
        print(f"  {key}: {components} ({dims} {unit})")
        total += dims
    return total


print("\nState Space (13D):")
state_keys = uav_config.state_keys
total_state_dims = describe_keys(state_keys, STATE_KEY_INFO)
print(f"\nTotal state dimensions: {total_state_dims}")

print("\nAction Space (9D):")
action_keys = uav_config.action_keys
total_action_dims = describe_keys(action_keys, ACTION_KEY_INFO)
print(f"\nTotal action dimensions: {total_action_dims}")

print("\nVideo Inputs:")
for key in uav_config.video_keys:
    print(f"  {key}")

UAV State and Action Space Definition:

State Space (13D):
  state.position: x, y, z (3 dims)
  state.orientation: roll, pitch, yaw (3 dims)
  state.velocity: vx, vy, vz (3 dims)
  state.battery: battery level (1 dim)
  state.gps: lat, lon, alt (3 dims)

Total state dimensions: 13

Action Space (9D):
  action.flight_control: throttle, roll, pitch, yaw (4 dims)
  action.velocity_command: vx, vy, vz (3 dims)
  action.gimbal: gimbal_pitch, gimbal_yaw (2 dims)

Total action dimensions: 9

Video Inputs:
  video.front_camera
  video.gimbal_camera


In [4]:
# Load the UAV dataset (no transforms yet) under the UAV embodiment tag.
embodiment_tag = EmbodimentTag.UAV_QUADCOPTER

REPO_PATH = os.path.dirname(os.path.dirname(gr00t.__file__))
DATA_PATH = os.path.join(REPO_PATH, "demo_data", "uav.Landing")

# Fail early with an actionable message instead of an opaque loader error when
# the synthetic dataset has not been generated yet.
if not os.path.isdir(DATA_PATH):
    raise FileNotFoundError(
        f"UAV dataset not found at {DATA_PATH}. "
        "Run the synthetic UAV dataset generation cell first."
    )

dataset = LeRobotSingleDataset(DATA_PATH, modality_configs, embodiment_tag=embodiment_tag)

print("\n" * 2)
print("=" * 100)
print(f"{' UAV Dataset ':=^100}")
print("=" * 100)

# Inspect the sample at index 7 (the 8th data point).
resp = dataset[7]
any_describe(resp)
print(resp.keys())

Initialized dataset uav.Landing with EmbodimentTag.UAV_QUADCOPTER





KeyError: 'task_description'

### Show image frames from the dataset

In [None]:
# Plot a grid of frames sampled every 10 steps from the first UAV camera.
import matplotlib.pyplot as plt

# The UAV config has no "video.ego_view" key; use the first configured camera
# (video.front_camera for the UAV config).
camera_key = modality_configs["video"].modality_keys[0]

# Sample every 10th step, without indexing past the end of a short dataset.
frame_indices = list(range(0, min(100, len(dataset)), 10))
# [0] picks the first timestep of the returned video clip.
images_list = [dataset[idx][camera_key][0] for idx in frame_indices]

fig, axs = plt.subplots(2, 5, figsize=(20, 10))
for ax in axs.flat:
    ax.axis("off")
for ax, idx, img in zip(axs.flat, frame_indices, images_list):
    ax.imshow(img)
    ax.set_title(f"Frame {idx}")
plt.tight_layout()  # adjust the subplots to fit into the figure area.
plt.show()


## Apply transforms to the LeRobot dataset

In [None]:
from gr00t.data.transform.base import ComposedModalityTransform
from gr00t.data.transform import VideoToTensor, VideoCrop, VideoResize, VideoColorJitter, VideoToNumpy
from gr00t.data.transform.state_action import StateActionToTensor, StateActionTransform
from gr00t.data.transform.concat import ConcatTransform


video_modality = modality_configs["video"]
state_modality = modality_configs["state"]
action_modality = modality_configs["action"]


def _min_max_state_action(keys):
    """Return [tensor conversion, per-key min-max normalization] for the given keys."""
    return [
        StateActionToTensor(apply_to=keys),
        StateActionTransform(
            apply_to=keys,
            normalization_modes=dict.fromkeys(keys, "min_max"),
        ),
    ]


# Video pipeline: tensorize -> crop -> resize -> color augment -> back to numpy.
video_steps = [
    VideoToTensor(apply_to=video_modality.modality_keys),
    VideoCrop(apply_to=video_modality.modality_keys, scale=0.95),
    VideoResize(apply_to=video_modality.modality_keys, height=224, width=224, interpolation="linear"),
    VideoColorJitter(apply_to=video_modality.modality_keys, brightness=0.3, contrast=0.4, saturation=0.5, hue=0.08),
    VideoToNumpy(apply_to=video_modality.modality_keys),
]
state_steps = _min_max_state_action(state_modality.modality_keys)
action_steps = _min_max_state_action(action_modality.modality_keys)

# Finally concatenate each modality's keys in their configured order.
concat_step = ConcatTransform(
    video_concat_order=video_modality.modality_keys,
    state_concat_order=state_modality.modality_keys,
    action_concat_order=action_modality.modality_keys,
)

to_apply_transforms = ComposedModalityTransform(
    transforms=video_steps + state_steps + action_steps + [concat_step]
)

In [None]:
# Reload the dataset, this time running every sample through the pipeline above.
dataset = LeRobotSingleDataset(
    DATA_PATH,
    modality_configs,
    embodiment_tag=embodiment_tag,
    transforms=to_apply_transforms,
)

# Inspect the transformed sample at index 7.
resp = dataset[7]
any_describe(resp)
print(resp.keys())

In [None]:
# Build a mock UAV observation and verify it matches the 13-D state layout.
print("\nCreating Mock UAV Observation:")
print("=" * 40)

# Seeded generator so the mock camera frames are reproducible across runs.
rng = np.random.default_rng(0)

mock_uav_observation = {
    # Video feeds (random pixels stand in for camera frames). integers() has an
    # exclusive upper bound, so 256 covers the full uint8 range 0..255.
    "video.front_camera": rng.integers(0, 256, (224, 224, 3), dtype=np.uint8),
    "video.gimbal_camera": rng.integers(0, 256, (224, 224, 3), dtype=np.uint8),

    # UAV telemetry state (13D total)
    "state.position": np.array([10.5, 5.2, 15.0], dtype=np.float32),  # x, y, z
    "state.orientation": np.array([0.1, -0.05, 1.57], dtype=np.float32),  # roll, pitch, yaw
    "state.velocity": np.array([2.0, 0.5, -0.1], dtype=np.float32),  # vx, vy, vz
    "state.battery": np.array([85.5], dtype=np.float32),  # battery %
    "state.gps": np.array([37.7749, -122.4194, 100.0], dtype=np.float32),  # lat, lon, alt

    # Task description
    "annotation.human.task_description": "Navigate to landing zone and perform precision landing"
}

print("Mock observation keys:")
for key, value in mock_uav_observation.items():
    if hasattr(value, "shape"):
        print(f"  {key}: shape {value.shape}, dtype {value.dtype}")
    else:
        print(f"  {key}: {type(value)}")

# Verify state dimensionality against the config's state keys.
missing_state_keys = [key for key in uav_config.state_keys if key not in mock_uav_observation]
state_dims = sum(
    len(mock_uav_observation[key]) for key in uav_config.state_keys if key in mock_uav_observation
)

print(f"\nTotal state dimensions: {state_dims}")
# Only claim validation after actually checking the structure.
assert not missing_state_keys, f"mock observation is missing state keys: {missing_state_keys}"
assert state_dims == 13, f"expected 13 state dimensions, got {state_dims}"
print("✓ UAV observation structure validated")

In [None]:
# Build a mock 9-D UAV action, verify its size, and print a readable breakdown.
print("\nUAV Action Space Demonstration:")
print("=" * 40)

# Mock UAV action (9D total)
mock_uav_action = {
    "action.flight_control": np.array([0.6, 0.1, -0.05, 0.2], dtype=np.float32),  # throttle, roll, pitch, yaw
    "action.velocity_command": np.array([1.5, 0.3, -0.2], dtype=np.float32),      # vx, vy, vz
    "action.gimbal": np.array([-0.1, 0.4], dtype=np.float32),                     # gimbal_pitch, gimbal_yaw
}

print("Mock UAV action:")
for key, value in mock_uav_action.items():
    print(f"  {key}: {value}")

# Verify action dimensionality against the config's action keys.
missing_action_keys = [key for key in uav_config.action_keys if key not in mock_uav_action]
action_dims = sum(
    len(mock_uav_action[key]) for key in uav_config.action_keys if key in mock_uav_action
)

print(f"\nTotal action dimensions: {action_dims}")
assert not missing_action_keys, f"mock action is missing keys: {missing_action_keys}"
assert action_dims == 9, f"expected 9 action dimensions, got {action_dims}"

# Interpret action components
print("\nAction Interpretation:")
flight_control = mock_uav_action["action.flight_control"]
print("  Flight Control:")
print(f"    Throttle: {flight_control[0]:.3f} (0=min, 1=max thrust)")
print(f"    Roll:     {flight_control[1]:.3f} (rad, + = right roll)")
print(f"    Pitch:    {flight_control[2]:.3f} (rad, + = nose up)")
print(f"    Yaw:      {flight_control[3]:.3f} (rad, + = clockwise)")

velocity_cmd = mock_uav_action["action.velocity_command"]
print("  Velocity Command:")
print(f"    Vx: {velocity_cmd[0]:.3f} m/s (forward)")
print(f"    Vy: {velocity_cmd[1]:.3f} m/s (right)")
# NOTE(review): "down" is a forward/right/down (NED-style) convention, whereas the
# synthetic data generator treats z as altitude (descent => negative vz).
# Confirm which frame the controller expects before relying on this sign.
print(f"    Vz: {velocity_cmd[2]:.3f} m/s (down)")

gimbal = mock_uav_action["action.gimbal"]
print("  Gimbal Control:")
print(f"    Pitch: {gimbal[0]:.3f} rad (+ = tilt down)")
print(f"    Yaw:   {gimbal[1]:.3f} rad (+ = pan right)")

print("✓ UAV action space validated")

In [None]:
# Generate a small synthetic UAV landing dataset for testing
print("🚁 Generating Synthetic UAV Landing Dataset")
print("=" * 50)

import json
import sys
import subprocess
import os

# Write to the same location the loading cells read from (DATA_PATH, defined in
# the dataset-loading cell) rather than a path relative to the kernel's cwd.
dataset_path = DATA_PATH
num_episodes = 10  # Small dataset for testing

print(f"Generating {num_episodes} episodes of UAV landing data...")
print(f"Output path: {dataset_path}")

# Create the data generator inline for notebook use
import cv2
import pandas as pd
from datetime import datetime
from pathlib import Path

# Single task shared by every episode; frames store its index (0).
TASK_DESCRIPTION = "Land on the designated platform"


class SimpleUAVDataGenerator:
    """Simplified UAV landing-data generator for notebook use.

    Writes one parquet file per episode under ``data/chunk-XXX/`` plus the
    ``meta/`` files (modality.json, info.json, tasks.jsonl, episodes.jsonl).
    No video files are rendered.

    Args:
        output_dir: Root directory of the generated dataset.
        seed: Base seed for orientation noise; episode ``i`` uses ``seed + i``.
        chunks_size: Number of episodes stored per ``chunk-XXX`` directory.
    """

    def __init__(self, output_dir: str, seed: int = 0, chunks_size: int = 1000):
        self.output_dir = Path(output_dir)
        self.image_size = (480, 640, 3)
        self.fps = 20
        self.episode_length = 100  # 5 seconds at 20 fps
        self.seed = seed
        self.chunks_size = chunks_size
        self.episode_lengths = {}  # episode_idx -> frames written (for episodes.jsonl)
        self._total_frames = 0     # running global frame index across saved episodes
        self.setup_directories()

    def setup_directories(self):
        """Create the dataset root and its data/, meta/ and videos/ subdirectories."""
        self.output_dir.mkdir(parents=True, exist_ok=True)
        for sub in ("data", "meta", "videos"):
            (self.output_dir / sub).mkdir(exist_ok=True)

    def generate_simple_episode(self, episode_idx: int):
        """Generate a straight-line descent from (10, 5, 20) to (0, 0, 0).

        Returns:
            dict with 'states' (T, 13), 'actions' (T, 9) and 'positions' (T, 3).
        """
        # Per-episode seed: reproducible yet distinct noise for each episode.
        rng = np.random.default_rng(self.seed + episode_idx)
        n = self.episode_length
        start = np.array([10.0, 5.0, 20.0])  # x, y, altitude

        # progress spans 0..1 inclusive so the last frame actually reaches the ground.
        progress = np.linspace(0.0, 1.0, n)
        positions = start * (1.0 - progress)[:, None]

        # Constant velocity consistent with the trajectory: displacement / duration.
        duration = max(n - 1, 1) / self.fps
        velocities = np.tile(-start / duration, (n, 1))

        # Level orientation with small roll/pitch variations, zero yaw.
        orientations = np.column_stack([
            rng.normal(0, 0.05, n),  # roll
            rng.normal(0, 0.05, n),  # pitch
            np.zeros(n),             # yaw
        ])

        # Simple proportional control: more throttle at altitude.
        throttle = 0.4 + positions[:, 2] * 0.01
        actions = np.column_stack([
            throttle, np.zeros((n, 3)),  # flight_control: throttle, roll, pitch, yaw
            velocities,                  # velocity_command
            np.zeros((n, 2)),            # gimbal
        ])

        battery = 100 - np.arange(n) * 0.5  # Battery decreases
        gps = np.column_stack([
            np.full(n, 37.7749), np.full(n, -122.4194), positions[:, 2]  # SF coordinates
        ])
        states = np.column_stack([
            positions,     # position (3)
            orientations,  # orientation (3)
            velocities,    # velocity (3)
            battery,       # battery (1)
            gps,           # gps (3)
        ])

        return {'states': states, 'actions': actions, 'positions': positions}

    def save_simple_episode(self, episode_data, episode_idx):
        """Save one episode as ``data/chunk-XXX/episode_XXXXXX.parquet``.

        Episodes are expected to be saved in order so the global ``index``
        column is contiguous across episodes.
        """
        chunk_dir = self.output_dir / "data" / f"chunk-{episode_idx // self.chunks_size:03d}"
        chunk_dir.mkdir(parents=True, exist_ok=True)

        n = len(episode_data['states'])
        frame_index = np.arange(n)
        episode_df = pd.DataFrame({
            'observation.state': [row.tolist() for row in episode_data['states']],
            'action': [row.tolist() for row in episode_data['actions']],
            'timestamp': frame_index / self.fps,
            'frame_index': frame_index,
            'episode_index': episode_idx,
            'index': self._total_frames + frame_index,
            'task_index': 0,
            # Per-frame column backing annotation.human.task_description; the
            # value is an index into tasks.jsonl.
            'annotation.human.task_description': 0,
        })
        episode_df.to_parquet(chunk_dir / f"episode_{episode_idx:06d}.parquet")

        self.episode_lengths[episode_idx] = n
        self._total_frames += n
        print(f"  ✓ Episode {episode_idx} data saved")

    def create_metadata(self, num_episodes):
        """Write meta/modality.json, info.json, tasks.jsonl and episodes.jsonl.

        Call after all episodes are saved so frame counts are known.
        """
        meta_dir = self.output_dir / "meta"

        # start/end index ranges into the concatenated state / action vectors.
        modality = {
            "state": {
                "position": {"start": 0, "end": 3},
                "orientation": {"start": 3, "end": 6, "rotation_type": "euler_angles"},
                "velocity": {"start": 6, "end": 9},
                "battery": {"start": 9, "end": 10},
                "gps": {"start": 10, "end": 13}
            },
            "action": {
                "flight_control": {"start": 0, "end": 4},
                "velocity_command": {"start": 4, "end": 7},
                "gimbal": {"start": 7, "end": 9}
            },
            # NOTE(review): no video files are rendered, so loaders that decode
            # these video keys will fail until videos are generated — TODO.
            "video": {
                "front_camera": {"original_key": "observation.images.front_camera"},
                "gimbal_camera": {"original_key": "observation.images.gimbal_camera"}
            },
            "annotation": {
                # Must name a column that exists in the episode parquet files.
                "human.task_description": {"original_key": "annotation.human.task_description"}
            }
        }
        with open(meta_dir / "modality.json", 'w') as f:
            json.dump(modality, f, indent=2)

        # NOTE(review): LeRobot-style loaders may also expect "features",
        # "video_path" and meta/stats.json; none are written here — TODO confirm.
        info = {
            "data_name": "uav_landing_test",
            "fps": self.fps,
            "video": False,  # No video for simple test
            "total_episodes": num_episodes,
            "total_frames": self._total_frames,
            "total_tasks": 1,
            "chunks_size": self.chunks_size,
            "data_path": "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet",
        }
        with open(meta_dir / "info.json", 'w') as f:
            json.dump(info, f, indent=2)

        # One shared task; frames reference it by task_index 0.
        with open(meta_dir / "tasks.jsonl", 'w') as f:
            f.write(json.dumps({"task_index": 0, "task": TASK_DESCRIPTION}) + '\n')

        with open(meta_dir / "episodes.jsonl", 'w') as f:
            for ep in range(num_episodes):
                record = {
                    "episode_index": ep,
                    "tasks": [TASK_DESCRIPTION],
                    "length": self.episode_lengths.get(ep, self.episode_length),
                }
                f.write(json.dumps(record) + '\n')

        print("  ✓ Metadata files created")


# Generate the test dataset
generator = SimpleUAVDataGenerator(dataset_path)

print("Generating episodes...")
for i in range(num_episodes):
    episode_data = generator.generate_simple_episode(i)
    generator.save_simple_episode(episode_data, i)

generator.create_metadata(num_episodes)

print(f"\n✅ UAV test dataset generated successfully!")
print(f"Location: {dataset_path}")
print(f"Episodes: {num_episodes}")
print(f"Format: LeRobot compatible")
print(f"\nDataset structure:")
print(f"  📁 {dataset_path}/")
print(f"    📁 data/ - Episode data (states, actions)")
print(f"    📁 meta/ - Metadata (modality.json, info.json, tasks.jsonl, episodes.jsonl)")
print(f"    📄 Contains 13D state space and 9D action space")

In [None]:
# Load the generated UAV dataset through the configured transforms and
# sanity-check the structure of the returned samples.
import traceback

print("🔍 Testing Generated UAV Dataset")
print("=" * 50)


def report_trailing_dim(sample, key, label, expected):
    """Print the trailing size of ``sample[key]`` and whether it equals ``expected``.

    A missing key is reported explicitly instead of being skipped silently.
    """
    value = sample.get(key)
    if value is None:
        print(f"   ⚠️  '{key}' not found in sample; available keys: {sorted(sample.keys())}")
        return
    # Last axis for multi-dimensional data, plain length for 1-D data.
    size = value.shape[-1] if len(value.shape) > 1 else len(value)
    print(f"   {label} dimensions: {size} (expected: {expected})")
    if size == expected:
        print(f"   ✅ {label} space correct!")
    else:
        print(f"   ❌ {label} space mismatch!")


try:
    uav_dataset = LeRobotSingleDataset(
        dataset_path=dataset_path,
        modality_configs=modality_configs,
        embodiment_tag=EmbodimentTag.UAV_QUADCOPTER,
        transforms=transforms,
        video_backend="torchvision_av",
    )
    num_samples = len(uav_dataset)

    print("✅ Dataset loaded successfully!")
    # len() counts indexable samples (time steps), not episodes.
    print(f"   Samples: {num_samples}")
    print(f"   Embodiment: {EmbodimentTag.UAV_QUADCOPTER}")

    sample = uav_dataset[0]
    print("\n📊 Sample Data Structure:")
    for key in sorted(sample.keys()):
        value = sample[key]
        if hasattr(value, "shape"):
            print(f"   {key}: shape {value.shape}, dtype {value.dtype}")
        else:
            print(f"   {key}: {type(value)}")

    print("\n🔢 Dimension Verification:")
    # NOTE(review): the last step in `transforms` is GR00TTransform; if it pads
    # state/action to a model-wide maximum size, these checks report a mismatch
    # even for a correct dataset — TODO confirm and compare unpadded dims if so.
    report_trailing_dim(sample, "state", "State", 13)
    report_trailing_dim(sample, "action", "Action", 9)

    print("\n📈 Testing Multiple Samples:")
    for idx in range(min(3, num_samples)):
        print(f"   Sample {idx}: {len(uav_dataset[idx])} keys")

    print("\n🎯 Dataset Ready for UAV Training!")
    print("   Use this dataset with:")
    print(f"   - Fine-tuning: python scripts/uav_finetune.py --data_path {dataset_path}")
    print("   - Evaluation: python getting_started/examples/eval_uav_quadcopter.py")

except Exception as e:
    # Best-effort check: report the failure with its traceback but keep the notebook going.
    print(f"❌ Error loading dataset: {type(e).__name__}: {e}")
    traceback.print_exc()
    print("This might be due to missing dependencies or data format issues")
    print("Dataset structure has been created, but full loading may require additional setup")

In [None]:
# Check the registered UAV data config against the expected state/action index layout.
print("🤖 Testing UAV Model Configuration")
print("=" * 50)

# Index layout of the concatenated vectors: key -> (label, start, end), end exclusive.
# Matches the start/end ranges written to meta/modality.json by the generator cell.
STATE_LAYOUT = {
    "state.position": ("Position (x,y,z)", 0, 3),
    "state.orientation": ("Orientation (roll,pitch,yaw)", 3, 6),
    "state.velocity": ("Velocity (vx,vy,vz)", 6, 9),
    "state.battery": ("Battery level", 9, 10),
    "state.gps": ("GPS (lat,lon,alt)", 10, 13),
}
ACTION_LAYOUT = {
    "action.flight_control": ("Flight control (throttle,roll,pitch,yaw)", 0, 4),
    "action.velocity_command": ("Velocity command (vx,vy,vz)", 4, 7),
    # Pitch then yaw, matching the gimbal_pitch, gimbal_yaw order used elsewhere.
    "action.gimbal": ("Gimbal control (pitch,yaw)", 7, 9),
}


def format_span(start, end):
    """Render the half-open range [start, end) as 'index i' or 'indices a-b'."""
    return f"index {start}" if end - start == 1 else f"indices {start}-{end - 1}"


try:
    # Imported to confirm the model module is available; not otherwise used here.
    from gr00t.model.gr00t_n1 import GR00TN1Policy  # noqa: F401
    from gr00t.experiment.data_config import get_data_config

    # Separate name so the UAVQuadcopterDataConfig instance used by earlier cells
    # is not clobbered.
    registry_config = get_data_config("uav_quadcopter")
    cfg_state_keys = list(registry_config.state_keys)
    cfg_action_keys = list(registry_config.action_keys)
    state_dim = sum(end - start for _, start, end in STATE_LAYOUT.values())
    action_dim = sum(end - start for _, start, end in ACTION_LAYOUT.values())

    print("✅ UAV Data Config loaded:")
    print(f"   State keys: {cfg_state_keys}")
    print(f"   Action keys: {cfg_action_keys}")
    print(f"   State dim: {state_dim}")
    print(f"   Action dim: {action_dim}")

    # The config's keys must match the layout the dims were derived from.
    assert set(cfg_state_keys) == set(STATE_LAYOUT), (
        f"state keys {cfg_state_keys} do not match layout {list(STATE_LAYOUT)}"
    )
    assert set(cfg_action_keys) == set(ACTION_LAYOUT), (
        f"action keys {cfg_action_keys} do not match layout {list(ACTION_LAYOUT)}"
    )

    print("\n🔧 Model Architecture Info:")
    print(f"   Action space: {action_dim}D")
    print(f"   State space: {state_dim}D")
    print(f"   Embodiment: {EmbodimentTag.UAV_QUADCOPTER}")

    print("\n🎯 State-Action Mappings:")
    for label, start, end in STATE_LAYOUT.values():
        print(f"   {label}: {format_span(start, end)}")
    print("   ")
    for label, start, end in ACTION_LAYOUT.values():
        print(f"   {label}: {format_span(start, end)}")

    print("\n✅ UAV Configuration Complete!")
    print("Ready for training with frozen VLM approach:")
    print("  1. VLM backbone: FROZEN (pretrained vision-language understanding)")
    print("  2. Action head: TRAINABLE (UAV-specific diffusion)")
    print("  3. Dataset: LeRobot compatible format")

except ImportError as e:
    print(f"⚠️  Import error: {e}")
    print("Some GR00T modules may not be available in this environment")
    print("This is expected if running outside the full GR00T environment")
except Exception as e:
    print(f"❌ Error: {type(e).__name__}: {e}")
    print("Configuration test encountered an issue")