Skip to content

Commit

Permalink
Use Python 3.11's start_tls
Browse files Browse the repository at this point in the history
  • Loading branch information
icgood committed Apr 22, 2023
1 parent 93071ac commit ed794ab
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 64 deletions.
48 changes: 14 additions & 34 deletions pymap/imap/__init__.py
Expand Up @@ -7,17 +7,15 @@
import re
import sys
from argparse import ArgumentParser
from asyncio import shield, StreamReader, StreamWriter, WriteTransport, \
AbstractServer, CancelledError, TimeoutError
from asyncio import shield, StreamReader, StreamWriter, AbstractServer, \
CancelledError, TimeoutError
from base64 import b64encode, b64decode
from collections.abc import Awaitable, Iterable
from contextlib import closing, AsyncExitStack
from ssl import SSLError
from typing import TypeVar
from uuid import uuid4

from proxyprotocol.reader import ProxyProtocolReader
from proxyprotocol.result import ProxyResult
from proxyprotocol.sock import SocketInfo
from proxyprotocol.version import ProxyProtocolVersion
from pymap.concurrent import Event
Expand Down Expand Up @@ -78,16 +76,18 @@ async def start(self, stack: AsyncExitStack) -> None:
config = self.config
servers: list[AbstractServer] = []
imap_server = IMAPServer(backend.login, config)
pp_reader = ProxyProtocolReader(config.proxy_protocol)
imap_server_cb = pp_reader.get_callback(imap_server)
if config.args.inherited_sockets:
sockets = InheritedSockets.of(config.args.inherited_sockets).get()
if not sockets:
raise ValueError('No inherited sockets found')
for sock in sockets:
servers.append(await asyncio.start_server(
imap_server, sock=sock))
imap_server_cb, sock=sock))
else:
servers.append(await asyncio.start_server(
imap_server, host=config.host, port=config.port))
imap_server_cb, host=config.host, port=config.port))
for server in servers:
await stack.enter_async_context(server)
task = asyncio.create_task(server.serve_forever())
Expand All @@ -113,9 +113,10 @@ def __init__(self, login: LoginInterface, config: IMAPConfig) -> None:
self._login = login
self._config = config

async def __call__(self, reader: StreamReader,
writer: StreamWriter) -> None:
conn = IMAPConnection(self.commands, self._config, reader, writer)
async def __call__(self, reader: StreamReader, writer: StreamWriter,
sock_info: SocketInfo) -> None:
conn = IMAPConnection(self.commands, self._config,
reader, writer, sock_info)
state = ConnectionState(self._login, self._config)
async with AsyncExitStack() as stack:
connection_exit.set(stack)
Expand All @@ -141,27 +142,16 @@ class IMAPConnection:
'reader', 'writer', 'pp_reader', 'pp_result']

def __init__(self, commands: Commands, config: IMAPConfig,
reader: StreamReader,
writer: StreamWriter) -> None:
reader: StreamReader, writer: StreamWriter,
sock_info: SocketInfo) -> None:
super().__init__()
self.commands = commands
self.config = config
self.params = config.parsing_params
self.bad_command_limit = config.bad_command_limit
self.pp_reader = ProxyProtocolReader(config.proxy_protocol)
self.pp_result: ProxyResult | None = None
self._reset_streams(reader, writer)

def _reset_streams(self, reader: StreamReader,
writer: StreamWriter) -> None:
self.reader = reader
self.writer = writer
socket_info.set(SocketInfo.get(writer, self.pp_result,
unique_id=uuid4().bytes))

async def _read_proxy_protocol(self) -> None:
self.pp_result = await self.pp_reader.read(self.reader)
self._reset_streams(self.reader, self.writer)
socket_info.set(sock_info)

def close(self) -> None:
self.writer.close()
Expand Down Expand Up @@ -275,17 +265,8 @@ async def write_response(self, resp: Response) -> None:
self._print('%s <--| %s', bytes(resp))

async def start_tls(self) -> None:
loop = asyncio.get_event_loop()
transport = self.writer.transport
protocol = transport.get_protocol()
ssl_context = self.config.ssl_context
new_transport = await loop.start_tls(
transport, protocol, ssl_context, server_side=True)
assert isinstance(new_transport, WriteTransport)
new_protocol = new_transport.get_protocol()
new_writer = StreamWriter(new_transport, new_protocol,
self.reader, loop)
self._reset_streams(self.reader, new_writer)
await self.writer.start_tls(ssl_context)
self._print('%s <->| %s', '<TLS handshake>')

async def send_error_disconnect(self) -> None:
Expand Down Expand Up @@ -350,7 +331,6 @@ async def run(self, state: ConnectionState) -> None:
state: Defines the interaction with the backend plugin.
"""
await self._read_proxy_protocol()
self._print('%s +++| %s', str(socket_info.get()))
try:
await self._run_state(state)
Expand Down
38 changes: 12 additions & 26 deletions pymap/sieve/manage/__init__.py
Expand Up @@ -6,7 +6,7 @@
import logging
import re
from argparse import ArgumentParser
from asyncio import StreamReader, StreamWriter, WriteTransport
from asyncio import StreamReader, StreamWriter
from base64 import b64encode, b64decode
from collections.abc import Mapping
from contextlib import closing, AsyncExitStack
Expand Down Expand Up @@ -62,10 +62,12 @@ async def start(self, stack: AsyncExitStack) -> None:
backend = self.backend
config = self.config
managesieve_server = ManageSieveServer(backend.login, config)
pp_reader = ProxyProtocolReader(config.proxy_protocol)
managesieve_server_cb = pp_reader.get_callback(managesieve_server)
host: str | None = config.args.sieve_host
port: str | int = config.args.sieve_port
server = await asyncio.start_server(
managesieve_server, host=host, port=port)
managesieve_server_cb, host=host, port=port)
await stack.enter_async_context(server)
task = asyncio.create_task(server.serve_forever())
stack.callback(task.cancel)
Expand All @@ -87,9 +89,10 @@ def __init__(self, login: LoginInterface, config: IMAPConfig) -> None:
self._login = login
self._config = config

async def __call__(self, reader: StreamReader,
writer: StreamWriter) -> None:
conn = ManageSieveConnection(self._login, self._config, reader, writer)
async def __call__(self, reader: StreamReader, writer: StreamWriter,
sock_info: SocketInfo) -> None:
conn = ManageSieveConnection(self._login, self._config,
reader, writer, sock_info)
async with AsyncExitStack() as stack:
connection_exit.set(stack)
stack.enter_context(closing(writer))
Expand All @@ -113,7 +116,8 @@ class ManageSieveConnection:
_impl = b'pymap managesieve ' + __version__.encode('ascii')

def __init__(self, login: LoginInterface, config: IMAPConfig,
reader: StreamReader, writer: StreamWriter) -> None:
reader: StreamReader, writer: StreamWriter,
sock_info: SocketInfo) -> None:
super().__init__()
self.login = login
self.config = config
Expand All @@ -123,17 +127,9 @@ def __init__(self, login: LoginInterface, config: IMAPConfig,
self.pp_result: ProxyResult | None = None
self._offer_starttls = b'STARTTLS' in config.initial_capability
self._state: FilterState | None = None
self._reset_streams(reader, writer)

def _reset_streams(self, reader: StreamReader,
writer: StreamWriter) -> None:
self.reader = reader
self.writer = writer
socket_info.set(SocketInfo.get(writer, self.pp_result))

async def _read_proxy_protocol(self) -> None:
self.pp_result = await self.pp_reader.read(self.reader)
self._reset_streams(self.reader, self.writer)
socket_info.set(sock_info)

def _get_state(self, session: SessionInterface) -> FilterState:
owner = session.owner.encode('utf-8')
Expand Down Expand Up @@ -280,16 +276,7 @@ async def _do_starttls(self) -> Response:
return Response(Condition.NO, text='Bad command.')
resp = Response(Condition.OK)
await self._write_response(resp)
loop = asyncio.get_event_loop()
transport = self.writer.transport
protocol = transport.get_protocol()
new_transport = await loop.start_tls(
transport, protocol, ssl_context, server_side=True)
assert isinstance(new_transport, WriteTransport)
new_protocol = new_transport.get_protocol()
new_writer = StreamWriter(new_transport, new_protocol,
self.reader, loop)
self._reset_streams(self.reader, new_writer)
await self.writer.start_tls(ssl_context)
self._print('%d <->| %s', b'<TLS handshake>')
self._offer_starttls = False
self.auth = self.config.tls_auth
Expand All @@ -300,7 +287,6 @@ async def run(self) -> None:
enter the command/response cycle.
"""
await self._read_proxy_protocol()
self._print('%d +++| %s', str(socket_info.get()))
greeting = await self._do_greeting()
await self._write_response(greeting)
Expand Down
9 changes: 6 additions & 3 deletions pyproject.toml
Expand Up @@ -25,7 +25,7 @@ build-backend = 'hatchling.build'

[project]
name = 'pymap'
version = '0.29.0'
version = '0.30.0'
authors = [
{ name = 'Ian Good', email = 'ian@icgood.net' },
]
Expand All @@ -45,15 +45,15 @@ classifiers = [
]
dependencies = [
'pysasl ~= 1.0',
'proxy-protocol ~= 0.9.1',
'proxy-protocol ~= 0.10.3',
]

[project.optional-dependencies]
admin = ['pymap-admin ~= 0.9.0', 'protobuf', 'googleapis-common-protos']
macaroon = ['pymacaroons']
redis = ['redis ~= 4.2', 'msgpack ~= 1.0']
sieve = ['sievelib']
swim = ['swim-protocol ~= 0.3.10']
swim = ['swim-protocol ~= 0.3.12']
systemd = ['systemd-python']
optional = ['hiredis', 'passlib', 'pid']

Expand Down Expand Up @@ -88,6 +88,9 @@ user = 'pymap.admin.handlers.user:UserHandlers'
[tool.hatch.build]
exclude = ['/tasks', '/doc', '/.github']

[tool.hatch.build.targets.wheel]
packages = ['pymap']

[tool.mypy]
files = ['pymap', 'test']
warn_redundant_casts = true
Expand Down
4 changes: 3 additions & 1 deletion test/server/base.py
Expand Up @@ -5,6 +5,7 @@
from collections.abc import Iterable

import pytest
from proxyprotocol.sock import SocketInfoLocal
from pysasl.hashing import BuiltinHash

from pymap.backend.dict import DictBackend
Expand Down Expand Up @@ -92,7 +93,8 @@ async def _run_transport(self, transport: MockTransport) -> None:
server = transport.server
reader: StreamReader = transport # type: ignore
writer: StreamWriter = transport # type: ignore
return await server(reader, writer)
sock_info = SocketInfoLocal(transport)
return await server(reader, writer, sock_info)

async def run(self, *transports: MockTransport) -> None:
failures = []
Expand Down

0 comments on commit ed794ab

Please sign in to comment.