Skip to content

Commit 7756012

Browse files
committed
fix(to_dual): also store reorder for the reverse translation (#31716)
This PR fixes the bug in `to_dual` that `@[to_dual]` forgets to add the `reorder := ...` flag to the reverse translation. The correct behaviour is that it adds the inverse of the `reorder` permutation for the reverse translation. To fix this, I refactor `insertTranslation`/`insertTranslationAndInfo`. Since `insertTranslationAndInfo` with an empty `ArgInfo` does the same as `insertTranslation`, I decided to merge these two into one function `insertTranslation`. I also merged the `reorderAttr` and `relevantArgAttr` environment extensions into a single `argInfoAttr` environment extension, because this is more convenient. I add a new `ToDual` test file which includes a test for this fix. I replace the implementation of `insert_to_additive_translation` with something more principled, instead of using `insertTranslation`.
1 parent f710fe1 commit 7756012

File tree

5 files changed

+101
-114
lines changed

5 files changed

+101
-114
lines changed

Mathlib/Tactic/Translate/Core.lean

Lines changed: 53 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -135,17 +135,22 @@ initialize changeNumeralAttr : NameMapExtension (List Nat) ←
135135
pure <| arg.map (·.1.isNatLit?.get!.pred) |>.toList
136136
| _, _ => throwUnsupportedSyntax }
137137

138+
/-- `ArgInfo` stores information about how a constant should be translated. -/
139+
structure ArgInfo where
140+
/-- The arguments that should be reordered when translating, using cycle notation. -/
141+
reorder : List (List Nat) := []
142+
/-- The argument used to determine whether this constant should be translated. -/
143+
relevantArg : Nat := 0
144+
138145
/-- `TranslateData` is a structure that holds all data required for a translation attribute. -/
139146
structure TranslateData : Type where
140147
/-- An attribute that tells that certain arguments of this definition are not
141148
involved when translating.
142149
This helps the translation heuristic by also transforming definitions if `ℕ` or another
143150
fixed type occurs as one of these arguments. -/
144151
ignoreArgsAttr : NameMapExtension (List Nat)
145-
/-- `reorderAttr` stores the declarations that need their arguments reordered when translating.
146-
This is specified using the `(reorder := ...)` syntax. -/
147-
reorderAttr : NameMapExtension (List <| List Nat)
148-
relevantArgAttr : NameMapExtension Nat
152+
/-- `argInfoAttr` stores the declarations that need some extra information to be translated. -/
153+
argInfoAttr : NameMapExtension ArgInfo
149154
/-- The global `dont_translate` attribute specifies that operations on the given type
150155
should not be translated. This can be either for types that are translated,
151156
such as `MonoidAlgebra` -> `AddMonoidAlgebra`, or for fixed types, such as `Fin n`/`ZMod n`.
@@ -165,7 +170,6 @@ structure TranslateData : Type where
165170
isDual : Bool
166171
guessNameData : GuessName.GuessNameData
167172

168-
attribute [inherit_doc relevantArgOption] TranslateData.relevantArgAttr
169173
attribute [inherit_doc GuessName.GuessNameData] TranslateData.guessNameData
170174

171175
/-- Get the translation for the given name. -/
@@ -178,44 +182,36 @@ This allows translating automatically generated declarations such as `IsRegular.
178182
def findPrefixTranslation (env : Environment) (nm : Name) (t : TranslateData) : Name :=
179183
nm.mapPrefix (findTranslation? env t)
180184

181-
/-- Add a name translation to the translations map. -/
182-
def insertTranslation (t : TranslateData) (src tgt : Name) (failIfExists := true) : CoreM Unit := do
183-
if let some tgt' := findTranslation? (← getEnv) t src then
184-
if failIfExists then
185-
throwError "The translation {src} ↦ {tgt'} already exists"
186-
else
187-
trace[translate] "The translation {src} ↦ {tgt'} already exists"
188-
return
189-
modifyEnv (t.translations.addEntry · (src, tgt))
190-
trace[translate] "Added translation {src} ↦ {tgt}"
191-
-- For an attribute like `to_dual`, we also insert the reverse direction of the translation
185+
/-- Compute the `ArgInfo` for the reverse translation. The `reorder` permutation is inverted.
186+
In practice, `relevantArg` does not overlap with `reorder` for dual translations,
187+
so we don't bother applying the permutation to `relevantArg`. -/
188+
def ArgInfo.reverse (i : ArgInfo) : ArgInfo where
189+
reorder := i.reorder.map (·.reverse)
190+
relevantArg := i.relevantArg
191+
192+
/-- Add a name translation to the translations map and add the `argInfo` information to `src`.
193+
If the translation attribute is dual, also add the reverse translation. -/
194+
def insertTranslation (t : TranslateData) (src tgt : Name) (argInfo : ArgInfo)
195+
(failIfExists := true) : CoreM Unit := do
196+
insertTranslationAux t src tgt failIfExists argInfo
192197
if t.isDual && src != tgt then
193-
if let some src' := findTranslation? (← getEnv) t tgt then
198+
insertTranslationAux t tgt src failIfExists argInfo.reverse
199+
where
200+
/-- Insert only one direction of a translation. -/
201+
insertTranslationAux (t : TranslateData) (src tgt : Name) (failIfExists : Bool)
202+
(argInfo : ArgInfo) : CoreM Unit := do
203+
if let some tgt' := findTranslation? (← getEnv) t src then
194204
if failIfExists then
195-
throwError "The translation {tgt} ↦ {src'} already exists"
205+
throwError "The translation {src} ↦ {tgt'} already exists"
196206
else
197-
trace[translate] "The translation {tgt} ↦ {src'} already exists"
198-
return
199-
modifyEnv (t.translations.addEntry · (tgt, src))
200-
trace[translate] "Also added translation {tgt} ↦ {src}"
201-
202-
/-- `ArgInfo` stores information about how a constant should be translated. -/
203-
structure ArgInfo where
204-
/-- The arguments that should be reordered when translating, using cycle notation. -/
205-
reorder : List (List Nat) := []
206-
/-- The argument used to determine whether this constant should be translated. -/
207-
relevantArg : Nat := 0
208-
209-
/-- Add a name translation to the translations map and add the `argInfo` information to `src`. -/
210-
def insertTranslationAndInfo (t : TranslateData) (src tgt : Name) (argInfo : ArgInfo)
211-
(failIfExists := true) : CoreM Unit := do
212-
insertTranslation t src tgt failIfExists
213-
if argInfo.reorder != [] then
214-
trace[translate] "@[{t.attrName}] will reorder the arguments of {tgt} by {argInfo.reorder}."
215-
t.reorderAttr.add src argInfo.reorder
216-
if argInfo.relevantArg != 0 then
217-
trace[translate_detail] "Setting relevant_arg for {src} to be {argInfo.relevantArg}."
218-
t.relevantArgAttr.add src argInfo.relevantArg
207+
trace[translate] "The translation {src} ↦ {tgt'} already exists"
208+
else
209+
modifyEnv (t.translations.addEntry · (src, tgt))
210+
trace[translate] "Added translation {src} ↦ {tgt}"
211+
unless argInfo matches {} do
212+
trace[translate] "@[{t.attrName}] will reorder the arguments of {src} by {argInfo.reorder}."
213+
trace[translate_detail] "Setting relevant_arg for {src} to be {argInfo.relevantArg}."
214+
modifyEnv (t.argInfoAttr.addEntry · (src, argInfo))
219215

220216
/-- `Config` is the type of the arguments that can be provided to `to_additive`. -/
221217
structure Config : Type where
@@ -266,7 +262,6 @@ They are expanded until they are applied to one more argument than the maximum i
266262
It also expands all kernel projections that have as head a constant `n` in `reorder`. -/
267263
def expand (t : TranslateData) (e : Expr) : MetaM Expr := do
268264
let env ← getEnv
269-
let reorderFn : Name → List (List ℕ) := fun nm ↦ (t.reorderAttr.find? env nm |>.getD [])
270265
let e₂ ← Lean.Meta.transform (input := e) (skipConstInApp := true)
271266
(post := fun e => return .done e) fun e ↦
272267
e.withApp fun f args ↦ do
@@ -281,11 +276,11 @@ def expand (t : TranslateData) (e : Expr) : MetaM Expr := do
281276
return .visit <| (← whnfD (← inferType s)).withApp fun sf sargs ↦
282277
mkAppN (mkApp (mkAppN (.const projName sf.constLevels!) sargs) s) args
283278
| .const c _ =>
284-
let reorder := reorderFn c
285-
if reorder.isEmpty then
279+
let some info := t.argInfoAttr.find? env c | return .continue
280+
if info.reorder.isEmpty then
286281
-- no need to expand if nothing needs reordering
287282
return .continue
288-
let needed_n := reorder.flatten.foldr Nat.max 0 + 1
283+
let needed_n := info.reorder.flatten.foldr Nat.max 0 + 1
289284
if needed_n ≤ args.size then
290285
return .continue
291286
else
@@ -387,19 +382,18 @@ def applyReplacementFun (t : TranslateData) (e : Expr) (dontTranslate : Array FV
387382
return e'
388383
where /-- Implementation of `applyReplacementFun`. -/
389384
aux (env : Environment) (trace : Bool) : Expr → Expr :=
390-
let reorderFn : Name → List (List ℕ) := fun nm ↦ (t.reorderAttr.find? env nm |>.getD [])
391-
let relevantArg : Name → ℕ := fun nm ↦ (t.relevantArgAttr.find? env nm).getD 0
392385
Lean.Expr.replaceRec fun r e ↦ Id.run do
393386
if trace then
394387
dbg_trace s!"replacing at {e}"
395388
match e with
396389
| .const n₀ ls₀ => do
397390
let n₁ := findPrefixTranslation env n₀ t
398-
let ls₁ : List Level := if 0 ∈ (reorderFn n₀).flatten then ls₀.swapFirstTwo else ls₀
391+
let swapUniv := (t.argInfoAttr.find? env n₀).elim false (·.reorder.any (·.contains 0))
392+
let ls₁ : List Level := if swapUniv then ls₀.swapFirstTwo else ls₀
399393
if trace then
400394
if n₀ != n₁ then
401395
dbg_trace s!"changing {n₀} to {n₁}"
402-
if 0 ∈ (reorderFn n₀).flatten then
396+
if swapUniv then
403397
dbg_trace s!"reordering the universe variables from {ls₀} to {ls₁}"
404398
return some <| .const n₁ ls₁
405399
| .app g x => do
@@ -412,9 +406,9 @@ where /-- Implementation of `applyReplacementFun`. -/
412406
let some nm := gf.constName? | return mkAppN (← r gf) (← gAllArgs.mapM r)
413407
-- e = `(nm y₁ .. yₙ x)
414408
/- Test if the head should not be replaced. -/
415-
let relevantArgId := relevantArg nm
416-
if h : relevantArgId < gAllArgs.size then
417-
if let some fxd := shouldTranslate env t gAllArgs[relevantArgId] dontTranslate then
409+
let { reorder, relevantArg } := t.argInfoAttr.find? env nm |>.getD {}
410+
if h : relevantArg < gAllArgs.size then
411+
if let some fxd := shouldTranslate env t gAllArgs[relevantArg] dontTranslate then
418412
if trace then
419413
match fxd with
420414
| .inl fxd => dbg_trace s!"The application of {nm} contains the fixed type \
@@ -424,7 +418,6 @@ where /-- Implementation of `applyReplacementFun`. -/
424418
else
425419
gf ← r gf
426420
/- Test if arguments should be reordered. -/
427-
let reorder := reorderFn nm
428421
if !reorder.isEmpty then
429422
gAllArgs := gAllArgs.permute! reorder
430423
if trace then
@@ -653,7 +646,7 @@ partial def transformDeclAux (t : TranslateData) (cfg : Config) (pre tgt_pre : N
653646
-- if the auxiliary declaration doesn't have prefix `pre`, then we have to add this declaration
654647
-- to the translation dictionary, since otherwise we cannot translate the name.
655648
if !pre.isPrefixOf src then
656-
insertTranslation t src tgt
649+
insertTranslation t src tgt {}
657650
-- now transform the source declaration
658651
let trgDecl : ConstantInfo ← MetaM.run' <|
659652
if src == pre then
@@ -761,7 +754,7 @@ def translateLemmas {m : Type → Type} [Monad m] [MonadError m] [MonadLiftT Cor
761754
throwError "{names[0]!} and {nm} do not generate the same number of {desc}."
762755
for (srcLemmas, tgtLemmas) in auxLemmas.zip <| auxLemmas.eraseIdx! 0 do
763756
for (srcLemma, tgtLemma) in srcLemmas.zip tgtLemmas do
764-
insertTranslationAndInfo t srcLemma tgtLemma argInfo
757+
insertTranslation t srcLemma tgtLemma argInfo
765758

766759
/--
767760
Find the argument of `nm` that appears in the first translatable (type-class) argument.
@@ -778,7 +771,7 @@ def findRelevantArg (t : TranslateData) (nm : Name) : MetaM Nat := do
778771
let relevantArg? (tgt : Expr) : Option Nat := do
779772
let c ← tgt.getAppFn.constName?
780773
guard (findTranslation? env t c).isSome
781-
let relevantArg := (t.relevantArgAttr.find? env c).getD 0
774+
let relevantArg := (t.argInfoAttr.find? env c).elim 0 (·.relevantArg)
782775
let arg ← tgt.getArg? relevantArg
783776
xs.findIdx? (arg.containsFVar ·.fvarId!)
784777
-- run the above check on all hypotheses and on the conclusion
@@ -825,7 +818,7 @@ def proceedFieldsAux (t : TranslateData) (src tgt : Name) (argInfo : ArgInfo)
825818
throwError "Failed to map fields of {src}, {tgt} with {srcFields} ↦ {tgtFields}.\n \
826819
Lengths do not match."
827820
for srcField in srcFields, tgtField in tgtFields do
828-
insertTranslationAndInfo t srcField tgtField argInfo
821+
insertTranslation t srcField tgtField argInfo
829822

830823
/-- Add the structure fields of `src` to the translations dictionary
831824
so that they will be translated correctly. -/
@@ -1068,15 +1061,10 @@ partial def addTranslationAttr (t : TranslateData) (src : Name) (cfg : Config)
10681061
-- If `tgt` is not in the environment, the translation to `tgt` was added only for
10691062
-- translating the namespace, and `src` wasn't actually tagged.
10701063
if (← getEnv).contains tgt then
1071-
let mut updated := false
1072-
if cfg.reorder != [] then
1073-
modifyEnv (t.reorderAttr.addEntry · (src, cfg.reorder))
1074-
updated := true
1075-
if let some relevantArg := cfg.relevantArg? then
1076-
modifyEnv (t.relevantArgAttr.addEntry · (src, relevantArg))
1077-
updated := true
1078-
if updated then
1064+
if cfg.reorder != [] || cfg.relevantArg?.isSome then
10791065
MetaM.run' <| checkExistingType t src tgt cfg.reorder cfg.dontTranslate
1066+
let argInfo := { reorder := cfg.reorder, relevantArg := cfg.relevantArg?.getD 0 }
1067+
insertTranslation t src tgt argInfo false
10801068
return #[tgt]
10811069
throwError
10821070
"Cannot apply attribute @[{t.attrName}] to '{src}': it is already translated to '{tgt}'. \n\
@@ -1095,7 +1083,7 @@ partial def addTranslationAttr (t : TranslateData) (src : Name) (cfg : Config)
10951083
MetaM.run' <| checkExistingType t src tgt cfg.reorder cfg.dontTranslate
10961084
let relevantArg ← cfg.relevantArg?.getDM <| MetaM.run' <| findRelevantArg t src
10971085
let argInfo := { reorder := cfg.reorder, relevantArg }
1098-
insertTranslationAndInfo t src tgt argInfo alreadyExists
1086+
insertTranslation t src tgt argInfo alreadyExists
10991087
let nestedNames ←
11001088
if alreadyExists then
11011089
-- since `tgt` already exists, we just need to copy metadata and

Mathlib/Tactic/Translate/ToAdditive.lean

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -258,28 +258,8 @@ initialize ignoreArgsAttr : NameMapExtension (List Nat) ←
258258
| _ => throwUnsupportedSyntax
259259
return ids.toList }
260260

261-
/-- An extension that stores all the declarations that need their arguments reordered when
262-
applying `@[to_additive]`. It is applied using the `to_additive (reorder := ...)` syntax. -/
263-
initialize reorderAttr : NameMapExtension (List (List Nat)) ←
264-
registerNameMapExtension _
265-
266-
/-- Linter to check that the `relevant_arg` attribute is not given manually -/
267-
register_option linter.toAdditiveRelevantArg : Bool := {
268-
defValue := true
269-
descr := "Linter to check that the `relevant_arg` attribute is not given manually." }
270-
271-
@[inherit_doc to_additive_relevant_arg]
272-
initialize relevantArgAttr : NameMapExtension Nat ←
273-
registerNameMapAttribute {
274-
name := `to_additive_relevant_arg
275-
descr := "Auxiliary attribute for `to_additive` stating \
276-
which arguments are the types with a multiplicative structure."
277-
add := fun
278-
| _, stx@`(attr| to_additive_relevant_arg $id) => do
279-
Linter.logLintIf linter.toAdditiveRelevantArg stx
280-
m!"This attribute is deprecated. Use `@[to_additive (relevant_arg := ...)]` instead."
281-
pure <| id.getNat.pred
282-
| _, _ => throwUnsupportedSyntax }
261+
@[inherit_doc TranslateData.argInfoAttr]
262+
initialize argInfoAttr : NameMapExtension ArgInfo ← registerNameMapExtension _
283263

284264
@[inherit_doc to_additive_dont_translate]
285265
initialize dontTranslateAttr : NameMapExtension Unit ←
@@ -389,11 +369,7 @@ def abbreviationDict : Std.HashMap String String := .ofList [
389369

390370
/-- The bundle of environment extensions for `to_additive` -/
391371
def data : TranslateData where
392-
ignoreArgsAttr := ignoreArgsAttr
393-
reorderAttr := reorderAttr
394-
relevantArgAttr := relevantArgAttr
395-
dontTranslateAttr := dontTranslateAttr
396-
translations := translations
372+
ignoreArgsAttr; argInfoAttr; dontTranslateAttr; translations
397373
attrName := `to_additive
398374
changeNumeral := true
399375
isDual := false
@@ -412,6 +388,6 @@ initialize registerBuiltinAttribute {
412388
into the `to_additive` dictionary. This is useful for translating namespaces that don't (yet)
413389
have a corresponding translated declaration. -/
414390
elab "insert_to_additive_translation" src:ident tgt:ident : command => do
415-
Command.liftCoreM <| insertTranslation data src.getId tgt.getId
391+
translations.add src.getId tgt.getId
416392

417393
end Mathlib.Tactic.ToAdditive

Mathlib/Tactic/Translate/ToDual.lean

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -96,20 +96,8 @@ initialize ignoreArgsAttr : NameMapExtension (List Nat) ←
9696
| _ => throwUnsupportedSyntax
9797
return ids.toList }
9898

99-
/-- An extension that stores all the declarations that need their arguments reordered when
100-
applying `@[to_dual]`. It is applied using the `to_dual (reorder := ...)` syntax. -/
101-
initialize reorderAttr : NameMapExtension (List (List Nat)) ←
102-
registerNameMapExtension _
103-
104-
@[inherit_doc to_dual_relevant_arg]
105-
initialize relevantArgAttr : NameMapExtension Nat ←
106-
registerNameMapAttribute {
107-
name := `to_dual_relevant_arg
108-
descr := "Auxiliary attribute for `to_dual` stating \
109-
which arguments are the types with a dual structure."
110-
add := fun
111-
| _, `(attr| to_dual_relevant_arg $id) => pure <| id.1.isNatLit?.get!.pred
112-
| _, _ => throwUnsupportedSyntax }
99+
@[inherit_doc TranslateData.argInfoAttr]
100+
initialize argInfoAttr : NameMapExtension ArgInfo ← registerNameMapExtension _
113101

114102
@[inherit_doc to_dual_dont_translate]
115103
initialize dontTranslateAttr : NameMapExtension Unit ←
@@ -177,11 +165,7 @@ def abbreviationDict : Std.HashMap String String := .ofList []
177165

178166
/-- The bundle of environment extensions for `to_dual` -/
179167
def data : TranslateData where
180-
ignoreArgsAttr := ignoreArgsAttr
181-
reorderAttr := reorderAttr
182-
relevantArgAttr := relevantArgAttr
183-
dontTranslateAttr := dontTranslateAttr
184-
translations := translations
168+
ignoreArgsAttr; argInfoAttr; dontTranslateAttr; translations
185169
attrName := `to_dual
186170
changeNumeral := false
187171
isDual := true

MathlibTest/ToDual.lean

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import Mathlib.Order.Defs.PartialOrder
2+
import Mathlib.Order.Notation
3+
4+
-- test that we can translate between structures, reordering the arguments of the fields
5+
class SemilatticeInf (α : Type) extends PartialOrder α, Min α where
6+
le_inf : ∀ a b c : α, a ≤ b → a ≤ c → a ≤ b ⊓ c
7+
8+
class SemilatticeSup (α : Type) extends PartialOrder α, Max α where
9+
protected sup_le : ∀ a b c : α, a ≤ c → b ≤ c → a ⊔ b ≤ c
10+
11+
attribute [to_dual] SemilatticeInf
12+
attribute [to_dual (reorder := 3 4 5)] SemilatticeSup.sup_le
13+
14+
@[to_dual]
15+
lemma SemilatticeInf.le_inf' {α : Type} [SemilatticeInf α] (a b c : α) : a ≤ b → a ≤ c → a ≤ b ⊓ c :=
16+
SemilatticeInf.le_inf a b c
17+
18+
@[to_dual]
19+
lemma SemilatticeSup.sup_le' {α : Type} [SemilatticeSup α] (a b c : α) : a ≤ c → b ≤ c → a ⊔ b ≤ c :=
20+
SemilatticeSup.sup_le a b c
21+
22+
-- we still cannot reorder arguments of arguments, so `SemilatticeInf.mk` is not tranlatable
23+
/--
24+
error: @[to_dual] failed. The translated value is not type correct. For help, see the docstring of `to_additive`, section `Troubleshooting`. Failed to add declaration
25+
instSemilatticeSupOfForallLeForallMax:
26+
Application type mismatch: The argument
27+
le_inf
28+
has type
29+
∀ (a b c : α), b ≤ a → c ≤ a → b ⊔ c ≤ a
30+
but is expected to have type
31+
∀ (a b c : α), a ≤ c → b ≤ c → a ⊔ b ≤ c
32+
in the application
33+
{ toPartialOrder := inst✝¹, toMax := inst✝, sup_le := le_inf }
34+
-/
35+
#guard_msgs in
36+
@[to_dual]
37+
instance {α : Type} [PartialOrder α] [Min α]
38+
(le_inf : ∀ a b c : α, a ≤ b → a ≤ c → a ≤ b ⊓ c) : SemilatticeInf α where
39+
le_inf

MathlibTest/toAdditive.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,14 @@ example {x} (h : 1 = x) : baz20 = x := by simp; guard_target = 1 = x; exact h
156156
@[to_additive bar21]
157157
def foo21 {N} {A} [Pow A N] (a : A) (n : N) : A := a ^ n
158158

159-
run_cmd liftCoreM <| MetaM.run' <| guard <| relevantArgAttr.find? (← getEnv) `Test.foo21 == some 1
159+
run_meta guard <| argInfoAttr.find? (← getEnv) `Test.foo21 matches some ⟨[], 1
160160

161161
@[to_additive bar22]
162162
abbrev foo22 {α} [Monoid α] (a : α) : ℕ → α
163163
| 0 => 1
164164
| _ => a
165165

166-
run_cmd liftCoreM <| MetaM.run' <| do
166+
run_meta do
167167
-- make `abbrev` definition `reducible` automatically
168168
guard <| (← getReducibilityStatus `Test.bar22) == .reducible
169169
-- make `abbrev` definition `inline` automatically

0 commit comments

Comments
 (0)