Skip to content

Commit

Permalink
feat: refactor beet exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
vberlier committed Apr 26, 2022
1 parent cfcba0f commit 65745c3
Show file tree
Hide file tree
Showing 16 changed files with 265 additions and 119 deletions.
6 changes: 4 additions & 2 deletions beet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
__version__ = "0.59.0"


from .core.cache import *
from .core.container import *
from .core.error import *
from .core.file import *
from .core.watch import *
from .library.base import *
Expand All @@ -14,5 +18,3 @@
from .toolchain.template import *
from .toolchain.tree import *
from .toolchain.worker import *

__version__ = "0.59.0"
24 changes: 24 additions & 0 deletions beet/core/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
__all__ = [
"BeetException",
"BubbleException",
"WrappedException",
]


class BeetException(Exception):
"""Base class for beet exceptions."""


class BubbleException(BeetException):
"""Exceptions inheriting from this class will bubble up through exception wrappers."""


class WrappedException(BubbleException):
"""Raised to wrap an underlying exception."""

__cause__: Exception
hide_wrapped_exception: bool

def __init__(self, *args: object) -> None:
super().__init__(*args)
self.hide_wrapped_exception = False
97 changes: 90 additions & 7 deletions beet/core/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"YamlFileBase",
"YamlFile",
"PngFile",
"SerializationError",
"DeserializationError",
"InvalidDataModel",
]


Expand All @@ -27,7 +30,9 @@
from zipfile import ZipFile

import yaml
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError

from .error import BubbleException, WrappedException

try:
from PIL.Image import Image
Expand All @@ -43,7 +48,14 @@ def open_image(*args: Any, **kwargs: Any) -> Any:
raise RuntimeError("Please install Pillow to edit images programmatically")


from .utils import FileSystemPath, JsonDict, dump_json, extra_field
from .utils import (
FileSystemPath,
JsonDict,
dump_json,
extra_field,
format_validation_error,
snake_case,
)

ValueType = TypeVar("ValueType", bound=Any)
SerializeType = TypeVar("SerializeType", bound=Any)
Expand Down Expand Up @@ -278,6 +290,37 @@ def __set__(self, obj: File[ValueType, Any], value: ValueType):
obj.set_content(value)


class SerializationError(WrappedException):
"""Raised when serialization fails."""

file: File[Any, Any]

def __init__(self, file: File[Any, Any]):
super().__init__(file)
self.file = file

def __str__(self) -> str:
if self.file.original.source_path:
return f'Couldn\'t serialize "{self.file.original.source_path}".'
return f"Couldn't serialize file of type {type(self.file)}."


class DeserializationError(WrappedException):
"""Raised when deserialization fails."""

file: File[Any, Any]
message: str

def __init__(self, file: File[Any, Any]):
super().__init__(file)
self.file = file

def __str__(self) -> str:
if self.file.original.source_path:
return f'Couldn\'t deserialize "{self.file.original.source_path}".'
return f"Couldn't deserialize file of type {type(self.file)}."


class TextFileBase(File[ValueType, str]):
"""Base class for files that get serialized to strings."""

Expand All @@ -289,10 +332,20 @@ def __post_init__(self):
self.deserializer = self.from_str

def serialize(self, content: Union[ValueType, str]) -> str:
return content if isinstance(content, str) else self.serializer(content)
try:
return content if isinstance(content, str) else self.serializer(content)
except BubbleException:
raise
except Exception as exc:
raise SerializationError(self) from exc

def deserialize(self, content: Union[ValueType, str]) -> ValueType:
return self.deserializer(content) if isinstance(content, str) else content
try:
return self.deserializer(content) if isinstance(content, str) else content
except BubbleException:
raise
except Exception as exc:
raise DeserializationError(self) from exc

@classmethod
def from_zip(cls, origin: ZipFile, name: str) -> str:
Expand Down Expand Up @@ -339,10 +392,20 @@ def __post_init__(self):
self.deserializer = self.from_bytes

def serialize(self, content: Union[ValueType, bytes]) -> bytes:
return content if isinstance(content, bytes) else self.serializer(content)
try:
return content if isinstance(content, bytes) else self.serializer(content)
except BubbleException:
raise
except Exception as exc:
raise SerializationError(self) from exc

def deserialize(self, content: Union[ValueType, bytes]) -> ValueType:
return self.deserializer(content) if isinstance(content, bytes) else content
try:
return self.deserializer(content) if isinstance(content, bytes) else content
except BubbleException:
raise
except Exception as exc:
raise DeserializationError(self) from exc

@classmethod
def from_zip(cls, origin: ZipFile, name: str) -> bytes:
Expand Down Expand Up @@ -378,6 +441,22 @@ def default(cls) -> bytes:
return b""


class InvalidDataModel(DeserializationError):
"""Raised when data model deserialization fails."""

explanation: str

def __init__(self, file: File[Any, Any], explanation: str):
super().__init__(file)
self.explanation = explanation
self.hide_wrapped_exception = True

def __str__(self) -> str:
if self.file.original.source_path:
return f'Validation error for "{self.file.original.source_path}".\n\n{self.explanation}'
return f"Validation error for file of type {type(self.file)}.\n\n{self.explanation}"


@dataclass(eq=False, repr=False)
class DataModelBase(TextFileBase[ValueType]):
"""Base class for data models."""
Expand All @@ -403,7 +482,11 @@ def to_str(self, content: ValueType) -> str:
def from_str(self, content: str) -> ValueType:
value = self.decoder(content)
if self.model and issubclass(self.model, BaseModel):
value = self.model.parse_obj(value)
try:
value = self.model.parse_obj(value)
except ValidationError as exc:
message = format_validation_error(snake_case(self.model.__name__), exc)
raise InvalidDataModel(self, message) from exc
return value # type: ignore

@classmethod
Expand Down
40 changes: 39 additions & 1 deletion beet/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"local_import_path",
"log_time",
"remove_path",
"format_obj",
"format_exc",
"format_validation_error",
]


Expand All @@ -29,6 +32,7 @@
from importlib import import_module
from importlib.util import find_spec
from pathlib import Path
from traceback import format_exception
from typing import (
Any,
Dict,
Expand All @@ -42,7 +46,7 @@
runtime_checkable,
)

from pydantic import PydanticTypeError
from pydantic import PydanticTypeError, ValidationError
from pydantic.validators import _VALIDATORS # type: ignore

T = TypeVar("T")
Expand Down Expand Up @@ -209,6 +213,40 @@ def remove_path(*paths: FileSystemPath):
path.unlink(missing_ok=True)


def format_exc(exc: BaseException) -> str:
return "".join(format_exception(exc.__class__, exc, exc.__traceback__))


def format_obj(obj: Any) -> str:
module = getattr(obj, "__module__", None)
name = getattr(obj, "__qualname__", getattr(obj, "__name__", None))
return f'"{module}.{name}"' if module and name else repr(obj)


def format_validation_error(prefix: str, exc: ValidationError) -> str:
errors = [
(
prefix
+ "".join(
json.dumps([item]) for item in error["loc"] if item != "__root__"
),
error["msg"]
if error["msg"][0].isupper()
else error["msg"][0].capitalize() + error["msg"][1:],
)
for error in exc.errors()
]
width = max(len(loc) for loc, _ in errors) + 1
return "\n".join(
"{loc:<{width}} => {msg}".format(
loc=loc,
width=width,
msg=msg + "." * (not msg.endswith(".")),
)
for loc, msg in errors
)


class PathObjectError(PydanticTypeError):
msg_template = "value is not a valid path object"

Expand Down
12 changes: 8 additions & 4 deletions beet/toolchain/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
from click_help_colors import HelpColorsCommand, HelpColorsGroup

from beet import __version__
from beet.core.error import BeetException, WrappedException
from beet.core.utils import format_exc

from .pipeline import FormattedPipelineException
from .project import Project
from .utils import format_exc


def format_error(
Expand All @@ -49,8 +49,12 @@ def error_handler(should_exit: bool = False, format_padding: int = 0) -> Iterato

try:
yield
except FormattedPipelineException as exc:
message, exception = exc.message, exc.__cause__ if exc.format_cause else None
except WrappedException as exc:
message = str(exc)
if not exc.hide_wrapped_exception:
exception = exc.__cause__
except BeetException as exc:
message = str(exc)
except (click.Abort, KeyboardInterrupt):
click.echo()
message = "Aborted."
Expand Down
18 changes: 12 additions & 6 deletions beet/toolchain/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,17 @@
from pydantic import BaseModel, ValidationError, validator
from pydantic.generics import GenericModel

from beet.core.error import BubbleException
from beet.core.utils import (
FileSystemPath,
JsonDict,
TextComponent,
format_validation_error,
local_import_path,
resolve_packageable_path,
)

from .pipeline import FormattedPipelineException
from .utils import apply_option, eval_option, format_validation_error
from .utils import apply_option, eval_option

DETECT_CONFIG_FILES: Tuple[str, ...] = (
"beet.json",
Expand All @@ -54,12 +55,17 @@
)


class InvalidProjectConfig(FormattedPipelineException):
class InvalidProjectConfig(BubbleException):
"""Raised when trying to load an invalid project config."""

def __init__(self, *args: Any):
super().__init__(*args)
self.message = f"Couldn't load project config.\n\n{self}"
explanation: str

def __init__(self, explanation: str):
super().__init__(explanation)
self.explanation = explanation

def __str__(self) -> str:
return f"Couldn't load project config.\n\n{self.explanation}"


ItemType = TypeVar("ItemType")
Expand Down

0 comments on commit 65745c3

Please sign in to comment.