From 4b0ee6f79bc12429f1baea098a75e04a8801409c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sun, 2 Aug 2020 12:33:31 +0100 Subject: [PATCH] Add an `UnsupportedProtocol` exception (#128) * Add an UnsupportedProtocol exception * Update httpcore/_async/connection_pool.py Co-authored-by: Florimond Manca * Update tests/async_tests/test_interfaces.py Co-authored-by: Florimond Manca * Update tests/async_tests/test_interfaces.py Co-authored-by: Florimond Manca * Unasync Co-authored-by: Florimond Manca --- httpcore/__init__.py | 2 ++ httpcore/_async/connection_pool.py | 7 +++++-- httpcore/_exceptions.py | 4 ++++ httpcore/_sync/connection_pool.py | 7 +++++-- tests/async_tests/test_interfaces.py | 10 ++++++++++ tests/sync_tests/test_interfaces.py | 10 ++++++++++ 6 files changed, 36 insertions(+), 4 deletions(-) diff --git a/httpcore/__init__.py b/httpcore/__init__.py index 2cf161935..084dd30c5 100644 --- a/httpcore/__init__.py +++ b/httpcore/__init__.py @@ -12,6 +12,7 @@ ReadError, ReadTimeout, TimeoutException, + UnsupportedProtocol, WriteError, WriteTimeout, ) @@ -38,5 +39,6 @@ "ReadError", "WriteError", "CloseError", + "UnsupportedProtocol", ] __version__ = "0.9.1" diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index ce976608a..b90e136df 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -2,7 +2,7 @@ from typing import AsyncIterator, Callable, Dict, List, Optional, Set, Tuple from .._backends.auto import AsyncLock, AsyncSemaphore, AutoBackend -from .._exceptions import PoolTimeout +from .._exceptions import PoolTimeout, UnsupportedProtocol from .._threadlock import ThreadLock from .._types import URL, Headers, Origin, TimeoutDict from .._utils import get_logger, origin_to_url_string, url_to_origin @@ -124,7 +124,10 @@ async def request( stream: AsyncByteStream = None, timeout: TimeoutDict = None, ) -> Tuple[bytes, int, bytes, Headers, AsyncByteStream]: - assert url[0] in (b"http", b"https") + if url[0] not in (b"http", b"https"): + scheme = url[0].decode("latin-1") + raise UnsupportedProtocol(f"Unsupported URL protocol {scheme!r}") + origin = url_to_origin(url) if self._keepalive_expiry is not None: diff --git a/httpcore/_exceptions.py b/httpcore/_exceptions.py index b6ab2b2f9..269132393 100644 --- a/httpcore/_exceptions.py +++ b/httpcore/_exceptions.py @@ -13,6 +13,10 @@ def map_exceptions(map: Dict[Type[Exception], Type[Exception]]) -> Iterator[None raise +class UnsupportedProtocol(Exception): + pass + + class ProtocolError(Exception): pass diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 19269f8f8..44a14b6db 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -2,7 +2,7 @@ from typing import Iterator, Callable, Dict, List, Optional, Set, Tuple from .._backends.auto import SyncLock, SyncSemaphore, SyncBackend -from .._exceptions import PoolTimeout +from .._exceptions import PoolTimeout, UnsupportedProtocol from .._threadlock import ThreadLock from .._types import URL, Headers, Origin, TimeoutDict from .._utils import get_logger, origin_to_url_string, url_to_origin @@ -124,7 +124,10 @@ def request( stream: SyncByteStream = None, timeout: TimeoutDict = None, ) -> Tuple[bytes, int, bytes, Headers, SyncByteStream]: - assert url[0] in (b"http", b"https") + if url[0] not in (b"http", b"https"): + scheme = url[0].decode("latin-1") + raise UnsupportedProtocol(f"Unsupported URL protocol {scheme!r}") + origin = url_to_origin(url) if self._keepalive_expiry is not None: diff --git a/tests/async_tests/test_interfaces.py b/tests/async_tests/test_interfaces.py index 14f1bae10..e144343c8 100644 --- a/tests/async_tests/test_interfaces.py +++ b/tests/async_tests/test_interfaces.py @@ -50,6 +50,16 @@ async def test_https_request() -> None: assert len(http._connections[url[:3]]) == 1 # type: ignore +@pytest.mark.usefixtures("async_environment") +async def test_request_unsupported_protocol() -> None: + async with httpcore.AsyncConnectionPool() as http: + method = b"GET" + url = (b"ftp", b"example.org", 443, b"/") + headers = [(b"host", b"example.org")] + with pytest.raises(httpcore.UnsupportedProtocol): + await http.request(method, url, headers) + + @pytest.mark.usefixtures("async_environment") async def test_http2_request() -> None: async with httpcore.AsyncConnectionPool(http2=True) as http: diff --git a/tests/sync_tests/test_interfaces.py b/tests/sync_tests/test_interfaces.py index 9ba369ed4..23cd16932 100644 --- a/tests/sync_tests/test_interfaces.py +++ b/tests/sync_tests/test_interfaces.py @@ -51,6 +51,16 @@ def test_https_request() -> None: +def test_request_unsupported_protocol() -> None: + with httpcore.SyncConnectionPool() as http: + method = b"GET" + url = (b"ftp", b"example.org", 443, b"/") + headers = [(b"host", b"example.org")] + with pytest.raises(httpcore.UnsupportedProtocol): + http.request(method, url, headers) + + + def test_http2_request() -> None: with httpcore.SyncConnectionPool(http2=True) as http: method = b"GET"