Merged
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
+# SCM syntax highlighting & preventing 3-way merges
+pixi.lock merge=binary linguist-language=YAML linguist-generated=true
35 changes: 20 additions & 15 deletions .github/workflows/testing.yml
@@ -1,22 +1,27 @@
-name: Testing # Skips RL tests because stable-baselines3 comes with a lot of heavy-weight dependencies
-
-on: [push]
+name: Testing
+on: [push, pull_request]
 
 jobs:
   test:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
-      - uses: mamba-org/setup-micromamba@v1
+
+      - name: Setup Pixi (installs pixi + caches envs) # https://github.com/marketplace/actions/setup-pixi
+        uses: prefix-dev/setup-pixi@v0.9.0 # pin the action version
         with:
-          micromamba-version: '2.0.2-1' # any version from https://github.com/mamba-org/micromamba-releases
-          environment-name: test-env
-          init-shell: bash
-          create-args: python=3.11
-          cache-environment: true
-      - name: Install dependencies and package
-        run: pip install .[test]
-        shell: micromamba-shell {0}
-      - name: Test with pytest
-        run: pytest tests --cov=crazyflow
-        shell: micromamba-shell {0}
+          pixi-version: v0.49.0 # pin the pixi binary version (optional)
+          cache: true # enable caching of installed envs
+          # only write new caches on main pushes (TODO: Enable)
+          # cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
+          # ensure the 'test' environment(s) are installed
+          environments: test
+          # don't activate env (we'll call pixi run -e test explicitly)
+          activate-environment: false
+          # prefer using existing lockfile if present (faster, deterministic)
+          locked: true
+
+      - name: Verify pixi and run tests
+        run: |
+          pixi --version
+          pixi run -e test pytest
5 changes: 4 additions & 1 deletion .gitignore
@@ -16,4 +16,7 @@ build
**/*.pt
tutorials/ppo/wandb
dist
-benchmark/data
+benchmark/data
+# pixi environments
+.pixi
+*.egg-info
12 changes: 9 additions & 3 deletions README.md
@@ -2,7 +2,7 @@

--------------------------------------------------------------------------------

-Fast, parallelizable simulations of Crazyflies with JAX and MuJoCo.
+Fast, parallelizable simulations of Crazyflies with JAX.

[![Python Version]][Python Version URL] [![Ruff Check]][Ruff Check URL] [![Documentation Status]][Documentation Status URL] [![Tests]][Tests URL]

@@ -32,7 +32,6 @@ The simulation is built as a pipeline of functions that are composed at initiali
Multiple physics models are supported:
- analytical: A first-principles model based on physical equations
- sys_id: A system-identified model trained on real drone data
-- mujoco: MuJoCo physics engine for more complex interactions

#### Control Modes
Different control interfaces are available:
@@ -41,9 +40,10 @@ Different control interfaces are available:
- thrust: Low-level control of individual motor thrusts

#### Integration Methods
-For analytical and system-identified physics:
+We support multiple integration schemes with different accuracy trade-offs:
 - euler: Simple first-order integration
 - rk4: Fourth-order Runge-Kutta integration for higher accuracy
+- symplectic\_euler: Symplectic Euler integration for better long-term energy conservation
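As a rough illustration of how these schemes differ (a minimal sketch for a generic system — these helper names are illustrative, not crazyflow's API):

```python
import jax.numpy as jnp

def euler_step(f, y, dt):
    """Explicit Euler: one evaluation of f per step, first-order accurate."""
    return y + dt * f(y)

def rk4_step(f, y, dt):
    """Classical Runge-Kutta: four evaluations of f per step, fourth-order accurate."""
    k1 = f(y)
    k2 = f(y + 0.5 * dt * k1)
    k3 = f(y + 0.5 * dt * k2)
    k4 = f(y + dt * k3)
    return y + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

def symplectic_euler_step(accel, x, v, dt):
    """Semi-implicit Euler: update the velocity first, then the position with
    the new velocity. For oscillatory dynamics this keeps energy error bounded."""
    v_next = v + dt * accel(x, v)
    return x + dt * v_next, v_next
```

For a harmonic oscillator (`accel = lambda x, v: -x`), symplectic Euler drifts far less in energy over long horizons than explicit Euler at the same step size, which is why it is attractive for drone dynamics.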

### Parallelization
Crazyflow supports massive parallelization across:
@@ -58,6 +58,12 @@ The framework supports domain randomization through the crazyflow/randomize modu
### Functional Design
The simulation follows a functional programming paradigm: All state is contained in immutable data structures. Updates create new states rather than modifying existing ones. All functions are pure, enabling JAX's transformations (JIT, grad, vmap) and thus automatic differentiation through the entire simulation, making it suitable for gradient-based optimization and reinforcement learning.
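As a toy sketch of why this matters (the state layout and dynamics here are made up for illustration, not crazyflow's actual data structures), a pure step function composes directly with `jit`, `vmap`, and `grad`:

```python
import jax
import jax.numpy as jnp

def step(state, action):
    """Pure toy dynamics: computes a new state instead of mutating the old one."""
    return state + 0.01 * action

# Because step is pure, JAX transformations compose freely:
batched_step = jax.jit(jax.vmap(step))  # step all worlds in a single call

states = jnp.zeros((1000, 3))   # 1000 worlds with a 3-dim state each
actions = jnp.ones((1000, 3))
new_states = batched_step(states, actions)

# ...and gradients flow through the step, e.g. for gradient-based optimization:
grad_wrt_state = jax.grad(lambda s, a: step(s, a).sum())(states[0], actions[0])
```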

### Contacts and Non-Drone Models
We focus on drone dynamics in free-space flight. Consequently, no models other than drones are available in the simulation, and contact dynamics with external objects are not considered. However, we use MuJoCo for contact detection and visualization. Users can load their own objects into the simulation by changing the MuJoCo world spec. Drone collisions with these objects are detected during collision checks, but they do not affect the dynamics (i.e., drones will pass through objects). Similarly, the objects themselves remain static.

### Visualization
We use `gymnasium`'s MuJoCo renderer and synchronize the simulation data with MuJoCo to render either an interactive UI or RGB arrays.

## Examples
The repository includes several example scripts demonstrating different capabilities:
| Example | Description |
203 changes: 112 additions & 91 deletions benchmark/main.py
@@ -3,10 +3,12 @@
from datetime import datetime
from pathlib import Path

import fire
import gymnasium
import jax
import jax.numpy as jnp
import numpy as np
from jax.errors import JaxRuntimeError
from ml_collections import config_dict

import crazyflow # noqa: F401, ensure gymnasium envs are registered
@@ -41,15 +43,15 @@ def analyze_timings(times: list[float], n_steps: int, n_worlds: int, freq: float


def profile_gym_env_step(
sim_config: config_dict.ConfigDict, n_steps: int, device: str
sim_config: config_dict.ConfigDict, n_steps: int, device: str, print_summary: bool = True
) -> list[float]:
"""Profile the Crazyflow gym environment step performance."""
times = []
device = jax.devices(device)[0]

envs = gymnasium.make_vec(
"DroneReachPos-v0",
time_horizon_in_seconds=3,
max_episode_time=3,
num_envs=sim_config.n_worlds,
device=sim_config.device,
freq=sim_config.freq,
@@ -60,7 +62,7 @@ def profile_gym_env_step(
action = np.zeros((sim_config.n_worlds, 4), dtype=np.float32)
action[..., 0] = 0.3
# Step through env once to ensure JIT compilation
envs.reset(seed=42)
envs.reset()
envs.step(action)

jax.block_until_ready(envs.unwrapped.sim.data) # Ensure JIT compiled dynamics
@@ -73,12 +75,15 @@ def profile_gym_env_step(
times.append(time.perf_counter() - tstart)

envs.close()
print("Gym env step performance:")
analyze_timings(times, n_steps, envs.unwrapped.sim.n_worlds, envs.unwrapped.sim.freq)
if print_summary:
print("Gym env step performance:")
analyze_timings(times, n_steps, envs.unwrapped.sim.n_worlds, sim_config.freq)
return times


def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str) -> list[float]:
def profile_step(
sim_config: config_dict.ConfigDict, n_steps: int, device: str, print_summary: bool = True
) -> list[float]:
"""Profile the Crazyflow simulator step performance."""
sim = Sim(**sim_config)
times = []
@@ -99,8 +104,9 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str)
jax.block_until_ready(sim.data)
times.append(time.perf_counter() - tstart)

print("Sim step performance:")
analyze_timings(times, n_steps, sim.n_worlds, sim.freq)
if print_summary:
print("Sim step performance:")
analyze_timings(times, n_steps, sim.n_worlds, sim.freq)
return times


@@ -140,9 +146,8 @@ def profile_reset(sim_config: config_dict.ConfigDict, n_steps: int, device: str)
analyze_timings(times_masked, n_steps, sim.n_worlds, sim.freq)


def main():
def main(device: str = "cpu", n_worlds_exp: int = 6):
"""Main entry point for profiling."""
device = "cpu"
sim_config = config_dict.ConfigDict()
sim_config.n_worlds = 1
sim_config.n_drones = 1
@@ -181,93 +186,109 @@ def main():
# Reopen the file in append mode for each result

n_steps = 1000
skip_sim, skip_gym = False, False
# Test with increasing number of parallel environments (worlds)
for n_worlds in [1, 10, 100, 1000, 10000, 100000, 1000000]:
print(f"\nTesting with {n_worlds} parallel environments:")
for n_worlds in [10**i for i in range(n_worlds_exp + 1)]:
sim_config.n_worlds = n_worlds
print("-" * 80)
if not skip_sim:
# Test with a single step first to see if we should continue
sim_config.freq = 500 # Test sim at 500 hz
single_step_time = profile_step(sim_config, 2, device, print_summary=False)[1]

# If single step takes too long, skip this and remaining tests
if single_step_time > max_seconds_per_run / n_steps: # threshold for the tests
print(
f" Skipping benchmark for {n_worlds} and higher - projected time "
f"{single_step_time * n_steps:.2f}s (> 1m)"
)
skip_sim = True

if not skip_sim:
# Configure simulator
print(f"Running simulator benchmark ({n_worlds} worlds)...")
# Run simulator benchmark using existing function
times_sim = profile_step(sim_config, n_steps, device)

# Calculate metrics for CSV
total_time = sum(times_sim)
avg_step_time = np.mean(times_sim)
n_frames = n_steps * n_worlds
fps = n_frames / total_time
real_time_factor = (n_steps / sim_config.freq) * n_worlds / total_time

# Save simulator results
# Reopen CSV writer in append mode
with open(csv_file, "a", newline="") as f:
csv_writer = csv.writer(f)
csv_writer.writerow(
[
"simulator",
1, # n_drones
n_worlds,
n_steps,
total_time,
avg_step_time,
fps,
real_time_factor,
sim_config.device,
]
)
f.flush()

if not skip_gym:
print(f"Running gym environment benchmark ({n_worlds} worlds)...")
# Run gym environment benchmark using existing function
sim_config.freq = 50 # Test gym at 50 hz
try:
step_times = profile_gym_env_step(sim_config, 2, device, print_summary=False)
single_step_time = step_times[1]
# If single step takes too long, skip this test only
if single_step_time > max_seconds_per_run / n_steps: # threshold for the tests
print(
f" Skipping benchmark for {n_worlds} - projected time "
f"{single_step_time * n_steps:.2f}s (> 1m)"
)
skip_gym = True
except JaxRuntimeError:
print(f" Skipping benchmark for {n_worlds} - resource exhausted")
skip_gym = True

# Test with a single step first to see if we should continue
sim_config.freq = 500 # Test sim at 500 hz
test_times = profile_step(sim_config, 1, device)

single_step_time = test_times[0]
# If single step takes too long, skip this and remaining tests
if single_step_time > max_seconds_per_run / n_steps: # threshold for the tests
print(
f" Skipping benchmark for {n_worlds} and higher - single step took "
f"{single_step_time * 1000:.2f}s (> 1m)"
)
break

# Configure simulator
print(f" Running simulator benchmark ({n_worlds} worlds)...")
# Run simulator benchmark using existing function
times_sim = profile_step(sim_config, n_steps, device)

# Calculate metrics for CSV
total_time = sum(times_sim)
avg_step_time = np.mean(times_sim)
n_frames = n_steps * n_worlds
fps = n_frames / total_time
real_time_factor = (n_steps / sim_config.freq) * n_worlds / total_time

# Save simulator results
# Reopen CSV writer in append mode
with open(csv_file, "w", newline="") as f:
csv_writer = csv.writer(f)
csv_writer.writerow(
[
"simulator",
1, # n_drones
n_worlds,
n_steps,
total_time,
avg_step_time,
fps,
real_time_factor,
sim_config.device,
]
)
f.flush()

print(f" Running gym environment benchmark ({n_worlds} worlds)...")
# Run gym environment benchmark using existing function
sim_config.freq = 50 # Test gym at 50 hz
try:
times_gym = profile_gym_env_step(sim_config, n_steps, device)
except ValueError as e:
if "RESOURCE_EXHAUSTED" in str(e):
if not skip_gym:
try:
times_gym = profile_gym_env_step(sim_config, n_steps, device)
except JaxRuntimeError:
print(f" Skipping benchmark for {n_worlds} - resource exhausted")
continue # Only continue, we might still be able to benchmark sim
raise e

# Calculate metrics for CSV
total_time = sum(times_gym)
avg_step_time = np.mean(times_gym)
n_frames = n_steps * n_worlds
fps = n_frames / total_time
real_time_factor = (n_steps / sim_config.freq) * sim_config.n_worlds / total_time

# Save gym environment results
with open(csv_file, "a", newline="") as f:
csv_writer = csv.writer(f)
csv_writer.writerow(
[
"gym_env",
sim_config.n_drones,
sim_config.n_worlds,
n_steps,
total_time,
avg_step_time,
fps,
real_time_factor,
sim_config.device,
]
)
f.flush()
skip_gym = True
continue

# Calculate metrics for CSV
total_time = sum(times_gym)
avg_step_time = np.mean(times_gym)
n_frames = n_steps * n_worlds
fps = n_frames / total_time
real_time_factor = (n_steps / sim_config.freq) * sim_config.n_worlds / total_time

# Save gym environment results
with open(csv_file, "a", newline="") as f:
csv_writer = csv.writer(f)
csv_writer.writerow(
[
"gym_env",
sim_config.n_drones,
sim_config.n_worlds,
n_steps,
total_time,
avg_step_time,
fps,
real_time_factor,
sim_config.device,
]
)
f.flush()

print(f"\nBenchmark results saved to {csv_file}")


if __name__ == "__main__":
main()
fire.Fire(main)
4 changes: 2 additions & 2 deletions benchmark/performance.py
@@ -13,7 +13,7 @@
from crazyflow.sim import Sim

if TYPE_CHECKING:
-    from crazyflow.gymnasium_envs import CrazyflowEnvReachGoal
+    from crazyflow.envs import ReachPosEnv


def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
@@ -44,7 +44,7 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
device = jax.devices(device)[0]

-    envs: CrazyflowEnvReachGoal = gymnasium.make_vec(
+    envs: ReachPosEnv = gymnasium.make_vec(
"DroneReachPos-v0", time_horizon_in_seconds=2, num_envs=sim_config.n_worlds, **sim_config
)
