Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add MySQLVectorStore initialization methods #52

Merged
merged 8 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integration.cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ availableSecrets:
env: "DB_PASSWORD"

substitutions:
_INSTANCE_ID: test-instance
_INSTANCE_ID: mysql-vector
_REGION: us-central1
_DB_NAME: test
_VERSION: "3.8"
Expand Down
5 changes: 4 additions & 1 deletion src/langchain_google_cloud_sql_mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
# limitations under the License.

from .chat_message_history import MySQLChatMessageHistory
from .engine import MySQLEngine
from .engine import Column, MySQLEngine
from .loader import MySQLDocumentSaver, MySQLLoader
from .vectorstore import MySQLVectorStore
from .version import __version__

__all__ = [
"Column",
"MySQLChatMessageHistory",
"MySQLDocumentSaver",
"MySQLEngine",
"MySQLLoader",
"MySQLVectorStore",
"__version__",
]
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class MySQLChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a Cloud SQL MySQL database.

Args:
engine (MySQLEngine): SQLAlchemy connection pool engine for managing
engine (MySQLEngine): Connection pool engine for managing
connections to Cloud SQL for MySQL.
session_id (str): Arbitrary key that is used to store the messages
of a single chat session.
Expand Down
77 changes: 77 additions & 0 deletions src/langchain_google_cloud_sql_mysql/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,21 @@

USER_AGENT = "langchain-google-cloud-sql-mysql-python/" + __version__

from dataclasses import dataclass


@dataclass
class Column:
name: str
data_type: str
nullable: bool = True

def __post_init__(self):
if not isinstance(self.name, str):
raise ValueError("Column name must be type string")
if not isinstance(self.data_type, str):
raise ValueError("Column data_type must be type string")


def _get_iam_principal_email(
credentials: google.auth.credentials.Credentials,
Expand Down Expand Up @@ -206,6 +221,20 @@ def connect(self) -> sqlalchemy.engine.Connection:
"""
return self.engine.connect()

def _execute(self, query: str, params: Optional[dict] = None) -> None:
"""Execute a SQL query."""
with self.engine.connect() as conn:
conn.execute(sqlalchemy.text(query), params)
conn.commit()

def _fetch(self, query: str, params: Optional[dict] = None):
"""Fetch results from a SQL query."""
with self.engine.connect() as conn:
result = conn.execute(sqlalchemy.text(query), params)
result_map = result.mappings()
result_fetch = result_map.fetchall()
return result_fetch

def init_chat_history_table(self, table_name: str) -> None:
"""Create table with schema required for MySQLChatMessageHistory class.

Expand Down Expand Up @@ -293,3 +322,51 @@ def _load_document_table(self, table_name: str) -> sqlalchemy.Table:
metadata = sqlalchemy.MetaData()
sqlalchemy.MetaData.reflect(metadata, bind=self.engine, only=[table_name])
return metadata.tables[table_name]

def init_vectorstore_table(
self,
table_name: str,
vector_size: int,
content_column: str = "content",
embedding_column: str = "embedding",
metadata_columns: List[Column] = [],
metadata_json_column: str = "langchain_metadata",
id_column: str = "langchain_id",
overwrite_existing: bool = False,
store_metadata: bool = True,
) -> None:
"""
Create a table for saving of vectors to be used with MySQLVectorStore.

Args:
table_name (str): The MySQL database table name.
vector_size (int): Vector size for the embedding model to be used.
content_column (str): Name of the column to store document content.
Deafult: `page_content`.
embedding_column (str) : Name of the column to store vector embeddings.
Default: `embedding`.
metadata_columns (List[Column]): A list of Columns to create for custom
metadata. Default: []. Optional.
metadata_json_column (str): The column to store extra metadata in JSON format.
Default: `langchain_metadata`. Optional.
id_column (str): Name of the column to store ids.
Default: `langchain_id`. Optional,
overwrite_existing (bool): Whether to drop existing table. Default: False.
jackwotherspoon marked this conversation as resolved.
Show resolved Hide resolved
store_metadata (bool): Whether to store metadata in the table.
Default: True.
"""
query = f"""CREATE TABLE `{table_name}`(
`{id_column}` CHAR(36) PRIMARY KEY,
`{content_column}` TEXT NOT NULL,
`{embedding_column}` vector({vector_size}) USING VARBINARY NOT NULL"""
for column in metadata_columns:
nullable = "NOT NULL" if not column.nullable else ""
query += f",\n`{column.name}` {column.data_type} {nullable}"
if store_metadata:
query += f""",\n`{metadata_json_column}` JSON"""
query += "\n);"

with self.engine.connect() as conn:
if overwrite_existing:
conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`"))
conn.execute(sqlalchemy.text(query))
22 changes: 22 additions & 0 deletions src/langchain_google_cloud_sql_mysql/indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from dataclasses import dataclass


@dataclass
class QueryOptions(ABC):
def to_string(self) -> str:
raise NotImplementedError("to_string method must be implemented by subclass")
Loading
Loading