Skip to content

Commit

Permalink
feat: support add and delete from MySQLVectorStore (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon committed Mar 27, 2024
1 parent a1c9411 commit ce45617
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/langchain_google_cloud_sql_mysql/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def init_vectorstore_table(
table_name (str): The MySQL database table name.
vector_size (int): Vector size for the embedding model to be used.
content_column (str): Name of the column to store document content.
Deafult: `page_content`.
Default: `page_content`.
embedding_column (str) : Name of the column to store vector embeddings.
Default: `embedding`.
metadata_columns (List[Column]): A list of Columns to create for custom
Expand Down
26 changes: 26 additions & 0 deletions src/langchain_google_cloud_sql_mysql/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,32 @@ def add_texts(
)
return ids

def add_documents(
self,
documents: List[Document],
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
ids = self.add_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
return ids

def delete(
self,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> bool:
if not ids:
return False

id_list = ", ".join([f"'{id}'" for id in ids])
query = (
f"DELETE FROM `{self.table_name}` WHERE `{self.id_column}` in ({id_list})"
)
self.engine._execute(query)
return True

@classmethod
def from_texts( # type: ignore[override]
cls: Type[MySQLVectorStore],
Expand Down
56 changes: 56 additions & 0 deletions tests/integration/test_mysql_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,30 @@ def test_add_texts_edge_cases(self, engine, vs):
assert len(results) == 3
engine._execute(f"TRUNCATE TABLE `{DEFAULT_TABLE}`")

def test_add_docs(self, engine, vs):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
vs.add_documents(docs, ids=ids)
results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`")
assert len(results) == 3
engine._execute(f"TRUNCATE TABLE `{DEFAULT_TABLE}`")

def test_add_embedding(self, engine, vs):
ids = [str(uuid.uuid4()) for _ in range(len(texts))]
vs._add_embeddings(texts, embeddings, metadatas, ids)
results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`")
assert len(results) == 3
engine._execute(f"TRUNCATE TABLE `{DEFAULT_TABLE}`")

def test_delete(self, engine, vs):
ids = [str(uuid.uuid4()) for _ in range(len(texts))]
vs.add_texts(texts, ids=ids)
results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`")
assert len(results) == 3
# delete an ID
vs.delete([ids[0]])
results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`")
assert len(results) == 2

def test_add_texts_custom(self, engine, vs_custom):
ids = [str(uuid.uuid4()) for _ in range(len(texts))]
vs_custom.add_texts(texts, ids=ids)
Expand All @@ -172,11 +189,50 @@ def test_add_texts_custom(self, engine, vs_custom):
assert len(results) == 6
engine._execute(f"TRUNCATE TABLE `{CUSTOM_TABLE}`")

def test_add_docs_custom(self, engine, vs_custom):
ids = [str(uuid.uuid4()) for i in range(len(texts))]
docs = [
Document(
page_content=texts[i],
metadata={"page": str(i), "source": "google.com"},
)
for i in range(len(texts))
]
vs_custom.add_documents(docs, ids=ids)

results = engine._fetch(f"SELECT * FROM `{CUSTOM_TABLE}`")
content = [result["mycontent"] for result in results]
assert len(results) == 3
assert "foo" in content
assert "bar" in content
assert "baz" in content
assert results[0]["myembedding"]
pages = [result["page"] for result in results]
assert "0" in pages
assert "1" in pages
assert "2" in pages
assert results[0]["source"] == "google.com"
engine._execute(f"TRUNCATE TABLE `{CUSTOM_TABLE}`")

def test_add_embedding_custom(self, engine, vs_custom):
ids = [str(uuid.uuid4()) for _ in range(len(texts))]
vs_custom._add_embeddings(texts, embeddings, metadatas, ids)
results = engine._fetch(f"SELECT * FROM `{CUSTOM_TABLE}`")
assert len(results) == 3
engine._execute(f"TRUNCATE TABLE `{CUSTOM_TABLE}`")

def test_delete_custom(self, engine, vs_custom):
ids = [str(uuid.uuid4()) for _ in range(len(texts))]
vs_custom.add_texts(texts, ids=ids)
results = engine._fetch(f"SELECT * FROM `{CUSTOM_TABLE}`")
content = [result["mycontent"] for result in results]
assert len(results) == 3
assert "foo" in content
# delete an ID
vs_custom.delete([ids[0]])
results = engine._fetch(f"SELECT * FROM `{CUSTOM_TABLE}`")
content = [result["mycontent"] for result in results]
assert len(results) == 2
assert "foo" not in content

# Need tests for store metadata=False

0 comments on commit ce45617

Please sign in to comment.