Skip to content

Commit

Permalink
Fix the incorrect fetch partial parsing
Browse files Browse the repository at this point in the history
Fix #131

The original implementation interpreted the partial fetch arguments as
a start and end octet, but the RFC states they are a start octet and a
length. This also causes incorrect parsing failures because there was a
check that the first number was less than the second number, which does
not make sense with corrected assumptions.
  • Loading branch information
icgood committed May 7, 2022
1 parent 73c52f8 commit 3cf851e
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 34 deletions.
2 changes: 1 addition & 1 deletion pymap/admin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
from grpclib.events import listen, RecvRequest
from grpclib.health.check import ServiceStatus
from grpclib.health.service import Health, OVERALL
from grpclib.server import Server
from pymap.context import cluster_metadata
from pymap.interfaces.backend import ServiceInterface
from pymapadmin import is_compatible, __version__ as server_version
from pymapadmin.local import socket_file, token_file

from .errors import get_incompatible_version_error
from .handlers import handlers
from .server import Server
from .typing import Handler

__all__ = ['AdminService']
Expand Down
32 changes: 32 additions & 0 deletions pymap/admin/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@

from __future__ import annotations

from asyncio import AbstractEventLoop

from grpclib.server import Server as _Server

__all__ = ['Server']


class Server(_Server):
"""The :class:`~grpclib.server.Server` class included with :mod:`grpclib`
has typing issues due to some missing methods. These methods do not seem to
be part of its public API, so they should be implemented to simply raise
exceptions.
Note:
If this is fixed upstream, this class should be removed.
"""

def get_loop(self) -> AbstractEventLoop:
raise NotImplementedError()

def is_serving(self) -> bool:
raise NotImplementedError()

async def serve_forever(self) -> None:
raise NotImplementedError()

async def start_serving(self) -> None:
raise NotImplementedError()
7 changes: 4 additions & 3 deletions pymap/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import asynccontextmanager, AbstractAsyncContextManager
from contextvars import copy_context, Context
from functools import partial
from threading import local, Event as _threading_Event, Lock as _threading_Lock
from typing import cast, TypeVar, TypeAlias
from typing import TypeVar, TypeAlias
from weakref import WeakSet

__all__ = ['Subsystem', 'Event', 'ReadWriteLock', 'FileLock', 'EventT', 'RetT']
Expand Down Expand Up @@ -135,8 +136,8 @@ async def execute(self, future: Awaitable[RetT]) -> RetT:

def _run_in_thread(self, future: Awaitable[RetT], ctx: Context) -> RetT:
loop = self._local.event_loop
ret = ctx.run(loop.run_until_complete, future)
return cast(RetT, ret)
foo: partial[RetT] = partial(loop.run_until_complete, future)
return ctx.run(foo)

def new_rwlock(self) -> _ThreadingReadWriteLock:
return _ThreadingReadWriteLock()
Expand Down
21 changes: 12 additions & 9 deletions pymap/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,19 @@
from abc import abstractmethod, ABCMeta
from collections.abc import Iterator, Mapping, Sequence, AsyncIterator
from contextlib import contextmanager, asynccontextmanager
from typing import TypeAlias, ClassVar, Final, Protocol, Any, Union
from typing import ClassVar, Final, Protocol, Any

from .bytes import BytesFormat, MaybeBytes, Writeable
from .interfaces.message import MessageInterface, LoadedMessageInterface
from .parsing.primitives import Nil, Number, List, LiteralString
from .parsing.specials import DateTime, FetchRequirement, FetchAttribute, \
FetchValue
from .parsing.specials import DateTime
from .parsing.specials.fetchattr import FetchPartial, FetchRequirement, \
FetchAttribute, FetchValue
from .selected import SelectedMailbox

__all__ = ['LoadedMessageProvider', 'DynamicFetchValue',
'DynamicLoadedFetchValue', 'MessageAttributes']

_Partial: TypeAlias = Union[tuple[int, int | None], None]


class LoadedMessageProvider(Protocol):
"""Generic protocol that provides access to a message's loaded contents
Expand Down Expand Up @@ -101,7 +100,8 @@ def __bytes__(self) -> bytes:
self.attribute.for_response, value)

@classmethod
def _get_data(cls, section: FetchAttribute.Section, partial: _Partial,
def _get_data(cls, section: FetchAttribute.Section,
partial: FetchPartial | None,
loaded_msg: LoadedMessageInterface, *,
binary: bool = False) -> Writeable:
specifier = section.specifier
Expand All @@ -124,13 +124,16 @@ def _get_data(cls, section: FetchAttribute.Section, partial: _Partial,
return cls._get_partial(data, partial)

@classmethod
def _get_partial(cls, data: Writeable, partial: _Partial) -> Writeable:
def _get_partial(cls, data: Writeable,
partial: FetchPartial | None) -> Writeable:
if partial is None:
return data
full = bytes(data)
start, end = partial
if end is None:
start, length = (partial.start, partial.length)
if length is None:
end = len(full)
else:
end = start + length
return Writeable.wrap(full[start:end])


Expand Down
5 changes: 3 additions & 2 deletions pymap/imap/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import re
import sys
from argparse import ArgumentParser
from asyncio import shield, StreamReader, StreamWriter, AbstractServer, \
CancelledError, TimeoutError
from asyncio import shield, StreamReader, StreamWriter, WriteTransport, \
AbstractServer, CancelledError, TimeoutError
from base64 import b64encode, b64decode
from collections.abc import Awaitable, Iterable
from contextlib import closing, AsyncExitStack
Expand Down Expand Up @@ -279,6 +279,7 @@ async def start_tls(self) -> None:
ssl_context = self.config.ssl_context
new_transport = await loop.start_tls(
transport, protocol, ssl_context, server_side=True)
assert isinstance(new_transport, WriteTransport)
new_protocol = new_transport.get_protocol()
new_writer = StreamWriter(new_transport, new_protocol,
self.reader, loop)
Expand Down
40 changes: 30 additions & 10 deletions pymap/parsing/specials/fetchattr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
from abc import ABCMeta
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from functools import total_ordering, reduce
from typing import Final, Any

Expand All @@ -14,7 +15,22 @@
from ..primitives import Atom, List
from ...bytes import BytesFormat, MaybeBytes, Writeable

__all__ = ['FetchRequirement', 'FetchAttribute', 'FetchValue']
__all__ = ['FetchPartial', 'FetchRequirement', 'FetchAttribute', 'FetchValue']


@dataclass(frozen=True)
class FetchPartial:
"""Used to indicate that only a substring of the desired fetch is being
requested.
Args:
start: The first octet of the requested substring.
length: The maximum length of the requested substring, or ``None``.
"""

start: int
length: int | None = None


class FetchRequirement(enum.Flag):
Expand Down Expand Up @@ -121,7 +137,7 @@ def __hash__(self) -> int:

def __init__(self, attribute: bytes,
section: FetchAttribute.Section | None = None,
partial: tuple[int, int | None] | None = None) -> None:
partial: FetchPartial | None = None) -> None:
super().__init__()
self.attribute = attribute.upper()
self.section = section
Expand All @@ -137,11 +153,12 @@ def value(self) -> bytes:
@property
def for_response(self) -> FetchAttribute:
if self._for_response is None:
if self.partial is None or len(self.partial) < 2:
if self.partial is None or self.partial.length is None:
self._for_response = self
else:
new_partial = FetchPartial(self.partial.start, None)
self._for_response = FetchAttribute(
self.value, self.section, (self.partial[0], None))
self.value, self.section, new_partial)
return self._for_response

@property
Expand Down Expand Up @@ -199,9 +216,11 @@ def raw(self) -> bytes:
parts.append(bytes(List(headers, sort=True)))
parts.append(b']')
if self.partial:
partial = BytesFormat(b'.').join(
[b'%i' % num for num in self.partial if num is not None])
parts += [b'<', partial, b'>']
start, length = (self.partial.start, self.partial.length)
if length is None:
parts.append(b'<%i>' % start)
else:
parts.append(b'<%i.%i>' % (start, length))
self._raw = raw = b''.join(parts)
return raw

Expand Down Expand Up @@ -294,10 +313,11 @@ def parse(cls, buf: memoryview, params: Params) \
if match:
if attr == b'BINARY.SIZE':
raise NotParseable(buf)
from_, to = int(match.group(1)), int(match.group(2))
if from_ < 0 or to <= 0 or from_ > to:
start, length = int(match.group(1)), int(match.group(2))
if start < 0 or length <= 0:
raise NotParseable(buf)
return cls(attr, section, (from_, to)), buf[match.end(0):]
partial = FetchPartial(start, length)
return cls(attr, section, partial), buf[match.end(0):]
return cls(attr, section), buf

def __bytes__(self) -> bytes:
Expand Down
3 changes: 2 additions & 1 deletion pymap/sieve/manage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import re
from argparse import ArgumentParser
from asyncio import StreamReader, StreamWriter
from asyncio import StreamReader, StreamWriter, WriteTransport
from base64 import b64encode, b64decode
from collections.abc import Mapping
from contextlib import closing, AsyncExitStack
Expand Down Expand Up @@ -283,6 +283,7 @@ async def _do_starttls(self) -> Response:
protocol = transport.get_protocol()
new_transport = await loop.start_tls(
transport, protocol, ssl_context, server_side=True)
assert isinstance(new_transport, WriteTransport)
new_protocol = new_transport.get_protocol()
new_writer = StreamWriter(new_transport, new_protocol,
self.reader, loop)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
license = f.read()

setup(name='pymap',
version='0.27.0',
version='0.27.1',
author='Ian Good',
author_email='ian@icgood.net',
description='Lightweight, asynchronous IMAP serving in Python.',
Expand Down
2 changes: 1 addition & 1 deletion test/server/test_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ async def test_fetch_binary_section_partial(self, imap_server):
transport.push_login()
transport.push_select(b'Sent')
transport.push_readline(
b'fetch1 FETCH 2 (BINARY[1]<5.22>)\r\n')
b'fetch1 FETCH 2 (BINARY[1]<5.17>)\r\n')
transport.push_write(
b'* 2 FETCH (BINARY[1]<5> ~{17}\r\n'
+ 'ⅬⅬՕ ᎳоᏒ'.encode('utf-8') +
Expand Down
6 changes: 6 additions & 0 deletions test/test_parsing_command_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def test_parse_list(self):
FetchAttribute(b'ENVELOPE')], ret.attributes)
self.assertEqual(b' ', buf)

def test_parse_uid_list(self):
ret, buf = UidFetchCommand.parse(b' 265 (BODY.PEEK[1]<208.78> '
b'BODY.PEEK[2]<2027.77>)\n ',
Params())
self.assertEqual(b' ', buf)

def test_parse_macro_all(self):
ret, buf = FetchCommand.parse(b' 1,2,3 ALL\n ', Params())
self.assertEqual([1, 2, 3], ret.sequence_set.value)
Expand Down
12 changes: 6 additions & 6 deletions test/test_parsing_specials.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from pymap.parsing.exceptions import NotParseable, UnexpectedType, \
InvalidContent
from pymap.parsing.specials import AString, Tag, Mailbox, DateTime, Flag, \
StatusAttribute, SequenceSet, FetchAttribute, SearchKey, ObjectId, \
ExtensionOptions
StatusAttribute, SequenceSet, SearchKey, ObjectId, ExtensionOptions
from pymap.parsing.state import ParsingState
from pymap.parsing.specials.fetchattr import FetchPartial, FetchAttribute
from pymap.parsing.specials.sequenceset import MaxValue


Expand Down Expand Up @@ -223,12 +223,12 @@ def test_hash(self):

def test_parse(self):
ret, buf = FetchAttribute.parse(
b'body.peek[1.2.HEADER.FIELDS (A B)]<4.5> ', Params())
b'body.peek[1.2.HEADER.FIELDS (A B)]<4.2> ', Params())
self.assertEqual(b'BODY.PEEK', ret.value)
self.assertEqual([1, 2], ret.section.parts)
self.assertEqual(b'HEADER.FIELDS', ret.section.specifier)
self.assertEqual({b'A', b'B'}, ret.section.headers)
self.assertEqual((4, 5), ret.partial)
self.assertEqual(FetchPartial(4, 2), ret.partial)
self.assertEqual(b' ', buf)

def test_parse_simple(self):
Expand Down Expand Up @@ -283,8 +283,8 @@ def test_bytes(self):
self.assertEqual(b'ENVELOPE', bytes(attr1))
section = FetchAttribute.Section((1, 2), b'STUFF',
frozenset({b'A', b'B'}))
attr2 = FetchAttribute(b'BODY', section, (4, 5))
self.assertEqual(b'BODY[1.2.STUFF (A B)]<4.5>', bytes(attr2))
attr2 = FetchAttribute(b'BODY', section, FetchPartial(4, 2))
self.assertEqual(b'BODY[1.2.STUFF (A B)]<4.2>', bytes(attr2))
self.assertEqual(b'BODY[1.2.STUFF (A B)]<4>',
bytes(attr2.for_response))

Expand Down

0 comments on commit 3cf851e

Please sign in to comment.