Skip to content

Commit

Permalink
added chat modified, created, and title fields
Browse files Browse the repository at this point in the history
  • Loading branch information
yk committed Mar 12, 2023
1 parent 8fcd982 commit 31faa39
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 4 deletions.
@@ -0,0 +1,32 @@
"""added created and title to chat
Revision ID: af8668441ff0
Revises: 486fe9e7fb84
Create Date: 2023-03-12 17:35:27.837252
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op

# revision identifiers, used by Alembic.
revision = "af8668441ff0"
down_revision = "486fe9e7fb84"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("chat", sa.Column("created_at", sa.DateTime(), nullable=False))
op.add_column("chat", sa.Column("title", sqlmodel.sql.sqltypes.AutoString(), nullable=True))
op.create_index(op.f("ix_chat_created_at"), "chat", ["created_at"], unique=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_chat_created_at"), table_name="chat")
op.drop_column("chat", "title")
op.drop_column("chat", "created_at")
# ### end Alembic commands ###
@@ -0,0 +1,29 @@
"""added modified to chat
Revision ID: b247f202e522
Revises: af8668441ff0
Create Date: 2023-03-12 17:36:53.061631
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "b247f202e522"
down_revision = "af8668441ff0"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("chat", sa.Column("modified_at", sa.DateTime(), nullable=False))
op.create_index(op.f("ix_chat_modified_at"), "chat", ["modified_at"], unique=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_chat_modified_at"), table_name="chat")
op.drop_column("chat", "modified_at")
# ### end Alembic commands ###
16 changes: 14 additions & 2 deletions inference/server/oasst_inference_server/models/chat.py
Expand Up @@ -64,14 +64,26 @@ class DbChat(SQLModel, table=True):
id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)

user_id: str = Field(foreign_key="user.id", index=True)
created_at: datetime.datetime = Field(default_factory=datetime.datetime.utcnow, index=True)
modified_at: datetime.datetime = Field(default_factory=datetime.datetime.utcnow, index=True)
title: str | None = Field(None)

messages: list[DbMessage] = Relationship(back_populates="chat")

def to_list_read(self) -> chat_schema.ChatListRead:
return chat_schema.ChatListRead(id=self.id)
return chat_schema.ChatListRead(
id=self.id,
created_at=self.created_at,
title=self.title,
)

def to_read(self) -> chat_schema.ChatRead:
return chat_schema.ChatRead(id=self.id, messages=[m.to_read() for m in self.messages])
return chat_schema.ChatRead(
id=self.id,
created_at=self.created_at,
title=self.title,
messages=[m.to_read() for m in self.messages],
)

def get_msg_dict(self) -> dict[str, DbMessage]:
return {m.id: m for m in self.messages}
Expand Down
1 change: 1 addition & 0 deletions inference/server/oasst_inference_server/routes/chats.py
Expand Up @@ -55,6 +55,7 @@ async def create_message(
"""Allows the client to stream the results of a request."""

try:
ucr: UserChatRepository
async with deps.manual_user_chat_repository(user_id) as ucr:
prompter_message = await ucr.add_prompter_message(
chat_id=chat_id, parent_id=request.parent_id, content=request.content
Expand Down
6 changes: 4 additions & 2 deletions inference/server/oasst_inference_server/schemas/chat.py
@@ -1,3 +1,4 @@
import datetime
from typing import Annotated, Literal, Union

import pydantic
Expand Down Expand Up @@ -58,10 +59,11 @@ class CreateChatRequest(pydantic.BaseModel):

class ChatListRead(pydantic.BaseModel):
id: str
created_at: datetime.datetime
title: str | None


class ChatRead(pydantic.BaseModel):
id: str
class ChatRead(ChatListRead):
messages: list[inference.MessageRead]


Expand Down
Expand Up @@ -72,6 +72,7 @@ async def add_prompter_message(self, chat_id: str, parent_id: str | None, conten
if parent_id is None:
if len(chat.messages) > 0:
raise fastapi.HTTPException(status_code=400, detail="Trying to add first message to non-empty chat")
chat.title = content
else:
msg_dict = chat.get_msg_dict()
if parent_id not in msg_dict:
Expand All @@ -87,6 +88,7 @@ async def add_prompter_message(self, chat_id: str, parent_id: str | None, conten
content=content,
)
self.session.add(message)
chat.modified_at = message.created_at

await self.session.commit()
logger.debug(f"Added prompter message {len(content)=} to chat {chat_id}")
Expand Down

0 comments on commit 31faa39

Please sign in to comment.