Skip to content

Commit

Permalink
Support string annotations
Browse files Browse the repository at this point in the history
Fixes #45

Signed-off-by: David Euresti <david@pilot.com>
  • Loading branch information
euresti authored and gaborbernat committed Jul 1, 2020
1 parent 2a93324 commit 23b6e05
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 1 deletion.
5 changes: 5 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import sys


def pytest_ignore_collect(path):
return sys.version_info[0] <= 2 and str(path).endswith("__py3.py")
42 changes: 41 additions & 1 deletion src/attrs_strict/_type_validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import collections
import typing

import attr

from ._commons import is_newtype
from ._error import (
AttributeTypeError,
Expand Down Expand Up @@ -28,6 +30,15 @@
# silencing type error so mypy doesn't complain about duplicate import
from itertools import izip_longest as zip_longest # type: ignore

try:
from typing import ForwardRef # type: ignore # Not in stubs
except ImportError:
from typing import _ForwardRef as ForwardRef # type: ignore # Not in stubs


class _StringAnnotationError(Exception):
"""Raised when we find string annotations in a class."""


class SimilarTypes:
Dict = {
Expand All @@ -46,6 +57,26 @@ class SimilarTypes:
Callable = {typing.Callable, Callable}


def resolve_types(cls, global_ns=None, local_ns=None):
"""
Resolve any strings and forward annotations in type annotations.
:param type cls: Class to resolve.
:param globalns: Dictionary containing global variables, if needed.
:param localns: Dictionary containing local variables, if needed.
:raise TypeError: If *cls* is not a class.
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
class.
:raise NameError: If types cannot be resolved because of missing variables.
"""
hints = typing.get_type_hints(cls, globalns=global_ns, localns=local_ns)
for field in attr.fields(cls):
if field.name in hints:
# Since fields have been frozen we must work around it.
object.__setattr__(field, "type", hints[field.name])


def type_validator(empty_ok=True):
"""
Validates the attributes using the type argument specified. If the
Expand All @@ -60,7 +91,11 @@ def _validator(instance, attribute, field):
if not empty_ok and not field:
raise EmptyError(field, attribute)

_validate_elements(attribute, field, attribute.type)
try:
_validate_elements(attribute, field, attribute.type)
except _StringAnnotationError:
resolve_types(type(instance))
_validate_elements(attribute, field, attribute.type)

return _validator

Expand All @@ -74,6 +109,11 @@ def _validate_elements(attribute, value, expected_type):
if base_type == typing.Any:
return

if isinstance(base_type, (str, ForwardRef)):
# These base_types happen when you have string annotations and cannot
# be used in isinstance.
raise _StringAnnotationError()

if base_type != typing.Union and not isinstance( # type: ignore
value, base_type
):
Expand Down
5 changes: 5 additions & 0 deletions src/attrs_strict/_type_validation.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import typing
import attr

def resolve_types(
cls: type,
global_ns: typing.Optional[typing.Dict[str, typing.Any]] = None,
local_ns: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> None: ...
def type_validator(
empty_ok: bool = True,
) -> typing.Callable[
Expand Down
54 changes: 54 additions & 0 deletions tests/test_auto_attribs__py3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Optional

import attr
import pytest

from attrs_strict import type_validator


@pytest.mark.parametrize(
"type_, good_value, bad_value",
[(str, "str", 0xBAD), (int, 7, "bad"), (Optional[str], None, 0xBAD)],
)
def test_real_types(type_, good_value, bad_value):
@attr.s(auto_attribs=True)
class Something:
value: type_ = attr.ib(validator=type_validator())

with pytest.raises(ValueError):
Something(bad_value)

x = Something(good_value)
with pytest.raises(ValueError):
x.value = bad_value
attr.validate(x)


@attr.s(auto_attribs=True)
class Child:
parent: "Parent" = attr.ib(validator=type_validator())


@attr.s(auto_attribs=True)
class Parent:
pass


def test_forward_ref():
Child(Parent())

with pytest.raises(ValueError):
Child(15)


@attr.s(auto_attribs=True)
class Self:
parent: Optional["Self"] = attr.ib(None, validator=type_validator())


def test_recursive():
Self(Self())
Self(Self(None))

with pytest.raises(ValueError):
Self(Self(17))
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ commands = pytest \
--cov-config "{toxinidir}/tox.ini" \
--junitxml {toxworkdir}/junit.{envname}.xml \
{posargs:tests}
rsyncdirs = conftest.py

[testenv:type]
description = try to merge our types against our source
Expand Down

0 comments on commit 23b6e05

Please sign in to comment.