diff --git a/Cargo.toml b/Cargo.toml index 991b168..b91d029 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ repository = "https://github.com/Kitt3120/lum" [dependencies] dirs = "5.0.1" +downcast-rs = "1.2.0" fern = { version = "0.6.2", features = ["chrono", "colored", "date-based"] } humantime = "2.1.0" log = { version = "0.4.20", features = ["serde"] } diff --git a/src/bot.rs b/src/bot.rs index e9c8946..3984cab 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -1,3 +1,7 @@ +use std::sync::Arc; + +use tokio::sync::RwLock; + use crate::service::{PinnedBoxedFuture, Service, ServiceManager, ServiceManagerBuilder}; pub struct BotBuilder { @@ -13,28 +17,31 @@ impl BotBuilder { } } - pub fn with_service(mut self, service: Box) -> Self { - self.service_manager = self.service_manager.with_service(service); // The ServiceManagerBuilder itself will warn when adding a service multiple times + pub async fn with_service(mut self, service: Arc>) -> Self { + self.service_manager = self.service_manager.with_service(service).await; // The ServiceManagerBuilder itself will warn when adding a service multiple times self } - pub fn with_services(mut self, services: Vec>) -> Self { + pub async fn with_services(mut self, services: Vec>>) -> Self { for service in services { - self.service_manager = self.service_manager.with_service(service); + self.service_manager = self.service_manager.with_service(service).await; } self } - pub fn build(self) -> Bot { - Bot::from(self) + pub async fn build(self) -> Bot { + Bot { + name: self.name, + service_manager: self.service_manager.build().await, + } } } pub struct Bot { pub name: String, - pub service_manager: ServiceManager, + pub service_manager: Arc>, } impl Bot { @@ -45,7 +52,7 @@ impl Bot { //TODO: When Rust allows async trait methods to be object-safe, refactor this to use async instead of returning a future pub fn start(&mut self) -> PinnedBoxedFuture<'_, ()> { Box::pin(async move { - self.service_manager.start_services().await; + self.service_manager.write().await.start_services().await; //TODO: Potential for further initialization here, like modules }) } @@ -53,17 +60,8 @@ impl Bot { //TODO: When Rust allows async trait methods to be object-safe, refactor this to use async instead of returning a future pub fn stop(&mut self) -> PinnedBoxedFuture<'_, ()> { Box::pin(async move { - self.service_manager.stop_services().await; + self.service_manager.write().await.stop_services().await; //TODO: Potential for further deinitialization here, like modules }) } } - -impl From for Bot { - fn from(builder: BotBuilder) -> Self { - Self { - name: builder.name, - service_manager: builder.service_manager.build(), - } - } -} diff --git a/src/lib.rs b/src/lib.rs index 1296293..4396be6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ pub mod bot; pub mod config; pub mod log; pub mod service; +pub mod setlock; pub fn is_debug() -> bool { cfg!(debug_assertions) @@ -35,8 +36,9 @@ pub async fn run(mut bot: Bot) { } }; - if bot.service_manager.overall_status().await != OverallStatus::Healthy { - let status_tree = bot.service_manager.status_tree().await; + let service_manager = bot.service_manager.read().await; + if service_manager.overall_status().await != OverallStatus::Healthy { + let status_tree = service_manager.status_tree().await; error!("{} is not healthy! Some essential services did not start up successfully. Please check the logs.\nService status tree:\n{}\n{} will exit.", bot.name, @@ -44,6 +46,7 @@ pub async fn run(mut bot: Bot) { bot.name); return; } + drop(service_manager); info!("{} is alive", bot.name,); diff --git a/src/main.rs b/src/main.rs index 7624dae..def8bff 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use ::log::{error, warn}; use lum::{ bot::Bot, @@ -5,6 +7,7 @@ use lum::{ log, service::{discord::DiscordService, Service}, }; +use tokio::sync::RwLock; const BOT_NAME: &str = "Lum"; @@ -31,7 +34,9 @@ async fn main() { let bot = Bot::builder(BOT_NAME) .with_services(initialize_services(&config)) - .build(); + .await + .build() + .await; lum::run(bot).await; } @@ -45,12 +50,12 @@ fn setup_logger() { } } -fn initialize_services(config: &Config) -> Vec> { +fn initialize_services(config: &Config) -> Vec>> { //TODO: Add services //... let discord_service = DiscordService::new(config.discord_token.as_str(), config.discord_timeout); - vec![Box::new(discord_service)] + vec![Arc::new(RwLock::new(discord_service))] } diff --git a/src/service.rs b/src/service.rs index d0e09c3..f0737af 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,15 +1,19 @@ +use downcast_rs::{impl_downcast, DowncastSync}; use log::{error, info, warn}; use std::{ + any::Any, cmp::Ordering, - collections::HashMap, error::Error, fmt::Display, future::Future, hash::{Hash, Hasher}, + mem, pin::Pin, sync::Arc, }; -use tokio::sync::Mutex; +use tokio::sync::RwLock; + +use crate::setlock::SetLock; pub mod discord; @@ -92,11 +96,11 @@ impl Display for Priority { #[derive(Debug)] pub struct ServiceInfo { - pub id: String, + id: String, pub name: String, pub priority: Priority, - pub status: Arc>, + pub status: Arc>, } impl ServiceInfo { @@ -105,20 +109,12 @@ impl ServiceInfo { id: id.to_string(), name: name.to_string(), priority, - status: Arc::new(Mutex::new(Status::Stopped)), + status: Arc::new(RwLock::new(Status::Stopped)), } } - pub fn is_available(&self) -> Pin + '_>> { - Box::pin(async move { - let lock = self.status.lock().await; - matches!(&*lock, Status::Started) - }) - } - pub async fn set_status(&self, status: Status) { - let mut lock = self.status.lock().await; - *lock = status; + *(self.status.write().await) = status } } @@ -147,20 +143,24 @@ impl Hash for ServiceInfo { self.id.hash(state); } } - //TODO: When Rust allows async trait methods to be object-safe, refactor this to use async instead of returning a PinnedBoxedFutureResult -pub trait ServiceInternals { - fn start(&mut self) -> PinnedBoxedFutureResult<'_, ()>; +pub trait Service: DowncastSync { + fn info(&self) -> &ServiceInfo; + fn start( + &mut self, + service_manager: Arc>, + ) -> PinnedBoxedFutureResult<'_, ()>; fn stop(&mut self) -> PinnedBoxedFutureResult<'_, ()>; -} -//TODO: When Rust allows async trait methods to be object-safe, refactor this to use async instead of returning a PinnedBoxedFutureResult -pub trait Service: ServiceInternals { - fn info(&self) -> &ServiceInfo; + // Used for downcasting in get_service method of ServiceManager + //fn as_any_arc(&self) -> Arc; - fn wrapped_start(&mut self) -> PinnedBoxedFuture<'_, ()> { + fn wrapped_start( + &mut self, + service_manager: Arc>, + ) -> PinnedBoxedFuture<()> { Box::pin(async move { - let mut status = self.info().status.lock().await; + let mut status = self.info().status.write().await; if !matches!(&*status, Status::Stopped) { warn!( @@ -174,7 +174,7 @@ pub trait Service: ServiceInternals { *status = Status::Starting; drop(status); - match self.start().await { + match self.start(service_manager).await { Ok(()) => { info!("Started service: {}", self.info().name); self.info().set_status(Status::Started).await; @@ -189,7 +189,7 @@ pub trait Service: ServiceInternals { fn wrapped_stop(&mut self) -> PinnedBoxedFuture<'_, ()> { Box::pin(async move { - let mut status = self.info().status.lock().await; + let mut status = self.info().status.write().await; if !matches!(&*status, Status::Started) { warn!( @@ -203,7 +203,7 @@ pub trait Service: ServiceInternals { *status = Status::Stopping; drop(status); - match ServiceInternals::stop(self).await { + match self.stop().await { Ok(()) => { info!("Stopped service: {}", self.info().name); self.info().set_status(Status::Stopped).await; @@ -215,8 +215,14 @@ pub trait Service: ServiceInternals { } }) } + + fn is_available(&self) -> PinnedBoxedFuture<'_, bool> { + Box::pin(async move { matches!(&*(self.info().status.read().await), Status::Started) }) + } } +impl_downcast!(sync Service); + impl Eq for dyn Service {} impl PartialEq for dyn Service { @@ -245,38 +251,75 @@ impl Hash for dyn Service { #[derive(Default)] pub struct ServiceManagerBuilder { - services: Vec>, + services: Vec>>, } impl ServiceManagerBuilder { pub fn new() -> Self { - Self { services: vec![] } + Self { + services: Vec::new(), + } } - pub fn with_service(mut self, service: Box) -> Self { - let service_exists = self.services.iter().any(|s| s.info() == service.info()); + //TODO: When Rust allows async closures, refactor this to use iterator methods instead of for loop + pub async fn with_service(mut self, service: Arc>) -> Self { + let lock = service.read().await; - if service_exists { - warn!( - "Tried to add service {} ({}), but a service with that ID already exists. Ignoring.", - service.info().name, service.info().id - ); + let mut found = false; + for registered_service in self.services.iter() { + let registered_service = registered_service.read().await; + + if registered_service.info().id == lock.info().id { + found = true; + } + } + if found { + warn!( + "Tried to add service {} ({}), but a service with that ID already exists. Ignoring.", + lock.info().name, lock.info().id + ); return self; } - self.services.push(service); + drop(lock); + self.services.push(service); self } - pub fn build(self) -> ServiceManager { - ServiceManager::from(self) + pub async fn build(self) -> Arc> { + let service_manager = ServiceManager { + services: self.services, + arc: RwLock::new(SetLock::new()), + }; + + let self_arc = Arc::new(RwLock::new(service_manager)); + + match self_arc + .write() + .await + .arc + .write() + .await + .set(Arc::clone(&self_arc)) + { + Ok(()) => {} + Err(err) => { + panic!( + "Failed to set ServiceManager in SetLock for self_arc: {}", + err + ); + } + } + + self_arc } } pub struct ServiceManager { - pub services: Vec>, + pub services: Vec>>, + arc: RwLock>>>, } impl ServiceManager { @@ -284,49 +327,59 @@ impl ServiceManager { ServiceManagerBuilder::new() } - pub fn start_services(&mut self) -> PinnedBoxedFuture<'_, ()> { + pub fn start_services(&self) -> PinnedBoxedFuture<'_, ()> { Box::pin(async move { - for service in &mut self.services { - service.wrapped_start().await; + for service in &self.services { + let mut service = service.write().await; + let service_manager = Arc::clone(self.arc.read().await.unwrap()); + service.wrapped_start(service_manager).await; } }) } - pub fn stop_services(&mut self) -> PinnedBoxedFuture<'_, ()> { + pub fn stop_services(&self) -> PinnedBoxedFuture<'_, ()> { Box::pin(async move { - for service in &mut self.services { + for service in &self.services { + let mut service = service.write().await; service.wrapped_stop().await; } }) } - pub fn get_service(&self, id: &str) -> Option<&dyn Service> { - self.services - .iter() - .find(|s| s.info().id == id) - .map(|s| &**s) - } - - pub fn status_map(&self) -> PinnedBoxedFuture<'_, HashMap>>> { - Box::pin(async move { - let mut status_map = HashMap::new(); - - for service in self.services.iter() { - status_map.insert( - service.info().id.clone(), - Arc::clone(&service.info().status), - ); + pub async fn get_service(&self) -> Option>> + where + T: Service + Any + Send + Sync + 'static, + { + for service in self.services.iter() { + let lock = service.read().await; + let is_t = lock.as_any().is::(); + drop(lock); + + if is_t { + let arc_clone = Arc::clone(service); + let service_ptr: *const Arc> = &arc_clone; + + /* + I tried to do this in safe rust for 3 days, but I couldn't figure it out + Should you come up with a way to do this in safe rust, please make a PR! :) + Anyways, this should never cause any issues, since we checked if the service is of type T + */ + unsafe { + let t_ptr: *const Arc> = mem::transmute(service_ptr); + return Some(Arc::clone(&*t_ptr)); + } } + } - status_map - }) + None } //TODO: When Rust allows async closures, refactor this to use iterator methods instead of for loop pub fn overall_status(&self) -> PinnedBoxedFuture<'_, OverallStatus> { Box::pin(async move { for service in self.services.iter() { - let status = service.info().status.lock().await; + let service = service.read().await; + let status = service.info().status.read().await; if !matches!(&*status, Status::Started) { return OverallStatus::Unhealthy; @@ -340,96 +393,70 @@ impl ServiceManager { //TODO: When Rust allows async closures, refactor this to use iterator methods instead of for loop pub fn status_tree(&self) -> PinnedBoxedFuture<'_, String> { Box::pin(async move { - let status_map = self.status_map().await; - let mut text_buffer = String::new(); - let mut failed_essentials = HashMap::new(); - let mut failed_optionals = HashMap::new(); - let mut non_failed_essentials = HashMap::new(); - let mut non_failed_optionals = HashMap::new(); - let mut others = HashMap::new(); - - for (service, status) in status_map.into_iter() { - let priority = match self.get_service(service.as_str()) { - Some(service) => service.info().priority, - None => unreachable!( - "Service with ID {} not found in ServiceManager. This should never happen!", - service, - ), - }; - - let status = status.lock().await; - - match &*status { - Status::Started | Status::Stopped => { - if priority == Priority::Essential { - non_failed_essentials.insert(service, status.to_string()); - } else { - non_failed_optionals.insert(service, status.to_string()); + let mut failed_essentials = String::new(); + let mut failed_optionals = String::new(); + let mut non_failed_essentials = String::new(); + let mut non_failed_optionals = String::new(); + let mut others = String::new(); + + for service in self.services.iter() { + let service = service.read().await; + let info = service.info(); + let priority = &info.priority; + let status = info.status.read().await; + + match *status { + Status::Started | Status::Stopped => match priority { + Priority::Essential => { + non_failed_essentials + .push_str(&format!(" - {}: {}\n", info.name, status)); } - } + Priority::Optional => { + non_failed_optionals + .push_str(&format!(" - {}: {}\n", info.name, status)); + } + }, Status::FailedStarting(_) | Status::FailedStopping(_) - | Status::RuntimeError(_) => { - if priority == Priority::Essential { - failed_essentials.insert(service, status.to_string()); - } else { - failed_optionals.insert(service, status.to_string()); + | Status::RuntimeError(_) => match priority { + Priority::Essential => { + failed_essentials.push_str(&format!(" - {}: {}\n", info.name, status)); } - } + Priority::Optional => { + failed_optionals.push_str(&format!(" - {}: {}\n", info.name, status)); + } + }, _ => { - others.insert(service, status.to_string()); + others.push_str(&format!(" - {}: {}\n", info.name, status)); } } } - let section_generator = |services: &HashMap, title: &str| -> String { - let mut text_buffer = String::new(); - - text_buffer.push_str(&format!("- {}:\n", title)); - - for (service, status) in services.iter() { - let service = match self.get_service(service) { - Some(service) => service, - None => unreachable!( - "Service with ID {} not found in ServiceManager. This should never happen!", - service - ), - }; - - text_buffer.push_str(&format!(" - {}: {}\n", service.info().name, status)); - } - - text_buffer - }; - if !failed_essentials.is_empty() { - text_buffer.push_str( - section_generator(&failed_essentials, "Failed essential services").as_str(), - ); + text_buffer.push_str(&format!("- {}:\n", "Failed essential services")); + text_buffer.push_str(&failed_essentials); } if !failed_optionals.is_empty() { - text_buffer.push_str( - section_generator(&failed_optionals, "Failed optional services").as_str(), - ); + text_buffer.push_str(&format!("- {}:\n", "Failed optional services")); + text_buffer.push_str(&failed_optionals); } if !non_failed_essentials.is_empty() { - text_buffer.push_str( - section_generator(&non_failed_essentials, "Essential services").as_str(), - ); + text_buffer.push_str(&format!("- {}:\n", "Essential services")); + text_buffer.push_str(&non_failed_essentials); } if !non_failed_optionals.is_empty() { - text_buffer.push_str( - section_generator(&non_failed_optionals, "Optional services").as_str(), - ); + text_buffer.push_str(&format!("- {}:\n", "Optional services")); + text_buffer.push_str(&non_failed_optionals); } if !others.is_empty() { - text_buffer.push_str(section_generator(&others, "Other services").as_str()); + text_buffer.push_str(&format!("- {}:\n", "Other services")); + text_buffer.push_str(&others); } text_buffer @@ -448,6 +475,7 @@ impl Display for ServiceManager { let mut services = self.services.iter().peekable(); while let Some(service) = services.next() { + let service = service.blocking_read(); write!(f, "{} ({})", service.info().name, service.info().id)?; if services.peek().is_some() { write!(f, ", ")?; @@ -456,11 +484,3 @@ impl Display for ServiceManager { Ok(()) } } - -impl From for ServiceManager { - fn from(builder: ServiceManagerBuilder) -> Self { - Self { - services: builder.services, - } - } -} diff --git a/src/service/discord.rs b/src/service/discord.rs index c025531..4b48b6d 100644 --- a/src/service/discord.rs +++ b/src/service/discord.rs @@ -1,5 +1,7 @@ -use super::{PinnedBoxedFutureResult, Priority, Service, ServiceInfo, ServiceInternals}; -use log::info; +use crate::setlock::SetLock; + +use super::{PinnedBoxedFutureResult, Priority, Service, ServiceInfo, ServiceManager}; +use log::{error, info}; use serenity::{ all::{GatewayIntents, Ready}, async_trait, @@ -14,21 +16,21 @@ use std::{sync::Arc, time::Duration}; use tokio::{ sync::{Mutex, Notify, RwLock}, task::JoinHandle, - time::timeout, + time::{sleep, timeout}, }; pub struct DiscordService { info: ServiceInfo, discord_token: String, connection_timeout: Duration, - pub client: Arc>>, + pub ready: Arc>>, client_handle: Option>>, - pub cache: Option>, - pub data: Option>>, - pub http: Option>, - pub shard_manager: Option>, - pub voice_manager: Option>, - pub ws_url: Option>>, + pub cache: SetLock>, + pub data: SetLock>>, + pub http: SetLock>, + pub shard_manager: SetLock>, + pub voice_manager: SetLock>, + pub ws_url: SetLock>>, } impl DiscordService { @@ -37,20 +39,27 @@ impl DiscordService { info: ServiceInfo::new("lum_builtin_discord", "Discord", Priority::Essential), discord_token: discord_token.to_string(), connection_timeout, - client: Arc::new(Mutex::new(None)), + ready: Arc::new(RwLock::new(SetLock::new())), client_handle: None, - cache: None, - data: None, - http: None, - shard_manager: None, - voice_manager: None, - ws_url: None, + cache: SetLock::new(), + data: SetLock::new(), + http: SetLock::new(), + shard_manager: SetLock::new(), + voice_manager: SetLock::new(), + ws_url: SetLock::new(), } } } -impl ServiceInternals for DiscordService { - fn start(&mut self) -> PinnedBoxedFutureResult<'_, ()> { +impl Service for DiscordService { + fn info(&self) -> &ServiceInfo { + &self.info + } + + fn start( + &mut self, + _service_manager: Arc>, + ) -> PinnedBoxedFutureResult<'_, ()> { Box::pin(async move { let framework = StandardFramework::new(); framework.configure(Configuration::new().prefix("!")); @@ -60,43 +69,62 @@ impl ServiceInternals for DiscordService { let mut client = Client::builder(self.discord_token.as_str(), GatewayIntents::all()) .framework(framework) .event_handler(EventHandler::new( - Arc::clone(&self.client), + Arc::clone(&self.ready), Arc::clone(&client_ready_notify), )) .await?; - self.cache = Some(Arc::clone(&client.cache)); - self.data = Some(Arc::clone(&client.data)); - self.http = Some(Arc::clone(&client.http)); - self.shard_manager = Some(Arc::clone(&client.shard_manager)); - if let Some(shard_manager) = &self.shard_manager { - self.shard_manager = Some(Arc::clone(shard_manager)); + if let Err(error) = self.cache.set(Arc::clone(&client.cache)) { + return Err(format!("Failed to set cache SetLock: {}", error).into()); + } + + if let Err(error) = self.data.set(Arc::clone(&client.data)) { + return Err(format!("Failed to set data SetLock: {}", error).into()); } - if let Some(voice_manager) = &self.voice_manager { - self.voice_manager = Some(Arc::clone(voice_manager)); + + if let Err(error) = self.http.set(Arc::clone(&client.http)) { + return Err(format!("Failed to set http SetLock: {}", error).into()); + } + + if let Err(error) = self.shard_manager.set(Arc::clone(&client.shard_manager)) { + return Err(format!("Failed to set shard_manager SetLock: {}", error).into()); + } + + if let Some(voice_manager) = &client.voice_manager { + if let Err(error) = self.voice_manager.set(Arc::clone(voice_manager)) { + return Err(format!("Failed to set voice_manager SetLock: {}", error).into()); + } + } + + if let Err(error) = self.ws_url.set(Arc::clone(&client.ws_url)) { + return Err(format!("Failed to set ws_url SetLock: {}", error).into()); } - self.ws_url = Some(Arc::clone(&client.ws_url)); info!("Connecting to Discord"); let client_handle = tokio::spawn(async move { client.start().await }); + // This prevents waiting for the timeout if the client fails immediately + // TODO: Optimize this, as it will currently add 1000mqs to the startup time + sleep(Duration::from_secs(1)).await; + if client_handle.is_finished() { + client_handle.await??; + return Err("Discord client stopped unexpectedly and with no error".into()); + } + if timeout(self.connection_timeout, client_ready_notify.notified()) .await .is_err() { client_handle.abort(); + let result = convert_thread_result(client_handle).await; + result?; return Err(format!( "Discord client failed to connect within {} seconds", self.connection_timeout.as_secs() ) .into()); - } - - if client_handle.is_finished() { - client_handle.await??; - return Err("Discord client stopped unexpectedly and with no error".into()); - } + }; self.client_handle = Some(client_handle); Ok(()) @@ -105,39 +133,35 @@ impl ServiceInternals for DiscordService { fn stop(&mut self) -> PinnedBoxedFutureResult<'_, ()> { Box::pin(async move { - if let Some(handle) = self.client_handle.take() { + if let Some(client_handle) = self.client_handle.take() { info!("Waiting for Discord client to stop..."); - handle.abort(); - - let result = match handle.await { - Ok(result) => result, - Err(_) => { - info!("Discord client stopped"); - return Ok(()); - } - }; + client_handle.abort(); + let result = convert_thread_result(client_handle).await; result?; } + info!("Discord client stopped"); Ok(()) }) } } -impl Service for DiscordService { - fn info(&self) -> &ServiceInfo { - &self.info +// If the thread ended WITHOUT a JoinError from aborting, the client already stopped unexpectedly +async fn convert_thread_result(client_handle: JoinHandle>) -> Result<(), Error> { + match client_handle.await { + Ok(result) => result, + Err(_) => Ok(()), } } struct EventHandler { - client: Arc>>, + client: Arc>>, ready_notify: Arc, } impl EventHandler { - pub fn new(client: Arc>>, ready_notify: Arc) -> Self { + pub fn new(client: Arc>>, ready_notify: Arc) -> Self { Self { client, ready_notify, @@ -149,7 +173,10 @@ impl EventHandler { impl client::EventHandler for EventHandler { async fn ready(&self, _ctx: Context, data_about_bot: Ready) { info!("Connected to Discord as {}", data_about_bot.user.tag()); - *self.client.lock().await = Some(data_about_bot); + if let Err(error) = self.client.write().await.set(data_about_bot) { + error!("Failed to set client SetLock: {}", error); + panic!("Failed to set client SetLock: {}", error); + } self.ready_notify.notify_one(); } } diff --git a/src/setlock.rs b/src/setlock.rs new file mode 100644 index 0000000..7bee17c --- /dev/null +++ b/src/setlock.rs @@ -0,0 +1,70 @@ +use std::{error::Error, fmt::Display}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum SetLockError { + AlreadySet, +} + +impl Display for SetLockError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SetLockError::AlreadySet => write!(f, "AlreadySet"), + } + } +} + +impl Error for SetLockError {} + +pub struct SetLock { + data: Option, +} + +impl SetLock { + pub fn new() -> Self { + Self { data: None } + } + + pub fn set(&mut self, data: T) -> Result<(), SetLockError> { + if self.data.is_some() { + return Err(SetLockError::AlreadySet); + } + + self.data = Some(data); + + Ok(()) + } + + pub fn is_set(&self) -> bool { + self.data.is_some() + } + + pub fn unwrap(&self) -> &T { + if self.data.is_none() { + panic!("unwrap called on an unset SetLock"); + } + + self.data.as_ref().unwrap() + } + + pub fn unwrap_mut(&mut self) -> &mut T { + if self.data.is_none() { + panic!("unwrap_mut called on an unset SetLock"); + } + + self.data.as_mut().unwrap() + } + + pub fn get(&self) -> Option<&T> { + self.data.as_ref() + } + + pub fn get_mut(&mut self) -> Option<&mut T> { + self.data.as_mut() + } +} + +impl Default for SetLock { + fn default() -> Self { + Self::new() + } +}