Skip to content

Commit

Permalink
improved web browsing feature
Browse files Browse the repository at this point in the history
  • Loading branch information
c0sogi committed Jun 12, 2023
1 parent e2d3b92 commit c351e1a
Show file tree
Hide file tree
Showing 46 changed files with 53,622 additions and 51,651 deletions.
45 changes: 36 additions & 9 deletions app/common/constants.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# flake8: noqa

from re import compile, Pattern

from langchain import PromptTemplate


class QueryTemplates:
CONTEXT_QUESTION__DEFAULT = PromptTemplate(
template=(
"Context information is below. \n"
"Context information is below. You can utilize this information to answer the question.\n"
"---------------------\n"
"{context}"
"\n---------------------\n"
Expand Down Expand Up @@ -98,19 +100,43 @@ class SummarizationTemplates:
)


class WebSearchTemplates:
QUERY__DEFAULT: PromptTemplate = PromptTemplate(
class QueryBasedSearchTemplates:
QUERY__JSONIFY_WEB_BROWSING: PromptTemplate = PromptTemplate(
template=(
"You are a Web search API bot that performs a web search for a user's question. F"
"ollow the rules below to output a response.\n- Output the query to search the web"
' for USER\'S QUESTION in the form {"query": QUERY_TO_SEARCH}.\n- QUERY_TO_SEARCH i'
"s a set of words within 10 words.\n- Your response must be in JSON format, starti"
"ng with { and ending with }.\n- Output a generalized query to return sufficiently"
" relevant results when searching the web.\n- If a suitable search query does not "
'exist, output {"query": null} - don\'t be afraid to output null!\n\n'
"```USER'S QUESTION\n{{query}}\n```"
),
input_variables=["query"],
template_format="jinja2",
)
QUERY__JSONIFY_VECTORSTORE: PromptTemplate = PromptTemplate(
template=(
"You are a Web search API bot that only outputs JSON. If you decide that you need"
" to search a USER'S QUESTION surrounded by the following triple backticks, you m"
'ust use {"query": QUERY_TO_SEARCH} in JSON format. If you don\'t need to search t'
'he web, output something like {"query": null}. Keep your QUERY_TO_SEARCH as a se'
"t of words rather than a sentence. The JSON must contain only the query.\n\nUSER'S"
" QUESTION\n```\n{{query}}\n```"
"You are a Search API bot performing a vector similarity-based search for a user'"
"s question. Follow the rules below to output a response.\n- Output the query to s"
'earch the web for USER\'S QUESTION in a format like {"query": QUERY_TO_SEARCH}.\n-'
" QUERY_TO_SEARCH creates a hypothetical answer to facilitate searching in the Ve"
"ctor database.\n- Your response must be in JSON format, starting with { and endin"
'g with }.\n- If a suitable search query does not exist, output {"query": NULL} - '
"don't be afraid to output NULL!"
"```USER'S QUESTION\n{{query}}\n```"
),
input_variables=["query"],
template_format="jinja2",
)
QUERY__SUMMARIZE_QUERY: PromptTemplate = PromptTemplate(
template=(
"Summarize the user's question in 10 words or less.\n\n"
" ```USER'S QUESTION\n{query}\n```"
),
input_variables=["query"],
template_format="f-string",
)


class SystemPrompts:
Expand Down Expand Up @@ -342,6 +368,7 @@ class SystemPrompts:
},
]

JSON_PATTERN: Pattern = compile(r"\{.*\}")

if __name__ == "__main__":

Expand Down
Binary file added app/contents/browsing_demo.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed app/contents/browsing_demo.png
Binary file not shown.
Binary file removed app/contents/chat_demo.gif
Binary file not shown.
Binary file removed app/contents/embed_demo.png
Binary file not shown.
Binary file removed app/contents/llama_demo.gif
Binary file not shown.
Binary file removed app/contents/stop_generation_demo.gif
Binary file not shown.
Binary file removed app/contents/ui_demo.gif
Binary file not shown.
Binary file added app/contents/upload_demo.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
30 changes: 22 additions & 8 deletions app/database/crud/api_keys.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, Tuple
from sqlalchemy import select, func, exists
from app.viewmodels.base_models import AddApiKey
from app.models.base_models import AddApiKey
from app.errors.api_exceptions import (
Responses_400,
Responses_404,
Expand All @@ -24,9 +24,15 @@ async def create_api_key(
if api_key_count is not None and api_key_count >= MAX_API_KEY:
raise Responses_400.max_key_count_exceed
while True:
new_api_key: ApiKeys = generate_new_api_key(user_id=user_id, additional_key_info=additional_key_info)
is_api_key_duplicate_stmt = select(exists().where(ApiKeys.access_key == new_api_key.access_key))
is_api_key_duplicate: bool | None = await transaction.scalar(is_api_key_duplicate_stmt)
new_api_key: ApiKeys = generate_new_api_key(
user_id=user_id, additional_key_info=additional_key_info
)
is_api_key_duplicate_stmt = select(
exists().where(ApiKeys.access_key == new_api_key.access_key)
)
is_api_key_duplicate: bool | None = await transaction.scalar(
is_api_key_duplicate_stmt
)
if not is_api_key_duplicate:
break
transaction.add(new_api_key)
Expand All @@ -43,7 +49,9 @@ async def get_api_key_owner(access_key: str) -> Users:
if db.session is None:
raise Responses_500.database_not_initialized
async with db.session() as transaction:
matched_api_key: Optional[ApiKeys] = await transaction.scalar(select(ApiKeys).filter_by(access_key=access_key))
matched_api_key: Optional[ApiKeys] = await transaction.scalar(
select(ApiKeys).filter_by(access_key=access_key)
)
if matched_api_key is None:
raise Responses_404.not_found_access_key
owner: Users = await Users.first_filtered_by(id=matched_api_key.user_id) # type: ignore
Expand All @@ -56,10 +64,14 @@ async def get_api_key_and_owner(access_key: str) -> Tuple[ApiKeys, Users]:
if db.session is None:
raise Responses_500.database_not_initialized
async with db.session() as transaction:
matched_api_key: Optional[ApiKeys] = await transaction.scalar(select(ApiKeys).filter_by(access_key=access_key))
matched_api_key: Optional[ApiKeys] = await transaction.scalar(
select(ApiKeys).filter_by(access_key=access_key)
)
if matched_api_key is None:
raise Responses_404.not_found_access_key
api_key_owner: Optional[Users] = await transaction.scalar(select(Users).filter_by(id=matched_api_key.user_id))
api_key_owner: Optional[Users] = await transaction.scalar(
select(Users).filter_by(id=matched_api_key.user_id)
)
if api_key_owner is None:
raise Responses_404.not_found_user
return matched_api_key, api_key_owner
Expand Down Expand Up @@ -94,7 +106,9 @@ async def delete_api_key(
raise Responses_500.database_not_initialized
async with db.session() as transaction:
matched_api_key: Optional[ApiKeys] = await transaction.scalar(
select(ApiKeys).filter_by(id=access_key_id, user_id=user_id, access_key=access_key)
select(ApiKeys).filter_by(
id=access_key_id, user_id=user_id, access_key=access_key
)
)
if matched_api_key is None:
raise Responses_404.not_found_api_key
Expand Down
2 changes: 1 addition & 1 deletion app/middlewares/token_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
InternalServerError,
exception_handler,
)
from app.viewmodels.base_models import UserToken
from app.models.base_models import UserToken
from app.utils.auth.token import token_decode
from app.utils.date_utils import UTC
from app.utils.logger import api_logger
Expand Down
34 changes: 17 additions & 17 deletions app/viewmodels/base_models.py → app/models/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,27 +46,27 @@ class MessageOk(BaseModel):
class UserToken(BaseModel):
id: int
status: UserStatus
email: str | None = None
name: str | None = None
email: Optional[str] = None
name: Optional[str] = None

class Config:
orm_mode = True


class UserMe(BaseModel):
id: int
email: str | None = None
name: str | None = None
phone_number: str | None = None
profile_img: str | None = None
sns_type: str | None = None
email: Optional[str] = None
name: Optional[str] = None
phone_number: Optional[str] = None
profile_img: Optional[str] = None
sns_type: Optional[str] = None

class Config:
orm_mode = True


class AddApiKey(BaseModel):
user_memo: str | None = None
user_memo: Optional[str] = None

class Config:
orm_mode = True
Expand Down Expand Up @@ -104,28 +104,29 @@ class Config:


class MessageToWebsocket(BaseModel):
msg: str | None
msg: Optional[str]
finish: bool
chat_room_id: Optional[str] = None
is_user: bool
init: bool = False
model_name: Optional[str] = None
uuid: Optional[str] = None
wait_next_query: Optional[bool] = None

class Config:
orm_mode = True


class MessageFromWebsocket(BaseModel):
msg: str
translate: bool
translate: Optional[str] = None
chat_room_id: str


class CreateChatRoom(BaseModel): # stub
chat_room_type: str
name: str
description: str | None = None
description: Optional[str] = None
user_id: int

class Config:
Expand Down Expand Up @@ -153,12 +154,11 @@ class Config:


class InitMessage(BaseModel):
previous_chats: list[dict] | None = None
chat_rooms: list[dict[str, str]] | None = None
models: list[str] | None = None
selected_model: str | None = None
tokens: int | None = None
wait_next_query: bool
previous_chats: Optional[list[dict]] = None
chat_rooms: Optional[list[dict[str, str]]] = None
models: Optional[list[str]] = None
selected_model: Optional[str] = None
tokens: Optional[int] = None


class StreamProgress(BaseModel):
Expand Down
53 changes: 50 additions & 3 deletions app/models/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import Enum
from typing import Optional, Union
from functools import wraps
from inspect import iscoroutinefunction
from typing import Any, Callable, Optional, Tuple, Union
from uuid import uuid4

from orjson import dumps as orjson_dumps
from orjson import loads as orjson_loads
from app.common.config import DEFAULT_LLM_MODEL

from app.common.config import DEFAULT_LLM_MODEL
from app.models.base_models import UserChatRoles
from app.models.llms import LLMModels
from app.utils.date_utils import UTC
from app.viewmodels.base_models import UserChatRoles


class ChatRoles(str, Enum):
Expand Down Expand Up @@ -285,3 +287,48 @@ def clear_tokens(self, tokens_to_remove: int) -> int:
self.user_message_histories.pop(0)
self.ai_message_histories.pop(0)
return deleted_histories


class ResponseType(str, Enum):
SEND_MESSAGE_AND_STOP = "send_message_and_stop"
SEND_MESSAGE_AND_KEEP_GOING = "send_message_and_keep_going"
HANDLE_USER = "handle_user"
HANDLE_AI = "handle_ai"
HANDLE_BOTH = "handle_both"
DO_NOTHING = "do_nothing"
REPEAT_COMMAND = "repeat_command"


class command_response:
@staticmethod
def _wrapper(enum_type: ResponseType) -> Callable:
def decorator(func: Callable) -> Callable:
@wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Tuple[Any, ResponseType]:
result = func(*args, **kwargs)
return (result, enum_type)

@wraps(func)
async def async_wrapper(
*args: Any, **kwargs: Any
) -> Tuple[Any, ResponseType]:
result = await func(*args, **kwargs)
return (result, enum_type)

return async_wrapper if iscoroutinefunction(func) else sync_wrapper

return decorator

send_message_and_stop = _wrapper(ResponseType.SEND_MESSAGE_AND_STOP)
send_message_and_keep_going = _wrapper(ResponseType.SEND_MESSAGE_AND_KEEP_GOING)
handle_user = _wrapper(ResponseType.HANDLE_USER)
handle_ai = _wrapper(ResponseType.HANDLE_AI)
handle_both = _wrapper(ResponseType.HANDLE_BOTH)
do_nothing = _wrapper(ResponseType.DO_NOTHING)
repeat_command = _wrapper(ResponseType.REPEAT_COMMAND)


class ChainStatus(str, Enum):
BEGIN = "begin"
END = "end"
ERROR = "error"
11 changes: 10 additions & 1 deletion app/models/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from app.common.constants import ChatTurnTemplates, DescriptionTemplates

from app.models.llm_tokenizers import BaseTokenizer, LlamaTokenizer, OpenAITokenizer
from app.viewmodels.base_models import UserChatRoles
from app.models.base_models import UserChatRoles


@dataclass
Expand Down Expand Up @@ -341,6 +341,15 @@ class LLMModels(Enum):
model_path="./llama_models/ggml/airoboros-13b-gpt4.ggmlv3.q5_1.bin",
prefix_template=DescriptionTemplates.USER_AI__SHORT,
)
selfee_7b = LlamaCppModel(
name="selfee-7B-GGML",
max_total_tokens=4096, # context tokens (n_ctx)
max_tokens_per_request=2048, # The maximum number of tokens to generate.
token_margin=8,
tokenizer=LlamaTokenizer("kaist-ai/selfee-7b-delta"),
model_path="./llama_models/ggml/selfee-7B.ggmlv3.q4_1.bin",
prefix_template=DescriptionTemplates.USER_AI__SHORT,
)

@classmethod
def find_model_by_name(cls, name: str) -> LLMModel | None:
Expand Down
2 changes: 1 addition & 1 deletion app/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from app.utils.auth.token import create_access_token, token_decode
from app.utils.chat.cache_manager import CacheManager
from app.viewmodels.base_models import SnsType, Token, UserRegister, UserToken
from app.models.base_models import SnsType, Token, UserRegister, UserToken

router = APIRouter(prefix="/auth")

Expand Down
2 changes: 1 addition & 1 deletion app/routers/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from app.errors.api_exceptions import Responses_400
from app.utils.api.weather import fetch_weather_data
from app.utils.encoding_utils import encode_from_utf8
from app.viewmodels.base_models import KakaoMsgBody, MessageOk, SendEmail
from app.models.base_models import KakaoMsgBody, MessageOk, SendEmail

router = APIRouter(prefix="/services")
router.redirect_slashes = False
Expand Down
2 changes: 1 addition & 1 deletion app/routers/user_services.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional
from fastapi import APIRouter
from langchain.utilities import DuckDuckGoSearchAPIWrapper
from app.viewmodels.base_models import MessageOk
from app.models.base_models import MessageOk


router = APIRouter(prefix="/user-services")
Expand Down
Loading

0 comments on commit c351e1a

Please sign in to comment.