Skip to content

Commit

Permalink
feat: add ctx[Function] and refactor contrib to take into account ove…
Browse files Browse the repository at this point in the history
…rlays
  • Loading branch information
vberlier committed Oct 7, 2023
1 parent 1297472 commit 6f0cec4
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 57 deletions.
5 changes: 2 additions & 3 deletions beet/contrib/function_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from typing import List, Optional

from beet import Context, PluginOptions, configurable
from beet import Context, Function, PluginOptions, configurable


class FunctionHeaderOptions(PluginOptions):
Expand All @@ -27,8 +27,7 @@ def function_header(ctx: Context, opts: FunctionHeaderOptions):
if not opts.template:
return

for path in ctx.data.functions.match(*opts.match):
for function, (_, path) in ctx.select(match=opts.match, extend=Function).items():
with ctx.override(render_path=path, render_group="functions"):
header = ctx.template.render(opts.template)
function = ctx.data.functions[path]
function.text = header + function.text
4 changes: 2 additions & 2 deletions beet/contrib/minify_function.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Plugin that minifies function files."""


from beet import Context
from beet import Context, Function


def beet_default(ctx: Context):
for function in ctx.data.functions.values():
for _, function in ctx[Function]:
function.text = "".join(
stripped + "\n"
for line in function.lines
Expand Down
4 changes: 2 additions & 2 deletions beet/contrib/relative_function_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import re
from pathlib import PurePosixPath

from beet import Context
from beet import Context, Function

REGEX_RELATIVE_PATH = re.compile(r"^(|.*\s)function\s+(\.\.?/\S+)(\s*)$")


def beet_default(ctx: Context):
for path, function in ctx.data.functions.items():
for path, function in ctx[Function]:
namespace, _, original_path = path.partition(":")
current_dir = PurePosixPath(original_path).parent

Expand Down
2 changes: 1 addition & 1 deletion beet/contrib/rename_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def handle_filename_for_namespace_file(

logger.warning(
'Invalid %s destination "%s".',
snake_case(file_type.__name__),
file_type.snake_name,
dest,
extra={"annotate": filename},
)
Expand Down
42 changes: 15 additions & 27 deletions beet/contrib/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
]


from typing import Dict, Union
from typing import Any

from beet import Context, ListOption, PluginOptions, configurable
from beet.core.utils import snake_case
from beet import Context, PluginOptions, TextFileBase, configurable
from beet.toolchain.select import PackMatchOption, PackSelector


class RenderOptions(PluginOptions):
resource_pack: Union[Dict[str, ListOption[str]], ListOption[str]] = {}
data_pack: Union[Dict[str, ListOption[str]], ListOption[str]] = {}
resource_pack: PackMatchOption = PackMatchOption()
data_pack: PackMatchOption = PackMatchOption()


def beet_default(ctx: Context):
Expand All @@ -26,25 +26,13 @@ def beet_default(ctx: Context):
def render(ctx: Context, opts: RenderOptions):
"""Plugin that processes the data pack and the resource pack with Jinja."""
for groups, pack in zip([opts.resource_pack, opts.data_pack], ctx.packs):
file_types = set(pack.resolve_scope_map().values())
group_map = {
snake_case(file_type.__name__): file_type for file_type in file_types
}

if isinstance(groups, ListOption):
groups = {k: groups for k in group_map}
else:
for singular in list(group_map):
group_map.setdefault(f"{singular}s", group_map[singular])

for group, render_options in groups.items():
try:
file_type = group_map[group]
proxy = pack[file_type]
file_paths = proxy.match(*render_options.entries())
except:
raise ValueError(f"Invalid render group {group!r}.") from None
else:
for path in file_paths:
with ctx.override(render_path=path, render_group=group):
ctx.template.render_file(proxy[path]) # type: ignore
for file_instance, (_, path) in (
PackSelector.from_options(groups, template=ctx.template)
.select_files(pack, extend=TextFileBase[Any])
.items()
):
with ctx.override(
render_path=path,
render_group=f"{file_instance.snake_name}s",
):
ctx.template.render_file(file_instance)
6 changes: 6 additions & 0 deletions beet/core/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ class File(Generic[ValueType, SerializeType]):

original: "File[ValueType, SerializeType]" = extra_field(default=None)

snake_name: ClassVar[str] = "file"

def __init_subclass__(cls):
super().__init_subclass__()
cls.snake_name = snake_case(cls.__name__)

def __post_init__(self):
if self._content is self.source_path is None:
self._content = self.default()
Expand Down
2 changes: 2 additions & 0 deletions beet/library/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class NamespaceFile(Protocol):
scope: ClassVar[Tuple[str, ...]]
extension: ClassVar[str]

snake_name: ClassVar[str]

def __init__(
self,
_content: Optional[Any] = None,
Expand Down
7 changes: 6 additions & 1 deletion beet/toolchain/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import (
Any,
Callable,
Iterable,
List,
Optional,
Protocol,
Expand Down Expand Up @@ -54,7 +55,7 @@

from .generator import Generator
from .pipeline import GenericPipeline, GenericPlugin, GenericPluginSpec
from .select import PackSelection, select_files
from .select import PackSelection, select_all, select_files
from .template import TemplateManager
from .tree import generate_tree
from .worker import WorkerPoolHandle
Expand Down Expand Up @@ -365,6 +366,10 @@ def select(

return result

def __getitem__(self, extend: Type[T]) -> Iterable[Tuple[str, T]]:
for pack in self.packs:
yield from select_all(pack, extend)


@overload
def configurable(
Expand Down
62 changes: 41 additions & 21 deletions beet/toolchain/select.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
__all__ = [
"select_files",
"select_all",
"PackSelector",
"PackSelection",
"PackSelectOption",
"PackMatchOption",
"PathSpecOption",
"RegexOption",
"RegexFlagsOption",
Expand Down Expand Up @@ -30,7 +32,6 @@
from pydantic import BaseModel

from beet.core.file import File
from beet.core.utils import snake_case
from beet.library.base import NamespaceFile, Pack

from .config import ListOption
Expand Down Expand Up @@ -93,9 +94,26 @@ def compile_spec(
return PathSpec.from_lines("gitwildmatch", patterns)


class PackMatchOption(BaseModel):
__root__: Union[PathSpecOption, Dict[str, PathSpecOption]] = PathSpecOption()

def compile_match(
self,
template: Optional[TemplateManager] = None,
) -> Optional[Union[PathSpec, Dict[str, PathSpec]]]:
if isinstance(self.__root__, PathSpecOption):
return self.__root__.compile_spec(template)
else:
return {
group_name: spec
for group_name, match in self.__root__.items()
if (spec := match.compile_spec(template))
}


class PackSelectOption(BaseModel):
files: RegexOption = RegexOption()
match: Union[PathSpecOption, Dict[str, PathSpecOption]] = PathSpecOption()
match: PackMatchOption = PackMatchOption()

class Config:
extra = "forbid"
Expand All @@ -107,22 +125,7 @@ def compile(
Optional["re.Pattern[str]"],
Optional[Union[PathSpec, Dict[str, PathSpec]]],
]:
files_regex = None
match_spec = None

if self.files:
files_regex = self.files.compile_regex(template)

if isinstance(self.match, PathSpecOption):
match_spec = self.match.compile_spec(template)
else:
match_spec = {
group_name: spec
for group_name, match in self.match.items()
if (spec := match.compile_spec(template))
}

return files_regex, match_spec
return self.files.compile_regex(template), self.match.compile_match(template)


@dataclass(frozen=True)
Expand All @@ -133,14 +136,20 @@ class PackSelector:
@classmethod
def from_options(
cls,
select_options: Optional[PackSelectOption] = None,
select_options: Optional[
Union[PackSelectOption, RegexOption, PackMatchOption]
] = None,
*,
files: Optional[Any] = None,
match: Optional[Any] = None,
template: Optional[TemplateManager] = None,
) -> "PackSelector":
if select_options:
if isinstance(select_options, PackSelectOption):
return PackSelector(*select_options.compile(template))
if isinstance(select_options, RegexOption):
return cls.from_options(PackSelectOption(files=select_options))
if isinstance(select_options, PackMatchOption):
return cls.from_options(PackSelectOption(match=select_options))
values = {}
if files:
values["files"] = files
Expand Down Expand Up @@ -194,7 +203,7 @@ def select_files(
file_types = {t for t in file_types if issubclass(t, extend)}

if isinstance(self.match_spec, dict):
group_map = {snake_case(t.__name__): t for t in file_types}
group_map = {t.snake_name: t for t in file_types}
for singular in list(group_map):
group_map.setdefault(f"{singular}s", group_map[singular])

Expand Down Expand Up @@ -248,6 +257,17 @@ def select_files(
)


def select_all(pack: Pack[Any], extend: Type[T]) -> Iterable[Tuple[str, T]]:
file_types = set(pack.resolve_scope_map().values())
if extend:
file_types = {t for t in file_types if issubclass(t, extend)}
for file_type in file_types:
yield from pack[file_type].items() # type: ignore
if pack.overlay_parent is None:
for overlay in pack.overlays.values():
yield from overlay[file_type].items() # type: ignore


def _gather_from_pack(
pack: Pack[Any], file_type: Type[NamespaceFile], spec: PathSpec
) -> Iterable[Tuple[NamespaceFile, Tuple[Optional[str], Optional[str]]]]:
Expand Down

0 comments on commit 6f0cec4

Please sign in to comment.