change: Redis vectorstore -> Qdrant vectorstore
c0sogi committed May 31, 2023
1 parent 5b2d56f commit 6f5d5e8
Showing 19 changed files with 820 additions and 310 deletions.
6 changes: 5 additions & 1 deletion .env-sample
@@ -34,4 +34,8 @@ GOOGLE_TRANSLATE_API_KEY="OPTIONAL_FOR_TRANSLATION"
GOOGLE_TRANSLATE_OAUTH_ID="OPTIONAL_FOR_TRANSLATION"
GOOGLE_TRANSLATE_OAUTH_SECRET="OPTIONAL_FOR_TRANSLATION"
RAPIDAPI_KEY="OPTIONAL_FOR_TRANSLATION"
CUSTOM_TRANSLATE_URL="OPTIONAL_FOR_TRANSLATION"
SUMMARIZE_FOR_CHAT=True
SUMMARIZATION_THRESHOLD=512
EMBEDDING_TOKEN_CHUNK_SIZE=512
EMBEDDING_TOKEN_CHUNK_OVERLAP=128
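
For reference, the four new settings feed the summarization and embedding options added to app/common/config.py below. A minimal annotated sketch of the same block (the comments are interpretations of the config.py defaults, not part of the sample file):

SUMMARIZE_FOR_CHAT=True             # enable summarization of chat messages
SUMMARIZATION_THRESHOLD=512         # summarize once a message exceeds this many tokens
EMBEDDING_TOKEN_CHUNK_SIZE=512      # tokens per chunk when embedding documents
EMBEDDING_TOKEN_CHUNK_OVERLAP=128   # tokens shared between adjacent chunks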
59 changes: 36 additions & 23 deletions app/common/config.py
@@ -5,6 +5,7 @@
from os import environ
from pathlib import Path
from re import Pattern, compile
from typing import Optional
from aiohttp import ClientTimeout
from dotenv import load_dotenv
from urllib import parse
@@ -35,7 +36,6 @@ def __call__(cls, *args, **kwargs):
MAX_API_KEY: int = 3
MAX_API_WHITELIST: int = 10
BASE_DIR: Path = Path(__file__).parents[2]
EMBEDDING_VECTOR_DIMENSION: int = 1536

# MySQL Variables
MYSQL_ROOT_PASSWORD: str = environ["MYSQL_ROOT_PASSWORD"]
@@ -57,24 +57,29 @@ def __call__(cls, *args, **kwargs):


# Optional Service Variables
EMBEDDING_VECTOR_DIMENSION: int = 1536
EMBEDDING_TOKEN_CHUNK_SIZE: int = int(environ.get("EMBEDDING_TOKEN_CHUNK_SIZE", 512))
EMBEDDING_TOKEN_CHUNK_OVERLAP: int = int(environ.get("EMBEDDING_TOKEN_CHUNK_OVERLAP", 128))
SUMMARIZE_FOR_CHAT: bool = environ.get("SUMMARIZE_FOR_CHAT", "True").lower() == "true"
SUMMARIZATION_THRESHOLD: int = int(environ.get("SUMMARIZATION_THRESHOLD", 512))
DEFAULT_LLM_MODEL: str = environ.get("DEFAULT_LLM_MODEL", "gpt_3_5_turbo")
OPENAI_API_KEY: str | None = environ.get("OPENAI_API_KEY")
RAPID_API_KEY: str | None = environ.get("RAPID_API_KEY")
GOOGLE_TRANSLATE_API_KEY: str | None = environ.get("GOOGLE_TRANSLATE_API_KEY")
PAPAGO_CLIENT_ID: str | None = environ.get("PAPAGO_CLIENT_ID")
PAPAGO_CLIENT_SECRET: str | None = environ.get("PAPAGO_CLIENT_SECRET")
CUSTOM_TRANSLATE_URL: str | None = environ.get("CUSTOM_TRANSLATE_URL")
AWS_ACCESS_KEY: str | None = environ.get("AWS_ACCESS_KEY")
AWS_SECRET_KEY: str | None = environ.get("AWS_SECRET_KEY")
AWS_AUTHORIZED_EMAIL: str | None = environ.get("AWS_AUTHORIZED_EMAIL")
SAMPLE_JWT_TOKEN: str | None = environ.get("SAMPLE_JWT_TOKEN")
SAMPLE_ACCESS_KEY: str | None = environ.get("SAMPLE_ACCESS_KEY")
SAMPLE_SECRET_KEY: str | None = environ.get("SAMPLE_SECRET_KEY")
KAKAO_RESTAPI_TOKEN: str | None = environ.get("KAKAO_RESTAPI_TOKEN")
WEATHERBIT_API_KEY: str | None = environ.get("WEATHERBIT_API_KEY")
KAKAO_IMAGE_URL: str | None = (
"http://k.kakaocdn.net/dn/wwWjr/btrYVhCnZDF/2bgXDJth2LyIajIjILhLK0/kakaolink40_original.png"
)
OPENAI_API_KEY: Optional[str] = environ.get("OPENAI_API_KEY")
RAPID_API_KEY: Optional[str] = environ.get("RAPID_API_KEY")
GOOGLE_TRANSLATE_API_KEY: Optional[str] = environ.get("GOOGLE_TRANSLATE_API_KEY")
PAPAGO_CLIENT_ID: Optional[str] = environ.get("PAPAGO_CLIENT_ID")
PAPAGO_CLIENT_SECRET: Optional[str] = environ.get("PAPAGO_CLIENT_SECRET")
CUSTOM_TRANSLATE_URL: Optional[str] = environ.get("CUSTOM_TRANSLATE_URL")
AWS_ACCESS_KEY: Optional[str] = environ.get("AWS_ACCESS_KEY")
AWS_SECRET_KEY: Optional[str] = environ.get("AWS_SECRET_KEY")
AWS_AUTHORIZED_EMAIL: Optional[str] = environ.get("AWS_AUTHORIZED_EMAIL")
SAMPLE_JWT_TOKEN: Optional[str] = environ.get("SAMPLE_JWT_TOKEN")
SAMPLE_ACCESS_KEY: Optional[str] = environ.get("SAMPLE_ACCESS_KEY")
SAMPLE_SECRET_KEY: Optional[str] = environ.get("SAMPLE_SECRET_KEY")
KAKAO_RESTAPI_TOKEN: Optional[str] = environ.get("KAKAO_RESTAPI_TOKEN")
WEATHERBIT_API_KEY: Optional[str] = environ.get("WEATHERBIT_API_KEY")
KAKAO_IMAGE_URL: Optional[
str
] = "http://k.kakaocdn.net/dn/wwWjr/btrYVhCnZDF/2bgXDJth2LyIajIjILhLK0/kakaolink40_original.png"

"""
400 Bad Request
@@ -113,6 +118,10 @@ class Config(metaclass=SingletonMetaClass):
redis_port: int = REDIS_PORT
redis_database: int = REDIS_DATABASE
redis_password: str = REDIS_PASSWORD
qdrant_host: str = "vectorstore"
qdrant_port: int = 6333
qdrant_grpc_port: int = 6334
shared_vectorestore_name: str = "SharedCollection"
trusted_hosts: list[str] = field(default_factory=lambda: ["*"])
allowed_sites: list[str] = field(default_factory=lambda: ["*"])

@@ -121,6 +130,7 @@ def __post_init__(self):
self.port = 8001
self.mysql_host = "localhost"
self.redis_host = "localhost"
self.qdrant_host = "localhost"
self.mysql_root_url = self.database_url_format.format(
dialect="mysql",
driver="pymysql",
@@ -149,7 +159,7 @@ def __post_init__(self):

@staticmethod
def get(
option: str | None = None,
option: Optional[str] = None,
) -> LocalConfig | ProdConfig | TestConfig:
if environ.get("PYTEST_RUNNING") is not None:
return TestConfig()
@@ -202,15 +212,16 @@ class TestConfig(Config):
mysql_database: str = MYSQL_TEST_DATABASE
mysql_host: str = "localhost"
redis_host: str = "localhost"
qdrant_host: str = "localhost"
port: int = 8001


@dataclass
class LoggingConfig:
logger_level: int = logging.DEBUG
console_log_level: int = logging.INFO
file_log_level: int | None = logging.DEBUG
file_log_name: str | None = "./logs/debug.log"
file_log_level: Optional[int] = logging.DEBUG
file_log_name: Optional[str] = "./logs/debug.log"
logging_format: str = "[%(asctime)s] %(name)s:%(levelname)s - %(message)s"


@@ -223,8 +234,10 @@ class ChatConfig:
api_regex_pattern: Pattern = compile(r"data:\s*({.+?})\n\n")
extra_token_margin: int = 512 # number of tokens to remove when tokens exceed token limit
continue_message: str = "...[CONTINUED]" # message to append when tokens exceed token limit
summarize_for_chat: bool = True # whether to summarize chat messages
summarization_threshold: int = 512 # token threshold for summarization. if message tokens exceed this, summarize
summarize_for_chat: bool = SUMMARIZE_FOR_CHAT # whether to summarize chat messages
summarization_threshold: int = (
SUMMARIZATION_THRESHOLD # token threshold for summarization. if message tokens exceed this, summarize
)
summarization_openai_model: str = "gpt-3.5-turbo"
summarization_token_limit: int = 2048 # token limit for summarization
summarization_token_overlap: int = 100 # number of tokens to overlap between chunks
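
For context, a minimal usage sketch of the new configuration surface (names come from the diff above; the prod/local behavior is inferred from the defaults and __post_init__):

from app.common.config import Config

config = Config.get()            # LocalConfig, ProdConfig, or TestConfig
config.qdrant_host               # "vectorstore" by default, "localhost" in local/test
config.qdrant_port               # 6333 (REST); gRPC goes through qdrant_grpc_port 6334
config.shared_vectorestore_name  # "SharedCollection" (identifier spelled as in the source)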
51 changes: 25 additions & 26 deletions app/database/connection.py
@@ -2,8 +2,8 @@
from collections.abc import Iterable
from typing import Any, AsyncGenerator, Callable, Optional, Type

from langchain.embeddings.base import Embeddings
from redis.asyncio import Redis as AsyncRedisType
from qdrant_client import QdrantClient
from redis.asyncio import Redis, from_url
from sqlalchemy import Delete, Result, ScalarResult, Select, TextClause, Update, create_engine, text
from sqlalchemy.engine.base import Connection, Engine
from sqlalchemy.ext.asyncio import (
@@ -16,9 +16,9 @@
from sqlalchemy_utils import create_database, database_exists

from app.common.config import Config, SingletonMetaClass, logging_config
from app.shared import Shared
from app.errors.api_exceptions import Responses_500
from app.utils.langchain.redis_vectorstore import Redis as RedisVectorStore
from app.shared import Shared
from app.utils.langchain.qdrant_vectorstore import Qdrant
from app.utils.logger import CustomLogger

from . import Base, DeclarativeMeta
@@ -338,50 +338,49 @@ async def scalars__one_or_none(
return (await self.run_in_session(self._scalars)(session, stmt=stmt)).one_or_none()


class RedisFactory(metaclass=SingletonMetaClass):
class CacheFactory(metaclass=SingletonMetaClass):
def __init__(self):
self._vectorstore: RedisVectorStore | None = None
self._vectorstore: Optional[Qdrant] = None
self.is_test_mode: bool = False
self.is_initiated: bool = False

def start(
self,
config: Config,
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
) -> None:
if self.is_initiated:
return
self.is_test_mode = True if config.test_mode else False
embeddings: Embeddings = Shared().openai_embeddings
self._vectorstore = RedisVectorStore( # type: ignore
redis_url=config.redis_url,
embedding_function=embeddings.embed_query,
content_key=content_key,
metadata_key=metadata_key,
vector_key=vector_key,
is_async=True,
self._redis = from_url(url=config.redis_url)
self._vectorstore = Qdrant(
client=QdrantClient(
host=config.qdrant_host,
port=config.qdrant_port,
grpc_port=config.qdrant_grpc_port,
prefer_grpc=True,
),
collection_name=config.shared_vectorestore_name,
embeddings=Shared().openai_embeddings,
)
self.is_initiated = True

async def close(self) -> None:
if self._vectorstore is not None:
assert isinstance(self._vectorstore.client, AsyncRedisType)
await self._vectorstore.client.close()
if self._redis is not None:
assert isinstance(self._redis, Redis)
await self._redis.close()
self.is_initiated = False

@property
def redis(self) -> AsyncRedisType:
def redis(self) -> Redis:
try:
assert self._vectorstore is not None
assert isinstance(self._vectorstore.client, AsyncRedisType)
assert self._redis is not None
assert isinstance(self._redis, Redis)
except AssertionError:
raise Responses_500.cache_not_initialized
return self._vectorstore.client
return self._redis

@property
def vectorstore(self) -> RedisVectorStore:
def vectorstore(self) -> Qdrant:
try:
assert self._vectorstore is not None
except AssertionError:
@@ -390,4 +389,4 @@ def vectorstore(self) -> RedisVectorStore:


db: SQLAlchemy = SQLAlchemy()
cache: RedisFactory = RedisFactory()
cache: CacheFactory = CacheFactory()
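
A sketch of how the renamed CacheFactory is presumably driven at application startup and shutdown (the lifespan wiring is not part of this commit; note that _redis is only assigned inside start(), so close() and the redis property assume start() was called first):

from app.common.config import Config
from app.database.connection import cache

async def lifespan_sketch() -> None:
    cache.start(config=Config.get())  # builds the async Redis client and the Qdrant vectorstore
    _ = cache.redis                   # redis.asyncio.Redis, used for chat-history caching
    _ = cache.vectorstore             # langchain Qdrant wrapper around QdrantClient
    await cache.close()               # closes only the Redis connection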
5 changes: 3 additions & 2 deletions app/models/llms.py
@@ -190,9 +190,10 @@ class LLMModels(Enum):
token_margin=8,
tokenizer=LlamaTokenizer("timdettmers/guanaco-65b-merged"), # timdettmers/guanaco-13b
model_path="./llama_models/ggml/guanaco-13B.ggmlv3.q5_1.bin",
description=DESCRIPTION_TMPL2,
user_chat_roles=UserChatRoles(
user="Instruction",
ai="Response",
user="Human",
ai="Assistant",
system="System",
),
)
3 changes: 2 additions & 1 deletion app/utils/chat/buffer.py
@@ -1,6 +1,6 @@
import asyncio
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable
from typing import Any, Awaitable, Callable, Optional

from fastapi import WebSocket

@@ -45,6 +45,7 @@ class BufferedUserContext:
queue: asyncio.Queue = field(default_factory=asyncio.Queue)
done: asyncio.Event = field(default_factory=asyncio.Event)
task_list: list[asyncio.Task[Any]] = field(default_factory=list) # =
last_user_message: Optional[str] = None
_sorted_ctxts: ContextList = field(init=False)
_current_ctxt: UserChatContext = field(init=False)

38 changes: 25 additions & 13 deletions app/utils/chat/chat_commands.py
@@ -9,11 +9,13 @@
from fastapi import WebSocket
from fastapi.concurrency import run_in_threadpool

from app.common.constants import CODEX_PROMPT, QUERY_TMPL1, REDEX_PROMPT, CHAT_TURN_TMPL1
from app.common.config import config
from app.common.constants import CHAT_TURN_TMPL1, CODEX_PROMPT, QUERY_TMPL1, REDEX_PROMPT
from app.database.schemas.auth import UserStatus
from app.errors.api_exceptions import InternalServerError
from app.models.chat_models import ChatRoles, LLMModels, MessageHistory, UserChatContext
from app.shared import Shared
from app.utils.api.translate import Translator
from app.utils.chat.buffer import BufferedUserContext
from app.utils.chat.cache_manager import CacheManager
from app.utils.chat.message_handler import MessageHandler
@@ -580,19 +582,27 @@ async def query(query: str, /, buffer: BufferedUserContext, **kwargs) -> Tuple[s
return query, ResponseType.REPEAT_COMMAND

k: int = 3
if kwargs.get("translate", False):
query = await Translator.translate(text=query, src_lang="ko", trg_lang="en")
await SendToWebsocket.message(
websocket=buffer.websocket,
msg=f"## 번역된 질문\n\n{query}\n\n## 생성된 답변\n\n",
chat_room_id=buffer.current_chat_room_id,
finish=False,
model_name=buffer.current_user_chat_context.llm_model.value.name,
)
found_text_and_score: list[
list[Tuple[Document, float]]
] = await VectorStoreManager.asimilarity_search_multiple_index_with_score(
queries=[query], index_names=[buffer.user_id, ""], k=k
Tuple[Document, float]
] = await VectorStoreManager.asimilarity_search_multiple_collections_with_score(
query=query, collection_names=[buffer.user_id, config.shared_vectorestore_name], k=k
) # lower score is the better!
print(found_text_and_score)

if len(found_text_and_score[0]) > 0:
found_text: str = "\n\n".join([document.page_content for document, _ in found_text_and_score[0]])
if len(found_text_and_score) > 0:
found_text: str = "\n\n".join([document.page_content for document, _ in found_text_and_score])
context_and_query: str = QUERY_TMPL1.format(question=query, context=found_text)
await MessageHandler.user(
msg=context_and_query,
translate=kwargs.get("translate", False),
translate=False,
buffer=buffer,
)
await MessageHandler.ai(
@@ -621,15 +631,15 @@ async def query(query: str, /, buffer: BufferedUserContext, **kwargs) -> Tuple[s
async def embed(text_to_embed: str, /, buffer: BufferedUserContext) -> str:
"""Embed the text and save its vectors in the redis vectorstore.\n
/embed <text_to_embed>"""
await VectorStoreManager.create_documents(text=text_to_embed, index_name=buffer.user_id)
await VectorStoreManager.create_documents(text=text_to_embed, collection_name=buffer.user_id)
return "Embedding successful!"

@staticmethod
@CommandResponse.send_message_and_stop
async def share(text_to_embed: str, /) -> str:
"""Embed the text and save its vectors in the redis vectorstore. This index is shared for everyone.\n
/share <text_to_embed>"""
await VectorStoreManager.create_documents(text=text_to_embed, index_name="")
await VectorStoreManager.create_documents(text=text_to_embed, collection_name=config.shared_vectorestore_name)
return "Embedding successful! This data will be shared for everyone."

@staticmethod
Expand All @@ -638,10 +648,12 @@ async def drop(buffer: BufferedUserContext) -> str:
"""Drop the index from the redis vectorstore.\n
/drop"""
dropped_index: list[str] = []
if await VectorStoreManager.drop_index(index_name=buffer.user_id):
if await VectorStoreManager.delete_collection(collection_name=buffer.user_id):
dropped_index.append(buffer.user_id)
if buffer.user.status is UserStatus.admin and await VectorStoreManager.drop_index(index_name=""):
dropped_index.append("shared")
if buffer.user.status is UserStatus.admin and await VectorStoreManager.delete_collection(
collection_name=config.shared_vectorestore_name,
):
dropped_index.append(config.shared_vectorestore_name)
if not dropped_index:
return "No index dropped."
return f"Index dropped: {', '.join(dropped_index)}"
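
Taken together, the collection-based API that /embed, /share, /query, and /drop now rely on looks roughly like this (signatures are inferred from the call sites above; the import path for VectorStoreManager is an assumption, and its implementation is not shown in this commit):

from app.common.config import config
from app.utils.chat.vectorstore_manager import VectorStoreManager  # import path assumed

async def vectorstore_sketch(user_id: str) -> None:
    # /embed — store text under the user's own collection
    await VectorStoreManager.create_documents(text="some text", collection_name=user_id)
    # /query — search the user's collection plus the shared one; lower score is better
    results = await VectorStoreManager.asimilarity_search_multiple_collections_with_score(
        query="some question",
        collection_names=[user_id, config.shared_vectorestore_name],
        k=3,
    )  # -> list of (Document, score) tuples
    # /drop — delete a collection; returns whether anything was deleted
    await VectorStoreManager.delete_collection(collection_name=user_id)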
2 changes: 1 addition & 1 deletion app/utils/chat/llama_cpp.py
@@ -155,7 +155,7 @@ def get_generator() -> Iterator[Any]:
else:
real_max_tokens = max_tokens
if real_max_tokens <= 0:
raise ChatLengthException()
raise ChatLengthException(msg=content_buffer)
return llm_client.create_completion( # type: ignore
prompt=prompt,
suffix=llm.suffix,
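
A plausible reading of this one-line change (a sketch; only the msg keyword comes from the diff, the surrounding handler is assumed): attaching the partial completion to the exception lets the caller keep the buffered text and resume with ChatConfig.continue_message instead of discarding it.

from app.errors.chat_exceptions import ChatLengthException  # import path assumed

try:
    content_buffer = generate_completion()  # hypothetical wrapper around get_generator()
except ChatLengthException as e:
    # e.msg now carries whatever was generated before the token limit was hit,
    # so the caller can surface it with "...[CONTINUED]" appended and continue.
    partial_text = e.msg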
7 changes: 4 additions & 3 deletions app/utils/chat/stream_manager.py
@@ -87,8 +87,8 @@ async def harvest_done_tasks(buffer: BufferedUserContext) -> None:
api_logger.exception(f"Some error occurred while running update tasks: {result}")
except Exception as e:
api_logger.exception(f"Unexpected error occurred while running update tasks: {e}")

buffer.task_list = [task for task in buffer.task_list if task not in harvested_tasks]
finally:
buffer.task_list = [task for task in buffer.task_list if task not in harvested_tasks]


class ChatStreamManager:
@@ -173,7 +173,7 @@ async def _websocket_receiver(buffer: BufferedUserContext) -> None:
elif received_bytes is not None:
await buffer.queue.put(
await VectorStoreManager.embed_file_to_vectorstore(
file=received_bytes, filename=filename, index_name=buffer.current_user_chat_context.user_id
file=received_bytes, filename=filename, collection_name=buffer.current_user_chat_context.user_id
)
)

@@ -213,6 +213,7 @@ async def _websocket_sender(cls, buffer: BufferedUserContext) -> None:
buffer=buffer,
)
else:
buffer.last_user_message = item.msg
await MessageHandler.user(
msg=item.msg,
translate=item.translate,

1 comment on commit 6f5d5e8

@Torhamilton

Thank you, thank you! I still have not updated my local install; I'm waiting for a new release so I can test.
