Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/kernelbot/api/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ async def to_submit_info(
user_name=user_name,
gpus=[gpu_type],
leaderboard=leaderboard_name,
identity_type=user_info.get("id_type"),
)
except UnicodeDecodeError:
raise HTTPException(
Expand Down
171 changes: 158 additions & 13 deletions src/kernelbot/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ async def validate_cli_header(
if user_info is None:
raise HTTPException(status_code=401, detail="Invalid or unauthorized X-Popcorn-Cli-Id")

user_info["id_type"] = "cli"
return user_info


Expand Down Expand Up @@ -178,6 +179,38 @@ async def validate_user_header(
return user_info


async def optional_user_header(
x_web_auth_id: Optional[str] = Header(None, alias="X-Web-Auth-Id"),
x_popcorn_cli_id: Optional[str] = Header(None, alias="X-Popcorn-Cli-Id"),
db_context: LeaderboardDB = Depends(get_db),
) -> Optional[Any]:
"""Like validate_user_header but returns None instead of raising when no auth header is present."""
token = x_web_auth_id or x_popcorn_cli_id
if not token:
return None

if x_web_auth_id:
id_type = IdentityType.WEB
else:
id_type = IdentityType.CLI

try:
with db_context as db:
user_info = db.validate_identity(token, id_type)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Database error during validation: {e}",
) from e

if not user_info:
raise HTTPException(
status_code=401,
detail="Invalid or unauthorized auth header",
)
return user_info


def require_admin(
authorization: Optional[str] = Header(None, alias="Authorization"),
) -> None:
Expand All @@ -188,6 +221,16 @@ def require_admin(
raise HTTPException(status_code=401, detail="Invalid admin token")


def enforce_leaderboard_access(db, leaderboard_name: str, user_info: Optional[dict]) -> None:
"""Raise 401/403 if the leaderboard is closed and the user lacks access."""
lb = db.get_leaderboard(leaderboard_name)
if lb.get("visibility") == "closed":
if user_info is None:
raise HTTPException(status_code=401, detail="Authentication required for closed leaderboard")
if not db.check_leaderboard_access(leaderboard_name, user_info["user_id"]):
raise HTTPException(status_code=403, detail="You do not have access to this leaderboard")


@app.get("/auth/init")
async def auth_init(provider: str, db_context=Depends(get_db)) -> dict:
if provider not in ["discord", "github"]:
Expand Down Expand Up @@ -576,13 +619,18 @@ async def create_dev_leaderboard(
except Exception:
pass # Leaderboard doesn't exist, that's fine

visibility = payload.get("visibility", "public")
if visibility not in ("public", "closed"):
raise HTTPException(status_code=400, detail="visibility must be 'public' or 'closed'")

db.create_leaderboard(
name=leaderboard_name,
deadline=deadline_value,
definition=definition,
creator_id=0,
forum_id=-1,
gpu_types=definition.gpus,
visibility=visibility,
)
return {"status": "ok", "leaderboard": leaderboard_name}

Expand Down Expand Up @@ -652,6 +700,9 @@ async def admin_update_problems(
problem_set = payload.get("problem_set")
branch = payload.get("branch", "main")
force = payload.get("force", False)
visibility = payload.get("visibility", "public")
if visibility not in ("public", "closed"):
raise HTTPException(status_code=400, detail="visibility must be 'public' or 'closed'")

try:
result = sync_problems(
Expand All @@ -662,6 +713,7 @@ async def admin_update_problems(
force=force,
creator_id=0, # API-created
forum_id=-1, # No Discord forum
visibility=visibility,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
Expand Down Expand Up @@ -740,21 +792,19 @@ async def get_leaderboards(db_context=Depends(get_db)):


@app.get("/gpus/{leaderboard_name}")
async def get_gpus(leaderboard_name: str, db_context=Depends(get_db)) -> list[str]:
"""An endpoint that returns all GPU types that are available for a given leaderboard and runner.

Args:
leaderboard_name (str): The name of the leaderboard to get the GPU types for.
runner_name (str): The name of the runner to get the GPU types for.

Returns:
list[str]: A list of GPU types that are available for the given leaderboard and runner.
"""
async def get_gpus(
leaderboard_name: str,
user_info: Annotated[Optional[Any], Depends(optional_user_header)] = None,
db_context=Depends(get_db),
) -> list[str]:
"""An endpoint that returns all GPU types that are available for a given leaderboard and runner."""
await simple_rate_limit()
try:
with db_context as db:
enforce_leaderboard_access(db, leaderboard_name, user_info)
return db.get_leaderboard_gpu_types(leaderboard_name)

except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching GPU data: {e}") from e

Expand All @@ -765,29 +815,39 @@ async def get_submissions(
gpu_name: str,
limit: int = None,
offset: int = 0,
user_info: Annotated[Optional[Any], Depends(optional_user_header)] = None,
db_context=Depends(get_db),
) -> list[LeaderboardRankedEntry]:
await simple_rate_limit()
try:
with db_context as db:
# Add validation for leaderboard and GPU? Might be redundant if DB handles it.
enforce_leaderboard_access(db, leaderboard_name, user_info)
return db.get_leaderboard_submissions(
leaderboard_name, gpu_name, limit=limit, offset=offset
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching submissions: {e}") from e


@app.get("/submission_count/{leaderboard_name}/{gpu_name}")
async def get_submission_count(
leaderboard_name: str, gpu_name: str, user_id: str = None, db_context=Depends(get_db)
leaderboard_name: str,
gpu_name: str,
user_id: str = None,
user_info: Annotated[Optional[Any], Depends(optional_user_header)] = None,
db_context=Depends(get_db),
) -> dict:
"""Get the total count of submissions for pagination"""
await simple_rate_limit()
try:
with db_context as db:
enforce_leaderboard_access(db, leaderboard_name, user_info)
count = db.get_leaderboard_submission_count(leaderboard_name, gpu_name, user_id)
return {"count": count}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching submission count: {e}") from e

Expand Down Expand Up @@ -912,3 +972,88 @@ async def delete_user_submission(
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error deleting submission: {e}") from e


@app.post("/admin/invites")
async def admin_generate_invites(
payload: dict,
_: Annotated[None, Depends(require_admin)],
db_context=Depends(get_db),
) -> dict:
"""Generate invite codes covering one or more leaderboards.

Accepts either:
{"leaderboards": ["lb1", "lb2"], "count": 10}
{"leaderboard": "lb1", "count": 10} (single leaderboard shorthand)
"""
count = payload.get("count")
if not isinstance(count, int) or count < 1 or count > 10000:
raise HTTPException(status_code=400, detail="count must be an integer between 1 and 10000")
leaderboards = payload.get("leaderboards") or []
if not leaderboards:
single = payload.get("leaderboard")
if single:
leaderboards = [single]
if not leaderboards or not isinstance(leaderboards, list):
raise HTTPException(status_code=400, detail="Must provide 'leaderboards' list or 'leaderboard' string")
with db_context as db:
codes = db.generate_invite_codes(leaderboards, count)
return {"status": "ok", "leaderboards": leaderboards, "codes": codes}


@app.get("/admin/leaderboards/{leaderboard_name}/invites")
async def admin_list_invites(
leaderboard_name: str,
_: Annotated[None, Depends(require_admin)],
db_context=Depends(get_db),
) -> dict:
"""List all invite codes for a leaderboard with claim status."""
with db_context as db:
invites = db.get_invite_codes(leaderboard_name)
return {"status": "ok", "leaderboard": leaderboard_name, "invites": invites}


@app.delete("/admin/invites/{code}")
async def admin_revoke_invite(
code: str,
_: Annotated[None, Depends(require_admin)],
db_context=Depends(get_db),
) -> dict:
"""Revoke an invite code, removing it from the pool."""
with db_context as db:
result = db.revoke_invite_code(code)
return {"status": "ok", **result}


@app.post("/admin/leaderboards/{leaderboard_name}/visibility")
async def admin_set_visibility(
leaderboard_name: str,
payload: dict,
_: Annotated[None, Depends(require_admin)],
db_context=Depends(get_db),
) -> dict:
"""Change the visibility of an existing leaderboard."""
visibility = payload.get("visibility")
if visibility not in ("public", "closed"):
raise HTTPException(status_code=400, detail="visibility must be 'public' or 'closed'")
with db_context as db:
db.set_leaderboard_visibility(leaderboard_name, visibility)
return {"status": "ok", "leaderboard": leaderboard_name, "visibility": visibility}


@app.post("/user/join")
async def user_join_leaderboard(
payload: dict,
user_info: Annotated[dict, Depends(validate_cli_header)],
db_context=Depends(get_db),
) -> dict:
"""Claim an invite code to join a closed leaderboard. CLI only."""
code = payload.get("code")
if not code:
raise HTTPException(status_code=400, detail="Missing required field: code")
try:
with db_context as db:
result = db.claim_invite_code(code, user_info["user_id"])
except KernelBotError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
return {"status": "ok", "leaderboards": result["leaderboards"]}
16 changes: 12 additions & 4 deletions src/kernelbot/cogs/admin_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ async def leaderboard_create_local(
interaction: discord.Interaction,
directory: str,
gpu: Optional[app_commands.Choice[str]],
closed: bool = False,
):
is_admin = await self.admin_check(interaction)
if not is_admin:
Expand Down Expand Up @@ -218,6 +219,7 @@ async def leaderboard_create_local(
definition=definition,
forum_id=forum_id,
gpu=gpu.value if gpu else None,
visibility="closed" if closed else "public",
):
await send_discord_message(
interaction,
Expand All @@ -241,6 +243,7 @@ async def leaderboard_create_impl( # noqa: C901
deadline: str,
definition: LeaderboardDefinition,
gpus: Optional[str | list[str]],
visibility: str = "public",
):
if len(leaderboard_name) > 95:
await send_discord_message(
Expand Down Expand Up @@ -282,7 +285,8 @@ async def leaderboard_create_impl( # noqa: C901
)

success = await self.create_leaderboard_in_db(
interaction, leaderboard_name, date_value, definition, forum_thread.thread.id, gpus
interaction, leaderboard_name, date_value, definition, forum_thread.thread.id, gpus,
visibility=visibility,
)
if not success:
await forum_thread.delete()
Expand Down Expand Up @@ -331,6 +335,7 @@ async def create_leaderboard_in_db(
definition: LeaderboardDefinition,
forum_id: int,
gpu: Optional[str | list[str]] = None,
visibility: str = "public",
) -> bool:
if gpu is None:
# Ask the user to select GPUs
Expand Down Expand Up @@ -361,6 +366,7 @@ async def create_leaderboard_in_db(
gpu_types=selected_gpus,
creator_id=interaction.user.id,
forum_id=forum_id,
visibility=visibility,
)
except KernelBotError as e:
await send_discord_message(
Expand Down Expand Up @@ -521,6 +527,7 @@ async def update_problems(
problem_set: Optional[str] = None,
branch: Optional[str] = "main",
force: bool = False,
closed: bool = False,
):
is_admin = await self.admin_check(interaction)
if not is_admin:
Expand Down Expand Up @@ -579,7 +586,7 @@ async def update_problems(
)
return
for competition in problem_dir.glob("*.yaml"):
await self.update_competition(interaction, competition)
await self.update_competition(interaction, competition, closed=closed)
else:
problem_set = problem_dir / f"{problem_set}.yaml"
if not problem_set.exists():
Expand All @@ -592,7 +599,7 @@ async def update_problems(
ephemeral=True,
)
return
await self.update_competition(interaction, problem_set, force)
await self.update_competition(interaction, problem_set, force, closed=closed)

async def _create_update_plan( # noqa: C901
self,
Expand Down Expand Up @@ -699,7 +706,7 @@ async def _create_update_plan( # noqa: C901
return update_list, create_list

async def update_competition(
self, interaction: discord.Interaction, spec_file: Path, force: bool = False
self, interaction: discord.Interaction, spec_file: Path, force: bool = False, closed: bool = False
):
try:
root = spec_file.parent
Expand Down Expand Up @@ -738,6 +745,7 @@ async def update_competition(
entry["deadline"],
make_task_definition(root / entry["directory"]),
entry["gpus"],
visibility="closed" if closed else "public",
)
steps += "done\n"

Expand Down
1 change: 1 addition & 0 deletions src/kernelbot/cogs/leaderboard_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ async def submit(
user_name=interaction.user.global_name or interaction.user.name,
gpus=gpu,
leaderboard=leaderboard_name,
identity_type="discord",
)
req = prepare_submission(req, self.bot.backend)

Expand Down
1 change: 1 addition & 0 deletions src/libkernelbot/db_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class LeaderboardItem(TypedDict):
gpu_types: List[str]
forum_id: int
secret_seed: NotRequired[int]
visibility: str


class LeaderboardRankedEntry(TypedDict):
Expand Down
Loading
Loading