Skip to content

Commit

Permalink
feat: merge and default overloads for context generator
Browse files Browse the repository at this point in the history
  • Loading branch information
vberlier committed Sep 18, 2022
1 parent 520de01 commit 810a859
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 17 deletions.
32 changes: 32 additions & 0 deletions beet/library/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ class NamespaceFile(Protocol):
scope: ClassVar[Tuple[str, ...]]
extension: ClassVar[str]

def merge(self, other: Any) -> bool:
...

def bind(self, pack: Any, path: str) -> Any:
...

Expand All @@ -120,6 +123,10 @@ def ensure_deserialized(
) -> Any:
...

@classmethod
def default(cls) -> Any:
...

@classmethod
def load(cls: Type[T], origin: FileOrigin, path: FileSystemPath) -> T:
...
Expand Down Expand Up @@ -319,6 +326,23 @@ def bind(self, namespace: "Namespace", file_type: Type[NamespaceFileType]):
except Drop:
del self[key]

def setdefault(
self,
key: str,
default: Optional[NamespaceFileType] = None,
) -> NamespaceFileType:
if value := self.get(key):
return value
if default:
self[key] = default
else:
if not self.file_type:
raise ValueError(
"File type associated to the namespace container is not available."
)
self[key] = self.file_type()
return self[key]

def merge(self, other: Mapping[Any, SupportsMerge]) -> bool:
if (
self.namespace is not None
Expand Down Expand Up @@ -609,6 +633,14 @@ def split_key(self, key: str) -> Tuple[str, str]:
def join_key(self, key1: str, key2: str) -> str:
return f"{key1}:{key2}"

def setdefault(
self,
key: str,
default: Optional[NamespaceFileType] = None,
) -> NamespaceFileType:
key1, key2 = self.split_key(key)
return self.proxy[key1][self.proxy_key].setdefault(key2, default) # type: ignore

def walk(self) -> Iterator[Tuple[str, Set[str], Dict[str, NamespaceFileType]]]:
"""Walk over the file hierarchy."""
for prefix, namespace in self.proxy.items():
Expand Down
82 changes: 66 additions & 16 deletions beet/toolchain/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
Iterator,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)

Expand All @@ -37,6 +40,7 @@

T = TypeVar("T", contravariant=True)
GeneratorType = TypeVar("GeneratorType", bound="Generator")
NamespaceFileType = TypeVar("NamespaceFileType", bound="NamespaceFile")


@dataclass
Expand Down Expand Up @@ -131,6 +135,26 @@ def __call__(
) -> str:
...

@overload
def __call__(
self,
fmt: str,
*,
merge: NamespaceFile,
hash: Any = None,
) -> str:
...

@overload
def __call__(
self,
fmt: str,
*,
default: Union[Type[NamespaceFileType], NamespaceFileType],
hash: Any = None,
) -> NamespaceFileType:
...

@overload
def __call__(
self,
Expand All @@ -154,35 +178,61 @@ def __call__(
self,
*args: Any,
render: Optional[TextFileBase[Any]] = None,
merge: Optional[NamespaceFile] = None,
default: Optional[Union[Type[NamespaceFile], NamespaceFile]] = None,
hash: Any = None,
**kwargs: Any,
) -> Any:
file_instance: NamespaceFile
default_file = None

if render:
file_instance = render # type: ignore
fmt = args[0] if args else None
elif len(args) == 2:
fmt, file_instance = args
else:
file_instance = args[0]
fmt = None
if default:
if isinstance(default, type):
file_type = default
else:
file_type = type(default)
default_file = default

if hash is None and not render:
hash = lambda: file_instance.ensure_serialized()
file_instance = None
fmt = args[0]

file_type = type(file_instance)
key = (
self[file_type].path(fmt, hash) if fmt else self[file_type].path(hash=hash)
)
else:
if render:
file_instance = cast(NamespaceFile, render)
fmt = args[0] if args else None
elif merge:
file_instance = merge
fmt = args[0]
elif len(args) == 2:
fmt, file_instance = args
else:
file_instance = args[0]
fmt = None

if hash is None and not render:
hash = lambda: file_instance.ensure_serialized()

file_type = type(file_instance)

pack = (
self.data
if file_type in self.data.namespace_type.field_map
else self.assets
)

pack[key] = file_instance
if not fmt:
key = self[file_type].path(hash=hash)
elif ":" in fmt:
key = self[file_type].format(fmt, hash)
else:
key = self[file_type].path(fmt, hash)

if file_instance:
if merge:
pack[file_type].merge({key: file_instance})
else:
pack[key] = file_instance
elif default:
return pack[file_type].setdefault(key, default_file)

if render:
with self.ctx.override(
Expand Down
12 changes: 11 additions & 1 deletion examples/code_context_generator/demo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from beet import Context, Function
from beet import Context, Function, FunctionTag


def beet_default(ctx: Context):
Expand Down Expand Up @@ -40,3 +40,13 @@ def beet_default(ctx: Context):
tag2 = generate.id("foo")
obj2 = generate.hash("foo")
generate(Function([f"scoreboard players set @s[tag={tag2}] {obj2} 1"]))

func = ctx.generate(
"hoisted:{namespace}/{path}{hash}",
Function(["say hoisted"], tags=["minecraft:load"]),
)

for i in range(3):
ctx.generate("hoisted:foo", merge=FunctionTag({"values": [f"demo:foo{i}"]}))
ctx.generate("hoisted:abc", default=Function).append(f"function {func}")
ctx.generate("hoisted:def", default=Function("say init")).prepend(f"say {i+1}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
function hoisted:code_context_generator/3n96dmzzfmn4w
function hoisted:code_context_generator/3n96dmzzfmn4w
function hoisted:code_context_generator/3n96dmzzfmn4w
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
say hoisted
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
say 3
say 2
say 1
say init
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"values": [
"demo:foo0",
"demo:foo1",
"demo:foo2"
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"values": [
"hoisted:code_context_generator/3n96dmzzfmn4w"
]
}

0 comments on commit 810a859

Please sign in to comment.