Skip to content

Commit

Permalink
fix a bug where connections would not be fully closed (mitmproxy#6543)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhils committed Dec 12, 2023
1 parent 1fcd033 commit 0a3e016
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 78 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -11,6 +11,8 @@
([#6548](https://github.com/mitmproxy/mitmproxy/pull/6548), @zanieb)
* Improved handling for `--allow-hosts`/`--ignore-hosts` options in WireGuard mode (#5930).
([#6513](https://github.com/mitmproxy/mitmproxy/pull/6513), @dsphper)
* Fix a bug where TCP connections were not closed properly.
([#6543](https://github.com/mitmproxy/mitmproxy/pull/6543), @mhils)
* DNS resolution is now exempted from `--ignore-hosts` in WireGuard Mode.
([#6513](https://github.com/mitmproxy/mitmproxy/pull/6513), @dsphper)
* Fix a bug where logging was stopped prematurely during shutdown.
Expand Down
6 changes: 5 additions & 1 deletion mitmproxy/addons/clientplayback.py
Expand Up @@ -161,9 +161,13 @@ def running(self):
)
self.options = ctx.options

def done(self):
async def done(self):
if self.playback_task:
self.playback_task.cancel()
try:
await self.playback_task
except asyncio.CancelledError:
pass

async def playback(self):
while True:
Expand Down
2 changes: 1 addition & 1 deletion mitmproxy/optmanager.py
Expand Up @@ -523,7 +523,7 @@ def parse(text):
if not text:
return {}
try:
yaml = ruamel.yaml.YAML(typ="unsafe", pure=True)
yaml = ruamel.yaml.YAML(typ="safe", pure=True)
data = yaml.load(text)
except ruamel.yaml.error.YAMLError as v:
if hasattr(v, "problem_mark"):
Expand Down
1 change: 1 addition & 0 deletions mitmproxy/proxy/mode_servers.py
Expand Up @@ -199,6 +199,7 @@ async def handle_tcp_connection(
original_dst = platform.original_addr(s)
except Exception as e:
logger.error(f"Transparent mode failure: {e!r}")
writer.close()
return
else:
handler.layer.context.client.sockname = original_dst
Expand Down
18 changes: 13 additions & 5 deletions mitmproxy/proxy/server.py
Expand Up @@ -308,7 +308,10 @@ async def handle_connection(self, connection: Connection) -> None:
# we may still use this connection to *send* stuff,
# even though the remote has closed their side of the connection.
# to make this work we keep this task running and wait for cancellation.
await asyncio.Event().wait()
try:
await asyncio.Event().wait()
except asyncio.CancelledError as e:
cancelled = e

try:
writer = self.transports[connection].writer
Expand Down Expand Up @@ -336,10 +339,15 @@ async def drain_writers(self):
transport.handler.cancel(f"Error sending data: {e}")

async def on_timeout(self) -> None:
self.log(f"Closing connection due to inactivity: {self.client}")
handler = self.transports[self.client].handler
assert handler
handler.cancel("timeout")
try:
handler = self.transports[self.client].handler
except KeyError: # pragma: no cover
# there is a super short window between connection close and watchdog cancellation
pass
else:
self.log(f"Closing connection due to inactivity: {self.client}")
assert handler
handler.cancel("timeout")

async def hook_task(self, hook: commands.StartHook) -> None:
await self.handle_hook(hook)
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Expand Up @@ -130,6 +130,8 @@ testpaths = "test"
addopts = "--capture=no --color=yes"
filterwarnings = [
"ignore::DeprecationWarning:tornado.*:",
"error::RuntimeWarning",
"error::pytest.PytestUnraisableExceptionWarning",
]

[tool.mypy]
Expand Down
73 changes: 43 additions & 30 deletions test/mitmproxy/addons/test_asgiapp.py
Expand Up @@ -57,57 +57,70 @@ async def test_asgi_full(caplog):
assert await ps.setup_servers()
proxy_addr = ("127.0.0.1", ps.listen_addrs()[0][1])

reader, writer = await asyncio.open_connection(*proxy_addr)
# We parallelize connection establishment/closure because those operations tend to be slow.
[
(r1, w1),
(r2, w2),
(r3, w3),
(r4, w4),
(r5, w5),
] = await asyncio.gather(
asyncio.open_connection(*proxy_addr),
asyncio.open_connection(*proxy_addr),
asyncio.open_connection(*proxy_addr),
asyncio.open_connection(*proxy_addr),
asyncio.open_connection(*proxy_addr),
)

req = f"GET http://testapp:80/ HTTP/1.1\r\n\r\n"
writer.write(req.encode())
header = await reader.readuntil(b"\r\n\r\n")
w1.write(req.encode())
header = await r1.readuntil(b"\r\n\r\n")
assert header.startswith(b"HTTP/1.1 200 OK")
body = await reader.readuntil(b"testapp")
body = await r1.readuntil(b"testapp")
assert body == b"testapp"
writer.close()
await writer.wait_closed()

reader, writer = await asyncio.open_connection(*proxy_addr)
req = f"GET http://testapp:80/parameters?param1=1&param2=2 HTTP/1.1\r\n\r\n"
writer.write(req.encode())
header = await reader.readuntil(b"\r\n\r\n")
w2.write(req.encode())
header = await r2.readuntil(b"\r\n\r\n")
assert header.startswith(b"HTTP/1.1 200 OK")
body = await reader.readuntil(b"}")
body = await r2.readuntil(b"}")
assert body == b'{"param1": "1", "param2": "2"}'
writer.close()
await writer.wait_closed()

reader, writer = await asyncio.open_connection(*proxy_addr)
req = f"POST http://testapp:80/requestbody HTTP/1.1\r\nContent-Length: 6\r\n\r\nHello!"
writer.write(req.encode())
header = await reader.readuntil(b"\r\n\r\n")
w3.write(req.encode())
header = await r3.readuntil(b"\r\n\r\n")
assert header.startswith(b"HTTP/1.1 200 OK")
body = await reader.readuntil(b"}")
body = await r3.readuntil(b"}")
assert body == b'{"body": "Hello!"}'
writer.close()
await writer.wait_closed()

reader, writer = await asyncio.open_connection(*proxy_addr)
req = f"GET http://errapp:80/?foo=bar HTTP/1.1\r\n\r\n"
writer.write(req.encode())
header = await reader.readuntil(b"\r\n\r\n")
w4.write(req.encode())
header = await r4.readuntil(b"\r\n\r\n")
assert header.startswith(b"HTTP/1.1 500")
body = await reader.readuntil(b"ASGI Error")
body = await r4.readuntil(b"ASGI Error")
assert body == b"ASGI Error"
writer.close()
await writer.wait_closed()
assert "ValueError" in caplog.text

reader, writer = await asyncio.open_connection(*proxy_addr)
req = f"GET http://noresponseapp:80/ HTTP/1.1\r\n\r\n"
writer.write(req.encode())
header = await reader.readuntil(b"\r\n\r\n")
w5.write(req.encode())
header = await r5.readuntil(b"\r\n\r\n")
assert header.startswith(b"HTTP/1.1 500")
body = await reader.readuntil(b"ASGI Error")
body = await r5.readuntil(b"ASGI Error")
assert body == b"ASGI Error"
writer.close()
await writer.wait_closed()
assert "no response sent" in caplog.text

w1.close()
w2.close()
w3.close()
w4.close()
w5.close()
await asyncio.gather(
w1.wait_closed(),
w2.wait_closed(),
w3.wait_closed(),
w4.wait_closed(),
w5.wait_closed(),
)

tctx.configure(ps, server=False)
assert await ps.setup_servers()
75 changes: 49 additions & 26 deletions test/mitmproxy/addons/test_clientplayback.py
Expand Up @@ -17,23 +17,48 @@

@asynccontextmanager
async def tcp_server(handle_conn, **server_args) -> Address:
server = await asyncio.start_server(handle_conn, "127.0.0.1", 0, **server_args)
"""TCP server context manager that...
1. Exits only after all handlers have returned.
2. Ensures that all handlers are closed properly. If we don't do that,
we get ghost errors in others tests from StreamWriter.__del__.
Spawning a TCP server is relatively slow. Consider using in-memory networking for faster tests.
"""
if not hasattr(asyncio, "TaskGroup"):
pytest.skip("Skipped because asyncio.TaskGroup is unavailable.")

tasks = asyncio.TaskGroup()

async def handle_conn_wrapper(
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
) -> None:
try:
await handle_conn(reader, writer)
except Exception as e:
print(f"!!! TCP handler failed: {e}")
raise
finally:
if not writer.is_closing():
writer.close()
await writer.wait_closed()

async def _handle(r, w):
tasks.create_task(handle_conn_wrapper(r, w))

server = await asyncio.start_server(_handle, "127.0.0.1", 0, **server_args)
await server.start_serving()
try:
yield server.sockets[0].getsockname()
finally:
server.close()
async with server:
async with tasks:
yield server.sockets[0].getsockname()


@pytest.mark.parametrize("mode", ["http", "https", "upstream", "err"])
@pytest.mark.parametrize("concurrency", [-1, 1])
async def test_playback(tdata, mode, concurrency):
handler_ok = asyncio.Event()

async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
if mode == "err":
writer.close()
handler_ok.set()
return
req = await reader.readline()
if mode == "upstream":
Expand All @@ -49,7 +74,6 @@ async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
writer.write(b"HTTP/1.1 204 No Content\r\n\r\n")
await writer.drain()
assert not await reader.read()
handler_ok.set()

cp = ClientPlayback()
ps = Proxyserver()
Expand Down Expand Up @@ -92,22 +116,20 @@ async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
cp.start_replay([flow])
assert cp.count() == 1
await asyncio.wait_for(cp.queue.join(), 5)
await asyncio.wait_for(handler_ok.wait(), 5)
cp.done()
if mode != "err":
assert flow.response.status_code == 204
while cp.replay_tasks:
await asyncio.sleep(0.001)
if mode != "err":
assert flow.response.status_code == 204
await cp.done()


async def test_playback_https_upstream():
handler_ok = asyncio.Event()

async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
conn_req = await reader.readuntil(b"\r\n\r\n")
assert conn_req == b"CONNECT address:22 HTTP/1.1\r\n\r\n"
writer.write(b"HTTP/1.1 502 Bad Gateway\r\n\r\n")
await writer.drain()
assert not await reader.read()
handler_ok.set()

cp = ClientPlayback()
ps = Proxyserver()
Expand All @@ -122,17 +144,17 @@ async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
cp.start_replay([flow])
assert cp.count() == 1
await asyncio.wait_for(cp.queue.join(), 5)
await asyncio.wait_for(handler_ok.wait(), 5)
cp.done()
assert flow.response is None
assert (
str(flow.error)
== f"Upstream proxy {addr[0]}:{addr[1]} refused HTTP CONNECT request: 502 Bad Gateway"
)

assert flow.response is None
assert (
str(flow.error)
== f"Upstream proxy {addr[0]}:{addr[1]} refused HTTP CONNECT request: 502 Bad Gateway"
)
await cp.done()


async def test_playback_crash(monkeypatch, caplog_async):
async def raise_err():
async def raise_err(*_, **__):
raise ValueError("oops")

monkeypatch.setattr(ReplayHandler, "replay", raise_err)
Expand All @@ -141,8 +163,9 @@ async def raise_err():
cp.running()
cp.start_replay([tflow.tflow(live=False)])
await caplog_async.await_log("Client replay has crashed!")
assert "oops" in caplog_async.caplog.text
assert cp.count() == 0
cp.done()
await cp.done()


def test_check():
Expand Down
26 changes: 15 additions & 11 deletions test/mitmproxy/addons/test_proxyserver.py
Expand Up @@ -23,6 +23,7 @@
from aioquic.quic.connection import QuicConnection
from aioquic.quic.connection import QuicConnectionError

from .test_clientplayback import tcp_server
import mitmproxy.platform
from mitmproxy import dns
from mitmproxy import exceptions
Expand Down Expand Up @@ -55,16 +56,6 @@ def tcp_start(self, f):
self.flows.append(f)


@asynccontextmanager
async def tcp_server(handle_conn) -> Address:
server = await asyncio.start_server(handle_conn, "127.0.0.1", 0)
await server.start_serving()
try:
yield server.sockets[0].getsockname()
finally:
server.close()


async def test_start_stop(caplog_async):
caplog_async.set_level("INFO")

Expand All @@ -74,7 +65,6 @@ async def server_handler(
assert await reader.readuntil(b"\r\n\r\n") == b"GET /hello HTTP/1.1\r\n\r\n"
writer.write(b"HTTP/1.1 204 No Content\r\n\r\n")
await writer.drain()
writer.close()

ps = Proxyserver()
nl = NextLayer()
Expand Down Expand Up @@ -160,6 +150,9 @@ async def server_handler(
ps.inject_tcp(state.flows[0], True, b"c")
assert await reader.read(1) == b"c"

writer.close()
await writer.wait_closed()


async def test_inject_fail(caplog) -> None:
ps = Proxyserver()
Expand Down Expand Up @@ -311,6 +304,9 @@ async def test_dns(caplog_async) -> None:
tctx.configure(ps, server=False)
await caplog_async.await_log("stopped")

w.close()
await w.wait_closed()


def test_validation_no_transparent(monkeypatch):
monkeypatch.setattr(mitmproxy.platform, "original_addr", None)
Expand Down Expand Up @@ -373,6 +369,9 @@ def server_handler(
tctx.configure(ps, server=False)
await caplog_async.await_log("stopped")

w.close()
await w.wait_closed()


class H3EchoServer(QuicConnectionProtocol):
def __init__(self, *args, **kwargs) -> None:
Expand Down Expand Up @@ -779,6 +778,11 @@ async def test_reverse_http3_and_quic_stream(
await _test_echo(client, strict=scheme == "http3")
assert len(ps.connections) == 1

# dirty hack: forcibly close all connections so that there are no unexpected asyncio tasks
# that may cause test failures because they have not been run.
for conn in ps.servers[mode].manager.connections.values():
await conn.on_timeout()

tctx.configure(ps, server=False)
await caplog_async.await_log(f"stopped")

Expand Down

0 comments on commit 0a3e016

Please sign in to comment.