Skip to content

Commit

Permalink
Use yarl and add generic C-S discovery
Browse files Browse the repository at this point in the history
Fixes #20
Fixes #25
  • Loading branch information
tulir committed Sep 18, 2020
1 parent 351cce7 commit cc43498
Show file tree
Hide file tree
Showing 10 changed files with 234 additions and 138 deletions.
84 changes: 52 additions & 32 deletions mautrix/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,25 @@
import logging
import asyncio

from yarl import URL
from aiohttp import ClientSession
from aiohttp.client_exceptions import ContentTypeError, ClientError

from mautrix.errors import make_request_error, MatrixConnectionError
from mautrix.util.logging import TraceLogger

if TYPE_CHECKING:
from mautrix.types import JSON


class APIPath(Enum):
"""The known Matrix API path prefixes."""
CLIENT = "/_matrix/client/r0"
CLIENT_UNSTABLE = "/_matrix/client/unstable"
MEDIA = "/_matrix/media/r0"
IDENTITY = "/_matrix/identity/r0"
"""
The known Matrix API path prefixes.
These don't start with a slash so they can be used nicely with yarl.
"""
CLIENT = "_matrix/client/r0"
CLIENT_UNSTABLE = "_matrix/client/unstable"
MEDIA = "_matrix/media/r0"

def __repr__(self):
return self.value
Expand Down Expand Up @@ -60,7 +64,7 @@ class PathBuilder:
>>> room_id = "!foo:example.com"
>>> event_id = "$bar:example.com"
>>> str(Path.rooms[room_id].event[event_id])
"/_matrix/client/r0/rooms/%21foo%3Aexample.com/event/%24bar%3Aexample.com"
"_matrix/client/r0/rooms/%21foo%3Aexample.com/event/%24bar%3Aexample.com"
"""

def __init__(self, path: Union[str, APIPath] = "") -> None:
Expand Down Expand Up @@ -105,14 +109,21 @@ def __getitem__(self, append: Union[str, int]) -> 'PathBuilder':
ClientPath = Path
UnstableClientPath = PathBuilder(APIPath.CLIENT_UNSTABLE)
MediaPath = PathBuilder(APIPath.MEDIA)
IdentityPath = PathBuilder(APIPath.IDENTITY)


class HTTPAPI:
"""HTTPAPI is a simple asyncio Matrix API request sender."""

def __init__(self, base_url: str, token: str = "", *, client_session: ClientSession = None,
txn_id: int = 0, log: Optional[logging.Logger] = None,
base_url: URL
token: str
log: TraceLogger
loop: asyncio.AbstractEventLoop
session: ClientSession
txn_id: Optional[int]

def __init__(self, base_url: Union[URL, str], token: str = "", *,
client_session: ClientSession = None,
txn_id: int = 0, log: Optional[TraceLogger] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
"""
Args:
Expand All @@ -122,18 +133,18 @@ def __init__(self, base_url: str, token: str = "", *, client_session: ClientSess
txn_id: The outgoing transaction ID to start with.
log: The logging.Logger instance to log requests with.
"""
self.base_url: str = base_url
self.token: str = token
self.log: Optional[logging.Logger] = log or logging.getLogger("mau.http")
self.base_url = URL(base_url)
self.token = token
self.log = log or logging.getLogger("mau.http")
self.loop = loop or asyncio.get_event_loop()
self.session: ClientSession = client_session or ClientSession(loop=self.loop)
self.session = client_session or ClientSession(loop=self.loop)
if txn_id is not None:
self.txn_id: int = txn_id
self.txn_id = txn_id

async def _send(self, method: Method, endpoint: str, content: Union[bytes, str],
async def _send(self, method: Method, url: URL, content: Union[bytes, str],
query_params: Dict[str, str], headers: Dict[str, str]) -> 'JSON':
while True:
request = self.session.request(str(method), endpoint, data=content,
request = self.session.request(str(method), url, data=content,
params=query_params, headers=headers)
async with request as response:
if response.status < 200 or response.status >= 300:
Expand All @@ -150,7 +161,10 @@ async def _send(self, method: Method, endpoint: str, content: Union[bytes, str],

if response.status == 429:
resp = await response.json()
await asyncio.sleep(resp["retry_after_ms"] / 1000, loop=self.loop)
seconds = resp["retry_after_ms"] / 1000
self.log.debug(f"Request to {url} returned 429, "
f"waiting {seconds} seconds and retrying")
await asyncio.sleep(seconds, loop=self.loop)
else:
return await response.json()

Expand All @@ -161,7 +175,7 @@ def _log_request(self, method: Method, path: PathBuilder, content: Union[str, by
log_content = content if not isinstance(content, bytes) else f"<{len(content)} bytes>"
as_user = query_params.get("user_id", None)
level = 1 if path == Path.sync else 5
self.log.log(level, f"{method} {path} {log_content}".strip(" "),
self.log.log(level, f"{method} /{path} {log_content}".strip(" "),
extra={"matrix_http_request": {
"method": str(method),
"path": str(path),
Expand All @@ -170,23 +184,25 @@ def _log_request(self, method: Method, path: PathBuilder, content: Union[str, by
"user": as_user,
}})

async def request(self, method: Method, path: PathBuilder,
content: Optional[Union['JSON', bytes, str]] = None,
async def request(self, method: Method, path: Union[PathBuilder, str],
content: Optional[Union[dict, list, bytes, str]] = None,
headers: Optional[Dict[str, str]] = None,
query_params: Optional[Dict[str, str]] = None) -> 'JSON':
"""
Make a raw HTTP request.
Make a raw Matrix API request.
Args:
method: The HTTP method to use.
path: The API endpoint to call.
Does not include the base path (e.g. /_matrix/client/r0).
content: The content to post as a dict (json) or bytes/str (raw).
headers: The dict of HTTP headers to send.
query_params: The dict of query parameters to send.
path: The full API endpoint to call (including the _matrix/... prefix)
content: The content to post as a dict/list (will be serialized as JSON)
or bytes/str (will be sent as-is).
headers: A dict of HTTP headers to send.
If the headers don't contain ``Content-Type``, it'll be set to ``application/json``.
The ``Authorization`` header is always overridden if :attr:`token` is set.
query_params: A dict of query parameters to send.
Returns:
The response as a dict.
The parsed response JSON.
"""
content = content or {}
headers = headers or {}
Expand All @@ -203,18 +219,22 @@ async def request(self, method: Method, path: PathBuilder,

self._log_request(method, path, content, orig_content, query_params)

endpoint = self.base_url + str(path)
path = str(path)
if path and path[0] == "/":
path = path[1:]

try:
return await self._send(method, endpoint, content, query_params, headers or {})
return await self._send(method, self.base_url / path,
content, query_params, headers or {})
except ClientError as e:
raise MatrixConnectionError(str(e)) from e

def get_txn_id(self) -> str:
"""Get a new unique transaction ID."""
self.txn_id += 1
return str(self.txn_id) + str(int(time() * 1000))
return f"mautrix-python_R{self.txn_id}@T{int(time() * 1000)}"

def get_download_url(self, mxc_uri: str, download_type: str = "download") -> str:
def get_download_url(self, mxc_uri: str, download_type: str = "download") -> URL:
"""
Get the full HTTP URL to download a mxc:// URI.
Expand All @@ -234,6 +254,6 @@ def get_download_url(self, mxc_uri: str, download_type: str = "download") -> str
"https://matrix.org/_matrix/media/r0/download/matrix.org/pqjkOuKZ1ZKRULWXgz2IVZV6"
"""
if mxc_uri.startswith("mxc://"):
return f"{self.base_url}{APIPath.MEDIA}/{download_type}/{mxc_uri[6:]}"
return self.base_url / str(APIPath.MEDIA) / download_type / mxc_uri[6:]
else:
raise ValueError("MXC URI did not begin with `mxc://`")
6 changes: 4 additions & 2 deletions mautrix/appservice/api/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import asyncio

from aiohttp import ClientSession
from yarl import URL

from mautrix.types import UserID
from mautrix.api import HTTPAPI, Method, PathBuilder
Expand Down Expand Up @@ -39,7 +40,7 @@ class AppServiceAPI(HTTPAPI):

_bot_intent: Optional[IntentAPI]

def __init__(self, base_url: str, bot_mxid: UserID = None, token: str = None,
def __init__(self, base_url: Union[URL, str], bot_mxid: UserID = None, token: str = None,
identity: Optional[UserID] = None, log: TraceLogger = None,
state_store: 'ASStateStore' = None, client_session: ClientSession = None,
child: bool = False, real_user: bool = False,
Expand Down Expand Up @@ -96,7 +97,8 @@ def user(self, user: UserID) -> 'ChildAppServiceAPI':
self.children[user] = child
return child

def real_user(self, mxid: UserID, token: str, base_url: Optional[str] = None) -> 'AppServiceAPI':
def real_user(self, mxid: UserID, token: str, base_url: Optional[URL] = None
) -> 'AppServiceAPI':
"""
Get the AppServiceAPI for a real (non-appservice-managed) Matrix user.
Expand Down

0 comments on commit cc43498

Please sign in to comment.