Skip to content

Commit

Permalink
feat: user-defined attributes
Browse files Browse the repository at this point in the history
See new test for an example.

closes #513
  • Loading branch information
leodemoura committed Jul 27, 2021
1 parent 0bea52d commit a77598f
Show file tree
Hide file tree
Showing 9 changed files with 194 additions and 131 deletions.
274 changes: 145 additions & 129 deletions src/Lean/Attributes.lean
Expand Up @@ -47,137 +47,9 @@ builtin_initialize attributeMapRef : IO.Ref (PersistentHashMap Name AttributeImp
def registerBuiltinAttribute (attr : AttributeImpl) : IO Unit := do
let m ← attributeMapRef.get
if m.contains attr.name then throw (IO.userError ("invalid builtin attribute declaration, '" ++ toString attr.name ++ "' has already been used"))
unless (← IO.initializing) do throw (IO.userError "failed to register attribute, attributes can only be registered during initialization")
unless (← IO.initializing) || (← importing) do throw (IO.userError "failed to register attribute, attributes can only be registered during initialization")
attributeMapRef.modify fun m => m.insert attr.name attr

abbrev AttributeImplBuilder := List DataValue → Except String AttributeImpl
abbrev AttributeImplBuilderTable := Std.HashMap Name AttributeImplBuilder

builtin_initialize attributeImplBuilderTableRef : IO.Ref AttributeImplBuilderTable ← IO.mkRef {}

def registerAttributeImplBuilder (builderId : Name) (builder : AttributeImplBuilder) : IO Unit := do
let table ← attributeImplBuilderTableRef.get
if table.contains builderId then throw (IO.userError ("attribute implementation builder '" ++ toString builderId ++ "' has already been declared"))
attributeImplBuilderTableRef.modify fun table => table.insert builderId builder

def mkAttributeImplOfBuilder (builderId : Name) (args : List DataValue) : IO AttributeImpl := do
let table ← attributeImplBuilderTableRef.get
match table.find? builderId with
| none => throw (IO.userError ("unknown attribute implementation builder '" ++ toString builderId ++ "'"))
| some builder => IO.ofExcept $ builder args

inductive AttributeExtensionOLeanEntry where
| decl (declName : Name) -- `declName` has type `AttributeImpl`
| builder (builderId : Name) (args : List DataValue)

structure AttributeExtensionState where
newEntries : List AttributeExtensionOLeanEntry := []
map : PersistentHashMap Name AttributeImpl
deriving Inhabited

abbrev AttributeExtension := PersistentEnvExtension AttributeExtensionOLeanEntry (AttributeExtensionOLeanEntry × AttributeImpl) AttributeExtensionState

private def AttributeExtension.mkInitial : IO AttributeExtensionState := do
let map ← attributeMapRef.get
pure { map := map }

unsafe def mkAttributeImplOfConstantUnsafe (env : Environment) (opts : Options) (declName : Name) : Except String AttributeImpl :=
match env.find? declName with
| none => throw ("unknow constant '" ++ toString declName ++ "'")
| some info =>
match info.type with
| Expr.const `Lean.AttributeImpl _ _ => env.evalConst AttributeImpl opts declName
| _ => throw ("unexpected attribute implementation type at '" ++ toString declName ++ "' (`AttributeImpl` expected")

@[implementedBy mkAttributeImplOfConstantUnsafe]
constant mkAttributeImplOfConstant (env : Environment) (opts : Options) (declName : Name) : Except String AttributeImpl

def mkAttributeImplOfEntry (env : Environment) (opts : Options) (e : AttributeExtensionOLeanEntry) : IO AttributeImpl :=
match e with
| AttributeExtensionOLeanEntry.decl declName => IO.ofExcept $ mkAttributeImplOfConstant env opts declName
| AttributeExtensionOLeanEntry.builder builderId args => mkAttributeImplOfBuilder builderId args

private def AttributeExtension.addImported (es : Array (Array AttributeExtensionOLeanEntry)) : ImportM AttributeExtensionState := do
let ctx ← read
let map ← attributeMapRef.get
let map ← es.foldlM
(fun map entries =>
entries.foldlM
(fun (map : PersistentHashMap Name AttributeImpl) entry => do
let attrImpl ← liftM $ mkAttributeImplOfEntry ctx.env ctx.opts entry
pure $ map.insert attrImpl.name attrImpl)
map)
map
pure { map := map }

private def addAttrEntry (s : AttributeExtensionState) (e : AttributeExtensionOLeanEntry × AttributeImpl) : AttributeExtensionState :=
{ s with map := s.map.insert e.2.name e.2, newEntries := e.1 :: s.newEntries }

builtin_initialize attributeExtension : AttributeExtension ←
registerPersistentEnvExtension {
name := `attrExt,
mkInitial := AttributeExtension.mkInitial,
addImportedFn := AttributeExtension.addImported,
addEntryFn := addAttrEntry,
exportEntriesFn := fun s => s.newEntries.reverse.toArray,
statsFn := fun s => format "number of local entries: " ++ format s.newEntries.length
}

/- Return true iff `n` is the name of a registered attribute. -/
@[export lean_is_attribute]
def isBuiltinAttribute (n : Name) : IO Bool := do
let m ← attributeMapRef.get; pure (m.contains n)

/- Return the name of all registered attributes. -/
def getBuiltinAttributeNames : IO (List Name) := do
let m ← attributeMapRef.get; pure $ m.foldl (fun r n _ => n::r) []

def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do
let m ← attributeMapRef.get
match m.find? attrName with
| some attr => pure attr
| none => throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'"))

@[export lean_attribute_application_time]
def getBuiltinAttributeApplicationTime (n : Name) : IO AttributeApplicationTime := do
let attr ← getBuiltinAttributeImpl n
pure attr.applicationTime

def isAttribute (env : Environment) (attrName : Name) : Bool :=
(attributeExtension.getState env).map.contains attrName

def getAttributeNames (env : Environment) : List Name :=
let m := (attributeExtension.getState env).map
m.foldl (fun r n _ => n::r) []

def getAttributeImpl (env : Environment) (attrName : Name) : Except String AttributeImpl :=
let m := (attributeExtension.getState env).map
match m.find? attrName with
| some attr => pure attr
| none => throw ("unknown attribute '" ++ toString attrName ++ "'")

def registerAttributeOfDecl (env : Environment) (opts : Options) (attrDeclName : Name) : Except String Environment := do
let attrImpl ← mkAttributeImplOfConstant env opts attrDeclName
if isAttribute env attrImpl.name then
throw ("invalid builtin attribute declaration, '" ++ toString attrImpl.name ++ "' has already been used")
else
pure $ attributeExtension.addEntry env (AttributeExtensionOLeanEntry.decl attrDeclName, attrImpl)

def registerAttributeOfBuilder (env : Environment) (builderId : Name) (args : List DataValue) : IO Environment := do
let attrImpl ← mkAttributeImplOfBuilder builderId args
if isAttribute env attrImpl.name then
throw (IO.userError ("invalid builtin attribute declaration, '" ++ toString attrImpl.name ++ "' has already been used"))
else
pure $ attributeExtension.addEntry env (AttributeExtensionOLeanEntry.builder builderId args, attrImpl)

def Attribute.add (declName : Name) (attrName : Name) (stx : Syntax) (kind := AttributeKind.global) : AttrM Unit := do
let attr ← ofExcept <| getAttributeImpl (← getEnv) attrName
attr.add declName stx kind

def Attribute.erase (declName : Name) (attrName : Name) : AttrM Unit := do
let attr ← ofExcept <| getAttributeImpl (← getEnv) attrName
attr.erase declName

/-
Helper methods for decoding the parameters of builtin attributes that are defined before `Lean.Parser`.
We have the following ones:
Expand Down Expand Up @@ -406,4 +278,148 @@ def setValue {α : Type} (attrs : EnumAttributes α) (env : Environment) (decl :

end EnumAttributes

/-
Attribute extension and builders. We use builders to implement attribute factories for parser categories.
-/

abbrev AttributeImplBuilder := List DataValue → Except String AttributeImpl
abbrev AttributeImplBuilderTable := Std.HashMap Name AttributeImplBuilder

builtin_initialize attributeImplBuilderTableRef : IO.Ref AttributeImplBuilderTable ← IO.mkRef {}

def registerAttributeImplBuilder (builderId : Name) (builder : AttributeImplBuilder) : IO Unit := do
let table ← attributeImplBuilderTableRef.get
if table.contains builderId then throw (IO.userError ("attribute implementation builder '" ++ toString builderId ++ "' has already been declared"))
attributeImplBuilderTableRef.modify fun table => table.insert builderId builder

def mkAttributeImplOfBuilder (builderId : Name) (args : List DataValue) : IO AttributeImpl := do
let table ← attributeImplBuilderTableRef.get
match table.find? builderId with
| none => throw (IO.userError ("unknown attribute implementation builder '" ++ toString builderId ++ "'"))
| some builder => IO.ofExcept $ builder args

inductive AttributeExtensionOLeanEntry where
| decl (declName : Name) -- `declName` has type `AttributeImpl`
| builder (builderId : Name) (args : List DataValue)

structure AttributeExtensionState where
newEntries : List AttributeExtensionOLeanEntry := []
map : PersistentHashMap Name AttributeImpl
deriving Inhabited

abbrev AttributeExtension := PersistentEnvExtension AttributeExtensionOLeanEntry (AttributeExtensionOLeanEntry × AttributeImpl) AttributeExtensionState

private def AttributeExtension.mkInitial : IO AttributeExtensionState := do
let map ← attributeMapRef.get
pure { map := map }

unsafe def mkAttributeImplOfConstantUnsafe (env : Environment) (opts : Options) (declName : Name) : Except String AttributeImpl :=
match env.find? declName with
| none => throw ("unknow constant '" ++ toString declName ++ "'")
| some info =>
match info.type with
| Expr.const `Lean.AttributeImpl _ _ => env.evalConst AttributeImpl opts declName
| _ => throw ("unexpected attribute implementation type at '" ++ toString declName ++ "' (`AttributeImpl` expected")

@[implementedBy mkAttributeImplOfConstantUnsafe]
constant mkAttributeImplOfConstant (env : Environment) (opts : Options) (declName : Name) : Except String AttributeImpl

def mkAttributeImplOfEntry (env : Environment) (opts : Options) (e : AttributeExtensionOLeanEntry) : IO AttributeImpl :=
match e with
| AttributeExtensionOLeanEntry.decl declName => IO.ofExcept $ mkAttributeImplOfConstant env opts declName
| AttributeExtensionOLeanEntry.builder builderId args => mkAttributeImplOfBuilder builderId args

private def AttributeExtension.addImported (es : Array (Array AttributeExtensionOLeanEntry)) : ImportM AttributeExtensionState := do
let ctx ← read
let map ← attributeMapRef.get
let map ← es.foldlM
(fun map entries =>
entries.foldlM
(fun (map : PersistentHashMap Name AttributeImpl) entry => do
let attrImpl ← liftM $ mkAttributeImplOfEntry ctx.env ctx.opts entry
pure $ map.insert attrImpl.name attrImpl)
map)
map
pure { map := map }

private def addAttrEntry (s : AttributeExtensionState) (e : AttributeExtensionOLeanEntry × AttributeImpl) : AttributeExtensionState :=
{ s with map := s.map.insert e.2.name e.2, newEntries := e.1 :: s.newEntries }

builtin_initialize attributeExtension : AttributeExtension ←
registerPersistentEnvExtension {
name := `attrExt,
mkInitial := AttributeExtension.mkInitial,
addImportedFn := AttributeExtension.addImported,
addEntryFn := addAttrEntry,
exportEntriesFn := fun s => s.newEntries.reverse.toArray,
statsFn := fun s => format "number of local entries: " ++ format s.newEntries.length
}

/- Return true iff `n` is the name of a registered attribute. -/
@[export lean_is_attribute]
def isBuiltinAttribute (n : Name) : IO Bool := do
let m ← attributeMapRef.get; pure (m.contains n)

/- Return the name of all registered attributes. -/
def getBuiltinAttributeNames : IO (List Name) := do
let m ← attributeMapRef.get; pure $ m.foldl (fun r n _ => n::r) []

def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do
let m ← attributeMapRef.get
match m.find? attrName with
| some attr => pure attr
| none => throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'"))

@[export lean_attribute_application_time]
def getBuiltinAttributeApplicationTime (n : Name) : IO AttributeApplicationTime := do
let attr ← getBuiltinAttributeImpl n
pure attr.applicationTime

def isAttribute (env : Environment) (attrName : Name) : Bool :=
(attributeExtension.getState env).map.contains attrName

def getAttributeNames (env : Environment) : List Name :=
let m := (attributeExtension.getState env).map
m.foldl (fun r n _ => n::r) []

def getAttributeImpl (env : Environment) (attrName : Name) : Except String AttributeImpl :=
let m := (attributeExtension.getState env).map
match m.find? attrName with
| some attr => pure attr
| none => throw ("unknown attribute '" ++ toString attrName ++ "'")

def registerAttributeOfDecl (env : Environment) (opts : Options) (attrDeclName : Name) : Except String Environment := do
let attrImpl ← mkAttributeImplOfConstant env opts attrDeclName
if isAttribute env attrImpl.name then
throw ("invalid builtin attribute declaration, '" ++ toString attrImpl.name ++ "' has already been used")
else
pure $ attributeExtension.addEntry env (AttributeExtensionOLeanEntry.decl attrDeclName, attrImpl)

def registerAttributeOfBuilder (env : Environment) (builderId : Name) (args : List DataValue) : IO Environment := do
let attrImpl ← mkAttributeImplOfBuilder builderId args
if isAttribute env attrImpl.name then
throw (IO.userError ("invalid builtin attribute declaration, '" ++ toString attrImpl.name ++ "' has already been used"))
else
pure $ attributeExtension.addEntry env (AttributeExtensionOLeanEntry.builder builderId args, attrImpl)

def Attribute.add (declName : Name) (attrName : Name) (stx : Syntax) (kind := AttributeKind.global) : AttrM Unit := do
let attr ← ofExcept <| getAttributeImpl (← getEnv) attrName
attr.add declName stx kind

def Attribute.erase (declName : Name) (attrName : Name) : AttrM Unit := do
let attr ← ofExcept <| getAttributeImpl (← getEnv) attrName
attr.erase declName

builtin_initialize
-- See comment at `updateEnvAttributesRef`
updateEnvAttributesRef.set fun env => do
let map ← attributeMapRef.get
let s ← attributeExtension.getState env
let s := map.foldl (init := s) fun s attrName attrImpl =>
if s.map.contains attrName then
s
else
{ s with map := s.map.insert attrName attrImpl }
return attributeExtension.setState env s

end Lean
5 changes: 3 additions & 2 deletions src/Lean/Compiler/InitAttr.lean
Expand Up @@ -49,8 +49,9 @@ unsafe def registerInitAttrUnsafe (attrName : Name) (runAfterImport : Bool) : IO
for modEntries in entries do
for (decl, initDecl) in modEntries do
if initDecl.isAnonymous then
_ ← IO.ofExcept $ ctx.env.evalConst (IO Unit) ctx.opts decl
else runInit ctx.env ctx.opts decl initDecl
discard <| IO.ofExcept $ ctx.env.evalConst (IO Unit) ctx.opts decl
else
runInit ctx.env ctx.opts decl initDecl
}

@[implementedBy registerInitAttrUnsafe]
Expand Down
13 changes: 13 additions & 0 deletions src/Lean/Environment.lean
Expand Up @@ -570,6 +570,17 @@ private def setImportedEntries (env : Environment) (mods : Array ModuleData) (st
env ← extDescr.toEnvExtension.modifyState env fun s => { s with importedEntries := s.importedEntries.push entries }
return env

/-
"Forward declaration" needed for updating the attribute table with user-defined attributes.
User-defined attributes are declared using the `initialize` command. The `initialize` command is just syntax sugar for the `init` attribute.
The `init` attribute is initialized after the `attributeExtension` is initialized. We cannot change the order since the `init` attribute is an attribute,
and requires this extension.
The `attributeExtension` initializer uses `attributeMapRef` to initialize the attribute mapping.
When we a new user-defined attribute declaration is imported, `attributeMapRef` is updated.
Later, we set this method with code that adds the user-defined attributes that were imported after we initialized `attributeExtension`.
-/
builtin_initialize updateEnvAttributesRef : IO.Ref (Environment → IO Environment) ← IO.mkRef (fun env => pure env)

private partial def finalizePersistentExtensions (env : Environment) (mods : Array ModuleData) (opts : Options) : IO Environment := do
loop 0 env
where
Expand All @@ -588,6 +599,8 @@ where
-- a user-defined persistent extension is imported.
-- Thus, we invoke `setImportedEntries` to update the array `importedEntries` with the entries for the new extensions.
env ← setImportedEntries env mods prevSize
-- See comment at `updateEnvAttributesRef`
env ← (← updateEnvAttributesRef.get) env
loop (i + 1) env
else
return env
Expand Down
8 changes: 8 additions & 0 deletions src/shell/CMakeLists.txt
Expand Up @@ -187,3 +187,11 @@ add_test(NAME leanpkgtest_user_ext
export PATH=${LEAN_BIN}:$PATH
find . -name '*.olean' -delete
leanpkg build | grep 'world, hello, test'")

add_test(NAME leanpkgtest_user_attr
WORKING_DIRECTORY "${LEAN_SOURCE_DIR}/../tests/leanpkg/user_attr"
COMMAND bash -c "
set -eu
export PATH=${LEAN_BIN}:$PATH
find . -name '*.olean' -delete
leanpkg build")
1 change: 1 addition & 0 deletions tests/leanpkg/user_attr/.gitignore
@@ -0,0 +1 @@
/build
12 changes: 12 additions & 0 deletions tests/leanpkg/user_attr/UserAttr.lean
@@ -0,0 +1,12 @@
import UserAttr.Tst

open Lean

def tst : MetaM Unit := do
let env ← getEnv
assert! (blaAttr.hasTag env `f)
assert! (blaAttr.hasTag env `g)
assert! !(blaAttr.hasTag env `id)
pure ()

#eval tst
5 changes: 5 additions & 0 deletions tests/leanpkg/user_attr/UserAttr/BlaAttr.lean
@@ -0,0 +1,5 @@
import Lean

open Lean

initialize blaAttr : TagAttribute ← registerTagAttribute `bla "simple user defined attribute"
4 changes: 4 additions & 0 deletions tests/leanpkg/user_attr/UserAttr/Tst.lean
@@ -0,0 +1,4 @@
import UserAttr.BlaAttr

@[bla] def f (x : Nat) := x + 2
@[bla] def g (x : Nat) := x + 1
3 changes: 3 additions & 0 deletions tests/leanpkg/user_attr/leanpkg.toml
@@ -0,0 +1,3 @@
[package]
name = "UserAttr"
version = "0.1"

0 comments on commit a77598f

Please sign in to comment.