Skip to content

Commit

Permalink
Add socket_factory and SSL socket support
Browse files Browse the repository at this point in the history
  • Loading branch information
iamaleksey committed Nov 5, 2012
1 parent e03e591 commit c3e4c07
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 10 deletions.
31 changes: 28 additions & 3 deletions pycassa/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from thrift.transport import TTransport
from thrift.transport import TSocket
from thrift.transport import TSocket, TSSLSocket
from thrift.protocol import TBinaryProtocol

from pycassa.cassandra import Cassandra
Expand All @@ -8,11 +8,19 @@
DEFAULT_SERVER = 'localhost:9160'
DEFAULT_PORT = 9160


def default_socket_factory(host, port):
"""
Returns a normal :class:`TSocket` instance.
"""
return TSocket.TSocket(host, port)


class Connection(Cassandra.Client):
"""Encapsulation of a client session."""

def __init__(self, keyspace, server, framed_transport=True, timeout=None,
credentials=None, api_version=None):
credentials=None, socket_factory=default_socket_factory):
self.keyspace = None
self.server = server
server = server.split(':')
Expand All @@ -21,7 +29,7 @@ def __init__(self, keyspace, server, framed_transport=True, timeout=None,
else:
port = server[1]
host = server[0]
socket = TSocket.TSocket(host, int(port))
socket = socket_factory(host, int(port))
if timeout is not None:
socket.setTimeout(timeout * 1000.0)
if framed_transport:
Expand All @@ -45,3 +53,20 @@ def set_keyspace(self, keyspace):

def close(self):
self.transport.close()


def make_ssl_socket_factory(ca_certs, validate=True):
"""
A convenience function for creating an SSL socket factory.
`ca_certs` should contain the path to the certificate file,
`validate` determines whether or not SSL certificate validation will be performed.
"""

def ssl_socket_factory(host, port):
"""
Returns a :class:`TSSLSocket` instance.
"""
return TSSLSocket.TSSLSocket(host, port, ca_certs=ca_certs, validate=validate)

return ssl_socket_factory
19 changes: 15 additions & 4 deletions pycassa/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from thrift import Thrift
from thrift.transport.TTransport import TTransportException
from connection import Connection
from connection import Connection, default_socket_factory
from logging.pool_logger import PoolLogger
from util import as_interface
from cassandra.ttypes import TimedOutException, UnavailableException
Expand Down Expand Up @@ -229,7 +229,7 @@ def _set_max_overflow(self, max_overflow):
up to `pool_timeout` seconds for a connection to be returned to the
pool before giving up. Note that this setting is only meaningful when you
are accessing the pool concurrently, such as with multiple threads.
This may be set to 0 to fail immediately or -1 to wait forever.
This may be set to 0 to fail immediately or -1 to wait forever.
The default value is 30. """

recycle = 10000
Expand All @@ -242,21 +242,30 @@ def _set_max_overflow(self, max_overflow):
or :exc:`~.UnavailableException`, which tend to indicate single or
multiple node failure, the operation will be retried on different nodes
up to `max_retries` times before an :exc:`~.MaximumRetryException` is raised.
Setting this to 0 disables retries and setting to -1 allows unlimited retries.
Setting this to 0 disables retries and setting to -1 allows unlimited retries.
The default value is 5. """

logging_name = None
""" By default, each pool identifies itself in the logs using ``id(self)``.
If multiple pools are in use for different purposes, setting `logging_name` will
help individual pools to be identified in the logs. """

socket_factory = default_socket_factory
""" A function that creates the socket for each connection in the pool.
This function should take two arguments: `host`, the host the connection is
being made to, and `port`, the destination port.
By default, this is function is :func:`~connection.default_socket_factory`.
"""

def __init__(self, keyspace,
server_list=['localhost:9160'],
credentials=None,
timeout=0.5,
use_threadlocal=True,
pool_size=5,
prefill=True,
socket_factory=default_socket_factory,
**kwargs):
"""
All connections in the pool will be opened to `keyspace`.
Expand Down Expand Up @@ -315,6 +324,7 @@ def __init__(self, keyspace,
self.keyspace = keyspace
self.credentials = credentials
self.timeout = timeout
self.socket_factory = socket_factory
if use_threadlocal:
self._tlocal = threading.local()

Expand Down Expand Up @@ -431,7 +441,8 @@ def _get_new_wrapper(self, server):
self.keyspace, server,
framed_transport=True,
timeout=self.timeout,
credentials=self.credentials)
credentials=self.credentials,
socket_factory=self.socket_factory)

def _replace_wrapper(self):
"""Try to replace the connection."""
Expand Down
6 changes: 3 additions & 3 deletions pycassa/system_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time

from pycassa.connection import Connection
from pycassa.connection import Connection, default_socket_factory
from pycassa.cassandra.ttypes import IndexType, KsDef, CfDef, ColumnDef,\
SchemaDisagreementException
import pycassa.marshal as marshal
Expand Down Expand Up @@ -66,8 +66,8 @@ class SystemManager(object):
"""

def __init__(self, server='localhost:9160', credentials=None, framed_transport=True,
timeout=_DEFAULT_TIMEOUT):
self._conn = Connection(None, server, framed_transport, timeout, credentials)
timeout=_DEFAULT_TIMEOUT, socket_factory=default_socket_factory):
self._conn = Connection(None, server, framed_transport, timeout, credentials, socket_factory)

def close(self):
""" Closes the underlying connection """
Expand Down

0 comments on commit c3e4c07

Please sign in to comment.