diff --git a/elasticsearch/connection/thrift.py b/elasticsearch/connection/thrift.py index d7c9b9a37..1260626bf 100644 --- a/elasticsearch/connection/thrift.py +++ b/elasticsearch/connection/thrift.py @@ -5,7 +5,7 @@ from .esthrift import Rest from .esthrift.ttypes import Method, RestRequest - from thrift.transport import TTransport, TSocket + from thrift.transport import TTransport, TSocket, TSSLSocket from thrift.protocol import TBinaryProtocol from thrift.Thrift import TException THRIFT_AVAILABLE = True @@ -23,7 +23,7 @@ class ThriftConnection(PoolingConnection): """ transport_schema = 'thrift' - def __init__(self, host='localhost', port=9500, framed_transport=False, **kwargs): + def __init__(self, host='localhost', port=9500, framed_transport=False, thrift_ssl=False, thrift_socket_extra_args=None, thrift_headers=None, **kwargs): """ :arg framed_transport: use `TTransport.TFramedTransport` instead of `TTransport.TBufferedTransport` @@ -33,10 +33,15 @@ def __init__(self, host='localhost', port=9500, framed_transport=False, **kwargs super(ThriftConnection, self).__init__(host=host, port=port, **kwargs) self._framed_transport = framed_transport + self._tsocket_class = TSocket.TSocket + if thrift_ssl: + self._tsocket_class = TSSLSocket.TSSLSocket self._tsocket_args = (host, port) + self._tsocket_kwargs = thrift_socket_extra_args or dict() + self._thrift_headers = thrift_headers or dict() def _make_connection(self): - socket = TSocket.TSocket(*self._tsocket_args) + socket = self._tsocket_class(*self._tsocket_args, **self._tsocket_kwargs) socket.setTimeout(self.timeout * 1000.0) if self._framed_transport: transport = TTransport.TFramedTransport(socket) @@ -50,7 +55,7 @@ def _make_connection(self): def perform_request(self, method, url, params=None, body=None, timeout=None): request = RestRequest(method=Method._NAMES_TO_VALUES[method.upper()], uri=url, - parameters=params, body=body) + parameters=params, headers=self._thrift_headers, body=body) start = time.time() tclient = self._get_connection()