Cache refinement #67

Merged 10 commits on Mar 11, 2023
src/BingService.py (4 changes: 3 additions & 1 deletion)
@@ -5,7 +5,7 @@
import requests
import yaml

from Util import setup_logger, get_project_root
from Util import setup_logger, get_project_root, storage_cached
from text_extract.html.beautiful_soup import BeautifulSoupSvc
from text_extract.html.trafilatura import TrafilaturaSvc

@@ -21,6 +21,7 @@ def __init__(self, config):
elif extract_svc == 'beautifulsoup':
self.txt_extract_svc = BeautifulSoupSvc()

@storage_cached('bing_search_website', 'query')
def call_bing_search_api(self, query: str) -> pd.DataFrame:
logger.info("BingService.call_bing_search_api. query: " + query)
subscription_key = self.config.get('bing_search').get('subscription_key')
@@ -81,6 +82,7 @@ def call_one_url(self, website_tuple):
logger.info(f" receive sentences: {len(sentences)}")
return sentences, name, url, url_id, snippet

@storage_cached('bing_search_website_content', 'website_df')
def call_urls_and_extract_sentences_concurrent(self, website_df):
logger.info(f"BingService.call_urls_and_extract_sentences_async. website_df.shape: {website_df.shape}")
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
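
The decorator hashes the keyword argument named in its second parameter and asserts that it arrives as a keyword, so call sites must use query=... rather than a positional argument. A minimal usage sketch, assuming a config dict loaded from config.yaml with cache.is_enable.bing_search_website set to true:

    bing_service = BingService(config)
    df = bing_service.call_bing_search_api(query="open source search")  # first call: hits the Bing API, pickles the result
    df = bing_service.call_bing_search_api(query="open source search")  # second call: loaded from .cache/bing_search_website/
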
src/LLMService.py (4 changes: 3 additions & 1 deletion)
@@ -6,7 +6,7 @@
import pandas as pd
import yaml

from Util import setup_logger, get_project_root
from Util import setup_logger, get_project_root, storage_cached

logger = setup_logger('LLMService')

@@ -103,6 +103,7 @@ def __init__(self, config):
raise Exception("OpenAI API key is not set.")
openai.api_key = open_api_key

@storage_cached('openai', 'prompt')
def call_api(self, prompt: str):
openai_api_config = self.config.get('openai_api')
model = openai_api_config.get('model')
@@ -143,6 +144,7 @@ def __init__(self, config):
openai.api_key = goose_api_key
openai.api_base = config.get('goose_ai_api').get('api_base')

@storage_cached('gooseai', 'prompt')
def call_api(self, prompt: str):
logger.info(f"GooseAIService.call_openai_api. len(prompt): {len(prompt)}")
goose_api_config = self.config.get('goose_ai_api')
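
Both providers get the same treatment, with results written to separate subdirectories keyed by cache type. The cache hash mixes the whole config with the prompt, so changing, say, the model setting invalidates earlier entries. A sketch of where a cached OpenAI response lands, assuming config and prompt are in scope:

    from hashlib import md5
    from pathlib import Path
    from Util import get_project_root

    cache_hash = md5(str(config).encode() + prompt.encode()).hexdigest()
    pickle_path = Path(get_project_root()) / '.cache' / 'openai' / f'{cache_hash}.pickle'
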
src/SearchGPTService.py (40 changes: 9 additions & 31 deletions)
@@ -9,7 +9,7 @@
from FrontendService import FrontendService
from LLMService import LLMServiceFactory
from SemanticSearchService import BatchOpenAISemanticSearchService
from Util import setup_logger, post_process_gpt_input_text_df, check_result_cache_exists, load_result_from_cache, save_result_cache, check_max_number_of_cache, get_project_root
from Util import setup_logger, post_process_gpt_input_text_df, get_project_root, storage_cached
from text_extract.doc import support_doc_type, doc_extract_svc_map
from text_extract.doc.abc_doc_extract import AbstractDocExtractSvc

@@ -60,23 +60,9 @@ def _prompt(self, search_text, text_df, cache_path=None):
gpt_input_text_df = semantic_search_service.search_related_source(text_df, search_text)
gpt_input_text_df = post_process_gpt_input_text_df(gpt_input_text_df, self.config.get('openai_api').get('prompt').get('prompt_length_limit'))

llm_service_provider = self.config.get('llm_service').get('provider')
# check if llm result is cached and load if exists
if self.config.get('cache').get('is_enable_cache') and check_result_cache_exists(cache_path, search_text, llm_service_provider):
logger.info(f"SemanticSearchService.load_result_from_cache. search_text: {search_text}, cache_path: {cache_path}")
cache = load_result_from_cache(cache_path, search_text, llm_service_provider)
prompt, response_text = cache['prompt'], cache['response_text']
else:
llm_service = LLMServiceFactory.create_llm_service(self.config)
prompt = llm_service.get_prompt_v3(search_text, gpt_input_text_df)
response_text = llm_service.call_api(prompt)

llm_config = self.config.get(f'{llm_service_provider}_api').copy()
llm_config.pop('api_key') # delete api_key to avoid saving it to .cache
save_result_cache(cache_path, search_text, llm_service_provider, prompt=prompt, response_text=response_text, config=llm_config)

# check whether the number of cache exceeds the limit
check_max_number_of_cache(cache_path, self.config.get('cache').get('max_number_of_cache'))
llm_service = LLMServiceFactory.create_llm_service(self.config)
prompt = llm_service.get_prompt_v3(search_text, gpt_input_text_df)
response_text = llm_service.call_api(prompt=prompt)

frontend_service = FrontendService(self.config, response_text, gpt_input_text_df)
source_text, data_json = frontend_service.get_data_json(response_text, gpt_input_text_df)
@@ -94,23 +80,14 @@ def _prompt(self, search_text, text_df, cache_path=None):

def _extract_bing_text_df(self, search_text, cache_path):
# BingSearch using search_text
# check if bing search result is cached and load if exists
bing_text_df = None
if not self.config['search_option']['is_use_source'] or not self.config['search_option']['is_enable_bing_search']:
return bing_text_df

if self.config.get('cache').get('is_enable_cache') and check_result_cache_exists(cache_path, search_text, 'bing_search'):
logger.info(f"BingService.load_result_from_cache. search_text: {search_text}, cache_path: {cache_path}")
cache = load_result_from_cache(cache_path, search_text, 'bing_search')
bing_text_df = cache['bing_text_df']
else:
bing_service = BingService(self.config)
website_df = bing_service.call_bing_search_api(search_text)
bing_text_df = bing_service.call_urls_and_extract_sentences_concurrent(website_df)

bing_search_config = self.config.get('bing_search').copy()
bing_search_config.pop('subscription_key') # delete api_key from config to avoid saving it to .cache
save_result_cache(cache_path, search_text, 'bing_search', bing_text_df=bing_text_df, config=bing_search_config)
bing_service = BingService(self.config)
website_df = bing_service.call_bing_search_api(query=search_text)
bing_text_df = bing_service.call_urls_and_extract_sentences_concurrent(website_df=website_df)

return bing_text_df

def _extract_doc_text_df(self, bing_text_df):
@@ -143,6 +120,7 @@ def _extract_doc_text_df(self, bing_text_df):
doc_text_df = pd.DataFrame(doc_sentence_list)
return doc_text_df

@storage_cached('web', 'search_text')
def query_and_get_answer(self, search_text):
cache_path = Path(self.config.get('cache').get('path'))
# TODO: strategy pattern to support different text sources (e.g. PDF later)
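
With query_and_get_answer decorated as well, the whole pipeline result is cached under the 'web' type while the Bing and LLM steps keep independent caches, so disabling one layer still reuses the others. A sketch of the top-level call (the constructor signature is not shown in this diff and is assumed here):

    service = SearchGPTService(config)  # assumed: the wrapper only requires a .config attribute on the instance
    result = service.query_and_get_answer(search_text="what is cache refinement")  # keyword form required by the decorator
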
src/Util.py (76 changes: 60 additions & 16 deletions)
@@ -2,7 +2,8 @@
import os
import pickle
import re
import shutil
from copy import deepcopy
from functools import wraps
from hashlib import md5
from pathlib import Path

@@ -37,41 +38,84 @@ def post_process_gpt_input_text_df(gpt_input_text_df, prompt_length_limit):
return gpt_input_text_df


def save_result_cache(path: Path, search_text: str, cache_type: str = 'bing_search', **kwargs):
cache_dir = path / md5(search_text.encode()).hexdigest()

def save_result_cache(path: Path, hash: str, type: str, **kwargs):
cache_dir = path / type
os.makedirs(cache_dir, exist_ok=True)
path = Path(cache_dir, f'{cache_type}.pickle')
path = Path(cache_dir, f'{hash}.pickle')
with open(path, 'wb') as f:
pickle.dump(kwargs, f)


def load_result_from_cache(path: Path, search_text: str, cache_type: str = 'bing_search'):
path = path / f'{md5(search_text.encode()).hexdigest()}' / f'{cache_type}.pickle'
def load_result_from_cache(path: Path, hash: str, type: str):
path = path / type / f'{hash}.pickle'
with open(path, 'rb') as f:
return pickle.load(f)


def check_result_cache_exists(path: Path, search_text: str, cache_type: str = 'bing_search') -> bool:
path = path / f'{md5(search_text.encode()).hexdigest()}' / f'{cache_type}.pickle'
if os.path.exists(path):
return True
else:
return False
def check_result_cache_exists(path: Path, hash: str, type: str) -> bool:
    path = path / type / f'{hash}.pickle'
    return os.path.exists(path)


def check_max_number_of_cache(path: Path, max_number_of_cache: int = 10):
if len(os.listdir(path)) >= max_number_of_cache:
def check_max_number_of_cache(path: Path, type: str, max_number_of_cache: int = 10):
path = path / type
if len(os.listdir(path)) > max_number_of_cache:
ctime_list = [(os.path.getctime(path / file), file) for file in os.listdir(path)]
oldest_file = sorted(ctime_list)[0][1]
shutil.rmtree(path / oldest_file)
os.remove(path / oldest_file)


def split_sentences_from_paragraph(text):
sentences = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s", text)
return sentences


def remove_api_keys(d):
key_to_remove = ['api_key', 'subscription_key']
temp_key_list = []
for key, value in d.items():
if key in key_to_remove:
temp_key_list += [key]
if isinstance(value, dict):
remove_api_keys(value)

for key in temp_key_list:
d.pop(key)
return d


def storage_cached(cache_type: str, cache_hash_key_name: str):
def storage_cache_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
assert getattr(args[0], 'config'), 'storage_cached is only applicable to class method with config attribute'
assert cache_hash_key_name in kwargs, f'Target method does not have {cache_hash_key_name} keyword argument'

config = getattr(args[0], 'config')
if config.get('cache').get('is_enable').get(cache_type):
hash_key = str(kwargs[cache_hash_key_name])

cache_path = Path(get_project_root(), config.get('cache').get('path'))
cache_hash = md5(str(config).encode() + hash_key.encode()).hexdigest()

if check_result_cache_exists(cache_path, cache_hash, cache_type):
result = load_result_from_cache(cache_path, cache_hash, cache_type)['result']
else:
result = func(*args, **kwargs)
config_for_cache = deepcopy(config)
config_for_cache = remove_api_keys(config_for_cache) # remove api keys
save_result_cache(cache_path, cache_hash, cache_type, result=result, config=config_for_cache)

check_max_number_of_cache(cache_path, cache_type, config.get('cache').get('max_number_of_cache'))
else:
result = func(*args, **kwargs)

return result

return wrapper

return storage_cache_decorator

if __name__ == '__main__':
text = "There are many things you can do to learn how to run faster, Mr. Wan, such as incorporating speed workouts into your running schedule, running hills, counting your strides, and adjusting your running form. Lean forward when you run and push off firmly with each foot. Pump your arms actively and keep your elbows bent at a 90-degree angle. Try to run every day, and gradually increase the distance you run for long-distance runs. Make sure you rest at least one day per week to allow your body to recover. Avoid running with excess gear that could slow you down."
sentences = split_sentences_from_paragraph(text)
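
The decorator generalizes the cache plumbing deleted from SearchGPTService: any method of a class holding a config attribute can opt in, provided the hashed argument is supplied as a keyword. A hypothetical example (DemoService and slow_lookup are invented for illustration):

    class DemoService:
        def __init__(self, config):
            self.config = config  # required: the wrapper reads self.config for cache settings

        @storage_cached('web', 'search_text')  # (cache_type, name of the keyword argument to hash)
        def answer(self, search_text: str):
            return slow_lookup(search_text)  # invented expensive call; its result is pickled on the first run

    DemoService(config).answer(search_text="hello")  # positional answer("hello") would fail the kwargs assert
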
src/config/config.yaml (9 changes: 7 additions & 2 deletions)
@@ -29,6 +29,11 @@ goose_ai_api:
model: gpt-neo-20b
max_tokens: 100
cache: # .cache result for efficiency and consistency
is_enable_cache: false
is_enable:
web: false
bing_search_website: false
bing_search_website_content: false
openai: false
gooseai: false
path: .cache
max_number_of_cache: 0
max_number_of_cache: 50
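
The single is_enable_cache switch becomes one flag per cache type, matching the first argument of each @storage_cached usage, and the eviction limit rises from 0 to 50 entries per type. A sketch of how the decorator reads these keys, with the file path assumed:

    import yaml

    with open('src/config/config.yaml') as f:
        config = yaml.safe_load(f)

    config.get('cache').get('is_enable').get('openai')  # per-type switch checked by storage_cached
    config.get('cache').get('max_number_of_cache')      # 50: beyond this count the oldest pickle is removed
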