Skip to content

Commit

Permalink
feat: add merge policy
Browse files Browse the repository at this point in the history
  • Loading branch information
vberlier committed Sep 1, 2021
1 parent f505b8b commit ff4e104
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 13 deletions.
9 changes: 4 additions & 5 deletions beet/core/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,10 @@ def merge(self, other: Mapping[Any, SupportsMerge]) -> bool:
"""Merge values from the given dict-like object."""
for key, value in other.items():
try:
if self[key].merge(value): # type: ignore
continue
except KeyError:
pass
self[key] = value # type: ignore
if key not in self or not self[key].merge(value): # type: ignore
self[key] = value # type: ignore
except Drop:
del self[key] # type: ignore
return True


Expand Down
152 changes: 145 additions & 7 deletions beet/library/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,21 @@
"NamespacePin",
"NamespaceProxy",
"NamespaceProxyDescriptor",
"MergeCallback",
"MergePolicy",
]


import shutil
from collections import defaultdict
from contextlib import nullcontext
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import partial
from itertools import count
from pathlib import Path, PurePosixPath
from typing import (
Any,
Callable,
ClassVar,
DefaultDict,
Dict,
Expand Down Expand Up @@ -71,6 +74,81 @@
PackFile = File[Any, Any]


@dataclass(eq=False)
class NamespaceFile(PackFile):
"""Base class for files that belong in pack namespaces."""

scope: ClassVar[Tuple[str, ...]]
extension: ClassVar[str]


class MergeCallback(Protocol):
"""Protocol for detecting merge callbacks."""

def __call__(self, pack: Any, path: str, current: Any, conflict: Any) -> bool:
...


@dataclass
class MergePolicy:
"""Class holding lists of rules for merging files."""

extra: Dict[str, List[MergeCallback]] = field(default_factory=dict)
namespace: Dict[Type[NamespaceFile], List[MergeCallback]] = field(
default_factory=dict
)
namespace_extra: Dict[str, List[MergeCallback]] = field(default_factory=dict)

def extend(self, other: "MergePolicy"):
for rules, other_rules in [
(self.extra, other.extra),
(self.namespace, other.namespace),
(self.namespace_extra, other.namespace_extra),
]:
for key, value in other_rules.items():
rules.setdefault(key, []).extend(value) # type: ignore

def extend_extra(self, filename: str, rule: MergeCallback):
"""Add rule for merging extra files."""
self.extra.setdefault(filename, []).append(rule)

def extend_namespace(self, file_type: Type[NamespaceFile], rule: MergeCallback):
"""Add rule for merging namespace files."""
self.namespace.setdefault(file_type, []).append(rule)

def extend_namespace_extra(self, filename: str, rule: MergeCallback):
"""Add rule for merging namespace extra files."""
self.namespace_extra.setdefault(filename, []).append(rule)

def merge_with_rules(
self,
pack: Any,
current: MutableMapping[Any, SupportsMerge],
other: Mapping[Any, SupportsMerge],
map_rules: Callable[[str], Tuple[str, List[MergeCallback]]],
) -> bool:
"""Merge values according to the given rules."""
for key, value in other.items():
if key not in current:
current[key] = value
continue

current_value = current[key]
path, rules = map_rules(key)

try:
for rule in rules:
if rule(pack, path, current_value, value):
break
else:
if not current_value.merge(value):
current[key] = value
except Drop:
del current[key]

return True


class ExtraContainer(MatchMixin, MergeMixin, Container[str, PackFile]):
"""Container that stores extra files in a pack or a namespace."""

Expand Down Expand Up @@ -112,6 +190,26 @@ def bind(self, namespace: NamespaceType):
except Drop:
del self[key]

def merge(self, other: Mapping[Any, SupportsMerge]) -> bool:
if (
self.namespace is not None
and self.namespace.pack is not None
and self.namespace.name
):
pack = self.namespace.pack
name = self.namespace.name

return pack.merge_policy.merge_with_rules(
pack=pack,
current=self,
other=other,
map_rules=lambda key: (
f"{name}:{key}",
pack.merge_policy.namespace_extra.get(key, []),
),
)
return super().merge(other)


class PackExtraContainer(ExtraContainer, Generic[PackType]):
"""Pack extra container."""
Expand All @@ -133,13 +231,20 @@ def bind(self, pack: PackType):
except Drop:
del self[key]

def merge(self, other: Mapping[Any, SupportsMerge]) -> bool:
if self.pack is not None:
pack = self.pack

@dataclass(eq=False)
class NamespaceFile(PackFile):
"""Base class for files that belong in pack namespaces."""

scope: ClassVar[Tuple[str, ...]]
extension: ClassVar[str]
return pack.merge_policy.merge_with_rules(
pack=pack,
current=self,
other=other,
map_rules=lambda key: (
key,
pack.merge_policy.extra.get(key, []),
),
)
return super().merge(other)


class NamespaceContainer(MatchMixin, MergeMixin, Container[str, NamespaceFileType]):
Expand Down Expand Up @@ -168,6 +273,28 @@ def bind(self, namespace: "Namespace", file_type: Type[NamespaceFileType]):
except Drop:
del self[key]

def merge(self, other: Mapping[Any, SupportsMerge]) -> bool:
if (
self.namespace is not None
and self.namespace.pack is not None
and self.namespace.name
and self.file_type is not None
):
pack = self.namespace.pack
name = self.namespace.name
file_type = self.file_type

return pack.merge_policy.merge_with_rules(
pack=pack,
current=self, # type: ignore
other=other,
map_rules=lambda key: (
f"{name}:{key}",
pack.merge_policy.namespace.get(file_type, []),
),
)
return super().merge(other)

def generate_tree(self, path: str = "") -> Dict[Any, Any]:
"""Generate a hierarchy of nested dictionaries representing the files and folders."""
prefix = path.split("/") if path else []
Expand Down Expand Up @@ -484,6 +611,8 @@ class Pack(MatchMixin, MergeMixin, Container[str, NamespaceType]):
extend_namespace: List[Type[NamespaceFile]]
extend_namespace_extra: Dict[str, Type[PackFile]]

merge_policy: MergePolicy

namespace_type: ClassVar[Type[NamespaceType]]
default_name: ClassVar[str]
latest_pack_format: ClassVar[int]
Expand All @@ -504,6 +633,7 @@ def __init__(
extend_extra: Optional[Mapping[str, Type[PackFile]]] = None,
extend_namespace: Iterable[Type[NamespaceFile]] = (),
extend_namespace_extra: Optional[Mapping[str, Type[PackFile]]] = None,
merge_policy: Optional[MergePolicy] = None,
):
super().__init__()
self.name = name
Expand All @@ -526,6 +656,10 @@ def __init__(
self.extend_namespace = list(extend_namespace)
self.extend_namespace_extra = dict(extend_namespace_extra or {})

self.merge_policy = MergePolicy()
if merge_policy:
self.merge_policy.extend(merge_policy)

self.load(path or zipfile)

@overload
Expand Down Expand Up @@ -661,12 +795,16 @@ def load(
extend_extra: Optional[Mapping[str, Type[PackFile]]] = None,
extend_namespace: Iterable[Type[NamespaceFile]] = (),
extend_namespace_extra: Optional[Mapping[str, Type[PackFile]]] = None,
merge_policy: Optional[MergePolicy] = None,
):
"""Load pack from a zipfile or from the filesystem."""
self.extend_extra.update(extend_extra or {})
self.extend_namespace.extend(extend_namespace)
self.extend_namespace_extra.update(extend_namespace_extra or {})

if merge_policy:
self.merge_policy.extend(merge_policy)

if origin:
if not isinstance(origin, ZipFile):
origin = Path(origin).resolve()
Expand Down
59 changes: 58 additions & 1 deletion tests/test_data_pack.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from beet import BlockTag, DataPack, Drop, Function, FunctionTag, JsonFile, Structure
from beet import (
BlockTag,
DataPack,
Drop,
Function,
FunctionTag,
JsonFile,
MergePolicy,
Structure,
)


def test_equality():
Expand Down Expand Up @@ -364,3 +374,50 @@ def __call__(self, instance: Function, pack: DataPack, path: str):
"hello": {Function: {"other": Function(["say hello"])}},
"minecraft": {FunctionTag: {"load": FunctionTag({"values": ["hello:other"]})}},
}


def test_merge_rules():
def combine_description(
pack: DataPack,
path: str,
current: JsonFile,
conflict: JsonFile,
) -> bool:
current.data["pack"]["description"] += conflict.data["pack"]["description"]
return True

pack = DataPack(description="hello")
pack.merge_policy.extend_extra("pack.mcmeta", combine_description)

pack.merge(DataPack(description="world"))

assert pack.description == "helloworld"


def test_merge_nuke():
def nuke(*args: Any) -> bool:
raise Drop()

p1 = DataPack(
description="hello",
merge_policy=MergePolicy(extra={"pack.mcmeta": [nuke]}),
)

p1.merge_policy.extend(MergePolicy(namespace={Function: [nuke]}))
p1.merge_policy.extend_namespace_extra("foo.json", nuke)

p1["demo:foo"] = Function()
p1["demo"].extra["foo.json"] = JsonFile()
p1["thing"].extra["foo.json"] = JsonFile()

p2 = DataPack(description="world")
p2["demo:foo"] = Function()
p2["demo:bar"] = Function()
p2["demo"].extra["foo.json"] = JsonFile()

p1.merge(p2)

assert p1.description == ""
assert list(p1.functions) == ["demo:bar"]
assert list(p1["demo"].extra) == []
assert p1["thing"].extra["foo.json"] == JsonFile()

0 comments on commit ff4e104

Please sign in to comment.