Skip to content

Commit

Permalink
feat: elaborate strict implicit binders
Browse files Browse the repository at this point in the history
  • Loading branch information
leodemoura committed Aug 4, 2021
1 parent 9988264 commit 4cd7e35
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 36 deletions.
51 changes: 32 additions & 19 deletions src/Lean/Elab/App.lean
Expand Up @@ -388,6 +388,11 @@ private def processExplictArg (k : M Expr) : M Expr := do
else
finalize

/- Return true if there are regular or named arguments to be processed. -/
private def hasArgsToProcess : M Bool := do
let s ← get
return !s.args.isEmpty || !s.namedArgs.isEmpty

/-
Process a `fType` of the form `{x : A} → B x`.
This method assume `fType` is a function type -/
Expand All @@ -397,6 +402,17 @@ private def processImplicitArg (k : M Expr) : M Expr := do
else
addImplicitArg k

/-
Process a `fType` of the form `{{x : A}} → B x`.
This method assume `fType` is a function type -/
private def processStrictImplicitArg (k : M Expr) : M Expr := do
if (← get).explicit then
processExplictArg k
else if (← hasArgsToProcess) then
addImplicitArg k
else
finalize

/- Return true if the next argument at `args` is of the form `_` -/
private def isNextArgHole : M Bool := do
match (← get).args with
Expand All @@ -423,11 +439,6 @@ private def processInstImplicitArg (k : M Expr) : M Expr := do
addNewArg arg
k

/- Return true if there are regular or named arguments to be processed. -/
private def hasArgsToProcess : M Bool := do
let s ← get
pure $ !s.args.isEmpty || !s.namedArgs.isEmpty

/- Elaborate function application arguments. -/
partial def main : M Expr := do
let s ← get
Expand All @@ -444,9 +455,10 @@ partial def main : M Expr := do
main
| none =>
match binfo with
| BinderInfo.implicit => processImplicitArg main
| BinderInfo.instImplicit => processInstImplicitArg main
| _ => processExplictArg main
| BinderInfo.implicit => processImplicitArg main
| BinderInfo.instImplicit => processInstImplicitArg main
| BinderInfo.strictImplicit => processStrictImplicitArg main
| _ => processExplictArg main
else if (← hasArgsToProcess) then
synthesizePendingAndNormalizeFunType
main
Expand Down Expand Up @@ -572,25 +584,25 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L

/- whnfCore + implicit consumption.
Example: given `e` with `eType := {α : Type} → (fun β => List β) α `, it produces `(e ?m, List ?m)` where `?m` is fresh metavariable. -/
private partial def consumeImplicits (stx : Syntax) (e eType : Expr) : TermElabM (Expr × Expr) := do
private partial def consumeImplicits (stx : Syntax) (e eType : Expr) (hasArgs : Bool) : TermElabM (Expr × Expr) := do
let eType ← whnfCore eType
match eType with
| Expr.forallE n d b c =>
if c.binderInfo.isImplicit then
if c.binderInfo.isImplicit || (hasArgs && c.binderInfo.isStrictImplicit) then
let mvar ← mkFreshExprMVar d
registerMVarErrorHoleInfo mvar.mvarId! stx
consumeImplicits stx (mkApp e mvar) (b.instantiate1 mvar)
consumeImplicits stx (mkApp e mvar) (b.instantiate1 mvar) hasArgs
else if c.binderInfo.isInstImplicit then
let mvar ← mkInstMVar d
consumeImplicits stx (mkApp e mvar) (b.instantiate1 mvar)
consumeImplicits stx (mkApp e mvar) (b.instantiate1 mvar) hasArgs
else match d.getOptParamDefault? with
| some defVal => consumeImplicits stx (mkApp e defVal) (b.instantiate1 defVal)
| some defVal => consumeImplicits stx (mkApp e defVal) (b.instantiate1 defVal) hasArgs
-- TODO: we do not handle autoParams here.
| _ => pure (e, eType)
| _ => pure (e, eType)

private partial def resolveLValLoop (lval : LVal) (e eType : Expr) (previousExceptions : Array Exception) : TermElabM (Expr × LValResolution) := do
let (e, eType) ← consumeImplicits lval.getRef e eType
private partial def resolveLValLoop (lval : LVal) (e eType : Expr) (previousExceptions : Array Exception) (hasArgs : Bool) : TermElabM (Expr × LValResolution) := do
let (e, eType) ← consumeImplicits lval.getRef e eType hasArgs
tryPostponeIfMVar eType
try
let lvalRes ← resolveLValAux e eType lval
Expand All @@ -599,15 +611,15 @@ private partial def resolveLValLoop (lval : LVal) (e eType : Expr) (previousExce
| ex@(Exception.error _ _) =>
let eType? ← unfoldDefinition? eType
match eType? with
| some eType => resolveLValLoop lval e eType (previousExceptions.push ex)
| some eType => resolveLValLoop lval e eType (previousExceptions.push ex) hasArgs
| none =>
previousExceptions.forM fun ex => logException ex
throw ex
| ex@(Exception.internal _ _) => throw ex

private def resolveLVal (e : Expr) (lval : LVal) : TermElabM (Expr × LValResolution) := do
private def resolveLVal (e : Expr) (lval : LVal) (hasArgs : Bool) : TermElabM (Expr × LValResolution) := do
let eType ← inferType e
resolveLValLoop lval e eType #[]
resolveLValLoop lval e eType #[] hasArgs

private partial def mkBaseProjections (baseStructName : Name) (structName : Name) (e : Expr) : TermElabM Expr := do
let env ← getEnv
Expand Down Expand Up @@ -675,7 +687,8 @@ private def elabAppLValsAux (namedArgs : Array NamedArg) (args : Array Arg) (exp
| f, lval::lvals => do
if let LVal.fieldName (ref := fieldStx) (targetStx := targetStx) .. := lval then
addDotCompletionInfo targetStx f expectedType? fieldStx
let (f, lvalRes) ← resolveLVal f lval
let hasArgs := !namedArgs.isEmpty || !args.isEmpty
let (f, lvalRes) ← resolveLVal f lval hasArgs
match lvalRes with
| LValResolution.projIdx structName idx =>
let f := mkProj structName idx f
Expand Down
26 changes: 16 additions & 10 deletions src/Lean/Elab/Binders.lean
Expand Up @@ -104,24 +104,29 @@ private def getBinderIds (ids : Syntax) : TermElabM (Array Syntax) :=

private def matchBinder (stx : Syntax) : TermElabM (Array BinderView) := do
let k := stx.getKind
if k == `Lean.Parser.Term.simpleBinder then
if k == ``Lean.Parser.Term.simpleBinder then
-- binderIdent+ >> optType
let ids ← getBinderIds stx[0]
let type := expandOptType (mkNullNode ids) stx[1]
ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.default }
else if k == `Lean.Parser.Term.explicitBinder then
else if k == ``Lean.Parser.Term.explicitBinder then
-- `(` binderIdent+ binderType (binderDefault <|> binderTactic)? `)`
let ids ← getBinderIds stx[1]
let type := expandBinderType (mkNullNode ids) stx[2]
let optModifier := stx[3]
let type ← expandBinderModifier type optModifier
ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.default }
else if k == `Lean.Parser.Term.implicitBinder then
else if k == ``Lean.Parser.Term.implicitBinder then
-- `{` binderIdent+ binderType `}`
let ids ← getBinderIds stx[1]
let type := expandBinderType (mkNullNode ids) stx[2]
ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.implicit }
else if k == `Lean.Parser.Term.instBinder then
else if k == ``Lean.Parser.Term.strictImplicitBinder then
-- `⦃` binderIdent+ binderType `⦄`
let ids ← getBinderIds stx[1]
let type := expandBinderType (mkNullNode ids) stx[2]
ids.mapM fun id => do pure { id := (← expandBinderIdent id), type := type, bi := BinderInfo.strictImplicit }
else if k == ``Lean.Parser.Term.instBinder then
-- `[` optIdent type `]`
let id ← expandOptIdent stx[1]
let type := stx[2]
Expand Down Expand Up @@ -256,15 +261,16 @@ partial def expandFunBinders (binders : Array Syntax) (body : Syntax) : MacroM (
let newBody ← `(match $major:ident with | $pattern => $newBody)
pure (binders, newBody, true)
match binder with
| Syntax.node `Lean.Parser.Term.implicitBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node `Lean.Parser.Term.instBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node `Lean.Parser.Term.explicitBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node `Lean.Parser.Term.simpleBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node `Lean.Parser.Term.hole _ =>
| Syntax.node ``Lean.Parser.Term.implicitBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node ``Lean.Parser.Term.strictImplicitBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node ``Lean.Parser.Term.instBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node ``Lean.Parser.Term.explicitBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node ``Lean.Parser.Term.simpleBinder _ => loop body (i+1) (newBinders.push binder)
| Syntax.node ``Lean.Parser.Term.hole _ =>
let ident ← mkFreshIdent binder
let type := binder
loop body (i+1) (newBinders.push <| mkExplicitBinder ident type)
| Syntax.node `Lean.Parser.Term.paren args =>
| Syntax.node ``Lean.Parser.Term.paren args =>
-- `(` (termParser >> parenSpecial)? `)`
-- parenSpecial := (tupleTail <|> typeAscription)?
let binderBody := binder[1]
Expand Down
17 changes: 10 additions & 7 deletions src/Lean/PrettyPrinter/Delaborator/Builtins.lean
Expand Up @@ -506,10 +506,12 @@ def delabLam : Delab :=
else
pure $ curNames.get! 0;
`(funBinder| ($stxCurNames : $stxT))
| BinderInfo.default, false => pure curNames.back -- here `curNames.size == 1`
| BinderInfo.implicit, true => `(funBinder| {$curNames* : $stxT})
| BinderInfo.implicit, false => `(funBinder| {$curNames*})
| BinderInfo.instImplicit, _ =>
| BinderInfo.default, false => pure curNames.back -- here `curNames.size == 1`
| BinderInfo.implicit, true => `(funBinder| {$curNames* : $stxT})
| BinderInfo.implicit, false => `(funBinder| {$curNames*})
| BinderInfo.strictImplicit, true => `(funBinder| ⦃$curNames* : $stxT⦄)
| BinderInfo.strictImplicit, false => `(funBinder| ⦃$curNames*⦄)
| BinderInfo.instImplicit, _ =>
if usedDownstream then `(funBinder| [$curNames.back : $stxT]) -- here `curNames.size == 1`
else `(funBinder| [$stxT])
| _ , _ => unreachable!;
Expand All @@ -524,10 +526,11 @@ def delabForall : Delab :=
let prop ← try isProp e catch _ => false
let stxT ← withBindingDomain delab
let group ← match e.binderInfo with
| BinderInfo.implicit => `(bracketedBinderF|{$curNames* : $stxT})
| BinderInfo.implicit => `(bracketedBinderF|{$curNames* : $stxT})
| BinderInfo.strictImplicit => `(bracketedBinderF|⦃$curNames* : $stxT⦄)
-- here `curNames.size == 1`
| BinderInfo.instImplicit => `(bracketedBinderF|[$curNames.back : $stxT])
| _ =>
| BinderInfo.instImplicit => `(bracketedBinderF|[$curNames.back : $stxT])
| _ =>
-- heuristic: use non-dependent arrows only if possible for whole group to avoid
-- noisy mix like `(α : Type) → Type → (γ : Type) → ...`.
let dependent := curNames.any $ fun n => hasIdent n.getId stxBody
Expand Down
7 changes: 7 additions & 0 deletions tests/lean/strictImplicit.lean
@@ -0,0 +1,7 @@
def g {α : Type} (a : α) := a
def f {{α : Type}} (a : α) := a

#check g
#check f
#check g 1
#check f 1
4 changes: 4 additions & 0 deletions tests/lean/strictImplicit.lean.expected.out
@@ -0,0 +1,4 @@
g : ?m → ?m
f : ⦃α : Type⦄ → α → α
g 1 : Nat
f 1 : Nat

0 comments on commit 4cd7e35

Please sign in to comment.