diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index 61e569b..fae720a 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -43,7 +43,7 @@ availableSecrets: env: "DB_PASSWORD" substitutions: - _INSTANCE_ID: test-instance + _INSTANCE_ID: mysql-vector _REGION: us-central1 _DB_NAME: test _VERSION: "3.8" diff --git a/src/langchain_google_cloud_sql_mysql/__init__.py b/src/langchain_google_cloud_sql_mysql/__init__.py index 72d5e3c..29d2540 100644 --- a/src/langchain_google_cloud_sql_mysql/__init__.py +++ b/src/langchain_google_cloud_sql_mysql/__init__.py @@ -13,14 +13,17 @@ # limitations under the License. from .chat_message_history import MySQLChatMessageHistory -from .engine import MySQLEngine +from .engine import Column, MySQLEngine from .loader import MySQLDocumentSaver, MySQLLoader +from .vectorstore import MySQLVectorStore from .version import __version__ __all__ = [ + "Column", "MySQLChatMessageHistory", "MySQLDocumentSaver", "MySQLEngine", "MySQLLoader", + "MySQLVectorStore", "__version__", ] diff --git a/src/langchain_google_cloud_sql_mysql/chat_message_history.py b/src/langchain_google_cloud_sql_mysql/chat_message_history.py index c51e607..56584b6 100644 --- a/src/langchain_google_cloud_sql_mysql/chat_message_history.py +++ b/src/langchain_google_cloud_sql_mysql/chat_message_history.py @@ -25,7 +25,7 @@ class MySQLChatMessageHistory(BaseChatMessageHistory): """Chat message history stored in a Cloud SQL MySQL database. Args: - engine (MySQLEngine): SQLAlchemy connection pool engine for managing + engine (MySQLEngine): Connection pool engine for managing connections to Cloud SQL for MySQL. session_id (str): Arbitrary key that is used to store the messages of a single chat session. diff --git a/src/langchain_google_cloud_sql_mysql/engine.py b/src/langchain_google_cloud_sql_mysql/engine.py index 27a590c..63a53a6 100644 --- a/src/langchain_google_cloud_sql_mysql/engine.py +++ b/src/langchain_google_cloud_sql_mysql/engine.py @@ -31,6 +31,21 @@ USER_AGENT = "langchain-google-cloud-sql-mysql-python/" + __version__ +from dataclasses import dataclass + + +@dataclass +class Column: + name: str + data_type: str + nullable: bool = True + + def __post_init__(self): + if not isinstance(self.name, str): + raise ValueError("Column name must be type string") + if not isinstance(self.data_type, str): + raise ValueError("Column data_type must be type string") + def _get_iam_principal_email( credentials: google.auth.credentials.Credentials, @@ -206,6 +221,20 @@ def connect(self) -> sqlalchemy.engine.Connection: """ return self.engine.connect() + def _execute(self, query: str, params: Optional[dict] = None) -> None: + """Execute a SQL query.""" + with self.engine.connect() as conn: + conn.execute(sqlalchemy.text(query), params) + conn.commit() + + def _fetch(self, query: str, params: Optional[dict] = None): + """Fetch results from a SQL query.""" + with self.engine.connect() as conn: + result = conn.execute(sqlalchemy.text(query), params) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + def init_chat_history_table(self, table_name: str) -> None: """Create table with schema required for MySQLChatMessageHistory class. @@ -293,3 +322,51 @@ def _load_document_table(self, table_name: str) -> sqlalchemy.Table: metadata = sqlalchemy.MetaData() sqlalchemy.MetaData.reflect(metadata, bind=self.engine, only=[table_name]) return metadata.tables[table_name] + + def init_vectorstore_table( + self, + table_name: str, + vector_size: int, + content_column: str = "content", + embedding_column: str = "embedding", + metadata_columns: List[Column] = [], + metadata_json_column: str = "langchain_metadata", + id_column: str = "langchain_id", + overwrite_existing: bool = False, + store_metadata: bool = True, + ) -> None: + """ + Create a table for saving of vectors to be used with MySQLVectorStore. + + Args: + table_name (str): The MySQL database table name. + vector_size (int): Vector size for the embedding model to be used. + content_column (str): Name of the column to store document content. + Deafult: `page_content`. + embedding_column (str) : Name of the column to store vector embeddings. + Default: `embedding`. + metadata_columns (List[Column]): A list of Columns to create for custom + metadata. Default: []. Optional. + metadata_json_column (str): The column to store extra metadata in JSON format. + Default: `langchain_metadata`. Optional. + id_column (str): Name of the column to store ids. + Default: `langchain_id`. Optional, + overwrite_existing (bool): Whether to drop existing table. Default: False. + store_metadata (bool): Whether to store metadata in the table. + Default: True. + """ + query = f"""CREATE TABLE `{table_name}`( + `{id_column}` CHAR(36) PRIMARY KEY, + `{content_column}` TEXT NOT NULL, + `{embedding_column}` vector({vector_size}) USING VARBINARY NOT NULL""" + for column in metadata_columns: + nullable = "NOT NULL" if not column.nullable else "" + query += f",\n`{column.name}` {column.data_type} {nullable}" + if store_metadata: + query += f""",\n`{metadata_json_column}` JSON""" + query += "\n);" + + with self.engine.connect() as conn: + if overwrite_existing: + conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`")) + conn.execute(sqlalchemy.text(query)) diff --git a/src/langchain_google_cloud_sql_mysql/indexes.py b/src/langchain_google_cloud_sql_mysql/indexes.py new file mode 100644 index 0000000..d038abb --- /dev/null +++ b/src/langchain_google_cloud_sql_mysql/indexes.py @@ -0,0 +1,22 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC +from dataclasses import dataclass + + +@dataclass +class QueryOptions(ABC): + def to_string(self) -> str: + raise NotImplementedError("to_string method must be implemented by subclass") diff --git a/src/langchain_google_cloud_sql_mysql/vectorstore.py b/src/langchain_google_cloud_sql_mysql/vectorstore.py new file mode 100644 index 0000000..602848e --- /dev/null +++ b/src/langchain_google_cloud_sql_mysql/vectorstore.py @@ -0,0 +1,257 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: Remove below import when minimum supported Python version is 3.10 +from __future__ import annotations + +import json +from typing import Any, Iterable, List, Optional, Type + +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VectorStore + +from .engine import MySQLEngine +from .indexes import QueryOptions + + +class MySQLVectorStore(VectorStore): + def __init__( + self, + engine: MySQLEngine, + embedding_service: Embeddings, + table_name: str, + content_column: str = "content", + embedding_column: str = "embedding", + metadata_columns: List[str] = [], + ignore_metadata_columns: Optional[List[str]] = None, + id_column: str = "langchain_id", + metadata_json_column: Optional[str] = "langchain_metadata", + query_options: Optional[QueryOptions] = None, + ): + """Constructor for MySQLVectorStore. + Args: + engine (MySQLEngine): Connection pool engine for managing + connections to Cloud SQL for MySQL database. + embedding_service (Embeddings): Text embedding model to use. + table_name (str): Name of an existing table or table to be created. + content_column (str): Column that represent a Document's + page_content. Defaults to "content". + embedding_column (str): Column for embedding vectors. The embedding + is generated from the document value. Defaults to "embedding". + metadata_columns (List[str]): Column(s) that represent a document's metadata. + ignore_metadata_columns (List[str]): Column(s) to ignore in + pre-existing tables for a document's metadata. Can not be used + with metadata_columns. Defaults to None. + id_column (str): Column that represents the Document's id. + Defaults to "langchain_id". + metadata_json_column (str): Column to store metadata as JSON. + Defaults to "langchain_metadata". + """ + if metadata_columns and ignore_metadata_columns: + raise ValueError( + "Can not use both metadata_columns and ignore_metadata_columns." + ) + # Get field type information + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}'" + + results = engine._fetch(stmt) + columns = {} + for field in results: + columns[field["COLUMN_NAME"]] = field["DATA_TYPE"] + + # Check columns + if id_column not in columns: + raise ValueError(f"Id column, {id_column}, does not exist.") + if content_column not in columns: + raise ValueError(f"Content column, {content_column}, does not exist.") + content_type = columns[content_column] + if content_type != "text" and "char" not in content_type: + raise ValueError( + f"Content column, {content_column}, is type, {content_type}. It must be a type of character string." + ) + if embedding_column not in columns: + raise ValueError(f"Embedding column, {embedding_column}, does not exist.") + if columns[embedding_column] != "varbinary": + raise ValueError( + f"Embedding column, {embedding_column}, is not type Vector (varbinary)." + ) + + metadata_json_column = ( + None if metadata_json_column not in columns else metadata_json_column + ) + + # If using metadata_columns check to make sure column exists + for column in metadata_columns: + if column not in columns: + raise ValueError(f"Metadata column, {column}, does not exist.") + + # If using ignore_metadata_columns, filter out known columns and set known metadata columns + all_columns = columns + if ignore_metadata_columns: + for column in ignore_metadata_columns: + del all_columns[column] + + del all_columns[id_column] + del all_columns[content_column] + del all_columns[embedding_column] + metadata_columns = [key for key, _ in all_columns.keys()] + + # set all class attributes + self.engine = engine + self.embedding_service = embedding_service + self.table_name = table_name + self.content_column = content_column + self.embedding_column = embedding_column + self.metadata_columns = metadata_columns + self.id_column = id_column + self.metadata_json_column = metadata_json_column + self.query_options = query_options + + @property + def embeddings(self) -> Embeddings: + return self.embedding_service + + def _add_embeddings( + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + if not ids: + ids = ["NULL" for _ in texts] + if not metadatas: + metadatas = [{} for _ in texts] + # Insert embeddings + for id, content, embedding, metadata in zip(ids, texts, embeddings, metadatas): + metadata_col_names = ( + ", " + ", ".join(self.metadata_columns) + if len(self.metadata_columns) > 0 + else "" + ) + insert_stmt = f"INSERT INTO `{self.table_name}`(`{self.id_column}`, `{self.content_column}`, `{self.embedding_column}`{metadata_col_names}" + values = {"id": id, "content": content, "embedding": str(embedding)} + values_stmt = "VALUES (:id, :content, string_to_vector(:embedding)" + + # Add metadata + extra = metadata + for metadata_column in self.metadata_columns: + if metadata_column in metadata: + values_stmt += f", :{metadata_column}" + values[metadata_column] = metadata[metadata_column] + del extra[metadata_column] + else: + values_stmt += ",null" + + # Add JSON column and/or close statement + insert_stmt += ( + f", {self.metadata_json_column})" if self.metadata_json_column else ")" + ) + if self.metadata_json_column: + values_stmt += ", :extra)" + values["extra"] = json.dumps(extra) + else: + values_stmt += ")" + + query = insert_stmt + values_stmt + self.engine._execute(query, values) + + return ids + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + embeddings = self.embedding_service.embed_documents(list(texts)) + ids = self._add_embeddings( + texts, embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + return ids + + @classmethod + def from_texts( # type: ignore[override] + cls: Type[MySQLVectorStore], + texts: List[str], + embedding: Embeddings, + engine: MySQLEngine, + table_name: str, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + content_column: str = "content", + embedding_column: str = "embedding", + metadata_columns: List[str] = [], + ignore_metadata_columns: Optional[List[str]] = None, + id_column: str = "langchain_id", + metadata_json_column: str = "langchain_metadata", + **kwargs: Any, + ): + vs = cls( + engine=engine, + embedding_service=embedding, + table_name=table_name, + content_column=content_column, + embedding_column=embedding_column, + metadata_columns=metadata_columns, + ignore_metadata_columns=ignore_metadata_columns, + id_column=id_column, + metadata_json_column=metadata_json_column, + ) + vs.add_texts(texts, metadatas=metadatas, ids=ids, **kwargs) + return vs + + @classmethod + def from_documents( # type: ignore[override] + cls: Type[MySQLVectorStore], + documents: List[Document], + embedding: Embeddings, + engine: MySQLEngine, + table_name: str, + ids: Optional[List[str]] = None, + content_column: str = "content", + embedding_column: str = "embedding", + metadata_columns: List[str] = [], + ignore_metadata_columns: Optional[List[str]] = None, + id_column: str = "langchain_id", + metadata_json_column: str = "langchain_metadata", + **kwargs: Any, + ) -> MySQLVectorStore: + vs = cls( + engine=engine, + embedding_service=embedding, + table_name=table_name, + content_column=content_column, + embedding_column=embedding_column, + metadata_columns=metadata_columns, + ignore_metadata_columns=ignore_metadata_columns, + id_column=id_column, + metadata_json_column=metadata_json_column, + ) + texts = [doc.page_content for doc in documents] + metadatas = [doc.metadata for doc in documents] + vs.add_texts(texts, metadatas=metadatas, ids=ids, **kwargs) + return vs + + def similarity_search( + self, + query: str, + k: Optional[int] = None, + filter: Optional[str] = None, + **kwargs: Any, + ): + raise NotImplementedError diff --git a/tests/integration/test_mysql_vectorstore.py b/tests/integration/test_mysql_vectorstore.py new file mode 100644 index 0000000..a362cea --- /dev/null +++ b/tests/integration/test_mysql_vectorstore.py @@ -0,0 +1,182 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid + +import pytest +from langchain_community.embeddings import DeterministicFakeEmbedding +from langchain_core.documents import Document + +from langchain_google_cloud_sql_mysql import Column, MySQLEngine, MySQLVectorStore + +DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_TABLE = "test-table-custom" + str(uuid.uuid4()) +VECTOR_SIZE = 768 + +embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) + +texts = ["foo", "bar", "baz"] +metadatas = [{"page": str(i), "source": "google.com"} for i in range(len(texts))] +docs = [ + Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) +] + +embeddings = [embeddings_service.embed_query(texts[i]) for i in range(len(texts))] + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +class TestVectorStore: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for cloud sql instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for cloud sql") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DB_NAME", "database name on cloud sql instance") + + @pytest.fixture(scope="class") + def engine(self, db_project, db_region, db_instance, db_name): + engine = MySQLEngine.from_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + + yield engine + + @pytest.fixture(scope="function") + def vs(self, engine): + engine.init_vectorstore_table( + DEFAULT_TABLE, + VECTOR_SIZE, + overwrite_existing=True, + ) + + vs = MySQLVectorStore( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + ) + yield vs + engine._execute(f"DROP TABLE IF EXISTS `{DEFAULT_TABLE}`") + + @pytest.fixture(scope="function") + def vs_custom(self, engine): + engine.init_vectorstore_table( + CUSTOM_TABLE, + VECTOR_SIZE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + metadata_json_column="mymeta", + overwrite_existing=True, + ) + + vs = MySQLVectorStore( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["page", "source"], + metadata_json_column="mymeta", + ) + yield vs + engine._execute(f"DROP TABLE IF EXISTS `{CUSTOM_TABLE}`") + + def test_post_init(self, engine): + with pytest.raises(ValueError): + MySQLVectorStore( + engine, + embedding_service=embeddings_service, + table_name=CUSTOM_TABLE, + id_column="myid", + content_column="noname", + embedding_column="myembedding", + metadata_columns=["page", "source"], + metadata_json_column="mymeta", + ) + + def test_add_texts(self, engine, vs): + ids = [str(uuid.uuid4()) for _ in range(len(texts))] + vs.add_texts(texts, ids=ids) + results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`") + assert len(results) == 3 + + ids = [str(uuid.uuid4()) for _ in range(len(texts))] + vs.add_texts(texts, metadatas, ids) + results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`") + assert len(results) == 6 + engine._execute(f"TRUNCATE TABLE `{DEFAULT_TABLE}`") + + def test_add_texts_edge_cases(self, engine, vs): + texts = ["Taylor's", '"Swift"', "best-friend"] + ids = [str(uuid.uuid4()) for _ in range(len(texts))] + vs.add_texts(texts, ids=ids) + results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`") + assert len(results) == 3 + engine._execute(f"TRUNCATE TABLE `{DEFAULT_TABLE}`") + + def test_add_embedding(self, engine, vs): + ids = [str(uuid.uuid4()) for _ in range(len(texts))] + vs._add_embeddings(texts, embeddings, metadatas, ids) + results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`") + assert len(results) == 3 + engine._execute(f"TRUNCATE TABLE `{DEFAULT_TABLE}`") + + def test_add_texts_custom(self, engine, vs_custom): + ids = [str(uuid.uuid4()) for _ in range(len(texts))] + vs_custom.add_texts(texts, ids=ids) + results = engine._fetch(f"SELECT * FROM `{CUSTOM_TABLE}`") + content = [result["mycontent"] for result in results] + assert len(results) == 3 + assert "foo" in content + assert "bar" in content + assert "baz" in content + assert results[0]["myembedding"] + assert results[0]["page"] is None + assert results[0]["source"] is None + + ids = [str(uuid.uuid4()) for _ in range(len(texts))] + vs_custom.add_texts(texts, metadatas, ids) + results = engine._fetch(f"SELECT * FROM `{CUSTOM_TABLE}`") + assert len(results) == 6 + engine._execute(f"TRUNCATE TABLE `{CUSTOM_TABLE}`") + + def test_add_embedding_custom(self, engine, vs_custom): + ids = [str(uuid.uuid4()) for _ in range(len(texts))] + vs_custom._add_embeddings(texts, embeddings, metadatas, ids) + results = engine._fetch(f"SELECT * FROM `{CUSTOM_TABLE}`") + assert len(results) == 3 + engine._execute(f"TRUNCATE TABLE `{CUSTOM_TABLE}`") + + # Need tests for store metadata=False diff --git a/tests/integration/test_mysql_vectorstore_from_methods.py b/tests/integration/test_mysql_vectorstore_from_methods.py new file mode 100644 index 0000000..c165559 --- /dev/null +++ b/tests/integration/test_mysql_vectorstore_from_methods.py @@ -0,0 +1,169 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid + +import pytest +from langchain_community.embeddings import DeterministicFakeEmbedding +from langchain_core.documents import Document + +from langchain_google_cloud_sql_mysql import Column, MySQLEngine, MySQLVectorStore + +DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_") +CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") +VECTOR_SIZE = 768 + + +embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) + +texts = ["foo", "bar", "baz"] +metadatas = [{"page": str(i), "source": "google.com"} for i in range(len(texts))] +docs = [ + Document(page_content=texts[i], metadata=metadatas[i]) for i in range(len(texts)) +] + +embeddings = [embeddings_service.embed_query(texts[i]) for i in range(len(texts))] + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +class TestVectorStoreFromMethods: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for cloud sql instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for cloud sql") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DB_NAME", "database name on cloud sql instance") + + @pytest.fixture + def engine(self, db_project, db_region, db_instance, db_name): + engine = MySQLEngine.from_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + engine.init_vectorstore_table(DEFAULT_TABLE, VECTOR_SIZE) + engine.init_vectorstore_table( + CUSTOM_TABLE, + VECTOR_SIZE, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=[Column("page", "TEXT"), Column("source", "TEXT")], + store_metadata=False, + ) + yield engine + engine._execute(f"DROP TABLE IF EXISTS `{DEFAULT_TABLE}`") + engine._execute(f"DROP TABLE IF EXISTS `{CUSTOM_TABLE}`") + + def test_from_texts(self, engine): + ids = [str(uuid.uuid4()) for i in range(len(texts))] + MySQLVectorStore.from_texts( + texts, + embeddings_service, + engine, + DEFAULT_TABLE, + metadatas=metadatas, + ids=ids, + ) + results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`") + assert len(results) == 3 + engine._execute(f"TRUNCATE TABLE `{DEFAULT_TABLE}`") + + def test_from_docs(self, engine): + ids = [str(uuid.uuid4()) for i in range(len(texts))] + MySQLVectorStore.from_documents( + docs, + embeddings_service, + engine, + DEFAULT_TABLE, + ids=ids, + ) + results = engine._fetch(f"SELECT * FROM `{DEFAULT_TABLE}`") + assert len(results) == 3 + engine._execute(f"TRUNCATE TABLE `{DEFAULT_TABLE}`") + + def test_from_texts_custom(self, engine): + ids = [str(uuid.uuid4()) for i in range(len(texts))] + MySQLVectorStore.from_texts( + texts, + embeddings_service, + engine, + CUSTOM_TABLE, + ids=ids, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["page", "source"], + ) + results = engine._fetch(f"SELECT * FROM `{CUSTOM_TABLE}`") + content = [result["mycontent"] for result in results] + assert len(results) == 3 + assert "foo" in content + assert "bar" in content + assert "baz" in content + assert results[0]["myembedding"] + assert results[0]["page"] is None + assert results[0]["source"] is None + + def test_from_docs_custom(self, engine): + ids = [str(uuid.uuid4()) for i in range(len(texts))] + docs = [ + Document( + page_content=texts[i], + metadata={"page": str(i), "source": "google.com"}, + ) + for i in range(len(texts)) + ] + MySQLVectorStore.from_documents( + docs, + embeddings_service, + engine, + CUSTOM_TABLE, + ids=ids, + id_column="myid", + content_column="mycontent", + embedding_column="myembedding", + metadata_columns=["page", "source"], + ) + + results = engine._fetch(f"SELECT * FROM `{CUSTOM_TABLE}`") + content = [result["mycontent"] for result in results] + assert len(results) == 3 + assert "foo" in content + assert "bar" in content + assert "baz" in content + assert results[0]["myembedding"] + pages = [result["page"] for result in results] + assert "0" in pages + assert "1" in pages + assert "2" in pages + assert results[0]["source"] == "google.com" + engine._execute(f"TRUNCATE TABLE `{CUSTOM_TABLE}`")