diff --git a/samply/src/linux/profiler.rs b/samply/src/linux/profiler.rs index 12acc3eca..6f141ccd1 100644 --- a/samply/src/linux/profiler.rs +++ b/samply/src/linux/profiler.rs @@ -7,7 +7,7 @@ use std::io::BufWriter; use std::path::Path; use std::process::Command; use std::process::ExitStatus; -use std::sync::atomic::AtomicBool; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::thread; use std::time::Duration; @@ -54,7 +54,14 @@ pub fn start_recording( let observer_thread = thread::spawn(move || { let product = command_name_copy; // start profiling pid - run_profiler(&output_file_copy, &product, time_limit, interval, pid); + run_profiler( + &output_file_copy, + &product, + time_limit, + interval, + pid, + Arc::new(AtomicBool::new(false)), + ); }); let exit_status = root_child.wait().expect("couldn't wait for child"); @@ -73,12 +80,49 @@ pub fn start_recording( Ok(exit_status) } +pub fn start_profiling_pid( + output_file: &Path, + pid: u32, + time_limit: Option, + interval: Duration, + server_props: Option, +) { + // When the first Ctrl+C is received, stop recording. + // The server launches after the recording finishes. On the second Ctrl+C, terminate the server. + let stop = Arc::new(AtomicBool::new(false)); + #[cfg(unix)] + signal_hook::flag::register_conditional_default(signal_hook::consts::SIGINT, stop.clone()) + .expect("cannot register signal handler"); + #[cfg(unix)] + signal_hook::flag::register(signal_hook::consts::SIGINT, stop.clone()) + .expect("cannot register signal handler"); + + let output_file_copy = output_file.to_owned(); + let product = format!("PID {pid}"); + let observer_thread = thread::spawn({ + let stop = stop.clone(); + move || run_profiler(&output_file_copy, &product, time_limit, interval, pid, stop) + }); + + observer_thread + .join() + .expect("couldn't join observer thread"); + // If the recording was stopped due to application terminating, set the flag so that Ctrl+C + // terminates the server. + stop.store(true, Ordering::SeqCst); + + if let Some(server_props) = server_props { + start_server_main(output_file, server_props); + } +} + fn run_profiler( output_filename: &Path, product_name: &str, _time_limit: Option, interval: Duration, pid: u32, + stop: Arc, ) { let interval_nanos = if interval.as_nanos() > 0 { interval.as_nanos() as u64 @@ -171,7 +215,7 @@ fn run_profiler( let mut pending_lost_events = 0; let mut total_lost_events = 0; loop { - if perf.is_empty() { + if stop.load(Ordering::SeqCst) || perf.is_empty() { break; } diff --git a/samply/src/main.rs b/samply/src/main.rs index a9f1e9fd5..82d032306 100644 --- a/samply/src/main.rs +++ b/samply/src/main.rs @@ -91,8 +91,18 @@ struct RecordArgs { server_args: ServerArgs, /// Profile the execution of this command. - #[arg(required = true, allow_hyphen_values = true, trailing_var_arg = true)] + #[arg( + required_unless_present = "pid", + conflicts_with = "pid", + allow_hyphen_values = true, + trailing_var_arg = true + )] command: Vec, + + /// Process ID of existing process to attach to. + #[cfg(target_os = "linux")] + #[arg(short, long)] + pid: Option, } #[derive(Debug, Args)] @@ -102,7 +112,7 @@ struct ServerArgs { no_open: bool, /// The port to use for the local web server - #[arg(short, long, default_value = "3000+")] + #[arg(short = 'P', long, default_value = "3000+")] port: String, /// Print debugging output. @@ -148,21 +158,37 @@ fn main() { std::process::exit(1); } let interval = Duration::from_secs_f64(1.0 / record_args.rate); - let exit_status = match profiler::start_recording( - &record_args.output, - record_args.command[0].clone(), - &record_args.command[1..], - time_limit, - interval, - server_props, - ) { - Ok(exit_status) => exit_status, - Err(err) => { - eprintln!("Encountered a mach error during profiling: {:?}", err); - std::process::exit(1); - } - }; - std::process::exit(exit_status.code().unwrap_or(0)); + + #[cfg(target_os = "linux")] + let pid = record_args.pid; + #[cfg(not(target_os = "linux"))] + let pid = None; + if let Some(pid) = pid { + #[cfg(target_os = "linux")] + profiler::start_profiling_pid( + &record_args.output, + pid, + time_limit, + interval, + server_props, + ); + } else { + let exit_status = match profiler::start_recording( + &record_args.output, + record_args.command[0].clone(), + &record_args.command[1..], + time_limit, + interval, + server_props, + ) { + Ok(exit_status) => exit_status, + Err(err) => { + eprintln!("Encountered a mach error during profiling: {:?}", err); + std::process::exit(1); + } + }; + std::process::exit(exit_status.code().unwrap_or(0)); + } } } }