In [1]:
# pip install -U transformers[agents] jupyter ipywidgets
# pip install -U langchain-community==0.2.1 langchain-core==0.2.1 sentence-transformers faiss-cpu

In [2]:
import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")

source_docs = [
    Document(
        page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]}
    ) for doc in knowledge_base
]
docs_processed = RecursiveCharacterTextSplitter(chunk_size=500).split_documents(source_docs)[:1000]

all_sources = list(set([doc.metadata["source"] for doc in docs_processed]))
print(all_sources)


embedding_model = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
vectordb = FAISS.from_documents(
    documents=docs_processed,
    embedding=embedding_model
)

['datasets', 'hf-endpoints-documentation', 'hub-docs', 'course', 'evaluate', 'deep-rl-class', 'diffusers', 'transformers', 'pytorch-image-models', 'datasets-server', 'gradio', 'optimum', 'peft', 'blog']




In [3]:
import json
from transformers.agents import Tool
from langchain_core.vectorstores import VectorStore

class RetrieverTool(Tool):
    name = "retriever"
    description = "Retrieves some documents from the knowledge base that have the closest embeddings to the input query."
    inputs = {
        "query": {
            "type": "text",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        },
        "source": {
            "type": "text", 
            "description": ""
        },
    }
    output_type = "text"
    
    def __init__(self, vectordb: VectorStore, all_sources: str, **kwargs):
        super().__init__(**kwargs)
        self.vectordb = vectordb
        self.inputs["source"]["description"] = (
            f"The source of the documents to search, as a str representation of a list. Possible values in the list are: {all_sources}. If this argument is not provided, all sources will be searched."
          )

    def forward(self, query: str, source: str = None) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        if source:
            if isinstance(source, str) and "[" not in str(source): # if the source is not representing a list
                source = [source]
            source = json.loads(str(source).replace("'", '"'))

        docs = self.vectordb.similarity_search(query, filter=({"source": source} if source else None), k=3)

        if len(docs) == 0:
            return "No documents found with this filtering. Try removing the source filter."
        return "Retrieved documents:\n\n" + "\n===Document===\n".join(
            [doc.page_content for doc in docs]
        )


In [4]:
from transformers.agents import ReactJsonAgent, HfEngine

agent = ReactJsonAgent(
    tools=[RetrieverTool(vectordb, all_sources)],
    llm_engine=HfEngine("http://127.0.0.1:8087")
)

In [5]:
agent_output = agent.run("Please show me a LORA finetuning script")
print(f"Final output:\n{agent_output}")

[37;1mPlease show me a LORA finetuning script[0m
[33;1mCalling tool: 'retriever' with arguments: {'query': 'LORA finetuning script', 'source': "['transformers', 'hf-endpoints-documentation']"}[0m
[33;1mCalling tool: 'retriever' with arguments: {'query': 'LORA finetuning script'}[0m
[33;1mCalling tool: 'final_answer' with arguments: {'answer': 'https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/text_to_image_lora.py'}[0m


Final output:
https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/text_to_image_lora.py
