Skip to content

Commit

Permalink
Enum tests (#332)
Browse files Browse the repository at this point in the history
- Remove unreachable and duplicate code
   - Fixed some logic
  • Loading branch information
davfsa committed Oct 30, 2020
1 parent 915dd5c commit 0a41cb6
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 41 deletions.
7 changes: 2 additions & 5 deletions coverage.ini
Expand Up @@ -7,14 +7,12 @@ source = hikari
omit =
hikari/__main__.py
hikari/cli.py
hikari/_about.py
hikari/utilities/art.py
.nox/*

[report]
precision = 2
show_missing = true
skip_covered = False
skip_covered = false
sort = cover
exclude_lines =
\#\s*pragma: no cover
^\s*raise AssertionError\b
Expand All @@ -26,4 +24,3 @@ exclude_lines =
^\s*\.\.\.$
^\s*@abc.abstractmethod$
^\s*if typing.TYPE_CHECKING:$
sort = Cover
48 changes: 13 additions & 35 deletions hikari/internal/enums.py
Expand Up @@ -44,13 +44,6 @@ def __init__(self, base: typing.Type[typing.Any]) -> None:
self.values_to_names: typing.Dict[str, typing.Any] = {}
self["__doc__"] = "An enumeration."

def __contains__(self, item: typing.Any) -> bool:
try:
_ = self[item]
return True
except KeyError:
return False

def __getitem__(self, name: str) -> typing.Any:
try:
return super().__getitem__(name)
Expand All @@ -60,10 +53,6 @@ def __getitem__(self, name: str) -> typing.Any:
except KeyError:
raise KeyError(name) from None

def __iter__(self) -> typing.Iterator[str]:
yield from super().__iter__()
yield self.names_to_values

def __setitem__(self, name: str, value: typing.Any) -> None:
if name == "" or name == "mro":
raise TypeError(f"Invalid enum member name: {name!r}")
Expand Down Expand Up @@ -97,8 +86,6 @@ def __setitem__(self, name: str, value: typing.Any) -> None:
# We must have defined some alias, so just register the name
self.names_to_values[name] = value
return
if not isinstance(value, self.base):
raise TypeError("Enum values must be an instance of the base type of the enum")

self.names_to_values[name] = value
self.values_to_names[value] = name
Expand All @@ -122,8 +109,11 @@ def __call__(cls, value: typing.Any) -> typing.Any:
def __getitem__(cls, name: str) -> typing.Any:
return cls._name_to_member_map_[name]

def __iter__(cls) -> typing.Iterator[str]:
yield from cls._name_to_member_map_
def __contains__(cls, item: typing.Any) -> bool:
return item in cls._value_to_member_map_

def __iter__(cls) -> typing.Iterator[typing.Any]:
yield from cls._name_to_member_map_.values()

@staticmethod
def __new__(
Expand Down Expand Up @@ -183,25 +173,18 @@ def __prepare__(
if _Enum is NotImplemented:
if name != "Enum":
raise TypeError("First instance of _EnumMeta must be Enum")
return {}
return _EnumNamespace(object)

try:
# Fails if Enum is not defined. We check this in `__new__` properly.
base, enum_type = bases

if isinstance(base, _EnumMeta):
raise TypeError("First base to an enum must be the type to combine with, not _EnumMeta")
if not isinstance(enum_type, _EnumMeta):
raise TypeError("Second base to an enum must be the enum type (derived from _EnumMeta) to be used")

if not issubclass(enum_type, _Enum):
raise TypeError("second base type for enum must be derived from Enum")

return _EnumNamespace(base)
except ValueError:
if name == "Enum" and _Enum is NotImplemented:
return _EnumNamespace(object)
raise TypeError("Expected two base classes for an enum") from None
raise TypeError("Expected exactly two base classes for an enum") from None

def __repr__(cls) -> str:
return f"<enum {cls.__name__}>"
Expand Down Expand Up @@ -365,7 +348,7 @@ def __call__(cls, value: typing.Any = 0) -> typing.Any:
def __getitem__(cls, name: str) -> typing.Any:
return cls._name_to_member_map_[name]

def __iter__(cls) -> typing.Iterator[str]:
def __iter__(cls) -> typing.Iterator[typing.Any]:
yield from cls._name_to_member_map_.values()

@classmethod
Expand All @@ -377,13 +360,10 @@ def __prepare__(
raise TypeError("First instance of _FlagMeta must be Flag")
return _EnumNamespace(object)

try:
# Fails if Enum is not defined.
if len(bases) == 1 and bases[0] == Flag:
return _EnumNamespace(int)
except ValueError:
pass
raise TypeError("Cannot define another Flag base type") from None
# Fails if Flag is not defined.
if len(bases) == 1 and bases[0] == Flag:
return _EnumNamespace(int)
raise TypeError("Cannot define another Flag base type")

@staticmethod
def __new__(
Expand Down Expand Up @@ -761,9 +741,7 @@ def __rsub__(self: _T, other: typing.Union[int, _T]) -> _T:
# This logic has to be reversed to be correct, since order matters for
# a subtraction operator. This also ensures `int - _T -> _T` is a valid
# case for us.
if not isinstance(other, self.__class__):
other = self.__class__(other)
return other - self
return self.__class__(other) - self

def __str__(self) -> str:
return self.name
Expand Down
120 changes: 119 additions & 1 deletion tests/hikari/internal/test_enums.py
Expand Up @@ -18,6 +18,8 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import builtins

import mock
import pytest

Expand Down Expand Up @@ -74,6 +76,18 @@ def test_init_enum_type_with_bases_in_wrong_order_is_TypeError(self, args, kwarg
class Enum(*args, **kwargs):
pass

def test_init_with_more_than_2_types(self):
with pytest.raises(TypeError):

class Enum(enums.Enum, str, int):
pass

def test_init_with_less_than_2_types(self):
with pytest.raises(TypeError):

class Enum(enums.Enum):
pass

def test_init_enum_type_default_docstring_set(self):
class Enum(str, enums.Enum):
pass
Expand Down Expand Up @@ -160,6 +174,26 @@ def p(self):

assert Enum.__members__ == {"foo": 9, "bar": 18, "baz": 27}

def test_init_with_invalid_name(self):
with pytest.raises(TypeError):

class Enum(int, enums.Enum):
mro = 420

def test_init_with_unhashable_value(self):
with mock.patch.object(builtins, "hash", side_effect=TypeError):
with pytest.raises(TypeError):

class Enum(int, enums.Enum):
test = 420

def test_init_with_duplicate(self):
with pytest.raises(TypeError):

class Enum(int, enums.Enum):
test = 123
test = 321

def test_call_when_member(self):
class Enum(int, enums.Enum):
foo = 9
Expand Down Expand Up @@ -190,6 +224,44 @@ class Enum(int, enums.Enum):
assert returned == Enum.foo
assert type(returned) == Enum

def test_contains(self):
class Enum(int, enums.Enum):
foo = 9
bar = 18
baz = 27

assert 9 in Enum
assert 100 not in Enum

def test_name(self):
class Enum(int, enums.Enum):
foo = 9
bar = 18
baz = 27

assert Enum.foo.name == "foo"

def test_iter(self):
class Enum(int, enums.Enum):
foo = 9
bar = 18
baz = 27

a = []
for i in Enum:
a.append(i)

assert a == [Enum.foo, Enum.bar, Enum.baz]

def test_repr(self):
class Enum(int, enums.Enum):
foo = 9
bar = 18
baz = 27

assert repr(Enum) == "<enum Enum>"
assert repr(Enum.foo) == "<Enum.foo: 9>"


class TestIntFlag:
@mock.patch.object(enums, "_Flag", new=NotImplemented)
Expand Down Expand Up @@ -260,7 +332,7 @@ class Flag(enums.Flag):
def foo(self):
return "foo"

assert Flag.foo(12) == "foo"
assert Flag().foo() == "foo"

def test_init_flag_type_allows_classmethods(self):
class Flag(enums.Flag):
Expand Down Expand Up @@ -461,6 +533,29 @@ class Flag(enums.Flag):
assert Flag._temp_members_ == {3: Flag.foo | Flag.bar, 7: Flag.foo | Flag.bar | Flag.baz}
assert Flag._temp_members_ == {3: Flag.foo | Flag.bar, 7: Flag.foo | Flag.bar | Flag.baz}

def test_cache_when_temp_values_over_MAX_CACHED_MEMBERS(self):
class MockDict:
def __getitem__(self, key):
raise KeyError

def __len__(self):
return enums._MAX_CACHED_MEMBERS + 1

def __setitem__(self, k, v):
pass

popitem = mock.Mock()

class Flag(enums.Flag):
foo = 1
bar = 2
baz = 3

Flag._temp_members_ = MockDict()

Flag(4)
Flag._temp_members_.popitem.assert_called_once_with()

def test_bitwise_name(self):
class Flag(enums.Flag):
foo = 1
Expand Down Expand Up @@ -1088,3 +1183,26 @@ class TestFlag(enums.Flag):
assert isinstance(0x1C ^ a, TestFlag)
assert 0x1C ^ a == TestFlag.FOO | TestFlag.BAR | TestFlag.BORK | TestFlag.QUX
assert 0x1C ^ a == 0x1B

def test_getitem(self):
class TestFlag(enums.Flag):
FOO = 0x1
BAR = 0x2
BAZ = 0x4
BORK = 0x8
QUX = 0x10

returned = TestFlag["FOO"]
assert returned == TestFlag.FOO
assert type(returned) == TestFlag

def test_repr(self):
class TestFlag(enums.Flag):
FOO = 0x1
BAR = 0x2
BAZ = 0x4
BORK = 0x8
QUX = 0x10

assert repr(TestFlag) == "<enum TestFlag>"
assert repr(TestFlag.FOO) == "<TestFlag.FOO: 1>"

0 comments on commit 0a41cb6

Please sign in to comment.