diff --git a/.gitignore b/.gitignore index d6a8c164..048c2749 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ *~ *.html *.log.* +.vscode \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 8a467cd9..0b419e79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2266,6 +2266,7 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-stream", + "tokio-util", "tower 0.5.2", "tower-service", "traceparent", @@ -2314,6 +2315,7 @@ dependencies = [ "dhat", "futures", "http", + "libc", "num_cpus", "opentelemetry", "orion-configuration", @@ -2325,6 +2327,7 @@ dependencies = [ "orion-tracing", "orion-xds", "parking_lot", + "pingora-timeout", "prometheus", "regex", "serde", @@ -2332,6 +2335,7 @@ dependencies = [ "thiserror 1.0.69", "tikv-jemallocator", "tokio", + "tokio-util", "tower 0.5.2", "tracing", "tracing-appender", diff --git a/orion-lib/Cargo.toml b/orion-lib/Cargo.toml index 7dcbea61..9cc5dac8 100644 --- a/orion-lib/Cargo.toml +++ b/orion-lib/Cargo.toml @@ -71,6 +71,7 @@ typed-builder = "0.18.2" url.workspace = true uuid = { version = "1.17.0", features = ["v4"] } x509-parser = { version = "0.17", features = ["default"] } +tokio-util = "0.7.16" [dev-dependencies] diff --git a/orion-lib/src/lib.rs b/orion-lib/src/lib.rs index a51ba294..48fa4458 100644 --- a/orion-lib/src/lib.rs +++ b/orion-lib/src/lib.rs @@ -131,13 +131,16 @@ pub fn new_configuration_channel(capacity: usize) -> (ConfigurationSenders, Conf /// Start the listeners manager directly without spawning a background task. /// Caller must be inside a Tokio runtime and await this async function. -pub async fn start_listener_manager(configuration_receivers: ConfigurationReceivers) -> Result<()> { +pub async fn start_listener_manager( + configuration_receivers: ConfigurationReceivers, + ct: tokio_util::sync::CancellationToken, +) -> Result<()> { let ConfigurationReceivers { listener_configuration_receiver, route_configuration_receiver } = configuration_receivers; tracing::debug!("listeners manager starting"); let mgr = ListenersManager::new(listener_configuration_receiver, route_configuration_receiver); - mgr.start().await.map_err(|err| { + mgr.start(ct).await.map_err(|err| { tracing::warn!(error = %err, "listeners manager exited with error"); err })?; diff --git a/orion-lib/src/listeners/listeners_manager.rs b/orion-lib/src/listeners/listeners_manager.rs index c08775c0..6f24a45e 100644 --- a/orion-lib/src/listeners/listeners_manager.rs +++ b/orion-lib/src/listeners/listeners_manager.rs @@ -55,26 +55,30 @@ impl ListenerInfo { } pub struct ListenersManager { - configuration_channel: mpsc::Receiver, + listener_configuration_channel: mpsc::Receiver, route_configuration_channel: mpsc::Receiver, listener_handles: BTreeMap<&'static str, ListenerInfo>, } impl ListenersManager { pub fn new( - configuration_channel: mpsc::Receiver, + listener_configuration_channel: mpsc::Receiver, route_configuration_channel: mpsc::Receiver, ) -> Self { - ListenersManager { configuration_channel, route_configuration_channel, listener_handles: BTreeMap::new() } + ListenersManager { + listener_configuration_channel, + route_configuration_channel, + listener_handles: BTreeMap::new(), + } } - pub async fn start(mut self) -> Result<()> { + pub async fn start(mut self, ct: tokio_util::sync::CancellationToken) -> Result<()> { let (tx_secret_updates, _) = broadcast::channel(16); let (tx_route_updates, _) = broadcast::channel(16); - + // TODO: create child token for each listener? loop { tokio::select! { - Some(listener_configuration_change) = self.configuration_channel.recv() => { + Some(listener_configuration_change) = self.listener_configuration_channel.recv() => { match listener_configuration_change { ListenerConfigurationChange::Added(boxed) => { let (factory, listener_conf) = *boxed; @@ -110,9 +114,9 @@ impl ListenersManager { warn!("Internal problem when updating a route: {e}"); } }, - else => { - warn!("All listener manager channels are closed...exiting"); - return Err("All listener manager channels are closed...exiting".into()); + _ = ct.cancelled() => { + warn!("Listener manager exiting"); + return Ok(()); } } } diff --git a/orion-proxy/Cargo.toml b/orion-proxy/Cargo.toml index 4d40b589..8ed0df0f 100644 --- a/orion-proxy/Cargo.toml +++ b/orion-proxy/Cargo.toml @@ -34,6 +34,8 @@ parking_lot = "0.12.3" tokio.workspace = true tower.workspace = true tracing.workspace = true +pingora-timeout = "0.3.0" + axum = "0.8.1" compact_str.workspace = true @@ -52,6 +54,7 @@ tracing-subscriber = { workspace = true, features = [ "registry", "std", ] } +tokio-util = "0.7.16" [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = { version = "0.6", optional = true } @@ -65,5 +68,8 @@ axum-test = "17.2.0" orion-data-plane-api.workspace = true tracing-test.workspace = true +[target.'cfg(unix)'.dev-dependencies] +libc = "0.2" + [lints] workspace = true diff --git a/orion-proxy/src/lib.rs b/orion-proxy/src/lib.rs index fbef52a8..5dfadcf1 100644 --- a/orion-proxy/src/lib.rs +++ b/orion-proxy/src/lib.rs @@ -23,6 +23,7 @@ mod admin; mod core_affinity; mod proxy; mod runtime; +mod signal; mod xds_configurator; pub fn run() -> Result<()> { diff --git a/orion-proxy/src/main.rs b/orion-proxy/src/main.rs index 4b975036..31468d3b 100644 --- a/orion-proxy/src/main.rs +++ b/orion-proxy/src/main.rs @@ -26,7 +26,8 @@ static GLOBAL: Jemalloc = Jemalloc; #[global_allocator] static ALLOC: dhat::Alloc = dhat::Alloc; -fn main() -> orion_error::Result<()> { +#[tokio::main] +async fn main() -> orion_error::Result<()> { #[cfg(all(feature = "dhat-heap", not(feature = "jemalloc")))] let _profiler = dhat::Profiler::new_heap(); orion_proxy::run() diff --git a/orion-proxy/src/proxy.rs b/orion-proxy/src/proxy.rs index 3dbdad57..39674b06 100644 --- a/orion-proxy/src/proxy.rs +++ b/orion-proxy/src/proxy.rs @@ -19,6 +19,7 @@ use crate::{ admin::start_admin_server, core_affinity, runtime::{self, RuntimeId}, + signal::wait_signal, xds_configurator::XdsConfigurationHandler, }; use compact_str::ToCompactString; @@ -50,8 +51,20 @@ use tracing::{debug, info, warn}; pub fn run_orion(bootstrap: Bootstrap, access_log_config: Option) { debug!("Starting on thread {:?}", std::thread::current().name()); + let ct = tokio_util::sync::CancellationToken::new(); + let ct_clone = ct.clone(); + tokio::spawn(async move { + // Set up signal handling and shutdown notification channel + wait_signal().await; + // Trigger cancellation + ct_clone.cancel(); + }); + // launch the runtimes... - _ = launch_runtimes(bootstrap, access_log_config).with_context_msg("failed to launch runtimes"); + let res = launch_runtimes(bootstrap, access_log_config, ct).with_context_msg("failed to launch runtimes"); + if let Err(err) = res { + warn!("Error running orion: {err}"); + } } fn calculate_num_threads_per_runtime(num_cpus: usize, num_runtimes: usize) -> Result { @@ -80,7 +93,7 @@ fn calculate_num_threads_per_runtime(num_cpus: usize, num_runtimes: usize) -> Re } #[derive(Debug, Clone)] -struct ServiceInfo { +struct ProxyConfiguration { bootstrap: Bootstrap, node: Node, configuration_senders: Vec, @@ -93,9 +106,11 @@ struct ServiceInfo { metrics: Vec, } -type SenderGuards = Vec; - -fn launch_runtimes(bootstrap: Bootstrap, access_log_config: Option) -> Result { +fn launch_runtimes( + bootstrap: Bootstrap, + access_log_config: Option, + ct: tokio_util::sync::CancellationToken, +) -> Result<()> { let rt_config = runtime_config(); let num_runtimes = rt_config.num_runtimes(); let num_cpus = rt_config.num_cpus(); @@ -106,11 +121,6 @@ fn launch_runtimes(bootstrap: Bootstrap, access_log_config: Option, Vec) = (0..num_runtimes).map(|_| new_configuration_channel(100)).collect::>().into_iter().unzip(); - // keep a copy of the senders to avoid them being dropped if no services are configured... - // - - let sender_guards = config_senders.clone(); - // launch services runtime... // @@ -145,24 +155,25 @@ fn launch_runtimes(bootstrap: Bootstrap, access_log_config: Option>>()? }; - let handles = proxy_handles.into_iter().chain(std::iter::once(services_handle)).collect::>(); + handlers.push(services_handle); - for h in handles { + for h in handlers { if let Err(err) = h.join() { warn!("Closing handler with error {err:?}"); } } - Ok(sender_guards) + Ok(()) } -type RuntimeHandle = JoinHandle>; - fn spawn_proxy_runtime_from_thread( thread_name: &'static str, num_threads: usize, metrics: Vec, affinity_info: Option<(RuntimeId, Affinity)>, configuration_receivers: ConfigurationReceivers, -) -> Result { + ct: tokio_util::sync::CancellationToken, +) -> Result> { let thread_name = build_thread_name(thread_name, affinity_info.as_ref()); - let handle: JoinHandle> = thread::Builder::new().name(thread_name.clone()).spawn(move || { + let handle = thread::Builder::new().name(thread_name.clone()).spawn(move || { let rt = runtime::build_tokio_runtime(&thread_name, num_threads, affinity_info, Some(metrics)); rt.block_on(async { tokio::select! { - _ = start_proxy(configuration_receivers) => { + _ = start_proxy(configuration_receivers, ct.clone()) => { info!("Proxy Runtime terminated!"); - Ok(()) } - _ = tokio::signal::ctrl_c() => { - info!("CTRL+C (Proxy runtime)!"); - Ok(()) + _ = ct.cancelled() => { + info!("Shutdown channel closed, shutting down Proxy runtime!"); } } - }) + }); })?; Ok(handle) } fn spawn_services_runtime_from_thread( thread_name: &'static str, - num_threads: usize, + threads_num: usize, affinity_info: Option<(RuntimeId, Affinity)>, - service_info: ServiceInfo, -) -> Result { + config: ProxyConfiguration, + ct: tokio_util::sync::CancellationToken, +) -> Result> { let thread_name = build_thread_name(thread_name, affinity_info.as_ref()); - let rt_handle = thread::Builder::new().name(thread_name.clone()).spawn(move || { - let rt = runtime::build_tokio_runtime(&thread_name, num_threads, affinity_info, None); + let rt = runtime::build_tokio_runtime(&thread_name, threads_num, affinity_info, None); rt.block_on(async { tokio::select! { - result = spawn_services(service_info) => { + result = run_services(config) => { if let Err(err) = result { warn!("Error in services runtime: {err:?}"); } - info!("Service Runtime terminated!"); - Ok(()) + info!("Services Runtime terminated!"); } - _ = tokio::signal::ctrl_c() => { - info!("CTRL+C (service runtime)!"); - Ok(()) + _ = ct.cancelled() => { + info!("Shutdown channel closed, shutting down Services runtime!"); } } - }) + }); })?; - Ok(rt_handle) } @@ -274,8 +280,8 @@ fn build_thread_name(thread_name: &'static str, affinity_info: Option<&(RuntimeI } } -async fn spawn_services(info: ServiceInfo) -> Result<()> { - let ServiceInfo { +async fn run_services(config: ProxyConfiguration) -> Result<()> { + let ProxyConfiguration { bootstrap, node, configuration_senders, @@ -287,63 +293,29 @@ async fn spawn_services(info: ServiceInfo) -> Result<()> { metrics, #[allow(unused_variables)] tracing, - } = info; + } = config; let mut set: JoinSet> = JoinSet::new(); - // spawn XSD configuration service... - let configuration_senders_clone = configuration_senders.clone(); - let bootstrap_clone = bootstrap.clone(); - let secret_manager_clone = secret_manager.clone(); - set.spawn(async move { - let initial_clusters = configure_initial_resources( - bootstrap_clone, - listener_factories, - clusters, - configuration_senders_clone.clone(), - ) - .await?; - if !ads_cluster_names.is_empty() { - let mut xds_handler = XdsConfigurationHandler::new(secret_manager_clone, configuration_senders_clone); - _ = xds_handler.run_loop(node, initial_clusters, ads_cluster_names).await; - } - Ok(()) - }); + // spawn XDS configuration service... + spawn_xds_client( + &mut set, + bootstrap.clone(), + node, + configuration_senders.clone(), + secret_manager.clone(), + listener_factories, + clusters, + ads_cluster_names, + ); // spawn access loggers service... if let Some(conf) = access_log_config { - let listeners = bootstrap.static_resources.listeners.clone(); - set.spawn(async move { - let handles = start_access_loggers( - conf.num_instances.get(), - conf.queue_length.get(), - conf.log_rotation.0.clone(), - conf.max_log_files.get(), - ); - - info!("Access loggers started with {} instances", conf.num_instances); - - let listener_configurations = - listeners.iter().map(|l| (l.name.clone(), l.get_access_log_configurations())).collect::>(); - - for (listener_name, access_log_configurations) in listener_configurations { - _ = update_configuration( - Target::Listener(listener_name.to_compact_string()), - access_log_configurations, - ) - .await; - } - - handles.join_all().await; - Ok(()) - }); + spawn_access_loggers(&mut set, bootstrap.clone(), conf); } // spawn admin interface task if bootstrap.admin.is_some() { - set.spawn(async move { - _ = start_admin_server(bootstrap, configuration_senders, secret_manager).await; - Ok(()) - }); + spawn_admin_service(&mut set, bootstrap, configuration_senders, secret_manager); } // spawn metrics exporter... @@ -366,6 +338,64 @@ async fn spawn_services(info: ServiceInfo) -> Result<()> { Ok(()) } +fn spawn_xds_client( + set: &mut JoinSet>, + bootstrap: Bootstrap, + node: Node, + configuration_senders: Vec, + secret_manager: Arc>, + listener_factories: Vec, + clusters: Vec, + ads_cluster_names: Vec, +) { + set.spawn(async move { + let initial_clusters = + configure_initial_resources(bootstrap, listener_factories, clusters, configuration_senders.clone()).await?; + if !ads_cluster_names.is_empty() { + let mut xds_handler = XdsConfigurationHandler::new(secret_manager, configuration_senders); + _ = xds_handler.run_loop(node, initial_clusters, ads_cluster_names).await; + } + Ok(()) + }); +} + +fn spawn_access_loggers(set: &mut JoinSet>, bootstrap: Bootstrap, conf: AccessLogConfig) { + let listeners = bootstrap.static_resources.listeners; + set.spawn(async move { + let handles = start_access_loggers( + conf.num_instances.get(), + conf.queue_length.get(), + conf.log_rotation.0.clone(), + conf.max_log_files.get(), + ); + + info!("Access loggers started with {} instances", conf.num_instances); + + let listener_configurations = + listeners.iter().map(|l| (l.name.clone(), l.get_access_log_configurations())).collect::>(); + + for (listener_name, access_log_configurations) in listener_configurations { + _ = update_configuration(Target::Listener(listener_name.to_compact_string()), access_log_configurations) + .await; + } + + handles.join_all().await; + Ok(()) + }); +} + +fn spawn_admin_service( + set: &mut JoinSet>, + bootstrap: Bootstrap, + configuration_senders: Vec, + secret_manager: Arc>, +) { + set.spawn(async move { + _ = start_admin_server(bootstrap, configuration_senders, secret_manager).await; + Ok(()) + }); +} + async fn configure_initial_resources( bootstrap: Bootstrap, listeners: Vec, @@ -392,7 +422,10 @@ async fn configure_initial_resources( clusters.into_iter().map(orion_lib::clusters::add_cluster).collect::>() } -async fn start_proxy(configuration_receivers: ConfigurationReceivers) -> Result<()> { - orion_lib::start_listener_manager(configuration_receivers).await?; +async fn start_proxy( + configuration_receivers: ConfigurationReceivers, + ct: tokio_util::sync::CancellationToken, +) -> Result<()> { + orion_lib::start_listener_manager(configuration_receivers, ct).await?; Ok(()) } diff --git a/orion-proxy/src/signal.rs b/orion-proxy/src/signal.rs new file mode 100644 index 00000000..91a4fcc1 --- /dev/null +++ b/orion-proxy/src/signal.rs @@ -0,0 +1,127 @@ +// Copyright 2025 The kmesh Authors +// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use tracing::{info, warn}; + +/// Signal types that can trigger shutdown +#[derive(Debug, Clone, Copy)] +pub enum ShutdownSignal { + /// CTRL+C (SIGINT) signal + Interrupt, + /// SIGTERM signal (Unix only) + #[cfg(unix)] + Terminate, +} + +impl std::fmt::Display for ShutdownSignal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ShutdownSignal::Interrupt => write!(f, "SIGINT (CTRL+C)"), + #[cfg(unix)] + ShutdownSignal::Terminate => write!(f, "SIGTERM"), + } + } +} + +/// `wait_signal` listens for shutdown signals +/// +/// On Unix platforms, this listens for both SIGINT (CTRL+C) and SIGTERM. +/// On Windows, this only listens for CTRL+C. +pub async fn wait_signal() { + if let Err(e) = listen_for_signals().await { + warn!("Signal handler error: {}", e); + } +} + +/// Unix-specific signal handling (SIGINT and SIGTERM) +#[cfg(unix)] +async fn listen_for_signals() -> Result<(), Box> { + use tokio::signal::unix::{signal, SignalKind}; + + let mut sigint = signal(SignalKind::interrupt())?; + let mut sigterm = signal(SignalKind::terminate())?; + + tokio::select! { + _ = sigint.recv() => { + info!("Received {} signal, initiating shutdown...", ShutdownSignal::Interrupt); + } + _ = sigterm.recv() => { + info!("Received {} signal, initiating shutdown...", ShutdownSignal::Terminate); + } + } + + Ok(()) +} + +/// Windows-specific signal handling (CTRL+C only) +#[cfg(not(unix))] +async fn listen_for_signals() -> Result<(), Box> { + if let Err(e) = tokio::signal::ctrl_c().await { + return Err(format!("Failed to listen for CTRL+C: {}", e).into()); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use pingora_timeout::fast_timeout::fast_timeout; + use std::time::Duration; + + // These tests send real POSIX signals; mark ignored so they are only run explicitly: + // cargo test --package orion-proxy --lib -- --ignored test_wait_signal_sigint + #[cfg(unix)] + #[ignore] + #[tokio::test] + async fn test_wait_signal_sigint() { + let handle = tokio::spawn(async { + wait_signal().await; + }); + // Give the signal handler a brief moment to install + tokio::time::sleep(Duration::from_millis(20)).await; + unsafe { + libc::kill(libc::getpid(), libc::SIGINT); + } + fast_timeout(Duration::from_secs(1), handle) + .await + .expect("timeout waiting for wait_signal to return") + .expect("wait_signal task panicked"); + } + + #[cfg(unix)] + #[ignore] + #[tokio::test] + async fn test_wait_signal_sigterm() { + let handle = tokio::spawn(async { + wait_signal().await; + }); + tokio::time::sleep(Duration::from_millis(20)).await; + unsafe { + libc::kill(libc::getpid(), libc::SIGTERM); + } + fast_timeout(Duration::from_secs(1), handle) + .await + .expect("timeout waiting for wait_signal to return") + .expect("wait_signal task panicked"); + } + + #[test] + fn test_display_variants() { + assert_eq!(format!("{}", ShutdownSignal::Interrupt), "SIGINT (CTRL+C)"); + #[cfg(unix)] + assert_eq!(format!("{}", ShutdownSignal::Terminate), "SIGTERM"); + } +} diff --git a/orion-proxy/src/xds_configurator.rs b/orion-proxy/src/xds_configurator.rs index a9726eaf..6272bb8f 100644 --- a/orion-proxy/src/xds_configurator.rs +++ b/orion-proxy/src/xds_configurator.rs @@ -15,7 +15,6 @@ // // -use abort_on_drop::ChildTask; #[cfg(feature = "tracing")] use compact_str::ToCompactString; use futures::future::join_all; @@ -138,11 +137,10 @@ impl XdsConfigurationHandler { tokio::time::sleep(RETRY_INTERVAL).await; }; - let _xds_worker: ChildTask<_> = tokio::spawn(async move { + tokio::spawn(async move { let subscribe = worker.run().await; info!("Worker exited {subscribe:?}"); - }) - .into(); + }); loop { select! { diff --git a/orion-xds/src/xds/client.rs b/orion-xds/src/xds/client.rs index d0697a56..f0baf95b 100644 --- a/orion-xds/src/xds/client.rs +++ b/orion-xds/src/xds/client.rs @@ -218,9 +218,8 @@ impl DeltaClientBackgroundWorker { async fn continuously_discover_resources(&mut self, state: &mut DiscoveryClientState) -> Result<(), XdsError> { let (discovery_requests_tx, mut discovery_requests_rx) = mpsc::channel::(100); - let initial_requests = self.build_initial_discovery_requests(state); - let outbound_request_stream = async_stream::stream! { + let request_stream = async_stream::stream! { for request in initial_requests { info!("sending initial discovery request {request:?}"); yield request; @@ -231,12 +230,8 @@ impl DeltaClientBackgroundWorker { } warn!("outbound discovery request stream has ended!"); }; - let mut inbound_response_stream = self - .client_binding - .delta_request(outbound_request_stream) - .await - .map_err(XdsError::GrpcStatus)? - .into_inner(); + let mut response_stream = + self.client_binding.delta_request(request_stream).await.map_err(XdsError::GrpcStatus)?.into_inner(); info!("xDS stream established"); state.reset_backoff(); @@ -245,14 +240,14 @@ impl DeltaClientBackgroundWorker { Some(event) = self.subscriptions_rx.recv() => { self.process_subscription_event(event, state, &discovery_requests_tx).await; } - discovered = inbound_response_stream.message() => { + discovered = response_stream.message() => { let payload = discovered?; let discovery_response = payload.ok_or(XdsError::UnknownResourceType("empty payload received".to_owned()))?; self.process_discovery_response(discovery_response, &discovery_requests_tx, state).await?; - }, + } else => { - warn!("xDS channels are closed...exiting"); - return Ok(()) + warn!("xDS stream has ended"); + return Ok(()); } } } @@ -313,7 +308,7 @@ impl DeltaClientBackgroundWorker { info!(type_url = type_url.to_string(), size = response.resources.len(), "received config resources from xDS"); let for_removal = Self::process_resource_ids_for_removal(state, &response, type_url); - match Self::decode_pending_updates(&response, type_url) { + match Self::decode_response(&response, type_url) { Ok(mut decoded_updates) => { let (internal_ack_tx, internal_ack_rx) = oneshot::channel::>(); @@ -441,7 +436,7 @@ impl DeltaClientBackgroundWorker { .collect() } - fn decode_pending_updates( + fn decode_response( response: &DeltaDiscoveryResponse, type_url: TypeUrl, ) -> Result, Vec> {