Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Python 3.11's start_tls #143

Merged
merged 1 commit into from
Apr 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 14 additions & 34 deletions pymap/imap/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@
import re
import sys
from argparse import ArgumentParser
from asyncio import shield, StreamReader, StreamWriter, WriteTransport, \
AbstractServer, CancelledError, TimeoutError
from asyncio import shield, StreamReader, StreamWriter, AbstractServer, \
CancelledError, TimeoutError
from base64 import b64encode, b64decode
from collections.abc import Awaitable, Iterable
from contextlib import closing, AsyncExitStack
from ssl import SSLError
from typing import TypeVar
from uuid import uuid4

from proxyprotocol.reader import ProxyProtocolReader
from proxyprotocol.result import ProxyResult
from proxyprotocol.sock import SocketInfo
from proxyprotocol.version import ProxyProtocolVersion
from pymap.concurrent import Event
Expand Down Expand Up @@ -78,16 +76,18 @@ async def start(self, stack: AsyncExitStack) -> None:
config = self.config
servers: list[AbstractServer] = []
imap_server = IMAPServer(backend.login, config)
pp_reader = ProxyProtocolReader(config.proxy_protocol)
imap_server_cb = pp_reader.get_callback(imap_server)
if config.args.inherited_sockets:
sockets = InheritedSockets.of(config.args.inherited_sockets).get()
if not sockets:
raise ValueError('No inherited sockets found')
for sock in sockets:
servers.append(await asyncio.start_server(
imap_server, sock=sock))
imap_server_cb, sock=sock))
else:
servers.append(await asyncio.start_server(
imap_server, host=config.host, port=config.port))
imap_server_cb, host=config.host, port=config.port))
for server in servers:
await stack.enter_async_context(server)
task = asyncio.create_task(server.serve_forever())
Expand All @@ -113,9 +113,10 @@ def __init__(self, login: LoginInterface, config: IMAPConfig) -> None:
self._login = login
self._config = config

async def __call__(self, reader: StreamReader,
writer: StreamWriter) -> None:
conn = IMAPConnection(self.commands, self._config, reader, writer)
async def __call__(self, reader: StreamReader, writer: StreamWriter,
sock_info: SocketInfo) -> None:
conn = IMAPConnection(self.commands, self._config,
reader, writer, sock_info)
state = ConnectionState(self._login, self._config)
async with AsyncExitStack() as stack:
connection_exit.set(stack)
Expand All @@ -141,27 +142,16 @@ class IMAPConnection:
'reader', 'writer', 'pp_reader', 'pp_result']

def __init__(self, commands: Commands, config: IMAPConfig,
reader: StreamReader,
writer: StreamWriter) -> None:
reader: StreamReader, writer: StreamWriter,
sock_info: SocketInfo) -> None:
super().__init__()
self.commands = commands
self.config = config
self.params = config.parsing_params
self.bad_command_limit = config.bad_command_limit
self.pp_reader = ProxyProtocolReader(config.proxy_protocol)
self.pp_result: ProxyResult | None = None
self._reset_streams(reader, writer)

def _reset_streams(self, reader: StreamReader,
writer: StreamWriter) -> None:
self.reader = reader
self.writer = writer
socket_info.set(SocketInfo.get(writer, self.pp_result,
unique_id=uuid4().bytes))

async def _read_proxy_protocol(self) -> None:
self.pp_result = await self.pp_reader.read(self.reader)
self._reset_streams(self.reader, self.writer)
socket_info.set(sock_info)

def close(self) -> None:
self.writer.close()
Expand Down Expand Up @@ -275,17 +265,8 @@ async def write_response(self, resp: Response) -> None:
self._print('%s <--| %s', bytes(resp))

async def start_tls(self) -> None:
loop = asyncio.get_event_loop()
transport = self.writer.transport
protocol = transport.get_protocol()
ssl_context = self.config.ssl_context
new_transport = await loop.start_tls(
transport, protocol, ssl_context, server_side=True)
assert isinstance(new_transport, WriteTransport)
new_protocol = new_transport.get_protocol()
new_writer = StreamWriter(new_transport, new_protocol,
self.reader, loop)
self._reset_streams(self.reader, new_writer)
await self.writer.start_tls(ssl_context)
self._print('%s <->| %s', '<TLS handshake>')

async def send_error_disconnect(self) -> None:
Expand Down Expand Up @@ -350,7 +331,6 @@ async def run(self, state: ConnectionState) -> None:
state: Defines the interaction with the backend plugin.

"""
await self._read_proxy_protocol()
self._print('%s +++| %s', str(socket_info.get()))
try:
await self._run_state(state)
Expand Down
38 changes: 12 additions & 26 deletions pymap/sieve/manage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import re
from argparse import ArgumentParser
from asyncio import StreamReader, StreamWriter, WriteTransport
from asyncio import StreamReader, StreamWriter
from base64 import b64encode, b64decode
from collections.abc import Mapping
from contextlib import closing, AsyncExitStack
Expand Down Expand Up @@ -62,10 +62,12 @@ async def start(self, stack: AsyncExitStack) -> None:
backend = self.backend
config = self.config
managesieve_server = ManageSieveServer(backend.login, config)
pp_reader = ProxyProtocolReader(config.proxy_protocol)
managesieve_server_cb = pp_reader.get_callback(managesieve_server)
host: str | None = config.args.sieve_host
port: str | int = config.args.sieve_port
server = await asyncio.start_server(
managesieve_server, host=host, port=port)
managesieve_server_cb, host=host, port=port)
await stack.enter_async_context(server)
task = asyncio.create_task(server.serve_forever())
stack.callback(task.cancel)
Expand All @@ -87,9 +89,10 @@ def __init__(self, login: LoginInterface, config: IMAPConfig) -> None:
self._login = login
self._config = config

async def __call__(self, reader: StreamReader,
writer: StreamWriter) -> None:
conn = ManageSieveConnection(self._login, self._config, reader, writer)
async def __call__(self, reader: StreamReader, writer: StreamWriter,
sock_info: SocketInfo) -> None:
conn = ManageSieveConnection(self._login, self._config,
reader, writer, sock_info)
async with AsyncExitStack() as stack:
connection_exit.set(stack)
stack.enter_context(closing(writer))
Expand All @@ -113,7 +116,8 @@ class ManageSieveConnection:
_impl = b'pymap managesieve ' + __version__.encode('ascii')

def __init__(self, login: LoginInterface, config: IMAPConfig,
reader: StreamReader, writer: StreamWriter) -> None:
reader: StreamReader, writer: StreamWriter,
sock_info: SocketInfo) -> None:
super().__init__()
self.login = login
self.config = config
Expand All @@ -123,17 +127,9 @@ def __init__(self, login: LoginInterface, config: IMAPConfig,
self.pp_result: ProxyResult | None = None
self._offer_starttls = b'STARTTLS' in config.initial_capability
self._state: FilterState | None = None
self._reset_streams(reader, writer)

def _reset_streams(self, reader: StreamReader,
writer: StreamWriter) -> None:
self.reader = reader
self.writer = writer
socket_info.set(SocketInfo.get(writer, self.pp_result))

async def _read_proxy_protocol(self) -> None:
self.pp_result = await self.pp_reader.read(self.reader)
self._reset_streams(self.reader, self.writer)
socket_info.set(sock_info)

def _get_state(self, session: SessionInterface) -> FilterState:
owner = session.owner.encode('utf-8')
Expand Down Expand Up @@ -280,16 +276,7 @@ async def _do_starttls(self) -> Response:
return Response(Condition.NO, text='Bad command.')
resp = Response(Condition.OK)
await self._write_response(resp)
loop = asyncio.get_event_loop()
transport = self.writer.transport
protocol = transport.get_protocol()
new_transport = await loop.start_tls(
transport, protocol, ssl_context, server_side=True)
assert isinstance(new_transport, WriteTransport)
new_protocol = new_transport.get_protocol()
new_writer = StreamWriter(new_transport, new_protocol,
self.reader, loop)
self._reset_streams(self.reader, new_writer)
await self.writer.start_tls(ssl_context)
self._print('%d <->| %s', b'<TLS handshake>')
self._offer_starttls = False
self.auth = self.config.tls_auth
Expand All @@ -300,7 +287,6 @@ async def run(self) -> None:
enter the command/response cycle.

"""
await self._read_proxy_protocol()
self._print('%d +++| %s', str(socket_info.get()))
greeting = await self._do_greeting()
await self._write_response(greeting)
Expand Down
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ build-backend = 'hatchling.build'

[project]
name = 'pymap'
version = '0.29.0'
version = '0.30.0'
authors = [
{ name = 'Ian Good', email = 'ian@icgood.net' },
]
Expand All @@ -45,15 +45,15 @@ classifiers = [
]
dependencies = [
'pysasl ~= 1.0',
'proxy-protocol ~= 0.9.1',
'proxy-protocol ~= 0.10.3',
]

[project.optional-dependencies]
admin = ['pymap-admin ~= 0.9.0', 'protobuf', 'googleapis-common-protos']
macaroon = ['pymacaroons']
redis = ['redis ~= 4.2', 'msgpack ~= 1.0']
sieve = ['sievelib']
swim = ['swim-protocol ~= 0.3.10']
swim = ['swim-protocol ~= 0.3.12']
systemd = ['systemd-python']
optional = ['hiredis', 'passlib', 'pid']

Expand Down Expand Up @@ -88,6 +88,9 @@ user = 'pymap.admin.handlers.user:UserHandlers'
[tool.hatch.build]
exclude = ['/tasks', '/doc', '/.github']

[tool.hatch.build.targets.wheel]
packages = ['pymap']

[tool.mypy]
files = ['pymap', 'test']
warn_redundant_casts = true
Expand Down
4 changes: 3 additions & 1 deletion test/server/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import Iterable

import pytest
from proxyprotocol.sock import SocketInfoLocal
from pysasl.hashing import BuiltinHash

from pymap.backend.dict import DictBackend
Expand Down Expand Up @@ -92,7 +93,8 @@ async def _run_transport(self, transport: MockTransport) -> None:
server = transport.server
reader: StreamReader = transport # type: ignore
writer: StreamWriter = transport # type: ignore
return await server(reader, writer)
sock_info = SocketInfoLocal(transport)
return await server(reader, writer, sock_info)

async def run(self, *transports: MockTransport) -> None:
failures = []
Expand Down