Skip to content

Commit

Permalink
feat: optimized noConfusion for enumeration types
Browse files Browse the repository at this point in the history
  • Loading branch information
leodemoura committed Sep 6, 2021
1 parent d43a1c7 commit a7c6218
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 4 deletions.
89 changes: 85 additions & 4 deletions src/Lean/Meta/Constructions.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Environment
import Lean.MonadEnv
import Lean.AuxRecursor
import Lean.Meta.AppBuilder

namespace Lean

@[extern "lean_mk_cases_on"] constant mkCasesOnImp (env : Environment) (declName : @& Name) : Except KernelException Environment
@[extern "lean_mk_rec_on"] constant mkRecOnImp (env : Environment) (declName : @& Name) : Except KernelException Environment
@[extern "lean_mk_no_confusion"] constant mkNoConfusionImp (env : Environment) (declName : @& Name) : Except KernelException Environment
@[extern "lean_mk_no_confusion"] constant mkNoConfusionCoreImp (env : Environment) (declName : @& Name) : Except KernelException Environment
@[extern "lean_mk_below"] constant mkBelowImp (env : Environment) (declName : @& Name) : Except KernelException Environment
@[extern "lean_mk_ibelow"] constant mkIBelowImp (env : Environment) (declName : @& Name) : Except KernelException Environment
@[extern "lean_mk_brec_on"] constant mkBRecOnImp (env : Environment) (declName : @& Name) : Except KernelException Environment
Expand All @@ -25,10 +25,91 @@ variable {m} [Monad m] [MonadEnv m] [MonadError m] [MonadOptions m]

def mkCasesOn (declName : Name) : m Unit := adaptFn mkCasesOnImp declName
def mkRecOn (declName : Name) : m Unit := adaptFn mkRecOnImp declName
def mkNoConfusion (declName : Name) : m Unit := adaptFn mkNoConfusionImp declName
def mkNoConfusionCore (declName : Name) : m Unit := adaptFn mkNoConfusionCoreImp declName
def mkBelow (declName : Name) : m Unit := adaptFn mkBelowImp declName
def mkIBelow (declName : Name) : m Unit := adaptFn mkIBelowImp declName
def mkBRecOn (declName : Name) : m Unit := adaptFn mkBRecOnImp declName
def mkBInductionOn (declName : Name) : m Unit := adaptFn mkBInductionOnImp declName

open Meta

def mkNoConfusionEnum (enumName : Name) : MetaM Unit := do
if (← getEnv).contains ``noConfusionEnum then
mkToCtorIdx
mkNoConfusionType
mkNoConfusion
else
-- `noConfusionEnum` was not defined yet, so we use `mkNoConfusionCore`
mkNoConfusionCore enumName
where

mkToCtorIdx : MetaM Unit := do
let ConstantInfo.inductInfo info ← getConstInfo enumName | unreachable!
let numCtors := info.ctors.length
let declName := Name.mkStr enumName "toCtorIdx"
let enumType := mkConst enumName
let natType := mkConst ``Nat
let declType ← mkArrow enumType natType
let mut minors := #[]
for i in [:numCtors] do
minors := minors.push <| mkNatLit i
withLocalDeclD `x enumType fun x => do
let motive ← mkLambdaFVars #[x] natType
let declValue ← mkLambdaFVars #[x] <| mkAppN (mkApp2 (mkConst (mkCasesOnName enumName) [levelOne]) motive x) minors
addAndCompile <| Declaration.defnDecl {
name := declName
levelParams := []
type := declType
value := declValue
safety := DefinitionSafety.safe
hints := ReducibilityHints.abbrev
}

mkNoConfusionType : MetaM Unit := do
let enumType := mkConst enumName
let sortU := mkSort (mkLevelParam `u)
let toCtorIdx := mkConst (Name.mkStr enumName "toCtorIdx")
withLocalDeclD `P sortU fun P =>
withLocalDeclD `x enumType fun x =>
withLocalDeclD `y enumType fun y => do
let declType ← mkForallFVars #[P, x, y] sortU
let declValue ← mkLambdaFVars #[P, x, y] (← mkAppM ``noConfusionTypeEnum #[toCtorIdx, P, x, y])
let declName := Name.mkStr enumName "noConfusionType"
addAndCompile <| Declaration.defnDecl {
name := declName
levelParams := [`u]
type := declType
value := declValue
safety := DefinitionSafety.safe
hints := ReducibilityHints.abbrev
}

mkNoConfusion : MetaM Unit := do
let enumType := mkConst enumName
let u := mkLevelParam `u
let sortU := mkSort u
let toCtorIdx := mkConst (Name.mkStr enumName "toCtorIdx")
let noConfusionType := mkConst (Name.mkStr enumName "noConfusionType") [u]
withLocalDecl `P BinderInfo.implicit sortU fun P =>
withLocalDecl `x BinderInfo.implicit enumType fun x =>
withLocalDecl `y BinderInfo.implicit enumType fun y => do
withLocalDeclD `h (← mkEq x y) fun h => do
let declType ← mkForallFVars #[P, x, y, h] (mkApp3 noConfusionType P x y)
let declValue ← mkLambdaFVars #[P, x, y, h] (← mkAppOptM ``noConfusionEnum #[none, none, none, toCtorIdx, P, x, y, h])
let declName := Name.mkStr enumName "noConfusion"
addAndCompile <| Declaration.defnDecl {
name := declName
levelParams := [`u]
type := declType
value := declValue
safety := DefinitionSafety.safe
hints := ReducibilityHints.abbrev
}

def mkNoConfusion (declName : Name) : MetaM Unit := do
if (← isEnumType declName) then
mkNoConfusionEnum declName
else
mkNoConfusionCore declName

end Lean
11 changes: 11 additions & 0 deletions src/Lean/MonadEnv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,15 @@ def findModuleOf? [Monad m] [MonadEnv m] [MonadError m] (declName : Name) : m (O
| none => return none
| some modIdx => return some ((← getEnv).allImportedModuleNames[modIdx])

def isEnumType [Monad m] [MonadEnv m] [MonadError m] (declName : Name) : m Bool := do
if let ConstantInfo.inductInfo info ← getConstInfo declName then
if info.all.length == 1 && info.numIndices == 0 && info.numParams == 0 && !info.isRec && !info.isNested && !info.isUnsafe then
info.ctors.allM fun ctorName => do
let ConstantInfo.ctorInfo info ← getConstInfo ctorName | return false
return info.numFields == 0
else
return false
else
return false

end Lean

0 comments on commit a7c6218

Please sign in to comment.