Skip to content

Commit

Permalink
避免configs对torch的依赖;webui自动从configs获取api地址(close #1319) (#1328)
Browse files Browse the repository at this point in the history
  • Loading branch information
liunux4odoo committed Aug 31, 2023
1 parent 215bc25 commit 26a9237
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
21 changes: 16 additions & 5 deletions configs/model_config.py.example
@@ -1,13 +1,25 @@
import os
import logging
import torch
# 日志格式
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)


# 分布式部署时,不运行LLM的机器上可以不装torch
def default_device():
try:
import torch
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
except:
pass
return "cpu"


# 在以下字典中修改属性值,以指定本地embedding模型存储位置
# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
# 此处请写绝对路径
Expand All @@ -33,7 +45,7 @@ embedding_model_dict = {
EMBEDDING_MODEL = "m3e-base"

# Embedding 模型运行设备
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
EMBEDDING_DEVICE = default_device()


llm_model_dict = {
Expand Down Expand Up @@ -76,15 +88,14 @@ llm_model_dict = {
},
}


# LLM 名称
LLM_MODEL = "chatglm2-6b"

# 历史对话轮数
HISTORY_LEN = 3

# LLM 运行设备
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
LLM_DEVICE = default_device()

# 日志存储路径
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
Expand Down Expand Up @@ -166,4 +177,4 @@ BING_SUBSCRIPTION_KEY = ""
# 是否开启中文标题加强,以及标题增强的相关配置
# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
ZH_TITLE_ENHANCE = False
ZH_TITLE_ENHANCE = False
4 changes: 3 additions & 1 deletion webui.py
Expand Up @@ -10,8 +10,10 @@
from webui_pages import *
import os
from configs import VERSION
from server.utils import api_address

api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)

api = ApiRequest(base_url=api_address())

if __name__ == "__main__":
st.set_page_config(
Expand Down

0 comments on commit 26a9237

Please sign in to comment.