Skip to content

Commit

Permalink
Fix mapping of types in get_key() to work w/py3k.
Browse files Browse the repository at this point in the history
Fixes #120
  • Loading branch information
coleifer committed Apr 3, 2020
1 parent b75faa1 commit d51e735
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
10 changes: 6 additions & 4 deletions walrus/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ def __init__(self, *args, **kwargs):
script_dir = kwargs.pop('script_dir', None)
super(Database, self).__init__(*args, **kwargs)
self.__mapping = {
'list': self.List,
'set': self.Set,
'zset': self.ZSet,
'hash': self.Hash}
b'list': self.List,
b'set': self.Set,
b'zset': self.ZSet,
b'hash': self.Hash}
self._transaction_local = TransactionLocal()
self._transaction_lock = threading.RLock()
self.init_scripts(script_dir=script_dir)
Expand Down Expand Up @@ -193,6 +193,8 @@ def get_key(self, key):
a hash key is requested, then a :py:class:`Hash` will be
returned.
Note: only works for Hash, List, Set and ZSet.
:param str key: Key to retrieve.
:returns: A hash, set, list, zset or array.
"""
Expand Down
30 changes: 30 additions & 0 deletions walrus/tests/database.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,38 @@
from walrus.containers import *
from walrus.tests.base import WalrusTestCase
from walrus.tests.base import db


class TestWalrus(WalrusTestCase):
def test_get_key(self):
h = db.Hash('h1')
h['hk1'] = 'v1'

l = db.List('l1')
l.append('i1')

s = db.Set('s1')
s.add('k1')

zs = db.ZSet('z1')
zs.add({'i1': 1., 'i2': 2.})

h_db = db.get_key('h1')
self.assertTrue(isinstance(h_db, Hash))
self.assertEqual(h_db['hk1'], b'v1')

l_db = db.get_key('l1')
self.assertTrue(isinstance(l_db, List))
self.assertEqual(l_db[0], b'i1')

s_db = db.get_key('s1')
self.assertTrue(isinstance(s_db, Set))
self.assertEqual(s_db.members(), set((b'k1',)))

z_db = db.get_key('z1')
self.assertTrue(isinstance(z_db, ZSet))
self.assertEqual(z_db.score('i1'), 1.)

def test_atomic(self):
def assertDepth(depth):
self.assertEqual(len(db._transaction_local.pipes), depth)
Expand Down

0 comments on commit d51e735

Please sign in to comment.