Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions elasticsearch/connection/thrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .esthrift import Rest
from .esthrift.ttypes import Method, RestRequest

from thrift.transport import TTransport, TSocket
from thrift.transport import TTransport, TSocket, TSSLSocket
from thrift.protocol import TBinaryProtocol
from thrift.Thrift import TException
THRIFT_AVAILABLE = True
Expand All @@ -23,7 +23,7 @@ class ThriftConnection(PoolingConnection):
"""
transport_schema = 'thrift'

def __init__(self, host='localhost', port=9500, framed_transport=False, **kwargs):
def __init__(self, host='localhost', port=9500, framed_transport=False, thrift_ssl=False, thrift_socket_extra_args=None, thrift_headers=None, **kwargs):
"""
:arg framed_transport: use `TTransport.TFramedTransport` instead of
`TTransport.TBufferedTransport`
Expand All @@ -33,10 +33,15 @@ def __init__(self, host='localhost', port=9500, framed_transport=False, **kwargs

super(ThriftConnection, self).__init__(host=host, port=port, **kwargs)
self._framed_transport = framed_transport
self._tsocket_class = TSocket.TSocket
if thrift_ssl:
self._tsocket_class = TSSLSocket.TSSLSocket
self._tsocket_args = (host, port)
self._tsocket_kwargs = thrift_socket_extra_args or dict()
self._thrift_headers = thrift_headers or dict()

def _make_connection(self):
socket = TSocket.TSocket(*self._tsocket_args)
socket = self._tsocket_class(*self._tsocket_args, **self._tsocket_kwargs)
socket.setTimeout(self.timeout * 1000.0)
if self._framed_transport:
transport = TTransport.TFramedTransport(socket)
Expand All @@ -50,7 +55,7 @@ def _make_connection(self):

def perform_request(self, method, url, params=None, body=None, timeout=None):
request = RestRequest(method=Method._NAMES_TO_VALUES[method.upper()], uri=url,
parameters=params, body=body)
parameters=params, headers=self._thrift_headers, body=body)

start = time.time()
tclient = self._get_connection()
Expand Down