Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
c0sogi committed May 25, 2023
1 parent c56955a commit 1a83a71
Show file tree
Hide file tree
Showing 16 changed files with 612 additions and 819 deletions.
33 changes: 16 additions & 17 deletions app/common/app_settings.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,25 @@
from fastapi import FastAPI, Depends
from fastapi import Depends, FastAPI
from fastapi.staticfiles import StaticFiles
from starlette.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette_admin.contrib.sqla.admin import Admin
from starlette_admin.contrib.sqla.view import ModelView
from starlette_admin.views import DropDown, Link

from app.auth.admin import MyAuthProvider
from app.common.config import JWT_SECRET, Config
from app.database.connection import db, cache
from app.database.connection import cache, db
from app.database.schemas.auth import ApiKeys, ApiWhiteLists, Users
from app.dependencies import USER_DEPENDENCY, api_service_dependency
from app.middlewares.token_validator import access_control
from app.middlewares.trusted_hosts import TrustedHostMiddleware
from app.routers import index, auth, services, users, websocket
from app.dependencies import (
api_service_dependency,
user_dependency,
)
from app.utils.logger import api_logger
from app.routers import auth, index, services, users, websocket
from app.shared import Shared
from app.utils.chat.cache_manager import CacheManager
from app.utils.js_initializer import js_url_initializer
from app import dependencies
from starlette_admin.contrib.sqla.admin import Admin
from starlette_admin.contrib.sqla.view import ModelView
from starlette_admin.views import DropDown, Link

from app.utils.logger import api_logger
from app.viewmodels.admin import ApiKeyAdminView, UserAdminView


Expand Down Expand Up @@ -101,7 +98,7 @@ def create_app(config: Config) -> FastAPI:
users.router,
prefix="/api",
tags=["Users"],
dependencies=[Depends(user_dependency)],
dependencies=[Depends(USER_DEPENDENCY)],
)

@new_app.on_event("startup")
Expand All @@ -118,18 +115,20 @@ async def startup():
else:
api_logger.critical("Redis CACHE connection failed!")
try:
import uvloop # type: ignore
import asyncio

import uvloop # type: ignore

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
api_logger.critical("uvloop installed!")
except ImportError:
api_logger.critical("uvloop not installed!")

@new_app.on_event("shutdown")
async def shutdown():
dependencies.process_pool_executor.shutdown(cancel_futures=True, wait=True)
# await CacheManager.delete_user(f"testaccount@{HOST_MAIN}")
Shared().process_manager.shutdown()
Shared().process_pool_executor.shutdown()
await db.close()
await cache.close()
api_logger.critical("DB & CACHE connection closed!")
Expand Down
66 changes: 51 additions & 15 deletions app/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from re import Pattern, compile
from dotenv import load_dotenv
from urllib import parse

load_dotenv()

Expand Down Expand Up @@ -33,6 +34,7 @@ def __call__(cls, *args, **kwargs):
MAX_API_KEY: int = 3
MAX_API_WHITELIST: int = 10
BASE_DIR: Path = Path(__file__).parents[2]
EMBEDDING_VECTOR_DIMENSION: int = 1536

# MySQL Variables
MYSQL_ROOT_PASSWORD: str = environ["MYSQL_ROOT_PASSWORD"]
Expand Down Expand Up @@ -113,22 +115,59 @@ class Config(metaclass=SingletonMetaClass):
trusted_hosts: list[str] = field(default_factory=lambda: ["*"])
allowed_sites: list[str] = field(default_factory=lambda: ["*"])

def __post_init__(self):
if not DOCKER_MODE:
self.port = 8001
self.mysql_host = "localhost"
self.redis_host = "localhost"
self.mysql_root_url = self.database_url_format.format(
dialect="mysql",
driver="pymysql",
user="root",
password=parse.quote(self.mysql_root_password),
host=self.mysql_host,
port=self.mysql_port,
database=self.mysql_database,
)
self.mysql_url = self.database_url_format.format(
dialect="mysql",
driver="aiomysql",
user=self.mysql_user,
password=parse.quote(self.mysql_password),
host=self.mysql_host,
port=self.mysql_port,
database=self.mysql_database,
)
self.redis_url = self.redis_url_format.format(
username="",
password=self.redis_password,
host=self.redis_host,
port=self.redis_port,
db=self.redis_database,
)

@staticmethod
def get(
option: str | None = None,
) -> LocalConfig | ProdConfig | TestConfig:
config_key = option if option is not None else API_ENV
config_key = "test" if environ.get("PYTEST_RUNNING") is not None else config_key
_config = {
"prod": ProdConfig,
"local": LocalConfig,
"test": TestConfig,
}[config_key]()
if not DOCKER_MODE:
_config.port = 8001
_config.mysql_host = "localhost"
_config.redis_host = "localhost"
return _config
if environ.get("PYTEST_RUNNING") is not None:
return TestConfig()
else:
if option is not None:
return {
"prod": ProdConfig,
"local": LocalConfig,
"test": TestConfig,
}[option]()
else:
if API_ENV is not None:
return {
"prod": ProdConfig,
"local": LocalConfig,
"test": TestConfig,
}[API_ENV]()
else:
return LocalConfig()


@dataclass
Expand Down Expand Up @@ -176,6 +215,3 @@ class LoggingConfig:

config = Config.get()
logging_config = LoggingConfig()

if __name__ == "__main__":
print(BASE_DIR / "app")
83 changes: 22 additions & 61 deletions app/database/connection.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,27 @@
from collections.abc import Iterable
from asyncio import current_task
from typing import AsyncGenerator, Callable, Optional, Type, Any
from urllib import parse
from collections.abc import Iterable
from typing import Any, AsyncGenerator, Callable, Optional, Type

from langchain.embeddings.base import Embeddings
from redis.asyncio import Redis as AsyncRedisType
from sqlalchemy import (
Result,
ScalarResult,
Select,
Delete,
TextClause,
Update,
create_engine,
text,
)
from sqlalchemy.engine.base import Engine, Connection
from sqlalchemy import Delete, Result, ScalarResult, Select, TextClause, Update, create_engine, text
from sqlalchemy.engine.base import Connection, Engine
from sqlalchemy.ext.asyncio import (
async_sessionmaker,
AsyncEngine,
AsyncSession,
async_scoped_session,
async_sessionmaker,
create_async_engine,
AsyncSession,
AsyncEngine,
)
from sqlalchemy_utils import database_exists, create_database
from sqlalchemy_utils import create_database, database_exists

from app.common.config import Config, SingletonMetaClass, logging_config
from app.shared import Shared
from app.errors.api_exceptions import Responses_500
from app.utils.langchain.redis_vectorstore import Redis as RedisVectorStore
from app.utils.logger import CustomLogger
from app.common.config import logging_config, Config, SingletonMetaClass
from . import Base, DeclarativeMeta

import openai
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings
from app.utils.langchain.redis_vectorstore import Redis as RedisVectorStore
from app.common.config import OPENAI_API_KEY
from . import Base, DeclarativeMeta


class MySQL(metaclass=SingletonMetaClass):
Expand Down Expand Up @@ -157,28 +147,11 @@ def start(self, config: Config) -> None:
f"Current DB connection of {type(config).__name__}: "
+ f"{config.mysql_host}/{config.mysql_database}@{config.mysql_user}"
)
root_url = config.database_url_format.format(
dialect="mysql",
driver="pymysql",
user="root",
password=parse.quote(config.mysql_root_password),
host=config.mysql_host,
port=config.mysql_port,
database=config.mysql_database,
)
database_url = config.database_url_format.format(
dialect="mysql",
driver="aiomysql",
user=config.mysql_user,
password=parse.quote(config.mysql_password),
host=config.mysql_host,
port=config.mysql_port,
database=config.mysql_database,
)
if not database_exists(root_url):
create_database(root_url)

self.root_engine = create_engine(root_url, echo=config.db_echo)
if not database_exists(config.mysql_root_url):
create_database(config.mysql_root_url)

self.root_engine = create_engine(config.mysql_root_url, echo=config.db_echo)
with self.root_engine.connect() as conn:
if not MySQL.is_user_exists(config.mysql_user, engine_or_conn=conn):
MySQL.create_user(config.mysql_user, config.mysql_password, "%", engine_or_conn=conn)
Expand All @@ -197,7 +170,7 @@ def start(self, config: Config) -> None:
self.root_engine.dispose()
self.root_engine = None
self.engine = create_async_engine(
database_url,
config.mysql_url,
echo=config.db_echo,
pool_recycle=config.db_pool_recycle,
pool_pre_ping=True,
Expand Down Expand Up @@ -377,25 +350,13 @@ def start(
content_key: str = "content",
metadata_key: str = "metadata",
vector_key: str = "content_vector",
vector_dimension: int = 1536,
openai_api_key: str | None = OPENAI_API_KEY,
) -> None:
if self.is_initiated:
return
self.is_test_mode = True if config.test_mode else False
redis_url = config.redis_url_format.format(
username="",
password=config.redis_password,
host=config.redis_host,
port=config.redis_port,
db=config.redis_database,
)
embeddings: Embeddings = OpenAIEmbeddings(
client=openai.Embedding,
openai_api_key=openai_api_key,
)
embeddings: Embeddings = Shared().openai_embeddings
self._vectorstore = RedisVectorStore( # type: ignore
redis_url=redis_url,
redis_url=config.redis_url,
embedding_function=embeddings.embed_query,
content_key=content_key,
metadata_key=metadata_key,
Expand Down
6 changes: 1 addition & 5 deletions app/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from concurrent.futures import ProcessPoolExecutor
from fastapi import Header, Query
from fastapi.security import APIKeyHeader
from multiprocessing import Manager


def api_service_dependency(secret: str = Header(...), key: str = Query(...), timestamp: str = Query(...)):
... # do some validation or processing with the headers


process_pool_executor = ProcessPoolExecutor()
process_manager = Manager()
user_dependency = APIKeyHeader(name="Authorization", auto_error=False)
USER_DEPENDENCY = APIKeyHeader(name="Authorization", auto_error=False)
2 changes: 1 addition & 1 deletion app/models/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class OpenAIModel(LLMModel):
class LLMModels(Enum): # gpt models for openai api
gpt_3_5_turbo = OpenAIModel(
name="gpt-3.5-turbo",
max_total_tokens=4096,
max_total_tokens=2048,
max_tokens_per_request=2048,
token_margin=8,
tokenizer=OpenAITokenizer("gpt-3.5-turbo"),
Expand Down
4 changes: 2 additions & 2 deletions app/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from app.common.config import TOKEN_EXPIRE_HOURS
from app.database.crud.users import is_email_exist, register_new_user
from app.database.schemas.auth import Users
from app.dependencies import user_dependency
from app.dependencies import USER_DEPENDENCY
from app.errors.api_exceptions import Responses_400, Responses_404
from app.utils.auth.register_validation import (
is_email_length_in_range,
Expand Down Expand Up @@ -81,7 +81,7 @@ async def register(

@router.delete("/register", status_code=status.HTTP_204_NO_CONTENT)
async def unregister(
authorization: str = Security(user_dependency),
authorization: str = Security(USER_DEPENDENCY),
):
registered_user: UserToken = UserToken(**token_decode(authorization))
if registered_user.email is not None:
Expand Down
17 changes: 17 additions & 0 deletions app/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass, field
from multiprocessing import Manager
from multiprocessing.managers import SyncManager

import openai
from app.common.config import OPENAI_API_KEY, SingletonMetaClass
from langchain.embeddings import OpenAIEmbeddings


@dataclass
class Shared(metaclass=SingletonMetaClass):
process_manager: SyncManager = field(default_factory=Manager)
process_pool_executor: ProcessPoolExecutor = field(default_factory=ProcessPoolExecutor)
openai_embeddings: OpenAIEmbeddings = field(
default_factory=lambda: OpenAIEmbeddings(client=openai.Embedding, openai_api_key=OPENAI_API_KEY)
)
11 changes: 10 additions & 1 deletion app/utils/chat/buffer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import asyncio
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from fastapi import WebSocket
from app.database.schemas.auth import Users

from app.database.schemas.auth import Users
from app.models.chat_models import MessageHistory, UserChatContext, UserChatProfile

if TYPE_CHECKING:
from app.models.llms import LLMModels


@dataclass
class BufferedUserContext:
Expand Down Expand Up @@ -45,6 +50,10 @@ def buffer_size(self) -> int:
def current_chat_room_id(self) -> str:
return self._current_ctxt.chat_room_id

@property
def current_llm_model(self) -> "LLMModels":
return self._current_ctxt.llm_model

@property
def current_chat_room_name(self) -> str:
return self._current_ctxt.chat_room_name
Expand Down
Loading

0 comments on commit 1a83a71

Please sign in to comment.