diff --git a/mecha/api.py b/mecha/api.py index 50fafbe..f21cb84 100644 --- a/mecha/api.py +++ b/mecha/api.py @@ -36,7 +36,7 @@ Context, DataPack, Function, - NamespaceFile, + ResourcePack, TextFileBase, ) from beet.core.utils import ( @@ -51,7 +51,12 @@ from .ast import AstLiteral, AstNode, AstRoot from .config import CommandTree -from .database import CompilationDatabase, CompilationUnit +from .database import ( + CompilationDatabase, + CompilationUnit, + CompilationUnitProvider, + FileTypeCompilationUnitProvider, +) from .diagnostic import ( Diagnostic, DiagnosticCollection, @@ -66,6 +71,7 @@ AstNodeType = TypeVar("AstNodeType", bound=AstNode) TextFileType = TypeVar("TextFileType", bound=TextFileBase[Any]) +PackType = TypeVar("PackType", bound=Union[ResourcePack, DataPack]) logger = logging.getLogger("mecha") @@ -156,9 +162,7 @@ class Mecha: spec: CommandSpec = extra_field(default=None) - providers: List[Type[NamespaceFile]] = extra_field( - default_factory=lambda: [Function] - ) + providers: List[CompilationUnitProvider] = extra_field(init=False) preprocessor: Preprocessor = extra_field(default=wrap_backslash_continuation) @@ -234,6 +238,8 @@ def __post_init__( parsers=get_parsers(version), ) + self.providers = [FileTypeCompilationUnitProvider([Function], self.directory)] + self.serialize = Serializer( spec=self.spec, database=self.database, @@ -375,7 +381,7 @@ def parse( @overload def compile( self, - source: DataPack, + source: PackType, *, match: Optional[List[str]] = None, multiline: Optional[bool] = None, @@ -383,7 +389,7 @@ def compile( readonly: Optional[bool] = None, initial_step: int = 0, report: Optional[DiagnosticCollection] = None, - ) -> DataPack: + ) -> PackType: ... @overload @@ -393,7 +399,7 @@ def compile( *, filename: Optional[FileSystemPath] = None, resource_location: Optional[str] = None, - within: Optional[DataPack] = None, + within: Optional[Union[ResourcePack, DataPack]] = None, multiline: Optional[bool] = None, formatting: Optional[JsonDict] = None, readonly: Optional[bool] = None, @@ -409,7 +415,7 @@ def compile( *, filename: Optional[FileSystemPath] = None, resource_location: Optional[str] = None, - within: Optional[DataPack] = None, + within: Optional[Union[ResourcePack, DataPack]] = None, multiline: Optional[bool] = None, formatting: Optional[JsonDict] = None, readonly: Optional[bool] = None, @@ -420,18 +426,20 @@ def compile( def compile( self, - source: Union[DataPack, TextFileBase[Any], List[str], str, AstRoot], + source: Union[ + Union[ResourcePack, DataPack], TextFileBase[Any], List[str], str, AstRoot + ], *, match: Optional[List[str]] = None, filename: Optional[FileSystemPath] = None, resource_location: Optional[str] = None, - within: Optional[DataPack] = None, + within: Optional[Union[ResourcePack, DataPack]] = None, multiline: Optional[bool] = None, formatting: Optional[JsonDict] = None, readonly: Optional[bool] = None, initial_step: int = 0, report: Optional[DiagnosticCollection] = None, - ) -> Union[DataPack, TextFileBase[Any]]: + ) -> Union[Union[ResourcePack, DataPack], TextFileBase[Any]]: """Apply all compilation steps.""" self.database.setup_compilation() @@ -440,32 +448,16 @@ def compile( if readonly is None: readonly = self.readonly - if isinstance(source, DataPack): + if isinstance(source, (ResourcePack, DataPack)): result = source if match is None: match = self.match - packs = [source] - if source.overlay_parent is None: - packs.extend(source.overlays.values()) - - for file_type in self.providers: - if not issubclass(file_type, TextFileBase): - continue - for pack in packs: - for key in pack[file_type].match(*match or ["*"]): - value = pack[file_type][key] - self.database[value] = CompilationUnit( - resource_location=key, - filename=( - os.path.relpath(value.source_path, self.directory) - if value.source_path - else None - ), - pack=pack, - ) - self.database.enqueue(value) + for provider in self.providers: + for file_instance, compilation_unit in provider(source, match): + self.database[file_instance] = compilation_unit + self.database.enqueue(file_instance) else: if isinstance(source, (list, str)): source = Function(source) @@ -477,31 +469,35 @@ def compile( else: result = Function() - self.database[result] = CompilationUnit( - ast=source if isinstance(source, AstRoot) else None, + compilation_unit = CompilationUnit( resource_location=resource_location, filename=str(filename) if filename else None, pack=within, ) + + if isinstance(source, AstRoot): + compilation_unit.ast = source + + self.database[result] = compilation_unit self.database.enqueue(result) - for step, function in self.database.process_queue(): - compilation_unit = self.database[function] + for step, file_instance in self.database.process_queue(): + compilation_unit = self.database[file_instance] start_time = perf_counter_ns() if step < 0: if compilation_unit.ast: - self.database.enqueue(function, initial_step) + self.database.enqueue(file_instance, initial_step) continue try: - compilation_unit.source = function.text + compilation_unit.source = file_instance.text compilation_unit.ast = self.parse( - function, + file_instance, filename=compilation_unit.filename, resource_location=compilation_unit.resource_location, multiline=multiline, ) - self.database.enqueue(function, initial_step) + self.database.enqueue(file_instance, initial_step) except DiagnosticError as exc: compilation_unit.diagnostics.extend(exc.diagnostics) @@ -513,27 +509,30 @@ def compile( if not compilation_unit.diagnostics.error: compilation_unit.ast = ast self.database.enqueue( - key=function, + key=file_instance, step=step + 1, priority=compilation_unit.priority, ) - - elif not readonly: + else: + if readonly: + continue if not compilation_unit.ast: continue with self.serialize.use_diagnostics(compilation_unit.diagnostics): - function.text = self.serialize(compilation_unit.ast, **formatting) + file_instance.text = self.serialize( + compilation_unit.ast, **formatting + ) compilation_unit.perf[step] = (perf_counter_ns() - start_time) * 1e-06 - sorted_functions = sorted( + sorted_source_files = sorted( self.database.session, key=lambda f: self.database[f].resource_location or "", ) if self.perf_report is not None: - for function in sorted_functions: - compilation_unit = self.database[function] + for file_instance in sorted_source_files: + compilation_unit = self.database[file_instance] self.perf_report.append( ( compilation_unit.filename or "", @@ -549,8 +548,8 @@ def compile( diagnostics = DiagnosticCollection( [ exc - for function in sorted_functions - for exc in self.database[function].diagnostics.exceptions + for file_instance in sorted_source_files + for exc in self.database[file_instance].diagnostics.exceptions ] ) diff --git a/mecha/database.py b/mecha/database.py index a1cdc80..3f40fea 100644 --- a/mecha/database.py +++ b/mecha/database.py @@ -1,17 +1,41 @@ __all__ = [ "CompilationDatabase", "CompilationUnit", + "CompilationUnitProvider", + "FileTypeCompilationUnitProvider", ] +import os from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass from heapq import heappop, heappush -from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple - -from beet import Container, DataPack, Generator, TextFile, TextFileBase -from beet.core.utils import extra_field +from typing import ( + Any, + DefaultDict, + Dict, + Iterable, + Iterator, + List, + Optional, + Protocol, + Set, + Tuple, + Type, + Union, +) + +from beet import ( + Container, + DataPack, + Generator, + NamespaceFile, + ResourcePack, + TextFile, + TextFileBase, +) +from beet.core.utils import FileSystemPath, extra_field from .ast import AstRoot from .diagnostic import DiagnosticCollection @@ -25,7 +49,7 @@ class CompilationUnit: source: Optional[str] = None filename: Optional[str] = None resource_location: Optional[str] = None - pack: Optional[DataPack] = None + pack: Optional[Union[ResourcePack, DataPack]] = None priority: int = 0 diagnostics: DiagnosticCollection = extra_field(init=False) @@ -42,7 +66,9 @@ def __post_init__(self): class CompilationDatabase(Container[TextFileBase[Any], CompilationUnit]): """Compilation database.""" - indices: DefaultDict[Optional[DataPack], Dict[str, TextFileBase[Any]]] + indices: DefaultDict[ + Optional[Union[ResourcePack, DataPack]], Dict[str, TextFileBase[Any]] + ] session: Set[TextFileBase[Any]] queue: List[Tuple[int, int, str, int, TextFileBase[Any]]] step: int @@ -127,3 +153,49 @@ def process_context(self, file_instance: TextFileBase[Any]): yield return yield + + +class CompilationUnitProvider(Protocol): + """Provide source files for compilation.""" + + def __call__( + self, + pack: Union[ResourcePack, DataPack], + match: Optional[List[str]] = None, + ) -> Iterable[Tuple[TextFileBase[Any], CompilationUnit]]: + ... + + +@dataclass +class FileTypeCompilationUnitProvider: + """Provide source files based on their type.""" + + file_types: List[Type[NamespaceFile]] + directory: FileSystemPath + + def __call__( + self, + pack: Union[ResourcePack, DataPack], + match: Optional[List[str]] = None, + ) -> Iterable[Tuple[TextFileBase[Any], CompilationUnit]]: + packs = [pack] + if pack.overlay_parent is None: + packs.extend(pack.overlays.values()) + + for file_type in self.file_types: + if not issubclass(file_type, TextFileBase): + continue + + for pack in packs: + for resource_location in pack[file_type].match(*match or ["*"]): + file_instance = pack[file_type][resource_location] + + yield file_instance, CompilationUnit( + resource_location=resource_location, + filename=( + os.path.relpath(file_instance.source_path, self.directory) + if file_instance.source_path + else None + ), + pack=pack, + ) diff --git a/tests/test_compile.py b/tests/test_compile.py index 73a0d88..5352141 100644 --- a/tests/test_compile.py +++ b/tests/test_compile.py @@ -1,19 +1,22 @@ from dataclasses import replace -from typing import Any +from typing import Any, ClassVar, Tuple import pytest -from beet import DataPack, Function, TextFile +from beet import DataPack, Function, ResourcePack, TextFile from mecha import ( + AstBool, AstChildren, AstCommand, AstJsonValue, AstMessage, AstMessageText, + AstRoot, AstSelector, Diagnostic, DiagnosticCollection, DiagnosticError, + FileTypeCompilationUnitProvider, Mecha, rule, ) @@ -179,3 +182,72 @@ def test_lint_error_report(mc: Mecha, dummy_transform: Any, dummy_lint_error: An d = mc.database[function].diagnostics.exceptions[0] assert d.level == "error" assert d.format_message() == "Really don't. (really_do_not_use_say)" + + +class DummySourceFile(TextFile): + scope: ClassVar[Tuple[str, ...]] = ("dummy",) + extension: ClassVar[str] = ".txt" + + +@pytest.fixture +def dummy_source_file_provider(mc: Mecha): + previous_providers = mc.providers + mc.providers = [FileTypeCompilationUnitProvider([DummySourceFile], mc.directory)] + yield + mc.providers = previous_providers + + +def test_dummy_source_file_provider(mc: Mecha, dummy_source_file_provider: Any): + p = DataPack() + p.extend_namespace.append(DummySourceFile) + + a = Function("say fold \\\n this") + b = DummySourceFile("# some comment") + c = DummySourceFile("gamerule keepInventory true") + + p["demo:a"] = a + p["demo:b"] = b + p["demo:c"] = c + + diagnostics = DiagnosticCollection() + mc.compile(p, report=diagnostics) + + assert a == Function("say fold \\\n this") + assert b == DummySourceFile("") + assert c == DummySourceFile("gamerule keepInventory true\n") + + assert a not in mc.database + assert mc.database[b].ast == AstRoot(commands=AstChildren()) + assert mc.database[c].ast == AstRoot( + commands=AstChildren( + [ + AstCommand( + identifier="gamerule:keepInventory:value", + arguments=AstChildren([AstBool(value=True)]), + ) + ] + ) + ) + + +def test_assets_dummy_source_file_provider(mc: Mecha, dummy_source_file_provider: Any): + r = ResourcePack() + r.extend_namespace.append(DummySourceFile) + + a = DummySourceFile("gamerule keep\\\nInventory true") + r["demo:a"] = a + + diagnostics = DiagnosticCollection() + mc.compile(r, report=diagnostics) + + assert a == DummySourceFile("gamerule keepInventory true\n") + assert mc.database[a].ast == AstRoot( + commands=AstChildren( + [ + AstCommand( + identifier="gamerule:keepInventory:value", + arguments=AstChildren([AstBool(value=True)]), + ) + ] + ) + )