Skip to content

Commit

Permalink
Merge pull request pycassa#198 from david-huber/master
Browse files Browse the repository at this point in the history
Adding column_count and column_reversed support to ColumnFamilyStub.get and ColumnFamilyStub.multiget
  • Loading branch information
thobbs committed May 3, 2013
2 parents f63a435 + 1327356 commit 0dbea64
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 23 deletions.
4 changes: 2 additions & 2 deletions doc/api/pycassa/contrib/stubs.rst
Expand Up @@ -5,9 +5,9 @@

.. autoclass:: pycassa.contrib.stubs.ColumnFamilyStub(pool=None, column_family=None, rows=None)

.. automethod:: get(key[, columns][, column_start][, column_finish][, include_timestamp])
.. automethod:: get(key[, columns][, column_start][, column_finish][, column_reversed][, column_count][, include_timestamp])

.. automethod:: multiget(keys[, columns][, column_start][, column_finish][, include_timestamp])
.. automethod:: multiget(keys[, columns][, column_start][, column_finish][, column_reversed][, column_count][, include_timestamp])

.. automethod:: get_range([columns][, include_timestamp])

Expand Down
66 changes: 46 additions & 20 deletions pycassa/contrib/stubs.py
Expand Up @@ -8,8 +8,7 @@

import operator

from functools import partial

from collections import MutableMapping
from pycassa import NotFoundException
from pycassa.util import OrderedDict
from pycassa.columnfamily import gm_timestamp
Expand All @@ -18,17 +17,30 @@

__all__ = ['ConnectionPoolStub', 'ColumnFamilyStub', 'SystemManagerStub']

class OrderedDictWithTime(OrderedDict):

class DictWithTime(MutableMapping):
def __init__(self, *args, **kwargs):
self.__timestamp = kwargs.pop('timestamp', None)
super(OrderedDictWithTime, self).__init__(*args, **kwargs)
self.store = dict()
self.update(dict(*args, **kwargs))

def __getitem__(self, key):
return self.store[key]

def __setitem__(self, key, value, timestamp=None):
if timestamp is None:
timestamp = self.__timestamp or gm_timestamp()

super(OrderedDictWithTime, self).__setitem__(key, (value, timestamp))
self.store[key] = (value, timestamp)

def __delitem__(self, key):
del self.store[key]

def __iter__(self):
return iter(self.store)

def __len__(self):
return len(self.store)

operator_dict = {
EQ: operator.eq,
Expand All @@ -38,7 +50,6 @@ def __setitem__(self, key, value, timestamp=None):
LTE: operator.le,
}


class ConnectionPoolStub(object):
"""Connection pool stub.
Expand Down Expand Up @@ -112,8 +123,8 @@ class ColumnFamilyStub(object):
def __init__(self, pool=None, column_family=None, rows=None, **kwargs):
rows = rows or OrderedDict()
for r in rows.itervalues():
if not isinstance(r, OrderedDictWithTime):
r = OrderedDictWithTime(r)
if not isinstance(r, DictWithTime):
r = DictWithTime(r)
self.rows = rows

if pool is not None:
Expand All @@ -125,7 +136,8 @@ def __len__(self):
def __contains__(self, obj):
return self.rows.__contains__(obj)

def get(self, key, columns=None, column_start=None, column_finish=None, include_timestamp=False, **kwargs):
def get(self, key, columns=None, column_start=None, column_finish=None,
column_reversed=False, column_count=100, include_timestamp=False, **kwargs):
"""Get a value from the column family stub."""

my_columns = self.rows.get(key)
Expand All @@ -136,26 +148,40 @@ def get(self, key, columns=None, column_start=None, column_finish=None, include_
if not my_columns:
raise NotFoundException()

return OrderedDict((k, get_value(v)) for (k, v)
in my_columns.iteritems()
if self._is_column_in_range(k, columns, column_start, column_finish))
items = my_columns.items()
items.sort()

if column_reversed:
items.reverse()

sliced_items = [(k, get_value(v)) for (k, v) in items
if self._is_column_in_range(k, columns,
column_start, column_finish, column_reversed)][:column_count]

return OrderedDict(sliced_items)

def _is_column_in_range(self, k, columns, column_start, column_finish, column_reversed):
lower_bound = column_start if not column_reversed else column_finish
upper_bound = column_finish if not column_reversed else column_start

def _is_column_in_range(self, k, columns, column_start, column_finish):
if columns:
return k in columns
return (not column_start or k >= column_start) and (not column_finish or k <= column_finish)
return (not lower_bound or k >= lower_bound) and (not upper_bound or k <= upper_bound)


def multiget(self, keys, columns=None, column_start=None, column_finish=None, include_timestamp=False, **kwargs):
def multiget(self, keys, columns=None, column_start=None, column_finish=None,
column_reversed=False, column_count=100, include_timestamp=False, **kwargs):
"""Get multiple key values from the column family stub."""

return OrderedDict(
(key, self.get(
key,
columns,
column_start,
column_finish,
include_timestamp,
columns=columns,
column_start=column_start,
column_finish=column_finish,
column_reversed=column_reversed,
column_count=column_count,
include_timestamp=include_timestamp,
)) for key in keys if key in self.rows)

def batch(self, **kwargs):
Expand All @@ -169,7 +195,7 @@ def insert(self, key, columns, timestamp=None, **kwargs):
"""Insert data to the column family stub."""

if key not in self.rows:
self.rows[key] = OrderedDictWithTime([], timestamp=timestamp)
self.rows[key] = DictWithTime([], timestamp=timestamp)

for column in columns:
self.rows[key].__setitem__(column, columns[column], timestamp)
Expand Down
84 changes: 83 additions & 1 deletion tests/contrib/stubs.py
Expand Up @@ -8,6 +8,8 @@
from pycassa.contrib.stubs import ColumnFamilyStub, ConnectionPoolStub, \
SystemManagerStub

from pycassa.util import OrderedDict

pool = cf = None
pool_stub = cf_stub = None

Expand Down Expand Up @@ -66,6 +68,15 @@ def test_insert_get(self):
assert_true(isinstance(ts, (int, long)))
assert_equal(test_cf.get(key), columns)

def test_insert_get_column_start_and_finish_reversed(self):
    """Check that ``column_reversed=True`` returns columns in descending order.

    The same steps run against both ``cf`` and ``cf_stub`` (presumably the
    Cassandra-backed column family and its in-memory stub) so that the two
    are held to the same behaviour.
    """
    key = 'TestColumnFamily.test_insert_get_reversed'
    columns = {'1': 'val1', '2': 'val2'}
    for test_cf in (cf, cf_stub):
        # The row must not exist until this test inserts it.
        assert_raises(NotFoundException, test_cf.get, key)
        ts = test_cf.insert(key, columns)
        assert_true(isinstance(ts, (int, long)))
        row = test_cf.get(key, column_reversed=True)
        # The fetched row used to be discarded without any check, so the
        # test passed regardless of ordering. Compare the ordered items so
        # that both content and direction are verified.
        assert_equal(list(row.items()), [('2', 'val2'), ('1', 'val1')])

def test_insert_get_column_start_and_finish(self):
key = 'TestColumnFamily.test_insert_get_column_start_and_finish'
columns = {'a': 'val1', 'b': 'val2', 'c': 'val3', 'd': 'val4'}
Expand All @@ -75,6 +86,34 @@ def test_insert_get_column_start_and_finish(self):
assert_true(isinstance(ts, (int, long)))
assert_equal(test_cf.get(key, column_start='b', column_finish='c'), {'b': 'val2', 'c': 'val3'})

def test_insert_get_column_start_and_reversed(self):
    """A reversed slice starting at 'b' returns 'b' and the column before it."""
    row_key = 'TestColumnFamily.test_insert_get_column_start_and_finish_reversed'
    inserted = {'a': 'val1', 'b': 'val2', 'c': 'val3', 'd': 'val4'}
    expected = {'b': 'val2', 'a': 'val1'}
    for family in (cf, cf_stub):
        # The row must be absent until it is inserted below.
        assert_raises(NotFoundException, family.get, row_key)
        stamp = family.insert(row_key, inserted)
        assert_true(isinstance(stamp, (int, long)))
        fetched = family.get(row_key, column_start='b', column_reversed=True)
        assert_equal(fetched, expected)

def test_insert_get_column_count(self):
    """``column_count=3`` limits the result to the first three columns."""
    row_key = 'TestColumnFamily.test_insert_get_column_count'
    inserted = {'a': 'val1', 'b': 'val2', 'c': 'val3', 'd': 'val4'}
    expected = {'a': 'val1', 'b': 'val2', 'c': 'val3'}
    for family in (cf, cf_stub):
        # The row must be absent until it is inserted below.
        assert_raises(NotFoundException, family.get, row_key)
        stamp = family.insert(row_key, inserted)
        assert_true(isinstance(stamp, (int, long)))
        fetched = family.get(row_key, column_count=3)
        assert_equal(fetched, expected)

def test_insert_get_default_column_count(self):
    """``get`` without ``column_count`` returns only the first 100 columns.

    A row with 1000 columns is inserted; the expected result is the first
    100 column names in sorted (string) order, each mapped to itself.
    """
    names = sorted(str(i) for i in range(1000))
    keys_and_values = [(name, name) for name in names]
    key = 'TestColumnFamily.test_insert_get_default_column_count'

    for test_cf in (cf, cf_stub):
        # The row must not exist until this test inserts it.
        assert_raises(NotFoundException, test_cf.get, key)
        test_cf.insert(key, dict(keys_and_values))
        assert_equal(test_cf.get(key), dict(keys_and_values[:100]))

def test_insert_multiget(self):
key1 = 'TestColumnFamily.test_insert_multiget1'
Expand Down Expand Up @@ -108,6 +147,50 @@ def test_insert_multiget_column_start_and_finish(self):
assert_equal(rows[key2], {'3': 'val1'})
assert_true(missing_key not in rows)

def test_insert_multiget_column_finish_and_reversed(self):
    """Reversed ``multiget`` with ``column_finish='3'`` keeps columns >= '3'."""
    first_key = 'TestColumnFamily.test_insert_multiget_column_finish_and_reversed1'
    first_columns = {'1': 'val1', '3': 'val2'}
    second_key = 'TestColumnFamily.test_insert_multiget_column_finish_and_reversed2'
    second_columns = {'5': 'val1', '7': 'val2'}
    absent_key = 'key3'
    requested = [first_key, second_key, absent_key]

    for family in (cf, cf_stub):
        family.insert(first_key, first_columns)
        family.insert(second_key, second_columns)
        result = family.multiget(requested, column_finish='3', column_reversed=True)
        # Only the two inserted rows come back; the unknown key is omitted.
        assert_equal(len(result), 2)
        assert_equal(result[first_key], {'3': 'val2'})
        assert_equal(result[second_key], {'5': 'val1', '7': 'val2'})
        assert_true(absent_key not in result)

def test_insert_multiget_column_start_column_count(self):
    """``multiget`` with ``column_start='2'`` and ``column_count=1`` yields one column per row."""
    first_key = 'TestColumnFamily.test_insert_multiget_column_start_column_count'
    first_columns = {'1': 'val1', '2': 'val2'}
    second_key = 'test_insert_multiget1'
    second_columns = {'3': 'val1', '4': 'val2'}
    absent_key = 'key3'
    requested = [first_key, second_key, absent_key]

    for family in (cf, cf_stub):
        family.insert(first_key, first_columns)
        family.insert(second_key, second_columns)
        result = family.multiget(requested, column_count=1, column_start='2')
        # Only the two inserted rows come back; the unknown key is omitted.
        assert_equal(len(result), 2)
        assert_equal(result[first_key], {'2': 'val2'})
        assert_equal(result[second_key], {'3': 'val1'})
        assert_true(absent_key not in result)

def test_insert_multiget_default_column_count(self):
    """``multiget`` without ``column_count`` returns only the first 100 columns per row.

    A single row with 1000 columns is inserted; the expected result is the
    first 100 column names in sorted (string) order, each mapped to itself.
    """
    names = sorted(str(i) for i in range(1000))
    keys_and_values = [(name, name) for name in names]
    key = 'TestColumnFamily.test_insert_multiget_default_column_count'

    for test_cf in (cf, cf_stub):
        test_cf.insert(key, dict(keys_and_values))
        rows = test_cf.multiget([key])
        assert_equal(len(rows), 1)
        assert_equal(rows[key], dict(keys_and_values[:100]))

def insert_insert_get_indexed_slices(self):
columns = {'birthdate': 1L}

Expand All @@ -128,7 +211,6 @@ def insert_insert_get_indexed_slices(self):
count += 1
assert_equal(count, 3)


def test_remove(self):
key = 'TestColumnFamily.test_remove'
for test_cf in (cf, cf_stub):
Expand Down

0 comments on commit 0dbea64

Please sign in to comment.