Skip to content

Commit

Permalink
Verify we're connected to Elasticsearch before requests
Browse files Browse the repository at this point in the history
  • Loading branch information
sethmlarson committed Jun 30, 2021
1 parent a2bacb2 commit 801a839
Show file tree
Hide file tree
Showing 9 changed files with 929 additions and 6 deletions.
106 changes: 105 additions & 1 deletion elasticsearch/_async/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,20 @@
import asyncio
import logging
import sys
import warnings
from itertools import chain

from ..exceptions import (
AuthenticationException,
AuthorizationException,
ConnectionError,
ConnectionTimeout,
ElasticsearchWarning,
NotElasticsearchError,
SerializationError,
TransportError,
)
from ..transport import Transport
from ..transport import Transport, _verify_elasticsearch
from .compat import get_running_loop
from .http_aiohttp import AIOHttpConnection

Expand Down Expand Up @@ -113,6 +118,10 @@ async def _async_init(self):
self.loop = get_running_loop()
self.kwargs["loop"] = self.loop

# Set our 'verified_once' implementation to one that
# works with 'asyncio' instead of 'threading'
self._verified_once = Once()

# Now that we have a loop we can create all our HTTP connections...
self.set_connections(self.hosts)
self.seed_connections = list(self.connection_pool.connections[:])
Expand Down Expand Up @@ -327,6 +336,19 @@ async def perform_request(self, method, url, headers=None, params=None, body=Non
method, headers, params, body
)

# Before we make the actual API call we verify the Elasticsearch instance.
if self._verified_elasticsearch is None:
await self._verified_once.call(
self._do_verify_elasticsearch, headers=headers, timeout=timeout
)

# If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch.
if self._verified_elasticsearch is False:
raise NotElasticsearchError(
"The client noticed that the server is not Elasticsearch "
"and we do not support this unknown product"
)

for attempt in range(self.max_retries + 1):
connection = self.get_connection()

Expand Down Expand Up @@ -398,3 +420,85 @@ async def close(self):

for connection in self.connection_pool.connections:
await connection.close()

async def _do_verify_elasticsearch(self, headers, timeout):
"""Verifies that we're connected to an Elasticsearch cluster.
This is done at least once before the first actual API call
and makes a single request to the 'GET /' API endpoint and
check version along with other details of the response.
If we're unable to verify we're talking to Elasticsearch
but we're also unable to rule it out due to a permission
error we instead emit an 'ElasticsearchWarning'.
"""
# Product check has already been done, no need to do again.
if self._verified_elasticsearch:
return

headers = {header.lower(): value for header, value in (headers or {}).items()}
# We know we definitely want JSON so request it via 'accept'
headers.setdefault("accept", "application/json")

info_headers = {}
info_response = {}
error = None

for conn in chain(self.connection_pool.connections, self.seed_connections):
try:
_, info_headers, info_response = await conn.perform_request(
"GET", "/", headers=headers, timeout=timeout
)

# Lowercase all the header names for consistency in accessing them.
info_headers = {
header.lower(): value for header, value in info_headers.items()
}

info_response = self.deserializer.loads(
info_response, mimetype="application/json"
)
break

# Previous versions of 7.x Elasticsearch required a specific
# permission so if we receive HTTP 401/403 we should warn
# instead of erroring out.
except (AuthenticationException, AuthorizationException):
warnings.warn(
(
"The client is unable to verify that the server is "
"Elasticsearch due security privileges on the server side"
),
ElasticsearchWarning,
stacklevel=4,
)
self._verified_elasticsearch = True
return

# This connection didn't work, we'll try another.
except (ConnectionError, SerializationError) as err:
if error is None:
error = err

# If we received a connection error and weren't successful
# anywhere then we reraise the more appropriate error.
if error and not info_response:
raise error

# Check the information we got back from the index request.
self._verified_elasticsearch = _verify_elasticsearch(
info_headers, info_response
)


class Once:
    """Runs an async function until it has completed successfully once.

    Callers are serialised through an 'asyncio.Lock', so concurrent
    callers wait for the in-flight call instead of starting their own.
    After the first call that returns without raising, every later
    'call()' is a no-op.

    If the function raises, the exception propagates to the caller and
    the next 'call()' tries again. Latching before the call would turn
    a single transient failure into "never run again".
    """

    def __init__(self):
        self._lock = asyncio.Lock()
        self._called = False

    async def call(self, func, *args, **kwargs):
        """Await 'func(*args, **kwargs)' unless a previous call succeeded."""
        async with self._lock:
            if self._called:
                return
            await func(*args, **kwargs)
            # Only latch once 'func' finished without raising.
            self._called = True
11 changes: 11 additions & 0 deletions elasticsearch/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ def to_bytes(x, encoding="ascii"):
except (ImportError, AttributeError):
pass

try:
    from threading import Lock
except ImportError:  # Python <3.7 isn't guaranteed to have threading support.

    class Lock:
        """No-op stand-in used when 'threading' cannot be imported.

        Only the context-manager protocol is provided; entering and
        leaving the block does nothing.
        """

        def __enter__(self):
            return None

        def __exit__(self, *_):
            return None


__all__ = [
"string_types",
Expand Down
6 changes: 6 additions & 0 deletions elasticsearch/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ class SerializationError(ElasticsearchException):
"""


class NotElasticsearchError(ElasticsearchException):
    """Raised when the client detects that the server it is
    connected to is not an Elasticsearch cluster.
    """


class TransportError(ElasticsearchException):
"""
Exception raised when ES returns a non-OK (>=400) HTTP status code. Or when
Expand Down
1 change: 1 addition & 0 deletions elasticsearch/exceptions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ from typing import Any, Dict, Union
class ImproperlyConfigured(Exception): ...
class ElasticsearchException(Exception): ...
class SerializationError(ElasticsearchException): ...
# Raised when the client detects it's not connected to an Elasticsearch cluster.
class NotElasticsearchError(ElasticsearchException): ...

class TransportError(ElasticsearchException):
@property
Expand Down
166 changes: 166 additions & 0 deletions elasticsearch/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,23 @@
# specific language governing permissions and limitations
# under the License.

import re
import time
import warnings
from itertools import chain
from platform import python_version

from ._version import __versionstr__
from .compat import Lock
from .connection import Urllib3HttpConnection
from .connection_pool import ConnectionPool, DummyConnectionPool, EmptyConnectionPool
from .exceptions import (
AuthenticationException,
AuthorizationException,
ConnectionError,
ConnectionTimeout,
ElasticsearchWarning,
NotElasticsearchError,
SerializationError,
TransportError,
)
Expand Down Expand Up @@ -198,6 +205,23 @@ def __init__(
if http_client_meta:
self._client_meta += (http_client_meta,)

# Tri-state flag that describes the state of the verification
# of whether we're connected to an Elasticsearch cluster or not.
# The three states are:
# - 'None': Means we've either not started the verification process
# or that the verification is in progress. '_verified_once' ensures
# that multiple requests don't kick off multiple verification processes.
# - 'True': Means we've verified that we're talking to Elasticsearch or
# that we can't rule out Elasticsearch due to auth issues. A warning
# will be raised if we receive 401/403.
# - 'False': Means we've discovered we're not talking to Elasticsearch,
# should raise an error in this case for every request.
self._verified_elasticsearch = None

# Ensures that the ES verification request only fires once and that
# all requests block until this request returns back.
self._verified_once = Once()

def add_connection(self, host):
"""
Create a new :class:`~elasticsearch.Connection` instance and add it to the pool.
Expand Down Expand Up @@ -380,6 +404,19 @@ def perform_request(self, method, url, headers=None, params=None, body=None):
method, headers, params, body
)

# Before we make the actual API call we verify the Elasticsearch instance.
if self._verified_elasticsearch is None:
self._verified_once.call(
self._do_verify_elasticsearch, headers=headers, timeout=timeout
)

# If '_verified_elasticsearch' is False we know we're not connected to Elasticsearch.
if self._verified_elasticsearch is False:
raise NotElasticsearchError(
"The client noticed that the server is not Elasticsearch "
"and we do not support this unknown product"
)

for attempt in range(self.max_retries + 1):
connection = self.get_connection()

Expand Down Expand Up @@ -488,3 +525,132 @@ def _resolve_request_args(self, method, headers, params, body):
)

return method, headers, params, body, ignore, timeout

def _do_verify_elasticsearch(self, headers, timeout):
"""Verifies that we're connected to an Elasticsearch cluster.
This is done at least once before the first actual API call
and makes a single request to the 'GET /' API endpoint to
check the version along with other details of the response.
If we're unable to verify we're talking to Elasticsearch
but we're also unable to rule it out due to a permission
error we instead emit an 'ElasticsearchWarning'.
"""
# Product check has already been done, no need to do again.
if self._verified_elasticsearch is not None:
return

headers = {header.lower(): value for header, value in (headers or {}).items()}
# We know we definitely want JSON so request it via 'accept'
headers.setdefault("accept", "application/json")

info_headers = {}
info_response = {}
error = None

for conn in chain(self.connection_pool.connections, self.seed_connections):
try:
_, info_headers, info_response = conn.perform_request(
"GET", "/", headers=headers, timeout=timeout
)

# Lowercase all the header names for consistency in accessing them.
info_headers = {
header.lower(): value for header, value in info_headers.items()
}

info_response = self.deserializer.loads(
info_response, mimetype="application/json"
)
break

# Previous versions of 7.x Elasticsearch required a specific
# permission so if we receive HTTP 401/403 we should warn
# instead of erroring out.
except (AuthenticationException, AuthorizationException):
warnings.warn(
(
"The client is unable to verify that the server is "
"Elasticsearch due security privileges on the server side"
),
ElasticsearchWarning,
stacklevel=5,
)
self._verified_elasticsearch = True
return

# This connection didn't work, we'll try another.
except (ConnectionError, SerializationError) as err:
if error is None:
error = err

# If we received a connection error and weren't successful
# anywhere then we reraise the more appropriate error.
if error and not info_response:
raise error

# Check the information we got back from the index request.
self._verified_elasticsearch = _verify_elasticsearch(
info_headers, info_response
)


def _verify_elasticsearch(headers, response):
"""Verifies that the server we're talking to is Elasticsearch.
Does this by checking HTTP headers and the deserialized
response to the 'info' API. Returns 'True' if we're verified
against Elasticsearch, 'False' otherwise.
"""
try:
version = response.get("version", {})
version_number = tuple(
int(x) if x is not None else 999
for x in re.search(
r"^([0-9]+)\.([0-9]+)(?:\.([0-9]+))?", version["number"]
).groups()
)
except (KeyError, TypeError, ValueError, AttributeError):
# No valid 'version.number' field, effectively 0.0.0
version = {}
version_number = (0, 0, 0)

# Check all of the fields and headers for missing/valid values.
try:
bad_tagline = response.get("tagline", None) != "You Know, for Search"
bad_build_flavor = version.get("build_flavor", None) != "default"
bad_product_header = headers.get("x-elastic-product", None) != "Elasticsearch"
except (AttributeError, TypeError):
bad_tagline = True
bad_build_flavor = True
bad_product_header = True

if (
# No version or version less than 6.x
version_number < (6, 0, 0)
# 6.x and there's a bad 'tagline'
or ((6, 0, 0) <= version_number < (7, 0, 0) and bad_tagline)
# 7.0-7.13 and there's a bad 'tagline' or 'build_flavor'
or (
(7, 0, 0) <= version_number < (7, 14, 0)
and (bad_tagline or bad_build_flavor)
)
# 7.14+ and there's a bad 'X-Elastic-Product' HTTP header
or ((7, 14, 0) <= version_number and bad_product_header)
):
return False

return True


class Once:
    """Runs a function until it has completed successfully once.

    Callers are serialised through a lock, so concurrent callers wait
    for the in-flight call instead of starting their own. After the
    first call that returns without raising, every later 'call()' is
    a no-op.

    If the function raises, the exception propagates to the caller and
    the next 'call()' tries again. Latching before the call would turn
    a single transient failure into "never run again".
    """

    def __init__(self):
        self._lock = Lock()
        self._called = False

    def call(self, func, *args, **kwargs):
        """Invoke 'func(*args, **kwargs)' unless a previous call succeeded."""
        with self._lock:
            if self._called:
                return
            func(*args, **kwargs)
            # Only latch once 'func' finished without raising.
            self._called = True
8 changes: 7 additions & 1 deletion test_elasticsearch/test_async/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from mock import patch
from multidict import CIMultiDict

from elasticsearch import AIOHttpConnection, __versionstr__
from elasticsearch import AIOHttpConnection, AsyncElasticsearch, __versionstr__
from elasticsearch.compat import reraise_exceptions
from elasticsearch.exceptions import ConnectionError

Expand Down Expand Up @@ -410,3 +410,9 @@ async def test_aiohttp_connection_error(self):
conn = AIOHttpConnection("not.a.host.name")
with pytest.raises(ConnectionError):
await conn.perform_request("GET", "/")

async def test_elasticsearch_connection_error(self):
    """A request through the client to an unresolvable host must
    surface 'ConnectionError'.
    """
    client = AsyncElasticsearch("http://not.a.host.name")
    with pytest.raises(ConnectionError):
        await client.search()

0 comments on commit 801a839

Please sign in to comment.