Skip to content

Commit 289d461

Browse files
committed
feat: multiple scoped in notation3 (#6793)
See the [discussion on Zulip](https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/mathport.20fails/near/387335863). `notation3` did not support more than one occurrence of `scoped` due to an unnecessary limitation in the generated delaborator. It now handles this case correctly.
1 parent 7320add commit 289d461

File tree

2 files changed

+90
-69
lines changed

2 files changed

+90
-69
lines changed

Mathlib/Mathport/Notation.lean

Lines changed: 48 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ structure MatchState where
9797
We store the contexts since we need to delaborate expressions after we leave
9898
scoping constructs. -/
9999
vars : HashMap Name (SubExpr × LocalContext × LocalInstances)
100-
/-- The binders accumulated when matching a `scoped` expression. -/
101-
scopeState : Array (TSyntax ``extBinderParenthesized)
100+
/-- The binders accumulated while matching a `scoped` expression. -/
101+
scopeState : Option (Array (TSyntax ``extBinderParenthesized))
102102
/-- The arrays of delaborated `Term`s accumulated while matching
103103
`foldl` and `foldr` expressions. For `foldl`, the arrays are stored in reverse order. -/
104104
foldState : HashMap Name (Array Term)
@@ -110,7 +110,7 @@ def Matcher := MatchState → DelabM MatchState
110110
/-- The initial state. -/
111111
def MatchState.empty : MatchState where
112112
vars := {}
113-
scopeState := #[]
113+
scopeState := none
114114
foldState := {}
115115

116116
/-- Evaluate `f` with the given variable's value as the `SubExpr` and within that subexpression's
@@ -134,19 +134,16 @@ def MatchState.delabVar (s : MatchState) (name : Name) (checkNot? : Option Expr
134134
def MatchState.captureSubexpr (s : MatchState) (name : Name) : DelabM MatchState := do
135135
return {s with vars := s.vars.insert name (← readThe SubExpr, ← getLCtx, ← getLocalInstances)}
136136

137-
/-- Push a binder onto the binder array. For `scoped`. -/
138-
def MatchState.pushBinder (s : MatchState) (b : TSyntax ``extBinderParenthesized) :
139-
DelabM MatchState := do
140-
let binders := s.scopeState
141-
-- TODO merge binders as an inverse to `satisfies_binder_pred%`
142-
let binders := binders.push b
143-
return {s with scopeState := binders}
144-
145137
/-- Get the accumulated array of delaborated terms for a given foldr/foldl.
146138
Returns `#[]` if nothing has been pushed yet. -/
147139
def MatchState.getFoldArray (s : MatchState) (name : Name) : Array Term :=
148140
(s.foldState.find? name).getD #[]
149141

142+
/-- Get the accumulated array of delaborated terms for a given foldr/foldl.
143+
Returns `#[]` if nothing has been pushed yet. -/
144+
def MatchState.getBinders (s : MatchState) : Array (TSyntax ``extBinderParenthesized) :=
145+
s.scopeState.getD #[]
146+
150147
/-- Push a delaborated term onto a foldr/foldl array. -/
151148
def MatchState.pushFold (s : MatchState) (name : Name) (t : Term) : MatchState :=
152149
let ts := (s.getFoldArray name).push t
@@ -268,46 +265,44 @@ where
268265
against is in the `lit` variable.
269266
270267
Runs `smatcher`, extracts the resulting `scopeId` variable, processes this value
271-
(which must be a lambda) to produce a binder, and loops.
272-
273-
Succeeds even if it matches nothing, so it is up to the caller to decide if the
274-
empty scope state is ok. -/
275-
partial def matchScoped (lit scopeId : Name) (smatcher : Matcher) : Matcher := fun s => do
276-
-- `lit` is bound to the SubExpr that the `scoped` syntax produced
277-
s.withVar lit do
278-
try
279-
-- Run `smatcher` at `lit`, clearing the `scopeId` variable so that it can get a fresh value
280-
let s ← smatcher {s with vars := s.vars.erase scopeId}
281-
s.withVar scopeId do
282-
guard (← getExpr).isLambda
283-
let prop ← try Meta.isProp (← getExpr).bindingDomain! catch _ => pure false
284-
let isDep := (← getExpr).bindingBody!.hasLooseBVar 0
285-
let ppTypes ← getPPOption getPPPiBinderTypes -- the same option controlling ∀
286-
let dom ← withBindingDomain delab
287-
withBindingBodyUnusedName <| fun x => do
288-
let x : Ident := ⟨x⟩
289-
let binder ←
290-
if prop && !isDep then
291-
-- this underscore is used to support binder predicates, since it indicates
292-
-- the variable is unused and this binder is safe to merge into another
293-
`(extBinderParenthesized|(_ : $dom))
294-
else if prop || ppTypes then
295-
`(extBinderParenthesized|($x:ident : $dom))
296-
else
297-
`(extBinderParenthesized|($x:ident))
298-
-- Now use the body of the lambda for `lit` for the next iteration
299-
let s ← s.captureSubexpr lit
300-
let s ← s.pushBinder binder
301-
matchScoped lit scopeId smatcher s
302-
catch _ =>
303-
return s
304-
305-
/-- Like `matchScoped` but ensures that it matches at least one binder. -/
306-
partial def matchScoped' (lit scopeId : Name) (smatcher : Matcher) : Matcher := fun s => do
307-
guard <| s.scopeState.isEmpty
308-
let s ← matchScoped lit scopeId smatcher s
309-
guard <| !s.scopeState.isEmpty
310-
return s
268+
(which must be a lambda) to produce a binder, and loops. -/
269+
partial def matchScoped (lit scopeId : Name) (smatcher : Matcher) : Matcher := go #[] where
270+
/-- Variant of `matchScoped` after some number of `binders` have already been captured. -/
271+
go (binders : Array (TSyntax ``extBinderParenthesized)) : Matcher := fun s => do
272+
-- `lit` is bound to the SubExpr that the `scoped` syntax produced
273+
s.withVar lit do
274+
try
275+
-- Run `smatcher` at `lit`, clearing the `scopeId` variable so that it can get a fresh value
276+
let s ← smatcher {s with vars := s.vars.erase scopeId}
277+
s.withVar scopeId do
278+
guard (← getExpr).isLambda
279+
let prop ← try Meta.isProp (← getExpr).bindingDomain! catch _ => pure false
280+
let isDep := (← getExpr).bindingBody!.hasLooseBVar 0
281+
let ppTypes ← getPPOption getPPPiBinderTypes -- the same option controlling ∀
282+
let dom ← withBindingDomain delab
283+
withBindingBodyUnusedName <| fun x => do
284+
let x : Ident := ⟨x⟩
285+
let binder ←
286+
if prop && !isDep then
287+
-- this underscore is used to support binder predicates, since it indicates
288+
-- the variable is unused and this binder is safe to merge into another
289+
`(extBinderParenthesized|(_ : $dom))
290+
else if prop || ppTypes then
291+
`(extBinderParenthesized|($x:ident : $dom))
292+
else
293+
`(extBinderParenthesized|($x:ident))
294+
-- Now use the body of the lambda for `lit` for the next iteration
295+
let s ← s.captureSubexpr lit
296+
-- TODO merge binders as an inverse to `satisfies_binder_pred%`
297+
let binders := binders.push binder
298+
go binders s
299+
catch _ =>
300+
guard <| !binders.isEmpty
301+
if let some binders₂ := s.scopeState then
302+
guard <| binders == binders₂ -- TODO: this might be a bit too strict, but it seems to work
303+
return s
304+
else
305+
return {s with scopeState := binders}
311306

312307
/- Create a `Term` that represents a matcher for `scoped` notation.
313308
Fails in the `OptionT` sense if a matcher couldn't be constructed.
@@ -317,7 +312,7 @@ partial def mkScopedMatcher (lit scopeId : Name) (scopedTerm : Term) (boundNames
317312
OptionT TermElabM (List Name × Term) := do
318313
-- Build the matcher for `scopedTerm` with `scopeId` as an additional variable
319314
let (keys, smatcher) ← mkExprMatcher scopedTerm (boundNames.insert scopeId)
320-
return (keys, ← ``(matchScoped' $(quote lit) $(quote scopeId) $smatcher))
315+
return (keys, ← ``(matchScoped $(quote lit) $(quote scopeId) $smatcher))
321316

322317
/-- Matcher for expressions produced by `foldl`. -/
323318
partial def matchFoldl (lit x y : Name) (smatcher : Matcher) (sinit : Matcher) :
@@ -482,8 +477,6 @@ elab doc:(docComment)? attrs?:(Parser.Term.attributes)? attrKind:Term.attrKind
482477
mkFoldrMatcher id.getId x.getId y.getId scopedTerm init (getBoundNames boundValues)
483478
| _ => throwUnsupportedSyntax
484479
| `(notation3Item| $lit:ident $(prec?)? : (scoped $scopedId:ident => $scopedTerm)) =>
485-
if hasScoped then
486-
throwErrorAt item "Cannot have more than one `scoped` item."
487480
hasScoped := true
488481
(syntaxArgs, pattArgs) ← pushMacro syntaxArgs pattArgs <|←
489482
`(macroArg| $lit:ident:term $(prec?)?)
@@ -547,7 +540,7 @@ elab doc:(docComment)? attrs?:(Parser.Term.attributes)? attrKind:Term.attrKind
547540
| .foldr => result ←
548541
`(let $id := MatchState.getFoldArray s $(quote name); $result)
549542
if hasBindersItem then
550-
result ← `(`(extBinders| $$(MatchState.scopeState s)*) >>= fun binders => $result)
543+
result ← `(`(extBinders| $$(MatchState.getBinders s)*) >>= fun binders => $result)
551544
elabCommand <| ← `(command|
552545
def $(Lean.mkIdent delabName) : Delab := whenPPOption getPPNotation <|
553546
getExpr >>= fun e => $matcher MatchState.empty >>= fun s => $result)

test/notation3.lean

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import Std.Tactic.GuardMsgs
12
import Mathlib.Mathport.Notation
23
import Mathlib.Init.Data.Nat.Lemmas
34

5+
set_option pp.unicode.fun true
46
set_option autoImplicit true
57

68
namespace Test
@@ -13,45 +15,71 @@ def Filter.eventually (p : α → Prop) (f : Filter α) := f p
1315

1416
notation3 "∀ᶠ " (...) " in " f ", " r:(scoped p => Filter.eventually p f) => r
1517

16-
#check ∀ᶠ (x : Nat) (y) in Filter.atTop, x < y
17-
#check ∀ᶠ x in Filter.atTop, x < 3
18+
/-- info: ∀ᶠ (x : ℕ) (y : ℕ) in Filter.atTop, x < y : Prop -/
19+
#guard_msgs in #check ∀ᶠ (x : Nat) (y) in Filter.atTop, x < y
20+
/-- info: ∀ᶠ (x : ℕ) in Filter.atTop, x < 3 : Prop -/
21+
#guard_msgs in #check ∀ᶠ x in Filter.atTop, x < 3
1822

23+
def foobar (p : α → Prop) (f : Prop) := ∀ x, p x = f
24+
25+
notation3 "∀ᶠᶠ " (...) " in " f ": "
26+
r1:(scoped p => Filter.eventually p f) ", " r2:(scoped p => foobar p r1) => r2
27+
28+
/-- info: ∀ᶠᶠ (x : ℕ) (y : ℕ) in Filter.atTop: x < y, x = y : Prop -/
29+
#guard_msgs in #check ∀ᶠᶠ (x : Nat) (y) in Filter.atTop: x < y, x = y
30+
/-- info: ∀ᶠᶠ (x : ℕ) in Filter.atTop: x < 3, x = 1 : Prop -/
31+
#guard_msgs in #check ∀ᶠᶠ x in Filter.atTop: x < 3, x = 1
32+
/-- info: ∀ᶠᶠ (x : ℕ) in Filter.atTop: x < 3, x = 1 : Prop -/
33+
#guard_msgs in #check foobar (fun x ↦ Eq x 1) (Filter.atTop.eventually fun x ↦ LT.lt x 3)
34+
/-- info: foobar (fun y ↦ y = 1) (∀ᶠ (x : ℕ) in Filter.atTop, x < 3) : Prop -/
35+
#guard_msgs in #check foobar (fun y ↦ Eq y 1) (Filter.atTop.eventually fun x ↦ LT.lt x 3)
1936

2037
notation3 "∃' " (...) ", " r:(scoped p => Exists p) => r
21-
#check ∃' x < 3, x < 3
38+
/-- info: ∃' (x : ℕ) (_ : x < 3), x < 3 : Prop -/
39+
#guard_msgs in #check ∃' x < 3, x < 3
2240

2341
def func (x : α) : α := x
2442
notation3 "func! " (...) ", " r:(scoped p => func p) => r
2543
-- Make sure it handles additional arguments. Should not consume `(· * 2)`.
2644
-- Note: right now this causes the notation to not pretty print at all.
27-
#check (func! (x : Nat → Nat), x) (· * 2)
45+
/-- info: func (fun x ↦ x) fun x ↦ x * 2 : ℕ → ℕ -/
46+
#guard_msgs in #check (func! (x : Nat → Nat), x) (· * 2)
2847

29-
structure MyUnit
48+
structure MyUnit where
3049
notation3 "~{" (x"; "* => foldl (a b => Prod.mk a b) MyUnit) "}~" => x
31-
#check ~{1; true; ~{2}~}~
32-
#check ~{}~
50+
/-- info: ~{1; true; ~{2}~}~ : ((Type × ℕ) × Bool) × Type × ℕ -/
51+
#guard_msgs in #check ~{1; true; ~{2}~}~
52+
/-- info: ~{}~ : Type -/
53+
#guard_msgs in #check ~{}~
3354

3455
notation3 "%[" (x", "* => foldr (a b => List.cons a b) List.nil) "]" => x
35-
#check %[1, 2, 3]
56+
/-- info: %[1, 2, 3] : List ℕ -/
57+
#guard_msgs in #check %[1, 2, 3]
3658

3759
def foo (a : Nat) (f : Nat → Nat) := a + f a
3860
def bar (a b : Nat) := a * b
3961
notation3 "*[" x "] " (...) ", " v:(scoped c => bar x (foo x c)) => v
40-
#check *[1] (x) (y), x + y
41-
#check bar 1
62+
/-- info: *[1] (x : ℕ) (y : ℕ), x + y : ℕ -/
63+
#guard_msgs in #check *[1] (x) (y), x + y
64+
/-- info: bar 1 : ℕ → ℕ -/
65+
#guard_msgs in #check bar 1
4266

4367
-- Checking that the `<|` macro is expanded when making matcher
4468
def foo' (a : Nat) (f : Nat → Nat) := a + f a
4569
def bar' (a b : Nat) := a * b
4670
notation3 "*'[" x "] " (...) ", " v:(scoped c => bar' x <| foo' x c) => v
47-
#check *'[1] (x) (y), x + y
48-
#check bar' 1
71+
/-- info: *'[1] (x : ℕ) (y : ℕ), x + y : ℕ -/
72+
#guard_msgs in #check *'[1] (x) (y), x + y
73+
/-- info: bar' 1 : ℕ → ℕ -/
74+
#guard_msgs in #check bar' 1
4975

5076
-- Currently does not pretty print due to pi type
5177
notation3 (prettyPrint := false) "MyPi " (...) ", " r:(scoped p => (x : _) → p x) => r
52-
#check MyPi (x : Nat) (y : Nat), x < y
78+
/-- info: ∀ (x : ℕ), (fun x ↦ ∀ (x_1 : ℕ), (fun y ↦ x < y) x_1) x : Prop -/
79+
#guard_msgs in #check MyPi (x : Nat) (y : Nat), x < y
5380

5481
-- The notation parses fine, but the delaborator never succeeds, which is expected
5582
def myId (x : α) := x
5683
notation3 "BAD " c "; " (x", "* => foldl (a b => b) c) " DAB" => myId x
57-
#check BAD 1; 2, 3 DAB
84+
/-- info: myId 3 : ℕ -/
85+
#guard_msgs in #check BAD 1; 2, 3 DAB

0 commit comments

Comments
 (0)