From 722b01f894f7e74366ee475d1a90e0dabf304d0c Mon Sep 17 00:00:00 2001 From: Floris van Doorn Date: Sat, 4 Feb 2023 00:58:57 +0000 Subject: [PATCH] fix: interaction between to_additive and rewriting definitions (#1948) --- Mathlib/Lean/Expr/Basic.lean | 15 ++++++-- Mathlib/Tactic/ToAdditive.lean | 67 +++++++++++++++++++++++++--------- test/toAdditive.lean | 10 +++++ 3 files changed, 71 insertions(+), 21 deletions(-) diff --git a/Mathlib/Lean/Expr/Basic.lean b/Mathlib/Lean/Expr/Basic.lean index 81bd1f366d779..b368a1980984f 100644 --- a/Mathlib/Lean/Expr/Basic.lean +++ b/Mathlib/Lean/Expr/Basic.lean @@ -72,6 +72,17 @@ def splitAt (nm : Name) (n : Nat) : Name × Name := let (nm2, nm1) := (nm.componentsRev.splitAt n) (.fromComponents <| nm1.reverse, .fromComponents <| nm2.reverse) +/-- `isPrefixOf? pre nm` returns `some post` if `nm = pre ++ post`. + Note that this includes the case where `nm` has multiple more namespaces. + If `pre` is not a prefix of `nm`, it returns `none`. -/ +def isPrefixOf? (pre nm : Name) : Option Name := + if pre == nm then + some anonymous + else match nm with + | anonymous => none + | num p' a => (isPrefixOf? pre p').map (·.num a) + | str p' s => (isPrefixOf? pre p').map (·.str s) + end Name @@ -221,10 +232,6 @@ def zero? (e : Expr) : Bool := | some 0 => true | _ => false -/-- Returns a `NameSet` of all constants in an expression starting with a prefix in `pre`. -/ -def listNamesWithPrefixes (pre : NameSet) (e : Expr) : NameSet := - e.foldConsts ∅ fun n l ↦ if pre.contains n.getPrefix then l.insert n else l - def modifyAppArgM [Functor M] [Pure M] (modifier : Expr → M Expr) : Expr → M Expr | app f a => mkApp f <$> modifier a | e => pure e diff --git a/Mathlib/Tactic/ToAdditive.lean b/Mathlib/Tactic/ToAdditive.lean index 57c01879fd118..d8cc77ff46703 100644 --- a/Mathlib/Tactic/ToAdditive.lean +++ b/Mathlib/Tactic/ToAdditive.lean @@ -268,7 +268,7 @@ def additiveTestAux (findTranslation? : Name → Option Name) This is used in `@[to_additive]` for deciding which subexpressions to transform: we only transform constants if `additiveTest` applied to their first argument returns `true`. This means we will replace expression applied to e.g. `α` or `α × β`, but not when applied to -e.g. `Nat` or `ℝ × α`. +e.g. `ℕ` or `ℝ × α`. We ignore all arguments specified by the `ignore` `NameMap`. -/ def additiveTest (findTranslation? : Name → Option Name) @@ -453,6 +453,42 @@ def isInternal' (declName : Name) : Bool := | .str _ s => "match_".isPrefixOf s || "proof_".isPrefixOf s || "eq_".isPrefixOf s | _ => true +/-- Find the target name of `pre` and all created auxiliary declarations. -/ +def findTargetName (env : Environment) (src pre tgt_pre : Name) : CoreM Name := + /- This covers auxiliary declarations like `match_i` and `proof_i`. -/ + if let some post := pre.isPrefixOf? src then + return tgt_pre ++ post + /- This covers equation lemmas (for other declarations). -/ + else if let some post := privateToUserName? src then + match findTranslation? env post.getPrefix with + -- this is an equation lemma for a declaration without `to_additive`. We will skip this. + | none => return src + -- this is an equation lemma for a declaration with `to_additive`. We will additivize this. + -- Note: if this errors we could do this instead by calling `getEqnsFor?` + | some addName => return src.updatePrefix <| mkPrivateName env addName + -- Note: this additivizes lemmas generated by `simp`. + -- Todo: we do not currently check whether such lemmas actually should be additivized. + else if let some post := env.mainModule ++ `_auxLemma |>.isPrefixOf? src then + return env.mainModule ++ `_auxAddLemma ++ post + else + throwError "internal @[to_additive] error." + +/-- Returns a `NameSet` of all auxiliary constants in `e` that might have been generated +when adding `pre` to the environment. +Examples include `pre.match_5`, `Mathlib.MyFile._auxLemma.3` and +`_private.Mathlib.MyFile.someOtherNamespace.someOtherDeclaration._eq_2`. +The last two examples may or may not have been generated by this declaration. +The last example may or may not be the equation lemma of a declaration with the `@[to_additive]` +attribute. We will only translate it has the `@[to_additive]` attribute. +-/ +def findAuxDecls (e : Expr) (pre mainModule : Name) : NameSet := +let auxLemma := mainModule ++ `_auxLemma +e.foldConsts ∅ fun n l ↦ + if n.getPrefix == pre || n.getPrefix == auxLemma || isPrivateName n then + l.insert n + else + l + /-- transform the declaration `src` and all declarations `pre._proof_i` occurring in `src` using the transforms dictionary. `replace_all`, `trace`, `ignore` and `reorder` are configuration options. @@ -472,24 +508,21 @@ partial def transformDeclAux pre}, but does not have the `@[to_additive]` attribute. This is not supported.\n{"" }Workaround: move {src} to a different namespace." -- we find the additive name of `src` - let tgt := src.mapPrefix (fun n ↦ if n == pre then some tgt_pre else - if n == mkPrivateName env pre then some <| mkPrivateName env tgt_pre else - -- note: this is only a partial solution to dealing with lemmas generated by the - -- `simp` attribute, and should be removed/revised when we have a full solution - if n == env.mainModule ++ `_auxLemma then env.mainModule ++ `_auxAddLemma else none) - if tgt == src then - throwError "@[to_additive] doesn't know how to translate {src}, since the additive version has { - ""}the same name. This is a bug in @[to_additive]." - -- we skip if we already transformed this declaration before + let tgt ← findTargetName env src pre tgt_pre + -- we skip if we already transformed this declaration before. if env.contains tgt then + if tgt == src then + -- Note: this can happen for equation lemmas of declarations without `@[to_additive]`. + trace[to_additive_detail] "Auxiliary declaration {src} will be translated to itself." + else + trace[to_additive_detail] "Already visited {tgt} as translation of {src}." return let srcDecl ← getConstInfo src - let prefixes : NameSet := .ofList [pre, env.mainModule ++ `_auxLemma] -- we first transform all auxiliary declarations generated when elaborating `pre` - for n in srcDecl.type.listNamesWithPrefixes prefixes do + for n in findAuxDecls srcDecl.type pre env.mainModule do transformDeclAux cfg pre tgt_pre n if let some value := srcDecl.value? then - for n in value.listNamesWithPrefixes prefixes do + for n in findAuxDecls value pre env.mainModule do transformDeclAux cfg pre tgt_pre n -- if the auxilliary declaration doesn't have prefix `pre`, then we have to add this declaration -- to the translation dictionary, since otherwise we cannot find the additive name. @@ -969,7 +1002,7 @@ mapped to its additive version. The basic heuristic is Examples: * `@Mul.mul Nat n m` (i.e. `(n * m : Nat)`) will not change to `+`, since its - first argument is `Nat`, an identifier not applied to any arguments. + first argument is `ℕ`, an identifier not applied to any arguments. * `@Mul.mul (α × β) x y` will change to `+`. It's first argument contains only the identifier `prod`, but this is applied to arguments, `α` and `β`. * `@Mul.mul (α × Int) x y` will not change to `+`, since its first argument contains `Int`. @@ -988,7 +1021,7 @@ There are some exceptions to this heuristic: * If an identifier has attribute `@[to_additive_ignore_args n1 n2 ...]` then all the arguments in positions `n1`, `n2`, ... will not be checked for unapplied identifiers (start counting from 1). For example, `cont_mdiff_map` has attribute `@[to_additive_ignore_args 21]`, which means - that its 21st argument `(n : WithTop Nat)` can contain `Nat` + that its 21st argument `(n : WithTop Nat)` can contain `ℕ` (usually in the form `Top.top Nat ...`) and still be additivized. So `@Mul.mul (C^∞⟮I, N; I', G⟯) _ f g` will be additivized. @@ -1004,7 +1037,7 @@ mismatch error. reorder the (implicit) arguments of `d` so that the first argument becomes a type with a multiplicative structure (and not some indexing type)? The reason is that `@[to_additive]` doesn't additivize declarations if their first argument - contains fixed types like `Nat` or `ℝ`. See section Heuristics. + contains fixed types like `ℕ` or `ℝ`. See section Heuristics. If the first argument is not the argument with a multiplicative type-class, `@[to_additive]` should have automatically added the attribute `@[to_additive_relevant_arg]` to the declaration. You can test this by running the following (where `d` is the full name of the declaration): @@ -1017,7 +1050,7 @@ mismatch error. multiplicative structure. * Option 2: It didn't additivize a declaration that should be additivized. This happened because the heuristic applied, and the first argument contains a fixed type, - like `Nat` or `ℝ`. Solutions: + like `ℕ` or `ℝ`. Solutions: * If the fixed type has an additive counterpart (like `↥Semigroup`), give it the `@[to_additive]` attribute. * If the fixed type occurs inside the `k`-th argument of a declaration `d`, and the diff --git a/test/toAdditive.lean b/test/toAdditive.lean index 5ca8b8a86e7b1..30baf68afb951 100644 --- a/test/toAdditive.lean +++ b/test/toAdditive.lean @@ -248,6 +248,16 @@ lemma zero_fooClass [Zero α] : FooClass α := by infer_instance end instances +/- Test that we can rewrite with definitions with the `@[to_additive]` attribute. -/ +@[to_additive] +lemma npowRec_zero [One M] [Mul M] (x : M) : npowRec 0 x = 1 := + by rw [npowRec] + +/- Test that we can rewrite with definitions without the `@[to_additive]` attribute. -/ +@[to_additive addoptiontest] +lemma optiontest (x : Option α) : x.elim .none Option.some = x := + by cases x <;> rw [Option.elim] + /- Check that `to_additive` works if a `_match` aux declaration is created. -/ @[to_additive] def IsUnit [Mul M] (a : M) : Prop := a ≠ a