Skip to content

Commit

Permalink
feat: make flat/non-flat ext lemma configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
digama0 committed May 24, 2023
1 parent 60f19be commit 62f319c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 19 deletions.
27 changes: 18 additions & 9 deletions Std/Tactic/Ext.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,24 @@ Calls the continuation `k` with the list of parameters to the structure,
two structure variables `x` and `y`, and a list of pairs `(field, ty)`
where `ty` is `x.field = y.field` or `HEq x.field y.field`.
-/
def withExtHyps (struct : Name)
def withExtHyps (struct : Name) (flat : Term)
(k : Array Expr → (x y : Expr) → Array (Name × Expr) → MetaM α) : MetaM α := do
let flat ← match flat with
| `(true) => pure true
| `(false) => pure false
| _ => throwErrorAt flat "expected 'true' or 'false'"
unless isStructure (← getEnv) struct do throwError "not a structure: {struct}"
let structC ← mkConstWithLevelParams struct
forallTelescope (← inferType structC) fun params _ => do
withNewBinderInfos (params.map (·.fvarId!, BinderInfo.implicit)) do
withLocalDeclD `x (mkAppN structC params) fun x => do
withLocalDeclD `y (mkAppN structC params) fun y => do
let mut hyps := #[]
for field in getStructureFieldsFlattened (← getEnv) struct (includeSubobjectFields := false) do
let fields := if flat then
getStructureFieldsFlattened (← getEnv) struct (includeSubobjectFields := false)
else
getStructureFields (← getEnv) struct
for field in fields do
let x_f ← mkProjection x field
let y_f ← mkProjection y field
if ← isProof x_f then
Expand All @@ -41,8 +49,8 @@ def withExtHyps (struct : Name)
Creates the type of the extensionality lemma for the given structure,
elaborating to `x.1 = y.1 → x.2 = y.2 → x = y`, for example.
-/
scoped elab "ext_type%" struct:ident : term => do
withExtHyps (← resolveGlobalConstNoOverloadWithInfo struct) fun params x y hyps => do
scoped elab "ext_type%" flat:term:max struct:ident : term => do
withExtHyps (← resolveGlobalConstNoOverloadWithInfo struct) flat fun params x y hyps => do
let ty := hyps.foldr (init := ← mkEq x y) fun (f, h) ty =>
mkForall f BinderInfo.default h ty
mkForallFVars (params |>.push x |>.push y) ty
Expand All @@ -60,22 +68,23 @@ def mkAndN : List Expr → Expr
Creates the type of the iff-variant of the extensionality lemma for the given structure,
elaborating to `x = y ↔ x.1 = y.1 ∧ x.2 = y.2`, for example.
-/
scoped elab "ext_iff_type%" struct:ident : term => do
withExtHyps (← resolveGlobalConstNoOverloadWithInfo struct) fun params x y hyps => do
scoped elab "ext_iff_type%" flat:term:max struct:ident : term => do
withExtHyps (← resolveGlobalConstNoOverloadWithInfo struct) flat fun params x y hyps => do
mkForallFVars (params |>.push x |>.push y) <|
mkIff (← mkEq x y) <| mkAndN (hyps.map (·.2)).toList

macro_rules | `(declare_ext_theorems_for $struct:ident $[$prio]?) => do
macro_rules | `(declare_ext_theorems_for $[(flat := $f)]? $struct:ident $(prio)?) => do
let flat := f.getD (mkIdent `true)
let names ← Macro.resolveGlobalName struct.getId.eraseMacroScopes
let name ← match names.filter (·.2.isEmpty) with
| [] => Macro.throwError s!"unknown constant {struct}"
| [(name, _)] => pure name
| _ => Macro.throwError s!"ambiguous name {struct}"
let extName := mkIdentFrom struct (canonical := true) <| name.mkStr "ext"
let extIffName := mkIdentFrom struct (canonical := true) <| name.mkStr "ext_iff"
`(@[ext $[$prio]?] protected theorem $extName:ident : ext_type% $struct:ident :=
`(@[ext $(prio)?] protected theorem $extName:ident : ext_type% $flat $struct:ident :=
fun {..} {..} => by intros; subst_eqs; rfl
protected theorem $extIffName:ident : ext_iff_type% $struct:ident :=
protected theorem $extIffName:ident : ext_iff_type% $flat $struct:ident :=
fun {..} {..} =>
fun h => by cases h; split_ands <;> rfl,
fun _ => by (repeat cases ‹_ ∧ _›); subst_eqs; rfl⟩)
Expand Down
13 changes: 8 additions & 5 deletions Std/Tactic/Ext/Attr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace Std.Tactic.Ext
open Lean Meta

/-- `declare_ext_theorems_for A` declares the extensionality theorems for the structure `A`. -/
syntax "declare_ext_theorems_for" ident prio ? : command
syntax "declare_ext_theorems_for" ("(" &"flat" " := " term ")")? ident (prio)? : command

/-- Information about an extensionality theorem, stored in the environment extension. -/
structure ExtTheorem where
Expand All @@ -33,7 +33,7 @@ initialize extExtension :
/-- Get the list of `@[ext]` lemmas corresponding to the key `ty`. -/
@[inline] def getExtLemmas (ty : Expr) : MetaM (Array ExtTheorem) :=
return (← (extExtension.getState (← getEnv)).getMatch ty)
|>.insertionSort fun a b => a.priority > b.priority
|>.qsort fun a b => a.priority > b.priority

/-- Registers an extensionality lemma.
Expand All @@ -45,17 +45,20 @@ When `@[ext]` is applied to a theorem,
the theorem is registered for the `ext` tactic.
You can use `@[ext 9000]` to specify a priority for the attribute. -/
syntax (name := ext) "ext" prio ? : attr
syntax (name := ext) "ext" ("(" &"flat" " := " term ")")? (prio)? : attr

initialize registerBuiltinAttribute {
name := `ext
descr := "Marks a lemma as extensionality lemma"
add := fun declName stx kind => do
let `(attr| ext $[$prio]?) := stx | throwError "unexpected @[ext] attribute {stx}"
let `(attr| ext $[(flat := $f)]? $(prio)?) := stx
| throwError "unexpected @[ext] attribute {stx}"
if isStructure (← getEnv) declName then
liftCommandElabM <| Elab.Command.elabCommand <|
← `(declare_ext_theorems_for $(mkCIdentFrom stx declName) $[$prio]?)
← `(declare_ext_theorems_for $[(flat := $f)]? $(mkCIdentFrom stx declName) $[$prio]?)
else MetaM.run' do
if let some flat := f then
throwErrorAt flat "unexpected 'flat' config on @[ext] lemma"
let declTy := (← getConstInfo declName).type
let (_, _, declTy) ← withDefault <| forallMetaTelescopeReducing declTy
let failNotEq := throwError
Expand Down
19 changes: 14 additions & 5 deletions test/ext.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import Std.Tactic.Ext
import Std.Logic

set_option linter.missingDocs false
axiom mySorry {α : Sort _} : α

structure A (n : Nat) where
a : Nat
Expand All @@ -21,15 +22,23 @@ structure B (n) extends A n where

example (a b : C n) : a = b := by
ext
guard_target = a.a = b.a; admit
guard_target = a.b = b.b; admit
guard_target = HEq a.i b.i; admit
guard_target = a.c = b.c; admit
guard_target = a.a = b.a; exact mySorry
guard_target = a.b = b.b; exact mySorry
guard_target = HEq a.i b.i; exact mySorry
guard_target = a.c = b.c; exact mySorry

@[ext (flat := false)] structure C' (n) extends B n where
c : Nat

example (a b : C' n) : a = b := by
ext
guard_target = a.toB = b.toB; exact mySorry
guard_target = a.c = b.c; exact mySorry

open Std.Tactic.Ext
example (f g : Nat × Nat → Nat) : f = g := by
ext ⟨x, y⟩
guard_target = f (x, y) = g (x, y); admit
guard_target = f (x, y) = g (x, y); exact mySorry

-- allow more specific ext theorems
declare_ext_theorems_for Fin
Expand Down

0 comments on commit 62f319c

Please sign in to comment.