@@ -33,6 +33,8 @@ syntax (name := to_additive_ignore_args) "to_additive_ignore_args" num* : attr
33
33
syntax (name := to_additive_relevant_arg) "to_additive_relevant_arg" num : attr
34
34
/-- The `to_additive_reorder` attribute. -/
35
35
syntax (name := to_additive_reorder) "to_additive_reorder" num* : attr
36
+ /-- The `to_additive_fixed_numeral` attribute. -/
37
+ syntax (name := to_additive_fixed_numeral) "to_additive_fixed_numeral" "?" ? : attr
36
38
/-- Remaining arguments of `to_additive`. -/
37
39
syntax to_additiveRest := (ppSpace ident)? (ppSpace str)?
38
40
/-- The `to_additive` attribute. -/
@@ -104,7 +106,7 @@ initialize ignoreArgsAttr : NameMapExtension (List Nat) ←
104
106
add := fun _ stx ↦ do
105
107
let ids ← match stx with
106
108
| `(attr| to_additive_ignore_args $[$ids:num]*) => pure <| ids.map (·.1 .isNatLit?.get!)
107
- | _ => throwError "unexpected to_additive_ignore_args syntax {stx}"
109
+ | _ => throwUnsupportedSyntax
108
110
return ids.toList
109
111
}
110
112
@@ -131,8 +133,7 @@ initialize reorderAttr : NameMapExtension (List Nat) ←
131
133
add := fun
132
134
| _, `(attr| to_additive_reorder $[$ids:num]*) =>
133
135
pure <| Array.toList <| ids.map (·.1 .isNatLit?.get!)
134
- | _, stx => throwError "unexpected to_additive_reorder syntax {stx}"
135
- }
136
+ | _, _ => throwUnsupportedSyntax }
136
137
137
138
/-- Get the reorder list (defined using `@[to_additive_reorder ...]`) for the given declaration. -/
138
139
def getReorder [Functor M] [MonadEnv M]: Name → M (List Nat)
@@ -170,9 +171,8 @@ initialize relevantArgAttr : NameMapExtension Nat ←
170
171
descr := "Auxiliary attribute for `to_additive` stating" ++
171
172
" which arguments are the types with a multiplicative structure."
172
173
add := fun
173
- | _, `(attr| to_additive_relevant_arg $id) => pure <| id.1 .isNatLit?.get!
174
- | _, stx => throwError "unexpected to_additive_relevant_arg syntax {stx}"
175
- }
174
+ | _, `(attr| to_additive_relevant_arg $id) => pure <| id.1 .isNatLit?.get!.pred
175
+ | _, _ => throwUnsupportedSyntax }
176
176
177
177
/-- Given a declaration name and an argument index, determines whether it
178
178
is relevant. This is used in `applyReplacementFun` where more detail on what it
@@ -182,6 +182,25 @@ def isRelevant [Monad M] [MonadEnv M] (n : Name) (i : Nat) : M Bool := do
182
182
| some j => return i == j
183
183
| none => return i == 0
184
184
185
+ /--
186
+ An attribute that stores all the declarations that deal with numeric literals on fixed types.
187
+ * `@[to_additive_fixed_numeral]` should be added to all functions that take a numeral as argument
188
+ that should never be changed by `@[to_additive]` (because it represents a numeral in a fixed
189
+ type).
190
+ * `@[to_additive_fixed_numeral?]` should be added to all functions that take a numeral as argument
191
+ that should only be changed if `additiveTest` succeeds on the first argument, i.e. when the
192
+ numeral is only translated if the first argument is a variable (or consists of variables).
193
+ -/
194
+ initialize fixedNumeralAttr : NameMapExtension Bool ←
195
+ registerNameMapAttribute {
196
+ name := `to_additive_fixed_numeral
197
+ descr :=
198
+ "Auxiliary attribute for `to_additive` that stores functions that have numerals as argument."
199
+ add := fun
200
+ | _, `(attr| to_additive_fixed_numeral $[?%$conditional]?) =>
201
+ pure <| conditional.isSome
202
+ | _, _ => throwUnsupportedSyntax }
203
+
185
204
/-- Maps multiplicative names to their additive counterparts. -/
186
205
initialize translations : NameMapExtension Name ← registerNameMapExtension _
187
206
@@ -238,18 +257,23 @@ def additiveTest (e : Expr) : M Bool := do
238
257
else
239
258
additiveTestAux false e
240
259
260
+ /-- Checks whether a numeral should be translated. -/
261
+ def shouldTranslateNumeral [Monad M] [MonadEnv M] (n : Name) (firstArg : Expr) : M Bool := do
262
+ match fixedNumeralAttr.find? (← getEnv) n with
263
+ | some true => additiveTest firstArg
264
+ | some false => return false
265
+ | none => return true
266
+
241
267
/--
242
- `e. applyReplacementFun f test` applies `f` to each identifier
243
- (inductive type, defined function etc) in an expression, unless
268
+ `applyReplacementFun e` replaces the expression `e` with its additive counterpart.
269
+ It translates each identifier (inductive type, defined function etc) in an expression, unless
244
270
* The identifier occurs in an application with first argument `arg`; and
245
271
* `test arg` is false.
246
272
However, if `f` is in the dictionary `relevant`, then the argument `relevant.find f`
247
273
is tested, instead of the first argument.
248
274
249
- Reorder contains the information about what arguments to reorder:
250
- e.g. `g x₁ x₂ x₃ ... xₙ` becomes `g x₂ x₁ x₃ ... xₙ` if `reorder.find g = some [1]`.
251
- We assume that all functions where we want to reorder arguments are fully applied.
252
- This can be done by applying `etaExpand` first.
275
+ It will also reorder arguments of certain functions, using `shouldReorder`:
276
+ e.g. `g x₁ x₂ x₃ ... xₙ` becomes `g x₂ x₁ x₃ ... xₙ` if `reorderAttr.find? env g = some [1]`.
253
277
-/
254
278
def applyReplacementFun : Expr → MetaM Expr :=
255
279
Lean.Expr.replaceRecMeta fun r e ↦ do
@@ -258,7 +282,8 @@ def applyReplacementFun : Expr → MetaM Expr :=
258
282
| .lit (.natVal 1 ) => pure <| mkRawNatLit 0
259
283
| .const n₀ ls => do
260
284
let n₁ := Name.mapPrefix (findTranslation? <|← getEnv) n₀
261
- trace[to_additive_detail] "applyReplacementFun: {n₀} → {n₁}"
285
+ if n₀ != n₁ then
286
+ trace[to_additive_detail] "applyReplacementFun: {n₀} → {n₁}"
262
287
let ls : List Level ← (do -- [ todo ] just get Lean to figure out the levels?
263
288
if ← shouldReorder n₀ 1 then
264
289
return ls.get! 1 ::ls.head!::ls.drop 2
@@ -270,6 +295,7 @@ def applyReplacementFun : Expr → MetaM Expr :=
270
295
let gArgs := g.getAppArgs
271
296
-- e = `(nm y₁ .. yₙ x)
272
297
trace[to_additive_detail] "applyReplacementFun: app {nm} {gArgs} {x}"
298
+ /- Test if arguments should be reordered. -/
273
299
if h : gArgs.size > 0 then
274
300
let c1 ← shouldReorder nm gArgs.size
275
301
let c2 ← additiveTest gArgs[0 ]
@@ -282,16 +308,25 @@ def applyReplacementFun : Expr → MetaM Expr :=
282
308
trace[to_additive_detail]
283
309
"applyReplacementFun: reordering {nm}: {x} ↔ {ga}\n Before: {e}\n After: {e₂}"
284
310
return some e₂
311
+ /- Test if the head should not be replaced. -/
285
312
let c1 ← isRelevant nm gArgs.size
286
313
let c2 := gf.isConst
287
314
let c3 ← additiveTest x
315
+ if c1 && c2 && c3 then
316
+ trace[to_additive_detail]
317
+ "applyReplacementFun: {x} doesn't contain a fixed type, so we will change {nm}"
288
318
if c1 && c2 && not c3 then
289
319
-- the test failed, so don't update the function body.
290
320
trace[to_additive_detail]
291
- "applyReplacementFun: isRelevant and test failed: {nm} {gArgs} {x} "
321
+ "applyReplacementFun: {x} contains a fixed type, so {nm} is not changed "
292
322
let x ← r x
293
323
let args ← gArgs.mapM r
294
324
return some $ mkApp (mkAppN gf args) x
325
+ /- Do not replace numerals in specific types. -/
326
+ let firstArg := if h : gArgs.size > 0 then gArgs[0 ] else x
327
+ if not (← shouldTranslateNumeral nm firstArg) then
328
+ trace[to_additive_detail] "applyReplacementFun: Do not change numeral {g.app x}"
329
+ return some <| g.app x
295
330
return e.updateApp! (← r g) (← r x)
296
331
| _ => return none
297
332
@@ -366,7 +401,7 @@ partial def transformDeclAux
366
401
if env.contains tgt then
367
402
return
368
403
let srcDecl ← getConstInfo src
369
- -- we first transform all the declarations of the form `pre._proof_i `
404
+ -- we first transform all auxilliary declarations generated when elaborating `pre`
370
405
for n in srcDecl.type.listNamesWithPrefix pre do
371
406
transformDeclAux none pre tgt_pre n
372
407
if let some value := srcDecl.value? then
@@ -450,7 +485,7 @@ Find the first argument of `nm` that has a multiplicative type-class on it.
450
485
Returns 1 if there are no types with a multiplicative class as arguments.
451
486
E.g. `prod.group` returns 1, and `pi.has_one` returns 2.
452
487
-/
453
- def firstMultiplicativeArg (nm : Name) : MetaM (Option Nat) := do
488
+ def firstMultiplicativeArg (nm : Name) : MetaM Nat := do
454
489
forallTelescopeReducing (← getConstInfo nm).type fun xs _ ↦ do
455
490
-- xs are the arguments to the constant
456
491
let xs := xs.toList
@@ -466,8 +501,8 @@ def firstMultiplicativeArg (nm : Name) : MetaM (Option Nat) := do
466
501
xs.findIdx? fun x ↦ Expr.containsFVar tgt_arg x.fvarId!
467
502
trace[to_additive_detail] "firstMultiplicativeArg: {l}"
468
503
match l.join with
469
- | [] => return none
470
- | (head :: tail) => return some <| tail.foldr Nat.min head
504
+ | [] => return 0
505
+ | (head :: tail) => return tail.foldr Nat.min head
471
506
472
507
/-- `ValueType` is the type of the arguments that can be provided to `to_additive`. -/
473
508
structure ValueType : Type where
@@ -633,12 +668,12 @@ def targetName (src tgt : Name) (allowAutoName : Bool) : CoreM Name := do
633
668
throwError "to_additive: can't transport {src} to itself."
634
669
return res
635
670
636
- private def proceedFieldsAux (src tgt : Name) (f : Name → CoreM (List String )) : CoreM Unit := do
671
+ private def proceedFieldsAux (src tgt : Name) (f : Name → CoreM (Array Name )) : CoreM Unit := do
637
672
let srcFields ← f src
638
673
let tgtFields ← f tgt
639
- if srcFields.length != tgtFields.length then
674
+ if srcFields.size != tgtFields.size then
640
675
throwError "Failed to map fields of {src}, {tgt} with {srcFields} ↦ {tgtFields}"
641
- for (srcField, tgtField) in List .zip srcFields tgtFields do
676
+ for (srcField, tgtField) in srcFields .zip tgtFields do
642
677
if srcField != tgtField then
643
678
insertTranslation (src ++ srcField) (tgt ++ tgtField)
644
679
@@ -647,12 +682,9 @@ so that future uses of `to_additive` will map them to the corresponding `tgt` fi
647
682
def proceedFields (src tgt : Name) : CoreM Unit := do
648
683
let env : Environment ← getEnv
649
684
let aux := proceedFieldsAux src tgt
650
- aux (fun n ↦ do
651
- let fields := if isStructure env n then getStructureFieldsFlattened env n else #[]
652
- return fields |> .map Name.toString |> Array.toList
653
- )
654
- -- [ todo ] run to_additive on the constructors of n:
655
- -- aux (fun n ↦ (env.constructorsOf n).mmap $ ...
685
+ aux fun n ↦ pure <| if isStructure env n then getStructureFields env n else #[]
686
+ -- We don't have to run toAdditive on the constructor of a structure, since the use of
687
+ -- `Name.mapPrefix` will do that automatically.
656
688
657
689
private def elabToAdditiveAux (ref : Syntax) (replaceAll trace : Bool) (tgt : Option Syntax)
658
690
(doc : Option Syntax) : ValueType :=
@@ -661,8 +693,7 @@ private def elabToAdditiveAux (ref : Syntax) (replaceAll trace : Bool) (tgt : Op
661
693
tgt := match tgt with | some tgt => tgt.getId | none => Name.anonymous
662
694
doc := doc.bind (·.isStrLit?)
663
695
allowAutoName := false
664
- ref
665
- }
696
+ ref }
666
697
667
698
private def elabToAdditive : Syntax → CoreM ValueType
668
699
| `(attr| to_additive%$tk $[!%$replaceAll]? $[?%$trace]? $[$tgt]? $[$doc]?) =>
@@ -676,7 +707,8 @@ def addToAdditiveAttr (src : Name) (val : ValueType) : AttrM Unit := do
676
707
if let some tgt' := findTranslation? (← getEnv) src then
677
708
throwError "{src} already has a to_additive translation {tgt'}."
678
709
insertTranslation src tgt
679
- if let some firstMultArg ← (MetaM.run' <| firstMultiplicativeArg src) then
710
+ let firstMultArg ← MetaM.run' <| firstMultiplicativeArg src
711
+ if firstMultArg != 0 then
680
712
trace[to_additive_detail] "Setting relevant_arg for {src} to be {firstMultArg}."
681
713
relevantArgAttr.add src firstMultArg
682
714
if (← getEnv).contains tgt then
0 commit comments