Skip to content

Commit

Permalink
Backport aiohttp conditional HEAD bug workaround
Browse files Browse the repository at this point in the history
  • Loading branch information
pquentin committed Nov 24, 2022
1 parent ebad38c commit 9b0d43b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
16 changes: 13 additions & 3 deletions elasticsearch/_async/http_aiohttp.py
Expand Up @@ -17,6 +17,7 @@

import asyncio
import os
import re
import ssl
import warnings

Expand Down Expand Up @@ -49,6 +50,15 @@
except ImportError:
pass

_version_parts = []
for _version_part in aiohttp.__version__.split(".")[:3]:
try:
_version_parts.append(int(re.search(r"^([0-9]+)", _version_part).group(1))) # type: ignore[union-attr]
except (AttributeError, ValueError):
break
_AIOHTTP_SEMVER_VERSION = tuple(_version_parts)
_AIOHTTP_FIXED_HEAD_BUG = _AIOHTTP_SEMVER_VERSION >= (3, 7, 0)


class AsyncConnection(Connection):
"""Base class for Async HTTP connection implementations"""
Expand Down Expand Up @@ -247,11 +257,11 @@ async def perform_request(
query_string = ""
url_target = url_path

# There is a bug in aiohttp that disables the re-use
is_head = False
# There is a bug in aiohttp<3.7 that disables the re-use
# of the connection in the pool when method=HEAD.
# See: aio-libs/aiohttp#1769
is_head = False
if method == "HEAD":
if method == "HEAD" and not _AIOHTTP_FIXED_HEAD_BUG:
method = "GET"
is_head = True

Expand Down
19 changes: 19 additions & 0 deletions test_elasticsearch/test_async/test_connection.py
Expand Up @@ -29,6 +29,7 @@
from mock import patch
from multidict import CIMultiDict

import elasticsearch._async.http_aiohttp
from elasticsearch import AIOHttpConnection, AsyncElasticsearch, __versionstr__
from elasticsearch.compat import reraise_exceptions
from elasticsearch.exceptions import ConnectionError, NotFoundError
Expand Down Expand Up @@ -56,6 +57,9 @@ async def __aenter__(self, *_, **__):
async def __aexit__(self, *_, **__):
pass

async def release(self):
pass

async def text(self):
return response_body.decode("utf-8", "surrogatepass")

Expand Down Expand Up @@ -421,6 +425,21 @@ def request_raise(*_, **__):
await conn.perform_request("GET", "/")
assert str(e.value) == "Wasn't modified!"

@pytest.mark.parametrize("aiohttp_fixed_head_bug", [True, False])
async def test_head_workaround(self, aiohttp_fixed_head_bug, monkeypatch):
monkeypatch.setattr(
elasticsearch._async.http_aiohttp,
"_AIOHTTP_FIXED_HEAD_BUG",
aiohttp_fixed_head_bug,
)

con = await self._get_mock_connection()
await con.perform_request("HEAD", "/anything")

method, url = con.session.request.call_args[0]
assert method == "HEAD" if aiohttp_fixed_head_bug else "GET"
assert url.human_repr() == "http://localhost:9200/anything"


class TestConnectionHttpbin:
"""Tests the HTTP connection implementations against a live server E2E"""
Expand Down

0 comments on commit 9b0d43b

Please sign in to comment.