Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def __init__(self,
metrics_enabled=False,
connection_class=None,
sockopts=None,
cql_version=None,
executor_threads=2,
max_schema_agreement_wait=10):
"""
Expand Down Expand Up @@ -236,6 +237,7 @@ def __init__(self,

self.metrics_enabled = metrics_enabled
self.sockopts = sockopts
self.cql_version = cql_version
self.max_schema_agreement_wait = max_schema_agreement_wait

# let Session objects be GC'ed (and shutdown) when the user no longer
Expand Down Expand Up @@ -316,6 +318,7 @@ def connection_factory(self, address, *args, **kwargs):
kwargs['port'] = self.port
kwargs['compression'] = self.compression
kwargs['sockopts'] = self.sockopts
kwargs['cql_version'] = self.cql_version

return self.connection_class.factory(address, *args, **kwargs)

Expand All @@ -326,6 +329,7 @@ def _make_connection_factory(self, host, *args, **kwargs):
kwargs['port'] = self.port
kwargs['compression'] = self.compression
kwargs['sockopts'] = self.sockopts
kwargs['cql_version'] = self.cql_version

return partial(self.connection_class.factory, host.address, *args, **kwargs)

Expand Down
7 changes: 7 additions & 0 deletions cassandra/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,12 @@ def cql_encode_str(val):
return cql_quote(val)


def cql_encode_bytes(val):
hex_val = ''.join('%02x' % byte for byte in val)
hex_val = '0x' + hex_val
return hex_val


def cql_encode_object(val):
return str(val)

Expand Down Expand Up @@ -815,6 +821,7 @@ def cql_encode_set_collection(val):

cql_encoders = {
float: cql_encode_object,
bytearray: cql_encode_bytes,
str: cql_encode_str,
unicode: cql_encode_unicode,
types.NoneType: cql_encode_none,
Expand Down
25 changes: 25 additions & 0 deletions tests/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,31 @@
os.mkdir(path)


class BaseTestCase(unittest.TestCase):
def _get_cass_and_cql_version(self):
"""
Probe system.local table to determine Cassandra and CQL version.
"""
c = Cluster()
s = c.connect()
s.set_keyspace('system')
row = s.execute('SELECT cql_version, release_version FROM local')[0]

cass_version = self._get_version_as_tuple(row.release_version)
cql_version = self._get_version_as_tuple(row.cql_version)

c.shutdown()

result = {'cass_version': cass_version, 'cql_version': cql_version}
return result

def _get_version_as_tuple(self, version):
version = version.split('.')
version = [int(p) for p in version]
version = tuple(version)
return version


def get_cluster():
return CCM_CLUSTER

Expand Down
108 changes: 97 additions & 11 deletions tests/integration/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,93 @@

from blist import sortedset

from cassandra.cluster import Cluster
from cassandra import InvalidRequest
from cassandra.cluster import Cluster, NoHostAvailable

class TypeTests(unittest.TestCase):
from tests.integration import BaseTestCase


class TypeTests(BaseTestCase):
def setUp(self):
super(TypeTests, self).setUp()
self._versions = self._get_cass_and_cql_version()

def test_blob_type_as_string(self):
c = Cluster()
s = c.connect()

s.execute("""
CREATE KEYSPACE typetests_blob1
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}
""")
s.set_keyspace("typetests_blob1")
s.execute("""
CREATE TABLE mytable (
a ascii,
b blob,
PRIMARY KEY (a)
)
""")

params = [
'key1',
'blobyblob'.encode('hex')
]

query = 'INSERT INTO mytable (a, b) VALUES (%s, %s)'

if self._versions['cql_version'] >= (3, 1, 0):
# Blob values can't be specified using string notation in CQL 3.1.0 and
# above which is used by default in Cassandra 2.0.
msg = r'.*Invalid STRING constant \(.*?\) for b of type blob.*'
self.assertRaisesRegexp(InvalidRequest, msg, s.execute, query, params)
return

s.execute(query, params)
expected_vals = [
'key1',
'blobyblob'
]

results = s.execute("SELECT * FROM mytable")

for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual)

def test_blob_type_as_bytearray(self):
c = Cluster()
s = c.connect()

s.execute("""
CREATE KEYSPACE typetests_blob2
WITH replication = { 'class' : 'SimpleStrategy', 'replication_factor': '1'}
""")
s.set_keyspace("typetests_blob2")
s.execute("""
CREATE TABLE mytable (
a ascii,
b blob,
PRIMARY KEY (a)
)
""")

params = [
'key1',
bytearray('blob1', 'hex')
]

query = 'INSERT INTO mytable (a, b) VALUES (%s, %s);'
s.execute(query, params)

expected_vals = [
'key1',
bytearray('blob1', 'hex')
]

results = s.execute("SELECT * FROM mytable")

for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual)

def test_basic_types(self):
c = Cluster()
Expand All @@ -27,7 +111,6 @@ def test_basic_types(self):
b text,
c ascii,
d bigint,
e blob,
f boolean,
g decimal,
h double,
Expand Down Expand Up @@ -55,7 +138,6 @@ def test_basic_types(self):
"sometext",
"ascii", # ascii
12345678923456789, # bigint
"blob".encode('hex'), # blob
True, # boolean
Decimal('1.234567890123456789'), # decimal
0.000244140625, # double
Expand All @@ -77,7 +159,6 @@ def test_basic_types(self):
"sometext",
"ascii", # ascii
12345678923456789, # bigint
"blob", # blob
True, # boolean
Decimal('1.234567890123456789'), # decimal
0.000244140625, # double
Expand All @@ -95,8 +176,8 @@ def test_basic_types(self):
)

s.execute("""
INSERT INTO mytable (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", params)

results = s.execute("SELECT * FROM mytable")
Expand All @@ -106,11 +187,10 @@ def test_basic_types(self):

# try the same thing with a prepared statement
prepared = s.prepare("""
INSERT INTO mytable (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO mytable (a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""")

params[4] = 'blob'
s.execute(prepared.bind(params))

results = s.execute("SELECT * FROM mytable")
Expand All @@ -120,7 +200,7 @@ def test_basic_types(self):

# query with prepared statement
prepared = s.prepare("""
SELECT a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s FROM mytable
SELECT a, b, c, d, f, g, h, i, j, k, l, m, n, o, p, q, r, s FROM mytable
""")
results = s.execute(prepared.bind(()))

Expand All @@ -133,3 +213,9 @@ def test_basic_types(self):

for expected, actual in zip(expected_vals, results[0]):
self.assertEquals(expected, actual)

def _get_version_as_tuple(self, version):
version = version.split('.')
version = [int(p) for p in version]
version = tuple(version)
return version