-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
640 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from tortoise import BaseDBAsyncClient | ||
|
||
|
||
async def upgrade(db: BaseDBAsyncClient) -> str: | ||
return """ | ||
CREATE TABLE IF NOT EXISTS "conversation_elements" ( | ||
"id" SERIAL NOT NULL PRIMARY KEY, | ||
"last_modified" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, | ||
"first_created" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, | ||
"conversation_id" UUID NOT NULL, | ||
"role" VARCHAR(9) NOT NULL, | ||
"content" TEXT NOT NULL | ||
); | ||
COMMENT ON COLUMN "conversation_elements"."role" IS 'USER: user\nSYSTEM: system\nASSISTANT: assistant'; | ||
COMMENT ON TABLE "conversation_elements" IS 'Model for a snippet of a Conversation.';""" | ||
|
||
|
||
async def downgrade(db: BaseDBAsyncClient) -> str: | ||
return """ | ||
DROP TABLE IF EXISTS "conversation_elements";""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from enum import Enum | ||
|
||
from tortoise import fields | ||
|
||
from fia_api.db.models.fia_base_model import FiaBaseModel | ||
|
||
|
||
class ConversationElementRole(Enum): | ||
"""Enum for role the message content is from.""" | ||
|
||
USER = "user" | ||
SYSTEM = "system" | ||
ASSISTANT = "assistant" | ||
|
||
|
||
class ConversationElementModel(FiaBaseModel): | ||
"""Model for a snippet of a Conversation.""" | ||
|
||
conversation_id = fields.data.UUIDField(null=False, required=True) | ||
role = fields.data.CharEnumField(ConversationElementRole, required=True) | ||
content = fields.data.TextField(null=False, required=True) | ||
|
||
def __str__(self) -> str: | ||
return f"ConversationElement: {self.id}" | ||
|
||
class Meta: | ||
ordering = ["first_created"] | ||
table = "conversation_elements" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
from fastapi.routing import APIRouter | ||
|
||
from fia_api.web.api import dummy, echo, monitoring, redis, user | ||
from fia_api.web.api import dummy, echo, monitoring, redis, teacher, user | ||
|
||
api_router = APIRouter() | ||
api_router.include_router(monitoring.router) | ||
api_router.include_router(echo.router, prefix="/echo", tags=["echo"]) | ||
api_router.include_router(dummy.router, prefix="/dummy", tags=["dummy"]) | ||
api_router.include_router(redis.router, prefix="/redis", tags=["redis"]) | ||
api_router.include_router(user.router, prefix="/user", tags=["user"]) | ||
api_router.include_router(teacher.router, prefix="/teacher", tags=["teacher"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
"""Teacher model API.""" | ||
from fia_api.web.api.teacher.views import router | ||
|
||
__all__ = ["router"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import List | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
# OpenAI Related Response Models: | ||
class Mistake(BaseModel): | ||
"""Represents a "Mistake" that the model found in the user input.""" | ||
|
||
mistake_text: str | ||
fixed_text: str | ||
explanation: str | ||
|
||
|
||
class TranslatedWords(BaseModel): | ||
"""Represents a "Translation" that the model was requested to make.""" | ||
|
||
word: str | ||
translated_word: str | ||
|
||
|
||
class TeacherResponse(BaseModel): | ||
"""Represents the entire response from the "Teacher".""" | ||
|
||
translated_words: List[TranslatedWords] | ||
mistakes: List[Mistake] | ||
conversation_response: str | ||
|
||
|
||
class TeacherConverseRequest(BaseModel): | ||
"""Request object for calls to the Teacher to continue a conversation.""" | ||
|
||
# If conversation_id is "new", then start a new conversation. | ||
conversation_id: str | ||
message: str | ||
|
||
|
||
class TeacherConverseResponse(BaseModel): | ||
"""Response from a call to the teacher/converse endpoint.""" | ||
|
||
conversation_id: str | ||
response: TeacherResponse | ||
|
||
|
||
class TeacherConverseEndRequest(BaseModel): | ||
"""Response from a call to the teacher/end_converse endpoint.""" | ||
|
||
conversation_id: str | ||
|
||
|
||
class TeacherConverseEndResponse(BaseModel): | ||
"""Response from a call to the teacher/end_converse endpoint.""" | ||
|
||
message: str |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# noqa: WPS462 | ||
import json | ||
import uuid | ||
from typing import Dict, List | ||
|
||
import openai | ||
|
||
from fia_api.db.models.conversation_model import ( | ||
ConversationElementModel, | ||
ConversationElementRole, | ||
) | ||
from fia_api.settings import settings | ||
from fia_api.web.api.teacher.schema import TeacherConverseResponse, TeacherResponse | ||
|
||
openai.api_key = settings.openai_api_key | ||
|
||
|
||
async def get_messages_from_conversation_id( | ||
conversation_id: str, | ||
) -> List[Dict[str, str]]: | ||
""" | ||
Given a conversation_id, return a list of dicts ready to pass to OpenAI. | ||
:param conversation_id: str ID of the conversation | ||
:return: List of dicts of shape {"role": EnumValue, "content": "message"} | ||
""" | ||
raw_conversation = await ConversationElementModel.filter( | ||
conversation_id=uuid.UUID(conversation_id), | ||
).values() | ||
|
||
return [ | ||
{ | ||
"role": conversation_element["role"].value, | ||
"content": conversation_element["content"], | ||
} | ||
for conversation_element in raw_conversation | ||
] | ||
|
||
|
||
async def get_response(conversation_id: str, message: str) -> TeacherConverseResponse: | ||
""" | ||
Converse with OpenAI. | ||
Given the conversation ID, and a new message to add to it, store the | ||
message, get the response, store that, and return it. | ||
:param conversation_id: String ID representing the conversation. | ||
:param message: String message the user wants to send. | ||
:return: TeacherResponse | ||
""" | ||
await ConversationElementModel.create( | ||
conversation_id=uuid.UUID(conversation_id), | ||
role=ConversationElementRole.USER, | ||
content=message, | ||
) | ||
|
||
chat_response = openai.ChatCompletion.create( | ||
model="gpt-3.5-turbo-0613", | ||
messages=await get_messages_from_conversation_id(conversation_id), | ||
functions=[ | ||
{ | ||
"name": "get_answer_for_user_query", | ||
"description": "Get user language learning mistakes and a sentence to continue the conversation", # noqa: E501 | ||
"parameters": TeacherResponse.schema(), | ||
}, | ||
], | ||
function_call={"name": "get_answer_for_user_query"}, | ||
) | ||
|
||
teacher_response = chat_response.choices[ # noqa: WPS219 | ||
0 | ||
].message.function_call.arguments | ||
|
||
await ConversationElementModel.create( | ||
conversation_id=uuid.UUID(conversation_id), | ||
role=ConversationElementRole.SYSTEM, | ||
content=teacher_response, | ||
) | ||
|
||
# TODO: Store the token usage in the DB... | ||
|
||
return TeacherConverseResponse( | ||
conversation_id=conversation_id, | ||
response=json.loads(teacher_response), | ||
) | ||
|
||
|
||
async def initialize_conversation(message: str) -> TeacherConverseResponse: | ||
""" | ||
Starts the conversation. | ||
Set up the DB with the initial conversation prompt and return the new | ||
conversation ID, along with the first response from the model. | ||
:param message: The message to start the conversation with. | ||
:returns: TeacherConverseResponse of the teacher's first reply. | ||
""" | ||
conversation_id = uuid.uuid4() | ||
|
||
await ConversationElementModel.create( | ||
conversation_id=conversation_id, | ||
role=ConversationElementRole.SYSTEM, | ||
content=settings.prompts["p3"], | ||
) | ||
|
||
return await get_response(str(conversation_id), message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from fastapi import APIRouter, Depends | ||
|
||
from fia_api.web.api.teacher.schema import ( | ||
TeacherConverseRequest, | ||
TeacherConverseResponse, | ||
) | ||
from fia_api.web.api.teacher.utils import get_response, initialize_conversation | ||
from fia_api.web.api.user.schema import AuthenticatedUser | ||
from fia_api.web.api.user.utils import get_current_user | ||
|
||
router = APIRouter() | ||
|
||
|
||
@router.post("/converse", response_model=TeacherConverseResponse) | ||
async def converse( | ||
converse_request: TeacherConverseRequest, | ||
user: AuthenticatedUser = Depends(get_current_user), | ||
) -> TeacherConverseResponse: | ||
""" | ||
Starts or continues a conversation with the Teacher. | ||
:param converse_request: The request object. | ||
:param user: The AuthenticatedUser making the request. | ||
:returns: TeacherConverseResponse of mistakes and conversation. | ||
""" | ||
if converse_request.conversation_id == "new": | ||
return await initialize_conversation(converse_request.message) | ||
|
||
return await get_response( | ||
converse_request.conversation_id, | ||
converse_request.message, | ||
) |
Oops, something went wrong.