Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/kernelbot/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,32 @@ async def _stream_submission_response(
pass


@app.post("/admin/ban/{user_id}")
async def admin_ban_user(
user_id: str,
_: Annotated[None, Depends(require_admin)],
db_context=Depends(get_db),
) -> dict:
with db_context as db:
found = db.ban_user(user_id)
if not found:
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
return {"status": "ok", "user_id": user_id, "banned": True}


@app.delete("/admin/ban/{user_id}")
async def admin_unban_user(
user_id: str,
_: Annotated[None, Depends(require_admin)],
db_context=Depends(get_db),
) -> dict:
with db_context as db:
found = db.unban_user(user_id)
if not found:
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
return {"status": "ok", "user_id": user_id, "banned": False}


@app.post("/{leaderboard_name}/{gpu_type}/{submission_mode}")
async def run_submission( # noqa: C901
leaderboard_name: str,
Expand Down
46 changes: 46 additions & 0 deletions src/kernelbot/cogs/admin_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ def __init__(self, bot: "ClusterBot"):
name="set-forum-ids", description="Sets forum IDs"
)(self.set_forum_ids)

self.ban_user = bot.admin_group.command(
name="ban", description="Ban a user from making submissions"
)(self.ban_user)

self.unban_user = bot.admin_group.command(
name="unban", description="Unban a user"
)(self.unban_user)

self.export_to_hf = bot.admin_group.command(
name="export-hf", description="Export competition data to Hugging Face dataset"
)(self.export_to_hf)
Expand Down Expand Up @@ -154,6 +162,44 @@ async def is_creator_check(
return True
return False

@discord.app_commands.describe(user_id="Discord user ID to ban")
@with_error_handling
async def ban_user(self, interaction: discord.Interaction, user_id: str):
if not await self.admin_check(interaction):
await send_discord_message(
interaction, "You need to have Admin permissions to run this command", ephemeral=True
)
return

with self.bot.leaderboard_db as db:
if db.ban_user(user_id):
await send_discord_message(
interaction, f"User `{user_id}` has been banned.", ephemeral=True
)
else:
await send_discord_message(
interaction, f"User `{user_id}` not found.", ephemeral=True
)

@discord.app_commands.describe(user_id="Discord user ID to unban")
@with_error_handling
async def unban_user(self, interaction: discord.Interaction, user_id: str):
if not await self.admin_check(interaction):
await send_discord_message(
interaction, "You need to have Admin permissions to run this command", ephemeral=True
)
return

with self.bot.leaderboard_db as db:
if db.unban_user(user_id):
await send_discord_message(
interaction, f"User `{user_id}` has been unbanned.", ephemeral=True
)
else:
await send_discord_message(
interaction, f"User `{user_id}` not found.", ephemeral=True
)

@discord.app_commands.describe(
directory="Directory of the kernel definition. Also used as the leaderboard's name",
gpu="The GPU to submit to. Leave empty for interactive selection/multiple GPUs",
Expand Down
53 changes: 53 additions & 0 deletions src/libkernelbot/leaderboard_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,6 +1445,59 @@ def validate_cli_id(self, cli_id: str) -> Optional[dict[str, str]]:
raise KernelBotError("Error validating CLI ID") from e


def ban_user(self, user_id: str) -> bool:
"""Ban a user by their ID. Returns True if the user was found and banned."""
try:
self.cursor.execute(
"""
UPDATE leaderboard.user_info
SET is_banned = TRUE
WHERE id = %s
""",
(str(user_id),),
)
self.connection.commit()
return self.cursor.rowcount > 0
except psycopg2.Error as e:
self.connection.rollback()
logger.exception("Error banning user %s", user_id, exc_info=e)
raise KernelBotError("Error banning user") from e

def unban_user(self, user_id: str) -> bool:
"""Unban a user by their ID. Returns True if the user was found and unbanned."""
try:
self.cursor.execute(
"""
UPDATE leaderboard.user_info
SET is_banned = FALSE
WHERE id = %s
""",
(str(user_id),),
)
self.connection.commit()
return self.cursor.rowcount > 0
except psycopg2.Error as e:
self.connection.rollback()
logger.exception("Error unbanning user %s", user_id, exc_info=e)
raise KernelBotError("Error unbanning user") from e

def is_user_banned(self, user_id: str) -> bool:
"""Check if a user is banned."""
try:
self.cursor.execute(
"""
SELECT is_banned FROM leaderboard.user_info
WHERE id = %s
""",
(str(user_id),),
)
row = self.cursor.fetchone()
return row[0] if row else False
except psycopg2.Error as e:
self.connection.rollback()
logger.exception("Error checking ban status for user %s", user_id, exc_info=e)
raise KernelBotError("Error checking ban status") from e

def set_rate_limit(self, leaderboard_name: str, mode_category: str, max_per_hour: int) -> RateLimitItem:
try:
self.cursor.execute(
Expand Down
4 changes: 4 additions & 0 deletions src/libkernelbot/submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def prepare_submission( # noqa: C901
"The bot is currently not accepting any new submissions, please try again later."
)

with backend.db as db:
if db.is_user_banned(str(req.user_id)):
raise KernelBotError("You are banned from making submissions.")

if profanity.contains_profanity(req.file_name):
raise KernelBotError("Please provide a non-rude filename")

Expand Down
22 changes: 22 additions & 0 deletions src/migrations/20260318_01_ban-user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
add_is_banned_to_user_info
"""

from yoyo import step

__depends__ = {'20260317_01_rate-limits'}

steps = [
step(
# forward
"""
ALTER TABLE leaderboard.user_info
ADD COLUMN is_banned BOOLEAN NOT NULL DEFAULT FALSE
""",
# backward
"""
ALTER TABLE leaderboard.user_info
DROP COLUMN is_banned;
"""
)
]
1 change: 1 addition & 0 deletions tests/test_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def mock_backend():
"name": "test_board",
}
db_context.get_leaderboard_gpu_types.return_value = ["A100", "V100"]
db_context.is_user_banned.return_value = False

return backend

Expand Down
Loading