Skip to content

Commit

Permalink
Merge pull request #6 from gunyu1019/feature/2-annotated
Browse files Browse the repository at this point in the history
[Feat] Support Annotated
  • Loading branch information
gunyu1019 committed Jan 31, 2024
2 parents e5b8dec + 8241f1b commit 6f8ddeb
Show file tree
Hide file tree
Showing 20 changed files with 279 additions and 124 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ jobs:
python -m pip install --upgrade pip
pip install .[lint]
- name: Analysing the code with pycodestyle
run: python -m pycodestyle async_client_decorator --max-line-length=120
run: python -m pycodestyle async_client --max-line-length=120
- name: Checking the code for formatted in black
run: python -m black --check async_client_decorator example
run: python -m black --check async_client example
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ pip install async-client-decorator
## Quick Example

An example is the API provided by the [BUS API](https://github.com/gunyu1019/trafficAPI).

```python
import asyncio
import aiohttp
from async_client_decorator import request, Session, Query
from async_client import request, Session, Query

loop = asyncio.get_event_loop()

Expand All @@ -30,11 +31,11 @@ class BusAPI(Session):

@request("GET", "/bus/station")
async def station_search_with_query(
self,
response: aiohttp.ClientResponse,
name: Query | str
self,
response: aiohttp.ClientResponse,
name: Query | str
):
return await response.json()
return await response.json()


async def main():
Expand All @@ -43,5 +44,6 @@ async def main():
data = await response.json()
print(len)


loop.run_until_complete(main())
```
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
__author__ = "gunyu1019"
__license__ = "MIT"
__copyright__ = "Copyright 2023-present gunyu1019"
__version__ = "0.1.1" # version_info.to_string()
__version__ = "0.2.0" # version_info.to_string()


class VersionInfo(NamedTuple):
Expand All @@ -54,5 +54,5 @@ def to_string(self) -> str:


version_info: VersionInfo = VersionInfo(
major=0, minor=1, micro=1, release_level=None, serial=0
major=0, minor=2, micro=0, release_level=None, serial=0
)
File renamed without changes.
File renamed without changes.
File renamed without changes.
23 changes: 15 additions & 8 deletions async_client_decorator/component.py → async_client/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
SOFTWARE.
"""

import aiohttp
import dataclasses
import inspect
from collections.abc import Collection
from typing import Any, Optional, Literal

import aiohttp
from .utils import *


@dataclasses.dataclass
Expand Down Expand Up @@ -60,23 +62,28 @@ def fill_keyword_argument(
return kwargs.get(key) if key in kwargs.keys() else parameter.default

def set_body(self, data: inspect.Parameter | dict | list | aiohttp.FormData):
body_annotations = (
body_annotation = (
data.annotation if isinstance(data, inspect.Parameter) else type(data)
)
argument = (
body_annotation.__args__
if is_annotated_parameter(body_annotation)
else body_annotation
)
separated_argument = separate_union_type(argument)
origin_argument = [
get_origin_for_generic(x) for x in make_collection(separated_argument)
]

if not (
issubclass(dict, body_annotations)
or issubclass(list, body_annotations)
or issubclass(aiohttp.FormData, body_annotations)
):
if not is_subclass_safe(origin_argument, (dict, list, aiohttp.FormData)):
raise TypeError(
"Body parameter can only have aiohttp.FormData or dict, list."
)

if self.body is not None:
raise ValueError("Only one Body Parameter is allowed.")

if issubclass(aiohttp.FormData, body_annotations):
if is_subclass_safe(separated_argument, aiohttp.FormData):
self.body_type = "data"
else:
self.body_type = "json"
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
39 changes: 29 additions & 10 deletions async_client_decorator/request.py → async_client/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .path import Path
from .query import Query
from .session import Session
from .utils import *

T = TypeVar("T")

Expand Down Expand Up @@ -117,23 +118,41 @@ def decorator(func: RequestFunction):
components.path.update(getattr(func, Path.DEFAULT_KEY, dict()))

for parameter in func_parameters.values():
if hasattr(parameter.annotation, "__args__"):
annotation = parameter.annotation.__args__
else:
annotation = (parameter.annotation,)
annotation = parameter.annotation
metadata = (
annotation.__metadata__
if is_annotated_parameter(annotation)
else annotation
)
separated_annotation = separate_union_type(metadata)

if issubclass(Header, annotation) or parameter.name in header_parameter:
if (
is_subclass_safe(separated_annotation, Header)
or parameter.name in header_parameter
):
components.header[parameter.name] = parameter
elif issubclass(Query, annotation) or parameter.name in query_parameter:
elif (
is_subclass_safe(separated_annotation, Query)
or parameter.name in query_parameter
):
components.query[parameter.name] = parameter
elif issubclass(Path, annotation) or parameter.name in path_parameter:
elif (
is_subclass_safe(separated_annotation, Path)
or parameter.name in path_parameter
):
components.path[parameter.name] = parameter
elif issubclass(Form, annotation) or parameter.name in form_parameter:
elif (
is_subclass_safe(separated_annotation, Form)
or parameter.name in form_parameter
):
components.add_form(parameter.name, parameter)
elif issubclass(Body, annotation) or parameter.name == body_parameter:
elif (
is_subclass_safe(separated_annotation, Body)
or parameter.name == body_parameter
):
components.set_body(parameter)
elif (
issubclass(aiohttp.ClientResponse, annotation)
is_subclass_safe(separated_annotation, aiohttp.ClientResponse)
or parameter.name in response_parameter
):
components.response.append(parameter.name)
Expand Down
File renamed without changes.
45 changes: 45 additions & 0 deletions async_client/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from collections.abc import Collection
from types import UnionType, GenericAlias
from typing import Annotated, get_origin


def is_subclass_safe(_class, _class_info) -> bool:
"""
Same functionality as `issubclass` method
However, _class parameter can be list of type.
"""
_class = make_collection(_class)
return any([issubclass(t, _class_info) for t in _class if isinstance(t, type)])


def separate_union_type(t):
"""
If type is Union, return list of type
else return t
"""
if isinstance(t, UnionType):
return t.__args__
return t


def is_annotated_parameter(t) -> bool:
"""
Return `True` if type is Annotated
"""
return get_origin(t) is Annotated


def get_origin_for_generic(t):
"""
If type is Generic, return origin of generic type
else return t
"""
if isinstance(t, GenericAlias):
return t.__origin__
return t


def make_collection(t):
if not isinstance(t, Collection):
return (t,)
return t
28 changes: 28 additions & 0 deletions example/annotated_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import asyncio
import aiohttp

from async_client import request, Session, Query
from typing import Annotated

loop = asyncio.get_event_loop()


class BusAPI(Session):
def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__("https://api.yhs.kr", loop)

@request("GET", "/bus/station")
async def station_search_with_query(
self, response: aiohttp.ClientResponse, name: Annotated[str, Query]
):
return await response.json()


async def main():
async with BusAPI(loop) as client:
response = await client.station_search_with_query(name="bus-station-name")
data = await response.json()
print(len)


loop.run_until_complete(main())
2 changes: 1 addition & 1 deletion example/flask_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import aiohttp

from flask import Flask
from async_client_decorator import request, Session, Query
from async_client import request, Session, Query

app = Flask(__name__)
loop = asyncio.get_event_loop()
Expand Down
2 changes: 1 addition & 1 deletion example/single_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import aiohttp

from async_client_decorator import request, Session, Query
from async_client import request, Session, Query

loop = asyncio.get_event_loop()

Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from setuptools import setup

version = ""
with open("async_client_decorator/__init__.py") as f:
with open("async_client/__init__.py") as f:
version = re.search(
r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE
).group(1)
Expand All @@ -18,9 +18,9 @@
}

setup(
name="async_client_decorator",
name="async_client",
version=version,
packages=["async_client_decorator"],
packages=["async_client"],
url="https://github.com/gunyu1019/async-client-decorator",
license="MIT",
author="gunyu1019",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_component.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from async_client_decorator import *
from async_client_decorator.component import Component
from async_client import *
from async_client.component import Component


def test_duplicated_body_type():
Expand Down
Loading

0 comments on commit 6f8ddeb

Please sign in to comment.