Skip to content

Commit

Permalink
feat: add source range
Browse files Browse the repository at this point in the history
  • Loading branch information
vberlier committed Mar 1, 2022
1 parent 80095c4 commit 321b410
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 15 deletions.
52 changes: 37 additions & 15 deletions beet/core/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,21 @@ class File(Generic[ValueType, SerializeType]):
content: Union[ValueType, SerializeType, None] = None
source_path: Optional[FileSystemPath] = None

source_start: Optional[int] = extra_field(default=None)
source_stop: Optional[int] = extra_field(default=None)

on_bind: Optional[Callable[[Any, Any, str], Any]] = extra_field(default=None)

serializer: Callable[[ValueType], SerializeType] = extra_field(init=False)
deserializer: Callable[[SerializeType], ValueType] = extra_field(init=False)
reader: Callable[[FileSystemPath, int, int], SerializeType] = extra_field(
init=False
)

def __post_init__(self):
if self.content is self.source_path is None:
self.content = self.default()
self.reader = self.from_path

def merge(self: FileType, other: FileType) -> bool:
"""Merge the given file or return False to indicate no special handling."""
Expand All @@ -81,11 +89,17 @@ def set_content(self, content: Union[ValueType, SerializeType]):
"""Update the internal content."""
self.content = content
self.source_path = None
self.source_start = None
self.source_stop = None

def get_content(self) -> Union[ValueType, SerializeType]:
"""Return the internal content."""
return (
self.decode(Path(self.ensure_source_path()).read_bytes())
self.reader(
self.ensure_source_path(),
0 if self.source_start is None else self.source_start,
-1 if self.source_stop is None else self.source_stop,
)
if self.content is None
else self.content
)
Expand Down Expand Up @@ -163,13 +177,13 @@ def deserialize(self, content: Union[ValueType, SerializeType]) -> ValueType:
raise NotImplementedError()

@classmethod
def decode(cls, raw: bytes) -> SerializeType:
"""Convert bytes to serialized representation."""
def from_path(cls, path: FileSystemPath, start: int, stop: int) -> SerializeType:
"""Read file content from path."""
raise NotImplementedError()

@classmethod
def encode(cls, raw: SerializeType) -> bytes:
"""Convert serialized representation to bytes."""
def from_zip(cls, origin: ZipFile, name: str) -> SerializeType:
"""Read file content from zip."""
raise NotImplementedError()

@classmethod
Expand All @@ -187,7 +201,7 @@ def try_load(
"""Try to load a file from a zipfile or from the filesystem."""
if isinstance(origin, ZipFile):
try:
return cls(cls.decode(origin.read(str(path))))
return cls(cls.from_zip(origin, str(path)))
except KeyError:
return None
path = Path(origin, path)
Expand All @@ -201,9 +215,11 @@ def dump(self, origin: FileOrigin, path: FileSystemPath):
else:
shutil.copyfile(self.ensure_source_path(), str(Path(origin, path)))
else:
raw = self.encode(self.ensure_serialized())
raw = self.ensure_serialized()
if isinstance(origin, ZipFile):
origin.writestr(str(path), raw)
elif isinstance(raw, str):
Path(origin, path).write_text(raw)
else:
Path(origin, path).write_bytes(raw)

Expand Down Expand Up @@ -253,12 +269,15 @@ def deserialize(self, content: Union[ValueType, str]) -> ValueType:
return self.deserializer(content) if isinstance(content, str) else content

@classmethod
def decode(cls, raw: bytes) -> str:
return raw.decode()
def from_zip(cls, origin: ZipFile, name: str) -> str:
return origin.read(name).decode()

@classmethod
def encode(cls, raw: str) -> bytes:
return raw.encode()
def from_path(cls, path: FileSystemPath, start: int, stop: int) -> str:
with open(path, "r") as f:
if start > 0:
f.seek(start)
return f.read(stop - start) if stop >= -1 else f.read()

@classmethod
def to_str(cls, content: ValueType) -> str:
Expand Down Expand Up @@ -304,12 +323,15 @@ def deserialize(self, content: Union[ValueType, bytes]) -> ValueType:
return self.deserializer(content) if isinstance(content, bytes) else content

@classmethod
def decode(cls, raw: bytes) -> bytes:
return raw
def from_zip(cls, origin: ZipFile, name: str) -> bytes:
return origin.read(name)

@classmethod
def encode(cls, raw: bytes) -> bytes:
return raw
def from_path(cls, path: FileSystemPath, start: int, stop: int) -> bytes:
with open(path, "rb") as f:
if start > 0:
f.seek(start)
return f.read() if stop == -1 else f.read(stop - start)

@classmethod
def to_bytes(cls, content: ValueType) -> bytes:
Expand Down
19 changes: 19 additions & 0 deletions tests/test_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pathlib import Path

from beet import BinaryFile, TextFile


def test_text_range(tmp_path: Path):
p1 = tmp_path / "p1"
p1.write_text("abc")
assert TextFile(source_path=p1, source_start=1).text == "bc"
assert TextFile(source_path=p1, source_stop=2).text == "ab"
assert TextFile(source_path=p1, source_start=1, source_stop=2).text == "b"


def test_binary_range(tmp_path: Path):
p1 = tmp_path / "p1"
p1.write_bytes(b"abc")
assert BinaryFile(source_path=p1, source_start=1).blob == b"bc"
assert BinaryFile(source_path=p1, source_stop=2).blob == b"ab"
assert BinaryFile(source_path=p1, source_start=1, source_stop=2).blob == b"b"

0 comments on commit 321b410

Please sign in to comment.