Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: 支持 WebSocket 连接同时获取 str 或 bytes #962

Merged
merged 3 commits into from May 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 17 additions & 4 deletions nonebot/drivers/aiohttp.py
Expand Up @@ -132,20 +132,33 @@ async def _receive(self) -> aiohttp.WSMessage:

@overrides(BaseWebSocket)
async def receive(self) -> str:
msg = await self._receive()
if msg.type not in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
raise TypeError(
f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
)
return msg.data

@overrides(BaseWebSocket)
async def receive_text(self) -> str:
msg = await self._receive()
if msg.type != aiohttp.WSMsgType.TEXT:
raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
raise TypeError(
f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
)
return msg.data

@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
msg = await self._receive()
if msg.type != aiohttp.WSMsgType.TEXT:
raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
if msg.type != aiohttp.WSMsgType.BINARY:
raise TypeError(
f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
)
return msg.data

@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
async def send_text(self, data: str) -> None:
await self.websocket.send_str(data)

@overrides(BaseWebSocket)
Expand Down
16 changes: 13 additions & 3 deletions nonebot/drivers/fastapi.py
Expand Up @@ -11,7 +11,7 @@

import logging
from functools import wraps
from typing import Any, List, Tuple, Callable, Optional
from typing import Any, List, Tuple, Union, Callable, Optional

import uvicorn
from pydantic import BaseSettings
Expand All @@ -36,6 +36,8 @@ async def decorator(*args, **kwargs):
return await func(*args, **kwargs)
except WebSocketDisconnect as e:
raise WebSocketClosed(e.code)
except KeyError:
raise TypeError("WebSocket received unexpected frame type")

return decorator

Expand Down Expand Up @@ -261,9 +263,17 @@ async def close(
) -> None:
await self.websocket.close(code)

@overrides(BaseWebSocket)
async def receive(self) -> Union[str, bytes]:
# assert self.websocket.application_state == WebSocketState.CONNECTED
msg = await self.websocket.receive()
if msg["type"] == "websocket.disconnect":
raise WebSocketClosed(msg["code"])
return msg["text"] if "text" in msg else msg["bytes"]

@overrides(BaseWebSocket)
@catch_closed
async def receive(self) -> str:
async def receive_text(self) -> str:
return await self.websocket.receive_text()

@overrides(BaseWebSocket)
Expand All @@ -272,7 +282,7 @@ async def receive_bytes(self) -> bytes:
return await self.websocket.receive_bytes()

@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
async def send_text(self, data: str) -> None:
await self.websocket.send({"type": "websocket.send", "text": data})

@overrides(BaseWebSocket)
Expand Down
10 changes: 5 additions & 5 deletions nonebot/drivers/httpx.py
Expand Up @@ -49,22 +49,22 @@ def type(self) -> str:
async def request(self, setup: Request) -> Response:
async with httpx.AsyncClient(
http2=setup.version == HTTPVersion.H2,
proxies=setup.proxy,
proxies=setup.proxy, # type: ignore
follow_redirects=True,
) as client:
response = await client.request(
setup.method,
str(setup.url),
content=setup.content,
data=setup.data,
content=setup.content, # type: ignore
data=setup.data, # type: ignore
json=setup.json,
files=setup.files,
files=setup.files, # type: ignore
headers=tuple(setup.headers.items()),
timeout=setup.timeout,
)
return Response(
response.status_code,
headers=response.headers,
headers=response.headers.multi_items(),
content=response.content,
request=setup,
)
Expand Down
15 changes: 10 additions & 5 deletions nonebot/drivers/quart.py
Expand Up @@ -17,7 +17,7 @@

import asyncio
from functools import wraps
from typing import List, Tuple, TypeVar, Callable, Optional, Coroutine
from typing import List, Tuple, Union, TypeVar, Callable, Optional, Coroutine

import uvicorn
from pydantic import BaseSettings
Expand Down Expand Up @@ -199,7 +199,7 @@ async def _handle_http(self, setup: HTTPServerSetup) -> Response:
http_request = BaseRequest(
request.method,
request.url,
headers=request.headers.items(),
headers=list(request.headers.items()),
cookies=list(request.cookies.items()),
content=await request.get_data(
cache=False, as_text=False, parse_form_data=False
Expand All @@ -224,7 +224,7 @@ async def _handle_ws(self, setup: WebSocketServerSetup) -> None:
http_request = BaseRequest(
websocket.method,
websocket.url,
headers=websocket.headers.items(),
headers=list(websocket.headers.items()),
cookies=list(websocket.cookies.items()),
version=websocket.http_version,
)
Expand Down Expand Up @@ -257,7 +257,12 @@ async def close(self, code: int = 1000, reason: str = ""):

@overrides(BaseWebSocket)
@catch_closed
async def receive(self) -> str:
async def receive(self) -> Union[str, bytes]:
return await self.websocket.receive()

@overrides(BaseWebSocket)
@catch_closed
async def receive_text(self) -> str:
msg = await self.websocket.receive()
if isinstance(msg, bytes):
raise TypeError("WebSocket received unexpected frame type: bytes")
Expand All @@ -272,7 +277,7 @@ async def receive_bytes(self) -> bytes:
return msg

@overrides(BaseWebSocket)
async def send(self, data: str):
async def send_text(self, data: str):
await self.websocket.send(data)

@overrides(BaseWebSocket)
Expand Down
16 changes: 11 additions & 5 deletions nonebot/drivers/websockets.py
Expand Up @@ -16,8 +16,8 @@
"""
import logging
from functools import wraps
from typing import Type, AsyncGenerator
from contextlib import asynccontextmanager
from typing import Type, Union, AsyncGenerator

from nonebot.typing import overrides
from nonebot.log import LoguruHandler
Expand Down Expand Up @@ -46,9 +46,9 @@ async def decorator(*args, **kwargs):
return await func(*args, **kwargs)
except ConnectionClosed as e:
if e.rcvd_then_sent:
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason)
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason) # type: ignore
else:
raise WebSocketClosed(e.sent.code, e.sent.reason)
raise WebSocketClosed(e.sent.code, e.sent.reason) # type: ignore

return decorator

Expand Down Expand Up @@ -100,7 +100,13 @@ async def close(self, code: int = 1000, reason: str = ""):

@overrides(BaseWebSocket)
@catch_closed
async def receive(self) -> str:
async def receive(self) -> Union[str, bytes]:
msg = await self.websocket.recv()
return msg

@overrides(BaseWebSocket)
@catch_closed
async def receive_text(self) -> str:
msg = await self.websocket.recv()
if isinstance(msg, bytes):
raise TypeError("WebSocket received unexpected frame type: bytes")
Expand All @@ -115,7 +121,7 @@ async def receive_bytes(self) -> bytes:
return msg

@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
async def send_text(self, data: str) -> None:
await self.websocket.send(data)

@overrides(BaseWebSocket)
Expand Down
18 changes: 16 additions & 2 deletions nonebot/internal/driver/model.py
Expand Up @@ -186,7 +186,12 @@ async def close(self, code: int = 1000, reason: str = "") -> None:
raise NotImplementedError

@abc.abstractmethod
async def receive(self) -> str:
async def receive(self) -> Union[str, bytes]:
"""接收一条 WebSocket text/bytes 信息"""
raise NotImplementedError

@abc.abstractmethod
async def receive_text(self) -> str:
"""接收一条 WebSocket text 信息"""
raise NotImplementedError

Expand All @@ -195,8 +200,17 @@ async def receive_bytes(self) -> bytes:
"""接收一条 WebSocket binary 信息"""
raise NotImplementedError

async def send(self, data: Union[str, bytes]) -> None:
"""发送一条 WebSocket text/bytes 信息"""
if isinstance(data, str):
await self.send_text(data)
elif isinstance(data, bytes):
await self.send_bytes(data)
else:
raise TypeError("WebSocker send method expects str or bytes!")

@abc.abstractmethod
async def send(self, data: str) -> None:
async def send_text(self, data: str) -> None:
"""发送一条 WebSocket text 信息"""
raise NotImplementedError

Expand Down
27 changes: 26 additions & 1 deletion tests/test_driver.py
Expand Up @@ -15,6 +15,7 @@
)
async def test_reverse_driver(app: App):
import nonebot
from nonebot.exception import WebSocketClosed
from nonebot.drivers import (
URL,
Request,
Expand All @@ -36,7 +37,21 @@ async def _handle_ws(ws: WebSocket) -> None:
data = await ws.receive()
assert data == "ping"
await ws.send("pong")
await ws.close()

data = await ws.receive()
assert data == b"ping"
await ws.send(b"pong")

data = await ws.receive_text()
assert data == "ping"
await ws.send("pong")

data = await ws.receive_bytes()
assert data == b"ping"
await ws.send(b"pong")

with pytest.raises(WebSocketClosed):
await ws.receive()

http_setup = HTTPServerSetup(URL("/http_test"), "POST", "http_test", _handle_http)
driver.setup_http_server(http_setup)
Expand All @@ -53,3 +68,13 @@ async def _handle_ws(ws: WebSocket) -> None:
async with client.websocket_connect("/ws_test") as ws:
await ws.send_text("ping")
assert await ws.receive_text() == "pong"
await ws.send_bytes(b"ping")
assert await ws.receive_bytes() == b"pong"

await ws.send_text("ping")
assert await ws.receive_text() == "pong"

await ws.send_bytes(b"ping")
assert await ws.receive_bytes() == b"pong"

await ws.close()