Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
refactor: convert AutopushSettings to attrs
Browse files Browse the repository at this point in the history
Issue #632
  • Loading branch information
pjenvey committed Jun 30, 2017
1 parent 2594f9e commit 1616d24
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 177 deletions.
6 changes: 4 additions & 2 deletions autopush/http.py
Expand Up @@ -41,8 +41,10 @@
from autopush.web.webpush import WebPushHandler
from autopush.websocket import (
NotificationHandler,
PushServerProtocol,
RouterHandler,
)
from autopush.websocket import PushServerProtocol # noqa

APHandlers = Sequence[Tuple[str, Type[BaseHandler]]]
CycloneLogger = Callable[[BaseHandler], None]
Expand Down Expand Up @@ -122,8 +124,8 @@ def for_handler(cls,
handler_cls)) # pragma: nocover

@classmethod
def _for_handler(cls, **kwargs):
# type: (**Any) -> BaseHTTPFactory
def _for_handler(cls, ap_settings, **kwargs):
# type: (AutopushSettings, **Any) -> BaseHTTPFactory
"""Create an instance w/ default kwargs for for_handler"""
raise NotImplementedError # pragma: nocover

Expand Down
2 changes: 1 addition & 1 deletion autopush/main.py
Expand Up @@ -205,7 +205,7 @@ def from_argparse(cls, ns):
endpoint_scheme=ns.endpoint_scheme,
endpoint_hostname=ns.endpoint_hostname or ns.hostname,
endpoint_port=ns.endpoint_port,
enable_cors=not ns.no_cors,
cors=not ns.no_cors,
bear_hash_key=ns.auth_key,
proxy_protocol_port=ns.proxy_protocol_port,
)
Expand Down
2 changes: 1 addition & 1 deletion autopush/main_argparse.py
Expand Up @@ -17,7 +17,7 @@ def add_shared_args(parser):
" Stream Name", default="", env_var="STREAM_NAME",
type=str)
parser.add_argument('--crypto_key', help="Crypto key for tokens",
default=[], env_var="CRYPTO_KEY", type=str,
default=None, env_var="CRYPTO_KEY", type=str,
action="append")
parser.add_argument('--key_hash', help="Key to hash IDs for storage",
default="", env_var="KEY_HASH", type=str)
Expand Down
283 changes: 117 additions & 166 deletions autopush/settings.py
Expand Up @@ -3,18 +3,29 @@
import socket
from argparse import Namespace # noqa
from hashlib import sha256
from typing import Any # noqa
from typing import ( # noqa
Any,
Dict,
List,
Optional,
Union
)

from attr import (
attrs,
attrib,
Factory
)
from cryptography.fernet import Fernet, MultiFernet
from cryptography.hazmat.primitives import constant_time
from twisted.web.client import _HTTP11ClientFactory

import autopush.db as db
from autopush.exceptions import (
InvalidSettings,
InvalidTokenException,
VapidAuthException
)
from autopush.types import JSONDict # noqa
from autopush.utils import (
CLIENT_SHA256_RE,
canonical_url,
Expand All @@ -27,177 +38,131 @@
from autopush.crypto_key import CryptoKey, CryptoKeyException


class QuietClientFactory(_HTTP11ClientFactory):
"""Silence the start/stop factory messages."""
noisy = False
def _init_crypto_key(ck):
# type: (Optional[Union[str, List[str]]]) -> List[str]
"""Provide a default or ensure the provided's a list"""
if ck is None:
return [Fernet.generate_key()]
return ck if isinstance(ck, list) else [ck]


@attrs
class AutopushSettings(object):
"""Main Autopush Settings Object"""
options = ["crypto_key", "hostname", "min_ping_interval",
"max_data"]

def __init__(self,
crypto_key=None,
datadog_api_key=None,
datadog_app_key=None,
datadog_flush_interval=None,
hostname=None,
port=None,
router_scheme=None,
router_hostname=None,
router_port=None,
endpoint_scheme=None,
endpoint_hostname=None,
endpoint_port=None,
proxy_protocol_port=None,
memusage_port=None,
router_conf=None,
router_tablename="router",
router_read_throughput=5,
router_write_throughput=5,
storage_tablename="storage",
storage_read_throughput=5,
storage_write_throughput=5,
message_tablename="message",
message_read_throughput=5,
message_write_throughput=5,
statsd_host="localhost",
statsd_port=8125,
resolve_hostname=False,
max_data=4096,
# Reflected up from UDP Router
wake_timeout=0,
env='development',
enable_cors=False,
hello_timeout=0,
bear_hash_key=None,
preflight_uaid="deadbeef00000000deadbeef00000000",
ami_id=None,
msg_limit=100,
debug=False,
connect_timeout=0.5,
ssl_key=None,
ssl_cert=None,
ssl_dh_param=None,
router_ssl_key=None,
router_ssl_cert=None,
client_certs=None,
auto_ping_interval=None,
auto_ping_timeout=None,
max_connections=None,
close_handshake_timeout=None,
):
"""Initialize the Settings object
Upon creation, the HTTP agent will initialize, all configured routers
will be setup and started, logging will be started, and the database
will have a preflight check done.
"""
self.debug = debug

self.connect_timeout = connect_timeout

if not crypto_key:
crypto_key = [Fernet.generate_key()]
if not isinstance(crypto_key, list):
crypto_key = [crypto_key]
self.update(crypto_key=crypto_key)
self.crypto_key = crypto_key

if bear_hash_key is None:
bear_hash_key = []
if not isinstance(bear_hash_key, list):
bear_hash_key = [bear_hash_key]
self.bear_hash_key = bear_hash_key

self.max_data = max_data

debug = attrib(default=False) # type: bool

fernet = attrib(init=False) # type: MultiFernet
_crypto_key = attrib(
convert=_init_crypto_key, default=None) # type: List[str]

bear_hash_key = attrib(default=Factory(list)) # type: List[str]

hostname = attrib(default=None) # type: Optional[str]
port = attrib(default=None) # type: Optional[int]
_resolve_hostname = attrib(default=False) # type: bool

router_scheme = attrib(default=None) # type: Optional[str]
router_hostname = attrib(default=None) # type: Optional[str]
router_port = attrib(default=None) # type: Optional[int]

endpoint_scheme = attrib(default=None) # type: Optional[str]
endpoint_hostname = attrib(default=None) # type: Optional[str]
endpoint_port = attrib(default=None) # type: Optional[int]

proxy_protocol_port = attrib(default=None) # type: Optional[int]
memusage_port = attrib(default=None) # type: Optional[int]

statsd_host = attrib(default="localhost") # type: str
statsd_port = attrib(default=8125) # type: int

datadog_api_key = attrib(default=None) # type: Optional[str]
datadog_app_key = attrib(default=None) # type: Optional[str]
datadog_flush_interval = attrib(default=None) # type: Optional[int]

router_tablename = attrib(default="router") # type: str
router_read_throughput = attrib(default=5) # type: int
router_write_throughput = attrib(default=5) # type: int
storage_tablename = attrib(default="storage") # type: str
storage_read_throughput = attrib(default=5) # type: int
storage_write_throughput = attrib(default=5) # type: int
message_tablename = attrib(default="message") # type: str
message_read_throughput = attrib(default=5) # type: int
message_write_throughput = attrib(default=5) # type: int
preflight_uaid = attrib(
default="deadbeef00000000deadbeef00000000") # type: str

ssl_key = attrib(default=None) # type: Optional[str]
ssl_cert = attrib(default=None) # type: Optional[str]
ssl_dh_param = attrib(default=None) # type: Optional[str]

router_ssl_key = attrib(default=None) # type: Optional[str]
router_ssl_cert = attrib(default=None) # type: Optional[str]

client_certs = attrib(default=None) # type: Optional[Dict[str, str]]

router_url = attrib(init=False) # type: str
endpoint_url = attrib(init=False) # type: str
ws_url = attrib(init=False) # type: str

router_conf = attrib(default=Factory(dict)) # type: JSONDict

# twisted Agent's connectTimeout
connect_timeout = attrib(default=0.5) # type: float
max_data = attrib(default=4096) # type: int
env = attrib(default='development') # type: str
ami_id = attrib(default=None) # type: Optional[str]
cors = attrib(default=False) # type: bool

hello_timeout = attrib(default=0) # type: int
# Force timeout in idle seconds
wake_timeout = attrib(default=0) # type: int
msg_limit = attrib(default=100) # type: int
auto_ping_interval = attrib(default=None) # type: Optional[int]
auto_ping_timeout = attrib(default=None) # type: Optional[int]
max_connections = attrib(default=None) # type: Optional[int]
close_handshake_timeout = attrib(default=None) # type: Optional[int]

# Generate messages per legacy rules, only used for testing to
# generate legacy data.
_notification_legacy = attrib(default=False) # type: bool

def __attrs_post_init__(self):
"""Initialize the Settings object"""
# Setup hosts/ports/urls
default_hostname = socket.gethostname()
self.hostname = hostname or default_hostname
if resolve_hostname:
if not self.hostname:
self.hostname = socket.gethostname()
if self._resolve_hostname:
self.hostname = resolve_ip(self.hostname)

self.datadog_api_key = datadog_api_key
self.datadog_app_key = datadog_app_key
self.datadog_flush_interval = datadog_flush_interval
self.statsd_host = statsd_host
self.statsd_port = statsd_port

self.port = port
self.router_port = router_port
self.proxy_protocol_port = proxy_protocol_port
self.memusage_port = memusage_port
self.endpoint_hostname = endpoint_hostname or self.hostname
self.router_hostname = router_hostname or self.hostname

if router_conf is None:
router_conf = {}
self.router_conf = router_conf
if not self.endpoint_hostname:
self.endpoint_hostname = self.hostname
if not self.router_hostname:
self.router_hostname = self.hostname

self.router_url = canonical_url(
router_scheme or 'http',
self.router_scheme or 'http',
self.router_hostname,
router_port
self.router_port
)

self.endpoint_url = canonical_url(
endpoint_scheme or 'http',
self.endpoint_scheme or 'http',
self.endpoint_hostname,
endpoint_port
self.endpoint_port
)

# not accurate under autoendpoint (like router_url)
self.ws_url = "{}://{}:{}/".format(
"wss" if ssl_key else "ws",
'wss' if self.ssl_key else 'ws',
self.hostname,
self.port
)

self.ssl_key = ssl_key
self.ssl_cert = ssl_cert
self.ssl_dh_param = ssl_dh_param
self.router_ssl_key = router_ssl_key
self.router_ssl_cert = router_ssl_cert
self.fernet = MultiFernet([Fernet(key) for key in self._crypto_key])

self.enable_tls_auth = client_certs is not None
self.client_certs = client_certs

self.auto_ping_interval = auto_ping_interval
self.auto_ping_timeout = auto_ping_timeout
self.max_connections = max_connections
self.close_handshake_timeout = close_handshake_timeout

self.router_tablename = router_tablename
self.router_read_throughput = router_read_throughput
self.router_write_throughput = router_write_throughput
self.storage_tablename = storage_tablename
self.storage_read_throughput = storage_read_throughput
self.storage_write_throughput = storage_write_throughput
self.message_tablename = message_tablename
self.message_read_throughput = message_read_throughput
self.message_write_throughput = message_write_throughput

self.msg_limit = msg_limit

# CORS
self.cors = enable_cors

# Force timeout in idle seconds
self.wake_timeout = wake_timeout

# Env
self.env = env

self.hello_timeout = hello_timeout

self.ami_id = ami_id

# Generate messages per legacy rules, only used for testing to
# generate legacy data.
self._notification_legacy = False
self.preflight_uaid = preflight_uaid
@property
def enable_tls_auth(self):
"""Whether TLS authentication w/ client certs is enabled"""
return self.client_certs is not None

@classmethod
def from_argparse(cls, ns, **kwargs):
Expand Down Expand Up @@ -315,20 +280,6 @@ def from_argparse(cls, ns, **kwargs):
**kwargs
)

def update(self, **kwargs):
"""Update the arguments, if a ``crypto_key`` is in kwargs then the
``self.fernet`` attribute will be initialized"""
for key, val in kwargs.items():
if key == "crypto_key":
fkeys = []
if not isinstance(val, list):
val = [val]
for v in val:
fkeys.append(Fernet(v))
self.fernet = MultiFernet(fkeys)
else:
setattr(self, key, val)

def make_simplepush_endpoint(self, uaid, chid):
"""Create a simplepush endpoint"""
root = self.endpoint_url + "/spush/"
Expand Down
16 changes: 10 additions & 6 deletions autopush/tests/test_endpoint.py
Expand Up @@ -216,14 +216,18 @@ def test_init_info(self):
d = self.reg._init_info()
eq_(d["remote_ip"], "local2")

def test_ap_settings_update(self):
def test_settings_crypto_key(self):
fake = 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='
reg = self.reg
reg.ap_settings.update(banana="fruit")
eq_(reg.ap_settings.banana, "fruit")
reg.ap_settings.update(crypto_key=fake)
eq_(reg.ap_settings.fernet._fernets[0]._encryption_key,
settings = AutopushSettings(crypto_key=fake)
eq_(settings.fernet._fernets[0]._encryption_key,
'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')

fake2 = 'BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB='
settings = AutopushSettings(crypto_key=[fake, fake2])
eq_(settings.fernet._fernets[0]._encryption_key,
'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')
eq_(settings.fernet._fernets[1]._encryption_key,
'\x10A\x04\x10A\x04\x10A\x04\x10A\x04\x10A\x04\x10')

def test_cors(self):
ch1 = "Access-Control-Allow-Origin"
Expand Down

0 comments on commit 1616d24

Please sign in to comment.