Skip to content
This repository has been archived by the owner on Jan 5, 2022. It is now read-only.

Commit

Permalink
Make DB optional and do some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
norinorin committed Sep 24, 2021
1 parent 99d4643 commit e426033
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 50 deletions.
7 changes: 6 additions & 1 deletion nokari/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
load_dotenv()

if missing := [
var for var in ("DISCORD_BOT_TOKEN", "POSTGRESQL_DSN") if var not in os.environ
var
for var in (
"DISCORD_BOT_TOKEN",
# "POSTGRESQL_DSN"
)
if var not in os.environ
]:
raise RuntimeError(f"missing {', '.join(missing)} env variable{'s'*bool(missing)}")

Expand Down
79 changes: 42 additions & 37 deletions nokari/core/bot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""A module that contains a custom command handler class implementation."""
from __future__ import annotations

import asyncio
import datetime
Expand All @@ -8,7 +9,6 @@
import os
import shutil
import sys
import traceback
import typing
from contextlib import suppress

Expand All @@ -32,6 +32,10 @@
from nokari.utils import db, human_timedelta

__all__: typing.Final[typing.List[str]] = ["Nokari"]
_CommandOrPluginT = typing.TypeVar(
"_CommandOrPluginT", bound=typing.Union[lightbulb.Plugin, lightbulb.Command]
)
_LOGGER = logging.getLogger("nokari.core.bot")


def _get_prefixes(bot: lightbulb.Bot, message: hikari.Message) -> typing.List[str]:
Expand Down Expand Up @@ -92,9 +96,6 @@ def __init__(self) -> None:
# Responses cache
self._resp_cache = LRU(1024)

# Setup logger
self.setup_logger()

# Non-modular commands
_ = [
self.add_command(g)
Expand All @@ -103,11 +104,12 @@ def __init__(self) -> None:
]

# Set Launch time
self.launch_time: typing.Optional[datetime.datetime] = None
self.launch_time: datetime.datetime | None = None

# Default prefixes
self.default_prefixes = ["nokari", "n!"]

# pylint: disable=redefined-outer-name
@functools.wraps(lightbulb.Bot._invoke_command)
async def _invoke_command(
self,
Expand Down Expand Up @@ -165,7 +167,7 @@ def loop(self) -> asyncio.AbstractEventLoop:
return asyncio.get_running_loop()

@property
def session(self) -> typing.Optional[aiohttp.ClientSession]:
def session(self) -> aiohttp.ClientSession | None:
"""Returns a ClientSession."""
return self.rest._get_live_attributes().client_session

Expand All @@ -175,30 +177,22 @@ def responses_cache(self) -> LRU:
return self._resp_cache

@property
def pool(self) -> typing.Optional[asyncpg.Pool]:
def pool(self) -> asyncpg.Pool | None:
return getattr(self, "_pool", None)

async def create_pool(self) -> None:
"""Creates a connection pool."""
self._pool = await db.create_pool()
if pool := await db.create_pool():
self._pool = pool

async def _load_prefixes(self) -> None:
self.prefixes = {
record["hash"]: record["prefixes"]
for record in await self._pool.fetch("SELECT * FROM prefixes")
}

def setup_logger(self) -> None:
"""Sets a logger that outputs to a file as well as stdout."""
self.log = logging.getLogger(self.__class__.__name__)

file_handler = logging.handlers.TimedRotatingFileHandler( # type: ignore
"nokari.log", when="D", interval=7
)
file_handler.setLevel(logging.INFO)
self.log.addHandler(file_handler)
if self.pool:
self.prefixes = {
record["hash"]: record["prefixes"]
for record in await self.pool.fetch("SELECT * FROM prefixes")
}

async def _resolve_prefix(self, message: hikari.Message) -> typing.Optional[str]:
async def _resolve_prefix(self, message: hikari.Message) -> str | None:
"""Case-insensitive prefix resolver."""
prefixes = await maybe_await(self.get_prefix, self, message)

Expand Down Expand Up @@ -258,20 +252,11 @@ def load_extensions(self) -> None:
try:
self.load_extension(extension)
except lightbulb.errors.ExtensionMissingLoad:
print(extension, "is missing load function.")
_LOGGER.error("%s is missing load function", extension)
except lightbulb.errors.ExtensionAlreadyLoaded:
pass
except lightbulb.errors.ExtensionError as _e:
print(extension, "failed to load.")
print(
" ".join(
traceback.format_exception(
type(_e or _e.__cause__),
_e or _e.__cause__,
_e.__traceback__,
)
)
)
_LOGGER.error("%s failed to load", exc_info=_e)

# pylint: disable=lost-exception
async def prompt(
Expand Down Expand Up @@ -340,10 +325,22 @@ def predicate(event: InteractionCreateEvent) -> bool:
return confirm

@property
def me(self) -> typing.Optional[hikari.OwnUser]:
def me(self) -> hikari.OwnUser | None:
"""Temp fix until lightbub updates."""
return self.get_me()

def add_plugin(
self, plugin: lightbulb.Plugin | typing.Type[lightbulb.Plugin]
) -> None:
if getattr(plugin, "__requires_db__", False) and self.pool is None:
if (name := getattr(plugin, "name", None)) is None:
name = plugin.__class__.name

_LOGGER.warning("Not loading %s plugin as it requires DB", name)
return None

return super().add_plugin(plugin)


@lightbulb.check(checks.owner_only)
@group(name="reload")
Expand Down Expand Up @@ -378,7 +375,7 @@ async def reload_module(ctx: Context, *, modules: str) -> None:
module = sys.modules[mod]
importlib.reload(module)
except Exception as e: # pylint: disable=broad-except
ctx.bot.log.error("Failed to reload %s", mod, exc_info=e)
_LOGGER.error("Failed to reload %s", mod, exc_info=e)
failed.add((mod, e.__class__.__name__))

for parent in parents:
Expand All @@ -388,8 +385,16 @@ async def reload_module(ctx: Context, *, modules: str) -> None:
module = sys.modules[".".join(parent_split[:idx])]
importlib.reload(module)
except Exception as e: # pylint: disable=broad-except
ctx.bot.log.error("Failed to reload parent %s", parent, exc_info=e)
_LOGGER.error("Failed to reload parent %s", parent, exc_info=e)

loaded = "\n".join(f"+ {i}" for i in modules ^ {x[0] for x in failed})
failed = "\n".join(f"- {m} {e}" for m, e in failed)
await ctx.respond(f"```diff\n{loaded}\n{failed}```")


def requires_db(command_or_plugin: _CommandOrPluginT) -> _CommandOrPluginT:
if isinstance(command_or_plugin, commands.Command):
command_or_plugin.disabled = True
else:
command_or_plugin.__requires_db__ = True
return command_or_plugin
6 changes: 4 additions & 2 deletions nokari/core/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import typing

from lightbulb import commands, context, errors
from lightbulb import commands
from lightbulb import context as context_
from lightbulb import errors

__all__: typing.Final[typing.List[str]] = ["Command", "command", "group"]
_CommandCallbackT = typing.TypeVar(
Expand All @@ -26,7 +28,7 @@ def __init__(
self.usage = usage
"""The custom command signature if specified."""

async def is_runnable(self, context: context.Context) -> bool:
async def is_runnable(self, context: context_.Context) -> bool:
if getattr(self, "disabled", False):
raise errors.CheckFailure("Command is disabled.")

Expand Down
4 changes: 3 additions & 1 deletion nokari/core/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A module that contains a custom Context class implementation."""

import logging
import time
import typing
from types import SimpleNamespace
Expand All @@ -14,6 +15,7 @@
from nokari.utils.perms import has_channel_perms, has_guild_perms

__all__: typing.Final[typing.List[str]] = ["Context"]
_LOGGER = logging.getLogger("nokari.core.context")


class Context(lightbulb.Context):
Expand Down Expand Up @@ -151,7 +153,7 @@ def execute_plugins(
else plugin
)
except Exception as _e: # pylint: disable=broad-except
self.bot.log.error("Failed to reload %s", plugin, exc_info=_e)
_LOGGER.error("Failed to reload %s", plugin, exc_info=_e)
failed.add((plugin, _e.__class__.__name__))

key = lambda s: (len(s), s)
Expand Down
1 change: 1 addition & 0 deletions nokari/plugins/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def convert_prefix(arg: WrappedArg) -> str:
return arg.data.strip().lower()


@core.bot.requires_db
class Config(plugins.Plugin):
"""A plugin that contains config commands."""

Expand Down
4 changes: 3 additions & 1 deletion nokari/plugins/extras/errors.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import typing
from inspect import getmembers

Expand Down Expand Up @@ -25,6 +26,7 @@
typing.Literal[None],
],
)
_LOGGER = logging.getLogger("nokari.plugins.extras.errors")


def handle(
Expand Down Expand Up @@ -90,7 +92,7 @@ async def on_error(self, event: lightbulb.CommandErrorEvent) -> None:
# then it's an empty prefix
return

self.bot.log.error(
_LOGGER.error(
"Ignoring exception in command %s",
event.command and event.command.qualified_name,
exc_info=error,
Expand Down
11 changes: 9 additions & 2 deletions nokari/plugins/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Reminders based on RoboDanny.

import asyncio
import logging
import textwrap
import typing
from datetime import datetime, timedelta, timezone
Expand All @@ -14,6 +15,7 @@
from tabulate import tabulate

from nokari.core import command, group
from nokari.core.bot import requires_db
from nokari.core.context import Context
from nokari.utils import db, plural, timers
from nokari.utils.chunker import chunk, simple_chunk
Expand All @@ -24,6 +26,7 @@

MAX_DAYS: typing.Final[int] = 40
RETRY_IN: typing.Final[int] = 86400
_LOGGER = logging.getLogger("nokari.plugins.utils")


class SERIAL:
Expand All @@ -43,14 +46,18 @@ class Reminders(db.Table):
interval: db.Column[Snowflake] # BIGINT


# todo: move this to commands
# if there were commands that don't require DB.
@requires_db
class Utils(Plugin):
def __init__(self, bot: Bot) -> None:
self.bot = bot
super().__init__()

self.event = asyncio.Event()
self._current_timer: typing.Optional[timers.Timer] = None
self._task: asyncio.Task[None] = asyncio.create_task(self.dispatch_timers())
if bot.pool:
self._task: asyncio.Task[None] = asyncio.create_task(self.dispatch_timers())
self._remind_parser = (
ArgumentParser()
.interval("--interval", "-i", argmax=0, default=False)
Expand Down Expand Up @@ -83,7 +90,7 @@ async def wait_for_active_timers(self) -> timers.Timer:
async def call_timer(self, timer: timers.Timer) -> None:
args = [timer.id]

self.bot.log.debug("Dispatching timer with interval %s", timer.interval)
_LOGGER.debug("Dispatching timer with interval %s", timer.interval)

if timer.interval:
query = "UPDATE reminders SET expires_at = CURRENT_TIMESTAMP + $2 * interval '1 sec' WHERE id=$1"
Expand Down
9 changes: 6 additions & 3 deletions nokari/utils/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

__all__: typing.Final[typing.List[str]] = ["require_env"]
CommandT = typing.TypeVar("CommandT", bound=Command)
_LOGGER = logging.getLogger("nokari.utils.checks")


def require_env(*vars_: str) -> typing.Callable[[Command], Command]:
Expand All @@ -16,9 +17,11 @@ def decorator(cmd: Command) -> Command:
"'require_env' decorator must be above the command decorator."
)

logging.warning(
f"Missing {', '.join(missing)} env variable{'s'*bool(missing)}, "
f"{cmd.name} will be disabled"
_LOGGER.warning(
"Missing %s env variable%s. %s will be disabled",
", ".join(missing),
"s" * bool(missing),
cmd.name,
)

cmd.disabled = True
Expand Down
2 changes: 2 additions & 0 deletions nokari/utils/chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __getitem__(self, key: int) -> T:
def __getitem__(self, key: slice) -> "Indexable[T]":
...

# pylint: disable=non-iterator-returned
def __iter__(self) -> Iterator[T]:
...

Expand Down Expand Up @@ -82,6 +83,7 @@ def simple_chunk(
...


# pylint: disable=used-before-assignment
def simple_chunk(text: Any, length: Any, lazy: bool = False) -> Any:
"""A lite version of the chunk function."""
return (
Expand Down
9 changes: 6 additions & 3 deletions nokari/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,12 @@ def create_tables(
return con.execute(" ".join(statements))


def create_pool(
async def create_pool(
min_size: int = 3, max_size: int = 10, max_inactive_connection_lifetime: int = 60
) -> typing.Coroutine[typing.Any, typing.Any, asyncpg.Pool]:
) -> asyncpg.Pool | None:
if not (dsn := os.getenv("POSTGRESQL_DSN")):
return None

def _encode_jsonb(value: dict) -> str:
return json.dumps(value)

Expand All @@ -124,7 +127,7 @@ async def init(con: asyncpg.Connection) -> None:
)

return asyncpg.create_pool(
dsn=os.getenv("POSTGRESQL_DSN"),
dsn=dsn,
init=init,
min_size=min_size,
max_size=max_size,
Expand Down
1 change: 1 addition & 0 deletions nokari/utils/spotify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ def wrapper(
spotify_text + metadata.album, self.SMALL_FONT, False
)

# pylint: disable=unsubscriptable-object
spotify_width = sum(
[spotify_album_c_mapping[char][0] for char in spotify_text]
)
Expand Down

0 comments on commit e426033

Please sign in to comment.