Skip to content

Commit

Permalink
feat: reorder tc subgoals according to out-params
Browse files Browse the repository at this point in the history
  • Loading branch information
gebner committed Apr 10, 2023
1 parent 25fe723 commit 4544443
Show file tree
Hide file tree
Showing 19 changed files with 233 additions and 80 deletions.
1 change: 0 additions & 1 deletion src/Lean/Elab/Structure.lean
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,6 @@ private partial def mkCoercionToCopiedParent (levelParams : List Name) (params :
let binfo := if view.isClass && isClass env parentStructName then BinderInfo.instImplicit else BinderInfo.default
withLocalDeclD `self structType fun source => do
let mut declType ← instantiateMVars (← mkForallFVars params (← mkForallFVars #[source] parentType))
declType := mkOutParamArgsImplicit declType
if view.isClass && isClass env parentStructName then
declType := setSourceInstImplicit declType
declType := declType.inferImplicit params.size true
Expand Down
103 changes: 102 additions & 1 deletion src/Lean/Meta/Instances.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@ Authors: Leonardo de Moura
import Lean.ScopedEnvExtension
import Lean.Meta.GlobalInstances
import Lean.Meta.DiscrTree
import Lean.Meta.CollectMVars

namespace Lean.Meta

register_builtin_option synthInstance.checkSynthOrder : Bool := {
defValue := true
descr := "check instances do not introduce metavariable in non-out-params"
}

/-
Note: we want to use iota reduction when indexing instaces. Otherwise,
we cannot elaborate examples such as
Expand Down Expand Up @@ -38,6 +44,8 @@ structure InstanceEntry where
val : Expr
priority : Nat
globalName? : Option Name := none
/-- The order in which the instance's arguments are to be synthesized. -/
synthOrder : Array Nat
/-
We store the attribute kind to be able to implement the API `getInstanceAttrKind`.
TODO: add better support for retrieving the `attrKind` of any attribute.
Expand Down Expand Up @@ -88,11 +96,103 @@ private def mkInstanceKey (e : Expr) : MetaM (Array InstanceKey) := do
let (_, _, type) ← forallMetaTelescopeReducing type
DiscrTree.mkPath type

/--
Compute the order the arguments of `inst` should by synthesized.
The synthesization order makes sure that all mvars in non-out-params of the
subgoals are assigned before we try to synthesize it. Otherwise it goes left
to right.
For example:
- `[Add α] [Zero α] : Foo α` returns `[0, 1]`
- `[Mul A] [Mul B] [MulHomClass F A B] : FunLike F A B` returns `[2, 0, 1]`
(because A B are out-params and are only filled in once we synthesize 2)
(The type of `inst` must not contain mvars.)
-/
partial def computeSynthOrder (inst : Expr) : MetaM (Array Nat) :=
withReducible do
let instTy ← inferType inst

-- Gets positions of all out- and semi-out-params of `classTy`
-- (where `classTy` is e.g. something like `Inhabited Nat`)
let rec getSemiOutParamPositionsOf (classTy : Expr) : MetaM (Array Nat) := do
if let .const className .. := classTy.getAppFn then
forallTelescopeReducing (← inferType classTy.getAppFn) fun args _ => do
let mut pos := (getOutParamPositions? (← getEnv) className).getD #[]
for arg in args, i in [:args.size] do
if (← inferType arg).isAppOf ``semiOutParam then
pos := pos.push i
return pos
else
return #[]

-- Create both metavariables and free variables for the instance args
-- We will successively pick subgoals where all non-out-params have been
-- assigned already. After picking such a "ready" subgoal, we assign the
-- mvars in its out-params by the corresponding fvars.
let (argMVars, argBIs, ty) ← forallMetaTelescopeReducing instTy
let ty ← whnf ty
forallTelescopeReducing instTy fun argVars _ => do

-- Assigns all mvars from argMVars in e by the corresponding fvar.
let rec assignMVarsIn (e : Expr) : MetaM Unit := do
for mvarId in ← getMVars e do
if let some i := argMVars.findIdx? (·.mvarId! == mvarId) then
mvarId.assign argVars[i]!
assignMVarsIn (← inferType (.mvar mvarId))

-- We start by assigning all metavariables in non-out-params of the return value.
-- These are assumed to not be mvars during TC search (or at least not assignable)
let tyOutParams ← getSemiOutParamPositionsOf ty
let tyArgs := ty.getAppArgs
for tyArg in tyArgs, i in [:tyArgs.size] do
unless tyOutParams.contains i do assignMVarsIn tyArg

-- Now we successively try to find the next ready subgoal, where all
-- non-out-params are mvar-free.
let mut synthed := #[]
let mut toSynth := List.range argMVars.size |>.filter (argBIs[·]! == .instImplicit) |>.toArray
while !toSynth.isEmpty do
let next? ← toSynth.findM? fun i => do
forallTelescopeReducing (← instantiateMVars (← inferType argMVars[i]!)) fun _ argTy => do
let argTy ← whnf argTy
let argOutParams ← getSemiOutParamPositionsOf argTy
let argTyArgs := argTy.getAppArgs
for i in [:argTyArgs.size], argTyArg in argTyArgs do
if !argOutParams.contains i && argTyArg.hasExprMVar then
return false
return true
let next ←
match next? with
| some next => pure next
| none =>
if synthInstance.checkSynthOrder.get (← getOptions) then
let typeLines := ("" : MessageData).joinSep <| Array.toList <| ← toSynth.mapM fun i => do
let ty ← instantiateMVars (← inferType argMVars[i]!)
return indentExpr (ty.setPPExplicit true)
logError m!"cannot find synthesization order for instance {inst} with type{indentExpr instTy}\nall remaining arguments have metavariables:{typeLines}"
pure toSynth[0]!
synthed := synthed.push next
toSynth := toSynth.filter (· != next)
assignMVarsIn (← inferType argMVars[next]!)
assignMVarsIn argMVars[next]!

if synthInstance.checkSynthOrder.get (← getOptions) then
let ty ← instantiateMVars ty
if ty.hasExprMVar then
logError m!"instance does not provide concrete values for (semi-)out-params{indentExpr (ty.setPPExplicit true)}"

trace[Meta.synthOrder] "synthesizing the arguments of {inst} in the order {synthed}:{("" : MessageData).joinSep (← synthed.mapM fun i => return indentExpr (← inferType argVars[i]!)).toList}"

return synthed

def addInstance (declName : Name) (attrKind : AttributeKind) (prio : Nat) : MetaM Unit := do
let c ← mkConstWithLevelParams declName
let keys ← mkInstanceKey c
addGlobalInstance declName attrKind
instanceExtension.add { keys := keys, val := c, priority := prio, globalName? := declName, attrKind } attrKind
let synthOrder ← computeSynthOrder c
instanceExtension.add { keys, val := c, priority := prio, globalName? := declName, attrKind, synthOrder } attrKind

builtin_initialize
registerBuiltinAttribute {
Expand Down Expand Up @@ -171,6 +271,7 @@ builtin_initialize
unless kind == AttributeKind.global do throwError "invalid attribute 'default_instance', must be global"
discard <| addDefaultInstance declName prio |>.run {} {}
}
registerTraceClass `Meta.synthOrder

def getDefaultInstancesPriorities [Monad m] [MonadEnv m] : m PrioritySet :=
return defaultInstanceExtension.getState (← getEnv) |>.priorities
Expand Down
88 changes: 47 additions & 41 deletions src/Lean/Meta/SynthInstance.lean
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,16 @@ namespace SynthInstance
def getMaxHeartbeats (opts : Options) : Nat :=
synthInstance.maxHeartbeats.get opts * 1000

builtin_initialize inferTCGoalsRLAttr : TagAttribute ←
registerTagAttribute `infer_tc_goals_rl "instruct type class resolution procedure to solve goals from right to left for this instance"

def hasInferTCGoalsRLAttribute (env : Environment) (constName : Name) : Bool :=
inferTCGoalsRLAttr.hasTag env constName
structure Instance where
val : Expr
synthOrder : Array Nat
deriving Inhabited

structure GeneratorNode where
mvar : Expr
key : Expr
mctx : MetavarContext
instances : Array Expr
instances : Array Instance
currInstanceIdx : Nat
deriving Inhabited

Expand Down Expand Up @@ -191,7 +190,7 @@ instance : Inhabited (SynthM α) where
default := fun _ _ => default

/-- Return globals and locals instances that may unify with `type` -/
def getInstances (type : Expr) : MetaM (Array Expr) := do
def getInstances (type : Expr) : MetaM (Array Instance) := do
-- We must retrieve `localInstances` before we use `forallTelescopeReducing` because it will update the set of local instances
let localInstances ← getLocalInstances
forallTelescopeReducing type fun _ type => do
Expand All @@ -205,16 +204,27 @@ def getInstances (type : Expr) : MetaM (Array Expr) := do
-- Most instances have default priority.
let result := result.insertionSort fun e₁ e₂ => e₁.priority < e₂.priority
let erasedInstances ← getErasedInstances
let result ← result.filterMapM fun e => match e.val with
let mut result ← result.filterMapM fun e => match e.val with
| Expr.const constName us =>
if erasedInstances.contains constName then
return none
else
return some <| e.val.updateConst! (← us.mapM (fun _ => mkFreshLevelMVar))
return some {
val := e.val.updateConst! (← us.mapM (fun _ => mkFreshLevelMVar))
synthOrder := e.synthOrder
}
| _ => panic! "global instance is not a constant"
let result := localInstances.foldl (init := result) fun (result : Array Expr) linst =>
if linst.className == className then result.push linst.fvar else result
trace[Meta.synthInstance.instances] result
for linst in localInstances do
if linst.className == className then
let synthOrder ← forallTelescopeReducing (← inferType linst.fvar) fun xs _ => do
if xs.isEmpty then return #[]
let mut order := #[]
for i in [:xs.size], x in xs do
if (← getFVarLocalDecl x).binderInfo == .instImplicit then
order := order.push i
return order
result := result.push { val := linst.fvar, synthOrder }
trace[Meta.synthInstance.instances] result.map (·.val)
return result

def mkGeneratorNode? (key mvar : Expr) : MetaM (Option GeneratorNode) := do
Expand Down Expand Up @@ -275,25 +285,6 @@ structure SubgoalsResult where
instVal : Expr
instTypeBody : Expr

private partial def getSubgoalsAux (lctx : LocalContext) (localInsts : LocalInstances) (xs : Array Expr)
: Array Expr → Nat → List Expr → Expr → Expr → MetaM SubgoalsResult
| args, j, subgoals, instVal, Expr.forallE _ d b bi => do
let d := d.instantiateRevRange j args.size args
let mvarType ← mkForallFVars xs d
let mvar ← mkFreshExprMVarAt lctx localInsts mvarType
let arg := mkAppN mvar xs
let instVal := mkApp instVal arg
let subgoals := if bi.isInstImplicit then mvar::subgoals else subgoals
let args := args.push (mkAppN mvar xs)
getSubgoalsAux lctx localInsts xs args j subgoals instVal b
| args, j, subgoals, instVal, type => do
let type := type.instantiateRevRange j args.size args
let type ← whnf type
if type.isForall then
getSubgoalsAux lctx localInsts xs args args.size subgoals instVal type
else
return ⟨subgoals, instVal, type⟩

/--
`getSubgoals lctx localInsts xs inst` creates the subgoals for the instance `inst`.
The subgoals are in the context of the free variables `xs`, and
Expand All @@ -309,21 +300,36 @@ private partial def getSubgoalsAux (lctx : LocalContext) (localInsts : LocalInst
metavariables that are instance implicit arguments, and the expressions:
- `inst (?m_1 xs) ... (?m_n xs)` (aka `instVal`)
- `B (?m_1 xs) ... (?m_n xs)` -/
def getSubgoals (lctx : LocalContext) (localInsts : LocalInstances) (xs : Array Expr) (inst : Expr) : MetaM SubgoalsResult := do
let instType ← inferType inst
let result ← getSubgoalsAux lctx localInsts xs #[] 0 [] inst instType
if let .const constName _ := inst.getAppFn then
let env ← getEnv
if hasInferTCGoalsRLAttribute env constName then
return result
return { result with subgoals := result.subgoals.reverse }
def getSubgoals (lctx : LocalContext) (localInsts : LocalInstances) (xs : Array Expr) (inst : Instance) : MetaM SubgoalsResult := do
let mut instVal := inst.val
let mut instType ← inferType instVal
let mut mvars := #[]
let mut subst := #[]
repeat do
if let .forallE _ d b _ := instType then
let d := d.instantiateRev subst
let mvar ← mkFreshExprMVarAt lctx localInsts (← mkForallFVars xs d)
subst := subst.push (mkAppN mvar xs)
instVal := mkApp instVal (mkAppN mvar xs)
instType := b
mvars := mvars.push mvar
else
instType ← whnf (instType.instantiateRev subst)
instVal := instVal.instantiateRev subst
subst := #[]
unless instType.isForall do break
return {
instVal := instVal.instantiateRev subst
instTypeBody := instType.instantiateRev subst
subgoals := inst.synthOrder.map (mvars[·]!) |>.toList
}

/--
Try to synthesize metavariable `mvar` using the instance `inst`.
Remark: `mctx` is set using `withMCtx`.
If it succeeds, the result is a new updated metavariable context and a new list of subgoals.
A subgoal is created for each instance implicit parameter of `inst`. -/
def tryResolve (mvar : Expr) (inst : Expr) : MetaM (Option (MetavarContext × List Expr)) := do
def tryResolve (mvar : Expr) (inst : Instance) : MetaM (Option (MetavarContext × List Expr)) := do
let mvarType ← inferType mvar
let lctx ← getLCtx
let localInsts ← getLocalInstances
Expand Down Expand Up @@ -518,7 +524,7 @@ def generate : SynthM Unit := do
let mvar := gNode.mvar
discard do withMCtx mctx do
withTraceNode `Meta.synthInstance
(return m!"{exceptOptionEmoji ·} apply {inst} to {← instantiateMVars (← inferType mvar)}") do
(return m!"{exceptOptionEmoji ·} apply {inst.val} to {← instantiateMVars (← inferType mvar)}") do
modifyTop fun gNode => { gNode with currInstanceIdx := idx }
if let some (mctx, subgoals) ← tryResolve mvar inst then
consume { key, mvar, subgoals, mctx, size := 0 }
Expand Down
2 changes: 1 addition & 1 deletion src/lake
26 changes: 10 additions & 16 deletions src/library/constructions/projection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ static bool is_prop(expr type) {
return is_sort(type) && is_zero(sort_level(type));
}

extern "C" object * lean_mk_outparam_args_implicit(object * n);
expr mk_outparam_args_implicit(expr const & type) { return expr(lean_mk_outparam_args_implicit(type.to_obj_arg())); }

environment mk_projections(environment const & env, name const & n, buffer<name> const & proj_names, bool inst_implicit) {
local_ctx lctx;
name_generator ngen = mk_constructions_name_generator();
Expand All @@ -48,38 +45,35 @@ environment mk_projections(environment const & env, name const & n, buffer<name>
throw exception(sstream() << "projection generation, '" << n << "' does not have a single constructor");
constant_info cnstr_info = env.get(head(ind_val.get_cnstrs()));
expr cnstr_type = cnstr_info.get_type();
expr cnstr_type_norm = mk_outparam_args_implicit(cnstr_type);
// The binder inference is quite messy since it is using `mk_outparam_args_implicit` and `infer_implicit_params`.
// TODO: cleanup
bool is_predicate = is_prop(ind_info.get_type());
names lvl_params = ind_info.get_lparams();
levels lvls = lparams_to_levels(lvl_params);
buffer<expr> params; // datatype parameters
expr cnstr_type_orig = cnstr_type; // we use the original type before `mk_outparam_args_implicit` to get the original binder info
expr cnstr_type_orig = cnstr_type;
for (unsigned i = 0; i < nparams; i++) {
if (!is_pi(cnstr_type_norm))
if (!is_pi(cnstr_type))
throw_ill_formed(n);
lean_assert(is_pi(cnstr_type_orig));
auto bi = binding_info(cnstr_type_norm);
auto bi = binding_info(cnstr_type);
auto bi_orig = binding_info(cnstr_type_orig);
auto type = binding_domain(cnstr_type_norm);
auto type = binding_domain(cnstr_type);
auto type_orig = binding_domain(cnstr_type_orig);
if (!is_inst_implicit(bi_orig) && !is_class_out_param(type_orig)) {
// We reset implicit binders in favor of having them inferred by `infer_implicit_params` later IF
// 1. The original binder before `mk_outparam_args_implicit` is not an instance implicit.
// 2. It is not originally an outparam. Outparams must be implicit.
bi = mk_binder_info();
}
expr param = lctx.mk_local_decl(ngen, binding_name(cnstr_type_norm), type, bi);
cnstr_type_norm = instantiate(binding_body(cnstr_type_norm), param);
expr param = lctx.mk_local_decl(ngen, binding_name(cnstr_type), type, bi);
cnstr_type = instantiate(binding_body(cnstr_type), param);
cnstr_type_orig = binding_body(cnstr_type_orig);
params.push_back(param);
}
expr C_A = mk_app(mk_constant(n, lvls), params);
binder_info c_bi = inst_implicit ? mk_inst_implicit_binder_info() : mk_binder_info();
expr c = lctx.mk_local_decl(ngen, name("self"), C_A, c_bi);
buffer<expr> cnstr_type_args; // arguments that are not parameters
expr it = cnstr_type_norm;
expr it = cnstr_type;
while (is_pi(it)) {
expr local = lctx.mk_local_decl(ngen, binding_name(it), binding_domain(it), binding_info(it));
cnstr_type_args.push_back(local);
Expand All @@ -88,10 +82,10 @@ environment mk_projections(environment const & env, name const & n, buffer<name>
unsigned i = 0;
environment new_env = env;
for (name const & proj_name : proj_names) {
if (!is_pi(cnstr_type_norm))
if (!is_pi(cnstr_type))
throw exception(sstream() << "generating projection '" << proj_name << "', '"
<< n << "' does not have sufficient data");
expr result_type = consume_type_annotations(binding_domain(cnstr_type_norm));
expr result_type = consume_type_annotations(binding_domain(cnstr_type));
if (is_predicate && !type_checker(new_env, lctx).is_prop(result_type)) {
throw exception(sstream() << "failed to generate projection '" << proj_name << "' for '" << n << "', "
<< "type is an inductive predicate, but field is not a proposition");
Expand All @@ -110,7 +104,7 @@ environment mk_projections(environment const & env, name const & n, buffer<name>
new_env = set_reducible(new_env, proj_name, reducible_status::Reducible, true);
new_env = save_projection_info(new_env, proj_name, cnstr_info.get_name(), nparams, i, inst_implicit);
expr proj = mk_app(mk_app(mk_constant(proj_name, lvls), params), c);
cnstr_type_norm = instantiate(binding_body(cnstr_type_norm), proj);
cnstr_type = instantiate(binding_body(cnstr_type), proj);
i++;
}
return new_env;
Expand Down
7 changes: 3 additions & 4 deletions tests/lean/1007.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@ class Trait (X : Type u) where

attribute [reducible] Trait.R

class SemiInner (X : Type u) (R : Type v) where
class SemiInner (X : Type u) (R : outParam (Type v)) where
semiInner : X → X → R

@[reducible] instance (X) (R : Type u) [SemiInner X R] : Trait X := ⟨R⟩

class SemiHilbert (X) (R : Type u) [Vec R] extends Vec X, SemiInner X R
class SemiHilbert (X) (R : outParam (Type u)) [Vec R] [Vec X] extends SemiInner X R

@[infer_tc_goals_rl]
instance (X R) [Trait X] [Vec R] [SemiHilbert X R] (ι : Type v) : SemiHilbert (ι → X) R := sorry
instance (X R) [Trait X] [Vec R] [Vec X] [SemiHilbert X R] (ι : Type v) : SemiHilbert (ι → X) R := sorry
instance : SemiHilbert ℝ ℝ := sorry

--------------
Expand Down
Loading

0 comments on commit 4544443

Please sign in to comment.