@@ -135,17 +135,22 @@ initialize changeNumeralAttr : NameMapExtension (List Nat) ←
135135 pure <| arg.map (·.1 .isNatLit?.get!.pred) |>.toList
136136 | _, _ => throwUnsupportedSyntax }
137137
138+ /-- `ArgInfo` stores information about how a constant should be translated. -/
139+ structure ArgInfo where
140+ /-- The arguments that should be reordered when translating, using cycle notation. -/
141+ reorder : List (List Nat) := []
142+ /-- The argument used to determine whether this constant should be translated. -/
143+ relevantArg : Nat := 0
144+
138145/-- `TranslateData` is a structure that holds all data required for a translation attribute. -/
139146structure TranslateData : Type where
140147 /-- An attribute that tells that certain arguments of this definition are not
141148 involved when translating.
142149 This helps the translation heuristic by also transforming definitions if `ℕ` or another
143150 fixed type occurs as one of these arguments. -/
144151 ignoreArgsAttr : NameMapExtension (List Nat)
145- /-- `reorderAttr` stores the declarations that need their arguments reordered when translating.
146- This is specified using the `(reorder := ...)` syntax. -/
147- reorderAttr : NameMapExtension (List <| List Nat)
148- relevantArgAttr : NameMapExtension Nat
152+ /-- `argInfoAttr` stores the declarations that need some extra information to be translated. -/
153+ argInfoAttr : NameMapExtension ArgInfo
149154 /-- The global `dont_translate` attribute specifies that operations on the given type
150155 should not be translated. This can be either for types that are translated,
151156 such as `MonoidAlgebra` -> `AddMonoidAlgebra`, or for fixed types, such as `Fin n`/`ZMod n`.
@@ -165,7 +170,6 @@ structure TranslateData : Type where
165170 isDual : Bool
166171 guessNameData : GuessName.GuessNameData
167172
168- attribute [inherit_doc relevantArgOption] TranslateData.relevantArgAttr
169173attribute [inherit_doc GuessName.GuessNameData] TranslateData.guessNameData
170174
171175/-- Get the translation for the given name. -/
@@ -178,44 +182,36 @@ This allows translating automatically generated declarations such as `IsRegular.
178182def findPrefixTranslation (env : Environment) (nm : Name) (t : TranslateData) : Name :=
179183 nm.mapPrefix (findTranslation? env t)
180184
181- /-- Add a name translation to the translations map. -/
182- def insertTranslation (t : TranslateData) (src tgt : Name) (failIfExists := true ) : CoreM Unit := do
183- if let some tgt' := findTranslation? (← getEnv) t src then
184- if failIfExists then
185- throwError "The translation {src} ↦ {tgt'} already exists"
186- else
187- trace[translate] "The translation {src} ↦ {tgt'} already exists"
188- return
189- modifyEnv (t.translations.addEntry · (src, tgt))
190- trace[translate] "Added translation {src} ↦ {tgt}"
191- -- For an attribute like `to_dual`, we also insert the reverse direction of the translation
185+ /-- Compute the `ArgInfo` for the reverse translation. The `reorder` permutation is inverted.
186+ In practice, `relevantArg` does not overlap with `reorder` for dual translations,
187+ so we don't bother applying the permutation to `relevantArg`. -/
188+ def ArgInfo.reverse (i : ArgInfo) : ArgInfo where
189+ reorder := i.reorder.map (·.reverse)
190+ relevantArg := i.relevantArg
191+
192+ /-- Add a name translation to the translations map and add the `argInfo` information to `src`.
193+ If the translation attribute is dual, also add the reverse translation. -/
194+ def insertTranslation (t : TranslateData) (src tgt : Name) (argInfo : ArgInfo)
195+ (failIfExists := true ) : CoreM Unit := do
196+ insertTranslationAux t src tgt failIfExists argInfo
192197 if t.isDual && src != tgt then
193- if let some src' := findTranslation? (← getEnv) t tgt then
198+ insertTranslationAux t tgt src failIfExists argInfo.reverse
199+ where
200+ /-- Insert only one direction of a translation. -/
201+ insertTranslationAux (t : TranslateData) (src tgt : Name) (failIfExists : Bool)
202+ (argInfo : ArgInfo) : CoreM Unit := do
203+ if let some tgt' := findTranslation? (← getEnv) t src then
194204 if failIfExists then
195- throwError "The translation {tgt } ↦ {src '} already exists"
205+ throwError "The translation {src } ↦ {tgt '} already exists"
196206 else
197- trace[translate] "The translation {tgt} ↦ {src'} already exists"
198- return
199- modifyEnv (t.translations.addEntry · (tgt, src))
200- trace[translate] "Also added translation {tgt} ↦ {src}"
201-
202- /-- `ArgInfo` stores information about how a constant should be translated. -/
203- structure ArgInfo where
204- /-- The arguments that should be reordered when translating, using cycle notation. -/
205- reorder : List (List Nat) := []
206- /-- The argument used to determine whether this constant should be translated. -/
207- relevantArg : Nat := 0
208-
209- /-- Add a name translation to the translations map and add the `argInfo` information to `src`. -/
210- def insertTranslationAndInfo (t : TranslateData) (src tgt : Name) (argInfo : ArgInfo)
211- (failIfExists := true ) : CoreM Unit := do
212- insertTranslation t src tgt failIfExists
213- if argInfo.reorder != [] then
214- trace[translate] "@[{t.attrName}] will reorder the arguments of {tgt} by {argInfo.reorder}."
215- t.reorderAttr.add src argInfo.reorder
216- if argInfo.relevantArg != 0 then
217- trace[translate_detail] "Setting relevant_arg for {src} to be {argInfo.relevantArg}."
218- t.relevantArgAttr.add src argInfo.relevantArg
207+ trace[translate] "The translation {src} ↦ {tgt'} already exists"
208+ else
209+ modifyEnv (t.translations.addEntry · (src, tgt))
210+ trace[translate] "Added translation {src} ↦ {tgt}"
211+ unless argInfo matches {} do
212+ trace[translate] "@[{t.attrName}] will reorder the arguments of {src} by {argInfo.reorder}."
213+ trace[translate_detail] "Setting relevant_arg for {src} to be {argInfo.relevantArg}."
214+ modifyEnv (t.argInfoAttr.addEntry · (src, argInfo))
219215
220216/-- `Config` is the type of the arguments that can be provided to `to_additive`. -/
221217structure Config : Type where
@@ -266,7 +262,6 @@ They are expanded until they are applied to one more argument than the maximum i
266262It also expands all kernel projections that have as head a constant `n` in `reorder`. -/
267263def expand (t : TranslateData) (e : Expr) : MetaM Expr := do
268264 let env ← getEnv
269- let reorderFn : Name → List (List ℕ) := fun nm ↦ (t.reorderAttr.find? env nm |>.getD [])
270265 let e₂ ← Lean.Meta.transform (input := e) (skipConstInApp := true )
271266 (post := fun e => return .done e) fun e ↦
272267 e.withApp fun f args ↦ do
@@ -281,11 +276,11 @@ def expand (t : TranslateData) (e : Expr) : MetaM Expr := do
281276 return .visit <| (← whnfD (← inferType s)).withApp fun sf sargs ↦
282277 mkAppN (mkApp (mkAppN (.const projName sf.constLevels!) sargs) s) args
283278 | .const c _ =>
284- let reorder := reorderFn c
285- if reorder.isEmpty then
279+ let some info := t.argInfoAttr.find? env c | return .continue
280+ if info. reorder.isEmpty then
286281 -- no need to expand if nothing needs reordering
287282 return .continue
288- let needed_n := reorder.flatten.foldr Nat.max 0 + 1
283+ let needed_n := info. reorder.flatten.foldr Nat.max 0 + 1
289284 if needed_n ≤ args.size then
290285 return .continue
291286 else
@@ -387,19 +382,18 @@ def applyReplacementFun (t : TranslateData) (e : Expr) (dontTranslate : Array FV
387382 return e'
388383where /-- Implementation of `applyReplacementFun`. -/
389384 aux (env : Environment) (trace : Bool) : Expr → Expr :=
390- let reorderFn : Name → List (List ℕ) := fun nm ↦ (t.reorderAttr.find? env nm |>.getD [])
391- let relevantArg : Name → ℕ := fun nm ↦ (t.relevantArgAttr.find? env nm).getD 0
392385 Lean.Expr.replaceRec fun r e ↦ Id.run do
393386 if trace then
394387 dbg_trace s! "replacing at { e} "
395388 match e with
396389 | .const n₀ ls₀ => do
397390 let n₁ := findPrefixTranslation env n₀ t
398- let ls₁ : List Level := if 0 ∈ (reorderFn n₀).flatten then ls₀.swapFirstTwo else ls₀
391+ let swapUniv := (t.argInfoAttr.find? env n₀).elim false (·.reorder.any (·.contains 0 ))
392+ let ls₁ : List Level := if swapUniv then ls₀.swapFirstTwo else ls₀
399393 if trace then
400394 if n₀ != n₁ then
401395 dbg_trace s! "changing { n₀} to { n₁} "
402- if 0 ∈ (reorderFn n₀).flatten then
396+ if swapUniv then
403397 dbg_trace s! "reordering the universe variables from { ls₀} to { ls₁} "
404398 return some <| .const n₁ ls₁
405399 | .app g x => do
@@ -412,9 +406,9 @@ where /-- Implementation of `applyReplacementFun`. -/
412406 let some nm := gf.constName? | return mkAppN (← r gf) (← gAllArgs.mapM r)
413407 -- e = `(nm y₁ .. yₙ x)
414408 /- Test if the head should not be replaced. -/
415- let relevantArgId := relevantArg nm
416- if h : relevantArgId < gAllArgs.size then
417- if let some fxd := shouldTranslate env t gAllArgs[relevantArgId ] dontTranslate then
409+ let { reorder, relevantArg } := t.argInfoAttr.find? env nm |>.getD {}
410+ if h : relevantArg < gAllArgs.size then
411+ if let some fxd := shouldTranslate env t gAllArgs[relevantArg ] dontTranslate then
418412 if trace then
419413 match fxd with
420414 | .inl fxd => dbg_trace s! "The application of { nm} contains the fixed type \
@@ -424,7 +418,6 @@ where /-- Implementation of `applyReplacementFun`. -/
424418 else
425419 gf ← r gf
426420 /- Test if arguments should be reordered. -/
427- let reorder := reorderFn nm
428421 if !reorder.isEmpty then
429422 gAllArgs := gAllArgs.permute! reorder
430423 if trace then
@@ -653,7 +646,7 @@ partial def transformDeclAux (t : TranslateData) (cfg : Config) (pre tgt_pre : N
653646 -- if the auxiliary declaration doesn't have prefix `pre`, then we have to add this declaration
654647 -- to the translation dictionary, since otherwise we cannot translate the name.
655648 if !pre.isPrefixOf src then
656- insertTranslation t src tgt
649+ insertTranslation t src tgt {}
657650 -- now transform the source declaration
658651 let trgDecl : ConstantInfo ← MetaM.run' <|
659652 if src == pre then
@@ -761,7 +754,7 @@ def translateLemmas {m : Type → Type} [Monad m] [MonadError m] [MonadLiftT Cor
761754 throwError "{names[0]!} and {nm} do not generate the same number of {desc}."
762755 for (srcLemmas, tgtLemmas) in auxLemmas.zip <| auxLemmas.eraseIdx! 0 do
763756 for (srcLemma, tgtLemma) in srcLemmas.zip tgtLemmas do
764- insertTranslationAndInfo t srcLemma tgtLemma argInfo
757+ insertTranslation t srcLemma tgtLemma argInfo
765758
766759/--
767760Find the argument of `nm` that appears in the first translatable (type-class) argument.
@@ -778,7 +771,7 @@ def findRelevantArg (t : TranslateData) (nm : Name) : MetaM Nat := do
778771 let relevantArg? (tgt : Expr) : Option Nat := do
779772 let c ← tgt.getAppFn.constName?
780773 guard (findTranslation? env t c).isSome
781- let relevantArg := (t.relevantArgAttr .find? env c).getD 0
774+ let relevantArg := (t.argInfoAttr .find? env c).elim 0 (·.relevantArg)
782775 let arg ← tgt.getArg? relevantArg
783776 xs.findIdx? (arg.containsFVar ·.fvarId!)
784777 -- run the above check on all hypotheses and on the conclusion
@@ -825,7 +818,7 @@ def proceedFieldsAux (t : TranslateData) (src tgt : Name) (argInfo : ArgInfo)
825818 throwError "Failed to map fields of {src}, {tgt} with {srcFields} ↦ {tgtFields}.\n \
826819 Lengths do not match."
827820 for srcField in srcFields, tgtField in tgtFields do
828- insertTranslationAndInfo t srcField tgtField argInfo
821+ insertTranslation t srcField tgtField argInfo
829822
830823/-- Add the structure fields of `src` to the translations dictionary
831824so that they will be translated correctly. -/
@@ -1068,15 +1061,10 @@ partial def addTranslationAttr (t : TranslateData) (src : Name) (cfg : Config)
10681061 -- If `tgt` is not in the environment, the translation to `tgt` was added only for
10691062 -- translating the namespace, and `src` wasn't actually tagged.
10701063 if (← getEnv).contains tgt then
1071- let mut updated := false
1072- if cfg.reorder != [] then
1073- modifyEnv (t.reorderAttr.addEntry · (src, cfg.reorder))
1074- updated := true
1075- if let some relevantArg := cfg.relevantArg? then
1076- modifyEnv (t.relevantArgAttr.addEntry · (src, relevantArg))
1077- updated := true
1078- if updated then
1064+ if cfg.reorder != [] || cfg.relevantArg?.isSome then
10791065 MetaM.run' <| checkExistingType t src tgt cfg.reorder cfg.dontTranslate
1066+ let argInfo := { reorder := cfg.reorder, relevantArg := cfg.relevantArg?.getD 0 }
1067+ insertTranslation t src tgt argInfo false
10801068 return #[tgt]
10811069 throwError
10821070 "Cannot apply attribute @[{t.attrName}] to '{src}': it is already translated to '{tgt}'. \n \
@@ -1095,7 +1083,7 @@ partial def addTranslationAttr (t : TranslateData) (src : Name) (cfg : Config)
10951083 MetaM.run' <| checkExistingType t src tgt cfg.reorder cfg.dontTranslate
10961084 let relevantArg ← cfg.relevantArg?.getDM <| MetaM.run' <| findRelevantArg t src
10971085 let argInfo := { reorder := cfg.reorder, relevantArg }
1098- insertTranslationAndInfo t src tgt argInfo alreadyExists
1086+ insertTranslation t src tgt argInfo alreadyExists
10991087 let nestedNames ←
11001088 if alreadyExists then
11011089 -- since `tgt` already exists, we just need to copy metadata and
0 commit comments