diff --git a/.test.env b/.test.env
index def9d16..dfc9088 100644
--- a/.test.env
+++ b/.test.env
@@ -32,6 +32,8 @@ CHANNEL_SPOILER=2769521890099371011
CHANNEL_BOT_LOGS=1105517088266788925
# Roles
+ROLE_VERIFIED=1333333333333333337
+
ROLE_BIZCTF2022=7629466241011276950
ROLE_NOAH_GANG=6706800691011276950
ROLE_BUDDY_GANG=6706800681011276950
diff --git a/src/bot.py b/src/bot.py
index f535903..f598844 100644
--- a/src/bot.py
+++ b/src/bot.py
@@ -12,6 +12,7 @@
MissingRequiredArgument, NoPrivateMessage, UserInputError,
)
from sqlalchemy.exc import NoResultFound
+from typing import TypeVar
from src import trace_config
from src.core import constants, settings
@@ -20,6 +21,9 @@
logger = logging.getLogger(__name__)
+BOT_TYPE = TypeVar("BOT_TYPE", "Bot", DiscordBot)
+
+
class Bot(DiscordBot):
"""Base bot class."""
diff --git a/src/cmds/automation/auto_verify.py b/src/cmds/automation/auto_verify.py
index 606ad3e..98c538d 100644
--- a/src/cmds/automation/auto_verify.py
+++ b/src/cmds/automation/auto_verify.py
@@ -2,12 +2,8 @@
from discord import Member, Message, User
from discord.ext import commands
-from sqlalchemy import select
from src.bot import Bot
-from src.database.models import HtbDiscordLink
-from src.database.session import AsyncSessionLocal
-from src.helpers.verification import get_user_details, process_identification
logger = logging.getLogger(__name__)
@@ -19,31 +15,11 @@ def __init__(self, bot: Bot):
self.bot = bot
async def process_reverification(self, member: Member | User) -> None:
- """Re-verifation process for a member."""
- async with AsyncSessionLocal() as session:
- stmt = (
- select(HtbDiscordLink)
- .where(HtbDiscordLink.discord_user_id == member.id)
- .order_by(HtbDiscordLink.id)
- .limit(1)
- )
- result = await session.scalars(stmt)
- htb_discord_link: HtbDiscordLink = result.first()
-
- if not htb_discord_link:
- raise VerificationError(f"HTB Discord link for user {member.name} with ID {member}")
-
- member_token: str = htb_discord_link.account_identifier
-
- if member_token is None:
- raise VerificationError(f"HTB account identifier for user {member.name} with ID {member.id} not found")
-
- logger.debug(f"Processing re-verify of member {member.name} ({member.id}).")
- htb_details = await get_user_details(member_token)
- if htb_details is None:
- raise VerificationError(f"Retrieving user details for user {member.name} with ID {member.id} failed")
-
- await process_identification(htb_details, user=member, bot=self.bot)
+ """Re-verifation process for a member.
+
+ TODO: Reimplement once it's possible to fetch link state from the HTB Account.
+ """
+ raise VerificationError("Not implemented")
@commands.Cog.listener()
@commands.cooldown(1, 60, commands.BucketType.user)
@@ -74,4 +50,5 @@ class VerificationError(Exception):
def setup(bot: Bot) -> None:
"""Load the `MessageHandler` cog."""
- bot.add_cog(MessageHandler(bot))
+ # bot.add_cog(MessageHandler(bot))
+ pass
diff --git a/src/cmds/core/identify.py b/src/cmds/core/identify.py
index ef52ed7..34a378a 100644
--- a/src/cmds/core/identify.py
+++ b/src/cmds/core/identify.py
@@ -1,17 +1,12 @@
import logging
-from typing import Sequence
-import discord
from discord import ApplicationContext, Interaction, WebhookMessage, slash_command
from discord.ext import commands
from discord.ext.commands import cooldown
-from sqlalchemy import select
from src.bot import Bot
from src.core import settings
-from src.database.models import HtbDiscordLink
-from src.database.session import AsyncSessionLocal
-from src.helpers.verification import get_user_details, process_identification
+from src.helpers.verification import send_verification_instructions
logger = logging.getLogger(__name__)
@@ -25,121 +20,15 @@ def __init__(self, bot: Bot):
@slash_command(
guild_ids=settings.guild_ids,
description="Identify yourself on the HTB Discord server by linking your HTB account ID to your Discord user "
- "ID.", guild_only=False
+ "ID.",
+ guild_only=False,
)
@cooldown(1, 60, commands.BucketType.user)
- async def identify(self, ctx: ApplicationContext, account_identifier: str) -> Interaction | WebhookMessage:
- """Identify yourself on the HTB Discord server by linking your HTB account ID to your Discord user ID."""
- if len(account_identifier) != 60:
- return await ctx.respond(
- "This Account Identifier does not appear to be the right length (must be 60 characters long).",
- ephemeral=True
- )
-
- await ctx.respond("Identification initiated, please wait...", ephemeral=True)
- htb_user_details = await get_user_details(account_identifier)
- if htb_user_details is None:
- embed = discord.Embed(title="Error: Invalid account identifier.", color=0xFF0000)
- return await ctx.respond(embed=embed, ephemeral=True)
-
- json_htb_user_id = htb_user_details["user_id"]
-
- author = ctx.user
- member = await self.bot.get_or_fetch_user(author.id)
- if not member:
- return await ctx.respond(f"Error getting guild member with id: {author.id}.")
-
- # Step 1: Check if the Account Identifier has already been recorded and if they are the previous owner.
- # Scenario:
- # - I create a new Discord account.
- # - I reuse my previous Account Identifier.
- # - I now have an "alt account" with the same roles.
- async with AsyncSessionLocal() as session:
- stmt = (
- select(HtbDiscordLink)
- .filter(HtbDiscordLink.account_identifier == account_identifier)
- .order_by(HtbDiscordLink.id.desc())
- .limit(1)
- )
- result = await session.scalars(stmt)
- most_recent_rec: HtbDiscordLink = result.first()
-
- if most_recent_rec and most_recent_rec.discord_user_id_as_int != member.id:
- error_desc = (
- f"Verified user {member.mention} tried to identify as another identified user.\n"
- f"Current Discord UID: {member.id}\n"
- f"Other Discord UID: {most_recent_rec.discord_user_id}\n"
- f"Related HTB UID: {most_recent_rec.htb_user_id}"
- )
- embed = discord.Embed(title="Identification error", description=error_desc, color=0xFF2429)
- await self.bot.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed)
-
- return await ctx.respond(
- "Identification error: please contact an online Moderator or Administrator for help.", ephemeral=True
- )
-
- # Step 2: Given the htb_user_id from JSON, check if each discord_user_id are different from member.id.
- # Scenario:
- # - I have a Discord account that is linked already to a "Hacker" role.
- # - I create a new HTB account.
- # - I identify with the new account.
- # - `SELECT * FROM htb_discord_link WHERE htb_user_id = %s` will be empty,
- # because the new account has not been verified before. All is good.
- # - I am now "Noob" rank.
- async with AsyncSessionLocal() as session:
- stmt = select(HtbDiscordLink).filter(HtbDiscordLink.htb_user_id == json_htb_user_id)
- result = await session.scalars(stmt)
- user_links: Sequence[HtbDiscordLink] = result.all()
-
- discord_user_ids = {u_link.discord_user_id_as_int for u_link in user_links}
- if discord_user_ids and member.id not in discord_user_ids:
- orig_discord_ids = ", ".join([f"<@{id_}>" for id_ in discord_user_ids])
- error_desc = (f"The HTB account {json_htb_user_id} attempted to be identified by user <@{member.id}>, "
- f"but is tied to another Discord account.\n"
- f"Originally linked to Discord UID {orig_discord_ids}.")
- embed = discord.Embed(title="Identification error", description=error_desc, color=0xFF2429)
- await self.bot.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed)
-
- return await ctx.respond(
- "Identification error: please contact an online Moderator or Administrator for help.", ephemeral=True
- )
-
- # Step 3: Check if discord_user_id already linked to an htb_user_id, and if JSON/db HTB IDs are the same.
- # Scenario:
- # - I have a new, unlinked Discord account.
- # - Clubby generates a new token and gives it to me.
- # - `SELECT * FROM htb_discord_link WHERE discord_user_id = %s`
- # will be empty because I have not identified before.
- # - I am now Clubby.
- async with AsyncSessionLocal() as session:
- stmt = select(HtbDiscordLink).filter(HtbDiscordLink.discord_user_id == member.id)
- result = await session.scalars(stmt)
- user_links: Sequence[HtbDiscordLink] = result.all()
-
- user_htb_ids = {u_link.htb_user_id_as_int for u_link in user_links}
- if user_htb_ids and json_htb_user_id not in user_htb_ids:
- error_desc = (f"User {member.mention} ({member.id}) tried to identify with a new HTB account.\n"
- f"Original HTB UIDs: {', '.join([str(i) for i in user_htb_ids])}, new HTB UID: "
- f"{json_htb_user_id}.")
- embed = discord.Embed(title="Identification error", description=error_desc, color=0xFF2429)
- await self.bot.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed)
-
- return await ctx.respond(
- "Identification error: please contact an online Moderator or Administrator for help.", ephemeral=True
- )
-
- htb_discord_link = HtbDiscordLink(
- account_identifier=account_identifier, discord_user_id=member.id, htb_user_id=json_htb_user_id
- )
- async with AsyncSessionLocal() as session:
- session.add(htb_discord_link)
- await session.commit()
-
- await process_identification(htb_user_details, user=member, bot=self.bot)
-
- return await ctx.respond(
- f"Your Discord user has been successfully identified as HTB user {json_htb_user_id}.", ephemeral=True
- )
+ async def identify(
+ self, ctx: ApplicationContext, account_identifier: str
+ ) -> Interaction | WebhookMessage:
+ """Legacy command. Now sends instructions to identify with HTB account."""
+ await send_verification_instructions(ctx, ctx.author)
def setup(bot: Bot) -> None:
diff --git a/src/cmds/core/verify.py b/src/cmds/core/verify.py
index 23088c1..919f872 100644
--- a/src/cmds/core/verify.py
+++ b/src/cmds/core/verify.py
@@ -1,14 +1,12 @@
import logging
-import discord
from discord import ApplicationContext, Interaction, WebhookMessage, slash_command
-from discord.errors import Forbidden, HTTPException
from discord.ext import commands
from discord.ext.commands import cooldown
from src.bot import Bot
from src.core import settings
-from src.helpers.verification import process_certification
+from src.helpers.verification import process_certification, send_verification_instructions
logger = logging.getLogger(__name__)
@@ -47,57 +45,7 @@ async def verifycertification(self, ctx: ApplicationContext, certid: str, fullna
@cooldown(1, 60, commands.BucketType.user)
async def verify(self, ctx: ApplicationContext) -> Interaction | WebhookMessage:
"""Receive instructions in a DM on how to identify yourself with your HTB account."""
- member = ctx.user
-
- # Step one
- embed_step1 = discord.Embed(color=0x9ACC14)
- embed_step1.add_field(
- name="Step 1: Log in at Hack The Box",
- value="Go to the Hack The Box website at "
- " and navigate to **Login > HTB Labs**. Log in to your HTB Account."
- , inline=False, )
- embed_step1.set_image(
- url="https://media.discordapp.net/attachments/724587782755844098/839871275627315250/unknown.png"
- )
-
- # Step two
- embed_step2 = discord.Embed(color=0x9ACC14)
- embed_step2.add_field(
- name="Step 2: Locate the Account Identifier",
- value='Click on your profile name, then select **My Profile**. '
- 'In the Profile Settings tab, find the field labeled **Account Identifier**. () '
- "Click the green button to copy your secret identifier.", inline=False, )
- embed_step2.set_image(
- url="https://media.discordapp.net/attachments/724587782755844098/839871332963188766/unknown.png"
- )
-
- # Step three
- embed_step3 = discord.Embed(color=0x9ACC14)
- embed_step3.add_field(
- name="Step 3: Identification",
- value="Now type `/identify IDENTIFIER_HERE` in the bot-commands channel.\n\nYour roles will be "
- "applied automatically.", inline=False
- )
- embed_step3.set_image(
- url="https://media.discordapp.net/attachments/709907130102317093/904744444539076618/unknown.png"
- )
-
- try:
- await member.send(embed=embed_step1)
- await member.send(embed=embed_step2)
- await member.send(embed=embed_step3)
- except Forbidden as ex:
- logger.error("Exception during verify call", exc_info=ex)
- return await ctx.respond(
- "Whoops! I cannot DM you after all due to your privacy settings. Please allow DMs from other server "
- "members and try again in 1 minute."
- )
- except HTTPException as ex:
- logger.error("Exception during verify call.", exc_info=ex)
- return await ctx.respond(
- "An unexpected error happened (HTTP 400, bad request). Please contact an Administrator."
- )
- return await ctx.respond("Please check your DM for instructions.", ephemeral=True)
+ await send_verification_instructions(ctx, ctx.author)
def setup(bot: Bot) -> None:
diff --git a/src/core/config.py b/src/core/config.py
index 6b3025e..d07de2c 100644
--- a/src/core/config.py
+++ b/src/core/config.py
@@ -91,6 +91,7 @@ class AcademyCertificates(BaseSettings):
class Roles(BaseSettings):
"""The roles settings."""
+ VERIFIED: int
# Moderation
COMMUNITY_MANAGER: int
@@ -330,6 +331,8 @@ def load_settings(env_file: str | None = None):
global_settings.roles.NOOB,
global_settings.roles.VIP,
global_settings.roles.VIP_PLUS,
+ ],
+ "ALL_SEASON_RANKS": [
global_settings.roles.SEASON_HOLO,
global_settings.roles.SEASON_PLATINUM,
global_settings.roles.SEASON_RUBY,
diff --git a/src/helpers/ban.py b/src/helpers/ban.py
index 8d2ddd0..92fe22e 100644
--- a/src/helpers/ban.py
+++ b/src/helpers/ban.py
@@ -1,11 +1,21 @@
"""Helper methods to handle bans, mutes and infractions. Bot or message responses are NOT allowed."""
+
import asyncio
import logging
from datetime import datetime, timezone
+from enum import Enum
-import arrow
import discord
-from discord import Forbidden, Guild, HTTPException, Member, NotFound, User
+from discord import (
+ Forbidden,
+ Guild,
+ HTTPException,
+ Member,
+ NotFound,
+ User,
+ TextChannel,
+ ClientUser,
+)
from sqlalchemy import select
from sqlalchemy.exc import NoResultFound
@@ -22,39 +32,233 @@
logger = logging.getLogger(__name__)
-async def _check_member(bot: Bot, guild: Guild, member: Member | User, author: Member = None) -> SimpleResponse | None:
+class BanCodes(Enum):
+ SUCCESS = "SUCCESS"
+ ALREADY_EXISTS = "ALREADY_EXISTS"
+ FAILED = "FAILED"
+
+
+async def _check_member(
+ bot: Bot, guild: Guild, member: Member | User, author: Member | ClientUser | None = None
+) -> SimpleResponse | None:
if isinstance(member, Member):
if member_is_staff(member):
- return SimpleResponse(message="You cannot ban another staff member.", delete_after=None)
+ return SimpleResponse(
+ message="You cannot ban another staff member.", delete_after=None
+ )
elif isinstance(member, User):
- member = await bot.get_member_or_user(guild, member.id)
+ member = await bot.get_member_or_user(guild, member.id) # type: ignore
if member.bot:
- return SimpleResponse(message="You cannot ban a bot.", delete_after=None)
+ return SimpleResponse(
+ message="You cannot ban a bot.", delete_after=None, code=BanCodes.FAILED
+ )
if author and author.id == member.id:
- return SimpleResponse(message="You cannot ban yourself.", delete_after=None)
+ return SimpleResponse(
+ message="You cannot ban yourself.", delete_after=None, code=BanCodes.FAILED
+ )
-async def _get_ban_or_create(member: Member, ban: Ban, infraction: Infraction) -> tuple[int, bool]:
+async def get_ban(member: Member | User) -> Ban | None:
async with AsyncSessionLocal() as session:
- stmt = select(Ban).filter(Ban.user_id == member.id, Ban.unbanned.is_(False)).limit(1)
+ stmt = (
+ select(Ban)
+ .filter(Ban.user_id == member.id, Ban.unbanned.is_(False))
+ .limit(1)
+ )
result = await session.scalars(stmt)
- existing_ban = result.first()
- if existing_ban:
- return existing_ban.id, True
+ return result.first()
+
+
+async def update_ban(ban: Ban) -> None:
+ logger.info(f"Updating ban {ban.id} for user {ban.user_id} with expiration {ban.unban_time}")
+ async with AsyncSessionLocal() as session:
+ session.add(ban)
+ await session.commit()
+
+
+async def _get_ban_or_create(
+ member: Member | User, ban: Ban, infraction: Infraction
+) -> tuple[int, bool]:
+ existing_ban = await get_ban(member)
+ if existing_ban:
+ return existing_ban.id, True
+ async with AsyncSessionLocal() as session:
session.add(ban)
session.add(infraction)
await session.commit()
+
ban_id: int = ban.id
assert ban_id is not None
return ban_id, False
-async def ban_member(
- bot: Bot, guild: Guild, member: Member | User, duration: str, reason: str, evidence: str, author: Member = None,
- needs_approval: bool = True
-) -> SimpleResponse | None:
- """Ban a member from the guild."""
+async def _create_ban_response(
+ member: Member | User, end_date: str, dm_banned_member: bool, needs_approval: bool
+) -> SimpleResponse:
+ """Create a SimpleResponse for ban operations."""
+ if needs_approval:
+ if member:
+ message = f"{member.display_name} ({member.id}) has been banned until {end_date} (UTC)."
+ else:
+ message = f"{member.id} has been banned until {end_date} (UTC)."
+ else:
+ if member:
+ message = f"Member {member.display_name} has been banned permanently."
+ else:
+ message = f"Member {member.id} has been banned permanently."
+
+ if not dm_banned_member:
+ message += "\n Could not DM banned member due to permission error."
+
+ return SimpleResponse(
+ message=message,
+ delete_after=0 if not needs_approval else None,
+ code=BanCodes.SUCCESS,
+ )
+
+
+async def _send_ban_notice(
+ guild: Guild,
+ member: Member,
+ reason: str,
+ author: str,
+ end_date: str,
+ channel: TextChannel | None,
+) -> None:
+ """Send a ban log to the moderator channel."""
+ if not isinstance(channel, TextChannel):
+ channel = guild.get_channel(settings.channels.SR_MOD) # type: ignore
+
+ embed = discord.Embed(
+ title="Ban",
+ description=f"User {member.mention} ({member.id}) was banned on the platform and thus banned here.",
+ color=0xFF2429,
+ )
+ embed.add_field(name="Reason", value=reason)
+ embed.add_field(name="Author", value=author)
+ embed.add_field(name="End Date", value=end_date)
+
+ await channel.send(embed=embed) # type: ignore
+
+
+async def handle_platform_ban_or_update(
+ bot: Bot,
+ guild: Guild,
+ member: Member,
+ expires_timestamp: int,
+ reason: str,
+ evidence: str,
+ author_name: str,
+ expires_at_str: str,
+ log_channel_id: int,
+ logger,
+ extra_log_data: dict | None = None,
+) -> dict:
+ """Handle platform ban by either creating new ban, updating existing ban, or taking no action.
+
+ Args:
+ bot: The Discord bot instance
+ guild: The guild to ban the member from
+ member: The member to ban
+ expires_timestamp: Unix timestamp when the ban should end
+ reason: Reason for the ban
+ evidence: Evidence supporting the ban (notes)
+ author_name: Name of the person who created the ban
+ expires_at_str: Human-readable expiration date string
+ log_channel_id: Channel ID for logging ban actions
+ logger: Logger instance for recording events
+ extra_log_data: Additional data to include in log entries
+
+ Returns:
+ dict with 'action' key indicating what was done: 'unbanned', 'extended', 'no_action', 'updated', 'created'
+ """
+ if extra_log_data is None:
+ extra_log_data = {}
+
+ expires_dt = datetime.fromtimestamp(expires_timestamp)
+
+ existing_ban = await get_ban(member)
+ if not existing_ban:
+ # No existing ban, create new one
+ await ban_member_with_epoch(
+ bot, guild, member, expires_timestamp, reason, evidence, needs_approval=False
+ )
+ await _send_ban_notice(
+ guild, member, reason, author_name, expires_at_str, guild.get_channel(log_channel_id) # type: ignore
+ )
+ logger.info(f"Created new platform ban for user {member.id} until {expires_at_str}", extra=extra_log_data)
+ return {"action": "created"}
+
+ # Existing ban found - determine what to do based on ban type and timing
+ is_platform_ban = existing_ban.reason.startswith("Platform Ban")
+
+ if is_platform_ban:
+ # Platform bans have authority over other platform bans
+ if expires_dt < datetime.now():
+ # Platform ban has expired, unban the user
+ await unban_member(guild, member)
+ msg = f"User {member.mention} ({member.id}) has been unbanned due to platform ban expiration."
+ await guild.get_channel(log_channel_id).send(msg) # type: ignore
+ logger.info(msg, extra=extra_log_data)
+ return {"action": "unbanned"}
+
+ if existing_ban.unban_time < expires_timestamp:
+ # Extend the existing platform ban
+ existing_ban.unban_time = expires_timestamp
+ await update_ban(existing_ban)
+ msg = f"User {member.mention} ({member.id}) has had their ban extended to {expires_at_str}."
+ await guild.get_channel(log_channel_id).send(msg) # type: ignore
+ logger.info(msg, extra=extra_log_data)
+ return {"action": "extended"}
+ else:
+ # Non-platform ban exists
+ if existing_ban.unban_time >= expires_timestamp:
+ # Existing ban is longer than platform ban, no action needed
+ logger.info(
+ f"User {member.mention} ({member.id}) is already banned until {existing_ban.unban_time}, "
+ f"which exceeds or equals the platform ban expiration date {expires_at_str}. No action taken.",
+ extra=extra_log_data,
+ )
+ return {"action": "no_action"}
+ else:
+ # Platform ban is longer, update the existing ban
+ existing_ban.unban_time = expires_timestamp
+ existing_ban.reason = f"Platform Ban: {reason}" # Update reason to indicate platform authority
+ await update_ban(existing_ban)
+ logger.info(f"Updated existing ban for user {member.id} until {expires_at_str}.", extra=extra_log_data)
+ return {"action": "updated"}
+
+ # Default case (shouldn't reach here, but for safety)
+ logger.warning(f"Unexpected case in platform ban handling for user {member.id}", extra=extra_log_data)
+ return {"action": "no_action"}
+
+
+async def ban_member_with_epoch(
+ bot: Bot,
+ guild: Guild,
+ member: Member | User,
+ unban_epoch_time: int,
+ reason: str,
+ evidence: str,
+ author: Member | ClientUser | None = None,
+ needs_approval: bool = True,
+) -> SimpleResponse:
+ """Ban a member from the guild until a specific epoch time.
+
+ Args:
+ bot: The Discord bot instance
+ guild: The guild to ban the member from
+ member: The member or user to ban
+ unban_epoch_time: Unix timestamp when the ban should end
+ reason: Reason for the ban
+ evidence: Evidence supporting the ban
+ author: The member issuing the ban (defaults to bot user)
+ needs_approval: Whether the ban requires approval
+
+ Returns:
+ SimpleResponse with the result of the ban operation, or None if no response needed
+ """
if checked := await _check_member(bot, guild, member, author):
return checked
@@ -65,35 +269,50 @@ async def ban_member(
if not evidence:
evidence = "none provided"
- # Validate duration
- dur, dur_exc = validate_duration(duration)
- # Check if duration is valid,
- # negative values are generally not allowed,
- # so they should be caught here
- if dur <= 0:
- return SimpleResponse(message=dur_exc, delete_after=15)
- else:
- end_date: str = datetime.fromtimestamp(dur, tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
+ # Validate epoch time is in the future
+ current_time = datetime.now(tz=timezone.utc).timestamp()
+ if unban_epoch_time <= current_time:
+ return SimpleResponse(
+ message="Unban time must be in the future",
+ delete_after=15
+ )
+
+ end_date: str = datetime.fromtimestamp(unban_epoch_time, tz=timezone.utc).strftime(
+ "%Y-%m-%d %H:%M:%S"
+ )
if author is None:
author = bot.user
+
+ # Author should never be None at this point
+ if author is None:
+ raise ValueError("Author cannot be None")
ban = Ban(
- user_id=member.id, reason=reason, moderator_id=author.id, unban_time=dur,
- approved=False if needs_approval else True
+ user_id=member.id,
+ reason=reason,
+ moderator_id=author.id,
+ unban_time=unban_epoch_time,
+ approved=False if needs_approval else True,
)
infraction = Infraction(
user_id=member.id,
reason=f"Previously banned for: {reason} - Evidence: {evidence}",
weight=0,
moderator_id=author.id,
- date=datetime.now().date()
+ date=datetime.now().date(),
)
+
ban_id, is_existing = await _get_ban_or_create(member, ban, infraction)
if is_existing:
+ try:
+ await guild.ban(member, reason=reason, delete_message_seconds=0)
+ except NotFound:
+ pass
return SimpleResponse(
message=f"A ban with id: {ban_id} already exists for member {member}",
- delete_after=None
+ delete_after=None,
+ code=BanCodes.ALREADY_EXISTS,
)
# DM member, before we ban, else we cannot dm since we do not share a guild
@@ -103,67 +322,120 @@ async def ban_member(
await guild.ban(member, reason=reason, delete_message_seconds=0)
except Forbidden as exc:
logger.warning(
- "Ban failed due to permission error", exc_info=exc,
- extra={"ban_requestor": author.name, "ban_receiver": member.id}
+ "Ban failed due to permission error",
+ exc_info=exc,
+ extra={"ban_requestor": author.name, "ban_receiver": member.id},
)
if author:
- return SimpleResponse(message="You do not have the proper permissions to ban.", delete_after=None)
+ return SimpleResponse(
+ message="You do not have the proper permissions to ban.",
+ delete_after=None,
+ code=BanCodes.FAILED,
+ )
return
except HTTPException as ex:
- logger.warning(f"HTTPException when trying to ban user with ID {member.id}", exc_info=ex)
+ logger.warning(
+ f"HTTPException when trying to ban user with ID {member.id}", exc_info=ex
+ )
if author:
return SimpleResponse(
message="Here's a 400 Bad Request for you. Just like when you tried to ask me out, last week.",
- delete_after=None
+ delete_after=None,
+ code=BanCodes.FAILED,
)
return
# If approval is required, send a message to the moderator channel about the ban
if not needs_approval:
- if member:
- message = f"Member {member.display_name} has been banned permanently."
- else:
- message = f"Member {member.id} has been banned permanently."
-
- if not dm_banned_member:
- message += "\n Could not DM banned member due to permission error."
-
logger.info(
"Member has been banned permanently.",
- extra={"ban_requestor": author.name, "ban_receiver": member.id, "dm_banned_member": dm_banned_member}
+ extra={
+ "ban_requestor": author.name,
+ "ban_receiver": member.id,
+ "dm_banned_member": dm_banned_member,
+ },
)
- unban_task = schedule(unban_member(guild, member), run_at=datetime.fromtimestamp(ban.unban_time))
+ unban_task = schedule(
+ unban_member(guild, member), run_at=datetime.fromtimestamp(ban.unban_time)
+ )
asyncio.create_task(unban_task)
- logger.debug("Unbanned sceduled for ban", extra={"ban_id": ban_id, "unban_time": ban.unban_time})
- return SimpleResponse(message=message, delete_after=0)
+ logger.debug(
+ "Unbanned sceduled for ban",
+ extra={"ban_id": ban_id, "unban_time": ban.unban_time},
+ )
else:
- if member:
- message = f"{member.display_name} ({member.id}) has been banned until {end_date} (UTC)."
- else:
- message = f"{member.id} has been banned until {end_date} (UTC)."
-
- if not dm_banned_member:
- message += " Could not DM banned member due to permission error."
-
member_name = f"{member.display_name} ({member.name})"
embed = discord.Embed(
title=f"Ban request #{ban_id}",
description=f"{author.display_name} ({author.name}) "
- f"would like to ban {member_name} until {end_date} (UTC).\n"
- f"Reason: {reason}\n"
- f"Evidence: {evidence}", )
+ f"would like to ban {member_name} until {end_date} (UTC).\n"
+ f"Reason: {reason}\n"
+ f"Evidence: {evidence}",
+ )
embed.set_thumbnail(url=f"{settings.HTB_URL}/images/logo600.png")
view = BanDecisionView(ban_id, bot, guild, member, end_date, reason)
- await guild.get_channel(settings.channels.SR_MOD).send(embed=embed, view=view)
- return SimpleResponse(message=message)
+ await guild.get_channel(settings.channels.SR_MOD).send(embed=embed, view=view) # type: ignore
+ return await _create_ban_response(
+ member, end_date, dm_banned_member, needs_approval
+ )
-async def _dm_banned_member(end_date: str, guild: Guild, member: Member, reason: str) -> bool:
+
+async def ban_member(
+ bot: Bot,
+ guild: Guild,
+ member: Member | User,
+ duration: str,
+ reason: str,
+ evidence: str,
+ author: Member | None = None,
+ needs_approval: bool = True,
+) -> SimpleResponse:
+ """Ban a member from the guild using a duration.
+
+ Args:
+ bot: The Discord bot instance
+ guild: The guild to ban the member from
+ member: The member or user to ban
+ duration: Duration string (e.g., "1d", "1h") or seconds as int
+ reason: Reason for the ban
+ evidence: Evidence supporting the ban
+ author: The member issuing the ban (defaults to bot user)
+ needs_approval: Whether the ban requires approval
+
+ Returns:
+ SimpleResponse with the result of the ban operation, or None if no response needed
+ """
+ dur, dur_exc = validate_duration(duration)
+
+ # Check if duration is valid,
+ # negative values are generally not allowed,
+ # so they should be caught here
+ if dur <= 0:
+ return SimpleResponse(message=dur_exc, delete_after=15)
+
+ return await ban_member_with_epoch(
+ bot=bot,
+ guild=guild,
+ member=member,
+ unban_epoch_time=dur,
+ reason=reason,
+ evidence=evidence,
+ author=author,
+ needs_approval=needs_approval,
+ )
+
+
+async def _dm_banned_member(
+ end_date: str, guild: Guild, member: Member | User, reason: str
+) -> bool:
"""Send a message to the member about the ban."""
- message = (f"You have been banned from {guild.name} until {end_date} (UTC). "
- f"To appeal the ban, please reach out to an Administrator.\n"
- f"Following is the reason given:\n>>> {reason}\n")
+ message = (
+ f"You have been banned from {guild.name} until {end_date} (UTC). "
+ f"To appeal the ban, please reach out to an Administrator.\n"
+ f"Following is the reason given:\n>>> {reason}\n"
+ )
try:
await member.send(message)
return True
@@ -171,29 +443,43 @@ async def _dm_banned_member(end_date: str, guild: Guild, member: Member, reason:
logger.warning(
f"Could not DM member with id {member.id} due to privacy settings, however will still attempt to ban "
f"them...",
- exc_info=ex
+ exc_info=ex,
)
except HTTPException as ex:
- logger.warning(f"HTTPException when trying to unban user with ID {member.id}", exc_info=ex)
+ logger.warning(
+ f"HTTPException when trying to unban user with ID {member.id}", exc_info=ex
+ )
return False
-async def unban_member(guild: Guild, member: Member) -> Member:
+async def unban_member(guild: Guild, member: Member | User) -> Member | User:
"""Unban a member from the guild."""
try:
await guild.unban(member)
logger.info(f"Unbanned user {member.id}.")
except Forbidden as ex:
- logger.error(f"Permission denied when trying to unban user with ID {member.id}", exc_info=ex)
+ logger.error(
+ f"Permission denied when trying to unban user with ID {member.id}",
+ exc_info=ex,
+ )
except NotFound as ex:
logger.error(
f"NotFound when trying to unban user with ID {member.id}. "
- f"This could indicate that the user is not currently banned.", exc_info=ex, )
+ f"This could indicate that the user is not currently banned.",
+ exc_info=ex,
+ )
except HTTPException as ex:
- logger.error(f"HTTPException when trying to unban user with ID {member.id}", exc_info=ex)
+ logger.error(
+ f"HTTPException when trying to unban user with ID {member.id}", exc_info=ex
+ )
async with AsyncSessionLocal() as session:
- stmt = select(Ban).filter(Ban.user_id == member.id).filter(Ban.unbanned.is_(False)).limit(1)
+ stmt = (
+ select(Ban)
+ .filter(Ban.user_id == member.id)
+ .filter(Ban.unbanned.is_(False))
+ .limit(1)
+ )
result = await session.scalars(stmt)
ban = result.first()
if ban:
@@ -207,7 +493,12 @@ async def unban_member(guild: Guild, member: Member) -> Member:
async def mute_member(
- bot: Bot, guild: Guild, member: Member, duration: str, reason: str, author: Member = None
+ bot: Bot,
+ guild: Guild,
+ member: Member,
+ duration: str,
+ reason: str,
+ author: Member | ClientUser | None = None,
) -> SimpleResponse | None:
"""Mute a member on the guild."""
if checked := await _check_member(bot, guild, member, author):
@@ -225,13 +516,17 @@ async def mute_member(
if author is None:
author = bot.user
+
+ # Author should never be None at this point
+ if author is None:
+ raise ValueError("Author cannot be None")
role = guild.get_role(settings.roles.MUTED)
if member:
# No longer on the server - cleanup, but don't attempt to remove a role
logger.info(f"Add mute from {member.name}:{member.id}.")
- await member.add_roles(role, reason=reason)
+ await member.add_roles(role, reason=reason) # type: ignore
mute = Mute(
user_id=member.id, reason=reason, moderator_id=author.id, unmute_time=dur
@@ -248,7 +543,7 @@ async def unmute_member(guild: Guild, member: Member) -> Member:
if isinstance(member, Member):
# No longer on the server - cleanup, but don't attempt to remove a role
logger.info(f"Remove mute from {member.name}:{member.id}.")
- await member.remove_roles(role)
+ await member.remove_roles(role) # type: ignore
await member.remove_timeout()
async with AsyncSessionLocal() as session:
@@ -272,7 +567,9 @@ async def add_infraction(
if len(reason) == 0:
reason = "No reason given ..."
- infraction = Infraction(user_id=member.id, reason=reason, weight=weight, moderator_id=author.id)
+ infraction = Infraction(
+ user_id=member.id, reason=reason, weight=weight, moderator_id=author.id
+ )
async with AsyncSessionLocal() as session:
session.add(infraction)
await session.commit()
@@ -287,9 +584,15 @@ async def add_infraction(
)
except Forbidden as ex:
message = "Could not DM member due to privacy settings, however the infraction was still added."
- logger.warning(f"Forbidden, when trying to contact user with ID {member.id} about infraction.", exc_info=ex)
+ logger.warning(
+ f"Forbidden, when trying to contact user with ID {member.id} about infraction.",
+ exc_info=ex,
+ )
except HTTPException as ex:
message = "Here's a 400 Bad Request for you. Just like when you tried to ask me out, last week."
- logger.warning(f"HTTPException when trying to add infraction for user with ID {member.id}", exc_info=ex)
+ logger.warning(
+ f"HTTPException when trying to add infraction for user with ID {member.id}",
+ exc_info=ex,
+ )
return SimpleResponse(message=message, delete_after=None)
diff --git a/src/helpers/responses.py b/src/helpers/responses.py
index 6513bfd..e80da1a 100644
--- a/src/helpers/responses.py
+++ b/src/helpers/responses.py
@@ -1,15 +1,16 @@
import json
-
+from typing import Any
class SimpleResponse(object):
"""A simple response object."""
- def __init__(self, message: str, delete_after: int | None = None):
+ def __init__(self, message: str, delete_after: int | None = None, code: str | Any = None):
self.message = message
self.delete_after = delete_after
+ self.code = code
def __str__(self):
- return json.dumps(dict(self), ensure_ascii=False)
-
+ return json.dumps(dict(self), ensure_ascii=False) # type: ignore
+
def __repr__(self):
return self.__str__()
diff --git a/src/helpers/verification.py b/src/helpers/verification.py
index ddb4be9..7e8358d 100644
--- a/src/helpers/verification.py
+++ b/src/helpers/verification.py
@@ -1,52 +1,151 @@
import logging
-from datetime import datetime
-from typing import Dict, List, Optional, cast
+import traceback
+from typing import Dict, List, Optional, Any, TypeVar
import aiohttp
import discord
-from discord import Forbidden, Member, Role, User
+from discord import (
+ ApplicationContext,
+ Forbidden,
+ HTTPException,
+ Member,
+ Role,
+ User,
+ Guild,
+)
from discord.ext.commands import GuildNotFound, MemberNotFound
-from src.bot import Bot
+from src.bot import Bot, BOT_TYPE
+
from src.core import settings
-from src.helpers.ban import ban_member
+from src.helpers.ban import BanCodes, ban_member, _send_ban_notice
logger = logging.getLogger(__name__)
-async def get_user_details(account_identifier: str) -> Optional[Dict]:
+async def send_verification_instructions(
+ ctx: ApplicationContext, member: Member
+) -> discord.Interaction | discord.WebhookMessage:
+ """Send instructions via DM on how to identify with HTB account.
+
+ Args:
+ ctx (ApplicationContext): The context of the command.
+ member (Member): The member to send the instructions to.
+
+ Returns:
+ discord.Interaction | discord.WebhookMessage: The response message.
+ """
+ member = ctx.user
+
+ # Create step-by-step instruction embeds
+ embed_step1 = discord.Embed(color=0x9ACC14)
+ embed_step1.add_field(
+ name="Step 1: Login to your HTB Account",
+ value="Go to and login.",
+ inline=False,
+ )
+ embed_step1.set_image(
+ url="https://media.discordapp.net/attachments/1102700815493378220/1384587341338902579/image.png"
+ )
+
+ embed_step2 = discord.Embed(color=0x9ACC14)
+ embed_step2.add_field(
+ name="Step 2: Navigate to your Security Settings",
+ value="In the navigation bar, click on **Security Settings** and scroll down to the **Discord Account** section. "
+ "()",
+ inline=False,
+ )
+ embed_step2.set_image(
+ url="https://media.discordapp.net/attachments/1102700815493378220/1384587813760270392/image.png"
+ )
+
+ embed_step3 = discord.Embed(color=0x9ACC14)
+ embed_step3.add_field(
+ name="Step 3: Link your Discord Account",
+ value="Click **Connect** and you will be redirected to login to your Discord account via oauth. "
+ "After logging in, you will be redirected back to the HTB Account page. "
+ "Your Discord account will now be linked. Discord may take a few minutes to update. "
+ "If you have any issues, please contact a Moderator.",
+ inline=False,
+ )
+ embed_step3.set_image(
+ url="https://media.discordapp.net/attachments/1102700815493378220/1384586811384402042/image.png"
+ )
+
+ try:
+ await member.send(embed=embed_step1)
+ await member.send(embed=embed_step2)
+ await member.send(embed=embed_step3)
+ except Forbidden as ex:
+ logger.error("Exception during verify call", exc_info=ex)
+ return await ctx.respond(
+ "Whoops! I cannot DM you after all due to your privacy settings. Please allow DMs from other server "
+ "members and try again in 1 minute."
+ )
+ except HTTPException as ex:
+ logger.error("Exception during verify call.", exc_info=ex)
+ return await ctx.respond(
+ "An unexpected error happened (HTTP 400, bad request). Please contact an Administrator."
+ )
+
+ return await ctx.respond("Please check your DM for instructions.", ephemeral=True)
+
+
+
+def get_labs_session() -> aiohttp.ClientSession:
+ """Get a session for the HTB Labs API."""
+ return aiohttp.ClientSession(headers={"Authorization": f"Bearer {settings.HTB_API_KEY}"})
+
+
+async def get_user_details(labs_id: int | str) -> dict:
"""Get user details from HTB."""
- acc_id_url = f"{settings.API_URL}/discord/identifier/{account_identifier}?secret={settings.HTB_API_SECRET}"
- async with aiohttp.ClientSession() as session:
- async with session.get(acc_id_url) as r:
+ if not labs_id:
+ return {}
+
+ user_profile_api_url = f"{settings.API_V4_URL}/user/profile/basic/{labs_id}"
+ user_content_api_url = f"{settings.API_V4_URL}/user/profile/content/{labs_id}"
+
+ async with get_labs_session() as session:
+ async with session.get(user_profile_api_url) as r:
if r.status == 200:
- response = await r.json()
- elif r.status == 404:
- logger.debug("Account identifier has been regenerated since last identification. Cannot re-verify.")
- response = None
+ profile_response = await r.json()
else:
- logger.error(f"Non-OK HTTP status code returned from identifier lookup: {r.status}.")
- response = None
+ logger.error(
+ f"Non-OK HTTP status code returned from user details lookup: {r.status}."
+ )
+ profile_response = {}
- return response
+ async with session.get(user_content_api_url) as r:
+ if r.status == 200:
+ content_response = await r.json()
+ else:
+ logger.error(
+ f"Non-OK HTTP status code returned from user content lookup: {r.status}."
+ )
+ content_response = {}
+
+ profile = profile_response.get("profile", {})
+ profile["content"] = content_response.get("profile", {}).get("content", {})
+ return profile
async def get_season_rank(htb_uid: int) -> str | None:
"""Get season rank from HTB."""
- headers = {"Authorization": f"Bearer {settings.HTB_API_KEY}"}
season_api_url = f"{settings.API_V4_URL}/season/end/{settings.SEASON_ID}/{htb_uid}"
- async with aiohttp.ClientSession() as session:
- async with session.get(season_api_url, headers=headers) as r:
+ async with get_labs_session() as session:
+ async with session.get(season_api_url) as r:
if r.status == 200:
response = await r.json()
elif r.status == 404:
logger.error("Invalid Season ID.")
- response = None
+ response = {}
else:
- logger.error(f"Non-OK HTTP status code returned from identifier lookup: {r.status}.")
- response = None
+ logger.error(
+ f"Non-OK HTTP status code returned from identifier lookup: {r.status}."
+ )
+ response = {}
if not response["data"]:
rank = None
@@ -59,26 +158,10 @@ async def get_season_rank(htb_uid: int) -> str | None:
return rank
-async def _check_for_ban(uid: str) -> Optional[Dict]:
- async with aiohttp.ClientSession() as session:
- token_url = f"{settings.API_URL}/discord/{uid}/banned?secret={settings.HTB_API_SECRET}"
- async with session.get(token_url) as r:
- if r.status == 200:
- ban_details = await r.json()
- else:
- logger.error(
- f"Could not fetch ban details for uid {uid}: "
- f"non-OK status code returned ({r.status}). Body: {r.content}"
- )
- ban_details = None
-
- return ban_details
-
-
async def process_certification(certid: str, name: str):
"""Process certifications."""
cert_api_url = f"{settings.API_V4_URL}/certificate/lookup"
- params = {'id': certid, 'name': name}
+ params = {"id": certid, "name": name}
async with aiohttp.ClientSession() as session:
async with session.get(cert_api_url, params=params) as r:
if r.status == 200:
@@ -86,8 +169,10 @@ async def process_certification(certid: str, name: str):
elif r.status == 404:
return False
else:
- logger.error(f"Non-OK HTTP status code returned from identifier lookup: {r.status}.")
- response = None
+ logger.error(
+ f"Non-OK HTTP status code returned from identifier lookup: {r.status}."
+ )
+ response = {}
try:
certRawName = response["certificates"][0]["name"]
except IndexError:
@@ -107,118 +192,300 @@ async def process_certification(certid: str, name: str):
return cert
-async def process_identification(
- htb_user_details: Dict[str, str], user: Optional[Member | User], bot: Bot
+async def _handle_banned_user(member: Member, bot: BOT_TYPE):
+ """Handle banned trait during account linking.
+
+ Args:
+ member (Member): The member to process.
+ bot (Bot): The bot instance.
+ """
+ resp = await ban_member(
+ bot, # type: ignore
+ member.guild,
+ member,
+ "1337w",
+ (
+ "Platform Ban - Ban duration could not be determined. "
+ "Please login to confirm ban details and contact HTB Support to appeal."
+ ),
+ "N/A",
+ None,
+ needs_approval=False,
+ )
+ if resp.code == BanCodes.SUCCESS:
+ await _send_ban_notice(
+ member.guild,
+ member,
+ resp.message,
+ "System",
+ "1337w",
+ member.guild.get_channel(settings.channels.VERIFY_LOGS), # type: ignore
+ )
+
+
+async def _set_nickname(member: Member, nickname: str) -> bool:
+ """Set the nickname of the member.
+
+ Args:
+ member (Member): The member to set the nickname for.
+ nickname (str): The nickname to set.
+
+ Returns:
+ bool: True if the nickname was set, False otherwise.
+ """
+ try:
+ await member.edit(nick=nickname)
+ return True
+ except Forbidden as e:
+ logger.error(f"Exception whe trying to edit the nick-name of the user: {e}")
+ return False
+
+
+async def process_account_identification(
+ member: Member, bot: BOT_TYPE, traits: dict[str, Any]
+) -> None:
+ """Process HTB account identification, to be called during account linking.
+
+ Args:
+ member (Member): The member to process.
+ bot (Bot): The bot instance.
+ traits (dict[str, str] | None): Optional user traits to process.
+ """
+ try:
+ await member.add_roles(member.guild.get_role(settings.roles.VERIFIED), atomic=True) # type: ignore
+ except Exception as e:
+ logger.error(f"Failed to add VERIFIED role to user {member.id}: {e}")
+ # Don't raise - continue with other operations
+
+ nickname_changed = False
+
+ traits = traits or {}
+
+ if traits.get("username") and traits.get("username") != member.name:
+ nickname_changed = await _set_nickname(member, traits.get("username")) # type: ignore
+
+ if not nickname_changed:
+ logger.warning(
+ f"No username provided for {member.name} with ID {member.id} during identification."
+ )
+
+ if traits.get("mp_user_id"):
+ try:
+ logger.debug(f"MP user ID: {traits.get('mp_user_id', None)}")
+ htb_user_details = await get_user_details(traits.get("mp_user_id", None))
+ if htb_user_details:
+ await process_labs_identification(htb_user_details, member, bot) # type: ignore
+
+ if not nickname_changed and htb_user_details.get("username"):
+ logger.debug(
+ f"Falling back on HTB username to set nickname for {member.name} with ID {member.id}."
+ )
+ await _set_nickname(member, htb_user_details["username"])
+ except Exception as e:
+ logger.error(f"Failed to process labs identification for user {member.id}: {e}")
+ # Don't raise - this is not critical
+
+ if traits.get("banned", False) == True: # noqa: E712 - explicit bool only, no truthiness
+ try:
+ logger.debug(f"Handling banned user {member.id}")
+ await _handle_banned_user(member, bot)
+ return
+ except Exception as e:
+ logger.error(f"Failed to handle banned user {member.id}: {e}")
+ logger.exception(traceback.format_exc())
+ # Don't raise - continue processing
+
+
+async def process_labs_identification(
+ htb_user_details: dict, user: Optional[Member | User], bot: Bot
) -> Optional[List[Role]]:
"""Returns roles to assign if identification was successfully processed."""
- htb_uid = htb_user_details["user_id"]
+
+ # Resolve member and guild
+ member, guild = await _resolve_member_and_guild(user, bot)
+
+ # Get roles to remove and assign
+ to_remove = _get_roles_to_remove(member, guild)
+ to_assign = await _process_role_assignments(htb_user_details, guild)
+
+ # Remove roles that will be reassigned
+ to_remove = list(set(to_remove) - set(to_assign))
+
+ # Apply role changes
+ await _apply_role_changes(member, to_remove, to_assign)
+
+ return to_assign
+
+
+async def _resolve_member_and_guild(
+ user: Optional[Member | User], bot: Bot
+) -> tuple[Member, Guild]:
+ """Resolve member and guild from user object."""
if isinstance(user, Member):
- member = user
- guild = member.guild
- # This will only work if the user and the bot share only one guild.
- elif isinstance(user, User) and len(user.mutual_guilds) == 1:
+ return user, user.guild
+
+ if isinstance(user, User) and len(user.mutual_guilds) == 1:
guild = user.mutual_guilds[0]
member = await bot.get_member_or_user(guild, user.id)
if not member:
raise MemberNotFound(str(user.id))
- else:
- raise GuildNotFound(f"Could not identify member {user} in guild.")
- season_rank = await get_season_rank(htb_uid)
- banned_details = await _check_for_ban(htb_uid)
-
- if banned_details is not None and banned_details["banned"]:
- # If user is banned, this field must be a string
- # Strip date e.g. from "2022-01-31T11:00:00.000000Z"
- banned_until: str = cast(str, banned_details["ends_at"])[:10]
- banned_until_dt: datetime = datetime.strptime(banned_until, "%Y-%m-%d")
- ban_duration: str = f"{(banned_until_dt - datetime.now()).days}d"
- reason = "Banned on the HTB Platform. Please contact HTB Support to appeal."
- logger.info(f"Discord user {member.name} ({member.id}) is platform banned. Banning from Discord...")
- await ban_member(bot, guild, member, ban_duration, reason, None, needs_approval=False)
-
- embed = discord.Embed(
- title="Identification error",
- description=f"User {member.mention} ({member.id}) was platform banned HTB and thus also here.",
- color=0xFF2429, )
-
- await guild.get_channel(settings.channels.VERIFY_LOGS).send(embed=embed)
- return None
+ return member, guild # type: ignore
+
+ raise GuildNotFound(f"Could not identify member {user} in guild.")
+
+def _get_roles_to_remove(member: Member, guild: Guild) -> list[Role]:
+ """Get existing roles that should be removed."""
to_remove = []
+ try:
+ all_ranks = settings.role_groups.get("ALL_RANKS", [])
+ all_positions = settings.role_groups.get("ALL_POSITIONS", [])
+ removable_role_ids = all_ranks + all_positions
+
+ for role in member.roles:
+ if role.id in removable_role_ids:
+ guild_role = guild.get_role(role.id)
+ if guild_role:
+ to_remove.append(guild_role)
+ except Exception as e:
+ logger.error(f"Error processing existing roles for user {member.id}: {e}")
+ return to_remove
- for role in member.roles:
- if role.id in settings.role_groups.get("ALL_RANKS") + settings.role_groups.get("ALL_POSITIONS"):
- to_remove.append(guild.get_role(role.id))
+async def _process_role_assignments(
+ htb_user_details: dict, guild: Guild
+) -> list[Role]:
+ """Process role assignments based on HTB user details."""
to_assign = []
- logger.debug(
- "Getting role 'rank':", extra={
- "role_id": settings.get_post_or_rank(htb_user_details["rank"]),
- "role_obj": guild.get_role(settings.get_post_or_rank(htb_user_details["rank"])),
- "htb_rank": htb_user_details["rank"],
- }, )
- if htb_user_details["rank"] not in ["Deleted", "Moderator", "Ambassador", "Admin", "Staff"]:
- to_assign.append(guild.get_role(settings.get_post_or_rank(htb_user_details["rank"])))
- if season_rank:
- to_assign.append(guild.get_role(settings.get_season(season_rank)))
- if htb_user_details["vip"]:
- logger.debug(
- 'Getting role "VIP":', extra={"role_id": settings.roles.VIP, "role_obj": guild.get_role(settings.roles.VIP)}
- )
- to_assign.append(guild.get_role(settings.roles.VIP))
- if htb_user_details["dedivip"]:
- logger.debug(
- 'Getting role "VIP+":',
- extra={"role_id": settings.roles.VIP_PLUS, "role_obj": guild.get_role(settings.roles.VIP_PLUS)}
- )
- to_assign.append(guild.get_role(settings.roles.VIP_PLUS))
- if htb_user_details["hof_position"] != "unranked":
- position = int(htb_user_details["hof_position"])
- pos_top = None
- if position == 1:
- pos_top = "1"
- elif position <= 10:
- pos_top = "10"
- if pos_top:
- logger.debug(f"User is Hall of Fame rank {position}. Assigning role Top-{pos_top}...")
- logger.debug(
- 'Getting role "HoF role":', extra={
- "role_id": settings.get_post_or_rank(pos_top),
- "role_obj": guild.get_role(settings.get_post_or_rank(pos_top)), "hof_val": pos_top,
- }, )
- to_assign.append(guild.get_role(settings.get_post_or_rank(pos_top)))
- else:
- logger.debug(f"User is position {position}. No Hall of Fame roles for them.")
- if htb_user_details["machines"]:
- logger.debug(
- 'Getting role "BOX_CREATOR":',
- extra={"role_id": settings.roles.BOX_CREATOR, "role_obj": guild.get_role(settings.roles.BOX_CREATOR)}, )
- to_assign.append(guild.get_role(settings.roles.BOX_CREATOR))
- if htb_user_details["challenges"]:
- logger.debug(
- 'Getting role "CHALLENGE_CREATOR":', extra={
- "role_id": settings.roles.CHALLENGE_CREATOR,
- "role_obj": guild.get_role(settings.roles.CHALLENGE_CREATOR),
- }, )
- to_assign.append(guild.get_role(settings.roles.CHALLENGE_CREATOR))
-
- if member.nick != htb_user_details["user_name"]:
- try:
- await member.edit(nick=htb_user_details["user_name"])
- except Forbidden as e:
- logger.error(f"Exception whe trying to edit the nick-name of the user: {e}")
+
+ # Process rank roles
+ to_assign.extend(_process_rank_roles(htb_user_details.get("rank", ""), guild))
+
+ # Process season rank roles
+ to_assign.extend(await _process_season_rank_roles(htb_user_details.get("id", ""), guild))
+
+ # Process VIP roles
+ to_assign.extend(_process_vip_roles(htb_user_details, guild))
+
+ # Process HOF position roles
+ to_assign.extend(_process_hof_position_roles(htb_user_details.get("ranking", "unranked"), guild))
+
+ # Process creator roles
+ to_assign.extend(_process_creator_roles(htb_user_details.get("content", {}), guild))
+
+ return to_assign
- logger.debug("All roles to_assign:", extra={"to_assign": to_assign})
- # We don't need to remove any roles that are going to be assigned again
- to_remove = list(set(to_remove) - set(to_assign))
- logger.debug("All roles to_remove:", extra={"to_remove": to_remove})
- if to_remove:
- await member.remove_roles(*to_remove, atomic=True)
- else:
- logger.debug("No roles need to be removed")
- if to_assign:
- await member.add_roles(*to_assign, atomic=True)
- else:
- logger.debug("No roles need to be assigned")
- return to_assign
+def _process_rank_roles(rank: str, guild: Guild) -> list[Role]:
+ """Process rank-based role assignments."""
+ roles = []
+
+ if rank and rank not in ["Deleted", "Moderator", "Ambassador", "Admin", "Staff"]:
+ role_id = settings.get_post_or_rank(rank)
+ if role_id:
+ role = guild.get_role(role_id)
+ if role:
+ roles.append(role)
+
+ return roles
+
+
+async def _process_season_rank_roles(mp_user_id: int, guild: Guild) -> list[Role]:
+ """Process season rank role assignments."""
+ roles = []
+ try:
+ season_rank = await get_season_rank(mp_user_id)
+ if isinstance(season_rank, str):
+ season_role_id = settings.get_season(season_rank)
+ if season_role_id:
+ season_role = guild.get_role(season_role_id)
+ if season_role:
+ roles.append(season_role)
+ except Exception as e:
+ logger.error(f"Error getting season rank for user {mp_user_id}: {e}")
+ return roles
+
+
+def _process_vip_roles(htb_user_details: dict, guild: Guild) -> list[Role]:
+ """Process VIP role assignments."""
+ roles = []
+ try:
+ if htb_user_details.get("isVip", False):
+ vip_role = guild.get_role(settings.roles.VIP)
+ if vip_role:
+ roles.append(vip_role)
+
+ if htb_user_details.get("isDedicatedVip", False):
+ vip_plus_role = guild.get_role(settings.roles.VIP_PLUS)
+ if vip_plus_role:
+ roles.append(vip_plus_role)
+ except Exception as e:
+ logger.error(f"Error processing VIP roles: {e}")
+ return roles
+
+
+def _process_hof_position_roles(htb_user_ranking: str | int, guild: Guild) -> list[Role]:
+ """Process Hall of Fame position role assignments."""
+ roles = []
+ try:
+ hof_position = htb_user_ranking or "unranked"
+ logger.debug(f"HTB user ranking: {hof_position}")
+ if hof_position != "unranked":
+ position = int(hof_position)
+ pos_top = _get_position_tier(position)
+
+ if pos_top:
+ pos_role_id = settings.get_post_or_rank(pos_top)
+ if pos_role_id:
+ pos_role = guild.get_role(pos_role_id)
+ if pos_role:
+ roles.append(pos_role)
+ except (ValueError, TypeError) as e:
+ logger.error(f"Error processing HOF position: {e}")
+ return roles
+
+
+def _get_position_tier(position: int) -> Optional[str]:
+ """Get position tier based on HOF position."""
+ if position == 1:
+ return "1"
+ elif position <= 10:
+ return "10"
+ return None
+
+
+def _process_creator_roles(htb_user_content: dict, guild: Guild) -> list[Role]:
+ """Process creator role assignments."""
+ roles = []
+ try:
+ if htb_user_content.get("machines"):
+ box_creator_role = guild.get_role(settings.roles.BOX_CREATOR)
+ if box_creator_role:
+ logger.debug("Adding box creator role to user.")
+ roles.append(box_creator_role)
+
+ if htb_user_content.get("challenges"):
+ challenge_creator_role = guild.get_role(settings.roles.CHALLENGE_CREATOR)
+ if challenge_creator_role:
+ logger.debug("Adding challenge creator role to user.")
+ roles.append(challenge_creator_role)
+ except Exception as e:
+ logger.error(f"Error processing creator roles: {e}")
+ return roles
+
+
+async def _apply_role_changes(
+ member: Member, to_remove: list[Role], to_assign: list[Role]
+) -> None:
+ """Apply role changes to member."""
+ try:
+ if to_remove:
+ await member.remove_roles(*to_remove, atomic=True)
+ except Exception as e:
+ logger.error(f"Error removing roles from user {member.id}: {e}")
+
+ try:
+ if to_assign:
+ await member.add_roles(*to_assign, atomic=True)
+ except Exception as e:
+ logger.error(f"Error adding roles to user {member.id}: {e}")
diff --git a/src/webhooks/handlers/__init__.py b/src/webhooks/handlers/__init__.py
index 02e0edf..4036462 100644
--- a/src/webhooks/handlers/__init__.py
+++ b/src/webhooks/handlers/__init__.py
@@ -1,16 +1,23 @@
from discord import Bot
+from typing import Any
-from src.webhooks.handlers.academy import handler as academy_handler
+from src.webhooks.handlers.account import AccountHandler
+from src.webhooks.handlers.academy import AcademyHandler
+from src.webhooks.handlers.mp import MPHandler
from src.webhooks.types import Platform, WebhookBody
-handlers = {Platform.ACADEMY: academy_handler}
+handlers = {
+ Platform.ACCOUNT: AccountHandler().handle,
+ Platform.MAIN: MPHandler().handle,
+ Platform.ACADEMY: AcademyHandler().handle,
+}
def can_handle(platform: Platform) -> bool:
return platform in handlers.keys()
-def handle(body: WebhookBody, bot: Bot) -> any:
+def handle(body: WebhookBody, bot: Bot) -> Any:
platform = body.platform
if not can_handle(platform):
diff --git a/src/webhooks/handlers/academy.py b/src/webhooks/handlers/academy.py
index 5aa3568..a57785c 100644
--- a/src/webhooks/handlers/academy.py
+++ b/src/webhooks/handlers/academy.py
@@ -1,76 +1,50 @@
-import logging
-
from discord import Bot
-from discord.errors import NotFound
-from fastapi import HTTPException
from src.core import settings
+from src.webhooks.handlers.base import BaseHandler
from src.webhooks.types import WebhookBody, WebhookEvent
-logger = logging.getLogger(__name__)
-
-
-async def handler(body: WebhookBody, bot: Bot) -> dict:
- """
- Handles incoming webhook events and performs actions accordingly.
-
- This function processes different webhook events related to account linking,
- certificate awarding, and account unlinking. It updates the member's roles
- based on the received event.
-
- Args:
- body (WebhookBody): The data received from the webhook.
- bot (Bot): The instance of the Discord bot.
-
- Returns:
- dict: A dictionary with a "success" key indicating whether the operation was successful.
-
- Raises:
- HTTPException: If an error occurs while processing the webhook event.
- """
- # TODO: Change it here so we pass the guild instead of the bot # noqa: T000
- guild = await bot.fetch_guild(settings.guild_ids[0])
-
- try:
- discord_id = int(body.data["discord_id"])
- member = await guild.fetch_member(discord_id)
- except ValueError as exc:
- logger.debug("Invalid Discord ID", exc_info=exc)
- raise HTTPException(status_code=400, detail="Invalid Discord ID") from exc
- except NotFound as exc:
- logger.debug("User is not in the Discord server", exc_info=exc)
- raise HTTPException(status_code=400, detail="User is not in the Discord server") from exc
-
- if body.event == WebhookEvent.ACCOUNT_LINKED:
- roles_to_add = {settings.roles.ACADEMY_USER}
- roles_to_add.update(settings.get_academy_cert_role(cert["id"]) for cert in body.data["certifications"])
-
- # Filter out invalid role IDs
- role_ids_to_add = {role_id for role_id in roles_to_add if role_id is not None}
- roles_to_add = {guild.get_role(role_id) for role_id in role_ids_to_add}
-
- await member.add_roles(*roles_to_add, atomic=True)
- elif body.event == WebhookEvent.CERTIFICATE_AWARDED:
- cert_id = body.data["certification"]["id"]
-
- role = settings.get_academy_cert_role(cert_id)
- if not role:
- logger.debug(f"Role for certification: {cert_id} does not exist")
- raise HTTPException(status_code=400, detail=f"Role for certification: {cert_id} does not exist")
-
- await member.add_roles(role, atomic=True)
- elif body.event == WebhookEvent.ACCOUNT_UNLINKED:
- current_role_ids = {role.id for role in member.roles}
- cert_role_ids = {settings.get_academy_cert_role(cert_id) for _, cert_id in settings.academy_certificates}
-
- common_role_ids = current_role_ids.intersection(cert_role_ids)
-
- role_ids_to_remove = {settings.roles.ACADEMY_USER}.union(common_role_ids)
- roles_to_remove = {guild.get_role(role_id) for role_id in role_ids_to_remove}
-
- await member.remove_roles(*roles_to_remove, atomic=True)
- else:
- logger.debug(f"Event {body.event} not implemented")
- raise HTTPException(status_code=501, detail=f"Event {body.event} not implemented")
- return {"success": True}
+class AcademyHandler(BaseHandler):
+ async def handle(self, body: WebhookBody, bot: Bot):
+ """
+ Handles incoming webhook events and performs actions accordingly.
+
+ This function processes different webhook events originating from the
+ HTB Account.
+ """
+ if body.event == WebhookEvent.CERTIFICATE_AWARDED:
+ return await self._handle_certificate_awarded(body, bot)
+ else:
+ raise ValueError(f"Invalid event: {body.event}")
+
+ async def _handle_certificate_awarded(self, body: WebhookBody, bot: Bot) -> dict:
+ """
+ Handles the certificate awarded event.
+ """
+ discord_id = self.validate_discord_id(self.get_property_or_trait(body, "discord_id"))
+ _ = self.validate_account_id(self.get_property_or_trait(body, "account_id"))
+ certificate_id = self.validate_property(
+ self.get_property_or_trait(body, "certificate_id"), "certificate_id"
+ )
+
+ self.logger.info(f"Handling certificate awarded event for {discord_id} with certificate {certificate_id}")
+
+ member = await self.get_guild_member(discord_id, bot)
+ certificate_role_id = settings.get_academy_cert_role(int(certificate_id))
+
+ if not certificate_role_id:
+ self.logger.warning(f"No certificate role found for certificate {certificate_id}")
+ return self.fail()
+
+ if certificate_role_id:
+ self.logger.info(f"Adding certificate role {certificate_role_id} to member {member.id}")
+ try:
+ await member.add_roles(
+ bot.guilds[0].get_role(certificate_role_id), atomic=True # type: ignore
+ ) # type: ignore
+ except Exception as e:
+ self.logger.error(f"Error adding certificate role {certificate_role_id} to member {member.id}: {e}")
+ raise e
+
+ return self.success()
diff --git a/src/webhooks/handlers/account.py b/src/webhooks/handlers/account.py
new file mode 100644
index 0000000..d941f8b
--- /dev/null
+++ b/src/webhooks/handlers/account.py
@@ -0,0 +1,168 @@
+from datetime import datetime
+from discord import Bot
+
+from src.core import settings
+from src.helpers.ban import handle_platform_ban_or_update
+from src.helpers.verification import process_account_identification
+from src.webhooks.handlers.base import BaseHandler
+from src.webhooks.types import WebhookBody, WebhookEvent
+
+
+class AccountHandler(BaseHandler):
+ async def handle(self, body: WebhookBody, bot: Bot):
+ """
+ Handles incoming webhook events and performs actions accordingly.
+
+ This function processes different webhook events originating from the
+ HTB Account.
+ """
+ if body.event == WebhookEvent.ACCOUNT_LINKED:
+ return await self._handle_account_linked(body, bot)
+ elif body.event == WebhookEvent.ACCOUNT_UNLINKED:
+ return await self._handle_account_unlinked(body, bot)
+ elif body.event == WebhookEvent.ACCOUNT_DELETED:
+ return await self._handle_account_deleted(body, bot)
+ elif body.event == WebhookEvent.ACCOUNT_BANNED:
+ return await self._handle_account_banned(body, bot)
+ else:
+ raise ValueError(f"Invalid event: {body.event}")
+
+ async def _handle_account_linked(self, body: WebhookBody, bot: Bot) -> dict:
+ """
+ Handles the account linked event.
+ """
+ discord_id = self.validate_discord_id(
+ self.get_property_or_trait(body, "discord_id")
+ )
+ account_id = self.validate_account_id(
+ self.get_property_or_trait(body, "account_id")
+ )
+
+ member = await self.get_guild_member(discord_id, bot)
+ await process_account_identification(
+ member,
+ bot, # type: ignore
+ traits=self.merge_properties_and_traits(body.properties, body.traits),
+ )
+
+ # Safely attempt to send verification log
+ try:
+ verify_channel = bot.guilds[0].get_channel(settings.channels.VERIFY_LOGS)
+ if verify_channel:
+ await verify_channel.send( # type: ignore
+ f"Account linked: {account_id} -> ({member.mention} ({member.id})",
+ )
+ else:
+ self.logger.warning(
+ f"Verify logs channel {settings.channels.VERIFY_LOGS} not found"
+ )
+ except Exception as e:
+ self.logger.error(f"Failed to send verification log: {e}")
+ # Don't raise - this is not critical
+
+ self.logger.info(
+ f"Account {account_id} linked to {member.id}",
+ extra={"account_id": account_id, "discord_id": discord_id},
+ )
+
+ return self.success()
+
+ async def _handle_account_unlinked(self, body: WebhookBody, bot: Bot) -> dict:
+ """
+ Handles the account unlinked event.
+ """
+ discord_id = self.validate_discord_id(
+ self.get_property_or_trait(body, "discord_id")
+ )
+ account_id = self.validate_account_id(
+ self.get_property_or_trait(body, "account_id")
+ )
+
+ member = await self.get_guild_member(discord_id, bot)
+
+ await member.remove_roles(
+ bot.guilds[0].get_role(settings.roles.VERIFIED), atomic=True # type: ignore
+ ) # type: ignore
+
+ return self.success()
+
+ async def _handle_name_change(self, body: WebhookBody, bot: Bot) -> dict:
+ """
+ Handles the name change event.
+ """
+ discord_id = self.validate_discord_id(body.properties.get("discord_id"))
+ _ = self.validate_account_id(body.properties.get("account_id"))
+ name = self.validate_property(body.properties.get("name"), "name")
+
+ member = await self.get_guild_member(discord_id, bot)
+ await member.edit(nick=name)
+ return self.success()
+
+ async def _handle_account_banned(self, body: WebhookBody, bot: Bot) -> dict:
+ """
+ Handles the account banned event.
+ """
+ discord_id = self.validate_discord_id(
+ self.get_property_or_trait(body, "discord_id")
+ )
+ account_id = self.validate_account_id(
+ self.get_property_or_trait(body, "account_id")
+ )
+ expires_at = self.validate_property(
+ self.get_property_or_trait(body, "expires_at"), "expires_at"
+ )
+ reason = body.properties.get("reason")
+ notes = body.properties.get("notes")
+ created_by = body.properties.get("created_by")
+
+ expires_ts = int(datetime.fromisoformat(expires_at).timestamp()) # type: ignore
+ extra = {"account_id": account_id, "discord_id": discord_id}
+
+ member = await self.get_guild_member(discord_id, bot)
+ if not member:
+ self.logger.warning(
+ f"Cannot ban user {discord_id}- not found in guild", extra=extra
+ )
+ return self.fail()
+
+ # Use the generic ban helper to handle all the complex logic
+ result = await handle_platform_ban_or_update(
+ bot=bot, # type: ignore
+ guild=bot.guilds[0],
+ member=member,
+ expires_timestamp=expires_ts,
+ reason=f"Platform Ban - {reason}",
+ evidence=notes or "N/A",
+ author_name=created_by or "System",
+ expires_at_str=expires_at, # type: ignore
+ log_channel_id=settings.channels.BOT_LOGS,
+ logger=self.logger,
+ extra_log_data=extra,
+ )
+
+ self.logger.debug(
+ f"Platform ban handling result: {result['action']}", extra=extra
+ )
+
+ return self.success()
+
+ async def _handle_account_deleted(self, body: WebhookBody, bot: Bot) -> dict:
+ """
+ Handles the account deleted event.
+ """
+ discord_id = self.validate_discord_id(body.properties.get("discord_id"))
+ account_id = self.validate_account_id(body.properties.get("account_id"))
+
+ member = await self.get_guild_member(discord_id, bot)
+ if not member:
+ self.logger.warning(
+ f"Cannot delete account {account_id}- not found in guild",
+ extra={"account_id": account_id, "discord_id": discord_id},
+ )
+ return self.fail()
+
+ await member.remove_roles(
+ bot.guilds[0].get_role(settings.roles.VERIFIED), atomic=True # type: ignore
+ ) # type: ignore
+
+ return self.success()
diff --git a/src/webhooks/handlers/base.py b/src/webhooks/handlers/base.py
new file mode 100644
index 0000000..b604f01
--- /dev/null
+++ b/src/webhooks/handlers/base.py
@@ -0,0 +1,128 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import TypeVar
+
+from discord import Bot, Member
+from discord.errors import NotFound
+from fastapi import HTTPException
+
+from src.core import settings
+from src.webhooks.types import WebhookBody
+
+T = TypeVar("T")
+
+
+class BaseHandler(ABC):
+ ACADEMY_USER_ID = "academy_user_id"
+ MP_USER_ID = "mp_user_id"
+ EP_USER_ID = "ep_user_id"
+ CTF_USER_ID = "ctf_user_id"
+ ACCOUNT_ID = "account_id"
+ DISCORD_ID = "discord_id"
+
+ def __init__(self):
+ self.logger = logging.getLogger(self.__class__.__name__)
+
+ @abstractmethod
+ async def handle(self, body: WebhookBody, bot: Bot) -> dict:
+ pass
+
+ async def get_guild_member(self, discord_id: int | str, bot: Bot) -> Member:
+ """
+ Fetches a guild member from the Discord server.
+
+ Args:
+ discord_id (int): The Discord ID of the user.
+ bot (Bot): The Discord bot instance.
+
+ Returns:
+ Member: The guild member.
+
+ Raises:
+ HTTPException: If the user is not in the Discord server (400)
+ """
+
+ try:
+ guild = await bot.fetch_guild(settings.guild_ids[0])
+ member = await guild.fetch_member(int(discord_id))
+ return member
+
+ except NotFound as exc:
+ self.logger.debug("User is not in the Discord server", exc_info=exc)
+ raise HTTPException(
+ status_code=400, detail="User is not in the Discord server"
+ ) from exc
+
+ def validate_property(self, property: T | None, name: str) -> T:
+ """
+ Validates a property is not None.
+
+ Args:
+ property (T | None): The property to validate.
+ name (str): The name of the property.
+
+ Returns:
+ T: The validated property.
+
+ Raises:
+ HTTPException: If the property is None (400)
+ """
+ if property is None:
+ msg = f"Invalid {name}"
+ self.logger.debug(msg)
+ raise HTTPException(status_code=400, detail=msg)
+
+ return property
+
+ def validate_discord_id(self, discord_id: str | int | None) -> int | str:
+ """
+ Validates the Discord ID. See validate_property function.
+ """
+ return self.validate_property(discord_id, "Discord ID")
+
+ def validate_account_id(self, account_id: str | int | None) -> int | str:
+ """
+ Validates the Account ID. See validate_property function.
+ """
+ return self.validate_property(account_id, "Account ID")
+
+ def get_property_or_trait(self, body: WebhookBody, name: str) -> int | None:
+ """
+ Gets a trait or property from the webhook body.
+ """
+ return body.properties.get(name) or body.traits.get(name)
+
+ def merge_properties_and_traits(
+ self, properties: dict[str, int | None], traits: dict[str, int | None]
+ ) -> dict[str, int | None]:
+ """
+ Merges the properties and traits from the webhook body without duplicates.
+ If a property and trait have the same name but different values, the property value will be used.
+ """
+ return {
+ **properties,
+ **{k: v for k, v in traits.items() if k not in properties},
+ }
+
+ def get_platform_properties(self, body: WebhookBody) -> dict[str, int | None]:
+ """
+ Gets the platform properties from the webhook body.
+ """
+ properties = {
+ self.ACCOUNT_ID: self.get_property_or_trait(body, self.ACCOUNT_ID),
+ self.MP_USER_ID: self.get_property_or_trait(body, self.MP_USER_ID),
+ self.EP_USER_ID: self.get_property_or_trait(body, self.EP_USER_ID),
+ self.CTF_USER_ID: self.get_property_or_trait(body, self.CTF_USER_ID),
+ self.ACADEMY_USER_ID: self.get_property_or_trait(
+ body, self.ACADEMY_USER_ID
+ ),
+ }
+ return properties
+
+ @staticmethod
+ def success():
+ return {"success": True}
+
+ @staticmethod
+ def fail():
+ return {"success": False}
\ No newline at end of file
diff --git a/src/webhooks/handlers/mp.py b/src/webhooks/handlers/mp.py
new file mode 100644
index 0000000..c8b0e2b
--- /dev/null
+++ b/src/webhooks/handlers/mp.py
@@ -0,0 +1,180 @@
+import discord
+
+from datetime import datetime
+from discord import Bot, Member, Role
+
+from typing import Literal
+from sqlalchemy import select
+
+from src.core import settings
+from src.webhooks.handlers.base import BaseHandler
+from src.webhooks.types import WebhookBody, WebhookEvent
+
+
+class MPHandler(BaseHandler):
+ async def handle(self, body: WebhookBody, bot: Bot):
+ """
+ Handles incoming webhook events and performs actions accordingly.
+
+ This function processes different webhook events originating from the
+ HTB Account.
+ """
+ if body.event == WebhookEvent.HOF_CHANGE:
+ return await self._handle_hof_change(body, bot)
+ elif body.event == WebhookEvent.RANK_UP:
+ return await self._handle_rank_up(body, bot)
+ elif body.event == WebhookEvent.SUBSCRIPTION_CHANGE:
+ return await self._handle_subscription_change(body, bot)
+ else:
+ raise ValueError(f"Invalid event: {body.event}")
+
+ async def _handle_subscription_change(self, body: WebhookBody, bot: Bot) -> dict:
+ """
+ Handles the subscription change event.
+ """
+ discord_id = self.validate_discord_id(body.properties.get("discord_id"))
+ _ = self.validate_account_id(body.properties.get("account_id"))
+ subscription_name = self.validate_property(
+ body.properties.get("subscription_name"), "subscription_name"
+ )
+
+ member = await self.get_guild_member(discord_id, bot)
+
+ role = settings.get_post_or_rank(subscription_name)
+ if not role:
+ raise ValueError(f"Invalid subscription name: {subscription_name}")
+
+ await member.add_roles(bot.guilds[0].get_role(role), atomic=True) # type: ignore
+ return self.success()
+
+ async def _handle_hof_change(self, body: WebhookBody, bot: Bot) -> dict:
+ """
+ Handles the HOF change event.
+ """
+ self.logger.info("Handling HOF change event.")
+ discord_id = self.validate_discord_id(self.get_property_or_trait(body, "discord_id"))
+ account_id = self.validate_account_id(self.get_property_or_trait(body, "account_id"))
+ hof_tier: Literal["1", "10"] = self.validate_property(
+ self.get_property_or_trait(body, "hof_tier"), "hof_tier" # type: ignore
+ )
+ hof_roles = {
+ "1": bot.guilds[0].get_role(settings.roles.RANK_ONE),
+ "10": bot.guilds[0].get_role(settings.roles.RANK_TEN),
+ }
+
+ member = await self.get_guild_member(discord_id, bot)
+ member_roles = member.roles
+
+ if not member:
+ msg = f"Cannot find member {discord_id}"
+ self.logger.warning(
+ msg, extra={"account_id": account_id, "discord_id": discord_id}
+ )
+ raise ValueError(msg)
+
+ async def _swap_hof_roles(member: Member, role_to_grant: Role | None):
+ """Grants a HOF role to a member and removes the other HOF role"""
+ if not role_to_grant:
+ return
+
+ member_hof_role = next(
+ (r for r in member_roles if r in hof_roles.values()), None
+ )
+ if member_hof_role:
+ await member.remove_roles(member_hof_role, atomic=True)
+ await member.add_roles(role_to_grant, atomic=True)
+
+ if int(hof_tier) == 1:
+ # Find existing top 1 user and make them a top 10
+ if existing_top_one_user := await self._find_user_with_role(
+ bot, hof_roles["1"]
+ ):
+ if existing_top_one_user.id != member.id:
+ await _swap_hof_roles(existing_top_one_user, hof_roles["10"])
+ else:
+ return self.success()
+
+ # Grant top 1 role to member
+ await _swap_hof_roles(member, hof_roles["1"])
+ return self.success()
+
+ # Just grant top 10 role to member
+ elif int(hof_tier) == 10:
+ await _swap_hof_roles(member, hof_roles["10"])
+ return self.success()
+
+ else:
+ err = ValueError(f"Invalid HOF tier: {hof_tier}")
+ self.logger.error(
+ err,
+ extra={
+ "account_id": account_id,
+ "discord_id": discord_id,
+ "hof_tier": hof_tier,
+ },
+ )
+ raise err
+
+ async def _handle_rank_up(self, body: WebhookBody, bot: Bot) -> dict:
+ """
+ Handles the rank up event.
+ """
+ discord_id = self.validate_discord_id(self.get_property_or_trait(body, "discord_id"))
+ account_id = self.validate_account_id(self.get_property_or_trait(body, "account_id"))
+ rank = self.validate_property(self.get_property_or_trait(body, "rank"), "rank")
+
+ member = await self.get_guild_member(discord_id, bot)
+
+ rank_id = settings.get_post_or_rank(rank)
+ if not rank_id:
+ err = ValueError(f"Cannot find role for '{rank}'")
+ self.logger.error(
+ err,
+ extra={
+ "account_id": account_id,
+ "discord_id": discord_id,
+ "rank": rank,
+ },
+ )
+ raise err
+
+ rank_role = bot.guilds[0].get_role(rank_id)
+ rank_roles = [
+ bot.guilds[0].get_role(int(r)) for r in settings.role_groups["ALL_RANKS"]
+ ] # All rank roles
+ new_role = next(
+ (r for r in rank_roles if r and r.id == rank_role.id), None
+ ) # Get passed rank as role from rank roles
+ old_role = next(
+ (r for r in member.roles if r in rank_roles), None
+ ) # Find existing rank role on user
+
+ if old_role:
+ await member.remove_roles(old_role, atomic=True) # Yeet the old role
+
+ if new_role:
+ await member.add_roles(new_role, atomic=True) # Add the new role
+
+ if not new_role:
+ # Why are you passing me BS roles?
+ err = ValueError(f"Cannot find role for '{rank}'")
+ self.logger.error(
+ err,
+ extra={
+ "account_id": account_id,
+ "discord_id": discord_id,
+ "rank": rank,
+ },
+ )
+ raise err
+
+ return self.success()
+
+ async def _find_user_with_role(self, bot: Bot, role: Role | None) -> Member | None:
+ """
+ Finds the user with the given role.
+ """
+ if not role:
+ return None
+
+ return next((m for m in role.members), None)
diff --git a/src/webhooks/server.py b/src/webhooks/server.py
index 92222a7..1c4d9e3 100644
--- a/src/webhooks/server.py
+++ b/src/webhooks/server.py
@@ -1,10 +1,13 @@
+import hashlib
import hmac
import logging
-from typing import Any, Dict, Union
+import json
+from typing import Any, Dict
-from fastapi import FastAPI, Header, HTTPException
+from fastapi import FastAPI, HTTPException, Request
from hypercorn.asyncio import serve as hypercorn_serve
from hypercorn.config import Config as HypercornConfig
+from pydantic import ValidationError
from src.bot import bot
from src.core import settings
@@ -17,20 +20,36 @@
app = FastAPI()
+def verify_signature(body: dict, signature: str, secret: str) -> bool:
+ """
+ HMAC SHA1 signature verification.
+
+ Args:
+ body (dict): The raw body of the webhook request.
+ signature (str): The X-Signature header of the webhook request.
+ secret (str): The webhook secret.
+
+ Returns:
+ bool: True if the signature is valid, False otherwise.
+ """
+ if not signature:
+ return False
+
+ digest = hmac.new(secret.encode(), body, hashlib.sha1).hexdigest() # type: ignore
+ return hmac.compare_digest(signature, digest)
+
+
@app.post("/webhook")
-async def webhook_handler(
- body: WebhookBody, authorization: Union[str, None] = Header(default=None)
-) -> Dict[str, Any]:
+async def webhook_handler(request: Request) -> Dict[str, Any]:
"""
Handles incoming webhook requests and forwards them to the appropriate handler.
- This function first checks the provided authorization token in the request header.
- If the token is valid, it checks if the platform can be handled and then forwards
+ This function first verifies the provided HMAC signature in the request header.
+ If the signature is valid, it checks if the platform can be handled and then forwards
the request to the corresponding handler.
Args:
- body (WebhookBody): The data received from the webhook.
- authorization (Union[str, None]): The authorization header containing the Bearer token.
+ request (Request): The incoming webhook request.
Returns:
Dict[str, Any]: The response from the corresponding handler. The dictionary contains
@@ -39,20 +58,42 @@ async def webhook_handler(
Raises:
HTTPException: If an error occurs while processing the webhook event or if unauthorized.
"""
- if authorization is None or not authorization.strip().startswith("Bearer"):
- logger.warning("Unauthorized webhook request")
- raise HTTPException(status_code=401, detail="Unauthorized")
+ body = await request.body()
+ signature = request.headers.get("X-Signature")
- token = authorization[6:].strip()
- if hmac.compare_digest(token, settings.WEBHOOK_TOKEN):
+ if not verify_signature(body, signature, settings.WEBHOOK_TOKEN): # type: ignore
logger.warning("Unauthorized webhook request")
raise HTTPException(status_code=401, detail="Unauthorized")
+ try:
+ body = WebhookBody.validate(json.loads(body))
+ except ValidationError as e:
+ logger.warning("Invalid webhook request: %s", e.errors())
+ raise HTTPException(status_code=400, detail="Invalid webhook request body")
+
if not handlers.can_handle(body.platform):
- logger.warning("Webhook request not handled by platform")
+ logger.warning("Webhook request not handled by platform: %s", body.platform)
raise HTTPException(status_code=501, detail="Platform not implemented")
- return await handlers.handle(body, bot)
+ try:
+ return await handlers.handle(body, bot)
+ except HTTPException:
+ # Re-raise HTTP exceptions as they already have appropriate status codes
+ raise
+ except Exception as e:
+ # Log the full exception details for debugging
+ logger.error(
+ "Unhandled exception in webhook handler",
+ exc_info=e,
+ extra={
+ "platform": body.platform,
+ "event": body.event,
+ "properties": body.properties,
+ "traits": body.traits,
+ }
+ )
+ # Return a generic 500 error to the client
+ raise HTTPException(status_code=500, detail="Internal server error")
app.mount("/metrics", metrics_app)
diff --git a/src/webhooks/types.py b/src/webhooks/types.py
index 3885dda..2fabb3b 100644
--- a/src/webhooks/types.py
+++ b/src/webhooks/types.py
@@ -1,17 +1,19 @@
from enum import Enum
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict, Extra, Field
class WebhookEvent(Enum):
- ACCOUNT_LINKED = "AccountLinked"
- ACCOUNT_UNLINKED = "AccountUnlinked"
+ ACCOUNT_LINKED = "DiscordAccountLinked"
+ ACCOUNT_UNLINKED = "DiscordAccountUnlinked"
+ ACCOUNT_DELETED = "UserAccountDeleted"
+ ACCOUNT_BANNED = "UserAccountBanned"
CERTIFICATE_AWARDED = "CertificateAwarded"
RANK_UP = "RankUp"
HOF_CHANGE = "HofChange"
SUBSCRIPTION_CHANGE = "SubscriptionChange"
- CONTENT_RELEASED = "ContentReleased"
NAME_CHANGE = "NameChange"
+ SEASON_RANK_CHANGE = "SeasonRankChange"
class Platform(Enum):
@@ -19,9 +21,13 @@ class Platform(Enum):
ACADEMY = "academy"
CTF = "ctf"
ENTERPRISE = "enterprise"
+ ACCOUNT = "account"
class WebhookBody(BaseModel):
+ model_config = ConfigDict(extra=Extra.allow)
+
platform: Platform
event: WebhookEvent
- data: dict
+ properties: dict = Field(default_factory=dict)
+ traits: dict = Field(default_factory=dict)
diff --git a/tests/src/helpers/test_ban.py b/tests/src/helpers/test_ban.py
index 99d3bdc..1771928 100644
--- a/tests/src/helpers/test_ban.py
+++ b/tests/src/helpers/test_ban.py
@@ -1,8 +1,10 @@
+from datetime import datetime, timezone
from unittest import mock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from discord import Forbidden, HTTPException
+from datetime import datetime, timezone
from src.helpers.ban import _check_member, _dm_banned_member, ban_member
from src.helpers.responses import SimpleResponse
@@ -116,11 +118,15 @@ async def test_ban_member_valid_duration(self, bot, guild, member, author):
evidence = "Some evidence"
member.display_name = "Banned Member"
+ # Use a future timestamp instead of a past one
+ future_timestamp = int((datetime.now(tz=timezone.utc).timestamp() + 86400)) # 1 day from now
+ expected_date = datetime.fromtimestamp(future_timestamp, tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
+
with (
mock.patch("src.helpers.ban._check_member", return_value=None),
mock.patch("src.helpers.ban._dm_banned_member", return_value=True),
mock.patch("src.helpers.ban._get_ban_or_create", return_value=(1, False)),
- mock.patch("src.helpers.ban.validate_duration", return_value=(1684276900, "")),
+ mock.patch("src.helpers.ban.validate_duration", return_value=(future_timestamp, "")),
):
mock_channel = helpers.MockTextChannel()
mock_channel.send.return_value = MagicMock()
@@ -128,7 +134,7 @@ async def test_ban_member_valid_duration(self, bot, guild, member, author):
result = await ban_member(bot, guild, member, duration, reason, evidence)
assert isinstance(result, SimpleResponse)
- assert result.message == f"{member.display_name} ({member.id}) has been banned until 2023-05-16 22:41:40 " \
+ assert result.message == f"{member.display_name} ({member.id}) has been banned until {expected_date} " \
f"(UTC)."
@pytest.mark.asyncio
@@ -153,12 +159,15 @@ async def test_ban_member_permanently_success(self, bot, guild, member, author):
evidence = "Some evidence"
member.display_name = "Banned Member"
+ # Use a future timestamp instead of a past one
+ future_timestamp = int((datetime.now(tz=timezone.utc).timestamp() + 86400)) # 1 day from now
+
# Patching the necessary classes and functions
with (
mock.patch("src.helpers.ban._check_member", return_value=None),
mock.patch("src.helpers.ban._dm_banned_member", return_value=True),
mock.patch("src.helpers.ban._get_ban_or_create", return_value=(1, False)),
- mock.patch("src.helpers.ban.validate_duration", return_value=(1684276900, "")),
+ mock.patch("src.helpers.ban.validate_duration", return_value=(future_timestamp, "")),
):
response = await ban_member(bot, guild, member, duration, reason, evidence, author, False)
assert isinstance(response, SimpleResponse)
@@ -171,12 +180,15 @@ async def test_ban_member_no_reason_success(self, bot, guild, member, author):
evidence = "Some evidence"
member.display_name = "Banned Member"
+ # Use a future timestamp instead of a past one
+ future_timestamp = int((datetime.now(tz=timezone.utc).timestamp() + 86400)) # 1 day from now
+
# Patching the necessary classes and functions
with (
mock.patch("src.helpers.ban._check_member", return_value=None),
mock.patch("src.helpers.ban._dm_banned_member", return_value=True),
mock.patch("src.helpers.ban._get_ban_or_create", return_value=(1, False)),
- mock.patch("src.helpers.ban.validate_duration", return_value=(1684276900, "")),
+ mock.patch("src.helpers.ban.validate_duration", return_value=(future_timestamp, "")),
):
response = await ban_member(bot, guild, member, duration, reason, evidence, author, False)
assert isinstance(response, SimpleResponse)
@@ -189,11 +201,14 @@ async def test_ban_member_no_author_success(self, bot, guild, member):
evidence = "Some evidence"
member.display_name = "Banned Member"
+ # Use a future timestamp instead of a past one
+ future_timestamp = int((datetime.now(tz=timezone.utc).timestamp() + 86400)) # 1 day from now
+
with (
mock.patch("src.helpers.ban._check_member", return_value=None),
mock.patch("src.helpers.ban._dm_banned_member", return_value=True),
mock.patch("src.helpers.ban._get_ban_or_create", return_value=(1, False)),
- mock.patch("src.helpers.ban.validate_duration", return_value=(1684276900, "")),
+ mock.patch("src.helpers.ban.validate_duration", return_value=(future_timestamp, "")),
):
response = await ban_member(bot, guild, member, duration, reason, evidence, None, False)
assert isinstance(response, SimpleResponse)
@@ -206,11 +221,14 @@ async def test_ban_already_exists(self, bot, guild, member, author):
evidence = "Some evidence"
member.display_name = "Banned Member"
+ # Use a future timestamp instead of a past one
+ future_timestamp = int((datetime.now(tz=timezone.utc).timestamp() + 86400)) # 1 day from now
+
with (
mock.patch("src.helpers.ban._check_member", return_value=None),
mock.patch("src.helpers.ban._dm_banned_member", return_value=True),
mock.patch("src.helpers.ban._get_ban_or_create", return_value=(1, True)),
- mock.patch("src.helpers.ban.validate_duration", return_value=(1684276900, "")),
+ mock.patch("src.helpers.ban.validate_duration", return_value=(future_timestamp, "")),
):
response = await ban_member(bot, guild, member, duration, reason, evidence, author)
assert isinstance(response, SimpleResponse)
diff --git a/tests/src/helpers/test_verification.py b/tests/src/helpers/test_verification.py
index 3ebddf0..1c949da 100644
--- a/tests/src/helpers/test_verification.py
+++ b/tests/src/helpers/test_verification.py
@@ -14,37 +14,60 @@ async def test_get_user_details_success(self):
account_identifier = "some_identifier"
with aioresponses.aioresponses() as m:
+ # Mock the profile API call
m.get(
- f"{settings.API_URL}/discord/identifier/{account_identifier}?secret={settings.HTB_API_SECRET}",
+ f"{settings.API_V4_URL}/user/profile/basic/{account_identifier}",
status=200,
- payload={"some_key": "some_value"},
+ payload={"profile": {"some_key": "some_value"}},
+ )
+ # Mock the content API call
+ m.get(
+ f"{settings.API_V4_URL}/user/profile/content/{account_identifier}",
+ status=200,
+ payload={"profile": {"content": {"content_key": "content_value"}}},
)
result = await get_user_details(account_identifier)
- self.assertEqual(result, {"some_key": "some_value"})
+ expected = {
+ "some_key": "some_value",
+ "content": {"content_key": "content_value"}
+ }
+ self.assertEqual(result, expected)
@pytest.mark.asyncio
async def test_get_user_details_404(self):
account_identifier = "some_identifier"
with aioresponses.aioresponses() as m:
+ # Mock the profile API call with404
m.get(
- f"{settings.API_URL}/discord/identifier/{account_identifier}?secret={settings.HTB_API_SECRET}",
+ f"{settings.API_V4_URL}/user/profile/basic/{account_identifier}",
+ status=404,
+ )
+ # Mock the content API call with404
+ m.get(
+ f"{settings.API_V4_URL}/user/profile/content/{account_identifier}",
status=404,
)
result = await get_user_details(account_identifier)
- self.assertIsNone(result)
+ self.assertEqual(result, {"content": {}})
@pytest.mark.asyncio
async def test_get_user_details_other_status(self):
account_identifier = "some_identifier"
with aioresponses.aioresponses() as m:
+ # Mock the profile API call with500
+ m.get(
+ f"{settings.API_V4_URL}/user/profile/basic/{account_identifier}",
+ status=500,
+ )
+ # Mock the content API call with500
m.get(
- f"{settings.API_URL}/discord/identifier/{account_identifier}?secret={settings.HTB_API_SECRET}",
+ f"{settings.API_V4_URL}/user/profile/content/{account_identifier}",
status=500,
)
result = await get_user_details(account_identifier)
- self.assertIsNone(result)
+ self.assertEqual(result, {"content": {}})
diff --git a/tests/src/webhooks/handlers/test_academy.py b/tests/src/webhooks/handlers/test_academy.py
new file mode 100644
index 0000000..2c76af4
--- /dev/null
+++ b/tests/src/webhooks/handlers/test_academy.py
@@ -0,0 +1,119 @@
+import pytest
+from unittest.mock import AsyncMock, patch
+from fastapi import HTTPException
+
+from src.webhooks.handlers.academy import AcademyHandler
+from src.webhooks.types import WebhookBody, Platform, WebhookEvent
+from tests import helpers
+
+class TestAcademyHandler:
+ @pytest.mark.asyncio
+ async def test_handle_certificate_awarded_success(self, bot):
+ handler = AcademyHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ certificate_id = 42
+ mock_member = helpers.MockMember(id=discord_id)
+ mock_member.add_roles = AsyncMock()
+ body = WebhookBody(
+ platform=Platform.ACADEMY,
+ event=WebhookEvent.CERTIFICATE_AWARDED,
+ properties={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "certificate_id": certificate_id,
+ },
+ traits={},
+ )
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(handler, "validate_property", return_value=certificate_id),
+ patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member),
+ patch("src.webhooks.handlers.academy.settings") as mock_settings,
+ patch.object(handler.logger, "info") as mock_log,
+ ):
+ mock_settings.get_academy_cert_role.return_value = 555
+ mock_guild = helpers.MockGuild(id=1)
+ mock_guild.get_role.return_value = 555
+ bot.guilds = [mock_guild]
+ result = await handler._handle_certificate_awarded(body, bot)
+ mock_member.add_roles.assert_awaited()
+ mock_log.assert_called()
+ assert result == handler.success()
+
+ @pytest.mark.asyncio
+ async def test_handle_certificate_awarded_no_role(self, bot):
+ handler = AcademyHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ certificate_id = 42
+ mock_member = helpers.MockMember(id=discord_id)
+ body = WebhookBody(
+ platform=Platform.ACADEMY,
+ event=WebhookEvent.CERTIFICATE_AWARDED,
+ properties={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "certificate_id": certificate_id,
+ },
+ traits={},
+ )
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(handler, "validate_property", return_value=certificate_id),
+ patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member),
+ patch("src.webhooks.handlers.academy.settings") as mock_settings,
+ patch.object(handler.logger, "warning") as mock_log,
+ ):
+ mock_settings.get_academy_cert_role.return_value = None
+ result = await handler._handle_certificate_awarded(body, bot)
+ mock_log.assert_called()
+ assert result == handler.fail()
+
+ @pytest.mark.asyncio
+ async def test_handle_certificate_awarded_add_roles_error(self, bot):
+ handler = AcademyHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ certificate_id = 42
+ mock_member = helpers.MockMember(id=discord_id)
+ mock_member.add_roles = AsyncMock(side_effect=Exception("add_roles error"))
+ body = WebhookBody(
+ platform=Platform.ACADEMY,
+ event=WebhookEvent.CERTIFICATE_AWARDED,
+ properties={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "certificate_id": certificate_id,
+ },
+ traits={},
+ )
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(handler, "validate_property", return_value=certificate_id),
+ patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member),
+ patch("src.webhooks.handlers.academy.settings") as mock_settings,
+ patch.object(handler.logger, "error") as mock_log,
+ ):
+ mock_settings.get_academy_cert_role.return_value = 555
+ mock_guild = helpers.MockGuild(id=1)
+ mock_guild.get_role.return_value = 555
+ bot.guilds = [mock_guild]
+ with pytest.raises(Exception, match="add_roles error"):
+ await handler._handle_certificate_awarded(body, bot)
+ mock_log.assert_called()
+
+ @pytest.mark.asyncio
+ async def test_handle_invalid_event(self, bot):
+ handler = AcademyHandler()
+ body = WebhookBody(
+ platform=Platform.ACADEMY,
+ event=WebhookEvent.RANK_UP, # Not handled by AcademyHandler
+ properties={},
+ traits={},
+ )
+ with pytest.raises(ValueError, match="Invalid event"):
+ await handler.handle(body, bot)
\ No newline at end of file
diff --git a/tests/src/webhooks/handlers/test_account.py b/tests/src/webhooks/handlers/test_account.py
new file mode 100644
index 0000000..e18cf7e
--- /dev/null
+++ b/tests/src/webhooks/handlers/test_account.py
@@ -0,0 +1,510 @@
+import logging
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from discord import Bot, Member
+from discord.errors import NotFound
+from fastapi import HTTPException
+
+from src.webhooks.handlers.account import AccountHandler
+from src.webhooks.types import WebhookBody, Platform, WebhookEvent
+from tests import helpers
+
+
+class TestAccountHandler:
+ """Test the `AccountHandler` class."""
+
+ def test_initialization(self):
+ """Test that AccountHandler initializes correctly."""
+ handler = AccountHandler()
+
+ assert isinstance(handler.logger, logging.Logger)
+ assert handler.logger.name == "AccountHandler"
+
+ @pytest.mark.asyncio
+ async def test_handle_account_linked_event(self, bot):
+ """Test handle method routes ACCOUNT_LINKED event correctly."""
+ handler = AccountHandler()
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"discord_id": 123456789, "account_id": 987654321},
+ traits={},
+ )
+
+ with patch.object(
+ handler, "_handle_account_linked", new_callable=AsyncMock
+ ) as mock_handle:
+ await handler.handle(body, bot)
+ mock_handle.assert_called_once_with(body, bot)
+
+ @pytest.mark.asyncio
+ async def test_handle_account_unlinked_event(self, bot):
+ """Test handle method routes ACCOUNT_UNLINKED event correctly."""
+ handler = AccountHandler()
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_UNLINKED,
+ properties={"discord_id": 123456789, "account_id": 987654321},
+ traits={},
+ )
+
+ with patch.object(
+ handler, "_handle_account_unlinked", new_callable=AsyncMock
+ ) as mock_handle:
+ await handler.handle(body, bot)
+ mock_handle.assert_called_once_with(body, bot)
+
+ @pytest.mark.asyncio
+ async def test_handle_account_deleted_event(self, bot):
+ """Test handle method with ACCOUNT_DELETED event."""
+ handler = AccountHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ mock_member = helpers.MockMember(id=discord_id)
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_DELETED,
+ properties={"discord_id": discord_id, "account_id": account_id},
+ traits={},
+ )
+
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member),
+ patch("src.webhooks.handlers.account.settings") as mock_settings,
+ ):
+ mock_settings.roles.VERIFIED = helpers.MockRole(id=99999, name="Verified")
+ mock_member.remove_roles = AsyncMock()
+
+ result = await handler.handle(body, bot)
+
+ # Should succeed and return success
+ assert result == handler.success()
+
+ @pytest.mark.asyncio
+ async def test_handle_unknown_event(self, bot):
+ """Test handle method with unknown event raises ValueError."""
+ handler = AccountHandler()
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.CERTIFICATE_AWARDED, # Not handled by AccountHandler
+ properties={"discord_id": 123456789, "account_id": 987654321},
+ traits={},
+ )
+
+ # Should raise ValueError for unknown event
+ with pytest.raises(ValueError, match="Invalid event"):
+ await handler.handle(body, bot)
+
+ @pytest.mark.asyncio
+ async def test_handle_account_linked_success(self, bot):
+ """Test successful account linking."""
+ handler = AccountHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ mock_member = helpers.MockMember(id=discord_id, mention="@testuser")
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"discord_id": discord_id, "account_id": account_id},
+ traits={"htb_user_id": 555},
+ )
+
+ with (
+ patch.object(
+ handler, "validate_discord_id", return_value=discord_id
+ ) as mock_validate_discord,
+ patch.object(
+ handler, "validate_account_id", return_value=account_id
+ ) as mock_validate_account,
+ patch.object(
+ handler,
+ "get_guild_member",
+ new_callable=AsyncMock,
+ return_value=mock_member,
+ ) as mock_get_member,
+ patch.object(
+ handler,
+ "merge_properties_and_traits",
+ return_value={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "htb_user_id": 555,
+ },
+ ) as mock_merge,
+ patch(
+ "src.webhooks.handlers.account.process_account_identification",
+ new_callable=AsyncMock,
+ ) as mock_process,
+ patch("src.webhooks.handlers.account.settings") as mock_settings,
+ patch.object(handler.logger, "info") as mock_log,
+ ):
+ mock_settings.channels.VERIFY_LOGS = 12345
+
+ # Mock the bot's guild structure and channel
+ mock_channel = MagicMock()
+ mock_channel.send = AsyncMock()
+ mock_guild = MagicMock()
+ mock_guild.get_channel.return_value = mock_channel
+ bot.guilds = [mock_guild]
+
+ result = await handler._handle_account_linked(body, bot)
+
+ # Verify all method calls
+ mock_validate_discord.assert_called_once_with(discord_id)
+ mock_validate_account.assert_called_once_with(account_id)
+ mock_get_member.assert_called_once_with(discord_id, bot)
+ mock_merge.assert_called_once_with(body.properties, body.traits)
+ mock_process.assert_called_once_with(
+ mock_member,
+ bot,
+ traits={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "htb_user_id": 555,
+ },
+ )
+ mock_channel.send.assert_called_once_with(
+ f"Account linked: {account_id} -> (@testuser ({discord_id})"
+ )
+ mock_log.assert_called_once_with(
+ f"Account {account_id} linked to {discord_id}",
+ extra={"account_id": account_id, "discord_id": discord_id},
+ )
+
+ # Should return success
+ assert result == handler.success()
+
+ @pytest.mark.asyncio
+ async def test_handle_account_linked_invalid_discord_id(self, bot):
+ """Test account linking with invalid Discord ID."""
+ handler = AccountHandler()
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"discord_id": None, "account_id": 987654321},
+ traits={},
+ )
+
+ with patch.object(
+ handler,
+ "validate_discord_id",
+ side_effect=HTTPException(status_code=400, detail="Invalid Discord ID"),
+ ):
+ with pytest.raises(HTTPException) as exc_info:
+ await handler._handle_account_linked(body, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid Discord ID"
+
+ @pytest.mark.asyncio
+ async def test_handle_account_linked_invalid_account_id(self, bot):
+ """Test account linking with invalid Account ID."""
+ handler = AccountHandler()
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"discord_id": 123456789, "account_id": None},
+ traits={},
+ )
+
+ with (
+ patch.object(handler, "validate_discord_id", return_value=123456789),
+ patch.object(
+ handler,
+ "validate_account_id",
+ side_effect=HTTPException(status_code=400, detail="Invalid Account ID"),
+ ),
+ ):
+ with pytest.raises(HTTPException) as exc_info:
+ await handler._handle_account_linked(body, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid Account ID"
+
+ @pytest.mark.asyncio
+ async def test_handle_account_linked_user_not_in_guild(self, bot):
+ """Test account linking when user is not in the Discord guild."""
+ handler = AccountHandler()
+ discord_id = 123456789
+ account_id = 987654321
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"discord_id": discord_id, "account_id": account_id},
+ traits={},
+ )
+
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(
+ handler,
+ "get_guild_member",
+ new_callable=AsyncMock,
+ side_effect=HTTPException(
+ status_code=400, detail="User is not in the Discord server"
+ ),
+ ),
+ ):
+ with pytest.raises(HTTPException) as exc_info:
+ await handler._handle_account_linked(body, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "User is not in the Discord server"
+
+ @pytest.mark.asyncio
+ async def test_handle_account_unlinked_success(self, bot):
+ """Test successful account unlinking."""
+ handler = AccountHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ mock_member = helpers.MockMember(id=discord_id)
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_UNLINKED,
+ properties={"discord_id": discord_id, "account_id": account_id},
+ traits={},
+ )
+
+ with (
+ patch.object(
+ handler, "validate_discord_id", return_value=discord_id
+ ) as mock_validate_discord,
+ patch.object(
+ handler, "validate_account_id", return_value=account_id
+ ) as mock_validate_account,
+ patch.object(
+ handler,
+ "get_guild_member",
+ new_callable=AsyncMock,
+ return_value=mock_member,
+ ) as mock_get_member,
+ patch("src.webhooks.handlers.account.settings") as mock_settings,
+ ):
+ # Mock the bot's guild structure and role
+ mock_role = helpers.MockRole(id=99999, name="Verified")
+ mock_guild = MagicMock()
+ mock_guild.get_role.return_value = mock_role
+ bot.guilds = [mock_guild]
+ mock_settings.roles.VERIFIED = 99999
+ mock_member.remove_roles = AsyncMock()
+
+ result = await handler._handle_account_unlinked(body, bot)
+
+ # Verify all method calls
+ mock_validate_discord.assert_called_once_with(discord_id)
+ mock_validate_account.assert_called_once_with(account_id)
+ mock_get_member.assert_called_once_with(discord_id, bot)
+ mock_member.remove_roles.assert_called_once_with(
+ mock_role, atomic=True
+ )
+
+ # Should return success
+ assert result == handler.success()
+
+ @pytest.mark.asyncio
+ async def test_handle_account_unlinked_invalid_discord_id(self, bot):
+ """Test account unlinking with invalid Discord ID."""
+ handler = AccountHandler()
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_UNLINKED,
+ properties={"discord_id": None, "account_id": 987654321},
+ traits={},
+ )
+
+ with patch.object(
+ handler,
+ "validate_discord_id",
+ side_effect=HTTPException(status_code=400, detail="Invalid Discord ID"),
+ ):
+ with pytest.raises(HTTPException) as exc_info:
+ await handler._handle_account_unlinked(body, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid Discord ID"
+
+ @pytest.mark.asyncio
+ async def test_handle_account_unlinked_invalid_account_id(self, bot):
+ """Test account unlinking with invalid Account ID."""
+ handler = AccountHandler()
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_UNLINKED,
+ properties={"discord_id": 123456789, "account_id": None},
+ traits={},
+ )
+
+ with (
+ patch.object(handler, "validate_discord_id", return_value=123456789),
+ patch.object(
+ handler,
+ "validate_account_id",
+ side_effect=HTTPException(status_code=400, detail="Invalid Account ID"),
+ ),
+ ):
+ with pytest.raises(HTTPException) as exc_info:
+ await handler._handle_account_unlinked(body, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid Account ID"
+
+ @pytest.mark.asyncio
+ async def test_handle_account_unlinked_user_not_in_guild(self, bot):
+ """Test account unlinking when user is not in the Discord guild."""
+ handler = AccountHandler()
+ discord_id = 123456789
+ account_id = 987654321
+
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_UNLINKED,
+ properties={"discord_id": discord_id, "account_id": account_id},
+ traits={},
+ )
+
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(
+ handler,
+ "get_guild_member",
+ new_callable=AsyncMock,
+ side_effect=HTTPException(
+ status_code=400, detail="User is not in the Discord server"
+ ),
+ ),
+ ):
+ with pytest.raises(HTTPException) as exc_info:
+ await handler._handle_account_unlinked(body, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "User is not in the Discord server"
+
+ @pytest.mark.asyncio
+ async def test_handle_name_change_success(self, bot):
+ """Test successful name change event."""
+ handler = AccountHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ new_name = "NewNickname"
+ mock_member = helpers.MockMember(id=discord_id)
+ mock_member.edit = AsyncMock()
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.NAME_CHANGE,
+ properties={"discord_id": discord_id, "account_id": account_id, "name": new_name},
+ traits={},
+ )
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(handler, "validate_property", return_value=new_name),
+ patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member),
+ ):
+ result = await handler._handle_name_change(body, bot)
+ mock_member.edit.assert_called_once_with(nick=new_name)
+ assert result == handler.success()
+
+ @pytest.mark.asyncio
+ async def test_handle_name_change_invalid_discord_id(self, bot):
+ """Test name change event with invalid Discord ID."""
+ handler = AccountHandler()
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.NAME_CHANGE,
+ properties={"discord_id": None, "account_id": 987654321, "name": "NewNickname"},
+ traits={},
+ )
+ with patch.object(
+ handler,
+ "validate_discord_id",
+ side_effect=HTTPException(status_code=400, detail="Invalid Discord ID"),
+ ):
+ with pytest.raises(HTTPException) as exc_info:
+ await handler._handle_name_change(body, bot)
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid Discord ID"
+
+ @pytest.mark.asyncio
+ async def test_handle_account_banned_success(self, bot):
+ """Test successful account banned event."""
+ handler = AccountHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ expires_at = "2024-12-31T23:59:59"
+ reason = "Violation"
+ notes = "Repeated violations"
+ created_by = "Admin"
+ mock_member = helpers.MockMember(id=discord_id)
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_BANNED,
+ properties={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "expires_at": expires_at,
+ "reason": reason,
+ "notes": notes,
+ "created_by": created_by,
+ },
+ traits={},
+ )
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(handler, "validate_property", return_value=expires_at),
+ patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member),
+ patch("src.webhooks.handlers.account.handle_platform_ban_or_update", new_callable=AsyncMock) as mock_ban,
+ patch("src.webhooks.handlers.account.settings") as mock_settings,
+ patch.object(handler.logger, "debug") as mock_log,
+ ):
+ mock_ban.return_value = {"action": "banned"}
+ mock_settings.channels.BOT_LOGS = 12345
+ mock_settings.channels.VERIFY_LOGS = 54321
+ mock_settings.roles.VERIFIED = 99999
+ mock_settings.guild_ids = [1]
+ bot.guilds = [helpers.MockGuild(id=1)]
+ result = await handler._handle_account_banned(body, bot)
+ mock_ban.assert_awaited()
+ mock_log.assert_called()
+ assert result == handler.success()
+
+ @pytest.mark.asyncio
+ async def test_handle_account_banned_member_not_found(self, bot):
+ """Test account banned event when member is not found in guild."""
+ handler = AccountHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ expires_at = "2024-12-31T23:59:59"
+ body = WebhookBody(
+ platform=Platform.ACCOUNT,
+ event=WebhookEvent.ACCOUNT_BANNED,
+ properties={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "expires_at": expires_at,
+ },
+ traits={},
+ )
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(handler, "validate_property", return_value=expires_at),
+ patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=None),
+ patch.object(handler.logger, "warning") as mock_log,
+ ):
+ result = await handler._handle_account_banned(body, bot)
+ mock_log.assert_called()
+ assert result == handler.fail()
diff --git a/tests/src/webhooks/handlers/test_base.py b/tests/src/webhooks/handlers/test_base.py
new file mode 100644
index 0000000..ba774cf
--- /dev/null
+++ b/tests/src/webhooks/handlers/test_base.py
@@ -0,0 +1,315 @@
+import logging
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from discord import Bot
+from discord.errors import NotFound
+from fastapi import HTTPException
+
+from src.webhooks.handlers.base import BaseHandler
+from src.webhooks.types import WebhookBody, Platform, WebhookEvent
+from tests import helpers
+
+
+class ConcreteHandler(BaseHandler):
+ """Concrete implementation of BaseHandler for testing purposes."""
+
+ async def handle(self, body: WebhookBody, bot: Bot) -> dict:
+ return {"status": "handled"}
+
+
+class TestBaseHandler:
+ """Test the `BaseHandler` class."""
+
+ def test_initialization(self):
+ """Test that BaseHandler initializes correctly."""
+ handler = ConcreteHandler()
+
+ assert isinstance(handler.logger, logging.Logger)
+ assert handler.logger.name == "ConcreteHandler"
+
+ def test_constants(self):
+ """Test that all required constants are defined."""
+ handler = ConcreteHandler()
+
+ assert handler.ACADEMY_USER_ID == "academy_user_id"
+ assert handler.MP_USER_ID == "mp_user_id"
+ assert handler.EP_USER_ID == "ep_user_id"
+ assert handler.CTF_USER_ID == "ctf_user_id"
+ assert handler.ACCOUNT_ID == "account_id"
+ assert handler.DISCORD_ID == "discord_id"
+
+ @pytest.mark.asyncio
+ async def test_get_guild_member_success(self, bot):
+ """Test successful guild member retrieval."""
+ handler = ConcreteHandler()
+ discord_id = 123456789
+ mock_guild = helpers.MockGuild(id=12345)
+ mock_member = helpers.MockMember(id=discord_id)
+
+ bot.fetch_guild = AsyncMock(return_value=mock_guild)
+ mock_guild.fetch_member = AsyncMock(return_value=mock_member)
+
+ with patch("src.webhooks.handlers.base.settings") as mock_settings:
+ mock_settings.guild_ids = [12345]
+
+ result = await handler.get_guild_member(discord_id, bot)
+
+ assert result == mock_member
+ bot.fetch_guild.assert_called_once_with(12345)
+ mock_guild.fetch_member.assert_called_once_with(discord_id)
+
+ @pytest.mark.asyncio
+ async def test_get_guild_member_not_found(self, bot):
+ """Test guild member retrieval when user is not in server."""
+ handler = ConcreteHandler()
+ discord_id = 123456789
+ mock_guild = helpers.MockGuild(id=12345)
+
+ bot.fetch_guild = AsyncMock(return_value=mock_guild)
+ mock_guild.fetch_member = AsyncMock(
+ side_effect=NotFound(MagicMock(), "User not found")
+ )
+
+ with patch("src.webhooks.handlers.base.settings") as mock_settings:
+ mock_settings.guild_ids = [12345]
+
+ with pytest.raises(HTTPException) as exc_info:
+ await handler.get_guild_member(discord_id, bot)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "User is not in the Discord server"
+
+ def test_validate_property_success(self):
+ """Test successful property validation."""
+ handler = ConcreteHandler()
+
+ result = handler.validate_property("valid_value", "test_property")
+ assert result == "valid_value"
+
+ result = handler.validate_property(123, "test_number")
+ assert result == 123
+
+ def test_validate_property_none(self):
+ """Test property validation with None value."""
+ handler = ConcreteHandler()
+
+ with pytest.raises(HTTPException) as exc_info:
+ handler.validate_property(None, "test_property")
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid test_property"
+
+ def test_validate_discord_id_success(self):
+ """Test successful Discord ID validation."""
+ handler = ConcreteHandler()
+
+ result = handler.validate_discord_id(123456789)
+ assert result == 123456789
+
+ result = handler.validate_discord_id("987654321")
+ assert result == "987654321"
+
+ def test_validate_discord_id_none(self):
+ """Test Discord ID validation with None value."""
+ handler = ConcreteHandler()
+
+ with pytest.raises(HTTPException) as exc_info:
+ handler.validate_discord_id(None)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid Discord ID"
+
+ def test_validate_account_id_success(self):
+ """Test successful Account ID validation."""
+ handler = ConcreteHandler()
+
+ result = handler.validate_account_id(123456789)
+ assert result == 123456789
+
+ result = handler.validate_account_id("987654321")
+ assert result == "987654321"
+
+ def test_validate_account_id_none(self):
+ """Test Account ID validation with None value."""
+ handler = ConcreteHandler()
+
+ with pytest.raises(HTTPException) as exc_info:
+ handler.validate_account_id(None)
+
+ assert exc_info.value.status_code == 400
+ assert exc_info.value.detail == "Invalid Account ID"
+
+ def test_get_property_or_trait_from_properties(self):
+ """Test getting value from properties."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"test_key": 123},
+ traits={"test_key": 456, "other_key": 789},
+ )
+
+ result = handler.get_property_or_trait(body, "test_key")
+ assert result == 123 # Should prioritize properties over traits
+
+ def test_get_property_or_trait_from_traits(self):
+ """Test getting value from traits when not in properties."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={},
+ traits={"test_key": 456},
+ )
+
+ result = handler.get_property_or_trait(body, "test_key")
+ assert result == 456
+
+ def test_get_property_or_trait_not_found(self):
+ """Test getting value when key is not found."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={},
+ traits={},
+ )
+
+ result = handler.get_property_or_trait(body, "missing_key")
+ assert result is None
+
+ def test_merge_properties_and_traits_no_duplicates(self):
+ """Test merging properties and traits without duplicates."""
+ handler = ConcreteHandler()
+ properties = {"key1": 1, "key2": 2}
+ traits = {"key3": 3, "key4": 4}
+
+ result = handler.merge_properties_and_traits(properties, traits)
+
+ expected = {"key1": 1, "key2": 2, "key3": 3, "key4": 4}
+ assert result == expected
+
+ def test_merge_properties_and_traits_with_duplicates(self):
+ """Test merging properties and traits with duplicate keys."""
+ handler = ConcreteHandler()
+ properties = {"key1": 1, "key2": 2}
+ traits = {"key2": 99, "key3": 3} # key2 is duplicate
+
+ result = handler.merge_properties_and_traits(properties, traits)
+
+ expected = {"key1": 1, "key2": 2, "key3": 3} # Properties value should win
+ assert result == expected
+
+ def test_merge_properties_and_traits_empty_properties(self):
+ """Test merging when properties is empty."""
+ handler = ConcreteHandler()
+ properties = {}
+ traits = {"key1": 1, "key2": 2}
+
+ result = handler.merge_properties_and_traits(properties, traits)
+
+ assert result == traits
+
+ def test_merge_properties_and_traits_empty_traits(self):
+ """Test merging when traits is empty."""
+ handler = ConcreteHandler()
+ properties = {"key1": 1, "key2": 2}
+ traits = {}
+
+ result = handler.merge_properties_and_traits(properties, traits)
+
+ assert result == properties
+
+ def test_get_platform_properties_all_present(self):
+ """Test getting platform properties when all are present."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={
+ "account_id": 1,
+ "mp_user_id": 2,
+ "ep_user_id": 3,
+ "ctf_user_id": 4,
+ "academy_user_id": 5,
+ },
+ traits={},
+ )
+
+ result = handler.get_platform_properties(body)
+
+ expected = {
+ "account_id": 1,
+ "mp_user_id": 2,
+ "ep_user_id": 3,
+ "ctf_user_id": 4,
+ "academy_user_id": 5,
+ }
+ assert result == expected
+
+ def test_get_platform_properties_mixed_sources(self):
+ """Test getting platform properties from both properties and traits."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"account_id": 1, "mp_user_id": 2},
+ traits={"ep_user_id": 3, "ctf_user_id": 4, "academy_user_id": 5},
+ )
+
+ result = handler.get_platform_properties(body)
+
+ expected = {
+ "account_id": 1,
+ "mp_user_id": 2,
+ "ep_user_id": 3,
+ "ctf_user_id": 4,
+ "academy_user_id": 5,
+ }
+ assert result == expected
+
+ def test_get_platform_properties_missing_values(self):
+ """Test getting platform properties when some are missing."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"account_id": 1},
+ traits={"mp_user_id": 2},
+ )
+
+ result = handler.get_platform_properties(body)
+
+ expected = {
+ "account_id": 1,
+ "mp_user_id": 2,
+ "ep_user_id": None,
+ "ctf_user_id": None,
+ "academy_user_id": None,
+ }
+ assert result == expected
+
+ def test_get_platform_properties_properties_override_traits(self):
+ """Test that properties override traits for the same key."""
+ handler = ConcreteHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.ACCOUNT_LINKED,
+ properties={"account_id": 1, "mp_user_id": 2},
+ traits={
+ "mp_user_id": 999, # Should be overridden
+ "ep_user_id": 3,
+ },
+ )
+
+ result = handler.get_platform_properties(body)
+
+ expected = {
+ "account_id": 1,
+ "mp_user_id": 2, # Properties value should win
+ "ep_user_id": 3,
+ "ctf_user_id": None,
+ "academy_user_id": None,
+ }
+ assert result == expected
diff --git a/tests/src/webhooks/handlers/test_mp.py b/tests/src/webhooks/handlers/test_mp.py
new file mode 100644
index 0000000..3454d1b
--- /dev/null
+++ b/tests/src/webhooks/handlers/test_mp.py
@@ -0,0 +1,223 @@
+import pytest
+from unittest.mock import AsyncMock, patch, MagicMock
+from fastapi import HTTPException
+
+from src.webhooks.handlers.mp import MPHandler
+from src.webhooks.types import WebhookBody, Platform, WebhookEvent
+from tests import helpers
+
+class TestMPHandler:
+ @pytest.mark.asyncio
+ async def test_handle_invalid_event(self, bot):
+ handler = MPHandler()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.CERTIFICATE_AWARDED, # Not handled by MPHandler
+ properties={},
+ traits={},
+ )
+ with pytest.raises(ValueError, match="Invalid event"):
+ await handler.handle(body, bot)
+
+ @pytest.mark.asyncio
+ async def test_handle_subscription_change_success(self, bot):
+ handler = MPHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ subscription_name = "VIP"
+ mock_member = helpers.MockMember(id=discord_id)
+ mock_member.add_roles = AsyncMock()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.SUBSCRIPTION_CHANGE,
+ properties={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "subscription_name": subscription_name,
+ },
+ traits={},
+ )
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(handler, "validate_property", return_value=subscription_name),
+ patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member),
+ patch("src.webhooks.handlers.mp.settings") as mock_settings,
+ ):
+ mock_settings.get_post_or_rank.return_value = 555
+ mock_guild = helpers.MockGuild(id=1)
+ mock_guild.get_role.return_value = 555
+ bot.guilds = [mock_guild]
+ result = await handler._handle_subscription_change(body, bot)
+ mock_member.add_roles.assert_awaited()
+ assert result == handler.success()
+
+ @pytest.mark.asyncio
+ async def test_handle_subscription_change_invalid_role(self, bot):
+ handler = MPHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ subscription_name = "INVALID"
+ mock_member = helpers.MockMember(id=discord_id)
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.SUBSCRIPTION_CHANGE,
+ properties={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "subscription_name": subscription_name,
+ },
+ traits={},
+ )
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(handler, "validate_property", return_value=subscription_name),
+ patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member),
+ patch("src.webhooks.handlers.mp.settings") as mock_settings,
+ ):
+ mock_settings.get_post_or_rank.return_value = None
+ with pytest.raises(ValueError, match="Invalid subscription name"):
+ await handler._handle_subscription_change(body, bot)
+
+ @pytest.mark.asyncio
+ async def test_handle_hof_change_success_top1(self, bot):
+ handler = MPHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ hof_tier = "1"
+ mock_member = helpers.MockMember(id=discord_id)
+ mock_member.roles = []
+ mock_member.add_roles = AsyncMock()
+ mock_member.remove_roles = AsyncMock()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.HOF_CHANGE,
+ properties={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "hof_tier": hof_tier,
+ },
+ traits={},
+ )
+ mock_role_1 = MagicMock()
+ mock_role_10 = MagicMock()
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(handler, "validate_property", return_value=hof_tier),
+ patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member),
+ patch("src.webhooks.handlers.mp.settings") as mock_settings,
+ patch.object(handler, "_find_user_with_role", new_callable=AsyncMock, return_value=None),
+ ):
+ mock_settings.roles.RANK_ONE = 1
+ mock_settings.roles.RANK_TEN = 10
+ mock_guild = helpers.MockGuild(id=1)
+ mock_guild.get_role.side_effect = lambda rid: mock_role_1 if rid == 1 else mock_role_10
+ bot.guilds = [mock_guild]
+ result = await handler._handle_hof_change(body, bot)
+ mock_member.add_roles.assert_awaited_with(mock_role_1, atomic=True)
+ assert result == handler.success()
+
+ @pytest.mark.asyncio
+ async def test_handle_hof_change_invalid_tier(self, bot):
+ handler = MPHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ hof_tier = "99"
+ mock_member = helpers.MockMember(id=discord_id)
+ mock_member.roles = []
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.HOF_CHANGE,
+ properties={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "hof_tier": hof_tier,
+ },
+ traits={},
+ )
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(handler, "validate_property", return_value=hof_tier),
+ patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member),
+ patch("src.webhooks.handlers.mp.settings") as mock_settings,
+ ):
+ mock_settings.roles.RANK_ONE = 1
+ mock_settings.roles.RANK_TEN = 10
+ mock_guild = helpers.MockGuild(id=1)
+ mock_guild.get_role.side_effect = lambda rid: None
+ bot.guilds = [mock_guild]
+ with pytest.raises(ValueError, match="Invalid HOF tier"):
+ await handler._handle_hof_change(body, bot)
+
+ @pytest.mark.asyncio
+ async def test_handle_rank_up_success(self, bot):
+ handler = MPHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ rank = "Elite Hacker"
+ mock_member = helpers.MockMember(id=discord_id)
+ mock_member.roles = []
+ mock_member.add_roles = AsyncMock()
+ mock_member.remove_roles = AsyncMock()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.RANK_UP,
+ properties={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "rank": rank,
+ },
+ traits={},
+ )
+ mock_role = MagicMock()
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(handler, "validate_property", return_value=rank),
+ patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member),
+ patch("src.webhooks.handlers.mp.settings") as mock_settings,
+ ):
+ mock_settings.role_groups = {"ALL_RANKS": [555]}
+ mock_guild = helpers.MockGuild(id=1)
+ mock_guild.get_role.return_value = mock_role
+ bot.guilds = [mock_guild]
+ result = await handler._handle_rank_up(body, bot)
+ mock_member.add_roles.assert_awaited_with(mock_role, atomic=True)
+ assert result == handler.success()
+
+ @pytest.mark.asyncio
+ async def test_handle_rank_up_invalid_role(self, bot):
+ handler = MPHandler()
+ discord_id = 123456789
+ account_id = 987654321
+ rank = "Nonexistent"
+ mock_member = helpers.MockMember(id=discord_id)
+ mock_member.roles = []
+ mock_member.add_roles = AsyncMock()
+ mock_member.remove_roles = AsyncMock()
+ body = WebhookBody(
+ platform=Platform.MAIN,
+ event=WebhookEvent.RANK_UP,
+ properties={
+ "discord_id": discord_id,
+ "account_id": account_id,
+ "rank": rank,
+ },
+ traits={},
+ )
+ with (
+ patch.object(handler, "validate_discord_id", return_value=discord_id),
+ patch.object(handler, "validate_account_id", return_value=account_id),
+ patch.object(handler, "validate_property", return_value=rank),
+ patch.object(handler, "get_guild_member", new_callable=AsyncMock, return_value=mock_member),
+ patch("src.webhooks.handlers.mp.settings") as mock_settings,
+ ):
+ mock_settings.role_groups = {"ALL_RANKS": [555]}
+ mock_guild = helpers.MockGuild(id=1)
+ mock_guild.get_role.return_value = None
+ bot.guilds = [mock_guild]
+ with pytest.raises(ValueError, match="Cannot find role for"):
+ await handler._handle_rank_up(body, bot)
\ No newline at end of file
diff --git a/tests/src/webhooks/test_handlers_init.py b/tests/src/webhooks/test_handlers_init.py
new file mode 100644
index 0000000..ce0b436
--- /dev/null
+++ b/tests/src/webhooks/test_handlers_init.py
@@ -0,0 +1,35 @@
+from unittest import mock
+from typing import Callable
+
+import pytest
+
+from src.webhooks.handlers import handlers, can_handle, handle
+from src.webhooks.types import Platform, WebhookBody, WebhookEvent
+from tests.conftest import bot
+
+class TestHandlersInit:
+ def test_handler_init(self):
+ assert handlers is not None
+ assert isinstance(handlers, dict)
+ assert len(handlers) > 0
+ assert all(isinstance(handler, Callable) for handler in handlers.values())
+
+ def test_can_handle_unknown_platform(self):
+ assert not can_handle("UNKNOWN")
+
+ def test_can_handle_success(self):
+ with mock.patch("src.webhooks.handlers.handlers", {Platform.MAIN: lambda x, y: True}):
+ assert can_handle(Platform.MAIN)
+
+ def test_handle_success(self):
+ with mock.patch("src.webhooks.handlers.handlers", {Platform.MAIN: lambda x, y: 1337}):
+ assert handle(WebhookBody(platform=Platform.MAIN, event=WebhookEvent.ACCOUNT_LINKED, properties={}, traits={}), bot) == 1337
+
+ def test_handle_unknown_platform(self):
+ with pytest.raises(ValueError):
+ handle(WebhookBody(platform="UNKNOWN", event=WebhookEvent.ACCOUNT_LINKED, properties={}, traits={}), bot)
+
+ def test_handle_unknown_event(self):
+ with mock.patch("src.webhooks.handlers.handlers", {Platform.MAIN: lambda x, y: 1337}):
+ with pytest.raises(ValueError):
+ handle(WebhookBody(platform=Platform.MAIN, event="UNKNOWN", properties={}, traits={}), bot)
\ No newline at end of file