Skip to content

Commit

Permalink
Split Models and Add Flashcard Creation (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
conor-f committed Sep 10, 2023
1 parent 48a67cc commit 685ccfd
Show file tree
Hide file tree
Showing 6 changed files with 365 additions and 143 deletions.
3 changes: 3 additions & 0 deletions fia_api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class Settings(BaseSettings):

openai_api_key: str = "INVALID_OPENAI_API_KEY"

get_learning_moments_prompt: str = """You are an expert German language teacher who works with native English speakers to help them learn German. They give you a message and you explain each mistake in their message. You give them "Learning Moments" which they can review and learn from."""
conversation_continuation_prompt: str = """You are a native German speaker who is helping someone to learn to speak German. They are a beginner and want to try have a conversation only in German with you. Sometimes they make spelling/grammar mistakes, but you always try to continue on the conversation while only sometimes explaining their mistakes to them. You are friendly and ask questions to direct the conversation to help the user learn. You are allowed to use English if the user asks you what a word means. You speak in very simple, short sentences."""

prompts: Dict[str, Dict[str, str]] = {
"p1": {
"role": "system",
Expand Down
102 changes: 77 additions & 25 deletions fia_api/tests/test_conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from loguru import logger
from pytest_mock import MockerFixture


Expand Down Expand Up @@ -39,6 +40,57 @@ class OpenAIAPIResponse:
usage: Dict[str, int]


def get_mocked_openai_response(*args, **kwargs) -> OpenAIAPIResponse: # type: ignore
"""
Return the mocked OpenAI API response based on the input.
:param args: All args passed to OpenAI
:param kwargs: All kwargs passed to OpenAI
:returns: OpenAIAPIReponse
"""
learning_moments_api_response = OpenAIAPIResponse(
choices=[
OpenAIAPIChoices(
message=OpenAIAPIMessage(
role="assistant",
function_call=OpenAIAPIFunctionCall(
name="get_learning_moments",
arguments='{\n "learning_moments": [\n {\n "moment": {\n "incorrect_section": "Hallo",\n "corrected_section": "Hallo,",\n "explanation": "In German, a comma is often used after greetings like \'Hallo\' or \'Guten Tag\'."\n }\n },\n {\n "moment": {\n "incorrect_section": "Wie Geht\'s?",\n "corrected_section": "Wie geht es dir?",\n "explanation": "The correct way to ask \'How are you?\' in German is \'Wie geht es dir?\'"\n }\n }\n ]\n}', # noqa: E501
),
),
),
],
usage={
"prompt_tokens": 181,
"completion_tokens": 114,
"total_tokens": 295,
},
)
chat_continuation_api_response = OpenAIAPIResponse(
choices=[
OpenAIAPIChoices(
message=OpenAIAPIMessage(
role="assistant",
function_call=OpenAIAPIFunctionCall(
name="get_conversation_response",
arguments='{\n"message": "Mir geht es gut, danke! Wie geht es dir?"\n}', # noqa: E501
),
),
),
],
usage={
"prompt_tokens": 181,
"completion_tokens": 114,
"total_tokens": 295,
},
)

if kwargs["functions"][0]["name"] == "get_learning_moments":
return learning_moments_api_response

return chat_continuation_api_response


async def get_access_token(
fastapi_app: FastAPI,
client: AsyncClient,
Expand Down Expand Up @@ -101,6 +153,7 @@ async def test_conversations(
list_conversations_url = fastapi_app.url_path_for("list_user_conversations")
get_conversation_url = fastapi_app.url_path_for("get_user_conversation")
converse_url = fastapi_app.url_path_for("converse")
get_flashcards_url = fastapi_app.url_path_for("get_flashcards")

# No conversations by default:
response = await client.get(
Expand All @@ -109,28 +162,17 @@ async def test_conversations(
)
assert not response.json()["conversations"]

# Begin conversation:
api_response = OpenAIAPIResponse(
choices=[
OpenAIAPIChoices(
message=OpenAIAPIMessage(
role="assistant",
function_call=OpenAIAPIFunctionCall(
name="get_answer_for_user_query",
arguments='{\n "translated_words": [\n {\n "word": "Hallo",\n "translated_word": "Hello"\n },\n {\n "word": "Wie",\n "translated_word": "How"\n },\n {\n "word": "Geht",\n "translated_word": "is going"\n },\n {\n "word": "s",\n "translated_word": "it"\n }\n ],\n "mistakes": [],\n "conversation_response": "Mir geht es gut, danke. Wie kann ich Ihnen helfen?"\n}', # noqa: E501
),
),
),
],
usage={
"prompt_tokens": 181,
"completion_tokens": 114,
"total_tokens": 295,
},
# No flashcards by default:
response = await client.get(
get_flashcards_url,
headers=auth_headers,
)
assert not response.json()["flashcards"]

# Begin conversation:
mocker.patch(
"fia_api.web.api.teacher.utils.get_openai_response",
return_value=api_response,
"fia_api.web.api.teacher.utils.openai.ChatCompletion.create",
side_effect=get_mocked_openai_response,
)
response = await client.post(
converse_url,
Expand All @@ -142,11 +184,12 @@ async def test_conversations(
)

conversation_id = response.json()["conversation_id"]
conversation = response.json()["conversation"]
conversation_response = response.json()["conversation_response"]

assert conversation_id != "new"
assert len(conversation) == 1
assert conversation[0]["conversation_element"]["role"] == "teacher"
assert len(response.json()["learning_moments"]) > 0 # noqa: WPS507
assert isinstance(conversation_response, str)
assert len(conversation_response) > 0 # noqa: WPS507

# Now one conversation
response = await client.get(
Expand All @@ -167,8 +210,9 @@ async def test_conversations(
)

assert conversation_id == response.json()["conversation_id"]
assert len(conversation) == 1
assert conversation[0]["conversation_element"]["role"] == "teacher"
assert len(response.json()["learning_moments"]) > 0 # noqa: WPS507
assert isinstance(conversation_response, str)
assert len(conversation_response) > 0 # noqa: WPS507

# Now get conversation:
response = await client.get(
Expand All @@ -178,4 +222,12 @@ async def test_conversations(
"conversation_id": conversation_id,
},
)
logger.error(response.json())
assert len(response.json()["conversation"]) == 4

# Now we have flashcards
response = await client.get(
get_flashcards_url,
headers=auth_headers,
)
assert len(response.json()["flashcards"]) == 4
131 changes: 92 additions & 39 deletions fia_api/web/api/teacher/schema.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,94 @@
from typing import List, Union
from typing import List, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, Field


# OpenAI Related Response Models:
class Mistake(BaseModel):
"""Represents a "Mistake" that the model found in the user input."""

mistake_text: str
fixed_text: str
explanation: str


class TranslatedWords(BaseModel):
"""Represents a "Translation" that the model was requested to make."""

word: str
translated_word: str
"""A single Mistake a user made in their message."""

incorrect_section: str = Field(
description=(
"The section of the sentence of the user message the grammar "
"mistake is in"
),
)
corrected_section: str = Field(
description="The corrected section of the sentence in German",
)
explanation: str = Field(
description=(
"The English language explanation of why this section of the "
"sentence is incorrect. Give details such as if it is using the "
"wrong gender/suffix, if the verb conjugation is wrong, etc. If "
"the sentence is correct, but there is a better way to phrase it, "
"explain this too."
),
)


class Translation(BaseModel):
"""A word or phrase the user wants translated."""

phrase: str = Field(
description=(
'The word or phrase the user wants translated. e.g. "Book", '
"'gerne', \"lesen\""
),
)
translated_phrase: str = Field(
description=(
"The translation of the word or phrase in context of the "
"sentence. If it is a noun, include the correct gender. "
"e.g. Das Buch, Die Hande, etc"
),
)


class LearningMoment(BaseModel):
"""A moment in a conversation a user can learn from."""

moment: Union[Mistake, Translation] = Field(
description=(
"A single language learning mistake found in a section of the "
"users message"
),
)


# This is returned by the model looking for mistakes in the user sentence.
class LearningMoments(BaseModel):
"""A list of individual LearningMoment objects."""

learning_moments: List[LearningMoment] = Field(
description=(
"A list of language learning mistakes in the users message. "
"There should be one LearningMoment per individual mistake in "
"the sentence."
),
)


# This is returned by the model trying to continue on the conversation with the
# user in an educational way.
class ConversationContinuation(BaseModel):
"""Basic wrapper to supply description data to OpenAI."""

message: str = Field(
description=(
"This is the response to to users message. You should always "
"try respond in German. If the user doesn't understand, then try "
"to use even more simple German in your response until you can "
"only use English. Responding in English is a last resort."
),
)


class TeacherResponse(BaseModel):
"""Represents the entire response from the "Teacher"."""

translated_words: List[TranslatedWords]
mistakes: List[Mistake]
conversation_response: str
learning_moments: LearningMoments
conversation_response: ConversationContinuation


class TeacherConverseRequest(BaseModel):
Expand All @@ -48,35 +112,24 @@ class UserConversationList(BaseModel):
conversations: List[ConversationSnippet]


class ConversationLine(BaseModel):
"""A single line of a conversation."""

# If str, then this is just a user response.
line: Union[TeacherResponse, str]


class UserConversationElement(BaseModel):
"""The user part of the conversation."""

role: str = "user"
message: str


class TeacherConversationElement(BaseModel):
"""The teacher response."""

role: str = "teacher"
response: TeacherResponse


class ConversationElement(BaseModel):
"""A single part of a conversation. Either from the user or system."""

conversation_element: Union[TeacherConversationElement, UserConversationElement]
role: str
message: str
learning_moments: Optional[LearningMoments] = None


class ConversationResponse(BaseModel):
"""A conversation a user had."""

conversation_id: str
conversation: List[ConversationElement]


class ConverseResponse(BaseModel):
"""Response from the Converse endpoint."""

conversation_id: str
learning_moments: LearningMoments
conversation_response: str
Loading

0 comments on commit 685ccfd

Please sign in to comment.