Skip to content

Commit

Permalink
feat: user-defined environment extensions
Browse files Browse the repository at this point in the history
New test demonstrates how to use them.
The user-defined extensions cannot be used in the same file where they
were declared because the `initialize` commands are only executed when
we import the modules containing them.

TODO: user-defined attributes.
  • Loading branch information
leodemoura committed Jul 26, 2021
1 parent 42561bb commit cdd1dbb
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 68 deletions.
196 changes: 128 additions & 68 deletions src/Lean/Environment.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ import Lean.Util.FindExpr
import Lean.Util.Profile

namespace Lean
builtin_initialize importingRef : IO.Ref Bool ← IO.mkRef false

/- True while modules are being imported. We use this flag to test check whether environment extensions are registered only
during initialization (builtin ones), and importing (user defined ones). -/
def importing : IO Bool :=
importingRef.get

/- Opaque environment extension state. -/
constant EnvExtensionStateSpec : PointedType.{0}
def EnvExtensionState : Type := EnvExtensionStateSpec.type
Expand Down Expand Up @@ -145,18 +152,20 @@ structure EnvExtensionInterface where
registerExt {σ} (mkInitial : IO σ) : IO (ext σ)
setState {σ} (e : ext σ) (env : Environment) : σ → Environment
modifyState {σ} (e : ext σ) (env : Environment) : (σ → σ) → Environment
getState {σ} (e : ext σ) (env : Environment) : σ
getState {σ} [Inhabited σ] (e : ext σ) (env : Environment) : σ
mkInitialExtStates : IO (Array EnvExtensionState)
ensureExtensionsSize : Environment → IO Environment

instance : Inhabited EnvExtensionInterface where
default := {
ext := id,
inhabitedExt := id,
registerExt := fun mk => mk,
setState := fun _ env _ => env,
modifyState := fun _ env _ => env,
getState := fun ext _ => ext,
mkInitialExtStates := pure #[]
ext := id
inhabitedExt := id
ensureExtensionsSize := fun env => pure env
registerExt := fun mk => mk
setState := fun _ env _ => env
modifyState := fun _ env _ => env
getState := fun ext _ => ext
mkInitialExtStates := pure #[]
}

/- Unsafe implementation of `EnvExtensionInterface` -/
Expand All @@ -170,23 +179,53 @@ structure Ext (σ : Type) where
private def mkEnvExtensionsRef : IO (IO.Ref (Array (Ext EnvExtensionState))) := IO.mkRef #[]
@[builtinInit mkEnvExtensionsRef] private constant envExtensionsRef : IO.Ref (Array (Ext EnvExtensionState))

/--
User-defined environment extensions are declared using the `initialize` command.
This command is just syntax sugar for the `init` attribute.
When we `import` lean modules, the vector stored at `envExtensionsRef` may increase in size because of
user-defined environment extensions. When this happens, we must adjust the size of the `env.extensions`.
This method is invoked when processing `import`s.
-/
partial def ensureExtensionsArraySize (env : Environment) : IO Environment := do
loop env.extensions.size env
where
loop (i : Nat) (env : Environment) : IO Environment := do
let envExtensions ← envExtensionsRef.get
if h : i < envExtensions.size then
let s ← envExtensions[i].mkInitial
let env := { env with extensions := env.extensions.push s }
loop (i + 1) env
else
return env

private def invalidExtMsg := "invalid environment extension has been accessed"

unsafe def setState {σ} (ext : Ext σ) (env : Environment) (s : σ) : Environment :=
{ env with extensions := env.extensions.set! ext.idx (unsafeCast s) }
if h : ext.idx < env.extensions.size then
{ env with extensions := env.extensions.set ⟨ext.idx, h⟩ (unsafeCast s) }
else
panic! invalidExtMsg

@[inline] unsafe def modifyState {σ : Type} (ext : Ext σ) (env : Environment) (f : σ → σ) : Environment :=
{ env with
extensions := env.extensions.modify ext.idx fun s =>
let s : σ := unsafeCast s;
let s : σ := f s;
unsafeCast s }
if ext.idx < env.extensions.size then
{ env with
extensions := env.extensions.modify ext.idx fun s =>
let s : σ := unsafeCast s
let s : σ := f s
unsafeCast s }
else
panic! invalidExtMsg

unsafe def getState {σ} (ext : Ext σ) (env : Environment) : σ :=
let s : EnvExtensionState := env.extensions.get! ext.idx
unsafeCast s
unsafe def getState {σ} [Inhabited σ] (ext : Ext σ) (env : Environment) : σ :=
if h : ext.idx < env.extensions.size then
let s : EnvExtensionState := env.extensions.get ⟨ext.idx, h⟩
unsafeCast s
else
panic! invalidExtMsg

unsafe def registerExt {σ} (mkInitial : IO σ) : IO (Ext σ) := do
let initializing ← IO.initializing
unless initializing do throw (IO.userError "failed to register environment, extensions can only be registered during initialization")
unless (← IO.initializing) || (← importing) do
throw (IO.userError "failed to register environment, extensions can only be registered during initialization")
let exts ← envExtensionsRef.get
let idx := exts.size
let ext : Ext σ := {
Expand All @@ -201,13 +240,14 @@ def mkInitialExtStates : IO (Array EnvExtensionState) := do
exts.mapM fun ext => ext.mkInitial

unsafe def imp : EnvExtensionInterface := {
ext := Ext,
inhabitedExt := fun _ => ⟨arbitrary⟩,
registerExt := registerExt,
setState := setState,
modifyState := modifyState,
getState := getState,
mkInitialExtStates := mkInitialExtStates
ext := Ext
ensureExtensionsSize := ensureExtensionsArraySize
inhabitedExt := fun _ => ⟨arbitrary⟩
registerExt := registerExt
setState := setState
modifyState := modifyState
getState := getState
mkInitialExtStates := mkInitialExtStates
}

end EnvExtensionInterfaceUnsafe
Expand All @@ -217,11 +257,14 @@ constant EnvExtensionInterfaceImp : EnvExtensionInterface

def EnvExtension (σ : Type) : Type := EnvExtensionInterfaceImp.ext σ

private def ensureExtensionsArraySize (env : Environment) : IO Environment :=
EnvExtensionInterfaceImp.ensureExtensionsSize env

namespace EnvExtension
instance {σ} [s : Inhabited σ] : Inhabited (EnvExtension σ) := EnvExtensionInterfaceImp.inhabitedExt s
def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment := EnvExtensionInterfaceImp.setState ext env s
def modifyState {σ : Type} (ext : EnvExtension σ) (env : Environment) (f : σ → σ) : Environment := EnvExtensionInterfaceImp.modifyState ext env f
def getState {σ : Type} (ext : EnvExtension σ) (env : Environment) : σ := EnvExtensionInterfaceImp.getState ext env
def getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment) : σ := EnvExtensionInterfaceImp.getState ext env
end EnvExtension

/- Environment extensions can only be registered during initialization.
Expand Down Expand Up @@ -297,15 +340,15 @@ instance {α β σ} [Inhabited σ] : Inhabited (PersistentEnvExtension α β σ)

namespace PersistentEnvExtension

def getModuleEntries {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α :=
def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α :=
(ext.toEnvExtension.getState env).importedEntries.get! m

def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (b : β) : Environment :=
ext.toEnvExtension.modifyState env fun s =>
let state := ext.addEntryFn s.state b;
{ s with state := state }

def getState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) : σ :=
def getState {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) : σ :=
(ext.toEnvExtension.getState env).state

def setState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (s : σ) : Environment :=
Expand Down Expand Up @@ -379,10 +422,10 @@ namespace SimplePersistentEnvExtension
instance {α σ : Type} [Inhabited σ] : Inhabited (SimplePersistentEnvExtension α σ) :=
inferInstanceAs (Inhabited (PersistentEnvExtension α α (List α × σ)))

def getEntries {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) : List α :=
def getEntries {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension α σ) (env : Environment) : List α :=
(PersistentEnvExtension.getState ext env).1

def getState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) : σ :=
def getState {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension α σ) (env : Environment) : σ :=
(PersistentEnvExtension.getState ext env).2

def setState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) (s : σ) : Environment :=
Expand Down Expand Up @@ -518,23 +561,36 @@ private partial def getEntriesFor (mod : ModuleData) (extId : Name) (i : Nat) :
else
#[]

private def setImportedEntries (env : Environment) (mods : Array ModuleData) : IO Environment := do
private def setImportedEntries (env : Environment) (mods : Array ModuleData) (startingAt : Nat := 0) : IO Environment := do
let mut env := env
let pExtDescrs ← persistentEnvExtensionsRef.get
for mod in mods do
for extDescr in pExtDescrs do
for extDescr in pExtDescrs[startingAt:] do
let entries := getEntriesFor mod extDescr.name 0
env ← extDescr.toEnvExtension.modifyState env fun s => { s with importedEntries := s.importedEntries.push entries }
return env

private def finalizePersistentExtensions (env : Environment) (opts : Options) : IO Environment := do
let mut env := env
let pExtDescrs ← persistentEnvExtensionsRef.get
for extDescr in pExtDescrs do
let s := extDescr.toEnvExtension.getState env
let newState ← extDescr.addImportedFn s.importedEntries { env := env, opts := opts }
env ← extDescr.toEnvExtension.setState env { s with state := newState }
return env
private partial def finalizePersistentExtensions (env : Environment) (mods : Array ModuleData) (opts : Options) : IO Environment := do
loop 0 env
where
loop (i : Nat) (env : Environment) : IO Environment := do
-- Recall that the size of the array stored `persistentEnvExtensionRef` may increase when we import user-defined environment extensions.
let pExtDescrs ← persistentEnvExtensionsRef.get
if h : i < pExtDescrs.size then
let extDescr := pExtDescrs[i]
let s := extDescr.toEnvExtension.getState env
let prevSize := (← persistentEnvExtensionsRef.get).size
let newState ← extDescr.addImportedFn s.importedEntries { env := env, opts := opts }
let mut env ← extDescr.toEnvExtension.setState env { s with state := newState }
env ← ensureExtensionsArraySize env
if (← persistentEnvExtensionsRef.get).size > prevSize then
-- This branch is executed when `pExtDescrs[i]` is the extension associated with the `init` attribute, and
-- a user-defined persistent extension is imported.
-- Thus, we invoke `setImportedEntries` to update the array `importedEntries` with the entries for the new extensions.
env ← setImportedEntries env mods prevSize
loop (i + 1) env
else
return env

structure ImportState where
moduleNameSet : NameSet := {}
Expand All @@ -544,34 +600,38 @@ structure ImportState where

@[export lean_import_modules]
partial def importModules (imports : List Import) (opts : Options) (trustLevel : UInt32 := 0) : IO Environment := profileitIO "import" opts do
let (_, s) ← importMods imports |>.run {}
-- (moduleNames, mods, regions)
let mut modIdx : Nat := 0
let mut const2ModIdx : HashMap Name ModuleIdx := {}
let mut constants : ConstMap := SMap.empty
for mod in s.moduleData do
for cinfo in mod.constants do
const2ModIdx := const2ModIdx.insert cinfo.name modIdx
if constants.contains cinfo.name then throw (IO.userError s!"import failed, environment already contains '{cinfo.name}'")
constants := constants.insert cinfo.name cinfo
modIdx := modIdx + 1
constants := constants.switch
let exts ← mkInitialExtensionStates
let env : Environment := {
const2ModIdx := const2ModIdx,
constants := constants,
extensions := exts,
header := {
quotInit := !imports.isEmpty, -- We assume `core.lean` initializes quotient module
trustLevel := trustLevel,
imports := imports.toArray,
regions := s.regions,
moduleNames := s.moduleNames
try
importingRef.set true
let (_, s) ← importMods imports |>.run {}
-- (moduleNames, mods, regions)
let mut modIdx : Nat := 0
let mut const2ModIdx : HashMap Name ModuleIdx := {}
let mut constants : ConstMap := SMap.empty
for mod in s.moduleData do
for cinfo in mod.constants do
const2ModIdx := const2ModIdx.insert cinfo.name modIdx
if constants.contains cinfo.name then throw (IO.userError s!"import failed, environment already contains '{cinfo.name}'")
constants := constants.insert cinfo.name cinfo
modIdx := modIdx + 1
constants := constants.switch
let exts ← mkInitialExtensionStates
let env : Environment := {
const2ModIdx := const2ModIdx,
constants := constants,
extensions := exts,
header := {
quotInit := !imports.isEmpty, -- We assume `core.lean` initializes quotient module
trustLevel := trustLevel,
imports := imports.toArray,
regions := s.regions,
moduleNames := s.moduleNames
}
}
}
let env ← setImportedEntries env s.moduleData
let env ← finalizePersistentExtensions env opts
pure env
let env ← setImportedEntries env s.moduleData
let env ← finalizePersistentExtensions env s.moduleData opts
pure env
finally
importingRef.set false
where
importMods : List Import → StateRefT ImportState IO Unit
| [] => pure ()
Expand Down
8 changes: 8 additions & 0 deletions src/shell/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,11 @@ add_test(NAME leanpkgtest_cyclic
set -eu
export PATH=${LEAN_BIN}:$PATH
leanpkg build 2>&1 | grep 'import cycle'")

add_test(NAME leanpkgtest_user_ext
WORKING_DIRECTORY "${LEAN_SOURCE_DIR}/../tests/leanpkg/user_ext"
COMMAND bash -c "
set -eu
export PATH=${LEAN_BIN}:$PATH
find . -name '*.olean' -delete
leanpkg build | grep 'world, hello, test'")
1 change: 1 addition & 0 deletions tests/leanpkg/user_ext/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/build
5 changes: 5 additions & 0 deletions tests/leanpkg/user_ext/UserExt.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import UserExt.Tst1
import UserExt.Tst2

show_foo_set
show_bla_set
23 changes: 23 additions & 0 deletions tests/leanpkg/user_ext/UserExt/BlaExt.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import Lean

open Lean

initialize blaExtension : SimplePersistentEnvExtension Name NameSet ←
registerSimplePersistentEnvExtension {
name := `blaExt
addEntryFn := NameSet.insert
addImportedFn := fun es => mkStateFromImportedEntries NameSet.insert {} es
}

syntax (name := insertBla) "insert_bla " ident : command
syntax (name := showBla) "show_bla_set" : command

open Lean.Elab
open Lean.Elab.Command

@[commandElab insertBla] def elabInsertBla : CommandElab := fun stx => do
IO.println s!"inserting {stx[1].getId}"
modifyEnv fun env => blaExtension.addEntry env stx[1].getId

@[commandElab showBla] def elabShowBla : CommandElab := fun stx => do
IO.println s!"bla set: {blaExtension.getState (← getEnv) |>.toList}"
23 changes: 23 additions & 0 deletions tests/leanpkg/user_ext/UserExt/FooExt.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import Lean

open Lean

initialize fooExtension : SimplePersistentEnvExtension Name NameSet ←
registerSimplePersistentEnvExtension {
name := `fooExt
addEntryFn := NameSet.insert
addImportedFn := fun es => mkStateFromImportedEntries NameSet.insert {} es
}

syntax (name := insertFoo) "insert_foo " ident : command
syntax (name := showFoo) "show_foo_set" : command

open Lean.Elab
open Lean.Elab.Command

@[commandElab insertFoo] def elabInsertFoo : CommandElab := fun stx => do
IO.println s!"inserting {stx[1].getId}"
modifyEnv fun env => fooExtension.addEntry env stx[1].getId

@[commandElab showFoo] def elabShowFoo : CommandElab := fun stx => do
IO.println s!"foo set: {fooExtension.getState (← getEnv) |>.toList}"
9 changes: 9 additions & 0 deletions tests/leanpkg/user_ext/UserExt/Tst1.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import UserExt.FooExt
import UserExt.BlaExt

insert_foo hello
insert_foo world
show_foo_set

insert_bla abc
show_bla_set
6 changes: 6 additions & 0 deletions tests/leanpkg/user_ext/UserExt/Tst2.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import UserExt.BlaExt
import UserExt.FooExt

insert_foo test
insert_bla defg
insert_bla hij
3 changes: 3 additions & 0 deletions tests/leanpkg/user_ext/leanpkg.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[package]
name = "UserExt"
version = "0.1"

0 comments on commit cdd1dbb

Please sign in to comment.