Skip to content

Commit

Permalink
do not use insecure mode in tests, tests for server-side subs (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
FZambia committed Feb 6, 2024
1 parent ddd1d2e commit 319d506
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
uses: crazy-max/ghaction-setup-docker@v3

- name: Start Centrifugo
run: docker run -d -p 8000:8000 -e CENTRIFUGO_PRESENCE=true -e CENTRIFUGO_JOIN_LEAVE=true -e CENTRIFUGO_FORCE_PUSH_JOIN_LEAVE=true -e CENTRIFUGO_HISTORY_TTL=300s -e CENTRIFUGO_HISTORY_SIZE=100 -e CENTRIFUGO_FORCE_RECOVERY=true centrifugo/centrifugo:v5 centrifugo --client_insecure
run: docker run -d -p 8000:8000 -e CENTRIFUGO_TOKEN_HMAC_SECRET_KEY="secret" -e CENTRIFUGO_PRESENCE=1 -e CENTRIFUGO_JOIN_LEAVE=true -e CENTRIFUGO_FORCE_PUSH_JOIN_LEAVE=true -e CENTRIFUGO_HISTORY_TTL=300s -e CENTRIFUGO_HISTORY_SIZE=100 -e CENTRIFUGO_FORCE_RECOVERY=true -e CENTRIFUGO_USER_SUBSCRIBE_TO_PERSONAL=true -e CENTRIFUGO_ALLOW_PUBLISH_FOR_SUBSCRIBER=true -e CENTRIFUGO_ALLOW_PRESENCE_FOR_SUBSCRIBER=true -e CENTRIFUGO_ALLOW_HISTORY_FOR_SUBSCRIBER=true centrifugo/centrifugo:v5 centrifugo

- name: Install dependencies
run: |
Expand Down
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: proto test lint lint-ci
.PHONY: proto test lint lint-fix lint-ci

dev:
pip install -e ".[dev]"
Expand All @@ -12,5 +12,8 @@ test:
lint:
ruff .

lint-fix:
ruff . --fix

lint-ci:
ruff . --output-format=github
4 changes: 0 additions & 4 deletions centrifuge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from .contexts import (
ConnectedContext,
ConnectingContext,
ClientTokenContext,
DisconnectedContext,
ErrorContext,
JoinContext,
Expand All @@ -19,7 +18,6 @@
SubscribedContext,
SubscribingContext,
SubscriptionErrorContext,
SubscriptionTokenContext,
UnsubscribedContext,
)
from .exceptions import (
Expand Down Expand Up @@ -50,7 +48,6 @@
"ClientEventHandler",
"ClientInfo",
"ClientState",
"ClientTokenContext",
"ConnectedContext",
"ConnectingContext",
"DisconnectedContext",
Expand Down Expand Up @@ -80,7 +77,6 @@
"SubscriptionErrorContext",
"SubscriptionEventHandler",
"SubscriptionState",
"SubscriptionTokenContext",
"SubscriptionUnsubscribedError",
"UnauthorizedError",
"UnsubscribedContext",
Expand Down
16 changes: 7 additions & 9 deletions centrifuge/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from centrifuge.contexts import (
ConnectedContext,
ConnectingContext,
ClientTokenContext,
DisconnectedContext,
ErrorContext,
JoinContext,
Expand All @@ -41,7 +40,6 @@
SubscribedContext,
SubscribingContext,
SubscriptionErrorContext,
SubscriptionTokenContext,
UnsubscribedContext,
ServerSubscribedContext,
ServerPublicationContext,
Expand Down Expand Up @@ -131,7 +129,7 @@ def __init__(
address: str,
events: Optional[ClientEventHandler] = None,
token: str = "",
get_token: Optional[Callable[[ClientTokenContext], Awaitable[str]]] = None,
get_token: Optional[Callable[[], Awaitable[str]]] = None,
use_protobuf: bool = False,
timeout: float = 5.0,
max_server_ping_delay: float = 10.0,
Expand Down Expand Up @@ -196,7 +194,7 @@ def new_subscription(
channel: str,
events: Optional[SubscriptionEventHandler] = None,
token: str = "",
get_token: Optional[Callable[[SubscriptionTokenContext], Awaitable[str]]] = None,
get_token: Optional[Callable[[str], Awaitable[str]]] = None,
data: Optional[Any] = None,
min_resubscribe_delay=0.1,
max_resubscribe_delay=10.0,
Expand Down Expand Up @@ -308,7 +306,7 @@ async def _create_connection(self) -> bool:

if not self._token and self._get_token:
try:
token = await self._get_token(ClientTokenContext())
token = await self._get_token()
except Exception as e:
if isinstance(e, UnauthorizedError):
code = _DisconnectedCode.UNAUTHORIZED
Expand Down Expand Up @@ -551,7 +549,7 @@ async def _refresh(self) -> None:
cmd_id = self._next_command_id()

try:
token = await self._get_token(ClientTokenContext())
token = await self._get_token()
except Exception as e:
if isinstance(e, UnauthorizedError):
code = _DisconnectedCode.UNAUTHORIZED
Expand Down Expand Up @@ -608,7 +606,7 @@ async def _sub_refresh(self, channel: str):
return

try:
token = await sub._get_token(SubscriptionTokenContext(channel=channel))
token = await sub._get_token(channel)
except Exception as e:
if isinstance(e, UnauthorizedError):
code = _UnsubscribedCode.UNAUTHORIZED
Expand Down Expand Up @@ -690,7 +688,7 @@ async def _subscribe(self, channel):

if not sub._token and sub._get_token:
try:
token = await sub._get_token(SubscriptionTokenContext(channel=channel))
token = await sub._get_token(channel)
except Exception as e:
if isinstance(e, UnauthorizedError):
code = _UnsubscribedCode.UNAUTHORIZED
Expand Down Expand Up @@ -1347,7 +1345,7 @@ def _initialize(
channel: str,
events: Optional[SubscriptionEventHandler] = None,
token: str = "",
get_token: Optional[Callable[[SubscriptionTokenContext], Awaitable[str]]] = None,
get_token: Optional[Callable[[str], Awaitable[str]]] = None,
data: Optional[Any] = None,
min_resubscribe_delay: float = 0.1,
max_resubscribe_delay: float = 10.0,
Expand Down
12 changes: 0 additions & 12 deletions centrifuge/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,3 @@ class SubscriptionErrorContext:

code: int
error: Exception


@dataclass
class ClientTokenContext:
"""ClientTokenContext is a context passed to get_token callback of connection."""


@dataclass
class SubscriptionTokenContext:
"""SubscriptionTokenContext is a context passed to get_token callback of subscription."""

channel: str
19 changes: 12 additions & 7 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
ClientEventHandler,
ConnectedContext,
ConnectingContext,
ClientTokenContext,
DisconnectedContext,
ErrorContext,
JoinContext,
Expand All @@ -17,7 +16,6 @@
SubscribedContext,
SubscribingContext,
SubscriptionErrorContext,
SubscriptionTokenContext,
UnsubscribedContext,
SubscriptionEventHandler,
ServerSubscribedContext,
Expand All @@ -37,11 +35,11 @@
cf_logger.setLevel(logging.DEBUG)


async def get_token(ctx: ClientTokenContext) -> str:
async def get_client_token() -> str:
# To reject connection raise centrifuge.UnauthorizedError() exception:
# raise centrifuge.UnauthorizedError()

logging.info("get connection token called: %s", ctx)
logging.info("get client token called")

# REPLACE with your own logic to get token from the backend!
example_token = (
Expand All @@ -51,11 +49,11 @@ async def get_token(ctx: ClientTokenContext) -> str:
return example_token


async def get_subscription_token(ctx: SubscriptionTokenContext) -> str:
async def get_subscription_token(channel: str) -> str:
# To reject subscription raise centrifuge.UnauthorizedError() exception:
# raise centrifuge.UnauthorizedError()

logging.info("get subscription token called: %s", ctx)
logging.info("get subscription token called for channel %s", channel)

# REPLACE with your own logic to get token from the backend!
example_token = (
Expand Down Expand Up @@ -129,7 +127,7 @@ def run_example():
client = Client(
"ws://localhost:8000/connection/websocket",
events=ClientEventLoggerHandler(),
get_token=get_token,
get_token=get_client_token,
use_protobuf=False,
)

Expand Down Expand Up @@ -178,6 +176,13 @@ async def run():
async def shutdown(received_signal):
logging.info("received exit signal %s...", received_signal.name)
await client.disconnect()

tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()

logging.info("Cancelling outstanding tasks")
await asyncio.gather(*tasks, return_exceptions=True)
loop.stop()

signals = (signal.SIGTERM, signal.SIGINT)
Expand Down
Loading

0 comments on commit 319d506

Please sign in to comment.