175 changes: 94 additions & 81 deletions ibis/common/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
Any,
FrozenDictOf,
Function,
NoMatch,
Option,
Pattern,
TupleOf,
Validator,
)
from ibis.common.typing import get_type_hints

Expand All @@ -22,64 +23,73 @@
VAR_POSITIONAL = inspect.Parameter.VAR_POSITIONAL


class ValidationError(Exception):
...


class Annotation:
"""Base class for all annotations.
Annotations are used to mark fields in a class and to validate them.
Parameters
----------
validator : Validator, default noop
Validator to validate the field.
pattern : Pattern, default noop
Pattern to validate the field.
default : Any, default EMPTY
Default value of the field.
typehint : type, default EMPTY
Type of the field, not used for validation.
"""

__slots__ = ("_validator", "_default", "_typehint")
__slots__ = ("_pattern", "_default", "_typehint")

def __init__(self, validator=None, default=EMPTY, typehint=EMPTY):
if validator is None or isinstance(validator, Validator):
def __init__(self, pattern=None, default=EMPTY, typehint=EMPTY):
if pattern is None or isinstance(pattern, Pattern):
pass
elif callable(validator):
validator = Function(validator)
elif callable(pattern):
pattern = Function(pattern)
else:
raise TypeError(f"Unsupported validator {validator!r}")
raise TypeError(f"Unsupported pattern {pattern!r}")
self._pattern = pattern
self._default = default
self._typehint = typehint
self._validator = validator

def __eq__(self, other):
return (
type(self) is type(other)
and self._pattern == other._pattern
and self._default == other._default
and self._typehint == other._typehint
and self._validator == other._validator
)

def __repr__(self):
return (
f"{self.__class__.__name__}(validator={self._validator!r}, "
f"{self.__class__.__name__}(pattern={self._pattern!r}, "
f"default={self._default!r}, typehint={self._typehint!r})"
)

def validate(self, arg, context=None):
if self._validator is None:
if self._pattern is None:
return arg
return self._validator.validate(arg, context)

result = self._pattern.match(arg, context)
if result is NoMatch:
raise ValidationError(f"{arg!r} doesn't match {self._pattern!r}")

return result


class Attribute(Annotation):
"""Annotation to mark a field in a class.
An optional validator can be provider to validate the field every time it
An optional pattern can be provider to validate the field every time it
is set.
Parameters
----------
validator : Validator, default noop
Validator to validate the field.
pattern : Pattern, default noop
Pattern to validate the field.
default : Callable, default EMPTY
Callable to compute the default value of the field.
"""
Expand All @@ -105,8 +115,8 @@ class Argument(Annotation):
Parameters
----------
validator
Optional validator to validate the argument.
pattern
Optional pattern to validate the argument.
default
Optional default value of the argument.
typehint
Expand All @@ -120,47 +130,47 @@ class Argument(Annotation):

def __init__(
self,
validator: Validator | None = None,
pattern: Pattern | None = None,
default: AnyType = EMPTY,
typehint: type | None = None,
kind: int = POSITIONAL_OR_KEYWORD,
):
super().__init__(validator, default, typehint)
super().__init__(pattern, default, typehint)
self._kind = kind

@classmethod
def required(cls, validator=None, **kwargs):
def required(cls, pattern=None, **kwargs):
"""Annotation to mark a mandatory argument."""
return cls(validator, **kwargs)
return cls(pattern, **kwargs)

@classmethod
def default(cls, default, validator=None, **kwargs):
def default(cls, default, pattern=None, **kwargs):
"""Annotation to allow missing arguments with a default value."""
return cls(validator, default, **kwargs)
return cls(pattern, default, **kwargs)

@classmethod
def optional(cls, validator=None, default=None, **kwargs):
def optional(cls, pattern=None, default=None, **kwargs):
"""Annotation to allow and treat `None` values as missing arguments."""
if validator is None:
validator = Option(Any(), default=default)
if pattern is None:
pattern = Option(Any(), default=default)
else:
validator = Option(validator, default=default)
return cls(validator, default=None, **kwargs)
pattern = Option(pattern, default=default)
return cls(pattern, default=None, **kwargs)

@classmethod
def varargs(cls, validator=None, **kwargs):
def varargs(cls, pattern=None, **kwargs):
"""Annotation to mark a variable length positional argument."""
validator = None if validator is None else TupleOf(validator)
return cls(validator, kind=VAR_POSITIONAL, **kwargs)
pattern = None if pattern is None else TupleOf(pattern)
return cls(pattern, kind=VAR_POSITIONAL, **kwargs)

@classmethod
def varkwargs(cls, validator=None, **kwargs):
validator = None if validator is None else FrozenDictOf(Any(), validator)
return cls(validator, kind=VAR_KEYWORD, **kwargs)
def varkwargs(cls, pattern=None, **kwargs):
pattern = None if pattern is None else FrozenDictOf(Any(), pattern)
return cls(pattern, kind=VAR_KEYWORD, **kwargs)


class Parameter(inspect.Parameter):
"""Augmented Parameter class to additionally hold a validator object."""
"""Augmented Parameter class to additionally hold a pattern object."""

__slots__ = ()

Expand Down Expand Up @@ -241,18 +251,18 @@ def merge(cls, *signatures, **annotations):
)

@classmethod
def from_callable(cls, fn, validators=None, return_validator=None):
def from_callable(cls, fn, patterns=None, return_pattern=None):
"""Create a validateable signature from a callable.
Parameters
----------
fn : Callable
Callable to create a signature from.
validators : list or dict, default None
Pass validators to add missing or override existing argument type
patterns : list or dict, default None
Pass patterns to add missing or override existing argument type
annotations.
return_validator : Validator, default None
Validator for the return value of the callable.
return_pattern : Pattern, default None
Pattern for the return value of the callable.
Returns
-------
Expand All @@ -261,15 +271,13 @@ def from_callable(cls, fn, validators=None, return_validator=None):
sig = super().from_callable(fn)
typehints = get_type_hints(fn)

if validators is None:
validators = {}
elif isinstance(validators, (list, tuple)):
# create a mapping of parameter name to validator
validators = dict(zip(sig.parameters.keys(), validators))
elif not isinstance(validators, dict):
raise TypeError(
f"validators must be a list or dict, got {type(validators)}"
)
if patterns is None:
patterns = {}
elif isinstance(patterns, (list, tuple)):
# create a mapping of parameter name to pattern
patterns = dict(zip(sig.parameters.keys(), patterns))
elif not isinstance(patterns, dict):
raise TypeError(f"patterns must be a list or dict, got {type(patterns)}")

parameters = []
for param in sig.parameters.values():
Expand All @@ -278,36 +286,36 @@ def from_callable(cls, fn, validators=None, return_validator=None):
default = param.default
typehint = typehints.get(name)

if name in validators:
validator = validators[name]
if name in patterns:
pattern = patterns[name]
elif typehint is not None:
validator = Validator.from_typehint(typehint)
pattern = Pattern.from_typehint(typehint)
else:
validator = None
pattern = None

if kind is VAR_POSITIONAL:
annot = Argument.varargs(validator, typehint=typehint)
annot = Argument.varargs(pattern, typehint=typehint)
elif kind is VAR_KEYWORD:
annot = Argument.varkwargs(validator, typehint=typehint)
annot = Argument.varkwargs(pattern, typehint=typehint)
elif default is EMPTY:
annot = Argument.required(validator, kind=kind, typehint=typehint)
annot = Argument.required(pattern, kind=kind, typehint=typehint)
else:
annot = Argument.default(
default, validator, kind=param.kind, typehint=typehint
default, pattern, kind=param.kind, typehint=typehint
)

parameters.append(Parameter(param.name, annot))

if return_validator is not None:
return_annotation = return_validator
if return_pattern is not None:
return_annotation = return_pattern
elif (typehint := typehints.get("return")) is not None:
return_annotation = Validator.from_typehint(typehint)
return_annotation = Pattern.from_typehint(typehint)
else:
return_annotation = EMPTY

return cls(parameters, return_annotation=return_annotation)

def unbind(self, this: AnyType):
def unbind(self, this: dict[str, Any]) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""Reverse bind of the parameters.
Attempts to reconstructs the original arguments as keyword only arguments.
Expand Down Expand Up @@ -355,7 +363,7 @@ def validate(self, *args, **kwargs):
validated : dict
Dictionary of validated arguments.
"""
# bind the signature to the passed arguments and apply the validators
# bind the signature to the passed arguments and apply the patterns
# before passing the arguments, so self.__init__() receives already
# validated arguments as keywords
bound = self.bind(*args, **kwargs)
Expand Down Expand Up @@ -396,7 +404,12 @@ def validate_return(self, value, context):
"""
if self.return_annotation is EMPTY:
return value
return self.return_annotation.validate(value, context)

result = self.return_annotation.match(value, context)
if result is NoMatch:
raise ValidationError(f"{value!r} doesn't match {self}")

return result


# aliases for convenience
Expand All @@ -409,7 +422,7 @@ def validate_return(self, value, context):
varkwargs = Argument.varkwargs


# TODO(kszucs): try to cache validator objects
# TODO(kszucs): try to cache pattern objects
# TODO(kszucs): try a quicker curry implementation


Expand All @@ -424,20 +437,20 @@ def annotated(_1=None, _2=None, _3=None, **kwargs):
... def foo(x: int, y: str) -> float:
... return float(x) + float(y)
2. With argument validators passed as keyword arguments
2. With argument patterns passed as keyword arguments
>>> from ibis.common.patterns import InstanceOf as instance_of
>>> @annotated(x=instance_of(int), y=instance_of(str))
... def foo(x, y):
... return float(x) + float(y)
3. With mixing type annotations and validators where the latter takes precedence
3. With mixing type annotations and patterns where the latter takes precedence
>>> @annotated(x=instance_of(float))
... def foo(x: int, y: str) -> float:
... return float(x) + float(y)
4. With argument validators passed as a list and/or an optional return validator
4. With argument patterns passed as a list and/or an optional return pattern
>>> @annotated([instance_of(int), instance_of(str)], instance_of(float))
... def foo(x, y):
Expand All @@ -447,18 +460,18 @@ def annotated(_1=None, _2=None, _3=None, **kwargs):
----------
*args : Union[
tuple[Callable],
tuple[list[Validator], Callable],
tuple[list[Validator], Validator, Callable]
tuple[list[Pattern], Callable],
tuple[list[Pattern], Pattern, Callable]
]
Positional arguments.
- If a single callable is passed, it's wrapped with the signature
- If two arguments are passed, the first one is a list of validators for the
- If two arguments are passed, the first one is a list of patterns for the
arguments and the second one is the callable to wrap
- If three arguments are passed, the first one is a list of validators for the
arguments, the second one is a validator for the return value and the third
- If three arguments are passed, the first one is a list of patterns for the
arguments, the second one is a pattern for the return value and the third
one is the callable to wrap
**kwargs : dict[str, Validator]
Validators for the arguments.
**kwargs : dict[str, Pattern]
Patterns for the arguments.
Returns
-------
Expand All @@ -468,19 +481,19 @@ def annotated(_1=None, _2=None, _3=None, **kwargs):
return functools.partial(annotated, **kwargs)
elif _2 is None:
if callable(_1):
func, validators, return_validator = _1, None, None
func, patterns, return_pattern = _1, None, None
else:
return functools.partial(annotated, _1, **kwargs)
elif _3 is None:
if not isinstance(_2, Validator):
func, validators, return_validator = _2, _1, None
if not isinstance(_2, Pattern):
func, patterns, return_pattern = _2, _1, None
else:
return functools.partial(annotated, _1, _2, **kwargs)
else:
func, validators, return_validator = _3, _1, _2
func, patterns, return_pattern = _3, _1, _2

sig = Signature.from_callable(
func, validators=validators or kwargs, return_validator=return_validator
func, patterns=patterns or kwargs, return_pattern=return_pattern
)

@functools.wraps(func)
Expand Down
12 changes: 6 additions & 6 deletions ibis/common/grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from ibis.common.caching import WeakCache
from ibis.common.collections import FrozenDict
from ibis.common.patterns import Validator
from ibis.common.patterns import Pattern
from ibis.common.typing import evaluate_annotations


Expand Down Expand Up @@ -60,7 +60,7 @@ def __new__(metacls, clsname, bases, dct, **kwargs):
with contextlib.suppress(AttributeError):
signatures.append(parent.__signature__)

# collection type annotations and convert them to validators
# collection type annotations and convert them to patterns
module = dct.get("__module__")
qualname = dct.get("__qualname__") or clsname
annotations = dct.get("__annotations__", {})
Expand All @@ -70,17 +70,17 @@ def __new__(metacls, clsname, bases, dct, **kwargs):
for name, typehint in typehints.items():
if get_origin(typehint) is ClassVar:
continue
validator = Validator.from_typehint(typehint)
pattern = Pattern.from_typehint(typehint)
if name in dct:
dct[name] = Argument.default(dct[name], validator, typehint=typehint)
dct[name] = Argument.default(dct[name], pattern, typehint=typehint)
else:
dct[name] = Argument.required(validator, typehint=typehint)
dct[name] = Argument.required(pattern, typehint=typehint)

# collect the newly defined annotations
slots = list(dct.pop("__slots__", []))
namespace, arguments = {}, {}
for name, attrib in dct.items():
if isinstance(attrib, Validator):
if isinstance(attrib, Pattern):
attrib = Argument.required(attrib)

if isinstance(attrib, Argument):
Expand Down
48 changes: 9 additions & 39 deletions ibis/common/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ class CoercionError(Exception):
...


class ValidationError(Exception):
...


class MatchError(Exception):
...

Expand All @@ -63,7 +59,14 @@ def __coerce__(cls, value: Any, **kwargs: Any) -> Self:
...


class Validator(ABC):
class NoMatch(metaclass=Sentinel):
"""Marker to indicate that a pattern didn't match."""


# TODO(kszucs): have an As[int] or Coerced[int] type in ibis.common.typing which
# would be used to annotate an argument as coercible to int or to a certain type
# without needing for the type to inherit from Coercible
class Pattern(Hashable):
__slots__ = ()

@classmethod
Expand Down Expand Up @@ -194,15 +197,6 @@ def from_typehint(cls, annot: type, allow_coercion: bool = True) -> Pattern:
f"Cannot create validator from annotation {annot!r} {origin!r}"
)


class NoMatch(metaclass=Sentinel):
"""Marker to indicate that a pattern didn't match."""


# TODO(kszucs): have an As[int] or Coerced[int] type in ibis.common.typing which
# would be used to annotate an argument as coercible to int or to a certain type
# without needing for the type to inherit from Coercible
class Pattern(Validator, Hashable):
@abstractmethod
def match(self, value: AnyType, context: dict[str, AnyType]) -> AnyType:
"""Match a value against the pattern.
Expand Down Expand Up @@ -241,30 +235,6 @@ def __rshift__(self, name: str) -> Pattern:
def __rmatmul__(self, name: str) -> Pattern:
return Capture(self, name)

def validate(
self, value: AnyType, context: Optional[dict[str, AnyType]] = None
) -> Any:
"""Validate a value against the pattern.
If the pattern doesn't match the value, then it raises a `ValidationError`.
Parameters
----------
value
The value to match the pattern against.
context
A dictionary providing arbitrary context for the pattern matching.
Returns
-------
match
The matched / validated value.
"""
result = self.match(value, context=context)
if result is NoMatch:
raise ValidationError(f"{value!r} doesn't match {self}")
return result


class Matcher(Pattern):
"""A lightweight alternative to `ibis.common.grounds.Concrete`.
Expand Down Expand Up @@ -293,7 +263,7 @@ def __hash__(self) -> int:
return self.__precomputed_hash__

def __setattr__(self, name, value) -> None:
raise AttributeError("Can't set attributes on immutable ENode instance")
raise AttributeError("Can't set attributes on immutable instance")

def __repr__(self):
fields = {k: getattr(self, k) for k in self.__slots__}
Expand Down
24 changes: 15 additions & 9 deletions ibis/common/tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,21 @@
import pytest
from typing_extensions import Annotated # noqa: TCH002

from ibis.common.annotations import Argument, Attribute, Parameter, Signature, annotated
from ibis.common.annotations import (
Argument,
Attribute,
Parameter,
Signature,
ValidationError,
annotated,
)
from ibis.common.patterns import (
Any,
CoercedTo,
InstanceOf,
NoMatch,
Option,
TupleOf,
ValidationError,
pattern,
)

Expand All @@ -24,13 +30,13 @@
def test_argument_repr():
argument = Argument(is_int, typehint=int, default=None)
assert repr(argument) == (
"Argument(validator=InstanceOf(type=<class 'int'>), default=None, "
"Argument(pattern=InstanceOf(type=<class 'int'>), default=None, "
"typehint=<class 'int'>)"
)


def test_default_argument():
annotation = Argument.default(validator=lambda x, context: int(x), default=3)
annotation = Argument.default(pattern=lambda x, context: int(x), default=3)
assert annotation.validate(1) == 1
with pytest.raises(TypeError):
annotation.validate(None)
Expand Down Expand Up @@ -82,7 +88,7 @@ class Foo:

assert field.initialize(Foo) == 20

field2 = Attribute(validator=lambda x, this: str(x), default=lambda self: self.a)
field2 = Attribute(pattern=lambda x, this: str(x), default=lambda self: self.a)
assert field != field2
assert field2.initialize(Foo) == "10"

Expand All @@ -103,7 +109,7 @@ def fn(x, this):

ofn = Argument.optional(fn)
op = Parameter("test", annotation=ofn)
assert op.annotation._validator == Option(fn, default=None)
assert op.annotation._pattern == Option(fn, default=None)
assert op.default is None
assert op.annotation.validate(None, {"other": 1}) is None

Expand Down Expand Up @@ -218,12 +224,12 @@ def add_other(x, this):

a = Parameter("a", annotation=Argument.required(CoercedTo(float)))
b = Parameter("b", annotation=Argument.required(CoercedTo(float)))
c = Parameter("c", annotation=Argument.default(default=0, validator=CoercedTo(float)))
c = Parameter("c", annotation=Argument.default(default=0, pattern=CoercedTo(float)))
d = Parameter(
"d",
annotation=Argument.default(default=tuple(), validator=TupleOf(CoercedTo(float))),
annotation=Argument.default(default=tuple(), pattern=TupleOf(CoercedTo(float))),
)
e = Parameter("e", annotation=Argument.optional(validator=CoercedTo(float)))
e = Parameter("e", annotation=Argument.optional(pattern=CoercedTo(float)))
sig = Signature(parameters=[a, b, c, d, e])


Expand Down
6 changes: 3 additions & 3 deletions ibis/common/tests/test_grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ibis.common.annotations import (
Parameter,
Signature,
ValidationError,
argument,
attribute,
optional,
Expand All @@ -36,7 +37,6 @@
Option,
Pattern,
TupleOf,
ValidationError,
)
from ibis.tests.util import assert_pickle_roundtrip

Expand Down Expand Up @@ -333,7 +333,7 @@ def test_annotable_with_recursive_generic_type_annotations():
# testing cons list
pattern = Pattern.from_typehint(List[Integer])
values = ["1", 2.0, 3]
result = pattern.validate(values, {})
result = pattern.match(values, {})
expected = ConsList(1, ConsList(2, ConsList(3, EmptyList())))
assert result == expected
assert result[0] == 1
Expand All @@ -346,7 +346,7 @@ def test_annotable_with_recursive_generic_type_annotations():
# testing cons map
pattern = Pattern.from_typehint(Map[Integer, Float])
values = {"1": 2, 3: "4.0", 5: 6.0}
result = pattern.validate(values, {})
result = pattern.match(values, {})
expected = ConsMap((1, 2.0), ConsMap((3, 4.0), ConsMap((5, 6.0), EmptyMap())))
assert result == expected
assert result[1] == 2.0
Expand Down
5 changes: 1 addition & 4 deletions ibis/common/tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pytest
from typing_extensions import Annotated

from ibis.common.annotations import ValidationError
from ibis.common.collections import FrozenDict
from ibis.common.graph import Node
from ibis.common.patterns import (
Expand Down Expand Up @@ -63,7 +64,6 @@
Topmost,
TupleOf,
TypeOf,
ValidationError,
match,
pattern,
)
Expand Down Expand Up @@ -638,7 +638,6 @@ def test_matching_mapping():
)
def test_various_patterns(pattern, value, expected):
assert pattern.match(value, context={}) == expected
assert pattern.validate(value, context={}) == expected


@pytest.mark.parametrize(
Expand All @@ -663,8 +662,6 @@ def test_various_patterns(pattern, value, expected):
)
def test_various_not_matching_patterns(pattern, value):
assert pattern.match(value, context={}) is NoMatch
with pytest.raises(ValidationError):
pattern.validate(value, context={})


@pattern
Expand Down
10 changes: 5 additions & 5 deletions ibis/common/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def test_interval_units(singular, plural, short):
def test_interval_unit_coercions(singular, plural, short):
u = IntervalUnit[singular.upper()]
v = CoercedTo(IntervalUnit)
assert v.validate(singular, {}) == u
assert v.validate(plural, {}) == u
assert v.validate(short, {}) == u
assert v.match(singular, {}) == u
assert v.match(plural, {}) == u
assert v.match(short, {}) == u


@pytest.mark.parametrize(
Expand All @@ -70,7 +70,7 @@ def test_interval_unit_coercions(singular, plural, short):
)
def test_interval_unit_aliases(alias, expected):
v = CoercedTo(IntervalUnit)
assert v.validate(alias, {}) == IntervalUnit(expected)
assert v.match(alias, {}) == IntervalUnit(expected)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_normalize_timedelta_invalid(value, unit):
def test_interval_unit_compatibility():
v = CoercedTo(IntervalUnit)
for unit in itertools.chain(DateUnit, TimeUnit):
interval = v.validate(unit, {})
interval = v.match(unit, {})
assert isinstance(interval, IntervalUnit)
assert unit.value == interval.value

Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import ibis.expr.operations.relations as rels
import ibis.expr.types as ir
from ibis import util
from ibis.common.annotations import ValidationError
from ibis.common.exceptions import IbisTypeError, IntegrityError
from ibis.common.patterns import ValidationError

# ---------------------------------------------------------------------
# Some expression metaprogramming / graph transformations to support
Expand Down
3 changes: 2 additions & 1 deletion ibis/expr/datatypes/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from typing_extensions import Annotated

import ibis.expr.datatypes as dt
from ibis.common.patterns import As, Attrs, NoMatch, Pattern, ValidationError
from ibis.common.annotations import ValidationError
from ibis.common.patterns import As, Attrs, NoMatch, Pattern
from ibis.common.temporal import TimestampUnit, TimeUnit


Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/datatypes/tests/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

import ibis.expr.datatypes as dt
from ibis.common.patterns import ValidationError
from ibis.common.annotations import ValidationError


@pytest.mark.parametrize(
Expand Down
3 changes: 1 addition & 2 deletions ibis/expr/operations/histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
from ibis.common.annotations import attribute
from ibis.common.patterns import ValidationError
from ibis.common.annotations import ValidationError, attribute
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.operations.core import Column, Value

Expand Down
3 changes: 1 addition & 2 deletions ibis/expr/operations/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
from ibis.common.annotations import attribute
from ibis.common.annotations import ValidationError, attribute
from ibis.common.exceptions import IbisTypeError
from ibis.common.patterns import ValidationError
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.operations.core import Binary, Column, Unary, Value
from ibis.expr.operations.generic import _Negatable
Expand Down
41 changes: 17 additions & 24 deletions ibis/expr/operations/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@
import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.patterns import (
CoercedTo,
GenericCoercedTo,
Pattern,
ValidationError,
)
from ibis.common.patterns import CoercedTo, GenericCoercedTo, NoMatch, Pattern


@pytest.mark.parametrize(
Expand All @@ -32,58 +27,56 @@ def test_literal_coercion_type_inference(value, dtype):
def test_coerced_to_literal():
p = CoercedTo(ops.Literal)
one = ops.Literal(1, dt.int8)
assert p.validate(ops.Literal(1, dt.int8), {}) == one
assert p.validate(1, {}) == one
assert p.validate(False, {}) == ops.Literal(False, dt.boolean)
assert p.match(ops.Literal(1, dt.int8), {}) == one
assert p.match(1, {}) == one
assert p.match(False, {}) == ops.Literal(False, dt.boolean)

p = GenericCoercedTo(ops.Literal[dt.Int8])
assert p.validate(ops.Literal(1, dt.int8), {}) == one
assert p.match(ops.Literal(1, dt.int8), {}) == one

p = Pattern.from_typehint(ops.Literal[dt.Int8])
assert p == GenericCoercedTo(ops.Literal[dt.Int8])

one = ops.Literal(1, dt.int16)
with pytest.raises(ValidationError):
p.validate(one, {})
assert p.match(one, {}) is NoMatch


def test_coerced_to_value():
one = ops.Literal(1, dt.int8)

p = Pattern.from_typehint(ops.Value)
assert p.validate(1, {}) == one
assert p.match(1, {}) == one

p = Pattern.from_typehint(ops.Value[dt.Int8, ds.Any])
assert p.validate(1, {}) == one
assert p.match(1, {}) == one

p = Pattern.from_typehint(ops.Value[dt.Int8, ds.Scalar])
assert p.validate(1, {}) == one
assert p.match(1, {}) == one

p = Pattern.from_typehint(ops.Value[dt.Int8, ds.Columnar])
with pytest.raises(ValidationError):
p.validate(1, {})
assert p.match(1, {}) is NoMatch

# dt.Integer is not instantiable so it will be only used for checking
# that the produced literal has any integer datatype
p = Pattern.from_typehint(ops.Value[dt.Integer, ds.Any])
assert p.validate(1, {}) == one
assert p.match(1, {}) == one

# same applies here, the coercion itself will use only the inferred datatype
# but then the result is checked against the given typehint
p = Pattern.from_typehint(ops.Value[dt.Int8 | dt.Int16, ds.Any])
assert p.validate(1, {}) == one
assert p.validate(128, {}) == ops.Literal(128, dt.int16)
assert p.match(1, {}) == one
assert p.match(128, {}) == ops.Literal(128, dt.int16)

p1 = Pattern.from_typehint(ops.Value[dt.Int8, ds.Any])
p2 = Pattern.from_typehint(ops.Value[dt.Int16, ds.Scalar])
assert p1.validate(1, {}) == one
assert p1.match(1, {}) == one
# this is actually supported by creating an explicit dtype
# in Value.__coerce__ based on the `T` keyword argument
assert p2.validate(1, {}) == ops.Literal(1, dt.int16)
assert p2.validate(128, {}) == ops.Literal(128, dt.int16)
assert p2.match(1, {}) == ops.Literal(1, dt.int16)
assert p2.match(128, {}) == ops.Literal(128, dt.int16)

p = p1 | p2
assert p.validate(1, {}) == one
assert p.match(1, {}) == one


@pytest.mark.pandas
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/operations/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ def make_node(

arg = rlz.ValueOf(dt.dtype(raw_dtype))
if (default := param.default) is EMPTY:
fields[name] = Argument.required(validator=arg)
fields[name] = Argument.required(pattern=arg)
else:
fields[name] = Argument.default(validator=arg, default=default)
fields[name] = Argument.default(pattern=arg, default=default)

fields["dtype"] = dt.dtype(return_annotation)

Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ class ObjectWithSchema(Annotable):

def test_schema_is_coercible():
s = sch.Schema({"a": dt.int64, "b": dt.Array(dt.int64)})
assert CoercedTo(sch.Schema).validate(PreferenceA, {}) == s
assert CoercedTo(sch.Schema).match(PreferenceA, {}) == s

o = ObjectWithSchema(schema=PreferenceA)
assert o.schema == s
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/types/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ibis.common.grounds import Immutable
from ibis.config import _default_backend, options
from ibis.util import experimental
from ibis.common.patterns import ValidationError, Coercible, CoercionError
from ibis.common.annotations import ValidationError
from rich.jupyter import JupyterMixin
from ibis.common.patterns import Coercible, CoercionError

Expand Down
2 changes: 1 addition & 1 deletion ibis/tests/expr/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import ibis
import ibis.expr.types as ir
from ibis.common.patterns import ValidationError
from ibis.common.annotations import ValidationError
from ibis.tests.expr.mocks import MockBackend
from ibis.tests.util import assert_equal

Expand Down
2 changes: 1 addition & 1 deletion ibis/tests/expr/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import ibis
import ibis.expr.datatypes as dt
import ibis.expr.types as ir
from ibis.common.patterns import ValidationError
from ibis.common.annotations import ValidationError
from ibis.expr import api


Expand Down
2 changes: 1 addition & 1 deletion ibis/tests/expr/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ibis.expr.operations as ops
import ibis.expr.rules as rlz
import ibis.expr.types as ir
from ibis.common.patterns import ValidationError
from ibis.common.annotations import ValidationError

t = ibis.table([("a", "int64")], name="t")

Expand Down
2 changes: 1 addition & 1 deletion ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import ibis.selectors as s
from ibis import _
from ibis import literal as L
from ibis.common.annotations import ValidationError
from ibis.common.exceptions import RelationError
from ibis.common.patterns import ValidationError
from ibis.expr import api
from ibis.expr.types import Column, Table
from ibis.tests.expr.mocks import MockAlchemyBackend, MockBackend
Expand Down
2 changes: 1 addition & 1 deletion ibis/tests/expr/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.common.patterns import ValidationError
from ibis.common.annotations import ValidationError


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion ibis/tests/expr/test_value_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import _, literal
from ibis.common.annotations import ValidationError
from ibis.common.collections import frozendict
from ibis.common.exceptions import IbisTypeError
from ibis.common.patterns import ValidationError
from ibis.expr import api
from ibis.tests.util import assert_equal

Expand Down
15 changes: 7 additions & 8 deletions ibis/tests/expr/test_window_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.annotations import ValidationError
from ibis.common.exceptions import IbisInputError, IbisTypeError
from ibis.common.patterns import Pattern, ValidationError
from ibis.common.patterns import NoMatch, Pattern


def test_window_boundary():
Expand All @@ -35,21 +36,19 @@ def test_window_boundary_typevars():

p = Pattern.from_typehint(ops.WindowBoundary[dt.Integer, ds.Any])
b = ops.WindowBoundary(5, preceding=False)
assert p.validate(b, {}) == b
with pytest.raises(ValidationError):
p.validate(ops.WindowBoundary(5.0, preceding=False), {})
with pytest.raises(ValidationError):
p.validate(ops.WindowBoundary(lit, preceding=True), {})
assert p.match(b, {}) == b
assert p.match(ops.WindowBoundary(5.0, preceding=False), {}) is NoMatch
assert p.match(ops.WindowBoundary(lit, preceding=True), {}) is NoMatch

p = Pattern.from_typehint(ops.WindowBoundary[dt.Interval, ds.Any])
b = ops.WindowBoundary(lit, preceding=True)
assert p.validate(b, {}) == b
assert p.match(b, {}) == b


def test_window_boundary_coercions():
RowsWindowBoundary = ops.WindowBoundary[dt.Integer, ds.Any]
p = Pattern.from_typehint(RowsWindowBoundary)
assert p.validate(1, {}) == RowsWindowBoundary(ops.Literal(1, dtype=dt.int8), False)
assert p.match(1, {}) == RowsWindowBoundary(ops.Literal(1, dtype=dt.int8), False)


def test_window_builder_rows():
Expand Down
2 changes: 1 addition & 1 deletion ibis/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from ibis.common.patterns import ValidationError
from ibis.common.annotations import ValidationError
from ibis.config import options


Expand Down