Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DigestAuth as middleware - No nonce count #332

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions httpx/__init__.py
Expand Up @@ -40,6 +40,7 @@
TooManyRedirects,
WriteTimeout,
)
from .middleware.digest_auth import DigestAuth
from .models import (
URL,
AsyncRequest,
Expand Down Expand Up @@ -133,4 +134,5 @@
"Response",
"ResponseContent",
"RequestFiles",
"DigestAuth",
]
5 changes: 3 additions & 2 deletions httpx/client.py
Expand Up @@ -211,8 +211,9 @@ def _get_auth_middleware(
) -> typing.Optional[BaseMiddleware]:
if isinstance(auth, tuple):
return BasicAuthMiddleware(username=auth[0], password=auth[1])

if callable(auth):
elif isinstance(auth, BaseMiddleware):
return auth
elif callable(auth):
return CustomAuthMiddleware(auth=auth)

if auth is not None:
Expand Down
188 changes: 188 additions & 0 deletions httpx/middleware/digest_auth.py
@@ -0,0 +1,188 @@
import hashlib
import os
import re
import time
import typing
from urllib.request import parse_http_list

from ..exceptions import ProtocolError
from ..models import AsyncRequest, AsyncResponse, StatusCode
from ..utils import to_bytes, to_str, unquote
from .base import BaseMiddleware


class DigestAuth(BaseMiddleware):
ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
"MD5": hashlib.md5,
"MD5-SESS": hashlib.md5,
"SHA": hashlib.sha1,
"SHA-SESS": hashlib.sha1,
"SHA-256": hashlib.sha256,
"SHA-256-SESS": hashlib.sha256,
"SHA-512": hashlib.sha512,
"SHA-512-SESS": hashlib.sha512,
}

def __init__(
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
) -> None:
self.username = to_bytes(username)
self.password = to_bytes(password)
self._num_401_responses = 0
yeraydiazdiaz marked this conversation as resolved.
Show resolved Hide resolved

async def __call__(
self, request: AsyncRequest, get_response: typing.Callable
) -> AsyncResponse:
response = await get_response(request)
if not (
StatusCode.is_client_error(response.status_code)
and "www-authenticate" in response.headers
):
self._num_401_responses = 0
return response
else:
self._num_401_responses += 1

header = response.headers["www-authenticate"]
try:
challenge = DigestAuthChallenge.from_header(header)
except ValueError:
raise ProtocolError("Malformed Digest authentication header")

if self._num_401_responses > 1:
return response

request.headers["Authorization"] = self._build_auth_header(request, challenge)
return await self(request, get_response)

def _build_auth_header(
self, request: AsyncRequest, challenge: "DigestAuthChallenge"
) -> str:
hash_func = self.ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm]

def digest(data: bytes) -> bytes:
return hash_func(data).hexdigest().encode()

A1 = b":".join((self.username, challenge.realm, self.password))

path = request.url.full_path.encode("utf-8")
A2 = b":".join((request.method.encode(), path))
# TODO: implement auth-int
HA2 = digest(A2)

nonce_count = 1 # TODO: implement nonce counting
nc_value = b"%08x" % nonce_count
cnonce = self._get_client_nonce(nonce_count, challenge.nonce)

HA1 = digest(A1)
if challenge.algorithm.lower().endswith("-sess"):
HA1 = digest(b":".join((HA1, challenge.nonce, cnonce)))

qop = self._resolve_qop(challenge.qop)
if qop is None:
digest_data = [HA1, challenge.nonce, HA2]
else:
digest_data = [challenge.nonce, nc_value, cnonce, qop, HA2]
key_digest = b":".join(digest_data)

format_args = {
"username": self.username,
"realm": challenge.realm,
"nonce": challenge.nonce,
"uri": path,
"response": digest(b":".join((HA1, key_digest))),
"algorithm": challenge.algorithm.encode(),
}
if challenge.opaque:
format_args["opaque"] = challenge.opaque
if qop:
format_args["qop"] = b"auth"
format_args["nc"] = nc_value
format_args["cnonce"] = cnonce

return "Digest " + self._get_header_value(format_args)

def _get_client_nonce(self, nonce_count: int, nonce: bytes) -> bytes:
s = str(nonce_count).encode()
s += nonce
s += time.ctime().encode()
s += os.urandom(8)

return hashlib.sha1(s).hexdigest()[:16].encode()

def _get_header_value(self, header_fields: typing.Dict[str, bytes]) -> str:
NON_QUOTED_FIELDS = ("algorithm", "qop", "nc")
QUOTED_TEMPLATE = '{}="{}"'
NON_QUOTED_TEMPLATE = "{}={}"

header_value = ""
for i, (field, value) in enumerate(header_fields.items()):
if i > 0:
header_value += ", "
template = (
QUOTED_TEMPLATE
if field not in NON_QUOTED_FIELDS
else NON_QUOTED_TEMPLATE
)
header_value += template.format(field, to_str(value))

return header_value

def _resolve_qop(self, qop: typing.Optional[bytes]) -> typing.Optional[bytes]:
if qop is None:
return None
qops = re.split(b", ?", qop)
if b"auth" in qops:
return b"auth"

if qops == [b"auth-int"]:
raise NotImplementedError("Digest auth-int support is not yet implemented")

raise ProtocolError(f'Unexpected qop value "{qop!r}" in digest auth')


class DigestAuthChallenge:
def __init__(
self,
realm: bytes,
nonce: bytes,
algorithm: str = None,
opaque: typing.Optional[bytes] = None,
qop: typing.Optional[bytes] = None,
) -> None:
self.realm = realm
self.nonce = nonce
self.algorithm = algorithm or "MD5"
self.opaque = opaque
self.qop = qop

@classmethod
def from_header(cls, header: str) -> "DigestAuthChallenge":
"""Returns a challenge from a Digest WWW-Authenticate header.
These take the form of:
`Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"`
"""
scheme, _, fields = header.partition(" ")
if scheme.lower() != "digest":
raise ValueError("Header does not start with 'Digest'")

header_dict: typing.Dict[str, str] = {}
for field in parse_http_list(fields):
key, value = field.strip().split("=")
yeraydiazdiaz marked this conversation as resolved.
Show resolved Hide resolved
header_dict[key] = unquote(value)

try:
return cls.from_header_dict(header_dict)
except KeyError as exc:
raise ValueError("Malformed Digest WWW-Authenticate header") from exc

@classmethod
def from_header_dict(cls, header_dict: dict) -> "DigestAuthChallenge":
realm = header_dict["realm"].encode()
nonce = header_dict["nonce"].encode()
qop = header_dict["qop"].encode() if "qop" in header_dict else None
opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None
algorithm = header_dict.get("algorithm")
return cls(
realm=realm, nonce=nonce, qop=qop, opaque=opaque, algorithm=algorithm
)
4 changes: 4 additions & 0 deletions httpx/models.py
Expand Up @@ -39,6 +39,9 @@
str_query_param,
)

if typing.TYPE_CHECKING:
from .middleware.base import BaseMiddleware # noqa: F401

PrimitiveData = typing.Optional[typing.Union[str, int, float, bool]]

URLTypes = typing.Union["URL", str]
Expand All @@ -61,6 +64,7 @@
AuthTypes = typing.Union[
typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
typing.Callable[["AsyncRequest"], "AsyncRequest"],
"BaseMiddleware",
]

AsyncRequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
Expand Down
10 changes: 10 additions & 0 deletions httpx/utils.py
Expand Up @@ -173,3 +173,13 @@ def get_logger(name: str) -> logging.Logger:

def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes:
return value.encode(encoding) if isinstance(value, str) else value


def to_str(str_or_bytes: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
return (
str_or_bytes if isinstance(str_or_bytes, str) else str_or_bytes.decode(encoding)
)


def unquote(value: str) -> str:
return value[1:-1] if value[0] == value[-1] == '"' else value