Skip to content

Commit

Permalink
Merge 25192c5 into 6ab7918
Browse files Browse the repository at this point in the history
  • Loading branch information
onerandomusername authored Nov 27, 2021
2 parents 6ab7918 + 25192c5 commit 15270e0
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 8 deletions.
8 changes: 8 additions & 0 deletions modmail/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import asyncio
import logging
import logging.handlers
import os
from pathlib import Path

import coloredlogs

from modmail.log import ModmailLogger


# On windows aiodns's asyncio support relies on APIs like add_reader (which aiodns uses)
# are not guaranteed to be available, and in particular are not available when using the
# ProactorEventLoop on Windows, this method is only supported with Windows SelectorEventLoop
if os.name == "nt":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

logging.TRACE = 5
logging.NOTICE = 25
logging.addLevelName(logging.TRACE, "TRACE")
Expand Down
41 changes: 34 additions & 7 deletions modmail/bot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import asyncio
import logging
import signal
import socket
import typing as t
from typing import Any

import aiohttp
import arrow
import discord
from aiohttp import ClientSession
from discord import Activity, AllowedMentions, Intents
from discord.client import _cleanup_loop
from discord.ext import commands
Expand Down Expand Up @@ -41,9 +41,12 @@ class ModmailBot(commands.Bot):
def __init__(self, **kwargs):
self.config = CONFIG
self.start_time: t.Optional[arrow.Arrow] = None # arrow.utcnow()
self.http_session: t.Optional[ClientSession] = None
self.http_session: t.Optional[aiohttp.ClientSession] = None
self.dispatcher = Dispatcher()

self._connector = None
self._resolver = None

status = discord.Status.online
activity = Activity(type=discord.ActivityType.listening, name="users dming me!")
# listen to messages mentioning the bot or matching the prefix
Expand All @@ -65,6 +68,24 @@ def __init__(self, **kwargs):
**kwargs,
)

async def create_connectors(self, *args, **kwargs) -> None:
"""Re-create the connector and set up sessions before logging into Discord."""
# Use asyncio for DNS resolution instead of threads so threads aren't spammed.
self._resolver = aiohttp.AsyncResolver()

# Use AF_INET as its socket family to prevent HTTPS related problems both locally
# and in production.
self._connector = aiohttp.TCPConnector(
resolver=self._resolver,
family=socket.AF_INET,
)

# Client.login() will call HTTPClient.static_login() which will create a session using
# this connector attribute.
self.http.connector = self._connector

self.http_session = aiohttp.ClientSession(connector=self._connector)

async def start(self, token: str, reconnect: bool = True) -> None:
"""
Start the bot.
Expand All @@ -74,8 +95,8 @@ async def start(self, token: str, reconnect: bool = True) -> None:
"""
try:
# create the aiohttp session
self.http_session = ClientSession(loop=self.loop)
self.logger.trace("Created ClientSession.")
await self.create_connectors()
self.logger.trace("Created aiohttp.ClientSession.")
# set start time to when we started the bot.
# This is now, since we're about to connect to the gateway.
# This should also be before we load any extensions, since if they have a load time, it should
Expand Down Expand Up @@ -122,7 +143,7 @@ def run(self, *args, **kwargs) -> None:
except NotImplementedError:
pass

def stop_loop_on_completion(f: Any) -> None:
def stop_loop_on_completion(f: t.Any) -> None:
loop.stop()

future = asyncio.ensure_future(self.start(*args, **kwargs), loop=loop)
Expand Down Expand Up @@ -164,10 +185,16 @@ async def close(self) -> None:
except Exception:
self.logger.error(f"Exception occured while removing cog {cog.name}", exc_info=True)

await super().close()

if self.http_session:
await self.http_session.close()

await super().close()
if self._connector:
await self._connector.close()

if self._resolver:
await self._resolver.close()

def load_extensions(self) -> None:
"""Load all enabled extensions."""
Expand Down
17 changes: 16 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ flake8-todo = "~=0.7"
isort = "^5.9.2"
pep8-naming = "~=0.11"
# testing
aioresponses = "^0.7.2"
coverage = { extras = ["toml"], version = "^6.0.2" }
coveralls = "^3.3.1"
pytest = "^6.2.4"
Expand Down
28 changes: 28 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1 +1,29 @@
import aiohttp
import aioresponses
import pytest


@pytest.fixture
def aioresponse():
"""Fixture to mock aiohttp responses."""
with aioresponses.aioresponses() as aioresponse:
yield aioresponse


@pytest.fixture
@pytest.mark.asyncio
async def http_session(aioresponse) -> aiohttp.ClientSession:
"""
Fixture function for a aiohttp.ClientSession.
Requests fixture aioresponse to ensure that all client sessions do not make actual requests.
"""
resolver = aiohttp.AsyncResolver()
connector = aiohttp.TCPConnector(resolver=resolver)
client_session = aiohttp.ClientSession(connector=connector)

yield client_session

await client_session.close()
await connector.close()
await resolver.close()
42 changes: 42 additions & 0 deletions tests/test_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import aiohttp
import pytest


if TYPE_CHECKING:
import aioresponses


class TestSessionFixture:
"""Grouping for aiohttp.ClientSession fixture tests."""

@pytest.mark.asyncio
async def test_session_fixture_no_requests(self, http_session: aiohttp.ClientSession):
"""
Test all requests fail.
This means that aioresponses is being requested by the http_session fixture.
"""
url = "https://github.com/"

with pytest.raises(aiohttp.ClientConnectionError):
await http_session.get(url)

@pytest.mark.asyncio
async def test_session_fixture_mock_requests(
self, aioresponse: aioresponses.aioresponses, http_session: aiohttp.ClientSession
):
"""
Test all requests fail.
This means that aioresponses is being requested by the http_session fixture.
"""
url = "https://github.com/"
status = 200
aioresponse.get(url, status=status)

async with http_session.get(url) as resp:
assert status == resp.status

0 comments on commit 15270e0

Please sign in to comment.