Digest auth middleware #305

Closed
48 commits
e2c07d8
Split middlewares to a package
Sep 2, 2019
61d711b
Add digest auth middleware and HTTPDigestAuth model
Sep 2, 2019
85d910e
Move safe_encode to utils
Sep 2, 2019
9f307b8
Handle nonce count
Sep 2, 2019
3fabceb
Fix linting issues
Sep 2, 2019
cee7ff2
Add test for Digest auth
Sep 2, 2019
f75f450
Parametrize Digest auth test with all supported algorithms
Sep 2, 2019
6f7f659
Raise when Digest header cannot be parsed
Sep 2, 2019
e35b8a4
Add HTTPDigestAuth to AuthTypes
Sep 2, 2019
b499ce8
Fix linting
Sep 2, 2019
b06ffda
Add type annotation for class variable
Sep 3, 2019
e88b4c3
Do not use `safe_encode` on known types
Sep 3, 2019
2d19616
Remove unnecessary encoding argument
Sep 3, 2019
75fd9bc
Inline function
Sep 3, 2019
9831ad6
Remove unnecessary UTF-8 specific encodings
Sep 3, 2019
672f80b
Added `to_str` util function
Sep 3, 2019
2e9a221
Allow space separated qop values
Sep 3, 2019
2e70dee
Add DigestAuthChallenge helper class
Sep 3, 2019
1cf0df1
Encode username and password on init
Sep 3, 2019
bfd8a71
Refactor parsing of the digest header into DigestAuthChallenge
Sep 3, 2019
4067181
Fix calculation of HA1 for `-sess` prefix algorithms
Sep 3, 2019
8e47ca6
Use !r in exception message
Sep 3, 2019
63589d1
Refactor nonce count tracking
Sep 3, 2019
6b11934
Rename variable
Sep 3, 2019
2ee35ab
Refactor client nonce generation
Sep 3, 2019
0ec47ca
Add typing to intermediate variable
Sep 3, 2019
49d8723
Remove unhelpful comments
Sep 3, 2019
2d327bc
Update README and documentation
Sep 3, 2019
cd5f6aa
Handle failed auth
Sep 3, 2019
22b0451
Raise ProtocolError on malformed requests
Sep 4, 2019
83b76a4
Use parse_http_list to correctly extract challenge fields
Sep 4, 2019
88dba56
Add tests covering digest auth malformed headers and missing qop
Sep 4, 2019
0af458a
Fix linting
Sep 4, 2019
82d5312
Add extra cases of malformed headers
Sep 4, 2019
2f95e02
Rename HTTPDigestAuth > DigestAuth
Sep 4, 2019
4e080f9
Remove DigestAuth model and accept BaseMiddleware instance as auth
Sep 4, 2019
c7eefa1
Use utf-8 encoding on BasicAuth
Sep 4, 2019
d580f97
Use to_bytes on basic auth
Sep 5, 2019
99b13de
Remove unnecessary type check
Sep 5, 2019
9bb823b
Avoid circular import on type checking
Sep 5, 2019
e436bab
Add helper class holding per-request Digest auth state
Sep 5, 2019
b2b7078
Hold nonce count globally
Sep 6, 2019
ed89de7
Do not quote algorithm, qop, and nc
Sep 6, 2019
36d8a33
Fix variable name
Sep 6, 2019
6f486f9
Handle ', ' separated qop values
Sep 6, 2019
33c072f
Add LRUDict util and use it for the global nonce count
Sep 7, 2019
5ee7ade
Fix merge conflicts with master
florimondmanca Sep 7, 2019
8a4ec2e
Merge pull request #1 from encode/digest-auth-middleware
yeraydiazdiaz Sep 8, 2019
2 changes: 1 addition & 1 deletion README.md
@@ -56,7 +56,7 @@ Plus all the standard features of `requests`...
 * Keep-Alive & Connection Pooling
 * Sessions with Cookie Persistence
 * Browser-style SSL Verification
-* Basic/Digest Authentication *(Digest is still TODO)*
+* Basic/Digest Authentication
 * Elegant Key/Value Cookies
 * Automatic Decompression
 * Automatic Content Decoding
22 changes: 22 additions & 0 deletions docs/quickstart.md
@@ -379,3 +379,25 @@ value to be more or less strict:
```python
>>> httpx.get('https://github.com/', timeout=0.001)
```

## Authentication

HTTPX supports Basic and Digest HTTP authentication.

To provide Basic authentication credentials, pass a tuple of
plaintext `str` or `bytes` objects as the `auth` argument to the request
functions:

```python
>>> httpx.get("https://example.com", auth=("my_user", "password123"))
```

To provide credentials for Digest authentication, instantiate
a `DigestAuth` object with the plaintext username and
password as arguments. This object can then be passed as the `auth` argument
to the request methods as above:

```python
>>> auth = httpx.DigestAuth("my_user", "password123")
>>> httpx.get("https://example.com", auth=auth)
```
4 changes: 4 additions & 0 deletions httpx/__init__.py
@@ -40,6 +40,8 @@
    TooManyRedirects,
    WriteTimeout,
)
from .middleware.basic_auth import BasicAuthMiddleware as BasicAuth
from .middleware.digest_auth import DigestAuthMiddleware as DigestAuth
from .models import (
    URL,
    AsyncRequest,
@@ -125,6 +127,8 @@
    "CookieTypes",
    "Headers",
    "HeaderTypes",
    "BasicAuth",
    "DigestAuth",
    "Origin",
    "QueryParams",
    "QueryParamTypes",
6 changes: 4 additions & 2 deletions httpx/client.py
@@ -211,14 +211,16 @@ def _get_auth_middleware(
     ) -> typing.Optional[BaseMiddleware]:
         if isinstance(auth, tuple):
             return BasicAuthMiddleware(username=auth[0], password=auth[1])
-
-        if callable(auth):
+        elif isinstance(auth, BaseMiddleware):
+            return auth
+        elif callable(auth):
             return CustomAuthMiddleware(auth=auth)
 
         if auth is not None:
             raise TypeError(
                 'When specified, "auth" must be a (username, password) tuple or '
                 "a callable with signature (AsyncRequest) -> AsyncRequest "
+                "or a subclass of BaseMiddleware "
                 f"(got {auth!r})"
             )
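
The order of the branches in this hunk matters. Middleware instances define `__call__`, so they are themselves callable and would be wrapped by the `callable(auth)` branch if that check came first. Below is a minimal, self-contained sketch of that dispatch. The names `Middleware`, `BasicAuth`, `CustomAuth`, and `resolve_auth` are hypothetical stand-ins for illustration, not HTTPX's own classes, and requests are modelled as plain strings.

```python
import typing


class Middleware:
    """Stand-in for a middleware base class; instances are callable."""

    def __call__(self, request: str) -> str:
        return request


class BasicAuth(Middleware):
    def __init__(self, username: str, password: str) -> None:
        self.username = username
        self.password = password


class CustomAuth(Middleware):
    """Wraps a plain callable of the form (request) -> request."""

    def __init__(self, func: typing.Callable[[str], str]) -> None:
        self.func = func


def resolve_auth(auth: typing.Any) -> typing.Optional[Middleware]:
    if isinstance(auth, tuple):
        return BasicAuth(username=auth[0], password=auth[1])
    # Must run before the callable() check: Middleware instances are callable too.
    if isinstance(auth, Middleware):
        return auth
    if callable(auth):
        return CustomAuth(auth)
    if auth is not None:
        raise TypeError(f"Unsupported auth value: {auth!r}")
    return None
```

With this ordering, a ready-made middleware instance is returned unchanged, while a bare function is still wrapped.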

226 changes: 226 additions & 0 deletions httpx/middleware/digest_auth.py
@@ -0,0 +1,226 @@
import hashlib
import os
import re
import time
import typing
from urllib.request import parse_http_list

from ..exceptions import ProtocolError
from ..models import AsyncRequest, AsyncResponse, StatusCode
from ..utils import DefaultLRUDict, to_bytes, to_str, unquote
from .base import BaseMiddleware


class DigestAuthMiddleware(BaseMiddleware):
    per_nonce_count: typing.Dict[bytes, int] = DefaultLRUDict(1_000, lambda: 0)

    def __init__(
        self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
    ) -> None:
        self.username = to_bytes(username)
        self.password = to_bytes(password)

    async def __call__(
        self, request: AsyncRequest, get_response: typing.Callable
    ) -> AsyncResponse:
        request_middleware = _RequestDigestAuth(
            username=self.username,
            password=self.password,
            per_nonce_count=self.per_nonce_count,
        )
        return await request_middleware(request=request, get_response=get_response)


class _RequestDigestAuth(BaseMiddleware):
    ALGORITHM_TO_HASH_FUNCTION: typing.Dict[str, typing.Callable] = {
        "MD5": hashlib.md5,
        "MD5-SESS": hashlib.md5,
        "SHA": hashlib.sha1,
        "SHA-SESS": hashlib.sha1,
        "SHA-256": hashlib.sha256,
        "SHA-256-SESS": hashlib.sha256,
        "SHA-512": hashlib.sha512,
        "SHA-512-SESS": hashlib.sha512,
    }

    def __init__(
        self, username: bytes, password: bytes, per_nonce_count: typing.Dict[bytes, int]
    ) -> None:
        self.username = username
        self.password = password
        self.per_nonce_count = per_nonce_count
        self.num_401_responses = 0

    async def __call__(
        self, request: AsyncRequest, get_response: typing.Callable
    ) -> AsyncResponse:
        response = await get_response(request)
        if not (
            StatusCode.is_client_error(response.status_code)
            and "www-authenticate" in response.headers
        ):
            self.num_401_responses = 0
            return response

        header = response.headers["www-authenticate"]
        try:
            challenge = DigestAuthChallenge.from_header(header)
        except ValueError:
            raise ProtocolError("Malformed Digest authentication header")

        if self._previous_auth_failed(challenge):
            return response

        request.headers["Authorization"] = self._build_auth_header(request, challenge)
        return await self(request, get_response)

    def _previous_auth_failed(self, challenge: "DigestAuthChallenge") -> bool:
        """Returns whether the previous auth failed.

        This is fairly subtle as the server may return a 401 with the *same* nonce
        as that of our first attempt. If it is different, however, we know the
        server has rejected our credentials.
        """
        self.num_401_responses += 1

        return (
            challenge.nonce not in self.per_nonce_count and self.num_401_responses > 1
        )

    def _build_auth_header(
        self, request: AsyncRequest, challenge: "DigestAuthChallenge"
    ) -> str:
        hash_func = self.ALGORITHM_TO_HASH_FUNCTION[challenge.algorithm]

        def digest(data: bytes) -> bytes:
            return hash_func(data).hexdigest().encode()

        A1 = b":".join((self.username, challenge.realm, self.password))

        path = request.url.full_path.encode("utf-8")
        A2 = b":".join((request.method.encode(), path))
        # TODO: implement auth-int
        HA2 = digest(A2)

        nonce_count, nc_value = self._get_nonce_count(challenge.nonce)
        cnonce = self._get_client_nonce(nonce_count, challenge.nonce)

        HA1 = digest(A1)
        if challenge.algorithm.lower().endswith("-sess"):
            HA1 = digest(b":".join((HA1, challenge.nonce, cnonce)))

        qop = self._resolve_qop(challenge.qop)
        if qop is None:
            # Without qop the response is H(HA1:nonce:HA2) (RFC 2069 form).
            # HA1 is prepended when "response" is built below, so it must not
            # be included here as well.
            digest_data = [challenge.nonce, HA2]
        else:
            digest_data = [challenge.nonce, nc_value, cnonce, qop, HA2]
        key_digest = b":".join(digest_data)

        format_args = {
            "username": self.username,
            "realm": challenge.realm,
            "nonce": challenge.nonce,
            "uri": path,
            "response": digest(b":".join((HA1, key_digest))),
            "algorithm": challenge.algorithm.encode(),
        }
        if challenge.opaque:
            format_args["opaque"] = challenge.opaque
        if qop:
            format_args["qop"] = b"auth"
            format_args["nc"] = nc_value
            format_args["cnonce"] = cnonce

        return "Digest " + self._get_header_value(format_args)

    def _get_client_nonce(self, nonce_count: int, nonce: bytes) -> bytes:
        s = str(nonce_count).encode()
        s += nonce
        s += time.ctime().encode()
        s += os.urandom(8)

        return hashlib.sha1(s).hexdigest()[:16].encode()

    def _get_nonce_count(self, nonce: bytes) -> typing.Tuple[int, bytes]:
        """Returns the number of requests made with the same server provided
        nonce value along with its 8-digit hex representation."""
        self.per_nonce_count[nonce] += 1
        return self.per_nonce_count[nonce], b"%08x" % self.per_nonce_count[nonce]

    def _get_header_value(self, header_fields: typing.Dict[str, bytes]) -> str:
        NON_QUOTED_FIELDS = ("algorithm", "qop", "nc")
        QUOTED_TEMPLATE = '{}="{}"'
        NON_QUOTED_TEMPLATE = "{}={}"

        header_value = ""
        for i, (field, value) in enumerate(header_fields.items()):
            if i > 0:
                header_value += ", "
            template = (
                QUOTED_TEMPLATE
                if field not in NON_QUOTED_FIELDS
                else NON_QUOTED_TEMPLATE
            )
            header_value += template.format(field, to_str(value))

        return header_value

    def _resolve_qop(self, qop: typing.Optional[bytes]) -> typing.Optional[bytes]:
        if qop is None:
            return None
        qops = re.split(b", ?", qop)
        if b"auth" in qops:
            return b"auth"

        if qops == [b"auth-int"]:
            raise NotImplementedError("Digest auth-int support is not yet implemented")

        raise ProtocolError(f'Unexpected qop value "{qop!r}" in digest auth')


class DigestAuthChallenge:
    def __init__(
        self,
        realm: bytes,
        nonce: bytes,
        algorithm: typing.Optional[str] = None,
        opaque: typing.Optional[bytes] = None,
        qop: typing.Optional[bytes] = None,
    ) -> None:
        self.realm = realm
        self.nonce = nonce
        self.algorithm = algorithm or "MD5"
        self.opaque = opaque
        self.qop = qop

    @classmethod
    def from_header(cls, header: str) -> "DigestAuthChallenge":
        """Returns a challenge from a Digest WWW-Authenticate header.

        These take the form of:
        `Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"`
        """
        scheme, _, fields = header.partition(" ")
        if scheme.lower() != "digest":
            raise ValueError("Header does not start with 'Digest'")

        header_dict: typing.Dict[str, str] = {}
        for field in parse_http_list(fields):
            # Split on the first "=" only: values such as base64-padded nonces
            # may themselves contain "=". A field without "=" still raises
            # ValueError on unpacking.
            key, value = field.strip().split("=", 1)
            header_dict[key] = unquote(value)

        try:
            return cls.from_header_dict(header_dict)
        except KeyError as exc:
            raise ValueError("Malformed Digest WWW-Authenticate header") from exc

    @classmethod
    def from_header_dict(cls, header_dict: dict) -> "DigestAuthChallenge":
        realm = header_dict["realm"].encode()
        nonce = header_dict["nonce"].encode()
        qop = header_dict["qop"].encode() if "qop" in header_dict else None
        opaque = header_dict["opaque"].encode() if "opaque" in header_dict else None
        algorithm = header_dict.get("algorithm")
        return cls(
            realm=realm, nonce=nonce, qop=qop, opaque=opaque, algorithm=algorithm
        )
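
As a rough check of the flow in this file, the sketch below parses a challenge header with `parse_http_list` and then computes the `response` value for the `qop=auth` case by hand. It uses only the standard library. The credentials, nonce, and client nonce are the example values from RFC 2617, section 3.5; the helper `md5_hex` and the plain-dict `challenge` are illustrative assumptions, not part of this PR.

```python
import hashlib
from urllib.request import parse_http_list

# Example challenge using the values from RFC 2617, section 3.5.
header = (
    'Digest realm="testrealm@host.com", qop="auth,auth-int", '
    'nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093", '
    'opaque="5ccc069c403ebaf9f0171e9517f40e41"'
)

scheme, _, fields = header.partition(" ")
challenge = {}
for field in parse_http_list(fields):
    # parse_http_list does not split on the comma inside the quoted qop value.
    key, value = field.strip().split("=", 1)
    challenge[key] = value.strip('"')


def md5_hex(data: str) -> str:
    """Hypothetical helper: hex MD5 digest of a text string."""
    return hashlib.md5(data.encode()).hexdigest()


username, password = "Mufasa", "Circle Of Life"
method, uri = "GET", "/dir/index.html"
nc, cnonce, qop = "00000001", "0a4f113b", "auth"

ha1 = md5_hex(f"{username}:{challenge['realm']}:{password}")
ha2 = md5_hex(f"{method}:{uri}")
# qop=auth form: H(HA1:nonce:nc:cnonce:qop:HA2)
response = md5_hex(f"{ha1}:{challenge['nonce']}:{nc}:{cnonce}:{qop}:{ha2}")
print(response)
```

The RFC lists `6629fae49393a05397450978507c4ef1` as the response for these inputs, which makes it a convenient fixed point when testing a Digest implementation.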
4 changes: 4 additions & 0 deletions httpx/models.py
@@ -39,6 +39,9 @@
    str_query_param,
)

if typing.TYPE_CHECKING:
    from .middleware.base import BaseMiddleware  # noqa: F401

PrimitiveData = typing.Optional[typing.Union[str, int, float, bool]]

URLTypes = typing.Union["URL", str]
@@ -61,6 +64,7 @@
AuthTypes = typing.Union[
    typing.Tuple[typing.Union[str, bytes], typing.Union[str, bytes]],
    typing.Callable[["AsyncRequest"], "AsyncRequest"],
    "BaseMiddleware",
]

AsyncRequestData = typing.Union[dict, str, bytes, typing.AsyncIterator[bytes]]
51 changes: 51 additions & 0 deletions httpx/utils.py
@@ -173,3 +173,54 @@ def get_logger(name: str) -> logging.Logger:

def to_bytes(value: typing.Union[str, bytes], encoding: str = "utf-8") -> bytes:
    return value.encode(encoding) if isinstance(value, str) else value


def to_str(str_or_bytes: typing.Union[str, bytes], encoding: str = "utf-8") -> str:
    return (
        str_or_bytes if isinstance(str_or_bytes, str) else str_or_bytes.decode(encoding)
    )


def unquote(value: str) -> str:
    # The length check avoids an IndexError on an empty value (e.g. `key=`).
    if len(value) >= 2 and value[0] == value[-1] == '"':
        return value[1:-1]
    return value


class LRUDict(dict):
    """Subclass of dict keeping only the N last items inserted.

    When setting a new item the oldest element on the dict is deleted.
    """

    def __init__(
        self,
        max_size: int,
        *args: typing.Iterable[typing.Tuple[typing.Any, typing.Any]],
        **kwargs: typing.Any,
    ) -> None:
        if len(args) + len(kwargs) > max_size:
            raise ValueError("Cannot initialize with more elements than the maximum")
        self.max_size = max_size
        super().__init__(*args, **kwargs)

    def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
        super().__setitem__(key, value)
        if len(self) > self.max_size:
            del self[list(self)[0]]


class DefaultLRUDict(LRUDict):
    def __init__(
        self,
        max_size: int,
        default_factory: typing.Callable = None,
        *args: typing.Iterable[typing.Tuple[typing.Any, typing.Any]],
        **kwargs: typing.Any,
    ) -> None:
        self.default_factory = default_factory
        super().__init__(max_size, *args, **kwargs)

    def __missing__(self, key: typing.Any) -> typing.Any:
        if self.default_factory is None:
            raise KeyError(key)
        self[key] = self.default_factory()
        return self[key]
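
To illustrate how a bounded, defaulting dict behaves when used as a per-nonce counter, here is a small self-contained sketch. `BoundedCounter` is a hypothetical stand-in that folds the two classes above into one for brevity; it is not the code from this PR, only an approximation of the same behaviour.

```python
import typing


class BoundedCounter(dict):
    """Hypothetical stand-in: missing keys default to 0, and the oldest
    inserted key is evicted once the dict grows past max_size."""

    def __init__(self, max_size: int) -> None:
        super().__init__()
        self.max_size = max_size

    def __setitem__(self, key: typing.Any, value: int) -> None:
        super().__setitem__(key, value)
        if len(self) > self.max_size:
            # dicts preserve insertion order, so the first key is the oldest.
            del self[next(iter(self))]

    def __missing__(self, key: typing.Any) -> int:
        self[key] = 0
        return self[key]


counts = BoundedCounter(max_size=2)
counts[b"nonce-a"] += 1
counts[b"nonce-a"] += 1
counts[b"nonce-b"] += 1
counts[b"nonce-c"] += 1  # a third key exceeds max_size, so b"nonce-a" is evicted

print(list(counts))
print(b"%08x" % counts[b"nonce-b"])  # the 8-digit hex form used for `nc`
```

Note that eviction follows insertion order rather than recency of use: updating an existing key does not move it to the back, so a long-lived nonce can still be evicted while it is in active use.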