Skip to content

Commit

Permalink
[Main] block/unblock users
Browse files Browse the repository at this point in the history
  • Loading branch information
noahkw committed Feb 2, 2024
1 parent 0bf7875 commit ebc5f1d
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 18 deletions.
25 changes: 24 additions & 1 deletion botwbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self, config, **kwargs):
self.prefixes = {} # guild.id -> prefix
self.whitelist = set() # guild.id
self.custom_emoji = {} # name -> Emoji instance
self.blocked_users = {} # guild.id -> set(user.id)

self.channel_locker = ChannelLocker()

Expand All @@ -72,6 +73,13 @@ async def on_ready(self):
if guild_settings.whitelisted:
self.whitelist.add(guild_settings._guild)

blocked_users = await db.get_blocked_users(session)
for blocked_user in blocked_users:
blocked_users_in_guild = self.blocked_users.setdefault(
blocked_user._guild, set()
)
blocked_users_in_guild.add(blocked_user._user)

for name, emoji_name in CUSTOM_EMOJI.items():
emoji = discord.utils.find(lambda e: e.name == emoji_name, self.emojis)
if emoji is None:
Expand Down Expand Up @@ -211,10 +219,25 @@ async def get_guilds_for_cog(self, cog: Cog) -> typing.Optional[set[discord.Guil

return guild_objs

def is_author_blocked_in_guild(self, author: discord.Member, guild: discord.Guild):
blocked_users_in_guild = (
self.blocked_users.get(guild.id) if guild is not None else None
)

return (
blocked_users_in_guild is not None
and not author.guild_permissions.administrator
and author.id in blocked_users_in_guild
)

async def process_commands(self, message):
ctx = await self.get_context(message)

if ctx.command is None or message.author.id in self.blacklist:
if (
ctx.command is None
or message.author.id in self.blacklist
or self.is_author_blocked_in_guild(message.author, message.guild)
):
return

# spam control
Expand Down
16 changes: 12 additions & 4 deletions cogs/CustomRoles.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,18 @@ async def create(self, ctx: commands.Context):
async def creation_callback(result: RoleCreatorResult):
async with self.bot.Session() as session:
member = ctx.guild.get_member(result.user_id)
if not member:
if member is None:
return

custom_role_settings = await db.get_custom_role_settings(
session, ctx.guild.id
)
if custom_role_settings is None:
logger.info(
"creation callback in %s (%d) triggered but custom roles not set up",
str(ctx.guild),
ctx.guild.id,
)
return

try:
Expand All @@ -111,9 +122,6 @@ async def creation_callback(result: RoleCreatorResult):
name=result.name,
color=Color.from_str("#" + result.color),
)
custom_role_settings = await db.get_custom_role_settings(
session, ctx.guild.id
)
await ctx.guild.edit_role_positions(
{role: custom_role_settings.role.position},
reason=f"Moving new custom role above {custom_role_settings.role}",
Expand Down
42 changes: 36 additions & 6 deletions cogs/Main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from discord import app_commands
from discord.ext import commands

import db
from botwbot import BotwBot
from menu import Confirm
from models import GuildSettings, GuildCog
from models import GuildSettings, GuildCog, BlockedUser
from util import safe_send, safe_mention, detail_mention, ack

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(self, bot):

@commands.command(brief="Sets the prefix for the current guild")
@commands.has_permissions(administrator=True)
async def prefix(self, ctx, prefix=None):
async def prefix(self, ctx: commands.Context, prefix=None):
if not prefix:
await ctx.send(
f"The current prefix is `{self.bot.prefixes[ctx.guild.id]}`."
Expand All @@ -77,7 +78,7 @@ async def prefix(self, ctx, prefix=None):
)

@commands.hybrid_command(brief="Request the bot for your guild")
async def invite(self, ctx, guild_id: int):
async def invite(self, ctx: commands.Context, guild_id: int):
await self.bot.get_user(self.bot.CREATOR_ID).send(
f"Request to whitelist guild `{guild_id}` from {detail_mention(ctx.author)}."
)
Expand Down Expand Up @@ -118,7 +119,7 @@ async def whitelist_cog(self, ctx: commands.Context, guild_id: str, cog_name: st

@commands.command(brief="Adds a guild to the whitelist")
@commands.is_owner()
async def whitelist(self, ctx, guild_id: int, requester_id: int):
async def whitelist(self, ctx: commands.Context, guild_id: int, requester_id: int):
requester = self.bot.get_user(requester_id)

confirm = await Confirm(f"Whitelist guild `{guild_id}`?").prompt(ctx)
Expand Down Expand Up @@ -147,7 +148,7 @@ async def whitelist(self, ctx, guild_id: int, requester_id: int):

@commands.command(brief="Removes a guild from the whitelist")
@commands.is_owner()
async def unwhitelist(self, ctx, guild_id: int):
async def unwhitelist(self, ctx: commands.Context, guild_id: int):
guild = self.bot.get_guild(guild_id)

confirm = await Confirm(
Expand Down Expand Up @@ -178,7 +179,7 @@ async def unwhitelist(self, ctx, guild_id: int):
@commands.command(brief="Sets the bot's activity")
@commands.is_owner()
@ack
async def activity(self, ctx, type_, *, message):
async def activity(self, ctx: commands.Context, type_, *, message):
try:
activity_type = discord.ActivityType[type_.lower()]
except KeyError:
Expand All @@ -191,6 +192,35 @@ async def activity(self, ctx, type_, *, message):
activity=discord.Activity(type=activity_type, name=message)
)

@commands.command(brief="Block a user from using the bot")
@commands.has_permissions(administrator=True)
@ack
async def block(self, ctx: commands.Context, member: discord.Member):
if member is None:
raise commands.BadArgument("Need a user.")

async with self.bot.Session() as session:
blocked_user = BlockedUser(_guild=ctx.guild.id, _user=member.id)
await session.merge(blocked_user)
await session.commit()

self.bot.blocked_users.setdefault(ctx.guild.id, set()).add(member.id)

@commands.command(brief="Unblock a user from using the bot")
@commands.has_permissions(administrator=True)
@ack
async def unblock(self, ctx: commands.Context, member: discord.Member):
if member is None:
raise commands.BadArgument("Need a user.")

async with self.bot.Session() as session:
await db.delete_blocked_user(session, ctx.guild.id, member.id)
await session.commit()

blocked_users_in_guild = self.bot.blocked_users.get(ctx.guild.id)
if blocked_users_in_guild is not None:
blocked_users_in_guild.remove(member.id)

@commands.Cog.listener()
async def on_guild_join(self, guild: discord.Guild):
await self.bot.get_user(self.bot.CREATOR_ID).send(
Expand Down
18 changes: 18 additions & 0 deletions db.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
GuildCog,
CustomRole,
CustomRoleSettings,
BlockedUser,
)


Expand Down Expand Up @@ -391,6 +392,23 @@ async def get_custom_role_settings(
return result[0] if result else None


async def get_blocked_users(session: AsyncSession) -> list[BlockedUser]:
statement = select(BlockedUser)
result = (await session.execute(statement)).all()

return [r for (r,) in result]


async def delete_blocked_user(
session: AsyncSession, guild_id: int, user_id: int
) -> None:
statement = delete(BlockedUser).where(
(BlockedUser._guild == guild_id) & (BlockedUser._user == user_id)
)

await session.execute(statement)


async def get_greeter(session, guild_id, greeter_type):
statement = select(Greeter).where(
(Greeter._guild == guild_id) & (Greeter.type == greeter_type)
Expand Down
2 changes: 2 additions & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .blocked_users import BlockedUser
from .botw import BotwState, BotwWinner, Nomination, Idol, BotwSettings
from .channel_mirror import ChannelMirror
from .custom_role import CustomRole, CustomRoleSettings
Expand All @@ -11,6 +12,7 @@
from .twitter import TwtSetting, TwtAccount, TwtSorting, TwtFilter

__all__ = (
"BlockedUser",
"BotwState",
"BotwWinner",
"ChannelMirror",
Expand Down
16 changes: 16 additions & 0 deletions models/blocked_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from discord import Member
from sqlalchemy import Column, BigInteger
from sqlalchemy.ext.hybrid import hybrid_property

from models.base import Base
from models.guild_settings import GuildSettingsMixin


class BlockedUser(GuildSettingsMixin, Base):
__tablename__ = "blocked_users"

_user = Column(BigInteger, primary_key=True)

@hybrid_property
def member(self) -> Member:
return self.guild.get_member(self._user)
14 changes: 14 additions & 0 deletions views/base_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from discord import Interaction
from discord.ui import View


class BaseView(View):
"""
Has to be called and checked by extending classes.
I.e., if False, the interaction_check has to fail
"""

async def interaction_check(self, interaction: Interaction, /) -> bool:
return not interaction.client.is_author_blocked_in_guild(
interaction.user, interaction.guild
)
22 changes: 15 additions & 7 deletions views/role_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import discord
from discord import Interaction, Button
from discord.ui import View, TextInput, Modal, RoleSelect
from discord.ui import TextInput, Modal, RoleSelect

import db

from models import CustomRoleSettings
from views.base_view import BaseView


class RoleCreatorResult:
Expand Down Expand Up @@ -46,7 +47,7 @@ async def callback(self, interaction: Interaction) -> typing.Any:
await interaction.response.defer()


class CustomRoleSetup(View):
class CustomRoleSetup(BaseView):
def __init__(self):
super().__init__()

Expand Down Expand Up @@ -88,7 +89,7 @@ async def confirm(self, interaction: Interaction, button: Button):
)


class RoleCreatorView(CallbackView, View):
class RoleCreatorView(CallbackView, BaseView):
@discord.ui.button(label="Click me to start", style=discord.ButtonStyle.blurple)
async def create_role(self, interaction: Interaction, button: Button):
self.result = RoleCreatorResult()
Expand All @@ -97,6 +98,9 @@ async def create_role(self, interaction: Interaction, button: Button):
)

async def interaction_check(self, interaction: Interaction, /) -> bool:
if not await super().interaction_check(interaction):
return False

async with interaction.client.Session() as session:
custom_role = await db.get_user_custom_role_in_guild(
session, interaction.user.id, interaction.guild_id
Expand Down Expand Up @@ -126,7 +130,7 @@ async def interaction_check(self, interaction: Interaction, /) -> bool:
return False


class RoleCreatorNameConfirmationView(CallbackView, View):
class RoleCreatorNameConfirmationView(CallbackView, BaseView):
@discord.ui.button(label="I confirm", style=discord.ButtonStyle.blurple)
async def confirm(self, interaction: Interaction, button: Button):
self.stop()
Expand All @@ -142,7 +146,9 @@ async def retry(self, interaction: Interaction, button: Button):
)


class RoleCreatorNameModal(CallbackView, Modal, title="Choose your role's name"):
class RoleCreatorNameModal(
CallbackView, BaseView, Modal, title="Choose your role's name"
):
name = TextInput(
label="Role name",
placeholder="Your custom role name here...",
Expand All @@ -162,7 +168,7 @@ async def on_submit(self, interaction: Interaction) -> None:
)


class RoleCreatorColorConfirmationView(CallbackView, View):
class RoleCreatorColorConfirmationView(CallbackView, BaseView):
@discord.ui.button(label="I confirm", style=discord.ButtonStyle.blurple)
async def confirm(self, interaction: Interaction, button: Button):
self.stop()
Expand All @@ -181,7 +187,9 @@ async def retry(self, interaction: Interaction, button: Button):
)


class RoleCreatorColorModal(CallbackView, Modal, title="Choose your role's color"):
class RoleCreatorColorModal(
CallbackView, BaseView, Modal, title="Choose your role's color"
):
color = TextInput(
label="Role color",
placeholder="000000",
Expand Down

0 comments on commit ebc5f1d

Please sign in to comment.