Skip to content

Commit

Permalink
chore: 去掉 bot 示例中的 pickle 依赖 (baidu#9640)
Browse files Browse the repository at this point in the history
  • Loading branch information
nwind committed Feb 22, 2024
1 parent 7e03085 commit dd34bf9
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
4 changes: 2 additions & 2 deletions scripts/bot/.gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
db
__pycache__
text.pickle
embedding.pickle
text.json
embedding.json
.env
m3e-base
flagged
14 changes: 7 additions & 7 deletions scripts/bot/gen_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import glob
import uuid
import pickle
import json
from embedding import get_embedding
from split_markdown import split_markdown
from vector_store import get_client
Expand All @@ -21,11 +21,11 @@
embedding_cache = {}

embedding_cache_file = os.path.join(
os.path.dirname(__file__), 'embedding.pickle')
os.path.dirname(__file__), 'embedding.json')

if os.path.exists(embedding_cache_file):
with open(embedding_cache_file, 'rb') as f:
embedding_cache = pickle.load(f)
embedding_cache = json.load(f)


def get_embedding_with_cache(text):
Expand Down Expand Up @@ -65,8 +65,8 @@ def get_embedding_with_cache(text):
)


with open(os.path.join(os.path.dirname(__file__), 'text.pickle'), 'wb') as f:
pickle.dump(text_blocks_by_id, f, pickle.HIGHEST_PROTOCOL)
with open(os.path.join(os.path.dirname(__file__), 'text.json'), 'w') as f:
json.dump(text_blocks_by_id, f)

with open(embedding_cache_file, 'wb') as f:
pickle.dump(embedding_cache, f, pickle.HIGHEST_PROTOCOL)
with open(embedding_cache_file, 'w') as f:
json.dump(embedding_cache, f)
6 changes: 3 additions & 3 deletions scripts/bot/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from embedding import get_embedding
import gradio as gr
import os
import pickle
import json
from llm.wenxin import Wenxin, ModelName
from dotenv import load_dotenv
load_dotenv()
Expand All @@ -15,8 +15,8 @@
wenxin = Wenxin()

text_blocks_by_id = {}
with open(os.path.join(os.path.dirname(__file__), 'text.pickle'), 'rb') as f:
text_blocks_by_id = pickle.load(f)
with open(os.path.join(os.path.dirname(__file__), 'text.json'), 'rb') as f:
text_blocks_by_id = json.load(f)


def get_prompt(context, query):
Expand Down

0 comments on commit dd34bf9

Please sign in to comment.