From 0dda3a8c0258f0d40b1fa285258e67d5b27c1198 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 1 Dec 2022 06:11:48 -0800 Subject: [PATCH] fix: include instance implicits that depend on `outParams` at `outParamsPos` This fixes the fix for #1852 --- src/Lean/Class.lean | 13 +++++++++---- src/Lean/Meta/SynthInstance.lean | 29 +++++++++++++++++------------ tests/lean/run/1852.lean | 21 +++++++++++++++++++++ 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/src/Lean/Class.lean b/src/Lean/Class.lean index 38edc2a5e0a..009bdc3a273 100644 --- a/src/Lean/Class.lean +++ b/src/Lean/Class.lean @@ -89,14 +89,19 @@ def hasOutParams (env : Environment) (declName : Name) : Bool := private partial def checkOutParam (i : Nat) (outParamFVarIds : Array FVarId) (outParams : Array Nat) (type : Expr) : Except String (Array Nat) := match type with | .forallE _ d b bi => - if d.isOutParam then + let addOutParam (_ : Unit) := let fvarId := { name := Name.mkNum `_fvar outParamFVarIds.size } let fvar := mkFVar fvarId let b := b.instantiate1 fvar checkOutParam (i+1) (outParamFVarIds.push fvarId) (outParams.push i) b - /- See issue #1852 for a motivation for `!bi.isInstImplicit` -/ - else if !bi.isInstImplicit && d.hasAnyFVar fun fvarId => outParamFVarIds.contains fvarId then - Except.error s!"invalid class, parameter #{i+1} depends on `outParam`, but it is not an `outParam`" + if d.isOutParam then + addOutParam () + else if d.hasAnyFVar fun fvarId => outParamFVarIds.contains fvarId then + if bi.isInstImplicit then + /- See issue #1852 for a motivation for `bi.isInstImplicit` -/ + addOutParam () + else + Except.error s!"invalid class, parameter #{i+1} depends on `outParam`, but it is not an `outParam`" else checkOutParam (i+1) outParamFVarIds outParams b | _ => return outParams diff --git a/src/Lean/Meta/SynthInstance.lean b/src/Lean/Meta/SynthInstance.lean index 7c04e2f4db6..261c39c6a21 100644 --- a/src/Lean/Meta/SynthInstance.lean +++ b/src/Lean/Meta/SynthInstance.lean @@ -624,15 +624,19 @@ private def preprocessLevels (us : List Level) : MetaM (List Level × Bool) := d r := r.push u return (r.toList, modified) -private partial def preprocessArgs (type : Expr) (i : Nat) (args : Array Expr) : MetaM (Array Expr) := do +private partial def preprocessArgs (type : Expr) (i : Nat) (args : Array Expr) (outParamsPos : Array Nat) : MetaM (Array Expr) := do if h : i < args.size then let type ← whnf type match type with - | Expr.forallE _ d b _ => do + | .forallE _ d b _ => do let arg := args.get ⟨i, h⟩ - let arg ← if d.isOutParam then mkFreshExprMVar d else pure arg + /- + We should not simply check `d.isOutParam`. See `checkOutParam` and issue #1852. + If an instance implicit argument depends on an `outParam`, it is treated as an `outParam` too. + -/ + let arg ← if outParamsPos.contains i then mkFreshExprMVar d else pure arg let args := args.set ⟨i, h⟩ arg - preprocessArgs (b.instantiate1 arg) (i+1) args + preprocessArgs (b.instantiate1 arg) (i+1) args outParamsPos | _ => throwError "type class resolution failed, insufficient number of arguments" -- TODO improve error message else @@ -641,15 +645,16 @@ private partial def preprocessArgs (type : Expr) (i : Nat) (args : Array Expr) : private def preprocessOutParam (type : Expr) : MetaM Expr := forallTelescope type fun xs typeBody => do match typeBody.getAppFn with - | c@(Expr.const constName _) => + | c@(Expr.const declName _) => let env ← getEnv - if !hasOutParams env constName then - return type - else - let args := typeBody.getAppArgs - let cType ← inferType c - let args ← preprocessArgs cType 0 args - mkForallFVars xs (mkAppN c args) + if let some outParamsPos := getOutParamPositions? env declName then + unless outParamsPos.isEmpty do + let args := typeBody.getAppArgs + let cType ← inferType c + trace[Meta.debug] "{declName} : {outParamsPos}" + let args ← preprocessArgs cType 0 args outParamsPos + return (← mkForallFVars xs (mkAppN c args)) + return type | _ => return type diff --git a/tests/lean/run/1852.lean b/tests/lean/run/1852.lean index 868f079fdb4..970ca3c2555 100644 --- a/tests/lean/run/1852.lean +++ b/tests/lean/run/1852.lean @@ -3,3 +3,24 @@ class foo (F : Type) where class foobar (F : outParam Type) [foo F] where bar : F + +class C (α : Type) where + val : α + +class D (α : Type) (β : outParam Type) [C β] where + val1 : α + val2 : β := C.val + +instance : C String where + val := "hello" + +instance : C Nat where + val := 42 + +instance : D Nat String where + val1 := 37 + +def f (α : Type) {β : Type} {_ : C β} [D α β] : α × β := + (D.val1, D.val2 α) + +example : f Nat = (37, "hello") := rfl