Skip to content

Commit

Permalink
Auth handler plugin hook (#12911)
Browse files Browse the repository at this point in the history
* replace CondaSession with get_session function
* fixing some issues with the name of the context object property for channel settings

---------

Co-authored-by: Jannis Leidel <jannis@leidel.info>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bianca Henderson <bhenderson@anaconda.com>
  • Loading branch information
4 people committed Aug 24, 2023
1 parent 1313ef5 commit b53705b
Show file tree
Hide file tree
Showing 23 changed files with 525 additions and 72 deletions.
6 changes: 3 additions & 3 deletions conda/gateways/connection/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
RequestsProxyError,
SSLError,
)
from .session import CondaSession
from .session import get_session

log = getLogger(__name__)

Expand All @@ -55,7 +55,7 @@ def download(

try:
timeout = context.remote_connect_timeout_secs, context.remote_read_timeout_secs
session = CondaSession()
session = get_session(url)
resp = session.get(url, stream=True, proxies=session.proxies, timeout=timeout)
if log.isEnabledFor(DEBUG):
log.debug(stringify(resp, content_max_len=256))
Expand Down Expand Up @@ -214,7 +214,7 @@ def download_text(url):
disable_ssl_verify_warning()
try:
timeout = context.remote_connect_timeout_secs, context.remote_read_timeout_secs
session = CondaSession()
session = get_session(url)
response = session.get(
url, stream=True, proxies=session.proxies, timeout=timeout
)
Expand Down
84 changes: 76 additions & 8 deletions conda/gateways/connection/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Copyright (C) 2012 Anaconda, Inc
# SPDX-License-Identifier: BSD-3-Clause
"""Requests session configured with all accepted scheme adapters."""
from __future__ import annotations

from functools import lru_cache
from logging import getLogger
from threading import local

Expand All @@ -14,6 +17,7 @@
urlparse,
)
from ...exceptions import ProxyError
from ...models.channel import Channel
from ..anaconda_client import read_binstar_tokens
from . import (
AuthBase,
Expand Down Expand Up @@ -60,6 +64,61 @@ def close(self):
raise NotImplementedError()


def get_channel_name_from_url(url: str) -> str | None:
"""
Given a URL, determine the channel it belongs to and return its name.
"""
return Channel.from_url(url).canonical_name


@lru_cache(maxsize=None)
def get_session(url: str):
"""
Function that determines the correct Session object to be returned
based on the URL that is passed in.
"""
channel_name = get_channel_name_from_url(url)

# If for whatever reason a channel name can't be determined, (should be unlikely)
# we just return the default session object.
if channel_name is None:
return CondaSession()

# We ensure here if there are duplicates defined, we choose the last one
channel_settings = {}
for settings in context.channel_settings:
if settings.get("channel") == channel_name:
channel_settings = settings

auth_handler = channel_settings.get("auth", "").strip() or None

# Return default session object
if auth_handler is None:
return CondaSession()

auth_handler_cls = context.plugin_manager.get_auth_handler(auth_handler)

if not auth_handler_cls:
return CondaSession()

return CondaSession(auth=auth_handler_cls(channel_name))


def get_session_storage_key(auth) -> str:
"""
Function that determines which storage key to use for our CondaSession object caching
"""
if auth is None:
return "default"

if isinstance(auth, tuple):
return hash(auth)

auth_type = type(auth)

return f"{auth_type.__module__}.{auth_type.__qualname__}::{auth.channel_name}"


class CondaSessionType(type):
"""
Takes advice from https://github.com/requests/requests/issues/1871#issuecomment-33327847
Expand All @@ -70,21 +129,30 @@ def __new__(mcs, name, bases, dct):
dct["_thread_local"] = local()
return super().__new__(mcs, name, bases, dct)

def __call__(cls):
def __call__(cls, **kwargs):
storage_key = get_session_storage_key(kwargs.get("auth"))

try:
return cls._thread_local.session
return cls._thread_local.sessions[storage_key]
except AttributeError:
session = cls._thread_local.session = super().__call__()
return session
session = super().__call__(**kwargs)
cls._thread_local.sessions = {storage_key: session}
except KeyError:
session = cls._thread_local.sessions[storage_key] = super().__call__(
**kwargs
)

return session


class CondaSession(Session, metaclass=CondaSessionType):
def __init__(self):
def __init__(self, auth: AuthBase | tuple[str, str] | None = None):
"""
:param auth: Optionally provide ``requests.AuthBase`` compliant objects
"""
super().__init__()

self.auth = (
CondaHttpAuth()
) # TODO: should this just be for certain protocol adapters?
self.auth = auth or CondaHttpAuth()

self.proxies.update(context.proxy_servers)

Expand Down
4 changes: 2 additions & 2 deletions conda/gateways/repodata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
Response,
SSLError,
)
from ..connection.session import CondaSession
from ..connection.session import get_session
from ..disk import mkdir_p_sudo_safe
from .lock import lock

Expand Down Expand Up @@ -127,7 +127,7 @@ def repodata(self, state: RepodataState) -> str | None:
if not context.ssl_verify:
warnings.simplefilter("ignore", InsecureRequestWarning)

session = CondaSession()
session = get_session(self._url)

headers = {}
etag = state.etag
Expand Down
6 changes: 3 additions & 3 deletions conda/gateways/repodata/jlap/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from conda.base.context import context
from conda.gateways.connection.download import disable_ssl_verify_warning
from conda.gateways.connection.session import CondaSession
from conda.gateways.connection.session import get_session

from .. import (
CACHE_CONTROL_KEY,
Expand Down Expand Up @@ -64,11 +64,11 @@ def repodata_parsed(self, state: dict | RepodataState) -> dict | None:
When repodata is not updated, it doesn't matter whether this function or
the caller reads from a file.
"""
session = get_session(self._url)

if not context.ssl_verify:
disable_ssl_verify_warning()

session = CondaSession()

repodata_url = f"{self._url}/{self._repodata_fn}"

# XXX won't modify caller's state dict
Expand Down
4 changes: 2 additions & 2 deletions conda/notices/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import requests

from ..common.io import Spinner
from ..gateways.connection.session import CondaSession
from ..gateways.connection.session import get_session
from .cache import cached_response
from .types import ChannelNoticeResponse

Expand Down Expand Up @@ -56,7 +56,7 @@ def get_channel_notice_response(url: str, name: str) -> ChannelNoticeResponse |
additional channel information to use. If the response was invalid we suppress/log
and error message.
"""
session = CondaSession()
session = get_session(url)
try:
resp = session.get(
url, allow_redirects=False, timeout=5
Expand Down
1 change: 1 addition & 0 deletions conda/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from .hookspec import hookimpl # noqa: F401
from .types import ( # noqa: F401
CondaAuthHandler,
CondaPostCommand,
CondaPreCommand,
CondaSolver,
Expand Down
38 changes: 38 additions & 0 deletions conda/plugins/hookspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pluggy

from .types import (
CondaAuthHandler,
CondaPostCommand,
CondaPreCommand,
CondaSolver,
Expand Down Expand Up @@ -170,3 +171,40 @@ def conda_post_commands():
run_for={"install", "create"},
)
"""

@_hookspec
def conda_auth_handlers(self) -> Iterable[CondaAuthHandler]:
"""
Register a conda auth handler derived from the requests API.
This plugin hook allows attaching requests auth handler subclasses,
e.g. when authenticating requests against individual channels hosted
at HTTP/HTTPS services.
**Example:**
.. code-block:: python
import os
from conda import plugins
from requests.auth import AuthBase
class EnvironmentHeaderAuth(AuthBase):
def __init__(self, *args, **kwargs):
self.username = os.environ["EXAMPLE_CONDA_AUTH_USERNAME"]
self.password = os.environ["EXAMPLE_CONDA_AUTH_PASSWORD"]
def __call__(self, request):
request.headers["X-Username"] = self.username
request.headers["X-Password"] = self.password
return request
@plugins.hookimpl
def conda_auth_handlers():
yield plugins.CondaAuthHandler(
name="environment-header-auth",
auth_handler=EnvironmentHeaderAuth,
)
"""
13 changes: 13 additions & 0 deletions conda/plugins/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from inspect import getmodule, isclass

import pluggy
from requests.auth import AuthBase

from ..auxlib.ish import dals
from ..base.context import context
Expand Down Expand Up @@ -202,6 +203,18 @@ def get_solver_backend(self, name: str | None = None) -> type[Solver]:

return backend

def get_auth_handler(self, name: str) -> type[AuthBase] | None:
"""
Get the auth handler with the given name or None
"""
auth_handlers = self.get_hook_results("auth_handlers")
matches = tuple(
item for item in auth_handlers if item.name.lower() == name.lower().strip()
)

if len(matches) > 0:
return matches[0].handler

def invoke_pre_commands(self, command: str) -> None:
"""
Invokes ``CondaPreCommand.action`` functions registered with ``conda_pre_commands``.
Expand Down
39 changes: 39 additions & 0 deletions conda/plugins/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from dataclasses import dataclass, field
from typing import Callable, NamedTuple

from requests.auth import AuthBase

from ..core.solve import Solver


Expand Down Expand Up @@ -102,3 +104,40 @@ class CondaPostCommand(NamedTuple):
name: str
action: Callable[[str], None]
run_for: set[str]


class ChannelNameMixin:
"""
Class mixin to make all plugin implementations compatible, e.g. when they
use an existing (e.g. 3rd party) requests authentication handler.
Please use the concrete :class:`~conda.plugins.types.ChannelAuthBase`
in case you're creating an own implementation.
"""

def __init__(self, channel_name: str, *args, **kwargs):
self.channel_name = channel_name
super().__init__(*args, **kwargs)


class ChannelAuthBase(ChannelNameMixin, AuthBase):
"""
Base class that we require all plugin implementations to use to be compatible.
Authentication is tightly coupled with individual channels. Therefore, an additional
``channel_name`` property must be set on the ``requests.auth.AuthBase`` based class.
"""


class CondaAuthHandler(NamedTuple):
"""
Return type to use when the defining the conda auth handlers hook.
:param name: Name (e.g., ``basic-auth``). This name should be unique
and only one may be registered at a time.
:param handler: Type that will be used as the authentication handler
during network requests.
"""

name: str
handler: type[ChannelAuthBase]
9 changes: 4 additions & 5 deletions conda/testing/notices/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ def notices_cache_dir(tmpdir):


@pytest.fixture(scope="function")
def notices_mock_http_session_get():
with mock.patch(
"conda.gateways.connection.session.CondaSession.get"
) as session_get:
yield session_get
def notices_mock_fetch_get_session():
with mock.patch("conda.notices.fetch.get_session") as mock_get_session:
mock_get_session.return_value = mock.MagicMock()
yield mock_get_session


@pytest.fixture(scope="function")
Expand Down
2 changes: 1 addition & 1 deletion conda/testing/notices/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def one_200():
yield MockResponse(status_code, messages_json, raise_exc=raise_exc)

chn = chain(one_200(), forever_404())
mock_session.side_effect = tuple(next(chn) for _ in range(100))
mock_session().get.side_effect = tuple(next(chn) for _ in range(100))


def create_notice_cache_files(
Expand Down

0 comments on commit b53705b

Please sign in to comment.