From b53705b1e8000ead53bdb582705928fdc5932fbd Mon Sep 17 00:00:00 2001 From: Travis Hathaway Date: Thu, 24 Aug 2023 16:17:43 -0700 Subject: [PATCH] Auth handler plugin hook (#12911) * replace CondaSession with get_session function * fixing some issues with the name of the context object property for channel settings --------- Co-authored-by: Jannis Leidel Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Bianca Henderson --- conda/gateways/connection/download.py | 6 +- conda/gateways/connection/session.py | 84 ++++++++- conda/gateways/repodata/__init__.py | 4 +- conda/gateways/repodata/jlap/interface.py | 6 +- conda/notices/fetch.py | 4 +- conda/plugins/__init__.py | 1 + conda/plugins/hookspec.py | 38 ++++ conda/plugins/manager.py | 13 ++ conda/plugins/types.py | 39 +++++ conda/testing/notices/fixtures.py | 9 +- conda/testing/notices/helpers.py | 2 +- conda/trust/signature_verification.py | 14 +- .../dev-guide/plugins/auth_handlers.rst | 27 +++ docs/source/dev-guide/plugins/index.rst | 1 + news/12911-auth-handler-plugin-hook | 19 ++ tests/cli/test_main_notices.py | 34 ++-- tests/gateways/test_connection.py | 164 +++++++++++++++++- tests/gateways/test_jlap.py | 20 ++- tests/notices/test_core.py | 23 +-- tests/notices/test_fetch.py | 8 +- tests/plugins/test_auth_handlers.py | 81 +++++++++ ..._post_command.py => test_post_commands.py} | 0 ...st_pre_command.py => test_pre_commands.py} | 0 23 files changed, 525 insertions(+), 72 deletions(-) create mode 100644 docs/source/dev-guide/plugins/auth_handlers.rst create mode 100644 news/12911-auth-handler-plugin-hook create mode 100644 tests/plugins/test_auth_handlers.py rename tests/plugins/{test_post_command.py => test_post_commands.py} (100%) rename tests/plugins/{test_pre_command.py => test_pre_commands.py} (100%) diff --git a/conda/gateways/connection/download.py b/conda/gateways/connection/download.py index 5616161f20b..dc1e52463e0 100644 --- a/conda/gateways/connection/download.py +++ b/conda/gateways/connection/download.py @@ -30,7 +30,7 @@ RequestsProxyError, SSLError, ) -from .session import CondaSession +from .session import get_session log = getLogger(__name__) @@ -55,7 +55,7 @@ def download( try: timeout = context.remote_connect_timeout_secs, context.remote_read_timeout_secs - session = CondaSession() + session = get_session(url) resp = session.get(url, stream=True, proxies=session.proxies, timeout=timeout) if log.isEnabledFor(DEBUG): log.debug(stringify(resp, content_max_len=256)) @@ -214,7 +214,7 @@ def download_text(url): disable_ssl_verify_warning() try: timeout = context.remote_connect_timeout_secs, context.remote_read_timeout_secs - session = CondaSession() + session = get_session(url) response = session.get( url, stream=True, proxies=session.proxies, timeout=timeout ) diff --git a/conda/gateways/connection/session.py b/conda/gateways/connection/session.py index a0bba8911f9..48f7ddc17bd 100644 --- a/conda/gateways/connection/session.py +++ b/conda/gateways/connection/session.py @@ -1,6 +1,9 @@ # Copyright (C) 2012 Anaconda, Inc # SPDX-License-Identifier: BSD-3-Clause """Requests session configured with all accepted scheme adapters.""" +from __future__ import annotations + +from functools import lru_cache from logging import getLogger from threading import local @@ -14,6 +17,7 @@ urlparse, ) from ...exceptions import ProxyError +from ...models.channel import Channel from ..anaconda_client import read_binstar_tokens from . import ( AuthBase, @@ -60,6 +64,61 @@ def close(self): raise NotImplementedError() +def get_channel_name_from_url(url: str) -> str | None: + """ + Given a URL, determine the channel it belongs to and return its name. + """ + return Channel.from_url(url).canonical_name + + +@lru_cache(maxsize=None) +def get_session(url: str): + """ + Function that determines the correct Session object to be returned + based on the URL that is passed in. + """ + channel_name = get_channel_name_from_url(url) + + # If for whatever reason a channel name can't be determined, (should be unlikely) + # we just return the default session object. + if channel_name is None: + return CondaSession() + + # We ensure here if there are duplicates defined, we choose the last one + channel_settings = {} + for settings in context.channel_settings: + if settings.get("channel") == channel_name: + channel_settings = settings + + auth_handler = channel_settings.get("auth", "").strip() or None + + # Return default session object + if auth_handler is None: + return CondaSession() + + auth_handler_cls = context.plugin_manager.get_auth_handler(auth_handler) + + if not auth_handler_cls: + return CondaSession() + + return CondaSession(auth=auth_handler_cls(channel_name)) + + +def get_session_storage_key(auth) -> str: + """ + Function that determines which storage key to use for our CondaSession object caching + """ + if auth is None: + return "default" + + if isinstance(auth, tuple): + return hash(auth) + + auth_type = type(auth) + + return f"{auth_type.__module__}.{auth_type.__qualname__}::{auth.channel_name}" + + class CondaSessionType(type): """ Takes advice from https://github.com/requests/requests/issues/1871#issuecomment-33327847 @@ -70,21 +129,30 @@ def __new__(mcs, name, bases, dct): dct["_thread_local"] = local() return super().__new__(mcs, name, bases, dct) - def __call__(cls): + def __call__(cls, **kwargs): + storage_key = get_session_storage_key(kwargs.get("auth")) + try: - return cls._thread_local.session + return cls._thread_local.sessions[storage_key] except AttributeError: - session = cls._thread_local.session = super().__call__() - return session + session = super().__call__(**kwargs) + cls._thread_local.sessions = {storage_key: session} + except KeyError: + session = cls._thread_local.sessions[storage_key] = super().__call__( + **kwargs + ) + + return session class CondaSession(Session, metaclass=CondaSessionType): - def __init__(self): + def __init__(self, auth: AuthBase | tuple[str, str] | None = None): + """ + :param auth: Optionally provide ``requests.AuthBase`` compliant objects + """ super().__init__() - self.auth = ( - CondaHttpAuth() - ) # TODO: should this just be for certain protocol adapters? + self.auth = auth or CondaHttpAuth() self.proxies.update(context.proxy_servers) diff --git a/conda/gateways/repodata/__init__.py b/conda/gateways/repodata/__init__.py index 8b750ba2779..f450aa2a437 100644 --- a/conda/gateways/repodata/__init__.py +++ b/conda/gateways/repodata/__init__.py @@ -45,7 +45,7 @@ Response, SSLError, ) -from ..connection.session import CondaSession +from ..connection.session import get_session from ..disk import mkdir_p_sudo_safe from .lock import lock @@ -127,7 +127,7 @@ def repodata(self, state: RepodataState) -> str | None: if not context.ssl_verify: warnings.simplefilter("ignore", InsecureRequestWarning) - session = CondaSession() + session = get_session(self._url) headers = {} etag = state.etag diff --git a/conda/gateways/repodata/jlap/interface.py b/conda/gateways/repodata/jlap/interface.py index a5a4306f171..063e01aa869 100644 --- a/conda/gateways/repodata/jlap/interface.py +++ b/conda/gateways/repodata/jlap/interface.py @@ -8,7 +8,7 @@ from conda.base.context import context from conda.gateways.connection.download import disable_ssl_verify_warning -from conda.gateways.connection.session import CondaSession +from conda.gateways.connection.session import get_session from .. import ( CACHE_CONTROL_KEY, @@ -64,11 +64,11 @@ def repodata_parsed(self, state: dict | RepodataState) -> dict | None: When repodata is not updated, it doesn't matter whether this function or the caller reads from a file. """ + session = get_session(self._url) + if not context.ssl_verify: disable_ssl_verify_warning() - session = CondaSession() - repodata_url = f"{self._url}/{self._repodata_fn}" # XXX won't modify caller's state dict diff --git a/conda/notices/fetch.py b/conda/notices/fetch.py index 7b620a5c085..c562d05f166 100644 --- a/conda/notices/fetch.py +++ b/conda/notices/fetch.py @@ -10,7 +10,7 @@ import requests from ..common.io import Spinner -from ..gateways.connection.session import CondaSession +from ..gateways.connection.session import get_session from .cache import cached_response from .types import ChannelNoticeResponse @@ -56,7 +56,7 @@ def get_channel_notice_response(url: str, name: str) -> ChannelNoticeResponse | additional channel information to use. If the response was invalid we suppress/log and error message. """ - session = CondaSession() + session = get_session(url) try: resp = session.get( url, allow_redirects=False, timeout=5 diff --git a/conda/plugins/__init__.py b/conda/plugins/__init__.py index 6c15fe21a45..ca021133641 100644 --- a/conda/plugins/__init__.py +++ b/conda/plugins/__init__.py @@ -27,6 +27,7 @@ from .hookspec import hookimpl # noqa: F401 from .types import ( # noqa: F401 + CondaAuthHandler, CondaPostCommand, CondaPreCommand, CondaSolver, diff --git a/conda/plugins/hookspec.py b/conda/plugins/hookspec.py index fb9459bf456..ad9978fccf4 100644 --- a/conda/plugins/hookspec.py +++ b/conda/plugins/hookspec.py @@ -14,6 +14,7 @@ import pluggy from .types import ( + CondaAuthHandler, CondaPostCommand, CondaPreCommand, CondaSolver, @@ -170,3 +171,40 @@ def conda_post_commands(): run_for={"install", "create"}, ) """ + + @_hookspec + def conda_auth_handlers(self) -> Iterable[CondaAuthHandler]: + """ + Register a conda auth handler derived from the requests API. + + This plugin hook allows attaching requests auth handler subclasses, + e.g. when authenticating requests against individual channels hosted + at HTTP/HTTPS services. + + **Example:** + + .. code-block:: python + + import os + from conda import plugins + from requests.auth import AuthBase + + + class EnvironmentHeaderAuth(AuthBase): + def __init__(self, *args, **kwargs): + self.username = os.environ["EXAMPLE_CONDA_AUTH_USERNAME"] + self.password = os.environ["EXAMPLE_CONDA_AUTH_PASSWORD"] + + def __call__(self, request): + request.headers["X-Username"] = self.username + request.headers["X-Password"] = self.password + return request + + + @plugins.hookimpl + def conda_auth_handlers(): + yield plugins.CondaAuthHandler( + name="environment-header-auth", + auth_handler=EnvironmentHeaderAuth, + ) + """ diff --git a/conda/plugins/manager.py b/conda/plugins/manager.py index 9ef94c3ca58..7b194b07150 100644 --- a/conda/plugins/manager.py +++ b/conda/plugins/manager.py @@ -15,6 +15,7 @@ from inspect import getmodule, isclass import pluggy +from requests.auth import AuthBase from ..auxlib.ish import dals from ..base.context import context @@ -202,6 +203,18 @@ def get_solver_backend(self, name: str | None = None) -> type[Solver]: return backend + def get_auth_handler(self, name: str) -> type[AuthBase] | None: + """ + Get the auth handler with the given name or None + """ + auth_handlers = self.get_hook_results("auth_handlers") + matches = tuple( + item for item in auth_handlers if item.name.lower() == name.lower().strip() + ) + + if len(matches) > 0: + return matches[0].handler + def invoke_pre_commands(self, command: str) -> None: """ Invokes ``CondaPreCommand.action`` functions registered with ``conda_pre_commands``. diff --git a/conda/plugins/types.py b/conda/plugins/types.py index a17dcc53e71..c85a9faf8cd 100644 --- a/conda/plugins/types.py +++ b/conda/plugins/types.py @@ -12,6 +12,8 @@ from dataclasses import dataclass, field from typing import Callable, NamedTuple +from requests.auth import AuthBase + from ..core.solve import Solver @@ -102,3 +104,40 @@ class CondaPostCommand(NamedTuple): name: str action: Callable[[str], None] run_for: set[str] + + +class ChannelNameMixin: + """ + Class mixin to make all plugin implementations compatible, e.g. when they + use an existing (e.g. 3rd party) requests authentication handler. + + Please use the concrete :class:`~conda.plugins.types.ChannelAuthBase` + in case you're creating an own implementation. + """ + + def __init__(self, channel_name: str, *args, **kwargs): + self.channel_name = channel_name + super().__init__(*args, **kwargs) + + +class ChannelAuthBase(ChannelNameMixin, AuthBase): + """ + Base class that we require all plugin implementations to use to be compatible. + + Authentication is tightly coupled with individual channels. Therefore, an additional + ``channel_name`` property must be set on the ``requests.auth.AuthBase`` based class. + """ + + +class CondaAuthHandler(NamedTuple): + """ + Return type to use when the defining the conda auth handlers hook. + + :param name: Name (e.g., ``basic-auth``). This name should be unique + and only one may be registered at a time. + :param handler: Type that will be used as the authentication handler + during network requests. + """ + + name: str + handler: type[ChannelAuthBase] diff --git a/conda/testing/notices/fixtures.py b/conda/testing/notices/fixtures.py index 42eb9f8e9cb..98a3108b76c 100644 --- a/conda/testing/notices/fixtures.py +++ b/conda/testing/notices/fixtures.py @@ -25,11 +25,10 @@ def notices_cache_dir(tmpdir): @pytest.fixture(scope="function") -def notices_mock_http_session_get(): - with mock.patch( - "conda.gateways.connection.session.CondaSession.get" - ) as session_get: - yield session_get +def notices_mock_fetch_get_session(): + with mock.patch("conda.notices.fetch.get_session") as mock_get_session: + mock_get_session.return_value = mock.MagicMock() + yield mock_get_session @pytest.fixture(scope="function") diff --git a/conda/testing/notices/helpers.py b/conda/testing/notices/helpers.py index 1ed79aeee21..a3d6cb1d49d 100644 --- a/conda/testing/notices/helpers.py +++ b/conda/testing/notices/helpers.py @@ -60,7 +60,7 @@ def one_200(): yield MockResponse(status_code, messages_json, raise_exc=raise_exc) chn = chain(one_200(), forever_404()) - mock_session.side_effect = tuple(next(chn) for _ in range(100)) + mock_session().get.side_effect = tuple(next(chn) for _ in range(100)) def create_notice_cache_files( diff --git a/conda/trust/signature_verification.py b/conda/trust/signature_verification.py index 76394723f44..4925c2388f8 100644 --- a/conda/trust/signature_verification.py +++ b/conda/trust/signature_verification.py @@ -26,7 +26,7 @@ class SignatureError(Exception): from ..base.context import context from ..common.url import join_url from ..gateways.connection import HTTPError, InsecureRequestWarning -from ..gateways.connection.session import CondaSession +from ..gateways.connection.session import get_session from .constants import INITIAL_TRUST_ROOT, KEY_MGR_FILE log = getLogger(__name__) @@ -171,15 +171,11 @@ def key_mgr(self): return trusted - # FUTURE: Python 3.8+, replace with functools.cached_property - @property - @lru_cache(maxsize=None) - def session(self): - return CondaSession() - def _fetch_channel_signing_data( self, signing_data_url, filename, etag=None, mod_stamp=None ): + session = get_session(signing_data_url) + if not context.ssl_verify: warnings.simplefilter("ignore", InsecureRequestWarning) @@ -200,10 +196,10 @@ def _fetch_channel_signing_data( # # TODO: Figure how to handle authn for obtaining trust metadata, # independently of the authn used to access package repositories. - resp = self.session.get( + resp = session.get( join_url(signing_data_url, filename), headers=headers, - proxies=self.session.proxies, + proxies=session.proxies, auth=lambda r: r, timeout=( context.remote_connect_timeout_secs, diff --git a/docs/source/dev-guide/plugins/auth_handlers.rst b/docs/source/dev-guide/plugins/auth_handlers.rst new file mode 100644 index 00000000000..2f6ec976bb2 --- /dev/null +++ b/docs/source/dev-guide/plugins/auth_handlers.rst @@ -0,0 +1,27 @@ +============= +Auth Handlers +============= + +The auth handlers plugin hook allows plugin authors to enable new modes +of authentication within conda. Registered auth handlers will be +available to configure on a per channel basis via the ``channel_settings`` +configuration option in the ``.condarc`` file. + +Auth handlers are subclasses of the :class:`~conda.plugins.types.ChannelAuthBase` class, +which is itself a subclass of `requests.auth.AuthBase`_. +The :class:`~conda.plugins.types.ChannelAuthBase` class adds an additional ``channel_name`` +property to the `requests.auth.AuthBase`_ class. This is necessary for appropriate handling of +channel based authentication in conda. + +For more information on how to implement your own auth handlers, please read the requests +documentation on `Custom Authentication`_. + + +.. autoapiclass:: conda.plugins.types.CondaAuthHandler + :members: + :undoc-members: + +.. autoapifunction:: conda.plugins.hookspec.CondaSpecs.conda_auth_handlers + +.. _requests.auth.AuthBase: https://docs.python-requests.org/en/latest/api/#requests.auth.AuthBase +.. _Custom Authentication: https://docs.python-requests.org/en/latest/user/advanced/#custom-authentication diff --git a/docs/source/dev-guide/plugins/index.rst b/docs/source/dev-guide/plugins/index.rst index 1122fd58960..71affca4b14 100644 --- a/docs/source/dev-guide/plugins/index.rst +++ b/docs/source/dev-guide/plugins/index.rst @@ -97,6 +97,7 @@ For examples of how to use other plugin hooks, please read their respective docu .. toctree:: :maxdepth: 1 + auth_handlers post_commands pre_commands solvers diff --git a/news/12911-auth-handler-plugin-hook b/news/12911-auth-handler-plugin-hook new file mode 100644 index 00000000000..291029612a7 --- /dev/null +++ b/news/12911-auth-handler-plugin-hook @@ -0,0 +1,19 @@ +### Enhancements + +* Adds a new "auth handler" plugin hook for conda (#12911) + +### Bug fixes + +* + +### Deprecations + +* + +### Docs + +* + +### Other + +* diff --git a/tests/cli/test_main_notices.py b/tests/cli/test_main_notices.py index 27d1bc07483..d11315139ad 100644 --- a/tests/cli/test_main_notices.py +++ b/tests/cli/test_main_notices.py @@ -42,7 +42,7 @@ def test_main_notices( capsys, conda_notices_args_n_parser, notices_cache_dir, - notices_mock_http_session_get, + notices_mock_fetch_get_session, ): """ Test the full working path through the code. We vary the test based on the status code @@ -54,7 +54,7 @@ def test_main_notices( args, parser = conda_notices_args_n_parser messages = ("Test One", "Test Two") messages_json = get_test_notices(messages) - add_resp_to_mock(notices_mock_http_session_get, status_code, messages_json) + add_resp_to_mock(notices_mock_fetch_get_session, status_code, messages_json) notices.execute(args, parser) @@ -74,7 +74,7 @@ def test_main_notices_reads_from_cache( capsys, conda_notices_args_n_parser, notices_cache_dir, - notices_mock_http_session_get, + notices_mock_fetch_get_session, ): """ Test the full working path through the code when reading from cache instead of making @@ -105,7 +105,7 @@ def test_main_notices_reads_from_expired_cache( capsys, conda_notices_args_n_parser, notices_cache_dir, - notices_mock_http_session_get, + notices_mock_fetch_get_session, ): """ Test the full working path through the code when reading from cache instead of making @@ -133,7 +133,7 @@ def test_main_notices_reads_from_expired_cache( # different messages messages_different_json = get_test_notices(messages_different) add_resp_to_mock( - notices_mock_http_session_get, + notices_mock_fetch_get_session, status_code=200, messages_json=messages_different_json, ) @@ -153,7 +153,7 @@ def test_main_notices_handles_bad_expired_at_field( capsys, conda_notices_args_n_parser, notices_cache_dir, - notices_mock_http_session_get, + notices_mock_fetch_get_session, ): """ This test ensures that an incorrectly defined `notices.json` file doesn't completely break @@ -177,7 +177,9 @@ def test_main_notices_handles_bad_expired_at_field( ] } add_resp_to_mock( - notices_mock_http_session_get, status_code=200, messages_json=bad_notices_json + notices_mock_fetch_get_session, + status_code=200, + messages_json=bad_notices_json, ) create_notice_cache_files(notices_cache_dir, [cache_file], [bad_notices_json]) @@ -213,7 +215,7 @@ def test_cache_names_appear_as_expected( capsys, conda_notices_args_n_parser, notices_cache_dir, - notices_mock_http_session_get, + notices_mock_fetch_get_session, ): """This is a test to make sure the cache filenames appear as we expect them to.""" with mock.patch( @@ -228,7 +230,7 @@ def test_cache_names_appear_as_expected( args, parser = conda_notices_args_n_parser messages = ("Test One", "Test Two") messages_json = get_test_notices(messages) - add_resp_to_mock(notices_mock_http_session_get, 200, messages_json) + add_resp_to_mock(notices_mock_fetch_get_session, 200, messages_json) notices.execute(args, parser) @@ -297,24 +299,26 @@ def test_notices_appear_once_when_running_decorated_commands( fetch_mock.assert_not_called() -def test_notices_work_with_s3_channel(notices_cache_dir, notices_mock_http_session_get): +def test_notices_work_with_s3_channel( + notices_cache_dir, notices_mock_fetch_get_session +): """As a user, I want notices to be correctly retrieved from channels with s3 URLs.""" s3_channel = "s3://conda-org" messages = ("Test One", "Test Two") messages_json = get_test_notices(messages) - add_resp_to_mock(notices_mock_http_session_get, 200, messages_json) + add_resp_to_mock(notices_mock_fetch_get_session, 200, messages_json) run(f"conda notices -c {s3_channel} --override-channels") - notices_mock_http_session_get.assert_called_once() - args, kwargs = notices_mock_http_session_get.call_args + notices_mock_fetch_get_session().get.assert_called_once() + args, kwargs = notices_mock_fetch_get_session().get.call_args arg_1, *_ = args assert arg_1 == "s3://conda-org/notices.json" def test_notices_does_not_interrupt_command_on_failure( - notices_cache_dir, notices_mock_http_session_get + notices_cache_dir, notices_mock_fetch_get_session ): """ As a user, when I run conda in an environment where notice cache files might not be readable or @@ -343,7 +347,7 @@ def test_notices_does_not_interrupt_command_on_failure( def test_notices_cannot_read_cache_files( - notices_cache_dir, notices_mock_http_session_get + notices_cache_dir, notices_mock_fetch_get_session ): """ As a user, when I run `conda notices` and the cache file cannot be read or written, I want diff --git a/tests/gateways/test_connection.py b/tests/gateways/test_connection.py index 68e8f77d68e..e63e59a666f 100644 --- a/tests/gateways/test_connection.py +++ b/tests/gateways/test_connection.py @@ -8,12 +8,20 @@ from requests import HTTPError from conda.auxlib.compat import Utf8NamedTemporaryFile +from conda.base.context import reset_context from conda.common.compat import ensure_binary from conda.common.url import path_to_url from conda.exceptions import CondaExitZero from conda.gateways.anaconda_client import remove_binstar_token, set_binstar_token -from conda.gateways.connection.session import CondaHttpAuth, CondaSession +from conda.gateways.connection.session import ( + CondaHttpAuth, + CondaSession, + get_channel_name_from_url, + get_session, + get_session_storage_key, +) from conda.gateways.disk.delete import rm_rf +from conda.plugins.types import ChannelAuthBase from conda.testing.gateways.fixtures import MINIO_EXE from conda.testing.integration import env_var, make_temp_env @@ -110,3 +118,157 @@ def test_s3_server(minio_s3_server): ): # we just want to run make_temp_env and cleanup after pass + + +def test_get_session_returns_default(): + """ + Tests to make sure that our session manager returns a regular + CondaSession object when no other session classes are registered. + """ + url = "https://localhost/test" + session_obj = get_session(url) + get_session.cache_clear() # ensuring cleanup + + assert type(session_obj) is CondaSession + + +def test_get_session_with_channel_settings(mocker): + """ + Tests to make sure the get_session function works when ``channel_settings`` + have been set on the context object. + """ + mocker.patch( + "conda.gateways.connection.session.get_channel_name_from_url", + return_value="defaults", + ) + mock_context = mocker.patch("conda.gateways.connection.session.context") + mock_context.channel_settings = ({"channel": "defaults", "auth": "dummy_one"},) + + url = "https://localhost/test1" + + session_obj = get_session(url) + get_session.cache_clear() # ensuring cleanup + + assert type(session_obj) is CondaSession + + # For session objects with a custom auth handler it will not be set to CondaHttpAuth + assert type(session_obj.auth) is not CondaHttpAuth + + # Make sure we tried to retrieve our auth handler in this function + assert ( + mocker.call("dummy_one") + in mock_context.plugin_manager.get_auth_handler.mock_calls + ) + + +def test_get_session_with_channel_settings_multiple(mocker): + """ + Tests to make sure the get_session function works when ``channel_settings`` + have been set on the context object and there exists more than one channel + configured using the same type of auth handler. + + It's important that our cache keys are set up so that we do not return the + same CondaSession object for these two different channels. + """ + mocker.patch( + "conda.gateways.connection.session.get_channel_name_from_url", + side_effect=["channel_one", "channel_two"], + ) + mock_context = mocker.patch("conda.gateways.connection.session.context") + mock_context.channel_settings = ( + {"channel": "channel_one", "auth": "dummy_one"}, + {"channel": "channel_two", "auth": "dummy_one"}, + ) + mock_context.plugin_manager.get_auth_handler.return_value = ChannelAuthBase + + url_one = "https://localhost/test1" + url_two = "https://localhost/test2" + + session_obj_one = get_session(url_one) + session_obj_two = get_session(url_two) + + get_session.cache_clear() # ensuring cleanup + + assert session_obj_one is not session_obj_two + + storage_key_one = get_session_storage_key(session_obj_one.auth) + storage_key_two = get_session_storage_key(session_obj_two.auth) + + assert storage_key_one in session_obj_one._thread_local.sessions + assert storage_key_two in session_obj_one._thread_local.sessions + + assert type(session_obj_one) is CondaSession + assert type(session_obj_two) is CondaSession + + # For session objects with a custom auth handler it will not be set to CondaHttpAuth + assert type(session_obj_one.auth) is not CondaHttpAuth + assert type(session_obj_two.auth) is not CondaHttpAuth + + # Make sure we tried to retrieve our auth handler in this function + assert ( + mocker.call("dummy_one") + in mock_context.plugin_manager.get_auth_handler.mock_calls + ) + + +def test_get_session_with_channel_settings_no_handler(mocker): + """ + Tests to make sure the get_session function works when ``channel_settings`` + have been set on the context objet. This test does not find a matching auth + handler. + """ + mocker.patch( + "conda.gateways.connection.session.get_channel_name_from_url", + return_value="defaults", + ) + mock_context = mocker.patch("conda.gateways.connection.session.context") + mock_context.plugin_manager.get_auth_handler.return_value = None + mock_context.channel_settings = ({"channel": "defaults", "auth": "dummy_two"},) + + url = "https://localhost/test2" + + session_obj = get_session(url) + get_session.cache_clear() # ensuring cleanup + + assert type(session_obj) is CondaSession + + # For sessions without a custom auth handler, this will be the default auth handler + assert type(session_obj.auth) is CondaHttpAuth + + # Make sure we tried to retrieve our auth handler in this function + assert ( + mocker.call("dummy_two") + in mock_context.plugin_manager.get_auth_handler.mock_calls + ) + + +@pytest.mark.parametrize( + "url, channels, expected", + ( + ( + "https://repo.anaconda.com/pkgs/main/linux-64/test-package-0.1.0.conda", + ("defaults",), + "defaults", + ), + ( + "https://conda.anaconda.org/conda-forge/linux-64/test-package-0.1.0.tar.bz2", + ("conda-forge", "defaults"), + "conda-forge", + ), + ( + "http://localhost/noarch/test-package-0.1.0.conda", + ("defaults", "http://localhost"), + "http://localhost", + ), + ("http://localhost", ("defaults",), "http://localhost"), + ), +) +def test_get_channel_name_from_url(url, channels, expected, monkeypatch): + """ + Makes sure we return the correct value from the ``get_channel_name_from_url`` function. + """ + monkeypatch.setenv("CONDA_CHANNELS", ",".join(channels)) + reset_context() + channel_name = get_channel_name_from_url(url) + + assert expected == channel_name diff --git a/tests/gateways/test_jlap.py b/tests/gateways/test_jlap.py index f32891cacb0..be5f2e86af8 100644 --- a/tests/gateways/test_jlap.py +++ b/tests/gateways/test_jlap.py @@ -15,11 +15,11 @@ import requests import zstandard -from conda.base.context import conda_tests_ctxt_mgmt_def_pol, context +from conda.base.context import conda_tests_ctxt_mgmt_def_pol, context, reset_context from conda.common.io import env_vars from conda.core.subdir_data import SubdirData from conda.exceptions import CondaHTTPError, CondaSSLError -from conda.gateways.connection.session import CondaSession +from conda.gateways.connection.session import CondaSession, get_session from conda.gateways.repodata import ( CACHE_CONTROL_KEY, CACHE_STATE_SUFFIX, @@ -127,7 +127,7 @@ def test_jlap_fetch_file(package_repository_base: Path, tmp_path: Path, mocker): @pytest.mark.parametrize("verify_ssl", [True, False]) def test_jlap_fetch_ssl( - package_server_ssl: socket, tmp_path: Path, mocker, verify_ssl: bool + package_server_ssl: socket, tmp_path: Path, monkeypatch, verify_ssl: bool ): """Check that JlapRepoInterface doesn't raise exceptions.""" host, port = package_server_ssl.getsockname() @@ -148,23 +148,25 @@ def test_jlap_fetch_ssl( # clear session cache to avoid leftover wrong-ssl-verify Session() try: - del CondaSession._thread_local.session + CondaSession._thread_local.sessions = {} except AttributeError: pass state = {} - with env_vars( - {"CONDA_SSL_VERIFY": str(verify_ssl).lower()}, - stack_callback=conda_tests_ctxt_mgmt_def_pol, - ), pytest.raises(expected_exception), pytest.warns() as record: + with pytest.raises(expected_exception), pytest.warns() as record: + monkeypatch.setenv("CONDA_SSL_VERIFY", str(verify_ssl).lower()) + reset_context() repo.repodata(state) + # Clear lru_cache from the `get_session` function + get_session.cache_clear() + # If we didn't disable warnings, we will see two 'InsecureRequestWarning' assert len(record) == 0, f"Unexpected warning {record[0]._category_name}" # clear session cache to avoid leftover wrong-ssl-verify Session() try: - del CondaSession._thread_local.session + CondaSession._thread_local.sessions = {} except AttributeError: pass diff --git a/tests/notices/test_core.py b/tests/notices/test_core.py index 0bd8fcdfef8..496af5235b1 100644 --- a/tests/notices/test_core.py +++ b/tests/notices/test_core.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize("status_code", (200, 404, 500)) def test_display_notices_happy_path( - status_code, capsys, notices_cache_dir, notices_mock_http_session_get + status_code, capsys, notices_cache_dir, notices_mock_fetch_get_session ): """ Happy path for displaying notices. We test two error codes to make sure we get @@ -23,7 +23,7 @@ def test_display_notices_happy_path( """ messages = ("Test One", "Test Two") messages_json = get_test_notices(messages) - add_resp_to_mock(notices_mock_http_session_get, status_code, messages_json) + add_resp_to_mock(notices_mock_fetch_get_session, status_code, messages_json) channel_notice_set = notices.retrieve_notices() notices.display_notices(channel_notice_set) @@ -48,14 +48,14 @@ def test_display_notices_happy_path( assert message not in captured.out -def test_notices_decorator(capsys, notices_cache_dir, notices_mock_http_session_get): +def test_notices_decorator(capsys, notices_cache_dir, notices_mock_fetch_get_session): """ Create a dummy function to wrap with our notices decorator and test it with two test messages. """ messages = ("Test One", "Test Two") messages_json = get_test_notices(messages) - add_resp_to_mock(notices_mock_http_session_get, 200, messages_json) + add_resp_to_mock(notices_mock_fetch_get_session, 200, messages_json) dummy_mesg = "Dummy mesg" offset_cache_file_mtime(NOTICES_DECORATOR_DISPLAY_INTERVAL + 100) @@ -77,7 +77,7 @@ def dummy(args, parser): def test__conda_user_story__only_see_once( - capsys, notices_cache_dir, notices_mock_http_session_get + capsys, notices_cache_dir, notices_mock_fetch_get_session ): """ As a conda user, I only want to see a channel notice once while running @@ -86,7 +86,7 @@ def test__conda_user_story__only_see_once( messages = ("Test One",) dummy_mesg = "Dummy Mesg" messages_json = get_test_notices(messages) - add_resp_to_mock(notices_mock_http_session_get, 200, messages_json) + add_resp_to_mock(notices_mock_fetch_get_session, 200, messages_json) offset_cache_file_mtime(NOTICES_DECORATOR_DISPLAY_INTERVAL + 100) @@ -110,7 +110,10 @@ def dummy(args, parser): def test__conda_user_story__disable_notices( - capsys, notices_cache_dir, notices_mock_http_session_get, disable_channel_notices + capsys, + notices_cache_dir, + notices_mock_fetch_get_session, + disable_channel_notices, ): """ As a conda user, if I disable channel notifications in my .condarc file, @@ -120,7 +123,7 @@ def test__conda_user_story__disable_notices( messages = ("Test One", "Test Two") dummy_mesg = "Dummy Mesg" messages_json = get_test_notices(messages) - add_resp_to_mock(notices_mock_http_session_get, 200, messages_json) + add_resp_to_mock(notices_mock_fetch_get_session, 200, messages_json) @notices.notices def dummy(args, parser): @@ -136,7 +139,7 @@ def dummy(args, parser): def test__conda_user_story__more_notices_message( - capsys, notices_cache_dir, notices_mock_http_session_get + capsys, notices_cache_dir, notices_mock_fetch_get_session ): """ As a conda user, I want to see a message telling me there are more notices @@ -144,7 +147,7 @@ def test__conda_user_story__more_notices_message( """ messages = tuple(f"Test {idx}" for idx in range(1, 11, 1)) messages_json = get_test_notices(messages) - add_resp_to_mock(notices_mock_http_session_get, 200, messages_json) + add_resp_to_mock(notices_mock_fetch_get_session, 200, messages_json) offset_cache_file_mtime(NOTICES_DECORATOR_DISPLAY_INTERVAL + 100) diff --git a/tests/notices/test_fetch.py b/tests/notices/test_fetch.py index 49f351080da..aad8140293e 100644 --- a/tests/notices/test_fetch.py +++ b/tests/notices/test_fetch.py @@ -10,11 +10,11 @@ def test_get_channel_notice_response_timeout_error( - notices_cache_dir, notices_mock_http_session_get + notices_cache_dir, notices_mock_fetch_get_session ): """Tests the timeout error case for the get_channel_notice_response function.""" with patch("conda.notices.fetch.logger") as mock_logger: - notices_mock_http_session_get.side_effect = requests.exceptions.Timeout + notices_mock_fetch_get_session().get.side_effect = requests.exceptions.Timeout channel_notice_set = retrieve_notices() display_notices(channel_notice_set) @@ -24,12 +24,12 @@ def test_get_channel_notice_response_timeout_error( def test_get_channel_notice_response_malformed_json( - notices_cache_dir, notices_mock_http_session_get + notices_cache_dir, notices_mock_fetch_get_session ): """Tests malformed json error case for the get_channel_notice_response function.""" messages = ("hello", "hello 2") with patch("conda.notices.fetch.logger") as mock_logger: - add_resp_to_mock(notices_mock_http_session_get, 200, messages, raise_exc=True) + add_resp_to_mock(notices_mock_fetch_get_session, 200, messages, raise_exc=True) channel_notice_set = retrieve_notices() display_notices(channel_notice_set) diff --git a/tests/plugins/test_auth_handlers.py b/tests/plugins/test_auth_handlers.py new file mode 100644 index 00000000000..731247829be --- /dev/null +++ b/tests/plugins/test_auth_handlers.py @@ -0,0 +1,81 @@ +# Copyright (C) 2012 Anaconda, Inc +# SPDX-License-Identifier: BSD-3-Clause +import re + +import pytest +from requests.auth import HTTPBasicAuth + +from conda import plugins +from conda.exceptions import PluginError + +PLUGIN_NAME = "custom_auth" +PLUGIN_NAME_ALT = "custom_auth_alt" + + +class CustomCondaAuth(HTTPBasicAuth): + def __init__(self): + username = "user_one" + password = "pass_one" + super().__init__(username, password) + + +class CustomAltCondaAuth(HTTPBasicAuth): + def __init__(self): + username = "user_two" + password = "pass_two" + super().__init__(username, password) + + +class CustomAuthPlugin: + @plugins.hookimpl + def conda_auth_handlers(self): + yield plugins.CondaAuthHandler(handler=CustomCondaAuth, name=PLUGIN_NAME) + + +class CustomAltAuthPlugin: + @plugins.hookimpl + def conda_auth_handlers(self): + yield plugins.CondaAuthHandler(handler=CustomAltCondaAuth, name=PLUGIN_NAME_ALT) + + +def test_get_auth_handler(plugin_manager): + """ + Return the correct auth backend class or return ``None`` + """ + plugin = CustomAuthPlugin() + plugin_manager.register(plugin) + + auth_handler_cls = plugin_manager.get_auth_handler(PLUGIN_NAME) + assert auth_handler_cls is CustomCondaAuth + + auth_handler_cls = plugin_manager.get_auth_handler("DOES_NOT_EXIST") + assert auth_handler_cls is None + + +def test_get_auth_handler_multiple(plugin_manager): + """ + Tests to make sure we can retrieve auth backends when there are multiple hooks registered. + """ + plugin_one = CustomAuthPlugin() + plugin_two = CustomAltAuthPlugin() + plugin_manager.register(plugin_one) + plugin_manager.register(plugin_two) + + auth_class = plugin_manager.get_auth_handler(PLUGIN_NAME) + assert auth_class is CustomCondaAuth + + auth_class = plugin_manager.get_auth_handler(PLUGIN_NAME_ALT) + assert auth_class is CustomAltCondaAuth + + +def test_duplicated(plugin_manager): + """ + Make sure that a PluginError is raised if we register the same auth backend twice. + """ + plugin_manager.register(CustomAuthPlugin()) + plugin_manager.register(CustomAuthPlugin()) + + with pytest.raises( + PluginError, match=re.escape("Conflicting `auth_handlers` plugins found") + ): + plugin_manager.get_auth_handler(PLUGIN_NAME) diff --git a/tests/plugins/test_post_command.py b/tests/plugins/test_post_commands.py similarity index 100% rename from tests/plugins/test_post_command.py rename to tests/plugins/test_post_commands.py diff --git a/tests/plugins/test_pre_command.py b/tests/plugins/test_pre_commands.py similarity index 100% rename from tests/plugins/test_pre_command.py rename to tests/plugins/test_pre_commands.py