feat: compilation unit provider
vberlier committed Oct 30, 2023
1 parent a8c8a8d commit 2738c9b
Showing 3 changed files with 201 additions and 58 deletions.
99 changes: 49 additions & 50 deletions mecha/api.py
@@ -36,7 +36,7 @@
Context,
DataPack,
Function,
NamespaceFile,
ResourcePack,
TextFileBase,
)
from beet.core.utils import (
@@ -51,7 +51,12 @@

from .ast import AstLiteral, AstNode, AstRoot
from .config import CommandTree
from .database import CompilationDatabase, CompilationUnit
from .database import (
CompilationDatabase,
CompilationUnit,
CompilationUnitProvider,
FileTypeCompilationUnitProvider,
)
from .diagnostic import (
Diagnostic,
DiagnosticCollection,
@@ -66,6 +71,7 @@

AstNodeType = TypeVar("AstNodeType", bound=AstNode)
TextFileType = TypeVar("TextFileType", bound=TextFileBase[Any])
PackType = TypeVar("PackType", bound=Union[ResourcePack, DataPack])


logger = logging.getLogger("mecha")
@@ -156,9 +162,7 @@ class Mecha:

spec: CommandSpec = extra_field(default=None)

providers: List[Type[NamespaceFile]] = extra_field(
default_factory=lambda: [Function]
)
providers: List[CompilationUnitProvider] = extra_field(init=False)

preprocessor: Preprocessor = extra_field(default=wrap_backslash_continuation)

@@ -234,6 +238,8 @@ def __post_init__(
parsers=get_parsers(version),
)

self.providers = [FileTypeCompilationUnitProvider([Function], self.directory)]

self.serialize = Serializer(
spec=self.spec,
database=self.database,
@@ -375,15 +381,15 @@ def parse(
@overload
def compile(
self,
source: DataPack,
source: PackType,
*,
match: Optional[List[str]] = None,
multiline: Optional[bool] = None,
formatting: Optional[JsonDict] = None,
readonly: Optional[bool] = None,
initial_step: int = 0,
report: Optional[DiagnosticCollection] = None,
) -> DataPack:
) -> PackType:
...

@overload
@@ -393,7 +399,7 @@ def compile(
*,
filename: Optional[FileSystemPath] = None,
resource_location: Optional[str] = None,
within: Optional[DataPack] = None,
within: Optional[Union[ResourcePack, DataPack]] = None,
multiline: Optional[bool] = None,
formatting: Optional[JsonDict] = None,
readonly: Optional[bool] = None,
@@ -409,7 +415,7 @@
*,
filename: Optional[FileSystemPath] = None,
resource_location: Optional[str] = None,
within: Optional[DataPack] = None,
within: Optional[Union[ResourcePack, DataPack]] = None,
multiline: Optional[bool] = None,
formatting: Optional[JsonDict] = None,
readonly: Optional[bool] = None,
@@ -420,18 +426,20 @@

def compile(
self,
source: Union[DataPack, TextFileBase[Any], List[str], str, AstRoot],
source: Union[
Union[ResourcePack, DataPack], TextFileBase[Any], List[str], str, AstRoot
],
*,
match: Optional[List[str]] = None,
filename: Optional[FileSystemPath] = None,
resource_location: Optional[str] = None,
within: Optional[DataPack] = None,
within: Optional[Union[ResourcePack, DataPack]] = None,
multiline: Optional[bool] = None,
formatting: Optional[JsonDict] = None,
readonly: Optional[bool] = None,
initial_step: int = 0,
report: Optional[DiagnosticCollection] = None,
) -> Union[DataPack, TextFileBase[Any]]:
) -> Union[Union[ResourcePack, DataPack], TextFileBase[Any]]:
"""Apply all compilation steps."""
self.database.setup_compilation()

@@ -440,32 +448,16 @@ def compile(
if readonly is None:
readonly = self.readonly

if isinstance(source, DataPack):
if isinstance(source, (ResourcePack, DataPack)):
result = source

if match is None:
match = self.match

packs = [source]
if source.overlay_parent is None:
packs.extend(source.overlays.values())

for file_type in self.providers:
if not issubclass(file_type, TextFileBase):
continue
for pack in packs:
for key in pack[file_type].match(*match or ["*"]):
value = pack[file_type][key]
self.database[value] = CompilationUnit(
resource_location=key,
filename=(
os.path.relpath(value.source_path, self.directory)
if value.source_path
else None
),
pack=pack,
)
self.database.enqueue(value)
for provider in self.providers:
for file_instance, compilation_unit in provider(source, match):
self.database[file_instance] = compilation_unit
self.database.enqueue(file_instance)
else:
if isinstance(source, (list, str)):
source = Function(source)
@@ -477,31 +469,35 @@
else:
result = Function()

self.database[result] = CompilationUnit(
ast=source if isinstance(source, AstRoot) else None,
compilation_unit = CompilationUnit(
resource_location=resource_location,
filename=str(filename) if filename else None,
pack=within,
)

if isinstance(source, AstRoot):
compilation_unit.ast = source

self.database[result] = compilation_unit
self.database.enqueue(result)

for step, function in self.database.process_queue():
compilation_unit = self.database[function]
for step, file_instance in self.database.process_queue():
compilation_unit = self.database[file_instance]
start_time = perf_counter_ns()

if step < 0:
if compilation_unit.ast:
self.database.enqueue(function, initial_step)
self.database.enqueue(file_instance, initial_step)
continue
try:
compilation_unit.source = function.text
compilation_unit.source = file_instance.text
compilation_unit.ast = self.parse(
function,
file_instance,
filename=compilation_unit.filename,
resource_location=compilation_unit.resource_location,
multiline=multiline,
)
self.database.enqueue(function, initial_step)
self.database.enqueue(file_instance, initial_step)
except DiagnosticError as exc:
compilation_unit.diagnostics.extend(exc.diagnostics)

@@ -513,27 +509,30 @@
if not compilation_unit.diagnostics.error:
compilation_unit.ast = ast
self.database.enqueue(
key=function,
key=file_instance,
step=step + 1,
priority=compilation_unit.priority,
)

elif not readonly:
else:
if readonly:
continue
if not compilation_unit.ast:
continue
with self.serialize.use_diagnostics(compilation_unit.diagnostics):
function.text = self.serialize(compilation_unit.ast, **formatting)
file_instance.text = self.serialize(
compilation_unit.ast, **formatting
)

compilation_unit.perf[step] = (perf_counter_ns() - start_time) * 1e-06

sorted_functions = sorted(
sorted_source_files = sorted(
self.database.session,
key=lambda f: self.database[f].resource_location or "<unknown>",
)

if self.perf_report is not None:
for function in sorted_functions:
compilation_unit = self.database[function]
for file_instance in sorted_source_files:
compilation_unit = self.database[file_instance]
self.perf_report.append(
(
compilation_unit.filename or "",
@@ -549,8 +548,8 @@ def compile(
diagnostics = DiagnosticCollection(
[
exc
for function in sorted_functions
for exc in self.database[function].diagnostics.exceptions
for file_instance in sorted_source_files
for exc in self.database[file_instance].diagnostics.exceptions
]
)

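The api.py changes above make Mecha.compile() generic over the pack type (PackType is bound to Union[ResourcePack, DataPack]) and delegate source-file discovery to the new providers list instead of iterating over NamespaceFile types directly. By default providers is initialized with a single FileTypeCompilationUnitProvider([Function], self.directory), so plain data-pack functions keep compiling as before. A minimal sketch of the call from a beet plugin, assuming the usual ctx.inject(Mecha) entry point (the plugin name and the demo resource are illustrative, not taken from this commit):

    from beet import Context, Function
    from mecha import Mecha


    def demo_plugin(ctx: Context):
        mc = ctx.inject(Mecha)

        # Hypothetical function added just for the example.
        ctx.data["demo:greet"] = Function(["say hello"])

        # compile() now accepts either pack flavor: a DataPack comes back as a
        # DataPack and a ResourcePack as a ResourcePack. The registered
        # providers decide which files become compilation units.
        mc.compile(ctx.data, match=["demo:*"])
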
84 changes: 78 additions & 6 deletions mecha/database.py
@@ -1,17 +1,41 @@
__all__ = [
"CompilationDatabase",
"CompilationUnit",
"CompilationUnitProvider",
"FileTypeCompilationUnitProvider",
]


import os
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from heapq import heappop, heappush
from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple

from beet import Container, DataPack, Generator, TextFile, TextFileBase
from beet.core.utils import extra_field
from typing import (
Any,
DefaultDict,
Dict,
Iterable,
Iterator,
List,
Optional,
Protocol,
Set,
Tuple,
Type,
Union,
)

from beet import (
Container,
DataPack,
Generator,
NamespaceFile,
ResourcePack,
TextFile,
TextFileBase,
)
from beet.core.utils import FileSystemPath, extra_field

from .ast import AstRoot
from .diagnostic import DiagnosticCollection
@@ -25,7 +49,7 @@ class CompilationUnit:
source: Optional[str] = None
filename: Optional[str] = None
resource_location: Optional[str] = None
pack: Optional[DataPack] = None
pack: Optional[Union[ResourcePack, DataPack]] = None
priority: int = 0

diagnostics: DiagnosticCollection = extra_field(init=False)
@@ -42,7 +66,9 @@ def __post_init__(self):
class CompilationDatabase(Container[TextFileBase[Any], CompilationUnit]):
"""Compilation database."""

indices: DefaultDict[Optional[DataPack], Dict[str, TextFileBase[Any]]]
indices: DefaultDict[
Optional[Union[ResourcePack, DataPack]], Dict[str, TextFileBase[Any]]
]
session: Set[TextFileBase[Any]]
queue: List[Tuple[int, int, str, int, TextFileBase[Any]]]
step: int
@@ -127,3 +153,49 @@ def process_context(self, file_instance: TextFileBase[Any]):
yield
return
yield


class CompilationUnitProvider(Protocol):
"""Provide source files for compilation."""

def __call__(
self,
pack: Union[ResourcePack, DataPack],
match: Optional[List[str]] = None,
) -> Iterable[Tuple[TextFileBase[Any], CompilationUnit]]:
...


@dataclass
class FileTypeCompilationUnitProvider:
"""Provide source files based on their type."""

file_types: List[Type[NamespaceFile]]
directory: FileSystemPath

def __call__(
self,
pack: Union[ResourcePack, DataPack],
match: Optional[List[str]] = None,
) -> Iterable[Tuple[TextFileBase[Any], CompilationUnit]]:
packs = [pack]
if pack.overlay_parent is None:
packs.extend(pack.overlays.values())

for file_type in self.file_types:
if not issubclass(file_type, TextFileBase):
continue

for pack in packs:
for resource_location in pack[file_type].match(*match or ["*"]):
file_instance = pack[file_type][resource_location]

yield file_instance, CompilationUnit(
resource_location=resource_location,
filename=(
os.path.relpath(file_instance.source_path, self.directory)
if file_instance.source_path
else None
),
pack=pack,
)
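
Since CompilationUnitProvider is a typing Protocol, any callable with a matching signature can be appended to Mecha.providers. A hedged sketch of a custom provider that only picks up functions from a single namespace (the class name and the namespace filter are illustrative, not part of this commit):

    from typing import Any, Iterable, List, Optional, Tuple, Union

    from beet import DataPack, ResourcePack, TextFileBase
    from mecha.database import CompilationUnit


    class NamespaceFunctionProvider:
        """Illustrative provider: yield compilation units for one namespace."""

        def __init__(self, namespace: str):
            self.namespace = namespace

        def __call__(
            self,
            pack: Union[ResourcePack, DataPack],
            match: Optional[List[str]] = None,
        ) -> Iterable[Tuple[TextFileBase[Any], CompilationUnit]]:
            # Resource packs carry no functions, so skip them.
            if not isinstance(pack, DataPack):
                return
            for resource_location in pack.functions.match(f"{self.namespace}:*"):
                file_instance = pack.functions[resource_location]
                yield file_instance, CompilationUnit(
                    resource_location=resource_location,
                    pack=pack,
                )

Registering it from a plugin would then be a one-liner along the lines of ctx.inject(Mecha).providers.append(NamespaceFunctionProvider("demo")).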
