diff --git a/sqlalchemy_utils/primitives/country.py b/sqlalchemy_utils/primitives/country.py index 72265425..3a27cd5b 100644 --- a/sqlalchemy_utils/primitives/country.py +++ b/sqlalchemy_utils/primitives/country.py @@ -1,9 +1,12 @@ +from functools import total_ordering + import six from .. import i18n from ..utils import str_coercible +@total_ordering @str_coercible class Country(object): """ @@ -95,6 +98,13 @@ def __hash__(self): def __ne__(self, other): return not (self == other) + def __lt__(self, other): + if isinstance(other, Country): + return self.code < other.code + elif isinstance(other, six.string_types): + return self.code < other + return NotImplemented + def __repr__(self): return '%s(%r)' % (self.__class__.__name__, self.code) diff --git a/tests/primitives/test_country.py b/tests/primitives/test_country.py index 24b9df34..7ff0a666 100644 --- a/tests/primitives/test_country.py +++ b/tests/primitives/test_country.py @@ -1,3 +1,5 @@ +import operator + import pytest import six @@ -56,6 +58,33 @@ def test_non_equality_operator(self): assert Country(u'FI') != u'sv' assert not (Country(u'FI') != u'FI') + @pytest.mark.parametrize( + 'op, code_left, code_right, is_', + [ + (operator.lt, u'ES', u'FI', True), + (operator.lt, u'FI', u'ES', False), + (operator.lt, u'ES', u'ES', False), + + (operator.le, u'ES', u'FI', True), + (operator.le, u'FI', u'ES', False), + (operator.le, u'ES', u'ES', True), + + (operator.ge, u'ES', u'FI', False), + (operator.ge, u'FI', u'ES', True), + (operator.ge, u'ES', u'ES', True), + + (operator.gt, u'ES', u'FI', False), + (operator.gt, u'FI', u'ES', True), + (operator.gt, u'ES', u'ES', False), + ] + ) + def test_ordering(self, op, code_left, code_right, is_): + country_left = Country(code_left) + country_right = Country(code_right) + assert op(country_left, country_right) is is_ + assert op(country_left, code_right) is is_ + assert op(code_left, country_right) is is_ + def test_hash(self): return hash(Country('FI')) == hash('FI')