Skip to content

Commit

Permalink
Exclude prompts of disabled users from prompt lottery (#1748)
Browse files Browse the repository at this point in the history
Co-authored-by: --show <--show>
  • Loading branch information
andreaskoepf committed Feb 19, 2023
1 parent 7f8163b commit 52945f1
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
3 changes: 2 additions & 1 deletion backend/main.py
Expand Up @@ -21,7 +21,7 @@
from oasst_backend.models import message_tree_state
from oasst_backend.prompt_repository import PromptRepository, UserRepository
from oasst_backend.task_repository import TaskRepository, delete_expired_tasks
from oasst_backend.tree_manager import TreeManager
from oasst_backend.tree_manager import TreeManager, halt_prompts_of_disabled_users
from oasst_backend.user_repository import User
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
Expand Down Expand Up @@ -333,6 +333,7 @@ def update_user_streak(session: Session) -> None:
@managed_tx_function(auto_commit=CommitMode.COMMIT)
def cronjob_delete_expired_tasks(session: Session) -> None:
delete_expired_tasks(session)
halt_prompts_of_disabled_users(session)


@app.on_event("startup")
Expand Down
33 changes: 33 additions & 0 deletions backend/oasst_backend/tree_manager.py
Expand Up @@ -120,6 +120,33 @@ class TreeManagerStats(pydantic.BaseModel):
message_counts: list[TreeMessageCountStats]


def halt_prompts_of_disabled_users(db: Session):
_sql_halt_prompts_of_disabled_users = """
-- remove prompts of disabled & deleted users from prompt lottery
WITH cte AS (
SELECT mts.message_tree_id
FROM message_tree_state mts
JOIN message m ON mts.message_tree_id = m.id
JOIN "user" u ON m.user_id = u.id
WHERE state = :prompt_lottery_waiting_state AND (NOT u.enabled OR u.deleted)
)
UPDATE message_tree_state mts2
SET active=false, state=:halted_by_moderator_state
FROM cte
WHERE mts2.message_tree_id = cte.message_tree_id;
"""

r = db.execute(
text(_sql_halt_prompts_of_disabled_users),
{
"prompt_lottery_waiting_state": message_tree_state.State.PROMPT_LOTTERY_WAITING,
"halted_by_moderator_state": message_tree_state.State.HALTED_BY_MODERATOR,
},
)
if r.rowcount > 0:
logger.info(f"Halted {r.rowcount} prompts of disabled users.")


class TreeManager:
def __init__(
self,
Expand Down Expand Up @@ -240,16 +267,20 @@ def _prompt_lottery(self, lang: str, max_activate: int = 1) -> int:

@managed_tx_function(CommitMode.COMMIT)
def activate_one(db: Session) -> int:

# select among distinct users
authors_qry = (
db.query(Message.user_id)
.select_from(MessageTreeState)
.join(Message, MessageTreeState.message_tree_id == Message.id)
.join(User, Message.user_id == User.id)
.filter(
MessageTreeState.state == message_tree_state.State.PROMPT_LOTTERY_WAITING,
Message.lang == lang,
not_(Message.deleted),
Message.review_result,
User.enabled,
not_(User.deleted),
)
.distinct(Message.user_id)
)
Expand Down Expand Up @@ -1309,6 +1340,8 @@ def ensure_tree_states(self) -> None:
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})")
self._insert_default_state(id, state=state)

halt_prompts_of_disabled_users(self.db)

# check tree state transitions (maybe variables haves changes): prompt review -> growing -> ranking -> scoring
prompt_review_trees: list[MessageTreeState] = (
self.db.query(MessageTreeState)
Expand Down

0 comments on commit 52945f1

Please sign in to comment.