Skip to content

Commit

Permalink
Improve type hints for queue.get()
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov committed May 10, 2023
1 parent 1ccc35d commit 835438e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 7 deletions.
24 changes: 18 additions & 6 deletions aio_pika/abc.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import asyncio
import dataclasses
import sys
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from enum import Enum, IntEnum, unique
from functools import singledispatch
from types import TracebackType
from typing import (
Any, AsyncContextManager, AsyncIterable, Awaitable, Callable, Dict,
Generator, Iterator, Optional, Type, TypeVar, Union,
Generator, Iterator, Optional, Type, TypeVar, Union, overload,
)


try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict
if sys.version_info >= (3, 8):
from typing import Literal, TypedDict
else:
from typing_extensions import Literal, TypedDict

import aiormq.abc
from aiormq.abc import ExceptionType
Expand Down Expand Up @@ -324,6 +324,18 @@ async def cancel(
) -> aiormq.spec.Basic.CancelOk:
raise NotImplementedError

@overload
async def get(
self, *, no_ack: bool = False,
fail: Literal[True] = ..., timeout: TimeoutType = ...,
) -> AbstractIncomingMessage:
...
@overload
async def get(
self, *, no_ack: bool = False,
fail: Literal[False] = ..., timeout: TimeoutType = ...,
) -> Optional[AbstractIncomingMessage]:
...
@abstractmethod
async def get(
self, *, no_ack: bool = False,
Expand Down
21 changes: 20 additions & 1 deletion aio_pika/queue.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import sys
from functools import partial
from types import TracebackType
from typing import Any, Callable, Optional, Type
from typing import Any, Callable, Optional, Type, overload

import aiormq
from aiormq.abc import DeliveredMessage
Expand All @@ -18,6 +19,12 @@
from .tools import CallbackCollection, create_task


if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal


log = get_logger(__name__)


Expand Down Expand Up @@ -256,6 +263,18 @@ async def cancel(
consumer_tag=consumer_tag, nowait=nowait, timeout=timeout,
)

@overload
async def get(
self, *, no_ack: bool = False,
fail: Literal[True] = ..., timeout: TimeoutType = ...,
) -> IncomingMessage:
...
@overload
async def get(
self, *, no_ack: bool = False,
fail: Literal[False] = ..., timeout: TimeoutType = ...,
) -> Optional[IncomingMessage]:
...
async def get(
self, *, no_ack: bool = False,
fail: bool = True, timeout: TimeoutType = 5,
Expand Down

0 comments on commit 835438e

Please sign in to comment.