From cb60232c8def5a5e36a09e8d123022a0e27273c8 Mon Sep 17 00:00:00 2001 From: Torben Schweren Date: Wed, 3 Jan 2024 22:14:03 +0100 Subject: [PATCH 1/5] Service framework improvements - Way better handling of mutable/immutable attributes - Less Mutexes - Better handling of passing references through the service framework's API - Reimplement get_service accepting a TypeId as a generic parameter for easier usage - Reimplement status_map and status_tree as a result of the above adaptations, resulting in way simpler versions --- src/service.rs | 138 ++++++++++++++++++----------------------- src/service/discord.rs | 45 +++++++++----- 2 files changed, 87 insertions(+), 96 deletions(-) diff --git a/src/service.rs b/src/service.rs index d0e09c3..dffa34e 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,5 +1,6 @@ use log::{error, info, warn}; use std::{ + any::{Any, TypeId}, cmp::Ordering, collections::HashMap, error::Error, @@ -7,7 +8,6 @@ use std::{ future::Future, hash::{Hash, Hasher}, pin::Pin, - sync::Arc, }; use tokio::sync::Mutex; @@ -96,7 +96,7 @@ pub struct ServiceInfo { pub name: String, pub priority: Priority, - pub status: Arc>, + pub status: Mutex, } impl ServiceInfo { @@ -105,17 +105,10 @@ impl ServiceInfo { id: id.to_string(), name: name.to_string(), priority, - status: Arc::new(Mutex::new(Status::Stopped)), + status: Mutex::new(Status::Stopped), } } - pub fn is_available(&self) -> Pin + '_>> { - Box::pin(async move { - let lock = self.status.lock().await; - matches!(&*lock, Status::Started) - }) - } - pub async fn set_status(&self, status: Status) { let mut lock = self.status.lock().await; *lock = status; @@ -158,6 +151,9 @@ pub trait ServiceInternals { pub trait Service: ServiceInternals { fn info(&self) -> &ServiceInfo; + // Used for downcasting in get_service method of ServiceManager + fn as_any(&self) -> &dyn Any; + fn wrapped_start(&mut self) -> PinnedBoxedFuture<'_, ()> { Box::pin(async move { let mut status = self.info().status.lock().await; @@ -215,6 +211,13 @@ pub trait Service: ServiceInternals { } }) } + + fn is_available(&self) -> Pin + '_>> { + Box::pin(async move { + let lock = self.info().status.lock().await; + matches!(&*lock, Status::Started) + }) + } } impl Eq for dyn Service {} @@ -250,7 +253,9 @@ pub struct ServiceManagerBuilder { impl ServiceManagerBuilder { pub fn new() -> Self { - Self { services: vec![] } + Self { + services: Vec::new(), + } } pub fn with_service(mut self, service: Box) -> Self { @@ -300,22 +305,22 @@ impl ServiceManager { }) } - pub fn get_service(&self, id: &str) -> Option<&dyn Service> { + pub fn get_service(&self) -> Option<&T> + where + T: Service + 'static, + { self.services .iter() - .find(|s| s.info().id == id) - .map(|s| &**s) + .find(|s| TypeId::of::() == s.as_any().type_id()) + .and_then(|s| s.as_any().downcast_ref::()) } - pub fn status_map(&self) -> PinnedBoxedFuture<'_, HashMap>>> { + pub fn status_map(&self) -> PinnedBoxedFuture<'_, HashMap<&dyn Service, &Mutex>> { Box::pin(async move { let mut status_map = HashMap::new(); for service in self.services.iter() { - status_map.insert( - service.info().id.clone(), - Arc::clone(&service.info().status), - ); + status_map.insert(&**service, &service.info().status); } status_map @@ -344,92 +349,67 @@ impl ServiceManager { let mut text_buffer = String::new(); - let mut failed_essentials = HashMap::new(); - let mut failed_optionals = HashMap::new(); - let mut non_failed_essentials = HashMap::new(); - let mut non_failed_optionals = HashMap::new(); - let mut others = HashMap::new(); + let mut failed_essentials = String::new(); + let mut failed_optionals = String::new(); + let mut non_failed_essentials = String::new(); + let mut non_failed_optionals = String::new(); + let mut others = String::new(); for (service, status) in status_map.into_iter() { - let priority = match self.get_service(service.as_str()) { - Some(service) => service.info().priority, - None => unreachable!( - "Service with ID {} not found in ServiceManager. This should never happen!", - service, - ), - }; - + let info = service.info(); + let priority = &info.priority; let status = status.lock().await; - match &*status { - Status::Started | Status::Stopped => { - if priority == Priority::Essential { - non_failed_essentials.insert(service, status.to_string()); - } else { - non_failed_optionals.insert(service, status.to_string()); + match *status { + Status::Started | Status::Stopped => match priority { + Priority::Essential => { + non_failed_essentials + .push_str(&format!(" - {}: {}\n", info.name, status)); } - } + Priority::Optional => { + non_failed_optionals + .push_str(&format!(" - {}: {}\n", info.name, status)); + } + }, Status::FailedStarting(_) | Status::FailedStopping(_) - | Status::RuntimeError(_) => { - if priority == Priority::Essential { - failed_essentials.insert(service, status.to_string()); - } else { - failed_optionals.insert(service, status.to_string()); + | Status::RuntimeError(_) => match priority { + Priority::Essential => { + failed_essentials.push_str(&format!(" - {}: {}\n", info.name, status)); } - } + Priority::Optional => { + failed_optionals.push_str(&format!(" - {}: {}\n", info.name, status)); + } + }, _ => { - others.insert(service, status.to_string()); + others.push_str(&format!(" - {}: {}\n", info.name, status)); } } } - let section_generator = |services: &HashMap, title: &str| -> String { - let mut text_buffer = String::new(); - - text_buffer.push_str(&format!("- {}:\n", title)); - - for (service, status) in services.iter() { - let service = match self.get_service(service) { - Some(service) => service, - None => unreachable!( - "Service with ID {} not found in ServiceManager. This should never happen!", - service - ), - }; - - text_buffer.push_str(&format!(" - {}: {}\n", service.info().name, status)); - } - - text_buffer - }; - if !failed_essentials.is_empty() { - text_buffer.push_str( - section_generator(&failed_essentials, "Failed essential services").as_str(), - ); + text_buffer.push_str(&format!("- {}:\n", "Failed essential services")); + text_buffer.push_str(&failed_essentials); } if !failed_optionals.is_empty() { - text_buffer.push_str( - section_generator(&failed_optionals, "Failed optional services").as_str(), - ); + text_buffer.push_str(&format!("- {}:\n", "Failed optional services")); + text_buffer.push_str(&failed_optionals); } if !non_failed_essentials.is_empty() { - text_buffer.push_str( - section_generator(&non_failed_essentials, "Essential services").as_str(), - ); + text_buffer.push_str(&format!("- {}:\n", "Essential services")); + text_buffer.push_str(&non_failed_essentials); } if !non_failed_optionals.is_empty() { - text_buffer.push_str( - section_generator(&non_failed_optionals, "Optional services").as_str(), - ); + text_buffer.push_str(&format!("- {}:\n", "Optional services")); + text_buffer.push_str(&non_failed_optionals); } if !others.is_empty() { - text_buffer.push_str(section_generator(&others, "Other services").as_str()); + text_buffer.push_str(&format!("- {}:\n", "Other services")); + text_buffer.push_str(&others); } text_buffer diff --git a/src/service/discord.rs b/src/service/discord.rs index c025531..3104452 100644 --- a/src/service/discord.rs +++ b/src/service/discord.rs @@ -10,11 +10,11 @@ use serenity::{ prelude::TypeMap, Client, Error, }; -use std::{sync::Arc, time::Duration}; +use std::{any::Any, sync::Arc, time::Duration}; use tokio::{ sync::{Mutex, Notify, RwLock}, task::JoinHandle, - time::timeout, + time::{sleep, timeout}, }; pub struct DiscordService { @@ -80,11 +80,21 @@ impl ServiceInternals for DiscordService { info!("Connecting to Discord"); let client_handle = tokio::spawn(async move { client.start().await }); + // This prevents waiting for the timeout if the client fails immediately + // TODO: Optimize this, as it will currently add 1000mqs to the startup time + sleep(Duration::from_secs(1)).await; + if client_handle.is_finished() { + client_handle.await??; + return Err("Discord client stopped unexpectedly and with no error".into()); + } + if timeout(self.connection_timeout, client_ready_notify.notified()) .await .is_err() { client_handle.abort(); + let result = convert_thread_result(client_handle).await; + result?; return Err(format!( "Discord client failed to connect within {} seconds", @@ -93,11 +103,6 @@ impl ServiceInternals for DiscordService { .into()); } - if client_handle.is_finished() { - client_handle.await??; - return Err("Discord client stopped unexpectedly and with no error".into()); - } - self.client_handle = Some(client_handle); Ok(()) }) @@ -105,21 +110,15 @@ impl ServiceInternals for DiscordService { fn stop(&mut self) -> PinnedBoxedFutureResult<'_, ()> { Box::pin(async move { - if let Some(handle) = self.client_handle.take() { + if let Some(client_handle) = self.client_handle.take() { info!("Waiting for Discord client to stop..."); - handle.abort(); - - let result = match handle.await { - Ok(result) => result, - Err(_) => { - info!("Discord client stopped"); - return Ok(()); - } - }; + client_handle.abort(); + let result = convert_thread_result(client_handle).await; result?; } + info!("Discord client stopped"); Ok(()) }) } @@ -129,6 +128,18 @@ impl Service for DiscordService { fn info(&self) -> &ServiceInfo { &self.info } + + fn as_any(&self) -> &dyn Any { + self + } +} + +// If the thread ended WITHOUT a JoinError from aborting, the client already stopped unexpectedly +async fn convert_thread_result(client_handle: JoinHandle>) -> Result<(), Error> { + match client_handle.await { + Ok(result) => result, + Err(_) => Ok(()), + } } struct EventHandler { From 6856c66b105e55e884f2234e2dfda8e1bcd5f6f0 Mon Sep 17 00:00:00 2001 From: Torben Schweren Date: Thu, 4 Jan 2024 20:25:52 +0100 Subject: [PATCH 2/5] More service framework improvements - Replace all Mutexes with RwLock - Remove status_map method of ServiceManager - Services vector in ServiceManager now wraps the Service trait objects in a RwLock to potentially make them available mutably through the public API of ServiceManager --- src/bot.rs | 12 ++-- src/main.rs | 8 ++- src/service.rs | 123 +++++++++++++++++++---------------------- src/service/discord.rs | 24 +++----- 4 files changed, 81 insertions(+), 86 deletions(-) diff --git a/src/bot.rs b/src/bot.rs index e9c8946..f77515a 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -1,3 +1,7 @@ +use std::sync::Arc; + +use tokio::sync::RwLock; + use crate::service::{PinnedBoxedFuture, Service, ServiceManager, ServiceManagerBuilder}; pub struct BotBuilder { @@ -13,15 +17,15 @@ impl BotBuilder { } } - pub fn with_service(mut self, service: Box) -> Self { - self.service_manager = self.service_manager.with_service(service); // The ServiceManagerBuilder itself will warn when adding a service multiple times + pub async fn with_service(mut self, service: Arc>) -> Self { + self.service_manager = self.service_manager.with_service(service).await; // The ServiceManagerBuilder itself will warn when adding a service multiple times self } - pub fn with_services(mut self, services: Vec>) -> Self { + pub async fn with_services(mut self, services: Vec>>) -> Self { for service in services { - self.service_manager = self.service_manager.with_service(service); + self.service_manager = self.service_manager.with_service(service).await; } self diff --git a/src/main.rs b/src/main.rs index 7624dae..e2a2a6b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use ::log::{error, warn}; use lum::{ bot::Bot, @@ -5,6 +7,7 @@ use lum::{ log, service::{discord::DiscordService, Service}, }; +use tokio::sync::RwLock; const BOT_NAME: &str = "Lum"; @@ -31,6 +34,7 @@ async fn main() { let bot = Bot::builder(BOT_NAME) .with_services(initialize_services(&config)) + .await .build(); lum::run(bot).await; @@ -45,12 +49,12 @@ fn setup_logger() { } } -fn initialize_services(config: &Config) -> Vec> { +fn initialize_services(config: &Config) -> Vec>> { //TODO: Add services //... let discord_service = DiscordService::new(config.discord_token.as_str(), config.discord_timeout); - vec![Box::new(discord_service)] + vec![Arc::new(RwLock::new(discord_service))] } diff --git a/src/service.rs b/src/service.rs index dffa34e..98cb3b9 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,15 +1,15 @@ use log::{error, info, warn}; use std::{ - any::{Any, TypeId}, + any::Any, cmp::Ordering, - collections::HashMap, error::Error, fmt::Display, future::Future, hash::{Hash, Hasher}, pin::Pin, + sync::Arc, }; -use tokio::sync::Mutex; +use tokio::sync::RwLock; pub mod discord; @@ -92,11 +92,11 @@ impl Display for Priority { #[derive(Debug)] pub struct ServiceInfo { - pub id: String, + id: String, pub name: String, pub priority: Priority, - pub status: Mutex, + pub status: Arc>, } impl ServiceInfo { @@ -105,13 +105,12 @@ impl ServiceInfo { id: id.to_string(), name: name.to_string(), priority, - status: Mutex::new(Status::Stopped), + status: Arc::new(RwLock::new(Status::Stopped)), } } pub async fn set_status(&self, status: Status) { - let mut lock = self.status.lock().await; - *lock = status; + *(self.status.write().await) = status } } @@ -140,23 +139,21 @@ impl Hash for ServiceInfo { self.id.hash(state); } } - //TODO: When Rust allows async trait methods to be object-safe, refactor this to use async instead of returning a PinnedBoxedFutureResult -pub trait ServiceInternals { - fn start(&mut self) -> PinnedBoxedFutureResult<'_, ()>; - fn stop(&mut self) -> PinnedBoxedFutureResult<'_, ()>; -} - -//TODO: When Rust allows async trait methods to be object-safe, refactor this to use async instead of returning a PinnedBoxedFutureResult -pub trait Service: ServiceInternals { +pub trait Service: Any + Send + Sync { fn info(&self) -> &ServiceInfo; + fn start(&mut self, service_manager: &ServiceManager) -> PinnedBoxedFutureResult<'_, ()>; + fn stop(&mut self) -> PinnedBoxedFutureResult<'_, ()>; // Used for downcasting in get_service method of ServiceManager - fn as_any(&self) -> &dyn Any; + //fn as_any_arc(&self) -> Arc; - fn wrapped_start(&mut self) -> PinnedBoxedFuture<'_, ()> { + fn wrapped_start<'a>( + &'a mut self, + service_manager: &'a ServiceManager, + ) -> PinnedBoxedFuture<'a, ()> { Box::pin(async move { - let mut status = self.info().status.lock().await; + let mut status = self.info().status.write().await; if !matches!(&*status, Status::Stopped) { warn!( @@ -170,7 +167,7 @@ pub trait Service: ServiceInternals { *status = Status::Starting; drop(status); - match self.start().await { + match self.start(service_manager).await { Ok(()) => { info!("Started service: {}", self.info().name); self.info().set_status(Status::Started).await; @@ -185,7 +182,7 @@ pub trait Service: ServiceInternals { fn wrapped_stop(&mut self) -> PinnedBoxedFuture<'_, ()> { Box::pin(async move { - let mut status = self.info().status.lock().await; + let mut status = self.info().status.write().await; if !matches!(&*status, Status::Started) { warn!( @@ -199,7 +196,7 @@ pub trait Service: ServiceInternals { *status = Status::Stopping; drop(status); - match ServiceInternals::stop(self).await { + match self.stop().await { Ok(()) => { info!("Stopped service: {}", self.info().name); self.info().set_status(Status::Stopped).await; @@ -212,11 +209,8 @@ pub trait Service: ServiceInternals { }) } - fn is_available(&self) -> Pin + '_>> { - Box::pin(async move { - let lock = self.info().status.lock().await; - matches!(&*lock, Status::Started) - }) + fn is_available(&self) -> PinnedBoxedFuture<'_, bool> { + Box::pin(async move { matches!(&*(self.info().status.read().await), Status::Started) }) } } @@ -248,7 +242,7 @@ impl Hash for dyn Service { #[derive(Default)] pub struct ServiceManagerBuilder { - services: Vec>, + services: Vec>>, } impl ServiceManagerBuilder { @@ -258,20 +252,30 @@ impl ServiceManagerBuilder { } } - pub fn with_service(mut self, service: Box) -> Self { - let service_exists = self.services.iter().any(|s| s.info() == service.info()); + //TODO: When Rust allows async closures, refactor this to use iterator methods instead of for loop + pub async fn with_service(mut self, service: Arc>) -> Self { + let lock = service.read().await; - if service_exists { - warn!( - "Tried to add service {} ({}), but a service with that ID already exists. Ignoring.", - service.info().name, service.info().id - ); + let mut found = false; + for registered_service in self.services.iter() { + let registered_service = registered_service.read().await; + if registered_service.info().id == lock.info().id { + found = true; + } + } + + if found { + warn!( + "Tried to add service {} ({}), but a service with that ID already exists. Ignoring.", + lock.info().name, lock.info().id + ); return self; } - self.services.push(service); + drop(lock); + self.services.push(service); self } @@ -281,7 +285,7 @@ impl ServiceManagerBuilder { } pub struct ServiceManager { - pub services: Vec>, + pub services: Vec>>, } impl ServiceManager { @@ -289,49 +293,38 @@ impl ServiceManager { ServiceManagerBuilder::new() } - pub fn start_services(&mut self) -> PinnedBoxedFuture<'_, ()> { + pub fn start_services(&self) -> PinnedBoxedFuture<'_, ()> { Box::pin(async move { - for service in &mut self.services { - service.wrapped_start().await; + for service in &self.services { + let mut service = service.write().await; + service.wrapped_start(self).await; } }) } - pub fn stop_services(&mut self) -> PinnedBoxedFuture<'_, ()> { + pub fn stop_services(&self) -> PinnedBoxedFuture<'_, ()> { Box::pin(async move { - for service in &mut self.services { + for service in &self.services { + let mut service = service.write().await; service.wrapped_stop().await; } }) } - pub fn get_service(&self) -> Option<&T> + pub fn get_service(&self) -> Option> where - T: Service + 'static, + T: Service, { - self.services - .iter() - .find(|s| TypeId::of::() == s.as_any().type_id()) - .and_then(|s| s.as_any().downcast_ref::()) - } - - pub fn status_map(&self) -> PinnedBoxedFuture<'_, HashMap<&dyn Service, &Mutex>> { - Box::pin(async move { - let mut status_map = HashMap::new(); - - for service in self.services.iter() { - status_map.insert(&**service, &service.info().status); - } - - status_map - }) + //TODO + todo!("Implement") } //TODO: When Rust allows async closures, refactor this to use iterator methods instead of for loop pub fn overall_status(&self) -> PinnedBoxedFuture<'_, OverallStatus> { Box::pin(async move { for service in self.services.iter() { - let status = service.info().status.lock().await; + let service = service.read().await; + let status = service.info().status.read().await; if !matches!(&*status, Status::Started) { return OverallStatus::Unhealthy; @@ -345,8 +338,6 @@ impl ServiceManager { //TODO: When Rust allows async closures, refactor this to use iterator methods instead of for loop pub fn status_tree(&self) -> PinnedBoxedFuture<'_, String> { Box::pin(async move { - let status_map = self.status_map().await; - let mut text_buffer = String::new(); let mut failed_essentials = String::new(); @@ -355,10 +346,11 @@ impl ServiceManager { let mut non_failed_optionals = String::new(); let mut others = String::new(); - for (service, status) in status_map.into_iter() { + for service in self.services.iter() { + let service = service.read().await; let info = service.info(); let priority = &info.priority; - let status = status.lock().await; + let status = info.status.read().await; match *status { Status::Started | Status::Stopped => match priority { @@ -428,6 +420,7 @@ impl Display for ServiceManager { let mut services = self.services.iter().peekable(); while let Some(service) = services.next() { + let service = service.blocking_read(); write!(f, "{} ({})", service.info().name, service.info().id)?; if services.peek().is_some() { write!(f, ", ")?; diff --git a/src/service/discord.rs b/src/service/discord.rs index 3104452..5925090 100644 --- a/src/service/discord.rs +++ b/src/service/discord.rs @@ -1,4 +1,4 @@ -use super::{PinnedBoxedFutureResult, Priority, Service, ServiceInfo, ServiceInternals}; +use super::{PinnedBoxedFutureResult, Priority, Service, ServiceInfo, ServiceManager}; use log::info; use serenity::{ all::{GatewayIntents, Ready}, @@ -10,7 +10,7 @@ use serenity::{ prelude::TypeMap, Client, Error, }; -use std::{any::Any, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use tokio::{ sync::{Mutex, Notify, RwLock}, task::JoinHandle, @@ -35,7 +35,7 @@ impl DiscordService { pub fn new(discord_token: &str, connection_timeout: Duration) -> Self { Self { info: ServiceInfo::new("lum_builtin_discord", "Discord", Priority::Essential), - discord_token: discord_token.to_string(), + discord_token: "discord_token".to_string(), connection_timeout, client: Arc::new(Mutex::new(None)), client_handle: None, @@ -49,8 +49,12 @@ impl DiscordService { } } -impl ServiceInternals for DiscordService { - fn start(&mut self) -> PinnedBoxedFutureResult<'_, ()> { +impl Service for DiscordService { + fn info(&self) -> &ServiceInfo { + &self.info + } + + fn start(&mut self, service_manager: &ServiceManager) -> PinnedBoxedFutureResult<'_, ()> { Box::pin(async move { let framework = StandardFramework::new(); framework.configure(Configuration::new().prefix("!")); @@ -124,16 +128,6 @@ impl ServiceInternals for DiscordService { } } -impl Service for DiscordService { - fn info(&self) -> &ServiceInfo { - &self.info - } - - fn as_any(&self) -> &dyn Any { - self - } -} - // If the thread ended WITHOUT a JoinError from aborting, the client already stopped unexpectedly async fn convert_thread_result(client_handle: JoinHandle>) -> Result<(), Error> { match client_handle.await { From 146d0edb87b8759c8d95a06282d3498fc975497a Mon Sep 17 00:00:00 2001 From: Torben Schweren Date: Fri, 5 Jan 2024 16:27:34 +0100 Subject: [PATCH 3/5] Implement get_service method - Add downcast-rs crate - Implement get_service method of ServiceManager Had to use unsafe Rust for this. Tried it with safe Rust for 3 days and couldn't do it. With unsafe Rust, it's very easy. It's also still kinda safe, as the crash case is checked and prevented before going into the unsafe block. --- Cargo.toml | 1 + src/service.rs | 34 +++++++++++++++++++++++++++++----- src/service/discord.rs | 2 +- 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 991b168..b91d029 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ repository = "https://github.com/Kitt3120/lum" [dependencies] dirs = "5.0.1" +downcast-rs = "1.2.0" fern = { version = "0.6.2", features = ["chrono", "colored", "date-based"] } humantime = "2.1.0" log = { version = "0.4.20", features = ["serde"] } diff --git a/src/service.rs b/src/service.rs index 98cb3b9..7aa81ea 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,3 +1,4 @@ +use downcast_rs::{impl_downcast, DowncastSync}; use log::{error, info, warn}; use std::{ any::Any, @@ -6,6 +7,7 @@ use std::{ fmt::Display, future::Future, hash::{Hash, Hasher}, + mem, pin::Pin, sync::Arc, }; @@ -140,7 +142,7 @@ impl Hash for ServiceInfo { } } //TODO: When Rust allows async trait methods to be object-safe, refactor this to use async instead of returning a PinnedBoxedFutureResult -pub trait Service: Any + Send + Sync { +pub trait Service: DowncastSync { fn info(&self) -> &ServiceInfo; fn start(&mut self, service_manager: &ServiceManager) -> PinnedBoxedFutureResult<'_, ()>; fn stop(&mut self) -> PinnedBoxedFutureResult<'_, ()>; @@ -214,6 +216,8 @@ pub trait Service: Any + Send + Sync { } } +impl_downcast!(sync Service); + impl Eq for dyn Service {} impl PartialEq for dyn Service { @@ -311,12 +315,32 @@ impl ServiceManager { }) } - pub fn get_service(&self) -> Option> + pub async fn get_service(&self) -> Option>> where - T: Service, + T: Service + Any + Send + Sync + 'static, { - //TODO - todo!("Implement") + for service in self.services.iter() { + let lock = service.read().await; + let is_t = lock.as_any().is::(); + drop(lock); + + if is_t { + let arc_clone = Arc::clone(service); + let service_ptr: *const Arc> = &arc_clone; + + /* + I tried to do this in safe rust for 3 days, but I couldn't figure it out + Should you come up with a way to do this in safe rust, please make a PR! :) + Anyways, this should never cause any issues, since we checked if the service is of type T + */ + unsafe { + let t_ptr: *const Arc> = mem::transmute(service_ptr); + return Some(Arc::clone(&*t_ptr)); + } + } + } + + None } //TODO: When Rust allows async closures, refactor this to use iterator methods instead of for loop diff --git a/src/service/discord.rs b/src/service/discord.rs index 5925090..8841e01 100644 --- a/src/service/discord.rs +++ b/src/service/discord.rs @@ -35,7 +35,7 @@ impl DiscordService { pub fn new(discord_token: &str, connection_timeout: Duration) -> Self { Self { info: ServiceInfo::new("lum_builtin_discord", "Discord", Priority::Essential), - discord_token: "discord_token".to_string(), + discord_token: discord_token.to_string(), connection_timeout, client: Arc::new(Mutex::new(None)), client_handle: None, From 484be2e1d0efb590d8324faa64dba69fceb46fd7 Mon Sep 17 00:00:00 2001 From: Torben Schweren Date: Sun, 7 Jan 2024 02:54:09 +0100 Subject: [PATCH 4/5] Finish refactor of service framework - ServiceManager now holds an Arc to itself - Self-Arc is now passed to services when initializing them, so they can access other services and copy Arcs to those for themselves - Implement SetLock struct which is a wrapper around Option for lazy-initialization - ServiceManagerBuilder handles the creation and injection of the Self-Arc of ServiceManager. That's why the build() method is now async and the From trait had to be removed. The From trait cannot be implemented async. - To keep everything consistent, the From trait has also been removed from the BotBuilder and the build() method becase async. --- src/bot.rs | 22 ++++++--------- src/lib.rs | 7 +++-- src/main.rs | 3 +- src/service.rs | 55 ++++++++++++++++++++++++++----------- src/service/discord.rs | 5 +++- src/setlock.rs | 62 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 120 insertions(+), 34 deletions(-) create mode 100644 src/setlock.rs diff --git a/src/bot.rs b/src/bot.rs index f77515a..3984cab 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -31,14 +31,17 @@ impl BotBuilder { self } - pub fn build(self) -> Bot { - Bot::from(self) + pub async fn build(self) -> Bot { + Bot { + name: self.name, + service_manager: self.service_manager.build().await, + } } } pub struct Bot { pub name: String, - pub service_manager: ServiceManager, + pub service_manager: Arc>, } impl Bot { @@ -49,7 +52,7 @@ impl Bot { //TODO: When Rust allows async trait methods to be object-safe, refactor this to use async instead of returning a future pub fn start(&mut self) -> PinnedBoxedFuture<'_, ()> { Box::pin(async move { - self.service_manager.start_services().await; + self.service_manager.write().await.start_services().await; //TODO: Potential for further initialization here, like modules }) } @@ -57,17 +60,8 @@ impl Bot { //TODO: When Rust allows async trait methods to be object-safe, refactor this to use async instead of returning a future pub fn stop(&mut self) -> PinnedBoxedFuture<'_, ()> { Box::pin(async move { - self.service_manager.stop_services().await; + self.service_manager.write().await.stop_services().await; //TODO: Potential for further deinitialization here, like modules }) } } - -impl From for Bot { - fn from(builder: BotBuilder) -> Self { - Self { - name: builder.name, - service_manager: builder.service_manager.build(), - } - } -} diff --git a/src/lib.rs b/src/lib.rs index 1296293..4396be6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ pub mod bot; pub mod config; pub mod log; pub mod service; +pub mod setlock; pub fn is_debug() -> bool { cfg!(debug_assertions) @@ -35,8 +36,9 @@ pub async fn run(mut bot: Bot) { } }; - if bot.service_manager.overall_status().await != OverallStatus::Healthy { - let status_tree = bot.service_manager.status_tree().await; + let service_manager = bot.service_manager.read().await; + if service_manager.overall_status().await != OverallStatus::Healthy { + let status_tree = service_manager.status_tree().await; error!("{} is not healthy! Some essential services did not start up successfully. Please check the logs.\nService status tree:\n{}\n{} will exit.", bot.name, @@ -44,6 +46,7 @@ pub async fn run(mut bot: Bot) { bot.name); return; } + drop(service_manager); info!("{} is alive", bot.name,); diff --git a/src/main.rs b/src/main.rs index e2a2a6b..def8bff 100644 --- a/src/main.rs +++ b/src/main.rs @@ -35,7 +35,8 @@ async fn main() { let bot = Bot::builder(BOT_NAME) .with_services(initialize_services(&config)) .await - .build(); + .build() + .await; lum::run(bot).await; } diff --git a/src/service.rs b/src/service.rs index 7aa81ea..f0737af 100644 --- a/src/service.rs +++ b/src/service.rs @@ -13,6 +13,8 @@ use std::{ }; use tokio::sync::RwLock; +use crate::setlock::SetLock; + pub mod discord; pub type PinnedBoxedFuture<'a, T> = Pin + 'a>>; @@ -144,16 +146,19 @@ impl Hash for ServiceInfo { //TODO: When Rust allows async trait methods to be object-safe, refactor this to use async instead of returning a PinnedBoxedFutureResult pub trait Service: DowncastSync { fn info(&self) -> &ServiceInfo; - fn start(&mut self, service_manager: &ServiceManager) -> PinnedBoxedFutureResult<'_, ()>; + fn start( + &mut self, + service_manager: Arc>, + ) -> PinnedBoxedFutureResult<'_, ()>; fn stop(&mut self) -> PinnedBoxedFutureResult<'_, ()>; // Used for downcasting in get_service method of ServiceManager //fn as_any_arc(&self) -> Arc; - fn wrapped_start<'a>( - &'a mut self, - service_manager: &'a ServiceManager, - ) -> PinnedBoxedFuture<'a, ()> { + fn wrapped_start( + &mut self, + service_manager: Arc>, + ) -> PinnedBoxedFuture<()> { Box::pin(async move { let mut status = self.info().status.write().await; @@ -283,13 +288,38 @@ impl ServiceManagerBuilder { self } - pub fn build(self) -> ServiceManager { - ServiceManager::from(self) + pub async fn build(self) -> Arc> { + let service_manager = ServiceManager { + services: self.services, + arc: RwLock::new(SetLock::new()), + }; + + let self_arc = Arc::new(RwLock::new(service_manager)); + + match self_arc + .write() + .await + .arc + .write() + .await + .set(Arc::clone(&self_arc)) + { + Ok(()) => {} + Err(err) => { + panic!( + "Failed to set ServiceManager in SetLock for self_arc: {}", + err + ); + } + } + + self_arc } } pub struct ServiceManager { pub services: Vec>>, + arc: RwLock>>>, } impl ServiceManager { @@ -301,7 +331,8 @@ impl ServiceManager { Box::pin(async move { for service in &self.services { let mut service = service.write().await; - service.wrapped_start(self).await; + let service_manager = Arc::clone(self.arc.read().await.unwrap()); + service.wrapped_start(service_manager).await; } }) } @@ -453,11 +484,3 @@ impl Display for ServiceManager { Ok(()) } } - -impl From for ServiceManager { - fn from(builder: ServiceManagerBuilder) -> Self { - Self { - services: builder.services, - } - } -} diff --git a/src/service/discord.rs b/src/service/discord.rs index 8841e01..8438b8e 100644 --- a/src/service/discord.rs +++ b/src/service/discord.rs @@ -54,7 +54,10 @@ impl Service for DiscordService { &self.info } - fn start(&mut self, service_manager: &ServiceManager) -> PinnedBoxedFutureResult<'_, ()> { + fn start( + &mut self, + _service_manager: Arc>, + ) -> PinnedBoxedFutureResult<'_, ()> { Box::pin(async move { let framework = StandardFramework::new(); framework.configure(Configuration::new().prefix("!")); diff --git a/src/setlock.rs b/src/setlock.rs new file mode 100644 index 0000000..f960a52 --- /dev/null +++ b/src/setlock.rs @@ -0,0 +1,62 @@ +use std::{error::Error, fmt::Display}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum SetLockError { + AlreadySet, +} + +impl Display for SetLockError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SetLockError::AlreadySet => write!(f, "AlreadySet"), + } + } +} + +impl Error for SetLockError {} + +pub struct SetLock { + data: Option, +} + +impl SetLock { + pub fn new() -> Self { + Self { data: None } + } + + pub fn set(&mut self, data: T) -> Result<(), SetLockError> { + if self.data.is_some() { + return Err(SetLockError::AlreadySet); + } + + self.data = Some(data); + + Ok(()) + } + + pub fn is_set(&self) -> bool { + self.data.is_some() + } + + pub fn unwrap(&self) -> &T { + self.data.as_ref().unwrap() + } + + pub fn unwrap_mut(&mut self) -> &mut T { + self.data.as_mut().unwrap() + } + + pub fn get(&self) -> Option<&T> { + self.data.as_ref() + } + + pub fn get_mut(&mut self) -> Option<&mut T> { + self.data.as_mut() + } +} + +impl Default for SetLock { + fn default() -> Self { + Self::new() + } +} From 0277ab5d1a820b2ee180ec40c90e123b15dee699 Mon Sep 17 00:00:00 2001 From: Torben Schweren Date: Sun, 7 Jan 2024 14:11:25 +0100 Subject: [PATCH 5/5] Adapt Discord service - Adapt Discord service to new service framework and SetLock type --- src/service/discord.rs | 77 ++++++++++++++++++++++++++---------------- src/setlock.rs | 8 +++++ 2 files changed, 56 insertions(+), 29 deletions(-) diff --git a/src/service/discord.rs b/src/service/discord.rs index 8438b8e..4b48b6d 100644 --- a/src/service/discord.rs +++ b/src/service/discord.rs @@ -1,5 +1,7 @@ +use crate::setlock::SetLock; + use super::{PinnedBoxedFutureResult, Priority, Service, ServiceInfo, ServiceManager}; -use log::info; +use log::{error, info}; use serenity::{ all::{GatewayIntents, Ready}, async_trait, @@ -21,14 +23,14 @@ pub struct DiscordService { info: ServiceInfo, discord_token: String, connection_timeout: Duration, - pub client: Arc>>, + pub ready: Arc>>, client_handle: Option>>, - pub cache: Option>, - pub data: Option>>, - pub http: Option>, - pub shard_manager: Option>, - pub voice_manager: Option>, - pub ws_url: Option>>, + pub cache: SetLock>, + pub data: SetLock>>, + pub http: SetLock>, + pub shard_manager: SetLock>, + pub voice_manager: SetLock>, + pub ws_url: SetLock>>, } impl DiscordService { @@ -37,14 +39,14 @@ impl DiscordService { info: ServiceInfo::new("lum_builtin_discord", "Discord", Priority::Essential), discord_token: discord_token.to_string(), connection_timeout, - client: Arc::new(Mutex::new(None)), + ready: Arc::new(RwLock::new(SetLock::new())), client_handle: None, - cache: None, - data: None, - http: None, - shard_manager: None, - voice_manager: None, - ws_url: None, + cache: SetLock::new(), + data: SetLock::new(), + http: SetLock::new(), + shard_manager: SetLock::new(), + voice_manager: SetLock::new(), + ws_url: SetLock::new(), } } } @@ -67,22 +69,36 @@ impl Service for DiscordService { let mut client = Client::builder(self.discord_token.as_str(), GatewayIntents::all()) .framework(framework) .event_handler(EventHandler::new( - Arc::clone(&self.client), + Arc::clone(&self.ready), Arc::clone(&client_ready_notify), )) .await?; - self.cache = Some(Arc::clone(&client.cache)); - self.data = Some(Arc::clone(&client.data)); - self.http = Some(Arc::clone(&client.http)); - self.shard_manager = Some(Arc::clone(&client.shard_manager)); - if let Some(shard_manager) = &self.shard_manager { - self.shard_manager = Some(Arc::clone(shard_manager)); + if let Err(error) = self.cache.set(Arc::clone(&client.cache)) { + return Err(format!("Failed to set cache SetLock: {}", error).into()); + } + + if let Err(error) = self.data.set(Arc::clone(&client.data)) { + return Err(format!("Failed to set data SetLock: {}", error).into()); + } + + if let Err(error) = self.http.set(Arc::clone(&client.http)) { + return Err(format!("Failed to set http SetLock: {}", error).into()); + } + + if let Err(error) = self.shard_manager.set(Arc::clone(&client.shard_manager)) { + return Err(format!("Failed to set shard_manager SetLock: {}", error).into()); + } + + if let Some(voice_manager) = &client.voice_manager { + if let Err(error) = self.voice_manager.set(Arc::clone(voice_manager)) { + return Err(format!("Failed to set voice_manager SetLock: {}", error).into()); + } } - if let Some(voice_manager) = &self.voice_manager { - self.voice_manager = Some(Arc::clone(voice_manager)); + + if let Err(error) = self.ws_url.set(Arc::clone(&client.ws_url)) { + return Err(format!("Failed to set ws_url SetLock: {}", error).into()); } - self.ws_url = Some(Arc::clone(&client.ws_url)); info!("Connecting to Discord"); let client_handle = tokio::spawn(async move { client.start().await }); @@ -108,7 +124,7 @@ impl Service for DiscordService { self.connection_timeout.as_secs() ) .into()); - } + }; self.client_handle = Some(client_handle); Ok(()) @@ -140,12 +156,12 @@ async fn convert_thread_result(client_handle: JoinHandle>) -> } struct EventHandler { - client: Arc>>, + client: Arc>>, ready_notify: Arc, } impl EventHandler { - pub fn new(client: Arc>>, ready_notify: Arc) -> Self { + pub fn new(client: Arc>>, ready_notify: Arc) -> Self { Self { client, ready_notify, @@ -157,7 +173,10 @@ impl EventHandler { impl client::EventHandler for EventHandler { async fn ready(&self, _ctx: Context, data_about_bot: Ready) { info!("Connected to Discord as {}", data_about_bot.user.tag()); - *self.client.lock().await = Some(data_about_bot); + if let Err(error) = self.client.write().await.set(data_about_bot) { + error!("Failed to set client SetLock: {}", error); + panic!("Failed to set client SetLock: {}", error); + } self.ready_notify.notify_one(); } } diff --git a/src/setlock.rs b/src/setlock.rs index f960a52..7bee17c 100644 --- a/src/setlock.rs +++ b/src/setlock.rs @@ -39,10 +39,18 @@ impl SetLock { } pub fn unwrap(&self) -> &T { + if self.data.is_none() { + panic!("unwrap called on an unset SetLock"); + } + self.data.as_ref().unwrap() } pub fn unwrap_mut(&mut self) -> &mut T { + if self.data.is_none() { + panic!("unwrap_mut called on an unset SetLock"); + } + self.data.as_mut().unwrap() }