diff --git a/elasticsearch/connection/base.py b/elasticsearch/connection/base.py index 7841fd02e..bda71c58e 100644 --- a/elasticsearch/connection/base.py +++ b/elasticsearch/connection/base.py @@ -17,13 +17,17 @@ class Connection(object): """ transport_schema = 'http' - def __init__(self, host='localhost', port=9200, url_prefix='', timeout=10, **kwargs): + def __init__(self, host='localhost', port=9200, url_prefix='', timeout=10, transport_schema='http', auth=None, **kwargs): """ :arg host: hostname of the node (default: localhost) :arg port: port to use (default: 9200) :arg url_prefix: optional url prefix for elasticsearch :arg timeout: default timeout in seconds (default: 10) """ + self.transport_schema = transport_schema + + self.auth = auth + self.host = '%s://%s:%s' % (self.transport_schema, host, port) if url_prefix: url_prefix = '/' + url_prefix.strip('/') diff --git a/elasticsearch/connection/http.py b/elasticsearch/connection/http.py index c55be9f23..bf25cf2ec 100644 --- a/elasticsearch/connection/http.py +++ b/elasticsearch/connection/http.py @@ -13,7 +13,10 @@ class RequestsHttpConnection(Connection): """ Connection using the `requests` library. """ def __init__(self, host='localhost', port=9200, **kwargs): super(RequestsHttpConnection, self).__init__(host=host, port=port, **kwargs) + self.session = requests.session() + if self.auth is not None: + self.session.auth = tuple(self.auth.split(':', 1)) def perform_request(self, method, url, params=None, body=None, timeout=None): url = self.host + self.url_prefix + url @@ -22,7 +25,9 @@ def perform_request(self, method, url, params=None, body=None, timeout=None): request = requests.Request(method, url, params=params or {}, data=body).prepare() start = time.time() try: - response = self.session.send(request, timeout=timeout or self.timeout) + response = self.session.request(method=request.method, url=request.url, headers=request.headers, + data=request.body, hooks=request.hooks, timeout=timeout or self.timeout) + duration = time.time() - start raw_data = response.text except (requests.ConnectionError, requests.Timeout) as e: @@ -44,7 +49,15 @@ class Urllib3HttpConnection(Connection): """ def __init__(self, host='localhost', port=9200, **kwargs): super(Urllib3HttpConnection, self).__init__(host=host, port=port, **kwargs) - self.pool = urllib3.HTTPConnectionPool(host, port=port, timeout=kwargs.get('timeout', None)) + + auth_headers = dict() + if self.auth is not None: + auth_headers.update(urllib3.make_headers(basic_auth=self.auth)) + + if self.transport_schema == 'https': + self.pool = urllib3.HTTPSConnectionPool(host, port=port, timeout=kwargs.get('timeout', None), headers=auth_headers) + else: + self.pool = urllib3.HTTPConnectionPool(host, port=port, timeout=kwargs.get('timeout', None), headers=auth_headers) def perform_request(self, method, url, params=None, body=None, timeout=None): url = self.url_prefix + url