diff --git a/fia_api/settings.py b/fia_api/settings.py index d40fa24..c7e97bd 100644 --- a/fia_api/settings.py +++ b/fia_api/settings.py @@ -67,6 +67,9 @@ class Settings(BaseSettings): openai_api_key: str = "INVALID_OPENAI_API_KEY" + get_learning_moments_prompt: str = """You are an expert German language teacher who works with native English speakers to help them learn German. They give you a message and you explain each mistake in their message. You give them "Learning Moments" which they can review and learn from.""" + conversation_continuation_prompt: str = """You are a native German speaker who is helping someone to learn to speak German. They are a beginner and want to try have a conversation only in German with you. Sometimes they make spelling/grammar mistakes, but you always try to continue on the conversation while only sometimes explaining their mistakes to them. You are friendly and ask questions to direct the conversation to help the user learn. You are allowed to use English if the user asks you what a word means. You speak in very simple, short sentences.""" + prompts: Dict[str, Dict[str, str]] = { "p1": { "role": "system", diff --git a/fia_api/tests/test_conversations.py b/fia_api/tests/test_conversations.py index 80ff045..cd8cefc 100644 --- a/fia_api/tests/test_conversations.py +++ b/fia_api/tests/test_conversations.py @@ -5,6 +5,7 @@ import pytest from fastapi import FastAPI from httpx import AsyncClient +from loguru import logger from pytest_mock import MockerFixture @@ -39,6 +40,57 @@ class OpenAIAPIResponse: usage: Dict[str, int] +def get_mocked_openai_response(*args, **kwargs) -> OpenAIAPIResponse: # type: ignore + """ + Return the mocked OpenAI API response based on the input. 
+ + :param args: All args passed to OpenAI + :param kwargs: All kwargs passed to OpenAI + :returns: OpenAIAPIResponse + """ + learning_moments_api_response = OpenAIAPIResponse( + choices=[ + OpenAIAPIChoices( + message=OpenAIAPIMessage( + role="assistant", + function_call=OpenAIAPIFunctionCall( + name="get_learning_moments", + arguments='{\n "learning_moments": [\n {\n "moment": {\n "incorrect_section": "Hallo",\n "corrected_section": "Hallo,",\n "explanation": "In German, a comma is often used after greetings like \'Hallo\' or \'Guten Tag\'."\n }\n },\n {\n "moment": {\n "incorrect_section": "Wie Geht\'s?",\n "corrected_section": "Wie geht es dir?",\n "explanation": "The correct way to ask \'How are you?\' in German is \'Wie geht es dir?\'"\n }\n }\n ]\n}', # noqa: E501 + ), + ), + ), + ], + usage={ + "prompt_tokens": 181, + "completion_tokens": 114, + "total_tokens": 295, + }, + ) + chat_continuation_api_response = OpenAIAPIResponse( + choices=[ + OpenAIAPIChoices( + message=OpenAIAPIMessage( + role="assistant", + function_call=OpenAIAPIFunctionCall( + name="get_conversation_response", + arguments='{\n"message": "Mir geht es gut, danke! 
Wie geht es dir?"\n}', # noqa: E501 + ), + ), + ), + ], + usage={ + "prompt_tokens": 181, + "completion_tokens": 114, + "total_tokens": 295, + }, + ) + + if kwargs["functions"][0]["name"] == "get_learning_moments": + return learning_moments_api_response + + return chat_continuation_api_response + + async def get_access_token( fastapi_app: FastAPI, client: AsyncClient, @@ -101,6 +153,7 @@ async def test_conversations( list_conversations_url = fastapi_app.url_path_for("list_user_conversations") get_conversation_url = fastapi_app.url_path_for("get_user_conversation") converse_url = fastapi_app.url_path_for("converse") + get_flashcards_url = fastapi_app.url_path_for("get_flashcards") # No conversations by default: response = await client.get( @@ -109,28 +162,17 @@ async def test_conversations( ) assert not response.json()["conversations"] - # Begin conversation: - api_response = OpenAIAPIResponse( - choices=[ - OpenAIAPIChoices( - message=OpenAIAPIMessage( - role="assistant", - function_call=OpenAIAPIFunctionCall( - name="get_answer_for_user_query", - arguments='{\n "translated_words": [\n {\n "word": "Hallo",\n "translated_word": "Hello"\n },\n {\n "word": "Wie",\n "translated_word": "How"\n },\n {\n "word": "Geht",\n "translated_word": "is going"\n },\n {\n "word": "s",\n "translated_word": "it"\n }\n ],\n "mistakes": [],\n "conversation_response": "Mir geht es gut, danke. 
Wie kann ich Ihnen helfen?"\n}', # noqa: E501 - ), - ), - ), - ], - usage={ - "prompt_tokens": 181, - "completion_tokens": 114, - "total_tokens": 295, - }, + # No flashcards by default: + response = await client.get( + get_flashcards_url, + headers=auth_headers, ) + assert not response.json()["flashcards"] + + # Begin conversation: mocker.patch( - "fia_api.web.api.teacher.utils.get_openai_response", - return_value=api_response, + "fia_api.web.api.teacher.utils.openai.ChatCompletion.create", + side_effect=get_mocked_openai_response, ) response = await client.post( converse_url, @@ -142,11 +184,12 @@ async def test_conversations( ) conversation_id = response.json()["conversation_id"] - conversation = response.json()["conversation"] + conversation_response = response.json()["conversation_response"] assert conversation_id != "new" - assert len(conversation) == 1 - assert conversation[0]["conversation_element"]["role"] == "teacher" + assert len(response.json()["learning_moments"]) > 0 # noqa: WPS507 + assert isinstance(conversation_response, str) + assert len(conversation_response) > 0 # noqa: WPS507 # Now one conversation response = await client.get( @@ -167,8 +210,9 @@ async def test_conversations( ) assert conversation_id == response.json()["conversation_id"] - assert len(conversation) == 1 - assert conversation[0]["conversation_element"]["role"] == "teacher" + assert len(response.json()["learning_moments"]) > 0 # noqa: WPS507 + assert isinstance(conversation_response, str) + assert len(conversation_response) > 0 # noqa: WPS507 # Now get conversation: response = await client.get( @@ -178,4 +222,12 @@ async def test_conversations( "conversation_id": conversation_id, }, ) + logger.error(response.json()) assert len(response.json()["conversation"]) == 4 + + # Now we have flashcards + response = await client.get( + get_flashcards_url, + headers=auth_headers, + ) + assert len(response.json()["flashcards"]) == 4 diff --git a/fia_api/web/api/teacher/schema.py 
b/fia_api/web/api/teacher/schema.py index 860b7e1..2ea13a4 100644 --- a/fia_api/web/api/teacher/schema.py +++ b/fia_api/web/api/teacher/schema.py @@ -1,30 +1,94 @@ -from typing import List, Union +from typing import List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field # OpenAI Related Response Models: class Mistake(BaseModel): - """Represents a "Mistake" that the model found in the user input.""" - - mistake_text: str - fixed_text: str - explanation: str - - -class TranslatedWords(BaseModel): - """Represents a "Translation" that the model was requested to make.""" - - word: str - translated_word: str + """A single Mistake a user made in their message.""" + + incorrect_section: str = Field( + description=( + "The section of the sentence of the user message the grammar " + "mistake is in" + ), + ) + corrected_section: str = Field( + description="The corrected section of the sentence in German", + ) + explanation: str = Field( + description=( + "The English language explanation of why this section of the " + "sentence is incorrect. Give details such as if it is using the " + "wrong gender/suffix, if the verb conjugation is wrong, etc. If " + "the sentence is correct, but there is a better way to phrase it, " + "explain this too." + ), + ) + + +class Translation(BaseModel): + """A word or phrase the user wants translated.""" + + phrase: str = Field( + description=( + 'The word or phrase the user wants translated. e.g. "Book", ' + "'gerne', \"lesen\"" + ), + ) + translated_phrase: str = Field( + description=( + "The translation of the word or phrase in context of the " + "sentence. If it is a noun, include the correct gender. " + "e.g. 
Das Buch, Die Hände, etc" + ), + ) + + +class LearningMoment(BaseModel): + """A moment in a conversation a user can learn from.""" + + moment: Union[Mistake, Translation] = Field( + description=( + "A single language learning mistake found in a section of the " + "user's message" + ), + ) + + +# This is returned by the model looking for mistakes in the user sentence. +class LearningMoments(BaseModel): + """A list of individual LearningMoment objects.""" + + learning_moments: List[LearningMoment] = Field( + description=( + "A list of language learning mistakes in the user's message. " + "There should be one LearningMoment per individual mistake in " + "the sentence." + ), + ) + + +# This is returned by the model trying to continue on the conversation with the +# user in an educational way. +class ConversationContinuation(BaseModel): + """Basic wrapper to supply description data to OpenAI.""" + + message: str = Field( + description=( + "This is the response to the user's message. You should always " + "try to respond in German. If the user doesn't understand, then try " + "to use even more simple German in your response until you can " + "only use English. Responding in English is a last resort." + ), + ) class TeacherResponse(BaseModel): """Represents the entire response from the "Teacher".""" - translated_words: List[TranslatedWords] - mistakes: List[Mistake] - conversation_response: str + learning_moments: LearningMoments + conversation_response: ConversationContinuation class TeacherConverseRequest(BaseModel): @@ -48,31 +112,12 @@ class UserConversationList(BaseModel): conversations: List[ConversationSnippet] -class ConversationLine(BaseModel): - """A single line of a conversation.""" - - # If str, then this is just a user response. 
- line: Union[TeacherResponse, str] - - -class UserConversationElement(BaseModel): - """The user part of the conversation.""" - - role: str = "user" - message: str - - -class TeacherConversationElement(BaseModel): - """The teacher response.""" - - role: str = "teacher" - response: TeacherResponse - - class ConversationElement(BaseModel): """A single part of a conversation. Either from the user or system.""" - conversation_element: Union[TeacherConversationElement, UserConversationElement] + role: str + message: str + learning_moments: Optional[LearningMoments] = None class ConversationResponse(BaseModel): @@ -80,3 +125,11 @@ class ConversationResponse(BaseModel): conversation_id: str conversation: List[ConversationElement] + + +class ConverseResponse(BaseModel): + """Response from the Converse endpoint.""" + + conversation_id: str + learning_moments: LearningMoments + conversation_response: str diff --git a/fia_api/web/api/teacher/utils.py b/fia_api/web/api/teacher/utils.py index 6af04a7..5679f3f 100644 --- a/fia_api/web/api/teacher/utils.py +++ b/fia_api/web/api/teacher/utils.py @@ -1,7 +1,7 @@ # noqa: WPS462 import json import uuid -from typing import Dict, List +from typing import Any, Dict, List import openai from loguru import logger @@ -14,12 +14,40 @@ from fia_api.db.models.user_conversation_model import UserConversationModel from fia_api.db.models.user_model import UserModel from fia_api.settings import settings -from fia_api.web.api.teacher.schema import ConversationResponse, TeacherResponse -from fia_api.web.api.user.utils import format_conversation_for_response +from fia_api.web.api.flashcards.utils import create_flashcard +from fia_api.web.api.teacher.schema import ( + ConversationContinuation, + ConverseResponse, + LearningMoments, + Mistake, + Translation, +) openai.api_key = settings.openai_api_key +async def store_token_usage( + conversation_id: str, + openai_response: Any, +) -> None: # type: ignore + """ + Store the token usage for an OpenAI 
request. + + :param conversation_id: String to store the usage under + :param openai_response: The messy openAI datatype + """ + token_usage_model = await TokenUsageModel.get( + conversation_id=uuid.UUID(conversation_id), + ) + + token_usage_model.prompt_token_usage += openai_response.usage["prompt_tokens"] + token_usage_model.completion_token_usage += openai_response.usage[ + "completion_tokens" + ] + + await token_usage_model.save() + + async def get_messages_from_conversation_id( conversation_id: str, ) -> List[Dict[str, str]]: @@ -42,29 +70,124 @@ async def get_messages_from_conversation_id( ] -# Ignoring type as I don't really know what the OpenAI API returns... -async def get_openai_response(conversation_id: str): # type: ignore +async def get_learning_moments_from_message( + message: str, + conversation_id: str, +) -> LearningMoments: """ - Wraps the OpenAI API call. Mostly to make mocking easier. + Get LearningMoments from a user message. - :param conversation_id: String conversation_id - :return: OpenAI Chat Response object + :param message: String message from the user to look for mistakes in. + :param conversation_id: Store the token usage in the conversation. 
+ :returns: LearningMoments """ - return openai.ChatCompletion.create( + openai_response = openai.ChatCompletion.create( + model="gpt-3.5-turbo-0613", + messages=[ + { + "role": "assistant", + "content": settings.get_learning_moments_prompt, + }, + { + "role": "user", + "content": message, + }, + ], + functions=[ + { + "name": "get_learning_moments", + "description": "List all of the mistakes in the user's message and any words in the user message that they would like translated.", # noqa: E501 + "parameters": LearningMoments.schema(), + }, + ], + function_call={"name": "get_learning_moments"}, + ) + + await store_token_usage(conversation_id, openai_response) + + return LearningMoments( + **json.loads( + openai_response.choices[0].message.function_call.arguments, # noqa: WPS219 + ), + ) + + +async def get_conversation_continuation( + conversation_id: str, +) -> ConversationContinuation: + """ + Continue the conversation with the user based on the context. + + The conversation in the DB must be updated with the most recent user + message. + + :param conversation_id: String conversation to continue on. 
+ :returns: ConversationContinuation + """ + openai_response = openai.ChatCompletion.create( model="gpt-3.5-turbo-0613", messages=await get_messages_from_conversation_id(conversation_id), functions=[ { - "name": "get_answer_for_user_query", - "description": "Get user language learning mistakes and a sentence to continue the conversation", # noqa: E501 - "parameters": TeacherResponse.schema(), + "name": "get_conversation_response", + "description": "Get the conversational response to the user's message.", + "parameters": ConversationContinuation.schema(), }, ], - function_call={"name": "get_answer_for_user_query"}, + function_call={"name": "get_conversation_response"}, ) + await store_token_usage(conversation_id, openai_response) -async def get_response(conversation_id: str, message: str) -> ConversationResponse: + return ConversationContinuation( + **json.loads( + openai_response.choices[0].message.function_call.arguments, # noqa: WPS219 + ), + ) + + +async def create_flashcards_from_learning_moments( + learning_moments: LearningMoments, + user: UserModel, + conversation_id: str, +) -> None: + """ + Store each learning moment as a flashcard. + + :param learning_moments: LearningMoments to store as flashcards. + :param user: UserModel to associate with the flashcards. + :param conversation_id: String conversation ID for context. 
+ """ + for learning_moment in learning_moments.learning_moments: + parsed_learning_moment = learning_moment.moment + + if isinstance(parsed_learning_moment, Mistake): + await create_flashcard( + user.username, + parsed_learning_moment.incorrect_section, + parsed_learning_moment.corrected_section + + "\n\n" + + parsed_learning_moment.explanation, + conversation_id, + ) + elif isinstance(parsed_learning_moment, Translation): + await create_flashcard( + user.username, + parsed_learning_moment.phrase, + parsed_learning_moment.translated_phrase, + conversation_id, + both_sides=True, + ) + else: + logger.error("Some weirdness going on....") + logger.error(learning_moment) + + +async def get_response( + conversation_id: str, + message: str, + user: UserModel, +) -> ConverseResponse: """ Converse with OpenAI. @@ -73,7 +196,8 @@ async def get_response(conversation_id: str, message: str) -> ConversationRespon :param conversation_id: String ID representing the conversation. :param message: String message the user wants to send. - :return: ConversationResponse + :param user: UserModel, needed to store flashcards. + :return: ConverseResponse """ await ConversationElementModel.create( conversation_id=uuid.UUID(conversation_id), @@ -81,41 +205,37 @@ async def get_response(conversation_id: str, message: str) -> ConversationRespon content=message, ) - chat_response = await get_openai_response(conversation_id) - logger.warning("-------------------------") - logger.warning(chat_response) - logger.warning("-------------------------") - - # Do this JSON dance to have it serialize correctly. 
- teacher_response = json.dumps( - json.loads( - chat_response.choices[0].message.function_call.arguments, # noqa: WPS219 - ), + learning_moments = await get_learning_moments_from_message( + message, + conversation_id, + ) + await create_flashcards_from_learning_moments( + learning_moments, + user, + conversation_id, ) + conversation_continuation = await get_conversation_continuation(conversation_id) + + # TODO: Store the learning moments. + logger.info(learning_moments) await ConversationElementModel.create( conversation_id=uuid.UUID(conversation_id), role=ConversationElementRole.SYSTEM, - content=teacher_response, + content=conversation_continuation.message, ) - token_usage_model = await TokenUsageModel.get( - conversation_id=uuid.UUID(conversation_id), - ) - token_usage_model.prompt_token_usage += chat_response.usage["prompt_tokens"] - token_usage_model.completion_token_usage += chat_response.usage["completion_tokens"] - await token_usage_model.save() - - return await format_conversation_for_response( - conversation_id, - last=True, + return ConverseResponse( + conversation_id=conversation_id, + learning_moments=learning_moments, + conversation_response=conversation_continuation.message, ) async def initialize_conversation( user: UserModel, message: str, -) -> ConversationResponse: +) -> ConverseResponse: """ Starts the conversation. 
@@ -130,8 +250,8 @@ async def initialize_conversation( await ConversationElementModel.create( conversation_id=conversation_id, - role=ConversationElementRole.SYSTEM, - content=settings.prompts["p2"], + role=ConversationElementRole.ASSISTANT, + content=settings.conversation_continuation_prompt, ) await UserConversationModel.create( @@ -141,4 +261,4 @@ async def initialize_conversation( await TokenUsageModel.create(conversation_id=conversation_id) - return await get_response(str(conversation_id), message) + return await get_response(str(conversation_id), message, user) diff --git a/fia_api/web/api/teacher/views.py b/fia_api/web/api/teacher/views.py index 0fcadc1..00c8c21 100644 --- a/fia_api/web/api/teacher/views.py +++ b/fia_api/web/api/teacher/views.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends from fia_api.db.models.user_model import UserModel -from fia_api.web.api.teacher.schema import ConversationResponse, TeacherConverseRequest +from fia_api.web.api.teacher.schema import ConverseResponse, TeacherConverseRequest from fia_api.web.api.teacher.utils import get_response, initialize_conversation from fia_api.web.api.user.schema import AuthenticatedUser from fia_api.web.api.user.utils import get_current_user @@ -9,17 +9,17 @@ router = APIRouter() -@router.post("/converse", response_model=ConversationResponse) +@router.post("/converse", response_model=ConverseResponse) async def converse( converse_request: TeacherConverseRequest, user: AuthenticatedUser = Depends(get_current_user), -) -> ConversationResponse: +) -> ConverseResponse: """ Starts or continues a conversation with the Teacher. :param converse_request: The request object. :param user: The AuthenticatedUser making the request. - :returns: ConversationResponse of mistakes and conversation. + :returns: ConverseResponse of mistakes and conversation. 
""" if converse_request.conversation_id == "new": return await initialize_conversation( @@ -30,4 +30,5 @@ async def converse( return await get_response( converse_request.conversation_id, converse_request.message, + await UserModel.get(username=user.username), ) diff --git a/fia_api/web/api/user/utils.py b/fia_api/web/api/user/utils.py index bbd29ad..e6b42f4 100644 --- a/fia_api/web/api/user/utils.py +++ b/fia_api/web/api/user/utils.py @@ -1,4 +1,3 @@ -import json import uuid from datetime import datetime, timedelta from typing import Any, Dict, Union @@ -16,13 +15,7 @@ ) from fia_api.db.models.user_model import UserModel from fia_api.settings import settings -from fia_api.web.api.teacher.schema import ( - ConversationElement, - ConversationResponse, - TeacherConversationElement, - TeacherResponse, - UserConversationElement, -) +from fia_api.web.api.teacher.schema import ConversationElement, ConversationResponse from fia_api.web.api.user.schema import AuthenticatedUser, TokenPayload ACCESS_TOKEN_EXPIRY_MINUTES = 30 @@ -132,23 +125,26 @@ async def get_current_user(token: str = Depends(reuseable_oauth)) -> Authenticat async def format_conversation_element( - conversation_element: Dict[str, str], -) -> Union[TeacherConversationElement, UserConversationElement]: + conversation_element: Dict[Any, Any], +) -> ConversationElement: """ - Return the correct object. + Formats a conversation element like dict into an actual object. - :param conversation_element: A dict of the current conversation element. - :returns: A sensible representation of the conversation element. + :param conversation_element: dict of mostly str -> str + :return: ConversationElement """ - # Ignoring type here as MyPy doesn't co-operate with Tortoise ORM for enums. 
- if conversation_element["role"] == ConversationElementRole.SYSTEM: # type: ignore - return TeacherConversationElement( - response=TeacherResponse( - **json.loads(conversation_element["content"]), - ), + parsed_element = { + "role": conversation_element["role"].value, + "message": conversation_element["content"], + } + + if "learning_moments" in conversation_element: + parsed_element["learning_moments"] = conversation_element.get( + "learning_moments", + None, ) - return UserConversationElement(message=conversation_element["content"]) + return ConversationElement(**parsed_element) async def format_conversation_for_response( @@ -163,27 +159,24 @@ async def format_conversation_for_response( of the whole conversation. :returns: ConversationResponse """ - raw_conversation = await ConversationElementModel.filter( - conversation_id=uuid.UUID(conversation_id), - ).values() + raw_conversation = ( + await ConversationElementModel.filter( + conversation_id=uuid.UUID(conversation_id), + ) + .exclude( + role=ConversationElementRole.ASSISTANT, + ) + .values() + ) if last: raw_conversation = [raw_conversation[-1]] conversation_list = [] for conversation_element in raw_conversation: - try: - conversation_list.append( - ConversationElement( - conversation_element=await format_conversation_element( - conversation_element, - ), - ), - ) - except Exception as ex: - # TODO: This is expected in the case of the initial assistant - # message. - logger.debug(ex) + conversation_list.append( + await format_conversation_element(conversation_element), + ) return ConversationResponse( conversation_id=conversation_id,