Skip to content

Commit

Permalink
Fix type hints due to changes in Twisted 21.2.0. (#178)
Browse files Browse the repository at this point in the history
Twisted 21.2.0 added type hints which showed flaws in the
type hints of Sygnal. Those are fixed to be consistent with
Twisted.
  • Loading branch information
clokep authored Mar 16, 2021
1 parent ab87ab8 commit d99867b
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 28 deletions.
1 change: 1 addition & 0 deletions changelog.d/178.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix type hints due to Twisted upgrade.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[mypy]
plugins = mypy_zope:plugin
check_untyped_defs = True
show_error_codes = True
show_traceback = True
Expand Down
25 changes: 12 additions & 13 deletions sygnal/gcmpushkin.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,10 @@ async def create(cls, name, sygnal, config):
an instance of this Pushkin
"""
logger.debug("About to set up CanonicalRegId Store")
canonical_reg_id_store = CanonicalRegIdStore()
await canonical_reg_id_store.setup(sygnal.database, sygnal.database_engine)
canonical_reg_id_store = CanonicalRegIdStore(
sygnal.database, sygnal.database_engine
)
await canonical_reg_id_store.setup()
logger.debug("Finished setting up CanonicalRegId Store")

return cls(name, sygnal, config, canonical_reg_id_store)
Expand Down Expand Up @@ -468,26 +470,23 @@ class CanonicalRegIdStore(object):
);
"""

def __init__(self):
self.db: ConnectionPool = None
self.engine = None

async def setup(self, db, engine):
def __init__(self, db: ConnectionPool, engine: str):
"""
Prepares, if necessary, the database for storing canonical registration IDs.
Separate method from the constructor because we wait for an async request
to complete, so it must be an `async def` method.
Args:
db (adbapi.ConnectionPool): database to prepare
engine (str):
Database engine to use. Shoud be either "sqlite" or "postgresql".
"""
self.db = db
self.engine = engine

async def setup(self):
"""
Prepares, if necessary, the database for storing canonical registration IDs.
Separate method from the constructor because we wait for an async request
to complete, so it must be an `async def` method.
"""
await self.db.runOperation(self.TABLE_CREATE_QUERY)

async def set_canonical_id(self, reg_id, canonical_reg_id):
Expand Down
27 changes: 14 additions & 13 deletions sygnal/helper/proxy/connectproxyclient_twisted.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from twisted.internet import defer, protocol
from twisted.internet.base import ReactorBase
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IProtocol, IStreamClientEndpoint
from twisted.internet.protocol import connectionDone
from twisted.internet.interfaces import IProtocolFactory, IStreamClientEndpoint
from twisted.internet.protocol import Protocol, connectionDone
from twisted.web import http
from zope.interface import implementer

Expand Down Expand Up @@ -71,7 +71,8 @@ def __init__(
def __repr__(self):
return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)

def connect(self, protocolFactory: protocol.ClientFactory):
def connect(self, protocolFactory: IProtocolFactory):
assert isinstance(protocolFactory, protocol.ClientFactory)
f = HTTPProxiedClientFactory(
self._host, self._port, self._proxy_auth, protocolFactory
)
Expand All @@ -90,11 +91,11 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
connection.
Args:
dst_host (bytes): hostname that we want to CONNECT to
dst_port (int): port that we want to connect to
proxy_auth (tuple): None or tuple of (username, pasword) for HTTP basic proxy
dst_host: hostname that we want to CONNECT to
dst_port: port that we want to connect to
proxy_auth: None or tuple of (username, pasword) for HTTP basic proxy
authentication
wrapped_factory (protocol.ClientFactory): The original Factory
wrapped_factory: The original Factory
"""

def __init__(
Expand Down Expand Up @@ -141,18 +142,18 @@ class HTTPConnectProtocol(protocol.Protocol):
"""Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
Args:
host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
host: The original HTTP(s) hostname or IPv4 or IPv6 address literal
to put in the CONNECT request
port (int): The original HTTP(s) port to put in the CONNECT request
port: The original HTTP(s) port to put in the CONNECT request
proxy_auth (tuple): None or tuple of (username, pasword) for HTTP basic proxy
proxy_auth: None or tuple of (username, pasword) for HTTP basic proxy
authentication
wrapped_protocol (interfaces.IProtocol): the original protocol (probably
wrapped_protocol: the original protocol (probably
HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
connected_deferred (Deferred): a Deferred which will be callbacked with
connected_deferred: a Deferred which will be callbacked with
wrapped_protocol when the CONNECT completes
"""

Expand All @@ -161,7 +162,7 @@ def __init__(
host: bytes,
port: int,
proxy_auth: Optional[Tuple[str, str]],
wrapped_protocol: IProtocol,
wrapped_protocol: Protocol,
connected_deferred: Deferred,
):
self.host = host
Expand Down
5 changes: 3 additions & 2 deletions sygnal/helper/proxy/proxyagent_twisted.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.python.failure import Failure
from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
from twisted.web.error import SchemeNotSupported
Expand Down Expand Up @@ -81,7 +82,7 @@ def __init__(

self.proxy_endpoint = HostnameEndpoint(
reactor, parsed_url.hostname, parsed_url.port, **self._endpoint_kwargs
)
) # type: Optional[HostnameEndpoint]
else:
self.proxy_endpoint = None

Expand Down Expand Up @@ -127,7 +128,7 @@ def request(self, method, uri, headers=None, bodyProducer=None):
# Cache *all* connections under the same key, since we are only
# connecting to a single destination, the proxy:
pool_key = ("http-proxy", self.proxy_endpoint)
endpoint = self.proxy_endpoint
endpoint = self.proxy_endpoint # type: IStreamClientEndpoint
request_path = uri
elif parsed_uri.scheme == b"https" and self.proxy_endpoint:
endpoint = HTTPConnectProxyEndpoint(
Expand Down
4 changes: 4 additions & 0 deletions tests/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ def getHostByName(name, timeout=None):

self.nameResolver = SimpleResolverComplexifier(FakeResolver())

def installNameResolver(self, resolver):
# It is not expected that this gets called.
raise RuntimeError(resolver)

def callFromThread(self, function, *args):
self.callLater(0, function, *args)

Expand Down

0 comments on commit d99867b

Please sign in to comment.