Skip to content

Commit

Permalink
add spam-only export option
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf committed Feb 8, 2023
1 parent 8bad8c3 commit af6885d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
29 changes: 27 additions & 2 deletions backend/export.py
Expand Up @@ -8,7 +8,7 @@
from oasst_backend.models import Message, MessageTreeState
from oasst_backend.models.message_tree_state import State as TreeState
from oasst_backend.utils import tree_export
from sqlmodel import Session
from sqlmodel import Session, not_


def fetch_tree_ids(
Expand Down Expand Up @@ -38,6 +38,7 @@ def fetch_tree_messages(
deleted: bool = None,
prompts_only: bool = False,
lang: Optional[str] = None,
review_result: Optional[bool] = None,
) -> List[Message]:
qry = db.query(Message)

Expand All @@ -51,6 +52,10 @@ def fetch_tree_messages(
qry = qry.filter(Message.parent_id.is_(None))
if lang:
qry = qry.filter(Message.lang == lang)
if review_result is False:
qry = qry.filter(not_(Message.review_result), Message.review_count > 2)
elif review_result is True:
qry = qry.filter(Message.review_result)

return qry.all()

Expand All @@ -64,16 +69,18 @@ def export_trees(
prompts_only: bool = False,
state_filter: Optional[TreeState] = None,
lang: Optional[str] = None,
review_result: Optional[bool] = None,
) -> None:
trees_to_export: List[tree_export.ExportMessageTree] = []

if user_id:
if user_id or review_result is False:
messages = fetch_tree_messages(
db,
user_id=user_id,
deleted=deleted,
prompts_only=prompts_only,
lang=lang,
review_result=review_result,
)
tree_export.write_messages_to_file(export_file, messages, use_compression)
else:
Expand All @@ -86,6 +93,7 @@ def export_trees(
deleted=deleted,
prompts_only=prompts_only,
lang=None,
review_result=review_result,
)
for (tree_id, _) in message_tree_ids
]
Expand Down Expand Up @@ -135,6 +143,16 @@ def parse_args():
action="store_true",
help="Export only deleted messages (implies --include-deleted)",
)
parser.add_argument(
"--include-spam",
action="store_true",
help="Export only messages with negative review result.",
)
parser.add_argument(
"--spam-only",
action="store_true",
help="Export only messages with negative review result (implies --include-spam).",
)
parser.add_argument(
"--user",
type=str,
Expand Down Expand Up @@ -176,6 +194,12 @@ def main():
if args.deleted_only:
deleted = True

review_result: Optional[bool] = True
if args.include_spam:
review_result = None
if args.spam_only:
review_result = False

with Session(engine) as db:
export_trees(
db,
Expand All @@ -186,6 +210,7 @@ def main():
prompts_only=args.prompts_only,
state_filter=state_filter,
lang=args.lang,
review_result=review_result,
)


Expand Down
2 changes: 2 additions & 0 deletions backend/oasst_backend/utils/tree_export.py
Expand Up @@ -21,6 +21,7 @@ class ExportMessageNode(BaseModel):
role: str
lang: str | None
review_count: int | None
review_result: bool | None
rank: int | None
synthetic: bool | None
model_name: str | None
Expand All @@ -36,6 +37,7 @@ def prep_message_export(message: Message) -> ExportMessageNode:
role=message.role,
lang=message.lang,
review_count=message.review_count,
review_result=message.review_result if message.review_result or message.review_count > 2 else None,
synthetic=message.synthetic,
model_name=message.model_name,
emojis=message.emojis,
Expand Down

0 comments on commit af6885d

Please sign in to comment.