Skip to content

Commit

Permalink
make stream-attacher work properly with async/await
Browse files Browse the repository at this point in the history
  • Loading branch information
meejah committed Apr 9, 2017
1 parent 1ada8ed commit 7418b66
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 7 deletions.
91 changes: 91 additions & 0 deletions test/py3_torstate.py
@@ -0,0 +1,91 @@
from twisted.trial import unittest
from twisted.test import proto_helpers
from twisted.internet import defer
from twisted.internet.interfaces import IReactorCore

from zope.interface import implementer

from txtorcon import TorControlProtocol
from txtorcon import TorState
from txtorcon import Circuit
from txtorcon.interface import IStreamAttacher


@implementer(IReactorCore)
class FakeReactor:

def __init__(self, test):
self.test = test

def addSystemEventTrigger(self, *args):
self.test.assertEqual(args[0], 'before')
self.test.assertEqual(args[1], 'shutdown')
self.test.assertEqual(args[2], self.test.state.undo_attacher)
return 1

def removeSystemEventTrigger(self, id):
self.test.assertEqual(id, 1)

def connectTCP(self, *args, **kw):
"""for testing build_tor_connection"""
raise RuntimeError('connectTCP: ' + str(args))

def connectUNIX(self, *args, **kw):
"""for testing build_tor_connection"""
raise RuntimeError('connectUNIX: ' + str(args))


class FakeCircuit(Circuit):

def __init__(self, id=-999):
self.streams = []
self.id = id
self.state = 'BOGUS'


class TorStatePy3Tests(unittest.TestCase):

def setUp(self):
self.protocol = TorControlProtocol()
self.state = TorState(self.protocol)
# avoid spew in trial logs; state prints this by default
self.state._attacher_error = lambda f: f
self.protocol.connectionMade = lambda: None
self.transport = proto_helpers.StringTransport()
self.protocol.makeConnection(self.transport)

def send(self, line):
self.protocol.dataReceived(line.strip() + b"\r\n")

def test_attacher_coroutine(self):
@implementer(IStreamAttacher)
class MyAttacher(object):

def __init__(self, answer):
self.streams = []
self.answer = answer

async def attach_stream(self, stream, circuits):
self.streams.append(stream)
x = await defer.succeed(self.answer)
return x

self.state.circuits[1] = FakeCircuit(1)
self.state.circuits[1].state = 'BUILT'
attacher = MyAttacher(self.state.circuits[1])
self.state.set_attacher(attacher, FakeReactor(self))

# boilerplate to finish enough set-up in the protocol so it
# works
events = 'GUARD STREAM CIRC NS NEWCONSENSUS ORCONN NEWDESC ADDRMAP STATUS_GENERAL'
self.protocol._set_valid_events(events)
self.state._add_events()
for ignored in self.state.event_map.items():
self.send(b"250 OK")

self.send(b"650 STREAM 1 NEW 0 ca.yahoo.com:80 SOURCE_ADDR=127.0.0.1:54327 PURPOSE=USER")
self.send(b"650 STREAM 1 REMAP 0 87.248.112.181:80 SOURCE=CACHE")
self.assertEqual(len(attacher.streams), 1)
self.assertEqual(attacher.streams[0].id, 1)
self.assertEqual(len(self.protocol.commands), 1)
self.assertEqual(self.protocol.commands[0][1], b'ATTACHSTREAM 1 1')
1 change: 0 additions & 1 deletion test/test_controller.py
Expand Up @@ -20,7 +20,6 @@
from txtorcon import launch
from txtorcon import connect
from txtorcon.controller import _is_non_public_numeric_address
from txtorcon.util import delete_file_or_tree
from .util import TempDir

from zope.interface import implementer, directlyProvides
Expand Down
4 changes: 4 additions & 0 deletions test/test_torstate.py
Expand Up @@ -7,6 +7,7 @@
from twisted.internet.interfaces import IStreamClientEndpoint, IReactorCore

import tempfile
import six

from ipaddress import IPv4Address

Expand All @@ -29,6 +30,9 @@
from txtorcon.torstate import _extract_reason
from txtorcon.circuit import _get_circuit_attacher

if six.PY3:
from .py3_torstate import TorStatePy3Tests # noqa


@implementer(ICircuitListener)
class CircuitListener(object):
Expand Down
3 changes: 2 additions & 1 deletion tox.ini
Expand Up @@ -11,7 +11,7 @@
# stretch: 16.3.0-1

[tox]
envlist = flake8,py27-tx15,py27-tx16,py27-tx17,pypy-tx15,pypy-tx16,pypy-tx17,py35-tx15,py35-tx16,py35-tx16
envlist = flake8,py27-tx15,py27-tx16,py27-tx17,pypy-tx15,pypy-tx16,pypy-tx17,py35-tx15,py35-tx16,py35-tx17
# if you're not using detox, you can use this list to get
# "all environments" coverage stats:
# tox -e 'clean,flake8,py27-tx16,pypy-tx16,py35-tx16,stats'
Expand Down Expand Up @@ -131,6 +131,7 @@ commands=
cuv graph

[testenv:flake8]
basepython=python3.5
deps=
flake8
commands=
Expand Down
8 changes: 4 additions & 4 deletions txtorcon/socks.py
Expand Up @@ -503,7 +503,7 @@ class _TorSocksProtocol(Protocol):
def __init__(self, host, port, socks_method, factory):
self._machine = _SocksMachine(
req_type=socks_method,
host=host, # unicode() on py3, py2? we want idna, actually?
host=host, # noqa unicode() on py3, py2? we want idna, actually?
port=port,
on_disconnect=self._on_disconnect,
on_data=self._on_data,
Expand Down Expand Up @@ -645,7 +645,7 @@ def resolve(tor_endpoint, hostname):
:param hostname: the hostname to look up.
"""
if six.PY2 and isinstance(hostname, str):
hostname = unicode(hostname)
hostname = unicode(hostname) # noqa
factory = _TorSocksFactory(
hostname, 0, 'RESOLVE', None,
)
Expand All @@ -657,7 +657,7 @@ def resolve(tor_endpoint, hostname):
@inlineCallbacks
def resolve_ptr(tor_endpoint, hostname):
if six.PY2 and isinstance(hostname, str):
hostname = unicode(hostname)
hostname = unicode(hostname) # noqa
factory = _TorSocksFactory(
hostname, 0, 'RESOLVE_PTR', None,
)
Expand All @@ -679,7 +679,7 @@ class TorSocksEndpoint(object):
def __init__(self, socks_endpoint, host, port, tls=False):
self._proxy_ep = socks_endpoint # can be Deferred
if six.PY2 and isinstance(host, str):
host = unicode(host)
host = unicode(host) # noqa
self._host = host
self._port = port
self._tls = tls
Expand Down
2 changes: 2 additions & 0 deletions txtorcon/torcontrolprotocol.py
Expand Up @@ -20,6 +20,7 @@

from txtorcon.interface import ITorControlProtocol
from .spaghetti import FSM, State, Transition
from .util import maybe_coroutine

import os
import re
Expand Down Expand Up @@ -757,6 +758,7 @@ def _do_authenticate(self, protoinfo):

if self.password_function and 'HASHEDPASSWORD' in methods:
d = defer.maybeDeferred(self.password_function)
d.addCallback(maybe_coroutine)
d.addCallback(self._do_password_authentication)
d.addErrback(self._auth_failed)
return
Expand Down
2 changes: 2 additions & 0 deletions txtorcon/torstate.py
Expand Up @@ -34,6 +34,7 @@
from txtorcon.interface import IStreamAttacher
from ._microdesc_parser import MicrodescriptorParser
from .router import hexIdFromHash
from .util import maybe_coroutine


def _build_state(proto):
Expand Down Expand Up @@ -630,6 +631,7 @@ def _maybe_attach(self, stream):
self._attacher.attach_stream,
stream, self.circuits,
)
circ_d.addCallback(maybe_coroutine)

# actually do the attachment logic; .attach() can return 3 things:
# 1. None: let Tor do whatever it wants
Expand Down
19 changes: 18 additions & 1 deletion txtorcon/util.py
Expand Up @@ -24,6 +24,9 @@
from zope.interface import implementer
from zope.interface import Interface

if six.PY3:
import asyncio

try:
import GeoIP as _GeoIP
GeoIP = _GeoIP
Expand Down Expand Up @@ -146,7 +149,7 @@ def maybe_ip_addr(addr):
"""

if six.PY2 and isinstance(addr, str):
addr = unicode(addr)
addr = unicode(addr) # noqa
try:
return ipaddress.ip_address(addr)
except ValueError:
Expand Down Expand Up @@ -379,6 +382,19 @@ def notify(*args, **kw):
"""


def maybe_coroutine(obj):
"""
If 'obj' is a coroutine and we're using Python3, wrap it in
ensureDeferred. Otherwise return the original object.
(This is to insert in all callback chains from user code, in case
that user code is Python3 and used 'async def')
"""
if six.PY3 and asyncio.iscoroutine(obj):
return defer.ensureDeferred(obj)
return obj


@implementer(IListener)
class _Listener(object):
"""
Expand Down Expand Up @@ -417,6 +433,7 @@ def failed(fail):

for cb in self._listeners:
d = defer.maybeDeferred(cb, *args, **kw)
d.addCallback(maybe_coroutine)
d.addErrback(failed)
calls.append(d)
return defer.DeferredList(calls)
Expand Down

0 comments on commit 7418b66

Please sign in to comment.