Skip to content

Commit

Permalink
updated deps
Browse files Browse the repository at this point in the history
  • Loading branch information
flepied committed Jun 7, 2024
1 parent d7be79f commit 990ea04
Show file tree
Hide file tree
Showing 4 changed files with 2,162 additions and 1,938 deletions.
22 changes: 8 additions & 14 deletions lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@
import time

import chromadb
from langchain.chains import VectorDBQAWithSourcesChain
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.indexes.vectorstore import VectorStoreIndexWrapper
from langchain_community.embeddings import HuggingFaceEmbeddings

# pylint: disable=no-name-in-module
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

# pylint: disable=no-name-in-module
from langchain_openai import OpenAI
Expand Down Expand Up @@ -120,10 +118,9 @@ def __init__(self):
"Initialize the agent"
self.vectorstore = get_vectorstore()
self.llm = OpenAI(temperature=0)
self.chain = VectorDBQAWithSourcesChain.from_llm(
self.chain = RetrievalQAWithSourcesChain.from_llm(
llm=self.llm,
vectorstore=self.vectorstore,
return_source_documents=True,
retriever=self.vectorstore.as_retriever(),
)
try:
with open("documents-desc.txt", "r", encoding="UTF-8") as desc_file:
Expand All @@ -134,9 +131,10 @@ def __init__(self):
def question(self, user_question):
"Ask a question and format the answer for text"
response = self._get_response(user_question)
print(f"{response=}", file=sys.stderr)
if (
response["sources"] not in ("None.", "N/A", "I don't know.")
and len(response["source_documents"]) > 0
and len(response["sources"]) > 0
):
sources = "- " + "\n- ".join(self._filter_file(self._get_sources(response)))
return f"{response['answer']}\nSources:\n{sources}"
Expand All @@ -147,7 +145,7 @@ def html_question(self, user_question):
response = self._get_response(user_question)
if (
response["sources"] not in ("None.", "N/A", "I don't know.")
and len(response["source_documents"]) > 0
and len(response["sources"]) > 0
):
sources = "- " + "\n- ".join(
[
Expand Down Expand Up @@ -228,11 +226,7 @@ def _regular_question(self, user_question):

def _get_sources(self, resp):
"Get the url instead of the chunk sources"
ret = []
for doc in resp["source_documents"]:
if doc.metadata["url"] not in ret:
ret.append(doc.metadata["url"])
return ret
return resp["sources"].split(", ")

def _build_filter(self, metadata):
"Build the filter for the vector store from the metadata"
Expand Down
Loading

0 comments on commit 990ea04

Please sign in to comment.