Skip to content

Commit

Permalink
Allow punishment to be optional for basic AutoMod rules
Browse files Browse the repository at this point in the history
  • Loading branch information
LightSage committed Apr 28, 2023
1 parent 891abe0 commit d736c71
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions sanctum/routers/automod.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import asyncpg
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from pydantic import BaseModel, root_validator

from ..app import Request
from ..errors import NotFound
Expand Down Expand Up @@ -61,11 +61,19 @@ class AutoModPunishmentModel(BaseModel):

class AutoModEventModel(BaseModel):
guild_id: int
type: Literal['message-spam', 'mass-mentions', 'url-spam', 'invite-spam', 'message-content-spam']
type: Literal['message-spam', 'mass-mentions', 'url-spam', 'invite-spam',
'message-content-spam', 'auto-dehoist', 'auto-normalize']
count: int
seconds: int
ignores: Optional[List[int]] = []
punishment: AutoModPunishmentModel
punishment: Optional[AutoModPunishmentModel]

@root_validator
def check_punishment(cls, values):
_type, punishment = values.get("type"), values.get('punishment')
if _type not in ("auto-dehoist", "auto-normalize") and punishment is None:
raise ValueError(f'{_type} requires a punishment')
return values

class Config:
schema_extra = {
Expand All @@ -90,7 +98,7 @@ def from_record(cls, record: asyncpg.Record) -> Self:
async def get_guild_automod_rules(guild_id: int, request: Request) -> List[AutoModEventDBModel]:
"""Gets a guild's automod rule configuration"""
query = """SELECT events.*, punishment.duration AS punishment_duration, punishment.type AS punishment_type FROM guild_automod_rules events
INNER JOIN guild_automod_punishment AS punishment ON events.id = punishment.id
LEFT OUTER JOIN guild_automod_punishment AS punishment ON events.id = punishment.id
WHERE events.guild_id=$1;"""
records = await request.app.pool.fetch(query, guild_id)
if not records:
Expand All @@ -116,10 +124,12 @@ async def add_new_automod_rule(guild_id: int, event: AutoModEventModel, request:
VALUES ($1, $2, $3, $4, $5)
RETURNING id;"""
rnum = await conn.fetchval(query, guild_id, event.type, event.count, event.seconds, event.ignores)
query = """INSERT INTO guild_automod_punishment (id, type, duration)
VALUES ($1, $2, $3);"""
await conn.execute(query, rnum, event.punishment.type,
event.punishment.duration)

if event.punishment:
query = """INSERT INTO guild_automod_punishment (id, type, duration)
VALUES ($1, $2, $3);"""
await conn.execute(query, rnum, event.punishment.type,
event.punishment.duration)

return {"id": rnum}

Expand Down

0 comments on commit d736c71

Please sign in to comment.