diff --git a/src/directed/dfs.rs b/src/directed/dfs.rs index 942b9bb..9cca6f5 100644 --- a/src/directed/dfs.rs +++ b/src/directed/dfs.rs @@ -5,6 +5,8 @@ use std::collections::HashSet; use std::hash::Hash; use std::iter::FusedIterator; +use crate::FxIndexSet; + /// Compute a path using the [depth-first search /// algorithm](https://en.wikipedia.org/wiki/Depth-first_search). /// The path starts from `start` up to a node for which `success` returns `true` is computed and @@ -45,18 +47,19 @@ use std::iter::FusedIterator; /// ``` pub fn dfs(start: N, mut successors: FN, mut success: FS) -> Option> where - N: Eq, + N: Eq + Hash, FN: FnMut(&N) -> IN, IN: IntoIterator, FS: FnMut(&N) -> bool, { - let mut path = vec![start]; - step(&mut path, &mut successors, &mut success).then_some(path) + let mut path = FxIndexSet::default(); + path.insert(start); + step(&mut path, &mut successors, &mut success).then_some(Vec::from_iter(path)) } -fn step(path: &mut Vec, successors: &mut FN, success: &mut FS) -> bool +fn step(path: &mut FxIndexSet, successors: &mut FN, success: &mut FS) -> bool where - N: Eq, + N: Eq + Hash, FN: FnMut(&N) -> IN, IN: IntoIterator, FS: FnMut(&N) -> bool, @@ -67,7 +70,7 @@ where let successors_it = successors(path.last().unwrap()); for n in successors_it { if !path.contains(&n) { - path.push(n); + path.insert(n); if step(path, successors, success) { return true; }