Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
authors = ["Rakshith Ravi <rakshith.ravi@gmx.com>"]
edition = "2018"
name = "juno"
version = "0.1.3-1"
version = "0.1.3-2"
license = "MIT"
description = "A helper rust library for the juno microservices framework"
homepage = "https://github.com/bytesonus/juno-rust"
Expand Down
73 changes: 32 additions & 41 deletions src/juno_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use futures::channel::{
oneshot::{channel, Sender},
};
use futures_util::sink::SinkExt;
use std::sync::RwLock;
use std::{
collections::HashMap,
net::{AddrParseError, SocketAddr},
Expand All @@ -41,7 +42,7 @@ pub struct JunoModule {
functions: ArcFunctionList,
hook_listeners: ArcHookListenerList,
message_buffer: Buffer,
registered: bool,
registered: Arc<RwLock<bool>>,
}

impl JunoModule {
Expand All @@ -62,27 +63,17 @@ impl JunoModule {

#[cfg(target_family = "unix")]
pub fn from_unix_socket(socket_path: &str) -> Self {
JunoModule {
protocol: BaseProtocol::default(),
connection: Box::new(UnixSocketConnection::new(socket_path.to_string())),
requests: Arc::new(Mutex::new(HashMap::new())),
functions: Arc::new(Mutex::new(HashMap::new())),
hook_listeners: Arc::new(Mutex::new(HashMap::new())),
message_buffer: vec![],
registered: false,
}
JunoModule::new(
BaseProtocol::default(),
Box::new(UnixSocketConnection::new(socket_path.to_string())),
)
}

pub fn from_inet_socket(host: &str, port: u16) -> Self {
JunoModule {
protocol: BaseProtocol::default(),
connection: Box::new(InetSocketConnection::new(format!("{}:{}", host, port))),
requests: Arc::new(Mutex::new(HashMap::new())),
functions: Arc::new(Mutex::new(HashMap::new())),
hook_listeners: Arc::new(Mutex::new(HashMap::new())),
message_buffer: vec![],
registered: false,
}
JunoModule::new(
BaseProtocol::default(),
Box::new(InetSocketConnection::new(format!("{}:{}", host, port))),
)
}

pub fn new(protocol: BaseProtocol, connection: Box<dyn BaseConnection + Send + Sync>) -> Self {
Expand All @@ -93,7 +84,7 @@ impl JunoModule {
functions: Arc::new(Mutex::new(HashMap::new())),
hook_listeners: Arc::new(Mutex::new(HashMap::new())),
message_buffer: vec![],
registered: false,
registered: Arc::new(RwLock::new(false)),
}
}

Expand All @@ -109,8 +100,6 @@ impl JunoModule {
self.protocol
.initialize(String::from(module_id), String::from(version), dependencies);
self.send_request(request).await?;

self.registered = true;
Ok(())
}

Expand Down Expand Up @@ -169,7 +158,7 @@ impl JunoModule {
}

fn ensure_registered(&self) -> Result<()> {
if !self.registered {
if !*self.registered.read().unwrap() {
return Err(Error::Internal(String::from(
"Module not registered. Did you .await the call to initialize?",
)));
Expand All @@ -187,6 +176,7 @@ impl JunoModule {
let requests = self.requests.clone();
let functions = self.functions.clone();
let hook_listeners = self.hook_listeners.clone();
let registered_store = self.registered.clone();

// Run the read-write loop
task::spawn(async {
Expand All @@ -196,6 +186,7 @@ impl JunoModule {
requests,
functions,
hook_listeners,
registered_store,
write_sender,
)
.await;
Expand All @@ -206,23 +197,15 @@ impl JunoModule {

async fn send_request(&mut self, request: BaseMessage) -> Result<Value> {
if let BaseMessage::RegisterModuleRequest { .. } = request {
if self.registered {
let (sender, receiver) = channel::<Result<Value>>();
sender.send(Ok(Value::Null)).unwrap();

return match receiver.await {
Ok(value) => value,
Err(_) => Err(Error::Internal(String::from(
"Request sender was dropped before data could be retrieved",
))),
};
if *self.registered.read().unwrap() {
return Err(Error::Internal(String::from("Module already registered")));
}
}

let request_type = request.get_type();
let request_id = request.get_request_id().clone();
let mut encoded = self.protocol.encode(request);
if self.registered || request_type == 1 {
if *self.registered.read().unwrap() || request_type == 1 {
self.connection.send(encoded).await;
} else {
self.message_buffer.append(&mut encoded);
Expand All @@ -247,6 +230,7 @@ async fn on_data_listener(
requests: ArcRequestList,
functions: ArcFunctionList,
hook_listeners: ArcHookListenerList,
registered_store: Arc<RwLock<bool>>,
mut write_sender: UnboundedSender<Buffer>,
) {
while let Some(data) = receiver.next().await {
Expand Down Expand Up @@ -277,7 +261,7 @@ async fn on_data_listener(
Ok(Value::Null)
}
BaseMessage::TriggerHookRequest { .. } => {
execute_hook_triggered(message, &hook_listeners).await
execute_hook_triggered(message, &registered_store, &hook_listeners).await
}
BaseMessage::Error { error, .. } => Err(Error::FromJuno(error)),
_ => Ok(Value::Null),
Expand Down Expand Up @@ -313,15 +297,22 @@ async fn execute_function_call(message: BaseMessage, functions: &ArcFunctionList

async fn execute_hook_triggered(
message: BaseMessage,
registered_store: &Arc<RwLock<bool>>,
hook_listeners: &ArcHookListenerList,
) -> Result<Value> {
if let BaseMessage::TriggerHookRequest { hook, .. } = message {
let hook_listeners = hook_listeners.lock().await;
if !hook_listeners.contains_key(&hook) {
todo!("Wtf do I do now? Need to propogate errors. How do I do that?");
}
for listener in &hook_listeners[&hook] {
listener(Value::Null);
if &hook == "juno.activated" {
*registered_store.write().unwrap() = true;
} else if &hook == "juno.deactivated" {
*registered_store.write().unwrap() = false;
} else {
let hook_listeners = hook_listeners.lock().await;
if !hook_listeners.contains_key(&hook) {
todo!("Wtf do I do now? Need to propogate errors. How do I do that?");
}
for listener in &hook_listeners[&hook] {
listener(Value::Null);
}
}
} else {
panic!("Cannot execute hook from a request that wasn't a TriggerHookRequest!");
Expand Down