Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat : derive DecidableEq for mutual inductives #2591

Merged
merged 4 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Please check the [releases](https://github.com/leanprover/lean4/releases) page f
v4.3.0 (development in progress)
---------

* The derive handler for `DecidableEq` [now handles](https://github.com/leanprover/lean4/pull/2591) mutual inductive types.
* [Fix linker warnings on macOS](https://github.com/leanprover/lean4/pull/2598).

v4.2.0
Expand Down
39 changes: 23 additions & 16 deletions src/Lean/Elab/Deriving/DecEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ open Meta
def mkDecEqHeader (indVal : InductiveVal) : TermElabM Header := do
mkHeader `DecidableEq 2 indVal

def mkMatch (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do
def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do
let discrs ← mkDiscrs header indVal
let alts ← mkAlts
`(match $[$discrs],* with $alts:matchAlt*)
where
mkSameCtorRhs : List (Ident × Ident × Bool × Bool) → TermElabM Term
mkSameCtorRhs : List (Ident × Ident × Option Name × Bool) → TermElabM Term
| [] => ``(isTrue rfl)
| (a, b, recField, isProof) :: todo => withFreshMacroScope do
let rhs ← if isProof then
Expand All @@ -30,7 +30,7 @@ where
by subst h; exact $(← mkSameCtorRhs todo):term
else
isFalse (by intro n; injection n; apply h _; assumption))
if recField then
if let some auxFunName := recField then
-- add local instance for `a = b` using the function being defined `auxFunName`
`(let inst := $(mkIdent auxFunName) $a $b; $rhs)
else
Expand Down Expand Up @@ -67,8 +67,11 @@ where
let b := mkIdent (← mkFreshUserName `b)
ctorArgs1 := ctorArgs1.push a
ctorArgs2 := ctorArgs2.push b
let recField := (← inferType x).isAppOf indVal.name
let isProof := (← inferType (← inferType x)).isProp
let indValNum :=
ctx.typeInfos.findIdx?
((← inferType x).isAppOf ∘ ConstantVal.name ∘ InductiveVal.toConstantVal)
let recField := indValNum.map (ctx.auxFunNames[·]!)
let isProof := (← inferType (← inferType x)).isProp
todo := todo.push (a, b, recField, isProof)
patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs1:term*))
patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs2:term*))
Expand All @@ -81,18 +84,24 @@ where
alts := alts.push (← `(matchAltExpr| | $[$patterns:term],* => $rhs:term))
return alts

def mkAuxFunction (ctx : Context) : TermElabM Syntax := do
let auxFunName := ctx.auxFunNames[0]!
let indVal :=ctx.typeInfos[0]!
let header ← mkDecEqHeader indVal
let mut body ← mkMatch header indVal auxFunName
let binders := header.binders
let type ← `(Decidable ($(mkIdent header.targetNames[0]!) = $(mkIdent header.targetNames[1]!)))
def mkAuxFunction (ctx : Context) (auxFunName : Name) (indVal : InductiveVal): TermElabM (TSyntax `command) := do
let header ← mkDecEqHeader indVal
let body ← mkMatch ctx header indVal
let binders := header.binders
let type ← `(Decidable ($(mkIdent header.targetNames[0]!) = $(mkIdent header.targetNames[1]!)))
`(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $type:term := $body:term)

def mkAuxFunctions (ctx : Context) : TermElabM (TSyntax `command) := do
let mut res : Array (TSyntax `command) := #[]
for i in [:ctx.auxFunNames.size] do
let auxFunName := ctx.auxFunNames[i]!
let indVal := ctx.typeInfos[i]!
res := res.push (← mkAuxFunction ctx auxFunName indVal)
`(command| mutual $[$res:command]* end)

def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do
let ctx ← mkContext "decEq" indVal.name
let cmds := #[← mkAuxFunction ctx] ++ (← mkInstanceCmds ctx `DecidableEq #[indVal.name] (useAnonCtor := false))
let cmds := #[← mkAuxFunctions ctx] ++ (← mkInstanceCmds ctx `DecidableEq #[indVal.name] (useAnonCtor := false))
trace[Elab.Deriving.decEq] "\n{cmds}"
return cmds

Expand Down Expand Up @@ -174,9 +183,7 @@ def mkDecEqEnum (declName : Name) : CommandElabM Unit := do
elabCommand cmd

def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if declNames.size != 1 then
return false -- mutually inductive types are not supported yet
else if (← isEnumType declNames[0]!) then
if (← isEnumType declNames[0]!) then
mkDecEqEnum declNames[0]!
return true
else
Expand Down
27 changes: 27 additions & 0 deletions tests/lean/decEqMutualInductives.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/-! Verify that the derive handler for `DecidableEq` handles mutual inductive types-/

-- Print the generated derivations
set_option trace.Elab.Deriving.decEq true

mutual
inductive Tree : Type :=
| node : ListTree → Tree

inductive ListTree : Type :=
| nil : ListTree
| cons : Tree → ListTree → ListTree
deriving DecidableEq
end

mutual
inductive Foo₁ : Type :=
| foo₁₁ : Foo₁
| foo₁₂ : Foo₂ → Foo₁
deriving DecidableEq

inductive Foo₂ : Type :=
| foo₂ : Foo₃ → Foo₂

inductive Foo₃ : Type :=
| foo₃ : Foo₁ → Foo₃
end
50 changes: 50 additions & 0 deletions tests/lean/decEqMutualInductives.lean.expected.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
[Elab.Deriving.decEq]
[mutual
private def decEqTree✝ (x✝ : @Tree✝) (x✝¹ : @Tree✝) : Decidable✝ (x✝ = x✝¹) :=
match x✝, x✝¹ with
| @Tree.node a✝, @Tree.node b✝ =>
let inst✝ := decEqListTree✝ a✝ b✝;
if h✝ : a✝ = b✝ then by subst h✝; exact isTrue✝ rfl✝
else isFalse✝ (by intro n✝; injection n✝; apply h✝ _; assumption)
private def decEqListTree✝ (x✝² : @ListTree✝) (x✝³ : @ListTree✝) : Decidable✝ (x✝² = x✝³) :=
match x✝², x✝³ with
| @ListTree.nil, @ListTree.nil => isTrue✝¹ rfl✝¹
| ListTree.nil .., ListTree.cons .. => isFalse✝¹ (by intro h✝¹; injection h✝¹)
| ListTree.cons .., ListTree.nil .. => isFalse✝¹ (by intro h✝¹; injection h✝¹)
| @ListTree.cons a✝¹ a✝², @ListTree.cons b✝¹ b✝² =>
let inst✝¹ := decEqTree✝ a✝¹ b✝¹;
if h✝² : a✝¹ = b✝¹ then by subst h✝²;
exact
let inst✝² := decEqListTree✝ a✝² b✝²;
if h✝³ : a✝² = b✝² then by subst h✝³; exact isTrue✝² rfl✝²
else isFalse✝² (by intro n✝¹; injection n✝¹; apply h✝³ _; assumption)
else isFalse✝³ (by intro n✝²; injection n✝²; apply h✝² _; assumption)
end,
instance : DecidableEq✝ (@ListTree✝) :=
decEqListTree✝]
[Elab.Deriving.decEq]
[mutual
private def decEqFoo₁✝ (x✝ : @Foo₁✝) (x✝¹ : @Foo₁✝) : Decidable✝ (x✝ = x✝¹) :=
match x✝, x✝¹ with
| @Foo₁.foo₁₁, @Foo₁.foo₁₁ => isTrue✝ rfl✝
| Foo₁.foo₁₁ .., Foo₁.foo₁₂ .. => isFalse✝ (by intro h✝; injection h✝)
| Foo₁.foo₁₂ .., Foo₁.foo₁₁ .. => isFalse✝ (by intro h✝; injection h✝)
| @Foo₁.foo₁₂ a✝, @Foo₁.foo₁₂ b✝ =>
let inst✝ := decEqFoo₂✝ a✝ b✝;
if h✝¹ : a✝ = b✝ then by subst h✝¹; exact isTrue✝¹ rfl✝¹
else isFalse✝¹ (by intro n✝; injection n✝; apply h✝¹ _; assumption)
private def decEqFoo₂✝ (x✝² : @Foo₂✝) (x✝³ : @Foo₂✝) : Decidable✝ (x✝² = x✝³) :=
match x✝², x✝³ with
| @Foo₂.foo₂ a✝¹, @Foo₂.foo₂ b✝¹ =>
let inst✝¹ := decEqFoo₃✝ a✝¹ b✝¹;
if h✝² : a✝¹ = b✝¹ then by subst h✝²; exact isTrue✝² rfl✝²
else isFalse✝² (by intro n✝¹; injection n✝¹; apply h✝² _; assumption)
private def decEqFoo₃✝ (x✝⁴ : @Foo₃✝) (x✝⁵ : @Foo₃✝) : Decidable✝ (x✝⁴ = x✝⁵) :=
match x✝⁴, x✝⁵ with
| @Foo₃.foo₃ a✝², @Foo₃.foo₃ b✝² =>
let inst✝² := decEqFoo₁✝ a✝² b✝²;
if h✝³ : a✝² = b✝² then by subst h✝³; exact isTrue✝³ rfl✝³
else isFalse✝³ (by intro n✝²; injection n✝²; apply h✝³ _; assumption)
end,
instance : DecidableEq✝ (@Foo₁✝) :=
decEqFoo₁✝]
Loading