Skip to content

Commit

Permalink
feat: improve binrel% elaborator
Browse files Browse the repository at this point in the history
  • Loading branch information
leodemoura committed May 10, 2022
1 parent 1768067 commit 7ce0471
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 144 deletions.
100 changes: 53 additions & 47 deletions src/Lean/Elab/Extra.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,6 @@ Auxiliary elaboration functions: AKA custom elaborators
namespace Lean.Elab.Term
open Meta

def elabBinRelCore (noProp : Bool) (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
match (← resolveId? stx[1]) with
| some f =>
let s ← saveState
let (lhs, rhs) ← withSynthesize (mayPostpone := true) do
let mut lhs ← elabTerm stx[2] none
let mut rhs ← elabTerm stx[3] none
if lhs.isAppOfArity ``OfNat.ofNat 3 then
lhs ← ensureHasType (← inferType rhs) lhs
else if rhs.isAppOfArity ``OfNat.ofNat 3 then
rhs ← ensureHasType (← inferType lhs) rhs
return (lhs, rhs)
let lhs ← toBoolIfNecessary lhs
let rhs ← toBoolIfNecessary rhs
let lhsType ← inferType lhs
let rhsType ← inferType rhs

let (lhs, rhs) ←
try
pure (lhs, ← withRef stx[3] do ensureHasType lhsType rhs)
catch _ =>
try
pure (← withRef stx[2] do ensureHasType rhsType lhs, rhs)
catch _ =>
s.restore
-- Use default approach
let lhs ← elabTerm stx[2] none
let rhs ← elabTerm stx[3] none
let lhsType ← inferType lhs
let rhsType ← inferType rhs
pure (lhs, ← withRef stx[3] do ensureHasType lhsType rhs)
elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] expectedType? (explicit := false) (ellipsis := false)
| none => throwUnknownConstant stx[1].getId
where
/-- If `noProp == true` and `e` has type `Prop`, then coerce it to `Bool`. -/
toBoolIfNecessary (e : Expr) : TermElabM Expr := do
if noProp then
-- We use `withNewMCtxDepth` to make sure metavariables are not assigned
if (← withNewMCtxDepth <| isDefEq (← inferType e) (mkSort levelZero)) then
return (← ensureHasType (Lean.mkConst ``Bool) e)
return e

@[builtinTermElab binrel] def elabBinRel : TermElab := elabBinRelCore false

@[builtinTermElab binrel_no_prop] def elabBinRelNoProp : TermElab := elabBinRelCore true

private def getMonadForIn (expectedType? : Option Expr) : TermElabM Expr := do
match expectedType? with
| none => throwError "invalid 'for_in%' notation, expected type is not available"
Expand Down Expand Up @@ -174,7 +128,9 @@ private inductive Tree where
| op (ref : Syntax) (lazy : Bool) (f : Expr) (lhs rhs : Tree)

private partial def toTree (s : Syntax) : TermElabM Tree := do
let result ← go (← liftMacroM <| expandMacros s)
let s ← liftMacroM <| expandMacros s
trace[Meta.debug] "toTree: {s}"
let result ← go s
synthesizeSyntheticMVars (mayPostpone := true)
return result
where
Expand Down Expand Up @@ -349,6 +305,55 @@ def elabBinOp : TermElab := fun stx expectedType? => do
@[builtinTermElab binop_lazy]
def elabBinOpLazy : TermElab := elabBinOp

/--
Elaboration functionf for `binrel%` and `binrel_no_prop%` notations.
We use the infrastructure for `binop%` to make sure we propagate information between the left and right hand sides
of a binary relation.
Recall that the `binrel_no_prop%` notation is used for relations such as `==` which do not support `Prop`, but
we still want to be able to write `(5 > 2) == (2 > 1)`.
-/
def elabBinRelCore (noProp : Bool) (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
match (← resolveId? stx[1]) with
| some f => withSynthesize (mayPostpone := true) do
let lhs ← withRef stx[2] <| toTree stx[2]
let rhs ← withRef stx[3] <| toTree stx[3]
let tree := Tree.op (lazy := false) stx f lhs rhs
let r ← analyze tree none
trace[Elab.binrel] "hasUncomparable: {r.hasUncomparable}, maxType: {r.max?}"
if r.hasUncomparable || r.max?.isNone then
-- Use default elaboration strategy + `toBoolIfNecessary`
let lhs ← toExpr lhs
let rhs ← toExpr rhs
let lhs ← toBoolIfNecessary lhs
let rhs ← toBoolIfNecessary rhs
let lhsType ← inferType lhs
let rhsType ← inferType rhs
let rhs ← ensureHasType lhsType rhs
elabAppArgs f #[] #[Arg.expr lhs, Arg.expr rhs] expectedType? (explicit := false) (ellipsis := false)
else
let mut maxType := r.max?.get!
/- If `noProp == true` and `maxType` is `Prop`, then set `maxType := Bool`. `See toBoolIfNecessary` -/
if noProp then
if (← withNewMCtxDepth <| isDefEq maxType (mkSort levelZero)) then
maxType := Lean.mkConst ``Bool
let result ← toExpr (← applyCoe tree maxType)
trace[Elab.binrel] "result: {result}"
return result
| none => throwUnknownConstant stx[1].getId
where
/-- If `noProp == true` and `e` has type `Prop`, then coerce it to `Bool`. -/
toBoolIfNecessary (e : Expr) : TermElabM Expr := do
if noProp then
-- We use `withNewMCtxDepth` to make sure metavariables are not assigned
if (← withNewMCtxDepth <| isDefEq (← inferType e) (mkSort levelZero)) then
return (← ensureHasType (Lean.mkConst ``Bool) e)
return e

@[builtinTermElab binrel] def elabBinRel : TermElab := elabBinRelCore false

@[builtinTermElab binrel_no_prop] def elabBinRelNoProp : TermElab := elabBinRelCore true

/--
Decompose `e` into `(r, a, b)`.
Expand Down Expand Up @@ -426,6 +431,7 @@ def elabDefaultOrNonempty : TermElab := fun stx expectedType? => do

builtin_initialize
registerTraceClass `Elab.binop
registerTraceClass `Elab.binrel

end BinOp

Expand Down
11 changes: 11 additions & 0 deletions tests/lean/binrel_binop.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
theorem ex1 (a : Int) (b c : Nat) : a = ↑b - ↑c := sorry

#check ex1

theorem ex2 (a : Int) (b c : Nat) : a = b - c := sorry

#check ex2

theorem ex3 (a : Int) (b c : Nat) : a = ↑(b - c) := sorry

#check ex3
6 changes: 6 additions & 0 deletions tests/lean/binrel_binop.lean.expected.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
binrel_binop.lean:1:51-1:56: warning: declaration uses 'sorry'
ex1 : ∀ (a : Int) (b c : Nat), a = Int.ofNat b - Int.ofNat c
binrel_binop.lean:5:49-5:54: warning: declaration uses 'sorry'
ex2 : ∀ (a : Int) (b c : Nat), a = Int.ofNat b - Int.ofNat c
binrel_binop.lean:9:52-9:57: warning: declaration uses 'sorry'
ex3 : ∀ (a : Int) (b c : Nat), a = Int.ofNat (b - c)
Loading

0 comments on commit 7ce0471

Please sign in to comment.