Skip to content

Commit

Permalink
feat: support irreducible_def in to_additive (#3399)
Browse files Browse the repository at this point in the history
  • Loading branch information
gebner committed Apr 14, 2023
1 parent dfbb8aa commit 27e0a79
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 72 deletions.
12 changes: 5 additions & 7 deletions Mathlib/Algebra/BigOperators/Finprod.lean
Expand Up @@ -93,23 +93,21 @@ section
with `Classical.dec` in their statement. -/
open Classical


-- Porting note: replaced irreducible_def with def and an irreducible tag here.
/-- Sum of `f x` as `x` ranges over the elements of the support of `f`, if it's finite. Zero
otherwise. -/
@[irreducible]
noncomputable def finsum {M α} [AddCommMonoid M] (f : α → M) : M :=
noncomputable irreducible_def finsum (lemma := finsum_def') [AddCommMonoid M] (f : α → M) : M :=
if h : (support (f ∘ PLift.down)).Finite then ∑ i in h.toFinset, f i.down else 0
#align finsum finsum

-- Porting note: replaced irreducible_def with def and an irreducible tag here.
/-- Product of `f x` as `x` ranges over the elements of the multiplicative support of `f`, if it's
finite. One otherwise. -/
@[to_additive existing (attr:= irreducible)]
noncomputable def finprod (f : α → M) : M :=
@[to_additive existing]
noncomputable irreducible_def finprod (lemma := finprod_def') (f : α → M) : M :=
if h : (mulSupport (f ∘ PLift.down)).Finite then ∏ i in h.toFinset, f i.down else 1
#align finprod finprod

attribute [to_additive existing] finprod_def'

end

open Std.ExtendedBinder
Expand Down
3 changes: 1 addition & 2 deletions Mathlib/Data/Real/Basic.lean
Expand Up @@ -88,8 +88,7 @@ private irreducible_def neg : ℝ → ℝ
private irreducible_def mul : ℝ → ℝ → ℝ
| ⟨a⟩, ⟨b⟩ => ⟨a * b⟩

-- TODO irreducible_def
private noncomputable def inv' : ℝ → ℝ
private noncomputable irreducible_def inv' : ℝ → ℝ
| ⟨a⟩ => ⟨a⁻¹⟩

instance : Zero ℝ :=
Expand Down
31 changes: 13 additions & 18 deletions Mathlib/GroupTheory/MonoidLocalization.lean
Expand Up @@ -214,26 +214,23 @@ instance inhabited : Inhabited (Localization S) := Con.Quotient.inhabited
#align add_localization.inhabited addLocalization.inhabited

/-- Multiplication in a `Localization` is defined as `⟨a, b⟩ * ⟨c, d⟩ = ⟨a * c, b * d⟩`. -/
-- Porting note: replaced irreducible_def by @[irreducible] to prevent an error with protected
@[to_additive (attr := irreducible)
"Addition in an `addLocalization` is defined as `⟨a, b⟩ + ⟨c, d⟩ = ⟨a + c, b + d⟩`.
@[to_additive "Addition in an `addLocalization` is defined as `⟨a, b⟩ + ⟨c, d⟩ = ⟨a + c, b + d⟩`.
Should not be confused with the ring localization counterpart `Localization.add`, which maps
`⟨a, b⟩ + ⟨c, d⟩` to `⟨d * a + b * c, b * d⟩`."]
protected def mul : Localization S → Localization S → Localization S := (r S).commMonoid.mul
protected irreducible_def mul : Localization S → Localization S → Localization S :=
(r S).commMonoid.mul
#align localization.mul Localization.mul
#align add_localization.add addLocalization.add

@[to_additive]
instance : Mul (Localization S) := ⟨Localization.mul S⟩

/-- The identity element of a `Localization` is defined as `⟨1, 1⟩`. -/
@[to_additive (attr := irreducible)
"The identity element of an `addLocalization` is defined as `⟨0, 0⟩`.
@[to_additive "The identity element of an `addLocalization` is defined as `⟨0, 0⟩`.
Should not be confused with the ring localization counterpart `Localization.zero`,
which is defined as `⟨0, 1⟩`."]
-- Porting note: replaced irreducible_def by @[irreducible] to prevent an error with protected
protected def one : Localization S := (r S).commMonoid.one
protected irreducible_def one : Localization S := (r S).commMonoid.one
#align localization.one Localization.one
#align add_localization.zero addLocalization.zero

Expand All @@ -245,14 +242,12 @@ instance : One (Localization S) := ⟨Localization.one S⟩
This is a separate `irreducible` def to ensure the elaborator doesn't waste its time
trying to unify some huge recursive definition with itself, but unfolded one step less.
-/
@[to_additive (attr := irreducible)
"Multiplication with a natural in an `AddLocalization` is defined as
@[to_additive "Multiplication with a natural in an `AddLocalization` is defined as
`n • ⟨a, b⟩ = ⟨n • a, n • b⟩`.
This is a separate `irreducible` def to ensure the elaborator doesn't waste its time
trying to unify some huge recursive definition with itself, but unfolded one step less."]
-- Porting note: replaced irreducible_def by @[irreducible] to prevent an error with protected
protected def npow : ℕ → Localization S → Localization S := (r S).commMonoid.npow
protected irreducible_def npow : ℕ → Localization S → Localization S := (r S).commMonoid.npow
#align localization.npow Localization.npow
#align add_localization.nsmul addLocalization.nsmul

Expand All @@ -261,18 +256,18 @@ instance : CommMonoid (Localization S) where
mul := (· * ·)
one := 1
mul_assoc x y z := show (x.mul S y).mul S z = x.mul S (y.mul S z) by
delta Localization.mul; apply (r S).commMonoid.mul_assoc
rw [Localization.mul]; apply (r S).commMonoid.mul_assoc
mul_comm x y := show x.mul S y = y.mul S x by
delta Localization.mul; apply (r S).commMonoid.mul_comm
rw [Localization.mul]; apply (r S).commMonoid.mul_comm
mul_one x := show x.mul S (.one S) = x by
delta Localization.mul Localization.one; apply (r S).commMonoid.mul_one
rw [Localization.mul, Localization.one]; apply (r S).commMonoid.mul_one
one_mul x := show (Localization.one S).mul S x = x by
delta Localization.mul Localization.one; apply (r S).commMonoid.one_mul
rw [Localization.mul, Localization.one]; apply (r S).commMonoid.one_mul
npow := Localization.npow S
npow_zero x := show Localization.npow S 0 x = .one S by
delta Localization.npow Localization.one; apply (r S).commMonoid.npow_zero
rw [Localization.npow, Localization.one]; apply (r S).commMonoid.npow_zero
npow_succ n x := show .npow S n.succ x = x.mul S (.npow S n x) by
delta Localization.npow Localization.mul; apply (r S).commMonoid.npow_succ
rw [Localization.npow, Localization.mul]; apply (r S).commMonoid.npow_succ

variable {S}

Expand Down
5 changes: 0 additions & 5 deletions Mathlib/LinearAlgebra/Finsupp.lean
Expand Up @@ -1137,11 +1137,6 @@ section

variable (R)

-- Porting note: `irreducible_def` produces a structure.
-- When a structure is defined, an injectivity theorem of the constructor is
-- generated, which has `simp` attr, but this get a `simpNF` linter.
-- So, this option is required.
set_option genInjectivity false in
/-- Pick some representation of `x : span R w` as a linear combination in `w`,
using the axiom of choice.
-/
Expand Down
4 changes: 2 additions & 2 deletions Mathlib/Tactic/Eqns.lean
Expand Up @@ -27,7 +27,7 @@ theorem transpose_const {m n} (c : ℕ) :
rw [transpose]
```
-/
open Lean
open Lean Elab

syntax (name := eqns) "eqns" ident* : attr

Expand All @@ -37,7 +37,7 @@ initialize eqnsAttribute : NameMapExtension (Array Name) ←
descr := "Overrides the equation lemmas for a declation to the provided list"
add := fun
| _, `(attr| eqns $[$names]*) =>
names.mapM resolveGlobalConstNoOverload
names.mapM resolveGlobalConstNoOverloadWithInfo
| _, _ => Lean.Elab.throwUnsupportedSyntax }

initialize Lean.Meta.registerGetEqnsFn (fun name => do
Expand Down
50 changes: 28 additions & 22 deletions Mathlib/Tactic/IrreducibleDef.lean
Expand Up @@ -5,6 +5,7 @@ Authors: Gabriel Ebner
-/
import Lean
import Mathlib.Tactic.Eqns
import Mathlib.Data.Subtype

/-!
# Irreducible definitions
Expand Down Expand Up @@ -45,64 +46,69 @@ local elab "eta_helper " t:term : term => do
let some (_, lhs, rhs) := t.eq? | throwError "not an equation: {t}"
synthesizeSyntheticMVars
let rhs ← instantiateMVars rhs
lambdaLetTelescope rhs fun xs rhs ↦ do
lambdaTelescope rhs fun xs rhs ↦ do
let lhs := (mkAppN lhs xs).headBeta
mkForallFVars xs <|← mkEq lhs rhs

/-- `value_proj x` elabs to `@x.value` -/
local elab "value_proj " e:term : term => do
let e ← elabTerm e none
mkProjection e `value
/-- `val_proj x` elabs to the *primitive projection* `@x.val`. -/
local elab "val_proj " e:term : term => do
let e ← elabTerm (← `(($e : Subtype _))) none
return mkProj ``Subtype 0 e

/--
Executes the commands,
and stops after the first error.
In short, S-A-F-E.
-/
local syntax "stop_at_first_error" command* : command
local syntax "stop_at_first_error" (ppLine command)* : command
open Command in elab_rules : command
| `(stop_at_first_error $[$cmds]*) => do
for cmd in cmds do
elabCommand cmd.raw
if (← get).messages.hasErrors then break

syntax irredDefLemma := atomic("(" "lemma" " := ") ident ")"

/--
Introduces an irreducible definition.
`irreducible_def foo := 42` generates
a constant `foo : Nat` as well as
a theorem `foo_def : foo = 42`.
-/
elab mods:declModifiers "irreducible_def" n_id:declId declSig:optDeclSig val:declVal :
elab mods:declModifiers "irreducible_def" n_id:declId n_def:(irredDefLemma)?
declSig:optDeclSig val:declVal :
command => do
let (n, us) ← match n_id with
| `(Parser.Command.declId| $n:ident $[.{$us,*}]?) => pure (n, us)
| _ => throwUnsupportedSyntax
let us' := us.getD { elemsAndSeps := #[] }
let n_def := mkIdent <| (·.review) <|
let scopes := extractMacroScopes n.getId
{ scopes with name := scopes.name.appendAfter "_def" }
let n_def ← match n_def.getD ⟨mkNullNode⟩ with
| `(irredDefLemma| (lemma := $id)) => pure id
| _ => pure <| mkIdent <| (·.review) <|
let scopes := extractMacroScopes n.getId
{ scopes with name := scopes.name.appendAfter "_def" }
let `(Parser.Command.declModifiersF|
$[$doc:docComment]? $[$attrs:attributes]?
$[$doc:docComment]? $[@[$attrs,*]]?
$[$vis]? $[$nc:noncomputable]? $[$uns:unsafe]?) := mods
| throwError "unsupported modifiers {format mods}"
let attrs := attrs.getD {}
let prot := vis.filter (· matches `(Parser.Command.visibility| protected))
let priv := vis.filter (· matches `(Parser.Command.visibility| private))
elabCommand <|<- `(stop_at_first_error
$[$nc:noncomputable]? $[$uns]? def definition$[.{$us,*}]? $declSig:optDeclSig $val
set_option genInjectivity false in -- generates awful simp lemmas
$[$uns:unsafe]? structure Wrapper$[.{$us,*}]? where
value : type_of% @definition.{$us',*}
prop : Eq @value @(delta% @definition)
$[$nc:noncomputable]? $[$uns]? opaque wrapped$[.{$us,*}]? : Wrapper.{$us',*} := ⟨_, rfl⟩
$[$doc:docComment]? $[$attrs:attributes]? $[private%$priv]? $[$nc:noncomputable]? $[$uns]?
$[$nc:noncomputable]? $[$uns]? opaque wrapped$[.{$us,*}]? : Subtype (Eq @definition.{$us',*}) :=
⟨_, rfl⟩
$[$doc:docComment]? $[private%$priv]? $[$nc:noncomputable]? $[$uns]?
def $n:ident$[.{$us,*}]? :=
value_proj @wrapped.{$us',*}
val_proj @wrapped.{$us',*}
$[private%$priv]? $[$uns:unsafe]? theorem $n_def:ident $[.{$us,*}]? :
eta_helper Eq @$n.{$us',*} @(delta% @definition) := by
intros
simp only [$n:ident]
rw [wrapped.prop]
attribute [irreducible] $n
attribute [eqns $n_def] $n)
delta $n:ident
rw [show wrapped = ⟨@definition.{$us',*}, rfl⟩ from Subtype.ext wrapped.2.symm]
rfl
attribute [irreducible] $n definition
attribute [eqns $n_def] $n
attribute [$attrs:attrInstance,*] $n)
if prot.isSome then
modifyEnv (addProtected · ((← getCurrNamespace) ++ n.getId))
35 changes: 25 additions & 10 deletions Mathlib/Tactic/ToAdditive.lean
Expand Up @@ -18,6 +18,7 @@ import Std.Tactic.Lint -- useful to lint this file and for for DiscrTree.element
import Mathlib.Tactic.Relation.Rfl -- just to copy the attribute
import Mathlib.Tactic.Relation.Symm -- just to copy the attribute
import Mathlib.Tactic.Relation.Trans -- just to copy the attribute
import Mathlib.Tactic.Eqns -- just to copy the attribute
import Mathlib.Tactic.Simps.Basic

/-!
Expand Down Expand Up @@ -488,6 +489,9 @@ def updateDecl
decl := decl.updateType <| ← applyReplacementFun <| ← reorderForall (← expand decl.type) reorder
if let some v := decl.value? then
decl := decl.updateValue <| ← applyReplacementFun <| ← reorderLambda (← expand v) reorder
else if let .opaqueInfo info := decl then -- not covered by `value?`
decl := .opaqueInfo { info with
value := ← applyReplacementFun <| ← reorderLambda (← expand info.value) reorder }
return decl

/-- Find the target name of `pre` and all created auxiliary declarations. -/
Expand All @@ -507,6 +511,8 @@ def findTargetName (env : Environment) (src pre tgt_pre : Name) : CoreM Name :=
-- Todo: we do not currently check whether such lemmas actually should be additivized.
else if let some post := env.mainModule ++ `_auxLemma |>.isPrefixOf? src then
return env.mainModule ++ `_auxAddLemma ++ post
else if src.hasMacroScopes then
mkFreshUserName src.eraseMacroScopes
else
throwError "internal @[to_additive] error."

Expand All @@ -521,7 +527,7 @@ attribute. We will only translate it has the `@[to_additive]` attribute.
def findAuxDecls (e : Expr) (pre mainModule : Name) : NameSet :=
let auxLemma := mainModule ++ `_auxLemma
e.foldConsts ∅ fun n l ↦
if n.getPrefix == pre || n.getPrefix == auxLemma || isPrivateName n then
if n.getPrefix == pre || n.getPrefix == auxLemma || isPrivateName n || n.hasMacroScopes then
l.insert n
else
l
Expand Down Expand Up @@ -561,20 +567,24 @@ partial def transformDeclAux
if let some value := srcDecl.value? then
for n in findAuxDecls value pre env.mainModule do
transformDeclAux cfg pre tgt_pre n
if let .opaqueInfo {value, ..} := srcDecl then
for n in findAuxDecls value pre env.mainModule do
transformDeclAux cfg pre tgt_pre n
-- if the auxilliary declaration doesn't have prefix `pre`, then we have to add this declaration
-- to the translation dictionary, since otherwise we cannot find the additive name.
if !pre.isPrefixOf src then
insertTranslation src tgt
-- now transform the source declaration
let trgDecl : ConstantInfo ←
MetaM.run' <| updateDecl tgt srcDecl <| if src == pre then cfg.reorder else []
if !trgDecl.hasValue then
throwError "Expected {tgt} to have a value."
trace[to_additive] "generating\n{tgt} : {trgDecl.type} :=\n {trgDecl.value!}"
let value ← match trgDecl with
| .thmInfo { value, .. } | .defnInfo { value, .. } | .opaqueInfo { value, .. } => pure value
| _ => throwError "Expected {tgt} to have a value."
trace[to_additive] "generating\n{tgt} : {trgDecl.type} :=\n {value}"
try
-- make sure that the type is correct,
-- and emit a more helpful error message if it fails
discard <| MetaM.run' <| inferType trgDecl.value!
discard <| MetaM.run' <| inferType value
catch
| Exception.error _ msg => throwError "@[to_additive] failed.
Type mismatch in additive declaration. For help, see the docstring
Expand Down Expand Up @@ -961,11 +971,16 @@ partial def applyAttributes (stx : Syntax) (rawAttrs : Array Syntax) (thisAttr s
Copies equation lemmas and attributes from `src` to `tgt`
-/
partial def copyMetaData (cfg : Config) (src tgt : Name) : CoreM (Array Name) := do
/- We need to generate all equation lemmas for `src` and `tgt`, even for non-recursive
definitions. If we don't do that, the equation lemma for `src` might be generated later
when doing a `rw`, but it won't be generated for `tgt`. -/
additivizeLemmas #[src, tgt] "equation lemmas" fun nm ↦
(·.getD #[]) <$> MetaM.run' (getEqnsFor? nm true)
if let some eqns := eqnsAttribute.find? (← getEnv) src then
unless (eqnsAttribute.find? (← getEnv) tgt).isSome do
for eqn in eqns do _ ← addToAdditiveAttr eqn cfg
eqnsAttribute.add tgt (eqns.map (findTranslation? (← getEnv) · |>.get!))
else
/- We need to generate all equation lemmas for `src` and `tgt`, even for non-recursive
definitions. If we don't do that, the equation lemma for `src` might be generated later
when doing a `rw`, but it won't be generated for `tgt`. -/
additivizeLemmas #[src, tgt] "equation lemmas" fun nm ↦
(·.getD #[]) <$> MetaM.run' (getEqnsFor? nm true)
MetaM.run' <| Elab.Term.TermElabM.run' <|
applyAttributes cfg.ref cfg.attrs `to_additive src tgt

Expand Down
5 changes: 0 additions & 5 deletions Mathlib/Topology/Algebra/InfiniteSum/Basic.lean
Expand Up @@ -64,11 +64,6 @@ def Summable (f : β → α) : Prop :=
∃ a, HasSum f a
#align summable Summable

-- Porting note: `irreducible_def` produces a structure.
-- When a structure is defined, an injectivity theorem of the constructor is
-- generated, which has `simp` attr, but this get a `simpNF` linter.
-- So, this option is required.
set_option genInjectivity false in
/-- `∑' i, f i` is the sum of `f` it exists, or 0 otherwise -/
irreducible_def tsum {β} (f : β → α) :=
if h : Summable f then Classical.choose h else 0
Expand Down
8 changes: 7 additions & 1 deletion test/irreducibleDef.lean
@@ -1,4 +1,5 @@
import Mathlib.Tactic.IrreducibleDef
import Mathlib.Util.WhatsNew

/-- Add two natural numbers, but not during unification. -/
irreducible_def frobnicate (a b : Nat) :=
Expand All @@ -10,9 +11,11 @@ example : frobnicate a 0 = a := by
example : frobnicate a 0 = a :=
frobnicate_def a 0

irreducible_def justAsArbitrary [Inhabited α] : α :=
irreducible_def justAsArbitrary (lemma := myLemma) [Inhabited α] : α :=
default

example : justAsArbitrary = 0 := myLemma

irreducible_def withoutType := 42

irreducible_def withEquations : Nat → Nat
Expand All @@ -35,3 +38,6 @@ protected noncomputable irreducible_def Nat.evenMoreArbitrary : Nat :=

private irreducible_def Real.zero := 42
example : Real.zero = 42 := Real.zero_def

irreducible_def y : Nat := let x := 42; x
example : y = 42 := @y_def
8 changes: 8 additions & 0 deletions test/toAdditiveIrredDef.lean
@@ -0,0 +1,8 @@
import Mathlib.Tactic.IrreducibleDef
import Mathlib.Algebra.Group.Defs

@[to_additive]
irreducible_def mul_conj [Group G] (a b : G) := a⁻¹ * b * a

example [AddGroup A] (a b : A) : add_conj a b = (-a) + b + a :=
add_conj_def a b

0 comments on commit 27e0a79

Please sign in to comment.