Skip to content

Commit

Permalink
adding support for column_count. ColumnFamilyStub.get needs to sort c…
Browse files Browse the repository at this point in the history
…olumns before slicing them. Insert order doesn't matter, so changing the dict object used.
  • Loading branch information
david-huber committed May 2, 2013
1 parent af25e70 commit fd9b9c0
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 18 deletions.
52 changes: 34 additions & 18 deletions pycassa/contrib/stubs.py
Expand Up @@ -8,8 +8,7 @@

import operator

from functools import partial

from collections import MutableMapping
from pycassa import NotFoundException
from pycassa.util import OrderedDict
from pycassa.columnfamily import gm_timestamp
Expand All @@ -18,17 +17,30 @@

__all__ = ['ConnectionPoolStub', 'ColumnFamilyStub', 'SystemManagerStub']

class OrderedDictWithTime(OrderedDict):

class DictWithTime(MutableMapping):
    """Mapping whose values are stored as ``(value, timestamp)`` tuples.

    Items live in a plain ``dict`` (``self.store``), so insertion order is
    not preserved; callers that need ordered columns must sort the items
    themselves.

    :param timestamp: optional keyword-only default timestamp applied to
        every item that is set without an explicit timestamp.  All other
        positional and keyword arguments are interpreted like the arguments
        of ``dict()`` and become the initial items.
    """

    def __init__(self, *args, **kwargs):
        # 'timestamp' is configuration, not an item, so remove it before the
        # remaining arguments are turned into the initial contents.
        self.__timestamp = kwargs.pop('timestamp', None)
        self.store = dict()
        # MutableMapping.update() routes through __setitem__, so the initial
        # items are wrapped as (value, timestamp) as well.
        self.update(dict(*args, **kwargs))

    def __getitem__(self, key):
        """Return the stored ``(value, timestamp)`` tuple for ``key``."""
        return self.store[key]

    def __setitem__(self, key, value, timestamp=None):
        """Store ``value`` under ``key`` together with a timestamp.

        The timestamp is resolved in this order: the explicit ``timestamp``
        argument, the default given at construction time, and finally
        ``gm_timestamp()``.
        """
        if timestamp is None:
            timestamp = self.__timestamp
        # Compare against None instead of relying on truthiness so that a
        # default timestamp of 0 is honoured rather than silently replaced.
        if timestamp is None:
            timestamp = gm_timestamp()

        self.store[key] = (value, timestamp)

    def __delitem__(self, key):
        """Remove ``key``; raises ``KeyError`` if it is not present."""
        del self.store[key]

    def __iter__(self):
        """Iterate over the keys of the underlying dict."""
        return iter(self.store)

    def __len__(self):
        """Return the number of stored items."""
        return len(self.store)

operator_dict = {
EQ: operator.eq,
Expand All @@ -38,7 +50,6 @@ def __setitem__(self, key, value, timestamp=None):
LTE: operator.le,
}


class ConnectionPoolStub(object):
"""Connection pool stub.
Expand Down Expand Up @@ -112,8 +123,8 @@ class ColumnFamilyStub(object):
def __init__(self, pool=None, column_family=None, rows=None, **kwargs):
rows = rows or OrderedDict()
for r in rows.itervalues():
if not isinstance(r, OrderedDictWithTime):
r = OrderedDictWithTime(r)
if not isinstance(r, DictWithTime):
r = DictWithTime(r)
self.rows = rows

if pool is not None:
Expand All @@ -125,7 +136,8 @@ def __len__(self):
def __contains__(self, obj):
return self.rows.__contains__(obj)

def get(self, key, columns=None, column_start=None, column_finish=None, include_timestamp=False, **kwargs):
def get(self, key, columns=None, column_start=None, column_finish=None,
column_count=100, include_timestamp=False, **kwargs):
"""Get a value from the column family stub."""

my_columns = self.rows.get(key)
Expand All @@ -136,26 +148,30 @@ def get(self, key, columns=None, column_start=None, column_finish=None, include_
if not my_columns:
raise NotFoundException()

return OrderedDict((k, get_value(v)) for (k, v)
in my_columns.iteritems()
if self._is_column_in_range(k, columns, column_start, column_finish))
items = my_columns.items()
items.sort()

return OrderedDict([(k, get_value(v)) for (k, v) in items
if self._is_column_in_range(k, columns, column_start, column_finish)][:column_count])

def _is_column_in_range(self, k, columns, column_start, column_finish):
if columns:
return k in columns
return (not column_start or k >= column_start) and (not column_finish or k <= column_finish)


def multiget(self, keys, columns=None, column_start=None, column_finish=None,
             column_count=100, include_timestamp=False, **kwargs):
    """Get multiple key values from the column family stub.

    Each requested key that exists in ``self.rows`` is looked up with
    ``self.get`` using the same column filters; keys that are not present
    are left out of the result instead of raising.

    :param keys: iterable of row keys to fetch.
    :param columns: explicit column names to return, forwarded to ``get``.
    :param column_start: lower bound of the column slice, forwarded to ``get``.
    :param column_finish: upper bound of the column slice, forwarded to ``get``.
    :param column_count: maximum number of columns per row, forwarded to ``get``.
    :param include_timestamp: forwarded to ``get``.
    :param kwargs: accepted but not forwarded to ``get``.
    :returns: an ``OrderedDict`` mapping each found key, in the order given
        by ``keys``, to the result of ``self.get`` for that key.
    """

    # Forward every filter by keyword so that adding a parameter to get()
    # can never silently shift the meaning of a positional argument.
    return OrderedDict(
        (key, self.get(
            key,
            columns=columns,
            column_start=column_start,
            column_finish=column_finish,
            column_count=column_count,
            include_timestamp=include_timestamp,
        )) for key in keys if key in self.rows)

def batch(self, **kwargs):
Expand All @@ -169,7 +185,7 @@ def insert(self, key, columns, timestamp=None, **kwargs):
"""Insert data to the column family stub."""

if key not in self.rows:
self.rows[key] = OrderedDictWithTime([], timestamp=timestamp)
self.rows[key] = DictWithTime([], timestamp=timestamp)

for column in columns:
self.rows[key].__setitem__(column, columns[column], timestamp)
Expand Down
47 changes: 47 additions & 0 deletions tests/contrib/stubs.py
Expand Up @@ -75,6 +75,25 @@ def test_insert_get_column_start_and_finish(self):
assert_true(isinstance(ts, (int, long)))
assert_equal(test_cf.get(key, column_start='b', column_finish='c'), {'b': 'val2', 'c': 'val3'})

def test_insert_get_column_count(self):
    """``get`` with ``column_count=3`` returns columns 'a', 'b' and 'c' of four."""
    row_key = 'TestColumnFamily.test_insert_get_column_count'
    inserted = {'a': 'val1', 'b': 'val2', 'c': 'val3', 'd': 'val4'}
    expected = {'a': 'val1', 'b': 'val2', 'c': 'val3'}

    # The same scenario is run against both ``cf`` and ``cf_stub``
    # (presumably the real column family and its stub -- confirm in the
    # module fixtures) so that both must give the same answer.
    for test_cf in (cf, cf_stub):
        # The row must not exist before the insert.
        assert_raises(NotFoundException, test_cf.get, row_key)
        timestamp = test_cf.insert(row_key, inserted)
        assert_true(isinstance(timestamp, (int, long)))
        assert_equal(test_cf.get(row_key, column_count=3), expected)

def test_insert_get_default_column_count(self):
    """``get`` without ``column_count`` returns only the first 100 of 1000 columns."""
    key = 'TestColumnFamily.test_insert_get_default_column_count'

    # Column names are the decimal strings '0'..'999' in lexicographic
    # (string) order; every column's value equals its name.
    names = sorted(str(i) for i in range(1000))
    pairs = [(name, name) for name in names]
    expected = dict(pairs[:100])

    for test_cf in (cf, cf_stub):
        # The row must not exist before the insert.
        assert_raises(NotFoundException, test_cf.get, key)
        test_cf.insert(key, dict(pairs))
        assert_equal(test_cf.get(key), expected)

def test_insert_multiget(self):
key1 = 'TestColumnFamily.test_insert_multiget1'
Expand Down Expand Up @@ -108,6 +127,34 @@ def test_insert_multiget_column_start_and_finish(self):
assert_equal(rows[key2], {'3': 'val1'})
assert_true(missing_key not in rows)

def test_insert_multiget_column_start_column_count(self):
    """``multiget`` applies ``column_start`` and ``column_count`` to every row."""
    first_key = 'TestColumnFamily.test_insert_multiget_column_start_column_count'
    second_key = 'test_insert_multiget1'
    absent_key = 'key3'
    first_columns = {'1': 'val1', '2': 'val2'}
    second_columns = {'3': 'val1', '4': 'val2'}

    for test_cf in (cf, cf_stub):
        test_cf.insert(first_key, first_columns)
        test_cf.insert(second_key, second_columns)

        fetched = test_cf.multiget(
            [first_key, second_key, absent_key],
            column_count=1,
            column_start='2',
        )

        # Only the two inserted rows come back; each is limited to the
        # single first column at or after '2'.
        assert_equal(len(fetched), 2)
        assert_equal(fetched[first_key], {'2': 'val2'})
        assert_equal(fetched[second_key], {'3': 'val1'})
        assert_true(absent_key not in fetched)

def test_insert_multiget_default_column_count(self):
    """``multiget`` without ``column_count`` returns only the first 100 of 1000 columns."""
    key = 'TestColumnFamily.test_insert_multiget_default_column_count'

    # Column names are the decimal strings '0'..'999' in lexicographic
    # (string) order; every column's value equals its name.
    names = sorted(str(i) for i in range(1000))
    pairs = [(name, name) for name in names]
    expected = dict(pairs[:100])

    for test_cf in (cf, cf_stub):
        test_cf.insert(key, dict(pairs))
        fetched = test_cf.multiget([key])
        assert_equal(len(fetched), 1)
        assert_equal(fetched[key], expected)

def insert_insert_get_indexed_slices(self):
columns = {'birthdate': 1L}

Expand Down

0 comments on commit fd9b9c0

Please sign in to comment.