Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/rag optimize #538

Merged
merged 2 commits into from
May 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tqdm import tqdm
from langchain.docstore.document import Document
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.llm import LLMChain
from bisheng_langchain.retrievers import EnsembleRetriever
from bisheng_langchain.vectorstores import ElasticKeywordsSearch, Milvus
from bisheng_langchain.rag.init_retrievers import (
Expand Down Expand Up @@ -69,10 +70,16 @@ def __init__(self, yaml_path) -> None:
)

# es
if self.params['elasticsearch'].get('extract_key_by_llm', False):
extract_key_prompt = import_class(f'bisheng_langchain.rag.prompts.EXTRACT_KEY_PROMPT')
llm_chain = LLMChain(llm=self.llm, prompt=extract_key_prompt)
else:
llm_chain = None
self.keyword_store = ElasticKeywordsSearch(
index_name='default_es',
elasticsearch_url=self.params['elasticsearch']['url'],
ssl_verify=self.params['elasticsearch']['ssl_verify'],
llm_chain=llm_chain
)

# init retriever
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
data:
origin_file_path: '/home/public/rag_benchmark_finance_report'
question: '/home/public/rag_benchmark_finance_report/finance_report_data_100_single.xlsx'
save_answer: '/home/public/rag_benchmark_finance_report/finance_report_data_100_single_qwen1.5_72b_20chunk_chunk_size_1000_with_source_title.xlsx'
save_answer: '/home/public/rag_benchmark_finance_report/finance_report_data_100_single_command-r-plus_20chunk_chunk_size_1000_with_source_title_overlap100.xlsx'

milvus:
host: '110.16.193.170'
Expand All @@ -13,6 +13,7 @@ elasticsearch:
ssl_verify:
basic_auth: ["elastic", "oSGL-zVvZ5P3Tm7qkDLC"]
drop_old: True
extract_key_by_llm: False

embedding:
type: 'OpenAIEmbeddings'
Expand All @@ -27,26 +28,34 @@ embedding:
# openai_proxy: ''
# temperature: 0.0

chat_llm:
type: 'ChatCohere'
model: 'command-r-plus'
cohere_api_key: ''
max_tokens: 1000
temperature: 0.01

# chat_llm:
# type: 'ChatCohere'
# model: 'command-r-plus'
# cohere_api_key: ''
# max_tokens: 1000
# type: 'ChatQWen'
# model_name: 'qwen1.5-110b-chat'
# api_key: ''
# temperature: 0.01

chat_llm:
type: 'ChatQWen'
model_name: 'qwen1.5-72b-chat'
api_key: ''
temperature: 0.01
# chat_llm:
# type: 'ChatOpenAI'
# model: 'qwen1.5-110b-chat'
# openai_api_base: 'http://60.31.21.42:12511/v1'
# openai_api_key : "Z9b8x3V7C2n0Q5T"
# openai_proxy: ''
# temperature: 0.01

loader:
type: 'ElemUnstructuredLoader'
unstructured_api_url: 'https://bisheng.dataelem.com/api/v1/etl4llm/predict'

retriever:
type: 'EnsembleRetriever' # 不动
suffix: 'benchmark_caibao_1000_source_title'
suffix: 'benchmark_caibao_1000_source_title_overlap100'
add_aux_info: True
retrievers:
- type: 'KeywordRetriever'
Expand All @@ -55,7 +64,7 @@ retriever:
# type: 'ElemCharacterTextSplitter'
type: 'RecursiveCharacterTextSplitter'
chunk_size: 1000
chunk_overlap: 0
chunk_overlap: 100
separators: ["\n\n"]
retrieval:
search_type: 'similarity'
Expand All @@ -67,7 +76,7 @@ retriever:
# type: 'ElemCharacterTextSplitter'
type: 'RecursiveCharacterTextSplitter'
chunk_size: 1000
chunk_overlap: 0
chunk_overlap: 100
separators: ["\n\n"]
retrieval:
search_type: 'similarity'
Expand All @@ -86,7 +95,7 @@ post_retrieval:

generate:
with_retrieval: True
max_content: 30000
max_content: 100000
chain_type: 'stuff'
# prompt_type: 'BASE_PROMPT'
prompt_type: 'CHAT_PROMPT'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _get_relevant_documents(
index_name=collection_name,
elasticsearch_url=self.keyword_store.elasticsearch_url,
ssl_verify=self.keyword_store.ssl_verify,
llm_chain=self.keyword_store.llm_chain
)
if self.search_type == 'similarity':
result = self.keyword_store.similarity_search(query, **self.search_kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def _get_relevant_documents(
index_name=collection_name,
elasticsearch_url=self.keyword_store.elasticsearch_url,
ssl_verify=self.keyword_store.ssl_verify,
llm_chain=self.keyword_store.llm_chain
)
self.vector_store = self.vector_store.__class__(
collection_name=collection_name,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .prompt import BASE_PROMPT, CHAT_PROMPT, CHAT_PROMPT_GENERAL
from .extract_key_prompt import EXTRACT_KEY_PROMPT

__all__ = [
'BASE_PROMPT',
'CHAT_PROMPT',
'CHAT_PROMPT_GENERAL',
'EXTRACT_KEY_PROMPT',
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from langchain.prompts.prompt import PromptTemplate


EXTRACT_KEY_PROMPT = PromptTemplate(
input_variables=['question'],
template="""分析给定Question,提取Question中包含的KeyWords,输出列表形式

Examples:
Question: 达梦公司在过去三年中的流动比率如下:2021年:3.74倍;2020年:2.82倍;2019年:2.05倍。
KeyWords: ['过去三年', '流动比率', '2021', '3.74', '2020', '2.82', '2019', '2.05']

----------------
Question: {question}
KeyWords: """,
)

# EXTRACT_KEY_PROMPT = PromptTemplate(
# input_variables=['question'],
# template="""分析给定Question,提取Question中包含的KeyWords,输出列表形式

# Examples:
# Question: 能否根据2020年金宇生物技术股份有限公司的年报,给我简要介绍一下报告期内公司的社会责任工作情况?
# KeyWords: ['报告期', '社会责任', '工作情况']

# Question: 请根据江化微2019年的年报,简要介绍报告期内公司主要销售客户的客户集中度情况,并结合同行业情况进行分析。
# KeyWords: ['报告期', '主要', '销售客户', '客户集中度', '同行业', '分析']

# Question: 请问,在苏州迈为科技股份有限公司2019年的年报中,现金流的情况是否发生了重大变化?若发生,导致重大变化的原因是什么?
# KeyWords: ['现金流', '重大变化', '原因']

# ----------------
# Question: {question}
# KeyWords: """,
# )
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def similarity_search_with_score(self,
# llm or jiaba extract keywords
if self.llm_chain:
keywords_str = self.llm_chain.run(query)
print('keywords_str:', keywords_str)
print('llm search keywords:', keywords_str)
try:
keywords = eval(keywords_str)
if not isinstance(keywords, list):
Expand All @@ -238,6 +238,7 @@ def similarity_search_with_score(self,
keywords = jieba.analyse.extract_tags(query, topK=10, withWeight=False)
else:
keywords = jieba.analyse.extract_tags(query, topK=10, withWeight=False)
print('jieba search keywords:', keywords)
match_query = {'bool': {must_or_should: []}}
for key in keywords:
match_query['bool'][must_or_should].append({query_strategy: {'text': key}})
Expand Down