Skip to content
This repository has been archived by the owner on Mar 20, 2023. It is now read-only.

Add default content-type headers to support ES6.0 #29

Merged
merged 1 commit on Nov 27, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
12 changes: 8 additions & 4 deletions elasticsearch_async/connection.py
Expand Up @@ -11,7 +11,7 @@
class AIOHttpConnection(Connection):
def __init__(self, host='localhost', port=9200, http_auth=None,
use_ssl=False, verify_certs=False, ca_certs=None, client_cert=None,
client_key=None, loop=None, use_dns_cache=True, **kwargs):
client_key=None, loop=None, use_dns_cache=True, headers=None, **kwargs):
super().__init__(host=host, port=port, **kwargs)

self.loop = asyncio.get_event_loop() if loop is None else loop
Expand All @@ -23,14 +23,18 @@ def __init__(self, host='localhost', port=9200, http_auth=None,
if isinstance(http_auth, (tuple, list)):
http_auth = aiohttp.BasicAuth(*http_auth)

headers = headers or {}
headers.setdefault('content-type', 'application/json')

self.session = aiohttp.ClientSession(
auth=http_auth,
conn_timeout=self.timeout,
connector=aiohttp.TCPConnector(
loop=self.loop,
verify_ssl=verify_certs,
use_dns_cache=use_dns_cache,
)
),
headers=headers
)

self.base_url = 'http%s://%s:%d%s' % (
Expand All @@ -42,7 +46,7 @@ def close(self):
return self.session.close()

@asyncio.coroutine
def perform_request(self, method, url, params=None, body=None, timeout=None, ignore=()):
def perform_request(self, method, url, params=None, body=None, timeout=None, ignore=(), headers=None):
url_path = url
if params:
url_path = '%s?%s' % (url, urlencode(params or {}))
Expand All @@ -52,7 +56,7 @@ def perform_request(self, method, url, params=None, body=None, timeout=None, ign
response = None
try:
with aiohttp.Timeout(timeout or self.timeout):
response = yield from self.session.request(method, url, data=body)
response = yield from self.session.request(method, url, data=body, headers=headers)
raw_data = yield from response.text()
duration = self.loop.time() - start

Expand Down
7 changes: 4 additions & 3 deletions elasticsearch_async/transport.py
Expand Up @@ -131,13 +131,13 @@ def sniff_hosts(self, initial=False):
yield from c.close()

@asyncio.coroutine
def main_loop(self, method, url, params, body, ignore=(), timeout=None):
def main_loop(self, method, url, params, body, headers=None, ignore=(), timeout=None):
for attempt in range(self.max_retries + 1):
connection = self.get_connection()

try:
status, headers, data = yield from connection.perform_request(
method, url, params, body, ignore=ignore, timeout=timeout)
method, url, params, body, headers=headers, ignore=ignore, timeout=timeout)
except TransportError as e:
if method == 'HEAD' and e.status_code == 404:
return False
Expand Down Expand Up @@ -169,7 +169,7 @@ def main_loop(self, method, url, params, body, ignore=(), timeout=None):
data = self.deserializer.loads(data, headers.get('content-type'))
return data

def perform_request(self, method, url, params=None, body=None):
def perform_request(self, method, url, headers=None, params=None, body=None):
if body is not None:
body = self.serializer.dumps(body)

Expand Down Expand Up @@ -202,6 +202,7 @@ def perform_request(self, method, url, params=None, body=None):
ignore = (ignore, )

return ensure_future(self.main_loop(method, url, params, body,
headers=headers,
ignore=ignore,
timeout=timeout),
loop=self.loop)
Expand Down