Skip to content

Commit

Permalink
Proper event loop usage
Browse files Browse the repository at this point in the history
  • Loading branch information
MrNaif2018 committed Oct 14, 2021
1 parent d6984b0 commit 4784303
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 20 deletions.
5 changes: 2 additions & 3 deletions api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def init_logging():
global logger
logger = get_logger(__name__)
sys.excepthook = excepthook_handler(sys.excepthook)
loop.set_exception_handler(handle_exception)
asyncio.get_running_loop().set_exception_handler(handle_exception)


# TODO: refactor it all into OOP style, i.e. class Settings
Expand Down Expand Up @@ -151,7 +151,6 @@ def get_coin(coin, xpub=None):
notifiers[notifier.name] = {"properties": properties, "required": required}

# initialize redis pool
loop = asyncio.get_event_loop()
redis_pool = None


Expand All @@ -164,7 +163,7 @@ async def init_redis():
async def init_db():
from api import db

await db.db.set_bind(db.CONNECTION_STR, min_size=1, loop=asyncio.get_event_loop())
await db.db.set_bind(db.CONNECTION_STR, min_size=1, loop=asyncio.get_running_loop())


def excepthook_handler(excepthook):
Expand Down
4 changes: 2 additions & 2 deletions api/views/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from starlette.endpoints import WebSocketEndpoint
from starlette.status import WS_1008_POLICY_VIOLATION

from api import models, settings, utils
from api import models, utils
from api.invoices import InvoiceStatus

router = APIRouter()
Expand Down Expand Up @@ -45,7 +45,7 @@ async def on_connect(self, websocket, **kwargs):
if await self.maybe_exit_early(websocket):
return
self.subscriber = await utils.redis.make_subscriber(f"{self.NAME}:{self.object_id}")
utils.tasks.create_task(self.poll_subs(websocket), loop=settings.loop)
utils.tasks.create_task(self.poll_subs(websocket))

async def poll_subs(self, websocket):
async for message in utils.redis.listen_channel(self.subscriber):
Expand Down
2 changes: 1 addition & 1 deletion daemons/bch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def create_daemon(self):
return daemon

async def shutdown_daemon(self):
if self.daemon:
if self.daemon and self.loop:
self.daemon.stop()
await self.loop.run_in_executor(None, self.daemon.join)

Expand Down
5 changes: 3 additions & 2 deletions daemons/btc.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self):
self.wallets = {}
self.wallets_updates = {}
# initialize not yet created network
self.loop = asyncio.get_event_loop()
self.loop = None
self.network = None
self.fx = None
self.daemon = None
Expand Down Expand Up @@ -148,14 +148,15 @@ def register_callbacks(self, callback_function):

async def on_startup(self, app):
await super().on_startup(app)
self.loop = asyncio.get_running_loop()
self.daemon = self.create_daemon()
self.network = self.daemon.network
callback_function = self._process_events if self.ASYNC_CLIENT else self._process_events_sync
self.register_callbacks(callback_function)
self.fx = self.daemon.fx

async def shutdown_daemon(self):
if self.daemon:
if self.daemon and self.loop:
await self.loop.run_in_executor(None, self.daemon.on_stop)

async def on_shutdown(self, app):
Expand Down
6 changes: 3 additions & 3 deletions requirements/deterministic/web.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ bcrypt==3.2.0 \
# via
# paramiko
# passlib
bitcart==1.5.1.1 \
--hash=sha256:5303e09c787053c1e67aaa9c1b12780bf543912820b720c0c4731914baa3c00b \
--hash=sha256:6b5aafe02d82ebadb452d7475749ed691121a91e35caf0f6853917898bf6a65e
bitcart==1.6.0.2 \
--hash=sha256:46627da45edab07f4cd74236dece26b8c714c13ff19f2e07f1f5742d089a20a8 \
--hash=sha256:7510038edeb0c18388a34c8015d12840d2885002af9d9149a0458873dcb63300
# via -r requirements/web.txt
certifi==2021.10.8 \
--hash=sha256:78884e7c1d4b00ce3cea67b44566851c4343c120abd683433ce934a68ea58872 \
Expand Down
7 changes: 4 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
import os
import shutil

import pytest
from async_asgi_testclient import TestClient as AsyncClient
from starlette.testclient import TestClient

from api import models, settings
from api import models
from api.db import db
from main import app

Expand All @@ -29,9 +30,9 @@ async def cleanup_db():
await conn.status(table.delete())


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(scope="session")
def event_loop():
yield settings.loop
yield asyncio.get_event_loop_policy().get_event_loop()


@pytest.fixture(scope="session", autouse=True)
Expand Down
7 changes: 3 additions & 4 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def test_make_subscriber():
sub = await utils.redis.make_subscriber("test")
assert isinstance(sub, PubSub)
await sub.subscribe("channel:test")
utils.tasks.create_task(reader(sub), loop=settings.loop)
utils.tasks.create_task(reader(sub))
assert await utils.redis.publish_message("test", {"hello": "world"}) == 1


Expand Down Expand Up @@ -197,12 +197,11 @@ async def test_custom_create_task(caplog):
async def task():
raise Exception(err_msg)

loop = asyncio.get_event_loop()
utils.tasks.create_task(task(), loop=loop)
utils.tasks.create_task(task())
await asyncio.sleep(1)
assert err_msg in caplog.text
caplog.clear()
utils.tasks.create_task(task(), loop=loop).cancel()
utils.tasks.create_task(task()).cancel()
await asyncio.sleep(1)
assert err_msg not in caplog.text

Expand Down
10 changes: 9 additions & 1 deletion tests/test_views/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,19 @@ class DummyInstance:
coin_name = "BTC"


def is_event_loop_running():
try:
asyncio.get_running_loop()
return True
except RuntimeError:
return False


def get_future_return_value(return_val):
future = asyncio.Future()
future.set_result(return_val)
minor_ver = int(platform.python_version_tuple()[1])
return future if minor_ver < 8 or asyncio.get_event_loop().is_running() else return_val
return future if minor_ver < 8 or is_event_loop_running() else return_val


def test_docs_root(client: TestClient):
Expand Down
2 changes: 1 addition & 1 deletion worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ def handler(signum, frame):
process.start()
wait_for_port()
signal.signal(signal.SIGINT, handler)
asyncio.get_event_loop().run_until_complete(main())
asyncio.run(main())

0 comments on commit 4784303

Please sign in to comment.