Skip to content

Commit

Permalink
switch to partner lib for weaviate (#352)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda committed Jul 15, 2024
1 parent 23463ad commit a588739
Show file tree
Hide file tree
Showing 4 changed files with 401 additions and 233 deletions.
14 changes: 7 additions & 7 deletions backend/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import weaviate
from langchain_anthropic import ChatAnthropic
from langchain_cohere import ChatCohere
from langchain_community.vectorstores import Weaviate
from langchain_core.documents import Document
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import (
Expand All @@ -24,6 +23,7 @@
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_weaviate import WeaviateVectorStore
from langgraph.graph import END, StateGraph, add_messages

from backend.constants import WEAVIATE_DOCS_INDEX_NAME
Expand Down Expand Up @@ -184,18 +184,18 @@ class AgentState(TypedDict):


def get_retriever() -> BaseRetriever:
weaviate_client = weaviate.Client(
url=os.environ["WEAVIATE_URL"],
auth_client_secret=weaviate.AuthApiKey(
api_key=os.environ.get("WEAVIATE_API_KEY", "not_provided")
weaviate_client = weaviate.connect_to_wcs(
cluster_url=os.environ["WEAVIATE_URL"],
auth_credentials=weaviate.classes.init.Auth.api_key(
os.environ.get("WEAVIATE_API_KEY", "not_provided")
),
skip_init_checks=True,
)
weaviate_client = Weaviate(
weaviate_client = WeaviateVectorStore(
client=weaviate_client,
index_name=WEAVIATE_DOCS_INDEX_NAME,
text_key="text",
embedding=get_embeddings_model(),
by_text=False,
attributes=["source", "title"],
)
return weaviate_client.as_retriever(search_kwargs=dict(k=6))
Expand Down
18 changes: 11 additions & 7 deletions backend/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from langchain.indexes import SQLRecordManager, index
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.utils.html import PREFIXES_TO_IGNORE_REGEX, SUFFIXES_TO_IGNORE_REGEX
from langchain_community.vectorstores import Weaviate
from langchain_core.embeddings import Embeddings
from langchain_openai import OpenAIEmbeddings
from langchain_weaviate import WeaviateVectorStore

from backend.constants import WEAVIATE_DOCS_INDEX_NAME
from backend.parser import langchain_docs_extractor
Expand Down Expand Up @@ -103,16 +103,16 @@ def ingest_docs():
text_splitter = RecursiveCharacterTextSplitter(chunk_size=4000, chunk_overlap=200)
embedding = get_embeddings_model()

client = weaviate.Client(
url=WEAVIATE_URL,
auth_client_secret=weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY),
client = weaviate.connect_to_wcs(
cluster_url=WEAVIATE_URL,
auth_credentials=weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY),
skip_init_checks=True,
)
vectorstore = Weaviate(
vectorstore = WeaviateVectorStore(
client=client,
index_name=WEAVIATE_DOCS_INDEX_NAME,
text_key="text",
embedding=embedding,
by_text=False,
attributes=["source", "title"],
)

Expand Down Expand Up @@ -152,7 +152,11 @@ def ingest_docs():
)

logger.info(f"Indexing stats: {indexing_stats}")
num_vecs = client.query.aggregate(WEAVIATE_DOCS_INDEX_NAME).with_meta_count().do()
num_vecs = (
client.collections.get(WEAVIATE_DOCS_INDEX_NAME)
.aggregate.over_all()
.total_count
)
logger.info(
f"LangChain now has this many vectors: {num_vecs}",
)
Expand Down
Loading

0 comments on commit a588739

Please sign in to comment.