Skip to content

Commit

Permalink
Improve or-with disjoint checks
Browse files Browse the repository at this point in the history
  • Loading branch information
mvlabat committed Jan 3, 2023
1 parent b44b606 commit bb8359f
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 7 deletions.
1 change: 1 addition & 0 deletions crates/bevy_ecs/Cargo.toml
Expand Up @@ -27,6 +27,7 @@ fixedbitset = "0.4"
fxhash = "0.2"
downcast-rs = "1.2"
serde = { version = "1", features = ["derive"] }
smallvec = "1.6"

[dev-dependencies]
rand = "0.8"
Expand Down
92 changes: 86 additions & 6 deletions crates/bevy_ecs/src/query/access.rs
Expand Up @@ -2,6 +2,7 @@ use crate::storage::SparseSetIndex;
use bevy_utils::HashSet;
use core::fmt;
use fixedbitset::FixedBitSet;
use smallvec::SmallVec;
use std::marker::PhantomData;

/// A wrapper struct to make Debug representations of [`FixedBitSet`] easier
Expand All @@ -25,6 +26,7 @@ struct FormattedBitSet<'a, T: SparseSetIndex> {
bit_set: &'a FixedBitSet,
_marker: PhantomData<T>,
}

impl<'a, T: SparseSetIndex> FormattedBitSet<'a, T> {
fn new(bit_set: &'a FixedBitSet) -> Self {
Self {
Expand All @@ -33,6 +35,7 @@ impl<'a, T: SparseSetIndex> FormattedBitSet<'a, T> {
}
}
}

impl<'a, T: SparseSetIndex + fmt::Debug> fmt::Debug for FormattedBitSet<'a, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list()
Expand All @@ -41,6 +44,28 @@ impl<'a, T: SparseSetIndex + fmt::Debug> fmt::Debug for FormattedBitSet<'a, T> {
}
}

struct FormattedExpandedOrWithAccesses<'a, T: SparseSetIndex> {
with: &'a ExpandedOrWithAccesses,
_marker: PhantomData<T>,
}

impl<'a, T: SparseSetIndex> FormattedExpandedOrWithAccesses<'a, T> {
fn new(with: &'a ExpandedOrWithAccesses) -> Self {
Self {
with,
_marker: PhantomData,
}
}
}

impl<'a, T: SparseSetIndex + fmt::Debug> fmt::Debug for FormattedExpandedOrWithAccesses<'a, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list()
.entries(self.with.arr.iter().map(FormattedBitSet::<T>::new))
.finish()
}
}

/// Tracks read and write access to specific elements in a collection.
///
/// Used internally to ensure soundness during system initialization and execution.
Expand Down Expand Up @@ -69,6 +94,7 @@ impl<T: SparseSetIndex + fmt::Debug> fmt::Debug for Access<T> {
.finish()
}
}

impl<T: SparseSetIndex> Default for Access<T> {
fn default() -> Self {
Self::new()
Expand Down Expand Up @@ -213,20 +239,24 @@ impl<T: SparseSetIndex> Access<T> {
/// is read/write `T`, read `U`. It must still have a read `U` access otherwise the following
/// queries would be incorrectly considered disjoint:
/// - `Query<&mut T>` read/write `T`
/// - `Query<Option<&T>` accesses nothing
/// - `Query<Option<&T>>` accesses nothing
///
/// See comments the `WorldQuery` impls of `AnyOf`/`Option`/`Or` for more information.
#[derive(Clone, Eq, PartialEq)]
pub struct FilteredAccess<T: SparseSetIndex> {
access: Access<T>,
with: FixedBitSet,
with: ExpandedOrWithAccesses,
without: FixedBitSet,
}

impl<T: SparseSetIndex + fmt::Debug> fmt::Debug for FilteredAccess<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FilteredAccess")
.field("access", &self.access)
.field("with", &FormattedBitSet::<T>::new(&self.with))
.field(
"with",
&FormattedExpandedOrWithAccesses::<T>::new(&self.with),
)
.field("without", &FormattedBitSet::<T>::new(&self.without))
.finish()
}
Expand Down Expand Up @@ -277,8 +307,7 @@ impl<T: SparseSetIndex> FilteredAccess<T> {

/// Retains only combinations where the element given by `index` is also present.
pub fn add_with(&mut self, index: T) {
self.with.grow(index.sparse_set_index() + 1);
self.with.insert(index.sparse_set_index());
self.with.add(index.sparse_set_index());
}

/// Retains only combinations where the element given by `index` is not present.
Expand All @@ -289,7 +318,7 @@ impl<T: SparseSetIndex> FilteredAccess<T> {

pub fn extend_intersect_filter(&mut self, other: &FilteredAccess<T>) {
self.without.intersect_with(&other.without);
self.with.intersect_with(&other.with);
self.with.extend_with_or(&other.with);
}

pub fn extend_access(&mut self, other: &FilteredAccess<T>) {
Expand Down Expand Up @@ -325,6 +354,57 @@ impl<T: SparseSetIndex> FilteredAccess<T> {
}
}

// A struct to express something like `Or<(With<A>, With<B>)>`.
// Filters like `(With<A>, Or<(With<B>, With<C>)>` are expanded into `Or<(With<(A, B)>, With<(B, C)>)>`.
#[derive(Clone, Eq, PartialEq)]
struct ExpandedOrWithAccesses {
arr: SmallVec<[FixedBitSet; 8]>,
}

impl Default for ExpandedOrWithAccesses {
fn default() -> Self {
Self {
arr: smallvec::smallvec![FixedBitSet::default()],
}
}
}

impl ExpandedOrWithAccesses {
fn add(&mut self, index: usize) {
for with in &mut self.arr {
with.grow(index + 1);
with.insert(index);
}
}

fn extend_with_or(&mut self, other: &ExpandedOrWithAccesses) {
self.arr.append(&mut other.arr.clone());
}

fn is_disjoint(&self, without: &FixedBitSet) -> bool {
self.arr.iter().any(|with| with.is_disjoint(without))
}

fn union_with(&mut self, other: &Self) {
if other.arr.len() == 1 {
for with in &mut self.arr {
with.union_with(&other.arr[0]);
}
return;
}

let mut new_with = SmallVec::with_capacity(self.arr.len() * other.arr.len());
for with in &self.arr {
for other_with in &other.arr {
let mut w = with.clone();
w.union_with(other_with);
new_with.push(w);
}
}
self.arr = new_with;
}
}

/// A collection of [`FilteredAccess`] instances.
///
/// Used internally to statically check if systems have conflicting access.
Expand Down
11 changes: 11 additions & 0 deletions crates/bevy_ecs/src/system/mod.rs
Expand Up @@ -390,6 +390,17 @@ mod tests {
run_system(&mut world, sys);
}

#[test]
fn or_has_filter_with() {
fn sys(
_: Query<&mut C, Or<(With<A>, With<B>)>>,
_: Query<&mut C, (Without<A>, Without<B>)>,
) {
}
let mut world = World::default();
run_system(&mut world, sys);
}

#[test]
fn or_doesnt_remove_unrelated_filter_with() {
fn sys(_: Query<&mut B, (Or<(With<A>, With<B>)>, With<A>)>, _: Query<&mut B, Without<A>>) {}
Expand Down
12 changes: 11 additions & 1 deletion examples/ecs/system_param.rs
Expand Up @@ -13,6 +13,12 @@ fn main() {
#[derive(Component)]
pub struct Player;

#[derive(Component)]
pub struct A;

#[derive(Component)]
pub struct B;

#[derive(Resource)]
pub struct PlayerCount(usize);

Expand Down Expand Up @@ -40,7 +46,11 @@ fn spawn(mut commands: Commands) {
}

/// The [`SystemParam`] can be used directly in a system argument.
fn count_players(mut counter: PlayerCounter) {
fn count_players(
mut counter: PlayerCounter,
q1: Query<&mut Transform, AnyOf<(&A, &B)>>,
q2: Query<&Transform, (Without<A>, Without<B>)>,
) {
counter.count();

println!("{} players in the game", counter.count.0);
Expand Down

0 comments on commit bb8359f

Please sign in to comment.