diff --git a/tests/test_autopacking.py b/tests/test_autopacking.py index d7431aff..1429db3a 100644 --- a/tests/test_autopacking.py +++ b/tests/test_autopacking.py @@ -10,7 +10,7 @@ from nose.tools import (assert_raises, assert_equal, assert_almost_equal, assert_true) -from datetime import datetime +from datetime import date, datetime from uuid import uuid1 import uuid import unittest @@ -1036,3 +1036,156 @@ def test_constructor_lambdas(self): self.cf.key_validation_class = TestCustomTypes.IntString2() self.cf.insert(1234, {'col': 'val'}) assert_equal(self.cf.get(1234), {'col': 'val'}) + +class TestCustomComposite(unittest.TestCase): + """ + Test CompositeTypes with custom inner types. + """ + + # Some contrived scenarios + class IntDateType(types.CassandraType): + """ + Represent a date as an integer. E.g.: March 05, 2012 = 20120305 + """ + @staticmethod + def pack(v, *args, **kwargs): + assert type(v) in (datetime, date), "Invalid arg" + str_date = v.strftime("%Y%m%d") + return marshal.encode_int(int(str_date)) + + @staticmethod + def unpack(v, *args, **kwargs): + int_date = marshal.decode_int(v) + return date(*time.strptime(str(int_date), "%Y%m%d")[0:3]) + + class IntString(types.CassandraType): + + @staticmethod + def pack(intval): + return str(intval) + + @staticmethod + def unpack(strval): + return int(strval) + + class IntString2(types.CassandraType): + + def __init__(self, *args, **kwargs): + self.pack = lambda val: str(val) + self.unpack = lambda val: int(val) + + @classmethod + def setup_class(cls): + sys = SystemManager() + have_composites = sys._conn.version != CASSANDRA_07 + if not have_composites: + raise SkipTest("Cassandra < 0.8 does not composite types") + + sys.create_column_family( + TEST_KS, + 'CustomComposite1', + comparator_type=CompositeType( + IntegerType(), + UTF8Type())) + + @classmethod + def teardown_class(cls): + sys = SystemManager() + sys.drop_column_family(TEST_KS, 'CustomComposite1') + + def test_static_composite_basic(self): + cf = ColumnFamily(pool, 'CustomComposite1') + colname = (20120305, '12345') + cf.insert('key', {colname: 'val1'}) + assert_equal(cf.get('key'), {colname: 'val1'}) + + def test_insert_with_custom_composite(self): + cf_std = ColumnFamily(pool, 'CustomComposite1') + cf_cust = ColumnFamily(pool, 'CustomComposite1') + cf_cust.column_name_class = CompositeType( + TestCustomComposite.IntDateType(), + TestCustomComposite.IntString()) + + std_col = (20120311, '321') + cust_col = (date(2012, 3, 11), 321) + cf_cust.insert('cust_insert_key_1', {cust_col: 'cust_insert_val_1'}) + assert_equal(cf_std.get('cust_insert_key_1'), + {std_col: 'cust_insert_val_1'}) + + def test_retrieve_with_custom_composite(self): + cf_std = ColumnFamily(pool, 'CustomComposite1') + cf_cust = ColumnFamily(pool, 'CustomComposite1') + cf_cust.column_name_class = CompositeType( + TestCustomComposite.IntDateType(), + TestCustomComposite.IntString()) + + std_col = (20120312, '321') + cust_col = (date(2012, 3, 12), 321) + cf_std.insert('cust_insert_key_2', {std_col: 'cust_insert_val_2'}) + assert_equal(cf_cust.get('cust_insert_key_2'), + {cust_col: 'cust_insert_val_2'}) + + def test_composite_slicing(self): + cf_std = ColumnFamily(pool, 'CustomComposite1') + cf_cust = ColumnFamily(pool, 'CustomComposite1') + cf_cust.column_name_class = CompositeType( + TestCustomComposite.IntDateType(), + TestCustomComposite.IntString2()) + + col0 = (20120101, '123') + col1 = (20120102, '123') + col2 = (20120102, '456') + col3 = (20120102, '789') + col4 = (20120103, '123') + + dt0 = date(2012, 1, 1) + dt1 = date(2012, 1, 2) + dt2 = date(2012, 1, 3) + + col0_cust = (dt0, 123) + col1_cust = (dt1, 123) + col2_cust = (dt1, 456) + col3_cust = (dt1, 789) + col4_cust = (dt2, 123) + + cf_std.insert('key2', {col0: '', col1: '', col2: '', col3: '', col4: ''}) + + result = cf_cust.get('key2', + column_start=((dt1, True),), + column_finish=((dt1, True),)) + assert_equal(result, {col1_cust: '', col2_cust: '', col3_cust: ''}) + + result = cf_cust.get('key2', + column_start=(dt1,), + column_finish=((dt2, False), )) + assert_equal(result, {col1_cust: '', col2_cust: '', col3_cust: ''}) + + result = cf_cust.get('key2', + column_start=((dt1, True),), + column_finish=((dt2, False), )) + assert_equal(result, {col1_cust: '', col2_cust: '', col3_cust: ''}) + + result = cf_cust.get('key2', + column_start=(dt1, ), + column_finish=((dt2, False), )) + assert_equal(result, {col1_cust: '', col2_cust: '', col3_cust: ''}) + + result = cf_cust.get('key2', + column_start=((dt0, False), ), + column_finish=((dt2, False), )) + assert_equal(result, {col1_cust: '', col2_cust: '', col3_cust: ''}) + + result = cf_cust.get('key2', + column_start=(dt1, 123), + column_finish=(dt1, 789)) + assert_equal(result, {col1_cust: '', col2_cust: '', col3_cust: ''}) + + result = cf_cust.get('key2', + column_start=(dt1, 123), + column_finish=(dt1, (789, True))) + assert_equal(result, {col1_cust: '', col2_cust: '', col3_cust: ''}) + + result = cf_cust.get('key2', + column_start=(dt1, (123, True)), + column_finish=((dt2, False), )) + assert_equal(result, {col1_cust: '', col2_cust: '', col3_cust: ''})