Skip to content

Commit

Permalink
perf: improve to_additive performance (#1060)
Browse files Browse the repository at this point in the history
Using
```
def Ones : ℕ → Q(Nat)
| 0     => q(1)
| (n+1) => q($(Ones n) + $(Ones n))
```
The new `to_additive` takes `45ms` on `Ones 500` (higher gives stack overflows)
The old `to_additive` takes `13794ms` on `Ones 17` (exponential in the argument)
There is still one issue workaround by using `transform` in `etaExpand`.

* Remove `replaceRecM` and `replaceRecMeta` that are exponentially slow
* Remove `replaceRecTraversal` because its interface is less convenient than `replaceRec`



Co-authored-by: Scott Morrison <scott.morrison@gmail.com>
  • Loading branch information
fpvandoorn and semorrison committed Dec 16, 2022
1 parent f976b5e commit 8b308fa
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 145 deletions.
38 changes: 0 additions & 38 deletions Mathlib/Lean/Expr/ReplaceRec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -31,42 +31,4 @@ def replaceRec (f? : (Expr → Expr) → Expr → Option Expr) : Expr → Expr :
| some x => x
| none => traverseChildren (M := Id) r e

/-- replaceRec under a monad. -/
partial def replaceRecM [Monad M] (f? : (Expr → M Expr) → Expr → M (Option Expr)) (e : Expr) :
M Expr := do
match ← f? (replaceRecM f?) e with
| some x => return x
| none => traverseChildren (replaceRecM f?) e

/-- Similar to `replaceRecM` except that bound variables are instantiated with free variables
(like `Lean.Meta.transform`).
This means that MetaM tactics can be used inside the replacement function.
If you don't need recursive calling,
you should prefer using `Lean.Meta.transform` because it also caches visits.
-/
partial def replaceRecMeta [Monad M] [MonadLiftT MetaM M] [MonadControlT MetaM M]
(f? : (Expr → M Expr) → Expr → M (Option Expr)) (e : Expr) : M Expr := do
match ← f? (replaceRecMeta f?) e with
| some x => return x
| none => Lean.Meta.traverseChildren (replaceRecMeta f?) e

/-- A version of `Expr.replace` where we can use recursive calls even if we replace a subexpression.
When reaching a subexpression `e` we call `traversal e` to see if we want to do anything with this
expression. If `traversal e = none` we proceed to the children of `e`. If
`traversal e = some (#[e₁, ..., eₙ], g)`, we first recursively apply this function to
`#[e₁, ..., eₙ]` to get new expressions `#[f₁, ..., fₙ]`.
Then we replace `e` by `g [f₁, ..., fₙ]`.
Important: In order for this function to terminate, the `[e₁, ..., eₙ]` must all be smaller than
`e` according to some measure (and this measure must also be strictly decreasing on the w.r.t.
the structural subterm relation).
-/
def replaceRecTraversal (traversal : Expr → Option (Array Expr × (Array Expr → Expr))) :
Expr → Expr :=
replaceRec fun r e ↦
match traversal e with
| none => none
| some (get, set) => some <| set <| .map r <| get

end Lean.Expr
175 changes: 81 additions & 94 deletions Mathlib/Tactic/ToAdditive.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mario Carneiro, Yury Kudryashov, Floris van Doorn, Jon Eugster
Ported by: E.W.Ayers
-/
import Mathlib.Init.Data.Nat.Basic
import Mathlib.Data.String.Defs
import Mathlib.Data.KVMap
import Mathlib.Lean.Expr.ReplaceRec
Expand Down Expand Up @@ -120,12 +121,6 @@ initialize ignoreArgsAttr : NameMapExtension (List Nat) ←
| _ => throwUnsupportedSyntax
return ids.toList }

/-- Gets the set of arguments that should be ignored for the given name
(according to `@[to_additive_ignore_args ...]`).
This value is used in `additiveTestAux`. -/
def ignore [Functor M] [MonadEnv M]: Name → M (Option (List Nat))
| n => (ignoreArgsAttr.find? · n) <$> getEnv

/--
An attribute that stores all the declarations that needs their arguments reordered when
applying `@[to_additive]`. Currently, we only support swapping consecutive arguments.
Expand All @@ -151,15 +146,6 @@ initialize reorderAttr : NameMapExtension (List Nat) ←
pure <| Array.toList <| ids.map (·.1.isNatLit?.get!)
| _, _ => throwUnsupportedSyntax }

/-- Get the reorder list (defined using `@[to_additive_reorder ...]`) for the given declaration. -/
def getReorder [Functor M] [MonadEnv M] : Name → M (List Nat)
| n => (reorderAttr.find? · n |>.getD []) <$> getEnv

/-- Given a declaration name and an argument index, determines whether this index
should be swapped with the next one. -/
def shouldReorder [Functor M] [MonadEnv M] : Name → Nat → M Bool
| n, i => (i ∈ ·) <$> getReorder n

/--
An attribute that is automatically added to declarations tagged with `@[to_additive]`, if needed.
Expand Down Expand Up @@ -190,14 +176,6 @@ initialize relevantArgAttr : NameMapExtension Nat ←
| _, `(attr| to_additive_relevant_arg $id) => pure <| id.1.isNatLit?.get!.pred
| _, _ => throwUnsupportedSyntax }

/-- Given a declaration name and an argument index, determines whether it
is relevant. This is used in `applyReplacementFun` where more detail on what it
does can be found. -/
def isRelevant [Monad M] [MonadEnv M] (n : Name) (i : Nat) : M Bool := do
match relevantArgAttr.find? (← getEnv) n with
| some j => return i == j
| none => return i == 0

/--
An attribute that stores all the declarations that deal with numeric literals on fixed types.
* `@[to_additive_fixed_numeral]` should be added to all functions that take a numeral as argument
Expand Down Expand Up @@ -231,12 +209,6 @@ def insertTranslation (src tgt : Name) : CoreM Unit := do
modifyEnv (ToAdditive.translations.addEntry · (src, tgt))
trace[to_additive] "Added translation {src} ↦ {tgt}"

/-- Get whether or not the replace-all flag is set. If this is true, then the
additiveTest heuristic is not used and all instances of multiplication are replaced.
You can enable this with `@[to_additive!]`-/
def replaceAll [Functor M] [MonadOptions M] : M Bool :=
(·.getBool `to_additive.replaceAll) <$> getOptions

/-- `Config` is the type of the arguments that can be provided to `to_additive`. -/
structure Config : Type where
/-- Replace all multiplicative declarations, do not use the heuristic. -/
Expand All @@ -263,21 +235,22 @@ variable [Monad M] [MonadOptions M] [MonadEnv M]

/-- Auxilliary function for `additiveTest`. The bool argument *only* matters when applied
to exactly a constant. -/
private def additiveTestAux : Bool → Expr → M Bool
| b, .const n _ => return b || (findTranslation? (← getEnv) n).isSome
| _, .app e a => do
if ← additiveTestAux true e then
private def additiveTestAux (findTranslation? : Name → Option Name)
(ignore : Name → Option (List ℕ)) : Bool → Expr → Bool := visit where
visit : Bool → Expr → Bool
| b, .const n _ => b || (findTranslation? n).isSome
| _, .app e a => Id.run do
if visit true e then
return true
if let some n := e.getAppFn.constName? then
if let some l ignore n then
if let some l := ignore n then
if e.getAppNumArgs + 1 ∈ l then
return true
additiveTestAux false a
| _, .lam _ _ t _ => additiveTestAux false t
| _, .forallE _ _ t _ => additiveTestAux false t
| _, .letE _ _ e body _ =>
additiveTestAux false e <&&> additiveTestAux false body
| _, _ => return true
visit false a
| _, .lam _ _ t _ => visit false t
| _, .forallE _ _ t _ => visit false t
| _, .letE _ _ e body _ => visit false e && visit false body
| _, _ => true

/--
`additiveTest e` tests whether the expression `e` contains no constant
Expand All @@ -289,18 +262,18 @@ e.g. `Nat` or `ℝ × α`.
We ignore all arguments specified by the `ignore` `NameMap`.
If `replaceAll` is `true` the test always returns `true`.
-/
def additiveTest (e : Expr) : M Bool := do
if ← replaceAll then
return true
else
additiveTestAux false e
def additiveTest (replaceAll : Bool) (findTranslation? : Name → Option Name)
(ignore : Name → Option (List ℕ)) (e : Expr) : Bool :=
replaceAll || additiveTestAux findTranslation? ignore false e

/-- Checks whether a numeral should be translated. -/
def shouldTranslateNumeral [Monad M] [MonadEnv M] (n : Name) (firstArg : Expr) : M Bool := do
match fixedNumeralAttr.find? (← getEnv) n with
| some true => additiveTest firstArg
| some false => return false
| none => return true
def shouldTranslateNumeral (replaceAll : Bool) (findTranslation? : Name → Option Name)
(ignore : Name → Option (List ℕ)) (fixedNumeral : Name → Option Bool)
(nm : Name) (firstArg : Expr) : Bool :=
match fixedNumeral nm with
| some true => additiveTest replaceAll findTranslation? ignore firstArg
| some false => false
| none => true

/-- Swap the first two elements of a list -/
def _root_.List.swapFirstTwo {α : Type _} : List α → List α
Expand All @@ -316,64 +289,74 @@ It translates each identifier (inductive type, defined function etc) in an expre
However, if `f` is in the dictionary `relevant`, then the argument `relevant.find f`
is tested, instead of the first argument.
It will also reorder arguments of certain functions, using `shouldReorder`:
e.g. `g x₁ x₂ x₃ ... xₙ` becomes `g x₂ x₁ x₃ ... xₙ` if `reorderAttr.find? env g = some [1]`.
It will also reorder arguments of certain functions, using `reorderFn`:
e.g. `g x₁ x₂ x₃ ... xₙ` becomes `g x₂ x₁ x₃ ... xₙ` if `reorderFn g = some [1]`.
-/
def applyReplacementFun : Expr → MetaM Expr :=
Lean.Expr.replaceRecMeta fun r e ↦ do
trace[to_additive_detail] "applyReplacementFun: replace at {e}"
def applyReplacementFun (e : Expr) : MetaM Expr := do
let env ← getEnv
let reorderFn : Name → List ℕ := fun nm ↦ (reorderAttr.find? env nm |>.getD [])
let isRelevant : Name → ℕ → Bool := fun nm i ↦ i == (relevantArgAttr.find? env nm).getD 0
return aux ((← getOptions).getBool `to_additive.replaceAll)
(findTranslation? <| ← getEnv) reorderFn (ignoreArgsAttr.find? env)
(fixedNumeralAttr.find? env) isRelevant e
where /-- Implementation of `applyReplacementFun`. -/
aux (replaceAll : Bool) (findTranslation? : Name → Option Name)
(reorderFn : Name → List ℕ) (ignore : Name → Option (List ℕ))
(fixedNumeral : Name → Option Bool) (isRelevant : Name → ℕ → Bool) : Expr → Expr :=
Lean.Expr.replaceRec fun r e ↦ Id.run do
-- trace[to_additive_detail] "applyReplacementFun: replace at {e}"
match e with
| .lit (.natVal 1) => pure <| mkRawNatLit 0
| .const n₀ ls => do
let n₁ := n₀.mapPrefix <| findTranslation? <| ← getEnv
if n₀ != n₁ then
trace[to_additive_detail] "applyReplacementFun: {n₀} → {n₁}"
let ls : List Level := if ← shouldReorder n₀ 1 then ls.swapFirstTwo else ls
let n₁ := n₀.mapPrefix findTranslation?
-- if n₀ != n₁ then
-- trace[to_additive_detail] "applyReplacementFun: {n₀} → {n₁}"
let ls : List Level := if 1 ∈ reorderFn n₀ then ls.swapFirstTwo else ls
return some <| Lean.mkConst n₁ ls
| .app g x => do
let gf := g.getAppFn
if let some nm := gf.constName? then
let gArgs := g.getAppArgs
-- e = `(nm y₁ .. yₙ x)
trace[to_additive_detail] "applyReplacementFun: app {nm} {gArgs} {x}"
-- trace[to_additive_detail] "applyReplacementFun: app {nm} {gArgs} {x}"
/- Test if arguments should be reordered. -/
if h : gArgs.size > 0 then
let c1 ← shouldReorder nm gArgs.size
let c2 additiveTest gArgs[0]
let c1 : Bool := gArgs.size ∈ reorderFn nm
let c2 := additiveTest replaceAll findTranslation? ignore gArgs[0]
if c1 && c2 then
-- interchange `x` and the last argument of `g`
let x r x
let gf r g.appFn!
let ga r g.appArg!
let e₂ := mkApp2 gf x ga
trace[to_additive_detail]
"applyReplacementFun: reordering {nm}: {x} ↔ {ga}\nBefore: {e}\nAfter: {e₂}"
let x := r x
let gf := r g.appFn!
let ga := r g.appArg!
let e₂ := mkApp2 gf x ga
-- trace[to_additive_detail]
-- "applyReplacementFun: reordering {nm}: {x} ↔ {ga}\nBefore: {e}\nAfter: {e₂}"
return some e₂
/- Test if the head should not be replaced. -/
let c1 isRelevant nm gArgs.size
let c1 := isRelevant nm gArgs.size
let c2 := gf.isConst
let c3 additiveTest x
if c1 && c2 && c3 then
trace[to_additive_detail]
"applyReplacementFun: {x} doesn't contain a fixed type, so we will change {nm}"
let c3 := additiveTest replaceAll findTranslation? ignore x
-- if c1 && c2 && c3 then
-- trace[to_additive_detail]
-- "applyReplacementFun: {x} doesn't contain a fixed type, so we will change {nm}"
if c1 && c2 && not c3 then
-- the test failed, so don't update the function body.
trace[to_additive_detail]
"applyReplacementFun: {x} contains a fixed type, so {nm} is not changed"
-- trace[to_additive_detail]
-- "applyReplacementFun: {x} contains a fixed type, so {nm} is not changed"
let x ← r x
let args ← gArgs.mapM r
return some $ mkApp (mkAppN gf args) x
/- Do not replace numerals in specific types. -/
let firstArg := if h : gArgs.size > 0 then gArgs[0] else x
if not (← shouldTranslateNumeral nm firstArg) then
trace[to_additive_detail] "applyReplacementFun: Do not change numeral {g.app x}"
if !shouldTranslateNumeral replaceAll findTranslation? ignore fixedNumeral nm firstArg then
-- trace[to_additive_detail] "applyReplacementFun: Do not change numeral {g.app x}"
return some <| g.app x
return e.updateApp! (← r g) (← r x)
| .proj n₀ idx e => do
let n₁ := n₀.mapPrefix <| findTranslation? <| ← getEnv
if n₀ != n₁ then
trace[to_additive_detail] "applyReplacementFun: in projection {e}.{idx} of type {n₀}, {""
}replace type with {n₁}"
let n₁ := n₀.mapPrefix findTranslation?
-- if n₀ != n₁ then
-- trace[to_additive_detail] "applyReplacementFun: in projection {e}.{idx} of type {n₀}, {""
-- }replace type with {n₁}"
return some <| .proj n₁ idx <| ← r e
| _ => return none

Expand All @@ -385,24 +368,29 @@ def etaExpandN (n : Nat) (e : Expr): MetaM Expr := do
`reorder`. They are expanded until they are applied to one more argument than the maximum in
`reorder.find n`. -/
def expand (e : Expr) : MetaM Expr := do
let e₂ ← e.replaceRecMeta $ fun r e ↦ do
let env ← getEnv
let reorderFn : Name → List ℕ := fun nm ↦ (reorderAttr.find? env nm |>.getD [])
let e₂ ← Lean.Meta.transform (input := e) (post := fun e => return .done e) <| fun e ↦ do
let e0 := e.getAppFn
let es := e.getAppArgs
let some e0n := e0.constName? | return none
let reorder ← getReorder e0n
let some e0n := e0.constName? | return .continue
let reorder := reorderFn e0n
if reorder.isEmpty then
-- no need to expand if nothing needs reordering
return none
let e' := mkAppN e0 $ ← es.mapM r
return .continue
let needed_n := reorder.foldr Nat.max 0 + 1
if needed_n ≤ es.size then
return some e'
-- the second disjunct is a temporary fix to avoid infinite loops.
-- We may need to use `replaceRec` or something similar to not change the head of an application
if needed_n ≤ es.size || es.size == 0 then
return .continue
else
-- in this case, we need to reorder arguments that are not yet
-- applied, so first η-expand the function.
let e' ← etaExpandN (needed_n - es.size) e'
return some $ e'
trace[to_additive_detail] "expand:\nBefore: {e}\nAfter: {e₂}"
let e' ← etaExpandN (needed_n - es.size) e
trace[to_additive_detail] "expanded {e} to {e'}"
return .continue e'
if e != e₂ then
trace[to_additive_detail] "expand:\nBefore: {e}\nAfter: {e₂}"
return e₂

/-- Reorder pi-binders. See doc of `reorderAttr` for the interpretation of the argument -/
Expand Down Expand Up @@ -439,10 +427,9 @@ def updateDecl
let mut decl := srcDecl.updateName tgt
if 1 ∈ reorder then
decl := decl.updateLevelParams decl.levelParams.swapFirstTwo
decl := decl.updateType <| ← applyReplacementFun <| ← (reorderForall · reorder) <|
← expand decl.type
decl := decl.updateType <| ← applyReplacementFun <| ← reorderForall (← expand decl.type) reorder
if let some v := decl.value? then
decl := decl.updateValue <| ← applyReplacementFun <| ← (reorderLambda · reorder) <| ← expand v
decl := decl.updateValue <| ← applyReplacementFun <| ← reorderLambda (← expand v) reorder
return decl

/-- Lean 4 makes declarations which are not internal
Expand Down
10 changes: 6 additions & 4 deletions test/Expr.lean
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
import Mathlib.Lean.Expr.ReplaceRec
import Mathlib.Tactic.RunCmd
import Mathlib.Init.Data.Nat.Basic

open Lean Meta Elab
open Lean Meta Elab Command

section replaceRec
/-! Test the implementation of `Expr.replaceRec` -/

/-- Reorder the last two arguments of every function in the expression.
(The resulting term will generally not be a type-correct) -/
def reorderLastArguments : Expr → Expr :=
Expr.replaceRecTraversal λ e =>
Expr.replaceRec λ r e =>
let n := e.getAppNumArgs
if n ≥ 2 then
some (e.getAppArgs, λ es => mkAppN e.getAppFn $ es.swap! (n - 1) (n - 2)) else
mkAppN e.getAppFn <| e.getAppArgs.map r |>.swap! (n - 1) (n - 2)
else
none

def foo (f : ℕ → ℕ → ℕ) (n₁ n₂ n₃ n₄ : ℕ) : ℕ := f (f n₁ n₂) (f n₃ n₄)
def bar (f : ℕ → ℕ → ℕ) (n₁ n₂ n₃ n₄ : ℕ) : ℕ := f (f n₄ n₃) (f n₂ n₁)

#eval show TermElabM _ from do
run_cmd liftTermElabM <| do
let d ← getConstInfo `foo
let e := d.value!
logInfo m!"before: {e}"
Expand Down
Loading

0 comments on commit 8b308fa

Please sign in to comment.