diff --git a/src/bin/lychee/main.rs b/src/bin/lychee/main.rs
index 9d54eefc23..5aa7c8f478 100644
--- a/src/bin/lychee/main.rs
+++ b/src/bin/lychee/main.rs
@@ -96,7 +96,7 @@ async fn run(cfg: &Config, inputs: Vec<Input>) -> Result<i32> {
     let accepted = cfg.accept.clone().and_then(|a| parse_statuscodes(&a).ok());
     let timeout = parse_timeout(cfg.timeout);
     let max_concurrency = cfg.max_concurrency;
-    let method: reqwest::Method = reqwest::Method::from_str(&cfg.method.to_uppercase())?;
+    let method: http::Method = http::Method::from_str(&cfg.method.to_uppercase())?;
 
     let include = RegexSet::new(&cfg.include)?;
     let exclude = RegexSet::new(&cfg.exclude)?;
@@ -118,21 +118,22 @@ async fn run(cfg: &Config, inputs: Vec<Input>) -> Result<i32> {
         .accepted(accepted)
         .build()?;
 
-    let links = collector::collect_links(
-        &inputs,
+    let mut links = collector::collect_links(
+        inputs,
         cfg.base_url.clone(),
         cfg.skip_missing,
-        max_concurrency,
+        cfg.max_concurrency,
     )
     .await?;
 
     let pb = match cfg.no_progress {
         true => None,
         false => {
-            let bar = ProgressBar::new(links.len() as u64)
-                .with_style(ProgressStyle::default_bar().template(
+            let bar = ProgressBar::new_spinner().with_style(ProgressStyle::default_bar().template(
                 "{spinner:.red.bright} {pos}/{len:.dim} [{elapsed_precise}] {bar:25} {wide_msg}",
             ));
+            bar.set_length(0);
+            bar.set_message("Extracting links");
             bar.enable_steady_tick(100);
             Some(bar)
         }
@@ -145,8 +146,9 @@ async fn run(cfg: &Config, inputs: Vec<Input>) -> Result<i32> {
 
     let bar = pb.clone();
     tokio::spawn(async move {
-        for link in links {
+        while let Some(link) = links.recv().await {
             if let Some(pb) = &bar {
+                pb.inc_length(1);
                 pb.set_message(&link.to_string());
             };
             send_req.send(link).await.unwrap();
@@ -219,7 +221,7 @@ fn parse_headers<T: AsRef<str>>(headers: &[T]) -> Result<HeaderMap> {
 fn parse_statuscodes<T: AsRef<str>>(accept: T) -> Result<HashSet<http::StatusCode>> {
     let mut statuscodes = HashSet::new();
     for code in accept.as_ref().split(',').into_iter() {
-        let code: reqwest::StatusCode = reqwest::StatusCode::from_bytes(code.as_bytes())?;
+        let code: http::StatusCode = http::StatusCode::from_bytes(code.as_bytes())?;
         statuscodes.insert(code);
     }
     Ok(statuscodes)
diff --git a/src/client_pool.rs b/src/client_pool.rs
index 56e58b316c..94b1f1292e 100644
--- a/src/client_pool.rs
+++ b/src/client_pool.rs
@@ -7,7 +7,7 @@ use crate::{client, types};
 pub struct ClientPool {
     tx: mpsc::Sender<types::Response>,
     rx: mpsc::Receiver<types::Request>,
-    pool: deadpool::unmanaged::Pool<client::Client>,
+    pool: deadpool::unmanaged::Pool<client::Client>,
 }
 
 impl ClientPool {
diff --git a/src/collector.rs b/src/collector.rs
index 0889179c7b..eeddd4b699 100644
--- a/src/collector.rs
+++ b/src/collector.rs
@@ -197,11 +197,11 @@ impl Input {
 /// Fetch all unique links from a slice of inputs
 /// All relative URLs get prefixed with `base_url` if given.
 pub async fn collect_links(
-    inputs: &[Input],
+    inputs: Vec<Input>,
     base_url: Option<String>,
     skip_missing_inputs: bool,
     max_concurrency: usize,
-) -> Result<HashSet<Request>> {
+) -> Result<tokio::sync::mpsc::Receiver<Request>> {
     let base_url = match base_url {
         Some(url) => Some(Url::parse(&url)?),
         _ => None,
     };
@@ -210,7 +210,7 @@ pub async fn collect_links(
     let (contents_tx, mut contents_rx) = tokio::sync::mpsc::channel(max_concurrency);
 
     // extract input contents
-    for input in inputs.iter().cloned() {
+    for input in inputs {
         let sender = contents_tx.clone();
 
         tokio::spawn(async move {
@@ -234,18 +234,32 @@ pub async fn collect_links(
         }
     }
 
-    // Note: we could dispatch links to be checked as soon as we get them,
-    // instead of building a HashSet with all links.
-    // This optimization would speed up cases where there's
-    // a lot of inputs and/or the inputs are large (e.g. big files).
-    let mut collected_links: HashSet<Request> = HashSet::new();
-
-    for handle in extract_links_handles {
-        let links = handle.await?;
-        collected_links.extend(links);
-    }
+    let (links_tx, links_rx) = tokio::sync::mpsc::channel(max_concurrency);
+    tokio::spawn(async move {
+        let mut collected_links = HashSet::new();
+
+        for handle in extract_links_handles {
+            // Unwrap should be fine because joining fails:
+            // * if the Task was dropped (which we don't do)
+            // * if the Task panicked. Propagating panics is correct here.
+            let requests = handle
+                .await
+                .expect("Awaiting termination of link handle failed");
+            for request in requests {
+                if !collected_links.contains(&request) {
+                    collected_links.insert(request.clone());
+                    // Unwrap should be fine because sending fails
+                    // if the receiver was closed - in which case we can't continue anyway
+                    links_tx
+                        .send(request)
+                        .await
+                        .expect("Extractor could not send link to channel");
+                }
+            }
+        }
+    });
 
-    Ok(collected_links)
+    Ok(links_rx)
 }
 
 #[cfg(test)]
@@ -292,11 +306,11 @@ mod test {
             },
         ];
 
-        let responses = collect_links(&inputs, None, false, 8).await?;
-        let links = responses
-            .into_iter()
-            .map(|r| r.uri)
-            .collect::<HashSet<Uri>>();
+        let mut responses = collect_links(inputs, None, false, 8).await?;
+        let mut links = HashSet::new();
+        while let Some(request) = responses.recv().await {
+            links.insert(request.uri);
+        }
 
         let mut expected_links: HashSet<Uri> = HashSet::new();
         expected_links.insert(website(TEST_STRING));
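
Note for reviewers: `collect_links` now returns a bounded `tokio::sync::mpsc::Receiver<Request>` instead of a fully materialized `HashSet<Request>`, so checking can start while extraction is still running; de-duplication still happens in the spawned forwarder task, which keeps a `HashSet` of the requests seen so far. Below is a minimal sketch of how a caller drains the new receiver, mirroring the updated test. It is not part of the patch: `print_links` and the use of `anyhow::Result` are illustrative assumptions, while `collect_links`, `Input`, and `Request` are the crate's own items.

```rust
use lychee::collector::{collect_links, Input};

// Illustrative consumer of the streaming collector (assumed names, see above).
async fn print_links(inputs: Vec<Input>) -> anyhow::Result<()> {
    // Capacity 8 mirrors the updated test; requests arrive as soon as each
    // extractor task finishes instead of after all inputs have been parsed.
    let mut links = collect_links(inputs, None, false, 8).await?;

    // `recv()` returns `None` once the forwarder task finishes and drops
    // `links_tx`, which closes the channel.
    while let Some(request) = links.recv().await {
        println!("{}", request);
    }
    Ok(())
}
```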