diff --git a/elasticsearch_async/connection.py b/elasticsearch_async/connection.py index f0dc630..1ae43ba 100644 --- a/elasticsearch_async/connection.py +++ b/elasticsearch_async/connection.py @@ -11,7 +11,7 @@ class AIOHttpConnection(Connection): def __init__(self, host='localhost', port=9200, http_auth=None, use_ssl=False, verify_certs=False, ca_certs=None, client_cert=None, - client_key=None, loop=None, use_dns_cache=True, **kwargs): + client_key=None, loop=None, use_dns_cache=True, headers=None, **kwargs): super().__init__(host=host, port=port, **kwargs) self.loop = asyncio.get_event_loop() if loop is None else loop @@ -23,6 +23,9 @@ def __init__(self, host='localhost', port=9200, http_auth=None, if isinstance(http_auth, (tuple, list)): http_auth = aiohttp.BasicAuth(*http_auth) + headers = headers or {} + headers.setdefault('content-type', 'application/json') + self.session = aiohttp.ClientSession( auth=http_auth, conn_timeout=self.timeout, @@ -30,7 +33,8 @@ def __init__(self, host='localhost', port=9200, http_auth=None, loop=self.loop, verify_ssl=verify_certs, use_dns_cache=use_dns_cache, - ) + ), + headers=headers ) self.base_url = 'http%s://%s:%d%s' % ( @@ -42,7 +46,7 @@ def close(self): return self.session.close() @asyncio.coroutine - def perform_request(self, method, url, params=None, body=None, timeout=None, ignore=()): + def perform_request(self, method, url, params=None, body=None, timeout=None, ignore=(), headers=None): url_path = url if params: url_path = '%s?%s' % (url, urlencode(params or {})) @@ -52,7 +56,7 @@ def perform_request(self, method, url, params=None, body=None, timeout=None, ign response = None try: with aiohttp.Timeout(timeout or self.timeout): - response = yield from self.session.request(method, url, data=body) + response = yield from self.session.request(method, url, data=body, headers=headers) raw_data = yield from response.text() duration = self.loop.time() - start diff --git a/elasticsearch_async/transport.py b/elasticsearch_async/transport.py index 69e060f..0d1b46c 100644 --- a/elasticsearch_async/transport.py +++ b/elasticsearch_async/transport.py @@ -131,13 +131,13 @@ def sniff_hosts(self, initial=False): yield from c.close() @asyncio.coroutine - def main_loop(self, method, url, params, body, ignore=(), timeout=None): + def main_loop(self, method, url, params, body, headers=None, ignore=(), timeout=None): for attempt in range(self.max_retries + 1): connection = self.get_connection() try: status, headers, data = yield from connection.perform_request( - method, url, params, body, ignore=ignore, timeout=timeout) + method, url, params, body, headers=headers, ignore=ignore, timeout=timeout) except TransportError as e: if method == 'HEAD' and e.status_code == 404: return False @@ -169,7 +169,7 @@ def main_loop(self, method, url, params, body, ignore=(), timeout=None): data = self.deserializer.loads(data, headers.get('content-type')) return data - def perform_request(self, method, url, params=None, body=None): + def perform_request(self, method, url, headers=None, params=None, body=None): if body is not None: body = self.serializer.dumps(body) @@ -202,6 +202,7 @@ def perform_request(self, method, url, params=None, body=None): ignore = (ignore, ) return ensure_future(self.main_loop(method, url, params, body, + headers=headers, ignore=ignore, timeout=timeout), loop=self.loop)