Skip to content

Commit

Permalink
Merge pull request #628 from i-dot-ai/feature/persist-route-in-django…
Browse files Browse the repository at this point in the history
…-chat-message

Persist route in Django chat message.
  • Loading branch information
brunns committed Jun 24, 2024
2 parents 1ccca60 + e640dc1 commit 5c86875
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 17 deletions.
4 changes: 2 additions & 2 deletions django_app/redbox_app/redbox_core/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class FileAdmin(admin.ModelAdmin):


class ChatMessageAdmin(admin.ModelAdmin):
list_display = ["chat_history", "get_user", "text", "role", "created_at"]
list_filter = ["role", "chat_history__users"]
list_display = ["chat_history", "get_user", "text", "role", "route", "created_at"]
list_filter = ["role", "route", "chat_history__users"]
date_hierarchy = "created_at"

@admin.display(ordering="chat_history__users", description="User")
Expand Down
22 changes: 15 additions & 7 deletions django_app/redbox_app/redbox_core/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dataclasses_json import Undefined, dataclass_json
from django.conf import settings
from django.utils import timezone
from redbox_app.redbox_core.models import ChatHistory, ChatMessage, ChatRoleEnum, File, User
from redbox_app.redbox_core.models import ChatHistory, ChatMessage, ChatRoleEnum, ChatRoute, File, User
from websockets import WebSocketClientProtocol
from websockets.client import connect
from yarl import URL
Expand All @@ -29,7 +29,7 @@ class CoreChatResponseDoc:
@dataclass_json(undefined=Undefined.EXCLUDE)
@dataclass(frozen=True)
class CoreChatResponse:
resource_type: Literal["text", "documents", "end"]
resource_type: Literal["text", "documents", "route_name", "end"]
data: list[CoreChatResponseDoc] | str | None = None


Expand Down Expand Up @@ -64,26 +64,29 @@ async def llm_conversation(self, selected_files: Sequence[File], session: ChatHi
}
await self.send_to_server(core_websocket, message)
await self.send_to_client({"type": "session-id", "data": str(session.id)})
reply, source_files = await self.receive_llm_responses(user, core_websocket)
await self.save_message(session, reply, ChatRoleEnum.ai, source_files=source_files)
reply, source_files, route = await self.receive_llm_responses(user, core_websocket)
await self.save_message(session, reply, ChatRoleEnum.ai, source_files=source_files, route=route)

for file in source_files:
file.last_referenced = timezone.now()
await self.file_save(file)

async def receive_llm_responses(
self, user: User, core_websocket: WebSocketClientProtocol
) -> tuple[str, Sequence[File]]:
) -> tuple[str, Sequence[File], ChatRoute]:
full_reply: MutableSequence[str] = []
source_files: MutableSequence[File] = []
route: ChatRoute | None = None
async for raw_message in core_websocket:
message = CoreChatResponse.schema().loads(raw_message)
logger.debug("received %s from core-api", message)
if message.resource_type == "text":
full_reply.append(await self.handle_text(message))
elif message.resource_type == "documents":
source_files += await self.handle_documents(message, user)
return "".join(full_reply), source_files
elif message.resource_type == "route_name":
route = await self.handle_route(message)
return "".join(full_reply), source_files, route

async def handle_documents(self, message: CoreChatResponse, user: User) -> Sequence[File]:
doc_uuids: Sequence[UUID] = [UUID(doc.file_uuid) for doc in message.data]
Expand All @@ -101,6 +104,10 @@ async def handle_text(self, message: CoreChatResponse) -> str:
await self.send_to_client({"type": "text", "data": message.data})
return message.data

async def handle_route(self, message: CoreChatResponse) -> ChatRoute:
await self.send_to_client({"type": "route", "data": message.data})
return ChatRoute[message.data]

async def send_to_client(self, data):
logger.debug("sending %s to browser", data)
await self.send(json.dumps(data, default=str))
Expand Down Expand Up @@ -134,8 +141,9 @@ def save_message(
role: ChatRoleEnum,
source_files: OptFileSeq = None,
selected_files: OptFileSeq = None,
route: ChatRoute | None = None,
) -> ChatMessage:
chat_message = ChatMessage(chat_history=session, text=user_message_text, role=role)
chat_message = ChatMessage(chat_history=session, text=user_message_text, role=role, route=route)
chat_message.save()
if source_files:
chat_message.source_files.set(source_files)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 5.0.6 on 2024-06-21 07:08

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('redbox_core', '0017_alter_user_business_unit_alter_user_grade_and_more'),
]

operations = [
migrations.AddField(
model_name='chatmessage',
name='route',
field=models.CharField(blank=True, choices=[('info', 'Info'), ('ability', 'Ability'), ('coach', 'Coach'), ('gratitude', 'Gratitude'), ('retrieval', 'Retrieval'), ('summarisation', 'Summarisation'), ('extract', 'Extract'), ('vanilla', 'Vanilla')], null=True),
),
]
12 changes: 12 additions & 0 deletions django_app/redbox_app/redbox_core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,22 @@ class ChatRoleEnum(models.TextChoices):
system = "system"


class ChatRoute(models.TextChoices):
info = "info"
ability = "ability"
coach = "coach"
gratitude = "gratitude"
retrieval = "retrieval"
summarisation = "summarisation"
extract = "extract"
vanilla = "vanilla"


class ChatMessage(UUIDPrimaryKeyBase, TimeStampedModel):
chat_history = models.ForeignKey(ChatHistory, on_delete=models.CASCADE)
text = models.TextField(max_length=32768, null=False, blank=False)
role = models.CharField(choices=ChatRoleEnum.choices, null=False, blank=False)
route = models.CharField(choices=ChatRoute.choices, null=True, blank=True)
source_files = models.ManyToManyField(
File,
related_name="chat_messages",
Expand Down
16 changes: 13 additions & 3 deletions django_app/tests/test_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from channels.testing import WebsocketCommunicator
from django.db.models import Model
from redbox_app.redbox_core.consumers import ChatConsumer
from redbox_app.redbox_core.models import ChatHistory, ChatMessage, ChatRoleEnum, File, User
from redbox_app.redbox_core.models import ChatHistory, ChatMessage, ChatRoleEnum, ChatRoute, File, User
from websockets import WebSocketClientProtocol
from websockets.legacy.client import Connect

Expand All @@ -33,20 +33,24 @@ async def test_chat_consumer_with_new_session(alice: User, uploaded_file: File,
response2 = await communicator.receive_json_from(timeout=5)
response3 = await communicator.receive_json_from(timeout=5)
response4 = await communicator.receive_json_from(timeout=5)
response5 = await communicator.receive_json_from(timeout=5)

# Then
assert response1["type"] == "session-id"
assert response2["type"] == "text"
assert response2["data"] == "Good afternoon, "
assert response3["type"] == "text"
assert response3["data"] == "Mr. Amor."
assert response4["type"] == "source"
assert response4["data"]["original_file_name"] == uploaded_file.original_file_name
assert response4["type"] == "route"
assert response4["data"] == "gratitude"
assert response5["type"] == "source"
assert response5["data"]["original_file_name"] == uploaded_file.original_file_name
# Close
await communicator.disconnect()

assert await get_chat_message_text(alice, ChatRoleEnum.user) == ["Hello Hal."]
assert await get_chat_message_text(alice, ChatRoleEnum.ai) == ["Good afternoon, Mr. Amor."]
assert await get_chat_message_route(alice, ChatRoleEnum.ai) == [ChatRoute.gratitude]
await refresh_from_db(uploaded_file)
assert uploaded_file.last_referenced.date() == datetime.now(tz=UTC).date()

Expand Down Expand Up @@ -82,6 +86,11 @@ def get_chat_message_text(user: User, role: ChatRoleEnum) -> Sequence[str]:
return [m.text for m in ChatMessage.objects.filter(chat_history__users=user, role=role)]


@database_sync_to_async
def get_chat_message_route(user: User, role: ChatRoleEnum) -> Sequence[ChatRoute]:
return [m.route for m in ChatMessage.objects.filter(chat_history__users=user, role=role)]


@pytest.mark.xfail()
@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio()
Expand Down Expand Up @@ -163,6 +172,7 @@ def mocked_connect(uploaded_file: File) -> Connect:
mocked_websocket.__aiter__.return_value = [
json.dumps({"resource_type": "text", "data": "Good afternoon, "}),
json.dumps({"resource_type": "text", "data": "Mr. Amor."}),
json.dumps({"resource_type": "route_name", "data": "gratitude"}),
json.dumps({"resource_type": "documents", "data": [{"file_uuid": str(uploaded_file.core_file_uuid)}]}),
json.dumps({"resource_type": "end"}),
]
Expand Down
5 changes: 1 addition & 4 deletions django_app/tests/test_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@

import pytest
from django.core.files.uploadedfile import SimpleUploadedFile
from django_test_migrations.migrator import Migrator


@pytest.mark.django_db()
def test_0012_alter_file_status():
migrator = Migrator(database="default")

def test_0012_alter_file_status(migrator):
old_state = migrator.apply_initial_migration(("redbox_core", "0012_alter_file_status"))

original_file = SimpleUploadedFile("original_file.txt", b"Lorem Ipsum.")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ select = [
"YTT",
]
exclude = ["django_app/redbox_app/redbox_core/migrations/*.py", "out/*"]
ignore = ["COM812", "DJ001", "RET505", "RET508"]
ignore = ["COM812", "DJ001", "RET505", "RET508", "PLR0913"]

[tool.ruff.lint.per-file-ignores]
"tests/*" = ["S101", "S106", "PLR0913", "PLR2004", "TD003", "S311"]
Expand Down

0 comments on commit 5c86875

Please sign in to comment.