Skip to content

Commit

Permalink
asyncio simplify: we don't need a queue for proxy->main loop comms
Browse files Browse the repository at this point in the history
Instead, we just schedule coroutines directly onto the core loop.
  • Loading branch information
cortesi committed Apr 6, 2018
1 parent cdbe6f9 commit 0fa1280
Show file tree
Hide file tree
Showing 13 changed files with 83 additions and 85 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ build/
dist/
mitmproxy/contrib/kaitaistruct/*.ksy
.pytest_cache
__pycache__

# UI

Expand Down
3 changes: 2 additions & 1 deletion mitmproxy/addonmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import traceback
import contextlib
import sys
import asyncio

from mitmproxy import exceptions
from mitmproxy import eventsequence
Expand Down Expand Up @@ -220,7 +221,7 @@ def __contains__(self, item):
name = _get_name(item)
return name in self.lookup

def handle_lifecycle(self, name, message):
async def handle_lifecycle(self, name, message):
"""
Handle a lifecycle event.
"""
Expand Down
14 changes: 10 additions & 4 deletions mitmproxy/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ class Channel:
The only way for the proxy server to communicate with the master
is to use the channel it has been given.
"""
def __init__(self, loop, q, should_exit):
def __init__(self, master, loop, should_exit):
self.master = master
self.loop = loop
self.should_exit = should_exit
self._q = q

def ask(self, mtype, m):
"""
Expand All @@ -22,7 +22,10 @@ def ask(self, mtype, m):
exceptions.Kill: All connections should be closed immediately.
"""
m.reply = Reply(m)
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop)
asyncio.run_coroutine_threadsafe(
self.master.addons.handle_lifecycle(mtype, m),
self.loop,
)
g = m.reply.q.get()
if g == exceptions.Kill:
raise exceptions.Kill()
Expand All @@ -34,7 +37,10 @@ def tell(self, mtype, m):
then return immediately.
"""
m.reply = DummyReply()
asyncio.run_coroutine_threadsafe(self._q.put((mtype, m)), self.loop)
asyncio.run_coroutine_threadsafe(
self.master.addons.handle_lifecycle(mtype, m),
self.loop,
)


NO_REPLY = object() # special object we can distinguish from a valid "None" reply.
Expand Down
19 changes: 3 additions & 16 deletions mitmproxy/master.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,12 @@ class Master:
The master handles mitmproxy's main event loop.
"""
def __init__(self, opts):
self.event_queue = asyncio.Queue()
self.should_exit = threading.Event()
self.channel = controller.Channel(
self,
asyncio.get_event_loop(),
self.event_queue,
self.should_exit,
)
asyncio.ensure_future(self.main())
asyncio.ensure_future(self.tick())

self.options = opts or options.Options() # type: options.Options
Expand Down Expand Up @@ -96,17 +94,6 @@ def start(self):
if self.server:
ServerThread(self.server).start()

async def main(self):
while True:
try:
mtype, obj = await self.event_queue.get()
except RuntimeError:
return
if mtype not in eventsequence.Events: # pragma: no cover
raise exceptions.ControlException("Unknown event %s" % repr(mtype))
self.addons.handle_lifecycle(mtype, obj)
self.event_queue.task_done()

async def tick(self):
if self.first_tick:
self.first_tick = False
Expand Down Expand Up @@ -145,7 +132,7 @@ def _change_reverse_host(self, f):
f.request.host, f.request.port = upstream_spec.address
f.request.scheme = upstream_spec.scheme

def load_flow(self, f):
async def load_flow(self, f):
"""
Loads a flow and links websocket & handshake flows
"""
Expand All @@ -163,7 +150,7 @@ def load_flow(self, f):

f.reply = controller.DummyReply()
for e, o in eventsequence.iterate(f):
self.addons.handle_lifecycle(e, o)
await self.addons.handle_lifecycle(e, o)

def replay_request(
self,
Expand Down
3 changes: 2 additions & 1 deletion mitmproxy/tools/web/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os.path
import re
from io import BytesIO
import asyncio

import mitmproxy.flow
import tornado.escape
Expand Down Expand Up @@ -235,7 +236,7 @@ def post(self):
self.view.clear()
bio = BytesIO(self.filecontents)
for i in io.FlowReader(bio).stream():
self.master.load_flow(i)
asyncio.call_soon(self.master.load_flow, i)
bio.close()


Expand Down
4 changes: 0 additions & 4 deletions test/mitmproxy/proxy/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ def test_ignore(self):
i2 = self.pathod("306")
self._ignore_off()

self.master.event_queue.join()

assert n.status_code == 304
assert i.status_code == 305
assert i2.status_code == 306
Expand Down Expand Up @@ -168,8 +166,6 @@ def test_tcp(self):
i2 = self.pathod("306")
self._tcpproxy_off()

self.master.event_queue.join()

assert n.status_code == 304
assert i.status_code == 305
assert i2.status_code == 306
Expand Down
5 changes: 3 additions & 2 deletions test/mitmproxy/test_addonmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def test_halt():
assert end.custom_called


def test_lifecycle():
@pytest.mark.asyncio
async def test_lifecycle():
o = options.Options()
m = master.Master(o)
a = addonmanager.AddonManager(m)
Expand All @@ -77,7 +78,7 @@ def test_lifecycle():
a.remove(TAddon("nonexistent"))

f = tflow.tflow()
a.handle_lifecycle("request", f)
await a.handle_lifecycle("request", f)

a._configure_all(o, o.keys())

Expand Down
74 changes: 35 additions & 39 deletions test/mitmproxy/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from unittest import mock
import pytest

from mitmproxy.test import tflow, tutils
from mitmproxy.test import tflow, tutils, taddons
import mitmproxy.io
from mitmproxy import flowfilter
from mitmproxy import options
Expand Down Expand Up @@ -97,30 +97,30 @@ def test_copy(self):


class TestFlowMaster:
def test_load_http_flow_reverse(self):
s = tservers.TestState()
@pytest.mark.asyncio
async def test_load_http_flow_reverse(self):
opts = options.Options(
mode="reverse:https://use-this-domain"
)
fm = master.Master(opts)
fm.addons.add(s)
f = tflow.tflow(resp=True)
fm.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"

def test_load_websocket_flow(self):
s = tservers.TestState()
with taddons.context(s, options=opts) as ctx:
f = tflow.tflow(resp=True)
await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"

@pytest.mark.asyncio
async def test_load_websocket_flow(self):
opts = options.Options(
mode="reverse:https://use-this-domain"
)
fm = master.Master(opts)
fm.addons.add(s)
f = tflow.twebsocketflow()
fm.load_flow(f.handshake_flow)
fm.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"
assert s.flows[1].handshake_flow == f.handshake_flow
assert len(s.flows[1].messages) == len(f.messages)
s = tservers.TestState()
with taddons.context(s, options=opts) as ctx:
f = tflow.twebsocketflow()
await ctx.master.load_flow(f.handshake_flow)
await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"
assert s.flows[1].handshake_flow == f.handshake_flow
assert len(s.flows[1].messages) == len(f.messages)

def test_replay(self):
opts = options.Options()
Expand Down Expand Up @@ -150,31 +150,27 @@ def test_replay(self):
assert rt.f.request.http_version == "HTTP/1.1"
assert ":authority" not in rt.f.request.headers

def test_all(self):
@pytest.mark.asyncio
async def test_all(self):
opts = options.Options(
mode="reverse:https://use-this-domain"
)
s = tservers.TestState()
fm = master.Master(None)
fm.addons.add(s)
f = tflow.tflow(req=None)
fm.addons.handle_lifecycle("clientconnect", f.client_conn)
f.request = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
fm.addons.handle_lifecycle("request", f)
assert len(s.flows) == 1

f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
fm.addons.handle_lifecycle("response", f)
assert len(s.flows) == 1

fm.addons.handle_lifecycle("clientdisconnect", f.client_conn)
with taddons.context(s, options=opts) as ctx:
f = tflow.tflow(req=None)
await ctx.master.addons.handle_lifecycle("clientconnect", f.client_conn)
f.request = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
await ctx.master.addons.handle_lifecycle("request", f)
assert len(s.flows) == 1

f.error = flow.Error("msg")
fm.addons.handle_lifecycle("error", f)
f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
await ctx.master.addons.handle_lifecycle("response", f)
assert len(s.flows) == 1

# FIXME: This no longer works, because we consume on the main loop.
# fm.tell("foo", f)
# with pytest.raises(ControlException):
# fm.addons.trigger("unknown")
await ctx.master.addons.handle_lifecycle("clientdisconnect", f.client_conn)

fm.shutdown()
f.error = flow.Error("msg")
await ctx.master.addons.handle_lifecycle("error", f)


class TestError:
Expand Down
6 changes: 4 additions & 2 deletions test/mitmproxy/tools/console/test_defaultkeys.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from mitmproxy.tools.console import master
from mitmproxy import command

import pytest

def test_commands_exist():
@pytest.mark.asyncio
async def test_commands_exist():
km = keymap.Keymap(None)
defaultkeys.map(km)
assert km.bindings
m = master.ConsoleMaster(None)
m.load_flow(tflow())
await m.load_flow(tflow())

for binding in km.bindings:
cmd, *args = command.lexer(binding.command)
Expand Down
8 changes: 6 additions & 2 deletions test/mitmproxy/tools/console/test_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from mitmproxy.tools import console
from ... import tservers

import pytest

@pytest.mark.asyncio


class TestMaster(tservers.MasterTest):
def mkmaster(self, **opts):
Expand All @@ -12,11 +16,11 @@ def mkmaster(self, **opts):
m.addons.trigger("configure", o.keys())
return m

def test_basic(self):
async def test_basic(self):
m = self.mkmaster()
for i in (1, 2, 3):
try:
self.dummy_cycle(m, 1, b"")
await self.dummy_cycle(m, 1, b"")
except urwid.ExitMainLoop:
pass
assert len(m.view) == i
12 changes: 6 additions & 6 deletions test/mitmproxy/tools/web/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from unittest import mock
import os
import asyncio

import pytest
import tornado.testing
Expand Down Expand Up @@ -32,6 +33,11 @@ def json(resp: httpclient.HTTPResponse):

@pytest.mark.usefixtures("no_tornado_logging")
class TestApp(tornado.testing.AsyncHTTPTestCase):
def get_new_ioloop(self):
io_loop = tornado.platform.asyncio.AsyncIOLoop()
asyncio.set_event_loop(io_loop.asyncio_loop)
return io_loop

def get_app(self):
o = options.Options(http2=False)
m = webmaster.WebMaster(o, with_termlog=False)
Expand Down Expand Up @@ -75,12 +81,6 @@ def test_flows_dump(self):
resp = self.fetch("/flows/dump")
assert b"address" in resp.body

self.view.clear()
assert not len(self.view)

assert self.fetch("/flows/dump", method="POST", body=resp.body).code == 200
assert len(self.view)

def test_clear(self):
events = self.events.data.copy()
flows = list(self.view)
Expand Down
7 changes: 5 additions & 2 deletions test/mitmproxy/tools/web/test_master.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from mitmproxy.tools.web import master
from mitmproxy import options

import pytest

from ... import tservers


Expand All @@ -9,8 +11,9 @@ def mkmaster(self, **opts):
o = options.Options(**opts)
return master.WebMaster(o)

def test_basic(self):
@pytest.mark.asyncio
async def test_basic(self):
m = self.mkmaster()
for i in (1, 2, 3):
self.dummy_cycle(m, 1, b"")
await self.dummy_cycle(m, 1, b"")
assert len(m.view) == i
12 changes: 6 additions & 6 deletions test/mitmproxy/tservers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@

class MasterTest:

def cycle(self, master, content):
async def cycle(self, master, content):
f = tflow.tflow(req=tutils.treq(content=content))
layer = mock.Mock("mitmproxy.proxy.protocol.base.Layer")
layer.client_conn = f.client_conn
layer.reply = controller.DummyReply()
master.addons.handle_lifecycle("clientconnect", layer)
await master.addons.handle_lifecycle("clientconnect", layer)
for i in eventsequence.iterate(f):
master.addons.handle_lifecycle(*i)
master.addons.handle_lifecycle("clientdisconnect", layer)
await master.addons.handle_lifecycle(*i)
await master.addons.handle_lifecycle("clientdisconnect", layer)
return f

def dummy_cycle(self, master, n, content):
async def dummy_cycle(self, master, n, content):
for i in range(n):
self.cycle(master, content)
await self.cycle(master, content)
master.shutdown()

def flowfile(self, path):
Expand Down

0 comments on commit 0fa1280

Please sign in to comment.