Skip to content

Commit

Permalink
Store worker compliance check results, add initial scoring algorithm (#…
Browse files Browse the repository at this point in the history
…1894)

Closes #1892.

This also refactors `models.py` into several model definition files by
category.
  • Loading branch information
olliestanley committed Feb 28, 2023
1 parent d8e1027 commit 5c6907f
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 41 deletions.
@@ -0,0 +1,60 @@
"""Add worker compliance checks to db
Revision ID: a2865908a537
Revises: e7a4bffb424c
Create Date: 2023-02-26 13:28:25.181340
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op

# revision identifiers, used by Alembic.
revision = "a2865908a537"
down_revision = "e7a4bffb424c"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Apply revision a2865908a537.

    Creates the ``worker_compliance_check`` table plus an index on each of its
    two worker columns, makes ``chat.user_id`` NOT NULL, and adds
    ``message_id`` indexes to the ``report`` and ``vote`` tables.

    The operations were originally auto-generated by Alembic.
    """
    auto_string = sqlmodel.sql.sqltypes.AutoString
    checks_table = "worker_compliance_check"

    op.create_table(
        checks_table,
        sa.Column("id", auto_string(), nullable=False),
        sa.Column("worker_id", auto_string(), nullable=False),
        sa.Column("compare_worker_id", auto_string(), nullable=False),
        sa.Column("start_time", sa.DateTime(), nullable=False),
        sa.Column("end_time", sa.DateTime(), nullable=True),
        sa.Column("responded", sa.Boolean(), nullable=False),
        sa.Column("error", auto_string(), nullable=True),
        sa.Column("passed", sa.Boolean(), nullable=False),
        # Only worker_id is constrained against worker.id in this revision.
        sa.ForeignKeyConstraint(["worker_id"], ["worker.id"]),
        sa.PrimaryKeyConstraint("id"),
    )

    # Non-unique lookup indexes on both worker columns of the new table.
    for index_name, column in (
        ("ix_worker_compliance_check_worker_id", "worker_id"),
        ("ix_worker_compliance_check_compare_worker_id", "compare_worker_id"),
    ):
        op.create_index(op.f(index_name), checks_table, [column], unique=False)

    op.alter_column("chat", "user_id", existing_type=sa.VARCHAR(), nullable=False)

    # Non-unique message_id indexes on the report and vote tables.
    for index_name, table in (
        ("ix_report_message_id", "report"),
        ("ix_vote_message_id", "vote"),
    ):
        op.create_index(op.f(index_name), table, ["message_id"], unique=False)


def downgrade() -> None:
    """Revert revision a2865908a537.

    Drops the ``message_id`` indexes on ``vote`` and ``report``, makes
    ``chat.user_id`` nullable again, then removes the
    ``worker_compliance_check`` indexes and finally the table itself.
    """
    for index_name, table in (
        ("ix_vote_message_id", "vote"),
        ("ix_report_message_id", "report"),
    ):
        op.drop_index(op.f(index_name), table_name=table)

    op.alter_column("chat", "user_id", existing_type=sa.VARCHAR(), nullable=True)

    checks_table = "worker_compliance_check"
    # Indexes are dropped before the table they belong to.
    for index_name in (
        "ix_worker_compliance_check_worker_id",
        "ix_worker_compliance_check_compare_worker_id",
    ):
        op.drop_index(op.f(index_name), table_name=checks_table)
    op.drop_table(checks_table)
15 changes: 15 additions & 0 deletions inference/server/oasst_inference_server/models/__init__.py
@@ -0,0 +1,15 @@
from .chat import DbChat, DbMessage, DbReport, DbVote
from .user import DbUser
from .worker import DbWorker, DbWorkerComplianceCheck, DbWorkerEvent, WorkerEventType

# Names exported by `from ... import *`: exactly the model classes and the enum
# imported from the .chat, .user and .worker submodules.
__all__ = [
    "DbChat",
    "DbMessage",
    "DbReport",
    "DbVote",
    "DbUser",
    "DbWorker",
    "DbWorkerComplianceCheck",
    "DbWorkerEvent",
    "WorkerEventType",
]
@@ -1,12 +1,11 @@
import datetime
import enum
from uuid import uuid4

import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from oasst_inference_server.schemas import chat as chat_schema
from oasst_shared.schemas import inference
from sqlmodel import Field, Index, Relationship, SQLModel
from sqlmodel import Field, Relationship, SQLModel


class DbMessage(SQLModel, table=True):
Expand Down Expand Up @@ -86,42 +85,3 @@ class DbReport(SQLModel, table=True):

def to_read(self) -> inference.Report:
    """Build the ``inference.Report`` schema object for this report row.

    Copies ``id``, ``report_type`` and ``reason`` from the row unchanged.
    """
    report_fields = {
        "id": self.id,
        "report_type": self.report_type,
        "reason": self.reason,
    }
    return inference.Report(**report_fields)


class WorkerEventType(str, enum.Enum):
    """Kinds of event that can be recorded for a worker.

    Members are also ``str`` instances, so each member compares equal to its
    plain string value.
    """

    # The only event type defined here.
    connect = "connect"


class DbWorkerEvent(SQLModel, table=True):
    """Row of the ``worker_event`` table: one event recorded for a worker."""

    __tablename__ = "worker_event"

    # Primary key; defaults to a freshly generated UUID4 string.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
    # Indexed foreign key to worker.id.
    worker_id: str = Field(foreign_key="worker.id", index=True)
    # ORM link to the owning worker; the reverse side is DbWorker.events.
    worker: "DbWorker" = Relationship(back_populates="events")
    # Defaults to datetime.utcnow(), i.e. a naive datetime holding UTC time.
    time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    event_type: WorkerEventType
    # Optional worker configuration, stored in a PostgreSQL JSONB column.
    worker_config: inference.WorkerConfig | None = Field(None, sa_column=sa.Column(pg.JSONB))


class DbWorker(SQLModel, table=True):
    """Row of the ``worker`` table: a registered inference worker."""

    __tablename__ = "worker"

    # Primary key; defaults to a freshly generated UUID4 string.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
    # Indexed; defaults to a freshly generated UUID4 string.
    api_key: str = Field(default_factory=lambda: str(uuid4()), index=True)
    name: str

    # False by default both in Python and in the database (server default "false").
    in_compliance_check: bool = Field(default=False, sa_column=sa.Column(sa.Boolean, server_default=sa.text("false")))
    # Nullable timestamp; presumably when the next compliance check is due — confirm against the handler.
    next_compliance_check: datetime.datetime | None = Field(None)
    # Events recorded for this worker; the reverse side is DbWorkerEvent.worker.
    events: list[DbWorkerEvent] = Relationship(back_populates="worker")


class DbUser(SQLModel, table=True):
    """Row of the ``user`` table: a user identified by provider and provider account id."""

    __tablename__ = "user"
    # NOTE(review): SQLAlchemy's Index takes the index *name* as its first
    # positional argument, so this looks like a unique index named "provider"
    # over provider_account_id alone rather than a composite unique index on
    # (provider, provider_account_id) — confirm the intent; changing it needs
    # a schema migration.
    __table_args__ = (Index("provider", "provider_account_id", unique=True),)

    # Primary key; defaults to a freshly generated UUID4 string.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)

    # Both columns are individually indexed.
    provider: str = Field(index=True)
    provider_account_id: str = Field(index=True)

    # Required; declared with a maximum length of 256.
    display_name: str = Field(nullable=False, max_length=256)
15 changes: 15 additions & 0 deletions inference/server/oasst_inference_server/models/user.py
@@ -0,0 +1,15 @@
from uuid import uuid4

from sqlmodel import Field, Index, SQLModel


class DbUser(SQLModel, table=True):
    """Row of the ``user`` table: a user identified by provider and provider account id."""

    __tablename__ = "user"
    # NOTE(review): SQLAlchemy's Index takes the index *name* as its first
    # positional argument, so this looks like a unique index named "provider"
    # over provider_account_id alone rather than a composite unique index on
    # (provider, provider_account_id) — confirm the intent; changing it needs
    # a schema migration.
    __table_args__ = (Index("provider", "provider_account_id", unique=True),)

    # Primary key; defaults to a freshly generated UUID4 string.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)

    # Both columns are individually indexed.
    provider: str = Field(index=True)
    provider_account_id: str = Field(index=True)

    # Required; declared with a maximum length of 256.
    display_name: str = Field(nullable=False, max_length=256)
51 changes: 51 additions & 0 deletions inference/server/oasst_inference_server/models/worker.py
@@ -0,0 +1,51 @@
import datetime
import enum
from uuid import uuid4

import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from oasst_shared.schemas import inference
from sqlmodel import Field, Relationship, SQLModel


class WorkerEventType(str, enum.Enum):
    """Kinds of event that can be recorded for a worker.

    Members are also ``str`` instances, so each member compares equal to its
    plain string value.
    """

    # The only event type defined here.
    connect = "connect"


class DbWorkerComplianceCheck(SQLModel, table=True):
    """Row of the ``worker_compliance_check`` table: the outcome of one check.

    ``worker_id`` is the worker that was checked. ``compare_worker_id`` is the
    worker that produced the earlier message the checked worker's output was
    compared against.
    """

    __tablename__ = "worker_compliance_check"

    # Primary key; defaults to a freshly generated UUID4 string.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
    # Indexed foreign key to worker.id: the worker under test.
    worker_id: str = Field(foreign_key="worker.id", index=True)
    # ORM link to the checked worker; the reverse side is DbWorker.compliance_checks.
    worker: "DbWorker" = Relationship(back_populates="compliance_checks")
    # Indexed, but deliberately NOT declared as a foreign key. A second foreign
    # key to worker.id would give the `worker` relationship two possible join
    # paths, which SQLAlchemy rejects unless `foreign_keys` is specified on both
    # sides of the relationship. The Alembic revision that creates this table
    # only constrains worker_id, so this also keeps the model and the schema in
    # agreement.
    # NOTE(review): run_compliance_check can persist a row before this value is
    # assigned (when no comparison message is found) while the column is
    # NOT NULL — confirm whether it should become nullable in model + migration.
    compare_worker_id: str = Field(index=True)

    # Defaults to datetime.utcnow(), i.e. a naive datetime holding UTC time.
    start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    # Stays NULL until the check has finished.
    end_time: datetime.datetime | None = Field(None, nullable=True)
    # Whether the worker sent any response to the check.
    responded: bool = Field(default=False, nullable=False)
    # Error text reported by the worker, if any.
    error: str | None = Field(None, nullable=True)
    # Whether the worker's output matched the comparison message.
    passed: bool = Field(default=False, nullable=False)


class DbWorkerEvent(SQLModel, table=True):
    """Row of the ``worker_event`` table: one event recorded for a worker."""

    __tablename__ = "worker_event"

    # Primary key; defaults to a freshly generated UUID4 string.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
    # Indexed foreign key to worker.id.
    worker_id: str = Field(foreign_key="worker.id", index=True)
    # ORM link to the owning worker; the reverse side is DbWorker.events.
    worker: "DbWorker" = Relationship(back_populates="events")
    # Defaults to datetime.utcnow(), i.e. a naive datetime holding UTC time.
    time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
    event_type: WorkerEventType
    # Optional worker configuration, stored in a PostgreSQL JSONB column.
    worker_config: inference.WorkerConfig | None = Field(None, sa_column=sa.Column(pg.JSONB))


class DbWorker(SQLModel, table=True):
    """Row of the ``worker`` table: a registered inference worker."""

    __tablename__ = "worker"

    # Primary key; defaults to a freshly generated UUID4 string.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
    # Indexed; defaults to a freshly generated UUID4 string.
    api_key: str = Field(default_factory=lambda: str(uuid4()), index=True)
    name: str

    # Compliance checks run against this worker; the reverse side is
    # DbWorkerComplianceCheck.worker.
    compliance_checks: list[DbWorkerComplianceCheck] = Relationship(back_populates="worker")
    # False by default both in Python and in the database (server default "false").
    in_compliance_check: bool = Field(default=False, sa_column=sa.Column(sa.Boolean, server_default=sa.text("false")))
    # Nullable timestamp; presumably when the next compliance check is due — confirm against the handler.
    next_compliance_check: datetime.datetime | None = Field(None)
    # Events recorded for this worker; the reverse side is DbWorkerEvent.worker.
    events: list[DbWorkerEvent] = Relationship(back_populates="worker")
43 changes: 43 additions & 0 deletions inference/server/oasst_inference_server/worker_handler.py
Expand Up @@ -12,6 +12,7 @@
from oasst_inference_server.settings import settings
from oasst_shared.schemas import inference
from sqlalchemy.sql.functions import random as sql_random
from sqlmodel import not_, or_

WSException = (
websockets.exceptions.WebSocketException,
Expand Down Expand Up @@ -171,6 +172,8 @@ async def run_compliance_check(websocket: fastapi.WebSocket, worker_id: str, wor
logger.info(f"Running compliance check for worker {worker_id}")

with deps.manual_create_session() as session:
compliance_check = models.DbWorkerComplianceCheck(worker_id=worker_id)

try:
message = find_compliance_work_request_message(session, worker_config, worker_id)
if message is None:
Expand All @@ -179,6 +182,7 @@ async def run_compliance_check(websocket: fastapi.WebSocket, worker_id: str, wor
)
return

compliance_check.compare_worker_id = message.worker_id
compliance_work_request = worker_handler.build_work_request(message)

logger.info(f"Found work request for compliance check for worker {worker_id}: {compliance_work_request}")
Expand All @@ -187,17 +191,23 @@ async def run_compliance_check(websocket: fastapi.WebSocket, worker_id: str, wor
while True:
response = await receive_work_response_packet(websocket)
if response.error is not None:
compliance_check.responded = True
compliance_check.error = response.error
logger.warning(f"Worker {worker_id} errored during compliance check: {response.error}")
return
if response.is_end:
break
if response is None:
logger.warning(f"Worker {worker_id} did not respond to compliance check")
return
compliance_check.responded = True
passes = response.generated_text.text == message.content
compliance_check.passed = passes
logger.info(f"Worker {worker_id} passed compliance check: {passes}")

finally:
compliance_check.end_time = datetime.datetime.utcnow()
session.add(compliance_check)
worker = get_worker(worker_id, session, with_for_update=True)
worker.next_compliance_check = datetime.datetime.utcnow() + datetime.timedelta(
seconds=settings.compliance_check_interval
Expand Down Expand Up @@ -405,3 +415,36 @@ async def perform_work(
logger.exception(f"Error handling {message_id=}")
cr.abort_work(message_id, reason=str(e))
raise WorkerError("Error handling chat", did_work=True)


def _score_compliance_checks(
    checks: "list[models.DbWorkerComplianceCheck]", worker_id: str, no_data_score: float
) -> float:
    """Return the fraction of finished checks involving *worker_id* that passed."""
    checked = [c for c in checks if c.worker_id == worker_id]
    compared = [c for c in checks if c.compare_worker_id == worker_id]

    # Count each role separately so the ratio stays within [0, 1] even if a
    # single row ever names the same worker in both roles.
    involvement_count = len(checked) + len(compared)
    if involvement_count == 0:
        return no_data_score

    # A check that errored or got no response is stored with passed=False, so
    # counting passes alone penalises errors, non-responses and mismatches.
    pass_count = sum(1 for c in checked if c.passed)
    # NOTE(review): as before, any non-passing check also counts against the
    # worker it was compared with, including ones that ended in an error or
    # without a response — confirm that this is intended.
    compare_pass_count = sum(1 for c in compared if c.passed)

    return (pass_count + compare_pass_count) / involvement_count


async def compute_worker_compliance_score(worker_id: str, *, no_data_score: float = 0.0) -> float:
    """
    Compute a float between 0 and 1 (inclusive) representing the compliance score of the worker.

    The score is the share of finished compliance checks involving the worker
    that passed, so 1.0 means every check passed and 0.0 means none did.
    Workers are rewarded for passing compliance checks, and penalised for
    failing to respond to a check, erroring during a check, or failing a check.
    Checks in which the worker served as the comparison baseline count as well.
    In-progress checks (no end time yet) are ignored.

    Args:
        worker_id: ID of the worker to score.
        no_data_score: Value returned when the worker has no finished checks.
            Should itself lie in [0, 1]. Defaults to 0.0.

    Returns:
        The compliance score, or ``no_data_score`` if there is nothing to score.
    """
    with deps.manual_create_session() as session:
        worker_checks: list[models.DbWorkerComplianceCheck] = session.exec(
            sqlmodel.select(models.DbWorkerComplianceCheck).where(
                or_(
                    models.DbWorkerComplianceCheck.worker_id == worker_id,
                    models.DbWorkerComplianceCheck.compare_worker_id == worker_id,
                ),
                not_(models.DbWorkerComplianceCheck.end_time.is_(None)),
            )
        ).all()

        # Rudimentary scoring algorithm, we may want to add weightings or other factors.
        # Scored while the session is still open so the rows stay attached.
        return _score_compliance_checks(worker_checks, worker_id, no_data_score)

0 comments on commit 5c6907f

Please sign in to comment.