Skip to content

Commit

Permalink
fix: interaction between to_additive and rewriting definitions (#1948)
Browse files Browse the repository at this point in the history
  • Loading branch information
fpvandoorn committed Feb 4, 2023
1 parent 0a2e56f commit 722b01f
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 21 deletions.
15 changes: 11 additions & 4 deletions Mathlib/Lean/Expr/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,17 @@ def splitAt (nm : Name) (n : Nat) : Name × Name :=
let (nm2, nm1) := (nm.componentsRev.splitAt n)
(.fromComponents <| nm1.reverse, .fromComponents <| nm2.reverse)

/-- `isPrefixOf? pre nm` returns `some post` if `nm = pre ++ post`.
Note that this includes the case where `nm` has multiple more namespaces.
If `pre` is not a prefix of `nm`, it returns `none`. -/
def isPrefixOf? (pre nm : Name) : Option Name :=
if pre == nm then
some anonymous
else match nm with
| anonymous => none
| num p' a => (isPrefixOf? pre p').map (·.num a)
| str p' s => (isPrefixOf? pre p').map (·.str s)

end Name


Expand Down Expand Up @@ -221,10 +232,6 @@ def zero? (e : Expr) : Bool :=
| some 0 => true
| _ => false

/-- Returns a `NameSet` of all constants in an expression starting with a prefix in `pre`. -/
def listNamesWithPrefixes (pre : NameSet) (e : Expr) : NameSet :=
e.foldConsts ∅ fun n l ↦ if pre.contains n.getPrefix then l.insert n else l

def modifyAppArgM [Functor M] [Pure M] (modifier : Expr → M Expr) : Expr → M Expr
| app f a => mkApp f <$> modifier a
| e => pure e
Expand Down
67 changes: 50 additions & 17 deletions Mathlib/Tactic/ToAdditive.lean
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def additiveTestAux (findTranslation? : Name → Option Name)
This is used in `@[to_additive]` for deciding which subexpressions to transform: we only transform
constants if `additiveTest` applied to their first argument returns `true`.
This means we will replace expression applied to e.g. `α` or `α × β`, but not when applied to
e.g. `Nat` or `ℝ × α`.
e.g. `` or `ℝ × α`.
We ignore all arguments specified by the `ignore` `NameMap`.
-/
def additiveTest (findTranslation? : Name → Option Name)
Expand Down Expand Up @@ -453,6 +453,42 @@ def isInternal' (declName : Name) : Bool :=
| .str _ s => "match_".isPrefixOf s || "proof_".isPrefixOf s || "eq_".isPrefixOf s
| _ => true

/-- Find the target name of `pre` and all created auxiliary declarations. -/
def findTargetName (env : Environment) (src pre tgt_pre : Name) : CoreM Name :=
/- This covers auxiliary declarations like `match_i` and `proof_i`. -/
if let some post := pre.isPrefixOf? src then
return tgt_pre ++ post
/- This covers equation lemmas (for other declarations). -/
else if let some post := privateToUserName? src then
match findTranslation? env post.getPrefix with
-- this is an equation lemma for a declaration without `to_additive`. We will skip this.
| none => return src
-- this is an equation lemma for a declaration with `to_additive`. We will additivize this.
-- Note: if this errors we could do this instead by calling `getEqnsFor?`
| some addName => return src.updatePrefix <| mkPrivateName env addName
-- Note: this additivizes lemmas generated by `simp`.
-- Todo: we do not currently check whether such lemmas actually should be additivized.
else if let some post := env.mainModule ++ `_auxLemma |>.isPrefixOf? src then
return env.mainModule ++ `_auxAddLemma ++ post
else
throwError "internal @[to_additive] error."

/-- Returns a `NameSet` of all auxiliary constants in `e` that might have been generated
when adding `pre` to the environment.
Examples include `pre.match_5`, `Mathlib.MyFile._auxLemma.3` and
`_private.Mathlib.MyFile.someOtherNamespace.someOtherDeclaration._eq_2`.
The last two examples may or may not have been generated by this declaration.
The last example may or may not be the equation lemma of a declaration with the `@[to_additive]`
attribute. We will only translate it has the `@[to_additive]` attribute.
-/
def findAuxDecls (e : Expr) (pre mainModule : Name) : NameSet :=
let auxLemma := mainModule ++ `_auxLemma
e.foldConsts ∅ fun n l ↦
if n.getPrefix == pre || n.getPrefix == auxLemma || isPrivateName n then
l.insert n
else
l

/-- transform the declaration `src` and all declarations `pre._proof_i` occurring in `src`
using the transforms dictionary.
`replace_all`, `trace`, `ignore` and `reorder` are configuration options.
Expand All @@ -472,24 +508,21 @@ partial def transformDeclAux
pre}, but does not have the `@[to_additive]` attribute. This is not supported.\n{""
}Workaround: move {src} to a different namespace."
-- we find the additive name of `src`
let tgt := src.mapPrefix (fun n ↦ if n == pre then some tgt_pre else
if n == mkPrivateName env pre then some <| mkPrivateName env tgt_pre else
-- note: this is only a partial solution to dealing with lemmas generated by the
-- `simp` attribute, and should be removed/revised when we have a full solution
if n == env.mainModule ++ `_auxLemma then env.mainModule ++ `_auxAddLemma else none)
if tgt == src then
throwError "@[to_additive] doesn't know how to translate {src}, since the additive version has {
""}the same name. This is a bug in @[to_additive]."
-- we skip if we already transformed this declaration before
let tgt ← findTargetName env src pre tgt_pre
-- we skip if we already transformed this declaration before.
if env.contains tgt then
if tgt == src then
-- Note: this can happen for equation lemmas of declarations without `@[to_additive]`.
trace[to_additive_detail] "Auxiliary declaration {src} will be translated to itself."
else
trace[to_additive_detail] "Already visited {tgt} as translation of {src}."
return
let srcDecl ← getConstInfo src
let prefixes : NameSet := .ofList [pre, env.mainModule ++ `_auxLemma]
-- we first transform all auxiliary declarations generated when elaborating `pre`
for n in srcDecl.type.listNamesWithPrefixes prefixes do
for n in findAuxDecls srcDecl.type pre env.mainModule do
transformDeclAux cfg pre tgt_pre n
if let some value := srcDecl.value? then
for n in value.listNamesWithPrefixes prefixes do
for n in findAuxDecls value pre env.mainModule do
transformDeclAux cfg pre tgt_pre n
-- if the auxilliary declaration doesn't have prefix `pre`, then we have to add this declaration
-- to the translation dictionary, since otherwise we cannot find the additive name.
Expand Down Expand Up @@ -969,7 +1002,7 @@ mapped to its additive version. The basic heuristic is
Examples:
* `@Mul.mul Nat n m` (i.e. `(n * m : Nat)`) will not change to `+`, since its
first argument is `Nat`, an identifier not applied to any arguments.
first argument is ``, an identifier not applied to any arguments.
* `@Mul.mul (α × β) x y` will change to `+`. It's first argument contains only the identifier
`prod`, but this is applied to arguments, `α` and `β`.
* `@Mul.mul (α × Int) x y` will not change to `+`, since its first argument contains `Int`.
Expand All @@ -988,7 +1021,7 @@ There are some exceptions to this heuristic:
* If an identifier has attribute `@[to_additive_ignore_args n1 n2 ...]` then all the arguments in
positions `n1`, `n2`, ... will not be checked for unapplied identifiers (start counting from 1).
For example, `cont_mdiff_map` has attribute `@[to_additive_ignore_args 21]`, which means
that its 21st argument `(n : WithTop Nat)` can contain `Nat`
that its 21st argument `(n : WithTop Nat)` can contain ``
(usually in the form `Top.top Nat ...`) and still be additivized.
So `@Mul.mul (C^∞⟮I, N; I', G⟯) _ f g` will be additivized.
Expand All @@ -1004,7 +1037,7 @@ mismatch error.
reorder the (implicit) arguments of `d` so that the first argument becomes a type with a
multiplicative structure (and not some indexing type)?
The reason is that `@[to_additive]` doesn't additivize declarations if their first argument
contains fixed types like `Nat` or `ℝ`. See section Heuristics.
contains fixed types like `` or `ℝ`. See section Heuristics.
If the first argument is not the argument with a multiplicative type-class, `@[to_additive]`
should have automatically added the attribute `@[to_additive_relevant_arg]` to the declaration.
You can test this by running the following (where `d` is the full name of the declaration):
Expand All @@ -1017,7 +1050,7 @@ mismatch error.
multiplicative structure.
* Option 2: It didn't additivize a declaration that should be additivized.
This happened because the heuristic applied, and the first argument contains a fixed type,
like `Nat` or `ℝ`. Solutions:
like `` or `ℝ`. Solutions:
* If the fixed type has an additive counterpart (like `↥Semigroup`), give it the `@[to_additive]`
attribute.
* If the fixed type occurs inside the `k`-th argument of a declaration `d`, and the
Expand Down
10 changes: 10 additions & 0 deletions test/toAdditive.lean
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,16 @@ lemma zero_fooClass [Zero α] : FooClass α := by infer_instance

end instances

/- Test that we can rewrite with definitions with the `@[to_additive]` attribute. -/
@[to_additive]
lemma npowRec_zero [One M] [Mul M] (x : M) : npowRec 0 x = 1 :=
by rw [npowRec]

/- Test that we can rewrite with definitions without the `@[to_additive]` attribute. -/
@[to_additive addoptiontest]
lemma optiontest (x : Option α) : x.elim .none Option.some = x :=
by cases x <;> rw [Option.elim]

/- Check that `to_additive` works if a `_match` aux declaration is created. -/
@[to_additive]
def IsUnit [Mul M] (a : M) : Prop := a ≠ a
Expand Down

0 comments on commit 722b01f

Please sign in to comment.