Skip to content

Commit

Permalink
Add not_rankable and reveal_synthetic flags (#1934)
Browse files Browse the repository at this point in the history
Extend protocol:
- add `synthetic` flag to Conversation message
- for ranking tasks add `reveal_synthetic` bool that indicates whether
synthetic status should be revealed
- ranking interaction (response) now has a new `not_rankable` bool that
can be set to indicate that all shown messages are flawed, factually
incorrect or unacceptable
  • Loading branch information
andreaskoepf committed Mar 2, 2023
1 parent 0b6865b commit 6244ef6
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 11 deletions.
6 changes: 4 additions & 2 deletions backend/main.py
Expand Up @@ -200,7 +200,7 @@ class DummyMessage(BaseModel):
tr.bind_frontend_message_id(task.id, msg.task_message_id)
message = pr.store_text_reply(
msg.text,
msg.lang,
msg.lang or "en",
msg.task_message_id,
msg.user_message_id,
review_count=5,
Expand All @@ -210,7 +210,9 @@ class DummyMessage(BaseModel):
)
if message.parent_id is None:
tm._insert_default_state(
root_message_id=message.id, state=msg.tree_state or message_tree_state.State.GROWING
root_message_id=message.id,
lang=message.lang,
state=msg.tree_state or message_tree_state.State.GROWING,
)
session.flush()

Expand Down
1 change: 1 addition & 0 deletions backend/oasst_backend/api/v1/utils.py
Expand Up @@ -44,6 +44,7 @@ def prepare_conversation_message(message: Message) -> protocol.ConversationMessa
emojis=message.emojis or {},
user_emojis=message.user_emojis or [],
user_is_author=message.user_is_author,
synthetic=message.synthetic,
)


Expand Down
2 changes: 2 additions & 0 deletions backend/oasst_backend/models/db_payload.py
Expand Up @@ -67,6 +67,7 @@ class RankingReactionPayload(ReactionPayload):
ranked_message_ids: list[UUID]
ranking_parent_id: Optional[UUID]
message_tree_id: Optional[UUID]
not_rankable: Optional[bool] # all options flawed, factually incorrect or unacceptable


@payload_type
Expand All @@ -75,6 +76,7 @@ class RankConversationRepliesPayload(TaskPayload):
reply_messages: list[protocol_schema.ConversationMessage]
ranking_parent_id: Optional[UUID]
message_tree_id: Optional[UUID]
reveal_synthetic: Optional[bool]


@payload_type
Expand Down
1 change: 1 addition & 0 deletions backend/oasst_backend/prompt_repository.py
Expand Up @@ -382,6 +382,7 @@ def store_ranking(self, ranking: protocol_schema.MessageRanking) -> tuple[Messag
ranked_message_ids=ranked_message_ids,
ranking_parent_id=task_payload.ranking_parent_id,
message_tree_id=task_payload.message_tree_id,
not_rankable=ranking.not_rankable,
)
reaction = self.insert_reaction(task_id=task.id, payload=reaction_payload, message_id=parent_msg.id)
self.journal.log_ranking(task, message_id=parent_msg.id, ranking=ranking.ranking)
Expand Down
2 changes: 2 additions & 0 deletions backend/oasst_backend/task_repository.py
Expand Up @@ -83,6 +83,7 @@ def store_task(
reply_messages=task.reply_messages,
ranking_parent_id=task.ranking_parent_id,
message_tree_id=task.message_tree_id,
reveal_synthetic=task.reveal_synthetic,
)

case protocol_schema.RankAssistantRepliesTask:
Expand All @@ -92,6 +93,7 @@ def store_task(
reply_messages=task.reply_messages,
ranking_parent_id=task.ranking_parent_id,
message_tree_id=task.message_tree_id,
reveal_synthetic=task.reveal_synthetic,
)

case protocol_schema.LabelInitialPromptTask:
Expand Down
12 changes: 10 additions & 2 deletions backend/oasst_backend/tree_manager.py
Expand Up @@ -510,8 +510,14 @@ def next_task(
assert len(replies) > 1
random.shuffle(replies) # hand out replies in random order
reply_messages = prepare_conversation_message_list(replies)
replies = [p.text for p in replies]
if any(not m.synthetic for m in reply_messages):
reveal_synthetic = False
for rm in reply_messages:
rm.synthetic = None
else:
reveal_synthetic = True

replies = [p.text for p in replies]
if messages[-1].role == "assistant":
logger.info("Generating a RankPrompterRepliesTask.")
task = protocol_schema.RankPrompterRepliesTask(
Expand All @@ -520,6 +526,7 @@ def next_task(
reply_messages=reply_messages,
ranking_parent_id=ranking_parent.id,
message_tree_id=ranking_parent.message_tree_id,
reveal_synthetic=reveal_synthetic,
)
else:
logger.info("Generating a RankAssistantRepliesTask.")
Expand All @@ -529,6 +536,7 @@ def next_task(
reply_messages=reply_messages,
ranking_parent_id=ranking_parent.id,
message_tree_id=ranking_parent.message_tree_id,
reveal_synthetic=reveal_synthetic,
)

parent_message_id = ranking_parent_id
Expand Down Expand Up @@ -718,7 +726,7 @@ async def handle_interaction(self, interaction: protocol_schema.AnyInteraction)
logger.info(
f"TreeManager: Inserting new tree state for initial prompt {message.id=} [{message.lang}]"
)
self._insert_default_state(message.id, message.lang)
self._insert_default_state(message.id, lang=message.lang)

if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
try:
Expand Down
16 changes: 9 additions & 7 deletions oasst-shared/oasst_shared/schemas/protocol.py
Expand Up @@ -68,15 +68,16 @@ class FrontEndUserPage(PageResult):
class ConversationMessage(BaseModel):
"""Represents a message in a conversation between the user and the assistant."""

id: Optional[UUID] = None
id: Optional[UUID]
user_id: Optional[UUID]
frontend_message_id: Optional[str] = None
frontend_message_id: Optional[str]
text: str
lang: Optional[str] # BCP 47
is_assistant: bool
emojis: Optional[dict[str, int]] = None
user_emojis: Optional[list[str]] = None
user_is_author: Optional[bool] = None
emojis: Optional[dict[str, int]]
user_emojis: Optional[list[str]]
user_is_author: Optional[bool]
synthetic: Optional[bool]


class Conversation(BaseModel):
Expand All @@ -103,7 +104,6 @@ class Message(ConversationMessage):
review_result: Optional[bool]
review_count: Optional[int]
deleted: Optional[bool]
synthetic: Optional[bool]
model_name: Optional[str]
message_tree_id: Optional[UUID]
ranking_count: Optional[int]
Expand Down Expand Up @@ -229,6 +229,7 @@ class RankConversationRepliesTask(Task):
reply_messages: list[ConversationMessage]
message_tree_id: UUID
ranking_parent_id: UUID
reveal_synthetic: bool


class RankPrompterRepliesTask(RankConversationRepliesTask):
Expand Down Expand Up @@ -354,8 +355,9 @@ class MessageRanking(Interaction):
"""A user has given a ranking for a message."""

type: Literal["message_ranking"] = "message_ranking"
message_id: str
message_id: str # parent message of replies that were ranked
ranking: conlist(item_type=int, min_items=1)
not_rankable: Optional[bool] # all options flawed, factually incorrect or unacceptable


class LabelWidget(str, enum.Enum):
Expand Down

0 comments on commit 6244ef6

Please sign in to comment.