Skip to content

Commit

Permalink
OpenAI Integration (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
conor-f committed Sep 5, 2023
1 parent 268d0b2 commit 36bb362
Show file tree
Hide file tree
Showing 11 changed files with 640 additions and 6 deletions.
1 change: 1 addition & 0 deletions fia_api/db/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"fia_api.db.models.dummy_model",
"fia_api.db.models.user_model",
"fia_api.db.models.user_details_model",
"fia_api.db.models.conversation_model",
] # noqa: WPS407

TORTOISE_CONFIG = { # noqa: WPS407
Expand Down
20 changes: 20 additions & 0 deletions fia_api/db/migrations/models/5_20230905144348_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from tortoise import BaseDBAsyncClient


async def upgrade(db: BaseDBAsyncClient) -> str:
return """
CREATE TABLE IF NOT EXISTS "conversation_elements" (
"id" SERIAL NOT NULL PRIMARY KEY,
"last_modified" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"first_created" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"conversation_id" UUID NOT NULL,
"role" VARCHAR(9) NOT NULL,
"content" TEXT NOT NULL
);
COMMENT ON COLUMN "conversation_elements"."role" IS 'USER: user\nSYSTEM: system\nASSISTANT: assistant';
COMMENT ON TABLE "conversation_elements" IS 'Model for a snippet of a Conversation.';"""


async def downgrade(db: BaseDBAsyncClient) -> str:
return """
DROP TABLE IF EXISTS "conversation_elements";"""
28 changes: 28 additions & 0 deletions fia_api/db/models/conversation_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from enum import Enum

from tortoise import fields

from fia_api.db.models.fia_base_model import FiaBaseModel


class ConversationElementRole(Enum):
"""Enum for role the message content is from."""

USER = "user"
SYSTEM = "system"
ASSISTANT = "assistant"


class ConversationElementModel(FiaBaseModel):
"""Model for a snippet of a Conversation."""

conversation_id = fields.data.UUIDField(null=False, required=True)
role = fields.data.CharEnumField(ConversationElementRole, required=True)
content = fields.data.TextField(null=False, required=True)

def __str__(self) -> str:
return f"ConversationElement: {self.id}"

class Meta:
ordering = ["first_created"]
table = "conversation_elements"
35 changes: 34 additions & 1 deletion fia_api/settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# flake8: noqa

# Ignoring whole file as Flake8 _hates_ the prompts dict with line length and
# multiline strings.

import enum
from pathlib import Path
from tempfile import gettempdir
from typing import Optional
from typing import Dict, Optional

from pydantic_settings import BaseSettings, SettingsConfigDict
from yarl import URL
Expand Down Expand Up @@ -60,6 +65,34 @@ class Settings(BaseSettings):
jwt_secret_key: str = "jwt_secret_key"
jwt_refresh_secret_key: str = "jwt_refresh_secret_key"

openai_api_key: str = "INVALID_OPENAI_API_KEY"

prompts: Dict[str, Dict[str, str]] = {
"p1": {
"role": "system",
"content": """You are an expert German language teacher. You hold basic conversations in German with users. You actively engage with the conversation and keep a pleasant tone. You use a simple vocabulary that the user can understand. If they don't understand you, use simpler words. If they understand you easily, use more complex words. Your response is in the following JSON object:
{
"mistakes": A list of JSON objects. There is one object for each mistake
made in the users message. Each object has an English language explanation
and shows the part of the sentence the mistake was in. If there were no grammar mistakes, the list is empty.
"fluency": A score from 0-100 of how natural sounding the users message was.
"conversation_response": A string in the German language to continue the conversation with the user.
}
You must respond to every message in this exact structure. You must not respond in any other way.""",
},
"p2": {
"role": "system",
"content": """You are a German language teacher analyzing sentences. You always respond in a JSON object. The JSON object has the following members: mistakes, response. mistakes is a list of every grammer/spelling/vocabulary mistake the user made. response is the German language response to the users message.""",
},
# BEST SO FAR!
"p3": {
"role": "system",
"content": """You are a German language teacher. You correct any English words in a users message. You also explain any spelling or grammar mistakes they make in English. You are having a conversation with them. Don't translate every word, only the words that the user typed in English. Always try to continue the conversation.""",
},
}

@property
def db_url(self) -> URL:
"""
Expand Down
3 changes: 2 additions & 1 deletion fia_api/web/api/router.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from fastapi.routing import APIRouter

from fia_api.web.api import dummy, echo, monitoring, redis, user
from fia_api.web.api import dummy, echo, monitoring, redis, teacher, user

api_router = APIRouter()
api_router.include_router(monitoring.router)
api_router.include_router(echo.router, prefix="/echo", tags=["echo"])
api_router.include_router(dummy.router, prefix="/dummy", tags=["dummy"])
api_router.include_router(redis.router, prefix="/redis", tags=["redis"])
api_router.include_router(user.router, prefix="/user", tags=["user"])
api_router.include_router(teacher.router, prefix="/teacher", tags=["teacher"])
4 changes: 4 additions & 0 deletions fia_api/web/api/teacher/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Teacher model API."""
from fia_api.web.api.teacher.views import router

__all__ = ["router"]
54 changes: 54 additions & 0 deletions fia_api/web/api/teacher/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import List

from pydantic import BaseModel


# OpenAI Related Response Models:
class Mistake(BaseModel):
"""Represents a "Mistake" that the model found in the user input."""

mistake_text: str
fixed_text: str
explanation: str


class TranslatedWords(BaseModel):
"""Represents a "Translation" that the model was requested to make."""

word: str
translated_word: str


class TeacherResponse(BaseModel):
"""Represents the entire response from the "Teacher"."""

translated_words: List[TranslatedWords]
mistakes: List[Mistake]
conversation_response: str


class TeacherConverseRequest(BaseModel):
"""Request object for calls to the Teacher to continue a conversation."""

# If conversation_id is "new", then start a new conversation.
conversation_id: str
message: str


class TeacherConverseResponse(BaseModel):
"""Response from a call to the teacher/converse endpoint."""

conversation_id: str
response: TeacherResponse


class TeacherConverseEndRequest(BaseModel):
"""Response from a call to the teacher/end_converse endpoint."""

conversation_id: str


class TeacherConverseEndResponse(BaseModel):
"""Response from a call to the teacher/end_converse endpoint."""

message: str
106 changes: 106 additions & 0 deletions fia_api/web/api/teacher/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# noqa: WPS462
import json
import uuid
from typing import Dict, List

import openai

from fia_api.db.models.conversation_model import (
ConversationElementModel,
ConversationElementRole,
)
from fia_api.settings import settings
from fia_api.web.api.teacher.schema import TeacherConverseResponse, TeacherResponse

openai.api_key = settings.openai_api_key


async def get_messages_from_conversation_id(
conversation_id: str,
) -> List[Dict[str, str]]:
"""
Given a conversation_id, return a list of dicts ready to pass to OpenAI.
:param conversation_id: str ID of the conversation
:return: List of dicts of shape {"role": EnumValue, "content": "message"}
"""
raw_conversation = await ConversationElementModel.filter(
conversation_id=uuid.UUID(conversation_id),
).values()

return [
{
"role": conversation_element["role"].value,
"content": conversation_element["content"],
}
for conversation_element in raw_conversation
]


async def get_response(conversation_id: str, message: str) -> TeacherConverseResponse:
"""
Converse with OpenAI.
Given the conversation ID, and a new message to add to it, store the
message, get the response, store that, and return it.
:param conversation_id: String ID representing the conversation.
:param message: String message the user wants to send.
:return: TeacherResponse
"""
await ConversationElementModel.create(
conversation_id=uuid.UUID(conversation_id),
role=ConversationElementRole.USER,
content=message,
)

chat_response = openai.ChatCompletion.create(
model="gpt-3.5-turbo-0613",
messages=await get_messages_from_conversation_id(conversation_id),
functions=[
{
"name": "get_answer_for_user_query",
"description": "Get user language learning mistakes and a sentence to continue the conversation", # noqa: E501
"parameters": TeacherResponse.schema(),
},
],
function_call={"name": "get_answer_for_user_query"},
)

teacher_response = chat_response.choices[ # noqa: WPS219
0
].message.function_call.arguments

await ConversationElementModel.create(
conversation_id=uuid.UUID(conversation_id),
role=ConversationElementRole.SYSTEM,
content=teacher_response,
)

# TODO: Store the token usage in the DB...

return TeacherConverseResponse(
conversation_id=conversation_id,
response=json.loads(teacher_response),
)


async def initialize_conversation(message: str) -> TeacherConverseResponse:
"""
Starts the conversation.
Set up the DB with the initial conversation prompt and return the new
conversation ID, along with the first response from the model.
:param message: The message to start the conversation with.
:returns: TeacherConverseResponse of the teacher's first reply.
"""
conversation_id = uuid.uuid4()

await ConversationElementModel.create(
conversation_id=conversation_id,
role=ConversationElementRole.SYSTEM,
content=settings.prompts["p3"],
)

return await get_response(str(conversation_id), message)
32 changes: 32 additions & 0 deletions fia_api/web/api/teacher/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from fastapi import APIRouter, Depends

from fia_api.web.api.teacher.schema import (
TeacherConverseRequest,
TeacherConverseResponse,
)
from fia_api.web.api.teacher.utils import get_response, initialize_conversation
from fia_api.web.api.user.schema import AuthenticatedUser
from fia_api.web.api.user.utils import get_current_user

router = APIRouter()


@router.post("/converse", response_model=TeacherConverseResponse)
async def converse(
converse_request: TeacherConverseRequest,
user: AuthenticatedUser = Depends(get_current_user),
) -> TeacherConverseResponse:
"""
Starts or continues a conversation with the Teacher.
:param converse_request: The request object.
:param user: The AuthenticatedUser making the request.
:returns: TeacherConverseResponse of mistakes and conversation.
"""
if converse_request.conversation_id == "new":
return await initialize_conversation(converse_request.message)

return await get_response(
converse_request.conversation_id,
converse_request.message,
)
Loading

0 comments on commit 36bb362

Please sign in to comment.