Skip to content

Commit

Permalink
Add global init and get accessors for all newtyped TaskPools (#2250)
Browse files Browse the repository at this point in the history
Right now, a direct reference to the target TaskPool is required to launch tasks on the pools, despite the three newtyped pools (AsyncComputeTaskPool, ComputeTaskPool, and IoTaskPool) effectively acting as global instances. The need to pass a TaskPool reference adds notable friction to spawning subtasks within existing tasks. Possible use cases for this may include chaining tasks within the same pool like spawning separate send/receive I/O tasks after waiting on a network connection to be established, or allowing cross-pool dependent tasks like starting dependent multi-frame computations following a long I/O load. 

Other task execution runtimes provide static access to spawning tasks (i.e. `tokio::spawn`), which is notably easier to use than the reference passing required by `bevy_tasks` right now.

This PR makes does the following:

 * Adds `*TaskPool::init` which initializes a `OnceCell`'ed with a provided TaskPool. Failing if the pool has already been initialized.
 * Adds `*TaskPool::get` which fetches the initialized global pool of the respective type or panics. This generally should not be an issue in normal Bevy use, as the pools are initialized before they are accessed.
 * Updated default task pool initialization to either pull the global handles and save them as resources, or if they are already initialized, pull the a cloned global handle as the resource.

This should make it notably easier to build more complex task hierarchies for dependent tasks. It should also make writing bevy-adjacent, but not strictly bevy-only plugin crates easier, as the global pools ensure it's all running on the same threads.

One alternative considered is keeping a thread-local reference to the pool for all threads in each pool to enable the same `tokio::spawn` interface. This would spawn tasks on the same pool that a task is currently running in. However this potentially leads to potential footgun situations where long running blocking tasks run on `ComputeTaskPool`.
  • Loading branch information
james7132 committed Jun 9, 2022
1 parent 5ace79f commit 012ae07
Show file tree
Hide file tree
Showing 19 changed files with 226 additions and 219 deletions.
3 changes: 2 additions & 1 deletion benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub struct Benchmark(World, Box<dyn System<In = (), Out = ()>>);

impl Benchmark {
pub fn new() -> Self {
ComputeTaskPool::init(TaskPool::default);

let mut world = World::default();

world.spawn_batch((0..1000).map(|_| {
Expand All @@ -39,7 +41,6 @@ impl Benchmark {
});
}

world.insert_resource(ComputeTaskPool(TaskPool::default()));
let mut system = IntoSystem::into_system(sys);
system.initialize(&mut world);
system.update_archetype_component_access(&world);
Expand Down
12 changes: 1 addition & 11 deletions crates/bevy_app/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use bevy_ecs::{
system::Resource,
world::World,
};
use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};
use bevy_utils::{tracing::debug, HashMap};
use std::fmt::Debug;

Expand Down Expand Up @@ -863,18 +862,9 @@ impl App {
pub fn add_sub_app(
&mut self,
label: impl AppLabel,
mut app: App,
app: App,
sub_app_runner: impl Fn(&mut World, &mut App) + 'static,
) -> &mut Self {
if let Some(pool) = self.world.get_resource::<ComputeTaskPool>() {
app.world.insert_resource(pool.clone());
}
if let Some(pool) = self.world.get_resource::<AsyncComputeTaskPool>() {
app.world.insert_resource(pool.clone());
}
if let Some(pool) = self.world.get_resource::<IoTaskPool>() {
app.world.insert_resource(pool.clone());
}
self.sub_apps.insert(
Box::new(label),
SubApp {
Expand Down
18 changes: 7 additions & 11 deletions crates/bevy_asset/src/asset_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
use anyhow::Result;
use bevy_ecs::system::{Res, ResMut};
use bevy_log::warn;
use bevy_tasks::TaskPool;
use bevy_tasks::IoTaskPool;
use bevy_utils::{Entry, HashMap, Uuid};
use crossbeam_channel::TryRecvError;
use parking_lot::{Mutex, RwLock};
Expand Down Expand Up @@ -56,7 +56,6 @@ pub struct AssetServerInternal {
loaders: RwLock<Vec<Arc<dyn AssetLoader>>>,
extension_to_loader_index: RwLock<HashMap<String, usize>>,
handle_to_path: Arc<RwLock<HashMap<HandleId, AssetPath<'static>>>>,
task_pool: TaskPool,
}

/// Loads assets from the filesystem on background threads
Expand All @@ -66,11 +65,11 @@ pub struct AssetServer {
}

impl AssetServer {
pub fn new<T: AssetIo>(source_io: T, task_pool: TaskPool) -> Self {
Self::with_boxed_io(Box::new(source_io), task_pool)
pub fn new<T: AssetIo>(source_io: T) -> Self {
Self::with_boxed_io(Box::new(source_io))
}

pub fn with_boxed_io(asset_io: Box<dyn AssetIo>, task_pool: TaskPool) -> Self {
pub fn with_boxed_io(asset_io: Box<dyn AssetIo>) -> Self {
AssetServer {
server: Arc::new(AssetServerInternal {
loaders: Default::default(),
Expand All @@ -79,7 +78,6 @@ impl AssetServer {
asset_ref_counter: Default::default(),
handle_to_path: Default::default(),
asset_lifecycles: Default::default(),
task_pool,
asset_io,
}),
}
Expand Down Expand Up @@ -315,7 +313,6 @@ impl AssetServer {
&self.server.asset_ref_counter.channel,
self.asset_io(),
version,
&self.server.task_pool,
);

if let Err(err) = asset_loader
Expand Down Expand Up @@ -377,8 +374,7 @@ impl AssetServer {
pub(crate) fn load_untracked(&self, asset_path: AssetPath<'_>, force: bool) -> HandleId {
let server = self.clone();
let owned_path = asset_path.to_owned();
self.server
.task_pool
IoTaskPool::get()
.spawn(async move {
if let Err(err) = server.load_async(owned_path, force).await {
warn!("{}", err);
Expand Down Expand Up @@ -620,8 +616,8 @@ mod test {

fn setup(asset_path: impl AsRef<Path>) -> AssetServer {
use crate::FileAssetIo;

AssetServer::new(FileAssetIo::new(asset_path, false), Default::default())
IoTaskPool::init(Default::default);
AssetServer::new(FileAssetIo::new(asset_path, false))
}

#[test]
Expand Down
12 changes: 6 additions & 6 deletions crates/bevy_asset/src/debug_asset_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ impl<T: Asset> Default for HandleMap<T> {

impl Plugin for DebugAssetServerPlugin {
fn build(&self, app: &mut bevy_app::App) {
IoTaskPool::init(|| {
TaskPoolBuilder::default()
.num_threads(2)
.thread_name("Debug Asset Server IO Task Pool".to_string())
.build()
});
let mut debug_asset_app = App::new();
debug_asset_app
.insert_resource(IoTaskPool(
TaskPoolBuilder::default()
.num_threads(2)
.thread_name("Debug Asset Server IO Task Pool".to_string())
.build(),
))
.insert_resource(AssetServerSettings {
asset_folder: "crates".to_string(),
watch_for_changes: true,
Expand Down
7 changes: 1 addition & 6 deletions crates/bevy_asset/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ pub use path::*;

use bevy_app::{prelude::Plugin, App};
use bevy_ecs::schedule::{StageLabel, SystemStage};
use bevy_tasks::IoTaskPool;

/// The names of asset stages in an App Schedule
#[derive(Debug, Hash, PartialEq, Eq, Clone, StageLabel)]
Expand Down Expand Up @@ -82,12 +81,8 @@ pub fn create_platform_default_asset_io(app: &mut App) -> Box<dyn AssetIo> {
impl Plugin for AssetPlugin {
fn build(&self, app: &mut App) {
if !app.world.contains_resource::<AssetServer>() {
let task_pool = app.world.resource::<IoTaskPool>().0.clone();

let source = create_platform_default_asset_io(app);

let asset_server = AssetServer::with_boxed_io(source, task_pool);

let asset_server = AssetServer::with_boxed_io(source);
app.insert_resource(asset_server);
}

Expand Down
8 changes: 0 additions & 8 deletions crates/bevy_asset/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use crate::{
use anyhow::Result;
use bevy_ecs::system::{Res, ResMut};
use bevy_reflect::{TypeUuid, TypeUuidDynamic};
use bevy_tasks::TaskPool;
use bevy_utils::{BoxedFuture, HashMap};
use crossbeam_channel::{Receiver, Sender};
use downcast_rs::{impl_downcast, Downcast};
Expand Down Expand Up @@ -84,7 +83,6 @@ pub struct LoadContext<'a> {
pub(crate) labeled_assets: HashMap<Option<String>, BoxedLoadedAsset>,
pub(crate) path: &'a Path,
pub(crate) version: usize,
pub(crate) task_pool: &'a TaskPool,
}

impl<'a> LoadContext<'a> {
Expand All @@ -93,15 +91,13 @@ impl<'a> LoadContext<'a> {
ref_change_channel: &'a RefChangeChannel,
asset_io: &'a dyn AssetIo,
version: usize,
task_pool: &'a TaskPool,
) -> Self {
Self {
ref_change_channel,
asset_io,
labeled_assets: Default::default(),
version,
path,
task_pool,
}
}

Expand Down Expand Up @@ -144,10 +140,6 @@ impl<'a> LoadContext<'a> {
asset_metas
}

pub fn task_pool(&self) -> &TaskPool {
self.task_pool
}

pub fn asset_io(&self) -> &dyn AssetIo {
self.asset_io
}
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl Plugin for CorePlugin {
.get_resource::<DefaultTaskPoolOptions>()
.cloned()
.unwrap_or_default()
.create_default_pools(&mut app.world);
.create_default_pools();

app.register_type::<Entity>().register_type::<Name>();

Expand Down
28 changes: 14 additions & 14 deletions crates/bevy_core/src/task_pool_options.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use bevy_ecs::world::World;
use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder};
use bevy_utils::tracing::trace;

Expand Down Expand Up @@ -93,14 +92,14 @@ impl DefaultTaskPoolOptions {
}

/// Inserts the default thread pools into the given resource map based on the configured values
pub fn create_default_pools(&self, world: &mut World) {
pub fn create_default_pools(&self) {
let total_threads =
bevy_tasks::logical_core_count().clamp(self.min_total_threads, self.max_total_threads);
trace!("Assigning {} cores to default task pools", total_threads);

let mut remaining_threads = total_threads;

if !world.contains_resource::<IoTaskPool>() {
{
// Determine the number of IO threads we will use
let io_threads = self
.io
Expand All @@ -109,15 +108,15 @@ impl DefaultTaskPoolOptions {
trace!("IO Threads: {}", io_threads);
remaining_threads = remaining_threads.saturating_sub(io_threads);

world.insert_resource(IoTaskPool(
IoTaskPool::init(|| {
TaskPoolBuilder::default()
.num_threads(io_threads)
.thread_name("IO Task Pool".to_string())
.build(),
));
.build()
});
}

if !world.contains_resource::<AsyncComputeTaskPool>() {
{
// Determine the number of async compute threads we will use
let async_compute_threads = self
.async_compute
Expand All @@ -126,28 +125,29 @@ impl DefaultTaskPoolOptions {
trace!("Async Compute Threads: {}", async_compute_threads);
remaining_threads = remaining_threads.saturating_sub(async_compute_threads);

world.insert_resource(AsyncComputeTaskPool(
AsyncComputeTaskPool::init(|| {
TaskPoolBuilder::default()
.num_threads(async_compute_threads)
.thread_name("Async Compute Task Pool".to_string())
.build(),
));
.build()
});
}

if !world.contains_resource::<ComputeTaskPool>() {
{
// Determine the number of compute threads we will use
// This is intentionally last so that an end user can specify 1.0 as the percent
let compute_threads = self
.compute
.get_number_of_threads(remaining_threads, total_threads);

trace!("Compute Threads: {}", compute_threads);
world.insert_resource(ComputeTaskPool(

ComputeTaskPool::init(|| {
TaskPoolBuilder::default()
.num_threads(compute_threads)
.thread_name("Compute Task Pool".to_string())
.build(),
));
.build()
});
}
}
}
4 changes: 2 additions & 2 deletions crates/bevy_ecs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@ mod tests {

#[test]
fn par_for_each_dense() {
ComputeTaskPool::init(TaskPool::default);
let mut world = World::new();
world.insert_resource(ComputeTaskPool(TaskPool::default()));
let e1 = world.spawn().insert(A(1)).id();
let e2 = world.spawn().insert(A(2)).id();
let e3 = world.spawn().insert(A(3)).id();
Expand All @@ -397,8 +397,8 @@ mod tests {

#[test]
fn par_for_each_sparse() {
ComputeTaskPool::init(TaskPool::default);
let mut world = World::new();
world.insert_resource(ComputeTaskPool(TaskPool::default()));
let e1 = world.spawn().insert(SparseStored(1)).id();
let e2 = world.spawn().insert(SparseStored(2)).id();
let e3 = world.spawn().insert(SparseStored(3)).id();
Expand Down
Loading

0 comments on commit 012ae07

Please sign in to comment.