Skip to content

Commit

Permalink
Merge pull request #426 from mjcaley/client-cert-password
Browse files Browse the repository at this point in the history
Better handling of client private key
  • Loading branch information
mjcaley committed Oct 10, 2023
2 parents 9fc6606 + c7837c2 commit 74ad8fa
Show file tree
Hide file tree
Showing 9 changed files with 357 additions and 20 deletions.
6 changes: 4 additions & 2 deletions aiospamc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import ssl
import sys
from enum import Enum
from getpass import getuser
from getpass import getpass, getuser
from io import BufferedReader
from pathlib import Path
from typing import Optional
Expand Down Expand Up @@ -176,7 +176,9 @@ def add_client_cert(
if self._ssl is False:
self._ssl = True
self.add_verify(True)
self._ssl_builder.add_client(cert, key, password)
self._ssl_builder.add_client(
cert, key, lambda: password or getpass("Private key password")
)

return self

Expand Down
10 changes: 7 additions & 3 deletions aiospamc/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import asyncio
import ssl
from enum import Enum, auto
from getpass import getpass
from pathlib import Path
from typing import Any, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union

import certifi
import loguru
Expand Down Expand Up @@ -499,13 +500,16 @@ def add_default_ca(self) -> SSLContextBuilder:
return self

def add_client(
self, file: Path, key: Optional[Path] = None, password: Optional[str] = None
self,
file: Path,
key: Optional[Path] = None,
password: Optional[Callable[[], Union[str, bytes, bytearray]]] = None,
) -> SSLContextBuilder:
"""Add client certificate.
:param file: Path to the client certificate.
:param key: Path to the key.
:param password: Password of the key.
:param password: Callable that returns the password, if any.
"""

self._context.load_cert_chain(file, key, password)
Expand Down
20 changes: 17 additions & 3 deletions aiospamc/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import ssl
from functools import partial
from pathlib import Path
from typing import Any, Dict, Optional, SupportsBytes, Tuple, Union, cast

Expand Down Expand Up @@ -118,16 +119,29 @@ def add_client_cert(
if not self._ssl:
self.add_verify(True)

def pwd_check(password: Optional[str] = None) -> str:
"""Return the password, otherwise throw an exception.
:return: The password.
:raises ValueError: When the password is `None`.
"""

if password is None:
raise ValueError("Private key password not provided")
return password

if isinstance(cert, Path):
self._ssl_builder.add_client(cert)
self._ssl_builder.add_client(cert, password=partial(pwd_check, None))
elif isinstance(cert, tuple) and len(cert) == 2:
client, key = cast(Tuple[Path, Optional[Path]], cert)
self._ssl_builder.add_client(client, key)
self._ssl_builder.add_client(client, key, password=partial(pwd_check, None))
elif isinstance(cert, tuple) and len(cert) == 3:
client, key, password = cast(
Tuple[Path, Optional[Path], Optional[str]], cert
)
self._ssl_builder.add_client(client, key, password)
self._ssl_builder.add_client(
client, key, password=partial(pwd_check, password)
)
else:
raise TypeError("Unexepected value")

Expand Down
38 changes: 35 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@
import sys
import threading
from asyncio import StreamReader, StreamWriter
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
from shutil import which
from socket import gethostbyname
from subprocess import PIPE, STDOUT, Popen, TimeoutExpired

import pytest
import trustme
from cryptography.hazmat.primitives.serialization import (
BestAvailableEncryption,
Encoding,
PrivateFormat,
load_pem_private_key,
)
from pytest_mock import MockerFixture

from aiospamc.header_values import ContentLengthValue
Expand Down Expand Up @@ -321,22 +327,43 @@ def server_cert_and_key(server_cert, tmp_path_factory: pytest.TempdirFactory):
yield cert_file, key_file


@pytest.fixture(scope="session")
def client_private_key_password():
yield b"password"


@pytest.fixture(scope="session")
def client_cert_and_key(
ca, hostname, ip_address, tmp_path_factory: pytest.TempdirFactory
ca,
hostname,
ip_address,
tmp_path_factory: pytest.TempdirFactory,
client_private_key_password,
):
tmp_path = tmp_path_factory.mktemp("client_certs")
cert_file = tmp_path / "client.cert"
key_file = tmp_path / "client.key"
cert_key_file = tmp_path / "client_cert_key.pem"
enc_key_file = tmp_path / "client_enc_key.pem"

cert: trustme.LeafCert = ca.issue_cert(hostname, ip_address)

cert.private_key_and_cert_chain_pem.write_to_path(cert_key_file)
cert_file.write_bytes(b"".join([blob.bytes() for blob in cert.cert_chain_pems]))
cert.private_key_pem.write_to_path(key_file)

yield cert_file, key_file, cert_key_file
client_private_key = load_pem_private_key(
cert.private_key_pem.bytes(),
None,
)
client_enc_key_bytes = client_private_key.private_bytes(
Encoding.PEM,
PrivateFormat.PKCS8,
BestAvailableEncryption(client_private_key_password),
)
enc_key_file.write_bytes(client_enc_key_bytes)

yield cert_file, key_file, cert_key_file, enc_key_file


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -364,6 +391,11 @@ def client_key_path(client_cert_and_key):
yield client_cert_and_key[1]


@pytest.fixture(scope="session")
def client_encrypted_key_path(client_cert_and_key):
yield client_cert_and_key[3]


@dataclass
class ServerResponse:
response: bytes = b""
Expand Down
14 changes: 11 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest
from loguru import logger
from pytest_mock import MockerFixture
from pytest_mock import MockerFixture, MockFixture
from typer.testing import CliRunner

import aiospamc
Expand Down Expand Up @@ -97,11 +97,19 @@ def test_cli_builder_add_ca_cert_not_found():
CliClientBuilder().with_connection().add_ca_cert(Path("doesnt_exist")).build()


def test_cli_builder_add_ca_client(client_cert_path, client_key_path):
def test_cli_builder_add_ca_client(
mocker: MockFixture,
client_cert_path,
client_encrypted_key_path,
client_private_key_password,
):
mocker.patch("getpass.getpass", return_value=client_private_key_password)
c = (
CliClientBuilder()
.with_connection()
.add_client_cert(client_cert_path, client_key_path, "password")
.add_client_cert(
client_cert_path, client_encrypted_key_path, client_private_key_password
)
.build()
)

Expand Down
16 changes: 12 additions & 4 deletions tests/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,18 @@ def test_ssl_context_builder_add_ca_path_not_found():


def test_ssl_context_builder_add_client_cert(
mocker: MockerFixture, client_cert_path, client_key_path
mocker: MockerFixture,
client_cert_path,
client_key_path,
client_private_key_password,
):
builder = SSLContextBuilder()
certs_spy = mocker.spy(builder._context, "load_cert_chain")
s = builder.add_client(client_cert_path, client_key_path, "password").build()

assert (client_cert_path, client_key_path, "password") == certs_spy.call_args.args
password_call = lambda: client_private_key_password
s = builder.add_client(client_cert_path, client_key_path, password_call).build()

assert (
client_cert_path,
client_key_path,
password_call,
) == certs_spy.call_args.args
43 changes: 41 additions & 2 deletions tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,14 @@ def test_frontend_builder_add_client_cert_and_key(client_cert_path, client_key_p


def test_frontend_builder_add_client_cert_key_and_password(
client_cert_path, client_key_path
client_cert_path, client_encrypted_key_path, client_private_key_password
):
f = (
FrontendClientBuilder()
.with_connection()
.add_client_cert((client_cert_path, client_key_path, "password"))
.add_client_cert(
(client_cert_path, client_encrypted_key_path, client_private_key_password)
)
.build()
)

Expand Down Expand Up @@ -313,6 +315,43 @@ async def test_ping_returns_response_ssl_client(
assert isinstance(result, Response)


async def test_ping_returns_response_ssl_client_encrypted_private_key(
fake_tcp_ssl_client,
spam,
ca_cert_path,
client_cert_path,
client_encrypted_key_path,
client_private_key_password,
):
_, host, port = fake_tcp_ssl_client
result = await ping(
host=host,
port=port,
verify=ca_cert_path,
cert=(client_cert_path, client_encrypted_key_path, client_private_key_password),
)

assert isinstance(result, Response)


async def test_ping_returns_response_ssl_client_encrypted_private_key_raises_error(
fake_tcp_ssl_client,
spam,
ca_cert_path,
client_cert_path,
client_encrypted_key_path,
):
_, host, port = fake_tcp_ssl_client

with pytest.raises(ValueError):
await ping(
host=host,
port=port,
verify=ca_cert_path,
cert=(client_cert_path, client_encrypted_key_path),
)


async def test_tell_request_with_default_parameters(fake_tcp_server, spam, mocker):
_, host, port = fake_tcp_server
req_spy = mocker.spy(Client, "request")
Expand Down

0 comments on commit 74ad8fa

Please sign in to comment.