diff --git a/.test.env b/.test.env index def9d16..dfc9088 100644 --- a/.test.env +++ b/.test.env @@ -32,6 +32,8 @@ CHANNEL_SPOILER=2769521890099371011 CHANNEL_BOT_LOGS=1105517088266788925 # Roles +ROLE_VERIFIED=1333333333333333337 + ROLE_BIZCTF2022=7629466241011276950 ROLE_NOAH_GANG=6706800691011276950 ROLE_BUDDY_GANG=6706800681011276950 diff --git a/src/bot.py b/src/bot.py index f535903..f598844 100644 --- a/src/bot.py +++ b/src/bot.py @@ -12,6 +12,7 @@ MissingRequiredArgument, NoPrivateMessage, UserInputError, ) from sqlalchemy.exc import NoResultFound +from typing import TypeVar from src import trace_config from src.core import constants, settings @@ -20,6 +21,9 @@ logger = logging.getLogger(__name__) +BOT_TYPE = TypeVar("BOT_TYPE", "Bot", DiscordBot) + + class Bot(DiscordBot): """Base bot class.""" diff --git a/src/cmds/automation/auto_verify.py b/src/cmds/automation/auto_verify.py index 606ad3e..98c538d 100644 --- a/src/cmds/automation/auto_verify.py +++ b/src/cmds/automation/auto_verify.py @@ -2,12 +2,8 @@ from discord import Member, Message, User from discord.ext import commands -from sqlalchemy import select from src.bot import Bot -from src.database.models import HtbDiscordLink -from src.database.session import AsyncSessionLocal -from src.helpers.verification import get_user_details, process_identification logger = logging.getLogger(__name__) @@ -19,31 +15,11 @@ def __init__(self, bot: Bot): self.bot = bot async def process_reverification(self, member: Member | User) -> None: - """Re-verifation process for a member.""" - async with AsyncSessionLocal() as session: - stmt = ( - select(HtbDiscordLink) - .where(HtbDiscordLink.discord_user_id == member.id) - .order_by(HtbDiscordLink.id) - .limit(1) - ) - result = await session.scalars(stmt) - htb_discord_link: HtbDiscordLink = result.first() - - if not htb_discord_link: - raise VerificationError(f"HTB Discord link for user {member.name} with ID {member}") - - member_token: str = htb_discord_link.account_identifier - - if member_token is None: - raise VerificationError(f"HTB account identifier for user {member.name} with ID {member.id} not found") - - logger.debug(f"Processing re-verify of member {member.name} ({member.id}).") - htb_details = await get_user_details(member_token) - if htb_details is None: - raise VerificationError(f"Retrieving user details for user {member.name} with ID {member.id} failed") - - await process_identification(htb_details, user=member, bot=self.bot) + """Re-verifation process for a member. + + TODO: Reimplement once it's possible to fetch link state from the HTB Account. + """ + raise VerificationError("Not implemented") @commands.Cog.listener() @commands.cooldown(1, 60, commands.BucketType.user) @@ -74,4 +50,5 @@ class VerificationError(Exception): def setup(bot: Bot) -> None: """Load the `MessageHandler` cog.""" - bot.add_cog(MessageHandler(bot)) + # bot.add_cog(MessageHandler(bot)) + pass diff --git a/src/cmds/core/identify.py b/src/cmds/core/identify.py index ef52ed7..34a378a 100644 --- a/src/cmds/core/identify.py +++ b/src/cmds/core/identify.py @@ -1,17 +1,12 @@ import logging -from typing import Sequence -import discord from discord import ApplicationContext, Interaction, WebhookMessage, slash_command from discord.ext import commands from discord.ext.commands import cooldown -from sqlalchemy import select from src.bot import Bot from src.core import settings -from src.database.models import HtbDiscordLink -from src.database.session import AsyncSessionLocal -from src.helpers.verification import get_user_details, process_identification +from src.helpers.verification import send_verification_instructions logger = logging.getLogger(__name__) @@ -25,121 +20,15 @@ def __init__(self, bot: Bot): @slash_command( guild_ids=settings.guild_ids, description="Identify yourself on the HTB Discord server by linking your HTB account ID to your Discord user " - "ID.", guild_only=False + "ID.", + guild_only=False, ) @cooldown(1, 60, commands.BucketType.user) - async def identify(self, ctx: ApplicationContext, account_identifier: str) -> Interaction | WebhookMessage: - """Identify yourself on the HTB Discord server by linking your HTB account ID to your Discord user ID.""" - if len(account_identifier) != 60: - return await ctx.respond( - "This Account Identifier does not appear to be the right length (must be 60 characters long).", - ephemeral=True - ) - - await ctx.respond("Identification initiated, please wait...", ephemeral=True) - htb_user_details = await get_user_details(account_identifier) - if htb_user_details is None: - embed = discord.Embed(title="Error: Invalid account identifier.", color=0xFF0000) - return await ctx.respond(embed=embed, ephemeral=True) - - json_htb_user_id = htb_user_details["user_id"] - - author = ctx.user - member = await self.bot.get_or_fetch_user(author.id) - if not member: - return await ctx.respond(f"Error getting guild member with id: {author.id}.") - - # Step 1: Check if the Account Identifier has already been recorded and if they are the previous owner. - # Scenario: - # - I create a new Discord account. - # - I reuse my previous Account Identifier. - # - I now have an "alt account" with the same roles. - async with AsyncSessionLocal() as session: - stmt = ( - select(HtbDiscordLink) - .filter(HtbDiscordLink.account_identifier == account_identifier) - .order_by(HtbDiscordLink.id.desc()) - .limit(1) - ) - result = await session.scalars(stmt) - most_recent_rec: HtbDiscordLink = result.first() - - if most_recent_rec and most_recent_rec.discord_user_id_as_int != member.id: - error_desc = ( - f"Verified user {member.mention} tried to identify as another identified user.\n" - f"Current Discord UID: {member.id}\n" - f"Other Discord UID: {most_recent_rec.discord_user_id}\n" - f"Related HTB UID: {most_recent_rec.htb_user_id}" - ) - embed = discord.Embed(title="Identification error", description=error_desc, color=0xFF2429) - await self.bot.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed) - - return await ctx.respond( - "Identification error: please contact an online Moderator or Administrator for help.", ephemeral=True - ) - - # Step 2: Given the htb_user_id from JSON, check if each discord_user_id are different from member.id. - # Scenario: - # - I have a Discord account that is linked already to a "Hacker" role. - # - I create a new HTB account. - # - I identify with the new account. - # - `SELECT * FROM htb_discord_link WHERE htb_user_id = %s` will be empty, - # because the new account has not been verified before. All is good. - # - I am now "Noob" rank. - async with AsyncSessionLocal() as session: - stmt = select(HtbDiscordLink).filter(HtbDiscordLink.htb_user_id == json_htb_user_id) - result = await session.scalars(stmt) - user_links: Sequence[HtbDiscordLink] = result.all() - - discord_user_ids = {u_link.discord_user_id_as_int for u_link in user_links} - if discord_user_ids and member.id not in discord_user_ids: - orig_discord_ids = ", ".join([f"<@{id_}>" for id_ in discord_user_ids]) - error_desc = (f"The HTB account {json_htb_user_id} attempted to be identified by user <@{member.id}>, " - f"but is tied to another Discord account.\n" - f"Originally linked to Discord UID {orig_discord_ids}.") - embed = discord.Embed(title="Identification error", description=error_desc, color=0xFF2429) - await self.bot.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed) - - return await ctx.respond( - "Identification error: please contact an online Moderator or Administrator for help.", ephemeral=True - ) - - # Step 3: Check if discord_user_id already linked to an htb_user_id, and if JSON/db HTB IDs are the same. - # Scenario: - # - I have a new, unlinked Discord account. - # - Clubby generates a new token and gives it to me. - # - `SELECT * FROM htb_discord_link WHERE discord_user_id = %s` - # will be empty because I have not identified before. - # - I am now Clubby. - async with AsyncSessionLocal() as session: - stmt = select(HtbDiscordLink).filter(HtbDiscordLink.discord_user_id == member.id) - result = await session.scalars(stmt) - user_links: Sequence[HtbDiscordLink] = result.all() - - user_htb_ids = {u_link.htb_user_id_as_int for u_link in user_links} - if user_htb_ids and json_htb_user_id not in user_htb_ids: - error_desc = (f"User {member.mention} ({member.id}) tried to identify with a new HTB account.\n" - f"Original HTB UIDs: {', '.join([str(i) for i in user_htb_ids])}, new HTB UID: " - f"{json_htb_user_id}.") - embed = discord.Embed(title="Identification error", description=error_desc, color=0xFF2429) - await self.bot.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed) - - return await ctx.respond( - "Identification error: please contact an online Moderator or Administrator for help.", ephemeral=True - ) - - htb_discord_link = HtbDiscordLink( - account_identifier=account_identifier, discord_user_id=member.id, htb_user_id=json_htb_user_id - ) - async with AsyncSessionLocal() as session: - session.add(htb_discord_link) - await session.commit() - - await process_identification(htb_user_details, user=member, bot=self.bot) - - return await ctx.respond( - f"Your Discord user has been successfully identified as HTB user {json_htb_user_id}.", ephemeral=True - ) + async def identify( + self, ctx: ApplicationContext, account_identifier: str + ) -> Interaction | WebhookMessage: + """Legacy command. Now sends instructions to identify with HTB account.""" + await send_verification_instructions(ctx, ctx.author) def setup(bot: Bot) -> None: diff --git a/src/cmds/core/verify.py b/src/cmds/core/verify.py index 23088c1..919f872 100644 --- a/src/cmds/core/verify.py +++ b/src/cmds/core/verify.py @@ -1,14 +1,12 @@ import logging -import discord from discord import ApplicationContext, Interaction, WebhookMessage, slash_command -from discord.errors import Forbidden, HTTPException from discord.ext import commands from discord.ext.commands import cooldown from src.bot import Bot from src.core import settings -from src.helpers.verification import process_certification +from src.helpers.verification import process_certification, send_verification_instructions logger = logging.getLogger(__name__) @@ -47,57 +45,7 @@ async def verifycertification(self, ctx: ApplicationContext, certid: str, fullna @cooldown(1, 60, commands.BucketType.user) async def verify(self, ctx: ApplicationContext) -> Interaction | WebhookMessage: """Receive instructions in a DM on how to identify yourself with your HTB account.""" - member = ctx.user - - # Step one - embed_step1 = discord.Embed(color=0x9ACC14) - embed_step1.add_field( - name="Step 1: Log in at Hack The Box", - value="Go to the Hack The Box website at " - " and navigate to **Login > HTB Labs**. Log in to your HTB Account." - , inline=False, ) - embed_step1.set_image( - url="https://media.discordapp.net/attachments/724587782755844098/839871275627315250/unknown.png" - ) - - # Step two - embed_step2 = discord.Embed(color=0x9ACC14) - embed_step2.add_field( - name="Step 2: Locate the Account Identifier", - value='Click on your profile name, then select **My Profile**. ' - 'In the Profile Settings tab, find the field labeled **Account Identifier**. () ' - "Click the green button to copy your secret identifier.", inline=False, ) - embed_step2.set_image( - url="https://media.discordapp.net/attachments/724587782755844098/839871332963188766/unknown.png" - ) - - # Step three - embed_step3 = discord.Embed(color=0x9ACC14) - embed_step3.add_field( - name="Step 3: Identification", - value="Now type `/identify IDENTIFIER_HERE` in the bot-commands channel.\n\nYour roles will be " - "applied automatically.", inline=False - ) - embed_step3.set_image( - url="https://media.discordapp.net/attachments/709907130102317093/904744444539076618/unknown.png" - ) - - try: - await member.send(embed=embed_step1) - await member.send(embed=embed_step2) - await member.send(embed=embed_step3) - except Forbidden as ex: - logger.error("Exception during verify call", exc_info=ex) - return await ctx.respond( - "Whoops! I cannot DM you after all due to your privacy settings. Please allow DMs from other server " - "members and try again in 1 minute." - ) - except HTTPException as ex: - logger.error("Exception during verify call.", exc_info=ex) - return await ctx.respond( - "An unexpected error happened (HTTP 400, bad request). Please contact an Administrator." - ) - return await ctx.respond("Please check your DM for instructions.", ephemeral=True) + await send_verification_instructions(ctx, ctx.author) def setup(bot: Bot) -> None: diff --git a/src/core/config.py b/src/core/config.py index 6b3025e..d07de2c 100644 --- a/src/core/config.py +++ b/src/core/config.py @@ -91,6 +91,7 @@ class AcademyCertificates(BaseSettings): class Roles(BaseSettings): """The roles settings.""" + VERIFIED: int # Moderation COMMUNITY_MANAGER: int @@ -330,6 +331,8 @@ def load_settings(env_file: str | None = None): global_settings.roles.NOOB, global_settings.roles.VIP, global_settings.roles.VIP_PLUS, + ], + "ALL_SEASON_RANKS": [ global_settings.roles.SEASON_HOLO, global_settings.roles.SEASON_PLATINUM, global_settings.roles.SEASON_RUBY, diff --git a/src/helpers/ban.py b/src/helpers/ban.py index 8d2ddd0..92fe22e 100644 --- a/src/helpers/ban.py +++ b/src/helpers/ban.py @@ -1,11 +1,21 @@ """Helper methods to handle bans, mutes and infractions. Bot or message responses are NOT allowed.""" + import asyncio import logging from datetime import datetime, timezone +from enum import Enum -import arrow import discord -from discord import Forbidden, Guild, HTTPException, Member, NotFound, User +from discord import ( + Forbidden, + Guild, + HTTPException, + Member, + NotFound, + User, + TextChannel, + ClientUser, +) from sqlalchemy import select from sqlalchemy.exc import NoResultFound @@ -22,39 +32,233 @@ logger = logging.getLogger(__name__) -async def _check_member(bot: Bot, guild: Guild, member: Member | User, author: Member = None) -> SimpleResponse | None: +class BanCodes(Enum): + SUCCESS = "SUCCESS" + ALREADY_EXISTS = "ALREADY_EXISTS" + FAILED = "FAILED" + + +async def _check_member( + bot: Bot, guild: Guild, member: Member | User, author: Member | ClientUser | None = None +) -> SimpleResponse | None: if isinstance(member, Member): if member_is_staff(member): - return SimpleResponse(message="You cannot ban another staff member.", delete_after=None) + return SimpleResponse( + message="You cannot ban another staff member.", delete_after=None + ) elif isinstance(member, User): - member = await bot.get_member_or_user(guild, member.id) + member = await bot.get_member_or_user(guild, member.id) # type: ignore if member.bot: - return SimpleResponse(message="You cannot ban a bot.", delete_after=None) + return SimpleResponse( + message="You cannot ban a bot.", delete_after=None, code=BanCodes.FAILED + ) if author and author.id == member.id: - return SimpleResponse(message="You cannot ban yourself.", delete_after=None) + return SimpleResponse( + message="You cannot ban yourself.", delete_after=None, code=BanCodes.FAILED + ) -async def _get_ban_or_create(member: Member, ban: Ban, infraction: Infraction) -> tuple[int, bool]: +async def get_ban(member: Member | User) -> Ban | None: async with AsyncSessionLocal() as session: - stmt = select(Ban).filter(Ban.user_id == member.id, Ban.unbanned.is_(False)).limit(1) + stmt = ( + select(Ban) + .filter(Ban.user_id == member.id, Ban.unbanned.is_(False)) + .limit(1) + ) result = await session.scalars(stmt) - existing_ban = result.first() - if existing_ban: - return existing_ban.id, True + return result.first() + + +async def update_ban(ban: Ban) -> None: + logger.info(f"Updating ban {ban.id} for user {ban.user_id} with expiration {ban.unban_time}") + async with AsyncSessionLocal() as session: + session.add(ban) + await session.commit() + + +async def _get_ban_or_create( + member: Member | User, ban: Ban, infraction: Infraction +) -> tuple[int, bool]: + existing_ban = await get_ban(member) + if existing_ban: + return existing_ban.id, True + async with AsyncSessionLocal() as session: session.add(ban) session.add(infraction) await session.commit() + ban_id: int = ban.id assert ban_id is not None return ban_id, False -async def ban_member( - bot: Bot, guild: Guild, member: Member | User, duration: str, reason: str, evidence: str, author: Member = None, - needs_approval: bool = True -) -> SimpleResponse | None: - """Ban a member from the guild.""" +async def _create_ban_response( + member: Member | User, end_date: str, dm_banned_member: bool, needs_approval: bool +) -> SimpleResponse: + """Create a SimpleResponse for ban operations.""" + if needs_approval: + if member: + message = f"{member.display_name} ({member.id}) has been banned until {end_date} (UTC)." + else: + message = f"{member.id} has been banned until {end_date} (UTC)." + else: + if member: + message = f"Member {member.display_name} has been banned permanently." + else: + message = f"Member {member.id} has been banned permanently." + + if not dm_banned_member: + message += "\n Could not DM banned member due to permission error." + + return SimpleResponse( + message=message, + delete_after=0 if not needs_approval else None, + code=BanCodes.SUCCESS, + ) + + +async def _send_ban_notice( + guild: Guild, + member: Member, + reason: str, + author: str, + end_date: str, + channel: TextChannel | None, +) -> None: + """Send a ban log to the moderator channel.""" + if not isinstance(channel, TextChannel): + channel = guild.get_channel(settings.channels.SR_MOD) # type: ignore + + embed = discord.Embed( + title="Ban", + description=f"User {member.mention} ({member.id}) was banned on the platform and thus banned here.", + color=0xFF2429, + ) + embed.add_field(name="Reason", value=reason) + embed.add_field(name="Author", value=author) + embed.add_field(name="End Date", value=end_date) + + await channel.send(embed=embed) # type: ignore + + +async def handle_platform_ban_or_update( + bot: Bot, + guild: Guild, + member: Member, + expires_timestamp: int, + reason: str, + evidence: str, + author_name: str, + expires_at_str: str, + log_channel_id: int, + logger, + extra_log_data: dict | None = None, +) -> dict: + """Handle platform ban by either creating new ban, updating existing ban, or taking no action. + + Args: + bot: The Discord bot instance + guild: The guild to ban the member from + member: The member to ban + expires_timestamp: Unix timestamp when the ban should end + reason: Reason for the ban + evidence: Evidence supporting the ban (notes) + author_name: Name of the person who created the ban + expires_at_str: Human-readable expiration date string + log_channel_id: Channel ID for logging ban actions + logger: Logger instance for recording events + extra_log_data: Additional data to include in log entries + + Returns: + dict with 'action' key indicating what was done: 'unbanned', 'extended', 'no_action', 'updated', 'created' + """ + if extra_log_data is None: + extra_log_data = {} + + expires_dt = datetime.fromtimestamp(expires_timestamp) + + existing_ban = await get_ban(member) + if not existing_ban: + # No existing ban, create new one + await ban_member_with_epoch( + bot, guild, member, expires_timestamp, reason, evidence, needs_approval=False + ) + await _send_ban_notice( + guild, member, reason, author_name, expires_at_str, guild.get_channel(log_channel_id) # type: ignore + ) + logger.info(f"Created new platform ban for user {member.id} until {expires_at_str}", extra=extra_log_data) + return {"action": "created"} + + # Existing ban found - determine what to do based on ban type and timing + is_platform_ban = existing_ban.reason.startswith("Platform Ban") + + if is_platform_ban: + # Platform bans have authority over other platform bans + if expires_dt < datetime.now(): + # Platform ban has expired, unban the user + await unban_member(guild, member) + msg = f"User {member.mention} ({member.id}) has been unbanned due to platform ban expiration." + await guild.get_channel(log_channel_id).send(msg) # type: ignore + logger.info(msg, extra=extra_log_data) + return {"action": "unbanned"} + + if existing_ban.unban_time < expires_timestamp: + # Extend the existing platform ban + existing_ban.unban_time = expires_timestamp + await update_ban(existing_ban) + msg = f"User {member.mention} ({member.id}) has had their ban extended to {expires_at_str}." + await guild.get_channel(log_channel_id).send(msg) # type: ignore + logger.info(msg, extra=extra_log_data) + return {"action": "extended"} + else: + # Non-platform ban exists + if existing_ban.unban_time >= expires_timestamp: + # Existing ban is longer than platform ban, no action needed + logger.info( + f"User {member.mention} ({member.id}) is already banned until {existing_ban.unban_time}, " + f"which exceeds or equals the platform ban expiration date {expires_at_str}. No action taken.", + extra=extra_log_data, + ) + return {"action": "no_action"} + else: + # Platform ban is longer, update the existing ban + existing_ban.unban_time = expires_timestamp + existing_ban.reason = f"Platform Ban: {reason}" # Update reason to indicate platform authority + await update_ban(existing_ban) + logger.info(f"Updated existing ban for user {member.id} until {expires_at_str}.", extra=extra_log_data) + return {"action": "updated"} + + # Default case (shouldn't reach here, but for safety) + logger.warning(f"Unexpected case in platform ban handling for user {member.id}", extra=extra_log_data) + return {"action": "no_action"} + + +async def ban_member_with_epoch( + bot: Bot, + guild: Guild, + member: Member | User, + unban_epoch_time: int, + reason: str, + evidence: str, + author: Member | ClientUser | None = None, + needs_approval: bool = True, +) -> SimpleResponse: + """Ban a member from the guild until a specific epoch time. + + Args: + bot: The Discord bot instance + guild: The guild to ban the member from + member: The member or user to ban + unban_epoch_time: Unix timestamp when the ban should end + reason: Reason for the ban + evidence: Evidence supporting the ban + author: The member issuing the ban (defaults to bot user) + needs_approval: Whether the ban requires approval + + Returns: + SimpleResponse with the result of the ban operation, or None if no response needed + """ if checked := await _check_member(bot, guild, member, author): return checked @@ -65,35 +269,50 @@ async def ban_member( if not evidence: evidence = "none provided" - # Validate duration - dur, dur_exc = validate_duration(duration) - # Check if duration is valid, - # negative values are generally not allowed, - # so they should be caught here - if dur <= 0: - return SimpleResponse(message=dur_exc, delete_after=15) - else: - end_date: str = datetime.fromtimestamp(dur, tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + # Validate epoch time is in the future + current_time = datetime.now(tz=timezone.utc).timestamp() + if unban_epoch_time <= current_time: + return SimpleResponse( + message="Unban time must be in the future", + delete_after=15 + ) + + end_date: str = datetime.fromtimestamp(unban_epoch_time, tz=timezone.utc).strftime( + "%Y-%m-%d %H:%M:%S" + ) if author is None: author = bot.user + + # Author should never be None at this point + if author is None: + raise ValueError("Author cannot be None") ban = Ban( - user_id=member.id, reason=reason, moderator_id=author.id, unban_time=dur, - approved=False if needs_approval else True + user_id=member.id, + reason=reason, + moderator_id=author.id, + unban_time=unban_epoch_time, + approved=False if needs_approval else True, ) infraction = Infraction( user_id=member.id, reason=f"Previously banned for: {reason} - Evidence: {evidence}", weight=0, moderator_id=author.id, - date=datetime.now().date() + date=datetime.now().date(), ) + ban_id, is_existing = await _get_ban_or_create(member, ban, infraction) if is_existing: + try: + await guild.ban(member, reason=reason, delete_message_seconds=0) + except NotFound: + pass return SimpleResponse( message=f"A ban with id: {ban_id} already exists for member {member}", - delete_after=None + delete_after=None, + code=BanCodes.ALREADY_EXISTS, ) # DM member, before we ban, else we cannot dm since we do not share a guild @@ -103,67 +322,120 @@ async def ban_member( await guild.ban(member, reason=reason, delete_message_seconds=0) except Forbidden as exc: logger.warning( - "Ban failed due to permission error", exc_info=exc, - extra={"ban_requestor": author.name, "ban_receiver": member.id} + "Ban failed due to permission error", + exc_info=exc, + extra={"ban_requestor": author.name, "ban_receiver": member.id}, ) if author: - return SimpleResponse(message="You do not have the proper permissions to ban.", delete_after=None) + return SimpleResponse( + message="You do not have the proper permissions to ban.", + delete_after=None, + code=BanCodes.FAILED, + ) return except HTTPException as ex: - logger.warning(f"HTTPException when trying to ban user with ID {member.id}", exc_info=ex) + logger.warning( + f"HTTPException when trying to ban user with ID {member.id}", exc_info=ex + ) if author: return SimpleResponse( message="Here's a 400 Bad Request for you. Just like when you tried to ask me out, last week.", - delete_after=None + delete_after=None, + code=BanCodes.FAILED, ) return # If approval is required, send a message to the moderator channel about the ban if not needs_approval: - if member: - message = f"Member {member.display_name} has been banned permanently." - else: - message = f"Member {member.id} has been banned permanently." - - if not dm_banned_member: - message += "\n Could not DM banned member due to permission error." - logger.info( "Member has been banned permanently.", - extra={"ban_requestor": author.name, "ban_receiver": member.id, "dm_banned_member": dm_banned_member} + extra={ + "ban_requestor": author.name, + "ban_receiver": member.id, + "dm_banned_member": dm_banned_member, + }, ) - unban_task = schedule(unban_member(guild, member), run_at=datetime.fromtimestamp(ban.unban_time)) + unban_task = schedule( + unban_member(guild, member), run_at=datetime.fromtimestamp(ban.unban_time) + ) asyncio.create_task(unban_task) - logger.debug("Unbanned sceduled for ban", extra={"ban_id": ban_id, "unban_time": ban.unban_time}) - return SimpleResponse(message=message, delete_after=0) + logger.debug( + "Unbanned sceduled for ban", + extra={"ban_id": ban_id, "unban_time": ban.unban_time}, + ) else: - if member: - message = f"{member.display_name} ({member.id}) has been banned until {end_date} (UTC)." - else: - message = f"{member.id} has been banned until {end_date} (UTC)." - - if not dm_banned_member: - message += " Could not DM banned member due to permission error." - member_name = f"{member.display_name} ({member.name})" embed = discord.Embed( title=f"Ban request #{ban_id}", description=f"{author.display_name} ({author.name}) " - f"would like to ban {member_name} until {end_date} (UTC).\n" - f"Reason: {reason}\n" - f"Evidence: {evidence}", ) + f"would like to ban {member_name} until {end_date} (UTC).\n" + f"Reason: {reason}\n" + f"Evidence: {evidence}", + ) embed.set_thumbnail(url=f"{settings.HTB_URL}/images/logo600.png") view = BanDecisionView(ban_id, bot, guild, member, end_date, reason) - await guild.get_channel(settings.channels.SR_MOD).send(embed=embed, view=view) - return SimpleResponse(message=message) + await guild.get_channel(settings.channels.SR_MOD).send(embed=embed, view=view) # type: ignore + return await _create_ban_response( + member, end_date, dm_banned_member, needs_approval + ) -async def _dm_banned_member(end_date: str, guild: Guild, member: Member, reason: str) -> bool: + +async def ban_member( + bot: Bot, + guild: Guild, + member: Member | User, + duration: str, + reason: str, + evidence: str, + author: Member | None = None, + needs_approval: bool = True, +) -> SimpleResponse: + """Ban a member from the guild using a duration. + + Args: + bot: The Discord bot instance + guild: The guild to ban the member from + member: The member or user to ban + duration: Duration string (e.g., "1d", "1h") or seconds as int + reason: Reason for the ban + evidence: Evidence supporting the ban + author: The member issuing the ban (defaults to bot user) + needs_approval: Whether the ban requires approval + + Returns: + SimpleResponse with the result of the ban operation, or None if no response needed + """ + dur, dur_exc = validate_duration(duration) + + # Check if duration is valid, + # negative values are generally not allowed, + # so they should be caught here + if dur <= 0: + return SimpleResponse(message=dur_exc, delete_after=15) + + return await ban_member_with_epoch( + bot=bot, + guild=guild, + member=member, + unban_epoch_time=dur, + reason=reason, + evidence=evidence, + author=author, + needs_approval=needs_approval, + ) + + +async def _dm_banned_member( + end_date: str, guild: Guild, member: Member | User, reason: str +) -> bool: """Send a message to the member about the ban.""" - message = (f"You have been banned from {guild.name} until {end_date} (UTC). " - f"To appeal the ban, please reach out to an Administrator.\n" - f"Following is the reason given:\n>>> {reason}\n") + message = ( + f"You have been banned from {guild.name} until {end_date} (UTC). " + f"To appeal the ban, please reach out to an Administrator.\n" + f"Following is the reason given:\n>>> {reason}\n" + ) try: await member.send(message) return True @@ -171,29 +443,43 @@ async def _dm_banned_member(end_date: str, guild: Guild, member: Member, reason: logger.warning( f"Could not DM member with id {member.id} due to privacy settings, however will still attempt to ban " f"them...", - exc_info=ex + exc_info=ex, ) except HTTPException as ex: - logger.warning(f"HTTPException when trying to unban user with ID {member.id}", exc_info=ex) + logger.warning( + f"HTTPException when trying to unban user with ID {member.id}", exc_info=ex + ) return False -async def unban_member(guild: Guild, member: Member) -> Member: +async def unban_member(guild: Guild, member: Member | User) -> Member | User: """Unban a member from the guild.""" try: await guild.unban(member) logger.info(f"Unbanned user {member.id}.") except Forbidden as ex: - logger.error(f"Permission denied when trying to unban user with ID {member.id}", exc_info=ex) + logger.error( + f"Permission denied when trying to unban user with ID {member.id}", + exc_info=ex, + ) except NotFound as ex: logger.error( f"NotFound when trying to unban user with ID {member.id}. " - f"This could indicate that the user is not currently banned.", exc_info=ex, ) + f"This could indicate that the user is not currently banned.", + exc_info=ex, + ) except HTTPException as ex: - logger.error(f"HTTPException when trying to unban user with ID {member.id}", exc_info=ex) + logger.error( + f"HTTPException when trying to unban user with ID {member.id}", exc_info=ex + ) async with AsyncSessionLocal() as session: - stmt = select(Ban).filter(Ban.user_id == member.id).filter(Ban.unbanned.is_(False)).limit(1) + stmt = ( + select(Ban) + .filter(Ban.user_id == member.id) + .filter(Ban.unbanned.is_(False)) + .limit(1) + ) result = await session.scalars(stmt) ban = result.first() if ban: @@ -207,7 +493,12 @@ async def unban_member(guild: Guild, member: Member) -> Member: async def mute_member( - bot: Bot, guild: Guild, member: Member, duration: str, reason: str, author: Member = None + bot: Bot, + guild: Guild, + member: Member, + duration: str, + reason: str, + author: Member | ClientUser | None = None, ) -> SimpleResponse | None: """Mute a member on the guild.""" if checked := await _check_member(bot, guild, member, author): @@ -225,13 +516,17 @@ async def mute_member( if author is None: author = bot.user + + # Author should never be None at this point + if author is None: + raise ValueError("Author cannot be None") role = guild.get_role(settings.roles.MUTED) if member: # No longer on the server - cleanup, but don't attempt to remove a role logger.info(f"Add mute from {member.name}:{member.id}.") - await member.add_roles(role, reason=reason) + await member.add_roles(role, reason=reason) # type: ignore mute = Mute( user_id=member.id, reason=reason, moderator_id=author.id, unmute_time=dur @@ -248,7 +543,7 @@ async def unmute_member(guild: Guild, member: Member) -> Member: if isinstance(member, Member): # No longer on the server - cleanup, but don't attempt to remove a role logger.info(f"Remove mute from {member.name}:{member.id}.") - await member.remove_roles(role) + await member.remove_roles(role) # type: ignore await member.remove_timeout() async with AsyncSessionLocal() as session: @@ -272,7 +567,9 @@ async def add_infraction( if len(reason) == 0: reason = "No reason given ..." - infraction = Infraction(user_id=member.id, reason=reason, weight=weight, moderator_id=author.id) + infraction = Infraction( + user_id=member.id, reason=reason, weight=weight, moderator_id=author.id + ) async with AsyncSessionLocal() as session: session.add(infraction) await session.commit() @@ -287,9 +584,15 @@ async def add_infraction( ) except Forbidden as ex: message = "Could not DM member due to privacy settings, however the infraction was still added." - logger.warning(f"Forbidden, when trying to contact user with ID {member.id} about infraction.", exc_info=ex) + logger.warning( + f"Forbidden, when trying to contact user with ID {member.id} about infraction.", + exc_info=ex, + ) except HTTPException as ex: message = "Here's a 400 Bad Request for you. Just like when you tried to ask me out, last week." - logger.warning(f"HTTPException when trying to add infraction for user with ID {member.id}", exc_info=ex) + logger.warning( + f"HTTPException when trying to add infraction for user with ID {member.id}", + exc_info=ex, + ) return SimpleResponse(message=message, delete_after=None) diff --git a/src/helpers/responses.py b/src/helpers/responses.py index 6513bfd..e80da1a 100644 --- a/src/helpers/responses.py +++ b/src/helpers/responses.py @@ -1,15 +1,16 @@ import json - +from typing import Any class SimpleResponse(object): """A simple response object.""" - def __init__(self, message: str, delete_after: int | None = None): + def __init__(self, message: str, delete_after: int | None = None, code: str | Any = None): self.message = message self.delete_after = delete_after + self.code = code def __str__(self): - return json.dumps(dict(self), ensure_ascii=False) - + return json.dumps(dict(self), ensure_ascii=False) # type: ignore + def __repr__(self): return self.__str__() diff --git a/src/helpers/verification.py b/src/helpers/verification.py index ddb4be9..7e8358d 100644 --- a/src/helpers/verification.py +++ b/src/helpers/verification.py @@ -1,52 +1,151 @@ import logging -from datetime import datetime -from typing import Dict, List, Optional, cast +import traceback +from typing import Dict, List, Optional, Any, TypeVar import aiohttp import discord -from discord import Forbidden, Member, Role, User +from discord import ( + ApplicationContext, + Forbidden, + HTTPException, + Member, + Role, + User, + Guild, +) from discord.ext.commands import GuildNotFound, MemberNotFound -from src.bot import Bot +from src.bot import Bot, BOT_TYPE + from src.core import settings -from src.helpers.ban import ban_member +from src.helpers.ban import BanCodes, ban_member, _send_ban_notice logger = logging.getLogger(__name__) -async def get_user_details(account_identifier: str) -> Optional[Dict]: +async def send_verification_instructions( + ctx: ApplicationContext, member: Member +) -> discord.Interaction | discord.WebhookMessage: + """Send instructions via DM on how to identify with HTB account. + + Args: + ctx (ApplicationContext): The context of the command. + member (Member): The member to send the instructions to. + + Returns: + discord.Interaction | discord.WebhookMessage: The response message. + """ + member = ctx.user + + # Create step-by-step instruction embeds + embed_step1 = discord.Embed(color=0x9ACC14) + embed_step1.add_field( + name="Step 1: Login to your HTB Account", + value="Go to and login.", + inline=False, + ) + embed_step1.set_image( + url="https://media.discordapp.net/attachments/1102700815493378220/1384587341338902579/image.png" + ) + + embed_step2 = discord.Embed(color=0x9ACC14) + embed_step2.add_field( + name="Step 2: Navigate to your Security Settings", + value="In the navigation bar, click on **Security Settings** and scroll down to the **Discord Account** section. " + "()", + inline=False, + ) + embed_step2.set_image( + url="https://media.discordapp.net/attachments/1102700815493378220/1384587813760270392/image.png" + ) + + embed_step3 = discord.Embed(color=0x9ACC14) + embed_step3.add_field( + name="Step 3: Link your Discord Account", + value="Click **Connect** and you will be redirected to login to your Discord account via oauth. " + "After logging in, you will be redirected back to the HTB Account page. " + "Your Discord account will now be linked. Discord may take a few minutes to update. " + "If you have any issues, please contact a Moderator.", + inline=False, + ) + embed_step3.set_image( + url="https://media.discordapp.net/attachments/1102700815493378220/1384586811384402042/image.png" + ) + + try: + await member.send(embed=embed_step1) + await member.send(embed=embed_step2) + await member.send(embed=embed_step3) + except Forbidden as ex: + logger.error("Exception during verify call", exc_info=ex) + return await ctx.respond( + "Whoops! I cannot DM you after all due to your privacy settings. Please allow DMs from other server " + "members and try again in 1 minute." + ) + except HTTPException as ex: + logger.error("Exception during verify call.", exc_info=ex) + return await ctx.respond( + "An unexpected error happened (HTTP 400, bad request). Please contact an Administrator." + ) + + return await ctx.respond("Please check your DM for instructions.", ephemeral=True) + + + +def get_labs_session() -> aiohttp.ClientSession: + """Get a session for the HTB Labs API.""" + return aiohttp.ClientSession(headers={"Authorization": f"Bearer {settings.HTB_API_KEY}"}) + + +async def get_user_details(labs_id: int | str) -> dict: """Get user details from HTB.""" - acc_id_url = f"{settings.API_URL}/discord/identifier/{account_identifier}?secret={settings.HTB_API_SECRET}" - async with aiohttp.ClientSession() as session: - async with session.get(acc_id_url) as r: + if not labs_id: + return {} + + user_profile_api_url = f"{settings.API_V4_URL}/user/profile/basic/{labs_id}" + user_content_api_url = f"{settings.API_V4_URL}/user/profile/content/{labs_id}" + + async with get_labs_session() as session: + async with session.get(user_profile_api_url) as r: if r.status == 200: - response = await r.json() - elif r.status == 404: - logger.debug("Account identifier has been regenerated since last identification. Cannot re-verify.") - response = None + profile_response = await r.json() else: - logger.error(f"Non-OK HTTP status code returned from identifier lookup: {r.status}.") - response = None + logger.error( + f"Non-OK HTTP status code returned from user details lookup: {r.status}." + ) + profile_response = {} - return response + async with session.get(user_content_api_url) as r: + if r.status == 200: + content_response = await r.json() + else: + logger.error( + f"Non-OK HTTP status code returned from user content lookup: {r.status}." + ) + content_response = {} + + profile = profile_response.get("profile", {}) + profile["content"] = content_response.get("profile", {}).get("content", {}) + return profile async def get_season_rank(htb_uid: int) -> str | None: """Get season rank from HTB.""" - headers = {"Authorization": f"Bearer {settings.HTB_API_KEY}"} season_api_url = f"{settings.API_V4_URL}/season/end/{settings.SEASON_ID}/{htb_uid}" - async with aiohttp.ClientSession() as session: - async with session.get(season_api_url, headers=headers) as r: + async with get_labs_session() as session: + async with session.get(season_api_url) as r: if r.status == 200: response = await r.json() elif r.status == 404: logger.error("Invalid Season ID.") - response = None + response = {} else: - logger.error(f"Non-OK HTTP status code returned from identifier lookup: {r.status}.") - response = None + logger.error( + f"Non-OK HTTP status code returned from identifier lookup: {r.status}." + ) + response = {} if not response["data"]: rank = None @@ -59,26 +158,10 @@ async def get_season_rank(htb_uid: int) -> str | None: return rank -async def _check_for_ban(uid: str) -> Optional[Dict]: - async with aiohttp.ClientSession() as session: - token_url = f"{settings.API_URL}/discord/{uid}/banned?secret={settings.HTB_API_SECRET}" - async with session.get(token_url) as r: - if r.status == 200: - ban_details = await r.json() - else: - logger.error( - f"Could not fetch ban details for uid {uid}: " - f"non-OK status code returned ({r.status}). Body: {r.content}" - ) - ban_details = None - - return ban_details - - async def process_certification(certid: str, name: str): """Process certifications.""" cert_api_url = f"{settings.API_V4_URL}/certificate/lookup" - params = {'id': certid, 'name': name} + params = {"id": certid, "name": name} async with aiohttp.ClientSession() as session: async with session.get(cert_api_url, params=params) as r: if r.status == 200: @@ -86,8 +169,10 @@ async def process_certification(certid: str, name: str): elif r.status == 404: return False else: - logger.error(f"Non-OK HTTP status code returned from identifier lookup: {r.status}.") - response = None + logger.error( + f"Non-OK HTTP status code returned from identifier lookup: {r.status}." + ) + response = {} try: certRawName = response["certificates"][0]["name"] except IndexError: @@ -107,118 +192,300 @@ async def process_certification(certid: str, name: str): return cert -async def process_identification( - htb_user_details: Dict[str, str], user: Optional[Member | User], bot: Bot +async def _handle_banned_user(member: Member, bot: BOT_TYPE): + """Handle banned trait during account linking. + + Args: + member (Member): The member to process. + bot (Bot): The bot instance. + """ + resp = await ban_member( + bot, # type: ignore + member.guild, + member, + "1337w", + ( + "Platform Ban - Ban duration could not be determined. " + "Please login to confirm ban details and contact HTB Support to appeal." + ), + "N/A", + None, + needs_approval=False, + ) + if resp.code == BanCodes.SUCCESS: + await _send_ban_notice( + member.guild, + member, + resp.message, + "System", + "1337w", + member.guild.get_channel(settings.channels.VERIFY_LOGS), # type: ignore + ) + + +async def _set_nickname(member: Member, nickname: str) -> bool: + """Set the nickname of the member. + + Args: + member (Member): The member to set the nickname for. + nickname (str): The nickname to set. + + Returns: + bool: True if the nickname was set, False otherwise. + """ + try: + await member.edit(nick=nickname) + return True + except Forbidden as e: + logger.error(f"Exception whe trying to edit the nick-name of the user: {e}") + return False + + +async def process_account_identification( + member: Member, bot: BOT_TYPE, traits: dict[str, Any] +) -> None: + """Process HTB account identification, to be called during account linking. + + Args: + member (Member): The member to process. + bot (Bot): The bot instance. + traits (dict[str, str] | None): Optional user traits to process. + """ + try: + await member.add_roles(member.guild.get_role(settings.roles.VERIFIED), atomic=True) # type: ignore + except Exception as e: + logger.error(f"Failed to add VERIFIED role to user {member.id}: {e}") + # Don't raise - continue with other operations + + nickname_changed = False + + traits = traits or {} + + if traits.get("username") and traits.get("username") != member.name: + nickname_changed = await _set_nickname(member, traits.get("username")) # type: ignore + + if not nickname_changed: + logger.warning( + f"No username provided for {member.name} with ID {member.id} during identification." + ) + + if traits.get("mp_user_id"): + try: + logger.debug(f"MP user ID: {traits.get('mp_user_id', None)}") + htb_user_details = await get_user_details(traits.get("mp_user_id", None)) + if htb_user_details: + await process_labs_identification(htb_user_details, member, bot) # type: ignore + + if not nickname_changed and htb_user_details.get("username"): + logger.debug( + f"Falling back on HTB username to set nickname for {member.name} with ID {member.id}." + ) + await _set_nickname(member, htb_user_details["username"]) + except Exception as e: + logger.error(f"Failed to process labs identification for user {member.id}: {e}") + # Don't raise - this is not critical + + if traits.get("banned", False) == True: # noqa: E712 - explicit bool only, no truthiness + try: + logger.debug(f"Handling banned user {member.id}") + await _handle_banned_user(member, bot) + return + except Exception as e: + logger.error(f"Failed to handle banned user {member.id}: {e}") + logger.exception(traceback.format_exc()) + # Don't raise - continue processing + + +async def process_labs_identification( + htb_user_details: dict, user: Optional[Member | User], bot: Bot ) -> Optional[List[Role]]: """Returns roles to assign if identification was successfully processed.""" - htb_uid = htb_user_details["user_id"] + + # Resolve member and guild + member, guild = await _resolve_member_and_guild(user, bot) + + # Get roles to remove and assign + to_remove = _get_roles_to_remove(member, guild) + to_assign = await _process_role_assignments(htb_user_details, guild) + + # Remove roles that will be reassigned + to_remove = list(set(to_remove) - set(to_assign)) + + # Apply role changes + await _apply_role_changes(member, to_remove, to_assign) + + return to_assign + + +async def _resolve_member_and_guild( + user: Optional[Member | User], bot: Bot +) -> tuple[Member, Guild]: + """Resolve member and guild from user object.""" if isinstance(user, Member): - member = user - guild = member.guild - # This will only work if the user and the bot share only one guild. - elif isinstance(user, User) and len(user.mutual_guilds) == 1: + return user, user.guild + + if isinstance(user, User) and len(user.mutual_guilds) == 1: guild = user.mutual_guilds[0] member = await bot.get_member_or_user(guild, user.id) if not member: raise MemberNotFound(str(user.id)) - else: - raise GuildNotFound(f"Could not identify member {user} in guild.") - season_rank = await get_season_rank(htb_uid) - banned_details = await _check_for_ban(htb_uid) - - if banned_details is not None and banned_details["banned"]: - # If user is banned, this field must be a string - # Strip date e.g. from "2022-01-31T11:00:00.000000Z" - banned_until: str = cast(str, banned_details["ends_at"])[:10] - banned_until_dt: datetime = datetime.strptime(banned_until, "%Y-%m-%d") - ban_duration: str = f"{(banned_until_dt - datetime.now()).days}d" - reason = "Banned on the HTB Platform. Please contact HTB Support to appeal." - logger.info(f"Discord user {member.name} ({member.id}) is platform banned. Banning from Discord...") - await ban_member(bot, guild, member, ban_duration, reason, None, needs_approval=False) - - embed = discord.Embed( - title="Identification error", - description=f"User {member.mention} ({member.id}) was platform banned HTB and thus also here.", - color=0xFF2429, ) - - await guild.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed) - return None + return member, guild # type: ignore + + raise GuildNotFound(f"Could not identify member {user} in guild.") + +def _get_roles_to_remove(member: Member, guild: Guild) -> list[Role]: + """Get existing roles that should be removed.""" to_remove = [] + try: + all_ranks = settings.role_groups.get("ALL_RANKS", []) + all_positions = settings.role_groups.get("ALL_POSITIONS", []) + removable_role_ids = all_ranks + all_positions + + for role in member.roles: + if role.id in removable_role_ids: + guild_role = guild.get_role(role.id) + if guild_role: + to_remove.append(guild_role) + except Exception as e: + logger.error(f"Error processing existing roles for user {member.id}: {e}") + return to_remove - for role in member.roles: - if role.id in settings.role_groups.get("ALL_RANKS") + settings.role_groups.get("ALL_POSITIONS"): - to_remove.append(guild.get_role(role.id)) +async def _process_role_assignments( + htb_user_details: dict, guild: Guild +) -> list[Role]: + """Process role assignments based on HTB user details.""" to_assign = [] - logger.debug( - "Getting role 'rank':", extra={ - "role_id": settings.get_post_or_rank(htb_user_details["rank"]), - "role_obj": guild.get_role(settings.get_post_or_rank(htb_user_details["rank"])), - "htb_rank": htb_user_details["rank"], - }, ) - if htb_user_details["rank"] not in ["Deleted", "Moderator", "Ambassador", "Admin", "Staff"]: - to_assign.append(guild.get_role(settings.get_post_or_rank(htb_user_details["rank"]))) - if season_rank: - to_assign.append(guild.get_role(settings.get_season(season_rank))) - if htb_user_details["vip"]: - logger.debug( - 'Getting role "VIP":', extra={"role_id": settings.roles.VIP, "role_obj": guild.get_role(settings.roles.VIP)} - ) - to_assign.append(guild.get_role(settings.roles.VIP)) - if htb_user_details["dedivip"]: - logger.debug( - 'Getting role "VIP+":', - extra={"role_id": settings.roles.VIP_PLUS, "role_obj": guild.get_role(settings.roles.VIP_PLUS)} - ) - to_assign.append(guild.get_role(settings.roles.VIP_PLUS)) - if htb_user_details["hof_position"] != "unranked": - position = int(htb_user_details["hof_position"]) - pos_top = None - if position == 1: - pos_top = "1" - elif position <= 10: - pos_top = "10" - if pos_top: - logger.debug(f"User is Hall of Fame rank {position}. Assigning role Top-{pos_top}...") - logger.debug( - 'Getting role "HoF role":', extra={ - "role_id": settings.get_post_or_rank(pos_top), - "role_obj": guild.get_role(settings.get_post_or_rank(pos_top)), "hof_val": pos_top, - }, ) - to_assign.append(guild.get_role(settings.get_post_or_rank(pos_top))) - else: - logger.debug(f"User is position {position}. No Hall of Fame roles for them.") - if htb_user_details["machines"]: - logger.debug( - 'Getting role "BOX_CREATOR":', - extra={"role_id": settings.roles.BOX_CREATOR, "role_obj": guild.get_role(settings.roles.BOX_CREATOR)}, ) - to_assign.append(guild.get_role(settings.roles.BOX_CREATOR)) - if htb_user_details["challenges"]: - logger.debug( - 'Getting role "CHALLENGE_CREATOR":', extra={ - "role_id": settings.roles.CHALLENGE_CREATOR, - "role_obj": guild.get_role(settings.roles.CHALLENGE_CREATOR), - }, ) - to_assign.append(guild.get_role(settings.roles.CHALLENGE_CREATOR)) - - if member.nick != htb_user_details["user_name"]: - try: - await member.edit(nick=htb_user_details["user_name"]) - except Forbidden as e: - logger.error(f"Exception whe trying to edit the nick-name of the user: {e}") + + # Process rank roles + to_assign.extend(_process_rank_roles(htb_user_details.get("rank", ""), guild)) + + # Process season rank roles + to_assign.extend(await _process_season_rank_roles(htb_user_details.get("id", ""), guild)) + + # Process VIP roles + to_assign.extend(_process_vip_roles(htb_user_details, guild)) + + # Process HOF position roles + to_assign.extend(_process_hof_position_roles(htb_user_details.get("ranking", "unranked"), guild)) + + # Process creator roles + to_assign.extend(_process_creator_roles(htb_user_details.get("content", {}), guild)) + + return to_assign - logger.debug("All roles to_assign:", extra={"to_assign": to_assign}) - # We don't need to remove any roles that are going to be assigned again - to_remove = list(set(to_remove) - set(to_assign)) - logger.debug("All roles to_remove:", extra={"to_remove": to_remove}) - if to_remove: - await member.remove_roles(*to_remove, atomic=True) - else: - logger.debug("No roles need to be removed") - if to_assign: - await member.add_roles(*to_assign, atomic=True) - else: - logger.debug("No roles need to be assigned") - return to_assign +def _process_rank_roles(rank: str, guild: Guild) -> list[Role]: + """Process rank-based role assignments.""" + roles = [] + + if rank and rank not in ["Deleted", "Moderator", "Ambassador", "Admin", "Staff"]: + role_id = settings.get_post_or_rank(rank) + if role_id: + role = guild.get_role(role_id) + if role: + roles.append(role) + + return roles + + +async def _process_season_rank_roles(mp_user_id: int, guild: Guild) -> list[Role]: + """Process season rank role assignments.""" + roles = [] + try: + season_rank = await get_season_rank(mp_user_id) + if isinstance(season_rank, str): + season_role_id = settings.get_season(season_rank) + if season_role_id: + season_role = guild.get_role(season_role_id) + if season_role: + roles.append(season_role) + except Exception as e: + logger.error(f"Error getting season rank for user {mp_user_id}: {e}") + return roles + + +def _process_vip_roles(htb_user_details: dict, guild: Guild) -> list[Role]: + """Process VIP role assignments.""" + roles = [] + try: + if htb_user_details.get("isVip", False): + vip_role = guild.get_role(settings.roles.VIP) + if vip_role: + roles.append(vip_role) + + if htb_user_details.get("isDedicatedVip", False): + vip_plus_role = guild.get_role(settings.roles.VIP_PLUS) + if vip_plus_role: + roles.append(vip_plus_role) + except Exception as e: + logger.error(f"Error processing VIP roles: {e}") + return roles + + +def _process_hof_position_roles(htb_user_ranking: str | int, guild: Guild) -> list[Role]: + """Process Hall of Fame position role assignments.""" + roles = [] + try: + hof_position = htb_user_ranking or "unranked" + logger.debug(f"HTB user ranking: {hof_position}") + if hof_position != "unranked": + position = int(hof_position) + pos_top = _get_position_tier(position) + + if pos_top: + pos_role_id = settings.get_post_or_rank(pos_top) + if pos_role_id: + pos_role = guild.get_role(pos_role_id) + if pos_role: + roles.append(pos_role) + except (ValueError, TypeError) as e: + logger.error(f"Error processing HOF position: {e}") + return roles + + +def _get_position_tier(position: int) -> Optional[str]: + """Get position tier based on HOF position.""" + if position == 1: + return "1" + elif position <= 10: + return "10" + return None + + +def _process_creator_roles(htb_user_content: dict, guild: Guild) -> list[Role]: + """Process creator role assignments.""" + roles = [] + try: + if htb_user_content.get("machines"): + box_creator_role = guild.get_role(settings.roles.BOX_CREATOR) + if box_creator_role: + logger.debug("Adding box creator role to user.") + roles.append(box_creator_role) + + if htb_user_content.get("challenges"): + challenge_creator_role = guild.get_role(settings.roles.CHALLENGE_CREATOR) + if challenge_creator_role: + logger.debug("Adding challenge creator role to user.") + roles.append(challenge_creator_role) + except Exception as e: + logger.error(f"Error processing creator roles: {e}") + return roles + + +async def _apply_role_changes( + member: Member, to_remove: list[Role], to_assign: list[Role] +) -> None: + """Apply role changes to member.""" + try: + if to_remove: + await member.remove_roles(*to_remove, atomic=True) + except Exception as e: + logger.error(f"Error removing roles from user {member.id}: {e}") + + try: + if to_assign: + await member.add_roles(*to_assign, atomic=True) + except Exception as e: + logger.error(f"Error adding roles to user {member.id}: {e}") diff --git a/src/webhooks/handlers/__init__.py b/src/webhooks/handlers/__init__.py index 02e0edf..4036462 100644 --- a/src/webhooks/handlers/__init__.py +++ b/src/webhooks/handlers/__init__.py @@ -1,16 +1,23 @@ from discord import Bot +from typing import Any -from src.webhooks.handlers.academy import handler as academy_handler +from src.webhooks.handlers.account import AccountHandler +from src.webhooks.handlers.academy import AcademyHandler +from src.webhooks.handlers.mp import MPHandler from src.webhooks.types import Platform, WebhookBody -handlers = {Platform.ACADEMY: academy_handler} +handlers = { + Platform.ACCOUNT: AccountHandler().handle, + Platform.MAIN: MPHandler().handle, + Platform.ACADEMY: AcademyHandler().handle, +} def can_handle(platform: Platform) -> bool: return platform in handlers.keys() -def handle(body: WebhookBody, bot: Bot) -> any: +def handle(body: WebhookBody, bot: Bot) -> Any: platform = body.platform if not can_handle(platform): diff --git a/src/webhooks/handlers/academy.py b/src/webhooks/handlers/academy.py index 5aa3568..a57785c 100644 --- a/src/webhooks/handlers/academy.py +++ b/src/webhooks/handlers/academy.py @@ -1,76 +1,50 @@ -import logging - from discord import Bot -from discord.errors import NotFound -from fastapi import HTTPException from src.core import settings +from src.webhooks.handlers.base import BaseHandler from src.webhooks.types import WebhookBody, WebhookEvent -logger = logging.getLogger(__name__) - - -async def handler(body: WebhookBody, bot: Bot) -> dict: - """ - Handles incoming webhook events and performs actions accordingly. - - This function processes different webhook events related to account linking, - certificate awarding, and account unlinking. It updates the member's roles - based on the received event. - - Args: - body (WebhookBody): The data received from the webhook. - bot (Bot): The instance of the Discord bot. - - Returns: - dict: A dictionary with a "success" key indicating whether the operation was successful. - - Raises: - HTTPException: If an error occurs while processing the webhook event. - """ - # TODO: Change it here so we pass the guild instead of the bot # noqa: T000 - guild = await bot.fetch_guild(settings.guild_ids[0]) - - try: - discord_id = int(body.data["discord_id"]) - member = await guild.fetch_member(discord_id) - except ValueError as exc: - logger.debug("Invalid Discord ID", exc_info=exc) - raise HTTPException(status_code=400, detail="Invalid Discord ID") from exc - except NotFound as exc: - logger.debug("User is not in the Discord server", exc_info=exc) - raise HTTPException(status_code=400, detail="User is not in the Discord server") from exc - - if body.event == WebhookEvent.ACCOUNT_LINKED: - roles_to_add = {settings.roles.ACADEMY_USER} - roles_to_add.update(settings.get_academy_cert_role(cert["id"]) for cert in body.data["certifications"]) - - # Filter out invalid role IDs - role_ids_to_add = {role_id for role_id in roles_to_add if role_id is not None} - roles_to_add = {guild.get_role(role_id) for role_id in role_ids_to_add} - - await member.add_roles(*roles_to_add, atomic=True) - elif body.event == WebhookEvent.CERTIFICATE_AWARDED: - cert_id = body.data["certification"]["id"] - - role = settings.get_academy_cert_role(cert_id) - if not role: - logger.debug(f"Role for certification: {cert_id} does not exist") - raise HTTPException(status_code=400, detail=f"Role for certification: {cert_id} does not exist") - - await member.add_roles(role, atomic=True) - elif body.event == WebhookEvent.ACCOUNT_UNLINKED: - current_role_ids = {role.id for role in member.roles} - cert_role_ids = {settings.get_academy_cert_role(cert_id) for _, cert_id in settings.academy_certificates} - - common_role_ids = current_role_ids.intersection(cert_role_ids) - - role_ids_to_remove = {settings.roles.ACADEMY_USER}.union(common_role_ids) - roles_to_remove = {guild.get_role(role_id) for role_id in role_ids_to_remove} - - await member.remove_roles(*roles_to_remove, atomic=True) - else: - logger.debug(f"Event {body.event} not implemented") - raise HTTPException(status_code=501, detail=f"Event {body.event} not implemented") - return {"success": True} +class AcademyHandler(BaseHandler): + async def handle(self, body: WebhookBody, bot: Bot): + """ + Handles incoming webhook events and performs actions accordingly. + + This function processes different webhook events originating from the + HTB Account. + """ + if body.event == WebhookEvent.CERTIFICATE_AWARDED: + return await self._handle_certificate_awarded(body, bot) + else: + raise ValueError(f"Invalid event: {body.event}") + + async def _handle_certificate_awarded(self, body: WebhookBody, bot: Bot) -> dict: + """ + Handles the certificate awarded event. + """ + discord_id = self.validate_discord_id(self.get_property_or_trait(body, "discord_id")) + _ = self.validate_account_id(self.get_property_or_trait(body, "account_id")) + certificate_id = self.validate_property( + self.get_property_or_trait(body, "certificate_id"), "certificate_id" + ) + + self.logger.info(f"Handling certificate awarded event for {discord_id} with certificate {certificate_id}") + + member = await self.get_guild_member(discord_id, bot) + certificate_role_id = settings.get_academy_cert_role(int(certificate_id)) + + if not certificate_role_id: + self.logger.warning(f"No certificate role found for certificate {certificate_id}") + return self.fail() + + if certificate_role_id: + self.logger.info(f"Adding certificate role {certificate_role_id} to member {member.id}") + try: + await member.add_roles( + bot.guilds[0].get_role(certificate_role_id), atomic=True # type: ignore + ) # type: ignore + except Exception as e: + self.logger.error(f"Error adding certificate role {certificate_role_id} to member {member.id}: {e}") + raise e + + return self.success() diff --git a/src/webhooks/handlers/account.py b/src/webhooks/handlers/account.py new file mode 100644 index 0000000..d941f8b --- /dev/null +++ b/src/webhooks/handlers/account.py @@ -0,0 +1,168 @@ +from datetime import datetime +from discord import Bot + +from src.core import settings +from src.helpers.ban import handle_platform_ban_or_update +from src.helpers.verification import process_account_identification +from src.webhooks.handlers.base import BaseHandler +from src.webhooks.types import WebhookBody, WebhookEvent + + +class AccountHandler(BaseHandler): + async def handle(self, body: WebhookBody, bot: Bot): + """ + Handles incoming webhook events and performs actions accordingly. + + This function processes different webhook events originating from the + HTB Account. + """ + if body.event == WebhookEvent.ACCOUNT_LINKED: + return await self._handle_account_linked(body, bot) + elif body.event == WebhookEvent.ACCOUNT_UNLINKED: + return await self._handle_account_unlinked(body, bot) + elif body.event == WebhookEvent.ACCOUNT_DELETED: + return await self._handle_account_deleted(body, bot) + elif body.event == WebhookEvent.ACCOUNT_BANNED: + return await self._handle_account_banned(body, bot) + else: + raise ValueError(f"Invalid event: {body.event}") + + async def _handle_account_linked(self, body: WebhookBody, bot: Bot) -> dict: + """ + Handles the account linked event. + """ + discord_id = self.validate_discord_id( + self.get_property_or_trait(body, "discord_id") + ) + account_id = self.validate_account_id( + self.get_property_or_trait(body, "account_id") + ) + + member = await self.get_guild_member(discord_id, bot) + await process_account_identification( + member, + bot, # type: ignore + traits=self.merge_properties_and_traits(body.properties, body.traits), + ) + + # Safely attempt to send verification log + try: + verify_channel = bot.guilds[0].get_channel(settings.channels.VERIFY_LOGS) + if verify_channel: + await verify_channel.send( # type: ignore + f"Account linked: {account_id} -> ({member.mention} ({member.id})", + ) + else: + self.logger.warning( + f"Verify logs channel {settings.channels.VERIFY_LOGS} not found" + ) + except Exception as e: + self.logger.error(f"Failed to send verification log: {e}") + # Don't raise - this is not critical + + self.logger.info( + f"Account {account_id} linked to {member.id}", + extra={"account_id": account_id, "discord_id": discord_id}, + ) + + return self.success() + + async def _handle_account_unlinked(self, body: WebhookBody, bot: Bot) -> dict: + """ + Handles the account unlinked event. + """ + discord_id = self.validate_discord_id( + self.get_property_or_trait(body, "discord_id") + ) + account_id = self.validate_account_id( + self.get_property_or_trait(body, "account_id") + ) + + member = await self.get_guild_member(discord_id, bot) + + await member.remove_roles( + bot.guilds[0].get_role(settings.roles.VERIFIED), atomic=True # type: ignore + ) # type: ignore + + return self.success() + + async def _handle_name_change(self, body: WebhookBody, bot: Bot) -> dict: + """ + Handles the name change event. + """ + discord_id = self.validate_discord_id(body.properties.get("discord_id")) + _ = self.validate_account_id(body.properties.get("account_id")) + name = self.validate_property(body.properties.get("name"), "name") + + member = await self.get_guild_member(discord_id, bot) + await member.edit(nick=name) + return self.success() + + async def _handle_account_banned(self, body: WebhookBody, bot: Bot) -> dict: + """ + Handles the account banned event. + """ + discord_id = self.validate_discord_id( + self.get_property_or_trait(body, "discord_id") + ) + account_id = self.validate_account_id( + self.get_property_or_trait(body, "account_id") + ) + expires_at = self.validate_property( + self.get_property_or_trait(body, "expires_at"), "expires_at" + ) + reason = body.properties.get("reason") + notes = body.properties.get("notes") + created_by = body.properties.get("created_by") + + expires_ts = int(datetime.fromisoformat(expires_at).timestamp()) # type: ignore + extra = {"account_id": account_id, "discord_id": discord_id} + + member = await self.get_guild_member(discord_id, bot) + if not member: + self.logger.warning( + f"Cannot ban user {discord_id}- not found in guild", extra=extra + ) + return self.fail() + + # Use the generic ban helper to handle all the complex logic + result = await handle_platform_ban_or_update( + bot=bot, # type: ignore + guild=bot.guilds[0], + member=member, + expires_timestamp=expires_ts, + reason=f"Platform Ban - {reason}", + evidence=notes or "N/A", + author_name=created_by or "System", + expires_at_str=expires_at, # type: ignore + log_channel_id=settings.channels.BOT_LOGS, + logger=self.logger, + extra_log_data=extra, + ) + + self.logger.debug( + f"Platform ban handling result: {result['action']}", extra=extra + ) + + return self.success() + + async def _handle_account_deleted(self, body: WebhookBody, bot: Bot) -> dict: + """ + Handles the account deleted event. + """ + discord_id = self.validate_discord_id(body.properties.get("discord_id")) + account_id = self.validate_account_id(body.properties.get("account_id")) + + member = await self.get_guild_member(discord_id, bot) + if not member: + self.logger.warning( + f"Cannot delete account {account_id}- not found in guild", + extra={"account_id": account_id, "discord_id": discord_id}, + ) + return self.fail() + + await member.remove_roles( + bot.guilds[0].get_role(settings.roles.VERIFIED), atomic=True # type: ignore + ) # type: ignore + + return self.success() diff --git a/src/webhooks/handlers/base.py b/src/webhooks/handlers/base.py new file mode 100644 index 0000000..b604f01 --- /dev/null +++ b/src/webhooks/handlers/base.py @@ -0,0 +1,128 @@ +import logging +from abc import ABC, abstractmethod +from typing import TypeVar + +from discord import Bot, Member +from discord.errors import NotFound +from fastapi import HTTPException + +from src.core import settings +from src.webhooks.types import WebhookBody + +T = TypeVar("T") + + +class BaseHandler(ABC): + ACADEMY_USER_ID = "academy_user_id" + MP_USER_ID = "mp_user_id" + EP_USER_ID = "ep_user_id" + CTF_USER_ID = "ctf_user_id" + ACCOUNT_ID = "account_id" + DISCORD_ID = "discord_id" + + def __init__(self): + self.logger = logging.getLogger(self.__class__.__name__) + + @abstractmethod + async def handle(self, body: WebhookBody, bot: Bot) -> dict: + pass + + async def get_guild_member(self, discord_id: int | str, bot: Bot) -> Member: + """ + Fetches a guild member from the Discord server. + + Args: + discord_id (int): The Discord ID of the user. + bot (Bot): The Discord bot instance. + + Returns: + Member: The guild member. + + Raises: + HTTPException: If the user is not in the Discord server (400) + """ + + try: + guild = await bot.fetch_guild(settings.guild_ids[0]) + member = await guild.fetch_member(int(discord_id)) + return member + + except NotFound as exc: + self.logger.debug("User is not in the Discord server", exc_info=exc) + raise HTTPException( + status_code=400, detail="User is not in the Discord server" + ) from exc + + def validate_property(self, property: T | None, name: str) -> T: + """ + Validates a property is not None. + + Args: + property (T | None): The property to validate. + name (str): The name of the property. + + Returns: + T: The validated property. + + Raises: + HTTPException: If the property is None (400) + """ + if property is None: + msg = f"Invalid {name}" + self.logger.debug(msg) + raise HTTPException(status_code=400, detail=msg) + + return property + + def validate_discord_id(self, discord_id: str | int | None) -> int | str: + """ + Validates the Discord ID. See validate_property function. + """ + return self.validate_property(discord_id, "Discord ID") + + def validate_account_id(self, account_id: str | int | None) -> int | str: + """ + Validates the Account ID. See validate_property function. + """ + return self.validate_property(account_id, "Account ID") + + def get_property_or_trait(self, body: WebhookBody, name: str) -> int | None: + """ + Gets a trait or property from the webhook body. + """ + return body.properties.get(name) or body.traits.get(name) + + def merge_properties_and_traits( + self, properties: dict[str, int | None], traits: dict[str, int | None] + ) -> dict[str, int | None]: + """ + Merges the properties and traits from the webhook body without duplicates. + If a property and trait have the same name but different values, the property value will be used. + """ + return { + **properties, + **{k: v for k, v in traits.items() if k not in properties}, + } + + def get_platform_properties(self, body: WebhookBody) -> dict[str, int | None]: + """ + Gets the platform properties from the webhook body. + """ + properties = { + self.ACCOUNT_ID: self.get_property_or_trait(body, self.ACCOUNT_ID), + self.MP_USER_ID: self.get_property_or_trait(body, self.MP_USER_ID), + self.EP_USER_ID: self.get_property_or_trait(body, self.EP_USER_ID), + self.CTF_USER_ID: self.get_property_or_trait(body, self.CTF_USER_ID), + self.ACADEMY_USER_ID: self.get_property_or_trait( + body, self.ACADEMY_USER_ID + ), + } + return properties + + @staticmethod + def success(): + return {"success": True} + + @staticmethod + def fail(): + return {"success": False} \ No newline at end of file diff --git a/src/webhooks/handlers/mp.py b/src/webhooks/handlers/mp.py new file mode 100644 index 0000000..c8b0e2b --- /dev/null +++ b/src/webhooks/handlers/mp.py @@ -0,0 +1,180 @@ +import discord + +from datetime import datetime +from discord import Bot, Member, Role + +from typing import Literal +from sqlalchemy import select + +from src.core import settings +from src.webhooks.handlers.base import BaseHandler +from src.webhooks.types import WebhookBody, WebhookEvent + + +class MPHandler(BaseHandler): + async def handle(self, body: WebhookBody, bot: Bot): + """ + Handles incoming webhook events and performs actions accordingly. + + This function processes different webhook events originating from the + HTB Account. + """ + if body.event == WebhookEvent.HOF_CHANGE: + return await self._handle_hof_change(body, bot) + elif body.event == WebhookEvent.RANK_UP: + return await self._handle_rank_up(body, bot) + elif body.event == WebhookEvent.SUBSCRIPTION_CHANGE: + return await self._handle_subscription_change(body, bot) + else: + raise ValueError(f"Invalid event: {body.event}") + + async def _handle_subscription_change(self, body: WebhookBody, bot: Bot) -> dict: + """ + Handles the subscription change event. + """ + discord_id = self.validate_discord_id(body.properties.get("discord_id")) + _ = self.validate_account_id(body.properties.get("account_id")) + subscription_name = self.validate_property( + body.properties.get("subscription_name"), "subscription_name" + ) + + member = await self.get_guild_member(discord_id, bot) + + role = settings.get_post_or_rank(subscription_name) + if not role: + raise ValueError(f"Invalid subscription name: {subscription_name}") + + await member.add_roles(bot.guilds[0].get_role(role), atomic=True) # type: ignore + return self.success() + + async def _handle_hof_change(self, body: WebhookBody, bot: Bot) -> dict: + """ + Handles the HOF change event. + """ + self.logger.info("Handling HOF change event.") + discord_id = self.validate_discord_id(self.get_property_or_trait(body, "discord_id")) + account_id = self.validate_account_id(self.get_property_or_trait(body, "account_id")) + hof_tier: Literal["1", "10"] = self.validate_property( + self.get_property_or_trait(body, "hof_tier"), "hof_tier" # type: ignore + ) + hof_roles = { + "1": bot.guilds[0].get_role(settings.roles.RANK_ONE), + "10": bot.guilds[0].get_role(settings.roles.RANK_TEN), + } + + member = await self.get_guild_member(discord_id, bot) + member_roles = member.roles + + if not member: + msg = f"Cannot find member {discord_id}" + self.logger.warning( + msg, extra={"account_id": account_id, "discord_id": discord_id} + ) + raise ValueError(msg) + + async def _swap_hof_roles(member: Member, role_to_grant: Role | None): + """Grants a HOF role to a member and removes the other HOF role""" + if not role_to_grant: + return + + member_hof_role = next( + (r for r in member_roles if r in hof_roles.values()), None + ) + if member_hof_role: + await member.remove_roles(member_hof_role, atomic=True) + await member.add_roles(role_to_grant, atomic=True) + + if int(hof_tier) == 1: + # Find existing top 1 user and make them a top 10 + if existing_top_one_user := await self._find_user_with_role( + bot, hof_roles["1"] + ): + if existing_top_one_user.id != member.id: + await _swap_hof_roles(existing_top_one_user, hof_roles["10"]) + else: + return self.success() + + # Grant top 1 role to member + await _swap_hof_roles(member, hof_roles["1"]) + return self.success() + + # Just grant top 10 role to member + elif int(hof_tier) == 10: + await _swap_hof_roles(member, hof_roles["10"]) + return self.success() + + else: + err = ValueError(f"Invalid HOF tier: {hof_tier}") + self.logger.error( + err, + extra={ + "account_id": account_id, + "discord_id": discord_id, + "hof_tier": hof_tier, + }, + ) + raise err + + async def _handle_rank_up(self, body: WebhookBody, bot: Bot) -> dict: + """ + Handles the rank up event. + """ + discord_id = self.validate_discord_id(self.get_property_or_trait(body, "discord_id")) + account_id = self.validate_account_id(self.get_property_or_trait(body, "account_id")) + rank = self.validate_property(self.get_property_or_trait(body, "rank"), "rank") + + member = await self.get_guild_member(discord_id, bot) + + rank_id = settings.get_post_or_rank(rank) + if not rank_id: + err = ValueError(f"Cannot find role for '{rank}'") + self.logger.error( + err, + extra={ + "account_id": account_id, + "discord_id": discord_id, + "rank": rank, + }, + ) + raise err + + rank_role = bot.guilds[0].get_role(rank_id) + rank_roles = [ + bot.guilds[0].get_role(int(r)) for r in settings.role_groups["ALL_RANKS"] + ] # All rank roles + new_role = next( + (r for r in rank_roles if r and r.id == rank_role.id), None + ) # Get passed rank as role from rank roles + old_role = next( + (r for r in member.roles if r in rank_roles), None + ) # Find existing rank role on user + + if old_role: + await member.remove_roles(old_role, atomic=True) # Yeet the old role + + if new_role: + await member.add_roles(new_role, atomic=True) # Add the new role + + if not new_role: + # Why are you passing me BS roles? + err = ValueError(f"Cannot find role for '{rank}'") + self.logger.error( + err, + extra={ + "account_id": account_id, + "discord_id": discord_id, + "rank": rank, + }, + ) + raise err + + return self.success() + + async def _find_user_with_role(self, bot: Bot, role: Role | None) -> Member | None: + """ + Finds the user with the given role. + """ + if not role: + return None + + return next((m for m in role.members), None) diff --git a/src/webhooks/server.py b/src/webhooks/server.py index 92222a7..1c4d9e3 100644 --- a/src/webhooks/server.py +++ b/src/webhooks/server.py @@ -1,10 +1,13 @@ +import hashlib import hmac import logging -from typing import Any, Dict, Union +import json +from typing import Any, Dict -from fastapi import FastAPI, Header, HTTPException +from fastapi import FastAPI, HTTPException, Request from hypercorn.asyncio import serve as hypercorn_serve from hypercorn.config import Config as HypercornConfig +from pydantic import ValidationError from src.bot import bot from src.core import settings @@ -17,20 +20,36 @@ app = FastAPI() +def verify_signature(body: dict, signature: str, secret: str) -> bool: + """ + HMAC SHA1 signature verification. + + Args: + body (dict): The raw body of the webhook request. + signature (str): The X-Signature header of the webhook request. + secret (str): The webhook secret. + + Returns: + bool: True if the signature is valid, False otherwise. + """ + if not signature: + return False + + digest = hmac.new(secret.encode(), body, hashlib.sha1).hexdigest() # type: ignore + return hmac.compare_digest(signature, digest) + + @app.post("/webhook") -async def webhook_handler( - body: WebhookBody, authorization: Union[str, None] = Header(default=None) -) -> Dict[str, Any]: +async def webhook_handler(request: Request) -> Dict[str, Any]: """ Handles incoming webhook requests and forwards them to the appropriate handler. - This function first checks the provided authorization token in the request header. - If the token is valid, it checks if the platform can be handled and then forwards + This function first verifies the provided HMAC signature in the request header. + If the signature is valid, it checks if the platform can be handled and then forwards the request to the corresponding handler. Args: - body (WebhookBody): The data received from the webhook. - authorization (Union[str, None]): The authorization header containing the Bearer token. + request (Request): The incoming webhook request. Returns: Dict[str, Any]: The response from the corresponding handler. The dictionary contains @@ -39,20 +58,42 @@ async def webhook_handler( Raises: HTTPException: If an error occurs while processing the webhook event or if unauthorized. """ - if authorization is None or not authorization.strip().startswith("Bearer"): - logger.warning("Unauthorized webhook request") - raise HTTPException(status_code=401, detail="Unauthorized") + body = await request.body() + signature = request.headers.get("X-Signature") - token = authorization[6:].strip() - if hmac.compare_digest(token, settings.WEBHOOK_TOKEN): + if not verify_signature(body, signature, settings.WEBHOOK_TOKEN): # type: ignore logger.warning("Unauthorized webhook request") raise HTTPException(status_code=401, detail="Unauthorized") + try: + body = WebhookBody.validate(json.loads(body)) + except ValidationError as e: + logger.warning("Invalid webhook request: %s", e.errors()) + raise HTTPException(status_code=400, detail="Invalid webhook request body") + if not handlers.can_handle(body.platform): - logger.warning("Webhook request not handled by platform") + logger.warning("Webhook request not handled by platform: %s", body.platform) raise HTTPException(status_code=501, detail="Platform not implemented") - return await handlers.handle(body, bot) + try: + return await handlers.handle(body, bot) + except HTTPException: + # Re-raise HTTP exceptions as they already have appropriate status codes + raise + except Exception as e: + # Log the full exception details for debugging + logger.error( + "Unhandled exception in webhook handler", + exc_info=e, + extra={ + "platform": body.platform, + "event": body.event, + "properties": body.properties, + "traits": body.traits, + } + ) + # Return a generic 500 error to the client + raise HTTPException(status_code=500, detail="Internal server error") app.mount("/metrics", metrics_app) diff --git a/src/webhooks/types.py b/src/webhooks/types.py index 3885dda..2fabb3b 100644 --- a/src/webhooks/types.py +++ b/src/webhooks/types.py @@ -1,17 +1,19 @@ from enum import Enum -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Extra, Field class WebhookEvent(Enum): - ACCOUNT_LINKED = "AccountLinked" - ACCOUNT_UNLINKED = "AccountUnlinked" + ACCOUNT_LINKED = "DiscordAccountLinked" + ACCOUNT_UNLINKED = "DiscordAccountUnlinked" + ACCOUNT_DELETED = "UserAccountDeleted" + ACCOUNT_BANNED = "UserAccountBanned" CERTIFICATE_AWARDED = "CertificateAwarded" RANK_UP = "RankUp" HOF_CHANGE = "HofChange" SUBSCRIPTION_CHANGE = "SubscriptionChange" - CONTENT_RELEASED = "ContentReleased" NAME_CHANGE = "NameChange" + SEASON_RANK_CHANGE = "SeasonRankChange" class Platform(Enum): @@ -19,9 +21,13 @@ class Platform(Enum): ACADEMY = "academy" CTF = "ctf" ENTERPRISE = "enterprise" + ACCOUNT = "account" class WebhookBody(BaseModel): + model_config = ConfigDict(extra=Extra.allow) + platform: Platform event: WebhookEvent - data: dict + properties: dict = Field(default_factory=dict) + traits: dict = Field(default_factory=dict) diff --git a/tests/src/helpers/test_ban.py b/tests/src/helpers/test_ban.py index 99d3bdc..1771928 100644 --- a/tests/src/helpers/test_ban.py +++ b/tests/src/helpers/test_ban.py @@ -1,8 +1,10 @@ +from datetime import datetime, timezone from unittest import mock from unittest.mock import AsyncMock, MagicMock, patch import pytest from discord import Forbidden, HTTPException +from datetime import datetime, timezone from src.helpers.ban import _check_member, _dm_banned_member, ban_member from src.helpers.responses import SimpleResponse @@ -116,11 +118,15 @@ async def test_ban_member_valid_duration(self, bot, guild, member, author): evidence = "Some evidence" member.display_name = "Banned Member" + # Use a future timestamp instead of a past one + future_timestamp = int((datetime.now(tz=timezone.utc).timestamp() + 86400)) # 1 day from now + expected_date = datetime.fromtimestamp(future_timestamp, tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + with ( mock.patch("src.helpers.ban._check_member", return_value=None), mock.patch("src.helpers.ban._dm_banned_member", return_value=True), mock.patch("src.helpers.ban._get_ban_or_create", return_value=(1, False)), - mock.patch("src.helpers.ban.validate_duration", return_value=(1684276900, "")), + mock.patch("src.helpers.ban.validate_duration", return_value=(future_timestamp, "")), ): mock_channel = helpers.MockTextChannel() mock_channel.send.return_value = MagicMock() @@ -128,7 +134,7 @@ async def test_ban_member_valid_duration(self, bot, guild, member, author): result = await ban_member(bot, guild, member, duration, reason, evidence) assert isinstance(result, SimpleResponse) - assert result.message == f"{member.display_name} ({member.id}) has been banned until 2023-05-16 22:41:40 " \ + assert result.message == f"{member.display_name} ({member.id}) has been banned until {expected_date} " \ f"(UTC)." @pytest.mark.asyncio @@ -153,12 +159,15 @@ async def test_ban_member_permanently_success(self, bot, guild, member, author): evidence = "Some evidence" member.display_name = "Banned Member" + # Use a future timestamp instead of a past one + future_timestamp = int((datetime.now(tz=timezone.utc).timestamp() + 86400)) # 1 day from now + # Patching the necessary classes and functions with ( mock.patch("src.helpers.ban._check_member", return_value=None), mock.patch("src.helpers.ban._dm_banned_member", return_value=True), mock.patch("src.helpers.ban._get_ban_or_create", return_value=(1, False)), - mock.patch("src.helpers.ban.validate_duration", return_value=(1684276900, "")), + mock.patch("src.helpers.ban.validate_duration", return_value=(future_timestamp, "")), ): response = await ban_member(bot, guild, member, duration, reason, evidence, author, False) assert isinstance(response, SimpleResponse) @@ -171,12 +180,15 @@ async def test_ban_member_no_reason_success(self, bot, guild, member, author): evidence = "Some evidence" member.display_name = "Banned Member" + # Use a future timestamp instead of a past one + future_timestamp = int((datetime.now(tz=timezone.utc).timestamp() + 86400)) # 1 day from now + # Patching the necessary classes and functions with ( mock.patch("src.helpers.ban._check_member", return_value=None), mock.patch("src.helpers.ban._dm_banned_member", return_value=True), mock.patch("src.helpers.ban._get_ban_or_create", return_value=(1, False)), - mock.patch("src.helpers.ban.validate_duration", return_value=(1684276900, "")), + mock.patch("src.helpers.ban.validate_duration", return_value=(future_timestamp, "")), ): response = await ban_member(bot, guild, member, duration, reason, evidence, author, False) assert isinstance(response, SimpleResponse) @@ -189,11 +201,14 @@ async def test_ban_member_no_author_success(self, bot, guild, member): evidence = "Some evidence" member.display_name = "Banned Member" + # Use a future timestamp instead of a past one + future_timestamp = int((datetime.now(tz=timezone.utc).timestamp() + 86400)) # 1 day from now + with ( mock.patch("src.helpers.ban._check_member", return_value=None), mock.patch("src.helpers.ban._dm_banned_member", return_value=True), mock.patch("src.helpers.ban._get_ban_or_create", return_value=(1, False)), - mock.patch("src.helpers.ban.validate_duration", return_value=(1684276900, "")), + mock.patch("src.helpers.ban.validate_duration", return_value=(future_timestamp, "")), ): response = await ban_member(bot, guild, member, duration, reason, evidence, None, False) assert isinstance(response, SimpleResponse) @@ -206,11 +221,14 @@ async def test_ban_already_exists(self, bot, guild, member, author): evidence = "Some evidence" member.display_name = "Banned Member" + # Use a future timestamp instead of a past one + future_timestamp = int((datetime.now(tz=timezone.utc).timestamp() + 86400)) # 1 day from now + with ( mock.patch("src.helpers.ban._check_member", return_value=None), mock.patch("src.helpers.ban._dm_banned_member", return_value=True), mock.patch("src.helpers.ban._get_ban_or_create", return_value=(1, True)), - mock.patch("src.helpers.ban.validate_duration", return_value=(1684276900, "")), + mock.patch("src.helpers.ban.validate_duration", return_value=(future_timestamp, "")), ): response = await ban_member(bot, guild, member, duration, reason, evidence, author) assert isinstance(response, SimpleResponse) diff --git a/tests/src/helpers/test_verification.py b/tests/src/helpers/test_verification.py index 3ebddf0..1c949da 100644 --- a/tests/src/helpers/test_verification.py +++ b/tests/src/helpers/test_verification.py @@ -14,37 +14,60 @@ async def test_get_user_details_success(self): account_identifier = "some_identifier" with aioresponses.aioresponses() as m: + # Mock the profile API call m.get( - f"{settings.API_URL}/discord/identifier/{account_identifier}?secret={settings.HTB_API_SECRET}", + f"{settings.API_V4_URL}/user/profile/basic/{account_identifier}", status=200, - payload={"some_key": "some_value"}, + payload={"profile": {"some_key": "some_value"}}, + ) + # Mock the content API call + m.get( + f"{settings.API_V4_URL}/user/profile/content/{account_identifier}", + status=200, + payload={"profile": {"content": {"content_key": "content_value"}}}, ) result = await get_user_details(account_identifier) - self.assertEqual(result, {"some_key": "some_value"}) + expected = { + "some_key": "some_value", + "content": {"content_key": "content_value"} + } + self.assertEqual(result, expected) @pytest.mark.asyncio async def test_get_user_details_404(self): account_identifier = "some_identifier" with aioresponses.aioresponses() as m: + # Mock the profile API call with404 m.get( - f"{settings.API_URL}/discord/identifier/{account_identifier}?secret={settings.HTB_API_SECRET}", + f"{settings.API_V4_URL}/user/profile/basic/{account_identifier}", + status=404, + ) + # Mock the content API call with404 + m.get( + f"{settings.API_V4_URL}/user/profile/content/{account_identifier}", status=404, ) result = await get_user_details(account_identifier) - self.assertIsNone(result) + self.assertEqual(result, {"content": {}}) @pytest.mark.asyncio async def test_get_user_details_other_status(self): account_identifier = "some_identifier" with aioresponses.aioresponses() as m: + # Mock the profile API call with500 + m.get( + f"{settings.API_V4_URL}/user/profile/basic/{account_identifier}", + status=500, + ) + # Mock the content API call with500 m.get( - f"{settings.API_URL}/discord/identifier/{account_identifier}?secret={settings.HTB_API_SECRET}", + f"{settings.API_V4_URL}/user/profile/content/{account_identifier}", status=500, ) result = await get_user_details(account_identifier) - self.assertIsNone(result) + self.assertEqual(result, {"content": {}}) diff --git a/tests/src/webhooks/handlers/test_academy.py b/tests/src/webhooks/handlers/test_academy.py new file mode 100644 index 0000000..2c76af4 --- /dev/null +++ b/tests/src/webhooks/handlers/test_academy.py @@ -0,0 +1,119 @@ +import pytest +from unittest.mock import AsyncMock, patch +from fastapi import HTTPException + +from src.webhooks.handlers.academy import AcademyHandler +from src.webhooks.types import WebhookBody, Platform, WebhookEvent +from tests import helpers + +class TestAcademyHandler: + @pytest.mark.asyncio + async def test_handle_certificate_awarded_success(self, bot): + handler = AcademyHandler() + discord_id = 123456789 + account_id = 987654321 + certificate_id = 42 + mock_member = helpers.MockMember(id=discord_id) + mock_member.add_roles = AsyncMock() + body = WebhookBody( + platform=Platform.ACADEMY, + event=WebhookEvent.CERTIFICATE_AWARDED, + properties={ + "discord_id": discord_id, + "account_id": account_id, + "certificate_id": certificate_id, + }, + traits={}, + ) + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object(handler, "validate_property", return_value=certificate_id), + patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), + patch("src.webhooks.handlers.academy.settings") as mock_settings, + patch.object(handler.logger, "info") as mock_log, + ): + mock_settings.get_academy_cert_role.return_value = 555 + mock_guild = helpers.MockGuild(id=1) + mock_guild.get_role.return_value = 555 + bot.guilds = [mock_guild] + result = await handler._handle_certificate_awarded(body, bot) + mock_member.add_roles.assert_awaited() + mock_log.assert_called() + assert result == handler.success() + + @pytest.mark.asyncio + async def test_handle_certificate_awarded_no_role(self, bot): + handler = AcademyHandler() + discord_id = 123456789 + account_id = 987654321 + certificate_id = 42 + mock_member = helpers.MockMember(id=discord_id) + body = WebhookBody( + platform=Platform.ACADEMY, + event=WebhookEvent.CERTIFICATE_AWARDED, + properties={ + "discord_id": discord_id, + "account_id": account_id, + "certificate_id": certificate_id, + }, + traits={}, + ) + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object(handler, "validate_property", return_value=certificate_id), + patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), + patch("src.webhooks.handlers.academy.settings") as mock_settings, + patch.object(handler.logger, "warning") as mock_log, + ): + mock_settings.get_academy_cert_role.return_value = None + result = await handler._handle_certificate_awarded(body, bot) + mock_log.assert_called() + assert result == handler.fail() + + @pytest.mark.asyncio + async def test_handle_certificate_awarded_add_roles_error(self, bot): + handler = AcademyHandler() + discord_id = 123456789 + account_id = 987654321 + certificate_id = 42 + mock_member = helpers.MockMember(id=discord_id) + mock_member.add_roles = AsyncMock(side_effect=Exception("add_roles error")) + body = WebhookBody( + platform=Platform.ACADEMY, + event=WebhookEvent.CERTIFICATE_AWARDED, + properties={ + "discord_id": discord_id, + "account_id": account_id, + "certificate_id": certificate_id, + }, + traits={}, + ) + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object(handler, "validate_property", return_value=certificate_id), + patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), + patch("src.webhooks.handlers.academy.settings") as mock_settings, + patch.object(handler.logger, "error") as mock_log, + ): + mock_settings.get_academy_cert_role.return_value = 555 + mock_guild = helpers.MockGuild(id=1) + mock_guild.get_role.return_value = 555 + bot.guilds = [mock_guild] + with pytest.raises(Exception, match="add_roles error"): + await handler._handle_certificate_awarded(body, bot) + mock_log.assert_called() + + @pytest.mark.asyncio + async def test_handle_invalid_event(self, bot): + handler = AcademyHandler() + body = WebhookBody( + platform=Platform.ACADEMY, + event=WebhookEvent.RANK_UP, # Not handled by AcademyHandler + properties={}, + traits={}, + ) + with pytest.raises(ValueError, match="Invalid event"): + await handler.handle(body, bot) \ No newline at end of file diff --git a/tests/src/webhooks/handlers/test_account.py b/tests/src/webhooks/handlers/test_account.py new file mode 100644 index 0000000..e18cf7e --- /dev/null +++ b/tests/src/webhooks/handlers/test_account.py @@ -0,0 +1,510 @@ +import logging +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from discord import Bot, Member +from discord.errors import NotFound +from fastapi import HTTPException + +from src.webhooks.handlers.account import AccountHandler +from src.webhooks.types import WebhookBody, Platform, WebhookEvent +from tests import helpers + + +class TestAccountHandler: + """Test the `AccountHandler` class.""" + + def test_initialization(self): + """Test that AccountHandler initializes correctly.""" + handler = AccountHandler() + + assert isinstance(handler.logger, logging.Logger) + assert handler.logger.name == "AccountHandler" + + @pytest.mark.asyncio + async def test_handle_account_linked_event(self, bot): + """Test handle method routes ACCOUNT_LINKED event correctly.""" + handler = AccountHandler() + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"discord_id": 123456789, "account_id": 987654321}, + traits={}, + ) + + with patch.object( + handler, "_handle_account_linked", new_callable=AsyncMock + ) as mock_handle: + await handler.handle(body, bot) + mock_handle.assert_called_once_with(body, bot) + + @pytest.mark.asyncio + async def test_handle_account_unlinked_event(self, bot): + """Test handle method routes ACCOUNT_UNLINKED event correctly.""" + handler = AccountHandler() + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_UNLINKED, + properties={"discord_id": 123456789, "account_id": 987654321}, + traits={}, + ) + + with patch.object( + handler, "_handle_account_unlinked", new_callable=AsyncMock + ) as mock_handle: + await handler.handle(body, bot) + mock_handle.assert_called_once_with(body, bot) + + @pytest.mark.asyncio + async def test_handle_account_deleted_event(self, bot): + """Test handle method with ACCOUNT_DELETED event.""" + handler = AccountHandler() + discord_id = 123456789 + account_id = 987654321 + mock_member = helpers.MockMember(id=discord_id) + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_DELETED, + properties={"discord_id": discord_id, "account_id": account_id}, + traits={}, + ) + + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), + patch("src.webhooks.handlers.account.settings") as mock_settings, + ): + mock_settings.roles.VERIFIED = helpers.MockRole(id=99999, name="Verified") + mock_member.remove_roles = AsyncMock() + + result = await handler.handle(body, bot) + + # Should succeed and return success + assert result == handler.success() + + @pytest.mark.asyncio + async def test_handle_unknown_event(self, bot): + """Test handle method with unknown event raises ValueError.""" + handler = AccountHandler() + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.CERTIFICATE_AWARDED, # Not handled by AccountHandler + properties={"discord_id": 123456789, "account_id": 987654321}, + traits={}, + ) + + # Should raise ValueError for unknown event + with pytest.raises(ValueError, match="Invalid event"): + await handler.handle(body, bot) + + @pytest.mark.asyncio + async def test_handle_account_linked_success(self, bot): + """Test successful account linking.""" + handler = AccountHandler() + discord_id = 123456789 + account_id = 987654321 + mock_member = helpers.MockMember(id=discord_id, mention="@testuser") + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"discord_id": discord_id, "account_id": account_id}, + traits={"htb_user_id": 555}, + ) + + with ( + patch.object( + handler, "validate_discord_id", return_value=discord_id + ) as mock_validate_discord, + patch.object( + handler, "validate_account_id", return_value=account_id + ) as mock_validate_account, + patch.object( + handler, + "get_guild_member", + new_callable=AsyncMock, + return_value=mock_member, + ) as mock_get_member, + patch.object( + handler, + "merge_properties_and_traits", + return_value={ + "discord_id": discord_id, + "account_id": account_id, + "htb_user_id": 555, + }, + ) as mock_merge, + patch( + "src.webhooks.handlers.account.process_account_identification", + new_callable=AsyncMock, + ) as mock_process, + patch("src.webhooks.handlers.account.settings") as mock_settings, + patch.object(handler.logger, "info") as mock_log, + ): + mock_settings.channels.VERIFY_LOGS = 12345 + + # Mock the bot's guild structure and channel + mock_channel = MagicMock() + mock_channel.send = AsyncMock() + mock_guild = MagicMock() + mock_guild.get_channel.return_value = mock_channel + bot.guilds = [mock_guild] + + result = await handler._handle_account_linked(body, bot) + + # Verify all method calls + mock_validate_discord.assert_called_once_with(discord_id) + mock_validate_account.assert_called_once_with(account_id) + mock_get_member.assert_called_once_with(discord_id, bot) + mock_merge.assert_called_once_with(body.properties, body.traits) + mock_process.assert_called_once_with( + mock_member, + bot, + traits={ + "discord_id": discord_id, + "account_id": account_id, + "htb_user_id": 555, + }, + ) + mock_channel.send.assert_called_once_with( + f"Account linked: {account_id} -> (@testuser ({discord_id})" + ) + mock_log.assert_called_once_with( + f"Account {account_id} linked to {discord_id}", + extra={"account_id": account_id, "discord_id": discord_id}, + ) + + # Should return success + assert result == handler.success() + + @pytest.mark.asyncio + async def test_handle_account_linked_invalid_discord_id(self, bot): + """Test account linking with invalid Discord ID.""" + handler = AccountHandler() + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"discord_id": None, "account_id": 987654321}, + traits={}, + ) + + with patch.object( + handler, + "validate_discord_id", + side_effect=HTTPException(status_code=400, detail="Invalid Discord ID"), + ): + with pytest.raises(HTTPException) as exc_info: + await handler._handle_account_linked(body, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid Discord ID" + + @pytest.mark.asyncio + async def test_handle_account_linked_invalid_account_id(self, bot): + """Test account linking with invalid Account ID.""" + handler = AccountHandler() + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"discord_id": 123456789, "account_id": None}, + traits={}, + ) + + with ( + patch.object(handler, "validate_discord_id", return_value=123456789), + patch.object( + handler, + "validate_account_id", + side_effect=HTTPException(status_code=400, detail="Invalid Account ID"), + ), + ): + with pytest.raises(HTTPException) as exc_info: + await handler._handle_account_linked(body, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid Account ID" + + @pytest.mark.asyncio + async def test_handle_account_linked_user_not_in_guild(self, bot): + """Test account linking when user is not in the Discord guild.""" + handler = AccountHandler() + discord_id = 123456789 + account_id = 987654321 + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"discord_id": discord_id, "account_id": account_id}, + traits={}, + ) + + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object( + handler, + "get_guild_member", + new_callable=AsyncMock, + side_effect=HTTPException( + status_code=400, detail="User is not in the Discord server" + ), + ), + ): + with pytest.raises(HTTPException) as exc_info: + await handler._handle_account_linked(body, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "User is not in the Discord server" + + @pytest.mark.asyncio + async def test_handle_account_unlinked_success(self, bot): + """Test successful account unlinking.""" + handler = AccountHandler() + discord_id = 123456789 + account_id = 987654321 + mock_member = helpers.MockMember(id=discord_id) + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_UNLINKED, + properties={"discord_id": discord_id, "account_id": account_id}, + traits={}, + ) + + with ( + patch.object( + handler, "validate_discord_id", return_value=discord_id + ) as mock_validate_discord, + patch.object( + handler, "validate_account_id", return_value=account_id + ) as mock_validate_account, + patch.object( + handler, + "get_guild_member", + new_callable=AsyncMock, + return_value=mock_member, + ) as mock_get_member, + patch("src.webhooks.handlers.account.settings") as mock_settings, + ): + # Mock the bot's guild structure and role + mock_role = helpers.MockRole(id=99999, name="Verified") + mock_guild = MagicMock() + mock_guild.get_role.return_value = mock_role + bot.guilds = [mock_guild] + mock_settings.roles.VERIFIED = 99999 + mock_member.remove_roles = AsyncMock() + + result = await handler._handle_account_unlinked(body, bot) + + # Verify all method calls + mock_validate_discord.assert_called_once_with(discord_id) + mock_validate_account.assert_called_once_with(account_id) + mock_get_member.assert_called_once_with(discord_id, bot) + mock_member.remove_roles.assert_called_once_with( + mock_role, atomic=True + ) + + # Should return success + assert result == handler.success() + + @pytest.mark.asyncio + async def test_handle_account_unlinked_invalid_discord_id(self, bot): + """Test account unlinking with invalid Discord ID.""" + handler = AccountHandler() + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_UNLINKED, + properties={"discord_id": None, "account_id": 987654321}, + traits={}, + ) + + with patch.object( + handler, + "validate_discord_id", + side_effect=HTTPException(status_code=400, detail="Invalid Discord ID"), + ): + with pytest.raises(HTTPException) as exc_info: + await handler._handle_account_unlinked(body, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid Discord ID" + + @pytest.mark.asyncio + async def test_handle_account_unlinked_invalid_account_id(self, bot): + """Test account unlinking with invalid Account ID.""" + handler = AccountHandler() + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_UNLINKED, + properties={"discord_id": 123456789, "account_id": None}, + traits={}, + ) + + with ( + patch.object(handler, "validate_discord_id", return_value=123456789), + patch.object( + handler, + "validate_account_id", + side_effect=HTTPException(status_code=400, detail="Invalid Account ID"), + ), + ): + with pytest.raises(HTTPException) as exc_info: + await handler._handle_account_unlinked(body, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid Account ID" + + @pytest.mark.asyncio + async def test_handle_account_unlinked_user_not_in_guild(self, bot): + """Test account unlinking when user is not in the Discord guild.""" + handler = AccountHandler() + discord_id = 123456789 + account_id = 987654321 + + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_UNLINKED, + properties={"discord_id": discord_id, "account_id": account_id}, + traits={}, + ) + + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object( + handler, + "get_guild_member", + new_callable=AsyncMock, + side_effect=HTTPException( + status_code=400, detail="User is not in the Discord server" + ), + ), + ): + with pytest.raises(HTTPException) as exc_info: + await handler._handle_account_unlinked(body, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "User is not in the Discord server" + + @pytest.mark.asyncio + async def test_handle_name_change_success(self, bot): + """Test successful name change event.""" + handler = AccountHandler() + discord_id = 123456789 + account_id = 987654321 + new_name = "NewNickname" + mock_member = helpers.MockMember(id=discord_id) + mock_member.edit = AsyncMock() + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.NAME_CHANGE, + properties={"discord_id": discord_id, "account_id": account_id, "name": new_name}, + traits={}, + ) + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object(handler, "validate_property", return_value=new_name), + patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), + ): + result = await handler._handle_name_change(body, bot) + mock_member.edit.assert_called_once_with(nick=new_name) + assert result == handler.success() + + @pytest.mark.asyncio + async def test_handle_name_change_invalid_discord_id(self, bot): + """Test name change event with invalid Discord ID.""" + handler = AccountHandler() + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.NAME_CHANGE, + properties={"discord_id": None, "account_id": 987654321, "name": "NewNickname"}, + traits={}, + ) + with patch.object( + handler, + "validate_discord_id", + side_effect=HTTPException(status_code=400, detail="Invalid Discord ID"), + ): + with pytest.raises(HTTPException) as exc_info: + await handler._handle_name_change(body, bot) + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid Discord ID" + + @pytest.mark.asyncio + async def test_handle_account_banned_success(self, bot): + """Test successful account banned event.""" + handler = AccountHandler() + discord_id = 123456789 + account_id = 987654321 + expires_at = "2024-12-31T23:59:59" + reason = "Violation" + notes = "Repeated violations" + created_by = "Admin" + mock_member = helpers.MockMember(id=discord_id) + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_BANNED, + properties={ + "discord_id": discord_id, + "account_id": account_id, + "expires_at": expires_at, + "reason": reason, + "notes": notes, + "created_by": created_by, + }, + traits={}, + ) + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object(handler, "validate_property", return_value=expires_at), + patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), + patch("src.webhooks.handlers.account.handle_platform_ban_or_update", new_callable=AsyncMock) as mock_ban, + patch("src.webhooks.handlers.account.settings") as mock_settings, + patch.object(handler.logger, "debug") as mock_log, + ): + mock_ban.return_value = {"action": "banned"} + mock_settings.channels.BOT_LOGS = 12345 + mock_settings.channels.VERIFY_LOGS = 54321 + mock_settings.roles.VERIFIED = 99999 + mock_settings.guild_ids = [1] + bot.guilds = [helpers.MockGuild(id=1)] + result = await handler._handle_account_banned(body, bot) + mock_ban.assert_awaited() + mock_log.assert_called() + assert result == handler.success() + + @pytest.mark.asyncio + async def test_handle_account_banned_member_not_found(self, bot): + """Test account banned event when member is not found in guild.""" + handler = AccountHandler() + discord_id = 123456789 + account_id = 987654321 + expires_at = "2024-12-31T23:59:59" + body = WebhookBody( + platform=Platform.ACCOUNT, + event=WebhookEvent.ACCOUNT_BANNED, + properties={ + "discord_id": discord_id, + "account_id": account_id, + "expires_at": expires_at, + }, + traits={}, + ) + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object(handler, "validate_property", return_value=expires_at), + patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=None), + patch.object(handler.logger, "warning") as mock_log, + ): + result = await handler._handle_account_banned(body, bot) + mock_log.assert_called() + assert result == handler.fail() diff --git a/tests/src/webhooks/handlers/test_base.py b/tests/src/webhooks/handlers/test_base.py new file mode 100644 index 0000000..ba774cf --- /dev/null +++ b/tests/src/webhooks/handlers/test_base.py @@ -0,0 +1,315 @@ +import logging +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from discord import Bot +from discord.errors import NotFound +from fastapi import HTTPException + +from src.webhooks.handlers.base import BaseHandler +from src.webhooks.types import WebhookBody, Platform, WebhookEvent +from tests import helpers + + +class ConcreteHandler(BaseHandler): + """Concrete implementation of BaseHandler for testing purposes.""" + + async def handle(self, body: WebhookBody, bot: Bot) -> dict: + return {"status": "handled"} + + +class TestBaseHandler: + """Test the `BaseHandler` class.""" + + def test_initialization(self): + """Test that BaseHandler initializes correctly.""" + handler = ConcreteHandler() + + assert isinstance(handler.logger, logging.Logger) + assert handler.logger.name == "ConcreteHandler" + + def test_constants(self): + """Test that all required constants are defined.""" + handler = ConcreteHandler() + + assert handler.ACADEMY_USER_ID == "academy_user_id" + assert handler.MP_USER_ID == "mp_user_id" + assert handler.EP_USER_ID == "ep_user_id" + assert handler.CTF_USER_ID == "ctf_user_id" + assert handler.ACCOUNT_ID == "account_id" + assert handler.DISCORD_ID == "discord_id" + + @pytest.mark.asyncio + async def test_get_guild_member_success(self, bot): + """Test successful guild member retrieval.""" + handler = ConcreteHandler() + discord_id = 123456789 + mock_guild = helpers.MockGuild(id=12345) + mock_member = helpers.MockMember(id=discord_id) + + bot.fetch_guild = AsyncMock(return_value=mock_guild) + mock_guild.fetch_member = AsyncMock(return_value=mock_member) + + with patch("src.webhooks.handlers.base.settings") as mock_settings: + mock_settings.guild_ids = [12345] + + result = await handler.get_guild_member(discord_id, bot) + + assert result == mock_member + bot.fetch_guild.assert_called_once_with(12345) + mock_guild.fetch_member.assert_called_once_with(discord_id) + + @pytest.mark.asyncio + async def test_get_guild_member_not_found(self, bot): + """Test guild member retrieval when user is not in server.""" + handler = ConcreteHandler() + discord_id = 123456789 + mock_guild = helpers.MockGuild(id=12345) + + bot.fetch_guild = AsyncMock(return_value=mock_guild) + mock_guild.fetch_member = AsyncMock( + side_effect=NotFound(MagicMock(), "User not found") + ) + + with patch("src.webhooks.handlers.base.settings") as mock_settings: + mock_settings.guild_ids = [12345] + + with pytest.raises(HTTPException) as exc_info: + await handler.get_guild_member(discord_id, bot) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "User is not in the Discord server" + + def test_validate_property_success(self): + """Test successful property validation.""" + handler = ConcreteHandler() + + result = handler.validate_property("valid_value", "test_property") + assert result == "valid_value" + + result = handler.validate_property(123, "test_number") + assert result == 123 + + def test_validate_property_none(self): + """Test property validation with None value.""" + handler = ConcreteHandler() + + with pytest.raises(HTTPException) as exc_info: + handler.validate_property(None, "test_property") + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid test_property" + + def test_validate_discord_id_success(self): + """Test successful Discord ID validation.""" + handler = ConcreteHandler() + + result = handler.validate_discord_id(123456789) + assert result == 123456789 + + result = handler.validate_discord_id("987654321") + assert result == "987654321" + + def test_validate_discord_id_none(self): + """Test Discord ID validation with None value.""" + handler = ConcreteHandler() + + with pytest.raises(HTTPException) as exc_info: + handler.validate_discord_id(None) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid Discord ID" + + def test_validate_account_id_success(self): + """Test successful Account ID validation.""" + handler = ConcreteHandler() + + result = handler.validate_account_id(123456789) + assert result == 123456789 + + result = handler.validate_account_id("987654321") + assert result == "987654321" + + def test_validate_account_id_none(self): + """Test Account ID validation with None value.""" + handler = ConcreteHandler() + + with pytest.raises(HTTPException) as exc_info: + handler.validate_account_id(None) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid Account ID" + + def test_get_property_or_trait_from_properties(self): + """Test getting value from properties.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"test_key": 123}, + traits={"test_key": 456, "other_key": 789}, + ) + + result = handler.get_property_or_trait(body, "test_key") + assert result == 123 # Should prioritize properties over traits + + def test_get_property_or_trait_from_traits(self): + """Test getting value from traits when not in properties.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={}, + traits={"test_key": 456}, + ) + + result = handler.get_property_or_trait(body, "test_key") + assert result == 456 + + def test_get_property_or_trait_not_found(self): + """Test getting value when key is not found.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={}, + traits={}, + ) + + result = handler.get_property_or_trait(body, "missing_key") + assert result is None + + def test_merge_properties_and_traits_no_duplicates(self): + """Test merging properties and traits without duplicates.""" + handler = ConcreteHandler() + properties = {"key1": 1, "key2": 2} + traits = {"key3": 3, "key4": 4} + + result = handler.merge_properties_and_traits(properties, traits) + + expected = {"key1": 1, "key2": 2, "key3": 3, "key4": 4} + assert result == expected + + def test_merge_properties_and_traits_with_duplicates(self): + """Test merging properties and traits with duplicate keys.""" + handler = ConcreteHandler() + properties = {"key1": 1, "key2": 2} + traits = {"key2": 99, "key3": 3} # key2 is duplicate + + result = handler.merge_properties_and_traits(properties, traits) + + expected = {"key1": 1, "key2": 2, "key3": 3} # Properties value should win + assert result == expected + + def test_merge_properties_and_traits_empty_properties(self): + """Test merging when properties is empty.""" + handler = ConcreteHandler() + properties = {} + traits = {"key1": 1, "key2": 2} + + result = handler.merge_properties_and_traits(properties, traits) + + assert result == traits + + def test_merge_properties_and_traits_empty_traits(self): + """Test merging when traits is empty.""" + handler = ConcreteHandler() + properties = {"key1": 1, "key2": 2} + traits = {} + + result = handler.merge_properties_and_traits(properties, traits) + + assert result == properties + + def test_get_platform_properties_all_present(self): + """Test getting platform properties when all are present.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={ + "account_id": 1, + "mp_user_id": 2, + "ep_user_id": 3, + "ctf_user_id": 4, + "academy_user_id": 5, + }, + traits={}, + ) + + result = handler.get_platform_properties(body) + + expected = { + "account_id": 1, + "mp_user_id": 2, + "ep_user_id": 3, + "ctf_user_id": 4, + "academy_user_id": 5, + } + assert result == expected + + def test_get_platform_properties_mixed_sources(self): + """Test getting platform properties from both properties and traits.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"account_id": 1, "mp_user_id": 2}, + traits={"ep_user_id": 3, "ctf_user_id": 4, "academy_user_id": 5}, + ) + + result = handler.get_platform_properties(body) + + expected = { + "account_id": 1, + "mp_user_id": 2, + "ep_user_id": 3, + "ctf_user_id": 4, + "academy_user_id": 5, + } + assert result == expected + + def test_get_platform_properties_missing_values(self): + """Test getting platform properties when some are missing.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"account_id": 1}, + traits={"mp_user_id": 2}, + ) + + result = handler.get_platform_properties(body) + + expected = { + "account_id": 1, + "mp_user_id": 2, + "ep_user_id": None, + "ctf_user_id": None, + "academy_user_id": None, + } + assert result == expected + + def test_get_platform_properties_properties_override_traits(self): + """Test that properties override traits for the same key.""" + handler = ConcreteHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.ACCOUNT_LINKED, + properties={"account_id": 1, "mp_user_id": 2}, + traits={ + "mp_user_id": 999, # Should be overridden + "ep_user_id": 3, + }, + ) + + result = handler.get_platform_properties(body) + + expected = { + "account_id": 1, + "mp_user_id": 2, # Properties value should win + "ep_user_id": 3, + "ctf_user_id": None, + "academy_user_id": None, + } + assert result == expected diff --git a/tests/src/webhooks/handlers/test_mp.py b/tests/src/webhooks/handlers/test_mp.py new file mode 100644 index 0000000..3454d1b --- /dev/null +++ b/tests/src/webhooks/handlers/test_mp.py @@ -0,0 +1,223 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from fastapi import HTTPException + +from src.webhooks.handlers.mp import MPHandler +from src.webhooks.types import WebhookBody, Platform, WebhookEvent +from tests import helpers + +class TestMPHandler: + @pytest.mark.asyncio + async def test_handle_invalid_event(self, bot): + handler = MPHandler() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.CERTIFICATE_AWARDED, # Not handled by MPHandler + properties={}, + traits={}, + ) + with pytest.raises(ValueError, match="Invalid event"): + await handler.handle(body, bot) + + @pytest.mark.asyncio + async def test_handle_subscription_change_success(self, bot): + handler = MPHandler() + discord_id = 123456789 + account_id = 987654321 + subscription_name = "VIP" + mock_member = helpers.MockMember(id=discord_id) + mock_member.add_roles = AsyncMock() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.SUBSCRIPTION_CHANGE, + properties={ + "discord_id": discord_id, + "account_id": account_id, + "subscription_name": subscription_name, + }, + traits={}, + ) + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object(handler, "validate_property", return_value=subscription_name), + patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), + patch("src.webhooks.handlers.mp.settings") as mock_settings, + ): + mock_settings.get_post_or_rank.return_value = 555 + mock_guild = helpers.MockGuild(id=1) + mock_guild.get_role.return_value = 555 + bot.guilds = [mock_guild] + result = await handler._handle_subscription_change(body, bot) + mock_member.add_roles.assert_awaited() + assert result == handler.success() + + @pytest.mark.asyncio + async def test_handle_subscription_change_invalid_role(self, bot): + handler = MPHandler() + discord_id = 123456789 + account_id = 987654321 + subscription_name = "INVALID" + mock_member = helpers.MockMember(id=discord_id) + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.SUBSCRIPTION_CHANGE, + properties={ + "discord_id": discord_id, + "account_id": account_id, + "subscription_name": subscription_name, + }, + traits={}, + ) + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object(handler, "validate_property", return_value=subscription_name), + patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), + patch("src.webhooks.handlers.mp.settings") as mock_settings, + ): + mock_settings.get_post_or_rank.return_value = None + with pytest.raises(ValueError, match="Invalid subscription name"): + await handler._handle_subscription_change(body, bot) + + @pytest.mark.asyncio + async def test_handle_hof_change_success_top1(self, bot): + handler = MPHandler() + discord_id = 123456789 + account_id = 987654321 + hof_tier = "1" + mock_member = helpers.MockMember(id=discord_id) + mock_member.roles = [] + mock_member.add_roles = AsyncMock() + mock_member.remove_roles = AsyncMock() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.HOF_CHANGE, + properties={ + "discord_id": discord_id, + "account_id": account_id, + "hof_tier": hof_tier, + }, + traits={}, + ) + mock_role_1 = MagicMock() + mock_role_10 = MagicMock() + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object(handler, "validate_property", return_value=hof_tier), + patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), + patch("src.webhooks.handlers.mp.settings") as mock_settings, + patch.object(handler, "_find_user_with_role", new_callable=AsyncMock, return_value=None), + ): + mock_settings.roles.RANK_ONE = 1 + mock_settings.roles.RANK_TEN = 10 + mock_guild = helpers.MockGuild(id=1) + mock_guild.get_role.side_effect = lambda rid: mock_role_1 if rid == 1 else mock_role_10 + bot.guilds = [mock_guild] + result = await handler._handle_hof_change(body, bot) + mock_member.add_roles.assert_awaited_with(mock_role_1, atomic=True) + assert result == handler.success() + + @pytest.mark.asyncio + async def test_handle_hof_change_invalid_tier(self, bot): + handler = MPHandler() + discord_id = 123456789 + account_id = 987654321 + hof_tier = "99" + mock_member = helpers.MockMember(id=discord_id) + mock_member.roles = [] + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.HOF_CHANGE, + properties={ + "discord_id": discord_id, + "account_id": account_id, + "hof_tier": hof_tier, + }, + traits={}, + ) + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object(handler, "validate_property", return_value=hof_tier), + patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), + patch("src.webhooks.handlers.mp.settings") as mock_settings, + ): + mock_settings.roles.RANK_ONE = 1 + mock_settings.roles.RANK_TEN = 10 + mock_guild = helpers.MockGuild(id=1) + mock_guild.get_role.side_effect = lambda rid: None + bot.guilds = [mock_guild] + with pytest.raises(ValueError, match="Invalid HOF tier"): + await handler._handle_hof_change(body, bot) + + @pytest.mark.asyncio + async def test_handle_rank_up_success(self, bot): + handler = MPHandler() + discord_id = 123456789 + account_id = 987654321 + rank = "Elite Hacker" + mock_member = helpers.MockMember(id=discord_id) + mock_member.roles = [] + mock_member.add_roles = AsyncMock() + mock_member.remove_roles = AsyncMock() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.RANK_UP, + properties={ + "discord_id": discord_id, + "account_id": account_id, + "rank": rank, + }, + traits={}, + ) + mock_role = MagicMock() + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object(handler, "validate_property", return_value=rank), + patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), + patch("src.webhooks.handlers.mp.settings") as mock_settings, + ): + mock_settings.role_groups = {"ALL_RANKS": [555]} + mock_guild = helpers.MockGuild(id=1) + mock_guild.get_role.return_value = mock_role + bot.guilds = [mock_guild] + result = await handler._handle_rank_up(body, bot) + mock_member.add_roles.assert_awaited_with(mock_role, atomic=True) + assert result == handler.success() + + @pytest.mark.asyncio + async def test_handle_rank_up_invalid_role(self, bot): + handler = MPHandler() + discord_id = 123456789 + account_id = 987654321 + rank = "Nonexistent" + mock_member = helpers.MockMember(id=discord_id) + mock_member.roles = [] + mock_member.add_roles = AsyncMock() + mock_member.remove_roles = AsyncMock() + body = WebhookBody( + platform=Platform.MAIN, + event=WebhookEvent.RANK_UP, + properties={ + "discord_id": discord_id, + "account_id": account_id, + "rank": rank, + }, + traits={}, + ) + with ( + patch.object(handler, "validate_discord_id", return_value=discord_id), + patch.object(handler, "validate_account_id", return_value=account_id), + patch.object(handler, "validate_property", return_value=rank), + patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member), + patch("src.webhooks.handlers.mp.settings") as mock_settings, + ): + mock_settings.role_groups = {"ALL_RANKS": [555]} + mock_guild = helpers.MockGuild(id=1) + mock_guild.get_role.return_value = None + bot.guilds = [mock_guild] + with pytest.raises(ValueError, match="Cannot find role for"): + await handler._handle_rank_up(body, bot) \ No newline at end of file diff --git a/tests/src/webhooks/test_handlers_init.py b/tests/src/webhooks/test_handlers_init.py new file mode 100644 index 0000000..ce0b436 --- /dev/null +++ b/tests/src/webhooks/test_handlers_init.py @@ -0,0 +1,35 @@ +from unittest import mock +from typing import Callable + +import pytest + +from src.webhooks.handlers import handlers, can_handle, handle +from src.webhooks.types import Platform, WebhookBody, WebhookEvent +from tests.conftest import bot + +class TestHandlersInit: + def test_handler_init(self): + assert handlers is not None + assert isinstance(handlers, dict) + assert len(handlers) > 0 + assert all(isinstance(handler, Callable) for handler in handlers.values()) + + def test_can_handle_unknown_platform(self): + assert not can_handle("UNKNOWN") + + def test_can_handle_success(self): + with mock.patch("src.webhooks.handlers.handlers", {Platform.MAIN: lambda x, y: True}): + assert can_handle(Platform.MAIN) + + def test_handle_success(self): + with mock.patch("src.webhooks.handlers.handlers", {Platform.MAIN: lambda x, y: 1337}): + assert handle(WebhookBody(platform=Platform.MAIN, event=WebhookEvent.ACCOUNT_LINKED, properties={}, traits={}), bot) == 1337 + + def test_handle_unknown_platform(self): + with pytest.raises(ValueError): + handle(WebhookBody(platform="UNKNOWN", event=WebhookEvent.ACCOUNT_LINKED, properties={}, traits={}), bot) + + def test_handle_unknown_event(self): + with mock.patch("src.webhooks.handlers.handlers", {Platform.MAIN: lambda x, y: 1337}): + with pytest.raises(ValueError): + handle(WebhookBody(platform=Platform.MAIN, event="UNKNOWN", properties={}, traits={}), bot) \ No newline at end of file