diff --git a/app/gpt.py b/app/gpt.py
index 415467fc..53be3384 100644
--- a/app/gpt.py
+++ b/app/gpt.py
@@ -6,7 +6,7 @@
 import uuid
 import openai
 from pathlib import Path
-from llama_index import ServiceContext, GPTVectorStoreIndex, LLMPredictor, RssReader, SimpleDirectoryReader
+from llama_index import ServiceContext, GPTVectorStoreIndex, LLMPredictor, RssReader, SimpleDirectoryReader, StorageContext, load_index_from_storage
 from llama_index.readers.schema.base import Document
 from langchain.chat_models import ChatOpenAI
 from azure.cognitiveservices.speech import SpeechConfig, SpeechSynthesizer, ResultReason, CancellationReason, SpeechSynthesisOutputFormat
@@ -21,14 +21,17 @@
 SPEECH_REGION = os.environ.get('SPEECH_REGION')
 
 openai.api_key = OPENAI_API_KEY
 
+index_cache_web_dir = Path('/tmp/myGPTReader/cache_web/')
+index_cache_file_dir = Path('/data/myGPTReader/file/')
+index_cache_voice_dir = Path('/tmp/myGPTReader/voice/')
+
 llm_predictor = LLMPredictor(llm=ChatOpenAI(
     temperature=0, model_name="gpt-3.5-turbo"))
 service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor)
 
-index_cache_web_dir = Path('/tmp/myGPTReader/cache_web/')
-index_cache_voice_dir = Path('/tmp/myGPTReader/voice/')
-index_cache_file_dir = Path('/data/myGPTReader/file/')
+web_storage_context = StorageContext.from_defaults(persist_dir=str(index_cache_web_dir))
+file_storage_context = StorageContext.from_defaults(persist_dir=str(index_cache_file_dir))
 
 if not index_cache_web_dir.is_dir():
     index_cache_web_dir.mkdir(parents=True, exist_ok=True)
@@ -81,27 +84,25 @@ def get_documents_from_urls(urls):
     return documents
 
 def get_index_from_web_cache(name):
-    web_cache_file = index_cache_web_dir / name
-    if not web_cache_file.is_file():
+    try:
+        index = load_index_from_storage(web_storage_context, index_id=name)
+    except Exception as e:
+        logging.error(e)
         return None
-    index = GPTVectorStoreIndex.load_from_disk(web_cache_file)
-    logging.info(
-        f"=====> Get index from web cache: {web_cache_file}")
     return index
 
 def get_index_from_file_cache(name):
-    file_cache_file = index_cache_file_dir / name
-    if not file_cache_file.is_file():
+    try:
+        index = load_index_from_storage(file_storage_context, index_id=name)
+    except Exception as e:
+        logging.error(e)
         return None
-    index = GPTVectorStoreIndex.load_from_disk(file_cache_file)
-    logging.info(
-        f"=====> Get index from file cache: {file_cache_file}")
     return index
 
 def get_index_name_from_file(file: str):
     file_md5_with_extension = str(Path(file).relative_to(index_cache_file_dir).name)
     file_md5 = file_md5_with_extension.split('.')[0]
-    return file_md5 + '.json'
+    return file_md5
 
 def get_answer_from_chatGPT(messages):
     dialog_messages = format_dialog_messages(messages)
@@ -127,9 +128,10 @@ def get_answer_from_llama_web(messages, urls):
         documents = get_documents_from_urls(combained_urls)
         logging.info(documents)
         index = GPTVectorStoreIndex.from_documents(documents, service_context=service_context)
+        index.set_index_id(index_file_name)
+        index.storage_context.persist(persist_dir=str(index_cache_web_dir))
         logging.info(
             f"=====> Save index to disk path: {index_cache_web_dir / index_file_name}")
-        index.save_to_disk(index_cache_web_dir / index_file_name)
     prompt = get_prompt_template(lang_code)
     logging.info('=====> Use llama web with chatGPT to answer!')
     logging.info('=====> dialog_messages')
@@ -150,9 +152,10 @@ def get_answer_from_llama_file(messages, file):
         logging.info(f"=====> Build index from file!")
         documents = SimpleDirectoryReader(input_files=[file]).load_data()
         index = GPTVectorStoreIndex.from_documents(documents, service_context=service_context)
+        index.set_index_id(index_name)
+        index.storage_context.persist(persist_dir=str(index_cache_file_dir))
         logging.info(
             f"=====> Save index to disk path: {index_cache_file_dir / index_name}")
-        index.save_to_disk(index_cache_file_dir / index_name)
     prompt = get_prompt_template(lang_code)
     logging.info('=====> Use llama file with chatGPT to answer!')
     logging.info('=====> dialog_messages')
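For reference, this is the persist/load round-trip the diff migrates to, as a minimal standalone sketch against the llama_index 0.6-era API. The directory, input file, and index id below are illustrative placeholders, not values from this PR; it assumes `OPENAI_API_KEY` is set and `example.txt` exists.

```python
from pathlib import Path

from llama_index import (
    GPTVectorStoreIndex,
    SimpleDirectoryReader,
    StorageContext,
    load_index_from_storage,
)

# Hypothetical cache directory and index id, for illustration only.
persist_dir = Path('/tmp/example_cache/')
persist_dir.mkdir(parents=True, exist_ok=True)

# Build an index, tag it with an explicit index_id, and persist it.
# The explicit id matters because several indices can share one
# persist_dir, which is how the web and file caches above are laid out.
documents = SimpleDirectoryReader(input_files=['example.txt']).load_data()
index = GPTVectorStoreIndex.from_documents(documents)
index.set_index_id('example-index-id')
index.storage_context.persist(persist_dir=str(persist_dir))

# Later (e.g. after a restart): rebuild a StorageContext from the same
# directory and load the index back by id rather than by file path.
storage_context = StorageContext.from_defaults(persist_dir=str(persist_dir))
index = load_index_from_storage(storage_context, index_id='example-index-id')
```

This shift from per-index files to ids inside a shared store is also why `get_index_name_from_file` now returns the bare MD5 instead of appending `.json`: the name is a lookup key for `load_index_from_storage`, not a file name on disk.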