diff --git a/Cargo.toml b/Cargo.toml index b32aaf3..991b168 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ humantime = "2.1.0" log = { version = "0.4.20", features = ["serde"] } serde = { version = "1.0.193", features = ["derive"] } serde_json = "1.0.108" +serenity = { version = "0.12.0", default-features=false, features = ["builder", "cache", "collector", "client", "framework", "gateway", "http", "model", "standard_framework", "utils", "voice", "default_native_tls", "tokio_task_builder", "unstable_discord_api", "simd_json", "temp_cache", "chrono", "interactions_endpoint"] } sqlx = { version = "0.7.3", features = ["runtime-tokio", "any", "postgres", "mysql", "sqlite", "tls-native-tls", "migrate", "macros", "uuid", "chrono", "json"] } thiserror = "1.0.52" tokio = { version = "1.35.1", features = ["full"] } diff --git a/src/config.rs b/src/config.rs index 41e4ed1..2abd660 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,6 +4,7 @@ use std::{ fmt::{Display, Formatter}, fs, io, path::PathBuf, + time::Duration, }; use thiserror::Error; @@ -37,16 +38,23 @@ fn discord_token_default() -> String { String::from("Please provide a token") } +fn discord_timeout_default() -> Duration { + Duration::from_secs(10) +} + #[derive(Debug, PartialEq, PartialOrd, Serialize, Deserialize, Clone)] pub struct Config { #[serde(rename = "discordToken", default = "discord_token_default")] pub discord_token: String, + #[serde(rename = "discordTimeout", default = "discord_timeout_default")] + pub discord_timeout: Duration, } impl Default for Config { fn default() -> Self { Config { discord_token: discord_token_default(), + discord_timeout: discord_timeout_default(), } } } diff --git a/src/log.rs b/src/log.rs index 4898578..43aa4e5 100644 --- a/src/log.rs +++ b/src/log.rs @@ -33,6 +33,11 @@ pub fn setup() -> Result<(), SetLoggerError> { )) }) .level(get_min_log_level()) + .level_for("serenity", LevelFilter::Warn) + .level_for("hyper", LevelFilter::Warn) + .level_for("tracing", LevelFilter::Warn) + .level_for("reqwest", LevelFilter::Warn) + .level_for("tungstenite", LevelFilter::Warn) .chain(io::stdout()) .apply()?; diff --git a/src/main.rs b/src/main.rs index 6e3bad7..7624dae 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,10 @@ use ::log::{error, warn}; -use lum::{bot::Bot, config::ConfigHandler, log, service::Service}; +use lum::{ + bot::Bot, + config::{Config, ConfigHandler}, + log, + service::{discord::DiscordService, Service}, +}; const BOT_NAME: &str = "Lum"; @@ -12,7 +17,7 @@ async fn main() { } let config_handler = ConfigHandler::new(BOT_NAME.to_lowercase().as_str()); - let _config = match config_handler.load_config() { + let config = match config_handler.load_config() { Ok(config) => config, Err(err) => { error!( @@ -25,7 +30,7 @@ async fn main() { }; let bot = Bot::builder(BOT_NAME) - .with_services(initialize_services()) + .with_services(initialize_services(&config)) .build(); lum::run(bot).await; @@ -40,9 +45,12 @@ fn setup_logger() { } } -fn initialize_services() -> Vec> { +fn initialize_services(config: &Config) -> Vec> { //TODO: Add services //... - vec![] + let discord_service = + DiscordService::new(config.discord_token.as_str(), config.discord_timeout); + + vec![Box::new(discord_service)] } diff --git a/src/service.rs b/src/service.rs index f5d4e73..d0e09c3 100644 --- a/src/service.rs +++ b/src/service.rs @@ -11,6 +11,8 @@ use std::{ }; use tokio::sync::Mutex; +pub mod discord; + pub type PinnedBoxedFuture<'a, T> = Pin + 'a>>; pub type PinnedBoxedFutureResult<'a, T> = @@ -174,12 +176,12 @@ pub trait Service: ServiceInternals { match self.start().await { Ok(()) => { - self.info().set_status(Status::Started).await; info!("Started service: {}", self.info().name); + self.info().set_status(Status::Started).await; } Err(error) => { + error!("Failed to start service {}: {}", self.info().name, error); self.info().set_status(Status::FailedStarting(error)).await; - error!("Failed to start service: {}", self.info().name); } } }) @@ -189,7 +191,7 @@ pub trait Service: ServiceInternals { Box::pin(async move { let mut status = self.info().status.lock().await; - if matches!(&*status, Status::Started) { + if !matches!(&*status, Status::Started) { warn!( "Tried to stop service {} while it was in state {}. Ignoring stop request.", self.info().name, @@ -203,12 +205,12 @@ pub trait Service: ServiceInternals { match ServiceInternals::stop(self).await { Ok(()) => { - self.info().set_status(Status::Stopped).await; info!("Stopped service: {}", self.info().name); + self.info().set_status(Status::Stopped).await; } Err(error) => { + error!("Failed to stop service {}: {}", self.info().name, error); self.info().set_status(Status::FailedStopping(error)).await; - error!("Failed to stop service: {}", self.info().name); } } }) diff --git a/src/service/discord.rs b/src/service/discord.rs new file mode 100644 index 0000000..c025531 --- /dev/null +++ b/src/service/discord.rs @@ -0,0 +1,155 @@ +use super::{PinnedBoxedFutureResult, Priority, Service, ServiceInfo, ServiceInternals}; +use log::info; +use serenity::{ + all::{GatewayIntents, Ready}, + async_trait, + client::{self, Cache, Context}, + framework::{standard::Configuration, StandardFramework}, + gateway::{ShardManager, VoiceGatewayManager}, + http::Http, + prelude::TypeMap, + Client, Error, +}; +use std::{sync::Arc, time::Duration}; +use tokio::{ + sync::{Mutex, Notify, RwLock}, + task::JoinHandle, + time::timeout, +}; + +pub struct DiscordService { + info: ServiceInfo, + discord_token: String, + connection_timeout: Duration, + pub client: Arc>>, + client_handle: Option>>, + pub cache: Option>, + pub data: Option>>, + pub http: Option>, + pub shard_manager: Option>, + pub voice_manager: Option>, + pub ws_url: Option>>, +} + +impl DiscordService { + pub fn new(discord_token: &str, connection_timeout: Duration) -> Self { + Self { + info: ServiceInfo::new("lum_builtin_discord", "Discord", Priority::Essential), + discord_token: discord_token.to_string(), + connection_timeout, + client: Arc::new(Mutex::new(None)), + client_handle: None, + cache: None, + data: None, + http: None, + shard_manager: None, + voice_manager: None, + ws_url: None, + } + } +} + +impl ServiceInternals for DiscordService { + fn start(&mut self) -> PinnedBoxedFutureResult<'_, ()> { + Box::pin(async move { + let framework = StandardFramework::new(); + framework.configure(Configuration::new().prefix("!")); + + let client_ready_notify = Arc::new(Notify::new()); + + let mut client = Client::builder(self.discord_token.as_str(), GatewayIntents::all()) + .framework(framework) + .event_handler(EventHandler::new( + Arc::clone(&self.client), + Arc::clone(&client_ready_notify), + )) + .await?; + + self.cache = Some(Arc::clone(&client.cache)); + self.data = Some(Arc::clone(&client.data)); + self.http = Some(Arc::clone(&client.http)); + self.shard_manager = Some(Arc::clone(&client.shard_manager)); + if let Some(shard_manager) = &self.shard_manager { + self.shard_manager = Some(Arc::clone(shard_manager)); + } + if let Some(voice_manager) = &self.voice_manager { + self.voice_manager = Some(Arc::clone(voice_manager)); + } + self.ws_url = Some(Arc::clone(&client.ws_url)); + + info!("Connecting to Discord"); + let client_handle = tokio::spawn(async move { client.start().await }); + + if timeout(self.connection_timeout, client_ready_notify.notified()) + .await + .is_err() + { + client_handle.abort(); + + return Err(format!( + "Discord client failed to connect within {} seconds", + self.connection_timeout.as_secs() + ) + .into()); + } + + if client_handle.is_finished() { + client_handle.await??; + return Err("Discord client stopped unexpectedly and with no error".into()); + } + + self.client_handle = Some(client_handle); + Ok(()) + }) + } + + fn stop(&mut self) -> PinnedBoxedFutureResult<'_, ()> { + Box::pin(async move { + if let Some(handle) = self.client_handle.take() { + info!("Waiting for Discord client to stop..."); + handle.abort(); + + let result = match handle.await { + Ok(result) => result, + Err(_) => { + info!("Discord client stopped"); + return Ok(()); + } + }; + + result?; + } + + Ok(()) + }) + } +} + +impl Service for DiscordService { + fn info(&self) -> &ServiceInfo { + &self.info + } +} + +struct EventHandler { + client: Arc>>, + ready_notify: Arc, +} + +impl EventHandler { + pub fn new(client: Arc>>, ready_notify: Arc) -> Self { + Self { + client, + ready_notify, + } + } +} + +#[async_trait] +impl client::EventHandler for EventHandler { + async fn ready(&self, _ctx: Context, data_about_bot: Ready) { + info!("Connected to Discord as {}", data_about_bot.user.tag()); + *self.client.lock().await = Some(data_about_bot); + self.ready_notify.notify_one(); + } +}