diff --git a/pygtrie.py b/pygtrie.py index 38c89e7..0c41e05 100644 --- a/pygtrie.py +++ b/pygtrie.py @@ -56,6 +56,11 @@ def _sorted_iteritems(d): _iteritems = lambda d: iter(d.items()) # pylint: disable=invalid-name _iterkeys = lambda d: iter(d.keys()) # pylint: disable=invalid-name +try: + _basestring = basestring +except NameError: + _basestring = str + class ShortKeyError(KeyError): """Raised when given key is a prefix of a longer key.""" @@ -1220,7 +1225,12 @@ def __init__(self, *args, **kwargs): named argument is not specified on the function's prototype because of Python's limitations. """ - self._separator = kwargs.pop('separator', '/') + separator = kwargs.pop('separator', '/') + if not isinstance(separator, _basestring): + raise TypeError('separator must be a string') + if not separator: + raise ValueError('separator can not be empty') + self._separator = separator super(StringTrie, self).__init__(*args, **kwargs) @classmethod diff --git a/test.py b/test.py index 934bfa0..77ab0d2 100755 --- a/test.py +++ b/test.py @@ -428,6 +428,19 @@ def path_from_key(cls, key): def key_from_path(cls, path): return '/'.join(path) + def test_valid_separator(self): + t = pygtrie.StringTrie() + t['foo/bar'] = 42 + self.assertTrue(bool(t.has_node('foo') & pygtrie.Trie.HAS_SUBTRIE)) + + t = pygtrie.StringTrie(separator='.') + t['foo.bar'] = 42 + self.assertTrue(bool(t.has_node('foo') & pygtrie.Trie.HAS_SUBTRIE)) + + def test_invalid_separator(self): + self.assertRaises(TypeError, pygtrie.StringTrie, separator=42) + self.assertRaises(ValueError, pygtrie.StringTrie, separator='') + class SortTest(unittest.TestCase):