Skip to content

Commit 822c870

Browse files
committed
fix: fix to_additive handling of Nat and Int (#689)
* Remove `[to_additive]` attributes from `Nat` and `Int`. That attribute disables all useful heuristics for `to_additive` on these types. * deal with literals in specific types, like `Nat`. This uses a new attribute `@[to_additive_fixed_numeral]`. * Add `to_additive_relevant_arg` attribute for heterogenous operations * Fix `to_additive_reorder` arguments on `HPow.hPow` * Fix parsing of `to_additive_relevant_arg` to use 1-indexed numbers instead of 0-indexed numbers * Do not add `relevantArgAttr` on many declarations if they would have the default value anyway. * Revert test `foo2` in `tests/ToAdditive` to the Lean 3 behavior. Previosuly the test would only succeed if the heuristics for `Nat` and `Int` were disabled. * Some cleanup in code, comments and docs
1 parent f017f32 commit 822c870

File tree

5 files changed

+104
-51
lines changed

5 files changed

+104
-51
lines changed

Mathlib/Algebra/Group/Defs.lean

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,16 @@ infixl:65 " +ᵥ " => VAdd.vadd
6060
infixl:65 " -ᵥ " => HasVsub.vsub
6161
infixr:73 " • " => SMul.smul
6262

63-
-- [todo] is this correct? I think it's needed to ensure that additiveTest
64-
-- succeeds if the relevant arg involves Nat.
65-
attribute [to_additive] Nat
66-
attribute [to_additive] Int
67-
6863
attribute [to_additive] Mul
6964
attribute [to_additive] Div
7065
attribute [to_additive] HMul
7166
attribute [to_additive] instHMul
7267
attribute [to_additive] HDiv
7368
attribute [to_additive] instHDiv
7469

70+
attribute [to_additive_relevant_arg 3] HMul HAdd HAdd.hAdd HMul.hMul
7571
attribute [to_additive_reorder 1] HPow
76-
attribute [to_additive_reorder 1 4] HPow.hPow
72+
attribute [to_additive_reorder 1 5] HPow.hPow
7773
attribute [to_additive] HPow
7874

7975
universe u
@@ -482,7 +478,7 @@ need right away.
482478
483479
In the definition, we use `n.succ` instead of `n + 1` in the `nsmul_succ'` and `npow_succ'` fields
484480
to make sure that `to_additive` is not confused (otherwise, it would try to convert `1 : ℕ`
485-
to `0 : ℕ`).
481+
to `0 : ℕ`). Todo: fix this in `to_additive`
486482
-/
487483

488484

Mathlib/Data/Fin/Basic.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ section
7474

7575
variable {n : Nat} [Nonempty (Fin n)]
7676

77+
@[to_additive_fixed_numeral]
7778
instance : OfNat (Fin n) a where
7879
ofNat := Fin.ofNat' a Fin.size_positive'
7980

Mathlib/Init/ZeroOne.lean

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,17 @@ class One (α : Type u) where
2424
one : α
2525
#align has_one One
2626

27-
@[to_additive Zero.toOfNat0]
27+
@[to_additive Zero.toOfNat0, to_additive_fixed_numeral ?]
2828
instance One.toOfNat1 {α} [One α] : OfNat α (nat_lit 1) where
2929
ofNat := ‹One α›.1
30-
@[to_additive Zero.ofOfNat0]
30+
@[to_additive Zero.ofOfNat0, to_additive_fixed_numeral ?]
3131
instance One.ofOfNat1 {α} [OfNat α (nat_lit 1)] : One α where
3232
one := 1
3333

3434
@[deprecated, match_pattern] def bit0 {α : Type u} [Add α] (a : α) : α := a + a
3535

3636
set_option linter.deprecated false in
3737
@[deprecated, match_pattern] def bit1 {α : Type u} [One α] [Add α] (a : α) : α := bit0 a + 1
38+
39+
attribute [to_additive_fixed_numeral ?] OfNat OfNat.ofNat
40+
attribute [to_additive_fixed_numeral] instOfNatNat instOfNatInt

Mathlib/Tactic/ToAdditive.lean

Lines changed: 62 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ syntax (name := to_additive_ignore_args) "to_additive_ignore_args" num* : attr
3333
syntax (name := to_additive_relevant_arg) "to_additive_relevant_arg" num : attr
3434
/-- The `to_additive_reorder` attribute. -/
3535
syntax (name := to_additive_reorder) "to_additive_reorder" num* : attr
36+
/-- The `to_additive_fixed_numeral` attribute. -/
37+
syntax (name := to_additive_fixed_numeral) "to_additive_fixed_numeral" "?"? : attr
3638
/-- Remaining arguments of `to_additive`. -/
3739
syntax to_additiveRest := (ppSpace ident)? (ppSpace str)?
3840
/-- The `to_additive` attribute. -/
@@ -104,7 +106,7 @@ initialize ignoreArgsAttr : NameMapExtension (List Nat) ←
104106
add := fun _ stx ↦ do
105107
let ids ← match stx with
106108
| `(attr| to_additive_ignore_args $[$ids:num]*) => pure <| ids.map (·.1.isNatLit?.get!)
107-
| _ => throwError "unexpected to_additive_ignore_args syntax {stx}"
109+
| _ => throwUnsupportedSyntax
108110
return ids.toList
109111
}
110112

@@ -131,8 +133,7 @@ initialize reorderAttr : NameMapExtension (List Nat) ←
131133
add := fun
132134
| _, `(attr| to_additive_reorder $[$ids:num]*) =>
133135
pure <| Array.toList <| ids.map (·.1.isNatLit?.get!)
134-
| _, stx => throwError "unexpected to_additive_reorder syntax {stx}"
135-
}
136+
| _, _ => throwUnsupportedSyntax }
136137

137138
/-- Get the reorder list (defined using `@[to_additive_reorder ...]`) for the given declaration. -/
138139
def getReorder [Functor M] [MonadEnv M]: Name → M (List Nat)
@@ -170,9 +171,8 @@ initialize relevantArgAttr : NameMapExtension Nat ←
170171
descr := "Auxiliary attribute for `to_additive` stating" ++
171172
" which arguments are the types with a multiplicative structure."
172173
add := fun
173-
| _, `(attr| to_additive_relevant_arg $id) => pure <| id.1.isNatLit?.get!
174-
| _, stx => throwError "unexpected to_additive_relevant_arg syntax {stx}"
175-
}
174+
| _, `(attr| to_additive_relevant_arg $id) => pure <| id.1.isNatLit?.get!.pred
175+
| _, _ => throwUnsupportedSyntax }
176176

177177
/-- Given a declaration name and an argument index, determines whether it
178178
is relevant. This is used in `applyReplacementFun` where more detail on what it
@@ -182,6 +182,25 @@ def isRelevant [Monad M] [MonadEnv M] (n : Name) (i : Nat) : M Bool := do
182182
| some j => return i == j
183183
| none => return i == 0
184184

185+
/--
186+
An attribute that stores all the declarations that deal with numeric literals on fixed types.
187+
* `@[to_additive_fixed_numeral]` should be added to all functions that take a numeral as argument
188+
that should never be changed by `@[to_additive]` (because it represents a numeral in a fixed
189+
type).
190+
* `@[to_additive_fixed_numeral?]` should be added to all functions that take a numeral as argument
191+
that should only be changed if `additiveTest` succeeds on the first argument, i.e. when the
192+
numeral is only translated if the first argument is a variable (or consists of variables).
193+
-/
194+
initialize fixedNumeralAttr : NameMapExtension Bool ←
195+
registerNameMapAttribute {
196+
name := `to_additive_fixed_numeral
197+
descr :=
198+
"Auxiliary attribute for `to_additive` that stores functions that have numerals as argument."
199+
add := fun
200+
| _, `(attr| to_additive_fixed_numeral $[?%$conditional]?) =>
201+
pure <| conditional.isSome
202+
| _, _ => throwUnsupportedSyntax }
203+
185204
/-- Maps multiplicative names to their additive counterparts. -/
186205
initialize translations : NameMapExtension Name ← registerNameMapExtension _
187206

@@ -238,18 +257,23 @@ def additiveTest (e : Expr) : M Bool := do
238257
else
239258
additiveTestAux false e
240259

260+
/-- Checks whether a numeral should be translated. -/
261+
def shouldTranslateNumeral [Monad M] [MonadEnv M] (n : Name) (firstArg : Expr) : M Bool := do
262+
match fixedNumeralAttr.find? (← getEnv) n with
263+
| some true => additiveTest firstArg
264+
| some false => return false
265+
| none => return true
266+
241267
/--
242-
`e.applyReplacementFun f test` applies `f` to each identifier
243-
(inductive type, defined function etc) in an expression, unless
268+
`applyReplacementFun e` replaces the expression `e` with its additive counterpart.
269+
It translates each identifier (inductive type, defined function etc) in an expression, unless
244270
* The identifier occurs in an application with first argument `arg`; and
245271
* `test arg` is false.
246272
However, if `f` is in the dictionary `relevant`, then the argument `relevant.find f`
247273
is tested, instead of the first argument.
248274
249-
Reorder contains the information about what arguments to reorder:
250-
e.g. `g x₁ x₂ x₃ ... xₙ` becomes `g x₂ x₁ x₃ ... xₙ` if `reorder.find g = some [1]`.
251-
We assume that all functions where we want to reorder arguments are fully applied.
252-
This can be done by applying `etaExpand` first.
275+
It will also reorder arguments of certain functions, using `shouldReorder`:
276+
e.g. `g x₁ x₂ x₃ ... xₙ` becomes `g x₂ x₁ x₃ ... xₙ` if `reorderAttr.find? env g = some [1]`.
253277
-/
254278
def applyReplacementFun : Expr → MetaM Expr :=
255279
Lean.Expr.replaceRecMeta fun r e ↦ do
@@ -258,7 +282,8 @@ def applyReplacementFun : Expr → MetaM Expr :=
258282
| .lit (.natVal 1) => pure <| mkRawNatLit 0
259283
| .const n₀ ls => do
260284
let n₁ := Name.mapPrefix (findTranslation? <|← getEnv) n₀
261-
trace[to_additive_detail] "applyReplacementFun: {n₀} → {n₁}"
285+
if n₀ != n₁ then
286+
trace[to_additive_detail] "applyReplacementFun: {n₀} → {n₁}"
262287
let ls : List Level ← (do -- [todo] just get Lean to figure out the levels?
263288
if ← shouldReorder n₀ 1 then
264289
return ls.get! 1::ls.head!::ls.drop 2
@@ -270,6 +295,7 @@ def applyReplacementFun : Expr → MetaM Expr :=
270295
let gArgs := g.getAppArgs
271296
-- e = `(nm y₁ .. yₙ x)
272297
trace[to_additive_detail] "applyReplacementFun: app {nm} {gArgs} {x}"
298+
/- Test if arguments should be reordered. -/
273299
if h : gArgs.size > 0 then
274300
let c1 ← shouldReorder nm gArgs.size
275301
let c2 ← additiveTest gArgs[0]
@@ -282,16 +308,25 @@ def applyReplacementFun : Expr → MetaM Expr :=
282308
trace[to_additive_detail]
283309
"applyReplacementFun: reordering {nm}: {x} ↔ {ga}\nBefore: {e}\nAfter: {e₂}"
284310
return some e₂
311+
/- Test if the head should not be replaced. -/
285312
let c1 ← isRelevant nm gArgs.size
286313
let c2 := gf.isConst
287314
let c3 ← additiveTest x
315+
if c1 && c2 && c3 then
316+
trace[to_additive_detail]
317+
"applyReplacementFun: {x} doesn't contain a fixed type, so we will change {nm}"
288318
if c1 && c2 && not c3 then
289319
-- the test failed, so don't update the function body.
290320
trace[to_additive_detail]
291-
"applyReplacementFun: isRelevant and test failed: {nm} {gArgs} {x}"
321+
"applyReplacementFun: {x} contains a fixed type, so {nm} is not changed"
292322
let x ← r x
293323
let args ← gArgs.mapM r
294324
return some $ mkApp (mkAppN gf args) x
325+
/- Do not replace numerals in specific types. -/
326+
let firstArg := if h : gArgs.size > 0 then gArgs[0] else x
327+
if not (← shouldTranslateNumeral nm firstArg) then
328+
trace[to_additive_detail] "applyReplacementFun: Do not change numeral {g.app x}"
329+
return some <| g.app x
295330
return e.updateApp! (← r g) (← r x)
296331
| _ => return none
297332

@@ -366,7 +401,7 @@ partial def transformDeclAux
366401
if env.contains tgt then
367402
return
368403
let srcDecl ← getConstInfo src
369-
-- we first transform all the declarations of the form `pre._proof_i`
404+
-- we first transform all auxilliary declarations generated when elaborating `pre`
370405
for n in srcDecl.type.listNamesWithPrefix pre do
371406
transformDeclAux none pre tgt_pre n
372407
if let some value := srcDecl.value? then
@@ -450,7 +485,7 @@ Find the first argument of `nm` that has a multiplicative type-class on it.
450485
Returns 1 if there are no types with a multiplicative class as arguments.
451486
E.g. `prod.group` returns 1, and `pi.has_one` returns 2.
452487
-/
453-
def firstMultiplicativeArg (nm : Name) : MetaM (Option Nat) := do
488+
def firstMultiplicativeArg (nm : Name) : MetaM Nat := do
454489
forallTelescopeReducing (← getConstInfo nm).type fun xs _ ↦ do
455490
-- xs are the arguments to the constant
456491
let xs := xs.toList
@@ -466,8 +501,8 @@ def firstMultiplicativeArg (nm : Name) : MetaM (Option Nat) := do
466501
xs.findIdx? fun x ↦ Expr.containsFVar tgt_arg x.fvarId!
467502
trace[to_additive_detail] "firstMultiplicativeArg: {l}"
468503
match l.join with
469-
| [] => return none
470-
| (head :: tail) => return some <| tail.foldr Nat.min head
504+
| [] => return 0
505+
| (head :: tail) => return tail.foldr Nat.min head
471506

472507
/-- `ValueType` is the type of the arguments that can be provided to `to_additive`. -/
473508
structure ValueType : Type where
@@ -633,12 +668,12 @@ def targetName (src tgt : Name) (allowAutoName : Bool) : CoreM Name := do
633668
throwError "to_additive: can't transport {src} to itself."
634669
return res
635670

636-
private def proceedFieldsAux (src tgt : Name) (f : Name → CoreM (List String)) : CoreM Unit := do
671+
private def proceedFieldsAux (src tgt : Name) (f : Name → CoreM (Array Name)) : CoreM Unit := do
637672
let srcFields ← f src
638673
let tgtFields ← f tgt
639-
if srcFields.length != tgtFields.length then
674+
if srcFields.size != tgtFields.size then
640675
throwError "Failed to map fields of {src}, {tgt} with {srcFields} ↦ {tgtFields}"
641-
for (srcField, tgtField) in List.zip srcFields tgtFields do
676+
for (srcField, tgtField) in srcFields.zip tgtFields do
642677
if srcField != tgtField then
643678
insertTranslation (src ++ srcField) (tgt ++ tgtField)
644679

@@ -647,12 +682,9 @@ so that future uses of `to_additive` will map them to the corresponding `tgt` fi
647682
def proceedFields (src tgt : Name) : CoreM Unit := do
648683
let env : Environment ← getEnv
649684
let aux := proceedFieldsAux src tgt
650-
aux (fun n ↦ do
651-
let fields := if isStructure env n then getStructureFieldsFlattened env n else #[]
652-
return fields |> .map Name.toString |> Array.toList
653-
)
654-
-- [todo] run to_additive on the constructors of n:
655-
-- aux (fun n ↦ (env.constructorsOf n).mmap $ ...
685+
aux fun n ↦ pure <| if isStructure env n then getStructureFields env n else #[]
686+
-- We don't have to run toAdditive on the constructor of a structure, since the use of
687+
-- `Name.mapPrefix` will do that automatically.
656688

657689
private def elabToAdditiveAux (ref : Syntax) (replaceAll trace : Bool) (tgt : Option Syntax)
658690
(doc : Option Syntax) : ValueType :=
@@ -661,8 +693,7 @@ private def elabToAdditiveAux (ref : Syntax) (replaceAll trace : Bool) (tgt : Op
661693
tgt := match tgt with | some tgt => tgt.getId | none => Name.anonymous
662694
doc := doc.bind (·.isStrLit?)
663695
allowAutoName := false
664-
ref
665-
}
696+
ref }
666697

667698
private def elabToAdditive : Syntax → CoreM ValueType
668699
| `(attr| to_additive%$tk $[!%$replaceAll]? $[?%$trace]? $[$tgt]? $[$doc]?) =>
@@ -676,7 +707,8 @@ def addToAdditiveAttr (src : Name) (val : ValueType) : AttrM Unit := do
676707
if let some tgt' := findTranslation? (← getEnv) src then
677708
throwError "{src} already has a to_additive translation {tgt'}."
678709
insertTranslation src tgt
679-
if let some firstMultArg ← (MetaM.run' <| firstMultiplicativeArg src) then
710+
let firstMultArg ← MetaM.run' <| firstMultiplicativeArg src
711+
if firstMultArg != 0 then
680712
trace[to_additive_detail] "Setting relevant_arg for {src} to be {firstMultArg}."
681713
relevantArgAttr.add src firstMultArg
682714
if (← getEnv).contains tgt then

test/toAdditive.lean

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,14 @@ theorem bar1_works : bar1 3 4 = 3 * 4 := by decide
4141

4242
infix:80 " ^ " => my_has_pow.pow
4343

44-
instance dummy_pow : my_has_pow ℕ $ PLift ℤ := ⟨fun _ _ => 0
45-
instance dummy_smul : my_has_scalar (PLift ℤ) ℕ := ⟨fun _ _ => 0
46-
attribute [to_additive dummy_smul] dummy_pow
44+
instance dummy_pow : my_has_pow ℕ $ PLift ℤ := ⟨fun _ _ => 5
4745

4846
set_option pp.universes true
4947
@[to_additive bar2]
5048
def foo2 {α} [my_has_pow α ℕ] (x : α) (n : ℕ) (m : PLift ℤ) : α := x ^ (n ^ m)
5149

52-
theorem foo2_works : foo2 2 3 (PLift.up 2) = Nat.pow 2 0 := by decide
53-
-- [todo] should it still be using dummy?
54-
theorem bar2_works : bar2 2 3 (PLift.up 2) = 2 * (dummy_smul.1 (PLift.up 2) 3) := by decide
50+
theorem foo2_works : foo2 2 3 (PLift.up 2) = Nat.pow 2 5 := by decide
51+
theorem bar2_works : bar2 2 3 (PLift.up 2) = 2 * 5 := by decide
5552

5653
@[to_additive bar3]
5754
def foo3 {α} [my_has_pow α ℕ] (x : α) : ℕ → α := @my_has_pow.pow α ℕ _ x
@@ -77,6 +74,28 @@ def foo7 := @my_has_pow.pow
7774
theorem foo7_works : foo7 2 3 = Nat.pow 2 3 := by decide
7875
theorem bar7_works : bar7 2 3 = 2 * 3 := by decide
7976

77+
/-- Check that we don't additivize `Nat` expressions. -/
78+
@[to_additive bar8]
79+
def foo8 (a b : ℕ) := a * b
80+
81+
theorem bar8_works : bar8 2 3 = 6 := by decide
82+
83+
/-- Check that we don't additivize `Nat` numerals. -/
84+
@[to_additive bar9]
85+
def foo9 := 1
86+
87+
theorem bar9_works : bar9 = 1 := by decide
88+
89+
@[to_additive bar10]
90+
def foo10 (n m : ℕ) := HPow.hPow n m + n * m * 2 + 1 * 0 + 37 * 1 + 2
91+
92+
theorem bar10_works : bar10 = foo10 := by rfl
93+
94+
@[to_additive bar11]
95+
def foo11 (n : ℕ) (m : ℤ) := n * m * 2 + 1 * 0 + 37 * 1 + 2
96+
97+
theorem bar11_works : bar11 = foo11 := by rfl
98+
8099
/- test the eta-expansion applied on `foo6`. -/
81100
run_cmd do
82101
let c ← getConstInfo `Test.foo6
@@ -106,9 +125,9 @@ run_cmd do
106125
Elab.Command.liftCoreM <| successIfFail (getConstInfo `Test.add_some_def.in_namespace)
107126

108127
-- [todo] currently this test breaks.
109-
-- example : (add_units.mk_of_add_eq_zero 0 0 (by simp) : ℕ)
110-
-- = (add_units.mk_of_add_eq_zero 0 0 (by simp) : ℕ) :=
111-
-- by normCast
128+
-- example : (AddUnits.mk_of_add_eq_zero 0 0 (by simp) : ℕ)
129+
-- = (AddUnits.mk_of_add_eq_zero 0 0 (by simp) : ℕ) :=
130+
-- by norm_cast
112131

113132
section
114133

@@ -126,10 +145,10 @@ instance pi.has_one {I : Type} {f : I → Type} [(i : I) → One $ f i] : One ((
126145
run_cmd do
127146
let n ← (Elab.Command.liftCoreM <| Lean.Meta.MetaM.run' <| ToAdditive.firstMultiplicativeArg
128147
`Test.pi.has_one)
129-
if n != some 1 then throwError "{n} != 1"
148+
if n != 1 then throwError "{n} != 1"
130149
let n ← (Elab.Command.liftCoreM <| Lean.Meta.MetaM.run' <| ToAdditive.firstMultiplicativeArg
131150
`Test.foo_mul)
132-
if n != some 4 then throwError "{n} != 4"
151+
if n != 4 then throwError "{n} != 4"
133152

134153
end
135154

@@ -139,6 +158,8 @@ def nat_pi_has_one {α : Type} [One α] : One ((x : Nat) → α) := by infer_ins
139158
@[to_additive]
140159
def pi_nat_has_one {I : Type} : One ((x : I) → Nat) := pi.has_one
141160

161+
example : @pi_nat_has_one = @pi_nat_has_zero := rfl
162+
142163
section noncomputablee
143164

144165
@[to_additive Bar.bar]
@@ -147,7 +168,7 @@ noncomputable def Foo.foo (h : ∃ _ : α, True) : α := Classical.choose h
147168
@[to_additive Bar.bar']
148169
def Foo.foo' : ℕ := 2
149170

150-
#eval Bar.bar'
171+
theorem Bar.bar'_works : Bar.bar' = 2 := by decide
151172

152173
run_cmd (do
153174
if !isNoncomputable (← getEnv) `Bar.bar then throwError "bar shouldn't be computable"

0 commit comments

Comments
 (0)