Skip to content

Commit

Permalink
refactor(pool): use a single lock for Pool.inner to avoid too many in…
Browse files Browse the repository at this point in the history
…stances being created
  • Loading branch information
zshipko committed Apr 15, 2024
1 parent 0883fb5 commit 45cd29f
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 34 deletions.
1 change: 0 additions & 1 deletion runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ extism-manifest = { workspace = true }
extism-convert = { workspace = true, features = ["extism-path"] }
uuid = { version = "1", features = ["v4"] }
libc = "0.2"
dashmap = "5.5.3"

[features]
default = ["http", "register-http", "register-filesystem"]
Expand Down
39 changes: 23 additions & 16 deletions runtime/src/pool.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;

use crate::{Error, FromBytesOwned, Plugin, PluginBuilder, ToBytes};
use dashmap::DashMap;

/// `PoolPlugin` is used by the pool to track the number of live instances of a particular plugin
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -33,15 +34,15 @@ impl PoolPlugin {
type PluginSource = dyn Fn() -> Result<Plugin, Error>;

struct PoolInner<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq = String> {
plugins: DashMap<Key, Box<PluginSource>>,
instances: DashMap<Key, Vec<PoolPlugin>>,
plugins: HashMap<Key, Box<PluginSource>>,
instances: HashMap<Key, Vec<PoolPlugin>>,
}

/// `Pool` manages threadsafe access to a limited number of instances of multiple plugins
#[derive(Clone)]
pub struct Pool<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq = String> {
max_instances: usize,
inner: std::sync::Arc<PoolInner<Key>>,
inner: std::sync::Arc<std::sync::Mutex<PoolInner<Key>>>,
}

unsafe impl<T: std::fmt::Debug + Clone + std::hash::Hash + Eq> Send for Pool<T> {}
Expand All @@ -52,10 +53,10 @@ impl<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq> Pool<Key> {
pub fn new(max_instances: usize) -> Self {
Pool {
max_instances,
inner: std::sync::Arc::new(PoolInner {
inner: std::sync::Arc::new(std::sync::Mutex::new(PoolInner {
plugins: Default::default(),
instances: Default::default(),
}),
})),
}
}

Expand All @@ -64,7 +65,7 @@ impl<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq> Pool<Key> {
where
F: 'static,
{
let pool = &self.inner;
let mut pool = self.inner.lock().unwrap();
if !pool.instances.contains_key(&key) {
pool.instances.insert(key.clone(), vec![]);
}
Expand All @@ -74,7 +75,7 @@ impl<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq> Pool<Key> {

/// Add a plugin using a `PluginBuilder`
pub fn add_builder(&self, key: Key, source: PluginBuilder<'static>) {
let pool = &self.inner;
let mut pool = self.inner.lock().unwrap();
if !pool.instances.contains_key(&key) {
pool.instances.insert(key.clone(), vec![]);
}
Expand All @@ -84,7 +85,8 @@ impl<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq> Pool<Key> {
}

fn find_available(&self, key: &Key) -> Result<Option<PoolPlugin>, Error> {
if let Some(entry) = self.inner.instances.get_mut(key) {
let mut pool = self.inner.lock().unwrap();
if let Some(entry) = pool.instances.get_mut(key) {
for instance in entry.iter() {
if std::rc::Rc::strong_count(&instance.0) == 1 {
return Ok(Some(instance.clone()));
Expand All @@ -97,6 +99,8 @@ impl<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq> Pool<Key> {
/// Get the number of live instances for a plugin
pub fn count(&self, key: &Key) -> usize {
self.inner
.lock()
.unwrap()
.instances
.get(key)
.map(|x| x.len())
Expand All @@ -117,13 +121,16 @@ impl<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq> Pool<Key> {
return Ok(Some(avail));
}

if self.inner.instances.get(key).map(|x| x.len()).unwrap_or(0) < max {
if let Some(source) = self.inner.plugins.get(key) {
let plugin = source()?;
let instance = PoolPlugin::new(plugin);
let mut v = self.inner.instances.get_mut(key).unwrap();
v.push(instance);
return Ok(Some(v.last().unwrap().clone()));
{
let mut pool = self.inner.lock().unwrap();
if pool.instances.get(key).map(|x| x.len()).unwrap_or_default() < max {
if let Some(source) = pool.plugins.get(key) {
let plugin = source()?;
let instance = PoolPlugin::new(plugin);
let v = pool.instances.get_mut(key).unwrap();
v.push(instance);
return Ok(Some(v.last().unwrap().clone()));
}
}
}

Expand Down
44 changes: 27 additions & 17 deletions runtime/src/tests/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ fn run_thread(p: Pool<String>, i: u64) -> std::thread::JoinHandle<()> {
std::thread::spawn(move || {
std::thread::sleep(std::time::Duration::from_millis(i));
let s: String = p
.get(&"test".to_string(), std::time::Duration::from_secs(5))
.get(&"test".to_string(), std::time::Duration::from_secs(1))
.unwrap()
.unwrap()
.call("count_vowels", "abc")
Expand All @@ -15,24 +15,34 @@ fn run_thread(p: Pool<String>, i: u64) -> std::thread::JoinHandle<()> {

#[test]
fn test_threads() {
let data = include_bytes!("../../../wasm/code.wasm");
let pool: Pool<String> = Pool::new(2);
for i in 1..=3 {
let data = include_bytes!("../../../wasm/code.wasm");
let pool: Pool<String> = Pool::new(i);

let test = "test".to_string();
pool.add_builder(
test.clone(),
extism::PluginBuilder::new(extism::Manifest::new([extism::Wasm::data(data)]))
.with_wasi(true),
);
let test = "test".to_string();
pool.add_builder(
test.clone(),
extism::PluginBuilder::new(extism::Manifest::new([extism::Wasm::data(data)]))
.with_wasi(true),
);

let mut threads = vec![];
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 0));
let mut threads = vec![];
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 0));

for t in threads {
t.join().unwrap();
for t in threads {
t.join().unwrap();
}
assert!(pool.count(&test) <= i);
}
assert_eq!(pool.count(&test), 2);
}

0 comments on commit 45cd29f

Please sign in to comment.