Skip to content

Commit c9aa46a

Browse files
committed
feat: depth first and best first search using ListM (#3221)
Implementations of depth first search, best first search, and beam search, for graphs described by a neighbours function `α → ListM m α`. There are also wrappers for using `α → List α`. This is only intended for use in meta code. Co-authored-by: Scott Morrison <scott.morrison@gmail.com>
1 parent 4cf370e commit c9aa46a

File tree

8 files changed

+239
-24
lines changed

8 files changed

+239
-24
lines changed

Mathlib.lean

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -861,7 +861,9 @@ import Mathlib.Data.List.Sublists
861861
import Mathlib.Data.List.TFAE
862862
import Mathlib.Data.List.ToFinsupp
863863
import Mathlib.Data.List.Zip
864-
import Mathlib.Data.ListM
864+
import Mathlib.Data.ListM.Basic
865+
import Mathlib.Data.ListM.BestFirst
866+
import Mathlib.Data.ListM.DepthFirst
865867
import Mathlib.Data.ListM.Heartbeats
866868
import Mathlib.Data.Matrix.Basic
867869
import Mathlib.Data.Matrix.Basis

Mathlib/Data/ListM.lean renamed to Mathlib/Data/ListM/Basic.lean

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ unsafe def fixl [Alternative m] (f : α → m (α × List β)) (s : α) : ListM
8585

8686
/-- Deconstruct a `ListM`, returning inside the monad an optional pair `α × ListM m α`
8787
representing the head and tail of the list. -/
88-
unsafe def uncons {α : Type u} : ListM m α → m (Option (α × ListM m α))
88+
unsafe def uncons : ListM m α → m (Option (α × ListM m α))
8989
| nil => pure none
9090
| cons l => do
9191
let (x, xs) ← l
@@ -94,27 +94,27 @@ unsafe def uncons {α : Type u} : ListM m α → m (Option (α × ListM m α))
9494
#align tactic.mllist.uncons ListM.uncons
9595

9696
/-- Compute, inside the monad, whether a `ListM` is empty. -/
97-
unsafe def isEmpty {α : Type u} (xs : ListM m α) : m (ULift Bool) :=
97+
unsafe def isEmpty (xs : ListM m α) : m (ULift Bool) :=
9898
(ULift.up ∘ Option.isSome) <$> uncons xs
9999
#align tactic.mllist.empty ListM.isEmpty
100100

101101
/-- Convert a `List` to a `ListM`. -/
102-
unsafe def ofList {α : Type u} : List α → ListM m α
102+
unsafe def ofList : List α → ListM m α
103103
| [] => nil
104104
| h :: t => cons (pure (h, ofList t))
105105
#align tactic.mllist.of_list ListM.ofList
106106

107107
/-- The empty `ListM`. -/
108-
unsafe def empty {α : Type u} : ListM m α := ofList []
108+
unsafe def empty : ListM m α := ofList []
109109

110110
/-- Convert a `List` of values inside the monad into a `ListM`. -/
111-
unsafe def ofListM {α : Type u} : List (m α) → ListM m α
111+
unsafe def ofListM : List (m α) → ListM m α
112112
| [] => nil
113113
| h :: t => cons ((fun x => (x, ofListM t)) <$> some <$> h)
114114
#align tactic.mllist.m_of_list ListM.ofListM
115115

116116
/-- Extract a list inside the monad from a `ListM`. -/
117-
unsafe def force {α} (L : ListM m α) : m (List α) := do
117+
unsafe def force (L : ListM m α) : m (List α) := do
118118
match ← uncons L with
119119
| none => pure []
120120
| some (x, xs) => (x :: ·) <$> force xs
@@ -142,7 +142,7 @@ unsafe def folds (f : β → α → β) (init : β) (L : ListM m α) : ListM m
142142
L.foldsM (fun b a => pure (f b a)) init
143143

144144
/-- Take the first `n` elements, as a list inside the monad. -/
145-
unsafe def takeAsList {α} : ListM m α → Nat → m (List α)
145+
unsafe def takeAsList : ListM m α → Nat → m (List α)
146146
| _, 0 => pure []
147147
| xs, n+1 => do
148148
match ← uncons xs with
@@ -167,33 +167,33 @@ unsafe def drop : ListM m α → Nat → ListM m α
167167
| none => return (none, empty)
168168

169169
/-- Apply a function which returns values in the monad to every element of a `ListM`. -/
170-
unsafe def mapM {α β : Type u} (f : α → m β) (L : ListM m α) : ListM m β :=
170+
unsafe def mapM (f : α → m β) (L : ListM m α) : ListM m β :=
171171
cons do match ← uncons L with
172172
| some (x, xs) => return (← f x, mapM f xs)
173173
| none => return (none, empty)
174174
#align tactic.mllist.mmap ListM.mapM
175175

176176
/-- Apply a function to every element of a `ListM`. -/
177-
unsafe def map {α β : Type u} (f : α → β) (L : ListM m α) : ListM m β :=
177+
unsafe def map (f : α → β) (L : ListM m α) : ListM m β :=
178178
L.mapM fun a => pure (f a)
179179
#align tactic.mllist.map ListM.map
180180

181181
/-- Filter a `ListM` using a monadic function. -/
182-
unsafe def filterM {α : Type u} (p : α → m (ULift Bool)) (L : ListM m α) : ListM m α :=
182+
unsafe def filterM (p : α → m (ULift Bool)) (L : ListM m α) : ListM m α :=
183183
cons do match ← uncons L with
184184
| some (x, xs) => return (if (← p x).down then some x else none, filterM p xs)
185185
| none => return (none, empty)
186186
#align tactic.mllist.mfilter ListM.filterM
187187

188188
/-- Filter a `ListM`. -/
189-
unsafe def filter {α : Type u} (p : α → Bool) (L : ListM m α) : ListM m α :=
189+
unsafe def filter (p : α → Bool) (L : ListM m α) : ListM m α :=
190190
L.filterM fun a => pure <| .up (p a)
191191
#align tactic.mllist.filter ListM.filter
192192

193193
/-- Filter and transform a `ListM` using a function that returns values inside the monad. -/
194194
-- Note that the type signature has changed since Lean 3, when we allowed `f` to fail.
195195
-- Use `try?` from `Mathlib.Control.Basic` to lift a possibly failing function to `Option`.
196-
unsafe def filterMapM {α β : Type u} (f : α → m (Option β)) (L : ListM m α) : ListM m β :=
196+
unsafe def filterMapM (f : α → m (Option β)) (L : ListM m α) : ListM m β :=
197197
cons do match ← uncons L with
198198
| none => return (none, empty)
199199
| some (x, xs) => match ← f x with
@@ -202,7 +202,7 @@ unsafe def filterMapM {α β : Type u} (f : α → m (Option β)) (L : ListM m
202202
#align tactic.mllist.mfilter_map ListM.filterMapM
203203

204204
/-- Filter and transform a `ListM` using an `Option` valued function. -/
205-
unsafe def filterMap {α β : Type u} (f : α → Option β) : ListM m α → ListM m β :=
205+
unsafe def filterMap (f : α → Option β) : ListM m α → ListM m β :=
206206
filterMapM fun a => do pure (f a)
207207
#align tactic.mllist.filter_map ListM.filterMap
208208

@@ -217,14 +217,14 @@ unsafe def takeWhile (f : α → Bool) : ListM m α → ListM m α :=
217217
takeWhileM fun a => pure (.up (f a))
218218

219219
/-- Concatenate two monadic lazy lists. -/
220-
unsafe def append {α : Type u} (L M : ListM m α) : ListM m α :=
220+
unsafe def append (L M : ListM m α) : ListM m α :=
221221
cons do match ← uncons L with
222222
| none => return (none, M)
223223
| some (x, xs) => return (some x, append xs M)
224224
#align tactic.mllist.append ListM.append
225225

226226
/-- Join a monadic lazy list of monadic lazy lists into a single monadic lazy list. -/
227-
unsafe def join {α : Type u} (L : ListM m (ListM m α)) : ListM m α :=
227+
unsafe def join (L : ListM m (ListM m α)) : ListM m α :=
228228
cons do match ← uncons L with
229229
| none => return (none, empty)
230230
| some (x, xs) => match ← uncons x with
@@ -238,14 +238,14 @@ unsafe def squash (t : m (ListM m α)) : ListM m α :=
238238
#align tactic.mllist.squash ListM.squash
239239

240240
/-- Enumerate the elements of a monadic lazy list, starting at a specified offset. -/
241-
unsafe def enum_from {α : Type u} (n : Nat) (L : ListM m α) : ListM m (Nat × α) :=
241+
unsafe def enum_from (n : Nat) (L : ListM m α) : ListM m (Nat × α) :=
242242
cons do match ← uncons L with
243243
| none => return (none, empty)
244244
| some (x, xs) => return (some (n, x), xs.enum_from (n+1))
245245
#align tactic.mllist.enum_from ListM.enum_from
246246

247247
/-- Enumerate the elements of a monadic lazy list. -/
248-
unsafe def enum {α : Type u} : ListM m α → ListM m (Nat × α) :=
248+
unsafe def enum : ListM m α → ListM m (Nat × α) :=
249249
enum_from 0
250250
#align tactic.mllist.enum ListM.enum
251251

@@ -255,7 +255,7 @@ unsafe def range {m : Type → Type} [Alternative m] : ListM m Nat :=
255255
#align tactic.mllist.range ListM.range
256256

257257
/-- Add one element to the end of a monadic lazy list. -/
258-
unsafe def concat {α : Type u} : ListM m α → α → ListM m α
258+
unsafe def concat : ListM m α → α → ListM m α
259259
| L, a => (ListM.ofList [L, ListM.ofList [a]]).join
260260
#align tactic.mllist.concat ListM.concat
261261

@@ -267,7 +267,7 @@ unsafe def zip (L : ListM m α) (M : ListM m β) : ListM m (α × β) :=
267267

268268
/-- Apply a function returning a monadic lazy list to each element of a monadic lazy list,
269269
joining the results. -/
270-
unsafe def bind {α β : Type u} (L : ListM m α) (f : α → ListM m β) : ListM m β :=
270+
unsafe def bind (L : ListM m α) (f : α → ListM m β) : ListM m β :=
271271
cons do match ← uncons L with
272272
| none => return (none, empty)
273273
| some (x, xs) => match ← uncons (f x) with
@@ -287,13 +287,13 @@ unsafe def liftM [Monad n] [MonadLift m n] (L : ListM m α) : ListM n α :=
287287
| some (a, L') => pure <| cons do pure (a, L'.liftM)
288288

289289
/-- Given a lazy list in a state monad, run it on some initial state, recording the states. -/
290-
unsafe def runState {σ α : Type u} (L : ListM (StateT.{u} σ m) α) (s : σ) : ListM m (α × σ) :=
290+
unsafe def runState (L : ListM (StateT.{u} σ m) α) (s : σ) : ListM m (α × σ) :=
291291
squash do match ← StateT.run (uncons L) s with
292292
| (none, _) => pure empty
293293
| (some (a, L'), s') => pure <| cons do pure (some (a, s'), L'.runState s')
294294

295295
/-- Given a lazy list in a state monad, run it on some initial state. -/
296-
unsafe def runState' {σ α : Type u} (L : ListM (StateT.{u} σ m) α) (s : σ) : ListM m α :=
296+
unsafe def runState' (L : ListM (StateT.{u} σ m) α) (s : σ) : ListM m α :=
297297
L.runState s |>.map (·.1)
298298

299299
/-- Return the head of a monadic lazy list if it exists, as an `Option` in the monad. -/

Mathlib/Data/ListM/BestFirst.lean

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/-
2+
Copyright (c) 2023 Scott Morrison. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Scott Morrison
5+
-/
6+
import Mathlib.Data.ListM.Basic
7+
8+
/-!
9+
# Best first search
10+
11+
We perform best first search of a tree or graph,
12+
where the neighbours of a vertex are provided by a lazy list `α → ListM m α`.
13+
14+
We maintain a priority queue of visited-but-not-exhausted nodes,
15+
and at each step take the next child of the highest priority node in the queue.
16+
17+
This is useful in meta code for searching for solutions in the presence of alternatives.
18+
It can be nice to represent the choices via a lazy list,
19+
so the later choices don't need to be evaluated while we do depth first search on earlier choices.
20+
21+
Options:
22+
* `maxDepth` allows bounding the search depth
23+
* `maxQueued` implements "beam" search,
24+
by discarding elements from the priority queue when it grows too large
25+
* `removeDuplicates` maintains an `RBSet` of previously visited nodes;
26+
otherwise if the graph is not a tree nodes may be visited multiple times.
27+
-/
28+
29+
30+
variable {α : Type u} [Monad m] [Alternative m] [Ord α]
31+
32+
open Std ListM
33+
34+
/--
35+
Auxiliary function for `bestFirstSearch`, that updates the internal state,
36+
consisting of a priority queue of triples `α × Nat × ListM m α`.
37+
We remove the next element from the list contained in the best triple
38+
(discarding the triple if there is no next element),
39+
enqueue it and return it.
40+
-/
41+
-- The return type has `× List α` rather than just `× Option α` so `bestFirstSearch` can use `fixl`.
42+
unsafe def bestFirstSearchAux
43+
(f : Nat → α → ListM m α) (maxQueued : Option Nat := none) :
44+
RBMap α (Nat × ListM m α) compare → m (RBMap α (Nat × ListM m α) compare × List α) :=
45+
fun s => do
46+
match s.min with
47+
| none => failure
48+
| some (a, (n, L)) =>
49+
match ← uncons L with
50+
| none => pure (s.erase a, [])
51+
| some (b, L') => do
52+
let s' := s.insert a (n, L') |>.insert b (n + 1, f (n+1) b)
53+
let s' := match maxQueued with
54+
| some q => if s'.size > q then
55+
match s'.max with | some x => s'.erase x.1 | none => unreachable!
56+
else
57+
s'
58+
| none => s'
59+
pure (s', [b])
60+
61+
/--
62+
A lazy list recording the best first search of a graph generated by a function
63+
`f : α → ListM m α`.
64+
65+
We maintain a priority queue of visited-but-not-exhausted nodes,
66+
and at each step take the next child of the highest priority node in the queue.
67+
68+
The option `maxDepth` limits the search depth.
69+
70+
The option `maxQueued` bounds the size of the priority queue,
71+
discarding the lowest priority nodes as needed.
72+
This implements a "beam" search, which may be incomplete but uses bounded memory.
73+
74+
The option `removeDuplicates` keeps an `RBSet` of previously visited nodes.
75+
Otherwise, if the graph is not a tree then nodes will be visited multiple times.
76+
-/
77+
unsafe def bestFirstSearch (f : α → ListM m α) (a : α)
78+
(maxDepth : Option Nat := none) (maxQueued : Option Nat := none) (removeDuplicates := true) :
79+
ListM m α :=
80+
let f := match maxDepth with
81+
| none => fun _ a => f a
82+
| some d => fun n a => if d < n then empty else f a
83+
if removeDuplicates then
84+
let f' : Nat → α → ListM (StateT.{u} (RBSet α compare) m) α := fun n a =>
85+
(f n a).liftM >>= fun b => do
86+
let s ← get
87+
if s.contains b then failure
88+
set <| s.insert b
89+
pure b
90+
cons (do pure (some a, fixl (bestFirstSearchAux f' maxQueued) (RBMap.single a (0, f' 0 a))))
91+
|>.runState' (RBSet.empty.insert a)
92+
else
93+
cons do pure (some a, fixl (bestFirstSearchAux f maxQueued) (RBMap.single a (0, f 0 a)))

Mathlib/Data/ListM/DepthFirst.lean

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/-
2+
Copyright (c) 2023 Scott Morrison. All rights reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Scott Morrison
5+
-/
6+
import Mathlib.Data.ListM.Basic
7+
import Mathlib.Control.Traversable.Basic
8+
9+
/-!
10+
# Depth first search
11+
12+
We perform depth first search of a tree or graph,
13+
where the neighbours of a vertex are provided either by list `α → List α`
14+
or a lazy list `α → ListM MetaM α`.
15+
16+
This is useful in meta code for searching for solutions in the presence of alternatives.
17+
It can be nice to represent the choices via a lazy list,
18+
so the later choices don't need to be evaluated while we do depth first search on earlier choices.
19+
-/
20+
21+
section
22+
variable [Monad m] [Alternative m]
23+
24+
/-- A generalisation of `depthFirst`, which allows the generation function to know the current
25+
depth, and to count the depth starting from a specified value. -/
26+
partial def depthFirst' (f : Nat → α → m α) (n : Nat) (a : α) : m α :=
27+
pure a <|> joinM ((f n a) <&> (depthFirst' f (n+1)))
28+
29+
/--
30+
Depth first search of a graph generated by a function
31+
`f : α → m α`.
32+
33+
Here `m` must be an `Alternative` `Monad`,
34+
and perhaps the only sensible values are `List` and `ListM MetaM`.
35+
36+
The option `maxDepth` limits the search depth.
37+
38+
Note that if the graph is not a tree then elements will be visited multiple times.
39+
(See `depthFirstRemovingDuplicates`)
40+
-/
41+
def depthFirst (f : α → m α) (a : α) (maxDepth : Option Nat := none) : m α :=
42+
match maxDepth with
43+
| some d => depthFirst' (fun n a => if n ≤ d then f a else failure) 0 a
44+
| none => depthFirst' (fun _ a => f a) 0 a
45+
46+
end
47+
48+
variable [Monad m]
49+
50+
open Lean in
51+
/--
52+
Variant of `depthFirst`,
53+
using an internal `HashSet` to record and avoid already visited nodes.
54+
55+
This version describes the graph using `α → ListM m α`,
56+
and returns the monadic lazy list of nodes visited in order.
57+
58+
This is potentially very expensive.
59+
If you want to do efficient enumerations from a generation function,
60+
avoiding duplication up to equality or isomorphism,
61+
use Brendan McKay's method of "generation by canonical construction path".
62+
-/
63+
-- TODO can you make this work in `List` and `ListM m` simultaneously, by being tricky with monads?
64+
unsafe def depthFirstRemovingDuplicates {α : Type u} [BEq α] [Hashable α]
65+
(f : α → ListM m α) (a : α) (maxDepth : Option Nat := none) : ListM m α :=
66+
let f' : α → ListM (StateT.{u} (HashSet α) m) α := fun a =>
67+
(f a).liftM >>= fun b => do
68+
let s ← get
69+
if s.contains b then failure
70+
set <| s.insert b
71+
pure b
72+
(depthFirst f' a maxDepth).runState' (HashSet.empty.insert a)
73+
74+
/--
75+
Variant of `depthFirst`,
76+
using an internal `HashSet` to record and avoid already visited nodes.
77+
78+
This version describes the graph using `α → List α`, and returns the list of nodes visited in order.
79+
-/
80+
def depthFirstRemovingDuplicates' [BEq α] [Hashable α]
81+
(f : α → List α) (a : α) (maxDepth : Option Nat := none) : List α :=
82+
unsafe depthFirstRemovingDuplicates
83+
(fun a => (.ofList (f a) : ListM Option α)) a maxDepth |>.force |>.get!

Mathlib/Data/ListM/Heartbeats.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Copyright (c) 2023 Scott Morrison. All rights reserved.
33
Released under Apache 2.0 license as described in the file LICENSE.
44
Authors: Scott Morrison
55
-/
6-
import Mathlib.Data.ListM
6+
import Mathlib.Data.ListM.Basic
77
import Mathlib.Lean.CoreM
88

99
/-!

test/ListM.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Copyright (c) 2019 Scott Morrison. All rights reserved.
33
Released under Apache 2.0 license as described in the file LICENSE.
44
Authors: Scott Morrison
55
-/
6-
import Mathlib.Data.ListM
6+
import Mathlib.Data.ListM.Basic
77
import Mathlib.Control.Basic
88

99
@[reducible] def S (α : Type) := StateT (List Nat) Option α

0 commit comments

Comments
 (0)