diff --git a/Cargo.toml b/Cargo.toml index 278f991..cdd0abb 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ authors = ["Rakshith Ravi "] edition = "2018" name = "juno" -version = "0.1.3-1" +version = "0.1.3-2" license = "MIT" description = "A helper rust library for the juno microservices framework" homepage = "https://github.com/bytesonus/juno-rust" diff --git a/src/juno_module.rs b/src/juno_module.rs index 7690885..2fd98f9 100644 --- a/src/juno_module.rs +++ b/src/juno_module.rs @@ -25,6 +25,7 @@ use futures::channel::{ oneshot::{channel, Sender}, }; use futures_util::sink::SinkExt; +use std::sync::RwLock; use std::{ collections::HashMap, net::{AddrParseError, SocketAddr}, @@ -41,7 +42,7 @@ pub struct JunoModule { functions: ArcFunctionList, hook_listeners: ArcHookListenerList, message_buffer: Buffer, - registered: bool, + registered: Arc>, } impl JunoModule { @@ -62,27 +63,17 @@ impl JunoModule { #[cfg(target_family = "unix")] pub fn from_unix_socket(socket_path: &str) -> Self { - JunoModule { - protocol: BaseProtocol::default(), - connection: Box::new(UnixSocketConnection::new(socket_path.to_string())), - requests: Arc::new(Mutex::new(HashMap::new())), - functions: Arc::new(Mutex::new(HashMap::new())), - hook_listeners: Arc::new(Mutex::new(HashMap::new())), - message_buffer: vec![], - registered: false, - } + JunoModule::new( + BaseProtocol::default(), + Box::new(UnixSocketConnection::new(socket_path.to_string())), + ) } pub fn from_inet_socket(host: &str, port: u16) -> Self { - JunoModule { - protocol: BaseProtocol::default(), - connection: Box::new(InetSocketConnection::new(format!("{}:{}", host, port))), - requests: Arc::new(Mutex::new(HashMap::new())), - functions: Arc::new(Mutex::new(HashMap::new())), - hook_listeners: Arc::new(Mutex::new(HashMap::new())), - message_buffer: vec![], - registered: false, - } + JunoModule::new( + BaseProtocol::default(), + Box::new(InetSocketConnection::new(format!("{}:{}", host, port))), + ) } pub fn new(protocol: BaseProtocol, connection: Box) -> Self { @@ -93,7 +84,7 @@ impl JunoModule { functions: Arc::new(Mutex::new(HashMap::new())), hook_listeners: Arc::new(Mutex::new(HashMap::new())), message_buffer: vec![], - registered: false, + registered: Arc::new(RwLock::new(false)), } } @@ -109,8 +100,6 @@ impl JunoModule { self.protocol .initialize(String::from(module_id), String::from(version), dependencies); self.send_request(request).await?; - - self.registered = true; Ok(()) } @@ -169,7 +158,7 @@ impl JunoModule { } fn ensure_registered(&self) -> Result<()> { - if !self.registered { + if !*self.registered.read().unwrap() { return Err(Error::Internal(String::from( "Module not registered. Did you .await the call to initialize?", ))); @@ -187,6 +176,7 @@ impl JunoModule { let requests = self.requests.clone(); let functions = self.functions.clone(); let hook_listeners = self.hook_listeners.clone(); + let registered_store = self.registered.clone(); // Run the read-write loop task::spawn(async { @@ -196,6 +186,7 @@ impl JunoModule { requests, functions, hook_listeners, + registered_store, write_sender, ) .await; @@ -206,23 +197,15 @@ impl JunoModule { async fn send_request(&mut self, request: BaseMessage) -> Result { if let BaseMessage::RegisterModuleRequest { .. } = request { - if self.registered { - let (sender, receiver) = channel::>(); - sender.send(Ok(Value::Null)).unwrap(); - - return match receiver.await { - Ok(value) => value, - Err(_) => Err(Error::Internal(String::from( - "Request sender was dropped before data could be retrieved", - ))), - }; + if *self.registered.read().unwrap() { + return Err(Error::Internal(String::from("Module already registered"))); } } let request_type = request.get_type(); let request_id = request.get_request_id().clone(); let mut encoded = self.protocol.encode(request); - if self.registered || request_type == 1 { + if *self.registered.read().unwrap() || request_type == 1 { self.connection.send(encoded).await; } else { self.message_buffer.append(&mut encoded); @@ -247,6 +230,7 @@ async fn on_data_listener( requests: ArcRequestList, functions: ArcFunctionList, hook_listeners: ArcHookListenerList, + registered_store: Arc>, mut write_sender: UnboundedSender, ) { while let Some(data) = receiver.next().await { @@ -277,7 +261,7 @@ async fn on_data_listener( Ok(Value::Null) } BaseMessage::TriggerHookRequest { .. } => { - execute_hook_triggered(message, &hook_listeners).await + execute_hook_triggered(message, ®istered_store, &hook_listeners).await } BaseMessage::Error { error, .. } => Err(Error::FromJuno(error)), _ => Ok(Value::Null), @@ -313,15 +297,22 @@ async fn execute_function_call(message: BaseMessage, functions: &ArcFunctionList async fn execute_hook_triggered( message: BaseMessage, + registered_store: &Arc>, hook_listeners: &ArcHookListenerList, ) -> Result { if let BaseMessage::TriggerHookRequest { hook, .. } = message { - let hook_listeners = hook_listeners.lock().await; - if !hook_listeners.contains_key(&hook) { - todo!("Wtf do I do now? Need to propogate errors. How do I do that?"); - } - for listener in &hook_listeners[&hook] { - listener(Value::Null); + if &hook == "juno.activated" { + *registered_store.write().unwrap() = true; + } else if &hook == "juno.deactivated" { + *registered_store.write().unwrap() = false; + } else { + let hook_listeners = hook_listeners.lock().await; + if !hook_listeners.contains_key(&hook) { + todo!("Wtf do I do now? Need to propogate errors. How do I do that?"); + } + for listener in &hook_listeners[&hook] { + listener(Value::Null); + } } } else { panic!("Cannot execute hook from a request that wasn't a TriggerHookRequest!");