Merge pull request #523 from yao531441/rag
[v1.2][ISSUE-306]Add search tool for rag.
xuechendi committed Jan 29, 2024
2 parents d059bbe + 4ebb38c commit 2db6849
Showing 8 changed files with 179 additions and 17 deletions.
57 changes: 57 additions & 0 deletions RecDP/pyrecdp/primitives/llmutils/search_pipeline.py
@@ -0,0 +1,57 @@
from faiss import IndexFlatL2
from langchain_community.docstore import InMemoryDocstore
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.utilities.google_search import GoogleSearchAPIWrapper
from langchain_community.vectorstores.faiss import FAISS
from langchain_core.tools import Tool

from pyrecdp.LLM import TextPipeline
from pyrecdp.primitives.operations import UrlLoader, RAGTextFix, DocumentSplit, LengthFilter, GoogleSearchTool, \
    DocumentIngestion


def db_similarity_search(query, db, k=4):
    docs = db.similarity_search(query=query, k=k)
    return docs

def get_search_results(query):
    search = GoogleSearchTool(query=query)

    text_splitter = "RecursiveCharacterTextSplitter"
    splitter_chunk_size = 500
    text_splitter_args = {
        "chunk_size": splitter_chunk_size,
        "chunk_overlap": 0,
        "separators": ["\n\n", "\n", " ", ""],
    }
    embedding_model_name = 'sentence-transformers/all-MiniLM-L6-v2'
    pipeline = TextPipeline()
    ops = [search]

    ops.extend(
        [
            UrlLoader(text_key='url', text_to_markdown=False),
            RAGTextFix(re_sentence=True),
            DocumentSplit(text_splitter=text_splitter, text_splitter_args=text_splitter_args),
            LengthFilter(),
            DocumentIngestion(
                vector_store='FAISS',
                vector_store_args={"in_memory": True, "index": 'search'},
                embeddings='HuggingFaceEmbeddings',
                embeddings_args={"model_name": embedding_model_name},
                return_db_handler=True
            )
        ]
    )
    pipeline.add_operations(ops)
    db = pipeline.execute()
    return db


if __name__ == '__main__':
    query = "chatgpt latest version?"
    db = get_search_results(query)
    res = db_similarity_search(query, db)
    for line in res:
        print(line)
        print("_" * 40)
1 change: 1 addition & 0 deletions RecDP/pyrecdp/primitives/operations/__init__.py
@@ -70,5 +70,6 @@
     from .table_summary import TableSummary
     from .text_spell_correct import TextSpellCorrect
     from .text_contraction_remove import TextContractionRemove
+    from .search_tool import GoogleSearchTool
 except Exception as e:
     pass
25 changes: 15 additions & 10 deletions RecDP/pyrecdp/primitives/operations/doc_loader.py
@@ -198,7 +198,7 @@ def __init__(self, urls: List[str], save_dir: str = None, model='small',
         self.model_name = model
         self.num_cpus = num_cpus
         os.system("apt-get -qq -y install ffmpeg")
-        check_availability_and_install(['langchain', 'pytube', 'openai-whisper', 'youtube-transcript-api'])
+        check_availability_and_install(['langchain', 'pytube', 'openai-whisper', 'youtube-transcript-api', 'yt_dlp'])

     def process_rayds(self, ds=None):
         import ray
@@ -240,7 +240,7 @@ def process_spark(self, spark, spark_df=None):
 class UrlLoader(TextReader):
     def __init__(
             self,
-            urls: Union[str, List[str]],
+            urls: Union[str, List[str]] = None,
             max_depth: Optional[int] = 1,
             use_async: Optional[bool] = None,
             extractor: Optional[Callable[[str], str]] = None,
@@ -254,6 +254,7 @@ def __init__(
             text_to_markdown: bool = True,
             requirements=None,
             num_cpus: Optional[int] = None,
+            text_key: str = None,
     ) -> None:
         """Initialize with URL to crawl and any subdirectories to exclude.
@@ -279,6 +280,7 @@
             check_response_status: If True, check HTTP response status and skip
                 URLs with error responses (400-599).
             num_cpus: The number of CPUs to reserve for each parallel url read worker.
+            text_key: text key to process.
         """
         if requirements is None:
             requirements = ['bs4', 'markdownify', 'langchain']
@@ -296,28 +298,31 @@
             'check_response_status': check_response_status,
         }
         settings = self.loader_kwargs.copy()
-        settings.update({'urls': urls, 'text_to_markdown': text_to_markdown, 'num_cpus': num_cpus})
+        settings.update({'urls': urls, 'text_to_markdown': text_to_markdown, 'num_cpus': num_cpus, 'text_key': text_key})
         self.text_to_markdown = text_to_markdown
+        self.text_key = text_key
         super().__init__(settings, requirements=requirements)
         self.support_spark = True
         self.support_ray = True
         self.num_cpus = num_cpus

-        if isinstance(urls, str):
-            urls = [urls]
-
-        self.urls = set(urls)
+        if urls:
+            if isinstance(urls, str):
+                urls = [urls]
+            self.urls = set(urls)

     def process_rayds(self, ds=None):
         import ray
-        urls_ds = ray.data.from_items([{'url': url} for url in self.urls])
+        if self.text_key:
+            urls_ds = ds.select_columns(['url'])
+        else:
+            urls_ds = ray.data.from_items([{'url': url} for url in self.urls])

         from pyrecdp.primitives.document.reader import read_from_url
         self.cache = urls_ds.flat_map(
             lambda record: read_from_url(record['url'], self.text_to_markdown, **self.loader_kwargs),
             num_cpus=self.num_cpus)

-        if ds is not None:
+        if ds is not None and not self.text_key:
             self.cache = self.union_ray_ds(ds, self.cache)
         return self.cache
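The UrlLoader change above makes the urls argument optional and adds a text_key parameter: when text_key is set, process_rayds reads the 'url' column from the upstream dataset (e.g. the records emitted by SearchTool) instead of building one from a fixed url list, and it no longer unions the loader output back with the input dataset. A short sketch of the two usage modes (example URL and parameter values are illustrative, not from this diff):

from pyrecdp.primitives.operations import UrlLoader

# Stand-alone usage, unchanged: crawl an explicit list of URLs.
loader = UrlLoader(urls=["https://example.com"], max_depth=1, text_to_markdown=False)

# Chained usage enabled by this commit: consume the 'url' column produced by an
# upstream operator such as GoogleSearchTool, with no url list of its own.
chained_loader = UrlLoader(text_key='url', text_to_markdown=False)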
94 changes: 94 additions & 0 deletions RecDP/pyrecdp/primitives/operations/search_tool.py
@@ -0,0 +1,94 @@
"""
Copyright 2024 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from .base import LLMOPERATORS

from langchain_core.tools import Tool
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.utilities.google_search import GoogleSearchAPIWrapper
from langchain_community.vectorstores.faiss import FAISS

import datetime

from .text_reader import TextReader


def get_search_tool(search_class, search_num):
def top5_results(query):
return search_class.results(query, search_num)

search_tool = Tool(
name="Search Tool",
description="Search Web for recent results.",
func=top5_results,
)

return search_tool


def generate_search_query(query, llm=None):
prompt_temp = ("You are tasked with generating web search queries. "
+ "Give me an appropriate query to answer my question for google search. "
+ "Answer with only the query. Today is {current_date}, Query: {query}")
prompt = prompt_temp.format(current_date=str(datetime.date.today()), query=query)
# TODO generate by llm:
return query


def content_similarity_search(query, texts, k=4,
embedding_model_name='sentence-transformers/all-MiniLM-L6-v2'):
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
db = FAISS.from_texts(texts, embeddings)
docs = db.similarity_search(query=query, k=k)
return docs


class SearchTool(TextReader):
def __init__(self, query, search_num=5):
settings = {'search_num': search_num, 'query': query}
super().__init__(settings)
self.search_num = search_num
self.query = query
self.search_tool = None

self.support_spark = False
self.support_ray = True

def process_rayds(self, ds=None):
import ray
self.cache = ray.data.from_items(self.get_result_urls())
return self.cache

def get_result_urls(self):
search_keywords = generate_search_query(self.query)
res = self.search_tool.run(search_keywords)
if res:
result_urls = [{'url': x['link']} for x in res]
return result_urls
else:
return None


LLMOPERATORS.register(SearchTool)


class GoogleSearchTool(SearchTool):
def __init__(self, query, search_num=5):
super().__init__(query, search_num)
self.search_tool = get_search_tool(GoogleSearchAPIWrapper(), search_num=search_num)


LLMOPERATORS.register(GoogleSearchTool)
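For reference, SearchTool.get_result_urls expects the wrapped tool to return a list of result dicts and keeps only each result's 'link', emitting records shaped like {'url': ...} for a downstream UrlLoader(text_key='url'). A standalone sketch of the operator, assuming Google search credentials are already configured in the environment:

from pyrecdp.primitives.operations import GoogleSearchTool

op = GoogleSearchTool(query="chatgpt latest version?", search_num=5)

# Runs the Google search and keeps only the result links, e.g. [{'url': 'https://...'}, ...].
print(op.get_result_urls())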
6 changes: 5 additions & 1 deletion RecDP/pyrecdp/primitives/operations/text_ingestion.py
@@ -129,6 +129,7 @@ def do_persist(self, ds: Dataset):
         check_availability_and_install(["langchain", "faiss-cpu"])

         db = self.vector_store_args["db_handler"]
+        in_memory = self.vector_store_args.get("in_memory", False)
         index_name = self.vector_store_args.get("index", "index")

         rows = ds.iter_rows() if isinstance(ds, Dataset) else ds.collect()
@@ -140,11 +141,14 @@ def do_persist(self, ds: Dataset):
         if db is not None:
             db.add_embeddings(text_embeddings)
             return db
+        embeddings = create_embeddings(self.embeddings, self.embeddings_args)
+        if in_memory:
+            db = FAISS.from_embeddings(text_embeddings, embedding=embeddings)
+            return db

         if "output_dir" not in self.vector_store_args:
             raise ValueError(f"You must have `output_dir` option specify for FAAIS vector store")
         faiss_folder_path = self.vector_store_args["output_dir"]
-        embeddings = create_embeddings(self.embeddings, self.embeddings_args)
         if not self.override and os.path.exists(os.path.join(faiss_folder_path, index_name + ".faiss")):
             db = FAISS.load_local(faiss_folder_path, embeddings, index_name)
             db.add_embeddings(text_embeddings)
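The new in_memory branch relies on langchain's FAISS.from_embeddings, which takes an iterable of (text, embedding_vector) pairs plus an embedding object for later queries; that matches the text_embeddings list assembled from the dataset rows earlier in do_persist. A small illustrative sketch with hypothetical texts (the real pairs come from the pipeline rows):

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.faiss import FAISS

embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')

# (text, vector) pairs in the shape FAISS.from_embeddings expects.
texts = ["example passage one", "example passage two"]
text_embeddings = list(zip(texts, embeddings.embed_documents(texts)))

db = FAISS.from_embeddings(text_embeddings, embedding=embeddings)
print(db.similarity_search("example passage", k=1))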
2 changes: 1 addition & 1 deletion RecDP/pyrecdp/primitives/operations/text_pii_remove.py
@@ -68,7 +68,7 @@ def __init__(self, text_key='text', inplace=True, model_root_path="", debug_mode
                  entity_types=None):
         settings = {'text_key': text_key, 'inplace': inplace, 'model_root_path': model_root_path,
                     'debug_mode': debug_mode, 'entity_types': entity_types}
-        requirements = ['torch', 'transformers', 'Faker', "phonenumbers"]
+        requirements = ['torch', 'transformers', 'Faker', "phonenumbers", 'gibberish_detector']
         super().__init__(settings, requirements)
         from pyrecdp.primitives.llmutils.pii.detect.utils import PIIEntityType
         self.text_key = text_key
5 changes: 3 additions & 2 deletions RecDP/setup.py
@@ -26,11 +26,11 @@ def __init__(self):
         self.version = find_version()
         self.files_to_include: list = []
         self.install_requires: list = [
-            "scikit-learn",
+            "scikit-learn==1.3.2",
             "psutil",
             "tqdm",
             "pyyaml",
-            "pandas",
+            "pandas==2.1.4",
             "numpy",
             "pyarrow",
             "ipywidgets",
@@ -48,6 +48,7 @@ def __init__(self):
             "typer>=0.6.1",
             "scipy==1.10.1",
             "tabulate==0.9.0",
+            "grpcio",
         ]
         self.extras: dict = {}
         self.extras['autofe'] = list_requirements("pyrecdp/autofe/requirements.txt")
6 changes: 3 additions & 3 deletions RecDP/tests/test_llmutils_operations.py
@@ -108,9 +108,9 @@ def test_filter_by_profanity_ray(self):
     def test_filter_by_url_ray(self):
         pass
         # Ray version not supported yet
-        op = URLFilter()
-        with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx:
-            ctx.show(op.process_rayds(ctx.ds))
+        # op = URLFilter()
+        # with RayContext("tests/data/llm_data/tiny_c4_sample.jsonl") as ctx:
+        #     ctx.show(op.process_rayds(ctx.ds))

     def test_filter_by_alphanumeric_ray(self):
         pass
