Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions src/Lean/Compiler/LCNF/SpecInfo.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ inductive SpecParamInfo where
/--
A parameter that is an type class instance (or an arrow that produces a type class instance),
and is fixed in recursive declarations. By default, Lean always specializes this kind of argument.
If the `weak` parameter is set we only specialize for this parameter iff another parameter causes
specialization as well.
-/
| fixedInst
| fixedInst (weak : Bool)
/--
A parameter that is a function and is fixed in recursive declarations. If the user tags a declaration
with `@[specialize]` without specifying which arguments should be specialized, Lean will specialize
Expand Down Expand Up @@ -49,14 +51,15 @@ namespace SpecParamInfo

@[inline]
def causesSpecialization : SpecParamInfo → Bool
| .fixedInst | .fixedHO | .user => true
| .fixedNeutral | .other => false
| .fixedInst false | .fixedHO | .user => true
| .fixedInst true | .fixedNeutral | .other => false

end SpecParamInfo

instance : ToMessageData SpecParamInfo where
toMessageData
| .fixedInst => "I"
| .fixedInst false => "I"
| .fixedInst true => "W"
| .fixedHO => "H"
| .fixedNeutral => "N"
| .user => "U"
Expand Down Expand Up @@ -130,6 +133,18 @@ private def isNoSpecType (env : Environment) (type : Expr) : Bool :=
else
false

/--
Return `true` if `type` is a type tagged with `@[weak_specialize]` or an arrow that produces this kind of type.
-/
private def isWeakSpecType (env : Environment) (type : Expr) : Bool :=
match type with
| .forallE _ _ b _ => isWeakSpecType env b
| _ =>
if let .const declName _ := type.getAppFn then
hasWeakSpecializeAttribute env declName
else
false

/-!
*Note*: `fixedNeutral` must have forward dependencies.

Expand Down Expand Up @@ -160,7 +175,7 @@ See comment at `.fixedNeutral`.
private def hasFwdDeps (decl : Decl .pure) (paramsInfo : Array SpecParamInfo) (j : Nat) : Bool := Id.run do
let param := decl.params[j]!
for h : k in (j+1)...decl.params.size do
if paramsInfo[k]!.causesSpecialization then
if paramsInfo[k]!.causesSpecialization || paramsInfo[k]! matches .fixedInst .. then
let param' := decl.params[k]
if param'.type.containsFVar param.fvarId then
return true
Expand Down Expand Up @@ -199,7 +214,7 @@ def computeSpecEntries (decls : Array (Decl .pure)) (autoSpecialize : Name → O
else if isTypeFormerType param.type then
pure .fixedNeutral
else if (← isArrowClass? param.type).isSome then
pure .fixedInst
pure (.fixedInst (weak := isWeakSpecType (← getEnv) param.type))
/-
Recall that if `specArgs? == some #[]`, then user annotated function with `@[specialize]`, but did not
specify which arguments must be specialized besides instances. In this case, we try to specialize
Expand Down
7 changes: 4 additions & 3 deletions src/Lean/Compiler/LCNF/Specialize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def collect (paramsInfo : Array SpecParamInfo) (args : Array (Arg .pure)) :
match paramInfo with
| .other =>
argMask := argMask.push none
| .fixedNeutral | .user | .fixedInst | .fixedHO =>
| .fixedNeutral | .user | .fixedInst .. | .fixedHO =>
argMask := argMask.push (some arg)
Closure.collectArg arg
return argMask
Expand Down Expand Up @@ -257,7 +257,8 @@ def shouldSpecialize (specEntry : SpecEntry) (args : Array (Arg .pure)) : Specia
match paramInfo with
| .other => pure ()
| .fixedNeutral => pure () -- If we want to monomorphize types such as `Array`, we need to change here
| .fixedInst | .user => if ← isGround arg then return true
| .fixedInst true => pure () -- weak: don't trigger specialization on its own
| .fixedInst false | .user => if ← isGround arg then return true
| .fixedHO => if ← hoCheck arg then return true

return false
Expand Down Expand Up @@ -509,7 +510,7 @@ def updateLocalSpecParamInfo : SpecializeM Unit := do
for entry in infos do
if let some mask := (← get).parentMasks[entry.declName]? then
let maskInfo info :=
mask.zipWith info (f := fun b i => if !b && i.causesSpecialization then .other else i)
mask.zipWith info (f := fun b i => if !b && (i.causesSpecialization || i matches .fixedInst ..) then .other else i)
let entry := { entry with paramsInfo := maskInfo entry.paramsInfo }
modify fun s => {
s with
Expand Down
14 changes: 14 additions & 0 deletions src/Lean/Compiler/Specialize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ Marks a definition to never be specialized during code generation.
builtin_initialize nospecializeAttr : TagAttribute ←
registerTagAttribute `nospecialize "mark definition to never be specialized"

/--
Marks a type for weak specialization: Parameters of this type are only specialized when
another argument already triggers specialization. Unlike `@[nospecialize]`, if specialization
happens for other reasons, parameters of this type will participate in the specialization
rather than being ignored.
-/
@[builtin_doc]
builtin_initialize weakSpecializeAttr : TagAttribute ←
registerTagAttribute `weak_specialize
"mark type for weak specialization: instances are only specialized when another argument already triggers specialization"

private def elabSpecArgs (declName : Name) (args : Array Syntax) : MetaM (Array Nat) := do
if args.isEmpty then return #[]
let info ← getConstInfo declName
Expand Down Expand Up @@ -82,4 +93,7 @@ def hasSpecializeAttribute (env : Environment) (declName : Name) : Bool :=
def hasNospecializeAttribute (env : Environment) (declName : Name) : Bool :=
nospecializeAttr.hasTag env declName

def hasWeakSpecializeAttribute (env : Environment) (declName : Name) : Bool :=
weakSpecializeAttr.hasTag env declName

end Lean.Compiler
2 changes: 1 addition & 1 deletion stage0/src/stdlib_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ options get_default_options() {
opts = opts.update({"debug", "terminalTacticsAsSorry"}, false);
// switch to `true` for ABI-breaking changes affecting meta code;
// see also next option!
opts = opts.update({"interpreter", "prefer_native"}, false);
opts = opts.update({"interpreter", "prefer_native"}, true);
// switch to `false` when enabling `prefer_native` should also affect use
// of built-in parsers in quotations; this is usually the case, but setting
// both to `true` may be necessary for handling non-builtin parsers with
Expand Down
34 changes: 34 additions & 0 deletions tests/elab/weak_specialize.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
module

public section

/-!
This test checks that weak_specialize:
- doesn't trigger specialization on its own
- does get included in specialization if other parameters provoke specialization
- does not affect the semantics of nospecialize parameters
-/

@[weak_specialize] class MyWeak (α : Type) where
val : α

/-- -/
#guard_msgs in
set_option trace.Compiler.specialize.info true in
def weakOnly [MyWeak α] (x : α) : α := x

class MyStrong (α : Type) where
op : α → α

/-- trace: [Compiler.specialize.info] weakAndStrong [N, W, I, O] -/
#guard_msgs in
set_option trace.Compiler.specialize.info true in
def weakAndStrong [MyWeak α] [MyStrong α] (x : α) : α := MyStrong.op x

@[nospecialize] class MyNoSpec (α : Type) where
val : α

/-- trace: [Compiler.specialize.info] nospecAndStrong [N, O, I, O] -/
#guard_msgs in
set_option trace.Compiler.specialize.info true in
def nospecAndStrong [MyNoSpec α] [MyStrong α] (x : α) : α := MyStrong.op x
Loading