Skip to content

Commit

Permalink
Merge pull request #92 from nekochans/feature/issue91/add-function-calling
Browse files Browse the repository at this point in the history

Function callingで回答を生成するように実装
  • Loading branch information
keitakn committed Jan 27, 2024
2 parents 405e618 + 1caa356 commit 763d316
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 8 deletions.
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ services:
- "5002:5000"
environment:
OPENAI_API_KEY: ${OPENAI_API_KEY}
OPEN_WEATHER_API_KEY: ${OPEN_WEATHER_API_KEY}
BASIC_AUTH_USERNAME: ${BASIC_AUTH_USERNAME}
BASIC_AUTH_PASSWORD: ${BASIC_AUTH_PASSWORD}
DB_HOST: ${DB_HOST}
Expand Down
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ mypy = "^1.5.1"
asyncstdlib = "^3.10.9"
pytest-env = "^1.1.1"
pytest-xdist = "^3.5.0"
httpx = "^0.26.0"

[build-system]
requires = ["poetry-core"]
Expand Down
133 changes: 129 additions & 4 deletions src/infrastructure/repository/openai/openai_cat_message_repository.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
import os
from typing import AsyncGenerator, cast, List
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
import math
import httpx
import json
from typing import AsyncGenerator, cast, List, TypedDict
from openai import AsyncOpenAI, AsyncStream
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionChunk,
ChatCompletionFunctionMessageParam,
completion_create_params,
)
from domain.repository.cat_message_repository_interface import (
CatMessageRepositoryInterface,
GenerateMessageForGuestUserDto,
GenerateMessageForGuestUserResult,
)


# Shape of the weather lookup result passed back to the model as the
# function-call response: city name (as requested), a localized textual
# description, and the temperature in whole degrees Celsius.
FetchCurrentWeatherResponse = TypedDict(
    "FetchCurrentWeatherResponse",
    {
        "city_name": str,
        "description": str,
        "temperature": int,
    },
)


class OpenAiCatMessageRepository(CatMessageRepositoryInterface):
def __init__(self) -> None:
    """Load API credentials from the environment and create the OpenAI client.

    Raises:
        KeyError: if ``OPENAI_API_KEY`` or ``OPEN_WEATHER_API_KEY`` is unset.
    """
    env = os.environ
    self.OPENAI_API_KEY = env["OPENAI_API_KEY"]
    # Used by _fetch_current_weather for the OpenWeather geocoding/weather APIs.
    self.OPEN_WEATHER_API_KEY = env["OPEN_WEATHER_API_KEY"]
    self.client = AsyncOpenAI(api_key=self.OPENAI_API_KEY)

# TODO: 型は合っているのに型チェックエラーが出る mypy が AsyncGenerator に対応していない可能性がある
Expand All @@ -22,16 +37,126 @@ async def generate_message_for_guest_user( # type: ignore
messages = cast(List[ChatCompletionMessageParam], dto.get("chat_messages"))
user = str(dto.get("user_id"))

functions = [
{
"name": "fetch_current_weather",
"description": "指定された都市の現在の天気を取得する。(日本の都市の天気しか取得出来ない)",
"parameters": {
"type": "object",
"properties": {
"city_name": {
"type": "string",
"description": "英語表記の日本の都市名",
}
},
"required": ["city_name"],
},
}
]
function_calling_params = cast(
List[completion_create_params.Function], functions
)

response = await self.client.chat.completions.create(
model="gpt-3.5-turbo-1106",
messages=messages,
stream=True,
temperature=0.7,
user=user,
functions=function_calling_params,
function_call="auto",
)

ai_response_id = ""
function_calling_arguments: dict[str, str] = {
"name": "",
"arguments": "",
}

async for chunk in response:
function_call = chunk.choices[0].delta.function_call

if function_call:
if function_call.name is not None and function_call.name != "":
function_calling_arguments["name"] = function_call.name
if (
function_call.arguments is not None
and function_call.arguments != ""
):
function_calling_arguments["arguments"] += function_call.arguments
continue

if chunk.choices[0].finish_reason == "function_call":
if function_calling_arguments["name"] == "fetch_current_weather":
arguments = (
function_calling_arguments["arguments"]
if function_calling_arguments["arguments"] is not None
else ""
)
city_name = json.loads(arguments)["city_name"]
function_response = await self._fetch_current_weather(city_name)

function_result_message: ChatCompletionFunctionMessageParam = {
"role": "function",
"name": "fetch_current_weather",
"content": json.dumps(function_response, ensure_ascii=False),
}

messages.append(function_result_message)
response = await self.client.chat.completions.create(
model="gpt-3.5-turbo-1106",
messages=messages,
stream=True,
temperature=0.7,
user=user,
)

async for generated_response in self._extract_chat_chunks(response):
yield generated_response
continue

async for generated_response in self._extract_chat_chunks(response):
yield generated_response

async def _fetch_current_weather(
    self, city_name: str = "Tokyo"
) -> FetchCurrentWeatherResponse:
    """Fetch the current weather for a Japanese city from OpenWeather.

    Two-step lookup: resolve the city name to coordinates via the
    geocoding endpoint, then query the current-weather endpoint for
    those coordinates.

    Args:
        city_name: City name in English (Japanese cities only — the
            query is pinned to the ``jp`` country code).

    Returns:
        FetchCurrentWeatherResponse with a Japanese-language description
        and the temperature floored to a whole degree Celsius.

    Raises:
        IndexError: if the geocoding lookup returns no match.
    """
    async with httpx.AsyncClient() as http_client:
        # Step 1: city name -> latitude/longitude (restricted to Japan).
        geocoding_result = await http_client.get(
            "http://api.openweathermap.org/geo/1.0/direct",
            params={
                "q": f"{city_name},jp",
                "limit": 1,
                "appid": self.OPEN_WEATHER_API_KEY,
            },
        )
        location = geocoding_result.json()[0]

        # Step 2: coordinates -> current conditions (metric units, ja text).
        weather_result = await http_client.get(
            "https://api.openweathermap.org/data/2.5/weather",
            params={
                "lat": location["lat"],
                "lon": location["lon"],
                "units": "metric",
                "lang": "ja",
                "appid": self.OPEN_WEATHER_API_KEY,
            },
        )
        weather_body = weather_result.json()

    return {
        "city_name": city_name,
        "description": weather_body["weather"][0]["description"],
        "temperature": math.floor(weather_body["main"]["temp"]),
    }

@staticmethod
async def _extract_chat_chunks(
async_stream: AsyncStream[ChatCompletionChunk],
) -> AsyncGenerator[GenerateMessageForGuestUserResult, None]:
ai_response_id = ""
async for chunk in async_stream:
chunk_message: str = (
chunk.choices[0].delta.content
if chunk.choices[0].delta.content is not None
Expand Down

0 comments on commit 763d316

Please sign in to comment.