Skip to content

Commit

Permalink
refactor: simplify runTermElabM and liftTermElabM
Browse files Browse the repository at this point in the history
  • Loading branch information
leodemoura committed Aug 7, 2022
1 parent e236eeb commit 413db56
Show file tree
Hide file tree
Showing 27 changed files with 81 additions and 84 deletions.
12 changes: 6 additions & 6 deletions src/Lean/Elab/BuiltinCommand.lean
Expand Up @@ -217,7 +217,7 @@ private def replaceBinderAnnotation (binder : TSyntax ``Parser.Term.bracketedBin
@[builtinCommandElab «variable»] def elabVariable : CommandElab
| `(variable $binders*) => do
-- Try to elaborate `binders` for sanity checking
runTermElabM none fun _ => Term.withAutoBoundImplicit <|
runTermElabM fun _ => Term.withAutoBoundImplicit <|
Term.elabBinders binders fun _ => pure ()
for binder in binders do
let binders ← replaceBinderAnnotation binder
Expand All @@ -230,7 +230,7 @@ private def replaceBinderAnnotation (binder : TSyntax ``Parser.Term.bracketedBin
open Meta

def elabCheckCore (ignoreStuckTC : Bool) : CommandElab
| `(#check%$tk $term) => withoutModifyingEnv $ runTermElabM (some `_check) fun _ => do
| `(#check%$tk $term) => withoutModifyingEnv <| runTermElabM fun _ => Term.withDeclName `_check do
let e ← Term.elabTerm term none
Term.synthesizeSyntheticMVarsNoPostponing (ignoreStuckTC := ignoreStuckTC)
let (e, _) ← Term.levelMVarToParam (← instantiateMVars e)
Expand All @@ -242,7 +242,7 @@ def elabCheckCore (ignoreStuckTC : Bool) : CommandElab
@[builtinCommandElab Lean.Parser.Command.check] def elabCheck : CommandElab := elabCheckCore (ignoreStuckTC := true)

@[builtinCommandElab Lean.Parser.Command.reduce] def elabReduce : CommandElab
| `(#reduce%$tk $term) => withoutModifyingEnv <| runTermElabM (some `_check) fun _ => do
| `(#reduce%$tk $term) => withoutModifyingEnv <| runTermElabM fun _ => Term.withDeclName `_reduce do
let e ← Term.elabTerm term none
Term.synthesizeSyntheticMVarsNoPostponing
let (e, _) ← Term.levelMVarToParam (← instantiateMVars e)
Expand Down Expand Up @@ -344,7 +344,7 @@ unsafe def elabEvalUnsafe : CommandElab
-- Evaluate using term using `MetaEval` class.
let elabMetaEval : CommandElabM Unit := do
-- act? is `some act` if elaborated `term` has type `CommandElabM α`
let act? ← runTermElabM (some declName) fun _ => do
let act? ← runTermElabM fun _ => Term.withDeclName declName do
let e ← elabEvalTerm
let eType ← instantiateMVars (← inferType e)
if eType.isAppOfArity ``CommandElabM 1 then
Expand All @@ -366,7 +366,7 @@ unsafe def elabEvalUnsafe : CommandElab
let some act := act? | return ()
act
-- Evaluate using term using `Eval` class.
let elabEval : CommandElabM Unit := runTermElabM (some declName) fun _ => do
let elabEval : CommandElabM Unit := runTermElabM fun _ => Term.withDeclName declName do
-- fall back to non-meta eval if MetaEval hasn't been defined yet
-- modify e to `runEval e`
let e ← mkRunEval (← elabEvalTerm)
Expand All @@ -388,7 +388,7 @@ opaque elabEval : CommandElab

@[builtinCommandElab «synth»] def elabSynth : CommandElab := fun stx => do
let term := stx[1]
withoutModifyingEnv <| runTermElabM `_synth_cmd fun _ => do
withoutModifyingEnv <| runTermElabM fun _ => Term.withDeclName `_synth_cmd do
let inst ← Term.elabTerm term none
Term.synthesizeSyntheticMVarsNoPostponing
let inst ← instantiateMVars inst
Expand Down
18 changes: 7 additions & 11 deletions src/Lean/Elab/Command.lean
Expand Up @@ -343,20 +343,18 @@ def getBracketedBinderIds : Syntax → Array Name
| `(bracketedBinder|[$_]) => #[Name.anonymous]
| _ => #[]

private def mkTermContext (ctx : Context) (s : State) (declName? : Option Name) : Term.Context := Id.run do
private def mkTermContext (ctx : Context) (s : State) : Term.Context := Id.run do
let scope := s.scopes.head!
let mut sectionVars := {}
for id in scope.varDecls.concatMap getBracketedBinderIds, uid in scope.varUIds do
sectionVars := sectionVars.insert id uid
{ macroStack := ctx.macroStack
declName? := declName?
sectionVars := sectionVars
isNoncomputableSection := scope.isNoncomputable
tacticCache? := ctx.tacticCache? }

/--
Lift the `TermElabM` monadic action `x` into a `CommandElabM` monadic action.
You can optionally set the current declaration name for `x` using the parameter `declName?`.
Note that `x` is executed with an empty message log. Thus, `x` cannot modify/view messages produced by
previous commands.
Expand All @@ -375,11 +373,11 @@ def printExpr (e : Expr) : MetaM Unit := do
IO.println s!"{← ppExpr e} : {← ppExpr (← inferType e)}"
#eval
liftTermElabM none do
liftTermElabM do
printExpr (mkConst ``Nat)
```
-/
def liftTermElabM (declName? : Option Name) (x : TermElabM α) : CommandElabM α := do
def liftTermElabM (x : TermElabM α) : CommandElabM α := do
let ctx ← read
let s ← get
let heartbeats ← IO.getNumHeartbeats
Expand All @@ -388,7 +386,7 @@ def liftTermElabM (declName? : Option Name) (x : TermElabM α) : CommandElabM α
-- We execute `x` with an empty message log. Thus, `x` cannot modify/view messages produced by previous commands.
-- This is useful for implementing `runTermElabM` where we use `Term.resetMessageLog`
let x : TermElabM _ := withSaveInfoContext x
let x : MetaM _ := (observing x).run (mkTermContext ctx s declName?) { levelNames := scope.levelNames }
let x : MetaM _ := (observing x).run (mkTermContext ctx s) { levelNames := scope.levelNames }
let x : CoreM _ := x.run mkMetaContext {}
let x : EIO _ _ := x.run (mkCoreContext ctx s heartbeats) { env := s.env, ngen := s.ngen, nextMacroScope := s.nextMacroScope, infoState.enabled := s.infoState.enabled }
let (((ea, _), _), coreS) ← liftEIO x
Expand All @@ -410,8 +408,6 @@ corresponding to all active scoped variables declared using the `variable` comma
This method is similar to `liftTermElabM`, but it elaborates all scoped variables declared using the `variable`
command.
You can optionally set the current declaration name for `elabFn xs` using the parameter `declName?`.
Example:
```
import Lean
Expand All @@ -422,14 +418,14 @@ variable {α : Type u} {f : α → α}
variable (n : Nat)
#eval
runTermElabM none fun xs => do
runTermElabM fun xs => do
for x in xs do
IO.println s!"{← ppExpr x} : {← ppExpr (← inferType x)}"
```
-/
def runTermElabM (declName? : Option Name) (elabFn : Array Expr → TermElabM α) : CommandElabM α := do
def runTermElabM (elabFn : Array Expr → TermElabM α) : CommandElabM α := do
let scope ← getScope
liftTermElabM declName? <|
liftTermElabM <|
Term.withAutoBoundImplicit <|
Term.elabBinders scope.varDecls fun xs => do
-- We need to synthesize postponed terms because this is a checkpoint for the auto-bound implicit feature
Expand Down
57 changes: 29 additions & 28 deletions src/Lean/Elab/Declaration.lean
Expand Up @@ -104,34 +104,35 @@ def elabAxiom (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit := do
let scopeLevelNames ← getLevelNames
let ⟨_, declName, allUserLevelNames⟩ ← expandDeclId declId modifiers
addDeclarationRanges declName stx
runTermElabM declName fun vars => Term.withLevelNames allUserLevelNames <| Term.elabBinders binders.getArgs fun xs => do
Term.applyAttributesAt declName modifiers.attrs AttributeApplicationTime.beforeElaboration
let type ← Term.elabType typeStx
Term.synthesizeSyntheticMVarsNoPostponing
let type ← instantiateMVars type
let type ← mkForallFVars xs type
let type ← mkForallFVars vars type (usedOnly := true)
let (type, _) ← Term.levelMVarToParam type
let usedParams := collectLevelParams {} type |>.params
match sortDeclLevelParams scopeLevelNames allUserLevelNames usedParams with
| Except.error msg => throwErrorAt stx msg
| Except.ok levelParams =>
runTermElabM fun vars =>
Term.withDeclName declName <| Term.withLevelNames allUserLevelNames <| Term.elabBinders binders.getArgs fun xs => do
Term.applyAttributesAt declName modifiers.attrs AttributeApplicationTime.beforeElaboration
let type ← Term.elabType typeStx
Term.synthesizeSyntheticMVarsNoPostponing
let type ← instantiateMVars type
let decl := Declaration.axiomDecl {
name := declName,
levelParams := levelParams,
type := type,
isUnsafe := modifiers.isUnsafe
}
trace[Elab.axiom] "{declName} : {type}"
Term.ensureNoUnassignedMVars decl
addDecl decl
withSaveInfoContext do -- save new env
Term.addTermInfo' declId (← mkConstWithLevelParams declName) (isBinder := true)
Term.applyAttributesAt declName modifiers.attrs AttributeApplicationTime.afterTypeChecking
if isExtern (← getEnv) declName then
compileDecl decl
Term.applyAttributesAt declName modifiers.attrs AttributeApplicationTime.afterCompilation
let type ← mkForallFVars xs type
let type ← mkForallFVars vars type (usedOnly := true)
let (type, _) ← Term.levelMVarToParam type
let usedParams := collectLevelParams {} type |>.params
match sortDeclLevelParams scopeLevelNames allUserLevelNames usedParams with
| Except.error msg => throwErrorAt stx msg
| Except.ok levelParams =>
let type ← instantiateMVars type
let decl := Declaration.axiomDecl {
name := declName,
levelParams := levelParams,
type := type,
isUnsafe := modifiers.isUnsafe
}
trace[Elab.axiom] "{declName} : {type}"
Term.ensureNoUnassignedMVars decl
addDecl decl
withSaveInfoContext do -- save new env
Term.addTermInfo' declId (← mkConstWithLevelParams declName) (isBinder := true)
Term.applyAttributesAt declName modifiers.attrs AttributeApplicationTime.afterTypeChecking
if isExtern (← getEnv) declName then
compileDecl decl
Term.applyAttributesAt declName modifiers.attrs AttributeApplicationTime.afterCompilation

/-
leading_parser "inductive " >> declId >> optDeclSig >> optional ":=" >> many ctor
Expand Down Expand Up @@ -366,7 +367,7 @@ def elabMutual : CommandElab := fun stx => do
attrInsts := attrInsts.push attrKindStx
let attrs ← elabAttrs attrInsts
let idents := stx[4].getArgs
for ident in idents do withRef ident <| liftTermElabM none do
for ident in idents do withRef ident <| liftTermElabM do
let declName ← resolveGlobalConstNoOverloadWithInfo ident
Term.applyAttributes declName attrs
for attrName in toErase do
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/DefView.lean
Expand Up @@ -83,7 +83,7 @@ def mkFreshInstanceName : CommandElabM Name := do
def mkInstanceName (binders : Array Syntax) (type : Syntax) : CommandElabM Name := do
let savedState ← get
try
let result ← runTermElabM `inst fun _ => Term.withAutoBoundImplicit <| Term.elabBinders binders fun _ => Term.withoutErrToSorry do
let result ← runTermElabM fun _ => Term.withAutoBoundImplicit <| Term.elabBinders binders fun _ => Term.withoutErrToSorry do
let type ← instantiateMVars (← Term.elabType type)
let ref ← IO.mkRef ""
Meta.forEachExpr type fun e => do
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Elab/Deriving/BEq.lean
Expand Up @@ -111,11 +111,11 @@ open Command

def mkBEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if declNames.size == 1 && (← isEnumType declNames[0]!) then
let cmds ← liftTermElabM none <| mkBEqEnumCmd declNames[0]!
let cmds ← liftTermElabM <| mkBEqEnumCmd declNames[0]!
cmds.forM elabCommand
return true
else if (← declNames.allM isInductive) && declNames.size > 0 then
let cmds ← liftTermElabM none <| mkBEqInstanceCmds declNames
let cmds ← liftTermElabM <| mkBEqInstanceCmds declNames
cmds.forM elabCommand
return true
else
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/Deriving/Basic.lean
Expand Up @@ -43,7 +43,7 @@ def applyDerivingHandlers (className : Name) (typeNames : Array Name) (args? : O
| none => defaultHandler className typeNames

private def tryApplyDefHandler (className : Name) (declName : Name) : CommandElabM Bool :=
liftTermElabM none do
liftTermElabM do
Term.processDefDeriving className declName

@[builtinCommandElab «deriving»] def elabDeriving : CommandElab
Expand Down
6 changes: 3 additions & 3 deletions src/Lean/Elab/Deriving/DecEq.lean
Expand Up @@ -103,7 +103,7 @@ def mkDecEq (declName : Name) : CommandElabM Bool := do
if indVal.isNested then
return false -- nested inductive types are not supported yet
else
let cmds ← liftTermElabM none <| mkDecEqCmds indVal
let cmds ← liftTermElabM <| mkDecEqCmds indVal
cmds.forM elabCommand
return true

Expand Down Expand Up @@ -157,8 +157,8 @@ def mkEnumOfNatThm (declName : Name) : MetaM Unit := do
}

def mkDecEqEnum (declName : Name) : CommandElabM Unit := do
liftTermElabM none <| mkEnumOfNat declName
liftTermElabM none <| mkEnumOfNatThm declName
liftTermElabM <| mkEnumOfNat declName
liftTermElabM <| mkEnumOfNatThm declName
let ofNatIdent := mkIdent (Name.mkStr declName "ofNat")
let auxThmIdent := mkIdent (Name.mkStr declName "ofNat_toCtorIdx")
let cmd ← `(
Expand Down
8 changes: 4 additions & 4 deletions src/Lean/Elab/Deriving/FromToJson.lean
Expand Up @@ -22,7 +22,7 @@ def mkJsonField (n : Name) : Bool × Term :=
def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if declNames.size == 1 then
if isStructure (← getEnv) declNames[0]! then
let cmds ← liftTermElabM none <| do
let cmds ← liftTermElabM do
let ctx ← mkContext "toJson" declNames[0]!
let header ← mkHeader ``ToJson 1 ctx.typeInfos[0]!
let fields := getStructureFieldsFlattened (← getEnv) declNames[0]! (includeSubobjectFields := false)
Expand All @@ -37,7 +37,7 @@ def mkToJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
return true
else
let indVal ← getConstInfoInduct declNames[0]!
let cmds ← liftTermElabM none <| do
let cmds ← liftTermElabM do
let ctx ← mkContext "toJson" declNames[0]!
let toJsonFuncId := mkIdent ctx.auxFunNames[0]!
-- Return syntax to JSONify `id`, either via `ToJson` or recursively
Expand Down Expand Up @@ -104,7 +104,7 @@ where
def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if declNames.size == 1 then
if isStructure (← getEnv) declNames[0]! then
let cmds ← liftTermElabM none <| do
let cmds ← liftTermElabM do
let ctx ← mkContext "fromJson" declNames[0]!
let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]!
let fields := getStructureFieldsFlattened (← getEnv) declNames[0]! (includeSubobjectFields := false)
Expand All @@ -119,7 +119,7 @@ def mkFromJsonInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
return true
else
let indVal ← getConstInfoInduct declNames[0]!
let cmds ← liftTermElabM none <| do
let cmds ← liftTermElabM do
let ctx ← mkContext "fromJson" declNames[0]!
let header ← mkHeader ``FromJson 0 ctx.typeInfos[0]!
let fromJsonFuncId := mkIdent ctx.auxFunNames[0]!
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/Deriving/Hashable.lean
Expand Up @@ -80,7 +80,7 @@ private def mkHashableInstanceCmds (declNames : Array Name) : TermElabM (Array S

def mkHashableHandler (declNames : Array Name) : CommandElabM Bool := do
if (← declNames.allM isInductive) && declNames.size > 0 then
let cmds ← liftTermElabM none <| mkHashableInstanceCmds declNames
let cmds ← liftTermElabM <| mkHashableInstanceCmds declNames
cmds.forM elabCommand
return true
else
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/Deriving/Inhabited.lean
Expand Up @@ -16,7 +16,7 @@ private def implicitBinderF := Parser.Term.implicitBinder
private def instBinderF := Parser.Term.instBinder

private def mkInhabitedInstanceUsing (inductiveTypeName : Name) (ctorName : Name) (addHypotheses : Bool) : CommandElabM Bool := do
match (← liftTermElabM none mkInstanceCmd?) with
match (← liftTermElabM mkInstanceCmd?) with
| some cmd =>
elabCommand cmd
return true
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/Deriving/Ord.lean
Expand Up @@ -96,7 +96,7 @@ open Command

def mkOrdInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if (← declNames.allM isInductive) && declNames.size > 0 then
let cmds ← liftTermElabM none <| mkOrdInstanceCmds declNames
let cmds ← liftTermElabM <| mkOrdInstanceCmds declNames
cmds.forM elabCommand
return true
else
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/Deriving/Repr.lean
Expand Up @@ -116,7 +116,7 @@ open Command

def mkReprInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if (← declNames.allM isInductive) && declNames.size > 0 then
let cmds ← liftTermElabM none <| mkReprInstanceCmds declNames
let cmds ← liftTermElabM <| mkReprInstanceCmds declNames
cmds.forM elabCommand
return true
else
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/Deriving/SizeOf.lean
Expand Up @@ -17,7 +17,7 @@ open Command

def mkSizeOfHandler (declNames : Array Name) : CommandElabM Bool := do
if (← declNames.allM isInductive) && declNames.size > 0 then
liftTermElabM none <| Meta.mkSizeOfInstances declNames[0]!
liftTermElabM <| Meta.mkSizeOfInstances declNames[0]!
return true
else
return false
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/GenInjective.lean
Expand Up @@ -10,7 +10,7 @@ namespace Lean.Elab.Command

@[builtinCommandElab genInjectiveTheorems] def elabGenInjectiveTheorems : CommandElab := fun stx => do
let declName ← resolveGlobalConstNoOverload stx[1]
liftTermElabM none do
liftTermElabM do
Meta.mkInjectiveTheorems declName

end Lean.Elab.Command
6 changes: 3 additions & 3 deletions src/Lean/Elab/Inductive.lean
Expand Up @@ -831,21 +831,21 @@ private def applyComputedFields (indViews : Array InductiveView) : CommandElabM
|>.setBool `elaboratingComputedFields true}) <|
elabCommand <| ← `(mutual $computedFieldDefs* end)

liftTermElabM indViews[0]!.declName do
liftTermElabM do Term.withDeclName indViews[0]!.declName do
ComputedFields.setComputedFields computedFields

def elabInductiveViews (views : Array InductiveView) : CommandElabM Unit := do
let view0 := views[0]!
let ref := view0.ref
runTermElabM view0.declName fun vars => withRef ref do
runTermElabM fun vars => Term.withDeclName view0.declName do withRef ref do
mkInductiveDecl vars views
mkSizeOfInstances view0.declName
Lean.Meta.IndPredBelow.mkBelow view0.declName
for view in views do
mkInjectiveTheorems view.declName
applyComputedFields views -- NOTE: any generated code before this line is invalid
applyDerivingHandlers views
runTermElabM view0.declName fun _ => withRef ref do
runTermElabM fun _ => Term.withDeclName view0.declName do withRef ref do
for view in views do
Term.applyAttributesAt view.declName view.modifiers.attrs .afterCompilation

Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Elab/MutualDef.lean
Expand Up @@ -873,7 +873,7 @@ def elabMutualDef (ds : Array Syntax) (hints : TerminationHints) : CommandElabM
if ds.size > 1 && modifiers.isNonrec then
throwErrorAt d "invalid use of 'nonrec' modifier in 'mutual' block"
mkDefView modifiers d[1]
runTermElabM none fun vars => Term.elabMutualDef vars views hints
runTermElabM fun vars => Term.elabMutualDef vars views hints

end Command
end Lean.Elab
4 changes: 2 additions & 2 deletions src/Lean/Elab/Structure.lean
Expand Up @@ -876,7 +876,7 @@ def elabStructure (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit :=
let derivingClassViews ← getOptDerivingClasses stx[6]
let type ← if optType.isNone then `(Sort _) else pure optType[0][1]
let declName ←
runTermElabM none fun scopeVars => do
runTermElabM fun scopeVars => do
let scopeLevelNames ← Term.getLevelNames
let ⟨name, declName, allUserLevelNames⟩ ← Elab.expandDeclId (← getCurrNamespace) scopeLevelNames declId modifiers
Term.withAutoBoundImplicitForbiddenPred (fun n => name == n) do
Expand Down Expand Up @@ -908,7 +908,7 @@ def elabStructure (modifiers : Modifiers) (stx : Syntax) : CommandElabM Unit :=
mkInjectiveTheorems declName
return declName
derivingClassViews.forM fun view => view.applyHandlers #[declName]
runTermElabM declName fun _ =>
runTermElabM fun _ => Term.withDeclName declName do
Term.applyAttributesAt declName modifiers.attrs .afterCompilation

builtin_initialize registerTraceClass `Elab.structure
Expand Down

0 comments on commit 413db56

Please sign in to comment.